# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Tuple

import torch
from mmcv.ops import MultiScaleDeformableAttention, batched_nms
from torch import Tensor, nn
from torch.nn.init import normal_

from mmdet.registry import MODELS
from mmdet.structures import OptSampleList
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy
from mmdet.utils import OptConfigType
from ..layers import DDQTransformerDecoder
from ..utils import align_tensor
from .deformable_detr import DeformableDETR
from .dino import DINO


@MODELS.register_module()
class DDQDETR(DINO):
    r"""Implementation of `Dense Distinct Query for
    End-to-End Object Detection <https://arxiv.org/abs/2303.12776>`_

    Code is modified from the `official github repo
    <https://github.com/jshilong/DDQ>`_.

    Args:
        dense_topk_ratio (float): Ratio of the number of dense queries to
            num_queries. Defaults to 1.5.
        dqs_cfg (:obj:`ConfigDict` or dict, optional): Config of
            Distinct Queries Selection. Defaults to nms with
            `iou_threshold` = 0.8.
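
    Example:
        >>> # A minimal, illustrative config sketch. The remaining DINO
        >>> # kwargs (backbone, neck, encoder, decoder, bbox_head, ...)
        >>> # are required in practice but omitted here.
        >>> cfg = dict(
        ...     type='DDQDETR',
        ...     dense_topk_ratio=1.5,
        ...     dqs_cfg=dict(type='nms', iou_threshold=0.8))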
    """

    def __init__(self,
                 *args,
                 dense_topk_ratio: float = 1.5,
                 dqs_cfg: OptConfigType = dict(type='nms', iou_threshold=0.8),
                 **kwargs):
        self.dense_topk_ratio = dense_topk_ratio
        self.decoder_cfg = kwargs['decoder']
        self.dqs_cfg = dqs_cfg
        super().__init__(*args, **kwargs)

        # a dict shared by all modules, used to pass intermediate
        #   results and config parameters between them
        cache_dict = dict()
        for m in self.modules():
            m.cache_dict = cache_dict
        # first element is the start index of matching queries
        # second element is the number of matching queries
        self.cache_dict['dis_query_info'] = [0, 0]

        # mask for distinct queries in each decoder layer
        self.cache_dict['distinct_query_mask'] = []
        # passed to the decoder to perform the DQS
        self.cache_dict['cls_branches'] = self.bbox_head.cls_branches
        # Used to construct the attention mask after dqs
        self.cache_dict['num_heads'] = self.encoder.layers[
            0].self_attn.num_heads
        # passed to the decoder to perform the DQS
        self.cache_dict['dqs_cfg'] = self.dqs_cfg
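        # Note: during training, `pre_decoder` additionally writes
        #   `num_dense_queries` into this shared dict.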

    def _init_layers(self) -> None:
        """Initialize layers except for backbone, neck and bbox_head."""
        super(DDQDETR, self)._init_layers()
        self.decoder = DDQTransformerDecoder(**self.decoder_cfg)
        self.query_embedding = None
        self.query_map = nn.Linear(self.embed_dims, self.embed_dims)
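        # `query_map` projects detached encoder memory at the selected
        #   distinct positions into the decoder content queries, replacing
        #   DINO's learned `query_embedding` (set to None above).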

    def init_weights(self) -> None:
        """Initialize weights for Transformer and other components."""
        super(DeformableDETR, self).init_weights()
        for coder in self.encoder, self.decoder:
            for p in coder.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MultiScaleDeformableAttention):
                m.init_weights()
        nn.init.xavier_uniform_(self.memory_trans_fc.weight)
        normal_(self.level_embed)

    def pre_decoder(
        self,
        memory: Tensor,
        memory_mask: Tensor,
        spatial_shapes: Tensor,
        batch_data_samples: OptSampleList = None,
    ) -> Tuple[Dict]:
        """Prepare intermediate variables before entering Transformer decoder,
        such as `query`, `memory`, and `reference_points`.

        Args:
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).
            memory_mask (Tensor): ByteTensor, the padding mask of the memory,
                has shape (bs, num_feat_points). Will only be used when
                `as_two_stage` is `True`.
            spatial_shapes (Tensor): Spatial shapes of features in all levels.
                With shape (num_levels, 2), last dimension represents (h, w).
                Will only be used when `as_two_stage` is `True`.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
                Defaults to None.

        Returns:
            tuple[dict]: The decoder_inputs_dict and head_inputs_dict.

            - decoder_inputs_dict (dict): The keyword dictionary args of
              `self.forward_decoder()`, which includes 'query', 'memory',
              `reference_points`, and `dn_mask`. The reference points of
              decoder input here are 4D boxes, although it has `points`
              in its name.
            - head_inputs_dict (dict): The keyword dictionary args of the
              bbox_head functions, which includes `topk_score`, `topk_coords`,
              `dense_topk_score`, `dense_topk_coords`,
              and `dn_meta`, when `self.training` is `True`, else is empty.
        """
        bs, _, c = memory.shape
        output_memory, output_proposals = self.gen_encoder_output_proposals(
            memory, memory_mask, spatial_shapes)
        enc_outputs_class = self.bbox_head.cls_branches[
            self.decoder.num_layers](
                output_memory)
        enc_outputs_coord_unact = self.bbox_head.reg_branches[
            self.decoder.num_layers](output_memory) + output_proposals

        if self.training:
            # The auxiliary dense branch is specific to DDQ DETR and does
            #   not exist in DINO.
            # Index -1 is the aux head for the encoder.
            dense_enc_outputs_class = self.bbox_head.cls_branches[-1](
                output_memory)
            dense_enc_outputs_coord_unact = self.bbox_head.reg_branches[-1](
                output_memory) + output_proposals

        topk = self.num_queries
        dense_topk = int(topk * self.dense_topk_ratio)
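        # e.g., with num_queries=900 (a typical DINO setting) and the
        #   default dense_topk_ratio=1.5, dense_topk = int(900 * 1.5) = 1350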

        proposals = enc_outputs_coord_unact.sigmoid()
        proposals = bbox_cxcywh_to_xyxy(proposals)
        scores = enc_outputs_class.max(-1)[0].sigmoid()

        if self.training:
            # The auxiliary dense branch is specific to DDQ DETR and does
            #   not exist in DINO.
            dense_proposals = dense_enc_outputs_coord_unact.sigmoid()
            dense_proposals = bbox_cxcywh_to_xyxy(dense_proposals)
            dense_scores = dense_enc_outputs_class.max(-1)[0].sigmoid()

        num_imgs = len(scores)
        topk_score = []
        topk_coords_unact = []
        # Distinct query.
        query = []

        dense_topk_score = []
        dense_topk_coords_unact = []
        dense_query = []

        for img_id in range(num_imgs):
            single_proposals = proposals[img_id]
            single_scores = scores[img_id]

            # DDQ DETR uses `batched_nms` over class scores and bbox
            #   coordinates for region proposal generation, instead of
            #   DINO's `topk` over class scores.
            _, keep_idxs = batched_nms(
                single_proposals, single_scores,
                torch.ones(len(single_scores), device=single_scores.device),
                self.cache_dict['dqs_cfg'])
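            # All proposals are given the same class index here, so this
            #   NMS is class-agnostic: near-duplicate boxes are suppressed
            #   to keep the selected queries distinct.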

            if self.training:
                # The auxiliary dense branch is specific to DDQ DETR and
                #   does not exist in DINO.
                dense_single_proposals = dense_proposals[img_id]
                dense_single_scores = dense_scores[img_id]
                # Sort by classification score only; neither NMS nor topk
                #   is required, so the input parameter `nms_cfg` is None.
                _, dense_keep_idxs = batched_nms(
                    dense_single_proposals, dense_single_scores,
                    torch.ones(
                        len(dense_single_scores),
                        device=dense_single_scores.device), None)

                dense_topk_score.append(dense_enc_outputs_class[img_id]
                                        [dense_keep_idxs][:dense_topk])
                dense_topk_coords_unact.append(
                    dense_enc_outputs_coord_unact[img_id][dense_keep_idxs]
                    [:dense_topk])

            topk_score.append(enc_outputs_class[img_id][keep_idxs][:topk])

            # Instead of initializing the content part with transformed
            #   coordinates as in Deformable DETR, we use the feature map
            #   embeddings at the distinct positions as the content part,
            #   which makes the initial queries more distinct.
            topk_coords_unact.append(
                enc_outputs_coord_unact[img_id][keep_idxs][:topk])

            map_memory = self.query_map(memory[img_id].detach())
            query.append(map_memory[keep_idxs][:topk])
            if self.training:
                # The auxiliary dense branch is specific to DDQ DETR and
                #   does not exist in DINO.
                dense_query.append(map_memory[dense_keep_idxs][:dense_topk])

        topk_score = align_tensor(topk_score, topk)
        topk_coords_unact = align_tensor(topk_coords_unact, topk)
        query = align_tensor(query, topk)
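        # NMS keeps a variable number of proposals per image, so the
        #   per-image results are padded by `align_tensor` to a common
        #   length for batching.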
        if self.training:
            dense_topk_score = align_tensor(dense_topk_score)
            dense_topk_coords_unact = align_tensor(dense_topk_coords_unact)

            dense_query = align_tensor(dense_query)
            num_dense_queries = dense_query.size(1)

            query = torch.cat([query, dense_query], dim=1)
            topk_coords_unact = torch.cat(
                [topk_coords_unact, dense_topk_coords_unact], dim=1)

        topk_coords = topk_coords_unact.sigmoid()
        if self.training:
            dense_topk_coords = topk_coords[:, -num_dense_queries:]
            topk_coords = topk_coords[:, :-num_dense_queries]

        topk_coords_unact = topk_coords_unact.detach()
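        # As in DINO, the reference points fed to the decoder are detached,
        #   so decoder gradients do not flow back into the encoder proposals.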

        if self.training:
            dn_label_query, dn_bbox_query, dn_mask, dn_meta = \
                self.dn_query_generator(batch_data_samples)
            query = torch.cat([dn_label_query, query], dim=1)
            reference_points = torch.cat([dn_bbox_query, topk_coords_unact],
                                         dim=1)

            # Update `dn_mask` to add mask for dense queries.
            ori_size = dn_mask.size(-1)
            new_size = dn_mask.size(-1) + num_dense_queries
            new_dn_mask = dn_mask.new_ones((new_size, new_size)).bool()
            dense_mask = torch.zeros(num_dense_queries,
                                     num_dense_queries).bool()
            self.cache_dict['dis_query_info'] = [dn_label_query.size(1), topk]

            new_dn_mask[ori_size:, ori_size:] = dense_mask
            new_dn_mask[:ori_size, :ori_size] = dn_mask
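            # Resulting self-attention mask layout (True = blocked):
            #   [ dn_mask | True  ]  <- dn + matching queries
            #   [ True    | False ]  <- dense queries
            # Dense queries attend only to each other, and no other query
            #   attends to the dense queries.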
            dn_meta['num_dense_queries'] = num_dense_queries
            dn_mask = new_dn_mask
            self.cache_dict['num_dense_queries'] = num_dense_queries
            self.decoder.aux_reg_branches = self.bbox_head.aux_reg_branches

        else:
            self.cache_dict['dis_query_info'] = [0, topk]
            reference_points = topk_coords_unact
            dn_mask, dn_meta = None, None

        reference_points = reference_points.sigmoid()

        decoder_inputs_dict = dict(
            query=query,
            memory=memory,
            reference_points=reference_points,
            dn_mask=dn_mask)
        head_inputs_dict = dict(
            enc_outputs_class=topk_score,
            enc_outputs_coord=topk_coords,
            aux_enc_outputs_class=dense_topk_score,
            aux_enc_outputs_coord=dense_topk_coords,
            dn_meta=dn_meta) if self.training else dict()

        return decoder_inputs_dict, head_inputs_dict