Hasanmog committed on
Commit
616dc83
·
1 Parent(s): 1f5a089
cfg_odvg.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Flat configuration module: every top-level name below is one setting.
# Values are identical to the original file; only grouping and formatting differ.

# --- Data augmentation and batching -----------------------------------------
data_aug_scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
data_aug_max_size = 1333
data_aug_scales2_resize = [400, 500, 600]
data_aug_scales2_crop = [384, 600]
data_aug_scale_overlap = None
batch_size = 2

# --- Model selection and backbone -------------------------------------------
modelname = "groundingdino"
backbone = "swin_B_384_22k"
position_embedding = "sine"
pe_temperatureH = 20
pe_temperatureW = 20
return_interm_indices = [1, 2, 3]

# --- Transformer encoder / decoder ------------------------------------------
enc_layers = 6
dec_layers = 6  # originally 6
pre_norm = False
dim_feedforward = 2048
hidden_dim = 256
dropout = 0.0
nheads = 8  # originally 8
num_queries = 900
query_dim = 4
num_patterns = 0
num_feature_levels = 4
enc_n_points = 4
dec_n_points = 4
two_stage_type = "standard"
two_stage_bbox_embed_share = False
two_stage_class_embed_share = False
transformer_activation = "relu"
dec_pred_bbox_embed_share = True

# --- Denoising (dn_*) settings ----------------------------------------------
dn_box_noise_scale = 1.0
dn_label_noise_ratio = 0.5
dn_label_coef = 1.0
dn_bbox_coef = 1.0
embed_init_tgt = True
dn_labelbook_size = 91

# --- Text encoder and vision/text fusion ------------------------------------
max_text_len = 256
text_encoder_type = "bert-base-uncased"
use_text_enhancer = True
use_fusion_layer = True
use_checkpoint = True
use_transformer_ckpt = True
use_text_cross_attention = True
text_dropout = 0.0
fusion_dropout = 0.0
fusion_droppath = 0.1
sub_sentence_present = True
max_labels = 50  # pos + neg

# --- Optimiser and parameter selection --------------------------------------
lr = 0.001  # base learning rate
backbone_freeze_keywords = None  # only for gdino backbone
lora = True
# For the whole model, e.g. ['backbone.0', 'bert'] to freeze the visual
# encoder and the text encoder.
trainable_keywords = ["transformer", "input_proj", "feat_map", "backbone.0"]
lr_backbone = 1e-05  # specific learning rate
lr_backbone_names = ["backbone.0", "bert"]
lr_linear_proj_mult = 1e-05
lr_linear_proj_names = ["ref_point_head", "sampling_offsets"]
weight_decay = 0.001
param_dict_type = "ddetr_in_mmdet"
ddetr_lr_param = False

# --- Training schedule ------------------------------------------------------
epochs = 50
lr_drop = 4
save_checkpoint_interval = 1
clip_max_norm = 0.1
# Scheduler switches; only ReduceLROnPlateau is enabled here.
onecyclelr = False
multi_step_lr = False
cosine_anneal = False
ReduceLROnPlateau = True
step_lr = False
gamma = 0.95
lr_drop_list = [2, 5, 10, 15, 20]
frozen_weights = None

# --- Architecture variants (all disabled / default) -------------------------
dilation = False
pdetr3_bbox_embed_diff_each_layer = False
pdetr3_refHW = -1
random_refpoints_xy = False
fix_refpoints_hw = -1
dabdetr_yolo_like_anchor_update = False
dabdetr_deformable_encoder = False
dabdetr_deformable_decoder = False
use_deformable_box_attn = False
box_attn_type = "roi_align"
dec_layer_number = None
decoder_layer_noise = False
dln_xy_noise = 0.2
dln_hw_noise = 0.2
add_channel_attention = False
add_pos_value = False
two_stage_pat_embed = 0
two_stage_add_query_num = 0
two_stage_learn_wh = False
two_stage_default_hw = 0.05
two_stage_keep_all_tokens = False
num_select = 10
batch_norm_type = "FrozenBatchNorm2d"
masks = False

# --- Matching costs and loss weights ----------------------------------------
aux_loss = True
set_cost_class = 1.0
set_cost_bbox = 5.0
set_cost_giou = 2.0
cls_loss_coef = 2.0  # originally 2.0
bbox_loss_coef = 5.0
giou_loss_coef = 2.0
enc_loss_coef = 1.0
interm_loss_coef = 1.0
no_interm_box_loss = False
mask_loss_coef = 1.0
dice_loss_coef = 1.0
focal_alpha = 0.25
focal_gamma = 2.0
decoder_sa_type = "sa"
matcher_type = "HungarianMatcher"
decoder_module_seq = ["sa", "ca", "ffn"]
nms_iou_threshold = -1
dec_pred_class_embed_share = True

# --- Category names ---------------------------------------------------------
# Alternative label list kept for reference (tagged "RSVGD" in the original):
# label_list = [
#     "airplane", "airport", "baseballfield", "basketballcourt", "bridge", "chimney", "dam",
#     "Expressway-Service-area", "Expressway-toll-station", "golffield",
#     "groundtrackfield", "harbor", "overpass", "ship", "stadium", "storagetank",
#     "tenniscourt", "trainstation", "vehicle", "windmill"
# ]
label_list = [
    "airplane",
    "baseball diamond",
    "basketball court",
    "bridge",
    "crossroad",
    "ground track field",
    "harbor",
    "parking lot",
    "ship",
    "storage tank",
    "swimming pool",
    "tennis court",
    "T junction",
    "vehicle",
]

# --- Miscellaneous ----------------------------------------------------------
match_unstable_error = True
use_ema = False
ema_decay = 0.9997
ema_epoch = 0
use_detached_boxes_dec_out = False
use_coco_eval = False
dn_scalar = 100
groundingdino/models/GroundingDINO/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # Conditional DETR
8
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # ------------------------------------------------------------------------
11
+ # Copied from DETR (https://github.com/facebookresearch/detr)
12
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
13
+ # ------------------------------------------------------------------------
14
+
15
+ from .groundingdino import build_groundingdino
groundingdino/models/GroundingDINO/backbone/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .backbone import build_backbone
groundingdino/models/GroundingDINO/bertwarper.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint as checkpoint
11
+ from torch import Tensor, nn
12
+ from torchvision.ops.boxes import nms
13
+ from transformers import BertConfig, BertModel, BertPreTrainedModel
14
+ from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
15
+
16
+
17
class BertModelWarper(nn.Module):
    """Run a BERT model's embeddings, encoder and pooler through a local forward pass.

    The wrapper does not keep the original model object. It keeps references
    to the model's ``config``, its three sub-components and three of its
    bound helper methods, and re-implements ``forward`` on top of them.
    """

    def __init__(self, bert_model):
        """Borrow the components of ``bert_model`` that ``forward`` needs.

        Args:
            bert_model: object exposing ``config``, ``embeddings``, ``encoder``,
                ``pooler``, ``get_extended_attention_mask``,
                ``invert_attention_mask`` and ``get_head_mask``.
        """
        super().__init__()

        # Configuration and sub-modules.
        self.config = bert_model.config
        self.embeddings = bert_model.embeddings
        self.encoder = bert_model.encoder
        self.pooler = bert_model.pooler

        # Mask helpers, kept as bound methods of the wrapped model.
        self.get_extended_attention_mask = bert_model.get_extended_attention_mask
        self.invert_attention_mask = bert_model.invert_attention_mask
        self.get_head_mask = bert_model.get_head_mask

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""Embed the input, run the encoder and optionally pool the result.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.

        Args:
            input_ids: token ids, unpacked as ``(batch_size, seq_length)``.
            attention_mask: mask handed to ``get_extended_attention_mask``;
                defaults to all ones over ``seq_length`` plus the cached length.
            token_type_ids: defaults to zeros (``torch.long``) of the input shape.
            position_ids: forwarded unchanged to the embeddings.
            head_mask: converted with ``get_head_mask`` for
                ``config.num_hidden_layers`` layers.
            inputs_embeds: pre-computed embeddings; the last dimension is
                dropped to obtain the input shape.
            encoder_hidden_states: hidden states for cross-attention; only
                used for mask preparation when ``config.is_decoder`` is set.
            encoder_attention_mask: mask for ``encoder_hidden_states``;
                defaults to all ones when those are used.
            past_key_values: cached key/value states; the cached length is
                read from ``past_key_values[0][0].shape[2]``.
            use_cache: forced to ``False`` unless ``config.is_decoder``.
            output_attentions: falls back to ``config.output_attentions``.
            output_hidden_states: falls back to ``config.output_hidden_states``.
            return_dict: falls back to ``config.use_return_dict``.

        Returns:
            A ``BaseModelOutputWithPoolingAndCrossAttentions`` when
            ``return_dict`` is truthy, otherwise the tuple
            ``(sequence_output, pooled_output) + encoder_outputs[1:]``.

        Raises:
            ValueError: if both or neither of ``input_ids`` and
                ``inputs_embeds`` are supplied.
        """
        cfg = self.config

        # Resolve unset flags from the configuration.
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        if not cfg.is_decoder:
            use_cache = False
        elif use_cache is None:
            use_cache = cfg.use_cache

        # Work out the (batch, sequence) shape and device from whichever input was given.
        have_ids = input_ids is not None
        have_embeds = inputs_embeds is not None
        if have_ids and have_embeds:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if not have_ids and not have_embeds:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if have_ids:
            input_shape = input_ids.size()
            device = input_ids.device
        else:
            input_shape = inputs_embeds.size()[:-1]
            device = inputs_embeds.device
        batch_size, seq_length = input_shape

        # Length already covered by cached key/value states, if any.
        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]

        if attention_mask is None:
            full_length = seq_length + past_key_values_length
            attention_mask = torch.ones((batch_size, full_length), device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # A caller-supplied [batch_size, from_seq_length, to_seq_length] mask only
        # needs to be made broadcastable over heads; the helper takes care of it.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
            attention_mask, input_shape, device
        )

        # Cross-attention mask: only prepared in decoder mode with encoder states.
        encoder_extended_attention_mask = None
        if cfg.is_decoder and encoder_hidden_states is not None:
            enc_batch, enc_length, _ = encoder_hidden_states.size()
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones((enc_batch, enc_length), device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)

        # Head mask: 1.0 keeps a head. Input is [num_heads] or
        # [num_hidden_layers x num_heads]; the helper expands it per layer.
        head_mask = self.get_head_mask(head_mask, cfg.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]
        pooled_output = None
        if self.pooler is not None:
            pooled_output = self.pooler(sequence_output)

        if return_dict:
            return BaseModelOutputWithPoolingAndCrossAttentions(
                last_hidden_state=sequence_output,
                pooler_output=pooled_output,
                past_key_values=encoder_outputs.past_key_values,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
                cross_attentions=encoder_outputs.cross_attentions,
            )

        return (sequence_output, pooled_output) + encoder_outputs[1:]
167
+
168
+
169
class TextEncoderShell(nn.Module):
    """Thin ``nn.Module`` wrapper that forwards keyword arguments to a text encoder."""

    def __init__(self, text_encoder):
        """Store ``text_encoder`` and expose its ``config`` on the wrapper."""
        super().__init__()
        self.text_encoder = text_encoder
        self.config = text_encoder.config

    def forward(self, **kw):
        """Call the wrapped encoder with ``kw`` and return its result unchanged."""
        encoder_output = self.text_encoder(**kw)
        return encoder_output
178
+
179
+
180
def generate_masks_with_special_tokens(tokenized, special_tokens_list, tokenizer):
    """Build a block-diagonal attention mask and per-segment position ids.

    Special tokens split every row of ``tokenized["input_ids"]`` into segments.
    A segment runs from the token after the previous special token up to and
    including the current special token. Tokens of one segment may attend to
    each other, and their position ids restart at 0. A special token in the
    first or last column only attends to itself. All other entries keep the
    identity mask and position id 0.

    Args:
        tokenized: mapping with an ``"input_ids"`` tensor of shape
            ``[bs, num_token]``.
        special_tokens_list: iterable of token ids treated as segment
            delimiters.
        tokenizer: unused; kept for interface compatibility.

    Returns:
        tuple: ``(attention_mask, position_ids)`` where ``attention_mask`` is a
        bool tensor of shape ``[bs, num_token, num_token]`` and
        ``position_ids`` is a ``torch.long`` tensor of shape ``[bs, num_token]``.
    """
    input_ids = tokenized["input_ids"]
    bs, num_token = input_ids.shape
    device = input_ids.device

    # special_tokens_mask: bs, num_token. True for special tokens.
    special_tokens_mask = torch.zeros((bs, num_token), device=device).bool()
    for special_token in special_tokens_list:
        special_tokens_mask |= input_ids == special_token

    # Start from "every token attends to itself" and position id 0 everywhere.
    attention_mask = torch.eye(num_token, device=device).bool().unsqueeze(0).repeat(bs, 1, 1)
    position_ids = torch.zeros((bs, num_token), dtype=torch.long, device=device)

    # torch.nonzero yields (row, col) pairs sorted by row, then by column.
    # Converting once to Python ints avoids a tensor -> int sync per element.
    previous_row = None
    previous_col = 0
    for row, col in torch.nonzero(special_tokens_mask).tolist():
        if row != previous_row:
            # New batch row: the segment start must not leak from the previous
            # row. Otherwise a row that does not open with a special token
            # would get a negative segment length.
            previous_row, previous_col = row, 0

        if col == 0 or col == num_token - 1:
            attention_mask[row, col, col] = True
            position_ids[row, col] = 0
        else:
            start, stop = previous_col + 1, col + 1
            attention_mask[row, start:stop, start:stop] = True
            position_ids[row, start:stop] = torch.arange(0, col - previous_col, device=device)

        previous_col = col

    # The padding mask from tokenized["attention_mask"] is deliberately not applied here.
    return attention_mask, position_ids
222
+
223
+
224
def generate_masks_with_special_tokens_and_transfer_map(tokenized, special_tokens_list, tokenizer):
    """Build segment attention masks, position ids and category-to-token masks.

    Special tokens split every row of ``tokenized["input_ids"]`` into segments.
    A segment runs from the token after the previous special token up to and
    including the current special token. Tokens of one segment may attend to
    each other, and their position ids restart at 0. A special token in the
    first or last column only attends to itself. For every segment closed by
    an interior special token, a bool mask marking its tokens (without the
    closing special token) is recorded.

    Args:
        tokenized: mapping with an ``"input_ids"`` tensor of shape
            ``[bs, num_token]``.
        special_tokens_list: iterable of token ids treated as segment
            delimiters.
        tokenizer: unused; kept for interface compatibility.

    Returns:
        tuple: ``(attention_mask, position_ids, cate_to_token_mask_list)``.
        ``attention_mask`` is a bool tensor ``[bs, num_token, num_token]``.
        ``position_ids`` is a ``torch.long`` tensor ``[bs, num_token]``.
        ``cate_to_token_mask_list`` is a list of ``bs`` bool tensors, each of
        shape ``[num_segments_in_row, num_token]``. A row without any recorded
        segment yields an empty ``[0, num_token]`` tensor.
    """
    input_ids = tokenized["input_ids"]
    bs, num_token = input_ids.shape
    device = input_ids.device

    # special_tokens_mask: bs, num_token. True for special tokens.
    special_tokens_mask = torch.zeros((bs, num_token), device=device).bool()
    for special_token in special_tokens_list:
        special_tokens_mask |= input_ids == special_token

    # Start from "every token attends to itself" and position id 0 everywhere.
    attention_mask = torch.eye(num_token, device=device).bool().unsqueeze(0).repeat(bs, 1, 1)
    position_ids = torch.zeros((bs, num_token), dtype=torch.long, device=device)
    per_row_masks = [[] for _ in range(bs)]

    # torch.nonzero yields (row, col) pairs sorted by row, then by column.
    previous_row = None
    previous_col = 0
    for row, col in torch.nonzero(special_tokens_mask).tolist():
        if row != previous_row:
            # New batch row: do not inherit the segment start of the previous row.
            previous_row, previous_col = row, 0

        if col == 0 or col == num_token - 1:
            attention_mask[row, col, col] = True
            position_ids[row, col] = 0
        else:
            start, stop = previous_col + 1, col + 1
            attention_mask[row, start:stop, start:stop] = True
            position_ids[row, start:stop] = torch.arange(0, col - previous_col, device=device)
            # Tokens of this segment, excluding the closing special token.
            c2t_maski = torch.zeros(num_token, device=device).bool()
            c2t_maski[start:col] = True
            per_row_masks[row].append(c2t_maski)

        previous_col = col

    # torch.stack rejects an empty list, so rows without segments get an
    # explicit empty [0, num_token] mask instead of raising.
    cate_to_token_mask_list = [
        torch.stack(row_masks, dim=0)
        if row_masks
        else torch.zeros((0, num_token), device=device).bool()
        for row_masks in per_row_masks
    ]

    # The padding mask from tokenized["attention_mask"] is deliberately not applied here.
    return attention_mask, position_ids, cate_to_token_mask_list
groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn.h ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #pragma once
12
+
13
+ #include "ms_deform_attn_cpu.h"
14
+
15
+ #ifdef WITH_CUDA
16
+ #include "ms_deform_attn_cuda.h"
17
+ #endif
18
+
19
+ namespace groundingdino {
20
+
21
+ at::Tensor
22
+ ms_deform_attn_forward(
23
+ const at::Tensor &value,
24
+ const at::Tensor &spatial_shapes,
25
+ const at::Tensor &level_start_index,
26
+ const at::Tensor &sampling_loc,
27
+ const at::Tensor &attn_weight,
28
+ const int im2col_step)
29
+ {
30
+ if (value.type().is_cuda())
31
+ {
32
+ #ifdef WITH_CUDA
33
+ return ms_deform_attn_cuda_forward(
34
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
35
+ #else
36
+ AT_ERROR("Not compiled with GPU support");
37
+ #endif
38
+ }
39
+ AT_ERROR("Not implemented on the CPU");
40
+ }
41
+
42
+ std::vector<at::Tensor>
43
+ ms_deform_attn_backward(
44
+ const at::Tensor &value,
45
+ const at::Tensor &spatial_shapes,
46
+ const at::Tensor &level_start_index,
47
+ const at::Tensor &sampling_loc,
48
+ const at::Tensor &attn_weight,
49
+ const at::Tensor &grad_output,
50
+ const int im2col_step)
51
+ {
52
+ if (value.type().is_cuda())
53
+ {
54
+ #ifdef WITH_CUDA
55
+ return ms_deform_attn_cuda_backward(
56
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
57
+ #else
58
+ AT_ERROR("Not compiled with GPU support");
59
+ #endif
60
+ }
61
+ AT_ERROR("Not implemented on the CPU");
62
+ }
63
+
64
+ } // namespace groundingdino
groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.cpp ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #include <vector>
12
+
13
+ #include <ATen/ATen.h>
14
+ #include <ATen/cuda/CUDAContext.h>
15
+
16
+ namespace groundingdino {
17
+
18
+ at::Tensor
19
+ ms_deform_attn_cpu_forward(
20
+ const at::Tensor &value,
21
+ const at::Tensor &spatial_shapes,
22
+ const at::Tensor &level_start_index,
23
+ const at::Tensor &sampling_loc,
24
+ const at::Tensor &attn_weight,
25
+ const int im2col_step)
26
+ {
27
+ AT_ERROR("Not implement on cpu");
28
+ }
29
+
30
+ std::vector<at::Tensor>
31
+ ms_deform_attn_cpu_backward(
32
+ const at::Tensor &value,
33
+ const at::Tensor &spatial_shapes,
34
+ const at::Tensor &level_start_index,
35
+ const at::Tensor &sampling_loc,
36
+ const at::Tensor &attn_weight,
37
+ const at::Tensor &grad_output,
38
+ const int im2col_step)
39
+ {
40
+ AT_ERROR("Not implement on cpu");
41
+ }
42
+
43
+ } // namespace groundingdino
groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.h ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #pragma once
12
+ #include <torch/extension.h>
13
+
14
+ namespace groundingdino {
15
+
16
+ at::Tensor
17
+ ms_deform_attn_cpu_forward(
18
+ const at::Tensor &value,
19
+ const at::Tensor &spatial_shapes,
20
+ const at::Tensor &level_start_index,
21
+ const at::Tensor &sampling_loc,
22
+ const at::Tensor &attn_weight,
23
+ const int im2col_step);
24
+
25
+ std::vector<at::Tensor>
26
+ ms_deform_attn_cpu_backward(
27
+ const at::Tensor &value,
28
+ const at::Tensor &spatial_shapes,
29
+ const at::Tensor &level_start_index,
30
+ const at::Tensor &sampling_loc,
31
+ const at::Tensor &attn_weight,
32
+ const at::Tensor &grad_output,
33
+ const int im2col_step);
34
+
35
+ } // namespace groundingdino
groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #include <vector>
12
+ #include "ms_deform_im2col_cuda.cuh"
13
+
14
+ #include <ATen/ATen.h>
15
+ #include <ATen/cuda/CUDAContext.h>
16
+ #include <cuda.h>
17
+ #include <cuda_runtime.h>
18
+
19
+ namespace groundingdino {
20
+
21
+ at::Tensor ms_deform_attn_cuda_forward(
22
+ const at::Tensor &value,
23
+ const at::Tensor &spatial_shapes,
24
+ const at::Tensor &level_start_index,
25
+ const at::Tensor &sampling_loc,
26
+ const at::Tensor &attn_weight,
27
+ const int im2col_step)
28
+ {
29
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
30
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
31
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
32
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
33
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
34
+
35
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
36
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
37
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
38
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
39
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
40
+
41
+ const int batch = value.size(0);
42
+ const int spatial_size = value.size(1);
43
+ const int num_heads = value.size(2);
44
+ const int channels = value.size(3);
45
+
46
+ const int num_levels = spatial_shapes.size(0);
47
+
48
+ const int num_query = sampling_loc.size(1);
49
+ const int num_point = sampling_loc.size(4);
50
+
51
+ const int im2col_step_ = std::min(batch, im2col_step);
52
+
53
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
54
+
55
+ auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
56
+
57
+ const int batch_n = im2col_step_;
58
+ auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
59
+ auto per_value_size = spatial_size * num_heads * channels;
60
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
61
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
62
+ for (int n = 0; n < batch/im2col_step_; ++n)
63
+ {
64
+ auto columns = output_n.select(0, n);
65
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
66
+ ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
67
+ value.data<scalar_t>() + n * im2col_step_ * per_value_size,
68
+ spatial_shapes.data<int64_t>(),
69
+ level_start_index.data<int64_t>(),
70
+ sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
71
+ attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
72
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
73
+ columns.data<scalar_t>());
74
+
75
+ }));
76
+ }
77
+
78
+ output = output.view({batch, num_query, num_heads*channels});
79
+
80
+ return output;
81
+ }
82
+
83
+
84
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
85
+ const at::Tensor &value,
86
+ const at::Tensor &spatial_shapes,
87
+ const at::Tensor &level_start_index,
88
+ const at::Tensor &sampling_loc,
89
+ const at::Tensor &attn_weight,
90
+ const at::Tensor &grad_output,
91
+ const int im2col_step)
92
+ {
93
+
94
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
95
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
96
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
97
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
98
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
99
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
100
+
101
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
102
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
103
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
104
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
105
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
106
+ AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
107
+
108
+ const int batch = value.size(0);
109
+ const int spatial_size = value.size(1);
110
+ const int num_heads = value.size(2);
111
+ const int channels = value.size(3);
112
+
113
+ const int num_levels = spatial_shapes.size(0);
114
+
115
+ const int num_query = sampling_loc.size(1);
116
+ const int num_point = sampling_loc.size(4);
117
+
118
+ const int im2col_step_ = std::min(batch, im2col_step);
119
+
120
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
121
+
122
+ auto grad_value = at::zeros_like(value);
123
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
124
+ auto grad_attn_weight = at::zeros_like(attn_weight);
125
+
126
+ const int batch_n = im2col_step_;
127
+ auto per_value_size = spatial_size * num_heads * channels;
128
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
129
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
130
+ auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
131
+
132
+ for (int n = 0; n < batch/im2col_step_; ++n)
133
+ {
134
+ auto grad_output_g = grad_output_n.select(0, n);
135
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
136
+ ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
137
+ grad_output_g.data<scalar_t>(),
138
+ value.data<scalar_t>() + n * im2col_step_ * per_value_size,
139
+ spatial_shapes.data<int64_t>(),
140
+ level_start_index.data<int64_t>(),
141
+ sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
142
+ attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
143
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
144
+ grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
145
+ grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
146
+ grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
147
+
148
+ }));
149
+ }
150
+
151
+ return {
152
+ grad_value, grad_sampling_loc, grad_attn_weight
153
+ };
154
+ }
155
+
156
+ } // namespace groundingdino
groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.h ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #pragma once
12
+ #include <torch/extension.h>
13
+
14
+ namespace groundingdino {
15
+
16
+ at::Tensor ms_deform_attn_cuda_forward(
17
+ const at::Tensor &value,
18
+ const at::Tensor &spatial_shapes,
19
+ const at::Tensor &level_start_index,
20
+ const at::Tensor &sampling_loc,
21
+ const at::Tensor &attn_weight,
22
+ const int im2col_step);
23
+
24
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
25
+ const at::Tensor &value,
26
+ const at::Tensor &spatial_shapes,
27
+ const at::Tensor &level_start_index,
28
+ const at::Tensor &sampling_loc,
29
+ const at::Tensor &attn_weight,
30
+ const at::Tensor &grad_output,
31
+ const int im2col_step);
32
+
33
+ } // namespace groundingdino
groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_im2col_cuda.cuh ADDED
@@ -0,0 +1,1327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************
7
+ * Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
8
+ * Copyright (c) 2018 Microsoft
9
+ **************************************************************************
10
+ */
11
+
12
+ #include <cstdio>
13
+ #include <algorithm>
14
+ #include <cstring>
15
+
16
+ #include <ATen/ATen.h>
17
+ #include <ATen/cuda/CUDAContext.h>
18
+
19
+ #include <THC/THCAtomics.cuh>
20
+
21
// Grid-stride loop: each thread starts at its global linear id and advances
// by the total number of launched threads until it reaches n.
#define CUDA_KERNEL_LOOP(i, n)                          \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x;   \
      i < (n);                                          \
      i += blockDim.x * gridDim.x)

// Default thread count per block used by the launch helpers.
const int CUDA_NUM_THREADS = 1024;

// Number of blocks of `num_threads` threads needed to cover N work items
// (ceiling division).
inline int GET_BLOCKS(const int N, const int num_threads)
{
    const int rounded_up = N + num_threads - 1;
    return rounded_up / num_threads;
}
31
+
32
+
33
+ template <typename scalar_t>
34
+ __device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
35
+ const int &height, const int &width, const int &nheads, const int &channels,
36
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c)
37
+ {
38
+ const int h_low = floor(h);
39
+ const int w_low = floor(w);
40
+ const int h_high = h_low + 1;
41
+ const int w_high = w_low + 1;
42
+
43
+ const scalar_t lh = h - h_low;
44
+ const scalar_t lw = w - w_low;
45
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
46
+
47
+ const int w_stride = nheads * channels;
48
+ const int h_stride = width * w_stride;
49
+ const int h_low_ptr_offset = h_low * h_stride;
50
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
51
+ const int w_low_ptr_offset = w_low * w_stride;
52
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
53
+ const int base_ptr = m * channels + c;
54
+
55
+ scalar_t v1 = 0;
56
+ if (h_low >= 0 && w_low >= 0)
57
+ {
58
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
59
+ v1 = bottom_data[ptr1];
60
+ }
61
+ scalar_t v2 = 0;
62
+ if (h_low >= 0 && w_high <= width - 1)
63
+ {
64
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
65
+ v2 = bottom_data[ptr2];
66
+ }
67
+ scalar_t v3 = 0;
68
+ if (h_high <= height - 1 && w_low >= 0)
69
+ {
70
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
71
+ v3 = bottom_data[ptr3];
72
+ }
73
+ scalar_t v4 = 0;
74
+ if (h_high <= height - 1 && w_high <= width - 1)
75
+ {
76
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
77
+ v4 = bottom_data[ptr4];
78
+ }
79
+
80
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
81
+
82
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
83
+ return val;
84
+ }
85
+
86
+
87
+ template <typename scalar_t>
88
+ __device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
89
+ const int &height, const int &width, const int &nheads, const int &channels,
90
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
91
+ const scalar_t &top_grad,
92
+ const scalar_t &attn_weight,
93
+ scalar_t* &grad_value,
94
+ scalar_t* grad_sampling_loc,
95
+ scalar_t* grad_attn_weight)
96
+ {
97
+ const int h_low = floor(h);
98
+ const int w_low = floor(w);
99
+ const int h_high = h_low + 1;
100
+ const int w_high = w_low + 1;
101
+
102
+ const scalar_t lh = h - h_low;
103
+ const scalar_t lw = w - w_low;
104
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
105
+
106
+ const int w_stride = nheads * channels;
107
+ const int h_stride = width * w_stride;
108
+ const int h_low_ptr_offset = h_low * h_stride;
109
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
110
+ const int w_low_ptr_offset = w_low * w_stride;
111
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
112
+ const int base_ptr = m * channels + c;
113
+
114
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
115
+ const scalar_t top_grad_value = top_grad * attn_weight;
116
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
117
+
118
+ scalar_t v1 = 0;
119
+ if (h_low >= 0 && w_low >= 0)
120
+ {
121
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
122
+ v1 = bottom_data[ptr1];
123
+ grad_h_weight -= hw * v1;
124
+ grad_w_weight -= hh * v1;
125
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
126
+ }
127
+ scalar_t v2 = 0;
128
+ if (h_low >= 0 && w_high <= width - 1)
129
+ {
130
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
131
+ v2 = bottom_data[ptr2];
132
+ grad_h_weight -= lw * v2;
133
+ grad_w_weight += hh * v2;
134
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
135
+ }
136
+ scalar_t v3 = 0;
137
+ if (h_high <= height - 1 && w_low >= 0)
138
+ {
139
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
140
+ v3 = bottom_data[ptr3];
141
+ grad_h_weight += hw * v3;
142
+ grad_w_weight -= lh * v3;
143
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
144
+ }
145
+ scalar_t v4 = 0;
146
+ if (h_high <= height - 1 && w_high <= width - 1)
147
+ {
148
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
149
+ v4 = bottom_data[ptr4];
150
+ grad_h_weight += lw * v4;
151
+ grad_w_weight += lh * v4;
152
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
153
+ }
154
+
155
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
156
+ *grad_attn_weight = top_grad * val;
157
+ *grad_sampling_loc = width * grad_w_weight * top_grad_value;
158
+ *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
159
+ }
160
+
161
+
162
+ template <typename scalar_t>
163
+ __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
164
+ const int &height, const int &width, const int &nheads, const int &channels,
165
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
166
+ const scalar_t &top_grad,
167
+ const scalar_t &attn_weight,
168
+ scalar_t* &grad_value,
169
+ scalar_t* grad_sampling_loc,
170
+ scalar_t* grad_attn_weight)
171
+ {
172
+ const int h_low = floor(h);
173
+ const int w_low = floor(w);
174
+ const int h_high = h_low + 1;
175
+ const int w_high = w_low + 1;
176
+
177
+ const scalar_t lh = h - h_low;
178
+ const scalar_t lw = w - w_low;
179
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
180
+
181
+ const int w_stride = nheads * channels;
182
+ const int h_stride = width * w_stride;
183
+ const int h_low_ptr_offset = h_low * h_stride;
184
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
185
+ const int w_low_ptr_offset = w_low * w_stride;
186
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
187
+ const int base_ptr = m * channels + c;
188
+
189
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
190
+ const scalar_t top_grad_value = top_grad * attn_weight;
191
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
192
+
193
+ scalar_t v1 = 0;
194
+ if (h_low >= 0 && w_low >= 0)
195
+ {
196
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
197
+ v1 = bottom_data[ptr1];
198
+ grad_h_weight -= hw * v1;
199
+ grad_w_weight -= hh * v1;
200
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
201
+ }
202
+ scalar_t v2 = 0;
203
+ if (h_low >= 0 && w_high <= width - 1)
204
+ {
205
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
206
+ v2 = bottom_data[ptr2];
207
+ grad_h_weight -= lw * v2;
208
+ grad_w_weight += hh * v2;
209
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
210
+ }
211
+ scalar_t v3 = 0;
212
+ if (h_high <= height - 1 && w_low >= 0)
213
+ {
214
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
215
+ v3 = bottom_data[ptr3];
216
+ grad_h_weight += hw * v3;
217
+ grad_w_weight -= lh * v3;
218
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
219
+ }
220
+ scalar_t v4 = 0;
221
+ if (h_high <= height - 1 && w_high <= width - 1)
222
+ {
223
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
224
+ v4 = bottom_data[ptr4];
225
+ grad_h_weight += lw * v4;
226
+ grad_w_weight += lh * v4;
227
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
228
+ }
229
+
230
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
231
+ atomicAdd(grad_attn_weight, top_grad * val);
232
+ atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
233
+ atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
234
+ }
235
+
236
+
237
+ template <typename scalar_t>
238
+ __global__ void ms_deformable_im2col_gpu_kernel(const int n,
239
+ const scalar_t *data_value,
240
+ const int64_t *data_spatial_shapes,
241
+ const int64_t *data_level_start_index,
242
+ const scalar_t *data_sampling_loc,
243
+ const scalar_t *data_attn_weight,
244
+ const int batch_size,
245
+ const int spatial_size,
246
+ const int num_heads,
247
+ const int channels,
248
+ const int num_levels,
249
+ const int num_query,
250
+ const int num_point,
251
+ scalar_t *data_col)
252
+ {
253
+ CUDA_KERNEL_LOOP(index, n)
254
+ {
255
+ int _temp = index;
256
+ const int c_col = _temp % channels;
257
+ _temp /= channels;
258
+ const int sampling_index = _temp;
259
+ const int m_col = _temp % num_heads;
260
+ _temp /= num_heads;
261
+ const int q_col = _temp % num_query;
262
+ _temp /= num_query;
263
+ const int b_col = _temp;
264
+
265
+ scalar_t *data_col_ptr = data_col + index;
266
+ int data_weight_ptr = sampling_index * num_levels * num_point;
267
+ int data_loc_w_ptr = data_weight_ptr << 1;
268
+ const int qid_stride = num_heads * channels;
269
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
270
+ scalar_t col = 0;
271
+
272
+ for (int l_col=0; l_col < num_levels; ++l_col)
273
+ {
274
+ const int level_start_id = data_level_start_index[l_col];
275
+ const int spatial_h_ptr = l_col << 1;
276
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
277
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
278
+ const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
279
+ for (int p_col=0; p_col < num_point; ++p_col)
280
+ {
281
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
282
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
283
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
284
+
285
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
286
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
287
+
288
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
289
+ {
290
+ col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
291
+ }
292
+
293
+ data_weight_ptr += 1;
294
+ data_loc_w_ptr += 2;
295
+ }
296
+ }
297
+ *data_col_ptr = col;
298
+ }
299
+ }
300
+
301
+ template <typename scalar_t, unsigned int blockSize>
302
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
303
+ const scalar_t *grad_col,
304
+ const scalar_t *data_value,
305
+ const int64_t *data_spatial_shapes,
306
+ const int64_t *data_level_start_index,
307
+ const scalar_t *data_sampling_loc,
308
+ const scalar_t *data_attn_weight,
309
+ const int batch_size,
310
+ const int spatial_size,
311
+ const int num_heads,
312
+ const int channels,
313
+ const int num_levels,
314
+ const int num_query,
315
+ const int num_point,
316
+ scalar_t *grad_value,
317
+ scalar_t *grad_sampling_loc,
318
+ scalar_t *grad_attn_weight)
319
+ {
320
+ CUDA_KERNEL_LOOP(index, n)
321
+ {
322
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
323
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
324
+ unsigned int tid = threadIdx.x;
325
+ int _temp = index;
326
+ const int c_col = _temp % channels;
327
+ _temp /= channels;
328
+ const int sampling_index = _temp;
329
+ const int m_col = _temp % num_heads;
330
+ _temp /= num_heads;
331
+ const int q_col = _temp % num_query;
332
+ _temp /= num_query;
333
+ const int b_col = _temp;
334
+
335
+ const scalar_t top_grad = grad_col[index];
336
+
337
+ int data_weight_ptr = sampling_index * num_levels * num_point;
338
+ int data_loc_w_ptr = data_weight_ptr << 1;
339
+ const int grad_sampling_ptr = data_weight_ptr;
340
+ grad_sampling_loc += grad_sampling_ptr << 1;
341
+ grad_attn_weight += grad_sampling_ptr;
342
+ const int grad_weight_stride = 1;
343
+ const int grad_loc_stride = 2;
344
+ const int qid_stride = num_heads * channels;
345
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
346
+
347
+ for (int l_col=0; l_col < num_levels; ++l_col)
348
+ {
349
+ const int level_start_id = data_level_start_index[l_col];
350
+ const int spatial_h_ptr = l_col << 1;
351
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
352
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
353
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
354
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
355
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
356
+
357
+ for (int p_col=0; p_col < num_point; ++p_col)
358
+ {
359
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
360
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
361
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
362
+
363
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
364
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
365
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
366
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
367
+ *(cache_grad_attn_weight+threadIdx.x)=0;
368
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
369
+ {
370
+ ms_deform_attn_col2im_bilinear(
371
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
372
+ top_grad, weight, grad_value_ptr,
373
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
374
+ }
375
+
376
+ __syncthreads();
377
+ if (tid == 0)
378
+ {
379
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
380
+ int sid=2;
381
+ for (unsigned int tid = 1; tid < blockSize; ++tid)
382
+ {
383
+ _grad_w += cache_grad_sampling_loc[sid];
384
+ _grad_h += cache_grad_sampling_loc[sid + 1];
385
+ _grad_a += cache_grad_attn_weight[tid];
386
+ sid += 2;
387
+ }
388
+
389
+
390
+ *grad_sampling_loc = _grad_w;
391
+ *(grad_sampling_loc + 1) = _grad_h;
392
+ *grad_attn_weight = _grad_a;
393
+ }
394
+ __syncthreads();
395
+
396
+ data_weight_ptr += 1;
397
+ data_loc_w_ptr += 2;
398
+ grad_attn_weight += grad_weight_stride;
399
+ grad_sampling_loc += grad_loc_stride;
400
+ }
401
+ }
402
+ }
403
+ }
404
+
405
+
406
+ template <typename scalar_t, unsigned int blockSize>
407
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
408
+ const scalar_t *grad_col,
409
+ const scalar_t *data_value,
410
+ const int64_t *data_spatial_shapes,
411
+ const int64_t *data_level_start_index,
412
+ const scalar_t *data_sampling_loc,
413
+ const scalar_t *data_attn_weight,
414
+ const int batch_size,
415
+ const int spatial_size,
416
+ const int num_heads,
417
+ const int channels,
418
+ const int num_levels,
419
+ const int num_query,
420
+ const int num_point,
421
+ scalar_t *grad_value,
422
+ scalar_t *grad_sampling_loc,
423
+ scalar_t *grad_attn_weight)
424
+ {
425
+ CUDA_KERNEL_LOOP(index, n)
426
+ {
427
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
428
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
429
+ unsigned int tid = threadIdx.x;
430
+ int _temp = index;
431
+ const int c_col = _temp % channels;
432
+ _temp /= channels;
433
+ const int sampling_index = _temp;
434
+ const int m_col = _temp % num_heads;
435
+ _temp /= num_heads;
436
+ const int q_col = _temp % num_query;
437
+ _temp /= num_query;
438
+ const int b_col = _temp;
439
+
440
+ const scalar_t top_grad = grad_col[index];
441
+
442
+ int data_weight_ptr = sampling_index * num_levels * num_point;
443
+ int data_loc_w_ptr = data_weight_ptr << 1;
444
+ const int grad_sampling_ptr = data_weight_ptr;
445
+ grad_sampling_loc += grad_sampling_ptr << 1;
446
+ grad_attn_weight += grad_sampling_ptr;
447
+ const int grad_weight_stride = 1;
448
+ const int grad_loc_stride = 2;
449
+ const int qid_stride = num_heads * channels;
450
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
451
+
452
+ for (int l_col=0; l_col < num_levels; ++l_col)
453
+ {
454
+ const int level_start_id = data_level_start_index[l_col];
455
+ const int spatial_h_ptr = l_col << 1;
456
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
457
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
458
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
459
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
460
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
461
+
462
+ for (int p_col=0; p_col < num_point; ++p_col)
463
+ {
464
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
465
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
466
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
467
+
468
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
469
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
470
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
471
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
472
+ *(cache_grad_attn_weight+threadIdx.x)=0;
473
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
474
+ {
475
+ ms_deform_attn_col2im_bilinear(
476
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
477
+ top_grad, weight, grad_value_ptr,
478
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
479
+ }
480
+
481
+ __syncthreads();
482
+
483
+ for (unsigned int s=blockSize/2; s>0; s>>=1)
484
+ {
485
+ if (tid < s) {
486
+ const unsigned int xid1 = tid << 1;
487
+ const unsigned int xid2 = (tid + s) << 1;
488
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
489
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
490
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
491
+ }
492
+ __syncthreads();
493
+ }
494
+
495
+ if (tid == 0)
496
+ {
497
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
498
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
499
+ *grad_attn_weight = cache_grad_attn_weight[0];
500
+ }
501
+ __syncthreads();
502
+
503
+ data_weight_ptr += 1;
504
+ data_loc_w_ptr += 2;
505
+ grad_attn_weight += grad_weight_stride;
506
+ grad_sampling_loc += grad_loc_stride;
507
+ }
508
+ }
509
+ }
510
+ }
511
+
512
+
513
+ template <typename scalar_t>
514
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
515
+ const scalar_t *grad_col,
516
+ const scalar_t *data_value,
517
+ const int64_t *data_spatial_shapes,
518
+ const int64_t *data_level_start_index,
519
+ const scalar_t *data_sampling_loc,
520
+ const scalar_t *data_attn_weight,
521
+ const int batch_size,
522
+ const int spatial_size,
523
+ const int num_heads,
524
+ const int channels,
525
+ const int num_levels,
526
+ const int num_query,
527
+ const int num_point,
528
+ scalar_t *grad_value,
529
+ scalar_t *grad_sampling_loc,
530
+ scalar_t *grad_attn_weight)
531
+ {
532
+ CUDA_KERNEL_LOOP(index, n)
533
+ {
534
+ extern __shared__ int _s[];
535
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
536
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
537
+ unsigned int tid = threadIdx.x;
538
+ int _temp = index;
539
+ const int c_col = _temp % channels;
540
+ _temp /= channels;
541
+ const int sampling_index = _temp;
542
+ const int m_col = _temp % num_heads;
543
+ _temp /= num_heads;
544
+ const int q_col = _temp % num_query;
545
+ _temp /= num_query;
546
+ const int b_col = _temp;
547
+
548
+ const scalar_t top_grad = grad_col[index];
549
+
550
+ int data_weight_ptr = sampling_index * num_levels * num_point;
551
+ int data_loc_w_ptr = data_weight_ptr << 1;
552
+ const int grad_sampling_ptr = data_weight_ptr;
553
+ grad_sampling_loc += grad_sampling_ptr << 1;
554
+ grad_attn_weight += grad_sampling_ptr;
555
+ const int grad_weight_stride = 1;
556
+ const int grad_loc_stride = 2;
557
+ const int qid_stride = num_heads * channels;
558
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
559
+
560
+ for (int l_col=0; l_col < num_levels; ++l_col)
561
+ {
562
+ const int level_start_id = data_level_start_index[l_col];
563
+ const int spatial_h_ptr = l_col << 1;
564
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
565
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
566
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
567
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
568
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
569
+
570
+ for (int p_col=0; p_col < num_point; ++p_col)
571
+ {
572
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
573
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
574
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
575
+
576
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
577
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
578
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
579
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
580
+ *(cache_grad_attn_weight+threadIdx.x)=0;
581
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
582
+ {
583
+ ms_deform_attn_col2im_bilinear(
584
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
585
+ top_grad, weight, grad_value_ptr,
586
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
587
+ }
588
+
589
+ __syncthreads();
590
+ if (tid == 0)
591
+ {
592
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
593
+ int sid=2;
594
+ for (unsigned int tid = 1; tid < blockDim.x; ++tid)
595
+ {
596
+ _grad_w += cache_grad_sampling_loc[sid];
597
+ _grad_h += cache_grad_sampling_loc[sid + 1];
598
+ _grad_a += cache_grad_attn_weight[tid];
599
+ sid += 2;
600
+ }
601
+
602
+
603
+ *grad_sampling_loc = _grad_w;
604
+ *(grad_sampling_loc + 1) = _grad_h;
605
+ *grad_attn_weight = _grad_a;
606
+ }
607
+ __syncthreads();
608
+
609
+ data_weight_ptr += 1;
610
+ data_loc_w_ptr += 2;
611
+ grad_attn_weight += grad_weight_stride;
612
+ grad_sampling_loc += grad_loc_stride;
613
+ }
614
+ }
615
+ }
616
+ }
617
+
618
+ template <typename scalar_t>
619
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
620
+ const scalar_t *grad_col,
621
+ const scalar_t *data_value,
622
+ const int64_t *data_spatial_shapes,
623
+ const int64_t *data_level_start_index,
624
+ const scalar_t *data_sampling_loc,
625
+ const scalar_t *data_attn_weight,
626
+ const int batch_size,
627
+ const int spatial_size,
628
+ const int num_heads,
629
+ const int channels,
630
+ const int num_levels,
631
+ const int num_query,
632
+ const int num_point,
633
+ scalar_t *grad_value,
634
+ scalar_t *grad_sampling_loc,
635
+ scalar_t *grad_attn_weight)
636
+ {
637
+ CUDA_KERNEL_LOOP(index, n)
638
+ {
639
+ extern __shared__ int _s[];
640
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
641
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
642
+ unsigned int tid = threadIdx.x;
643
+ int _temp = index;
644
+ const int c_col = _temp % channels;
645
+ _temp /= channels;
646
+ const int sampling_index = _temp;
647
+ const int m_col = _temp % num_heads;
648
+ _temp /= num_heads;
649
+ const int q_col = _temp % num_query;
650
+ _temp /= num_query;
651
+ const int b_col = _temp;
652
+
653
+ const scalar_t top_grad = grad_col[index];
654
+
655
+ int data_weight_ptr = sampling_index * num_levels * num_point;
656
+ int data_loc_w_ptr = data_weight_ptr << 1;
657
+ const int grad_sampling_ptr = data_weight_ptr;
658
+ grad_sampling_loc += grad_sampling_ptr << 1;
659
+ grad_attn_weight += grad_sampling_ptr;
660
+ const int grad_weight_stride = 1;
661
+ const int grad_loc_stride = 2;
662
+ const int qid_stride = num_heads * channels;
663
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
664
+
665
+ for (int l_col=0; l_col < num_levels; ++l_col)
666
+ {
667
+ const int level_start_id = data_level_start_index[l_col];
668
+ const int spatial_h_ptr = l_col << 1;
669
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
670
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
671
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
672
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
673
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
674
+
675
+ for (int p_col=0; p_col < num_point; ++p_col)
676
+ {
677
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
678
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
679
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
680
+
681
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
682
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
683
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
684
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
685
+ *(cache_grad_attn_weight+threadIdx.x)=0;
686
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
687
+ {
688
+ ms_deform_attn_col2im_bilinear(
689
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
690
+ top_grad, weight, grad_value_ptr,
691
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
692
+ }
693
+
694
+ __syncthreads();
695
+
696
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
697
+ {
698
+ if (tid < s) {
699
+ const unsigned int xid1 = tid << 1;
700
+ const unsigned int xid2 = (tid + s) << 1;
701
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
702
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
703
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
704
+ if (tid + (s << 1) < spre)
705
+ {
706
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
707
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
708
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
709
+ }
710
+ }
711
+ __syncthreads();
712
+ }
713
+
714
+ if (tid == 0)
715
+ {
716
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
717
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
718
+ *grad_attn_weight = cache_grad_attn_weight[0];
719
+ }
720
+ __syncthreads();
721
+
722
+ data_weight_ptr += 1;
723
+ data_loc_w_ptr += 2;
724
+ grad_attn_weight += grad_weight_stride;
725
+ grad_sampling_loc += grad_loc_stride;
726
+ }
727
+ }
728
+ }
729
+ }
730
+
731
+ template <typename scalar_t>
732
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
733
+ const scalar_t *grad_col,
734
+ const scalar_t *data_value,
735
+ const int64_t *data_spatial_shapes,
736
+ const int64_t *data_level_start_index,
737
+ const scalar_t *data_sampling_loc,
738
+ const scalar_t *data_attn_weight,
739
+ const int batch_size,
740
+ const int spatial_size,
741
+ const int num_heads,
742
+ const int channels,
743
+ const int num_levels,
744
+ const int num_query,
745
+ const int num_point,
746
+ scalar_t *grad_value,
747
+ scalar_t *grad_sampling_loc,
748
+ scalar_t *grad_attn_weight)
749
+ {
750
+ CUDA_KERNEL_LOOP(index, n)
751
+ {
752
+ extern __shared__ int _s[];
753
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
754
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
755
+ unsigned int tid = threadIdx.x;
756
+ int _temp = index;
757
+ const int c_col = _temp % channels;
758
+ _temp /= channels;
759
+ const int sampling_index = _temp;
760
+ const int m_col = _temp % num_heads;
761
+ _temp /= num_heads;
762
+ const int q_col = _temp % num_query;
763
+ _temp /= num_query;
764
+ const int b_col = _temp;
765
+
766
+ const scalar_t top_grad = grad_col[index];
767
+
768
+ int data_weight_ptr = sampling_index * num_levels * num_point;
769
+ int data_loc_w_ptr = data_weight_ptr << 1;
770
+ const int grad_sampling_ptr = data_weight_ptr;
771
+ grad_sampling_loc += grad_sampling_ptr << 1;
772
+ grad_attn_weight += grad_sampling_ptr;
773
+ const int grad_weight_stride = 1;
774
+ const int grad_loc_stride = 2;
775
+ const int qid_stride = num_heads * channels;
776
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
777
+
778
+ for (int l_col=0; l_col < num_levels; ++l_col)
779
+ {
780
+ const int level_start_id = data_level_start_index[l_col];
781
+ const int spatial_h_ptr = l_col << 1;
782
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
783
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
784
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
785
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
786
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
787
+
788
+ for (int p_col=0; p_col < num_point; ++p_col)
789
+ {
790
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
791
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
792
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
793
+
794
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
795
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
796
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
797
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
798
+ *(cache_grad_attn_weight+threadIdx.x)=0;
799
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
800
+ {
801
+ ms_deform_attn_col2im_bilinear(
802
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
803
+ top_grad, weight, grad_value_ptr,
804
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
805
+ }
806
+
807
+ __syncthreads();
808
+
809
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
810
+ {
811
+ if (tid < s) {
812
+ const unsigned int xid1 = tid << 1;
813
+ const unsigned int xid2 = (tid + s) << 1;
814
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
815
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
816
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
817
+ if (tid + (s << 1) < spre)
818
+ {
819
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
820
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
821
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
822
+ }
823
+ }
824
+ __syncthreads();
825
+ }
826
+
827
+ if (tid == 0)
828
+ {
829
+ atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
830
+ atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
831
+ atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
832
+ }
833
+ __syncthreads();
834
+
835
+ data_weight_ptr += 1;
836
+ data_loc_w_ptr += 2;
837
+ grad_attn_weight += grad_weight_stride;
838
+ grad_sampling_loc += grad_loc_stride;
839
+ }
840
+ }
841
+ }
842
+ }
843
+
844
+
845
+ template <typename scalar_t>
846
+ __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
847
+ const scalar_t *grad_col,
848
+ const scalar_t *data_value,
849
+ const int64_t *data_spatial_shapes,
850
+ const int64_t *data_level_start_index,
851
+ const scalar_t *data_sampling_loc,
852
+ const scalar_t *data_attn_weight,
853
+ const int batch_size,
854
+ const int spatial_size,
855
+ const int num_heads,
856
+ const int channels,
857
+ const int num_levels,
858
+ const int num_query,
859
+ const int num_point,
860
+ scalar_t *grad_value,
861
+ scalar_t *grad_sampling_loc,
862
+ scalar_t *grad_attn_weight)
863
+ {
864
+ CUDA_KERNEL_LOOP(index, n)
865
+ {
866
+ int _temp = index;
867
+ const int c_col = _temp % channels;
868
+ _temp /= channels;
869
+ const int sampling_index = _temp;
870
+ const int m_col = _temp % num_heads;
871
+ _temp /= num_heads;
872
+ const int q_col = _temp % num_query;
873
+ _temp /= num_query;
874
+ const int b_col = _temp;
875
+
876
+ const scalar_t top_grad = grad_col[index];
877
+
878
+ int data_weight_ptr = sampling_index * num_levels * num_point;
879
+ int data_loc_w_ptr = data_weight_ptr << 1;
880
+ const int grad_sampling_ptr = data_weight_ptr;
881
+ grad_sampling_loc += grad_sampling_ptr << 1;
882
+ grad_attn_weight += grad_sampling_ptr;
883
+ const int grad_weight_stride = 1;
884
+ const int grad_loc_stride = 2;
885
+ const int qid_stride = num_heads * channels;
886
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
887
+
888
+ for (int l_col=0; l_col < num_levels; ++l_col)
889
+ {
890
+ const int level_start_id = data_level_start_index[l_col];
891
+ const int spatial_h_ptr = l_col << 1;
892
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
893
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
894
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
895
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
896
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
897
+
898
+ for (int p_col=0; p_col < num_point; ++p_col)
899
+ {
900
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
901
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
902
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
903
+
904
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
905
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
906
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
907
+ {
908
+ ms_deform_attn_col2im_bilinear_gm(
909
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
910
+ top_grad, weight, grad_value_ptr,
911
+ grad_sampling_loc, grad_attn_weight);
912
+ }
913
+ data_weight_ptr += 1;
914
+ data_loc_w_ptr += 2;
915
+ grad_attn_weight += grad_weight_stride;
916
+ grad_sampling_loc += grad_loc_stride;
917
+ }
918
+ }
919
+ }
920
+ }
921
+
922
+
923
+ template <typename scalar_t>
924
+ void ms_deformable_im2col_cuda(cudaStream_t stream,
925
+ const scalar_t* data_value,
926
+ const int64_t* data_spatial_shapes,
927
+ const int64_t* data_level_start_index,
928
+ const scalar_t* data_sampling_loc,
929
+ const scalar_t* data_attn_weight,
930
+ const int batch_size,
931
+ const int spatial_size,
932
+ const int num_heads,
933
+ const int channels,
934
+ const int num_levels,
935
+ const int num_query,
936
+ const int num_point,
937
+ scalar_t* data_col)
938
+ {
939
+ const int num_kernels = batch_size * num_query * num_heads * channels;
940
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
941
+ const int num_threads = CUDA_NUM_THREADS;
942
+ ms_deformable_im2col_gpu_kernel<scalar_t>
943
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
944
+ 0, stream>>>(
945
+ num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
946
+ batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
947
+
948
+ cudaError_t err = cudaGetLastError();
949
+ if (err != cudaSuccess)
950
+ {
951
+ printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
952
+ }
953
+
954
+ }
955
+
956
+ template <typename scalar_t>
957
+ void ms_deformable_col2im_cuda(cudaStream_t stream,
958
+ const scalar_t* grad_col,
959
+ const scalar_t* data_value,
960
+ const int64_t * data_spatial_shapes,
961
+ const int64_t * data_level_start_index,
962
+ const scalar_t * data_sampling_loc,
963
+ const scalar_t * data_attn_weight,
964
+ const int batch_size,
965
+ const int spatial_size,
966
+ const int num_heads,
967
+ const int channels,
968
+ const int num_levels,
969
+ const int num_query,
970
+ const int num_point,
971
+ scalar_t* grad_value,
972
+ scalar_t* grad_sampling_loc,
973
+ scalar_t* grad_attn_weight)
974
+ {
975
+ const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
976
+ const int num_kernels = batch_size * num_query * num_heads * channels;
977
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
978
+ if (channels > 1024)
979
+ {
980
+ if ((channels & 1023) == 0)
981
+ {
982
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
983
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
984
+ num_threads*3*sizeof(scalar_t), stream>>>(
985
+ num_kernels,
986
+ grad_col,
987
+ data_value,
988
+ data_spatial_shapes,
989
+ data_level_start_index,
990
+ data_sampling_loc,
991
+ data_attn_weight,
992
+ batch_size,
993
+ spatial_size,
994
+ num_heads,
995
+ channels,
996
+ num_levels,
997
+ num_query,
998
+ num_point,
999
+ grad_value,
1000
+ grad_sampling_loc,
1001
+ grad_attn_weight);
1002
+ }
1003
+ else
1004
+ {
1005
+ ms_deformable_col2im_gpu_kernel_gm<scalar_t>
1006
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1007
+ 0, stream>>>(
1008
+ num_kernels,
1009
+ grad_col,
1010
+ data_value,
1011
+ data_spatial_shapes,
1012
+ data_level_start_index,
1013
+ data_sampling_loc,
1014
+ data_attn_weight,
1015
+ batch_size,
1016
+ spatial_size,
1017
+ num_heads,
1018
+ channels,
1019
+ num_levels,
1020
+ num_query,
1021
+ num_point,
1022
+ grad_value,
1023
+ grad_sampling_loc,
1024
+ grad_attn_weight);
1025
+ }
1026
+ }
1027
+ else{
1028
+ switch(channels)
1029
+ {
1030
+ case 1:
1031
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
1032
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1033
+ 0, stream>>>(
1034
+ num_kernels,
1035
+ grad_col,
1036
+ data_value,
1037
+ data_spatial_shapes,
1038
+ data_level_start_index,
1039
+ data_sampling_loc,
1040
+ data_attn_weight,
1041
+ batch_size,
1042
+ spatial_size,
1043
+ num_heads,
1044
+ channels,
1045
+ num_levels,
1046
+ num_query,
1047
+ num_point,
1048
+ grad_value,
1049
+ grad_sampling_loc,
1050
+ grad_attn_weight);
1051
+ break;
1052
+ case 2:
1053
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
1054
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1055
+ 0, stream>>>(
1056
+ num_kernels,
1057
+ grad_col,
1058
+ data_value,
1059
+ data_spatial_shapes,
1060
+ data_level_start_index,
1061
+ data_sampling_loc,
1062
+ data_attn_weight,
1063
+ batch_size,
1064
+ spatial_size,
1065
+ num_heads,
1066
+ channels,
1067
+ num_levels,
1068
+ num_query,
1069
+ num_point,
1070
+ grad_value,
1071
+ grad_sampling_loc,
1072
+ grad_attn_weight);
1073
+ break;
1074
+ case 4:
1075
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
1076
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1077
+ 0, stream>>>(
1078
+ num_kernels,
1079
+ grad_col,
1080
+ data_value,
1081
+ data_spatial_shapes,
1082
+ data_level_start_index,
1083
+ data_sampling_loc,
1084
+ data_attn_weight,
1085
+ batch_size,
1086
+ spatial_size,
1087
+ num_heads,
1088
+ channels,
1089
+ num_levels,
1090
+ num_query,
1091
+ num_point,
1092
+ grad_value,
1093
+ grad_sampling_loc,
1094
+ grad_attn_weight);
1095
+ break;
1096
+ case 8:
1097
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
1098
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1099
+ 0, stream>>>(
1100
+ num_kernels,
1101
+ grad_col,
1102
+ data_value,
1103
+ data_spatial_shapes,
1104
+ data_level_start_index,
1105
+ data_sampling_loc,
1106
+ data_attn_weight,
1107
+ batch_size,
1108
+ spatial_size,
1109
+ num_heads,
1110
+ channels,
1111
+ num_levels,
1112
+ num_query,
1113
+ num_point,
1114
+ grad_value,
1115
+ grad_sampling_loc,
1116
+ grad_attn_weight);
1117
+ break;
1118
+ case 16:
1119
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
1120
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1121
+ 0, stream>>>(
1122
+ num_kernels,
1123
+ grad_col,
1124
+ data_value,
1125
+ data_spatial_shapes,
1126
+ data_level_start_index,
1127
+ data_sampling_loc,
1128
+ data_attn_weight,
1129
+ batch_size,
1130
+ spatial_size,
1131
+ num_heads,
1132
+ channels,
1133
+ num_levels,
1134
+ num_query,
1135
+ num_point,
1136
+ grad_value,
1137
+ grad_sampling_loc,
1138
+ grad_attn_weight);
1139
+ break;
1140
+ case 32:
1141
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
1142
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1143
+ 0, stream>>>(
1144
+ num_kernels,
1145
+ grad_col,
1146
+ data_value,
1147
+ data_spatial_shapes,
1148
+ data_level_start_index,
1149
+ data_sampling_loc,
1150
+ data_attn_weight,
1151
+ batch_size,
1152
+ spatial_size,
1153
+ num_heads,
1154
+ channels,
1155
+ num_levels,
1156
+ num_query,
1157
+ num_point,
1158
+ grad_value,
1159
+ grad_sampling_loc,
1160
+ grad_attn_weight);
1161
+ break;
1162
+ case 64:
1163
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
1164
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1165
+ 0, stream>>>(
1166
+ num_kernels,
1167
+ grad_col,
1168
+ data_value,
1169
+ data_spatial_shapes,
1170
+ data_level_start_index,
1171
+ data_sampling_loc,
1172
+ data_attn_weight,
1173
+ batch_size,
1174
+ spatial_size,
1175
+ num_heads,
1176
+ channels,
1177
+ num_levels,
1178
+ num_query,
1179
+ num_point,
1180
+ grad_value,
1181
+ grad_sampling_loc,
1182
+ grad_attn_weight);
1183
+ break;
1184
+ case 128:
1185
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
1186
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1187
+ 0, stream>>>(
1188
+ num_kernels,
1189
+ grad_col,
1190
+ data_value,
1191
+ data_spatial_shapes,
1192
+ data_level_start_index,
1193
+ data_sampling_loc,
1194
+ data_attn_weight,
1195
+ batch_size,
1196
+ spatial_size,
1197
+ num_heads,
1198
+ channels,
1199
+ num_levels,
1200
+ num_query,
1201
+ num_point,
1202
+ grad_value,
1203
+ grad_sampling_loc,
1204
+ grad_attn_weight);
1205
+ break;
1206
+ case 256:
1207
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
1208
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1209
+ 0, stream>>>(
1210
+ num_kernels,
1211
+ grad_col,
1212
+ data_value,
1213
+ data_spatial_shapes,
1214
+ data_level_start_index,
1215
+ data_sampling_loc,
1216
+ data_attn_weight,
1217
+ batch_size,
1218
+ spatial_size,
1219
+ num_heads,
1220
+ channels,
1221
+ num_levels,
1222
+ num_query,
1223
+ num_point,
1224
+ grad_value,
1225
+ grad_sampling_loc,
1226
+ grad_attn_weight);
1227
+ break;
1228
+ case 512:
1229
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
1230
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1231
+ 0, stream>>>(
1232
+ num_kernels,
1233
+ grad_col,
1234
+ data_value,
1235
+ data_spatial_shapes,
1236
+ data_level_start_index,
1237
+ data_sampling_loc,
1238
+ data_attn_weight,
1239
+ batch_size,
1240
+ spatial_size,
1241
+ num_heads,
1242
+ channels,
1243
+ num_levels,
1244
+ num_query,
1245
+ num_point,
1246
+ grad_value,
1247
+ grad_sampling_loc,
1248
+ grad_attn_weight);
1249
+ break;
1250
+ case 1024:
1251
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
1252
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1253
+ 0, stream>>>(
1254
+ num_kernels,
1255
+ grad_col,
1256
+ data_value,
1257
+ data_spatial_shapes,
1258
+ data_level_start_index,
1259
+ data_sampling_loc,
1260
+ data_attn_weight,
1261
+ batch_size,
1262
+ spatial_size,
1263
+ num_heads,
1264
+ channels,
1265
+ num_levels,
1266
+ num_query,
1267
+ num_point,
1268
+ grad_value,
1269
+ grad_sampling_loc,
1270
+ grad_attn_weight);
1271
+ break;
1272
+ default:
1273
+ if (channels < 64)
1274
+ {
1275
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
1276
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1277
+ num_threads*3*sizeof(scalar_t), stream>>>(
1278
+ num_kernels,
1279
+ grad_col,
1280
+ data_value,
1281
+ data_spatial_shapes,
1282
+ data_level_start_index,
1283
+ data_sampling_loc,
1284
+ data_attn_weight,
1285
+ batch_size,
1286
+ spatial_size,
1287
+ num_heads,
1288
+ channels,
1289
+ num_levels,
1290
+ num_query,
1291
+ num_point,
1292
+ grad_value,
1293
+ grad_sampling_loc,
1294
+ grad_attn_weight);
1295
+ }
1296
+ else
1297
+ {
1298
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
1299
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1300
+ num_threads*3*sizeof(scalar_t), stream>>>(
1301
+ num_kernels,
1302
+ grad_col,
1303
+ data_value,
1304
+ data_spatial_shapes,
1305
+ data_level_start_index,
1306
+ data_sampling_loc,
1307
+ data_attn_weight,
1308
+ batch_size,
1309
+ spatial_size,
1310
+ num_heads,
1311
+ channels,
1312
+ num_levels,
1313
+ num_query,
1314
+ num_point,
1315
+ grad_value,
1316
+ grad_sampling_loc,
1317
+ grad_attn_weight);
1318
+ }
1319
+ }
1320
+ }
1321
+ cudaError_t err = cudaGetLastError();
1322
+ if (err != cudaSuccess)
1323
+ {
1324
+ printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
1325
+ }
1326
+
1327
+ }
groundingdino/models/GroundingDINO/csrc/cuda_version.cu ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ #include <cuda_runtime_api.h>
2
+
3
+ namespace groundingdino {
4
+ int get_cudart_version() {
5
+ return CUDART_VERSION;
6
+ }
7
+ } // namespace groundingdino
groundingdino/models/GroundingDINO/csrc/vision.cpp ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+
3
+ #include "MsDeformAttn/ms_deform_attn.h"
4
+
5
+ namespace groundingdino {
6
+
7
+ #ifdef WITH_CUDA
8
+ extern int get_cudart_version();
9
+ #endif
10
+
11
+ std::string get_cuda_version() {
12
+ #ifdef WITH_CUDA
13
+ std::ostringstream oss;
14
+
15
+ // copied from
16
+ // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231
17
+ auto printCudaStyleVersion = [&](int v) {
18
+ oss << (v / 1000) << "." << (v / 10 % 100);
19
+ if (v % 10 != 0) {
20
+ oss << "." << (v % 10);
21
+ }
22
+ };
23
+ printCudaStyleVersion(get_cudart_version());
24
+ return oss.str();
25
+ #else
26
+ return std::string("not available");
27
+ #endif
28
+ }
29
+
30
+ // similar to
31
+ // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Version.cpp
32
+ std::string get_compiler_version() {
33
+ std::ostringstream ss;
34
+ #if defined(__GNUC__)
35
+ #ifndef __clang__
36
+ { ss << "GCC " << __GNUC__ << "." << __GNUC_MINOR__; }
37
+ #endif
38
+ #endif
39
+
40
+ #if defined(__clang_major__)
41
+ {
42
+ ss << "clang " << __clang_major__ << "." << __clang_minor__ << "."
43
+ << __clang_patchlevel__;
44
+ }
45
+ #endif
46
+
47
+ #if defined(_MSC_VER)
48
+ { ss << "MSVC " << _MSC_FULL_VER; }
49
+ #endif
50
+ return ss.str();
51
+ }
52
+
53
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
54
+ m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
55
+ m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
56
+ }
57
+
58
+ } // namespace groundingdino
groundingdino/util/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
groundingdino/util/box_ops.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Utilities for bounding box manipulation and GIoU.
4
+ """
5
+ import torch
6
+ from torchvision.ops.boxes import box_area
7
+
8
+
9
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (cx, cy, w, h) to (x0, y0, x1, y1) along the last dimension."""
    center_x, center_y, width, height = x.unbind(-1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = (
        center_x - half_w,
        center_y - half_h,
        center_x + half_w,
        center_y + half_h,
    )
    return torch.stack(corners, dim=-1)
13
+
14
+
15
def box_xyxy_to_cxcywh(x):
    """Convert boxes from (x0, y0, x1, y1) to (cx, cy, w, h) along the last dimension."""
    left, top, right, bottom = x.unbind(-1)
    center_x = (left + right) / 2
    center_y = (top + bottom) / 2
    width = right - left
    height = bottom - top
    return torch.stack((center_x, center_y, width, height), dim=-1)
19
+
20
+
21
+ # modified from torchvision to also return the union
22
def box_iou(boxes1, boxes2):
    """Return the pairwise IoU and union area of two sets of xyxy boxes.

    Args:
        boxes1: tensor of shape [N, 4].
        boxes2: tensor of shape [M, 4].

    Returns:
        A tuple ``(iou, union)``, both of shape [N, M]. The IoU denominator is
        padded with 1e-6.
    """
    area_a = box_area(boxes1)
    area_b = box_area(boxes2)

    # Broadcasting [N, 1, 2] against [M, 2] yields every pair's overlap corners.
    top_left = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    bottom_right = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    # Non-overlapping pairs get a zero-sized intersection instead of a negative one.
    overlap = (bottom_right - top_left).clamp(min=0)  # [N,M,2]
    inter = overlap[:, :, 0] * overlap[:, :, 1]  # [N,M]

    union = area_a[:, None] + area_b - inter
    return inter / (union + 1e-6), union
37
+
38
+
39
def generalized_box_iou(boxes1, boxes2):
    """Compute pairwise Generalized IoU (https://giou.stanford.edu/).

    Both inputs must hold boxes in [x0, y0, x1, y1] format.

    Returns:
        A [N, M] matrix with N = len(boxes1) and M = len(boxes2).
    """
    # Degenerate boxes give inf / nan results, so reject them up front.
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()

    iou, union = box_iou(boxes1, boxes2)

    # Smallest axis-aligned box enclosing each (boxes1[i], boxes2[j]) pair.
    hull_min = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    hull_max = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
    hull_wh = (hull_max - hull_min).clamp(min=0)  # [N,M,2]
    hull_area = hull_wh[:, :, 0] * hull_wh[:, :, 1]

    penalty = (hull_area - union) / (hull_area + 1e-6)
    return iou - penalty
63
+
64
+
65
+ # modified from torchvision to also return the union
66
def box_iou_pairwise(boxes1, boxes2):
    """Return the element-wise IoU and union area of two aligned sets of xyxy boxes.

    Row ``i`` of ``boxes1`` is compared only with row ``i`` of ``boxes2``.

    Args:
        boxes1: tensor of shape [N, 4].
        boxes2: tensor of shape [N, 4].

    Returns:
        A tuple ``(iou, union)``, both of shape [N].
    """
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, :2], boxes2[:, :2])  # [N,2]
    rb = torch.min(boxes1[:, 2:], boxes2[:, 2:])  # [N,2]

    # Non-overlapping pairs get a zero-sized intersection instead of a negative one.
    wh = (rb - lt).clamp(min=0)  # [N,2]
    inter = wh[:, 0] * wh[:, 1]  # [N]

    union = area1 + area2 - inter

    # Same 1e-6 padding as box_iou: two zero-area boxes have union == 0 and
    # would otherwise produce 0 / 0 = NaN.
    iou = inter / (union + 1e-6)
    return iou, union
80
+
81
+
82
def generalized_box_iou_pairwise(boxes1, boxes2):
    """Compute element-wise Generalized IoU (https://giou.stanford.edu/).

    Input:
        - boxes1, boxes2: N,4 in [x0, y0, x1, y1] format, same shape
    Output:
        - giou: N
    """
    # Degenerate boxes give inf / nan results, so do an early check.
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    assert boxes1.shape == boxes2.shape
    iou, union = box_iou_pairwise(boxes1, boxes2)  # N

    # Smallest axis-aligned box enclosing each aligned pair.
    lt = torch.min(boxes1[:, :2], boxes2[:, :2])
    rb = torch.max(boxes1[:, 2:], boxes2[:, 2:])

    wh = (rb - lt).clamp(min=0)  # [N,2]
    area = wh[:, 0] * wh[:, 1]

    # The checks above still admit zero-area boxes, for which the enclosing
    # area can be 0; pad the denominator as generalized_box_iou does.
    return iou - (area - union) / (area + 1e-6)
105
+
106
+
107
def masks_to_boxes(masks):
    """Compute the bounding boxes around the provided masks.

    The masks should be in format [N, H, W] where N is the number of masks and
    (H, W) are the spatial dimensions.

    Returns a [N, 4] tensor with the boxes in xyxy format. An input with no
    elements yields an empty [0, 4] tensor.
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device)

    h, w = masks.shape[-2:]

    # Build the coordinate ranges on the masks' device so masks that live on
    # an accelerator can be multiplied with them.
    y = torch.arange(0, h, dtype=torch.float, device=masks.device)
    x = torch.arange(0, w, dtype=torch.float, device=masks.device)
    outside = ~(masks.bool())

    # Broadcasting [N, H, W] * [W] weights every pixel by its column index.
    x_mask = masks * x
    x_max = x_mask.flatten(1).max(-1)[0]
    # Pixels outside the mask are pushed to a huge value so they never win the min.
    x_min = x_mask.masked_fill(outside, 1e8).flatten(1).min(-1)[0]

    # Broadcasting [N, H, W] * [H, 1] weights every pixel by its row index.
    y_mask = masks * y[:, None]
    y_max = y_mask.flatten(1).max(-1)[0]
    y_min = y_mask.masked_fill(outside, 1e8).flatten(1).min(-1)[0]

    return torch.stack([x_min, y_min, x_max, y_max], 1)
132
+
133
+
134
if __name__ == "__main__":
    # Manual debugging entry point: compute the IoU of two random box sets and
    # drop into an interactive ipdb session to inspect the result.
    x = torch.rand(5, 4)
    y = torch.rand(3, 4)
    iou, union = box_iou(x, y)
    import ipdb

    ipdb.set_trace()
groundingdino/util/inference.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, List
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import supervision as sv
6
+ import torch
7
+ from PIL import Image
8
+ from torchvision.ops import box_convert
9
+ import bisect
10
+
11
+ import groundingdino.datasets.transforms as T
12
+ from groundingdino.models import build_model
13
+ from groundingdino.util.misc import clean_state_dict
14
+ from groundingdino.util.slconfig import SLConfig
15
+ from groundingdino.util.utils import get_phrases_from_posmap
16
+
17
+ # ----------------------------------------------------------------------------------------------------------------------
18
+ # OLD API
19
+ # ----------------------------------------------------------------------------------------------------------------------
20
+
21
+
22
def preprocess_caption(caption: str) -> str:
    """Lower-case and strip the caption, and make sure it ends with a period."""
    normalized = caption.lower().strip()
    return normalized if normalized.endswith(".") else normalized + "."
27
+
28
+
29
def load_model(model_config_path: str, model_checkpoint_path: str, device: str = "cuda"):
    """Build the model described by a config file and load checkpoint weights into it.

    The checkpoint is read onto the CPU, its ``"model"`` entry is passed
    through ``clean_state_dict`` and loaded non-strictly, and the model is
    returned in eval mode. ``device`` is only recorded on the config here.
    """
    config = SLConfig.fromfile(model_config_path)
    config.device = device
    model = build_model(config)

    state_dict = torch.load(model_checkpoint_path, map_location="cpu")["model"]
    model.load_state_dict(clean_state_dict(state_dict), strict=False)

    model.eval()
    return model
37
+
38
+
39
def load_image(image_path: str) -> Tuple[np.ndarray, torch.Tensor]:
    """Read an image file and prepare it for the model.

    Args:
        image_path: path of the image to open.

    Returns:
        A tuple ``(image, image_transformed)``: the RGB image as a NumPy array,
        and the result of running the RGB image through the resize / to-tensor
        / normalize pipeline below.
    """
    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    # Close the underlying file as soon as the pixels have been converted;
    # convert() returns an independent image, so it stays usable afterwards.
    with Image.open(image_path) as opened:
        image_source = opened.convert("RGB")
    image = np.asarray(image_source)
    image_transformed, _ = transform(image_source, None)
    return image, image_transformed
51
+
52
+
53
def predict(
        model,
        image: torch.Tensor,
        caption: str,
        box_threshold: float,
        text_threshold: float,
        device: str = "cuda",
        remove_combined: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
    """Run the model on one image and caption and return the confident detections.

    Args:
        model: callable as ``model(images, captions=[...])``; must expose ``tokenizer``.
        image: preprocessed image tensor; a batch dimension is added here.
        caption: text prompt; normalized with ``preprocess_caption``.
        box_threshold: a query is kept when its highest token score exceeds this.
        text_threshold: token scores above this are used to build the phrase.
        device: device the model and image are moved to.
        remove_combined: when True, each phrase is restricted to the span
            between the two separator tokens surrounding the best-scoring token.

    Returns:
        ``(boxes, scores, phrases)`` for the kept queries, where ``scores`` is
        the highest token score of each query.
    """
    caption = preprocess_caption(caption=caption)

    model = model.to(device)
    image = image.to(device)

    with torch.no_grad():
        outputs = model(image[None], captions=[caption])

    prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0]  # prediction_logits.shape = (nq, 256)
    prediction_boxes = outputs["pred_boxes"].cpu()[0]  # prediction_boxes.shape = (nq, 4)

    keep = prediction_logits.max(dim=1)[0] > box_threshold
    logits = prediction_logits[keep]  # logits.shape = (n, 256)
    boxes = prediction_boxes[keep]  # boxes.shape = (n, 4)

    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)

    if not remove_combined:
        phrases = []
        for logit in logits:
            phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer)
            phrases.append(phrase.replace('.', ''))
        return boxes, logits.max(dim=1)[0], phrases

    # Positions of the separator tokens in the caption.
    # NOTE(review): 101 / 102 / 1012 are presumably the BERT ids of [CLS], [SEP] and "." -- confirm.
    sep_idx = [
        position
        for position, token_id in enumerate(tokenized['input_ids'])
        if token_id in [101, 102, 1012]
    ]

    phrases = []
    for logit in logits:
        max_idx = logit.argmax()
        # Locate the separators on either side of the best-scoring token.
        insert_idx = bisect.bisect_left(sep_idx, max_idx)
        right_idx = sep_idx[insert_idx]
        left_idx = sep_idx[insert_idx - 1]
        phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer, left_idx, right_idx)
        phrases.append(phrase.replace('.', ''))

    return boxes, logits.max(dim=1)[0], phrases
98
+
99
+
100
def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str]) -> np.ndarray:
    """Draw labelled boxes on a copy of the image converted from RGB to BGR.

    ``boxes`` are cxcywh values that get scaled by the image width and height
    before being converted to xyxy; each label is the phrase followed by its
    score with two decimals.
    """
    height, width, _ = image_source.shape
    scaled_boxes = boxes * torch.Tensor([width, height, width, height])
    xyxy = box_convert(boxes=scaled_boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
    detections = sv.Detections(xyxy=xyxy)

    labels = []
    for phrase, logit in zip(phrases, logits):
        labels.append(f"{phrase} {logit:.2f}")

    annotator = sv.BoxAnnotator()
    bgr_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR)
    return annotator.annotate(scene=bgr_frame, detections=detections, labels=labels)
116
+
117
+
118
+ # ----------------------------------------------------------------------------------------------------------------------
119
+ # NEW API
120
+ # ----------------------------------------------------------------------------------------------------------------------
121
+
122
+
123
class Model:
    """Convenience wrapper that takes BGR images and returns ``sv.Detections``."""

    def __init__(
        self,
        model_config_path: str,
        model_checkpoint_path: str,
        device: str = "cuda"
    ):
        """Load the model from config and checkpoint and move it to ``device``."""
        loaded = load_model(
            model_config_path=model_config_path,
            model_checkpoint_path=model_checkpoint_path,
            device=device
        )
        self.model = loaded.to(device)
        self.device = device

    def predict_with_caption(
        self,
        image: np.ndarray,
        caption: str,
        box_threshold: float = 0.35,
        text_threshold: float = 0.25
    ) -> Tuple[sv.Detections, List[str]]:
        """
        Detect the objects described by a free-form caption and return the
        detections together with the matched phrases.

        Example:

        import cv2

        image = cv2.imread(IMAGE_PATH)

        model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
        detections, labels = model.predict_with_caption(
            image=image,
            caption=caption,
            box_threshold=BOX_THRESHOLD,
            text_threshold=TEXT_THRESHOLD
        )

        import supervision as sv

        box_annotator = sv.BoxAnnotator()
        annotated_image = box_annotator.annotate(scene=image, detections=detections, labels=labels)
        """
        model_input = Model.preprocess_image(image_bgr=image).to(self.device)
        boxes, logits, phrases = predict(
            model=self.model,
            image=model_input,
            caption=caption,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            device=self.device)
        height, width, _ = image.shape
        detections = Model.post_process_result(
            source_h=height,
            source_w=width,
            boxes=boxes,
            logits=logits)
        return detections, phrases

    def predict_with_classes(
        self,
        image: np.ndarray,
        classes: List[str],
        box_threshold: float,
        text_threshold: float
    ) -> sv.Detections:
        """
        Detect objects from a list of class names. The names are joined with
        ". " into one caption, and every detection gets the index of the first
        class name contained in its phrase as ``class_id``.

        Example:

        import cv2

        image = cv2.imread(IMAGE_PATH)

        model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
        detections = model.predict_with_classes(
            image=image,
            classes=CLASSES,
            box_threshold=BOX_THRESHOLD,
            text_threshold=TEXT_THRESHOLD
        )


        import supervision as sv

        box_annotator = sv.BoxAnnotator()
        annotated_image = box_annotator.annotate(scene=image, detections=detections)
        """
        caption = ". ".join(classes)
        model_input = Model.preprocess_image(image_bgr=image).to(self.device)
        boxes, logits, phrases = predict(
            model=self.model,
            image=model_input,
            caption=caption,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            device=self.device)
        height, width, _ = image.shape
        detections = Model.post_process_result(
            source_h=height,
            source_w=width,
            boxes=boxes,
            logits=logits)
        detections.class_id = Model.phrases2classes(phrases=phrases, classes=classes)
        return detections

    @staticmethod
    def preprocess_image(image_bgr: np.ndarray) -> torch.Tensor:
        """Convert a BGR array to RGB and apply the resize / to-tensor / normalize pipeline."""
        pipeline = T.Compose(
            [
                T.RandomResize([800], max_size=1333),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        rgb_image = Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
        tensor, _ = pipeline(rgb_image, None)
        return tensor

    @staticmethod
    def post_process_result(
            source_h: int,
            source_w: int,
            boxes: torch.Tensor,
            logits: torch.Tensor
    ) -> sv.Detections:
        """Scale cxcywh boxes by the source image size and wrap them as xyxy detections."""
        scale = torch.Tensor([source_w, source_h, source_w, source_h])
        xyxy = box_convert(boxes=boxes * scale, in_fmt="cxcywh", out_fmt="xyxy").numpy()
        confidence = logits.numpy()
        return sv.Detections(xyxy=xyxy, confidence=confidence)

    @staticmethod
    def phrases2classes(phrases: List[str], classes: List[str]) -> np.ndarray:
        """Map each phrase to the index of the first class name it contains, or None."""
        class_ids = []
        for phrase in phrases:
            first_match = next(
                (index for index, class_name in enumerate(classes) if class_name in phrase),
                None,
            )
            class_ids.append(first_match)
        return np.array(class_ids)
groundingdino/util/logger.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ import functools
3
+ import logging
4
+ import os
5
+ import sys
6
+
7
+ from termcolor import colored
8
+
9
+
10
+ class _ColorfulFormatter(logging.Formatter):
11
+ def __init__(self, *args, **kwargs):
12
+ self._root_name = kwargs.pop("root_name") + "."
13
+ self._abbrev_name = kwargs.pop("abbrev_name", "")
14
+ if len(self._abbrev_name):
15
+ self._abbrev_name = self._abbrev_name + "."
16
+ super(_ColorfulFormatter, self).__init__(*args, **kwargs)
17
+
18
+ def formatMessage(self, record):
19
+ record.name = record.name.replace(self._root_name, self._abbrev_name)
20
+ log = super(_ColorfulFormatter, self).formatMessage(record)
21
+ if record.levelno == logging.WARNING:
22
+ prefix = colored("WARNING", "red", attrs=["blink"])
23
+ elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
24
+ prefix = colored("ERROR", "red", attrs=["blink", "underline"])
25
+ else:
26
+ return log
27
+ return prefix + " " + log
28
+
29
+
30
# so that calling setup_logger multiple times won't add many handlers
@functools.lru_cache()
def setup_logger(output=None, distributed_rank=0, *, color=True, name="imagenet", abbrev_name=None):
    """
    Create (once per argument combination) a logger with level DEBUG.

    Args:
        output (str): a file name or a directory to save log. If None, will not save log file.
            If ends with ".txt" or ".log", assumed to be a file name.
            Otherwise, logs will be saved to `output/log.txt`.
        distributed_rank (int): rank of this process. Only rank 0 logs to stdout;
            ranks > 0 get ".rank<N>" appended to the log file name.
        color (bool): use the colored formatter for the stdout handler.
        name (str): the root module name of this logger
        abbrev_name (str): abbreviation shown instead of `name`; defaults to `name`.

    Returns:
        logging.Logger: a logger
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.propagate = False

    if abbrev_name is None:
        abbrev_name = name

    plain_formatter = logging.Formatter(
        "[%(asctime)s.%(msecs)03d]: %(message)s", datefmt="%m/%d %H:%M:%S"
    )
    # stdout logging: master only
    if distributed_rank == 0:
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.DEBUG)
        if color:
            formatter = _ColorfulFormatter(
                colored("[%(asctime)s.%(msecs)03d]: ", "green") + "%(message)s",
                datefmt="%m/%d %H:%M:%S",
                root_name=name,
                abbrev_name=str(abbrev_name),
            )
        else:
            formatter = plain_formatter
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    # file logging: all workers
    if output is not None:
        if output.endswith((".txt", ".log")):
            filename = output
        else:
            filename = os.path.join(output, "log.txt")
        if distributed_rank > 0:
            filename = filename + f".rank{distributed_rank}"
        # A bare file name has an empty dirname, and os.makedirs("") raises,
        # so only create the directory when there is one.
        log_dir = os.path.dirname(filename)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        fh = logging.StreamHandler(_cached_log_stream(filename))
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)

    return logger


# cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
    """Open `filename` for appending; the handle is cached and intentionally kept open."""
    return open(filename, "a")
groundingdino/util/misc.py ADDED
@@ -0,0 +1,717 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Misc functions, including distributed helpers.
4
+
5
+ Mostly copy-paste from torchvision references.
6
+ """
7
+ import colorsys
8
+ import datetime
9
+ import functools
10
+ import io
11
+ import json
12
+ import os
13
+ import pickle
14
+ import subprocess
15
+ import time
16
+ from collections import OrderedDict, defaultdict, deque
17
+ from typing import List, Optional
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.distributed as dist
22
+
23
+ # needed due to empty tensor bug in pytorch and torchvision 0.5
24
+ import torchvision
25
+ from torch import Tensor
26
+
27
+ __torchvision_need_compat_flag = float(torchvision.__version__.split(".")[1]) < 7
28
+ if __torchvision_need_compat_flag:
29
+ from torchvision.ops import _new_empty_tensor
30
+ from torchvision.ops.misc import _output_size
31
+
32
+
33
class SmoothedValue(object):
    """Keep a sliding window of recent values plus a running total and count,
    and expose window statistics and the global average.
    """

    def __init__(self, window_size=20, fmt=None):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt

    def update(self, value, n=1):
        """Record `value`, counted `n` times towards the global total."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Sum `count` and `total` over all processes.

        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        stats = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
        dist.barrier()
        dist.all_reduce(stats)
        synced_count, synced_total = stats.tolist()
        self.count = int(synced_count)
        self.total = synced_total

    @property
    def median(self):
        """Median of the window as computed by torch; 0 when the window is empty."""
        window = torch.tensor(list(self.deque))
        return window.median().item() if window.shape[0] != 0 else 0

    @property
    def avg(self):
        """Mean of the values currently in the window."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """Total divided by count, with a small epsilon added to the denominator."""
        # The epsilon is larger when the SHILONG_AMP environment variable is "1".
        eps = 1e-4 if os.environ.get("SHILONG_AMP", None) == "1" else 1e-6
        return self.total / (self.count + eps)

    @property
    def max(self):
        """Largest value in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        stats = {
            "median": self.median,
            "avg": self.avg,
            "global_avg": self.global_avg,
            "max": self.max,
            "value": self.value,
        }
        return self.fmt.format(**stats)
100
+
101
+
102
@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group usable for gathering on CPU.

    When the current backend is "nccl" a new gloo group is created; otherwise
    the default WORLD group is returned. The result is cached.
    """
    uses_nccl = dist.get_backend() == "nccl"
    return dist.new_group(backend="gloo") if uses_nccl else dist.group.WORLD
113
+
114
+
115
def all_gather_cpu(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors),
    using the group returned by `_get_global_gloo_group`.

    The object is serialized with `torch.save`, padded to the largest payload
    size among ranks, gathered, and deserialized with `torch.load`.

    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    cpu_group = _get_global_gloo_group()
    on_default_group = cpu_group is None
    device = "cuda" if on_default_group else "cpu"

    serialized = io.BytesIO()
    torch.save(data, serialized)
    tensor = torch.ByteTensor(serialized.getbuffer()).to(device)

    # exchange the payload size of every rank
    local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long)
    size_list = [torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)]
    if on_default_group:
        dist.all_gather(size_list, local_size)
    else:
        print("gathering on cpu")
        dist.all_gather(size_list, local_size, group=cpu_group)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)
    assert isinstance(local_size.item(), int)
    local_size = int(local_size.item())

    # all_gather needs equal shapes, so every payload is padded to max_size
    tensor_list = [
        torch.empty((max_size,), dtype=torch.uint8, device=device) for _ in size_list
    ]
    if local_size != max_size:
        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device=device)
        tensor = torch.cat((tensor, padding), dim=0)
    if on_default_group:
        dist.all_gather(tensor_list, tensor)
    else:
        dist.all_gather(tensor_list, tensor, group=cpu_group)

    data_list = []
    for size, received in zip(size_list, tensor_list):
        # drop the padding before deserializing
        payload = torch.split(received, [size, max_size - size], dim=0)[0]
        data_list.append(torch.load(io.BytesIO(payload.cpu().numpy())))

    return data_list
171
+
172
+
173
def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Delegates to `all_gather_cpu` when the CPU_REDUCE environment variable is
    "1"; otherwise the pickled payload is exchanged through CUDA byte tensors.

    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    if os.getenv("CPU_REDUCE") == "1":
        return all_gather_cpu(data)

    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # pickle the object into a byte tensor
    payload = pickle.dumps(data)
    tensor = torch.ByteTensor(torch.ByteStorage.from_buffer(payload)).to("cuda")

    # exchange the payload size of every rank
    local_size = torch.tensor([tensor.numel()], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # all_gather needs equal shapes, so every payload is padded to max_size
    tensor_list = [
        torch.empty((max_size,), dtype=torch.uint8, device="cuda") for _ in size_list
    ]
    if local_size != max_size:
        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    # NOTE(review): pickle.loads runs on bytes received from other ranks;
    # only safe when all ranks are trusted.
    return [
        pickle.loads(received.cpu().numpy().tobytes()[:size])
        for size, received in zip(size_list, tensor_list)
    ]
218
+
219
+
220
def reduce_dict(input_dict, average=True):
    """
    Reduce the values of `input_dict` across all processes.

    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Returns:
        A dict with the same keys holding the reduced values. With fewer than
        two processes `input_dict` itself is returned unchanged.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        # sort the keys so that they are consistent across processes
        names = sorted(input_dict.keys())
        values = torch.stack([input_dict[name] for name in names], dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = dict(zip(names, values))
    return reduced_dict
245
+
246
+
247
class MetricLogger(object):
    """Collection of named `SmoothedValue` meters with periodic progress logging."""

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Feed each keyword value (number or one-element tensor) to its meter."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # __getattr__ only runs when normal lookup fails. `meters` is read
        # through __dict__ so an instance without it (created by __new__,
        # copy or unpickling) raises AttributeError instead of recursing.
        meters = self.__dict__.get("meters")
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))

    def __str__(self):
        # Meters that never received a value are skipped.
        loss_str = [
            "{}: {}".format(name, str(meter))
            for name, meter in self.meters.items()
            if meter.count > 0
        ]
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None, logger=None):
        """Yield the items of `iterable`, logging progress every `print_freq` items.

        Args:
            iterable: sized iterable to wrap (`len()` is required).
            print_freq: log every this many iterations, and on the last one.
            header: optional prefix for every log line.
            logger: if given, `logger.info` is used instead of `print`.
        """
        print_func = print if logger is None else logger.info
        if not header:
            header = ""
        total = len(iterable)
        use_cuda = torch.cuda.is_available()

        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        space_fmt = ":" + str(len(str(total))) + "d"
        fields = [
            header,
            "[{0" + space_fmt + "}/{1}]",
            "eta: {eta}",
            "{meters}",
            "time: {time}",
            "data: {data}",
        ]
        if use_cuda:
            fields.append("max mem: {memory:.0f}")
        log_msg = self.delimiter.join(fields)
        MB = 1024.0 * 1024.0

        for i, obj in enumerate(iterable):
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == total - 1:
                eta_seconds = iter_time.global_avg * (total - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                extra = {}
                if use_cuda:
                    extra["memory"] = torch.cuda.max_memory_allocated() / MB
                print_func(
                    log_msg.format(
                        i,
                        total,
                        eta=eta_string,
                        meters=str(self),
                        time=str(iter_time),
                        data=str(data_time),
                        **extra,
                    )
                )
            end = time.time()

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(..., 1) keeps an empty iterable from raising ZeroDivisionError.
        print_func(
            "{} Total time: {} ({:.4f} s / it)".format(
                header, total_time_str, total_time / max(total, 1)
            )
        )
360
+
361
+
362
def get_sha():
    """Describe the git state of the directory containing this file.

    Returns:
        A string "sha: ..., status: ..., branch: ...". Any failure while
        running git is swallowed and the fields gathered so far keep their
        defaults ("N/A" / "clean").
    """
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()

    sha, diff, branch = "N/A", "clean", "N/A"
    try:
        sha = _run(["git", "rev-parse", "HEAD"])
        subprocess.check_output(["git", "diff"], cwd=cwd)
        changed = _run(["git", "diff-index", "HEAD"])
        diff = "has uncommited changes" if changed else "clean"
        branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
    except Exception:
        # best effort: git may be missing or this may not be a repository
        pass
    return f"sha: {sha}, status: {diff}, branch: {branch}"
381
+
382
+
383
def collate_fn(batch):
    """Transpose a batch of samples into per-field tuples.

    The first field is merged into one padded `NestedTensor` via
    `nested_tensor_from_tensor_list`; the other fields stay as tuples.
    """
    columns = [*zip(*batch)]
    columns[0] = nested_tensor_from_tensor_list(columns[0])
    return tuple(columns)
388
+
389
+
390
+ def _max_by_axis(the_list):
391
+ # type: (List[List[int]]) -> List[int]
392
+ maxes = the_list[0]
393
+ for sublist in the_list[1:]:
394
+ for index, item in enumerate(sublist):
395
+ maxes[index] = max(maxes[index], item)
396
+ return maxes
397
+
398
+
399
class NestedTensor(object):
    """A tensor (single image or batch) paired with a boolean padding mask.

    Mask entries that are True are treated as padding by the helpers below.
    Passing ``mask="auto"`` builds an all-False mask from ``tensors``.
    """

    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask
        if mask == "auto":
            zeros = torch.zeros_like(tensors).to(tensors.device)
            rank = zeros.dim()
            # Summing over the channel axis leaves a zero (= all False) spatial mask.
            if rank == 3:
                self.mask = zeros.sum(0).to(bool)
            elif rank == 4:
                self.mask = zeros.sum(1).to(bool)
            else:
                raise ValueError(
                    "tensors dim must be 3 or 4 but {}({})".format(
                        self.tensors.dim(), self.tensors.shape
                    )
                )

    def imgsize(self):
        """Return one ``Tensor([H, W])`` per batch item, measured from the mask."""
        sizes = []
        for idx in range(self.tensors.shape[0]):
            valid = ~self.mask[idx]
            height = valid.sum(0).max()
            width = valid.sum(1).max()
            sizes.append(torch.Tensor([height, width]))
        return sizes

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        """Return a new NestedTensor with tensors (and mask, if any) on `device`."""
        moved_tensors = self.tensors.to(device)
        moved_mask = None if self.mask is None else self.mask.to(device)
        return NestedTensor(moved_tensors, moved_mask)

    def to_img_list_single(self, tensor, mask):
        """Crop the padding off a single 3-dim image using its mask."""
        assert tensor.dim() == 3, "dim of tensor should be 3 but {}".format(tensor.dim())
        valid = ~mask
        height = valid.sum(0).max()
        width = valid.sum(1).max()
        return tensor[:, :height, :width]

    def to_img_list(self):
        """remove the padding and convert to img list

        Returns:
            A single cropped image when `tensors` is 3-dim, otherwise a list
            with one cropped image per batch item.
        """
        if self.tensors.dim() == 3:
            return self.to_img_list_single(self.tensors, self.mask)
        return [
            self.to_img_list_single(self.tensors[idx], self.mask[idx])
            for idx in range(self.tensors.shape[0])
        ]

    @property
    def device(self):
        return self.tensors.device

    def decompose(self):
        """Return the ``(tensors, mask)`` pair."""
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)

    @property
    def shape(self):
        return {"tensors.shape": self.tensors.shape, "mask.shape": self.mask.shape}
472
+
473
+
474
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    """Pad a list of 3-dim image tensors to a common size and batch them.

    Returns:
        A `NestedTensor` whose mask is False over each image's own pixels and
        True over the padding.

    Raises:
        ValueError: if the first tensor is not 3-dimensional.
    """
    # TODO make this more general
    if tensor_list[0].ndim != 3:
        raise ValueError("not supported")

    if torchvision._is_tracing():
        # nested_tensor_from_tensor_list() does not export well to ONNX
        # call _onnx_nested_tensor_from_tensor_list() instead
        return _onnx_nested_tensor_from_tensor_list(tensor_list)

    # TODO make it support different-sized images
    max_size = _max_by_axis([list(img.shape) for img in tensor_list])
    batch_shape = [len(tensor_list)] + max_size
    b, c, h, w = batch_shape
    first = tensor_list[0]
    tensor = torch.zeros(batch_shape, dtype=first.dtype, device=first.device)
    mask = torch.ones((b, h, w), dtype=torch.bool, device=first.device)
    for img, pad_img, m in zip(tensor_list, tensor, mask):
        channels, height, width = img.shape[0], img.shape[1], img.shape[2]
        pad_img[:channels, :height, :width].copy_(img)
        m[:height, :width] = False
    return NestedTensor(tensor, mask)
497
+
498
+
499
# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
    """Build the padded batch and mask with `pad`/`stack` instead of in-place slicing."""
    max_size = tuple(
        torch.max(
            torch.stack([img.shape[axis] for img in tensor_list]).to(torch.float32)
        ).to(torch.int64)
        for axis in range(tensor_list[0].dim())
    )

    # work around for
    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
    # m[: img.shape[1], :img.shape[2]] = False
    # which is not yet supported in onnx
    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        padding = [(limit - own) for limit, own in zip(max_size, tuple(img.shape))]
        padded_imgs.append(
            torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        )

        # start from an all-zero mask over the image and pad it with ones
        valid = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(valid, (0, padding[2], 0, padding[1]), "constant", 1)
        padded_masks.append(padded_mask.to(torch.bool))

    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)

    return NestedTensor(tensor, mask=mask)
530
+
531
+
532
def setup_for_distributed(is_master):
    """
    Replace the builtin `print` with a version that is silent unless this is
    the master process or the call passes ``force=True``.
    """
    import builtins as __builtin__

    original_print = __builtin__.print

    def print(*args, **kwargs):
        forced = kwargs.pop("force", False)
        if not (is_master or forced):
            return
        original_print(*args, **kwargs)

    __builtin__.print = print
546
+
547
+
548
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is available and initialized."""
    return dist.is_available() and dist.is_initialized()


def get_world_size():
    """Number of processes in the group, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1


def get_rank():
    """Rank of this process, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0


def is_main_process():
    """True when this process has rank 0."""
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    """Forward the arguments to `torch.save`, but only on the main process."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
575
+
576
+
577
def init_distributed_mode(args):
    """Fill the distributed fields of `args` from the environment and start the process group.

    Sources, in order: WORLD_SIZE/RANK/LOCAL_RANK, then the SLURM_* variables.
    With neither present, `args` is marked non-distributed and the function
    returns without touching torch.distributed.

    Args:
        args: namespace that receives rank, world_size, local_rank, gpu,
            distributed and dist_backend; `args.dist_url` is read as the init method.
    """
    env = os.environ
    if "WORLD_SIZE" in env and env["WORLD_SIZE"] != "":
        # launched with the RANK / WORLD_SIZE / LOCAL_RANK variables set
        args.rank = int(env["RANK"])
        args.world_size = int(env["WORLD_SIZE"])
        args.gpu = args.local_rank = int(env["LOCAL_RANK"])

        print(
            "world size: {}, rank: {}, local rank: {}".format(
                args.world_size, args.rank, args.local_rank
            )
        )
        # NOTE(review): this prints the whole process environment.
        print(json.dumps(dict(env), indent=2))
    elif "SLURM_PROCID" in env:
        args.rank = int(env["SLURM_PROCID"])
        args.gpu = args.local_rank = int(env["SLURM_LOCALID"])
        args.world_size = int(env["SLURM_NPROCS"])

        print(
            "world size: {}, world rank: {}, local rank: {}, device_count: {}".format(
                args.world_size, args.rank, args.local_rank, torch.cuda.device_count()
            )
        )
    else:
        print("Not using distributed mode")
        args.distributed = False
        args.world_size = 1
        args.rank = 0
        args.local_rank = 0
        return

    print("world_size:{} rank:{} local_rank:{}".format(args.world_size, args.rank, args.local_rank))
    args.distributed = True
    torch.cuda.set_device(args.local_rank)
    args.dist_backend = "nccl"
    print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)

    torch.distributed.init_process_group(
        backend=args.dist_backend,
        world_size=args.world_size,
        rank=args.rank,
        init_method=args.dist_url,
    )

    print("Before torch.distributed.barrier()")
    torch.distributed.barrier()
    print("End torch.distributed.barrier()")
    # from here on only the rank-0 process prints by default
    setup_for_distributed(args.rank == 0)
635
+
636
+
637
@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k

    Args:
        output: score tensor; top-k is taken along dim 1.
        target: tensor of true class indices, one per row of `output`.
        topk: the k values to evaluate.

    Returns:
        A list with one percentage tensor per k. For an empty `target` a
        single zero tensor is returned regardless of `topk`.
    """
    if target.numel() == 0:
        return [torch.zeros([], device=output.device)]
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` comes from a transposed tensor and may be non-contiguous,
        # so .view(-1) can raise for k > 1; reshape copies only when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
654
+
655
+
656
@torch.no_grad()
def accuracy_onehot(pred, gt):
    """Percentage of rows in which `pred` matches `gt`.

    A row counts as a match when the summed absolute difference over the last
    dimension is below 1e-4.

    Args:
        pred: n, c
        gt: n, c
    """
    row_error = (pred - gt).abs().sum(-1)
    matches = (row_error < 1e-4).float().sum()
    return matches / gt.shape[0] * 100
667
+
668
+
669
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
    """
    Equivalent to nn.functional.interpolate, but with support for empty batch sizes
    on old torchvision versions.

    `__torchvision_need_compat_flag` is a bool, so it is tested directly; the
    previous `< 0.7` comparison inverted the branch. The empty-tensor shim and
    its helpers are only used when the flag is set and the input is empty.
    Everything else goes to `torch.nn.functional.interpolate`.
    """
    if __torchvision_need_compat_flag and input.numel() == 0:
        output_shape = _output_size(2, input, size, scale_factor)
        output_shape = list(input.shape[:-2]) + list(output_shape)
        return _new_empty_tensor(input, output_shape)
    return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)
685
+
686
+
687
class color_sys:
    """Palette of RGB tuples with evenly spaced hues and randomized lightness/saturation."""

    def __init__(self, num_colors) -> None:
        self.num_colors = num_colors
        palette = []
        for degrees in np.arange(0.0, 360.0, 360.0 / num_colors):
            hue = degrees / 360.0
            # lightness is drawn from [0.50, 0.60), saturation from [0.90, 1.00)
            lightness = (50 + np.random.rand() * 10) / 100.0
            saturation = (90 + np.random.rand() * 10) / 100.0
            rgb = colorsys.hls_to_rgb(hue, lightness, saturation)
            palette.append(tuple([int(channel * 255) for channel in rgb]))
        self.colors = palette

    def __call__(self, idx):
        """Return the color stored at position `idx`."""
        return self.colors[idx]
702
+
703
+
704
def inverse_sigmoid(x, eps=1e-3):
    """Return log(x / (1 - x)) with `x` clamped to [0, 1].

    Numerator and denominator are each clamped to at least `eps`, so the
    result stays finite at 0 and 1.
    """
    clamped = x.clamp(min=0, max=1)
    numerator = clamped.clamp(min=eps)
    denominator = (1 - clamped).clamp(min=eps)
    return torch.log(numerator / denominator)
709
+
710
+
711
def clean_state_dict(state_dict):
    """Return an OrderedDict copy of `state_dict` with a leading "module." removed from keys."""
    prefix = "module."
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        name = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[name] = value
    return cleaned
groundingdino/util/slconfig.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ==========================================================
2
+ # Modified from mmcv
3
+ # ==========================================================
4
+ import ast
5
+ import os
6
+ import os.path as osp
7
+ import shutil
8
+ import sys
9
+ import tempfile
10
+ from argparse import Action
11
+ from importlib import import_module
12
+
13
+ from addict import Dict
14
+ from yapf.yapflib.yapf_api import FormatCode
15
+
16
+ BASE_KEY = "_base_"
17
+ DELETE_KEY = "_delete_"
18
+ RESERVED_KEYS = ["filename", "text", "pretty_text", "get", "dump", "merge_from_dict"]
19
+
20
+
21
def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
    """Raise FileNotFoundError (message built from `msg_tmpl`) unless `filename` is a file."""
    if osp.isfile(filename):
        return
    raise FileNotFoundError(msg_tmpl.format(filename))
24
+
25
+
26
class ConfigDict(Dict):
    """`Dict` subclass whose missing keys raise instead of being created."""

    def __missing__(self, name):
        raise KeyError(name)

    def __getattr__(self, name):
        # A KeyError from the parent lookup becomes AttributeError; any other
        # exception is re-raised unchanged. The raise happens outside the
        # except blocks, as in the parent-call pattern this class was built on.
        error = None
        try:
            value = super(ConfigDict, self).__getattr__(name)
        except KeyError:
            error = AttributeError(
                f"'{self.__class__.__name__}' object has no " f"attribute '{name}'"
            )
        except Exception as exc:
            error = exc
        if error is not None:
            raise error
        return value
40
+
41
+
42
+ class SLConfig(object):
43
+ """
44
+ config files.
45
+ only support .py file as config now.
46
+
47
+ ref: mmcv.utils.config
48
+
49
+ Example:
50
+ >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
51
+ >>> cfg.a
52
+ 1
53
+ >>> cfg.b
54
+ {'b1': [0, 1]}
55
+ >>> cfg.b.b1
56
+ [0, 1]
57
+ >>> cfg = Config.fromfile('tests/data/config/a.py')
58
+ >>> cfg.filename
59
+ "/home/kchen/projects/mmcv/tests/data/config/a.py"
60
+ >>> cfg.item4
61
+ 'test'
62
+ >>> cfg
63
+ "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
64
+ "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
65
+ """
66
+
67
+ @staticmethod
68
+ def _validate_py_syntax(filename):
69
+ with open(filename) as f:
70
+ content = f.read()
71
+ try:
72
+ ast.parse(content)
73
+ except SyntaxError:
74
+ raise SyntaxError("There are syntax errors in config " f"file {filename}")
75
+
76
+ @staticmethod
77
+ def _file2dict(filename):
78
+ filename = osp.abspath(osp.expanduser(filename))
79
+ check_file_exist(filename)
80
+ if filename.lower().endswith(".py"):
81
+ with tempfile.TemporaryDirectory() as temp_config_dir:
82
+ temp_config_file = tempfile.NamedTemporaryFile(dir=temp_config_dir, suffix=".py")
83
+ temp_config_name = osp.basename(temp_config_file.name)
84
+ if os.name == 'nt':
85
+ temp_config_file.close()
86
+ shutil.copyfile(filename, osp.join(temp_config_dir, temp_config_name))
87
+ temp_module_name = osp.splitext(temp_config_name)[0]
88
+ sys.path.insert(0, temp_config_dir)
89
+ SLConfig._validate_py_syntax(filename)
90
+ mod = import_module(temp_module_name)
91
+ sys.path.pop(0)
92
+ cfg_dict = {
93
+ name: value for name, value in mod.__dict__.items() if not name.startswith("__")
94
+ }
95
+ # delete imported module
96
+ del sys.modules[temp_module_name]
97
+ # close temp file
98
+ temp_config_file.close()
99
+ elif filename.lower().endswith((".yml", ".yaml", ".json")):
100
+ from .slio import slload
101
+
102
+ cfg_dict = slload(filename)
103
+ else:
104
+ raise IOError("Only py/yml/yaml/json type are supported now!")
105
+
106
+ cfg_text = filename + "\n"
107
+ with open(filename, "r") as f:
108
+ cfg_text += f.read()
109
+
110
+ # parse the base file
111
+ if BASE_KEY in cfg_dict:
112
+ cfg_dir = osp.dirname(filename)
113
+ base_filename = cfg_dict.pop(BASE_KEY)
114
+ base_filename = base_filename if isinstance(base_filename, list) else [base_filename]
115
+
116
+ cfg_dict_list = list()
117
+ cfg_text_list = list()
118
+ for f in base_filename:
119
+ _cfg_dict, _cfg_text = SLConfig._file2dict(osp.join(cfg_dir, f))
120
+ cfg_dict_list.append(_cfg_dict)
121
+ cfg_text_list.append(_cfg_text)
122
+
123
+ base_cfg_dict = dict()
124
+ for c in cfg_dict_list:
125
+ if len(base_cfg_dict.keys() & c.keys()) > 0:
126
+ raise KeyError("Duplicate key is not allowed among bases")
127
+ # TODO Allow the duplicate key while warnning user
128
+ base_cfg_dict.update(c)
129
+
130
+ base_cfg_dict = SLConfig._merge_a_into_b(cfg_dict, base_cfg_dict)
131
+ cfg_dict = base_cfg_dict
132
+
133
+ # merge cfg_text
134
+ cfg_text_list.append(cfg_text)
135
+ cfg_text = "\n".join(cfg_text_list)
136
+
137
+ return cfg_dict, cfg_text
138
+
139
@staticmethod
def _merge_a_into_b(a, b):
    """Merge ``a`` into a shallow copy of ``b``; values in ``a`` win.

    ``b`` itself is not modified (it is copied first). Note that a nested
    dict inside ``a`` does get ``DELETE_KEY`` popped from it while merging.

    Args:
        a: Overriding values. If it is not a dict it is returned unchanged.
        b (dict | list): Base container. When it is a list, the keys of
            ``a`` must be convertible to ``int`` and are used as indices.

    Returns:
        The merged copy of ``b``, or ``a`` itself when ``a`` is not a dict.

    Raises:
        TypeError: If a dict value in ``a`` would be merged into a base value
            that is neither dict nor list, or if ``b`` is a list and a key of
            ``a`` cannot be converted to ``int``.
    """
    if not isinstance(a, dict):
        return a

    b = b.copy()
    for k, v in a.items():
        # The pop only runs when v is a dict and k already exists in b; a truthy
        # DELETE_KEY makes the child value replace the base value outright.
        if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):

            if not isinstance(b[k], dict) and not isinstance(b[k], list):
                raise TypeError(
                    f"{k}={v} in child config cannot inherit from base "
                    f"because {k} is a dict in the child config but is of "
                    f"type {type(b[k])} in base config. You may set "
                    f"`{DELETE_KEY}=True` to ignore the base config"
                )
            b[k] = SLConfig._merge_a_into_b(v, b[k])
        elif isinstance(b, list):
            try:
                index = int(k)
            except (TypeError, ValueError) as err:
                # Only conversion failures are translated; anything else propagates.
                raise TypeError(
                    f"b is a list, " f"index {k} should be an int when input but {type(k)}"
                ) from err
            b[index] = SLConfig._merge_a_into_b(v, b[index])
        else:
            b[k] = v

    return b
182
+
183
@staticmethod
def fromfile(filename):
    """Build an ``SLConfig`` from the config file at ``filename``.

    The dict and text come from ``SLConfig._file2dict``; ``filename`` is stored
    on the new instance as given.
    """
    parsed_dict, raw_text = SLConfig._file2dict(filename)
    return SLConfig(parsed_dict, cfg_text=raw_text, filename=filename)
187
+
188
def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
    """Wrap ``cfg_dict`` in a ``ConfigDict`` and remember where it came from.

    Args:
        cfg_dict (dict, optional): Config content; defaults to an empty dict.
        cfg_text (str, optional): Text stored as ``text``. When empty and
            ``filename`` is given, the file content is read instead.
        filename (str, optional): Path stored as ``filename``.

    Raises:
        TypeError: If ``cfg_dict`` is neither ``None`` nor a dict.
        KeyError: If ``cfg_dict`` contains a key listed in ``RESERVED_KEYS``.
    """
    if cfg_dict is None:
        cfg_dict = dict()
    elif not isinstance(cfg_dict, dict):
        raise TypeError(f"cfg_dict must be a dict, but got {type(cfg_dict)}")
    for key in cfg_dict:
        if key in RESERVED_KEYS:
            raise KeyError(f"{key} is reserved for config file")

    # Go through object.__setattr__: this class's own __setattr__ writes into
    # the wrapped config dict instead of onto the instance.
    store = super(SLConfig, self).__setattr__
    store("_cfg_dict", ConfigDict(cfg_dict))
    store("_filename", filename)

    text = ""
    if cfg_text:
        text = cfg_text
    elif filename:
        with open(filename, "r") as cfg_file:
            text = cfg_file.read()
    store("_text", text)
207
+
208
@property
def filename(self):
    """The ``filename`` value given at construction (stored as ``_filename``)."""
    stored_path = self._filename
    return stored_path
211
+
212
@property
def text(self):
    """The config text stored at construction (``_text``)."""
    stored_text = self._text
    return stored_text
215
+
216
@property
def pretty_text(self):
    """Render the config as Python-like source text formatted by yapf.

    The wrapped config dict is converted with ``to_dict()``, serialised by the
    nested helpers below (``key=value`` lines, or ``{'key': value}`` mappings
    when a key is not a valid identifier) and then passed through
    ``FormatCode``.

    Returns:
        str: The formatted text.
    """

    indent = 4

    def _indent(s_, num_spaces):
        # Indent every line except the first by ``num_spaces`` spaces;
        # single-line strings are returned unchanged.
        s = s_.split("\n")
        if len(s) == 1:
            return s_
        first = s.pop(0)
        s = [(num_spaces * " ") + line for line in s]
        s = "\n".join(s)
        s = first + "\n" + s
        return s

    def _format_basic_types(k, v, use_mapping=False):
        # Strings are wrapped in single quotes verbatim (no escaping is applied);
        # everything else goes through str().
        if isinstance(v, str):
            v_str = f"'{v}'"
        else:
            v_str = str(v)

        if use_mapping:
            k_str = f"'{k}'" if isinstance(k, str) else str(k)
            attr_str = f"{k_str}: {v_str}"
        else:
            attr_str = f"{str(k)}={v_str}"
        attr_str = _indent(attr_str, indent)

        return attr_str

    def _format_list(k, v, use_mapping=False):
        # check if all items in the list are dict
        # (all() is also True for an empty list, which then renders as "[\n]")
        if all(isinstance(_, dict) for _ in v):
            v_str = "[\n"
            v_str += "\n".join(
                f"dict({_indent(_format_dict(v_), indent)})," for v_ in v
            ).rstrip(",")
            if use_mapping:
                k_str = f"'{k}'" if isinstance(k, str) else str(k)
                attr_str = f"{k_str}: {v_str}"
            else:
                attr_str = f"{str(k)}={v_str}"
            attr_str = _indent(attr_str, indent) + "]"
        else:
            attr_str = _format_basic_types(k, v, use_mapping)
        return attr_str

    def _contain_invalid_identifier(dict_str):
        # True when at least one key, converted to str, is not a valid identifier.
        contain_invalid_identifier = False
        for key_name in dict_str:
            contain_invalid_identifier |= not str(key_name).isidentifier()
        return contain_invalid_identifier

    def _format_dict(input_dict, outest_level=False):
        r = ""
        s = []

        # Keys that cannot be written as keyword arguments force mapping syntax.
        use_mapping = _contain_invalid_identifier(input_dict)
        if use_mapping:
            r += "{"
        for idx, (k, v) in enumerate(input_dict.items()):
            is_last = idx >= len(input_dict) - 1
            # No trailing comma at the outermost level or after the last entry.
            end = "" if outest_level or is_last else ","
            if isinstance(v, dict):
                v_str = "\n" + _format_dict(v)
                if use_mapping:
                    k_str = f"'{k}'" if isinstance(k, str) else str(k)
                    attr_str = f"{k_str}: dict({v_str}"
                else:
                    attr_str = f"{str(k)}=dict({v_str}"
                attr_str = _indent(attr_str, indent) + ")" + end
            elif isinstance(v, list):
                attr_str = _format_list(k, v, use_mapping) + end
            else:
                attr_str = _format_basic_types(k, v, use_mapping) + end

            s.append(attr_str)
        r += "\n".join(s)
        if use_mapping:
            r += "}"
        return r

    cfg_dict = self._cfg_dict.to_dict()
    text = _format_dict(cfg_dict, outest_level=True)
    # copied from setup.cfg
    yapf_style = dict(
        based_on_style="pep8",
        blank_line_before_nested_class_or_def=True,
        split_before_expression_after_opening_paren=True,
    )
    # FormatCode returns a pair; only the formatted text is kept.
    text, _ = FormatCode(text, style_config=yapf_style, verify=True)

    return text
309
+
310
def __repr__(self):
    """Return ``Config (path: <filename>): <repr of the wrapped dict>``."""
    path = self.filename
    body = self._cfg_dict.__repr__()
    return f"Config (path: {path}): {body}"
312
+
313
def __len__(self):
    """Number of entries in the wrapped config dict."""
    wrapped = self._cfg_dict
    return len(wrapped)
315
+
316
def __getattr__(self, name):
    """Forward attribute lookups that miss on the instance to the wrapped config dict."""
    wrapped = self._cfg_dict
    return getattr(wrapped, name)
328
+
329
def __getitem__(self, name):
    """Item access, delegated to the wrapped config dict."""
    wrapped = self._cfg_dict
    return wrapped.__getitem__(name)
331
+
332
def __setattr__(self, name, value):
    """Set ``name`` on the wrapped config dict; plain dicts are wrapped in ``ConfigDict`` first."""
    stored = ConfigDict(value) if isinstance(value, dict) else value
    self._cfg_dict.__setattr__(name, stored)
336
+
337
def __setitem__(self, name, value):
    """Store ``value`` under ``name`` in the wrapped config dict; plain dicts become ``ConfigDict``."""
    stored = ConfigDict(value) if isinstance(value, dict) else value
    self._cfg_dict.__setitem__(name, stored)
341
+
342
def __iter__(self):
    """Iterate over the wrapped config dict."""
    wrapped = self._cfg_dict
    return iter(wrapped)
344
+
345
def dump(self, file=None):
    """Return ``pretty_text``, or write it to ``file`` when a path is given.

    Args:
        file (str, optional): Destination path, opened in ``"w"`` mode.

    Returns:
        str | None: The text when ``file`` is ``None``; otherwise nothing.
    """
    if file is None:
        return self.pretty_text
    with open(file, "w") as out:
        out.write(self.pretty_text)
352
+
353
def merge_from_dict(self, options):
    """Merge dotted-key options into the wrapped config dict.

    Each key is split on ``"."`` and expanded into nested ``ConfigDict``
    levels, e.g. ``{'model.backbone.depth': 50}`` becomes
    ``{'model': {'backbone': {'depth': 50}}}``. The expanded dict is then
    merged over the current config with ``SLConfig._merge_a_into_b`` and the
    result replaces ``_cfg_dict``.

    Args:
        options (dict): dict of configs to merge from.
    """
    nested = {}
    for dotted_key, value in options.items():
        *parents, leaf = dotted_key.split(".")
        node = nested
        for part in parents:
            node.setdefault(part, ConfigDict())
            node = node[part]
        node[leaf] = value

    # Bypass this class's __getattribute__/__setattr__ routing so the wrapped
    # dict itself is read and replaced.
    current = super(SLConfig, self).__getattribute__("_cfg_dict")
    merged = SLConfig._merge_a_into_b(nested, current)
    super(SLConfig, self).__setattr__("_cfg_dict", merged)
384
+
385
# for multiprocess
def __setstate__(self, state):
    """Restore the instance by re-running ``__init__`` with ``state`` as ``cfg_dict``."""
    # NOTE(review): no matching __getstate__ is visible in this part of the
    # class -- confirm ``state`` really is the plain config dict.
    restored = state
    self.__init__(restored)
388
+
389
def copy(self):
    """Return a new ``SLConfig`` built from ``_cfg_dict.copy()`` (text and filename are not passed on)."""
    duplicated = self._cfg_dict.copy()
    return SLConfig(duplicated)
391
+
392
def deepcopy(self):
    """Return a new ``SLConfig`` built from ``_cfg_dict.deepcopy()`` (text and filename are not passed on)."""
    duplicated = self._cfg_dict.deepcopy()
    return SLConfig(duplicated)
394
+
395
+
396
class DictAction(Action):
    """argparse action that collects ``KEY=VALUE`` arguments into a dict.

    Each argument is split on the first ``=``. A value containing commas
    (``KEY=V1,V2,V3``) becomes a list; every item is converted with
    ``_parse_int_float_bool``.
    """

    @staticmethod
    def _parse_int_float_bool(val):
        """Convert ``val`` to int, float, bool or ``None`` when it looks like one; else return it unchanged."""
        for caster in (int, float):
            try:
                return caster(val)
            except ValueError:
                continue
        lowered = val.lower()
        if lowered in ["true", "false"]:
            return lowered == "true"
        if lowered in ["none", "null"]:
            return None
        return val

    def __call__(self, parser, namespace, values, option_string=None):
        """Parse ``values`` and store the resulting dict on ``namespace`` under ``self.dest``."""
        options = {}
        for kv in values:
            key, raw = kv.split("=", maxsplit=1)
            parsed = [self._parse_int_float_bool(item) for item in raw.split(",")]
            # A single value is stored as a scalar, not a one-element list.
            options[key] = parsed[0] if len(parsed) == 1 else parsed
        setattr(namespace, self.dest, options)
groundingdino/util/slio.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ==========================================================
2
+ # Modified from mmcv
3
+ # ==========================================================
4
+
5
+ import json
6
+ import pickle
7
+ from abc import ABCMeta, abstractmethod
8
+ from pathlib import Path
9
+
10
+ import yaml
11
+
12
+ try:
13
+ from yaml import CLoader as Loader, CDumper as Dumper
14
+ except ImportError:
15
+ from yaml import Loader, Dumper
16
+
17
+
18
+ # ===========================
19
+ # Register handler
20
+ # ===========================
21
+
22
+
23
class BaseFileHandler(metaclass=ABCMeta):
    """Abstract interface for (de)serialising objects in one file format.

    Subclasses implement the three file-object / string primitives; the
    path-based helpers open the file and delegate to them.
    """

    @abstractmethod
    def load_from_fileobj(self, file, **kwargs):
        """Read and return an object from an open file object."""

    @abstractmethod
    def dump_to_fileobj(self, obj, file, **kwargs):
        """Write ``obj`` to an open file object."""

    @abstractmethod
    def dump_to_str(self, obj, **kwargs):
        """Return ``obj`` serialised in memory."""

    def load_from_path(self, filepath, mode="r", **kwargs):
        """Open ``filepath`` with ``mode`` and load from the resulting file object."""
        with open(filepath, mode) as stream:
            loaded = self.load_from_fileobj(stream, **kwargs)
        return loaded

    def dump_to_path(self, obj, filepath, mode="w", **kwargs):
        """Open ``filepath`` with ``mode`` and dump ``obj`` into it."""
        with open(filepath, mode) as stream:
            self.dump_to_fileobj(obj, stream, **kwargs)
43
+
44
+
45
class JsonHandler(BaseFileHandler):
    """File handler for JSON, backed by the standard ``json`` module."""

    def load_from_fileobj(self, file, **kwargs):
        """Parse JSON from an open file object.

        ``**kwargs`` are forwarded to ``json.load``. Accepting them keeps the
        signature in line with the base class, whose ``load_from_path``
        passes its keyword arguments through to this method.
        """
        return json.load(file, **kwargs)

    def dump_to_fileobj(self, obj, file, **kwargs):
        """Write ``obj`` as JSON to an open file object (``**kwargs`` go to ``json.dump``)."""
        json.dump(obj, file, **kwargs)

    def dump_to_str(self, obj, **kwargs):
        """Return ``obj`` serialised as a JSON string (``**kwargs`` go to ``json.dumps``)."""
        return json.dumps(obj, **kwargs)
54
+
55
+
56
class PickleHandler(BaseFileHandler):
    """File handler for pickle files; paths are opened in binary mode.

    Dumping defaults to pickle protocol 2 unless ``protocol`` is passed.
    """

    # NOTE(review): pickle.load executes arbitrary code from the file --
    # only use this handler on trusted data.
    def load_from_fileobj(self, file, **kwargs):
        """Unpickle one object from a binary file object."""
        return pickle.load(file, **kwargs)

    def load_from_path(self, filepath, **kwargs):
        """Unpickle from ``filepath``, opened with mode ``"rb"``."""
        return super().load_from_path(filepath, mode="rb", **kwargs)

    def dump_to_str(self, obj, **kwargs):
        """Return the pickled bytes of ``obj``."""
        options = {"protocol": 2, **kwargs}
        return pickle.dumps(obj, **options)

    def dump_to_fileobj(self, obj, file, **kwargs):
        """Pickle ``obj`` into a binary file object."""
        options = {"protocol": 2, **kwargs}
        pickle.dump(obj, file, **options)

    def dump_to_path(self, obj, filepath, **kwargs):
        """Pickle ``obj`` into ``filepath``, opened with mode ``"wb"``."""
        super().dump_to_path(obj, filepath, mode="wb", **kwargs)
73
+
74
+
75
class YamlHandler(BaseFileHandler):
    """File handler for YAML using the module-level ``Loader`` / ``Dumper`` as defaults."""

    # NOTE(review): the default ``Loader`` is PyYAML's (C)Loader, not
    # SafeLoader -- avoid loading untrusted files with it.
    def load_from_fileobj(self, file, **kwargs):
        """Parse YAML from an open file object; ``Loader`` can be overridden via kwargs."""
        options = {"Loader": Loader, **kwargs}
        return yaml.load(file, **options)

    def dump_to_fileobj(self, obj, file, **kwargs):
        """Write ``obj`` as YAML to an open file object; ``Dumper`` can be overridden via kwargs."""
        options = {"Dumper": Dumper, **kwargs}
        yaml.dump(obj, file, **options)

    def dump_to_str(self, obj, **kwargs):
        """Return ``obj`` serialised as YAML text."""
        options = {"Dumper": Dumper, **kwargs}
        return yaml.dump(obj, **options)
87
+
88
+
89
# Maps a format name / file extension (without the dot) to its handler
# instance; every key gets its own instance.
file_handlers = {
    extension: handler_cls()
    for extension, handler_cls in (
        ("json", JsonHandler),
        ("yaml", YamlHandler),
        ("yml", YamlHandler),
        ("pickle", PickleHandler),
        ("pkl", PickleHandler),
    )
}
96
+
97
+ # ===========================
98
+ # load and dump
99
+ # ===========================
100
+
101
+
102
def is_str(x):
    """Tell whether ``x`` is a ``str`` instance.

    Note: This method is deprecated since python 2 is no longer supported.
    """
    result = isinstance(x, str)
    return result
108
+
109
+
110
def slload(file, file_format=None, **kwargs):
    """Load data from json/yaml/pickle files.

    This method provides a unified api for loading data from serialized files.

    Args:
        file (str or :obj:`Path` or file-like object): Filename or a file-like
            object.
        file_format (str, optional): If not specified, the file format will be
            inferred from the file extension (case-insensitively), otherwise
            use the specified one. Currently supported formats include
            "json", "yaml/yml" and "pickle/pkl".
        **kwargs: Forwarded to the selected handler's load method.

    Returns:
        The content from the file.

    Raises:
        TypeError: If the format is unsupported, or ``file`` is neither a
            path string nor an object with a ``read`` attribute.
    """
    if isinstance(file, Path):
        file = str(file)
    if file_format is None and is_str(file):
        # Lower-case the extension so "CFG.YAML" resolves like "cfg.yaml";
        # SLConfig._file2dict accepts such names via filename.lower().
        file_format = file.split(".")[-1].lower()
    if file_format not in file_handlers:
        raise TypeError(f"Unsupported format: {file_format}")

    handler = file_handlers[file_format]
    if is_str(file):
        obj = handler.load_from_path(file, **kwargs)
    elif hasattr(file, "read"):
        obj = handler.load_from_fileobj(file, **kwargs)
    else:
        raise TypeError('"file" must be a filepath str or a file-object')
    return obj
141
+
142
+
143
def sldump(obj, file=None, file_format=None, **kwargs):
    """Dump data to json/yaml/pickle strings or files.

    This method provides a unified api for dumping data as strings or to files,
    and also supports custom arguments for each file format.

    Args:
        obj (any): The python object to be dumped.
        file (str or :obj:`Path` or file-like object, optional): If not
            specified, then the object is dump to a str, otherwise to a file
            specified by the filename or file-like object.
        file_format (str, optional): Format name. When omitted it is inferred
            (case-insensitively) from the extension of a path-like ``file``.
        **kwargs: Forwarded to the selected handler's dump method.

    Returns:
        The serialised data when ``file`` is None; otherwise ``None``.

    Raises:
        ValueError: If both ``file`` and ``file_format`` are None.
        TypeError: If the format is unsupported, or ``file`` is neither a
            path string nor an object with a ``write`` attribute.
    """
    if isinstance(file, Path):
        file = str(file)
    if file_format is None:
        if is_str(file):
            # Case-insensitive, so "OUT.JSON" resolves to the json handler.
            file_format = file.split(".")[-1].lower()
        elif file is None:
            raise ValueError("file_format must be specified since file is None")
    if file_format not in file_handlers:
        raise TypeError(f"Unsupported format: {file_format}")

    handler = file_handlers[file_format]
    if file is None:
        return handler.dump_to_str(obj, **kwargs)
    elif is_str(file):
        handler.dump_to_path(obj, file, **kwargs)
    elif hasattr(file, "write"):
        handler.dump_to_fileobj(obj, file, **kwargs)
    else:
        raise TypeError('"file" must be a filename str or a file-object')
groundingdino/util/time_counter.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import time
3
+
4
+
5
class TimeCounter:
    """Records the elapsed time between consecutive ``timeit`` calls.

    Attributes are created in ``clear``:
        timedict: maps a name to the seconds elapsed for that step.
        basetime: ``time.perf_counter()`` value the next step is measured from.
    """

    def __init__(self) -> None:
        # Start in a usable state so timeit() works without an explicit clear().
        self.clear()

    def clear(self):
        """Drop all recorded timings and restart the clock."""
        self.timedict = {}
        self.basetime = time.perf_counter()

    def timeit(self, name):
        """Store the time since the previous ``timeit``/``clear`` under ``name`` and restart the clock."""
        nowtime = time.perf_counter() - self.basetime
        self.timedict[name] = nowtime
        self.basetime = time.perf_counter()
17
+
18
+
19
class TimeHolder:
    """Keeps one ``AverageMeter`` per timing name and reports the averages."""

    def __init__(self) -> None:
        self.timedict = {}

    def update(self, _timedict: dict):
        """Feed every ``name -> value`` pair into the meter of that name, creating it on first use."""
        for key, elapsed in _timedict.items():
            if key not in self.timedict:
                self.timedict[key] = AverageMeter(name=key, val_only=True)
            meter = self.timedict[key]
            meter.update(val=elapsed)

    def final_res(self):
        """Return ``{name: average}`` for all meters."""
        averages = {}
        for key, meter in self.timedict.items():
            averages[key] = meter.avg
        return averages

    def __str__(self):
        summary = self.final_res()
        return json.dumps(summary, indent=2)
34
+
35
+
36
class AverageMeter(object):
    """Tracks the latest value and the running (weighted) average of a series."""

    def __init__(self, name, fmt=":f", val_only=False):
        self.name = name
        self.fmt = fmt
        self.val_only = val_only
        self.reset()

    def reset(self):
        """Zero the current value, average, sum and count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as seen ``n`` times and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        # "<name> <val>" and, unless val_only is set, " (<avg>)" -- both
        # numbers rendered with the ``fmt`` format spec.
        pieces = ["{name} {val" + self.fmt + "}"]
        if not self.val_only:
            pieces.append("({avg" + self.fmt + "})")
        return " ".join(pieces).format(**self.__dict__)
groundingdino/util/utils.py ADDED
@@ -0,0 +1,610 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import warnings
4
+ from collections import OrderedDict
5
+ from copy import deepcopy
6
+ from typing import Any, Dict, List
7
+
8
+ import numpy as np
9
+ import torch
10
+ from transformers import AutoTokenizer
11
+
12
+ from groundingdino.util.slconfig import SLConfig
13
+
14
+
15
def slprint(x, name="x"):
    """Recursively print the shapes / types found inside ``x``.

    Tensors and arrays print their shape; tuples and lists print their type and
    recurse into at most the first 10 items; dicts recurse into every value;
    anything else prints its type.
    """
    if isinstance(x, (torch.Tensor, np.ndarray)):
        print(f"{name}.shape:", x.shape)
        return
    if isinstance(x, (tuple, list)):
        print("type x:", type(x))
        for idx in range(min(10, len(x))):
            slprint(x[idx], f"{name}[{idx}]")
        return
    if isinstance(x, dict):
        for key, value in x.items():
            slprint(value, f"{name}[{key}]")
        return
    print(f"{name}.type:", type(x))
27
+
28
+
29
def clean_state_dict(state_dict):
    """Return an ``OrderedDict`` copy of ``state_dict`` with a leading ``module.`` removed from each key."""
    prefix = "module."
    width = len(prefix)
    return OrderedDict(
        (key[width:] if key[:width] == prefix else key, value)
        for key, value in state_dict.items()
    )
36
+
37
+
38
def renorm(
    img: torch.FloatTensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
) -> torch.FloatTensor:
    """Undo a per-channel normalisation: ``img * std + mean``.

    Args:
        img: tensor(3,H,W) or tensor(B,3,H,W).
        mean: Per-channel mean, 3 values.
        std: Per-channel standard deviation, 3 values.

    Returns:
        Tensor with the same shape as ``img``.

    Raises:
        AssertionError: If ``img`` is not 3-D/4-D or its channel dimension
            is not of size 3.
    """
    assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim()
    channel_dim = 0 if img.dim() == 3 else 1
    assert img.size(channel_dim) == 3, 'img.size(%d) shoule be 3 but "%d". (%s)' % (
        channel_dim,
        img.size(channel_dim),
        str(img.size()),
    )
    # Build the statistics on the image's device; creating them on the CPU
    # breaks the arithmetic below for images that live on another device.
    stat_dtype = torch.get_default_dtype()
    mean_t = torch.as_tensor(mean, dtype=stat_dtype, device=img.device)
    std_t = torch.as_tensor(std, dtype=stat_dtype, device=img.device)

    # Move channels last so the (3,) statistics broadcast, then move them back.
    if channel_dim == 0:
        to_last, to_first = (1, 2, 0), (2, 0, 1)
    else:
        to_last, to_first = (0, 2, 3, 1), (0, 3, 1, 2)
    img_res = img.permute(*to_last) * std_t + mean_t
    return img_res.permute(*to_first)
64
+
65
+
66
class CocoClassMapper:
    """Converts between the original (sparse) category ids and compact 0-based indices.

    ``category_map_str`` maps the original id, as a string, to its 1-based
    position in the id list below; the two ``*_mapper`` dicts are the int
    versions shifted to 0-based compact indices.
    """

    def __init__(self) -> None:
        # The 80 original ids in ascending order (the gaps are ids that are not used).
        original_ids = [
            *range(1, 12),
            *range(13, 26),
            27,
            28,
            *range(31, 45),
            *range(46, 66),
            67,
            70,
            *range(72, 83),
            *range(84, 91),
        ]
        self.category_map_str = {
            str(origin_id): position for position, origin_id in enumerate(original_ids, start=1)
        }
        self.origin2compact_mapper = {int(k): v - 1 for k, v in self.category_map_str.items()}
        self.compact2origin_mapper = {int(v - 1): int(k) for k, v in self.category_map_str.items()}

    def origin2compact(self, idx):
        """Original category id -> compact 0-based index."""
        return self.origin2compact_mapper[int(idx)]

    def compact2origin(self, idx):
        """Compact 0-based index -> original category id."""
        return self.compact2origin_mapper[int(idx)]
158
+
159
+
160
def to_device(item, device):
    """Move a tensor, or every tensor inside nested lists / dicts, to ``device``.

    Raises:
        NotImplementedError: For any other container or value type.
    """
    if isinstance(item, torch.Tensor):
        return item.to(device)
    if isinstance(item, list):
        return [to_device(element, device) for element in item]
    if isinstance(item, dict):
        return {key: to_device(value, device) for key, value in item.items()}
    raise NotImplementedError(
        "Call Shilong if you use other containers! type: {}".format(type(item))
    )
171
+
172
+
173
+ #
174
def get_gaussian_mean(x, axis, other_axis, softmax=True):
    """Weighted mean position along ``axis`` in normalised [0, 1] coordinates.

    ``x`` is summed over ``other_axis``; the result is turned into weights
    over dimension 2 (softmax, or division by the sum plus 1e-6) and used to
    average ``linspace(0, 1, x.shape[axis])``.

    Args:
        x (float): Input images(BxCxHxW)
        axis (int): The index for weighted mean
        other_axis (int): The other index
        softmax (bool): Use softmax weights instead of sum-normalised ones.

    Returns: weighted index for axis, BxC
    """
    collapsed = torch.sum(x, axis=other_axis)
    if softmax:
        weights = torch.softmax(collapsed, axis=2)
    else:
        weights = collapsed / (collapsed.sum(2, keepdim=True) + 1e-6)

    n_batch, n_channel = x.shape[0], x.shape[1]
    positions = torch.linspace(0, 1, x.shape[axis]).to(x.device)
    grid = positions.repeat([n_batch, n_channel, 1])
    return torch.sum(grid * weights, dim=2)
198
+
199
+
200
def get_expected_points_from_map(hm, softmax=True):
    """Soft-argmax: expected (x, y) position of every channel of ``hm``.

    B,C,H,W -> B,C,2, each coordinate a float between 0 and 1.

    Args:
        hm (float): Input images(BxCxHxW)
        softmax (bool): Passed through to ``get_gaussian_mean``.

    Returns:
        weighted index for axis, BxCx2. float between 0 and 1.
    """
    # Unpacking fails unless hm has exactly four dimensions.
    _batch, _channels, _height, _width = hm.shape
    # x is the mean along the width axis (3), y along the height axis (2).
    coords = [
        get_gaussian_mean(hm, axis, other_axis, softmax=softmax)
        for axis, other_axis in ((3, 2), (2, 3))
    ]
    return torch.stack(coords, dim=2)
218
+
219
+
220
+ # Positional encoding (section 5.1)
221
+ # borrow from nerf
222
class Embedder:
    """Frequency (sin/cos style) embedding built from keyword settings.

    Required kwargs: ``input_dims``, ``include_input``, ``max_freq_log2``,
    ``num_freqs``, ``log_sampling`` and ``periodic_fns``.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        """Build ``embed_fns`` (list of callables) and ``out_dim`` (total output width)."""
        cfg = self.kwargs
        dims = cfg["input_dims"]
        fns = []
        if cfg["include_input"]:
            fns.append(lambda x: x)

        max_freq = cfg["max_freq_log2"]
        n_freqs = cfg["num_freqs"]
        if cfg["log_sampling"]:
            bands = 2.0 ** torch.linspace(0.0, max_freq, steps=n_freqs)
        else:
            bands = torch.linspace(2.0**0.0, 2.0**max_freq, steps=n_freqs)

        # Default arguments pin the current function / frequency for each lambda.
        fns.extend(
            (lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
            for freq in bands
            for p_fn in cfg["periodic_fns"]
        )

        self.embed_fns = fns
        # Every embedding function contributes ``input_dims`` output columns.
        self.out_dim = dims * len(fns)

    def embed(self, inputs):
        """Apply every embedding function and concatenate the results on the last dim."""
        pieces = [fn(inputs) for fn in self.embed_fns]
        return torch.cat(pieces, -1)
253
+
254
+
255
def get_embedder(multires, i=0):
    """Create an embedding callable for 3-dimensional inputs.

    Args:
        multires (int): Number of frequencies; the highest is ``2 ** (multires - 1)``.
        i (int): ``-1`` returns ``nn.Identity()`` with an output width of 3.

    Returns:
        tuple: ``(embed_fn, out_dim)``.
    """
    import torch.nn as nn

    if i == -1:
        return nn.Identity(), 3

    embedder_obj = Embedder(
        include_input=True,
        input_dims=3,
        max_freq_log2=multires - 1,
        num_freqs=multires,
        log_sampling=True,
        periodic_fns=[torch.sin, torch.cos],
    )
    return (lambda x, eo=embedder_obj: eo.embed(x)), embedder_obj.out_dim
273
+
274
+
275
class APOPMeter:
    """Accumulates binary confusion-matrix counts (tp, fp, tn, fn)."""

    def __init__(self) -> None:
        self.tp = 0
        self.fp = 0
        self.tn = 0
        self.fn = 0

    def update(self, pred, gt):
        """Add the counts of one batch.

        Input:
            pred, gt: Tensor() of identical shape; entries equal to 1 count as
            positive and entries equal to 0 as negative.
        """
        assert pred.shape == gt.shape
        self.tp += torch.logical_and(pred == 1, gt == 1).sum().item()
        self.fp += torch.logical_and(pred == 1, gt == 0).sum().item()
        self.tn += torch.logical_and(pred == 0, gt == 0).sum().item()
        # False negatives: predicted 0 while the ground truth is 1.
        self.fn += torch.logical_and(pred == 0, gt == 1).sum().item()

    def update_cm(self, tp, fp, tn, fn):
        """Add pre-computed confusion-matrix counts."""
        self.tp += tp
        self.fp += fp
        self.tn += tn
        self.fn += fn
298
+
299
+
300
def inverse_sigmoid(x, eps=1e-5):
    """Logit of ``x``: ``log(x / (1 - x))`` with clamping for numerical safety.

    ``x`` is clamped to [0, 1]; numerator and denominator are each clamped to
    at least ``eps`` before the division.
    """
    prob = x.clamp(min=0, max=1)
    numerator = prob.clamp(min=eps)
    denominator = (1 - prob).clamp(min=eps)
    return torch.log(numerator / denominator)
305
+
306
+
307
def get_raw_dict(args):
    """Return the plain dict held by ``args``.

    A ``Namespace`` yields ``vars(args)``, a dict is returned as-is and an
    ``SLConfig`` yields its ``_cfg_dict``.

    e.g:
        >>> with open(path, 'w') as f:
                json.dump(get_raw_dict(args), f, indent=2)

    Raises:
        NotImplementedError: For any other type.
    """
    if isinstance(args, argparse.Namespace):
        return vars(args)
    if isinstance(args, dict):
        return args
    if isinstance(args, SLConfig):
        return args._cfg_dict
    raise NotImplementedError("Unknown type {}".format(type(args)))
323
+
324
+
325
def stat_tensors(tensor):
    """Summary statistics of a 1-D tensor.

    Returns:
        dict: ``max``, ``min``, ``mean``, ``var``, ``std`` (square root of
        ``var``) and ``entropy`` computed on ``softmax(tensor)``.
    """
    assert tensor.dim() == 1
    probs = tensor.softmax(0)
    # NOTE(review): this is sum(p * log(p + 1e-9)) without a leading minus,
    # i.e. the negative of the usual entropy -- confirm the sign is intended.
    entropy = (probs * torch.log(probs + 1e-9)).sum()

    return dict(
        max=tensor.max(),
        min=tensor.min(),
        mean=tensor.mean(),
        var=tensor.var(),
        std=tensor.var() ** 0.5,
        entropy=entropy,
    )
338
+
339
+
340
class NiceRepr:
    """Mixin that derives ``__str__`` and ``__repr__`` from a ``__nice__`` method.

    Subclasses should define ``__nice__``. If they do not but define
    ``__len__``, the default ``__nice__`` returns the length as a string.
    Without either, a ``RuntimeWarning`` is issued and the default object
    representation is used.

    Example:
        >>> class Foo(NiceRepr):
        ...    def __nice__(self):
        ...        return 'info'
        >>> foo = Foo()
        >>> assert str(foo) == '<Foo(info)>'
        >>> assert repr(foo).startswith('<Foo(info) at ')

    Example:
        >>> class Baz(NiceRepr):
        ...    def __len__(self):
        ...        return 5
        >>> baz = Baz()
        >>> assert str(baz) == '<Baz(5)>'
    """

    def __nice__(self):
        """str: a "nice" summary string describing this module"""
        if not hasattr(self, "__len__"):
            # Without a length the subclass has to provide its own __nice__.
            raise NotImplementedError(f"Define the __nice__ method for {self.__class__!r}")
        # Convenience default: objects with a length are summarised by it.
        return str(len(self))

    def __repr__(self):
        """str: the string of the module"""
        try:
            summary = self.__nice__()
        except NotImplementedError as ex:
            warnings.warn(str(ex), category=RuntimeWarning)
            return object.__repr__(self)
        name = self.__class__.__name__
        return f"<{name}({summary}) at {hex(id(self))}>"

    def __str__(self):
        """str: the string of the module"""
        try:
            summary = self.__nice__()
        except NotImplementedError as ex:
            warnings.warn(str(ex), category=RuntimeWarning)
            return object.__repr__(self)
        name = self.__class__.__name__
        return f"<{name}({summary})>"
403
+
404
+
405
def ensure_rng(rng=None):
    """Coerce the input into a random number generator.

    ``None`` gives numpy's global random state, an ``int`` seeds a fresh
    ``RandomState`` and anything else is returned as-is.

    Adapted from [1]_.

    Args:
        rng (int | numpy.random.RandomState | None):
            if None, then defaults to the global rng. Otherwise this can be an
            integer or a RandomState class
    Returns:
        (numpy.random.RandomState) : rng -
            a numpy random number generator

    References:
        .. [1] https://gitlab.kitware.com/computer-vision/kwarray/blob/master/kwarray/util_random.py#L270  # noqa: E501
    """
    if rng is None:
        return np.random.mtrand._rand
    if isinstance(rng, int):
        return np.random.RandomState(rng)
    return rng
434
+
435
+
436
def random_boxes(num=1, scale=1, rng=None):
    """Simple version of ``kwimage.Boxes.random``

    Draws ``num`` x 4 uniform values, orders each coordinate pair so that the
    first corner is the smaller one, and multiplies by ``scale``.

    Returns:
        Tensor: shape (n, 4) in x1, y1, x2, y2 format.

    References:
        https://gitlab.kitware.com/computer-vision/kwimage/blob/master/kwimage/structs/boxes.py#L1390

    Example:
        >>> num = 3
        >>> scale = 512
        >>> rng = 0
        >>> boxes = random_boxes(num, scale, rng)
        >>> print(boxes)
        tensor([[280.9925, 278.9802, 308.6148, 366.1769],
                [216.9113, 330.6978, 224.0446, 456.5878],
                [405.3632, 196.3221, 493.3953, 270.7942]])
    """
    rng = ensure_rng(rng)

    corners = rng.rand(num, 4).astype(np.float32)

    # All four columns are computed from the raw draws before any is overwritten.
    ordered = (
        np.minimum(corners[:, 0], corners[:, 2]),
        np.minimum(corners[:, 1], corners[:, 3]),
        np.maximum(corners[:, 0], corners[:, 2]),
        np.maximum(corners[:, 1], corners[:, 3]),
    )
    for column, values in enumerate(ordered):
        corners[:, column] = values * scale

    return torch.from_numpy(corners)
471
+
472
+
473
class ModelEma(torch.nn.Module):
    """Holds a deep copy of a model whose state follows an exponential moving average.

    Args:
        model: Model to copy; the copy is put in eval mode.
        decay (float): Weight of the previous EMA value in each update.
        device: If set, the copy lives (and is updated) on this device.
    """

    def __init__(self, model, decay=0.9997, device=None):
        super().__init__()
        # make a copy of the model for accumulating moving average of weights
        self.module = deepcopy(model)
        self.module.eval()

        self.decay = decay
        self.device = device  # perform ema on different device from model if set
        if device is not None:
            self.module.to(device=device)

    def _update(self, model, update_fn):
        """Overwrite every EMA state entry with ``update_fn(ema_value, model_value)``."""
        with torch.no_grad():
            ema_values = self.module.state_dict().values()
            source_values = model.state_dict().values()
            for ema_v, model_v in zip(ema_values, source_values):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model):
        """Blend the model state into the EMA: ``decay * ema + (1 - decay) * model``."""

        def blend(e, m):
            return self.decay * e + (1.0 - self.decay) * m

        self._update(model, update_fn=blend)

    def set(self, model):
        """Copy the model state into the EMA unchanged."""

        def take_model(e, m):
            return m

        self._update(model, update_fn=take_model)
501
+
502
+
503
class BestMetricSingle:
    """Remember the best value seen for one metric and the epoch it occurred.

    Args:
        init_res: starting value that a new result has to beat.
        better (str): ``"large"`` if higher is better, ``"small"`` if lower
            is better.
    """

    def __init__(self, init_res=0.0, better="large") -> None:
        self.init_res = init_res
        self.best_res = init_res
        self.best_ep = -1  # -1 means no result has been recorded yet

        self.better = better
        assert better in ["large", "small"]

    def isbetter(self, new_res, old_res):
        """Return whether ``new_res`` strictly improves on ``old_res``."""
        direction = self.better
        if direction == "large":
            return new_res > old_res
        if direction == "small":
            return new_res < old_res
        return None

    def update(self, new_res, ep):
        """Record ``new_res`` if it is a new best; return True when it is."""
        if not self.isbetter(new_res, self.best_res):
            return False
        self.best_res, self.best_ep = new_res, ep
        return True

    def __str__(self) -> str:
        return f"best_res: {self.best_res}\t best_ep: {self.best_ep}"

    def __repr__(self) -> str:
        return self.__str__()

    def summary(self) -> dict:
        """Return the best result and its epoch as a plain dict."""
        return dict(best_res=self.best_res, best_ep=self.best_ep)
536
+
537
+
538
class BestMetricHolder:
    """Track the best metric overall and, optionally, per regular/EMA model.

    Args:
        init_res: starting value handed to every ``BestMetricSingle``.
        better (str): ``"large"`` or ``"small"``, forwarded unchanged.
        use_ema (bool): when True, separate trackers are kept for EMA and
            regular results in addition to the combined one.
    """

    def __init__(self, init_res=0.0, better="large", use_ema=False) -> None:
        self.best_all = BestMetricSingle(init_res, better)
        self.use_ema = use_ema
        if use_ema:
            self.best_ema = BestMetricSingle(init_res, better)
            self.best_regular = BestMetricSingle(init_res, better)

    def update(self, new_res, epoch, is_ema=False):
        """
        Record a result; return whether it is the best across all results.
        """
        if self.use_ema:
            # Update the tracker matching the source first, then the
            # combined tracker, whose verdict is what gets returned.
            tracker = self.best_ema if is_ema else self.best_regular
            tracker.update(new_res, epoch)
        return self.best_all.update(new_res, epoch)

    def summary(self):
        """Return the tracked bests as a dict (prefixed keys when using EMA)."""
        if not self.use_ema:
            return self.best_all.summary()

        res = {}
        sections = (
            ("all", self.best_all),
            ("regular", self.best_regular),
            ("ema", self.best_ema),
        )
        for prefix, tracker in sections:
            for key, value in tracker.summary().items():
                res[f"{prefix}_{key}"] = value
        return res

    def __repr__(self) -> str:
        return json.dumps(self.summary(), indent=2)

    def __str__(self) -> str:
        return self.__repr__()
575
+
576
+
577
def targets_to(targets: List[Dict[str, Any]], device):
    """Moves the target dicts to the given device.

    Values stored under the metadata keys listed below are passed through
    unchanged; every other value has ``.to(device)`` called on it. A new
    list of new dicts is returned.
    """
    passthrough_keys = {
        "questionId",
        "tokens_positive",
        "strings_positive",
        "tokens",
        "dataset_name",
        "sentence_id",
        "original_img_id",
        "nb_eval",
        "task_id",
        "original_id",
        "token_span",
        "caption",
        "dataset_type",
    }
    moved = []
    for target in targets:
        moved.append(
            {
                key: (value if key in passthrough_keys else value.to(device))
                for key, value in target.items()
            }
        )
    return moved
597
+
598
+
599
def get_phrases_from_posmap(
    posmap: torch.BoolTensor, tokenized: Dict, tokenizer: "AutoTokenizer", left_idx: int = 0, right_idx: int = 255
):
    """Decode the tokens selected by a 1-D positive map into text.

    Positions ``0..left_idx`` (inclusive) and ``right_idx..end`` are ignored;
    the remaining non-zero positions index into ``tokenized["input_ids"]``
    and the selected ids are passed to ``tokenizer.decode``.

    The caller's ``posmap`` is left untouched: masking is done on a copy.
    (Previously the boundary positions were cleared in place on the tensor
    that was passed in.)

    Args:
        posmap: 1-D tensor whose non-zero entries mark selected tokens.
        tokenized: mapping with an ``"input_ids"`` sequence.
        tokenizer: object providing ``decode(token_ids)``.
        left_idx: last leading position to ignore.
        right_idx: first trailing position to ignore.

    Returns:
        Whatever ``tokenizer.decode`` returns for the selected token ids.

    Raises:
        AssertionError: if ``posmap`` is not a ``torch.Tensor``.
        NotImplementedError: if ``posmap`` is not 1-dimensional.
    """
    assert isinstance(posmap, torch.Tensor), "posmap must be torch.Tensor"
    if posmap.dim() != 1:
        raise NotImplementedError("posmap must be 1-dim")

    # Clone so that masking the boundaries does not alter the caller's tensor.
    masked = posmap.clone()
    masked[0 : left_idx + 1] = False
    masked[right_idx:] = False

    non_zero_idx = masked.nonzero(as_tuple=True)[0].tolist()
    token_ids = [tokenized["input_ids"][i] for i in non_zero_idx]
    return tokenizer.decode(token_ids)
groundingdino/util/visualizer.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ @File : visualizer.py
4
+ @Time : 2022/04/05 11:39:33
5
+ @Author : Shilong Liu
6
+ @Contact : [email protected]
7
+ """
8
+
9
+ import datetime
10
+ import os
11
+
12
+ import cv2
13
+ import matplotlib.pyplot as plt
14
+ import numpy as np
15
+ import torch
16
+ from matplotlib import transforms
17
+ from matplotlib.collections import PatchCollection
18
+ from matplotlib.patches import Polygon
19
+ from pycocotools import mask as maskUtils
20
+
21
+
22
def renorm(
    img: torch.FloatTensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
) -> torch.FloatTensor:
    """Undo per-channel normalisation: ``img * std + mean``.

    The statistics are created on ``img.device`` so the function also works
    for images that are not on the CPU (building them with
    ``torch.Tensor(...)`` always yields CPU tensors).

    Args:
        img: tensor(3, H, W) or tensor(B, 3, H, W).
        mean: per-channel mean added back after scaling.
        std: per-channel standard deviation used for scaling.

    Returns:
        Tensor with the same shape as ``img``.

    Raises:
        AssertionError: if ``img`` is not 3-D/4-D or its channel axis is not
            of size 3.
    """
    assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim()
    # The channel axis is 0 for (3, H, W) and 1 for (B, 3, H, W).
    channel_axis = img.dim() - 3
    assert img.size(channel_axis) == 3, 'img.size(%d) shoule be 3 but "%d". (%s)' % (
        channel_axis,
        img.size(channel_axis),
        str(img.size()),
    )

    stat_dtype = torch.get_default_dtype()
    # Shape (C, 1, 1) broadcasts over H and W for both 3-D and 4-D inputs.
    mean_t = torch.as_tensor(mean, dtype=stat_dtype, device=img.device).reshape(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=stat_dtype, device=img.device).reshape(-1, 1, 1)
    return img * std_t + mean_t
48
+
49
+
50
class ColorMap:
    """Turn a uint8 attention map into an image with a fixed base colour.

    The base colour fills the leading channels of every pixel and the
    attention value is appended as the last channel.
    """

    def __init__(self, basergb=[255, 255, 0]):
        self.basergb = np.array(basergb)

    def __call__(self, attnmap):
        # attnmap: h, w. np.uint8.
        # return: h, w, 4. np.uint8.
        assert attnmap.dtype == np.uint8
        h, w = attnmap.shape
        # Repeat the base colour for every pixel: h, w, 3
        colour_plane = np.broadcast_to(self.basergb, (h, w) + self.basergb.shape)
        # Attention values become the trailing channel: h, w, 1
        alpha_plane = attnmap[..., None]
        stacked = np.concatenate((colour_plane, alpha_plane), axis=-1)
        return stacked.astype(np.uint8)
64
+
65
+
66
def rainbow_text(x, y, ls, lc, **kw):
    """
    Place the strings ``ls`` side by side on the current axes, drawing
    ``ls[i]`` in colour ``lc[i]``.

    Only the horizontal layout is implemented. All extra keyword arguments
    are passed to ``plt.text``, so you can set the font size, family, etc.
    """
    figure = plt.gcf()
    current_transform = plt.gca().transData
    plt.show()

    # Horizontal layout: after drawing each word, shift the transform right
    # by that word's rendered width so the next word starts where it ended.
    for word, colour in zip(ls, lc):
        label = plt.text(x, y, " " + word + " ", color=colour, transform=current_transform, **kw)
        label.draw(figure.canvas.get_renderer())
        extent = label.get_window_extent()
        current_transform = transforms.offset_copy(
            label._transform, x=extent.width, units="dots"
        )

    # NOTE: a vertical layout would pass rotation=90, va='bottom', ha='center'
    # to plt.text and offset by y=extent.height instead of x=extent.width.
93
+
94
+
95
class COCOVisualizer:
    """Draw images with boxes, labels, captions and attention overlays.

    Args:
        coco: optional COCO-style API object; ``showAnns`` reads ``imgs``
            and calls ``loadCats`` on it for mask and keypoint annotations.
        tokenlizer: accepted but not used by this class.
    """

    def __init__(self, coco=None, tokenlizer=None) -> None:
        self.coco = coco

    def visualize(self, img, tgt, caption=None, dpi=180, savedir="vis"):
        """
        Render ``img`` with the annotations in ``tgt`` and save it as a PNG.

        img: tensor(3, H, W)
        tgt: make sure they are all on cpu.
            must have items: 'image_id', 'boxes', 'size'
            (``None`` or a dict without 'boxes' only draws the image, plus
            the caption title when present).
        caption: optional prefix for the output file name.
        dpi: resolution of the created figure.
        savedir: directory that receives the image; created when missing.
        """
        plt.figure(dpi=dpi)
        plt.rcParams["font.size"] = "5"
        ax = plt.gca()
        img = renorm(img).permute(1, 2, 0)
        ax.imshow(img)

        self.addtgt(tgt)

        if tgt is None or "image_id" not in tgt:
            image_id = 0
        else:
            image_id = tgt["image_id"]

        # The timestamp keeps successive renders of one image from
        # overwriting each other.
        timestamp = str(datetime.datetime.now()).replace(" ", "-")
        if caption is None:
            savename = "{}/{}-{}.png".format(savedir, int(image_id), timestamp)
        else:
            savename = "{}/{}-{}-{}.png".format(savedir, caption, int(image_id), timestamp)
        print("savename: {}".format(savename))
        os.makedirs(os.path.dirname(savename), exist_ok=True)
        plt.savefig(savename)
        plt.close()

    @staticmethod
    def _label_box(ax, box, text, facecolor):
        """Write ``text`` at the top-left corner of ``box`` (x, y, w, h)."""
        bbox_x, bbox_y, _bbox_w, _bbox_h = box
        ax.text(
            bbox_x,
            bbox_y,
            text,
            color="black",
            bbox={"facecolor": facecolor, "alpha": 0.6, "pad": 1},
        )

    def addtgt(self, tgt):
        """Draw the content of ``tgt`` onto the current matplotlib axes.

        Boxes in ``tgt["boxes"]`` are read as normalised (cx, cy, w, h) and
        scaled by ``tgt["size"]`` (H, W). Optional keys: 'strings_positive'
        (with 'labels'), 'box_label', 'caption' and 'attn'.

        ``tgt`` may be ``None``; in that case only the axes are hidden.
        The passed dict is not modified.
        """
        ax = plt.gca()

        if tgt is None or "boxes" not in tgt:
            # Nothing to draw except an optional title. The explicit None
            # check avoids evaluating `"caption" in None`.
            if tgt is not None and "caption" in tgt:
                ax.set_title(tgt["caption"], wrap=True)

            ax.set_axis_off()
            return

        H, W = tgt["size"]
        numbox = tgt["boxes"].shape[0]

        color = []
        polygons = []
        boxes = []
        for box in tgt["boxes"].cpu():
            # Scale to pixels, then turn the centre into the top-left corner.
            unnormbbox = box * torch.Tensor([W, H, W, H])
            unnormbbox[:2] -= unnormbbox[2:] / 2
            [bbox_x, bbox_y, bbox_w, bbox_h] = unnormbbox.tolist()
            boxes.append([bbox_x, bbox_y, bbox_w, bbox_h])
            poly = [
                [bbox_x, bbox_y],
                [bbox_x, bbox_y + bbox_h],
                [bbox_x + bbox_w, bbox_y + bbox_h],
                [bbox_x + bbox_w, bbox_y],
            ]
            np_poly = np.array(poly).reshape((4, 2))
            polygons.append(Polygon(np_poly))
            c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]
            color.append(c)

        # Faint fill first, then a solid outline in the same colours.
        p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.1)
        ax.add_collection(p)
        p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2)
        ax.add_collection(p)

        if "strings_positive" in tgt and len(tgt["strings_positive"]) > 0:
            assert (
                len(tgt["strings_positive"]) == numbox
            ), f"{len(tgt['strings_positive'])} = {numbox}, "
            for idx, strlist in enumerate(tgt["strings_positive"]):
                cate_id = int(tgt["labels"][idx])
                _string = str(cate_id) + ":" + " ".join(strlist)
                self._label_box(ax, boxes[idx], _string, color[idx])

        if "box_label" in tgt:
            assert len(tgt["box_label"]) == numbox, f"{len(tgt['box_label'])} = {numbox}, "
            for idx, bl in enumerate(tgt["box_label"]):
                self._label_box(ax, boxes[idx], str(bl), color[idx])

        if "caption" in tgt:
            ax.set_title(tgt["caption"], wrap=True)

        if "attn" in tgt:
            # Accept a single (attn_map, basergb) tuple or a list of them,
            # without rewriting the caller's dict.
            attn_items = tgt["attn"]
            if isinstance(attn_items, tuple):
                attn_items = [attn_items]
            for attn_map, basergb in attn_items:
                # Min-max normalise to 0..255; the 1e-3 avoids division by
                # zero for constant maps.
                attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-3)
                attn_map = (attn_map * 255).astype(np.uint8)
                cm = ColorMap(basergb)
                heatmap = cm(attn_map)
                ax.imshow(heatmap)
        ax.set_axis_off()

    def showAnns(self, anns, draw_bbox=False):
        """
        Display the specified annotations.

        Image metadata and category skeletons are looked up on ``self.coco``
        (this class defines no ``imgs`` / ``loadCats`` of its own), so mask
        and keypoint annotations require ``coco`` to have been provided.

        :param anns (array of object): annotations to display
        :param draw_bbox (bool): also outline each annotation's 'bbox'
        :return: 0 when ``anns`` is empty, otherwise None
        """
        if len(anns) == 0:
            return 0
        if "segmentation" in anns[0] or "keypoints" in anns[0]:
            datasetType = "instances"
        elif "caption" in anns[0]:
            datasetType = "captions"
        else:
            raise Exception("datasetType not supported")
        if datasetType == "instances":
            ax = plt.gca()
            ax.set_autoscale_on(False)
            polygons = []
            color = []
            for ann in anns:
                c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]
                if "segmentation" in ann:
                    if isinstance(ann["segmentation"], list):
                        # polygon
                        for seg in ann["segmentation"]:
                            poly = np.array(seg).reshape((int(len(seg) / 2), 2))
                            polygons.append(Polygon(poly))
                            color.append(c)
                    else:
                        # mask
                        t = self.coco.imgs[ann["image_id"]]
                        if isinstance(ann["segmentation"]["counts"], list):
                            rle = maskUtils.frPyObjects(
                                [ann["segmentation"]], t["height"], t["width"]
                            )
                        else:
                            rle = [ann["segmentation"]]
                        m = maskUtils.decode(rle)
                        img = np.ones((m.shape[0], m.shape[1], 3))
                        if ann["iscrowd"] == 1:
                            color_mask = np.array([2.0, 166.0, 101.0]) / 255
                        if ann["iscrowd"] == 0:
                            color_mask = np.random.random((1, 3)).tolist()[0]
                        for i in range(3):
                            img[:, :, i] = color_mask[i]
                        ax.imshow(np.dstack((img, m * 0.5)))
                if "keypoints" in ann and isinstance(ann["keypoints"], list):
                    # turn skeleton into zero-based index
                    sks = np.array(self.coco.loadCats(ann["category_id"])[0]["skeleton"]) - 1
                    kp = np.array(ann["keypoints"])
                    # keypoints are stored flat as (x, y, v) triples
                    x = kp[0::3]
                    y = kp[1::3]
                    v = kp[2::3]
                    for sk in sks:
                        if np.all(v[sk] > 0):
                            plt.plot(x[sk], y[sk], linewidth=3, color=c)
                    plt.plot(
                        x[v > 0],
                        y[v > 0],
                        "o",
                        markersize=8,
                        markerfacecolor=c,
                        markeredgecolor="k",
                        markeredgewidth=2,
                    )
                    plt.plot(
                        x[v > 1],
                        y[v > 1],
                        "o",
                        markersize=8,
                        markerfacecolor=c,
                        markeredgecolor=c,
                        markeredgewidth=2,
                    )

                if draw_bbox:
                    [bbox_x, bbox_y, bbox_w, bbox_h] = ann["bbox"]
                    poly = [
                        [bbox_x, bbox_y],
                        [bbox_x, bbox_y + bbox_h],
                        [bbox_x + bbox_w, bbox_y + bbox_h],
                        [bbox_x + bbox_w, bbox_y],
                    ]
                    np_poly = np.array(poly).reshape((4, 2))
                    polygons.append(Polygon(np_poly))
                    color.append(c)

            p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2)
            ax.add_collection(p)
        elif datasetType == "captions":
            for ann in anns:
                print(ann["caption"])
groundingdino/util/vl_utils.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from typing import List
4
+
5
+ import torch
6
+
7
+
8
def _char_to_token_fallback(tokenized, char_indices):
    """Return the first non-None token index for ``char_indices``, else None.

    Ordinary lookup errors are treated as "no token found"; interpreter-level
    exceptions such as ``KeyboardInterrupt`` are allowed to propagate.
    """
    try:
        for char_idx in char_indices:
            token_pos = tokenized.char_to_token(char_idx)
            if token_pos is not None:
                return token_pos
    except Exception:
        return None
    return None


def create_positive_map_from_span(tokenized, token_span, max_text_len=256):
    """construct a map such that positive_map[i,j] is non-zero iff box i is associated to token j

    Input:
        - tokenized: object providing ``char_to_token(char_index)``, which
          returns a token index or None.
        - token_span: list with length num_boxes.
            - each item: list of [start_char, end_char] pairs (end exclusive)
        - max_text_len: number of columns of the returned map.

    Spans whose start or end cannot be mapped to a token - even after
    nudging the start forward or the end backward by up to two characters -
    are skipped.

    If the environment variable ``SHILONG_DEBUG_ONLY_ONE_POS`` equals
    ``"TRUE"``, only the first token of the first resolvable span is marked
    for each box.

    Return:
        Float tensor of shape (num_boxes, max_text_len); every row is
        divided by its sum (plus 1e-6), so rows without any span stay zero.
    """
    positive_map = torch.zeros((len(token_span), max_text_len), dtype=torch.float)
    only_first_token = os.environ.get("SHILONG_DEBUG_ONLY_ONE_POS", None) == "TRUE"

    for j, tok_list in enumerate(token_span):
        for (beg, end) in tok_list:
            beg_pos = tokenized.char_to_token(beg)
            end_pos = tokenized.char_to_token(end - 1)
            if beg_pos is None:
                # The span may start on a character with no token; try the
                # next two characters.
                beg_pos = _char_to_token_fallback(tokenized, (beg + 1, beg + 2))
            if end_pos is None:
                # Likewise step back from the end of the span.
                end_pos = _char_to_token_fallback(tokenized, (end - 2, end - 3))
            if beg_pos is None or end_pos is None:
                continue

            if only_first_token:
                positive_map[j, beg_pos] = 1
                break
            positive_map[j, beg_pos : end_pos + 1].fill_(1)

    return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)
47
+
48
+
49
def build_captions_and_token_span(cat_list, force_lowercase):
    """
    Join category names into one caption and record each word's char span.

    Every category contributes its space-separated words followed by " .".
    A name containing "/" is replaced by a random choice among its
    "/"-separated parts and the full name.

    Return:
        captions: str
        cat2tokenspan: dict
            {
                'dog': [[0, 2]],
                ...
            }
            keyed by the name that was actually written into the caption;
            each value lists [start, end) character offsets, one per word.
    """
    span_by_name = {}
    caption = ""
    for raw_name in cat_list:
        name = raw_name.lower() if force_lowercase else raw_name
        if "/" in name:
            candidates = name.strip().split("/")
            candidates.append(name)
            name = random.choice(candidates)

        word_spans = []
        for word in (piece.strip() for piece in name.strip().split(" ")):
            if not word:
                continue
            if caption:
                caption += " "
            start = len(caption)
            word_spans.append([start, start + len(word)])
            caption += word

        if word_spans:
            caption += " ."
        span_by_name[name] = word_spans

    return caption, span_by_name
88
+
89
+
90
def build_id2posspan_and_caption(category_dict: dict):
    """Build id2pos_span and caption from category_dict

    Args:
        category_dict (dict): iterable of category records, each providing
            an "id" and a "name" entry.

    Returns:
        tuple: ``(id2posspan, caption)`` where ``id2posspan`` maps each
        category id to the span list that ``build_captions_and_token_span``
        stored under the lower-cased category name.
    """
    lowered_names = [entry["name"].lower() for entry in category_dict]
    name_by_id = {entry["id"]: entry["name"].lower() for entry in category_dict}

    caption, spans_by_name = build_captions_and_token_span(lowered_names, force_lowercase=True)

    id2posspan = {}
    for cat_id, cat_name in name_by_id.items():
        id2posspan[cat_id] = spans_by_name[cat_name]
    return id2posspan, caption
run.py CHANGED
@@ -1,5 +1,6 @@
1
  import argparse
2
  <<<<<<< HEAD
 
3
  from functools import partial
4
  import cv2
5
  import requests
@@ -125,6 +126,8 @@ if __name__ == "__main__":
125
  block.launch(server_name='0.0.0.0', server_port=7579, debug=args.debug, share=args.share)
126
 
127
  =======
 
 
128
  import os
129
  import numpy as np
130
  import torch
@@ -337,4 +340,7 @@ if __name__ == "__main__":
337
  save_path = os.path.join(output_dir, "pred.jpg")
338
  image_with_box.save(save_path)
339
  print(f"\n======================\n{save_path} saved.\nThe program runs successfully!")
 
 
 
340
  >>>>>>> e7662d3789ee2d5b878c7399e1f04cb075927919
 
1
  import argparse
2
  <<<<<<< HEAD
3
+ <<<<<<< HEAD
4
  from functools import partial
5
  import cv2
6
  import requests
 
126
  block.launch(server_name='0.0.0.0', server_port=7579, debug=args.debug, share=args.share)
127
 
128
  =======
129
+ =======
130
+ >>>>>>> e7662d3789ee2d5b878c7399e1f04cb075927919
131
  import os
132
  import numpy as np
133
  import torch
 
340
  save_path = os.path.join(output_dir, "pred.jpg")
341
  image_with_box.save(save_path)
342
  print(f"\n======================\n{save_path} saved.\nThe program runs successfully!")
343
+ <<<<<<< HEAD
344
+ >>>>>>> e7662d3789ee2d5b878c7399e1f04cb075927919
345
+ =======
346
  >>>>>>> e7662d3789ee2d5b878c7399e1f04cb075927919