File size: 10,963 Bytes
3b96cb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
import torch.nn.functional as F
from mmengine import MessageHub
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.utils import InstanceList
from ..utils.misc import unfold_wo_center
from .condinst_head import CondInstBboxHead, CondInstMaskHead


@MODELS.register_module()
class BoxInstBboxHead(CondInstBboxHead):
    """Bounding-box head of BoxInst (https://arxiv.org/abs/2012.02310).

    This class adds nothing on top of ``CondInstBboxHead``: every
    constructor argument is forwarded to the parent unchanged. It is
    registered in ``MODELS`` under its own class name.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Pure pass-through; no BoxInst-specific state is created here.
        super().__init__(*args, **kwargs)


@MODELS.register_module()
class BoxInstMaskHead(CondInstMaskHead):
    """BoxInst mask head used in https://arxiv.org/abs/2012.02310.

    This head outputs the mask for BoxInst.

    Args:
        pairwise_size (int): The size of neighborhood for each pixel.
            Defaults to 3.
        pairwise_dilation (int): The dilation of neighborhood for each pixel.
            Defaults to 2.
        warmup_iters (int): Warmup iterations for pair-wise loss.
            Defaults to 10000.
    """

    def __init__(self,
                 *arg,
                 pairwise_size: int = 3,
                 pairwise_dilation: int = 2,
                 warmup_iters: int = 10000,
                 **kwargs) -> None:
        """Store the pairwise-loss settings and build the parent head.

        Args:
            *arg: Positional arguments forwarded to the parent constructor.
            pairwise_size (int): Kernel size handed to ``unfold_wo_center``
                when gathering each pixel's neighbours. Defaults to 3.
            pairwise_dilation (int): Dilation handed to
                ``unfold_wo_center``. Defaults to 2.
            warmup_iters (int): Number of iterations over which the
                pairwise loss weight is ramped up. Defaults to 10000.
            **kwargs: Keyword arguments forwarded to the parent constructor.
        """
        # NOTE(review): the three attributes are assigned *before* the parent
        # constructor runs, exactly as in the original ordering - presumably
        # so they already exist for anything the parent's __init__ invokes.
        # Confirm before moving the super() call.
        self.warmup_iters = warmup_iters
        self.pairwise_dilation = pairwise_dilation
        self.pairwise_size = pairwise_size
        super().__init__(*arg, **kwargs)

    def get_pairwise_affinity(self, mask_logits: Tensor) -> Tensor:
        """Negative log-probability that a pixel and each neighbour agree.

        For a pixel ``i`` with foreground probability ``p_i`` and a
        neighbour ``j``, the probability of both receiving the same label is
        ``p_i * p_j + (1 - p_i) * (1 - p_j)``. The method returns the
        negative logarithm of that quantity, evaluated in log space.

        Args:
            mask_logits (Tensor): Raw (pre-sigmoid) mask logits. The caller
                in this class passes a per-instance tensor; presumably
                (num_pos, H, W) - TODO confirm.

        Returns:
            Tensor: ``-log`` of the same-label probability, with the
            singleton channel axis inserted internally removed again.
            Presumably (num_pos, num_neighbours, H, W), the layout produced
            by ``unfold_wo_center`` - TODO confirm.
        """
        unfold_cfg = dict(
            kernel_size=self.pairwise_size, dilation=self.pairwise_dilation)

        # log(p) and log(1 - p), each with a channel axis inserted at dim 1.
        fg_log = F.logsigmoid(mask_logits).unsqueeze(1)
        bg_log = F.logsigmoid(-mask_logits).unsqueeze(1)

        # The same quantities gathered from each pixel's neighbourhood.
        fg_log_nbr = unfold_wo_center(fg_log, **unfold_cfg)
        bg_log_nbr = unfold_wo_center(bg_log, **unfold_cfg)

        # log(p_i * p_j) and log((1 - p_i) * (1 - p_j)); the extra axis lets
        # the centre pixel broadcast against all of its neighbours.
        both_fg = fg_log[:, :, None] + fg_log_nbr
        both_bg = bg_log[:, :, None] + bg_log_nbr

        # Log-sum-exp of the two terms, shifted by their element-wise maximum
        # so the exponentials cannot overflow.
        # TODO: Figure out the difference between it and directly sum
        shift = torch.max(both_fg, both_bg)
        shifted_sum = torch.exp(both_fg - shift) + torch.exp(both_bg - shift)
        log_same = torch.log(shifted_sum) + shift

        # Drop the channel axis inserted above and negate to obtain a cost.
        return -log_same[:, 0]

    def loss_by_feat(self, mask_preds: List[Tensor],
                     batch_gt_instances: InstanceList,
                     batch_img_metas: List[dict], positive_infos: InstanceList,
                     **kwargs) -> dict:
        """Calculate the loss based on the features extracted by the mask head.

        Two terms are accumulated over the images of the batch:

        - a projection term: ``self.loss_mask`` applied to the max-projection
          of prediction and target along dim 1 and along dim 2, normalised
          by the total number of positive samples;
        - a pairwise term: the pairwise affinity weighted by the pairwise
          masks, normalised by the number of active pairwise entries and
          scaled by a linear warmup factor.

        Args:
            mask_preds (list[Tensor]): Predicted mask logits, one tensor per
                image. Each is indexed per positive sample by
                ``_get_targets_single``; presumably (num_pos, H, W) -
                TODO confirm against the caller.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``masks``,
                and ``labels`` attributes.
            batch_img_metas (list[dict]): Meta information of multiple images.
                Not used by this method.
            positive_infos (List[:obj:``InstanceData``]): Information of
                positive samples of each image that are assigned in detection
                head.

        Returns:
            dict[str, Tensor]: A dictionary of loss components with the keys
            ``loss_mask_project`` and ``loss_mask_pairwise``.
        """
        assert positive_infos is not None, \
            'positive_infos should not be None in `BoxInstMaskHead`'
        losses = dict()

        loss_mask_project = 0.
        loss_mask_pairwise = 0.
        total_pos = 0.
        avg_factor = 0.

        for idx, per_img_pred in enumerate(mask_preds):
            (mask_pred, pos_mask_targets, pos_pairwise_masks,
             num_pos) = self._get_targets_single(per_img_pred,
                                                 batch_gt_instances[idx],
                                                 positive_infos[idx])
            total_pos += num_pos
            if num_pos == 0 or pos_mask_targets is None:
                # No positive sample in this image: contribute zero losses
                # created on the prediction's device/dtype.
                loss_project = mask_pred.new_zeros(1).mean()
                loss_pairwise = mask_pred.new_zeros(1).mean()
            else:
                # Projection term: compare the max-projections of prediction
                # and target along dim 1 and along dim 2.
                loss_project_x = self.loss_mask(
                    mask_pred.max(dim=1, keepdim=True)[0],
                    pos_mask_targets.max(dim=1, keepdim=True)[0],
                    reduction_override='none').sum()
                loss_project_y = self.loss_mask(
                    mask_pred.max(dim=2, keepdim=True)[0],
                    pos_mask_targets.max(dim=2, keepdim=True)[0],
                    reduction_override='none').sum()
                loss_project = loss_project_x + loss_project_y
                # Pairwise term: only entries enabled in the pairwise mask
                # contribute; their count (at least 1) is the normaliser.
                pairwise_affinity = self.get_pairwise_affinity(mask_pred)
                avg_factor += pos_pairwise_masks.sum().clamp(min=1.0)
                loss_pairwise = (pairwise_affinity * pos_pairwise_masks).sum()

            loss_mask_project += loss_project
            loss_mask_pairwise += loss_pairwise

        if total_pos == 0:
            total_pos += 1  # avoid nan
        if avg_factor == 0:
            avg_factor += 1  # avoid nan
        loss_mask_project = loss_mask_project / total_pos
        loss_mask_pairwise = loss_mask_pairwise / avg_factor

        # Linearly ramp the pairwise loss over the first ``warmup_iters``
        # iterations. A non-positive ``warmup_iters`` disables the warmup
        # instead of dividing by zero.
        if self.warmup_iters > 0:
            message_hub = MessageHub.get_current_instance()
            cur_iter = message_hub.get_info('iter')
            warmup_factor = min(cur_iter / float(self.warmup_iters), 1.0)
        else:
            warmup_factor = 1.0
        loss_mask_pairwise *= warmup_factor

        losses.update(
            loss_mask_project=loss_mask_project,
            loss_mask_pairwise=loss_mask_pairwise)
        return losses

    def _get_targets_single(self, mask_preds: Tensor,
                            gt_instances: InstanceData,
                            positive_info: InstanceData) -> tuple:
        """Compute targets for predictions of single image.

        Args:
            mask_preds (Tensor): Predicted masks of the positive samples.
                The first dimension is indexed per positive sample;
                presumably (num_pos, mask_h, mask_w) - TODO confirm.
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. This method reads the ``bboxes``, ``masks``
                and ``pairwise_masks`` attributes.
            positive_info (:obj:`InstanceData`): Information of positive
                samples that are assigned in detection head. This method
                reads the following keys.

                    - pos_assigned_gt_inds (Tensor): Assigned GT indexes of
                      positive proposals, has shape (num_pos, )
                    - scores (Tensor): Classification logits of the positive
                      samples; only used for the top-k selection.
                    - centernesses (Tensor): Centerness logits of the
                      positive samples; only used for the top-k selection.

        Returns:
            tuple: Usually returns a tuple containing learning targets.
            When the image has no GT mask or no positive sample,
            ``(mask_preds, None, None, 0)`` is returned instead.

            - mask_preds (Tensor): Positive predicted mask with shape
              (num_pos, mask_h, mask_w).
            - pos_mask_targets (Tensor): Positive mask targets with shape
              (num_pos, mask_h, mask_w).
            - pos_pairwise_masks (Tensor): Positive pairwise masks with
              shape: (num_pos, num_neighborhood, mask_h, mask_w).
            - num_pos (int): Positive numbers.
        """
        gt_bboxes = gt_instances.bboxes
        device = gt_bboxes.device
        # Note that gt_masks are generated by full box
        # from BoxInstDataPreprocessor
        gt_masks = gt_instances.masks.to_tensor(
            dtype=torch.bool, device=device).float()
        # Note that pairwise_masks are generated by image color similarity
        # from BoxInstDataPreprocessor
        pairwise_masks = gt_instances.pairwise_masks
        pairwise_masks = pairwise_masks.to(device=device)

        # process with mask targets
        pos_assigned_gt_inds = positive_info.get('pos_assigned_gt_inds')
        scores = positive_info.get('scores')
        centernesses = positive_info.get('centernesses')
        num_pos = pos_assigned_gt_inds.size(0)

        if gt_masks.size(0) == 0 or num_pos == 0:
            return mask_preds, None, None, 0
        # Since we're producing (near) full image masks,
        # it'd take too much vram to backprop on every single mask.
        # Thus we select only a subset.
        if (self.max_masks_to_train != -1) and \
           (num_pos > self.max_masks_to_train):
            # Uniformly random subset of the positive samples.
            perm = torch.randperm(num_pos)
            select = perm[:self.max_masks_to_train]
            mask_preds = mask_preds[select]
            pos_assigned_gt_inds = pos_assigned_gt_inds[select]
            num_pos = self.max_masks_to_train
        elif self.topk_masks_per_img != -1:
            # Split the per-image budget evenly over the GT instances that
            # own at least one positive sample (at least one each).
            unique_gt_inds = pos_assigned_gt_inds.unique()
            num_inst_per_gt = max(
                int(self.topk_masks_per_img / len(unique_gt_inds)), 1)

            keep_mask_preds = []
            keep_pos_assigned_gt_inds = []
            for gt_ind in unique_gt_inds:
                per_inst_pos_inds = (pos_assigned_gt_inds == gt_ind)
                mask_preds_per_inst = mask_preds[per_inst_pos_inds]
                gt_inds_per_inst = pos_assigned_gt_inds[per_inst_pos_inds]
                # Tensor.sum() counts the True entries in one reduction; the
                # builtin sum() would iterate the tensor element by element.
                if per_inst_pos_inds.sum() > num_inst_per_gt:
                    # Rank by (max class score * centerness), keep the best.
                    per_inst_scores = scores[per_inst_pos_inds].sigmoid().max(
                        dim=1)[0]
                    per_inst_centerness = centernesses[
                        per_inst_pos_inds].sigmoid().reshape(-1, )
                    select = (per_inst_scores * per_inst_centerness).topk(
                        k=num_inst_per_gt, dim=0)[1]
                    mask_preds_per_inst = mask_preds_per_inst[select]
                    gt_inds_per_inst = gt_inds_per_inst[select]
                keep_mask_preds.append(mask_preds_per_inst)
                keep_pos_assigned_gt_inds.append(gt_inds_per_inst)
            mask_preds = torch.cat(keep_mask_preds)
            pos_assigned_gt_inds = torch.cat(keep_pos_assigned_gt_inds)
            num_pos = pos_assigned_gt_inds.size(0)

        # Follow the original implementation: subsample the GT masks with a
        # step of ``mask_out_stride``, starting at half a stride.
        start = int(self.mask_out_stride // 2)
        gt_masks = gt_masks[:, start::self.mask_out_stride,
                            start::self.mask_out_stride]
        gt_masks = gt_masks.gt(0.5).float()
        pos_mask_targets = gt_masks[pos_assigned_gt_inds]
        pos_pairwise_masks = pairwise_masks[pos_assigned_gt_inds]
        # Keep only the pairwise entries whose centre pixel lies inside the
        # (box-generated) mask target.
        pos_pairwise_masks = pos_pairwise_masks * pos_mask_targets.unsqueeze(1)

        return (mask_preds, pos_mask_targets, pos_pairwise_masks, num_pos)