# TTP / mmdet / models / task_modules / assigners / topk_hungarian_assigner.py
# KyanChen's picture
# Upload 1861 files
# 3b96cb1
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.structures import BaseDataElement
from scipy.optimize import linear_sum_assignment
from mmdet.registry import TASK_UTILS
from .assign_result import AssignResult
from .task_aligned_assigner import TaskAlignedAssigner
@TASK_UTILS.register_module()
class TopkHungarianAssigner(TaskAlignedAssigner):
    """Computes 1-to-k matching between ground truth and predictions.

    This class computes an assignment between the targets and the predictions
    based on the costs. The costs are weighted sum of some components.
    For DETR the costs are weighted sum of classification cost, regression L1
    cost and regression iou cost. The targets don't include the no_object, so
    generally there are more predictions than targets. After the 1-to-k
    gt-pred matching, the un-matched are treated as backgrounds. Thus each
    query prediction will be assigned with `0` or a positive integer
    indicating the ground truth index:

    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    NOTE: ``self.topk`` and ``self.iou_calculator`` are read by
    :meth:`assign` but are not set in this class; they are expected to be
    provided by the parent ``TaskAlignedAssigner`` constructor.

    Args:
        *args: Positional arguments forwarded to the parent constructor.
        cls_cost (dict): Classification cost configuration.
        reg_cost (dict): Regression L1 cost configuration.
        iou_cost (dict): Regression iou cost configuration.
        **kwargs: Keyword arguments forwarded to the parent constructor.
    """

    def __init__(self,
                 *args,
                 cls_cost=dict(type='FocalLossCost', weight=2.0),
                 reg_cost=dict(type='BBoxL1Cost', weight=5.0),
                 iou_cost=dict(type='IoUCost', iou_mode='giou', weight=2.0),
                 **kwargs):
        super().__init__(*args, **kwargs)
        # Build the three cost modules from their configs via the registry.
        self.cls_cost = TASK_UTILS.build(cls_cost)
        self.reg_cost = TASK_UTILS.build(reg_cost)
        self.iou_cost = TASK_UTILS.build(iou_cost)

    def assign(self,
               pred_scores,
               decode_bboxes,
               gt_bboxes,
               gt_labels,
               img_meta,
               alpha=1,
               beta=6,
               **kwargs):
        """Computes 1-to-k gt-pred matching based on the weighted costs.

        This method assigns each query prediction to a ground truth or to
        background. In the returned ``gt_inds``, 0 means negative sample
        (background) and a positive number is the index (1-based) of the
        assigned gt.

        The assignment is done in the following steps, the order matters.

        1. If there is no ground truth, assign every prediction to 0
           (background) and return immediately, without computing costs.
        2. Compute the weighted costs, each cost has shape
           (num_pred, num_gt), and sum them.
        3. Update topk to be min(topk, num_pred // num_gt), then repeat
           costs topk times to shape: (num_pred, num_gt * topk), so that each
           gt will match topk predictions.
        4. Do Hungarian matching on CPU based on the costs.
        5. Assign all to 0 (background) first, then for each matched pair
           between predictions and gts, treat this prediction as foreground
           and assign the corresponding gt index (plus 1) to it.
        6. Calculate alignment metrics and overlaps of each matched pred-gt
           pair.

        Args:
            pred_scores (Tensor): Predicted normalized classification
                scores for one image, has shape (num_dense_queries,
                cls_out_channels).
            decode_bboxes (Tensor): Predicted unnormalized bbox coordinates
                for one image, has shape (num_dense_queries, 4) with the
                last dimension arranged as (x1, y1, x2, y2).
            gt_bboxes (Tensor): Unnormalized ground truth
                bboxes for one image, has shape (num_gt, 4) with the
                last dimension arranged as (x1, y1, x2, y2).
                NOTE: num_gt is dynamic for each image.
            gt_labels (Tensor): Ground truth classification
                index for the image, has shape (num_gt,).
                NOTE: num_gt is dynamic for each image.
            img_meta (dict): Meta information for one image. It is passed
                through to the cost modules.
            alpha (int): Hyper-parameters related to alignment_metrics
                (exponent applied to the classification score).
                Defaults to 1.
            beta (int): Hyper-parameters related to alignment_metrics
                (exponent applied to the IoU). Defaults to 6.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            :obj:`AssignResult`: The assigned result. When there is at least
            one ground truth, the result additionally carries an
            ``assign_metrics`` attribute holding the alignment metric of
            every matched pred-gt pair, and its overlaps are the IoUs of the
            matched pairs. When there is no ground truth, overlaps are
            ``None`` and ``assign_metrics`` is not set.
        """
        pred_scores = pred_scores.detach()
        decode_bboxes = decode_bboxes.detach()
        num_gt, num_bboxes = gt_bboxes.size(0), pred_scores.size(0)

        if num_gt == 0:
            # Nothing to match against: skip all IoU / cost computation and
            # mark every prediction as background (gt index 0, label -1).
            assigned_gt_inds = pred_scores.new_full((num_bboxes, ),
                                                    0,
                                                    dtype=torch.long)
            assigned_labels = pred_scores.new_full((num_bboxes, ),
                                                   -1,
                                                   dtype=torch.long)
            return AssignResult(
                0, assigned_gt_inds, None, labels=assigned_labels)

        # Alignment metric of every (pred, gt) pair:
        # score_of_gt_class ** alpha * iou ** beta.
        temp_overlaps = self.iou_calculator(decode_bboxes, gt_bboxes).detach()
        bbox_scores = pred_scores[:, gt_labels].detach()
        alignment_metrics = bbox_scores**alpha * temp_overlaps**beta

        # Wrap the inputs in the containers consumed by the cost modules.
        pred_instances = BaseDataElement()
        gt_instances = BaseDataElement()
        pred_instances.bboxes = decode_bboxes
        gt_instances.bboxes = gt_bboxes
        pred_instances.scores = pred_scores
        gt_instances.labels = gt_labels

        reg_cost = self.reg_cost(pred_instances, gt_instances, img_meta)
        iou_cost = self.iou_cost(pred_instances, gt_instances, img_meta)
        cls_cost = self.cls_cost(pred_instances, gt_instances, img_meta)
        all_cost = cls_cost + reg_cost + iou_cost

        # assign 0 (background) by default
        assigned_gt_inds = pred_scores.new_full((num_bboxes, ),
                                                0,
                                                dtype=torch.long)

        # Each gt can take at most num_pred // num_gt predictions, otherwise
        # there would not be enough predictions for every gt copy.
        topk = min(self.topk, all_cost.size(0) // num_gt)

        # Repeat the ground truth `topk` times to perform 1-to-k gt-pred
        # matching. For example, if `num_pred` = 900, `num_gt` = 3, then
        # there are only 3 gt-pred pairs in sum for 1-1 matching.
        # However, for 1-k gt-pred matching, if `topk` = 4, then each
        # gt is assigned 4 unique predictions, so there would be 12
        # gt-pred pairs in sum.
        # Column layout after the view: [gt0]*topk, [gt1]*topk, ...
        repeat_select_cost = all_cost[..., None].repeat(1, 1, topk).view(
            all_cost.size(0), -1)

        # Hungarian matching runs on CPU (scipy); rows are prediction
        # indices, columns are indices into the repeated gt axis.
        matched_row_inds, matched_col_inds = linear_sum_assignment(
            repeat_select_cost.detach().cpu().numpy())
        matched_row_inds = torch.from_numpy(matched_row_inds).to(
            pred_scores.device)
        matched_col_inds = torch.from_numpy(matched_col_inds).to(
            pred_scores.device)

        # Map a column of the repeated matrix back to its original gt index.
        match_gt_ids = matched_col_inds // topk
        candidate_idxs = matched_row_inds

        assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
        if candidate_idxs.numel() > 0:
            assigned_labels[candidate_idxs] = gt_labels[match_gt_ids]
        else:
            # No prediction was matched; labels are reported as None in
            # this case (behaviour kept from the original implementation).
            assigned_labels = None
        assigned_gt_inds[candidate_idxs] = match_gt_ids + 1

        # IoU of each matched (pred, gt) pair.
        overlaps = self.iou_calculator(
            decode_bboxes[candidate_idxs],
            gt_bboxes[match_gt_ids],
            is_aligned=True).detach()

        # Pick, for every matched prediction, the alignment metric that
        # belongs to the gt it was matched with.
        temp_pos_alignment_metrics = alignment_metrics[candidate_idxs]
        pos_alignment_metrics = torch.gather(
            temp_pos_alignment_metrics, 1, match_gt_ids[:, None]).view(-1)

        assign_result = AssignResult(
            num_gt, assigned_gt_inds, overlaps, labels=assigned_labels)
        assign_result.assign_metrics = pos_alignment_metrics
        return assign_result