# -----------------------------------------------------
# ConsNet
# Licensed under the GNU General Public License v3.0
# Written by Ye Liu (ye-liu at whu.edu.cn)
# -----------------------------------------------------
import nncore
import torch
from nncore.ops import bbox_iou
[docs]def pair_iou(bboxes1, bboxes2):
"""
Compute the intersection-over-unions (IoUs) among human-object pairs.
Args:
bboxes1 (:obj:`Tensor[N, 8]`): Human-object pairs to be computed. They
are expected to be in ``(x1, y1, x2, y2, ...)`` format.
bboxes2 (:obj:`Tensor[M, 8]`): Human-object pairs to be computed. They
are expected to be in ``(x1, y1, x2, y2, ...)`` format.
Returns:
:obj:`Tensor[N, M]`: The computed pairwise IoU values
"""
assert bboxes1.size(1) == bboxes2.size(1) == 8
h_iou = bbox_iou(bboxes1[:, :4], bboxes2[:, :4])
o_iou = bbox_iou(bboxes1[:, 4:], bboxes2[:, 4:])
return torch.min(h_iou, o_iou)
[docs]def pair_nms(bboxes,
scores,
method='fast',
hard_thr=0.5,
soft_thr=0.3,
sigma=0.5,
score_thr=1e-6):
"""
Perform non-maximum suppression (NMS) on human-object pairs. This method
supports multiple NMS types including Fast NMS [1], Cluster NMS [2], Normal
NMS [3] and Soft NMS [4] with linear or gaussian suppression terms.
Args:
bboxes (:obj:`Tensor[N, 9]`): Batches of human-object pairs to be
suppressed. The values are expected to be in
``(batch_id, x1, y1, x2, y2, ...)`` format.
scores (:obj:`Tensor[N]`): Human-object interaction detection scores
to be considered.
method (str, optional): Type of NMS. Expected values include
``'fast'``, ``'cluster'``, ``'normal'``, ``'linear'`` and
``'gaussian'``, indicating Fast NMS, Cluster NMS, Normal NMS and
Soft NMS with linear or gaussian suppression terms.
hard_thr (float, optional): Hard threshold of NMS. This attribute
is applied to all NMS methods. Human-object pairs with IoUs higher
than this value will be discarded.
soft_thr (float, optional): Soft threshold of NMS. This attribute
is only applied to ``linear`` and ``gaussian`` methods.
Human-object pairs with IoUs lower than ``hard_thr`` but higher
than this value will be suppressed in a soft manner.
sigma (float, optional): Hyperparameter for ``gaussian`` method.
score_thr (float, optional): Score threshold. This attribute is
applied to ``normal``, ``linear`` and ``gaussian`` methods.
Human-object pairs with suppressed scores lower than this value
will be discarded.
Returns:
:obj:`Tensor[N, 10]`: Human-object pairs and their updated scores \
after NMS. The values are expected to be in \
``(batch_id, x1, y1, x2, y2, ..., score)`` format.
References:
1. Bolya et al. (https://arxiv.org/abs/1904.02689)
2. Zheng et al. (https://arxiv.org/abs/2005.03572)
3. Neubeck er al. (https://doi.org/10.1109/icpr.2006.479)
4. Bodla et al. (https://arxiv.org/abs/1704.04503)
"""
assert bboxes.size(1) == 9
assert bboxes.size(0) == scores.size(0)
assert method in ('fast', 'cluster', 'normal', 'linear', 'gaussian')
if (num_bboxes := bboxes.size(0)) == 0:
return torch.cat((bboxes, scores[:, None]), dim=1)
if method in ('fast', 'cluster'):
batch_ids = bboxes[:, None, 0]
coors = (c := bboxes[:, 1:]) + batch_ids * (c.max() + 1)
scores, inds = scores.sort(descending=True)
bboxes, coors = bboxes[inds], coors[inds]
iou = pair_iou(coors, coors).triu(diagonal=1)
if method == 'fast':
keep = iou.amax(dim=0) <= hard_thr
else:
c = iou
for _ in range(num_bboxes):
max_iou = (a := c).amax(dim=0)
c = iou * (max_iou < hard_thr)[:, None].float().expand_as(a)
if torch.equal(a, c):
break
keep = max_iou < hard_thr
blob = torch.cat((bboxes[keep], scores[keep, None]), dim=1)
else:
batch_ids, collected = bboxes[:, 0].unique(), []
for batch_id in batch_ids:
keep = bboxes[:, 0] == batch_id
blob = torch.cat((bboxes[keep], scores[keep, None]), dim=1)
num_bboxes = blob.size(0)
for i in range(num_bboxes - 1):
max_score, max_idx = blob[i:, -1].max(dim=0)
if max_score < score_thr:
blob = blob[:i]
break
blob = nncore.swap_element(blob, i, max_idx + i)
iou = pair_iou(blob[i, None, 1:9], blob[i + 1:, 1:9])[0]
blob[i + 1:, -1][iou >= hard_thr] = 0
if method == 'normal':
continue
keep = iou >= soft_thr
if method == 'linear':
blob[i + 1:, -1][keep] *= 1 - iou[keep]
else:
blob[i + 1:, -1][keep] *= (-iou[keep].pow(2) / sigma).exp()
collected.append(blob)
blob = torch.cat(collected)
return blob