Spaces:
Runtime error
Runtime error
File size: 13,567 Bytes
3b96cb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Tuple
import torch
from torch import Tensor, nn
from torch.nn.init import normal_
from mmdet.registry import MODELS
from mmdet.structures import OptSampleList
from mmdet.utils import OptConfigType
from ..layers import (CdnQueryGenerator, DeformableDetrTransformerEncoder,
DinoTransformerDecoder, SinePositionalEncoding)
from .deformable_detr import DeformableDETR, MultiScaleDeformableAttention
@MODELS.register_module()
class DINO(DeformableDETR):
    r"""Implementation of `DINO: DETR with Improved DeNoising Anchor Boxes
    for End-to-End Object Detection <https://arxiv.org/abs/2203.03605>`_

    Code is modified from the `official github repo
    <https://github.com/IDEA-Research/DINO>`_.

    Args:
        dn_cfg (:obj:`ConfigDict` or dict, optional): Config of denoising
            query generator. It must not contain `num_classes`,
            `embed_dims` or `num_matching_queries`; these are filled in by
            the detector. The passed config is not modified.
            Defaults to `None`, which is treated as an empty config.
    """

    def __init__(self, *args, dn_cfg: OptConfigType = None, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        assert self.as_two_stage, 'as_two_stage must be True for DINO'
        assert self.with_box_refine, 'with_box_refine must be True for DINO'

        # Work on a shallow copy so the caller's config is left untouched;
        # writing the derived keys into the original would make a second
        # build from the same config fail the assertion below.
        # NOTE(review): `None` is treated as an empty config, which assumes
        # `CdnQueryGenerator` has defaults for its remaining options -
        # confirm against its signature.
        dn_cfg = dn_cfg.copy() if dn_cfg is not None else dict()

        # The first three keys are the ones set below; `num_queries` and
        # `hidden_dim` are the names the previous check rejected and stay
        # rejected for backward compatibility.
        reserved_keys = ('num_classes', 'embed_dims', 'num_matching_queries',
                         'num_queries', 'hidden_dim')
        assert not any(key in dn_cfg for key in reserved_keys), \
            'The three keyword args `num_classes`, `embed_dims`, and ' \
            '`num_matching_queries` are set in `detector.__init__()`, ' \
            'users should not set them in `dn_cfg` config.'
        dn_cfg['num_classes'] = self.bbox_head.num_classes
        dn_cfg['embed_dims'] = self.embed_dims
        dn_cfg['num_matching_queries'] = self.num_queries
        self.dn_query_generator = CdnQueryGenerator(**dn_cfg)

    def _init_layers(self) -> None:
        """Initialize layers except for backbone, neck and bbox_head."""
        # `positional_encoding`, `encoder` and `decoder` hold keyword
        # configs at this point and are replaced by the built modules.
        self.positional_encoding = SinePositionalEncoding(
            **self.positional_encoding)
        self.encoder = DeformableDetrTransformerEncoder(**self.encoder)
        self.decoder = DinoTransformerDecoder(**self.decoder)
        self.embed_dims = self.encoder.embed_dims
        self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims)
        # NOTE In DINO, the query_embedding only contains content
        # queries, while in Deformable DETR, the query_embedding
        # contains both content and spatial queries, and in DETR,
        # it only contains spatial queries.

        num_feats = self.positional_encoding.num_feats
        assert num_feats * 2 == self.embed_dims, \
            f'embed_dims should be exactly 2 times of num_feats. ' \
            f'Found {self.embed_dims} and {num_feats}.'

        # Allocated uninitialized here; filled by `init_weights`.
        self.level_embed = nn.Parameter(
            torch.Tensor(self.num_feature_levels, self.embed_dims))
        self.memory_trans_fc = nn.Linear(self.embed_dims, self.embed_dims)
        self.memory_trans_norm = nn.LayerNorm(self.embed_dims)

    def init_weights(self) -> None:
        """Initialize weights for Transformer and other components."""
        # Deliberately starts the MRO lookup *after* DeformableDETR, so
        # DeformableDETR's own `init_weights` is bypassed.
        super(DeformableDETR, self).init_weights()
        for coder in self.encoder, self.decoder:
            for p in coder.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)
        # Runs after the generic Xavier pass so the attention modules'
        # own initialization takes precedence.
        for m in self.modules():
            if isinstance(m, MultiScaleDeformableAttention):
                m.init_weights()
        nn.init.xavier_uniform_(self.memory_trans_fc.weight)
        nn.init.xavier_uniform_(self.query_embedding.weight)
        normal_(self.level_embed)

    def forward_transformer(
        self,
        img_feats: Tuple[Tensor],
        batch_data_samples: OptSampleList = None,
    ) -> Dict:
        """Forward process of Transformer.

        The forward procedure of the transformer is defined as:
        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
        More details can be found at `TransformerDetector.forward_transformer`
        in `mmdet/detector/base_detr.py`.
        The difference is that the ground truth in `batch_data_samples` is
        required for the `pre_decoder` to prepare the query of DINO.
        Additionally, DINO inherits the `pre_transformer` method and the
        `forward_encoder` method of DeformableDETR. More details about the
        two methods can be found in `mmdet/detector/deformable_detr.py`.

        Args:
            img_feats (tuple[Tensor]): Tuple of feature maps from neck. Each
                feature map has shape (bs, dim, H, W).
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
                Defaults to None.

        Returns:
            dict: The dictionary of bbox_head function inputs, which always
            includes the `hidden_states` of the decoder output and may contain
            `references` including the initial and intermediate references.
        """
        encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer(
            img_feats, batch_data_samples)

        encoder_outputs_dict = self.forward_encoder(**encoder_inputs_dict)

        tmp_dec_in, head_inputs_dict = self.pre_decoder(
            **encoder_outputs_dict, batch_data_samples=batch_data_samples)
        decoder_inputs_dict.update(tmp_dec_in)

        decoder_outputs_dict = self.forward_decoder(**decoder_inputs_dict)
        head_inputs_dict.update(decoder_outputs_dict)
        return head_inputs_dict

    def pre_decoder(
        self,
        memory: Tensor,
        memory_mask: Tensor,
        spatial_shapes: Tensor,
        batch_data_samples: OptSampleList = None,
    ) -> Tuple[Dict, Dict]:
        """Prepare intermediate variables before entering Transformer decoder,
        such as `query`, `query_pos`, and `reference_points`.

        Args:
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).
            memory_mask (Tensor): ByteTensor, the padding mask of the memory,
                has shape (bs, num_feat_points). Will only be used when
                `as_two_stage` is `True`.
            spatial_shapes (Tensor): Spatial shapes of features in all levels.
                With shape (num_levels, 2), last dimension represents (h, w).
                Will only be used when `as_two_stage` is `True`.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
                Defaults to None.

        Returns:
            tuple[dict, dict]: The decoder_inputs_dict and head_inputs_dict.

            - decoder_inputs_dict (dict): The keyword dictionary args of
              `self.forward_decoder()`, which includes 'query', 'memory',
              `reference_points`, and `dn_mask`. The reference points of
              decoder input here are 4D boxes, although it has `points`
              in its name.
            - head_inputs_dict (dict): The keyword dictionary args of the
              bbox_head functions, which includes `enc_outputs_class`,
              `enc_outputs_coord`, and `dn_meta` when `self.training` is
              `True`, else is empty.
        """
        bs = memory.size(0)

        # The branches at index `decoder.num_layers` are the ones applied
        # to the encoder output.
        enc_cls_branch = self.bbox_head.cls_branches[self.decoder.num_layers]
        enc_reg_branch = self.bbox_head.reg_branches[self.decoder.num_layers]
        cls_out_features = enc_cls_branch.out_features

        output_memory, output_proposals = self.gen_encoder_output_proposals(
            memory, memory_mask, spatial_shapes)
        enc_outputs_class = enc_cls_branch(output_memory)
        enc_outputs_coord_unact = \
            enc_reg_branch(output_memory) + output_proposals

        # NOTE The DINO selects top-k proposals according to scores of
        # multi-class classification, while DeformDETR, where the input
        # is `enc_outputs_class[..., 0]` selects according to scores of
        # binary classification.
        topk_indices = torch.topk(
            enc_outputs_class.max(-1)[0], k=self.num_queries, dim=1)[1]
        topk_score = torch.gather(
            enc_outputs_class, 1,
            topk_indices.unsqueeze(-1).repeat(1, 1, cls_out_features))
        topk_coords_unact = torch.gather(
            enc_outputs_coord_unact, 1,
            topk_indices.unsqueeze(-1).repeat(1, 1, 4))
        # `topk_coords` keeps its gradient for the encoder loss; the
        # detached copy is what seeds the decoder reference boxes.
        topk_coords = topk_coords_unact.sigmoid()
        topk_coords_unact = topk_coords_unact.detach()

        # Broadcast the learned content queries to (bs, num_queries, dim).
        query = self.query_embedding.weight[:, None, :]
        query = query.repeat(1, bs, 1).transpose(0, 1)
        if self.training:
            dn_label_query, dn_bbox_query, dn_mask, dn_meta = \
                self.dn_query_generator(batch_data_samples)
            # Denoising queries are placed in front of the matching ones.
            query = torch.cat([dn_label_query, query], dim=1)
            reference_points = torch.cat([dn_bbox_query, topk_coords_unact],
                                         dim=1)
        else:
            reference_points = topk_coords_unact
            dn_mask, dn_meta = None, None
        reference_points = reference_points.sigmoid()

        decoder_inputs_dict = dict(
            query=query,
            memory=memory,
            reference_points=reference_points,
            dn_mask=dn_mask)
        # NOTE DINO calculates encoder losses on scores and coordinates
        # of selected top-k encoder queries, while DeformDETR is of all
        # encoder queries.
        head_inputs_dict = dict(
            enc_outputs_class=topk_score,
            enc_outputs_coord=topk_coords,
            dn_meta=dn_meta) if self.training else dict()
        return decoder_inputs_dict, head_inputs_dict

    def forward_decoder(self,
                        query: Tensor,
                        memory: Tensor,
                        memory_mask: Tensor,
                        reference_points: Tensor,
                        spatial_shapes: Tensor,
                        level_start_index: Tensor,
                        valid_ratios: Tensor,
                        dn_mask: Optional[Tensor] = None,
                        **kwargs) -> Dict:
        """Forward with Transformer decoder.

        The forward procedure of the transformer is defined as:
        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
        More details can be found at `TransformerDetector.forward_transformer`
        in `mmdet/detector/base_detr.py`.

        Args:
            query (Tensor): The queries of decoder inputs, has shape
                (bs, num_queries_total, dim), where `num_queries_total` is the
                sum of `num_denoising_queries` and `num_matching_queries` when
                `self.training` is `True`, else `num_matching_queries`.
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).
            memory_mask (Tensor): ByteTensor, the padding mask of the memory,
                has shape (bs, num_feat_points).
            reference_points (Tensor): The initial reference, has shape
                (bs, num_queries_total, 4) with the last dimension arranged as
                (cx, cy, w, h).
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape (num_levels, ) and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).
            dn_mask (Tensor, optional): The attention mask to prevent
                information leakage from different denoising groups and
                matching parts, will be used as `self_attn_mask` of the
                `self.decoder`, has shape (num_queries_total,
                num_queries_total).
                It is `None` when `self.training` is `False`.

        Returns:
            dict: The dictionary of decoder outputs, which includes the
            `hidden_states` of the decoder output and `references` including
            the initial and intermediate reference_points.
        """
        inter_states, references = self.decoder(
            query=query,
            value=memory,
            key_padding_mask=memory_mask,
            self_attn_mask=dn_mask,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            reg_branches=self.bbox_head.reg_branches,
            **kwargs)

        # `query` is batch-first, so the number of queries is dimension 1;
        # `len(query)` would return the batch size instead.
        if query.size(1) == self.num_queries:
            # NOTE: This is to make sure label_embedding can be involved to
            # produce loss even if there is no denoising query (no ground truth
            # target in this GPU), otherwise, this will raise runtime error in
            # distributed training. The added term is multiplied by 0.0, so
            # it does not change the values of `inter_states`.
            inter_states[0] += \
                self.dn_query_generator.label_embedding.weight[0, 0] * 0.0

        decoder_outputs_dict = dict(
            hidden_states=inter_states, references=list(references))
        return decoder_outputs_dict
|