# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Tuple

import torch
from mmcv.ops import MultiScaleDeformableAttention, batched_nms
from torch import Tensor, nn
from torch.nn.init import normal_

from mmdet.registry import MODELS
from mmdet.structures import OptSampleList
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy
from mmdet.utils import OptConfigType
from ..layers import DDQTransformerDecoder
from ..utils import align_tensor
from .deformable_detr import DeformableDETR
from .dino import DINO


@MODELS.register_module()
class DDQDETR(DINO):
r"""Implementation of `Dense Distinct Query for
End-to-End Object Detection <https://arxiv.org/abs/2303.12776>`_
Code is modified from the `official github repo
<https://github.com/jshilong/DDQ>`_.
Args:
dense_topk_ratio (float): Ratio of num_dense queries to num_queries.
Defaults to 1.5.
dqs_cfg (:obj:`ConfigDict` or dict, optional): Config of
Distinct Queries Selection. Defaults to nms with
`iou_threshold` = 0.8.
"""

    def __init__(self,
*args,
dense_topk_ratio: float = 1.5,
dqs_cfg: OptConfigType = dict(type='nms', iou_threshold=0.8),
**kwargs):
self.dense_topk_ratio = dense_topk_ratio
self.decoder_cfg = kwargs['decoder']
self.dqs_cfg = dqs_cfg
super().__init__(*args, **kwargs)
        # A dict shared by all modules, used to pass intermediate
        # results and config parameters between them.
cache_dict = dict()
for m in self.modules():
m.cache_dict = cache_dict
# first element is the start index of matching queries
# second element is the number of matching queries
self.cache_dict['dis_query_info'] = [0, 0]
# mask for distinct queries in each decoder layer
self.cache_dict['distinct_query_mask'] = []
        # passed to the decoder to perform the DQS
self.cache_dict['cls_branches'] = self.bbox_head.cls_branches
# Used to construct the attention mask after dqs
self.cache_dict['num_heads'] = self.encoder.layers[
0].self_attn.num_heads
        # passed to the decoder to perform the DQS
self.cache_dict['dqs_cfg'] = self.dqs_cfg

    def _init_layers(self) -> None:
"""Initialize layers except for backbone, neck and bbox_head."""
super(DDQDETR, self)._init_layers()
self.decoder = DDQTransformerDecoder(**self.decoder_cfg)
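        # DDQ DETR uses no learned query embedding; initial queries are
        # instead mapped from encoder features via `query_map` (see
        # `pre_decoder`).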
self.query_embedding = None
self.query_map = nn.Linear(self.embed_dims, self.embed_dims)

    def init_weights(self) -> None:
"""Initialize weights for Transformer and other components."""
super(DeformableDETR, self).init_weights()
for coder in self.encoder, self.decoder:
for p in coder.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
for m in self.modules():
if isinstance(m, MultiScaleDeformableAttention):
m.init_weights()
nn.init.xavier_uniform_(self.memory_trans_fc.weight)
normal_(self.level_embed)

    def pre_decoder(
self,
memory: Tensor,
memory_mask: Tensor,
spatial_shapes: Tensor,
batch_data_samples: OptSampleList = None,
) -> Tuple[Dict]:
"""Prepare intermediate variables before entering Transformer decoder,
such as `query`, `memory`, and `reference_points`.
Args:
memory (Tensor): The output embeddings of the Transformer encoder,
has shape (bs, num_feat_points, dim).
memory_mask (Tensor): ByteTensor, the padding mask of the memory,
has shape (bs, num_feat_points). Will only be used when
`as_two_stage` is `True`.
spatial_shapes (Tensor): Spatial shapes of features in all levels.
With shape (num_levels, 2), last dimension represents (h, w).
Will only be used when `as_two_stage` is `True`.
batch_data_samples (list[:obj:`DetDataSample`]): The batch
data samples. It usually includes information such
as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
Defaults to None.
Returns:
tuple[dict]: The decoder_inputs_dict and head_inputs_dict.
- decoder_inputs_dict (dict): The keyword dictionary args of
`self.forward_decoder()`, which includes 'query', 'memory',
`reference_points`, and `dn_mask`. The reference points of
decoder input here are 4D boxes, although it has `points`
in its name.
- head_inputs_dict (dict): The keyword dictionary args of the
bbox_head functions, which includes `topk_score`, `topk_coords`,
`dense_topk_score`, `dense_topk_coords`,
and `dn_meta`, when `self.training` is `True`, else is empty.
"""
bs, _, c = memory.shape
output_memory, output_proposals = self.gen_encoder_output_proposals(
memory, memory_mask, spatial_shapes)
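        # `cls_branches[num_layers]` / `reg_branches[num_layers]` are the
        # two-stage encoder heads, following the DINO convention.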
enc_outputs_class = self.bbox_head.cls_branches[
self.decoder.num_layers](
output_memory)
enc_outputs_coord_unact = self.bbox_head.reg_branches[
self.decoder.num_layers](output_memory) + output_proposals
if self.training:
            # Aux dense branch, specific to DDQ DETR and absent in DINO;
            # index -1 is the aux head for the encoder output.
dense_enc_outputs_class = self.bbox_head.cls_branches[-1](
output_memory)
dense_enc_outputs_coord_unact = self.bbox_head.reg_branches[-1](
output_memory) + output_proposals
topk = self.num_queries
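        # The dense branch keeps `dense_topk_ratio` times as many
        # candidates as the distinct branch (1.5x by default).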
dense_topk = int(topk * self.dense_topk_ratio)
proposals = enc_outputs_coord_unact.sigmoid()
proposals = bbox_cxcywh_to_xyxy(proposals)
scores = enc_outputs_class.max(-1)[0].sigmoid()
if self.training:
            # Aux dense branch, specific to DDQ DETR and absent in DINO.
dense_proposals = dense_enc_outputs_coord_unact.sigmoid()
dense_proposals = bbox_cxcywh_to_xyxy(dense_proposals)
dense_scores = dense_enc_outputs_class.max(-1)[0].sigmoid()
num_imgs = len(scores)
topk_score = []
topk_coords_unact = []
# Distinct query.
query = []
dense_topk_score = []
dense_topk_coords_unact = []
dense_query = []
for img_id in range(num_imgs):
single_proposals = proposals[img_id]
single_scores = scores[img_id]
            # DDQ DETR selects region proposals with `batched_nms` over
            # class scores and bbox coordinates, instead of the plain
            # `topk` on class scores used by DINO.
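            # All proposals share a single label, so `batched_nms` acts as
            # class-agnostic NMS and returns `keep_idxs` sorted by score.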
_, keep_idxs = batched_nms(
single_proposals, single_scores,
torch.ones(len(single_scores), device=single_scores.device),
self.cache_dict['dqs_cfg'])
if self.training:
                # Aux dense branch, specific to DDQ DETR and absent in
                # DINO.
dense_single_proposals = dense_proposals[img_id]
dense_single_scores = dense_scores[img_id]
                # Sort by classification score only; neither NMS nor topk
                # is required, so `nms_cfg` is None.
_, dense_keep_idxs = batched_nms(
dense_single_proposals, dense_single_scores,
torch.ones(
len(dense_single_scores),
device=dense_single_scores.device), None)
dense_topk_score.append(dense_enc_outputs_class[img_id]
[dense_keep_idxs][:dense_topk])
dense_topk_coords_unact.append(
dense_enc_outputs_coord_unact[img_id][dense_keep_idxs]
[:dense_topk])
topk_score.append(enc_outputs_class[img_id][keep_idxs][:topk])
            # Unlike Deformable DETR, which initializes the content part
            # from transformed coordinates, we fuse the feature map
            # embeddings at the distinct positions into the content part,
            # which makes the initial queries more distinct.
topk_coords_unact.append(
enc_outputs_coord_unact[img_id][keep_idxs][:topk])
map_memory = self.query_map(memory[img_id].detach())
query.append(map_memory[keep_idxs][:topk])
if self.training:
                # Aux dense branch, specific to DDQ DETR and absent in
                # DINO.
dense_query.append(map_memory[dense_keep_idxs][:dense_topk])
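        # Pad the per-image lists to a common length (`topk` for the
        # distinct branch, the batch maximum for the dense branch) so
        # they can be stacked into batch tensors.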
topk_score = align_tensor(topk_score, topk)
topk_coords_unact = align_tensor(topk_coords_unact, topk)
query = align_tensor(query, topk)
if self.training:
dense_topk_score = align_tensor(dense_topk_score)
dense_topk_coords_unact = align_tensor(dense_topk_coords_unact)
dense_query = align_tensor(dense_query)
num_dense_queries = dense_query.size(1)
if self.training:
query = torch.cat([query, dense_query], dim=1)
topk_coords_unact = torch.cat(
[topk_coords_unact, dense_topk_coords_unact], dim=1)
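        # The activated coords below are split back into distinct and
        # dense parts for the bbox head, while the unactivated coords
        # (detached) seed the decoder reference points.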
topk_coords = topk_coords_unact.sigmoid()
if self.training:
dense_topk_coords = topk_coords[:, -num_dense_queries:]
topk_coords = topk_coords[:, :-num_dense_queries]
topk_coords_unact = topk_coords_unact.detach()
if self.training:
dn_label_query, dn_bbox_query, dn_mask, dn_meta = \
self.dn_query_generator(batch_data_samples)
query = torch.cat([dn_label_query, query], dim=1)
reference_points = torch.cat([dn_bbox_query, topk_coords_unact],
dim=1)
# Update `dn_mask` to add mask for dense queries.
ori_size = dn_mask.size(-1)
new_size = dn_mask.size(-1) + num_dense_queries
new_dn_mask = dn_mask.new_ones((new_size, new_size)).bool()
dense_mask = torch.zeros(num_dense_queries,
num_dense_queries).bool()
self.cache_dict['dis_query_info'] = [dn_label_query.size(1), topk]
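            # Layout of `new_dn_mask` (True means attention is blocked):
            #   [ dn_mask (dn + distinct) |      blocked      ]
            #   [         blocked         |  dense <-> dense  ]
            # Dense queries attend only to each other and are invisible
            # to the denoising and distinct queries.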
new_dn_mask[ori_size:, ori_size:] = dense_mask
new_dn_mask[:ori_size, :ori_size] = dn_mask
dn_meta['num_dense_queries'] = num_dense_queries
dn_mask = new_dn_mask
self.cache_dict['num_dense_queries'] = num_dense_queries
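            # During training the decoder refines the dense queries with
            # the bbox head's dedicated aux regression branches.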
self.decoder.aux_reg_branches = self.bbox_head.aux_reg_branches
else:
self.cache_dict['dis_query_info'] = [0, topk]
reference_points = topk_coords_unact
dn_mask, dn_meta = None, None
reference_points = reference_points.sigmoid()
decoder_inputs_dict = dict(
query=query,
memory=memory,
reference_points=reference_points,
dn_mask=dn_mask)
head_inputs_dict = dict(
enc_outputs_class=topk_score,
enc_outputs_coord=topk_coords,
aux_enc_outputs_class=dense_topk_score,
aux_enc_outputs_coord=dense_topk_coords,
dn_meta=dn_meta) if self.training else dict()
return decoder_inputs_dict, head_inputs_dict
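

if __name__ == '__main__':
    # Standalone illustrative sketch (not part of the upstream model
    # code): Distinct Queries Selection boils down to class-agnostic NMS
    # over encoder proposals, rather than a plain top-k on class scores.
    # Shapes and the IoU threshold below are toy assumptions.
    num_proposals, num_queries = 20, 5
    boxes = torch.rand(num_proposals, 4)
    boxes[:, 2:] = boxes[:, :2] + boxes[:, 2:]  # make x2 >= x1, y2 >= y1
    scores = torch.rand(num_proposals)
    # A single shared label makes `batched_nms` class-agnostic, matching
    # the `torch.ones(...)` trick in `pre_decoder` above.
    _, keep_idxs = batched_nms(boxes, scores, torch.ones(num_proposals),
                               dict(type='nms', iou_threshold=0.8))
    print('distinct query indices:', keep_idxs[:num_queries].tolist())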