# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Jeffrey Ouyang-Zhang
from typing import List, Tuple

import torch
import torch.nn as nn

from util.box_ops import (box_cxcywh_to_xyxy, box_iou, box_xyxy_to_cxcywh,
                          generalized_box_iou)


# from https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/layers/wrappers.py#L100
def nonzero_tuple(x):
    """
    A 'as_tuple=True' version of torch.nonzero to support torchscript.
    because of https://github.com/pytorch/pytorch/issues/38718

    Args:
        x (Tensor): tensor whose non-zero positions are wanted.

    Returns:
        tuple[Tensor]: one 1D index tensor per dimension of `x`
        (a 0-dim `x` is treated as a 1-element vector when scripting).
    """
    if torch.jit.is_scripting():
        if x.dim() == 0:
            return x.unsqueeze(0).nonzero().unbind(1)
        return x.nonzero().unbind(1)
    else:
        return x.nonzero(as_tuple=True)


# from https://github.com/facebookresearch/detectron2/blob/9921a2caa585d4fa66c4b534b6fab6e74d89b582/detectron2/modeling/matcher.py#L9
class Matcher(object):
    """
    This class assigns to each predicted "element" (e.g., a box) a ground-truth
    element. Each predicted element will have exactly zero or one matches; each
    ground-truth element may be matched to zero or more predicted elements.

    The matching is determined by the MxN match_quality_matrix, that characterizes
    how well each (ground-truth, prediction)-pair match each other. For example,
    if the elements are boxes, this matrix may contain box intersection-over-union
    overlap values.

    The matcher returns (a) a vector of length N containing the index of the
    ground-truth element m in [0, M) that matches to prediction n in [0, N).
    (b) a vector of length N containing the labels for each prediction.
    """

    def __init__(
        self, thresholds: List[float], labels: List[int], allow_low_quality_matches: bool = False
    ):
        """
        Args:
            thresholds (list): a list of thresholds used to stratify predictions
                into levels.
            labels (list): a list of values to label predictions belonging at
                each level. A label can be one of {-1, 0, 1} signifying
                {ignore, negative class, positive class}, respectively.
            allow_low_quality_matches (bool): if True, produce additional matches
                for predictions with maximum match quality lower than high_threshold.
                See set_low_quality_matches_ for more details.

            For example,
                thresholds = [0.3, 0.5]
                labels = [0, -1, 1]
                All predictions with iou < 0.3 will be marked with 0 and
                thus will be considered as false positives while training.
                All predictions with 0.3 <= iou < 0.5 will be marked with -1 and
                thus will be ignored.
                All predictions with 0.5 <= iou will be marked with 1 and
                thus will be considered as true positives.
        """
        # Add -inf and +inf to first and last position in thresholds.
        # Work on a copy so the caller's list is not mutated.
        thresholds = thresholds[:]
        assert thresholds[0] > 0
        thresholds.insert(0, -float("inf"))
        thresholds.append(float("inf"))
        # Currently torchscript does not support all + generator
        assert all([low <= high for (low, high) in zip(thresholds[:-1], thresholds[1:])]), thresholds
        assert all([label in [-1, 0, 1] for label in labels])
        assert len(labels) == len(thresholds) - 1
        self.thresholds = thresholds
        self.labels = labels
        self.allow_low_quality_matches = allow_low_quality_matches

    def __call__(self, match_quality_matrix: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
                pairwise quality between M ground-truth elements and N predicted
                elements. All elements must be >= 0 (due to the use of `torch.nonzero`
                for selecting indices in :meth:`set_low_quality_matches_`).

        Returns:
            matches (Tensor[int64]): a vector of length N, where matches[i] is a matched
                ground-truth index in [0, M)
            match_labels (Tensor[int8]): a vector of length N, where match_labels[i] indicates
                whether a prediction is a true or false positive or ignored
        """
        assert match_quality_matrix.dim() == 2
        if match_quality_matrix.numel() == 0:
            default_matches = match_quality_matrix.new_full(
                (match_quality_matrix.size(1),), 0, dtype=torch.int64
            )
            # When no gt boxes exist, we define IOU = 0 and therefore set labels
            # to `self.labels[0]`, which usually defaults to background class 0
            # To choose to ignore instead, can make labels=[-1,0,-1,1] + set appropriate thresholds
            default_match_labels = match_quality_matrix.new_full(
                (match_quality_matrix.size(1),), self.labels[0], dtype=torch.int8
            )
            return default_matches, default_match_labels

        assert torch.all(match_quality_matrix >= 0)

        # match_quality_matrix is M (gt) x N (predicted)
        # Max over gt elements (dim 0) to find best gt candidate for each prediction
        matched_vals, matches = match_quality_matrix.max(dim=0)

        match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8)

        # Each consecutive pair of thresholds delimits one half-open level [low, high).
        for (label, low, high) in zip(self.labels, self.thresholds[:-1], self.thresholds[1:]):
            low_high = (matched_vals >= low) & (matched_vals < high)
            match_labels[low_high] = label

        if self.allow_low_quality_matches:
            self.set_low_quality_matches_(match_labels, match_quality_matrix)

        return matches, match_labels

    def set_low_quality_matches_(self, match_labels, match_quality_matrix):
        """
        Produce additional matches for predictions that have only low-quality matches.
        Specifically, for each ground-truth G find the set of predictions that have
        maximum overlap with it (including ties); for each prediction in that set, if
        it is unmatched, then match it to the ground-truth G.

        This function implements the RPN assignment case (i) in Sec. 3.1.2 of
        :paper:`Faster R-CNN`.

        `match_labels` is modified in place; nothing is returned.
        """
        # For each gt, find the prediction with which it has highest quality
        highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
        # Find the highest quality match available, even if it is low, including ties.
        # Note that the matches qualities must be positive due to the use of
        # `torch.nonzero`.
        _, pred_inds_with_highest_quality = nonzero_tuple(
            match_quality_matrix == highest_quality_foreach_gt[:, None]
        )
        # If an anchor was labeled positive only due to a low-quality match
        # with gt_A, but it has larger overlap with gt_B, its matched index will still be gt_B.
        # This follows the implementation in Detectron, and is found to have no significant impact.
        match_labels[pred_inds_with_highest_quality] = 1


# from https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/modeling/sampling.py#L9
def subsample_labels(
    labels: torch.Tensor, num_samples: int, positive_fraction: float, bg_label: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Return `num_samples` (or fewer, if not enough found)
    random samples from `labels` which is a mixture of positives & negatives.
    It will try to return as many positives as possible without
    exceeding `positive_fraction * num_samples`, and then try to
    fill the remaining slots with negatives.

    Args:
        labels (Tensor): (N, ) label vector with values:
            * -1: ignore
            * bg_label: background ("negative") class
            * otherwise: one or more foreground ("positive") classes
        num_samples (int): The total number of labels with value >= 0 to return.
            Values that are not sampled will be filled with -1 (ignore).
        positive_fraction (float): The number of subsampled labels with values > 0
            is `min(num_positives, int(positive_fraction * num_samples))`. The number
            of negatives sampled is `min(num_negatives, num_samples - num_positives_sampled)`.
            In other words, if there are not enough positives, the sample is filled with
            negatives. If there are also not enough negatives, then as many elements are
            sampled as is possible.
        bg_label (int): label index of background ("negative") class.

    Returns:
        pos_idx, neg_idx (Tensor):
            1D vector of indices. The total length of both is `num_samples` or fewer.
    """
    positive = nonzero_tuple((labels != -1) & (labels != bg_label))[0]
    negative = nonzero_tuple(labels == bg_label)[0]

    num_pos = int(num_samples * positive_fraction)
    # protect against not enough positive examples
    num_pos = min(positive.numel(), num_pos)
    num_neg = num_samples - num_pos
    # protect against not enough negative examples
    num_neg = min(negative.numel(), num_neg)

    # randomly select positive and negative examples
    perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
    perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]

    pos_idx = positive[perm1]
    neg_idx = negative[perm2]
    return pos_idx, neg_idx


def sample_topk_per_gt(pr_inds, gt_inds, iou, k):
    """
    Re-select the positive predictions so that each matched ground truth keeps
    its highest-quality predictions.

    For every distinct ground-truth index g in `gt_inds` that occurs c times, the
    min(c, k) predictions with the largest `iou[g]` are returned. The candidates
    are ranked over all columns of `iou`, so the returned predictions are not
    necessarily a subset of `pr_inds`.

    Args:
        pr_inds (Tensor): 1D prediction indices; only returned unchanged when
            `gt_inds` is empty.
        gt_inds (Tensor): 1D ground-truth indices (rows of `iou`), one per
            currently matched prediction.
        iou (Tensor): MxN quality matrix between M ground truths and N predictions.
        k (int): maximum number of predictions kept per ground truth.

    Returns:
        pr_inds, gt_inds (Tensor): 1D vectors of equal length, grouped by
        ascending ground-truth index and, within a group, by descending quality.
    """
    if len(gt_inds) == 0:
        return pr_inds, gt_inds

    # `topk` raises when asked for more entries than there are predictions.
    k = min(k, iou.size(1))

    # find topk matches for each gt
    unique_gt, counts = gt_inds.unique(return_counts=True)
    _, topk_pr = iou[unique_gt].topk(k, dim=1)
    gt_grid = unique_gt[:, None].repeat(1, k)

    # filter to as many matches that gt has: keep the first min(count, k) columns
    # of each row. Boolean-mask indexing walks the grid row by row, so the order
    # is the same as concatenating the per-row slices, without a Python loop.
    keep = torch.arange(k, device=counts.device)[None, :] < counts[:, None]
    return topk_pr[keep], gt_grid[keep]


# modified from https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/modeling/roi_heads/roi_heads.py#L123
class Stage2Assigner(nn.Module):
    """
    Assigns ground-truth boxes to the boxes in `outputs['init_reference']` by IoU:
    a proposal is positive when its best IoU is >= 0.6 or when it is the best
    proposal for some ground truth (low-quality matches are enabled).
    """

    def __init__(self, num_queries, max_k=4):
        """
        Args:
            num_queries (int): number of proposals sampled per image.
            max_k (int): maximum number of positives kept per ground truth.
        """
        super().__init__()
        self.positive_fraction = 0.25
        self.bg_label = 400  # number > 91 to filter out later
        self.batch_size_per_image = num_queries
        self.proposal_matcher = Matcher(thresholds=[0.6], labels=[0, 1], allow_low_quality_matches=True)
        self.k = max_k

    def _sample_proposals(
        self, matched_idxs: torch.Tensor, matched_labels: torch.Tensor, gt_classes: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Based on the matching between N proposals and M groundtruth,
        sample the proposals and set their classification labels.

        Args:
            matched_idxs (Tensor): a vector of length N, each is the best-matched
                gt index in [0, M) for each proposal.
            matched_labels (Tensor): a vector of length N, the matcher's label
                (one of cfg.MODEL.ROI_HEADS.IOU_LABELS) for each proposal.
            gt_classes (Tensor): a vector of length M.

        Returns:
            Tensor: a vector of indices of sampled proposals. Each is in [0, N).
            Tensor: a vector of the same length, the classification label for
                each sampled proposal. Each sample is labeled as either a category in
                [0, num_classes) or the background (`self.bg_label`).
        """
        has_gt = gt_classes.numel() > 0
        # Get the corresponding GT for each proposal
        if has_gt:
            gt_classes = gt_classes[matched_idxs]
            # Label unmatched proposals (0 label from matcher) as background (label=bg_label)
            gt_classes[matched_labels == 0] = self.bg_label
            # Label ignore proposals (-1 label)
            gt_classes[matched_labels == -1] = -1
        else:
            gt_classes = torch.zeros_like(matched_idxs) + self.bg_label

        sampled_fg_idxs, sampled_bg_idxs = subsample_labels(
            gt_classes, self.batch_size_per_image, self.positive_fraction, self.bg_label
        )

        sampled_idxs = torch.cat([sampled_fg_idxs, sampled_bg_idxs], dim=0)
        return sampled_idxs, gt_classes[sampled_idxs]

    def forward(self, outputs, targets, return_cost_matrix=False):
        """
        Args:
            outputs (dict): must provide `outputs['init_reference'][b]`, the proposal
                boxes of image b (converted with `box_cxcywh_to_xyxy`, so expected
                in cxcywh format).
            targets (list[dict]): one dict per image with keys 'boxes' and 'labels'.
            return_cost_matrix (bool): if True, also return the per-image IoU matrices.

        Returns:
            list[tuple[Tensor, Tensor]]: per image, (proposal indices, gt indices)
            of the positive pairs; followed by the list of IoU matrices when
            `return_cost_matrix` is True.
        """
        # COCO categories are from 1 to 90. They set num_classes=91 and apply sigmoid.
        indices = []
        ious = []
        for b, target in enumerate(targets):
            iou, _ = box_iou(
                box_cxcywh_to_xyxy(target['boxes']),
                box_cxcywh_to_xyxy(outputs['init_reference'][b].detach()),
            )
            # proposal_id -> highest_iou_gt_id, proposal_id -> [1 if iou >= 0.6 or low-quality match, 0 ow]
            matched_idxs, matched_labels = self.proposal_matcher(iou)
            # list of sampled proposal_ids, sampled_id -> [0, num_classes)+[bg_label]
            sampled_idxs, sampled_gt_classes = self._sample_proposals(
                matched_idxs, matched_labels, target['labels']
            )
            pos_pr_inds = sampled_idxs[sampled_gt_classes != self.bg_label]
            pos_gt_inds = matched_idxs[pos_pr_inds]
            pos_pr_inds, pos_gt_inds = self.postprocess_indices(pos_pr_inds, pos_gt_inds, iou)
            indices.append((pos_pr_inds, pos_gt_inds))
            ious.append(iou)
        if return_cost_matrix:
            return indices, ious
        return indices

    def postprocess_indices(self, pr_inds, gt_inds, iou):
        """Keep at most `self.k` highest-IoU predictions per matched ground truth."""
        return sample_topk_per_gt(pr_inds, gt_inds, iou, self.k)


# modified from https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/modeling/proposal_generator/rpn.py#L181
class Stage1Assigner(nn.Module):
    """
    Assigns ground-truth boxes to the boxes in `outputs['anchors']` by IoU:
    positive when the best IoU is >= `t_high` (or the anchor is the best one for
    some ground truth), negative when it is < `t_low`, ignored in between.
    """

    def __init__(self, t_low=0.3, t_high=0.7, max_k=4):
        """
        Args:
            t_low (float): IoU below which an anchor is labeled negative.
            t_high (float): IoU from which an anchor is labeled positive.
            max_k (int): maximum number of positives kept per ground truth.
        """
        super().__init__()
        self.positive_fraction = 0.5
        self.batch_size_per_image = 256
        self.k = max_k
        self.t_low = t_low
        self.t_high = t_high
        self.anchor_matcher = Matcher(thresholds=[t_low, t_high], labels=[0, -1, 1], allow_low_quality_matches=True)

    def _subsample_labels(self, label):
        """
        Randomly sample a subset of positive and negative examples, and overwrite
        the label vector to the ignore value (-1) for all elements that are not
        included in the sample.

        Args:
            label (Tensor): a vector of -1, 0, 1. Will be modified in-place and returned.
        """
        pos_idx, neg_idx = subsample_labels(
            label, self.batch_size_per_image, self.positive_fraction, 0
        )
        # Fill with the ignore label (-1), then set positive and negative labels
        label.fill_(-1)
        label.scatter_(0, pos_idx, 1)
        label.scatter_(0, neg_idx, 0)
        return label

    def forward(self, outputs, targets):
        """
        Args:
            outputs (dict): must provide `outputs['anchors'][b]`, the anchor boxes of
                image b (converted with `box_cxcywh_to_xyxy`, so expected in
                cxcywh format).
            targets (list[dict]): one dict per image with key 'boxes'.

        Returns:
            list[tuple[Tensor, Tensor]]: per image, (anchor indices, gt indices) of
            the positive pairs, placed on the anchors' device. Both are empty for
            an image without ground-truth boxes.
        """
        indices = []
        for b, target in enumerate(targets):
            anchors = outputs['anchors'][b]
            if len(target['boxes']) == 0:
                indices.append((
                    torch.empty(0, dtype=torch.long, device=anchors.device),
                    torch.empty(0, dtype=torch.long, device=anchors.device),
                ))
                continue
            iou, _ = box_iou(
                box_cxcywh_to_xyxy(target['boxes']),
                box_cxcywh_to_xyxy(anchors),
            )
            # proposal_id -> highest_iou_gt_id,
            # proposal_id -> [1 if iou >= t_high or low-quality match, 0 if iou < t_low, -1 ow]
            matched_idxs, matched_labels = self.anchor_matcher(iou)
            matched_labels = self._subsample_labels(matched_labels)

            # Take the positive indices straight from the mask so they live on the
            # same device as the labels; indexing a CPU `arange` with a CUDA mask
            # is rejected by recent PyTorch versions.
            pos_pr_inds = nonzero_tuple(matched_labels == 1)[0]
            pos_gt_inds = matched_idxs[pos_pr_inds]
            pos_pr_inds, pos_gt_inds = self.postprocess_indices(pos_pr_inds, pos_gt_inds, iou)
            indices.append((pos_pr_inds.to(anchors.device), pos_gt_inds.to(anchors.device)))
        return indices

    def postprocess_indices(self, pr_inds, gt_inds, iou):
        """Keep at most `self.k` highest-IoU predictions per matched ground truth."""
        return sample_topk_per_gt(pr_inds, gt_inds, iou, self.k)