yuhangzang
update
a059c46
raw
history blame
16.2 kB
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Jeffrey Ouyang-Zhang
from typing import List
import torch
import torch.nn as nn
from util.box_ops import (box_cxcywh_to_xyxy, box_iou, box_xyxy_to_cxcywh,
generalized_box_iou)
# from https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/layers/wrappers.py#L100
def nonzero_tuple(x):
    """
    Return the indices of the non-zero entries of `x`, one index tensor per dimension.

    Mirrors `torch.nonzero(x, as_tuple=True)` while remaining usable under
    torchscript, see https://github.com/pytorch/pytorch/issues/38718
    """
    if torch.jit.is_scripting():
        # Scripted path: a 0-d input is promoted to 1-d first so that
        # `unbind(1)` has a dimension to split on.
        source = x.unsqueeze(0) if x.dim() == 0 else x
        return source.nonzero().unbind(1)
    else:
        return x.nonzero(as_tuple=True)
# from https://github.com/facebookresearch/detectron2/blob/9921a2caa585d4fa66c4b534b6fab6e74d89b582/detectron2/modeling/matcher.py#L9
class Matcher(object):
    """
    Assign a ground-truth element to every predicted "element" (e.g., a box).

    A prediction receives exactly zero or one match, whereas a ground-truth
    element may end up matched to zero or more predictions.

    Matching is driven by an MxN match_quality_matrix that scores every
    (ground-truth, prediction) pair. For boxes this matrix typically holds
    intersection-over-union overlap values.

    Calling the matcher yields (a) a length-N vector with, for each prediction
    n in [0, N), the index of its matched ground-truth element m in [0, M), and
    (b) a length-N vector with the label assigned to each prediction.
    """

    def __init__(
        self, thresholds: List[float], labels: List[int], allow_low_quality_matches: bool = False
    ):
        """
        Args:
            thresholds (list): thresholds that split predictions into levels
                according to their match quality.
            labels (list): the label given to predictions of each level. Every
                label is one of {-1, 0, 1}, meaning
                {ignore, negative class, positive class}, respectively.
            allow_low_quality_matches (bool): if True, additionally mark as
                positive the predictions selected by
                set_low_quality_matches_, even when their best match quality
                lies below the highest threshold.

        For example,
            thresholds = [0.3, 0.5]
            labels = [0, -1, 1]
            Predictions with iou < 0.3 are labeled 0 and are therefore treated
            as false positives during training.
            Predictions with 0.3 <= iou < 0.5 are labeled -1 and are ignored.
            Predictions with 0.5 <= iou are labeled 1 and are treated as true
            positives.
        """
        assert thresholds[0] > 0
        # Build a padded copy (the caller's list is not modified): -inf in front
        # and +inf at the back make the intervals cover every quality value.
        bounds = [-float("inf")] + thresholds[:] + [float("inf")]
        # Lists instead of generators: torchscript does not support all + generator
        assert all([lo <= hi for (lo, hi) in zip(bounds[:-1], bounds[1:])]), bounds
        assert all([value in [-1, 0, 1] for value in labels])
        assert len(labels) == len(bounds) - 1
        self.thresholds = bounds
        self.labels = labels
        self.allow_low_quality_matches = allow_low_quality_matches

    def __call__(self, match_quality_matrix):
        """
        Args:
            match_quality_matrix (Tensor[float]): an MxN tensor holding the
                pairwise quality between M ground-truth elements and N predicted
                elements. All entries must be >= 0 (this is asserted below).

        Returns:
            matches (Tensor[int64]): a vector of length N, where matches[i] is the
                matched ground-truth index in [0, M)
            match_labels (Tensor[int8]): a vector of length N, where match_labels[i]
                tells whether prediction i is a true positive, a false positive
                or ignored
        """
        assert match_quality_matrix.dim() == 2
        if match_quality_matrix.numel() == 0:
            num_preds = match_quality_matrix.size(1)
            empty_matches = match_quality_matrix.new_zeros((num_preds,), dtype=torch.int64)
            # Without any gt boxes the IOU is defined as 0, so every prediction
            # gets `self.labels[0]`, which usually is the background class 0.
            # To ignore instead, use labels=[-1,0,-1,1] with matching thresholds.
            empty_labels = match_quality_matrix.new_full(
                (num_preds,), self.labels[0], dtype=torch.int8
            )
            return empty_matches, empty_labels

        assert torch.all(match_quality_matrix >= 0)

        # The matrix is M (gt) x N (predicted); reducing over dim 0 picks the
        # best gt candidate for every prediction.
        best = match_quality_matrix.max(dim=0)
        matched_vals, matches = best.values, best.indices

        match_labels = matches.new_ones(matches.size(), dtype=torch.int8)
        for level, label in enumerate(self.labels):
            lower = self.thresholds[level]
            upper = self.thresholds[level + 1]
            in_level = (matched_vals >= lower) & (matched_vals < upper)
            match_labels[in_level] = label

        if self.allow_low_quality_matches:
            self.set_low_quality_matches_(match_labels, match_quality_matrix)

        return matches, match_labels

    def set_low_quality_matches_(self, match_labels, match_quality_matrix):
        """
        Mark additional predictions as positive, in place, in `match_labels`.

        For each ground-truth G, the predictions whose quality with G equals
        G's maximum quality (ties included) are labeled 1.

        This function implements the RPN assignment case (i) in Sec. 3.1.2 of
        :paper:`Faster R-CNN`.
        """
        # Best quality reachable by each gt, kept as a column for broadcasting.
        best_per_gt = match_quality_matrix.max(dim=1, keepdim=True).values
        # Every (gt, prediction) pair that attains that best quality, even if
        # the quality itself is low; only the prediction indices are needed.
        attains_best = match_quality_matrix == best_per_gt
        pred_inds_with_highest_quality = nonzero_tuple(attains_best)[1]
        # If an anchor becomes positive only because of a low-quality match
        # with gt_A while overlapping gt_B more, its matched index stays gt_B.
        # This follows the implementation in Detectron, and is found to have
        # no significant impact.
        match_labels[pred_inds_with_highest_quality] = 1
# from https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/modeling/sampling.py#L9
def subsample_labels(
    labels: torch.Tensor, num_samples: int, positive_fraction: float, bg_label: int
):
    """
    Randomly pick up to `num_samples` entries of `labels`, a mixture of
    positives & negatives.

    As many positives as possible are taken without exceeding
    `positive_fraction * num_samples`; the remaining slots are then filled
    with negatives.

    Args:
        labels (Tensor): (N, ) label vector with values:
            * -1: ignore
            * bg_label: background ("negative") class
            * otherwise: one or more foreground ("positive") classes
        num_samples (int): upper bound on the total number of indices returned.
        positive_fraction (float): The number of positives sampled is
            `min(num_positives, int(positive_fraction * num_samples))`. The number
            of negatives sampled is `min(num_negatives, num_samples - num_positives_sampled)`.
            In other words, if there are not enough positives, the sample is filled with
            negatives. If there are also not enough negatives, then as many elements are
            sampled as is possible.
        bg_label (int): label index of background ("negative") class.

    Returns:
        pos_idx, neg_idx (Tensor):
            1D vector of indices. The total length of both is `num_samples` or fewer.
    """
    is_background = labels == bg_label
    is_foreground = (labels != -1) & ~is_background
    fg_candidates = nonzero_tuple(is_foreground)[0]
    bg_candidates = nonzero_tuple(is_background)[0]

    # Cap each quota by what is actually available.
    num_pos = min(fg_candidates.numel(), int(num_samples * positive_fraction))
    num_neg = min(bg_candidates.numel(), num_samples - num_pos)

    # Shuffle the candidates and keep the leading entries of each shuffle.
    fg_order = torch.randperm(fg_candidates.numel(), device=fg_candidates.device)
    bg_order = torch.randperm(bg_candidates.numel(), device=bg_candidates.device)
    return fg_candidates[fg_order[:num_pos]], bg_candidates[bg_order[:num_neg]]
def sample_topk_per_gt(pr_inds, gt_inds, iou, k):
    """
    Re-select, for every matched ground truth, its highest-IoU predictions.

    Each distinct gt index in `gt_inds` receives as many predictions as it had
    entries in `gt_inds`, capped at `k`, chosen as the highest-scoring columns
    of that gt's row of `iou`. When `gt_inds` is non-empty the values of
    `pr_inds` are not consulted.

    Args:
        pr_inds (Tensor): prediction indices; returned untouched when
            `gt_inds` is empty.
        gt_inds (Tensor): 1D tensor of gt indices (rows of `iou`).
        iou (Tensor): 2D quality matrix indexed as [gt, prediction].
        k (int): maximum number of predictions kept per gt. Values larger than
            the number of predictions are clamped, since `topk` would
            otherwise raise.

    Returns:
        (Tensor, Tensor): 1D prediction indices and the gt index paired with
        each of them, grouped by gt in ascending gt order.
    """
    if len(gt_inds) == 0:
        return pr_inds, gt_inds

    # topk cannot return more entries than there are predictions.
    k = min(k, iou.size(1))

    # Distinct gts (sorted) and how many matches each of them had.
    unique_gt, counts = gt_inds.unique(return_counts=True)
    topk_pr = iou[unique_gt].topk(k, dim=1).indices

    # keep[i, j] is True for the first min(counts[i], k) columns of row i;
    # boolean indexing then flattens the kept entries in row-major order.
    keep = torch.arange(k, device=counts.device)[None, :] < counts[:, None]
    kept_pr = topk_pr[keep]
    kept_gt = unique_gt[:, None].expand(-1, k)[keep]
    return kept_pr, kept_gt
# modified from https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/modeling/roi_heads/roi_heads.py#L123
class Stage2Assigner(nn.Module):
    """
    Assign ground-truth boxes to the reference boxes in
    `outputs['init_reference']` by IoU matching, proposal sampling and a
    per-gt top-k selection.
    """

    def __init__(self, num_queries, max_k=4):
        super().__init__()
        self.k = max_k
        self.batch_size_per_image = num_queries
        self.positive_fraction = 0.25
        self.bg_label = 400  # number > 91 to filter out later
        self.proposal_matcher = Matcher(
            thresholds=[0.6], labels=[0, 1], allow_low_quality_matches=True
        )

    def _sample_proposals(
        self, matched_idxs: torch.Tensor, matched_labels: torch.Tensor, gt_classes: torch.Tensor
    ):
        """
        Sample proposals and give each sampled proposal a classification label,
        based on the matching between N proposals and M groundtruth.

        Args:
            matched_idxs (Tensor): a vector of length N, each is the best-matched
                gt index in [0, M) for each proposal.
            matched_labels (Tensor): a vector of length N, the matcher's label
                for each proposal.
            gt_classes (Tensor): a vector of length M.

        Returns:
            Tensor: a vector of indices of sampled proposals. Each is in [0, N).
            Tensor: a vector of the same length, the classification label for
                each sampled proposal: the class of its matched gt, or
                `self.bg_label` for background.
        """
        if gt_classes.numel() > 0:
            # Class of the best-matched gt for every proposal (indexing with a
            # tensor yields a new tensor, so the caller's labels stay intact).
            proposal_classes = gt_classes[matched_idxs]
            # Matcher label 0 means unmatched: treat as background.
            proposal_classes[matched_labels == 0] = self.bg_label
            # Matcher label -1 means ignore.
            proposal_classes[matched_labels == -1] = -1
        else:
            # No ground truth at all: everything is background.
            proposal_classes = torch.full_like(matched_idxs, self.bg_label)

        fg_idxs, bg_idxs = subsample_labels(
            proposal_classes, self.batch_size_per_image, self.positive_fraction, self.bg_label
        )
        sampled_idxs = torch.cat([fg_idxs, bg_idxs], dim=0)
        return sampled_idxs, proposal_classes[sampled_idxs]

    def forward(self, outputs, targets, return_cost_matrix=False):
        """
        Compute, per image, the (prediction index, gt index) pairs of the
        positive assignments.

        Args:
            outputs (dict): read at key 'init_reference'; entry b is passed to
                `box_cxcywh_to_xyxy` after `.detach()`.
            targets (list[dict]): one dict per image, read at keys 'boxes'
                and 'labels'.
            return_cost_matrix (bool): if True, also return the per-image IoU
                matrices.

        Returns:
            list of (pos_pr_inds, pos_gt_inds) tuples, one per image; when
            `return_cost_matrix` is True, a pair (that list, list of IoU matrices).
        """
        # COCO categories are from 1 to 90. They set num_classes=91 and apply sigmoid.
        indices, ious = [], []
        for b, target in enumerate(targets):
            references = outputs['init_reference'][b].detach()
            iou, _ = box_iou(
                box_cxcywh_to_xyxy(target['boxes']),
                box_cxcywh_to_xyxy(references),
            )
            # proposal_id -> highest_iou_gt_id, proposal_id -> matcher label
            matched_idxs, matched_labels = self.proposal_matcher(iou)
            # sampled proposal_ids, and for each of them a gt class or bg_label
            sampled_idxs, sampled_gt_classes = self._sample_proposals(
                matched_idxs, matched_labels, target['labels']
            )
            is_positive = sampled_gt_classes != self.bg_label
            pos_pr_inds = sampled_idxs[is_positive]
            pos_gt_inds = matched_idxs[pos_pr_inds]
            indices.append(self.postprocess_indices(pos_pr_inds, pos_gt_inds, iou))
            ious.append(iou)
        return (indices, ious) if return_cost_matrix else indices

    def postprocess_indices(self, pr_inds, gt_inds, iou):
        """Delegate to `sample_topk_per_gt` with this assigner's `k`."""
        return sample_topk_per_gt(pr_inds, gt_inds, iou, self.k)
# modified from https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/modeling/proposal_generator/rpn.py#L181
class Stage1Assigner(nn.Module):
    """
    Assign ground-truth boxes to the anchors in `outputs['anchors']` by IoU
    matching, label subsampling and a per-gt top-k selection.
    """

    def __init__(self, t_low=0.3, t_high=0.7, max_k=4):
        super().__init__()
        self.positive_fraction = 0.5
        self.batch_size_per_image = 256
        self.k = max_k
        self.t_low = t_low
        self.t_high = t_high
        # Below t_low -> negative (0), between the thresholds -> ignore (-1),
        # at or above t_high -> positive (1).
        self.anchor_matcher = Matcher(
            thresholds=[t_low, t_high], labels=[0, -1, 1], allow_low_quality_matches=True
        )

    def _subsample_labels(self, label):
        """
        Randomly sample a subset of positive and negative examples, and overwrite
        the label vector to the ignore value (-1) for all elements that are not
        included in the sample.

        Args:
            label (Tensor): a vector of -1, 0, 1. Will be modified in-place and returned.
        """
        pos_idx, neg_idx = subsample_labels(
            label, self.batch_size_per_image, self.positive_fraction, 0
        )
        # Fill with the ignore label (-1), then set positive and negative labels
        label.fill_(-1)
        label.scatter_(0, pos_idx, 1)
        label.scatter_(0, neg_idx, 0)
        return label

    def forward(self, outputs, targets):
        """
        Compute, per image, the (anchor index, gt index) pairs of the positive
        assignments.

        Args:
            outputs (dict): read at key 'anchors'; entry b holds the anchors of
                image b and is passed to `box_cxcywh_to_xyxy`.
            targets (list[dict]): one dict per image, read at key 'boxes'.

        Returns:
            list of (pos_pr_inds, pos_gt_inds) tuples, one per image, placed on
            the anchors' device. Images without gt boxes get two empty long
            tensors.
        """
        bs = len(targets)
        indices = []
        for b in range(bs):
            anchors = outputs['anchors'][b]
            if len(targets[b]['boxes']) == 0:
                indices.append((torch.tensor([], dtype=torch.long, device=anchors.device),
                                torch.tensor([], dtype=torch.long, device=anchors.device)))
                continue
            iou, _ = box_iou(
                box_cxcywh_to_xyxy(targets[b]['boxes']),
                box_cxcywh_to_xyxy(anchors),
            )
            # anchor_id -> highest_iou_gt_id, anchor_id -> matcher label (1 / 0 / -1)
            matched_idxs, matched_labels = self.anchor_matcher(iou)
            matched_labels = self._subsample_labels(matched_labels)

            # Build the index range on the same device as the mask: PyTorch
            # releases that enforce index/tensor device agreement reject a CPU
            # tensor indexed by a mask living on another device.
            all_pr_inds = torch.arange(len(anchors), device=matched_labels.device)
            pos_pr_inds = all_pr_inds[matched_labels == 1]
            pos_gt_inds = matched_idxs[pos_pr_inds]
            pos_pr_inds, pos_gt_inds = self.postprocess_indices(pos_pr_inds, pos_gt_inds, iou)
            pos_pr_inds, pos_gt_inds = pos_pr_inds.to(anchors.device), pos_gt_inds.to(anchors.device)
            indices.append((pos_pr_inds, pos_gt_inds))
        return indices

    def postprocess_indices(self, pr_inds, gt_inds, iou):
        """Delegate to `sample_topk_per_gt` with this assigner's `k`."""
        return sample_topk_per_gt(pr_inds, gt_inds, iou, self.k)