# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.structures import BaseDataElement
from scipy.optimize import linear_sum_assignment

from mmdet.registry import TASK_UTILS

from .assign_result import AssignResult
from .task_aligned_assigner import TaskAlignedAssigner
class TopkHungarianAssigner(TaskAlignedAssigner): | |
"""Computes 1-to-k matching between ground truth and predictions. | |
This class computes an assignment between the targets and the predictions | |
based on the costs. The costs are weighted sum of some components. | |
For DETR the costs are weighted sum of classification cost, regression L1 | |
cost and regression iou cost. The targets don't include the no_object, so | |
generally there are more predictions than targets. After the 1-to-k | |
gt-pred matching, the un-matched are treated as backgrounds. Thus each | |
query prediction will be assigned with `0` or a positive integer | |
indicating the ground truth index: | |
- 0: negative sample, no assigned gt | |
- positive integer: positive sample, index (1-based) of assigned gt | |
Args: | |
cls_cost (dict): Classification cost configuration. | |
reg_cost (dict): Regression L1 cost configuration. | |
iou_cost (dict): Regression iou cost configuration. | |
""" | |
def __init__(self, | |
*args, | |
cls_cost=dict(type='FocalLossCost', weight=2.0), | |
reg_cost=dict(type='BBoxL1Cost', weight=5.0), | |
iou_cost=dict(type='IoUCost', iou_mode='giou', weight=2.0), | |
**kwargs): | |
super(TopkHungarianAssigner, self).__init__(*args, **kwargs) | |
self.cls_cost = TASK_UTILS.build(cls_cost) | |
self.reg_cost = TASK_UTILS.build(reg_cost) | |
self.iou_cost = TASK_UTILS.build(iou_cost) | |
def assign(self, | |
pred_scores, | |
decode_bboxes, | |
gt_bboxes, | |
gt_labels, | |
img_meta, | |
alpha=1, | |
beta=6, | |
**kwargs): | |
"""Computes 1-to-k gt-pred matching based on the weighted costs. | |
This method assign each query prediction to a ground truth or | |
background. The `assigned_gt_inds` with -1 means don't care, | |
0 means negative sample, and positive number is the index (1-based) | |
of assigned gt. | |
The assignment is done in the following steps, the order matters. | |
1. Assign every prediction to -1. | |
2. Compute the weighted costs, each cost has shape (num_pred, num_gt). | |
3. Update topk to be min(topk, int(num_pred / num_gt)), then repeat | |
costs topk times to shape: (num_pred, num_gt * topk), so that each | |
gt will match topk predictions. | |
3. Do Hungarian matching on CPU based on the costs. | |
4. Assign all to 0 (background) first, then for each matched pair | |
between predictions and gts, treat this prediction as foreground | |
and assign the corresponding gt index (plus 1) to it. | |
5. Calculate alignment metrics and overlaps of each matched pred-gt | |
pair. | |
Args: | |
pred_scores (Tensor): Predicted normalized classification | |
scores for one image, has shape (num_dense_queries, | |
cls_out_channels). | |
decode_bboxes (Tensor): Predicted unnormalized bbox coordinates | |
for one image, has shape (num_dense_queries, 4) with the | |
last dimension arranged as (x1, y1, x2, y2). | |
gt_bboxes (Tensor): Unnormalized ground truth | |
bboxes for one image, has shape (num_gt, 4) with the | |
last dimension arranged as (x1, y1, x2, y2). | |
NOTE: num_gt is dynamic for each image. | |
gt_labels (Tensor): Ground truth classification | |
index for the image, has shape (num_gt,). | |
NOTE: num_gt is dynamic for each image. | |
img_meta (dict): Meta information for one image. | |
alpha (int): Hyper-parameters related to alignment_metrics. | |
Defaults to 1. | |
beta (int): Hyper-parameters related to alignment_metrics. | |
Defaults to 6. | |
Returns: | |
:obj:`AssignResult`: The assigned result. | |
""" | |
pred_scores = pred_scores.detach() | |
decode_bboxes = decode_bboxes.detach() | |
temp_overlaps = self.iou_calculator(decode_bboxes, gt_bboxes).detach() | |
bbox_scores = pred_scores[:, gt_labels].detach() | |
alignment_metrics = bbox_scores**alpha * temp_overlaps**beta | |
pred_instances = BaseDataElement() | |
gt_instances = BaseDataElement() | |
pred_instances.bboxes = decode_bboxes | |
gt_instances.bboxes = gt_bboxes | |
pred_instances.scores = pred_scores | |
gt_instances.labels = gt_labels | |
reg_cost = self.reg_cost(pred_instances, gt_instances, img_meta) | |
iou_cost = self.iou_cost(pred_instances, gt_instances, img_meta) | |
cls_cost = self.cls_cost(pred_instances, gt_instances, img_meta) | |
all_cost = cls_cost + reg_cost + iou_cost | |
num_gt, num_bboxes = gt_bboxes.size(0), pred_scores.size(0) | |
if num_gt > 0: | |
# assign 0 by default | |
assigned_gt_inds = pred_scores.new_full((num_bboxes, ), | |
0, | |
dtype=torch.long) | |
select_cost = all_cost | |
topk = min(self.topk, int(len(select_cost) / num_gt)) | |
# Repeat the ground truth `topk` times to perform 1-to-k gt-pred | |
# matching. For example, if `num_pred` = 900, `num_gt` = 3, then | |
# there are only 3 gt-pred pairs in sum for 1-1 matching. | |
# However, for 1-k gt-pred matching, if `topk` = 4, then each | |
# gt is assigned 4 unique predictions, so there would be 12 | |
# gt-pred pairs in sum. | |
repeat_select_cost = select_cost[..., | |
None].repeat(1, 1, topk).view( | |
select_cost.size(0), -1) | |
# anchor index and gt index | |
matched_row_inds, matched_col_inds = linear_sum_assignment( | |
repeat_select_cost.detach().cpu().numpy()) | |
matched_row_inds = torch.from_numpy(matched_row_inds).to( | |
pred_scores.device) | |
matched_col_inds = torch.from_numpy(matched_col_inds).to( | |
pred_scores.device) | |
match_gt_ids = matched_col_inds // topk | |
candidate_idxs = matched_row_inds | |
assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1) | |
if candidate_idxs.numel() > 0: | |
assigned_labels[candidate_idxs] = gt_labels[match_gt_ids] | |
else: | |
assigned_labels = None | |
assigned_gt_inds[candidate_idxs] = match_gt_ids + 1 | |
overlaps = self.iou_calculator( | |
decode_bboxes[candidate_idxs], | |
gt_bboxes[match_gt_ids], | |
is_aligned=True).detach() | |
temp_pos_alignment_metrics = alignment_metrics[candidate_idxs] | |
pos_alignment_metrics = torch.gather(temp_pos_alignment_metrics, 1, | |
match_gt_ids[:, | |
None]).view(-1) | |
assign_result = AssignResult( | |
num_gt, assigned_gt_inds, overlaps, labels=assigned_labels) | |
assign_result.assign_metrics = pos_alignment_metrics | |
return assign_result | |
else: | |
assigned_gt_inds = pred_scores.new_full((num_bboxes, ), | |
-1, | |
dtype=torch.long) | |
assigned_labels = pred_scores.new_full((num_bboxes, ), | |
-1, | |
dtype=torch.long) | |
assigned_gt_inds[:] = 0 | |
return AssignResult( | |
0, assigned_gt_inds, None, labels=assigned_labels) | |