Spaces:
Runtime error
Runtime error
File size: 10,963 Bytes
3b96cb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import torch
import torch.nn.functional as F
from mmengine import MessageHub
from mmengine.structures import InstanceData
from torch import Tensor
from mmdet.registry import MODELS
from mmdet.utils import InstanceList
from ..utils.misc import unfold_wo_center
from .condinst_head import CondInstBboxHead, CondInstMaskHead
@MODELS.register_module()
class BoxInstBboxHead(CondInstBboxHead):
    """BoxInst box head used in https://arxiv.org/abs/2012.02310.

    Behaves exactly like ``CondInstBboxHead``; it is registered as a
    distinct module so configs can name the BoxInst variant explicitly.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Pure pass-through constructor; all behavior lives in the parent.
        super().__init__(*args, **kwargs)
@MODELS.register_module()
class BoxInstMaskHead(CondInstMaskHead):
    """BoxInst mask head used in https://arxiv.org/abs/2012.02310.

    This head outputs the mask for BoxInst.

    Args:
        pairwise_size (int): The size of neighborhood for each pixel.
            Defaults to 3.
        pairwise_dilation (int): The dilation of neighborhood for each pixel.
            Defaults to 2.
        warmup_iters (int): Warmup iterations for pair-wise loss.
            Defaults to 10000.
    """

    def __init__(self,
                 *arg,
                 pairwise_size: int = 3,
                 pairwise_dilation: int = 2,
                 warmup_iters: int = 10000,
                 **kwargs) -> None:
        # Assign pairwise/warmup attributes before calling the parent
        # initializer, preserving the original construction order.
        self.pairwise_size = pairwise_size
        self.pairwise_dilation = pairwise_dilation
        self.warmup_iters = warmup_iters
        super().__init__(*arg, **kwargs)

    def get_pairwise_affinity(self, mask_logits: Tensor) -> Tensor:
        """Compute the pairwise affinity for each pixel.

        Args:
            mask_logits (Tensor): Predicted mask logits.

        Returns:
            Tensor: Negative log-probability that each pixel and each of
            its (non-center) neighbors get the same prediction.
        """
        log_fg_prob = F.logsigmoid(mask_logits).unsqueeze(1)
        log_bg_prob = F.logsigmoid(-mask_logits).unsqueeze(1)

        log_fg_prob_unfold = unfold_wo_center(
            log_fg_prob,
            kernel_size=self.pairwise_size,
            dilation=self.pairwise_dilation)
        log_bg_prob_unfold = unfold_wo_center(
            log_bg_prob,
            kernel_size=self.pairwise_size,
            dilation=self.pairwise_dilation)

        # the probability of making the same prediction:
        # p_i * p_j + (1 - p_i) * (1 - p_j)
        # we compute the probability in log space
        # to avoid numerical instability
        log_same_fg_prob = log_fg_prob[:, :, None] + log_fg_prob_unfold
        log_same_bg_prob = log_bg_prob[:, :, None] + log_bg_prob_unfold

        # TODO: Figure out the difference between it and directly sum
        # Log-sum-exp with the elementwise max subtracted for stability.
        max_ = torch.max(log_same_fg_prob, log_same_bg_prob)
        log_same_prob = torch.log(
            torch.exp(log_same_fg_prob - max_) +
            torch.exp(log_same_bg_prob - max_)) + max_

        return -log_same_prob[:, 0]

    def loss_by_feat(self, mask_preds: List[Tensor],
                     batch_gt_instances: InstanceList,
                     batch_img_metas: List[dict], positive_infos: InstanceList,
                     **kwargs) -> dict:
        """Calculate the loss based on the features extracted by the mask head.

        Args:
            mask_preds (list[Tensor]): List of predicted masks, each has
                shape (num_classes, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``masks``,
                and ``labels`` attributes.
            batch_img_metas (list[dict]): Meta information of multiple images.
            positive_infos (List[:obj:``InstanceData``]): Information of
                positive samples of each image that are assigned in detection
                head.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert positive_infos is not None, \
            'positive_infos should not be None in `BoxInstMaskHead`'
        losses = dict()

        loss_mask_project = 0.
        loss_mask_pairwise = 0.
        num_imgs = len(mask_preds)
        total_pos = 0.
        avg_factor = 0.

        for idx in range(num_imgs):
            (mask_pred, pos_mask_targets, pos_pairwise_masks, num_pos) = \
                self._get_targets_single(
                    mask_preds[idx], batch_gt_instances[idx],
                    positive_infos[idx])
            # mask loss
            total_pos += num_pos
            if num_pos == 0 or pos_mask_targets is None:
                # No positives for this image: contribute zero losses that
                # still participate in the graph.
                loss_project = mask_pred.new_zeros(1).mean()
                loss_pairwise = mask_pred.new_zeros(1).mean()
                avg_factor += 0.
            else:
                # compute the project term: compare x/y projections of the
                # predicted mask against the box-derived target projections
                loss_project_x = self.loss_mask(
                    mask_pred.max(dim=1, keepdim=True)[0],
                    pos_mask_targets.max(dim=1, keepdim=True)[0],
                    reduction_override='none').sum()
                loss_project_y = self.loss_mask(
                    mask_pred.max(dim=2, keepdim=True)[0],
                    pos_mask_targets.max(dim=2, keepdim=True)[0],
                    reduction_override='none').sum()
                loss_project = loss_project_x + loss_project_y
                # compute the pairwise term
                pairwise_affinity = self.get_pairwise_affinity(mask_pred)
                avg_factor += pos_pairwise_masks.sum().clamp(min=1.0)
                loss_pairwise = (pairwise_affinity * pos_pairwise_masks).sum()

            loss_mask_project += loss_project
            loss_mask_pairwise += loss_pairwise

        if total_pos == 0:
            total_pos += 1  # avoid nan
        if avg_factor == 0:
            avg_factor += 1  # avoid nan
        loss_mask_project = loss_mask_project / total_pos
        loss_mask_pairwise = loss_mask_pairwise / avg_factor

        # Linearly ramp the pairwise loss from 0 to its full weight over
        # ``warmup_iters`` training iterations.
        message_hub = MessageHub.get_current_instance()
        cur_iter = message_hub.get_info('iter')
        warmup_factor = min(cur_iter / float(self.warmup_iters), 1.0)
        loss_mask_pairwise *= warmup_factor

        losses.update(
            loss_mask_project=loss_mask_project,
            loss_mask_pairwise=loss_mask_pairwise)
        return losses

    def _get_targets_single(self, mask_preds: Tensor,
                            gt_instances: InstanceData,
                            positive_info: InstanceData):
        """Compute targets for predictions of single image.

        Args:
            mask_preds (Tensor): Predicted prototypes with shape
                (num_classes, H, W).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes``, ``labels``,
                and ``masks`` attributes.
            positive_info (:obj:`InstanceData`): Information of positive
                samples that are assigned in detection head. It usually
                contains following keys.

                    - pos_assigned_gt_inds (Tensor): Assigned GT indexes of
                      positive proposals, has shape (num_pos, )
                    - pos_inds (Tensor): Positive index of image, has
                      shape (num_pos, ).
                    - param_pred (Tensor): Positive param predictions
                      with shape (num_pos, num_params).

        Returns:
            tuple: Usually returns a tuple containing learning targets.

            - mask_preds (Tensor): Positive predicted mask with shape
              (num_pos, mask_h, mask_w).
            - pos_mask_targets (Tensor): Positive mask targets with shape
              (num_pos, mask_h, mask_w).
            - pos_pairwise_masks (Tensor): Positive pairwise masks with
              shape: (num_pos, num_neighborhood, mask_h, mask_w).
            - num_pos (int): Positive numbers.
        """
        gt_bboxes = gt_instances.bboxes
        device = gt_bboxes.device
        # Note that gt_masks are generated by full box
        # from BoxInstDataPreprocessor
        gt_masks = gt_instances.masks.to_tensor(
            dtype=torch.bool, device=device).float()
        # Note that pairwise_masks are generated by image color similarity
        # from BoxInstDataPreprocessor
        pairwise_masks = gt_instances.pairwise_masks
        pairwise_masks = pairwise_masks.to(device=device)

        # process with mask targets
        pos_assigned_gt_inds = positive_info.get('pos_assigned_gt_inds')
        scores = positive_info.get('scores')
        centernesses = positive_info.get('centernesses')
        num_pos = pos_assigned_gt_inds.size(0)

        if gt_masks.size(0) == 0 or num_pos == 0:
            return mask_preds, None, None, 0

        # Since we're producing (near) full image masks,
        # it'd take too much vram to backprop on every single mask.
        # Thus we select only a subset.
        if (self.max_masks_to_train != -1) and \
                (num_pos > self.max_masks_to_train):
            # Random subsampling down to ``max_masks_to_train``.
            perm = torch.randperm(num_pos)
            select = perm[:self.max_masks_to_train]
            mask_preds = mask_preds[select]
            pos_assigned_gt_inds = pos_assigned_gt_inds[select]
            num_pos = self.max_masks_to_train
        elif self.topk_masks_per_img != -1:
            # Keep only the top-k proposals per GT, ranked by
            # classification score * centerness.
            unique_gt_inds = pos_assigned_gt_inds.unique()
            num_inst_per_gt = max(
                int(self.topk_masks_per_img / len(unique_gt_inds)), 1)

            keep_mask_preds = []
            keep_pos_assigned_gt_inds = []
            for gt_ind in unique_gt_inds:
                per_inst_pos_inds = (pos_assigned_gt_inds == gt_ind)
                mask_preds_per_inst = mask_preds[per_inst_pos_inds]
                gt_inds_per_inst = pos_assigned_gt_inds[per_inst_pos_inds]
                if sum(per_inst_pos_inds) > num_inst_per_gt:
                    per_inst_scores = scores[per_inst_pos_inds].sigmoid().max(
                        dim=1)[0]
                    per_inst_centerness = centernesses[
                        per_inst_pos_inds].sigmoid().reshape(-1, )
                    select = (per_inst_scores * per_inst_centerness).topk(
                        k=num_inst_per_gt, dim=0)[1]
                    mask_preds_per_inst = mask_preds_per_inst[select]
                    gt_inds_per_inst = gt_inds_per_inst[select]
                keep_mask_preds.append(mask_preds_per_inst)
                keep_pos_assigned_gt_inds.append(gt_inds_per_inst)
            mask_preds = torch.cat(keep_mask_preds)
            pos_assigned_gt_inds = torch.cat(keep_pos_assigned_gt_inds)
            num_pos = pos_assigned_gt_inds.size(0)

        # Follow the original implementation: downsample GT masks to the
        # mask output stride, sampling at the stride-center offset.
        start = int(self.mask_out_stride // 2)
        gt_masks = gt_masks[:, start::self.mask_out_stride,
                            start::self.mask_out_stride]
        gt_masks = gt_masks.gt(0.5).float()
        pos_mask_targets = gt_masks[pos_assigned_gt_inds]
        pos_pairwise_masks = pairwise_masks[pos_assigned_gt_inds]
        # Restrict pairwise supervision to pixels inside the (box) mask.
        pos_pairwise_masks = pos_pairwise_masks * pos_mask_targets.unsqueeze(1)

        return (mask_preds, pos_mask_targets, pos_pairwise_masks, num_pos)
|