Hasanmog commited on
Commit
4c356a6
·
1 Parent(s): 676fc6d
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. groundingdino/.ipynb_checkpoints/__init__-checkpoint.py +0 -0
  2. groundingdino/_C.cpython-310-x86_64-linux-gnu.so +3 -0
  3. groundingdino/__init__.py +0 -0
  4. groundingdino/__pycache__/__init__.cpython-310.pyc +0 -0
  5. groundingdino/config/.ipynb_checkpoints/GroundingDINO_SwinB_cfg-checkpoint.py +43 -0
  6. groundingdino/config/GroundingDINO_SwinB_cfg.py +43 -0
  7. groundingdino/config/GroundingDINO_SwinT_OGC.py +43 -0
  8. groundingdino/config/__init__.py +0 -0
  9. groundingdino/datasets/__init__.py +0 -0
  10. groundingdino/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  11. groundingdino/datasets/__pycache__/transforms.cpython-310.pyc +0 -0
  12. groundingdino/datasets/cocogrounding_eval.py +269 -0
  13. groundingdino/datasets/transforms.py +311 -0
  14. groundingdino/models/.ipynb_checkpoints/__init__-checkpoint.py +18 -0
  15. groundingdino/models/.ipynb_checkpoints/registry-checkpoint.py +66 -0
  16. groundingdino/models/GroundingDINO/.ipynb_checkpoints/fuse_modules-checkpoint.py +298 -0
  17. groundingdino/models/GroundingDINO/.ipynb_checkpoints/groundingdino-checkpoint.py +412 -0
  18. groundingdino/models/GroundingDINO/.ipynb_checkpoints/ms_deform_attn-checkpoint.py +414 -0
  19. groundingdino/models/GroundingDINO/.ipynb_checkpoints/transformer-checkpoint.py +961 -0
  20. groundingdino/models/GroundingDINO/.ipynb_checkpoints/transformer_vanilla-checkpoint.py +124 -0
  21. groundingdino/models/GroundingDINO/.ipynb_checkpoints/utils-checkpoint.py +269 -0
  22. groundingdino/models/GroundingDINO/__pycache__/__init__.cpython-310.pyc +0 -0
  23. groundingdino/models/GroundingDINO/__pycache__/bertwarper.cpython-310.pyc +0 -0
  24. groundingdino/models/GroundingDINO/__pycache__/fuse_modules.cpython-310.pyc +0 -0
  25. groundingdino/models/GroundingDINO/__pycache__/groundingdino.cpython-310.pyc +0 -0
  26. groundingdino/models/GroundingDINO/__pycache__/ms_deform_attn.cpython-310.pyc +0 -0
  27. groundingdino/models/GroundingDINO/__pycache__/transformer.cpython-310.pyc +0 -0
  28. groundingdino/models/GroundingDINO/__pycache__/transformer_vanilla.cpython-310.pyc +0 -0
  29. groundingdino/models/GroundingDINO/__pycache__/utils.cpython-310.pyc +0 -0
  30. groundingdino/models/GroundingDINO/backbone/.ipynb_checkpoints/backbone-checkpoint.py +220 -0
  31. groundingdino/models/GroundingDINO/backbone/.ipynb_checkpoints/position_encoding-checkpoint.py +186 -0
  32. groundingdino/models/GroundingDINO/backbone/.ipynb_checkpoints/swin_transformer-checkpoint.py +804 -0
  33. groundingdino/models/GroundingDINO/backbone/__pycache__/__init__.cpython-310.pyc +0 -0
  34. groundingdino/models/GroundingDINO/backbone/__pycache__/backbone.cpython-310.pyc +0 -0
  35. groundingdino/models/GroundingDINO/backbone/__pycache__/position_encoding.cpython-310.pyc +0 -0
  36. groundingdino/models/GroundingDINO/backbone/__pycache__/swin_transformer.cpython-310.pyc +0 -0
  37. groundingdino/models/GroundingDINO/backbone/backbone.py +220 -0
  38. groundingdino/models/GroundingDINO/backbone/position_encoding.py +186 -0
  39. groundingdino/models/GroundingDINO/backbone/swin_transformer.py +804 -0
  40. groundingdino/models/GroundingDINO/fuse_modules.py +298 -0
  41. groundingdino/models/GroundingDINO/groundingdino.py +412 -0
  42. groundingdino/models/GroundingDINO/ms_deform_attn.py +414 -0
  43. groundingdino/models/GroundingDINO/transformer.py +961 -0
  44. groundingdino/models/GroundingDINO/transformer_vanilla.py +124 -0
  45. groundingdino/models/GroundingDINO/utils.py +269 -0
  46. groundingdino/models/__init__.py +18 -0
  47. groundingdino/models/__pycache__/__init__.cpython-310.pyc +0 -0
  48. groundingdino/models/__pycache__/registry.cpython-310.pyc +0 -0
  49. groundingdino/models/registry.py +66 -0
  50. groundingdino/util/__pycache__/__init__.cpython-310.pyc +0 -0
groundingdino/.ipynb_checkpoints/__init__-checkpoint.py ADDED
File without changes
groundingdino/_C.cpython-310-x86_64-linux-gnu.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee590993174e3c9b9b4110d8c90aee2873edec5a61da94c577d381b320791520
3
+ size 9940696
groundingdino/__init__.py ADDED
File without changes
groundingdino/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (182 Bytes). View file
 
groundingdino/config/.ipynb_checkpoints/GroundingDINO_SwinB_cfg-checkpoint.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ batch_size = 1
2
+ modelname = "groundingdino"
3
+ backbone = "swin_B_384_22k"
4
+ position_embedding = "sine"
5
+ pe_temperatureH = 20
6
+ pe_temperatureW = 20
7
+ return_interm_indices = [1, 2, 3]
8
+ backbone_freeze_keywords = None
9
+ enc_layers = 6
10
+ dec_layers = 6
11
+ pre_norm = False
12
+ dim_feedforward = 2048
13
+ hidden_dim = 256
14
+ dropout = 0.0
15
+ nheads = 8
16
+ num_queries = 900
17
+ query_dim = 4
18
+ num_patterns = 0
19
+ num_feature_levels = 4
20
+ enc_n_points = 4
21
+ dec_n_points = 4
22
+ two_stage_type = "standard"
23
+ two_stage_bbox_embed_share = False
24
+ two_stage_class_embed_share = False
25
+ transformer_activation = "relu"
26
+ dec_pred_bbox_embed_share = True
27
+ dn_box_noise_scale = 1.0
28
+ dn_label_noise_ratio = 0.5
29
+ dn_label_coef = 1.0
30
+ dn_bbox_coef = 1.0
31
+ embed_init_tgt = True
32
+ dn_labelbook_size = 2000
33
+ max_text_len = 256
34
+ text_encoder_type = "bert-base-uncased"
35
+ use_text_enhancer = True
36
+ use_fusion_layer = True
37
+ use_checkpoint = True
38
+ use_transformer_ckpt = True
39
+ use_text_cross_attention = True
40
+ text_dropout = 0.0
41
+ fusion_dropout = 0.0
42
+ fusion_droppath = 0.1
43
+ sub_sentence_present = True
groundingdino/config/GroundingDINO_SwinB_cfg.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ batch_size = 1
2
+ modelname = "groundingdino"
3
+ backbone = "swin_B_384_22k"
4
+ position_embedding = "sine"
5
+ pe_temperatureH = 20
6
+ pe_temperatureW = 20
7
+ return_interm_indices = [1, 2, 3]
8
+ backbone_freeze_keywords = None
9
+ enc_layers = 6
10
+ dec_layers = 6
11
+ pre_norm = False
12
+ dim_feedforward = 2048
13
+ hidden_dim = 256
14
+ dropout = 0.0
15
+ nheads = 8
16
+ num_queries = 900
17
+ query_dim = 4
18
+ num_patterns = 0
19
+ num_feature_levels = 4
20
+ enc_n_points = 4
21
+ dec_n_points = 4
22
+ two_stage_type = "standard"
23
+ two_stage_bbox_embed_share = False
24
+ two_stage_class_embed_share = False
25
+ transformer_activation = "relu"
26
+ dec_pred_bbox_embed_share = True
27
+ dn_box_noise_scale = 1.0
28
+ dn_label_noise_ratio = 0.5
29
+ dn_label_coef = 1.0
30
+ dn_bbox_coef = 1.0
31
+ embed_init_tgt = True
32
+ dn_labelbook_size = 2000
33
+ max_text_len = 256
34
+ text_encoder_type = "bert-base-uncased"
35
+ use_text_enhancer = True
36
+ use_fusion_layer = True
37
+ use_checkpoint = True
38
+ use_transformer_ckpt = True
39
+ use_text_cross_attention = True
40
+ text_dropout = 0.0
41
+ fusion_dropout = 0.0
42
+ fusion_droppath = 0.1
43
+ sub_sentence_present = True
groundingdino/config/GroundingDINO_SwinT_OGC.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ batch_size = 1
2
+ modelname = "groundingdino"
3
+ backbone = "swin_T_224_1k"
4
+ position_embedding = "sine"
5
+ pe_temperatureH = 20
6
+ pe_temperatureW = 20
7
+ return_interm_indices = [1, 2, 3]
8
+ backbone_freeze_keywords = None
9
+ enc_layers = 6
10
+ dec_layers = 6
11
+ pre_norm = False
12
+ dim_feedforward = 2048
13
+ hidden_dim = 256
14
+ dropout = 0.0
15
+ nheads = 8
16
+ num_queries = 900
17
+ query_dim = 4
18
+ num_patterns = 0
19
+ num_feature_levels = 4
20
+ enc_n_points = 4
21
+ dec_n_points = 4
22
+ two_stage_type = "standard"
23
+ two_stage_bbox_embed_share = False
24
+ two_stage_class_embed_share = False
25
+ transformer_activation = "relu"
26
+ dec_pred_bbox_embed_share = True
27
+ dn_box_noise_scale = 1.0
28
+ dn_label_noise_ratio = 0.5
29
+ dn_label_coef = 1.0
30
+ dn_bbox_coef = 1.0
31
+ embed_init_tgt = True
32
+ dn_labelbook_size = 2000
33
+ max_text_len = 256
34
+ text_encoder_type = "bert-base-uncased"
35
+ use_text_enhancer = True
36
+ use_fusion_layer = True
37
+ use_checkpoint = True
38
+ use_transformer_ckpt = True
39
+ use_text_cross_attention = True
40
+ text_dropout = 0.0
41
+ fusion_dropout = 0.0
42
+ fusion_droppath = 0.1
43
+ sub_sentence_present = True
groundingdino/config/__init__.py ADDED
File without changes
groundingdino/datasets/__init__.py ADDED
File without changes
groundingdino/datasets/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (191 Bytes). View file
 
groundingdino/datasets/__pycache__/transforms.cpython-310.pyc ADDED
Binary file (10.2 kB). View file
 
groundingdino/datasets/cocogrounding_eval.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO. Midified by Shilong Liu.
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
8
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
9
+ """
10
+ COCO evaluator that works in distributed mode.
11
+
12
+ Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py
13
+ The difference is that there is less copy-pasting from pycocotools
14
+ in the end of the file, as python3 can suppress prints with contextlib
15
+ """
16
+ import contextlib
17
+ import copy
18
+ import os
19
+
20
+ import numpy as np
21
+ import pycocotools.mask as mask_util
22
+ import torch
23
+ from pycocotools.coco import COCO
24
+ from pycocotools.cocoeval import COCOeval
25
+
26
+ from groundingdino.util.misc import all_gather
27
+
28
+
29
+ class CocoGroundingEvaluator(object):
30
+ def __init__(self, coco_gt, iou_types, useCats=True):
31
+ assert isinstance(iou_types, (list, tuple))
32
+ coco_gt = copy.deepcopy(coco_gt)
33
+ self.coco_gt = coco_gt
34
+
35
+ self.iou_types = iou_types
36
+ self.coco_eval = {}
37
+ for iou_type in iou_types:
38
+ self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)
39
+ self.coco_eval[iou_type].useCats = useCats
40
+
41
+ self.img_ids = []
42
+ self.eval_imgs = {k: [] for k in iou_types}
43
+ self.useCats = useCats
44
+
45
+ def update(self, predictions):
46
+ img_ids = list(np.unique(list(predictions.keys())))
47
+ self.img_ids.extend(img_ids)
48
+
49
+ for iou_type in self.iou_types:
50
+ results = self.prepare(predictions, iou_type)
51
+
52
+ # suppress pycocotools prints
53
+ with open(os.devnull, "w") as devnull:
54
+ with contextlib.redirect_stdout(devnull):
55
+ coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
56
+
57
+ coco_eval = self.coco_eval[iou_type]
58
+
59
+ coco_eval.cocoDt = coco_dt
60
+ coco_eval.params.imgIds = list(img_ids)
61
+ coco_eval.params.useCats = self.useCats
62
+ img_ids, eval_imgs = evaluate(coco_eval)
63
+
64
+ self.eval_imgs[iou_type].append(eval_imgs)
65
+
66
+ def synchronize_between_processes(self):
67
+ for iou_type in self.iou_types:
68
+ self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
69
+ create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])
70
+
71
+ def accumulate(self):
72
+ for coco_eval in self.coco_eval.values():
73
+ coco_eval.accumulate()
74
+
75
+ def summarize(self):
76
+ for iou_type, coco_eval in self.coco_eval.items():
77
+ print("IoU metric: {}".format(iou_type))
78
+ coco_eval.summarize()
79
+
80
+ def prepare(self, predictions, iou_type):
81
+ if iou_type == "bbox":
82
+ return self.prepare_for_coco_detection(predictions)
83
+ elif iou_type == "segm":
84
+ return self.prepare_for_coco_segmentation(predictions)
85
+ elif iou_type == "keypoints":
86
+ return self.prepare_for_coco_keypoint(predictions)
87
+ else:
88
+ raise ValueError("Unknown iou type {}".format(iou_type))
89
+
90
+ def prepare_for_coco_detection(self, predictions):
91
+ coco_results = []
92
+ for original_id, prediction in predictions.items():
93
+ if len(prediction) == 0:
94
+ continue
95
+
96
+ boxes = prediction["boxes"]
97
+ boxes = convert_to_xywh(boxes).tolist()
98
+ scores = prediction["scores"].tolist()
99
+ labels = prediction["labels"].tolist()
100
+
101
+ coco_results.extend(
102
+ [
103
+ {
104
+ "image_id": original_id,
105
+ "category_id": labels[k],
106
+ "bbox": box,
107
+ "score": scores[k],
108
+ }
109
+ for k, box in enumerate(boxes)
110
+ ]
111
+ )
112
+ return coco_results
113
+
114
+ def prepare_for_coco_segmentation(self, predictions):
115
+ coco_results = []
116
+ for original_id, prediction in predictions.items():
117
+ if len(prediction) == 0:
118
+ continue
119
+
120
+ scores = prediction["scores"]
121
+ labels = prediction["labels"]
122
+ masks = prediction["masks"]
123
+
124
+ masks = masks > 0.5
125
+
126
+ scores = prediction["scores"].tolist()
127
+ labels = prediction["labels"].tolist()
128
+
129
+ rles = [
130
+ mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
131
+ for mask in masks
132
+ ]
133
+ for rle in rles:
134
+ rle["counts"] = rle["counts"].decode("utf-8")
135
+
136
+ coco_results.extend(
137
+ [
138
+ {
139
+ "image_id": original_id,
140
+ "category_id": labels[k],
141
+ "segmentation": rle,
142
+ "score": scores[k],
143
+ }
144
+ for k, rle in enumerate(rles)
145
+ ]
146
+ )
147
+ return coco_results
148
+
149
+ def prepare_for_coco_keypoint(self, predictions):
150
+ coco_results = []
151
+ for original_id, prediction in predictions.items():
152
+ if len(prediction) == 0:
153
+ continue
154
+
155
+ boxes = prediction["boxes"]
156
+ boxes = convert_to_xywh(boxes).tolist()
157
+ scores = prediction["scores"].tolist()
158
+ labels = prediction["labels"].tolist()
159
+ keypoints = prediction["keypoints"]
160
+ keypoints = keypoints.flatten(start_dim=1).tolist()
161
+
162
+ coco_results.extend(
163
+ [
164
+ {
165
+ "image_id": original_id,
166
+ "category_id": labels[k],
167
+ "keypoints": keypoint,
168
+ "score": scores[k],
169
+ }
170
+ for k, keypoint in enumerate(keypoints)
171
+ ]
172
+ )
173
+ return coco_results
174
+
175
+
176
+ def convert_to_xywh(boxes):
177
+ xmin, ymin, xmax, ymax = boxes.unbind(1)
178
+ return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)
179
+
180
+
181
+ def merge(img_ids, eval_imgs):
182
+ all_img_ids = all_gather(img_ids)
183
+ all_eval_imgs = all_gather(eval_imgs)
184
+
185
+ merged_img_ids = []
186
+ for p in all_img_ids:
187
+ merged_img_ids.extend(p)
188
+
189
+ merged_eval_imgs = []
190
+ for p in all_eval_imgs:
191
+ merged_eval_imgs.append(p)
192
+
193
+ merged_img_ids = np.array(merged_img_ids)
194
+ merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)
195
+
196
+ # keep only unique (and in sorted order) images
197
+ merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
198
+ merged_eval_imgs = merged_eval_imgs[..., idx]
199
+
200
+ return merged_img_ids, merged_eval_imgs
201
+
202
+
203
+ def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
204
+ img_ids, eval_imgs = merge(img_ids, eval_imgs)
205
+ img_ids = list(img_ids)
206
+ eval_imgs = list(eval_imgs.flatten())
207
+
208
+ coco_eval.evalImgs = eval_imgs
209
+ coco_eval.params.imgIds = img_ids
210
+ coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
211
+
212
+
213
+ #################################################################
214
+ # From pycocotools, just removed the prints and fixed
215
+ # a Python3 bug about unicode not defined
216
+ #################################################################
217
+
218
+
219
+ def evaluate(self):
220
+ """
221
+ Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
222
+ :return: None
223
+ """
224
+ # tic = time.time()
225
+ # print('Running per image evaluation...')
226
+ p = self.params
227
+ # add backward compatibility if useSegm is specified in params
228
+ if p.useSegm is not None:
229
+ p.iouType = "segm" if p.useSegm == 1 else "bbox"
230
+ print("useSegm (deprecated) is not None. Running {} evaluation".format(p.iouType))
231
+ # print('Evaluate annotation type *{}*'.format(p.iouType))
232
+ p.imgIds = list(np.unique(p.imgIds))
233
+ if p.useCats:
234
+ p.catIds = list(np.unique(p.catIds))
235
+ p.maxDets = sorted(p.maxDets)
236
+ self.params = p
237
+
238
+ self._prepare()
239
+ # loop through images, area range, max detection number
240
+ catIds = p.catIds if p.useCats else [-1]
241
+
242
+ if p.iouType == "segm" or p.iouType == "bbox":
243
+ computeIoU = self.computeIoU
244
+ elif p.iouType == "keypoints":
245
+ computeIoU = self.computeOks
246
+ self.ious = {
247
+ (imgId, catId): computeIoU(imgId, catId)
248
+ for imgId in p.imgIds
249
+ for catId in catIds}
250
+
251
+ evaluateImg = self.evaluateImg
252
+ maxDet = p.maxDets[-1]
253
+ evalImgs = [
254
+ evaluateImg(imgId, catId, areaRng, maxDet)
255
+ for catId in catIds
256
+ for areaRng in p.areaRng
257
+ for imgId in p.imgIds
258
+ ]
259
+ # this is NOT in the pycocotools code, but could be done outside
260
+ evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
261
+ self._paramsEval = copy.deepcopy(self.params)
262
+ # toc = time.time()
263
+ # print('DONE (t={:0.2f}s).'.format(toc-tic))
264
+ return p.imgIds, evalImgs
265
+
266
+
267
+ #################################################################
268
+ # end of straight copy from pycocotools, just removing the prints
269
+ #################################################################
groundingdino/datasets/transforms.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Transforms and data augmentation for both image + bbox.
4
+ """
5
+ import os
6
+ import random
7
+
8
+ import PIL
9
+ import torch
10
+ import torchvision.transforms as T
11
+ import torchvision.transforms.functional as F
12
+
13
+ from groundingdino.util.box_ops import box_xyxy_to_cxcywh
14
+ from groundingdino.util.misc import interpolate
15
+
16
+
17
+ def crop(image, target, region):
18
+ cropped_image = F.crop(image, *region)
19
+
20
+ target = target.copy()
21
+ i, j, h, w = region
22
+
23
+ # should we do something wrt the original size?
24
+ target["size"] = torch.tensor([h, w])
25
+
26
+ fields = ["labels", "area", "iscrowd", "positive_map"]
27
+
28
+ if "boxes" in target:
29
+ boxes = target["boxes"]
30
+ max_size = torch.as_tensor([w, h], dtype=torch.float32)
31
+ cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
32
+ cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
33
+ cropped_boxes = cropped_boxes.clamp(min=0)
34
+ area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
35
+ target["boxes"] = cropped_boxes.reshape(-1, 4)
36
+ target["area"] = area
37
+ fields.append("boxes")
38
+
39
+ if "masks" in target:
40
+ # FIXME should we update the area here if there are no boxes?
41
+ target["masks"] = target["masks"][:, i : i + h, j : j + w]
42
+ fields.append("masks")
43
+
44
+ # remove elements for which the boxes or masks that have zero area
45
+ if "boxes" in target or "masks" in target:
46
+ # favor boxes selection when defining which elements to keep
47
+ # this is compatible with previous implementation
48
+ if "boxes" in target:
49
+ cropped_boxes = target["boxes"].reshape(-1, 2, 2)
50
+ keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
51
+ else:
52
+ keep = target["masks"].flatten(1).any(1)
53
+
54
+ for field in fields:
55
+ if field in target:
56
+ target[field] = target[field][keep]
57
+
58
+ if os.environ.get("IPDB_SHILONG_DEBUG", None) == "INFO":
59
+ # for debug and visualization only.
60
+ if "strings_positive" in target:
61
+ target["strings_positive"] = [
62
+ _i for _i, _j in zip(target["strings_positive"], keep) if _j
63
+ ]
64
+
65
+ return cropped_image, target
66
+
67
+
68
+ def hflip(image, target):
69
+ flipped_image = F.hflip(image)
70
+
71
+ w, h = image.size
72
+
73
+ target = target.copy()
74
+ if "boxes" in target:
75
+ boxes = target["boxes"]
76
+ boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor(
77
+ [w, 0, w, 0]
78
+ )
79
+ target["boxes"] = boxes
80
+
81
+ if "masks" in target:
82
+ target["masks"] = target["masks"].flip(-1)
83
+
84
+ return flipped_image, target
85
+
86
+
87
+ def resize(image, target, size, max_size=None):
88
+ # size can be min_size (scalar) or (w, h) tuple
89
+
90
+ def get_size_with_aspect_ratio(image_size, size, max_size=None):
91
+ w, h = image_size
92
+ if max_size is not None:
93
+ min_original_size = float(min((w, h)))
94
+ max_original_size = float(max((w, h)))
95
+ if max_original_size / min_original_size * size > max_size:
96
+ size = int(round(max_size * min_original_size / max_original_size))
97
+
98
+ if (w <= h and w == size) or (h <= w and h == size):
99
+ return (h, w)
100
+
101
+ if w < h:
102
+ ow = size
103
+ oh = int(size * h / w)
104
+ else:
105
+ oh = size
106
+ ow = int(size * w / h)
107
+
108
+ return (oh, ow)
109
+
110
+ def get_size(image_size, size, max_size=None):
111
+ if isinstance(size, (list, tuple)):
112
+ return size[::-1]
113
+ else:
114
+ return get_size_with_aspect_ratio(image_size, size, max_size)
115
+
116
+ size = get_size(image.size, size, max_size)
117
+ rescaled_image = F.resize(image, size)
118
+
119
+ if target is None:
120
+ return rescaled_image, None
121
+
122
+ ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
123
+ ratio_width, ratio_height = ratios
124
+
125
+ target = target.copy()
126
+ if "boxes" in target:
127
+ boxes = target["boxes"]
128
+ scaled_boxes = boxes * torch.as_tensor(
129
+ [ratio_width, ratio_height, ratio_width, ratio_height]
130
+ )
131
+ target["boxes"] = scaled_boxes
132
+
133
+ if "area" in target:
134
+ area = target["area"]
135
+ scaled_area = area * (ratio_width * ratio_height)
136
+ target["area"] = scaled_area
137
+
138
+ h, w = size
139
+ target["size"] = torch.tensor([h, w])
140
+
141
+ if "masks" in target:
142
+ target["masks"] = (
143
+ interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5
144
+ )
145
+
146
+ return rescaled_image, target
147
+
148
+
149
+ def pad(image, target, padding):
150
+ # assumes that we only pad on the bottom right corners
151
+ padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
152
+ if target is None:
153
+ return padded_image, None
154
+ target = target.copy()
155
+ # should we do something wrt the original size?
156
+ target["size"] = torch.tensor(padded_image.size[::-1])
157
+ if "masks" in target:
158
+ target["masks"] = torch.nn.functional.pad(target["masks"], (0, padding[0], 0, padding[1]))
159
+ return padded_image, target
160
+
161
+
162
+ class ResizeDebug(object):
163
+ def __init__(self, size):
164
+ self.size = size
165
+
166
+ def __call__(self, img, target):
167
+ return resize(img, target, self.size)
168
+
169
+
170
+ class RandomCrop(object):
171
+ def __init__(self, size):
172
+ self.size = size
173
+
174
+ def __call__(self, img, target):
175
+ region = T.RandomCrop.get_params(img, self.size)
176
+ return crop(img, target, region)
177
+
178
+
179
+ class RandomSizeCrop(object):
180
+ def __init__(self, min_size: int, max_size: int, respect_boxes: bool = False):
181
+ # respect_boxes: True to keep all boxes
182
+ # False to tolerence box filter
183
+ self.min_size = min_size
184
+ self.max_size = max_size
185
+ self.respect_boxes = respect_boxes
186
+
187
+ def __call__(self, img: PIL.Image.Image, target: dict):
188
+ init_boxes = len(target["boxes"])
189
+ max_patience = 10
190
+ for i in range(max_patience):
191
+ w = random.randint(self.min_size, min(img.width, self.max_size))
192
+ h = random.randint(self.min_size, min(img.height, self.max_size))
193
+ region = T.RandomCrop.get_params(img, [h, w])
194
+ result_img, result_target = crop(img, target, region)
195
+ if (
196
+ not self.respect_boxes
197
+ or len(result_target["boxes"]) == init_boxes
198
+ or i == max_patience - 1
199
+ ):
200
+ return result_img, result_target
201
+ return result_img, result_target
202
+
203
+
204
+ class CenterCrop(object):
205
+ def __init__(self, size):
206
+ self.size = size
207
+
208
+ def __call__(self, img, target):
209
+ image_width, image_height = img.size
210
+ crop_height, crop_width = self.size
211
+ crop_top = int(round((image_height - crop_height) / 2.0))
212
+ crop_left = int(round((image_width - crop_width) / 2.0))
213
+ return crop(img, target, (crop_top, crop_left, crop_height, crop_width))
214
+
215
+
216
+ class RandomHorizontalFlip(object):
217
+ def __init__(self, p=0.5):
218
+ self.p = p
219
+
220
+ def __call__(self, img, target):
221
+ if random.random() < self.p:
222
+ return hflip(img, target)
223
+ return img, target
224
+
225
+
226
+ class RandomResize(object):
227
+ def __init__(self, sizes, max_size=None):
228
+ assert isinstance(sizes, (list, tuple))
229
+ self.sizes = sizes
230
+ self.max_size = max_size
231
+
232
+ def __call__(self, img, target=None):
233
+ size = random.choice(self.sizes)
234
+ return resize(img, target, size, self.max_size)
235
+
236
+
237
+ class RandomPad(object):
238
+ def __init__(self, max_pad):
239
+ self.max_pad = max_pad
240
+
241
+ def __call__(self, img, target):
242
+ pad_x = random.randint(0, self.max_pad)
243
+ pad_y = random.randint(0, self.max_pad)
244
+ return pad(img, target, (pad_x, pad_y))
245
+
246
+
247
+ class RandomSelect(object):
248
+ """
249
+ Randomly selects between transforms1 and transforms2,
250
+ with probability p for transforms1 and (1 - p) for transforms2
251
+ """
252
+
253
+ def __init__(self, transforms1, transforms2, p=0.5):
254
+ self.transforms1 = transforms1
255
+ self.transforms2 = transforms2
256
+ self.p = p
257
+
258
+ def __call__(self, img, target):
259
+ if random.random() < self.p:
260
+ return self.transforms1(img, target)
261
+ return self.transforms2(img, target)
262
+
263
+
264
+ class ToTensor(object):
265
+ def __call__(self, img, target):
266
+ return F.to_tensor(img), target
267
+
268
+
269
+ class RandomErasing(object):
270
+ def __init__(self, *args, **kwargs):
271
+ self.eraser = T.RandomErasing(*args, **kwargs)
272
+
273
+ def __call__(self, img, target):
274
+ return self.eraser(img), target
275
+
276
+
277
+ class Normalize(object):
278
+ def __init__(self, mean, std):
279
+ self.mean = mean
280
+ self.std = std
281
+
282
+ def __call__(self, image, target=None):
283
+ image = F.normalize(image, mean=self.mean, std=self.std)
284
+ if target is None:
285
+ return image, None
286
+ target = target.copy()
287
+ h, w = image.shape[-2:]
288
+ if "boxes" in target:
289
+ boxes = target["boxes"]
290
+ boxes = box_xyxy_to_cxcywh(boxes)
291
+ boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
292
+ target["boxes"] = boxes
293
+ return image, target
294
+
295
+
296
+ class Compose(object):
297
+ def __init__(self, transforms):
298
+ self.transforms = transforms
299
+
300
+ def __call__(self, image, target):
301
+ for t in self.transforms:
302
+ image, target = t(image, target)
303
+ return image, target
304
+
305
+ def __repr__(self):
306
+ format_string = self.__class__.__name__ + "("
307
+ for t in self.transforms:
308
+ format_string += "\n"
309
+ format_string += " {0}".format(t)
310
+ format_string += "\n)"
311
+ return format_string
groundingdino/models/.ipynb_checkpoints/__init__-checkpoint.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
8
+ from .GroundingDINO import build_groundingdino
9
+
10
+
11
+ def build_model(args):
12
+ # we use register to maintain models from catdet6 on.
13
+ from .registry import MODULE_BUILD_FUNCS
14
+
15
+ assert args.modelname in MODULE_BUILD_FUNCS._module_dict
16
+ build_func = MODULE_BUILD_FUNCS.get(args.modelname)
17
+ model = build_func(args)
18
+ return model
groundingdino/models/.ipynb_checkpoints/registry-checkpoint.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # -*- coding: utf-8 -*-
8
+ # @Author: Yihao Chen
9
+ # @Date: 2021-08-16 16:03:17
10
+ # @Last Modified by: Shilong Liu
11
+ # @Last Modified time: 2022-01-23 15:26
12
+ # modified from mmcv
13
+
14
+ import inspect
15
+ from functools import partial
16
+
17
+
18
+ class Registry(object):
19
+ def __init__(self, name):
20
+ self._name = name
21
+ self._module_dict = dict()
22
+
23
+ def __repr__(self):
24
+ format_str = self.__class__.__name__ + "(name={}, items={})".format(
25
+ self._name, list(self._module_dict.keys())
26
+ )
27
+ return format_str
28
+
29
+ def __len__(self):
30
+ return len(self._module_dict)
31
+
32
+ @property
33
+ def name(self):
34
+ return self._name
35
+
36
+ @property
37
+ def module_dict(self):
38
+ return self._module_dict
39
+
40
+ def get(self, key):
41
+ return self._module_dict.get(key, None)
42
+
43
+ def registe_with_name(self, module_name=None, force=False):
44
+ return partial(self.register, module_name=module_name, force=force)
45
+
46
+ def register(self, module_build_function, module_name=None, force=False):
47
+ """Register a module build function.
48
+ Args:
49
+ module (:obj:`nn.Module`): Module to be registered.
50
+ """
51
+ if not inspect.isfunction(module_build_function):
52
+ raise TypeError(
53
+ "module_build_function must be a function, but got {}".format(
54
+ type(module_build_function)
55
+ )
56
+ )
57
+ if module_name is None:
58
+ module_name = module_build_function.__name__
59
+ if not force and module_name in self._module_dict:
60
+ raise KeyError("{} is already registered in {}".format(module_name, self.name))
61
+ self._module_dict[module_name] = module_build_function
62
+
63
+ return module_build_function
64
+
65
+
66
+ MODULE_BUILD_FUNCS = Registry("model build functions")
groundingdino/models/GroundingDINO/.ipynb_checkpoints/fuse_modules-checkpoint.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from timm.models.layers import DropPath
12
+ import loralib as lora
13
+
14
+ class FeatureResizer(nn.Module):
15
+ """
16
+ This class takes as input a set of embeddings of dimension C1 and outputs a set of
17
+ embedding of dimension C2, after a linear transformation, dropout and normalization (LN).
18
+ """
19
+
20
+ def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True):
21
+ super().__init__()
22
+ self.do_ln = do_ln
23
+ # Object feature encoding
24
+ r = 16
25
+ self.fc = lora.Linear(input_feat_size, output_feat_size,r=r , bias=True)
26
+ self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12)
27
+ self.dropout = nn.Dropout(dropout)
28
+
29
+ def forward(self, encoder_features):
30
+ x = self.fc(encoder_features)
31
+ if self.do_ln:
32
+ x = self.layer_norm(x)
33
+ output = self.dropout(x)
34
+ return output
35
+
36
+
37
+ def l1norm(X, dim, eps=1e-8):
38
+ """L1-normalize columns of X"""
39
+ norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps
40
+ X = torch.div(X, norm)
41
+ return X
42
+
43
+
44
+ def l2norm(X, dim, eps=1e-8):
45
+ """L2-normalize columns of X"""
46
+ norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
47
+ X = torch.div(X, norm)
48
+ return X
49
+
50
+
51
+ def func_attention(query, context, smooth=1, raw_feature_norm="softmax", eps=1e-8):
52
+ """
53
+ query: (n_context, queryL, d)
54
+ context: (n_context, sourceL, d)
55
+ """
56
+ batch_size_q, queryL = query.size(0), query.size(1)
57
+ batch_size, sourceL = context.size(0), context.size(1)
58
+
59
+ # Get attention
60
+ # --> (batch, d, queryL)
61
+ queryT = torch.transpose(query, 1, 2)
62
+
63
+ # (batch, sourceL, d)(batch, d, queryL)
64
+ # --> (batch, sourceL, queryL)
65
+ attn = torch.bmm(context, queryT)
66
+ if raw_feature_norm == "softmax":
67
+ # --> (batch*sourceL, queryL)
68
+ attn = attn.view(batch_size * sourceL, queryL)
69
+ attn = nn.Softmax()(attn)
70
+ # --> (batch, sourceL, queryL)
71
+ attn = attn.view(batch_size, sourceL, queryL)
72
+ elif raw_feature_norm == "l2norm":
73
+ attn = l2norm(attn, 2)
74
+ elif raw_feature_norm == "clipped_l2norm":
75
+ attn = nn.LeakyReLU(0.1)(attn)
76
+ attn = l2norm(attn, 2)
77
+ else:
78
+ raise ValueError("unknown first norm type:", raw_feature_norm)
79
+ # --> (batch, queryL, sourceL)
80
+ attn = torch.transpose(attn, 1, 2).contiguous()
81
+ # --> (batch*queryL, sourceL)
82
+ attn = attn.view(batch_size * queryL, sourceL)
83
+ attn = nn.Softmax()(attn * smooth)
84
+ # --> (batch, queryL, sourceL)
85
+ attn = attn.view(batch_size, queryL, sourceL)
86
+ # --> (batch, sourceL, queryL)
87
+ attnT = torch.transpose(attn, 1, 2).contiguous()
88
+
89
+ # --> (batch, d, sourceL)
90
+ contextT = torch.transpose(context, 1, 2)
91
+ # (batch x d x sourceL)(batch x sourceL x queryL)
92
+ # --> (batch, d, queryL)
93
+ weightedContext = torch.bmm(contextT, attnT)
94
+ # --> (batch, queryL, d)
95
+ weightedContext = torch.transpose(weightedContext, 1, 2)
96
+
97
+ return weightedContext, attnT
98
+
99
+
100
+ class BiMultiHeadAttention(nn.Module):
101
+ def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1, cfg=None):
102
+ super(BiMultiHeadAttention, self).__init__()
103
+
104
+ self.embed_dim = embed_dim
105
+ self.num_heads = num_heads
106
+ self.head_dim = embed_dim // num_heads
107
+ self.v_dim = v_dim
108
+ self.l_dim = l_dim
109
+
110
+ assert (
111
+ self.head_dim * self.num_heads == self.embed_dim
112
+ ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
113
+ self.scale = self.head_dim ** (-0.5)
114
+ self.dropout = dropout
115
+ r = 16
116
+ self.v_proj = lora.Linear(self.v_dim, self.embed_dim , r=r)
117
+ self.l_proj = lora.Linear(self.l_dim, self.embed_dim , r=r)
118
+ self.values_v_proj = lora.Linear(self.v_dim, self.embed_dim , r=r)
119
+ self.values_l_proj = lora.Linear(self.l_dim, self.embed_dim , r=r)
120
+
121
+ self.out_v_proj = lora.Linear(self.embed_dim, self.v_dim , r=r)
122
+ self.out_l_proj = lora.Linear(self.embed_dim, self.l_dim , r=r)
123
+
124
+ self.stable_softmax_2d = True
125
+ self.clamp_min_for_underflow = True
126
+ self.clamp_max_for_overflow = True
127
+
128
+ self._reset_parameters()
129
+
130
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
131
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
132
+
133
+ def _reset_parameters(self):
134
+ nn.init.xavier_uniform_(self.v_proj.weight)
135
+ self.v_proj.bias.data.fill_(0)
136
+ nn.init.xavier_uniform_(self.l_proj.weight)
137
+ self.l_proj.bias.data.fill_(0)
138
+ nn.init.xavier_uniform_(self.values_v_proj.weight)
139
+ self.values_v_proj.bias.data.fill_(0)
140
+ nn.init.xavier_uniform_(self.values_l_proj.weight)
141
+ self.values_l_proj.bias.data.fill_(0)
142
+ nn.init.xavier_uniform_(self.out_v_proj.weight)
143
+ self.out_v_proj.bias.data.fill_(0)
144
+ nn.init.xavier_uniform_(self.out_l_proj.weight)
145
+ self.out_l_proj.bias.data.fill_(0)
146
+
147
+ def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
148
+ """_summary_
149
+
150
+ Args:
151
+ v (_type_): bs, n_img, dim
152
+ l (_type_): bs, n_text, dim
153
+ attention_mask_v (_type_, optional): _description_. bs, n_img
154
+ attention_mask_l (_type_, optional): _description_. bs, n_text
155
+
156
+ Returns:
157
+ _type_: _description_
158
+ """
159
+ # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
160
+ # import ipdb; ipdb.set_trace()
161
+ bsz, tgt_len, _ = v.size()
162
+
163
+ query_states = self.v_proj(v) * self.scale
164
+ key_states = self._shape(self.l_proj(l), -1, bsz)
165
+ value_v_states = self._shape(self.values_v_proj(v), -1, bsz)
166
+ value_l_states = self._shape(self.values_l_proj(l), -1, bsz)
167
+
168
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
169
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
170
+ key_states = key_states.view(*proj_shape)
171
+ value_v_states = value_v_states.view(*proj_shape)
172
+ value_l_states = value_l_states.view(*proj_shape)
173
+
174
+ src_len = key_states.size(1)
175
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) # bs*nhead, nimg, ntxt
176
+
177
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
178
+ raise ValueError(
179
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
180
+ )
181
+
182
+ if self.stable_softmax_2d:
183
+ attn_weights = attn_weights - attn_weights.max()
184
+
185
+ if self.clamp_min_for_underflow:
186
+ attn_weights = torch.clamp(
187
+ attn_weights, min=-50000
188
+ ) # Do not increase -50000, data type half has quite limited range
189
+ if self.clamp_max_for_overflow:
190
+ attn_weights = torch.clamp(
191
+ attn_weights, max=50000
192
+ ) # Do not increase 50000, data type half has quite limited range
193
+
194
+ attn_weights_T = attn_weights.transpose(1, 2)
195
+ attn_weights_l = attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[0]
196
+ if self.clamp_min_for_underflow:
197
+ attn_weights_l = torch.clamp(
198
+ attn_weights_l, min=-50000
199
+ ) # Do not increase -50000, data type half has quite limited range
200
+ if self.clamp_max_for_overflow:
201
+ attn_weights_l = torch.clamp(
202
+ attn_weights_l, max=50000
203
+ ) # Do not increase 50000, data type half has quite limited range
204
+
205
+ # mask vison for language
206
+ if attention_mask_v is not None:
207
+ attention_mask_v = (
208
+ attention_mask_v[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
209
+ )
210
+ attn_weights_l.masked_fill_(attention_mask_v, float("-inf"))
211
+
212
+ attn_weights_l = attn_weights_l.softmax(dim=-1)
213
+
214
+ # mask language for vision
215
+ if attention_mask_l is not None:
216
+ attention_mask_l = (
217
+ attention_mask_l[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
218
+ )
219
+ attn_weights.masked_fill_(attention_mask_l, float("-inf"))
220
+ attn_weights_v = attn_weights.softmax(dim=-1)
221
+
222
+ attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training)
223
+ attn_probs_l = F.dropout(attn_weights_l, p=self.dropout, training=self.training)
224
+
225
+ attn_output_v = torch.bmm(attn_probs_v, value_l_states)
226
+ attn_output_l = torch.bmm(attn_probs_l, value_v_states)
227
+
228
+ if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
229
+ raise ValueError(
230
+ f"`attn_output_v` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output_v.size()}"
231
+ )
232
+
233
+ if attn_output_l.size() != (bsz * self.num_heads, src_len, self.head_dim):
234
+ raise ValueError(
235
+ f"`attn_output_l` should be of size {(bsz, self.num_heads, src_len, self.head_dim)}, but is {attn_output_l.size()}"
236
+ )
237
+
238
+ attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim)
239
+ attn_output_v = attn_output_v.transpose(1, 2)
240
+ attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim)
241
+
242
+ attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, self.head_dim)
243
+ attn_output_l = attn_output_l.transpose(1, 2)
244
+ attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim)
245
+
246
+ attn_output_v = self.out_v_proj(attn_output_v)
247
+ attn_output_l = self.out_l_proj(attn_output_l)
248
+
249
+ return attn_output_v, attn_output_l
250
+
251
+
252
+ # Bi-Direction MHA (text->image, image->text)
253
+ class BiAttentionBlock(nn.Module):
254
+ def __init__(
255
+ self,
256
+ v_dim,
257
+ l_dim,
258
+ embed_dim,
259
+ num_heads,
260
+ dropout=0.1,
261
+ drop_path=0.0,
262
+ init_values=1e-4,
263
+ cfg=None,
264
+ ):
265
+ """
266
+ Inputs:
267
+ embed_dim - Dimensionality of input and attention feature vectors
268
+ hidden_dim - Dimensionality of hidden layer in feed-forward network
269
+ (usually 2-4x larger than embed_dim)
270
+ num_heads - Number of heads to use in the Multi-Head Attention block
271
+ dropout - Amount of dropout to apply in the feed-forward network
272
+ """
273
+ super(BiAttentionBlock, self).__init__()
274
+
275
+ # pre layer norm
276
+ self.layer_norm_v = nn.LayerNorm(v_dim)
277
+ self.layer_norm_l = nn.LayerNorm(l_dim)
278
+ self.attn = BiMultiHeadAttention(
279
+ v_dim=v_dim, l_dim=l_dim, embed_dim=embed_dim, num_heads=num_heads, dropout=dropout
280
+ )
281
+
282
+ # add layer scale for training stability
283
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
284
+ self.gamma_v = nn.Parameter(init_values * torch.ones((v_dim)), requires_grad=True)
285
+ self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=True)
286
+
287
+ def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
288
+ v = self.layer_norm_v(v)
289
+ l = self.layer_norm_l(l)
290
+ delta_v, delta_l = self.attn(
291
+ v, l, attention_mask_v=attention_mask_v, attention_mask_l=attention_mask_l
292
+ )
293
+ # v, l = v + delta_v, l + delta_l
294
+ v = v + self.drop_path(self.gamma_v * delta_v)
295
+ l = l + self.drop_path(self.gamma_l * delta_l)
296
+ return v, l
297
+
298
+ # def forward(self, v:List[torch.Tensor], l, attention_mask_v=None, attention_mask_l=None)
groundingdino/models/GroundingDINO/.ipynb_checkpoints/groundingdino-checkpoint.py ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # Conditional DETR model and criterion classes.
8
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # ------------------------------------------------------------------------
11
+ # Modified from DETR (https://github.com/facebookresearch/detr)
12
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
13
+ # ------------------------------------------------------------------------
14
+ # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
15
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
16
+ # ------------------------------------------------------------------------
17
+ import copy
18
+ from typing import List
19
+ import loralib as lora
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from torch import nn
23
+ from torchvision.ops.boxes import nms
24
+ from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast
25
+
26
+ from groundingdino.util import box_ops, get_tokenlizer
27
+ from groundingdino.util.misc import (
28
+ NestedTensor,
29
+ accuracy,
30
+ get_world_size,
31
+ interpolate,
32
+ inverse_sigmoid,
33
+ is_dist_avail_and_initialized,
34
+ nested_tensor_from_tensor_list,
35
+ )
36
+ from groundingdino.util.utils import get_phrases_from_posmap
37
+ from groundingdino.util.visualizer import COCOVisualizer
38
+ from groundingdino.util.vl_utils import create_positive_map_from_span
39
+
40
+ from ..registry import MODULE_BUILD_FUNCS
41
+ from .backbone import build_backbone
42
+ from .bertwarper import (
43
+ BertModelWarper,
44
+ generate_masks_with_special_tokens,
45
+ generate_masks_with_special_tokens_and_transfer_map,
46
+ )
47
+ from .transformer import build_transformer
48
+ from .utils import MLP, ContrastiveEmbed, sigmoid_focal_loss
49
+
50
+
51
+ class GroundingDINO(nn.Module):
52
+ """This is the Cross-Attention Detector module that performs object detection"""
53
+
54
+ def __init__(
55
+ self,
56
+ backbone,
57
+ transformer,
58
+ num_queries,
59
+ aux_loss=False,
60
+ iter_update=False,
61
+ query_dim=2,
62
+ num_feature_levels=1,
63
+ nheads=8,
64
+ # two stage
65
+ two_stage_type="no", # ['no', 'standard']
66
+ dec_pred_bbox_embed_share=True,
67
+ two_stage_class_embed_share=True,
68
+ two_stage_bbox_embed_share=True,
69
+ num_patterns=0,
70
+ dn_number=100,
71
+ dn_box_noise_scale=0.4,
72
+ dn_label_noise_ratio=0.5,
73
+ dn_labelbook_size=100,
74
+ text_encoder_type="bert-base-uncased",
75
+ sub_sentence_present=True,
76
+ max_text_len=256,
77
+ ):
78
+ """Initializes the model.
79
+ Parameters:
80
+ backbone: torch module of the backbone to be used. See backbone.py
81
+ transformer: torch module of the transformer architecture. See transformer.py
82
+ num_queries: number of object queries, ie detection slot. This is the maximal number of objects
83
+ Conditional DETR can detect in a single image. For COCO, we recommend 100 queries.
84
+ aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
85
+ """
86
+ super().__init__()
87
+ self.num_queries = num_queries
88
+ self.transformer = transformer
89
+ self.hidden_dim = hidden_dim = transformer.d_model
90
+ self.num_feature_levels = num_feature_levels
91
+ self.nheads = nheads
92
+ self.max_text_len = 256
93
+ self.sub_sentence_present = sub_sentence_present
94
+
95
+ # setting query dim
96
+ self.query_dim = query_dim
97
+ assert query_dim == 4
98
+
99
+ # for dn training
100
+ self.num_patterns = num_patterns
101
+ self.dn_number = dn_number
102
+ self.dn_box_noise_scale = dn_box_noise_scale
103
+ self.dn_label_noise_ratio = dn_label_noise_ratio
104
+ self.dn_labelbook_size = dn_labelbook_size
105
+
106
+ # bert
107
+ self.tokenizer = get_tokenlizer.get_tokenlizer(text_encoder_type)
108
+ self.bert = get_tokenlizer.get_pretrained_language_model(text_encoder_type)
109
+ self.bert.pooler.dense.weight.requires_grad_(False)
110
+ self.bert.pooler.dense.bias.requires_grad_(False)
111
+ self.bert = BertModelWarper(bert_model=self.bert)
112
+
113
+ self.feat_map = nn.Linear(self.bert.config.hidden_size, self.hidden_dim, bias=True)
114
+ nn.init.constant_(self.feat_map.bias.data, 0)
115
+ nn.init.xavier_uniform_(self.feat_map.weight.data)
116
+ # freeze
117
+
118
+ # special tokens
119
+ self.specical_tokens = self.tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]", ".", "?"])
120
+
121
+ # prepare input projection layers
122
+ if num_feature_levels > 1:
123
+ num_backbone_outs = len(backbone.num_channels)
124
+ input_proj_list = []
125
+ for _ in range(num_backbone_outs):
126
+ in_channels = backbone.num_channels[_]
127
+ input_proj_list.append(
128
+ nn.Sequential(
129
+ nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
130
+ nn.GroupNorm(32, hidden_dim),
131
+ )
132
+ )
133
+ for _ in range(num_feature_levels - num_backbone_outs):
134
+ input_proj_list.append(
135
+ nn.Sequential(
136
+ nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
137
+ nn.GroupNorm(32, hidden_dim),
138
+ )
139
+ )
140
+ in_channels = hidden_dim
141
+ self.input_proj = nn.ModuleList(input_proj_list)
142
+ else:
143
+ assert two_stage_type == "no", "two_stage_type should be no if num_feature_levels=1 !!!"
144
+ self.input_proj = nn.ModuleList(
145
+ [
146
+ nn.Sequential(
147
+ nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1),
148
+ nn.GroupNorm(32, hidden_dim),
149
+ )
150
+ ]
151
+ )
152
+
153
+ self.backbone = backbone
154
+ self.aux_loss = aux_loss
155
+ self.box_pred_damping = box_pred_damping = None
156
+
157
+ self.iter_update = iter_update
158
+ assert iter_update, "Why not iter_update?"
159
+
160
+ # prepare pred layers
161
+ self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share
162
+ # prepare class & box embed
163
+ _class_embed = ContrastiveEmbed()
164
+
165
+ _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
166
+ nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
167
+ nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)
168
+
169
+ if dec_pred_bbox_embed_share:
170
+ box_embed_layerlist = [_bbox_embed for i in range(transformer.num_decoder_layers)]
171
+ else:
172
+ box_embed_layerlist = [
173
+ copy.deepcopy(_bbox_embed) for i in range(transformer.num_decoder_layers)
174
+ ]
175
+ class_embed_layerlist = [_class_embed for i in range(transformer.num_decoder_layers)]
176
+ self.bbox_embed = nn.ModuleList(box_embed_layerlist)
177
+ self.class_embed = nn.ModuleList(class_embed_layerlist)
178
+ self.transformer.decoder.bbox_embed = self.bbox_embed
179
+ self.transformer.decoder.class_embed = self.class_embed
180
+
181
+ # two stage
182
+ self.two_stage_type = two_stage_type
183
+ assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format(
184
+ two_stage_type
185
+ )
186
+ if two_stage_type != "no":
187
+ if two_stage_bbox_embed_share:
188
+ assert dec_pred_bbox_embed_share
189
+ self.transformer.enc_out_bbox_embed = _bbox_embed
190
+ else:
191
+ self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed)
192
+
193
+ if two_stage_class_embed_share:
194
+ assert dec_pred_bbox_embed_share
195
+ self.transformer.enc_out_class_embed = _class_embed
196
+ else:
197
+ self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed)
198
+
199
+ self.refpoint_embed = None
200
+
201
+ self._reset_parameters()
202
+
203
+ def _reset_parameters(self):
204
+ # init input_proj
205
+ for proj in self.input_proj:
206
+ nn.init.xavier_uniform_(proj[0].weight, gain=1)
207
+ nn.init.constant_(proj[0].bias, 0)
208
+
209
+ def set_image_tensor(self, samples: NestedTensor):
210
+ if isinstance(samples, (list, torch.Tensor)):
211
+ samples = nested_tensor_from_tensor_list(samples)
212
+ self.features, self.poss = self.backbone(samples)
213
+
214
+ def unset_image_tensor(self):
215
+ if hasattr(self, 'features'):
216
+ del self.features
217
+ if hasattr(self,'poss'):
218
+ del self.poss
219
+
220
+ def set_image_features(self, features , poss):
221
+ self.features = features
222
+ self.poss = poss
223
+
224
+ def init_ref_points(self, use_num_queries):
225
+ self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim)
226
+
227
+ def forward(self, samples: NestedTensor, targets: List = None, **kw):
228
+ """The forward expects a NestedTensor, which consists of:
229
+ - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
230
+ - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
231
+
232
+ It returns a dict with the following elements:
233
+ - "pred_logits": the classification logits (including no-object) for all queries.
234
+ Shape= [batch_size x num_queries x num_classes]
235
+ - "pred_boxes": The normalized boxes coordinates for all queries, represented as
236
+ (center_x, center_y, width, height). These values are normalized in [0, 1],
237
+ relative to the size of each individual image (disregarding possible padding).
238
+ See PostProcess for information on how to retrieve the unnormalized bounding box.
239
+ - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
240
+ dictionnaries containing the two above keys for each decoder layer.
241
+ """
242
+ if targets is None:
243
+ captions = kw["captions"]
244
+ else:
245
+ captions = [t["caption"] for t in targets]
246
+
247
+ # encoder texts
248
+ tokenized = self.tokenizer(captions, padding="longest", return_tensors="pt").to(
249
+ samples.device
250
+ )
251
+ (
252
+ text_self_attention_masks,
253
+ position_ids,
254
+ cate_to_token_mask_list,
255
+ ) = generate_masks_with_special_tokens_and_transfer_map(
256
+ tokenized, self.specical_tokens, self.tokenizer
257
+ )
258
+
259
+ if text_self_attention_masks.shape[1] > self.max_text_len:
260
+ text_self_attention_masks = text_self_attention_masks[
261
+ :, : self.max_text_len, : self.max_text_len
262
+ ]
263
+ position_ids = position_ids[:, : self.max_text_len]
264
+ tokenized["input_ids"] = tokenized["input_ids"][:, : self.max_text_len]
265
+ tokenized["attention_mask"] = tokenized["attention_mask"][:, : self.max_text_len]
266
+ tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : self.max_text_len]
267
+
268
+ # extract text embeddings
269
+ if self.sub_sentence_present:
270
+ tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
271
+ tokenized_for_encoder["attention_mask"] = text_self_attention_masks
272
+ tokenized_for_encoder["position_ids"] = position_ids
273
+ else:
274
+ # import ipdb; ipdb.set_trace()
275
+ tokenized_for_encoder = tokenized
276
+
277
+ bert_output = self.bert(**tokenized_for_encoder) # bs, 195, 768
278
+
279
+ encoded_text = self.feat_map(bert_output["last_hidden_state"]) # bs, 195, d_model
280
+ text_token_mask = tokenized.attention_mask.bool() # bs, 195
281
+ # text_token_mask: True for nomask, False for mask
282
+ # text_self_attention_masks: True for nomask, False for mask
283
+
284
+ if encoded_text.shape[1] > self.max_text_len:
285
+ encoded_text = encoded_text[:, : self.max_text_len, :]
286
+ text_token_mask = text_token_mask[:, : self.max_text_len]
287
+ position_ids = position_ids[:, : self.max_text_len]
288
+ text_self_attention_masks = text_self_attention_masks[
289
+ :, : self.max_text_len, : self.max_text_len
290
+ ]
291
+
292
+ text_dict = {
293
+ "encoded_text": encoded_text, # bs, 195, d_model
294
+ "text_token_mask": text_token_mask, # bs, 195
295
+ "position_ids": position_ids, # bs, 195
296
+ "text_self_attention_masks": text_self_attention_masks, # bs, 195,195
297
+ }
298
+
299
+ # import ipdb; ipdb.set_trace()
300
+ if isinstance(samples, (list, torch.Tensor)):
301
+ samples = nested_tensor_from_tensor_list(samples)
302
+ if not hasattr(self, 'features') or not hasattr(self, 'poss'):
303
+ self.set_image_tensor(samples)
304
+
305
+ srcs = []
306
+ masks = []
307
+ for l, feat in enumerate(self.features):
308
+ src, mask = feat.decompose()
309
+ srcs.append(self.input_proj[l](src))
310
+ masks.append(mask)
311
+ assert mask is not None
312
+ if self.num_feature_levels > len(srcs):
313
+ _len_srcs = len(srcs)
314
+ for l in range(_len_srcs, self.num_feature_levels):
315
+ if l == _len_srcs:
316
+ src = self.input_proj[l](self.features[-1].tensors)
317
+ else:
318
+ src = self.input_proj[l](srcs[-1])
319
+ m = samples.mask
320
+ mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
321
+ pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
322
+ srcs.append(src)
323
+ masks.append(mask)
324
+ self.poss.append(pos_l)
325
+
326
+ input_query_bbox = input_query_label = attn_mask = dn_meta = None
327
+ hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(
328
+ srcs, masks, input_query_bbox, self.poss, input_query_label, attn_mask, text_dict
329
+ )
330
+
331
+ # deformable-detr-like anchor update
332
+ outputs_coord_list = []
333
+ for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(
334
+ zip(reference[:-1], self.bbox_embed, hs)
335
+ ):
336
+ layer_delta_unsig = layer_bbox_embed(layer_hs)
337
+ layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
338
+ layer_outputs_unsig = layer_outputs_unsig.sigmoid()
339
+ outputs_coord_list.append(layer_outputs_unsig)
340
+ outputs_coord_list = torch.stack(outputs_coord_list)
341
+
342
+ # output
343
+ outputs_class = torch.stack(
344
+ [
345
+ layer_cls_embed(layer_hs, text_dict)
346
+ for layer_cls_embed, layer_hs in zip(self.class_embed, hs)
347
+ ]
348
+ )
349
+ out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord_list[-1]}
350
+
351
+ # # for intermediate outputs
352
+ # if self.aux_loss:
353
+ # out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord_list)
354
+
355
+ # # for encoder output
356
+ # if hs_enc is not None:
357
+ # # prepare intermediate outputs
358
+ # interm_coord = ref_enc[-1]
359
+ # interm_class = self.transformer.enc_out_class_embed(hs_enc[-1], text_dict)
360
+ # out['interm_outputs'] = {'pred_logits': interm_class, 'pred_boxes': interm_coord}
361
+ # out['interm_outputs_for_matching_pre'] = {'pred_logits': interm_class, 'pred_boxes': init_box_proposal}
362
+ unset_image_tensor = kw.get('unset_image_tensor', True)
363
+ if unset_image_tensor:
364
+ self.unset_image_tensor() ## If necessary
365
+ return out
366
+
367
+ @torch.jit.unused
368
+ def _set_aux_loss(self, outputs_class, outputs_coord):
369
+ # this is a workaround to make torchscript happy, as torchscript
370
+ # doesn't support dictionary with non-homogeneous values, such
371
+ # as a dict having both a Tensor and a list.
372
+ return [
373
+ {"pred_logits": a, "pred_boxes": b}
374
+ for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
375
+ ]
376
+
377
+
378
+ @MODULE_BUILD_FUNCS.registe_with_name(module_name="groundingdino")
379
+ def build_groundingdino(args):
380
+
381
+ backbone = build_backbone(args)
382
+ transformer = build_transformer(args)
383
+
384
+ dn_labelbook_size = args.dn_labelbook_size
385
+ dec_pred_bbox_embed_share = args.dec_pred_bbox_embed_share
386
+ sub_sentence_present = args.sub_sentence_present
387
+
388
+ model = GroundingDINO(
389
+ backbone,
390
+ transformer,
391
+ num_queries=args.num_queries,
392
+ aux_loss=True,
393
+ iter_update=True,
394
+ query_dim=4,
395
+ num_feature_levels=args.num_feature_levels,
396
+ nheads=args.nheads,
397
+ dec_pred_bbox_embed_share=dec_pred_bbox_embed_share,
398
+ two_stage_type=args.two_stage_type,
399
+ two_stage_bbox_embed_share=args.two_stage_bbox_embed_share,
400
+ two_stage_class_embed_share=args.two_stage_class_embed_share,
401
+ num_patterns=args.num_patterns,
402
+ dn_number=0,
403
+ dn_box_noise_scale=args.dn_box_noise_scale,
404
+ dn_label_noise_ratio=args.dn_label_noise_ratio,
405
+ dn_labelbook_size=dn_labelbook_size,
406
+ text_encoder_type=args.text_encoder_type,
407
+ sub_sentence_present=sub_sentence_present,
408
+ max_text_len=args.max_text_len,
409
+ )
410
+
411
+ return model
412
+
groundingdino/models/GroundingDINO/.ipynb_checkpoints/ms_deform_attn-checkpoint.py ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # Deformable DETR
8
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # ------------------------------------------------------------------------------------------------
11
+ # Modified from:
12
+ # https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/functions/ms_deform_attn_func.py
13
+ # https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py
14
+ # https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/multi_scale_deform_attn.py
15
+ # ------------------------------------------------------------------------------------------------
16
+
17
+ import math
18
+ import warnings
19
+ from typing import Optional
20
+ import loralib as lora
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from torch.autograd import Function
25
+ from torch.autograd.function import once_differentiable
26
+ from torch.nn.init import constant_, xavier_uniform_
27
+
28
+ try:
29
+ from groundingdino import _C
30
+ except:
31
+ warnings.warn("Failed to load custom C++ ops. Running on CPU mode Only!")
32
+
33
+
34
+ # helpers
35
+ def _is_power_of_2(n):
36
+ if (not isinstance(n, int)) or (n < 0):
37
+ raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
38
+ return (n & (n - 1) == 0) and n != 0
39
+
40
+
41
+ class MultiScaleDeformableAttnFunction(Function):
42
+ @staticmethod
43
+ def forward(
44
+ ctx,
45
+ value,
46
+ value_spatial_shapes,
47
+ value_level_start_index,
48
+ sampling_locations,
49
+ attention_weights,
50
+ im2col_step,
51
+ ):
52
+ ctx.im2col_step = im2col_step
53
+ output = _C.ms_deform_attn_forward(
54
+ value,
55
+ value_spatial_shapes,
56
+ value_level_start_index,
57
+ sampling_locations,
58
+ attention_weights,
59
+ ctx.im2col_step,
60
+ )
61
+ ctx.save_for_backward(
62
+ value,
63
+ value_spatial_shapes,
64
+ value_level_start_index,
65
+ sampling_locations,
66
+ attention_weights,
67
+ )
68
+ return output
69
+
70
+ @staticmethod
71
+ @once_differentiable
72
+ def backward(ctx, grad_output):
73
+ (
74
+ value,
75
+ value_spatial_shapes,
76
+ value_level_start_index,
77
+ sampling_locations,
78
+ attention_weights,
79
+ ) = ctx.saved_tensors
80
+ grad_value, grad_sampling_loc, grad_attn_weight = _C.ms_deform_attn_backward(
81
+ value,
82
+ value_spatial_shapes,
83
+ value_level_start_index,
84
+ sampling_locations,
85
+ attention_weights,
86
+ grad_output,
87
+ ctx.im2col_step,
88
+ )
89
+
90
+ return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
91
+
92
+
93
+ def multi_scale_deformable_attn_pytorch(
94
+ value: torch.Tensor,
95
+ value_spatial_shapes: torch.Tensor,
96
+ sampling_locations: torch.Tensor,
97
+ attention_weights: torch.Tensor,
98
+ ) -> torch.Tensor:
99
+
100
+ bs, _, num_heads, embed_dims = value.shape
101
+ _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
102
+ value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
103
+ sampling_grids = 2 * sampling_locations - 1
104
+ sampling_value_list = []
105
+ for level, (H_, W_) in enumerate(value_spatial_shapes):
106
+ # bs, H_*W_, num_heads, embed_dims ->
107
+ # bs, H_*W_, num_heads*embed_dims ->
108
+ # bs, num_heads*embed_dims, H_*W_ ->
109
+ # bs*num_heads, embed_dims, H_, W_
110
+ value_l_ = (
111
+ value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)
112
+ )
113
+ # bs, num_queries, num_heads, num_points, 2 ->
114
+ # bs, num_heads, num_queries, num_points, 2 ->
115
+ # bs*num_heads, num_queries, num_points, 2
116
+ sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
117
+ # bs*num_heads, embed_dims, num_queries, num_points
118
+ sampling_value_l_ = F.grid_sample(
119
+ value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
120
+ )
121
+ sampling_value_list.append(sampling_value_l_)
122
+ # (bs, num_queries, num_heads, num_levels, num_points) ->
123
+ # (bs, num_heads, num_queries, num_levels, num_points) ->
124
+ # (bs, num_heads, 1, num_queries, num_levels*num_points)
125
+ attention_weights = attention_weights.transpose(1, 2).reshape(
126
+ bs * num_heads, 1, num_queries, num_levels * num_points
127
+ )
128
+ output = (
129
+ (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
130
+ .sum(-1)
131
+ .view(bs, num_heads * embed_dims, num_queries)
132
+ )
133
+ return output.transpose(1, 2).contiguous()
134
+
135
+
136
+ class MultiScaleDeformableAttention(nn.Module):
137
+ """Multi-Scale Deformable Attention Module used in Deformable-DETR
138
+
139
+ `Deformable DETR: Deformable Transformers for End-to-End Object Detection.
140
+ <https://arxiv.org/pdf/2010.04159.pdf>`_.
141
+
142
+ Args:
143
+ embed_dim (int): The embedding dimension of Attention. Default: 256.
144
+ num_heads (int): The number of attention heads. Default: 8.
145
+ num_levels (int): The number of feature map used in Attention. Default: 4.
146
+ num_points (int): The number of sampling points for each query
147
+ in each head. Default: 4.
148
+ img2col_steps (int): The step used in image_to_column. Defualt: 64.
149
+ dropout (float): Dropout layer used in output. Default: 0.1.
150
+ batch_first (bool): if ``True``, then the input and output tensor will be
151
+ provided as `(bs, n, embed_dim)`. Default: False. `(n, bs, embed_dim)`
152
+ """
153
+
154
+ def __init__(
155
+ self,
156
+ embed_dim: int = 256,
157
+ num_heads: int = 8,
158
+ num_levels: int = 4,
159
+ num_points: int = 4,
160
+ img2col_step: int = 64,
161
+ batch_first: bool = False,
162
+ ):
163
+ super().__init__()
164
+ if embed_dim % num_heads != 0:
165
+ raise ValueError(
166
+ "embed_dim must be divisible by num_heads, but got {} and {}".format(
167
+ embed_dim, num_heads
168
+ )
169
+ )
170
+ head_dim = embed_dim // num_heads
171
+
172
+ self.batch_first = batch_first
173
+
174
+ if not _is_power_of_2(head_dim):
175
+ warnings.warn(
176
+ """
177
+ You'd better set d_model in MSDeformAttn to make sure that
178
+ each dim of the attention head a power of 2, which is more efficient.
179
+ """
180
+ )
181
+
182
+ self.im2col_step = img2col_step
183
+ self.embed_dim = embed_dim
184
+ self.num_heads = num_heads
185
+ self.num_levels = num_levels
186
+ self.num_points = num_points
187
+ r = 16
188
+ self.sampling_offsets = lora.Linear(embed_dim, num_heads * num_levels * num_points * 2 , r=r)
189
+ self.attention_weights = lora.Linear(embed_dim, num_heads * num_levels * num_points , r=r)
190
+ self.value_proj = lora.Linear(embed_dim, embed_dim , r=r)
191
+ self.output_proj = lora.Linear(embed_dim, embed_dim , r=r)
192
+
193
+ self.init_weights()
194
+
195
+ def _reset_parameters(self):
196
+ return self.init_weights()
197
+
198
+ def init_weights(self):
199
+ """
200
+ Default initialization for Parameters of Module.
201
+ """
202
+ constant_(self.sampling_offsets.weight.data, 0.0)
203
+ thetas = torch.arange(self.num_heads, dtype=torch.float32) * (
204
+ 2.0 * math.pi / self.num_heads
205
+ )
206
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
207
+ grid_init = (
208
+ (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
209
+ .view(self.num_heads, 1, 1, 2)
210
+ .repeat(1, self.num_levels, self.num_points, 1)
211
+ )
212
+ for i in range(self.num_points):
213
+ grid_init[:, :, i, :] *= i + 1
214
+ with torch.no_grad():
215
+ self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
216
+ constant_(self.attention_weights.weight.data, 0.0)
217
+ constant_(self.attention_weights.bias.data, 0.0)
218
+ xavier_uniform_(self.value_proj.weight.data)
219
+ constant_(self.value_proj.bias.data, 0.0)
220
+ xavier_uniform_(self.output_proj.weight.data)
221
+ constant_(self.output_proj.bias.data, 0.0)
222
+
223
+ def freeze_sampling_offsets(self):
224
+ print("Freeze sampling offsets")
225
+ self.sampling_offsets.weight.requires_grad = False
226
+ self.sampling_offsets.bias.requires_grad = False
227
+
228
+ def freeze_attention_weights(self):
229
+ print("Freeze attention weights")
230
+ self.attention_weights.weight.requires_grad = False
231
+ self.attention_weights.bias.requires_grad = False
232
+
233
+ def forward(
234
+ self,
235
+ query: torch.Tensor,
236
+ key: Optional[torch.Tensor] = None,
237
+ value: Optional[torch.Tensor] = None,
238
+ query_pos: Optional[torch.Tensor] = None,
239
+ key_padding_mask: Optional[torch.Tensor] = None,
240
+ reference_points: Optional[torch.Tensor] = None,
241
+ spatial_shapes: Optional[torch.Tensor] = None,
242
+ level_start_index: Optional[torch.Tensor] = None,
243
+ **kwargs
244
+ ) -> torch.Tensor:
245
+
246
+ """Forward Function of MultiScaleDeformableAttention
247
+
248
+ Args:
249
+ query (torch.Tensor): Query embeddings with shape
250
+ `(num_query, bs, embed_dim)`
251
+ key (torch.Tensor): Key embeddings with shape
252
+ `(num_key, bs, embed_dim)`
253
+ value (torch.Tensor): Value embeddings with shape
254
+ `(num_key, bs, embed_dim)`
255
+ query_pos (torch.Tensor): The position embedding for `query`. Default: None.
256
+ key_padding_mask (torch.Tensor): ByteTensor for `query`, with shape `(bs, num_key)`,
257
+ indicating which elements within `key` to be ignored in attention.
258
+ reference_points (torch.Tensor): The normalized reference points
259
+ with shape `(bs, num_query, num_levels, 2)`,
260
+ all elements is range in [0, 1], top-left (0, 0),
261
+ bottom-right (1, 1), including padding are.
262
+ or `(N, Length_{query}, num_levels, 4)`, add additional
263
+ two dimensions `(h, w)` to form reference boxes.
264
+ spatial_shapes (torch.Tensor): Spatial shape of features in different levels.
265
+ With shape `(num_levels, 2)`, last dimension represents `(h, w)`.
266
+ level_start_index (torch.Tensor): The start index of each level. A tensor with
267
+ shape `(num_levels, )` which can be represented as
268
+ `[0, h_0 * w_0, h_0 * w_0 + h_1 * w_1, ...]`.
269
+
270
+ Returns:
271
+ torch.Tensor: forward results with shape `(num_query, bs, embed_dim)`
272
+ """
273
+
274
+ if value is None:
275
+ value = query
276
+
277
+ if query_pos is not None:
278
+ query = query + query_pos
279
+
280
+ if not self.batch_first:
281
+ # change to (bs, num_query ,embed_dims)
282
+ query = query.permute(1, 0, 2)
283
+ value = value.permute(1, 0, 2)
284
+
285
+ bs, num_query, _ = query.shape
286
+ bs, num_value, _ = value.shape
287
+
288
+ assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
289
+
290
+ value = self.value_proj(value)
291
+ if key_padding_mask is not None:
292
+ value = value.masked_fill(key_padding_mask[..., None], float(0))
293
+ value = value.view(bs, num_value, self.num_heads, -1)
294
+ sampling_offsets = self.sampling_offsets(query).view(
295
+ bs, num_query, self.num_heads, self.num_levels, self.num_points, 2
296
+ )
297
+ attention_weights = self.attention_weights(query).view(
298
+ bs, num_query, self.num_heads, self.num_levels * self.num_points
299
+ )
300
+ attention_weights = attention_weights.softmax(-1)
301
+ attention_weights = attention_weights.view(
302
+ bs,
303
+ num_query,
304
+ self.num_heads,
305
+ self.num_levels,
306
+ self.num_points,
307
+ )
308
+
309
+ # bs, num_query, num_heads, num_levels, num_points, 2
310
+ if reference_points.shape[-1] == 2:
311
+ offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
312
+ sampling_locations = (
313
+ reference_points[:, :, None, :, None, :]
314
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
315
+ )
316
+ elif reference_points.shape[-1] == 4:
317
+ sampling_locations = (
318
+ reference_points[:, :, None, :, None, :2]
319
+ + sampling_offsets
320
+ / self.num_points
321
+ * reference_points[:, :, None, :, None, 2:]
322
+ * 0.5
323
+ )
324
+ else:
325
+ raise ValueError(
326
+ "Last dim of reference_points must be 2 or 4, but get {} instead.".format(
327
+ reference_points.shape[-1]
328
+ )
329
+ )
330
+
331
+ if torch.cuda.is_available() and value.is_cuda:
332
+ halffloat = False
333
+ if value.dtype == torch.float16:
334
+ halffloat = True
335
+ value = value.float()
336
+ sampling_locations = sampling_locations.float()
337
+ attention_weights = attention_weights.float()
338
+
339
+ output = MultiScaleDeformableAttnFunction.apply(
340
+ value,
341
+ spatial_shapes,
342
+ level_start_index,
343
+ sampling_locations,
344
+ attention_weights,
345
+ self.im2col_step,
346
+ )
347
+
348
+ if halffloat:
349
+ output = output.half()
350
+ else:
351
+ output = multi_scale_deformable_attn_pytorch(
352
+ value, spatial_shapes, sampling_locations, attention_weights
353
+ )
354
+
355
+ output = self.output_proj(output)
356
+
357
+ if not self.batch_first:
358
+ output = output.permute(1, 0, 2)
359
+
360
+ return output
361
+
362
+
363
+ def create_dummy_class(klass, dependency, message=""):
364
+ """
365
+ When a dependency of a class is not available, create a dummy class which throws ImportError
366
+ when used.
367
+
368
+ Args:
369
+ klass (str): name of the class.
370
+ dependency (str): name of the dependency.
371
+ message: extra message to print
372
+ Returns:
373
+ class: a class object
374
+ """
375
+ err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, klass)
376
+ if message:
377
+ err = err + " " + message
378
+
379
+ class _DummyMetaClass(type):
380
+ # throw error on class attribute access
381
+ def __getattr__(_, __): # noqa: B902
382
+ raise ImportError(err)
383
+
384
+ class _Dummy(object, metaclass=_DummyMetaClass):
385
+ # throw error on constructor
386
+ def __init__(self, *args, **kwargs):
387
+ raise ImportError(err)
388
+
389
+ return _Dummy
390
+
391
+
392
+ def create_dummy_func(func, dependency, message=""):
393
+ """
394
+ When a dependency of a function is not available, create a dummy function which throws
395
+ ImportError when used.
396
+
397
+ Args:
398
+ func (str): name of the function.
399
+ dependency (str or list[str]): name(s) of the dependency.
400
+ message: extra message to print
401
+ Returns:
402
+ function: a function object
403
+ """
404
+ err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, func)
405
+ if message:
406
+ err = err + " " + message
407
+
408
+ if isinstance(dependency, (list, tuple)):
409
+ dependency = ",".join(dependency)
410
+
411
+ def _dummy(*args, **kwargs):
412
+ raise ImportError(err)
413
+
414
+ return _dummy
groundingdino/models/GroundingDINO/.ipynb_checkpoints/transformer-checkpoint.py ADDED
@@ -0,0 +1,961 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # DINO
8
+ # Copyright (c) 2022 IDEA. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # ------------------------------------------------------------------------
11
+ # Conditional DETR Transformer class.
12
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
13
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
14
+ # ------------------------------------------------------------------------
15
+ # Modified from DETR (https://github.com/facebookresearch/detr)
16
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
17
+ # ------------------------------------------------------------------------
18
+
19
+ from typing import Optional
20
+
21
+ import torch
22
+ import torch.utils.checkpoint as checkpoint
23
+ from torch import Tensor, nn
24
+ import loralib as lora
25
+ from groundingdino.util.misc import inverse_sigmoid
26
+
27
+ from .fuse_modules import BiAttentionBlock
28
+ from .ms_deform_attn import MultiScaleDeformableAttention as MSDeformAttn
29
+ from .transformer_vanilla import TransformerEncoderLayer
30
+ from .utils import (
31
+ MLP,
32
+ _get_activation_fn,
33
+ _get_clones,
34
+ gen_encoder_output_proposals,
35
+ gen_sineembed_for_position,
36
+ get_sine_pos_embed,
37
+ )
38
+
39
+
40
+ class Transformer(nn.Module):
41
+ def __init__(
42
+ self,
43
+ d_model=256,
44
+ nhead=8,
45
+ num_queries=300,
46
+ num_encoder_layers=6,
47
+ num_unicoder_layers=0,
48
+ num_decoder_layers=6,
49
+ dim_feedforward=2048,
50
+ dropout=0.0,
51
+ activation="relu",
52
+ normalize_before=False,
53
+ return_intermediate_dec=False,
54
+ query_dim=4,
55
+ num_patterns=0,
56
+ # for deformable encoder
57
+ num_feature_levels=1,
58
+ enc_n_points=4,
59
+ dec_n_points=4,
60
+ # init query
61
+ learnable_tgt_init=False,
62
+ # two stage
63
+ two_stage_type="no", # ['no', 'standard', 'early', 'combine', 'enceachlayer', 'enclayer1']
64
+ embed_init_tgt=False,
65
+ # for text
66
+ use_text_enhancer=False,
67
+ use_fusion_layer=False,
68
+ use_checkpoint=False,
69
+ use_transformer_ckpt=False,
70
+ use_text_cross_attention=False,
71
+ text_dropout=0.1,
72
+ fusion_dropout=0.1,
73
+ fusion_droppath=0.0,
74
+ ):
75
+ super().__init__()
76
+ self.num_feature_levels = num_feature_levels
77
+ self.num_encoder_layers = num_encoder_layers
78
+ self.num_unicoder_layers = num_unicoder_layers
79
+ self.num_decoder_layers = num_decoder_layers
80
+ self.num_queries = num_queries
81
+ assert query_dim == 4
82
+
83
+ # choose encoder layer type
84
+ encoder_layer = DeformableTransformerEncoderLayer(
85
+ d_model, dim_feedforward, dropout, activation, num_feature_levels, nhead, enc_n_points
86
+ )
87
+
88
+ if use_text_enhancer:
89
+ text_enhance_layer = TransformerEncoderLayer(
90
+ d_model=d_model,
91
+ nhead=nhead // 2,
92
+ dim_feedforward=dim_feedforward // 2,
93
+ dropout=text_dropout,
94
+ )
95
+ else:
96
+ text_enhance_layer = None
97
+
98
+ if use_fusion_layer:
99
+ feature_fusion_layer = BiAttentionBlock(
100
+ v_dim=d_model,
101
+ l_dim=d_model,
102
+ embed_dim=dim_feedforward // 2,
103
+ num_heads=nhead // 2,
104
+ dropout=fusion_dropout,
105
+ drop_path=fusion_droppath,
106
+ )
107
+ else:
108
+ feature_fusion_layer = None
109
+
110
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
111
+ assert encoder_norm is None
112
+ self.encoder = TransformerEncoder(
113
+ encoder_layer,
114
+ num_encoder_layers,
115
+ d_model=d_model,
116
+ num_queries=num_queries,
117
+ text_enhance_layer=text_enhance_layer,
118
+ feature_fusion_layer=feature_fusion_layer,
119
+ use_checkpoint=use_checkpoint,
120
+ use_transformer_ckpt=use_transformer_ckpt,
121
+ )
122
+
123
+ # choose decoder layer type
124
+ decoder_layer = DeformableTransformerDecoderLayer(
125
+ d_model,
126
+ dim_feedforward,
127
+ dropout,
128
+ activation,
129
+ num_feature_levels,
130
+ nhead,
131
+ dec_n_points,
132
+ use_text_cross_attention=use_text_cross_attention,
133
+ )
134
+
135
+ decoder_norm = nn.LayerNorm(d_model)
136
+ self.decoder = TransformerDecoder(
137
+ decoder_layer,
138
+ num_decoder_layers,
139
+ decoder_norm,
140
+ return_intermediate=return_intermediate_dec,
141
+ d_model=d_model,
142
+ query_dim=query_dim,
143
+ num_feature_levels=num_feature_levels,
144
+ )
145
+
146
+ self.d_model = d_model
147
+ self.nhead = nhead
148
+ self.dec_layers = num_decoder_layers
149
+ self.num_queries = num_queries # useful for single stage model only
150
+ self.num_patterns = num_patterns
151
+ if not isinstance(num_patterns, int):
152
+ Warning("num_patterns should be int but {}".format(type(num_patterns)))
153
+ self.num_patterns = 0
154
+
155
+ if num_feature_levels > 1:
156
+ if self.num_encoder_layers > 0:
157
+ self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
158
+ else:
159
+ self.level_embed = None
160
+
161
+ self.learnable_tgt_init = learnable_tgt_init
162
+ assert learnable_tgt_init, "why not learnable_tgt_init"
163
+ self.embed_init_tgt = embed_init_tgt
164
+ if (two_stage_type != "no" and embed_init_tgt) or (two_stage_type == "no"):
165
+ self.tgt_embed = nn.Embedding(self.num_queries, d_model)
166
+ nn.init.normal_(self.tgt_embed.weight.data)
167
+ else:
168
+ self.tgt_embed = None
169
+
170
+ # for two stage
171
+ self.two_stage_type = two_stage_type
172
+ assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format(
173
+ two_stage_type
174
+ )
175
+ if two_stage_type == "standard":
176
+ # anchor selection at the output of encoder
177
+ self.enc_output = nn.Linear(d_model, d_model)
178
+ self.enc_output_norm = nn.LayerNorm(d_model)
179
+ self.two_stage_wh_embedding = None
180
+
181
+ if two_stage_type == "no":
182
+ self.init_ref_points(num_queries) # init self.refpoint_embed
183
+
184
+ self.enc_out_class_embed = None
185
+ self.enc_out_bbox_embed = None
186
+
187
+ self._reset_parameters()
188
+
189
+ def _reset_parameters(self):
190
+ for p in self.parameters():
191
+ if p.dim() > 1:
192
+ nn.init.xavier_uniform_(p)
193
+ for m in self.modules():
194
+ if isinstance(m, MSDeformAttn):
195
+ m._reset_parameters()
196
+ if self.num_feature_levels > 1 and self.level_embed is not None:
197
+ nn.init.normal_(self.level_embed)
198
+
199
+ def get_valid_ratio(self, mask):
200
+ _, H, W = mask.shape
201
+ valid_H = torch.sum(~mask[:, :, 0], 1)
202
+ valid_W = torch.sum(~mask[:, 0, :], 1)
203
+ valid_ratio_h = valid_H.float() / H
204
+ valid_ratio_w = valid_W.float() / W
205
+ valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
206
+ return valid_ratio
207
+
208
+ def init_ref_points(self, use_num_queries):
209
+ self.refpoint_embed = nn.Embedding(use_num_queries, 4)
210
+
211
+ def forward(self, srcs, masks, refpoint_embed, pos_embeds, tgt, attn_mask=None, text_dict=None):
212
+ """
213
+ Input:
214
+ - srcs: List of multi features [bs, ci, hi, wi]
215
+ - masks: List of multi masks [bs, hi, wi]
216
+ - refpoint_embed: [bs, num_dn, 4]. None in infer
217
+ - pos_embeds: List of multi pos embeds [bs, ci, hi, wi]
218
+ - tgt: [bs, num_dn, d_model]. None in infer
219
+
220
+ """
221
+ # prepare input for encoder
222
+ src_flatten = []
223
+ mask_flatten = []
224
+ lvl_pos_embed_flatten = []
225
+ spatial_shapes = []
226
+ for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
227
+ bs, c, h, w = src.shape
228
+ spatial_shape = (h, w)
229
+ spatial_shapes.append(spatial_shape)
230
+
231
+ src = src.flatten(2).transpose(1, 2) # bs, hw, c
232
+ mask = mask.flatten(1) # bs, hw
233
+ pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c
234
+ if self.num_feature_levels > 1 and self.level_embed is not None:
235
+ lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
236
+ else:
237
+ lvl_pos_embed = pos_embed
238
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
239
+ src_flatten.append(src)
240
+ mask_flatten.append(mask)
241
+ src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c
242
+ mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw}
243
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # bs, \sum{hxw}, c
244
+ spatial_shapes = torch.as_tensor(
245
+ spatial_shapes, dtype=torch.long, device=src_flatten.device
246
+ )
247
+ level_start_index = torch.cat(
248
+ (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
249
+ )
250
+ valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
251
+
252
+ # two stage
253
+ enc_topk_proposals = enc_refpoint_embed = None
254
+
255
+ #########################################################
256
+ # Begin Encoder
257
+ #########################################################
258
+ memory, memory_text = self.encoder(
259
+ src_flatten,
260
+ pos=lvl_pos_embed_flatten,
261
+ level_start_index=level_start_index,
262
+ spatial_shapes=spatial_shapes,
263
+ valid_ratios=valid_ratios,
264
+ key_padding_mask=mask_flatten,
265
+ memory_text=text_dict["encoded_text"],
266
+ text_attention_mask=~text_dict["text_token_mask"],
267
+ # we ~ the mask . False means use the token; True means pad the token
268
+ position_ids=text_dict["position_ids"],
269
+ text_self_attention_masks=text_dict["text_self_attention_masks"],
270
+ )
271
+ #########################################################
272
+ # End Encoder
273
+ # - memory: bs, \sum{hw}, c
274
+ # - mask_flatten: bs, \sum{hw}
275
+ # - lvl_pos_embed_flatten: bs, \sum{hw}, c
276
+ # - enc_intermediate_output: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
277
+ # - enc_intermediate_refpoints: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
278
+ #########################################################
279
+ text_dict["encoded_text"] = memory_text
280
+ # if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
281
+ # if memory.isnan().any() | memory.isinf().any():
282
+ # import ipdb; ipdb.set_trace()
283
+
284
+ if self.two_stage_type == "standard":
285
+ output_memory, output_proposals = gen_encoder_output_proposals(
286
+ memory, mask_flatten, spatial_shapes
287
+ )
288
+ output_memory = self.enc_output_norm(self.enc_output(output_memory))
289
+
290
+ if text_dict is not None:
291
+ enc_outputs_class_unselected = self.enc_out_class_embed(output_memory, text_dict)
292
+ else:
293
+ enc_outputs_class_unselected = self.enc_out_class_embed(output_memory)
294
+
295
+ topk_logits = enc_outputs_class_unselected.max(-1)[0]
296
+ enc_outputs_coord_unselected = (
297
+ self.enc_out_bbox_embed(output_memory) + output_proposals
298
+ ) # (bs, \sum{hw}, 4) unsigmoid
299
+ topk = self.num_queries
300
+
301
+ topk_proposals = torch.topk(topk_logits, topk, dim=1)[1] # bs, nq
302
+
303
+ # gather boxes
304
+ refpoint_embed_undetach = torch.gather(
305
+ enc_outputs_coord_unselected, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
306
+ ) # unsigmoid
307
+ refpoint_embed_ = refpoint_embed_undetach.detach()
308
+ init_box_proposal = torch.gather(
309
+ output_proposals, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
310
+ ).sigmoid() # sigmoid
311
+
312
+ # gather tgt
313
+ tgt_undetach = torch.gather(
314
+ output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model)
315
+ )
316
+ if self.embed_init_tgt:
317
+ tgt_ = (
318
+ self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
319
+ ) # nq, bs, d_model
320
+ else:
321
+ tgt_ = tgt_undetach.detach()
322
+
323
+ if refpoint_embed is not None:
324
+ refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
325
+ tgt = torch.cat([tgt, tgt_], dim=1)
326
+ else:
327
+ refpoint_embed, tgt = refpoint_embed_, tgt_
328
+
329
+ elif self.two_stage_type == "no":
330
+ tgt_ = (
331
+ self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
332
+ ) # nq, bs, d_model
333
+ refpoint_embed_ = (
334
+ self.refpoint_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
335
+ ) # nq, bs, 4
336
+
337
+ if refpoint_embed is not None:
338
+ refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
339
+ tgt = torch.cat([tgt, tgt_], dim=1)
340
+ else:
341
+ refpoint_embed, tgt = refpoint_embed_, tgt_
342
+
343
+ if self.num_patterns > 0:
344
+ tgt_embed = tgt.repeat(1, self.num_patterns, 1)
345
+ refpoint_embed = refpoint_embed.repeat(1, self.num_patterns, 1)
346
+ tgt_pat = self.patterns.weight[None, :, :].repeat_interleave(
347
+ self.num_queries, 1
348
+ ) # 1, n_q*n_pat, d_model
349
+ tgt = tgt_embed + tgt_pat
350
+
351
+ init_box_proposal = refpoint_embed_.sigmoid()
352
+
353
+ else:
354
+ raise NotImplementedError("unknown two_stage_type {}".format(self.two_stage_type))
355
+ #########################################################
356
+ # End preparing tgt
357
+ # - tgt: bs, NQ, d_model
358
+ # - refpoint_embed(unsigmoid): bs, NQ, d_model
359
+ #########################################################
360
+
361
+ #########################################################
362
+ # Begin Decoder
363
+ #########################################################
364
+ hs, references = self.decoder(
365
+ tgt=tgt.transpose(0, 1),
366
+ memory=memory.transpose(0, 1),
367
+ memory_key_padding_mask=mask_flatten,
368
+ pos=lvl_pos_embed_flatten.transpose(0, 1),
369
+ refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
370
+ level_start_index=level_start_index,
371
+ spatial_shapes=spatial_shapes,
372
+ valid_ratios=valid_ratios,
373
+ tgt_mask=attn_mask,
374
+ memory_text=text_dict["encoded_text"],
375
+ text_attention_mask=~text_dict["text_token_mask"],
376
+ # we ~ the mask . False means use the token; True means pad the token
377
+ )
378
+ #########################################################
379
+ # End Decoder
380
+ # hs: n_dec, bs, nq, d_model
381
+ # references: n_dec+1, bs, nq, query_dim
382
+ #########################################################
383
+
384
+ #########################################################
385
+ # Begin postprocess
386
+ #########################################################
387
+ if self.two_stage_type == "standard":
388
+ hs_enc = tgt_undetach.unsqueeze(0)
389
+ ref_enc = refpoint_embed_undetach.sigmoid().unsqueeze(0)
390
+ else:
391
+ hs_enc = ref_enc = None
392
+ #########################################################
393
+ # End postprocess
394
+ # hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or (n_enc, bs, nq, d_model) or None
395
+ # ref_enc: (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or (n_enc, bs, nq, d_model) or None
396
+ #########################################################
397
+
398
+ return hs, references, hs_enc, ref_enc, init_box_proposal
399
+ # hs: (n_dec, bs, nq, d_model)
400
+ # references: sigmoid coordinates. (n_dec+1, bs, bq, 4)
401
+ # hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or None
402
+ # ref_enc: sigmoid coordinates. \
403
+ # (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or None
404
+
405
+
406
+ class TransformerEncoder(nn.Module):
407
+ def __init__(
408
+ self,
409
+ encoder_layer,
410
+ num_layers,
411
+ d_model=256,
412
+ num_queries=300,
413
+ enc_layer_share=False,
414
+ text_enhance_layer=None,
415
+ feature_fusion_layer=None,
416
+ use_checkpoint=False,
417
+ use_transformer_ckpt=False,
418
+ ):
419
+ """_summary_
420
+
421
+ Args:
422
+ encoder_layer (_type_): _description_
423
+ num_layers (_type_): _description_
424
+ norm (_type_, optional): _description_. Defaults to None.
425
+ d_model (int, optional): _description_. Defaults to 256.
426
+ num_queries (int, optional): _description_. Defaults to 300.
427
+ enc_layer_share (bool, optional): _description_. Defaults to False.
428
+
429
+ """
430
+ super().__init__()
431
+ # prepare layers
432
+ self.layers = []
433
+ self.text_layers = []
434
+ self.fusion_layers = []
435
+ if num_layers > 0:
436
+ self.layers = _get_clones(encoder_layer, num_layers, layer_share=enc_layer_share)
437
+
438
+ if text_enhance_layer is not None:
439
+ self.text_layers = _get_clones(
440
+ text_enhance_layer, num_layers, layer_share=enc_layer_share
441
+ )
442
+ if feature_fusion_layer is not None:
443
+ self.fusion_layers = _get_clones(
444
+ feature_fusion_layer, num_layers, layer_share=enc_layer_share
445
+ )
446
+ else:
447
+ self.layers = []
448
+ del encoder_layer
449
+
450
+ if text_enhance_layer is not None:
451
+ self.text_layers = []
452
+ del text_enhance_layer
453
+ if feature_fusion_layer is not None:
454
+ self.fusion_layers = []
455
+ del feature_fusion_layer
456
+
457
+ self.query_scale = None
458
+ self.num_queries = num_queries
459
+ self.num_layers = num_layers
460
+ self.d_model = d_model
461
+
462
+ self.use_checkpoint = use_checkpoint
463
+ self.use_transformer_ckpt = use_transformer_ckpt
464
+
465
+ @staticmethod
466
+ def get_reference_points(spatial_shapes, valid_ratios, device):
467
+ reference_points_list = []
468
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
469
+
470
+ ref_y, ref_x = torch.meshgrid(
471
+ torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
472
+ torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
473
+ )
474
+ ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
475
+ ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
476
+ ref = torch.stack((ref_x, ref_y), -1)
477
+ reference_points_list.append(ref)
478
+ reference_points = torch.cat(reference_points_list, 1)
479
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
480
+ return reference_points
481
+
482
+ def forward(
483
+ self,
484
+ # for images
485
+ src: Tensor,
486
+ pos: Tensor,
487
+ spatial_shapes: Tensor,
488
+ level_start_index: Tensor,
489
+ valid_ratios: Tensor,
490
+ key_padding_mask: Tensor,
491
+ # for texts
492
+ memory_text: Tensor = None,
493
+ text_attention_mask: Tensor = None,
494
+ pos_text: Tensor = None,
495
+ text_self_attention_masks: Tensor = None,
496
+ position_ids: Tensor = None,
497
+ ):
498
+ """
499
+ Input:
500
+ - src: [bs, sum(hi*wi), 256]
501
+ - pos: pos embed for src. [bs, sum(hi*wi), 256]
502
+ - spatial_shapes: h,w of each level [num_level, 2]
503
+ - level_start_index: [num_level] start point of level in sum(hi*wi).
504
+ - valid_ratios: [bs, num_level, 2]
505
+ - key_padding_mask: [bs, sum(hi*wi)]
506
+
507
+ - memory_text: bs, n_text, 256
508
+ - text_attention_mask: bs, n_text
509
+ False for no padding; True for padding
510
+ - pos_text: bs, n_text, 256
511
+
512
+ - position_ids: bs, n_text
513
+ Intermedia:
514
+ - reference_points: [bs, sum(hi*wi), num_level, 2]
515
+ Outpus:
516
+ - output: [bs, sum(hi*wi), 256]
517
+ """
518
+
519
+ output = src
520
+
521
+ # preparation and reshape
522
+ if self.num_layers > 0:
523
+ reference_points = self.get_reference_points(
524
+ spatial_shapes, valid_ratios, device=src.device
525
+ )
526
+
527
+ if self.text_layers:
528
+ # generate pos_text
529
+ bs, n_text, text_dim = memory_text.shape
530
+ if pos_text is None and position_ids is None:
531
+ pos_text = (
532
+ torch.arange(n_text, device=memory_text.device)
533
+ .float()
534
+ .unsqueeze(0)
535
+ .unsqueeze(-1)
536
+ .repeat(bs, 1, 1)
537
+ )
538
+ pos_text = get_sine_pos_embed(pos_text, num_pos_feats=256, exchange_xy=False)
539
+ if position_ids is not None:
540
+ pos_text = get_sine_pos_embed(
541
+ position_ids[..., None], num_pos_feats=256, exchange_xy=False
542
+ )
543
+
544
+ # main process
545
+ for layer_id, layer in enumerate(self.layers):
546
+ # if output.isnan().any() or memory_text.isnan().any():
547
+ # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
548
+ # import ipdb; ipdb.set_trace()
549
+ if self.fusion_layers:
550
+ if self.use_checkpoint:
551
+ output, memory_text = checkpoint.checkpoint(
552
+ self.fusion_layers[layer_id],
553
+ output,
554
+ memory_text,
555
+ key_padding_mask,
556
+ text_attention_mask,
557
+ )
558
+ else:
559
+ output, memory_text = self.fusion_layers[layer_id](
560
+ v=output,
561
+ l=memory_text,
562
+ attention_mask_v=key_padding_mask,
563
+ attention_mask_l=text_attention_mask,
564
+ )
565
+
566
+ if self.text_layers:
567
+ memory_text = self.text_layers[layer_id](
568
+ src=memory_text.transpose(0, 1),
569
+ src_mask=~text_self_attention_masks, # note we use ~ for mask here
570
+ src_key_padding_mask=text_attention_mask,
571
+ pos=(pos_text.transpose(0, 1) if pos_text is not None else None),
572
+ ).transpose(0, 1)
573
+
574
+ # main process
575
+ if self.use_transformer_ckpt:
576
+ output = checkpoint.checkpoint(
577
+ layer,
578
+ output,
579
+ pos,
580
+ reference_points,
581
+ spatial_shapes,
582
+ level_start_index,
583
+ key_padding_mask,
584
+ )
585
+ else:
586
+ output = layer(
587
+ src=output,
588
+ pos=pos,
589
+ reference_points=reference_points,
590
+ spatial_shapes=spatial_shapes,
591
+ level_start_index=level_start_index,
592
+ key_padding_mask=key_padding_mask,
593
+ )
594
+
595
+ return output, memory_text
596
+
597
+
598
+ class TransformerDecoder(nn.Module):
599
+ def __init__(
600
+ self,
601
+ decoder_layer,
602
+ num_layers,
603
+ norm=None,
604
+ return_intermediate=False,
605
+ d_model=256,
606
+ query_dim=4,
607
+ num_feature_levels=1,
608
+ ):
609
+ super().__init__()
610
+ if num_layers > 0:
611
+ self.layers = _get_clones(decoder_layer, num_layers)
612
+ else:
613
+ self.layers = []
614
+ self.num_layers = num_layers
615
+ self.norm = norm
616
+ self.return_intermediate = return_intermediate
617
+ assert return_intermediate, "support return_intermediate only"
618
+ self.query_dim = query_dim
619
+ assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim)
620
+ self.num_feature_levels = num_feature_levels
621
+
622
+ self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)
623
+ self.query_pos_sine_scale = None
624
+
625
+ self.query_scale = None
626
+ self.bbox_embed = None
627
+ self.class_embed = None
628
+
629
+ self.d_model = d_model
630
+
631
+ self.ref_anchor_head = None
632
+
633
+ def forward(
634
+ self,
635
+ tgt,
636
+ memory,
637
+ tgt_mask: Optional[Tensor] = None,
638
+ memory_mask: Optional[Tensor] = None,
639
+ tgt_key_padding_mask: Optional[Tensor] = None,
640
+ memory_key_padding_mask: Optional[Tensor] = None,
641
+ pos: Optional[Tensor] = None,
642
+ refpoints_unsigmoid: Optional[Tensor] = None, # num_queries, bs, 2
643
+ # for memory
644
+ level_start_index: Optional[Tensor] = None, # num_levels
645
+ spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
646
+ valid_ratios: Optional[Tensor] = None,
647
+ # for text
648
+ memory_text: Optional[Tensor] = None,
649
+ text_attention_mask: Optional[Tensor] = None,
650
+ ):
651
+ """
652
+ Input:
653
+ - tgt: nq, bs, d_model
654
+ - memory: hw, bs, d_model
655
+ - pos: hw, bs, d_model
656
+ - refpoints_unsigmoid: nq, bs, 2/4
657
+ - valid_ratios/spatial_shapes: bs, nlevel, 2
658
+ """
659
+ output = tgt
660
+
661
+ intermediate = []
662
+ reference_points = refpoints_unsigmoid.sigmoid()
663
+ ref_points = [reference_points]
664
+
665
+ for layer_id, layer in enumerate(self.layers):
666
+
667
+ if reference_points.shape[-1] == 4:
668
+ reference_points_input = (
669
+ reference_points[:, :, None]
670
+ * torch.cat([valid_ratios, valid_ratios], -1)[None, :]
671
+ ) # nq, bs, nlevel, 4
672
+ else:
673
+ assert reference_points.shape[-1] == 2
674
+ reference_points_input = reference_points[:, :, None] * valid_ratios[None, :]
675
+ query_sine_embed = gen_sineembed_for_position(
676
+ reference_points_input[:, :, 0, :]
677
+ ) # nq, bs, 256*2
678
+
679
+ # conditional query
680
+ raw_query_pos = self.ref_point_head(query_sine_embed) # nq, bs, 256
681
+ pos_scale = self.query_scale(output) if self.query_scale is not None else 1
682
+ query_pos = pos_scale * raw_query_pos
683
+ # if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
684
+ # if query_pos.isnan().any() | query_pos.isinf().any():
685
+ # import ipdb; ipdb.set_trace()
686
+
687
+ # main process
688
+ output = layer(
689
+ tgt=output,
690
+ tgt_query_pos=query_pos,
691
+ tgt_query_sine_embed=query_sine_embed,
692
+ tgt_key_padding_mask=tgt_key_padding_mask,
693
+ tgt_reference_points=reference_points_input,
694
+ memory_text=memory_text,
695
+ text_attention_mask=text_attention_mask,
696
+ memory=memory,
697
+ memory_key_padding_mask=memory_key_padding_mask,
698
+ memory_level_start_index=level_start_index,
699
+ memory_spatial_shapes=spatial_shapes,
700
+ memory_pos=pos,
701
+ self_attn_mask=tgt_mask,
702
+ cross_attn_mask=memory_mask,
703
+ )
704
+ if output.isnan().any() | output.isinf().any():
705
+ print(f"output layer_id {layer_id} is nan")
706
+ try:
707
+ num_nan = output.isnan().sum().item()
708
+ num_inf = output.isinf().sum().item()
709
+ print(f"num_nan {num_nan}, num_inf {num_inf}")
710
+ except Exception as e:
711
+ print(e)
712
+ # if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
713
+ # import ipdb; ipdb.set_trace()
714
+
715
+ # iter update
716
+ if self.bbox_embed is not None:
717
+ # box_holder = self.bbox_embed(output)
718
+ # box_holder[..., :self.query_dim] += inverse_sigmoid(reference_points)
719
+ # new_reference_points = box_holder[..., :self.query_dim].sigmoid()
720
+
721
+ reference_before_sigmoid = inverse_sigmoid(reference_points)
722
+ delta_unsig = self.bbox_embed[layer_id](output)
723
+ outputs_unsig = delta_unsig + reference_before_sigmoid
724
+ new_reference_points = outputs_unsig.sigmoid()
725
+
726
+ reference_points = new_reference_points.detach()
727
+ # if layer_id != self.num_layers - 1:
728
+ ref_points.append(new_reference_points)
729
+
730
+ intermediate.append(self.norm(output))
731
+
732
+ return [
733
+ [itm_out.transpose(0, 1) for itm_out in intermediate],
734
+ [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points],
735
+ ]
736
+
737
+
738
+ class DeformableTransformerEncoderLayer(nn.Module):
739
+ def __init__(
740
+ self,
741
+ d_model=256,
742
+ d_ffn=1024,
743
+ dropout=0.1,
744
+ activation="relu",
745
+ n_levels=4,
746
+ n_heads=8,
747
+ n_points=4,
748
+ ):
749
+ super().__init__()
750
+
751
+ # self attention
752
+ self.self_attn = MSDeformAttn(
753
+ embed_dim=d_model,
754
+ num_levels=n_levels,
755
+ num_heads=n_heads,
756
+ num_points=n_points,
757
+ batch_first=True,
758
+ )
759
+ self.dropout1 = nn.Dropout(dropout)
760
+ self.norm1 = nn.LayerNorm(d_model)
761
+
762
+ # ffn
763
+ r = 16
764
+ self.linear1 = lora.Linear(d_model, d_ffn , r=r)
765
+ self.activation = _get_activation_fn(activation, d_model=d_ffn)
766
+ self.dropout2 = nn.Dropout(dropout)
767
+ self.linear2 = lora.Linear(d_ffn, d_model , r=r)
768
+ self.dropout3 = nn.Dropout(dropout)
769
+ self.norm2 = nn.LayerNorm(d_model)
770
+
771
+ @staticmethod
772
+ def with_pos_embed(tensor, pos):
773
+ return tensor if pos is None else tensor + pos
774
+
775
+ def forward_ffn(self, src):
776
+ src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
777
+ src = src + self.dropout3(src2)
778
+ src = self.norm2(src)
779
+ return src
780
+
781
+ def forward(
782
+ self, src, pos, reference_points, spatial_shapes, level_start_index, key_padding_mask=None
783
+ ):
784
+ # self attention
785
+ # import ipdb; ipdb.set_trace()
786
+ src2 = self.self_attn(
787
+ query=self.with_pos_embed(src, pos),
788
+ reference_points=reference_points,
789
+ value=src,
790
+ spatial_shapes=spatial_shapes,
791
+ level_start_index=level_start_index,
792
+ key_padding_mask=key_padding_mask,
793
+ )
794
+ src = src + self.dropout1(src2)
795
+ src = self.norm1(src)
796
+
797
+ # ffn
798
+ src = self.forward_ffn(src)
799
+
800
+ return src
801
+
802
+
803
+ class DeformableTransformerDecoderLayer(nn.Module):
804
+ def __init__(
805
+ self,
806
+ d_model=256,
807
+ d_ffn=1024,
808
+ dropout=0.1,
809
+ activation="relu",
810
+ n_levels=4,
811
+ n_heads=8,
812
+ n_points=4,
813
+ use_text_feat_guide=False,
814
+ use_text_cross_attention=False,
815
+ ):
816
+ super().__init__()
817
+
818
+ # cross attention
819
+ self.cross_attn = MSDeformAttn(
820
+ embed_dim=d_model,
821
+ num_levels=n_levels,
822
+ num_heads=n_heads,
823
+ num_points=n_points,
824
+ batch_first=True,
825
+ )
826
+ self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
827
+ self.norm1 = nn.LayerNorm(d_model)
828
+
829
+ # cross attention text
830
+ if use_text_cross_attention:
831
+ self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
832
+ self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
833
+ self.catext_norm = nn.LayerNorm(d_model)
834
+
835
+ # self attention
836
+ self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
837
+ self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
838
+ self.norm2 = nn.LayerNorm(d_model)
839
+
840
+ # ffn
841
+ r = 16
842
+ self.linear1 = lora.Linear(d_model, d_ffn , r=r)
843
+ self.activation = _get_activation_fn(activation, d_model=d_ffn, batch_dim=1)
844
+ self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
845
+ self.linear2 = lora.Linear(d_ffn, d_model , r=r)
846
+ self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
847
+ self.norm3 = nn.LayerNorm(d_model)
848
+
849
+ self.key_aware_proj = None
850
+ self.use_text_feat_guide = use_text_feat_guide
851
+ assert not use_text_feat_guide
852
+ self.use_text_cross_attention = use_text_cross_attention
853
+
854
+ def rm_self_attn_modules(self):
855
+ self.self_attn = None
856
+ self.dropout2 = None
857
+ self.norm2 = None
858
+
859
+ @staticmethod
860
+ def with_pos_embed(tensor, pos):
861
+ return tensor if pos is None else tensor + pos
862
+
863
+ def forward_ffn(self, tgt):
864
+ with torch.cuda.amp.autocast(enabled=False):
865
+ tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
866
+ tgt = tgt + self.dropout4(tgt2)
867
+ tgt = self.norm3(tgt)
868
+ return tgt
869
+
870
+ def forward(
871
+ self,
872
+ # for tgt
873
+ tgt: Optional[Tensor], # nq, bs, d_model
874
+ tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
875
+ tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos)
876
+ tgt_key_padding_mask: Optional[Tensor] = None,
877
+ tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4
878
+ memory_text: Optional[Tensor] = None, # bs, num_token, d_model
879
+ text_attention_mask: Optional[Tensor] = None, # bs, num_token
880
+ # for memory
881
+ memory: Optional[Tensor] = None, # hw, bs, d_model
882
+ memory_key_padding_mask: Optional[Tensor] = None,
883
+ memory_level_start_index: Optional[Tensor] = None, # num_levels
884
+ memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
885
+ memory_pos: Optional[Tensor] = None, # pos for memory
886
+ # sa
887
+ self_attn_mask: Optional[Tensor] = None, # mask used for self-attention
888
+ cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention
889
+ ):
890
+ """
891
+ Input:
892
+ - tgt/tgt_query_pos: nq, bs, d_model
893
+ -
894
+ """
895
+ assert cross_attn_mask is None
896
+
897
+ # self attention
898
+ if self.self_attn is not None:
899
+ # import ipdb; ipdb.set_trace()
900
+ q = k = self.with_pos_embed(tgt, tgt_query_pos)
901
+ tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
902
+ tgt = tgt + self.dropout2(tgt2)
903
+ tgt = self.norm2(tgt)
904
+
905
+ if self.use_text_cross_attention:
906
+ tgt2 = self.ca_text(
907
+ self.with_pos_embed(tgt, tgt_query_pos),
908
+ memory_text.transpose(0, 1),
909
+ memory_text.transpose(0, 1),
910
+ key_padding_mask=text_attention_mask,
911
+ )[0]
912
+ tgt = tgt + self.catext_dropout(tgt2)
913
+ tgt = self.catext_norm(tgt)
914
+
915
+ tgt2 = self.cross_attn(
916
+ query=self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
917
+ reference_points=tgt_reference_points.transpose(0, 1).contiguous(),
918
+ value=memory.transpose(0, 1),
919
+ spatial_shapes=memory_spatial_shapes,
920
+ level_start_index=memory_level_start_index,
921
+ key_padding_mask=memory_key_padding_mask,
922
+ ).transpose(0, 1)
923
+ tgt = tgt + self.dropout1(tgt2)
924
+ tgt = self.norm1(tgt)
925
+
926
+ # ffn
927
+ tgt = self.forward_ffn(tgt)
928
+
929
+ return tgt
930
+
931
+
932
+ def build_transformer(args):
933
+ return Transformer(
934
+ d_model=args.hidden_dim,
935
+ dropout=args.dropout,
936
+ nhead=args.nheads,
937
+ num_queries=args.num_queries,
938
+ dim_feedforward=args.dim_feedforward,
939
+ num_encoder_layers=args.enc_layers,
940
+ num_decoder_layers=args.dec_layers,
941
+ normalize_before=args.pre_norm,
942
+ return_intermediate_dec=True,
943
+ query_dim=args.query_dim,
944
+ activation=args.transformer_activation,
945
+ num_patterns=args.num_patterns,
946
+ num_feature_levels=args.num_feature_levels,
947
+ enc_n_points=args.enc_n_points,
948
+ dec_n_points=args.dec_n_points,
949
+ learnable_tgt_init=True,
950
+ # two stage
951
+ two_stage_type=args.two_stage_type, # ['no', 'standard', 'early']
952
+ embed_init_tgt=args.embed_init_tgt,
953
+ use_text_enhancer=args.use_text_enhancer,
954
+ use_fusion_layer=args.use_fusion_layer,
955
+ use_checkpoint=args.use_checkpoint,
956
+ use_transformer_ckpt=args.use_transformer_ckpt,
957
+ use_text_cross_attention=args.use_text_cross_attention,
958
+ text_dropout=args.text_dropout,
959
+ fusion_dropout=args.fusion_dropout,
960
+ fusion_droppath=args.fusion_droppath,
961
+ )
groundingdino/models/GroundingDINO/.ipynb_checkpoints/transformer_vanilla-checkpoint.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
8
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
9
+ """
10
+ DETR Transformer class.
11
+
12
+ Copy-paste from torch.nn.Transformer with modifications:
13
+ * positional encodings are passed in MHattention
14
+ * extra LN at the end of encoder is removed
15
+ * decoder returns a stack of activations from all decoding layers
16
+ """
17
+ from typing import Optional
18
+
19
+ import torch
20
+ import torch.nn.functional as F
21
+ from torch import Tensor, nn
22
+ import loralib as lora
23
+ from .utils import (
24
+ MLP,
25
+ _get_activation_fn,
26
+ _get_clones,
27
+ gen_encoder_output_proposals,
28
+ gen_sineembed_for_position,
29
+ sigmoid_focal_loss,
30
+ )
31
+
32
+
33
+ class TextTransformer(nn.Module):
34
+ def __init__(self, num_layers, d_model=256, nheads=8, dim_feedforward=2048, dropout=0.1):
35
+ super().__init__()
36
+ self.num_layers = num_layers
37
+ self.d_model = d_model
38
+ self.nheads = nheads
39
+ self.dim_feedforward = dim_feedforward
40
+ self.norm = None
41
+
42
+ single_encoder_layer = TransformerEncoderLayer(
43
+ d_model=d_model, nhead=nheads, dim_feedforward=dim_feedforward, dropout=dropout
44
+ )
45
+ self.layers = _get_clones(single_encoder_layer, num_layers)
46
+
47
+ def forward(self, memory_text: torch.Tensor, text_attention_mask: torch.Tensor):
48
+ """
49
+
50
+ Args:
51
+ text_attention_mask: bs, num_token
52
+ memory_text: bs, num_token, d_model
53
+
54
+ Raises:
55
+ RuntimeError: _description_
56
+
57
+ Returns:
58
+ output: bs, num_token, d_model
59
+ """
60
+
61
+ output = memory_text.transpose(0, 1)
62
+
63
+ for layer in self.layers:
64
+ output = layer(output, src_key_padding_mask=text_attention_mask)
65
+
66
+ if self.norm is not None:
67
+ output = self.norm(output)
68
+
69
+ return output.transpose(0, 1)
70
+
71
+
72
+ class TransformerEncoderLayer(nn.Module):
73
+ def __init__(
74
+ self,
75
+ d_model,
76
+ nhead,
77
+ dim_feedforward=2048,
78
+ dropout=0.1,
79
+ activation="relu",
80
+ normalize_before=False,
81
+ ):
82
+ super().__init__()
83
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
84
+ # Implementation of Feedforward model
85
+ r=16
86
+ self.linear1 = lora.Linear(d_model, dim_feedforward , r=r)
87
+ self.dropout = nn.Dropout(dropout)
88
+ self.linear2 = lora.Linear(dim_feedforward, d_model , r=r)
89
+
90
+ self.norm1 = nn.LayerNorm(d_model)
91
+ self.norm2 = nn.LayerNorm(d_model)
92
+ self.dropout1 = nn.Dropout(dropout)
93
+ self.dropout2 = nn.Dropout(dropout)
94
+
95
+ self.activation = _get_activation_fn(activation)
96
+ self.normalize_before = normalize_before
97
+ self.nhead = nhead
98
+
99
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
100
+ return tensor if pos is None else tensor + pos
101
+
102
+ def forward(
103
+ self,
104
+ src,
105
+ src_mask: Optional[Tensor] = None,
106
+ src_key_padding_mask: Optional[Tensor] = None,
107
+ pos: Optional[Tensor] = None,
108
+ ):
109
+ # repeat attn mask
110
+ if src_mask.dim() == 3 and src_mask.shape[0] == src.shape[1]:
111
+ # bs, num_q, num_k
112
+ src_mask = src_mask.repeat(self.nhead, 1, 1)
113
+
114
+ q = k = self.with_pos_embed(src, pos)
115
+
116
+ src2 = self.self_attn(q, k, value=src, attn_mask=src_mask)[0]
117
+
118
+ # src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
119
+ src = src + self.dropout1(src2)
120
+ src = self.norm1(src)
121
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
122
+ src = src + self.dropout2(src2)
123
+ src = self.norm2(src)
124
+ return src
groundingdino/models/GroundingDINO/.ipynb_checkpoints/utils-checkpoint.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+
8
+ import copy
9
+ import math
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import Tensor, nn
14
+ import loralib as lora
15
+
16
+ def _get_clones(module, N, layer_share=False):
17
+ # import ipdb; ipdb.set_trace()
18
+ if layer_share:
19
+ return nn.ModuleList([module for i in range(N)])
20
+ else:
21
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
22
+
23
+
24
+ def get_sine_pos_embed(
25
+ pos_tensor: torch.Tensor,
26
+ num_pos_feats: int = 128,
27
+ temperature: int = 10000,
28
+ exchange_xy: bool = True,
29
+ ):
30
+ """generate sine position embedding from a position tensor
31
+ Args:
32
+ pos_tensor (torch.Tensor): shape: [..., n].
33
+ num_pos_feats (int): projected shape for each float in the tensor.
34
+ temperature (int): temperature in the sine/cosine function.
35
+ exchange_xy (bool, optional): exchange pos x and pos y. \
36
+ For example, input tensor is [x,y], the results will be [pos(y), pos(x)]. Defaults to True.
37
+ Returns:
38
+ pos_embed (torch.Tensor): shape: [..., n*num_pos_feats].
39
+ """
40
+ scale = 2 * math.pi
41
+ dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device)
42
+ dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
43
+
44
+ def sine_func(x: torch.Tensor):
45
+ sin_x = x * scale / dim_t
46
+ sin_x = torch.stack((sin_x[..., 0::2].sin(), sin_x[..., 1::2].cos()), dim=3).flatten(2)
47
+ return sin_x
48
+
49
+ pos_res = [sine_func(x) for x in pos_tensor.split([1] * pos_tensor.shape[-1], dim=-1)]
50
+ if exchange_xy:
51
+ pos_res[0], pos_res[1] = pos_res[1], pos_res[0]
52
+ pos_res = torch.cat(pos_res, dim=-1)
53
+ return pos_res
54
+
55
+
56
+ def gen_encoder_output_proposals(
57
+ memory: Tensor, memory_padding_mask: Tensor, spatial_shapes: Tensor, learnedwh=None
58
+ ):
59
+ """
60
+ Input:
61
+ - memory: bs, \sum{hw}, d_model
62
+ - memory_padding_mask: bs, \sum{hw}
63
+ - spatial_shapes: nlevel, 2
64
+ - learnedwh: 2
65
+ Output:
66
+ - output_memory: bs, \sum{hw}, d_model
67
+ - output_proposals: bs, \sum{hw}, 4
68
+ """
69
+ N_, S_, C_ = memory.shape
70
+ proposals = []
71
+ _cur = 0
72
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
73
+ mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H_ * W_)].view(N_, H_, W_, 1)
74
+ valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
75
+ valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
76
+
77
+ # import ipdb; ipdb.set_trace()
78
+
79
+ grid_y, grid_x = torch.meshgrid(
80
+ torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
81
+ torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
82
+ )
83
+ grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2
84
+
85
+ scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
86
+ grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
87
+
88
+ if learnedwh is not None:
89
+ # import ipdb; ipdb.set_trace()
90
+ wh = torch.ones_like(grid) * learnedwh.sigmoid() * (2.0**lvl)
91
+ else:
92
+ wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
93
+
94
+ # scale = torch.cat([W_[None].unsqueeze(-1), H_[None].unsqueeze(-1)], 1).view(1, 1, 1, 2).repeat(N_, 1, 1, 1)
95
+ # grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
96
+ # wh = torch.ones_like(grid) / scale
97
+ proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
98
+ proposals.append(proposal)
99
+ _cur += H_ * W_
100
+ # import ipdb; ipdb.set_trace()
101
+ output_proposals = torch.cat(proposals, 1)
102
+ output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(
103
+ -1, keepdim=True
104
+ )
105
+ output_proposals = torch.log(output_proposals / (1 - output_proposals)) # unsigmoid
106
+ output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float("inf"))
107
+ output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf"))
108
+
109
+ output_memory = memory
110
+ output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
111
+ output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
112
+
113
+ # output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
114
+ # output_memory = output_memory.masked_fill(~output_proposals_valid, float('inf'))
115
+
116
+ return output_memory, output_proposals
117
+
118
+
119
+ class RandomBoxPerturber:
120
+ def __init__(
121
+ self, x_noise_scale=0.2, y_noise_scale=0.2, w_noise_scale=0.2, h_noise_scale=0.2
122
+ ) -> None:
123
+ self.noise_scale = torch.Tensor(
124
+ [x_noise_scale, y_noise_scale, w_noise_scale, h_noise_scale]
125
+ )
126
+
127
+ def __call__(self, refanchors: Tensor) -> Tensor:
128
+ nq, bs, query_dim = refanchors.shape
129
+ device = refanchors.device
130
+
131
+ noise_raw = torch.rand_like(refanchors)
132
+ noise_scale = self.noise_scale.to(device)[:query_dim]
133
+
134
+ new_refanchors = refanchors * (1 + (noise_raw - 0.5) * noise_scale)
135
+ return new_refanchors.clamp_(0, 1)
136
+
137
+
138
+ def sigmoid_focal_loss(
139
+ inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2, no_reduction=False
140
+ ):
141
+ """
142
+ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
143
+ Args:
144
+ inputs: A float tensor of arbitrary shape.
145
+ The predictions for each example.
146
+ targets: A float tensor with the same shape as inputs. Stores the binary
147
+ classification label for each element in inputs
148
+ (0 for the negative class and 1 for the positive class).
149
+ alpha: (optional) Weighting factor in range (0,1) to balance
150
+ positive vs negative examples. Default = -1 (no weighting).
151
+ gamma: Exponent of the modulating factor (1 - p_t) to
152
+ balance easy vs hard examples.
153
+ Returns:
154
+ Loss tensor
155
+ """
156
+ prob = inputs.sigmoid()
157
+ ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
158
+ p_t = prob * targets + (1 - prob) * (1 - targets)
159
+ loss = ce_loss * ((1 - p_t) ** gamma)
160
+
161
+ if alpha >= 0:
162
+ alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
163
+ loss = alpha_t * loss
164
+
165
+ if no_reduction:
166
+ return loss
167
+
168
+ return loss.mean(1).sum() / num_boxes
169
+
170
+
171
+ class MLP(nn.Module):
172
+ """Very simple multi-layer perceptron (also called FFN)"""
173
+
174
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
175
+ super().__init__()
176
+ self.num_layers = num_layers
177
+ r=16
178
+ h = [hidden_dim] * (num_layers - 1)
179
+ self.layers = nn.ModuleList(
180
+ [lora.Linear(n, k, r=r) for n, k in zip([input_dim] + h, h + [output_dim])]
181
+ )
182
+
183
+ def forward(self, x):
184
+ for i, layer in enumerate(self.layers):
185
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
186
+ return x
187
+
188
+
189
+ def _get_activation_fn(activation, d_model=256, batch_dim=0):
190
+ """Return an activation function given a string"""
191
+ if activation == "relu":
192
+ return F.relu
193
+ if activation == "gelu":
194
+ return F.gelu
195
+ if activation == "glu":
196
+ return F.glu
197
+ if activation == "prelu":
198
+ return nn.PReLU()
199
+ if activation == "selu":
200
+ return F.selu
201
+
202
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
203
+
204
+
205
+ def gen_sineembed_for_position(pos_tensor):
206
+ # n_query, bs, _ = pos_tensor.size()
207
+ # sineembed_tensor = torch.zeros(n_query, bs, 256)
208
+ scale = 2 * math.pi
209
+ dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
210
+ dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode='floor')) / 128)
211
+ x_embed = pos_tensor[:, :, 0] * scale
212
+ y_embed = pos_tensor[:, :, 1] * scale
213
+ pos_x = x_embed[:, :, None] / dim_t
214
+ pos_y = y_embed[:, :, None] / dim_t
215
+ pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
216
+ pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
217
+ if pos_tensor.size(-1) == 2:
218
+ pos = torch.cat((pos_y, pos_x), dim=2)
219
+ elif pos_tensor.size(-1) == 4:
220
+ w_embed = pos_tensor[:, :, 2] * scale
221
+ pos_w = w_embed[:, :, None] / dim_t
222
+ pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
223
+
224
+ h_embed = pos_tensor[:, :, 3] * scale
225
+ pos_h = h_embed[:, :, None] / dim_t
226
+ pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
227
+
228
+ pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
229
+ else:
230
+ raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
231
+ return pos
232
+
233
+
234
+ class ContrastiveEmbed(nn.Module):
235
+ def __init__(self, max_text_len=256):
236
+ """
237
+ Args:
238
+ max_text_len: max length of text.
239
+ """
240
+ super().__init__()
241
+ self.max_text_len = max_text_len
242
+
243
+ def forward(self, x, text_dict):
244
+ """_summary_
245
+
246
+ Args:
247
+ x (_type_): _description_
248
+ text_dict (_type_): _description_
249
+ {
250
+ 'encoded_text': encoded_text, # bs, 195, d_model
251
+ 'text_token_mask': text_token_mask, # bs, 195
252
+ # True for used tokens. False for padding tokens
253
+ }
254
+ Returns:
255
+ _type_: _description_
256
+ """
257
+ assert isinstance(text_dict, dict)
258
+
259
+ y = text_dict["encoded_text"]
260
+ text_token_mask = text_dict["text_token_mask"]
261
+
262
+ res = x @ y.transpose(-1, -2)
263
+ res.masked_fill_(~text_token_mask[:, None, :], float("-inf"))
264
+
265
+ # padding to max_text_len
266
+ new_res = torch.full((*res.shape[:-1], self.max_text_len), float("-inf"), device=res.device)
267
+ new_res[..., : res.shape[-1]] = res
268
+
269
+ return new_res
groundingdino/models/GroundingDINO/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (260 Bytes). View file
 
groundingdino/models/GroundingDINO/__pycache__/bertwarper.cpython-310.pyc ADDED
Binary file (7.23 kB). View file
 
groundingdino/models/GroundingDINO/__pycache__/fuse_modules.cpython-310.pyc ADDED
Binary file (7.88 kB). View file
 
groundingdino/models/GroundingDINO/__pycache__/groundingdino.cpython-310.pyc ADDED
Binary file (11.4 kB). View file
 
groundingdino/models/GroundingDINO/__pycache__/ms_deform_attn.cpython-310.pyc ADDED
Binary file (11.8 kB). View file
 
groundingdino/models/GroundingDINO/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (19.4 kB). View file
 
groundingdino/models/GroundingDINO/__pycache__/transformer_vanilla.cpython-310.pyc ADDED
Binary file (3.52 kB). View file
 
groundingdino/models/GroundingDINO/__pycache__/utils.cpython-310.pyc ADDED
Binary file (9.61 kB). View file
 
groundingdino/models/GroundingDINO/backbone/.ipynb_checkpoints/backbone-checkpoint.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # Conditional DETR
8
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # ------------------------------------------------------------------------
11
+ # Copied from DETR (https://github.com/facebookresearch/detr)
12
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
13
+ # ------------------------------------------------------------------------
14
+
15
+
16
+ # Backbone modules.
17
+
18
+ from typing import Dict, List
19
+ import loralib as lora
20
+ import torch
21
+ import torch.nn.functional as F
22
+ import torchvision
23
+ from torch import nn
24
+ from torchvision.models._utils import IntermediateLayerGetter
25
+ import loralib as lora
26
+ from groundingdino.util.misc import NestedTensor, clean_state_dict, is_main_process
27
+
28
+ from .position_encoding import build_position_encoding
29
+ from .swin_transformer import build_swin_transformer
30
+
31
+
32
+ class FrozenBatchNorm2d(torch.nn.Module):
33
+ """
34
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
35
+
36
+ Copy-paste from torchvision.misc.ops with added eps before rqsrt,
37
+ without which any other models than torchvision.models.resnet[18,34,50,101]
38
+ produce nans.
39
+ """
40
+
41
+ def __init__(self, n):
42
+ super(FrozenBatchNorm2d, self).__init__()
43
+ self.register_buffer("weight", torch.ones(n))
44
+ self.register_buffer("bias", torch.zeros(n))
45
+ self.register_buffer("running_mean", torch.zeros(n))
46
+ self.register_buffer("running_var", torch.ones(n))
47
+
48
+ def _load_from_state_dict(
49
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
50
+ ):
51
+ num_batches_tracked_key = prefix + "num_batches_tracked"
52
+ if num_batches_tracked_key in state_dict:
53
+ del state_dict[num_batches_tracked_key]
54
+
55
+ super(FrozenBatchNorm2d, self)._load_from_state_dict(
56
+ state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
57
+ )
58
+
59
+ def forward(self, x):
60
+ # move reshapes to the beginning
61
+ # to make it fuser-friendly
62
+ w = self.weight.reshape(1, -1, 1, 1)
63
+ b = self.bias.reshape(1, -1, 1, 1)
64
+ rv = self.running_var.reshape(1, -1, 1, 1)
65
+ rm = self.running_mean.reshape(1, -1, 1, 1)
66
+ eps = 1e-5
67
+ scale = w * (rv + eps).rsqrt()
68
+ bias = b - rm * scale
69
+ return x * scale + bias
70
+
71
+
72
+ class BackboneBase(nn.Module):
73
+ def __init__(
74
+ self,
75
+ backbone: nn.Module,
76
+ train_backbone: bool,
77
+ num_channels: int,
78
+ return_interm_indices: list,
79
+ ):
80
+ super().__init__()
81
+ for name, parameter in backbone.named_parameters():
82
+ if (
83
+ not train_backbone
84
+ or "layer2" not in name
85
+ and "layer3" not in name
86
+ and "layer4" not in name
87
+ ):
88
+ parameter.requires_grad_(False)
89
+
90
+ return_layers = {}
91
+ for idx, layer_index in enumerate(return_interm_indices):
92
+ return_layers.update(
93
+ {"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)}
94
+ )
95
+
96
+ # if len:
97
+ # if use_stage1_feature:
98
+ # return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
99
+ # else:
100
+ # return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
101
+ # else:
102
+ # return_layers = {'layer4': "0"}
103
+ self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
104
+ self.num_channels = num_channels
105
+
106
+ def forward(self, tensor_list: NestedTensor):
107
+ xs = self.body(tensor_list.tensors)
108
+ out: Dict[str, NestedTensor] = {}
109
+ for name, x in xs.items():
110
+ m = tensor_list.mask
111
+ assert m is not None
112
+ mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
113
+ out[name] = NestedTensor(x, mask)
114
+ # import ipdb; ipdb.set_trace()
115
+ return out
116
+
117
+
118
+ class Backbone(BackboneBase):
119
+ """ResNet backbone with frozen BatchNorm."""
120
+
121
+ def __init__(
122
+ self,
123
+ name: str,
124
+ train_backbone: bool,
125
+ dilation: bool,
126
+ return_interm_indices: list,
127
+ batch_norm=FrozenBatchNorm2d,
128
+ ):
129
+ if name in ["resnet18", "resnet34", "resnet50", "resnet101"]:
130
+ backbone = getattr(torchvision.models, name)(
131
+ replace_stride_with_dilation=[False, False, dilation],
132
+ pretrained=is_main_process(),
133
+ norm_layer=batch_norm,
134
+ )
135
+ else:
136
+ raise NotImplementedError("Why you can get here with name {}".format(name))
137
+ # num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
138
+ assert name not in ("resnet18", "resnet34"), "Only resnet50 and resnet101 are available."
139
+ assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
140
+ num_channels_all = [256, 512, 1024, 2048]
141
+ num_channels = num_channels_all[4 - len(return_interm_indices) :]
142
+ super().__init__(backbone, train_backbone, num_channels, return_interm_indices)
143
+
144
+
145
+ class Joiner(nn.Sequential):
146
+ def __init__(self, backbone, position_embedding):
147
+ super().__init__(backbone, position_embedding)
148
+
149
+ def forward(self, tensor_list: NestedTensor):
150
+ xs = self[0](tensor_list)
151
+ out: List[NestedTensor] = []
152
+ pos = []
153
+ for name, x in xs.items():
154
+ out.append(x)
155
+ # position encoding
156
+ pos.append(self[1](x).to(x.tensors.dtype))
157
+
158
+ return out, pos
159
+
160
+
161
+ def build_backbone(args):
162
+ """
163
+ Useful args:
164
+ - backbone: backbone name
165
+ - lr_backbone:
166
+ - dilation
167
+ - return_interm_indices: available: [0,1,2,3], [1,2,3], [3]
168
+ - backbone_freeze_keywords:
169
+ - use_checkpoint: for swin only for now
170
+
171
+ """
172
+ position_embedding = build_position_encoding(args)
173
+ train_backbone = True
174
+ if not train_backbone:
175
+ raise ValueError("Please set lr_backbone > 0")
176
+ return_interm_indices = args.return_interm_indices
177
+ assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
178
+ args.backbone_freeze_keywords
179
+ use_checkpoint = getattr(args, "use_checkpoint", False)
180
+
181
+ if args.backbone in ["resnet50", "resnet101"]:
182
+ backbone = Backbone(
183
+ args.backbone,
184
+ train_backbone,
185
+ args.dilation,
186
+ return_interm_indices,
187
+ batch_norm=FrozenBatchNorm2d,
188
+ )
189
+ bb_num_channels = backbone.num_channels
190
+ elif args.backbone in [
191
+ "swin_T_224_1k",
192
+ "swin_B_224_22k",
193
+ "swin_B_384_22k",
194
+ "swin_L_224_22k",
195
+ "swin_L_384_22k",
196
+ ]:
197
+ pretrain_img_size = int(args.backbone.split("_")[-2])
198
+ backbone = build_swin_transformer(
199
+ args.backbone,
200
+ pretrain_img_size=pretrain_img_size,
201
+ out_indices=tuple(return_interm_indices),
202
+ dilation=False,
203
+ use_checkpoint=use_checkpoint,
204
+ )
205
+
206
+ bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :]
207
+ else:
208
+ raise NotImplementedError("Unknown backbone {}".format(args.backbone))
209
+
210
+ assert len(bb_num_channels) == len(
211
+ return_interm_indices
212
+ ), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}"
213
+
214
+ model = Joiner(backbone, position_embedding)
215
+ model.num_channels = bb_num_channels
216
+ assert isinstance(
217
+ bb_num_channels, List
218
+ ), "bb_num_channels is expected to be a List but {}".format(type(bb_num_channels))
219
+ # import ipdb; ipdb.set_trace()
220
+ return model
groundingdino/models/GroundingDINO/backbone/.ipynb_checkpoints/position_encoding-checkpoint.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # DINO
8
+ # Copyright (c) 2022 IDEA. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # ------------------------------------------------------------------------
11
+ # Conditional DETR
12
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
13
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
14
+ # ------------------------------------------------------------------------
15
+ # Copied from DETR (https://github.com/facebookresearch/detr)
16
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
17
+ # ------------------------------------------------------------------------
18
+
19
+ """
20
+ Various positional encodings for the transformer.
21
+ """
22
+ import math
23
+
24
+ import torch
25
+ from torch import nn
26
+ import loralib as lora
27
+ from groundingdino.util.misc import NestedTensor
28
+
29
+
30
+ class PositionEmbeddingSine(nn.Module):
31
+ """
32
+ This is a more standard version of the position embedding, very similar to the one
33
+ used by the Attention is all you need paper, generalized to work on images.
34
+ """
35
+
36
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
37
+ super().__init__()
38
+ self.num_pos_feats = num_pos_feats
39
+ self.temperature = temperature
40
+ self.normalize = normalize
41
+ if scale is not None and normalize is False:
42
+ raise ValueError("normalize should be True if scale is passed")
43
+ if scale is None:
44
+ scale = 2 * math.pi
45
+ self.scale = scale
46
+
47
+ def forward(self, tensor_list: NestedTensor):
48
+ x = tensor_list.tensors
49
+ mask = tensor_list.mask
50
+ assert mask is not None
51
+ not_mask = ~mask
52
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
53
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
54
+ if self.normalize:
55
+ eps = 1e-6
56
+ # if os.environ.get("SHILONG_AMP", None) == '1':
57
+ # eps = 1e-4
58
+ # else:
59
+ # eps = 1e-6
60
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
61
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
62
+
63
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
64
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
65
+
66
+ pos_x = x_embed[:, :, :, None] / dim_t
67
+ pos_y = y_embed[:, :, :, None] / dim_t
68
+ pos_x = torch.stack(
69
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
70
+ ).flatten(3)
71
+ pos_y = torch.stack(
72
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
73
+ ).flatten(3)
74
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
75
+ return pos
76
+
77
+
78
+ class PositionEmbeddingSineHW(nn.Module):
79
+ """
80
+ This is a more standard version of the position embedding, very similar to the one
81
+ used by the Attention is all you need paper, generalized to work on images.
82
+ """
83
+
84
+ def __init__(
85
+ self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None
86
+ ):
87
+ super().__init__()
88
+ self.num_pos_feats = num_pos_feats
89
+ self.temperatureH = temperatureH
90
+ self.temperatureW = temperatureW
91
+ self.normalize = normalize
92
+ if scale is not None and normalize is False:
93
+ raise ValueError("normalize should be True if scale is passed")
94
+ if scale is None:
95
+ scale = 2 * math.pi
96
+ self.scale = scale
97
+
98
+ def forward(self, tensor_list: NestedTensor):
99
+ x = tensor_list.tensors
100
+ mask = tensor_list.mask
101
+ assert mask is not None
102
+ not_mask = ~mask
103
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
104
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
105
+
106
+ # import ipdb; ipdb.set_trace()
107
+
108
+ if self.normalize:
109
+ eps = 1e-6
110
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
111
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
112
+
113
+ dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
114
+ dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode='floor')) / self.num_pos_feats)
115
+ pos_x = x_embed[:, :, :, None] / dim_tx
116
+
117
+ dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
118
+ dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode='floor')) / self.num_pos_feats)
119
+ pos_y = y_embed[:, :, :, None] / dim_ty
120
+
121
+ pos_x = torch.stack(
122
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
123
+ ).flatten(3)
124
+ pos_y = torch.stack(
125
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
126
+ ).flatten(3)
127
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
128
+
129
+ # import ipdb; ipdb.set_trace()
130
+
131
+ return pos
132
+
133
+
134
+ class PositionEmbeddingLearned(nn.Module):
135
+ """
136
+ Absolute pos embedding, learned.
137
+ """
138
+
139
+ def __init__(self, num_pos_feats=256):
140
+ super().__init__()
141
+ self.row_embed = nn.Embedding(50, num_pos_feats)
142
+ self.col_embed = nn.Embedding(50, num_pos_feats)
143
+ self.reset_parameters()
144
+
145
+ def reset_parameters(self):
146
+ nn.init.uniform_(self.row_embed.weight)
147
+ nn.init.uniform_(self.col_embed.weight)
148
+
149
+ def forward(self, tensor_list: NestedTensor):
150
+ x = tensor_list.tensors
151
+ h, w = x.shape[-2:]
152
+ i = torch.arange(w, device=x.device)
153
+ j = torch.arange(h, device=x.device)
154
+ x_emb = self.col_embed(i)
155
+ y_emb = self.row_embed(j)
156
+ pos = (
157
+ torch.cat(
158
+ [
159
+ x_emb.unsqueeze(0).repeat(h, 1, 1),
160
+ y_emb.unsqueeze(1).repeat(1, w, 1),
161
+ ],
162
+ dim=-1,
163
+ )
164
+ .permute(2, 0, 1)
165
+ .unsqueeze(0)
166
+ .repeat(x.shape[0], 1, 1, 1)
167
+ )
168
+ return pos
169
+
170
+
171
+ def build_position_encoding(args):
172
+ N_steps = args.hidden_dim // 2
173
+ if args.position_embedding in ("v2", "sine"):
174
+ # TODO find a better way of exposing other arguments
175
+ position_embedding = PositionEmbeddingSineHW(
176
+ N_steps,
177
+ temperatureH=args.pe_temperatureH,
178
+ temperatureW=args.pe_temperatureW,
179
+ normalize=True,
180
+ )
181
+ elif args.position_embedding in ("v3", "learned"):
182
+ position_embedding = PositionEmbeddingLearned(N_steps)
183
+ else:
184
+ raise ValueError(f"not supported {args.position_embedding}")
185
+
186
+ return position_embedding
groundingdino/models/GroundingDINO/backbone/.ipynb_checkpoints/swin_transformer-checkpoint.py ADDED
@@ -0,0 +1,804 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # DINO
8
+ # Copyright (c) 2022 IDEA. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # --------------------------------------------------------
11
+ # modified from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py
12
+ # --------------------------------------------------------
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import torch.utils.checkpoint as checkpoint
19
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
20
+ import loralib as lora
21
+ from groundingdino.util.misc import NestedTensor
22
+
23
+
24
+ class Mlp(nn.Module):
25
+ """Multilayer perceptron."""
26
+
27
+ def __init__(
28
+ self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
29
+ ):
30
+ super().__init__()
31
+ out_features = out_features or in_features
32
+ hidden_features = hidden_features or in_features
33
+ r = 16
34
+ self.fc1 = lora.Linear(in_features, hidden_features , r=r)
35
+ self.act = act_layer()
36
+ self.fc2 = lora.Linear(hidden_features, out_features , r=r)
37
+ self.drop = nn.Dropout(drop)
38
+
39
+ def forward(self, x):
40
+ x = self.fc1(x)
41
+ x = self.act(x)
42
+ x = self.drop(x)
43
+ x = self.fc2(x)
44
+ x = self.drop(x)
45
+ return x
46
+
47
+
48
+ def window_partition(x, window_size):
49
+ """
50
+ Args:
51
+ x: (B, H, W, C)
52
+ window_size (int): window size
53
+ Returns:
54
+ windows: (num_windows*B, window_size, window_size, C)
55
+ """
56
+ B, H, W, C = x.shape
57
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
58
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
59
+ return windows
60
+
61
+
62
+ def window_reverse(windows, window_size, H, W):
63
+ """
64
+ Args:
65
+ windows: (num_windows*B, window_size, window_size, C)
66
+ window_size (int): Window size
67
+ H (int): Height of image
68
+ W (int): Width of image
69
+ Returns:
70
+ x: (B, H, W, C)
71
+ """
72
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
73
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
74
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
75
+ return x
76
+
77
+
78
+ class WindowAttention(nn.Module):
79
+ """Window based multi-head self attention (W-MSA) module with relative position bias.
80
+ It supports both of shifted and non-shifted window.
81
+ Args:
82
+ dim (int): Number of input channels.
83
+ window_size (tuple[int]): The height and width of the window.
84
+ num_heads (int): Number of attention heads.
85
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
86
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
87
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
88
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
89
+ """
90
+
91
+ def __init__(
92
+ self,
93
+ dim,
94
+ window_size,
95
+ num_heads,
96
+ qkv_bias=True,
97
+ qk_scale=None,
98
+ attn_drop=0.0,
99
+ proj_drop=0.0,
100
+ ):
101
+
102
+ super().__init__()
103
+ self.dim = dim
104
+ self.window_size = window_size # Wh, Ww
105
+ self.num_heads = num_heads
106
+ head_dim = dim // num_heads
107
+ self.scale = qk_scale or head_dim**-0.5
108
+ r =16
109
+ # define a parameter table of relative position bias
110
+ self.relative_position_bias_table = nn.Parameter(
111
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
112
+ ) # 2*Wh-1 * 2*Ww-1, nH
113
+
114
+ # get pair-wise relative position index for each token inside the window
115
+ coords_h = torch.arange(self.window_size[0])
116
+ coords_w = torch.arange(self.window_size[1])
117
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
118
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
119
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
120
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
121
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
122
+ relative_coords[:, :, 1] += self.window_size[1] - 1
123
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
124
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
125
+ self.register_buffer("relative_position_index", relative_position_index)
126
+
127
+ self.qkv = lora.Linear(dim, dim * 3,r=r , bias=qkv_bias)
128
+ self.attn_drop = nn.Dropout(attn_drop)
129
+ self.proj = lora.Linear(dim, dim , r=r)
130
+ self.proj_drop = nn.Dropout(proj_drop)
131
+
132
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
133
+ self.softmax = nn.Softmax(dim=-1)
134
+
135
+ def forward(self, x, mask=None):
136
+ """Forward function.
137
+ Args:
138
+ x: input features with shape of (num_windows*B, N, C)
139
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
140
+ """
141
+ B_, N, C = x.shape
142
+ qkv = (
143
+ self.qkv(x)
144
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
145
+ .permute(2, 0, 3, 1, 4)
146
+ )
147
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
148
+
149
+ q = q * self.scale
150
+ attn = q @ k.transpose(-2, -1)
151
+
152
+ relative_position_bias = self.relative_position_bias_table[
153
+ self.relative_position_index.view(-1)
154
+ ].view(
155
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
156
+ ) # Wh*Ww,Wh*Ww,nH
157
+ relative_position_bias = relative_position_bias.permute(
158
+ 2, 0, 1
159
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
160
+ attn = attn + relative_position_bias.unsqueeze(0)
161
+
162
+ if mask is not None:
163
+ nW = mask.shape[0]
164
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
165
+ attn = attn.view(-1, self.num_heads, N, N)
166
+ attn = self.softmax(attn)
167
+ else:
168
+ attn = self.softmax(attn)
169
+
170
+ attn = self.attn_drop(attn)
171
+
172
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
173
+ x = self.proj(x)
174
+ x = self.proj_drop(x)
175
+ return x
176
+
177
+
178
+ class SwinTransformerBlock(nn.Module):
179
+ """Swin Transformer Block.
180
+ Args:
181
+ dim (int): Number of input channels.
182
+ num_heads (int): Number of attention heads.
183
+ window_size (int): Window size.
184
+ shift_size (int): Shift size for SW-MSA.
185
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
186
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
187
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
188
+ drop (float, optional): Dropout rate. Default: 0.0
189
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
190
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
191
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
192
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
193
+ """
194
+
195
+ def __init__(
196
+ self,
197
+ dim,
198
+ num_heads,
199
+ window_size=7,
200
+ shift_size=0,
201
+ mlp_ratio=4.0,
202
+ qkv_bias=True,
203
+ qk_scale=None,
204
+ drop=0.0,
205
+ attn_drop=0.0,
206
+ drop_path=0.0,
207
+ act_layer=nn.GELU,
208
+ norm_layer=nn.LayerNorm,
209
+ ):
210
+ super().__init__()
211
+ self.dim = dim
212
+ self.num_heads = num_heads
213
+ self.window_size = window_size
214
+ self.shift_size = shift_size
215
+ self.mlp_ratio = mlp_ratio
216
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
217
+
218
+ self.norm1 = norm_layer(dim)
219
+ self.attn = WindowAttention(
220
+ dim,
221
+ window_size=to_2tuple(self.window_size),
222
+ num_heads=num_heads,
223
+ qkv_bias=qkv_bias,
224
+ qk_scale=qk_scale,
225
+ attn_drop=attn_drop,
226
+ proj_drop=drop,
227
+ )
228
+
229
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
230
+ self.norm2 = norm_layer(dim)
231
+ mlp_hidden_dim = int(dim * mlp_ratio)
232
+ self.mlp = Mlp(
233
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
234
+ )
235
+
236
+ self.H = None
237
+ self.W = None
238
+
239
+ def forward(self, x, mask_matrix):
240
+ """Forward function.
241
+ Args:
242
+ x: Input feature, tensor size (B, H*W, C).
243
+ H, W: Spatial resolution of the input feature.
244
+ mask_matrix: Attention mask for cyclic shift.
245
+ """
246
+ B, L, C = x.shape
247
+ H, W = self.H, self.W
248
+ assert L == H * W, "input feature has wrong size"
249
+
250
+ shortcut = x
251
+ x = self.norm1(x)
252
+ x = x.view(B, H, W, C)
253
+
254
+ # pad feature maps to multiples of window size
255
+ pad_l = pad_t = 0
256
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
257
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
258
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
259
+ _, Hp, Wp, _ = x.shape
260
+
261
+ # cyclic shift
262
+ if self.shift_size > 0:
263
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
264
+ attn_mask = mask_matrix
265
+ else:
266
+ shifted_x = x
267
+ attn_mask = None
268
+
269
+ # partition windows
270
+ x_windows = window_partition(
271
+ shifted_x, self.window_size
272
+ ) # nW*B, window_size, window_size, C
273
+ x_windows = x_windows.view(
274
+ -1, self.window_size * self.window_size, C
275
+ ) # nW*B, window_size*window_size, C
276
+
277
+ # W-MSA/SW-MSA
278
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
279
+
280
+ # merge windows
281
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
282
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
283
+
284
+ # reverse cyclic shift
285
+ if self.shift_size > 0:
286
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
287
+ else:
288
+ x = shifted_x
289
+
290
+ if pad_r > 0 or pad_b > 0:
291
+ x = x[:, :H, :W, :].contiguous()
292
+
293
+ x = x.view(B, H * W, C)
294
+
295
+ # FFN
296
+ x = shortcut + self.drop_path(x)
297
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
298
+
299
+ return x
300
+
301
+
302
+ class PatchMerging(nn.Module):
303
+ """Patch Merging Layer
304
+ Args:
305
+ dim (int): Number of input channels.
306
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
307
+ """
308
+
309
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
310
+ super().__init__()
311
+ self.dim = dim
312
+ r = 16
313
+ self.reduction = lora.Linear(4 * dim, 2 * dim, r=r, bias=False)
314
+ self.norm = norm_layer(4 * dim)
315
+
316
+ def forward(self, x, H, W):
317
+ """Forward function.
318
+ Args:
319
+ x: Input feature, tensor size (B, H*W, C).
320
+ H, W: Spatial resolution of the input feature.
321
+ """
322
+ B, L, C = x.shape
323
+ assert L == H * W, "input feature has wrong size"
324
+
325
+ x = x.view(B, H, W, C)
326
+
327
+ # padding
328
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
329
+ if pad_input:
330
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
331
+
332
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
333
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
334
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
335
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
336
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
337
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
338
+
339
+ x = self.norm(x)
340
+ x = self.reduction(x)
341
+
342
+ return x
343
+
344
+
345
+ class BasicLayer(nn.Module):
346
+ """A basic Swin Transformer layer for one stage.
347
+ Args:
348
+ dim (int): Number of feature channels
349
+ depth (int): Depths of this stage.
350
+ num_heads (int): Number of attention head.
351
+ window_size (int): Local window size. Default: 7.
352
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
353
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
354
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
355
+ drop (float, optional): Dropout rate. Default: 0.0
356
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
357
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
358
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
359
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
360
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
361
+ """
362
+
363
+ def __init__(
364
+ self,
365
+ dim,
366
+ depth,
367
+ num_heads,
368
+ window_size=7,
369
+ mlp_ratio=4.0,
370
+ qkv_bias=True,
371
+ qk_scale=None,
372
+ drop=0.0,
373
+ attn_drop=0.0,
374
+ drop_path=0.0,
375
+ norm_layer=nn.LayerNorm,
376
+ downsample=None,
377
+ use_checkpoint=False,
378
+ ):
379
+ super().__init__()
380
+ self.window_size = window_size
381
+ self.shift_size = window_size // 2
382
+ self.depth = depth
383
+ self.use_checkpoint = use_checkpoint
384
+
385
+ # build blocks
386
+ self.blocks = nn.ModuleList(
387
+ [
388
+ SwinTransformerBlock(
389
+ dim=dim,
390
+ num_heads=num_heads,
391
+ window_size=window_size,
392
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
393
+ mlp_ratio=mlp_ratio,
394
+ qkv_bias=qkv_bias,
395
+ qk_scale=qk_scale,
396
+ drop=drop,
397
+ attn_drop=attn_drop,
398
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
399
+ norm_layer=norm_layer,
400
+ )
401
+ for i in range(depth)
402
+ ]
403
+ )
404
+
405
+ # patch merging layer
406
+ if downsample is not None:
407
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
408
+ else:
409
+ self.downsample = None
410
+
411
+ def forward(self, x, H, W):
412
+ """Forward function.
413
+ Args:
414
+ x: Input feature, tensor size (B, H*W, C).
415
+ H, W: Spatial resolution of the input feature.
416
+ """
417
+
418
+ # calculate attention mask for SW-MSA
419
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
420
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
421
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
422
+ h_slices = (
423
+ slice(0, -self.window_size),
424
+ slice(-self.window_size, -self.shift_size),
425
+ slice(-self.shift_size, None),
426
+ )
427
+ w_slices = (
428
+ slice(0, -self.window_size),
429
+ slice(-self.window_size, -self.shift_size),
430
+ slice(-self.shift_size, None),
431
+ )
432
+ cnt = 0
433
+ for h in h_slices:
434
+ for w in w_slices:
435
+ img_mask[:, h, w, :] = cnt
436
+ cnt += 1
437
+
438
+ mask_windows = window_partition(
439
+ img_mask, self.window_size
440
+ ) # nW, window_size, window_size, 1
441
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
442
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
443
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
444
+ attn_mask == 0, float(0.0)
445
+ )
446
+
447
+ for blk in self.blocks:
448
+ blk.H, blk.W = H, W
449
+ if self.use_checkpoint:
450
+ x = checkpoint.checkpoint(blk, x, attn_mask)
451
+ else:
452
+ x = blk(x, attn_mask)
453
+ if self.downsample is not None:
454
+ x_down = self.downsample(x, H, W)
455
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
456
+ return x, H, W, x_down, Wh, Ww
457
+ else:
458
+ return x, H, W, x, H, W
459
+
460
+
461
+ class PatchEmbed(nn.Module):
462
+ """Image to Patch Embedding
463
+ Args:
464
+ patch_size (int): Patch token size. Default: 4.
465
+ in_chans (int): Number of input image channels. Default: 3.
466
+ embed_dim (int): Number of linear projection output channels. Default: 96.
467
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
468
+ """
469
+
470
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
471
+ super().__init__()
472
+ patch_size = to_2tuple(patch_size)
473
+ self.patch_size = patch_size
474
+
475
+ self.in_chans = in_chans
476
+ self.embed_dim = embed_dim
477
+
478
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
479
+ if norm_layer is not None:
480
+ self.norm = norm_layer(embed_dim)
481
+ else:
482
+ self.norm = None
483
+
484
+ def forward(self, x):
485
+ """Forward function."""
486
+ # padding
487
+ _, _, H, W = x.size()
488
+ if W % self.patch_size[1] != 0:
489
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
490
+ if H % self.patch_size[0] != 0:
491
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
492
+
493
+ x = self.proj(x) # B C Wh Ww
494
+ if self.norm is not None:
495
+ Wh, Ww = x.size(2), x.size(3)
496
+ x = x.flatten(2).transpose(1, 2)
497
+ x = self.norm(x)
498
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
499
+
500
+ return x
501
+
502
+
503
+ class SwinTransformer(nn.Module):
504
+ """Swin Transformer backbone.
505
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
506
+ https://arxiv.org/pdf/2103.14030
507
+ Args:
508
+ pretrain_img_size (int): Input image size for training the pretrained model,
509
+ used in absolute postion embedding. Default 224.
510
+ patch_size (int | tuple(int)): Patch size. Default: 4.
511
+ in_chans (int): Number of input image channels. Default: 3.
512
+ embed_dim (int): Number of linear projection output channels. Default: 96.
513
+ depths (tuple[int]): Depths of each Swin Transformer stage.
514
+ num_heads (tuple[int]): Number of attention head of each stage.
515
+ window_size (int): Window size. Default: 7.
516
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
517
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
518
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
519
+ drop_rate (float): Dropout rate.
520
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
521
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
522
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
523
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
524
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
525
+ out_indices (Sequence[int]): Output from which stages.
526
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
527
+ -1 means not freezing any parameters.
528
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
529
+ dilation (bool): if True, the output size if 16x downsample, ow 32x downsample.
530
+ """
531
+
532
+ def __init__(
533
+ self,
534
+ pretrain_img_size=224,
535
+ patch_size=4,
536
+ in_chans=3,
537
+ embed_dim=96,
538
+ depths=[2, 2, 6, 2],
539
+ num_heads=[3, 6, 12, 24],
540
+ window_size=7,
541
+ mlp_ratio=4.0,
542
+ qkv_bias=True,
543
+ qk_scale=None,
544
+ drop_rate=0.0,
545
+ attn_drop_rate=0.0,
546
+ drop_path_rate=0.2,
547
+ norm_layer=nn.LayerNorm,
548
+ ape=False,
549
+ patch_norm=True,
550
+ out_indices=(0, 1, 2, 3),
551
+ frozen_stages=-1,
552
+ dilation=False,
553
+ use_checkpoint=False,
554
+ ):
555
+ super().__init__()
556
+
557
+ self.pretrain_img_size = pretrain_img_size
558
+ self.num_layers = len(depths)
559
+ self.embed_dim = embed_dim
560
+ self.ape = ape
561
+ self.patch_norm = patch_norm
562
+ self.out_indices = out_indices
563
+ self.frozen_stages = frozen_stages
564
+ self.dilation = dilation
565
+
566
+ # if use_checkpoint:
567
+ # print("use_checkpoint!!!!!!!!!!!!!!!!!!!!!!!!")
568
+
569
+ # split image into non-overlapping patches
570
+ self.patch_embed = PatchEmbed(
571
+ patch_size=patch_size,
572
+ in_chans=in_chans,
573
+ embed_dim=embed_dim,
574
+ norm_layer=norm_layer if self.patch_norm else None,
575
+ )
576
+
577
+ # absolute position embedding
578
+ if self.ape:
579
+ pretrain_img_size = to_2tuple(pretrain_img_size)
580
+ patch_size = to_2tuple(patch_size)
581
+ patches_resolution = [
582
+ pretrain_img_size[0] // patch_size[0],
583
+ pretrain_img_size[1] // patch_size[1],
584
+ ]
585
+
586
+ self.absolute_pos_embed = nn.Parameter(
587
+ torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
588
+ )
589
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
590
+
591
+ self.pos_drop = nn.Dropout(p=drop_rate)
592
+
593
+ # stochastic depth
594
+ dpr = [
595
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
596
+ ] # stochastic depth decay rule
597
+
598
+ # build layers
599
+ self.layers = nn.ModuleList()
600
+ # prepare downsample list
601
+ downsamplelist = [PatchMerging for i in range(self.num_layers)]
602
+ downsamplelist[-1] = None
603
+ num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
604
+ if self.dilation:
605
+ downsamplelist[-2] = None
606
+ num_features[-1] = int(embed_dim * 2 ** (self.num_layers - 1)) // 2
607
+ for i_layer in range(self.num_layers):
608
+ layer = BasicLayer(
609
+ # dim=int(embed_dim * 2 ** i_layer),
610
+ dim=num_features[i_layer],
611
+ depth=depths[i_layer],
612
+ num_heads=num_heads[i_layer],
613
+ window_size=window_size,
614
+ mlp_ratio=mlp_ratio,
615
+ qkv_bias=qkv_bias,
616
+ qk_scale=qk_scale,
617
+ drop=drop_rate,
618
+ attn_drop=attn_drop_rate,
619
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
620
+ norm_layer=norm_layer,
621
+ # downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
622
+ downsample=downsamplelist[i_layer],
623
+ use_checkpoint=use_checkpoint,
624
+ )
625
+ self.layers.append(layer)
626
+
627
+ # num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
628
+ self.num_features = num_features
629
+
630
+ # add a norm layer for each output
631
+ for i_layer in out_indices:
632
+ layer = norm_layer(num_features[i_layer])
633
+ layer_name = f"norm{i_layer}"
634
+ self.add_module(layer_name, layer)
635
+
636
+ self._freeze_stages()
637
+
638
+ def _freeze_stages(self):
639
+ if self.frozen_stages >= 0:
640
+ self.patch_embed.eval()
641
+ for param in self.patch_embed.parameters():
642
+ param.requires_grad = False
643
+
644
+ if self.frozen_stages >= 1 and self.ape:
645
+ self.absolute_pos_embed.requires_grad = False
646
+
647
+ if self.frozen_stages >= 2:
648
+ self.pos_drop.eval()
649
+ for i in range(0, self.frozen_stages - 1):
650
+ m = self.layers[i]
651
+ m.eval()
652
+ for param in m.parameters():
653
+ param.requires_grad = False
654
+
655
+ # def init_weights(self, pretrained=None):
656
+ # """Initialize the weights in backbone.
657
+ # Args:
658
+ # pretrained (str, optional): Path to pre-trained weights.
659
+ # Defaults to None.
660
+ # """
661
+
662
+ # def _init_weights(m):
663
+ # if isinstance(m, nn.Linear):
664
+ # trunc_normal_(m.weight, std=.02)
665
+ # if isinstance(m, nn.Linear) and m.bias is not None:
666
+ # nn.init.constant_(m.bias, 0)
667
+ # elif isinstance(m, nn.LayerNorm):
668
+ # nn.init.constant_(m.bias, 0)
669
+ # nn.init.constant_(m.weight, 1.0)
670
+
671
+ # if isinstance(pretrained, str):
672
+ # self.apply(_init_weights)
673
+ # logger = get_root_logger()
674
+ # load_checkpoint(self, pretrained, strict=False, logger=logger)
675
+ # elif pretrained is None:
676
+ # self.apply(_init_weights)
677
+ # else:
678
+ # raise TypeError('pretrained must be a str or None')
679
+
680
+ def forward_raw(self, x):
681
+ """Forward function."""
682
+ x = self.patch_embed(x)
683
+
684
+ Wh, Ww = x.size(2), x.size(3)
685
+ if self.ape:
686
+ # interpolate the position embedding to the corresponding size
687
+ absolute_pos_embed = F.interpolate(
688
+ self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
689
+ )
690
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
691
+ else:
692
+ x = x.flatten(2).transpose(1, 2)
693
+ x = self.pos_drop(x)
694
+
695
+ outs = []
696
+ for i in range(self.num_layers):
697
+ layer = self.layers[i]
698
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
699
+ # import ipdb; ipdb.set_trace()
700
+
701
+ if i in self.out_indices:
702
+ norm_layer = getattr(self, f"norm{i}")
703
+ x_out = norm_layer(x_out)
704
+
705
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
706
+ outs.append(out)
707
+ # in:
708
+ # torch.Size([2, 3, 1024, 1024])
709
+ # outs:
710
+ # [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
711
+ # torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
712
+ return tuple(outs)
713
+
714
+ def forward(self, tensor_list: NestedTensor):
715
+ x = tensor_list.tensors
716
+
717
+ """Forward function."""
718
+ x = self.patch_embed(x)
719
+
720
+ Wh, Ww = x.size(2), x.size(3)
721
+ if self.ape:
722
+ # interpolate the position embedding to the corresponding size
723
+ absolute_pos_embed = F.interpolate(
724
+ self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
725
+ )
726
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
727
+ else:
728
+ x = x.flatten(2).transpose(1, 2)
729
+ x = self.pos_drop(x)
730
+
731
+ outs = []
732
+ for i in range(self.num_layers):
733
+ layer = self.layers[i]
734
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
735
+
736
+ if i in self.out_indices:
737
+ norm_layer = getattr(self, f"norm{i}")
738
+ x_out = norm_layer(x_out)
739
+
740
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
741
+ outs.append(out)
742
+ # in:
743
+ # torch.Size([2, 3, 1024, 1024])
744
+ # out:
745
+ # [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
746
+ # torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
747
+
748
+ # collect for nesttensors
749
+ outs_dict = {}
750
+ for idx, out_i in enumerate(outs):
751
+ m = tensor_list.mask
752
+ assert m is not None
753
+ mask = F.interpolate(m[None].float(), size=out_i.shape[-2:]).to(torch.bool)[0]
754
+ outs_dict[idx] = NestedTensor(out_i, mask)
755
+
756
+ return outs_dict
757
+
758
+ def train(self, mode=True):
759
+ """Convert the model into training mode while keep layers freezed."""
760
+ super(SwinTransformer, self).train(mode)
761
+ self._freeze_stages()
762
+
763
+
764
+ def build_swin_transformer(modelname, pretrain_img_size, **kw):
765
+ assert modelname in [
766
+ "swin_T_224_1k",
767
+ "swin_B_224_22k",
768
+ "swin_B_384_22k",
769
+ "swin_L_224_22k",
770
+ "swin_L_384_22k",
771
+ ]
772
+
773
+ model_para_dict = {
774
+ "swin_T_224_1k": dict(
775
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7
776
+ ),
777
+ "swin_B_224_22k": dict(
778
+ embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=7
779
+ ),
780
+ "swin_B_384_22k": dict(
781
+ embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12
782
+ ),
783
+ "swin_L_224_22k": dict(
784
+ embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=7
785
+ ),
786
+ "swin_L_384_22k": dict(
787
+ embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12
788
+ ),
789
+ }
790
+ kw_cgf = model_para_dict[modelname]
791
+ kw_cgf.update(kw)
792
+ model = SwinTransformer(pretrain_img_size=pretrain_img_size, **kw_cgf)
793
+ return model
794
+
795
+
796
+ if __name__ == "__main__":
797
+ model = build_swin_transformer("swin_L_384_22k", 384, dilation=True)
798
+ x = torch.rand(2, 3, 1024, 1024)
799
+ y = model.forward_raw(x)
800
+ import ipdb
801
+
802
+ ipdb.set_trace()
803
+ x = torch.rand(2, 3, 384, 384)
804
+ y = model.forward_raw(x)
groundingdino/models/GroundingDINO/backbone/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (259 Bytes). View file
 
groundingdino/models/GroundingDINO/backbone/__pycache__/backbone.cpython-310.pyc ADDED
Binary file (6.26 kB). View file
 
groundingdino/models/GroundingDINO/backbone/__pycache__/position_encoding.cpython-310.pyc ADDED
Binary file (5.18 kB). View file
 
groundingdino/models/GroundingDINO/backbone/__pycache__/swin_transformer.cpython-310.pyc ADDED
Binary file (20.7 kB). View file
 
groundingdino/models/GroundingDINO/backbone/backbone.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # Conditional DETR
8
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # ------------------------------------------------------------------------
11
+ # Copied from DETR (https://github.com/facebookresearch/detr)
12
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
13
+ # ------------------------------------------------------------------------
14
+
15
+
16
+ # Backbone modules.
17
+
18
+ from typing import Dict, List
19
+ import loralib as lora
20
+ import torch
21
+ import torch.nn.functional as F
22
+ import torchvision
23
+ from torch import nn
24
+ from torchvision.models._utils import IntermediateLayerGetter
25
+ import loralib as lora
26
+ from groundingdino.util.misc import NestedTensor, clean_state_dict, is_main_process
27
+
28
+ from .position_encoding import build_position_encoding
29
+ from .swin_transformer import build_swin_transformer
30
+
31
+
32
+ class FrozenBatchNorm2d(torch.nn.Module):
33
+ """
34
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
35
+
36
+ Copy-paste from torchvision.misc.ops with added eps before rqsrt,
37
+ without which any other models than torchvision.models.resnet[18,34,50,101]
38
+ produce nans.
39
+ """
40
+
41
+ def __init__(self, n):
42
+ super(FrozenBatchNorm2d, self).__init__()
43
+ self.register_buffer("weight", torch.ones(n))
44
+ self.register_buffer("bias", torch.zeros(n))
45
+ self.register_buffer("running_mean", torch.zeros(n))
46
+ self.register_buffer("running_var", torch.ones(n))
47
+
48
+ def _load_from_state_dict(
49
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
50
+ ):
51
+ num_batches_tracked_key = prefix + "num_batches_tracked"
52
+ if num_batches_tracked_key in state_dict:
53
+ del state_dict[num_batches_tracked_key]
54
+
55
+ super(FrozenBatchNorm2d, self)._load_from_state_dict(
56
+ state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
57
+ )
58
+
59
+ def forward(self, x):
60
+ # move reshapes to the beginning
61
+ # to make it fuser-friendly
62
+ w = self.weight.reshape(1, -1, 1, 1)
63
+ b = self.bias.reshape(1, -1, 1, 1)
64
+ rv = self.running_var.reshape(1, -1, 1, 1)
65
+ rm = self.running_mean.reshape(1, -1, 1, 1)
66
+ eps = 1e-5
67
+ scale = w * (rv + eps).rsqrt()
68
+ bias = b - rm * scale
69
+ return x * scale + bias
70
+
71
+
72
+ class BackboneBase(nn.Module):
73
+ def __init__(
74
+ self,
75
+ backbone: nn.Module,
76
+ train_backbone: bool,
77
+ num_channels: int,
78
+ return_interm_indices: list,
79
+ ):
80
+ super().__init__()
81
+ for name, parameter in backbone.named_parameters():
82
+ if (
83
+ not train_backbone
84
+ or "layer2" not in name
85
+ and "layer3" not in name
86
+ and "layer4" not in name
87
+ ):
88
+ parameter.requires_grad_(False)
89
+
90
+ return_layers = {}
91
+ for idx, layer_index in enumerate(return_interm_indices):
92
+ return_layers.update(
93
+ {"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)}
94
+ )
95
+
96
+ # if len:
97
+ # if use_stage1_feature:
98
+ # return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
99
+ # else:
100
+ # return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
101
+ # else:
102
+ # return_layers = {'layer4': "0"}
103
+ self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
104
+ self.num_channels = num_channels
105
+
106
+ def forward(self, tensor_list: NestedTensor):
107
+ xs = self.body(tensor_list.tensors)
108
+ out: Dict[str, NestedTensor] = {}
109
+ for name, x in xs.items():
110
+ m = tensor_list.mask
111
+ assert m is not None
112
+ mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
113
+ out[name] = NestedTensor(x, mask)
114
+ # import ipdb; ipdb.set_trace()
115
+ return out
116
+
117
+
118
+ class Backbone(BackboneBase):
119
+ """ResNet backbone with frozen BatchNorm."""
120
+
121
+ def __init__(
122
+ self,
123
+ name: str,
124
+ train_backbone: bool,
125
+ dilation: bool,
126
+ return_interm_indices: list,
127
+ batch_norm=FrozenBatchNorm2d,
128
+ ):
129
+ if name in ["resnet18", "resnet34", "resnet50", "resnet101"]:
130
+ backbone = getattr(torchvision.models, name)(
131
+ replace_stride_with_dilation=[False, False, dilation],
132
+ pretrained=is_main_process(),
133
+ norm_layer=batch_norm,
134
+ )
135
+ else:
136
+ raise NotImplementedError("Why you can get here with name {}".format(name))
137
+ # num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
138
+ assert name not in ("resnet18", "resnet34"), "Only resnet50 and resnet101 are available."
139
+ assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
140
+ num_channels_all = [256, 512, 1024, 2048]
141
+ num_channels = num_channels_all[4 - len(return_interm_indices) :]
142
+ super().__init__(backbone, train_backbone, num_channels, return_interm_indices)
143
+
144
+
145
+ class Joiner(nn.Sequential):
146
+ def __init__(self, backbone, position_embedding):
147
+ super().__init__(backbone, position_embedding)
148
+
149
+ def forward(self, tensor_list: NestedTensor):
150
+ xs = self[0](tensor_list)
151
+ out: List[NestedTensor] = []
152
+ pos = []
153
+ for name, x in xs.items():
154
+ out.append(x)
155
+ # position encoding
156
+ pos.append(self[1](x).to(x.tensors.dtype))
157
+
158
+ return out, pos
159
+
160
+
161
+ def build_backbone(args):
162
+ """
163
+ Useful args:
164
+ - backbone: backbone name
165
+ - lr_backbone:
166
+ - dilation
167
+ - return_interm_indices: available: [0,1,2,3], [1,2,3], [3]
168
+ - backbone_freeze_keywords:
169
+ - use_checkpoint: for swin only for now
170
+
171
+ """
172
+ position_embedding = build_position_encoding(args)
173
+ train_backbone = True
174
+ if not train_backbone:
175
+ raise ValueError("Please set lr_backbone > 0")
176
+ return_interm_indices = args.return_interm_indices
177
+ assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
178
+ args.backbone_freeze_keywords
179
+ use_checkpoint = getattr(args, "use_checkpoint", False)
180
+
181
+ if args.backbone in ["resnet50", "resnet101"]:
182
+ backbone = Backbone(
183
+ args.backbone,
184
+ train_backbone,
185
+ args.dilation,
186
+ return_interm_indices,
187
+ batch_norm=FrozenBatchNorm2d,
188
+ )
189
+ bb_num_channels = backbone.num_channels
190
+ elif args.backbone in [
191
+ "swin_T_224_1k",
192
+ "swin_B_224_22k",
193
+ "swin_B_384_22k",
194
+ "swin_L_224_22k",
195
+ "swin_L_384_22k",
196
+ ]:
197
+ pretrain_img_size = int(args.backbone.split("_")[-2])
198
+ backbone = build_swin_transformer(
199
+ args.backbone,
200
+ pretrain_img_size=pretrain_img_size,
201
+ out_indices=tuple(return_interm_indices),
202
+ dilation=False,
203
+ use_checkpoint=use_checkpoint,
204
+ )
205
+
206
+ bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :]
207
+ else:
208
+ raise NotImplementedError("Unknown backbone {}".format(args.backbone))
209
+
210
+ assert len(bb_num_channels) == len(
211
+ return_interm_indices
212
+ ), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}"
213
+
214
+ model = Joiner(backbone, position_embedding)
215
+ model.num_channels = bb_num_channels
216
+ assert isinstance(
217
+ bb_num_channels, List
218
+ ), "bb_num_channels is expected to be a List but {}".format(type(bb_num_channels))
219
+ # import ipdb; ipdb.set_trace()
220
+ return model
groundingdino/models/GroundingDINO/backbone/position_encoding.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # DINO
8
+ # Copyright (c) 2022 IDEA. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # ------------------------------------------------------------------------
11
+ # Conditional DETR
12
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
13
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
14
+ # ------------------------------------------------------------------------
15
+ # Copied from DETR (https://github.com/facebookresearch/detr)
16
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
17
+ # ------------------------------------------------------------------------
18
+
19
+ """
20
+ Various positional encodings for the transformer.
21
+ """
22
+ import math
23
+
24
+ import torch
25
+ from torch import nn
26
+ import loralib as lora
27
+ from groundingdino.util.misc import NestedTensor
28
+
29
+
30
+ class PositionEmbeddingSine(nn.Module):
31
+ """
32
+ This is a more standard version of the position embedding, very similar to the one
33
+ used by the Attention is all you need paper, generalized to work on images.
34
+ """
35
+
36
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
37
+ super().__init__()
38
+ self.num_pos_feats = num_pos_feats
39
+ self.temperature = temperature
40
+ self.normalize = normalize
41
+ if scale is not None and normalize is False:
42
+ raise ValueError("normalize should be True if scale is passed")
43
+ if scale is None:
44
+ scale = 2 * math.pi
45
+ self.scale = scale
46
+
47
+ def forward(self, tensor_list: NestedTensor):
48
+ x = tensor_list.tensors
49
+ mask = tensor_list.mask
50
+ assert mask is not None
51
+ not_mask = ~mask
52
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
53
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
54
+ if self.normalize:
55
+ eps = 1e-6
56
+ # if os.environ.get("SHILONG_AMP", None) == '1':
57
+ # eps = 1e-4
58
+ # else:
59
+ # eps = 1e-6
60
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
61
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
62
+
63
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
64
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
65
+
66
+ pos_x = x_embed[:, :, :, None] / dim_t
67
+ pos_y = y_embed[:, :, :, None] / dim_t
68
+ pos_x = torch.stack(
69
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
70
+ ).flatten(3)
71
+ pos_y = torch.stack(
72
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
73
+ ).flatten(3)
74
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
75
+ return pos
76
+
77
+
78
+ class PositionEmbeddingSineHW(nn.Module):
79
+ """
80
+ This is a more standard version of the position embedding, very similar to the one
81
+ used by the Attention is all you need paper, generalized to work on images.
82
+ """
83
+
84
+ def __init__(
85
+ self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None
86
+ ):
87
+ super().__init__()
88
+ self.num_pos_feats = num_pos_feats
89
+ self.temperatureH = temperatureH
90
+ self.temperatureW = temperatureW
91
+ self.normalize = normalize
92
+ if scale is not None and normalize is False:
93
+ raise ValueError("normalize should be True if scale is passed")
94
+ if scale is None:
95
+ scale = 2 * math.pi
96
+ self.scale = scale
97
+
98
+ def forward(self, tensor_list: NestedTensor):
99
+ x = tensor_list.tensors
100
+ mask = tensor_list.mask
101
+ assert mask is not None
102
+ not_mask = ~mask
103
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
104
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
105
+
106
+ # import ipdb; ipdb.set_trace()
107
+
108
+ if self.normalize:
109
+ eps = 1e-6
110
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
111
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
112
+
113
+ dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
114
+ dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode='floor')) / self.num_pos_feats)
115
+ pos_x = x_embed[:, :, :, None] / dim_tx
116
+
117
+ dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
118
+ dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode='floor')) / self.num_pos_feats)
119
+ pos_y = y_embed[:, :, :, None] / dim_ty
120
+
121
+ pos_x = torch.stack(
122
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
123
+ ).flatten(3)
124
+ pos_y = torch.stack(
125
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
126
+ ).flatten(3)
127
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
128
+
129
+ # import ipdb; ipdb.set_trace()
130
+
131
+ return pos
132
+
133
+
134
+ class PositionEmbeddingLearned(nn.Module):
135
+ """
136
+ Absolute pos embedding, learned.
137
+ """
138
+
139
+ def __init__(self, num_pos_feats=256):
140
+ super().__init__()
141
+ self.row_embed = nn.Embedding(50, num_pos_feats)
142
+ self.col_embed = nn.Embedding(50, num_pos_feats)
143
+ self.reset_parameters()
144
+
145
+ def reset_parameters(self):
146
+ nn.init.uniform_(self.row_embed.weight)
147
+ nn.init.uniform_(self.col_embed.weight)
148
+
149
+ def forward(self, tensor_list: NestedTensor):
150
+ x = tensor_list.tensors
151
+ h, w = x.shape[-2:]
152
+ i = torch.arange(w, device=x.device)
153
+ j = torch.arange(h, device=x.device)
154
+ x_emb = self.col_embed(i)
155
+ y_emb = self.row_embed(j)
156
+ pos = (
157
+ torch.cat(
158
+ [
159
+ x_emb.unsqueeze(0).repeat(h, 1, 1),
160
+ y_emb.unsqueeze(1).repeat(1, w, 1),
161
+ ],
162
+ dim=-1,
163
+ )
164
+ .permute(2, 0, 1)
165
+ .unsqueeze(0)
166
+ .repeat(x.shape[0], 1, 1, 1)
167
+ )
168
+ return pos
169
+
170
+
171
+ def build_position_encoding(args):
172
+ N_steps = args.hidden_dim // 2
173
+ if args.position_embedding in ("v2", "sine"):
174
+ # TODO find a better way of exposing other arguments
175
+ position_embedding = PositionEmbeddingSineHW(
176
+ N_steps,
177
+ temperatureH=args.pe_temperatureH,
178
+ temperatureW=args.pe_temperatureW,
179
+ normalize=True,
180
+ )
181
+ elif args.position_embedding in ("v3", "learned"):
182
+ position_embedding = PositionEmbeddingLearned(N_steps)
183
+ else:
184
+ raise ValueError(f"not supported {args.position_embedding}")
185
+
186
+ return position_embedding
groundingdino/models/GroundingDINO/backbone/swin_transformer.py ADDED
@@ -0,0 +1,804 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # DINO
8
+ # Copyright (c) 2022 IDEA. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # --------------------------------------------------------
11
+ # modified from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py
12
+ # --------------------------------------------------------
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import torch.utils.checkpoint as checkpoint
19
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
20
+ import loralib as lora
21
+ from groundingdino.util.misc import NestedTensor
22
+
23
+
24
+ class Mlp(nn.Module):
25
+ """Multilayer perceptron."""
26
+
27
+ def __init__(
28
+ self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
29
+ ):
30
+ super().__init__()
31
+ out_features = out_features or in_features
32
+ hidden_features = hidden_features or in_features
33
+ r = 16
34
+ self.fc1 = lora.Linear(in_features, hidden_features , r=r)
35
+ self.act = act_layer()
36
+ self.fc2 = lora.Linear(hidden_features, out_features , r=r)
37
+ self.drop = nn.Dropout(drop)
38
+
39
+ def forward(self, x):
40
+ x = self.fc1(x)
41
+ x = self.act(x)
42
+ x = self.drop(x)
43
+ x = self.fc2(x)
44
+ x = self.drop(x)
45
+ return x
46
+
47
+
48
+ def window_partition(x, window_size):
49
+ """
50
+ Args:
51
+ x: (B, H, W, C)
52
+ window_size (int): window size
53
+ Returns:
54
+ windows: (num_windows*B, window_size, window_size, C)
55
+ """
56
+ B, H, W, C = x.shape
57
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
58
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
59
+ return windows
60
+
61
+
62
+ def window_reverse(windows, window_size, H, W):
63
+ """
64
+ Args:
65
+ windows: (num_windows*B, window_size, window_size, C)
66
+ window_size (int): Window size
67
+ H (int): Height of image
68
+ W (int): Width of image
69
+ Returns:
70
+ x: (B, H, W, C)
71
+ """
72
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
73
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
74
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
75
+ return x
76
+
77
+
78
+ class WindowAttention(nn.Module):
79
+ """Window based multi-head self attention (W-MSA) module with relative position bias.
80
+ It supports both of shifted and non-shifted window.
81
+ Args:
82
+ dim (int): Number of input channels.
83
+ window_size (tuple[int]): The height and width of the window.
84
+ num_heads (int): Number of attention heads.
85
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
86
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
87
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
88
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
89
+ """
90
+
91
+ def __init__(
92
+ self,
93
+ dim,
94
+ window_size,
95
+ num_heads,
96
+ qkv_bias=True,
97
+ qk_scale=None,
98
+ attn_drop=0.0,
99
+ proj_drop=0.0,
100
+ ):
101
+
102
+ super().__init__()
103
+ self.dim = dim
104
+ self.window_size = window_size # Wh, Ww
105
+ self.num_heads = num_heads
106
+ head_dim = dim // num_heads
107
+ self.scale = qk_scale or head_dim**-0.5
108
+ r =16
109
+ # define a parameter table of relative position bias
110
+ self.relative_position_bias_table = nn.Parameter(
111
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
112
+ ) # 2*Wh-1 * 2*Ww-1, nH
113
+
114
+ # get pair-wise relative position index for each token inside the window
115
+ coords_h = torch.arange(self.window_size[0])
116
+ coords_w = torch.arange(self.window_size[1])
117
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
118
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
119
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
120
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
121
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
122
+ relative_coords[:, :, 1] += self.window_size[1] - 1
123
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
124
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
125
+ self.register_buffer("relative_position_index", relative_position_index)
126
+
127
+ self.qkv = lora.Linear(dim, dim * 3,r=r , bias=qkv_bias)
128
+ self.attn_drop = nn.Dropout(attn_drop)
129
+ self.proj = lora.Linear(dim, dim , r=r)
130
+ self.proj_drop = nn.Dropout(proj_drop)
131
+
132
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
133
+ self.softmax = nn.Softmax(dim=-1)
134
+
135
+ def forward(self, x, mask=None):
136
+ """Forward function.
137
+ Args:
138
+ x: input features with shape of (num_windows*B, N, C)
139
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
140
+ """
141
+ B_, N, C = x.shape
142
+ qkv = (
143
+ self.qkv(x)
144
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
145
+ .permute(2, 0, 3, 1, 4)
146
+ )
147
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
148
+
149
+ q = q * self.scale
150
+ attn = q @ k.transpose(-2, -1)
151
+
152
+ relative_position_bias = self.relative_position_bias_table[
153
+ self.relative_position_index.view(-1)
154
+ ].view(
155
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
156
+ ) # Wh*Ww,Wh*Ww,nH
157
+ relative_position_bias = relative_position_bias.permute(
158
+ 2, 0, 1
159
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
160
+ attn = attn + relative_position_bias.unsqueeze(0)
161
+
162
+ if mask is not None:
163
+ nW = mask.shape[0]
164
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
165
+ attn = attn.view(-1, self.num_heads, N, N)
166
+ attn = self.softmax(attn)
167
+ else:
168
+ attn = self.softmax(attn)
169
+
170
+ attn = self.attn_drop(attn)
171
+
172
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
173
+ x = self.proj(x)
174
+ x = self.proj_drop(x)
175
+ return x
176
+
177
+
178
+ class SwinTransformerBlock(nn.Module):
179
+ """Swin Transformer Block.
180
+ Args:
181
+ dim (int): Number of input channels.
182
+ num_heads (int): Number of attention heads.
183
+ window_size (int): Window size.
184
+ shift_size (int): Shift size for SW-MSA.
185
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
186
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
187
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
188
+ drop (float, optional): Dropout rate. Default: 0.0
189
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
190
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
191
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
192
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
193
+ """
194
+
195
+ def __init__(
196
+ self,
197
+ dim,
198
+ num_heads,
199
+ window_size=7,
200
+ shift_size=0,
201
+ mlp_ratio=4.0,
202
+ qkv_bias=True,
203
+ qk_scale=None,
204
+ drop=0.0,
205
+ attn_drop=0.0,
206
+ drop_path=0.0,
207
+ act_layer=nn.GELU,
208
+ norm_layer=nn.LayerNorm,
209
+ ):
210
+ super().__init__()
211
+ self.dim = dim
212
+ self.num_heads = num_heads
213
+ self.window_size = window_size
214
+ self.shift_size = shift_size
215
+ self.mlp_ratio = mlp_ratio
216
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
217
+
218
+ self.norm1 = norm_layer(dim)
219
+ self.attn = WindowAttention(
220
+ dim,
221
+ window_size=to_2tuple(self.window_size),
222
+ num_heads=num_heads,
223
+ qkv_bias=qkv_bias,
224
+ qk_scale=qk_scale,
225
+ attn_drop=attn_drop,
226
+ proj_drop=drop,
227
+ )
228
+
229
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
230
+ self.norm2 = norm_layer(dim)
231
+ mlp_hidden_dim = int(dim * mlp_ratio)
232
+ self.mlp = Mlp(
233
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
234
+ )
235
+
236
+ self.H = None
237
+ self.W = None
238
+
239
+ def forward(self, x, mask_matrix):
240
+ """Forward function.
241
+ Args:
242
+ x: Input feature, tensor size (B, H*W, C).
243
+ H, W: Spatial resolution of the input feature.
244
+ mask_matrix: Attention mask for cyclic shift.
245
+ """
246
+ B, L, C = x.shape
247
+ H, W = self.H, self.W
248
+ assert L == H * W, "input feature has wrong size"
249
+
250
+ shortcut = x
251
+ x = self.norm1(x)
252
+ x = x.view(B, H, W, C)
253
+
254
+ # pad feature maps to multiples of window size
255
+ pad_l = pad_t = 0
256
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
257
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
258
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
259
+ _, Hp, Wp, _ = x.shape
260
+
261
+ # cyclic shift
262
+ if self.shift_size > 0:
263
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
264
+ attn_mask = mask_matrix
265
+ else:
266
+ shifted_x = x
267
+ attn_mask = None
268
+
269
+ # partition windows
270
+ x_windows = window_partition(
271
+ shifted_x, self.window_size
272
+ ) # nW*B, window_size, window_size, C
273
+ x_windows = x_windows.view(
274
+ -1, self.window_size * self.window_size, C
275
+ ) # nW*B, window_size*window_size, C
276
+
277
+ # W-MSA/SW-MSA
278
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
279
+
280
+ # merge windows
281
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
282
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
283
+
284
+ # reverse cyclic shift
285
+ if self.shift_size > 0:
286
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
287
+ else:
288
+ x = shifted_x
289
+
290
+ if pad_r > 0 or pad_b > 0:
291
+ x = x[:, :H, :W, :].contiguous()
292
+
293
+ x = x.view(B, H * W, C)
294
+
295
+ # FFN
296
+ x = shortcut + self.drop_path(x)
297
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
298
+
299
+ return x
300
+
301
+
302
+ class PatchMerging(nn.Module):
303
+ """Patch Merging Layer
304
+ Args:
305
+ dim (int): Number of input channels.
306
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
307
+ """
308
+
309
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
310
+ super().__init__()
311
+ self.dim = dim
312
+ r = 16
313
+ self.reduction = lora.Linear(4 * dim, 2 * dim, r=r, bias=False)
314
+ self.norm = norm_layer(4 * dim)
315
+
316
+ def forward(self, x, H, W):
317
+ """Forward function.
318
+ Args:
319
+ x: Input feature, tensor size (B, H*W, C).
320
+ H, W: Spatial resolution of the input feature.
321
+ """
322
+ B, L, C = x.shape
323
+ assert L == H * W, "input feature has wrong size"
324
+
325
+ x = x.view(B, H, W, C)
326
+
327
+ # padding
328
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
329
+ if pad_input:
330
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
331
+
332
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
333
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
334
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
335
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
336
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
337
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
338
+
339
+ x = self.norm(x)
340
+ x = self.reduction(x)
341
+
342
+ return x
343
+
344
+
345
+ class BasicLayer(nn.Module):
346
+ """A basic Swin Transformer layer for one stage.
347
+ Args:
348
+ dim (int): Number of feature channels
349
+ depth (int): Depths of this stage.
350
+ num_heads (int): Number of attention head.
351
+ window_size (int): Local window size. Default: 7.
352
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
353
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
354
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
355
+ drop (float, optional): Dropout rate. Default: 0.0
356
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
357
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
358
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
359
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
360
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
361
+ """
362
+
363
+ def __init__(
364
+ self,
365
+ dim,
366
+ depth,
367
+ num_heads,
368
+ window_size=7,
369
+ mlp_ratio=4.0,
370
+ qkv_bias=True,
371
+ qk_scale=None,
372
+ drop=0.0,
373
+ attn_drop=0.0,
374
+ drop_path=0.0,
375
+ norm_layer=nn.LayerNorm,
376
+ downsample=None,
377
+ use_checkpoint=False,
378
+ ):
379
+ super().__init__()
380
+ self.window_size = window_size
381
+ self.shift_size = window_size // 2
382
+ self.depth = depth
383
+ self.use_checkpoint = use_checkpoint
384
+
385
+ # build blocks
386
+ self.blocks = nn.ModuleList(
387
+ [
388
+ SwinTransformerBlock(
389
+ dim=dim,
390
+ num_heads=num_heads,
391
+ window_size=window_size,
392
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
393
+ mlp_ratio=mlp_ratio,
394
+ qkv_bias=qkv_bias,
395
+ qk_scale=qk_scale,
396
+ drop=drop,
397
+ attn_drop=attn_drop,
398
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
399
+ norm_layer=norm_layer,
400
+ )
401
+ for i in range(depth)
402
+ ]
403
+ )
404
+
405
+ # patch merging layer
406
+ if downsample is not None:
407
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
408
+ else:
409
+ self.downsample = None
410
+
411
+ def forward(self, x, H, W):
412
+ """Forward function.
413
+ Args:
414
+ x: Input feature, tensor size (B, H*W, C).
415
+ H, W: Spatial resolution of the input feature.
416
+ """
417
+
418
+ # calculate attention mask for SW-MSA
419
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
420
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
421
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
422
+ h_slices = (
423
+ slice(0, -self.window_size),
424
+ slice(-self.window_size, -self.shift_size),
425
+ slice(-self.shift_size, None),
426
+ )
427
+ w_slices = (
428
+ slice(0, -self.window_size),
429
+ slice(-self.window_size, -self.shift_size),
430
+ slice(-self.shift_size, None),
431
+ )
432
+ cnt = 0
433
+ for h in h_slices:
434
+ for w in w_slices:
435
+ img_mask[:, h, w, :] = cnt
436
+ cnt += 1
437
+
438
+ mask_windows = window_partition(
439
+ img_mask, self.window_size
440
+ ) # nW, window_size, window_size, 1
441
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
442
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
443
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
444
+ attn_mask == 0, float(0.0)
445
+ )
446
+
447
+ for blk in self.blocks:
448
+ blk.H, blk.W = H, W
449
+ if self.use_checkpoint:
450
+ x = checkpoint.checkpoint(blk, x, attn_mask)
451
+ else:
452
+ x = blk(x, attn_mask)
453
+ if self.downsample is not None:
454
+ x_down = self.downsample(x, H, W)
455
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
456
+ return x, H, W, x_down, Wh, Ww
457
+ else:
458
+ return x, H, W, x, H, W
459
+
460
+
461
+ class PatchEmbed(nn.Module):
462
+ """Image to Patch Embedding
463
+ Args:
464
+ patch_size (int): Patch token size. Default: 4.
465
+ in_chans (int): Number of input image channels. Default: 3.
466
+ embed_dim (int): Number of linear projection output channels. Default: 96.
467
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
468
+ """
469
+
470
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
471
+ super().__init__()
472
+ patch_size = to_2tuple(patch_size)
473
+ self.patch_size = patch_size
474
+
475
+ self.in_chans = in_chans
476
+ self.embed_dim = embed_dim
477
+
478
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
479
+ if norm_layer is not None:
480
+ self.norm = norm_layer(embed_dim)
481
+ else:
482
+ self.norm = None
483
+
484
+ def forward(self, x):
485
+ """Forward function."""
486
+ # padding
487
+ _, _, H, W = x.size()
488
+ if W % self.patch_size[1] != 0:
489
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
490
+ if H % self.patch_size[0] != 0:
491
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
492
+
493
+ x = self.proj(x) # B C Wh Ww
494
+ if self.norm is not None:
495
+ Wh, Ww = x.size(2), x.size(3)
496
+ x = x.flatten(2).transpose(1, 2)
497
+ x = self.norm(x)
498
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
499
+
500
+ return x
501
+
502
+
503
+ class SwinTransformer(nn.Module):
504
+ """Swin Transformer backbone.
505
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
506
+ https://arxiv.org/pdf/2103.14030
507
+ Args:
508
+ pretrain_img_size (int): Input image size for training the pretrained model,
509
+ used in absolute postion embedding. Default 224.
510
+ patch_size (int | tuple(int)): Patch size. Default: 4.
511
+ in_chans (int): Number of input image channels. Default: 3.
512
+ embed_dim (int): Number of linear projection output channels. Default: 96.
513
+ depths (tuple[int]): Depths of each Swin Transformer stage.
514
+ num_heads (tuple[int]): Number of attention head of each stage.
515
+ window_size (int): Window size. Default: 7.
516
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
517
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
518
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
519
+ drop_rate (float): Dropout rate.
520
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
521
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
522
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
523
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
524
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
525
+ out_indices (Sequence[int]): Output from which stages.
526
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
527
+ -1 means not freezing any parameters.
528
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
529
+ dilation (bool): if True, the output size if 16x downsample, ow 32x downsample.
530
+ """
531
+
532
+ def __init__(
533
+ self,
534
+ pretrain_img_size=224,
535
+ patch_size=4,
536
+ in_chans=3,
537
+ embed_dim=96,
538
+ depths=[2, 2, 6, 2],
539
+ num_heads=[3, 6, 12, 24],
540
+ window_size=7,
541
+ mlp_ratio=4.0,
542
+ qkv_bias=True,
543
+ qk_scale=None,
544
+ drop_rate=0.0,
545
+ attn_drop_rate=0.0,
546
+ drop_path_rate=0.2,
547
+ norm_layer=nn.LayerNorm,
548
+ ape=False,
549
+ patch_norm=True,
550
+ out_indices=(0, 1, 2, 3),
551
+ frozen_stages=-1,
552
+ dilation=False,
553
+ use_checkpoint=False,
554
+ ):
555
+ super().__init__()
556
+
557
+ self.pretrain_img_size = pretrain_img_size
558
+ self.num_layers = len(depths)
559
+ self.embed_dim = embed_dim
560
+ self.ape = ape
561
+ self.patch_norm = patch_norm
562
+ self.out_indices = out_indices
563
+ self.frozen_stages = frozen_stages
564
+ self.dilation = dilation
565
+
566
+ # if use_checkpoint:
567
+ # print("use_checkpoint!!!!!!!!!!!!!!!!!!!!!!!!")
568
+
569
+ # split image into non-overlapping patches
570
+ self.patch_embed = PatchEmbed(
571
+ patch_size=patch_size,
572
+ in_chans=in_chans,
573
+ embed_dim=embed_dim,
574
+ norm_layer=norm_layer if self.patch_norm else None,
575
+ )
576
+
577
+ # absolute position embedding
578
+ if self.ape:
579
+ pretrain_img_size = to_2tuple(pretrain_img_size)
580
+ patch_size = to_2tuple(patch_size)
581
+ patches_resolution = [
582
+ pretrain_img_size[0] // patch_size[0],
583
+ pretrain_img_size[1] // patch_size[1],
584
+ ]
585
+
586
+ self.absolute_pos_embed = nn.Parameter(
587
+ torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
588
+ )
589
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
590
+
591
+ self.pos_drop = nn.Dropout(p=drop_rate)
592
+
593
+ # stochastic depth
594
+ dpr = [
595
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
596
+ ] # stochastic depth decay rule
597
+
598
+ # build layers
599
+ self.layers = nn.ModuleList()
600
+ # prepare downsample list
601
+ downsamplelist = [PatchMerging for i in range(self.num_layers)]
602
+ downsamplelist[-1] = None
603
+ num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
604
+ if self.dilation:
605
+ downsamplelist[-2] = None
606
+ num_features[-1] = int(embed_dim * 2 ** (self.num_layers - 1)) // 2
607
+ for i_layer in range(self.num_layers):
608
+ layer = BasicLayer(
609
+ # dim=int(embed_dim * 2 ** i_layer),
610
+ dim=num_features[i_layer],
611
+ depth=depths[i_layer],
612
+ num_heads=num_heads[i_layer],
613
+ window_size=window_size,
614
+ mlp_ratio=mlp_ratio,
615
+ qkv_bias=qkv_bias,
616
+ qk_scale=qk_scale,
617
+ drop=drop_rate,
618
+ attn_drop=attn_drop_rate,
619
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
620
+ norm_layer=norm_layer,
621
+ # downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
622
+ downsample=downsamplelist[i_layer],
623
+ use_checkpoint=use_checkpoint,
624
+ )
625
+ self.layers.append(layer)
626
+
627
+ # num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
628
+ self.num_features = num_features
629
+
630
+ # add a norm layer for each output
631
+ for i_layer in out_indices:
632
+ layer = norm_layer(num_features[i_layer])
633
+ layer_name = f"norm{i_layer}"
634
+ self.add_module(layer_name, layer)
635
+
636
+ self._freeze_stages()
637
+
638
+ def _freeze_stages(self):
639
+ if self.frozen_stages >= 0:
640
+ self.patch_embed.eval()
641
+ for param in self.patch_embed.parameters():
642
+ param.requires_grad = False
643
+
644
+ if self.frozen_stages >= 1 and self.ape:
645
+ self.absolute_pos_embed.requires_grad = False
646
+
647
+ if self.frozen_stages >= 2:
648
+ self.pos_drop.eval()
649
+ for i in range(0, self.frozen_stages - 1):
650
+ m = self.layers[i]
651
+ m.eval()
652
+ for param in m.parameters():
653
+ param.requires_grad = False
654
+
655
+ # def init_weights(self, pretrained=None):
656
+ # """Initialize the weights in backbone.
657
+ # Args:
658
+ # pretrained (str, optional): Path to pre-trained weights.
659
+ # Defaults to None.
660
+ # """
661
+
662
+ # def _init_weights(m):
663
+ # if isinstance(m, nn.Linear):
664
+ # trunc_normal_(m.weight, std=.02)
665
+ # if isinstance(m, nn.Linear) and m.bias is not None:
666
+ # nn.init.constant_(m.bias, 0)
667
+ # elif isinstance(m, nn.LayerNorm):
668
+ # nn.init.constant_(m.bias, 0)
669
+ # nn.init.constant_(m.weight, 1.0)
670
+
671
+ # if isinstance(pretrained, str):
672
+ # self.apply(_init_weights)
673
+ # logger = get_root_logger()
674
+ # load_checkpoint(self, pretrained, strict=False, logger=logger)
675
+ # elif pretrained is None:
676
+ # self.apply(_init_weights)
677
+ # else:
678
+ # raise TypeError('pretrained must be a str or None')
679
+
680
+ def forward_raw(self, x):
681
+ """Forward function."""
682
+ x = self.patch_embed(x)
683
+
684
+ Wh, Ww = x.size(2), x.size(3)
685
+ if self.ape:
686
+ # interpolate the position embedding to the corresponding size
687
+ absolute_pos_embed = F.interpolate(
688
+ self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
689
+ )
690
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
691
+ else:
692
+ x = x.flatten(2).transpose(1, 2)
693
+ x = self.pos_drop(x)
694
+
695
+ outs = []
696
+ for i in range(self.num_layers):
697
+ layer = self.layers[i]
698
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
699
+ # import ipdb; ipdb.set_trace()
700
+
701
+ if i in self.out_indices:
702
+ norm_layer = getattr(self, f"norm{i}")
703
+ x_out = norm_layer(x_out)
704
+
705
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
706
+ outs.append(out)
707
+ # in:
708
+ # torch.Size([2, 3, 1024, 1024])
709
+ # outs:
710
+ # [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
711
+ # torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
712
+ return tuple(outs)
713
+
714
+ def forward(self, tensor_list: NestedTensor):
715
+ x = tensor_list.tensors
716
+
717
+ """Forward function."""
718
+ x = self.patch_embed(x)
719
+
720
+ Wh, Ww = x.size(2), x.size(3)
721
+ if self.ape:
722
+ # interpolate the position embedding to the corresponding size
723
+ absolute_pos_embed = F.interpolate(
724
+ self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
725
+ )
726
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
727
+ else:
728
+ x = x.flatten(2).transpose(1, 2)
729
+ x = self.pos_drop(x)
730
+
731
+ outs = []
732
+ for i in range(self.num_layers):
733
+ layer = self.layers[i]
734
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
735
+
736
+ if i in self.out_indices:
737
+ norm_layer = getattr(self, f"norm{i}")
738
+ x_out = norm_layer(x_out)
739
+
740
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
741
+ outs.append(out)
742
+ # in:
743
+ # torch.Size([2, 3, 1024, 1024])
744
+ # out:
745
+ # [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
746
+ # torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
747
+
748
+ # collect for nesttensors
749
+ outs_dict = {}
750
+ for idx, out_i in enumerate(outs):
751
+ m = tensor_list.mask
752
+ assert m is not None
753
+ mask = F.interpolate(m[None].float(), size=out_i.shape[-2:]).to(torch.bool)[0]
754
+ outs_dict[idx] = NestedTensor(out_i, mask)
755
+
756
+ return outs_dict
757
+
758
+ def train(self, mode=True):
759
+ """Convert the model into training mode while keep layers freezed."""
760
+ super(SwinTransformer, self).train(mode)
761
+ self._freeze_stages()
762
+
763
+
764
+ def build_swin_transformer(modelname, pretrain_img_size, **kw):
765
+ assert modelname in [
766
+ "swin_T_224_1k",
767
+ "swin_B_224_22k",
768
+ "swin_B_384_22k",
769
+ "swin_L_224_22k",
770
+ "swin_L_384_22k",
771
+ ]
772
+
773
+ model_para_dict = {
774
+ "swin_T_224_1k": dict(
775
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7
776
+ ),
777
+ "swin_B_224_22k": dict(
778
+ embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=7
779
+ ),
780
+ "swin_B_384_22k": dict(
781
+ embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12
782
+ ),
783
+ "swin_L_224_22k": dict(
784
+ embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=7
785
+ ),
786
+ "swin_L_384_22k": dict(
787
+ embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12
788
+ ),
789
+ }
790
+ kw_cgf = model_para_dict[modelname]
791
+ kw_cgf.update(kw)
792
+ model = SwinTransformer(pretrain_img_size=pretrain_img_size, **kw_cgf)
793
+ return model
794
+
795
+
796
+ if __name__ == "__main__":
797
+ model = build_swin_transformer("swin_L_384_22k", 384, dilation=True)
798
+ x = torch.rand(2, 3, 1024, 1024)
799
+ y = model.forward_raw(x)
800
+ import ipdb
801
+
802
+ ipdb.set_trace()
803
+ x = torch.rand(2, 3, 384, 384)
804
+ y = model.forward_raw(x)
groundingdino/models/GroundingDINO/fuse_modules.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from timm.models.layers import DropPath
12
+ import loralib as lora
13
+
14
+ class FeatureResizer(nn.Module):
15
+ """
16
+ This class takes as input a set of embeddings of dimension C1 and outputs a set of
17
+ embedding of dimension C2, after a linear transformation, dropout and normalization (LN).
18
+ """
19
+
20
+ def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True):
21
+ super().__init__()
22
+ self.do_ln = do_ln
23
+ # Object feature encoding
24
+ r = 16
25
+ self.fc = lora.Linear(input_feat_size, output_feat_size,r=r , bias=True)
26
+ self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12)
27
+ self.dropout = nn.Dropout(dropout)
28
+
29
+ def forward(self, encoder_features):
30
+ x = self.fc(encoder_features)
31
+ if self.do_ln:
32
+ x = self.layer_norm(x)
33
+ output = self.dropout(x)
34
+ return output
35
+
36
+
37
+ def l1norm(X, dim, eps=1e-8):
38
+ """L1-normalize columns of X"""
39
+ norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps
40
+ X = torch.div(X, norm)
41
+ return X
42
+
43
+
44
+ def l2norm(X, dim, eps=1e-8):
45
+ """L2-normalize columns of X"""
46
+ norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
47
+ X = torch.div(X, norm)
48
+ return X
49
+
50
+
51
+ def func_attention(query, context, smooth=1, raw_feature_norm="softmax", eps=1e-8):
52
+ """
53
+ query: (n_context, queryL, d)
54
+ context: (n_context, sourceL, d)
55
+ """
56
+ batch_size_q, queryL = query.size(0), query.size(1)
57
+ batch_size, sourceL = context.size(0), context.size(1)
58
+
59
+ # Get attention
60
+ # --> (batch, d, queryL)
61
+ queryT = torch.transpose(query, 1, 2)
62
+
63
+ # (batch, sourceL, d)(batch, d, queryL)
64
+ # --> (batch, sourceL, queryL)
65
+ attn = torch.bmm(context, queryT)
66
+ if raw_feature_norm == "softmax":
67
+ # --> (batch*sourceL, queryL)
68
+ attn = attn.view(batch_size * sourceL, queryL)
69
+ attn = nn.Softmax()(attn)
70
+ # --> (batch, sourceL, queryL)
71
+ attn = attn.view(batch_size, sourceL, queryL)
72
+ elif raw_feature_norm == "l2norm":
73
+ attn = l2norm(attn, 2)
74
+ elif raw_feature_norm == "clipped_l2norm":
75
+ attn = nn.LeakyReLU(0.1)(attn)
76
+ attn = l2norm(attn, 2)
77
+ else:
78
+ raise ValueError("unknown first norm type:", raw_feature_norm)
79
+ # --> (batch, queryL, sourceL)
80
+ attn = torch.transpose(attn, 1, 2).contiguous()
81
+ # --> (batch*queryL, sourceL)
82
+ attn = attn.view(batch_size * queryL, sourceL)
83
+ attn = nn.Softmax()(attn * smooth)
84
+ # --> (batch, queryL, sourceL)
85
+ attn = attn.view(batch_size, queryL, sourceL)
86
+ # --> (batch, sourceL, queryL)
87
+ attnT = torch.transpose(attn, 1, 2).contiguous()
88
+
89
+ # --> (batch, d, sourceL)
90
+ contextT = torch.transpose(context, 1, 2)
91
+ # (batch x d x sourceL)(batch x sourceL x queryL)
92
+ # --> (batch, d, queryL)
93
+ weightedContext = torch.bmm(contextT, attnT)
94
+ # --> (batch, queryL, d)
95
+ weightedContext = torch.transpose(weightedContext, 1, 2)
96
+
97
+ return weightedContext, attnT
98
+
99
+
100
+ class BiMultiHeadAttention(nn.Module):
101
+ def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1, cfg=None):
102
+ super(BiMultiHeadAttention, self).__init__()
103
+
104
+ self.embed_dim = embed_dim
105
+ self.num_heads = num_heads
106
+ self.head_dim = embed_dim // num_heads
107
+ self.v_dim = v_dim
108
+ self.l_dim = l_dim
109
+
110
+ assert (
111
+ self.head_dim * self.num_heads == self.embed_dim
112
+ ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
113
+ self.scale = self.head_dim ** (-0.5)
114
+ self.dropout = dropout
115
+ r = 16
116
+ self.v_proj = lora.Linear(self.v_dim, self.embed_dim , r=r)
117
+ self.l_proj = lora.Linear(self.l_dim, self.embed_dim , r=r)
118
+ self.values_v_proj = lora.Linear(self.v_dim, self.embed_dim , r=r)
119
+ self.values_l_proj = lora.Linear(self.l_dim, self.embed_dim , r=r)
120
+
121
+ self.out_v_proj = lora.Linear(self.embed_dim, self.v_dim , r=r)
122
+ self.out_l_proj = lora.Linear(self.embed_dim, self.l_dim , r=r)
123
+
124
+ self.stable_softmax_2d = True
125
+ self.clamp_min_for_underflow = True
126
+ self.clamp_max_for_overflow = True
127
+
128
+ self._reset_parameters()
129
+
130
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
131
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
132
+
133
+ def _reset_parameters(self):
134
+ nn.init.xavier_uniform_(self.v_proj.weight)
135
+ self.v_proj.bias.data.fill_(0)
136
+ nn.init.xavier_uniform_(self.l_proj.weight)
137
+ self.l_proj.bias.data.fill_(0)
138
+ nn.init.xavier_uniform_(self.values_v_proj.weight)
139
+ self.values_v_proj.bias.data.fill_(0)
140
+ nn.init.xavier_uniform_(self.values_l_proj.weight)
141
+ self.values_l_proj.bias.data.fill_(0)
142
+ nn.init.xavier_uniform_(self.out_v_proj.weight)
143
+ self.out_v_proj.bias.data.fill_(0)
144
+ nn.init.xavier_uniform_(self.out_l_proj.weight)
145
+ self.out_l_proj.bias.data.fill_(0)
146
+
147
+ def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
148
+ """_summary_
149
+
150
+ Args:
151
+ v (_type_): bs, n_img, dim
152
+ l (_type_): bs, n_text, dim
153
+ attention_mask_v (_type_, optional): _description_. bs, n_img
154
+ attention_mask_l (_type_, optional): _description_. bs, n_text
155
+
156
+ Returns:
157
+ _type_: _description_
158
+ """
159
+ # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
160
+ # import ipdb; ipdb.set_trace()
161
+ bsz, tgt_len, _ = v.size()
162
+
163
+ query_states = self.v_proj(v) * self.scale
164
+ key_states = self._shape(self.l_proj(l), -1, bsz)
165
+ value_v_states = self._shape(self.values_v_proj(v), -1, bsz)
166
+ value_l_states = self._shape(self.values_l_proj(l), -1, bsz)
167
+
168
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
169
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
170
+ key_states = key_states.view(*proj_shape)
171
+ value_v_states = value_v_states.view(*proj_shape)
172
+ value_l_states = value_l_states.view(*proj_shape)
173
+
174
+ src_len = key_states.size(1)
175
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) # bs*nhead, nimg, ntxt
176
+
177
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
178
+ raise ValueError(
179
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
180
+ )
181
+
182
+ if self.stable_softmax_2d:
183
+ attn_weights = attn_weights - attn_weights.max()
184
+
185
+ if self.clamp_min_for_underflow:
186
+ attn_weights = torch.clamp(
187
+ attn_weights, min=-50000
188
+ ) # Do not increase -50000, data type half has quite limited range
189
+ if self.clamp_max_for_overflow:
190
+ attn_weights = torch.clamp(
191
+ attn_weights, max=50000
192
+ ) # Do not increase 50000, data type half has quite limited range
193
+
194
+ attn_weights_T = attn_weights.transpose(1, 2)
195
+ attn_weights_l = attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[0]
196
+ if self.clamp_min_for_underflow:
197
+ attn_weights_l = torch.clamp(
198
+ attn_weights_l, min=-50000
199
+ ) # Do not increase -50000, data type half has quite limited range
200
+ if self.clamp_max_for_overflow:
201
+ attn_weights_l = torch.clamp(
202
+ attn_weights_l, max=50000
203
+ ) # Do not increase 50000, data type half has quite limited range
204
+
205
+ # mask vison for language
206
+ if attention_mask_v is not None:
207
+ attention_mask_v = (
208
+ attention_mask_v[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
209
+ )
210
+ attn_weights_l.masked_fill_(attention_mask_v, float("-inf"))
211
+
212
+ attn_weights_l = attn_weights_l.softmax(dim=-1)
213
+
214
+ # mask language for vision
215
+ if attention_mask_l is not None:
216
+ attention_mask_l = (
217
+ attention_mask_l[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
218
+ )
219
+ attn_weights.masked_fill_(attention_mask_l, float("-inf"))
220
+ attn_weights_v = attn_weights.softmax(dim=-1)
221
+
222
+ attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training)
223
+ attn_probs_l = F.dropout(attn_weights_l, p=self.dropout, training=self.training)
224
+
225
+ attn_output_v = torch.bmm(attn_probs_v, value_l_states)
226
+ attn_output_l = torch.bmm(attn_probs_l, value_v_states)
227
+
228
+ if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
229
+ raise ValueError(
230
+ f"`attn_output_v` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output_v.size()}"
231
+ )
232
+
233
+ if attn_output_l.size() != (bsz * self.num_heads, src_len, self.head_dim):
234
+ raise ValueError(
235
+ f"`attn_output_l` should be of size {(bsz, self.num_heads, src_len, self.head_dim)}, but is {attn_output_l.size()}"
236
+ )
237
+
238
+ attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim)
239
+ attn_output_v = attn_output_v.transpose(1, 2)
240
+ attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim)
241
+
242
+ attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, self.head_dim)
243
+ attn_output_l = attn_output_l.transpose(1, 2)
244
+ attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim)
245
+
246
+ attn_output_v = self.out_v_proj(attn_output_v)
247
+ attn_output_l = self.out_l_proj(attn_output_l)
248
+
249
+ return attn_output_v, attn_output_l
250
+
251
+
252
+ # Bi-Direction MHA (text->image, image->text)
253
+ class BiAttentionBlock(nn.Module):
254
+ def __init__(
255
+ self,
256
+ v_dim,
257
+ l_dim,
258
+ embed_dim,
259
+ num_heads,
260
+ dropout=0.1,
261
+ drop_path=0.0,
262
+ init_values=1e-4,
263
+ cfg=None,
264
+ ):
265
+ """
266
+ Inputs:
267
+ embed_dim - Dimensionality of input and attention feature vectors
268
+ hidden_dim - Dimensionality of hidden layer in feed-forward network
269
+ (usually 2-4x larger than embed_dim)
270
+ num_heads - Number of heads to use in the Multi-Head Attention block
271
+ dropout - Amount of dropout to apply in the feed-forward network
272
+ """
273
+ super(BiAttentionBlock, self).__init__()
274
+
275
+ # pre layer norm
276
+ self.layer_norm_v = nn.LayerNorm(v_dim)
277
+ self.layer_norm_l = nn.LayerNorm(l_dim)
278
+ self.attn = BiMultiHeadAttention(
279
+ v_dim=v_dim, l_dim=l_dim, embed_dim=embed_dim, num_heads=num_heads, dropout=dropout
280
+ )
281
+
282
+ # add layer scale for training stability
283
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
284
+ self.gamma_v = nn.Parameter(init_values * torch.ones((v_dim)), requires_grad=True)
285
+ self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=True)
286
+
287
+ def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
288
+ v = self.layer_norm_v(v)
289
+ l = self.layer_norm_l(l)
290
+ delta_v, delta_l = self.attn(
291
+ v, l, attention_mask_v=attention_mask_v, attention_mask_l=attention_mask_l
292
+ )
293
+ # v, l = v + delta_v, l + delta_l
294
+ v = v + self.drop_path(self.gamma_v * delta_v)
295
+ l = l + self.drop_path(self.gamma_l * delta_l)
296
+ return v, l
297
+
298
+ # def forward(self, v:List[torch.Tensor], l, attention_mask_v=None, attention_mask_l=None)
groundingdino/models/GroundingDINO/groundingdino.py ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # Conditional DETR model and criterion classes.
8
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # ------------------------------------------------------------------------
11
+ # Modified from DETR (https://github.com/facebookresearch/detr)
12
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
13
+ # ------------------------------------------------------------------------
14
+ # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
15
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
16
+ # ------------------------------------------------------------------------
17
+ import copy
18
+ from typing import List
19
+ import loralib as lora
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from torch import nn
23
+ from torchvision.ops.boxes import nms
24
+ from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast
25
+
26
+ from groundingdino.util import box_ops, get_tokenlizer
27
+ from groundingdino.util.misc import (
28
+ NestedTensor,
29
+ accuracy,
30
+ get_world_size,
31
+ interpolate,
32
+ inverse_sigmoid,
33
+ is_dist_avail_and_initialized,
34
+ nested_tensor_from_tensor_list,
35
+ )
36
+ from groundingdino.util.utils import get_phrases_from_posmap
37
+ from groundingdino.util.visualizer import COCOVisualizer
38
+ from groundingdino.util.vl_utils import create_positive_map_from_span
39
+
40
+ from ..registry import MODULE_BUILD_FUNCS
41
+ from .backbone import build_backbone
42
+ from .bertwarper import (
43
+ BertModelWarper,
44
+ generate_masks_with_special_tokens,
45
+ generate_masks_with_special_tokens_and_transfer_map,
46
+ )
47
+ from .transformer import build_transformer
48
+ from .utils import MLP, ContrastiveEmbed, sigmoid_focal_loss
49
+
50
+
51
+ class GroundingDINO(nn.Module):
52
+ """This is the Cross-Attention Detector module that performs object detection"""
53
+
54
+ def __init__(
55
+ self,
56
+ backbone,
57
+ transformer,
58
+ num_queries,
59
+ aux_loss=False,
60
+ iter_update=False,
61
+ query_dim=2,
62
+ num_feature_levels=1,
63
+ nheads=8,
64
+ # two stage
65
+ two_stage_type="no", # ['no', 'standard']
66
+ dec_pred_bbox_embed_share=True,
67
+ two_stage_class_embed_share=True,
68
+ two_stage_bbox_embed_share=True,
69
+ num_patterns=0,
70
+ dn_number=100,
71
+ dn_box_noise_scale=0.4,
72
+ dn_label_noise_ratio=0.5,
73
+ dn_labelbook_size=100,
74
+ text_encoder_type="bert-base-uncased",
75
+ sub_sentence_present=True,
76
+ max_text_len=256,
77
+ ):
78
+ """Initializes the model.
79
+ Parameters:
80
+ backbone: torch module of the backbone to be used. See backbone.py
81
+ transformer: torch module of the transformer architecture. See transformer.py
82
+ num_queries: number of object queries, ie detection slot. This is the maximal number of objects
83
+ Conditional DETR can detect in a single image. For COCO, we recommend 100 queries.
84
+ aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
85
+ """
86
+ super().__init__()
87
+ self.num_queries = num_queries
88
+ self.transformer = transformer
89
+ self.hidden_dim = hidden_dim = transformer.d_model
90
+ self.num_feature_levels = num_feature_levels
91
+ self.nheads = nheads
92
+ self.max_text_len = 256
93
+ self.sub_sentence_present = sub_sentence_present
94
+
95
+ # setting query dim
96
+ self.query_dim = query_dim
97
+ assert query_dim == 4
98
+
99
+ # for dn training
100
+ self.num_patterns = num_patterns
101
+ self.dn_number = dn_number
102
+ self.dn_box_noise_scale = dn_box_noise_scale
103
+ self.dn_label_noise_ratio = dn_label_noise_ratio
104
+ self.dn_labelbook_size = dn_labelbook_size
105
+
106
+ # bert
107
+ self.tokenizer = get_tokenlizer.get_tokenlizer(text_encoder_type)
108
+ self.bert = get_tokenlizer.get_pretrained_language_model(text_encoder_type)
109
+ self.bert.pooler.dense.weight.requires_grad_(False)
110
+ self.bert.pooler.dense.bias.requires_grad_(False)
111
+ self.bert = BertModelWarper(bert_model=self.bert)
112
+
113
+ self.feat_map = nn.Linear(self.bert.config.hidden_size, self.hidden_dim, bias=True)
114
+ nn.init.constant_(self.feat_map.bias.data, 0)
115
+ nn.init.xavier_uniform_(self.feat_map.weight.data)
116
+ # freeze
117
+
118
+ # special tokens
119
+ self.specical_tokens = self.tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]", ".", "?"])
120
+
121
+ # prepare input projection layers
122
+ if num_feature_levels > 1:
123
+ num_backbone_outs = len(backbone.num_channels)
124
+ input_proj_list = []
125
+ for _ in range(num_backbone_outs):
126
+ in_channels = backbone.num_channels[_]
127
+ input_proj_list.append(
128
+ nn.Sequential(
129
+ nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
130
+ nn.GroupNorm(32, hidden_dim),
131
+ )
132
+ )
133
+ for _ in range(num_feature_levels - num_backbone_outs):
134
+ input_proj_list.append(
135
+ nn.Sequential(
136
+ nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
137
+ nn.GroupNorm(32, hidden_dim),
138
+ )
139
+ )
140
+ in_channels = hidden_dim
141
+ self.input_proj = nn.ModuleList(input_proj_list)
142
+ else:
143
+ assert two_stage_type == "no", "two_stage_type should be no if num_feature_levels=1 !!!"
144
+ self.input_proj = nn.ModuleList(
145
+ [
146
+ nn.Sequential(
147
+ nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1),
148
+ nn.GroupNorm(32, hidden_dim),
149
+ )
150
+ ]
151
+ )
152
+
153
+ self.backbone = backbone
154
+ self.aux_loss = aux_loss
155
+ self.box_pred_damping = box_pred_damping = None
156
+
157
+ self.iter_update = iter_update
158
+ assert iter_update, "Why not iter_update?"
159
+
160
+ # prepare pred layers
161
+ self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share
162
+ # prepare class & box embed
163
+ _class_embed = ContrastiveEmbed()
164
+
165
+ _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
166
+ nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
167
+ nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)
168
+
169
+ if dec_pred_bbox_embed_share:
170
+ box_embed_layerlist = [_bbox_embed for i in range(transformer.num_decoder_layers)]
171
+ else:
172
+ box_embed_layerlist = [
173
+ copy.deepcopy(_bbox_embed) for i in range(transformer.num_decoder_layers)
174
+ ]
175
+ class_embed_layerlist = [_class_embed for i in range(transformer.num_decoder_layers)]
176
+ self.bbox_embed = nn.ModuleList(box_embed_layerlist)
177
+ self.class_embed = nn.ModuleList(class_embed_layerlist)
178
+ self.transformer.decoder.bbox_embed = self.bbox_embed
179
+ self.transformer.decoder.class_embed = self.class_embed
180
+
181
+ # two stage
182
+ self.two_stage_type = two_stage_type
183
+ assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format(
184
+ two_stage_type
185
+ )
186
+ if two_stage_type != "no":
187
+ if two_stage_bbox_embed_share:
188
+ assert dec_pred_bbox_embed_share
189
+ self.transformer.enc_out_bbox_embed = _bbox_embed
190
+ else:
191
+ self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed)
192
+
193
+ if two_stage_class_embed_share:
194
+ assert dec_pred_bbox_embed_share
195
+ self.transformer.enc_out_class_embed = _class_embed
196
+ else:
197
+ self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed)
198
+
199
+ self.refpoint_embed = None
200
+
201
+ self._reset_parameters()
202
+
203
+ def _reset_parameters(self):
204
+ # init input_proj
205
+ for proj in self.input_proj:
206
+ nn.init.xavier_uniform_(proj[0].weight, gain=1)
207
+ nn.init.constant_(proj[0].bias, 0)
208
+
209
+ def set_image_tensor(self, samples: NestedTensor):
210
+ if isinstance(samples, (list, torch.Tensor)):
211
+ samples = nested_tensor_from_tensor_list(samples)
212
+ self.features, self.poss = self.backbone(samples)
213
+
214
+ def unset_image_tensor(self):
215
+ if hasattr(self, 'features'):
216
+ del self.features
217
+ if hasattr(self,'poss'):
218
+ del self.poss
219
+
220
+ def set_image_features(self, features , poss):
221
+ self.features = features
222
+ self.poss = poss
223
+
224
+ def init_ref_points(self, use_num_queries):
225
+ self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim)
226
+
227
+ def forward(self, samples: NestedTensor, targets: List = None, **kw):
228
+ """The forward expects a NestedTensor, which consists of:
229
+ - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
230
+ - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
231
+
232
+ It returns a dict with the following elements:
233
+ - "pred_logits": the classification logits (including no-object) for all queries.
234
+ Shape= [batch_size x num_queries x num_classes]
235
+ - "pred_boxes": The normalized boxes coordinates for all queries, represented as
236
+ (center_x, center_y, width, height). These values are normalized in [0, 1],
237
+ relative to the size of each individual image (disregarding possible padding).
238
+ See PostProcess for information on how to retrieve the unnormalized bounding box.
239
+ - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
240
+ dictionnaries containing the two above keys for each decoder layer.
241
+ """
242
+ if targets is None:
243
+ captions = kw["captions"]
244
+ else:
245
+ captions = [t["caption"] for t in targets]
246
+
247
+ # encoder texts
248
+ tokenized = self.tokenizer(captions, padding="longest", return_tensors="pt").to(
249
+ samples.device
250
+ )
251
+ (
252
+ text_self_attention_masks,
253
+ position_ids,
254
+ cate_to_token_mask_list,
255
+ ) = generate_masks_with_special_tokens_and_transfer_map(
256
+ tokenized, self.specical_tokens, self.tokenizer
257
+ )
258
+
259
+ if text_self_attention_masks.shape[1] > self.max_text_len:
260
+ text_self_attention_masks = text_self_attention_masks[
261
+ :, : self.max_text_len, : self.max_text_len
262
+ ]
263
+ position_ids = position_ids[:, : self.max_text_len]
264
+ tokenized["input_ids"] = tokenized["input_ids"][:, : self.max_text_len]
265
+ tokenized["attention_mask"] = tokenized["attention_mask"][:, : self.max_text_len]
266
+ tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : self.max_text_len]
267
+
268
+ # extract text embeddings
269
+ if self.sub_sentence_present:
270
+ tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
271
+ tokenized_for_encoder["attention_mask"] = text_self_attention_masks
272
+ tokenized_for_encoder["position_ids"] = position_ids
273
+ else:
274
+ # import ipdb; ipdb.set_trace()
275
+ tokenized_for_encoder = tokenized
276
+
277
+ bert_output = self.bert(**tokenized_for_encoder) # bs, 195, 768
278
+
279
+ encoded_text = self.feat_map(bert_output["last_hidden_state"]) # bs, 195, d_model
280
+ text_token_mask = tokenized.attention_mask.bool() # bs, 195
281
+ # text_token_mask: True for nomask, False for mask
282
+ # text_self_attention_masks: True for nomask, False for mask
283
+
284
+ if encoded_text.shape[1] > self.max_text_len:
285
+ encoded_text = encoded_text[:, : self.max_text_len, :]
286
+ text_token_mask = text_token_mask[:, : self.max_text_len]
287
+ position_ids = position_ids[:, : self.max_text_len]
288
+ text_self_attention_masks = text_self_attention_masks[
289
+ :, : self.max_text_len, : self.max_text_len
290
+ ]
291
+
292
+ text_dict = {
293
+ "encoded_text": encoded_text, # bs, 195, d_model
294
+ "text_token_mask": text_token_mask, # bs, 195
295
+ "position_ids": position_ids, # bs, 195
296
+ "text_self_attention_masks": text_self_attention_masks, # bs, 195,195
297
+ }
298
+
299
+ # import ipdb; ipdb.set_trace()
300
+ if isinstance(samples, (list, torch.Tensor)):
301
+ samples = nested_tensor_from_tensor_list(samples)
302
+ if not hasattr(self, 'features') or not hasattr(self, 'poss'):
303
+ self.set_image_tensor(samples)
304
+
305
+ srcs = []
306
+ masks = []
307
+ for l, feat in enumerate(self.features):
308
+ src, mask = feat.decompose()
309
+ srcs.append(self.input_proj[l](src))
310
+ masks.append(mask)
311
+ assert mask is not None
312
+ if self.num_feature_levels > len(srcs):
313
+ _len_srcs = len(srcs)
314
+ for l in range(_len_srcs, self.num_feature_levels):
315
+ if l == _len_srcs:
316
+ src = self.input_proj[l](self.features[-1].tensors)
317
+ else:
318
+ src = self.input_proj[l](srcs[-1])
319
+ m = samples.mask
320
+ mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
321
+ pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
322
+ srcs.append(src)
323
+ masks.append(mask)
324
+ self.poss.append(pos_l)
325
+
326
+ input_query_bbox = input_query_label = attn_mask = dn_meta = None
327
+ hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(
328
+ srcs, masks, input_query_bbox, self.poss, input_query_label, attn_mask, text_dict
329
+ )
330
+
331
+ # deformable-detr-like anchor update
332
+ outputs_coord_list = []
333
+ for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(
334
+ zip(reference[:-1], self.bbox_embed, hs)
335
+ ):
336
+ layer_delta_unsig = layer_bbox_embed(layer_hs)
337
+ layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
338
+ layer_outputs_unsig = layer_outputs_unsig.sigmoid()
339
+ outputs_coord_list.append(layer_outputs_unsig)
340
+ outputs_coord_list = torch.stack(outputs_coord_list)
341
+
342
+ # output
343
+ outputs_class = torch.stack(
344
+ [
345
+ layer_cls_embed(layer_hs, text_dict)
346
+ for layer_cls_embed, layer_hs in zip(self.class_embed, hs)
347
+ ]
348
+ )
349
+ out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord_list[-1]}
350
+
351
+ # # for intermediate outputs
352
+ # if self.aux_loss:
353
+ # out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord_list)
354
+
355
+ # # for encoder output
356
+ # if hs_enc is not None:
357
+ # # prepare intermediate outputs
358
+ # interm_coord = ref_enc[-1]
359
+ # interm_class = self.transformer.enc_out_class_embed(hs_enc[-1], text_dict)
360
+ # out['interm_outputs'] = {'pred_logits': interm_class, 'pred_boxes': interm_coord}
361
+ # out['interm_outputs_for_matching_pre'] = {'pred_logits': interm_class, 'pred_boxes': init_box_proposal}
362
+ unset_image_tensor = kw.get('unset_image_tensor', True)
363
+ if unset_image_tensor:
364
+ self.unset_image_tensor() ## If necessary
365
+ return out
366
+
367
+ @torch.jit.unused
368
+ def _set_aux_loss(self, outputs_class, outputs_coord):
369
+ # this is a workaround to make torchscript happy, as torchscript
370
+ # doesn't support dictionary with non-homogeneous values, such
371
+ # as a dict having both a Tensor and a list.
372
+ return [
373
+ {"pred_logits": a, "pred_boxes": b}
374
+ for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
375
+ ]
376
+
377
+
378
+ @MODULE_BUILD_FUNCS.registe_with_name(module_name="groundingdino")
379
+ def build_groundingdino(args):
380
+
381
+ backbone = build_backbone(args)
382
+ transformer = build_transformer(args)
383
+
384
+ dn_labelbook_size = args.dn_labelbook_size
385
+ dec_pred_bbox_embed_share = args.dec_pred_bbox_embed_share
386
+ sub_sentence_present = args.sub_sentence_present
387
+
388
+ model = GroundingDINO(
389
+ backbone,
390
+ transformer,
391
+ num_queries=args.num_queries,
392
+ aux_loss=True,
393
+ iter_update=True,
394
+ query_dim=4,
395
+ num_feature_levels=args.num_feature_levels,
396
+ nheads=args.nheads,
397
+ dec_pred_bbox_embed_share=dec_pred_bbox_embed_share,
398
+ two_stage_type=args.two_stage_type,
399
+ two_stage_bbox_embed_share=args.two_stage_bbox_embed_share,
400
+ two_stage_class_embed_share=args.two_stage_class_embed_share,
401
+ num_patterns=args.num_patterns,
402
+ dn_number=0,
403
+ dn_box_noise_scale=args.dn_box_noise_scale,
404
+ dn_label_noise_ratio=args.dn_label_noise_ratio,
405
+ dn_labelbook_size=dn_labelbook_size,
406
+ text_encoder_type=args.text_encoder_type,
407
+ sub_sentence_present=sub_sentence_present,
408
+ max_text_len=args.max_text_len,
409
+ )
410
+
411
+ return model
412
+
groundingdino/models/GroundingDINO/ms_deform_attn.py ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # Deformable DETR
8
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # ------------------------------------------------------------------------------------------------
11
+ # Modified from:
12
+ # https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/functions/ms_deform_attn_func.py
13
+ # https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py
14
+ # https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/multi_scale_deform_attn.py
15
+ # ------------------------------------------------------------------------------------------------
16
+
17
+ import math
18
+ import warnings
19
+ from typing import Optional
20
+ import loralib as lora
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from torch.autograd import Function
25
+ from torch.autograd.function import once_differentiable
26
+ from torch.nn.init import constant_, xavier_uniform_
27
+
28
+ try:
29
+ from groundingdino import _C
30
+ except:
31
+ warnings.warn("Failed to load custom C++ ops. Running on CPU mode Only!")
32
+
33
+
34
+ # helpers
35
+ def _is_power_of_2(n):
36
+ if (not isinstance(n, int)) or (n < 0):
37
+ raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
38
+ return (n & (n - 1) == 0) and n != 0
39
+
40
+
41
+ class MultiScaleDeformableAttnFunction(Function):
42
+ @staticmethod
43
+ def forward(
44
+ ctx,
45
+ value,
46
+ value_spatial_shapes,
47
+ value_level_start_index,
48
+ sampling_locations,
49
+ attention_weights,
50
+ im2col_step,
51
+ ):
52
+ ctx.im2col_step = im2col_step
53
+ output = _C.ms_deform_attn_forward(
54
+ value,
55
+ value_spatial_shapes,
56
+ value_level_start_index,
57
+ sampling_locations,
58
+ attention_weights,
59
+ ctx.im2col_step,
60
+ )
61
+ ctx.save_for_backward(
62
+ value,
63
+ value_spatial_shapes,
64
+ value_level_start_index,
65
+ sampling_locations,
66
+ attention_weights,
67
+ )
68
+ return output
69
+
70
+ @staticmethod
71
+ @once_differentiable
72
+ def backward(ctx, grad_output):
73
+ (
74
+ value,
75
+ value_spatial_shapes,
76
+ value_level_start_index,
77
+ sampling_locations,
78
+ attention_weights,
79
+ ) = ctx.saved_tensors
80
+ grad_value, grad_sampling_loc, grad_attn_weight = _C.ms_deform_attn_backward(
81
+ value,
82
+ value_spatial_shapes,
83
+ value_level_start_index,
84
+ sampling_locations,
85
+ attention_weights,
86
+ grad_output,
87
+ ctx.im2col_step,
88
+ )
89
+
90
+ return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
91
+
92
+
93
+ def multi_scale_deformable_attn_pytorch(
94
+ value: torch.Tensor,
95
+ value_spatial_shapes: torch.Tensor,
96
+ sampling_locations: torch.Tensor,
97
+ attention_weights: torch.Tensor,
98
+ ) -> torch.Tensor:
99
+
100
+ bs, _, num_heads, embed_dims = value.shape
101
+ _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
102
+ value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
103
+ sampling_grids = 2 * sampling_locations - 1
104
+ sampling_value_list = []
105
+ for level, (H_, W_) in enumerate(value_spatial_shapes):
106
+ # bs, H_*W_, num_heads, embed_dims ->
107
+ # bs, H_*W_, num_heads*embed_dims ->
108
+ # bs, num_heads*embed_dims, H_*W_ ->
109
+ # bs*num_heads, embed_dims, H_, W_
110
+ value_l_ = (
111
+ value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)
112
+ )
113
+ # bs, num_queries, num_heads, num_points, 2 ->
114
+ # bs, num_heads, num_queries, num_points, 2 ->
115
+ # bs*num_heads, num_queries, num_points, 2
116
+ sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
117
+ # bs*num_heads, embed_dims, num_queries, num_points
118
+ sampling_value_l_ = F.grid_sample(
119
+ value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
120
+ )
121
+ sampling_value_list.append(sampling_value_l_)
122
+ # (bs, num_queries, num_heads, num_levels, num_points) ->
123
+ # (bs, num_heads, num_queries, num_levels, num_points) ->
124
+ # (bs, num_heads, 1, num_queries, num_levels*num_points)
125
+ attention_weights = attention_weights.transpose(1, 2).reshape(
126
+ bs * num_heads, 1, num_queries, num_levels * num_points
127
+ )
128
+ output = (
129
+ (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
130
+ .sum(-1)
131
+ .view(bs, num_heads * embed_dims, num_queries)
132
+ )
133
+ return output.transpose(1, 2).contiguous()
134
+
135
+
136
+ class MultiScaleDeformableAttention(nn.Module):
137
+ """Multi-Scale Deformable Attention Module used in Deformable-DETR
138
+
139
+ `Deformable DETR: Deformable Transformers for End-to-End Object Detection.
140
+ <https://arxiv.org/pdf/2010.04159.pdf>`_.
141
+
142
+ Args:
143
+ embed_dim (int): The embedding dimension of Attention. Default: 256.
144
+ num_heads (int): The number of attention heads. Default: 8.
145
+ num_levels (int): The number of feature map used in Attention. Default: 4.
146
+ num_points (int): The number of sampling points for each query
147
+ in each head. Default: 4.
148
+ img2col_steps (int): The step used in image_to_column. Defualt: 64.
149
+ dropout (float): Dropout layer used in output. Default: 0.1.
150
+ batch_first (bool): if ``True``, then the input and output tensor will be
151
+ provided as `(bs, n, embed_dim)`. Default: False. `(n, bs, embed_dim)`
152
+ """
153
+
154
+ def __init__(
155
+ self,
156
+ embed_dim: int = 256,
157
+ num_heads: int = 8,
158
+ num_levels: int = 4,
159
+ num_points: int = 4,
160
+ img2col_step: int = 64,
161
+ batch_first: bool = False,
162
+ ):
163
+ super().__init__()
164
+ if embed_dim % num_heads != 0:
165
+ raise ValueError(
166
+ "embed_dim must be divisible by num_heads, but got {} and {}".format(
167
+ embed_dim, num_heads
168
+ )
169
+ )
170
+ head_dim = embed_dim // num_heads
171
+
172
+ self.batch_first = batch_first
173
+
174
+ if not _is_power_of_2(head_dim):
175
+ warnings.warn(
176
+ """
177
+ You'd better set d_model in MSDeformAttn to make sure that
178
+ each dim of the attention head a power of 2, which is more efficient.
179
+ """
180
+ )
181
+
182
+ self.im2col_step = img2col_step
183
+ self.embed_dim = embed_dim
184
+ self.num_heads = num_heads
185
+ self.num_levels = num_levels
186
+ self.num_points = num_points
187
+ r = 16
188
+ self.sampling_offsets = lora.Linear(embed_dim, num_heads * num_levels * num_points * 2 , r=r)
189
+ self.attention_weights = lora.Linear(embed_dim, num_heads * num_levels * num_points , r=r)
190
+ self.value_proj = lora.Linear(embed_dim, embed_dim , r=r)
191
+ self.output_proj = lora.Linear(embed_dim, embed_dim , r=r)
192
+
193
+ self.init_weights()
194
+
195
+ def _reset_parameters(self):
196
+ return self.init_weights()
197
+
198
+ def init_weights(self):
199
+ """
200
+ Default initialization for Parameters of Module.
201
+ """
202
+ constant_(self.sampling_offsets.weight.data, 0.0)
203
+ thetas = torch.arange(self.num_heads, dtype=torch.float32) * (
204
+ 2.0 * math.pi / self.num_heads
205
+ )
206
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
207
+ grid_init = (
208
+ (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
209
+ .view(self.num_heads, 1, 1, 2)
210
+ .repeat(1, self.num_levels, self.num_points, 1)
211
+ )
212
+ for i in range(self.num_points):
213
+ grid_init[:, :, i, :] *= i + 1
214
+ with torch.no_grad():
215
+ self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
216
+ constant_(self.attention_weights.weight.data, 0.0)
217
+ constant_(self.attention_weights.bias.data, 0.0)
218
+ xavier_uniform_(self.value_proj.weight.data)
219
+ constant_(self.value_proj.bias.data, 0.0)
220
+ xavier_uniform_(self.output_proj.weight.data)
221
+ constant_(self.output_proj.bias.data, 0.0)
222
+
223
+ def freeze_sampling_offsets(self):
224
+ print("Freeze sampling offsets")
225
+ self.sampling_offsets.weight.requires_grad = False
226
+ self.sampling_offsets.bias.requires_grad = False
227
+
228
+ def freeze_attention_weights(self):
229
+ print("Freeze attention weights")
230
+ self.attention_weights.weight.requires_grad = False
231
+ self.attention_weights.bias.requires_grad = False
232
+
233
+ def forward(
234
+ self,
235
+ query: torch.Tensor,
236
+ key: Optional[torch.Tensor] = None,
237
+ value: Optional[torch.Tensor] = None,
238
+ query_pos: Optional[torch.Tensor] = None,
239
+ key_padding_mask: Optional[torch.Tensor] = None,
240
+ reference_points: Optional[torch.Tensor] = None,
241
+ spatial_shapes: Optional[torch.Tensor] = None,
242
+ level_start_index: Optional[torch.Tensor] = None,
243
+ **kwargs
244
+ ) -> torch.Tensor:
245
+
246
+ """Forward Function of MultiScaleDeformableAttention
247
+
248
+ Args:
249
+ query (torch.Tensor): Query embeddings with shape
250
+ `(num_query, bs, embed_dim)`
251
+ key (torch.Tensor): Key embeddings with shape
252
+ `(num_key, bs, embed_dim)`
253
+ value (torch.Tensor): Value embeddings with shape
254
+ `(num_key, bs, embed_dim)`
255
+ query_pos (torch.Tensor): The position embedding for `query`. Default: None.
256
+ key_padding_mask (torch.Tensor): ByteTensor for `query`, with shape `(bs, num_key)`,
257
+ indicating which elements within `key` to be ignored in attention.
258
+ reference_points (torch.Tensor): The normalized reference points
259
+ with shape `(bs, num_query, num_levels, 2)`,
260
+ all elements is range in [0, 1], top-left (0, 0),
261
+ bottom-right (1, 1), including padding are.
262
+ or `(N, Length_{query}, num_levels, 4)`, add additional
263
+ two dimensions `(h, w)` to form reference boxes.
264
+ spatial_shapes (torch.Tensor): Spatial shape of features in different levels.
265
+ With shape `(num_levels, 2)`, last dimension represents `(h, w)`.
266
+ level_start_index (torch.Tensor): The start index of each level. A tensor with
267
+ shape `(num_levels, )` which can be represented as
268
+ `[0, h_0 * w_0, h_0 * w_0 + h_1 * w_1, ...]`.
269
+
270
+ Returns:
271
+ torch.Tensor: forward results with shape `(num_query, bs, embed_dim)`
272
+ """
273
+
274
+ if value is None:
275
+ value = query
276
+
277
+ if query_pos is not None:
278
+ query = query + query_pos
279
+
280
+ if not self.batch_first:
281
+ # change to (bs, num_query ,embed_dims)
282
+ query = query.permute(1, 0, 2)
283
+ value = value.permute(1, 0, 2)
284
+
285
+ bs, num_query, _ = query.shape
286
+ bs, num_value, _ = value.shape
287
+
288
+ assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
289
+
290
+ value = self.value_proj(value)
291
+ if key_padding_mask is not None:
292
+ value = value.masked_fill(key_padding_mask[..., None], float(0))
293
+ value = value.view(bs, num_value, self.num_heads, -1)
294
+ sampling_offsets = self.sampling_offsets(query).view(
295
+ bs, num_query, self.num_heads, self.num_levels, self.num_points, 2
296
+ )
297
+ attention_weights = self.attention_weights(query).view(
298
+ bs, num_query, self.num_heads, self.num_levels * self.num_points
299
+ )
300
+ attention_weights = attention_weights.softmax(-1)
301
+ attention_weights = attention_weights.view(
302
+ bs,
303
+ num_query,
304
+ self.num_heads,
305
+ self.num_levels,
306
+ self.num_points,
307
+ )
308
+
309
+ # bs, num_query, num_heads, num_levels, num_points, 2
310
+ if reference_points.shape[-1] == 2:
311
+ offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
312
+ sampling_locations = (
313
+ reference_points[:, :, None, :, None, :]
314
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
315
+ )
316
+ elif reference_points.shape[-1] == 4:
317
+ sampling_locations = (
318
+ reference_points[:, :, None, :, None, :2]
319
+ + sampling_offsets
320
+ / self.num_points
321
+ * reference_points[:, :, None, :, None, 2:]
322
+ * 0.5
323
+ )
324
+ else:
325
+ raise ValueError(
326
+ "Last dim of reference_points must be 2 or 4, but get {} instead.".format(
327
+ reference_points.shape[-1]
328
+ )
329
+ )
330
+
331
+ if torch.cuda.is_available() and value.is_cuda:
332
+ halffloat = False
333
+ if value.dtype == torch.float16:
334
+ halffloat = True
335
+ value = value.float()
336
+ sampling_locations = sampling_locations.float()
337
+ attention_weights = attention_weights.float()
338
+
339
+ output = MultiScaleDeformableAttnFunction.apply(
340
+ value,
341
+ spatial_shapes,
342
+ level_start_index,
343
+ sampling_locations,
344
+ attention_weights,
345
+ self.im2col_step,
346
+ )
347
+
348
+ if halffloat:
349
+ output = output.half()
350
+ else:
351
+ output = multi_scale_deformable_attn_pytorch(
352
+ value, spatial_shapes, sampling_locations, attention_weights
353
+ )
354
+
355
+ output = self.output_proj(output)
356
+
357
+ if not self.batch_first:
358
+ output = output.permute(1, 0, 2)
359
+
360
+ return output
361
+
362
+
363
+ def create_dummy_class(klass, dependency, message=""):
364
+ """
365
+ When a dependency of a class is not available, create a dummy class which throws ImportError
366
+ when used.
367
+
368
+ Args:
369
+ klass (str): name of the class.
370
+ dependency (str): name of the dependency.
371
+ message: extra message to print
372
+ Returns:
373
+ class: a class object
374
+ """
375
+ err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, klass)
376
+ if message:
377
+ err = err + " " + message
378
+
379
+ class _DummyMetaClass(type):
380
+ # throw error on class attribute access
381
+ def __getattr__(_, __): # noqa: B902
382
+ raise ImportError(err)
383
+
384
+ class _Dummy(object, metaclass=_DummyMetaClass):
385
+ # throw error on constructor
386
+ def __init__(self, *args, **kwargs):
387
+ raise ImportError(err)
388
+
389
+ return _Dummy
390
+
391
+
392
+ def create_dummy_func(func, dependency, message=""):
393
+ """
394
+ When a dependency of a function is not available, create a dummy function which throws
395
+ ImportError when used.
396
+
397
+ Args:
398
+ func (str): name of the function.
399
+ dependency (str or list[str]): name(s) of the dependency.
400
+ message: extra message to print
401
+ Returns:
402
+ function: a function object
403
+ """
404
+ err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, func)
405
+ if message:
406
+ err = err + " " + message
407
+
408
+ if isinstance(dependency, (list, tuple)):
409
+ dependency = ",".join(dependency)
410
+
411
+ def _dummy(*args, **kwargs):
412
+ raise ImportError(err)
413
+
414
+ return _dummy
groundingdino/models/GroundingDINO/transformer.py ADDED
@@ -0,0 +1,961 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # DINO
8
+ # Copyright (c) 2022 IDEA. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # ------------------------------------------------------------------------
11
+ # Conditional DETR Transformer class.
12
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
13
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
14
+ # ------------------------------------------------------------------------
15
+ # Modified from DETR (https://github.com/facebookresearch/detr)
16
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
17
+ # ------------------------------------------------------------------------
18
+
19
+ from typing import Optional
20
+
21
+ import torch
22
+ import torch.utils.checkpoint as checkpoint
23
+ from torch import Tensor, nn
24
+ import loralib as lora
25
+ from groundingdino.util.misc import inverse_sigmoid
26
+
27
+ from .fuse_modules import BiAttentionBlock
28
+ from .ms_deform_attn import MultiScaleDeformableAttention as MSDeformAttn
29
+ from .transformer_vanilla import TransformerEncoderLayer
30
+ from .utils import (
31
+ MLP,
32
+ _get_activation_fn,
33
+ _get_clones,
34
+ gen_encoder_output_proposals,
35
+ gen_sineembed_for_position,
36
+ get_sine_pos_embed,
37
+ )
38
+
39
+
40
+ class Transformer(nn.Module):
41
+ def __init__(
42
+ self,
43
+ d_model=256,
44
+ nhead=8,
45
+ num_queries=300,
46
+ num_encoder_layers=6,
47
+ num_unicoder_layers=0,
48
+ num_decoder_layers=6,
49
+ dim_feedforward=2048,
50
+ dropout=0.0,
51
+ activation="relu",
52
+ normalize_before=False,
53
+ return_intermediate_dec=False,
54
+ query_dim=4,
55
+ num_patterns=0,
56
+ # for deformable encoder
57
+ num_feature_levels=1,
58
+ enc_n_points=4,
59
+ dec_n_points=4,
60
+ # init query
61
+ learnable_tgt_init=False,
62
+ # two stage
63
+ two_stage_type="no", # ['no', 'standard', 'early', 'combine', 'enceachlayer', 'enclayer1']
64
+ embed_init_tgt=False,
65
+ # for text
66
+ use_text_enhancer=False,
67
+ use_fusion_layer=False,
68
+ use_checkpoint=False,
69
+ use_transformer_ckpt=False,
70
+ use_text_cross_attention=False,
71
+ text_dropout=0.1,
72
+ fusion_dropout=0.1,
73
+ fusion_droppath=0.0,
74
+ ):
75
+ super().__init__()
76
+ self.num_feature_levels = num_feature_levels
77
+ self.num_encoder_layers = num_encoder_layers
78
+ self.num_unicoder_layers = num_unicoder_layers
79
+ self.num_decoder_layers = num_decoder_layers
80
+ self.num_queries = num_queries
81
+ assert query_dim == 4
82
+
83
+ # choose encoder layer type
84
+ encoder_layer = DeformableTransformerEncoderLayer(
85
+ d_model, dim_feedforward, dropout, activation, num_feature_levels, nhead, enc_n_points
86
+ )
87
+
88
+ if use_text_enhancer:
89
+ text_enhance_layer = TransformerEncoderLayer(
90
+ d_model=d_model,
91
+ nhead=nhead // 2,
92
+ dim_feedforward=dim_feedforward // 2,
93
+ dropout=text_dropout,
94
+ )
95
+ else:
96
+ text_enhance_layer = None
97
+
98
+ if use_fusion_layer:
99
+ feature_fusion_layer = BiAttentionBlock(
100
+ v_dim=d_model,
101
+ l_dim=d_model,
102
+ embed_dim=dim_feedforward // 2,
103
+ num_heads=nhead // 2,
104
+ dropout=fusion_dropout,
105
+ drop_path=fusion_droppath,
106
+ )
107
+ else:
108
+ feature_fusion_layer = None
109
+
110
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
111
+ assert encoder_norm is None
112
+ self.encoder = TransformerEncoder(
113
+ encoder_layer,
114
+ num_encoder_layers,
115
+ d_model=d_model,
116
+ num_queries=num_queries,
117
+ text_enhance_layer=text_enhance_layer,
118
+ feature_fusion_layer=feature_fusion_layer,
119
+ use_checkpoint=use_checkpoint,
120
+ use_transformer_ckpt=use_transformer_ckpt,
121
+ )
122
+
123
+ # choose decoder layer type
124
+ decoder_layer = DeformableTransformerDecoderLayer(
125
+ d_model,
126
+ dim_feedforward,
127
+ dropout,
128
+ activation,
129
+ num_feature_levels,
130
+ nhead,
131
+ dec_n_points,
132
+ use_text_cross_attention=use_text_cross_attention,
133
+ )
134
+
135
+ decoder_norm = nn.LayerNorm(d_model)
136
+ self.decoder = TransformerDecoder(
137
+ decoder_layer,
138
+ num_decoder_layers,
139
+ decoder_norm,
140
+ return_intermediate=return_intermediate_dec,
141
+ d_model=d_model,
142
+ query_dim=query_dim,
143
+ num_feature_levels=num_feature_levels,
144
+ )
145
+
146
+ self.d_model = d_model
147
+ self.nhead = nhead
148
+ self.dec_layers = num_decoder_layers
149
+ self.num_queries = num_queries # useful for single stage model only
150
+ self.num_patterns = num_patterns
151
+ if not isinstance(num_patterns, int):
152
+ Warning("num_patterns should be int but {}".format(type(num_patterns)))
153
+ self.num_patterns = 0
154
+
155
+ if num_feature_levels > 1:
156
+ if self.num_encoder_layers > 0:
157
+ self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
158
+ else:
159
+ self.level_embed = None
160
+
161
+ self.learnable_tgt_init = learnable_tgt_init
162
+ assert learnable_tgt_init, "why not learnable_tgt_init"
163
+ self.embed_init_tgt = embed_init_tgt
164
+ if (two_stage_type != "no" and embed_init_tgt) or (two_stage_type == "no"):
165
+ self.tgt_embed = nn.Embedding(self.num_queries, d_model)
166
+ nn.init.normal_(self.tgt_embed.weight.data)
167
+ else:
168
+ self.tgt_embed = None
169
+
170
+ # for two stage
171
+ self.two_stage_type = two_stage_type
172
+ assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format(
173
+ two_stage_type
174
+ )
175
+ if two_stage_type == "standard":
176
+ # anchor selection at the output of encoder
177
+ self.enc_output = nn.Linear(d_model, d_model)
178
+ self.enc_output_norm = nn.LayerNorm(d_model)
179
+ self.two_stage_wh_embedding = None
180
+
181
+ if two_stage_type == "no":
182
+ self.init_ref_points(num_queries) # init self.refpoint_embed
183
+
184
+ self.enc_out_class_embed = None
185
+ self.enc_out_bbox_embed = None
186
+
187
+ self._reset_parameters()
188
+
189
+ def _reset_parameters(self):
190
+ for p in self.parameters():
191
+ if p.dim() > 1:
192
+ nn.init.xavier_uniform_(p)
193
+ for m in self.modules():
194
+ if isinstance(m, MSDeformAttn):
195
+ m._reset_parameters()
196
+ if self.num_feature_levels > 1 and self.level_embed is not None:
197
+ nn.init.normal_(self.level_embed)
198
+
199
+ def get_valid_ratio(self, mask):
200
+ _, H, W = mask.shape
201
+ valid_H = torch.sum(~mask[:, :, 0], 1)
202
+ valid_W = torch.sum(~mask[:, 0, :], 1)
203
+ valid_ratio_h = valid_H.float() / H
204
+ valid_ratio_w = valid_W.float() / W
205
+ valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
206
+ return valid_ratio
207
+
208
+ def init_ref_points(self, use_num_queries):
209
+ self.refpoint_embed = nn.Embedding(use_num_queries, 4)
210
+
211
+ def forward(self, srcs, masks, refpoint_embed, pos_embeds, tgt, attn_mask=None, text_dict=None):
212
+ """
213
+ Input:
214
+ - srcs: List of multi features [bs, ci, hi, wi]
215
+ - masks: List of multi masks [bs, hi, wi]
216
+ - refpoint_embed: [bs, num_dn, 4]. None in infer
217
+ - pos_embeds: List of multi pos embeds [bs, ci, hi, wi]
218
+ - tgt: [bs, num_dn, d_model]. None in infer
219
+
220
+ """
221
+ # prepare input for encoder
222
+ src_flatten = []
223
+ mask_flatten = []
224
+ lvl_pos_embed_flatten = []
225
+ spatial_shapes = []
226
+ for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
227
+ bs, c, h, w = src.shape
228
+ spatial_shape = (h, w)
229
+ spatial_shapes.append(spatial_shape)
230
+
231
+ src = src.flatten(2).transpose(1, 2) # bs, hw, c
232
+ mask = mask.flatten(1) # bs, hw
233
+ pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c
234
+ if self.num_feature_levels > 1 and self.level_embed is not None:
235
+ lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
236
+ else:
237
+ lvl_pos_embed = pos_embed
238
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
239
+ src_flatten.append(src)
240
+ mask_flatten.append(mask)
241
+ src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c
242
+ mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw}
243
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # bs, \sum{hxw}, c
244
+ spatial_shapes = torch.as_tensor(
245
+ spatial_shapes, dtype=torch.long, device=src_flatten.device
246
+ )
247
+ level_start_index = torch.cat(
248
+ (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
249
+ )
250
+ valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
251
+
252
+ # two stage
253
+ enc_topk_proposals = enc_refpoint_embed = None
254
+
255
+ #########################################################
256
+ # Begin Encoder
257
+ #########################################################
258
+ memory, memory_text = self.encoder(
259
+ src_flatten,
260
+ pos=lvl_pos_embed_flatten,
261
+ level_start_index=level_start_index,
262
+ spatial_shapes=spatial_shapes,
263
+ valid_ratios=valid_ratios,
264
+ key_padding_mask=mask_flatten,
265
+ memory_text=text_dict["encoded_text"],
266
+ text_attention_mask=~text_dict["text_token_mask"],
267
+ # we ~ the mask . False means use the token; True means pad the token
268
+ position_ids=text_dict["position_ids"],
269
+ text_self_attention_masks=text_dict["text_self_attention_masks"],
270
+ )
271
+ #########################################################
272
+ # End Encoder
273
+ # - memory: bs, \sum{hw}, c
274
+ # - mask_flatten: bs, \sum{hw}
275
+ # - lvl_pos_embed_flatten: bs, \sum{hw}, c
276
+ # - enc_intermediate_output: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
277
+ # - enc_intermediate_refpoints: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
278
+ #########################################################
279
+ text_dict["encoded_text"] = memory_text
280
+ # if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
281
+ # if memory.isnan().any() | memory.isinf().any():
282
+ # import ipdb; ipdb.set_trace()
283
+
284
+ if self.two_stage_type == "standard":
285
+ output_memory, output_proposals = gen_encoder_output_proposals(
286
+ memory, mask_flatten, spatial_shapes
287
+ )
288
+ output_memory = self.enc_output_norm(self.enc_output(output_memory))
289
+
290
+ if text_dict is not None:
291
+ enc_outputs_class_unselected = self.enc_out_class_embed(output_memory, text_dict)
292
+ else:
293
+ enc_outputs_class_unselected = self.enc_out_class_embed(output_memory)
294
+
295
+ topk_logits = enc_outputs_class_unselected.max(-1)[0]
296
+ enc_outputs_coord_unselected = (
297
+ self.enc_out_bbox_embed(output_memory) + output_proposals
298
+ ) # (bs, \sum{hw}, 4) unsigmoid
299
+ topk = self.num_queries
300
+
301
+ topk_proposals = torch.topk(topk_logits, topk, dim=1)[1] # bs, nq
302
+
303
+ # gather boxes
304
+ refpoint_embed_undetach = torch.gather(
305
+ enc_outputs_coord_unselected, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
306
+ ) # unsigmoid
307
+ refpoint_embed_ = refpoint_embed_undetach.detach()
308
+ init_box_proposal = torch.gather(
309
+ output_proposals, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
310
+ ).sigmoid() # sigmoid
311
+
312
+ # gather tgt
313
+ tgt_undetach = torch.gather(
314
+ output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model)
315
+ )
316
+ if self.embed_init_tgt:
317
+ tgt_ = (
318
+ self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
319
+ ) # nq, bs, d_model
320
+ else:
321
+ tgt_ = tgt_undetach.detach()
322
+
323
+ if refpoint_embed is not None:
324
+ refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
325
+ tgt = torch.cat([tgt, tgt_], dim=1)
326
+ else:
327
+ refpoint_embed, tgt = refpoint_embed_, tgt_
328
+
329
+ elif self.two_stage_type == "no":
330
+ tgt_ = (
331
+ self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
332
+ ) # nq, bs, d_model
333
+ refpoint_embed_ = (
334
+ self.refpoint_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
335
+ ) # nq, bs, 4
336
+
337
+ if refpoint_embed is not None:
338
+ refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
339
+ tgt = torch.cat([tgt, tgt_], dim=1)
340
+ else:
341
+ refpoint_embed, tgt = refpoint_embed_, tgt_
342
+
343
+ if self.num_patterns > 0:
344
+ tgt_embed = tgt.repeat(1, self.num_patterns, 1)
345
+ refpoint_embed = refpoint_embed.repeat(1, self.num_patterns, 1)
346
+ tgt_pat = self.patterns.weight[None, :, :].repeat_interleave(
347
+ self.num_queries, 1
348
+ ) # 1, n_q*n_pat, d_model
349
+ tgt = tgt_embed + tgt_pat
350
+
351
+ init_box_proposal = refpoint_embed_.sigmoid()
352
+
353
+ else:
354
+ raise NotImplementedError("unknown two_stage_type {}".format(self.two_stage_type))
355
+ #########################################################
356
+ # End preparing tgt
357
+ # - tgt: bs, NQ, d_model
358
+ # - refpoint_embed(unsigmoid): bs, NQ, d_model
359
+ #########################################################
360
+
361
+ #########################################################
362
+ # Begin Decoder
363
+ #########################################################
364
+ hs, references = self.decoder(
365
+ tgt=tgt.transpose(0, 1),
366
+ memory=memory.transpose(0, 1),
367
+ memory_key_padding_mask=mask_flatten,
368
+ pos=lvl_pos_embed_flatten.transpose(0, 1),
369
+ refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
370
+ level_start_index=level_start_index,
371
+ spatial_shapes=spatial_shapes,
372
+ valid_ratios=valid_ratios,
373
+ tgt_mask=attn_mask,
374
+ memory_text=text_dict["encoded_text"],
375
+ text_attention_mask=~text_dict["text_token_mask"],
376
+ # we ~ the mask . False means use the token; True means pad the token
377
+ )
378
+ #########################################################
379
+ # End Decoder
380
+ # hs: n_dec, bs, nq, d_model
381
+ # references: n_dec+1, bs, nq, query_dim
382
+ #########################################################
383
+
384
+ #########################################################
385
+ # Begin postprocess
386
+ #########################################################
387
+ if self.two_stage_type == "standard":
388
+ hs_enc = tgt_undetach.unsqueeze(0)
389
+ ref_enc = refpoint_embed_undetach.sigmoid().unsqueeze(0)
390
+ else:
391
+ hs_enc = ref_enc = None
392
+ #########################################################
393
+ # End postprocess
394
+ # hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or (n_enc, bs, nq, d_model) or None
395
+ # ref_enc: (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or (n_enc, bs, nq, d_model) or None
396
+ #########################################################
397
+
398
+ return hs, references, hs_enc, ref_enc, init_box_proposal
399
+ # hs: (n_dec, bs, nq, d_model)
400
+ # references: sigmoid coordinates. (n_dec+1, bs, bq, 4)
401
+ # hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or None
402
+ # ref_enc: sigmoid coordinates. \
403
+ # (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or None
404
+
405
+
406
+ class TransformerEncoder(nn.Module):
407
+ def __init__(
408
+ self,
409
+ encoder_layer,
410
+ num_layers,
411
+ d_model=256,
412
+ num_queries=300,
413
+ enc_layer_share=False,
414
+ text_enhance_layer=None,
415
+ feature_fusion_layer=None,
416
+ use_checkpoint=False,
417
+ use_transformer_ckpt=False,
418
+ ):
419
+ """_summary_
420
+
421
+ Args:
422
+ encoder_layer (_type_): _description_
423
+ num_layers (_type_): _description_
424
+ norm (_type_, optional): _description_. Defaults to None.
425
+ d_model (int, optional): _description_. Defaults to 256.
426
+ num_queries (int, optional): _description_. Defaults to 300.
427
+ enc_layer_share (bool, optional): _description_. Defaults to False.
428
+
429
+ """
430
+ super().__init__()
431
+ # prepare layers
432
+ self.layers = []
433
+ self.text_layers = []
434
+ self.fusion_layers = []
435
+ if num_layers > 0:
436
+ self.layers = _get_clones(encoder_layer, num_layers, layer_share=enc_layer_share)
437
+
438
+ if text_enhance_layer is not None:
439
+ self.text_layers = _get_clones(
440
+ text_enhance_layer, num_layers, layer_share=enc_layer_share
441
+ )
442
+ if feature_fusion_layer is not None:
443
+ self.fusion_layers = _get_clones(
444
+ feature_fusion_layer, num_layers, layer_share=enc_layer_share
445
+ )
446
+ else:
447
+ self.layers = []
448
+ del encoder_layer
449
+
450
+ if text_enhance_layer is not None:
451
+ self.text_layers = []
452
+ del text_enhance_layer
453
+ if feature_fusion_layer is not None:
454
+ self.fusion_layers = []
455
+ del feature_fusion_layer
456
+
457
+ self.query_scale = None
458
+ self.num_queries = num_queries
459
+ self.num_layers = num_layers
460
+ self.d_model = d_model
461
+
462
+ self.use_checkpoint = use_checkpoint
463
+ self.use_transformer_ckpt = use_transformer_ckpt
464
+
465
+ @staticmethod
466
+ def get_reference_points(spatial_shapes, valid_ratios, device):
467
+ reference_points_list = []
468
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
469
+
470
+ ref_y, ref_x = torch.meshgrid(
471
+ torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
472
+ torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
473
+ )
474
+ ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
475
+ ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
476
+ ref = torch.stack((ref_x, ref_y), -1)
477
+ reference_points_list.append(ref)
478
+ reference_points = torch.cat(reference_points_list, 1)
479
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
480
+ return reference_points
481
+
482
+ def forward(
483
+ self,
484
+ # for images
485
+ src: Tensor,
486
+ pos: Tensor,
487
+ spatial_shapes: Tensor,
488
+ level_start_index: Tensor,
489
+ valid_ratios: Tensor,
490
+ key_padding_mask: Tensor,
491
+ # for texts
492
+ memory_text: Tensor = None,
493
+ text_attention_mask: Tensor = None,
494
+ pos_text: Tensor = None,
495
+ text_self_attention_masks: Tensor = None,
496
+ position_ids: Tensor = None,
497
+ ):
498
+ """
499
+ Input:
500
+ - src: [bs, sum(hi*wi), 256]
501
+ - pos: pos embed for src. [bs, sum(hi*wi), 256]
502
+ - spatial_shapes: h,w of each level [num_level, 2]
503
+ - level_start_index: [num_level] start point of level in sum(hi*wi).
504
+ - valid_ratios: [bs, num_level, 2]
505
+ - key_padding_mask: [bs, sum(hi*wi)]
506
+
507
+ - memory_text: bs, n_text, 256
508
+ - text_attention_mask: bs, n_text
509
+ False for no padding; True for padding
510
+ - pos_text: bs, n_text, 256
511
+
512
+ - position_ids: bs, n_text
513
+ Intermedia:
514
+ - reference_points: [bs, sum(hi*wi), num_level, 2]
515
+ Outpus:
516
+ - output: [bs, sum(hi*wi), 256]
517
+ """
518
+
519
+ output = src
520
+
521
+ # preparation and reshape
522
+ if self.num_layers > 0:
523
+ reference_points = self.get_reference_points(
524
+ spatial_shapes, valid_ratios, device=src.device
525
+ )
526
+
527
+ if self.text_layers:
528
+ # generate pos_text
529
+ bs, n_text, text_dim = memory_text.shape
530
+ if pos_text is None and position_ids is None:
531
+ pos_text = (
532
+ torch.arange(n_text, device=memory_text.device)
533
+ .float()
534
+ .unsqueeze(0)
535
+ .unsqueeze(-1)
536
+ .repeat(bs, 1, 1)
537
+ )
538
+ pos_text = get_sine_pos_embed(pos_text, num_pos_feats=256, exchange_xy=False)
539
+ if position_ids is not None:
540
+ pos_text = get_sine_pos_embed(
541
+ position_ids[..., None], num_pos_feats=256, exchange_xy=False
542
+ )
543
+
544
+ # main process
545
+ for layer_id, layer in enumerate(self.layers):
546
+ # if output.isnan().any() or memory_text.isnan().any():
547
+ # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
548
+ # import ipdb; ipdb.set_trace()
549
+ if self.fusion_layers:
550
+ if self.use_checkpoint:
551
+ output, memory_text = checkpoint.checkpoint(
552
+ self.fusion_layers[layer_id],
553
+ output,
554
+ memory_text,
555
+ key_padding_mask,
556
+ text_attention_mask,
557
+ )
558
+ else:
559
+ output, memory_text = self.fusion_layers[layer_id](
560
+ v=output,
561
+ l=memory_text,
562
+ attention_mask_v=key_padding_mask,
563
+ attention_mask_l=text_attention_mask,
564
+ )
565
+
566
+ if self.text_layers:
567
+ memory_text = self.text_layers[layer_id](
568
+ src=memory_text.transpose(0, 1),
569
+ src_mask=~text_self_attention_masks, # note we use ~ for mask here
570
+ src_key_padding_mask=text_attention_mask,
571
+ pos=(pos_text.transpose(0, 1) if pos_text is not None else None),
572
+ ).transpose(0, 1)
573
+
574
+ # main process
575
+ if self.use_transformer_ckpt:
576
+ output = checkpoint.checkpoint(
577
+ layer,
578
+ output,
579
+ pos,
580
+ reference_points,
581
+ spatial_shapes,
582
+ level_start_index,
583
+ key_padding_mask,
584
+ )
585
+ else:
586
+ output = layer(
587
+ src=output,
588
+ pos=pos,
589
+ reference_points=reference_points,
590
+ spatial_shapes=spatial_shapes,
591
+ level_start_index=level_start_index,
592
+ key_padding_mask=key_padding_mask,
593
+ )
594
+
595
+ return output, memory_text
596
+
597
+
598
+ class TransformerDecoder(nn.Module):
599
+ def __init__(
600
+ self,
601
+ decoder_layer,
602
+ num_layers,
603
+ norm=None,
604
+ return_intermediate=False,
605
+ d_model=256,
606
+ query_dim=4,
607
+ num_feature_levels=1,
608
+ ):
609
+ super().__init__()
610
+ if num_layers > 0:
611
+ self.layers = _get_clones(decoder_layer, num_layers)
612
+ else:
613
+ self.layers = []
614
+ self.num_layers = num_layers
615
+ self.norm = norm
616
+ self.return_intermediate = return_intermediate
617
+ assert return_intermediate, "support return_intermediate only"
618
+ self.query_dim = query_dim
619
+ assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim)
620
+ self.num_feature_levels = num_feature_levels
621
+
622
+ self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)
623
+ self.query_pos_sine_scale = None
624
+
625
+ self.query_scale = None
626
+ self.bbox_embed = None
627
+ self.class_embed = None
628
+
629
+ self.d_model = d_model
630
+
631
+ self.ref_anchor_head = None
632
+
633
+ def forward(
634
+ self,
635
+ tgt,
636
+ memory,
637
+ tgt_mask: Optional[Tensor] = None,
638
+ memory_mask: Optional[Tensor] = None,
639
+ tgt_key_padding_mask: Optional[Tensor] = None,
640
+ memory_key_padding_mask: Optional[Tensor] = None,
641
+ pos: Optional[Tensor] = None,
642
+ refpoints_unsigmoid: Optional[Tensor] = None, # num_queries, bs, 2
643
+ # for memory
644
+ level_start_index: Optional[Tensor] = None, # num_levels
645
+ spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
646
+ valid_ratios: Optional[Tensor] = None,
647
+ # for text
648
+ memory_text: Optional[Tensor] = None,
649
+ text_attention_mask: Optional[Tensor] = None,
650
+ ):
651
+ """
652
+ Input:
653
+ - tgt: nq, bs, d_model
654
+ - memory: hw, bs, d_model
655
+ - pos: hw, bs, d_model
656
+ - refpoints_unsigmoid: nq, bs, 2/4
657
+ - valid_ratios/spatial_shapes: bs, nlevel, 2
658
+ """
659
+ output = tgt
660
+
661
+ intermediate = []
662
+ reference_points = refpoints_unsigmoid.sigmoid()
663
+ ref_points = [reference_points]
664
+
665
+ for layer_id, layer in enumerate(self.layers):
666
+
667
+ if reference_points.shape[-1] == 4:
668
+ reference_points_input = (
669
+ reference_points[:, :, None]
670
+ * torch.cat([valid_ratios, valid_ratios], -1)[None, :]
671
+ ) # nq, bs, nlevel, 4
672
+ else:
673
+ assert reference_points.shape[-1] == 2
674
+ reference_points_input = reference_points[:, :, None] * valid_ratios[None, :]
675
+ query_sine_embed = gen_sineembed_for_position(
676
+ reference_points_input[:, :, 0, :]
677
+ ) # nq, bs, 256*2
678
+
679
+ # conditional query
680
+ raw_query_pos = self.ref_point_head(query_sine_embed) # nq, bs, 256
681
+ pos_scale = self.query_scale(output) if self.query_scale is not None else 1
682
+ query_pos = pos_scale * raw_query_pos
683
+ # if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
684
+ # if query_pos.isnan().any() | query_pos.isinf().any():
685
+ # import ipdb; ipdb.set_trace()
686
+
687
+ # main process
688
+ output = layer(
689
+ tgt=output,
690
+ tgt_query_pos=query_pos,
691
+ tgt_query_sine_embed=query_sine_embed,
692
+ tgt_key_padding_mask=tgt_key_padding_mask,
693
+ tgt_reference_points=reference_points_input,
694
+ memory_text=memory_text,
695
+ text_attention_mask=text_attention_mask,
696
+ memory=memory,
697
+ memory_key_padding_mask=memory_key_padding_mask,
698
+ memory_level_start_index=level_start_index,
699
+ memory_spatial_shapes=spatial_shapes,
700
+ memory_pos=pos,
701
+ self_attn_mask=tgt_mask,
702
+ cross_attn_mask=memory_mask,
703
+ )
704
+ if output.isnan().any() | output.isinf().any():
705
+ print(f"output layer_id {layer_id} is nan")
706
+ try:
707
+ num_nan = output.isnan().sum().item()
708
+ num_inf = output.isinf().sum().item()
709
+ print(f"num_nan {num_nan}, num_inf {num_inf}")
710
+ except Exception as e:
711
+ print(e)
712
+ # if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
713
+ # import ipdb; ipdb.set_trace()
714
+
715
+ # iter update
716
+ if self.bbox_embed is not None:
717
+ # box_holder = self.bbox_embed(output)
718
+ # box_holder[..., :self.query_dim] += inverse_sigmoid(reference_points)
719
+ # new_reference_points = box_holder[..., :self.query_dim].sigmoid()
720
+
721
+ reference_before_sigmoid = inverse_sigmoid(reference_points)
722
+ delta_unsig = self.bbox_embed[layer_id](output)
723
+ outputs_unsig = delta_unsig + reference_before_sigmoid
724
+ new_reference_points = outputs_unsig.sigmoid()
725
+
726
+ reference_points = new_reference_points.detach()
727
+ # if layer_id != self.num_layers - 1:
728
+ ref_points.append(new_reference_points)
729
+
730
+ intermediate.append(self.norm(output))
731
+
732
+ return [
733
+ [itm_out.transpose(0, 1) for itm_out in intermediate],
734
+ [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points],
735
+ ]
736
+
737
+
738
+ class DeformableTransformerEncoderLayer(nn.Module):
739
+ def __init__(
740
+ self,
741
+ d_model=256,
742
+ d_ffn=1024,
743
+ dropout=0.1,
744
+ activation="relu",
745
+ n_levels=4,
746
+ n_heads=8,
747
+ n_points=4,
748
+ ):
749
+ super().__init__()
750
+
751
+ # self attention
752
+ self.self_attn = MSDeformAttn(
753
+ embed_dim=d_model,
754
+ num_levels=n_levels,
755
+ num_heads=n_heads,
756
+ num_points=n_points,
757
+ batch_first=True,
758
+ )
759
+ self.dropout1 = nn.Dropout(dropout)
760
+ self.norm1 = nn.LayerNorm(d_model)
761
+
762
+ # ffn
763
+ r = 16
764
+ self.linear1 = lora.Linear(d_model, d_ffn , r=r)
765
+ self.activation = _get_activation_fn(activation, d_model=d_ffn)
766
+ self.dropout2 = nn.Dropout(dropout)
767
+ self.linear2 = lora.Linear(d_ffn, d_model , r=r)
768
+ self.dropout3 = nn.Dropout(dropout)
769
+ self.norm2 = nn.LayerNorm(d_model)
770
+
771
+ @staticmethod
772
+ def with_pos_embed(tensor, pos):
773
+ return tensor if pos is None else tensor + pos
774
+
775
+ def forward_ffn(self, src):
776
+ src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
777
+ src = src + self.dropout3(src2)
778
+ src = self.norm2(src)
779
+ return src
780
+
781
+ def forward(
782
+ self, src, pos, reference_points, spatial_shapes, level_start_index, key_padding_mask=None
783
+ ):
784
+ # self attention
785
+ # import ipdb; ipdb.set_trace()
786
+ src2 = self.self_attn(
787
+ query=self.with_pos_embed(src, pos),
788
+ reference_points=reference_points,
789
+ value=src,
790
+ spatial_shapes=spatial_shapes,
791
+ level_start_index=level_start_index,
792
+ key_padding_mask=key_padding_mask,
793
+ )
794
+ src = src + self.dropout1(src2)
795
+ src = self.norm1(src)
796
+
797
+ # ffn
798
+ src = self.forward_ffn(src)
799
+
800
+ return src
801
+
802
+
803
+ class DeformableTransformerDecoderLayer(nn.Module):
804
+ def __init__(
805
+ self,
806
+ d_model=256,
807
+ d_ffn=1024,
808
+ dropout=0.1,
809
+ activation="relu",
810
+ n_levels=4,
811
+ n_heads=8,
812
+ n_points=4,
813
+ use_text_feat_guide=False,
814
+ use_text_cross_attention=False,
815
+ ):
816
+ super().__init__()
817
+
818
+ # cross attention
819
+ self.cross_attn = MSDeformAttn(
820
+ embed_dim=d_model,
821
+ num_levels=n_levels,
822
+ num_heads=n_heads,
823
+ num_points=n_points,
824
+ batch_first=True,
825
+ )
826
+ self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
827
+ self.norm1 = nn.LayerNorm(d_model)
828
+
829
+ # cross attention text
830
+ if use_text_cross_attention:
831
+ self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
832
+ self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
833
+ self.catext_norm = nn.LayerNorm(d_model)
834
+
835
+ # self attention
836
+ self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
837
+ self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
838
+ self.norm2 = nn.LayerNorm(d_model)
839
+
840
+ # ffn
841
+ r = 16
842
+ self.linear1 = lora.Linear(d_model, d_ffn , r=r)
843
+ self.activation = _get_activation_fn(activation, d_model=d_ffn, batch_dim=1)
844
+ self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
845
+ self.linear2 = lora.Linear(d_ffn, d_model , r=r)
846
+ self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
847
+ self.norm3 = nn.LayerNorm(d_model)
848
+
849
+ self.key_aware_proj = None
850
+ self.use_text_feat_guide = use_text_feat_guide
851
+ assert not use_text_feat_guide
852
+ self.use_text_cross_attention = use_text_cross_attention
853
+
854
+ def rm_self_attn_modules(self):
855
+ self.self_attn = None
856
+ self.dropout2 = None
857
+ self.norm2 = None
858
+
859
+ @staticmethod
860
+ def with_pos_embed(tensor, pos):
861
+ return tensor if pos is None else tensor + pos
862
+
863
+ def forward_ffn(self, tgt):
864
+ with torch.cuda.amp.autocast(enabled=False):
865
+ tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
866
+ tgt = tgt + self.dropout4(tgt2)
867
+ tgt = self.norm3(tgt)
868
+ return tgt
869
+
870
+ def forward(
871
+ self,
872
+ # for tgt
873
+ tgt: Optional[Tensor], # nq, bs, d_model
874
+ tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
875
+ tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos)
876
+ tgt_key_padding_mask: Optional[Tensor] = None,
877
+ tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4
878
+ memory_text: Optional[Tensor] = None, # bs, num_token, d_model
879
+ text_attention_mask: Optional[Tensor] = None, # bs, num_token
880
+ # for memory
881
+ memory: Optional[Tensor] = None, # hw, bs, d_model
882
+ memory_key_padding_mask: Optional[Tensor] = None,
883
+ memory_level_start_index: Optional[Tensor] = None, # num_levels
884
+ memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
885
+ memory_pos: Optional[Tensor] = None, # pos for memory
886
+ # sa
887
+ self_attn_mask: Optional[Tensor] = None, # mask used for self-attention
888
+ cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention
889
+ ):
890
+ """
891
+ Input:
892
+ - tgt/tgt_query_pos: nq, bs, d_model
893
+ -
894
+ """
895
+ assert cross_attn_mask is None
896
+
897
+ # self attention
898
+ if self.self_attn is not None:
899
+ # import ipdb; ipdb.set_trace()
900
+ q = k = self.with_pos_embed(tgt, tgt_query_pos)
901
+ tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
902
+ tgt = tgt + self.dropout2(tgt2)
903
+ tgt = self.norm2(tgt)
904
+
905
+ if self.use_text_cross_attention:
906
+ tgt2 = self.ca_text(
907
+ self.with_pos_embed(tgt, tgt_query_pos),
908
+ memory_text.transpose(0, 1),
909
+ memory_text.transpose(0, 1),
910
+ key_padding_mask=text_attention_mask,
911
+ )[0]
912
+ tgt = tgt + self.catext_dropout(tgt2)
913
+ tgt = self.catext_norm(tgt)
914
+
915
+ tgt2 = self.cross_attn(
916
+ query=self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
917
+ reference_points=tgt_reference_points.transpose(0, 1).contiguous(),
918
+ value=memory.transpose(0, 1),
919
+ spatial_shapes=memory_spatial_shapes,
920
+ level_start_index=memory_level_start_index,
921
+ key_padding_mask=memory_key_padding_mask,
922
+ ).transpose(0, 1)
923
+ tgt = tgt + self.dropout1(tgt2)
924
+ tgt = self.norm1(tgt)
925
+
926
+ # ffn
927
+ tgt = self.forward_ffn(tgt)
928
+
929
+ return tgt
930
+
931
+
932
+ def build_transformer(args):
933
+ return Transformer(
934
+ d_model=args.hidden_dim,
935
+ dropout=args.dropout,
936
+ nhead=args.nheads,
937
+ num_queries=args.num_queries,
938
+ dim_feedforward=args.dim_feedforward,
939
+ num_encoder_layers=args.enc_layers,
940
+ num_decoder_layers=args.dec_layers,
941
+ normalize_before=args.pre_norm,
942
+ return_intermediate_dec=True,
943
+ query_dim=args.query_dim,
944
+ activation=args.transformer_activation,
945
+ num_patterns=args.num_patterns,
946
+ num_feature_levels=args.num_feature_levels,
947
+ enc_n_points=args.enc_n_points,
948
+ dec_n_points=args.dec_n_points,
949
+ learnable_tgt_init=True,
950
+ # two stage
951
+ two_stage_type=args.two_stage_type, # ['no', 'standard', 'early']
952
+ embed_init_tgt=args.embed_init_tgt,
953
+ use_text_enhancer=args.use_text_enhancer,
954
+ use_fusion_layer=args.use_fusion_layer,
955
+ use_checkpoint=args.use_checkpoint,
956
+ use_transformer_ckpt=args.use_transformer_ckpt,
957
+ use_text_cross_attention=args.use_text_cross_attention,
958
+ text_dropout=args.text_dropout,
959
+ fusion_dropout=args.fusion_dropout,
960
+ fusion_droppath=args.fusion_droppath,
961
+ )
groundingdino/models/GroundingDINO/transformer_vanilla.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
8
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
9
+ """
10
+ DETR Transformer class.
11
+
12
+ Copy-paste from torch.nn.Transformer with modifications:
13
+ * positional encodings are passed in MHattention
14
+ * extra LN at the end of encoder is removed
15
+ * decoder returns a stack of activations from all decoding layers
16
+ """
17
+ from typing import Optional
18
+
19
+ import torch
20
+ import torch.nn.functional as F
21
+ from torch import Tensor, nn
22
+ import loralib as lora
23
+ from .utils import (
24
+ MLP,
25
+ _get_activation_fn,
26
+ _get_clones,
27
+ gen_encoder_output_proposals,
28
+ gen_sineembed_for_position,
29
+ sigmoid_focal_loss,
30
+ )
31
+
32
+
33
+ class TextTransformer(nn.Module):
34
+ def __init__(self, num_layers, d_model=256, nheads=8, dim_feedforward=2048, dropout=0.1):
35
+ super().__init__()
36
+ self.num_layers = num_layers
37
+ self.d_model = d_model
38
+ self.nheads = nheads
39
+ self.dim_feedforward = dim_feedforward
40
+ self.norm = None
41
+
42
+ single_encoder_layer = TransformerEncoderLayer(
43
+ d_model=d_model, nhead=nheads, dim_feedforward=dim_feedforward, dropout=dropout
44
+ )
45
+ self.layers = _get_clones(single_encoder_layer, num_layers)
46
+
47
+ def forward(self, memory_text: torch.Tensor, text_attention_mask: torch.Tensor):
48
+ """
49
+
50
+ Args:
51
+ text_attention_mask: bs, num_token
52
+ memory_text: bs, num_token, d_model
53
+
54
+ Raises:
55
+ RuntimeError: _description_
56
+
57
+ Returns:
58
+ output: bs, num_token, d_model
59
+ """
60
+
61
+ output = memory_text.transpose(0, 1)
62
+
63
+ for layer in self.layers:
64
+ output = layer(output, src_key_padding_mask=text_attention_mask)
65
+
66
+ if self.norm is not None:
67
+ output = self.norm(output)
68
+
69
+ return output.transpose(0, 1)
70
+
71
+
72
+ class TransformerEncoderLayer(nn.Module):
73
+ def __init__(
74
+ self,
75
+ d_model,
76
+ nhead,
77
+ dim_feedforward=2048,
78
+ dropout=0.1,
79
+ activation="relu",
80
+ normalize_before=False,
81
+ ):
82
+ super().__init__()
83
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
84
+ # Implementation of Feedforward model
85
+ r=16
86
+ self.linear1 = lora.Linear(d_model, dim_feedforward , r=r)
87
+ self.dropout = nn.Dropout(dropout)
88
+ self.linear2 = lora.Linear(dim_feedforward, d_model , r=r)
89
+
90
+ self.norm1 = nn.LayerNorm(d_model)
91
+ self.norm2 = nn.LayerNorm(d_model)
92
+ self.dropout1 = nn.Dropout(dropout)
93
+ self.dropout2 = nn.Dropout(dropout)
94
+
95
+ self.activation = _get_activation_fn(activation)
96
+ self.normalize_before = normalize_before
97
+ self.nhead = nhead
98
+
99
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
100
+ return tensor if pos is None else tensor + pos
101
+
102
+ def forward(
103
+ self,
104
+ src,
105
+ src_mask: Optional[Tensor] = None,
106
+ src_key_padding_mask: Optional[Tensor] = None,
107
+ pos: Optional[Tensor] = None,
108
+ ):
109
+ # repeat attn mask
110
+ if src_mask.dim() == 3 and src_mask.shape[0] == src.shape[1]:
111
+ # bs, num_q, num_k
112
+ src_mask = src_mask.repeat(self.nhead, 1, 1)
113
+
114
+ q = k = self.with_pos_embed(src, pos)
115
+
116
+ src2 = self.self_attn(q, k, value=src, attn_mask=src_mask)[0]
117
+
118
+ # src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
119
+ src = src + self.dropout1(src2)
120
+ src = self.norm1(src)
121
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
122
+ src = src + self.dropout2(src2)
123
+ src = self.norm2(src)
124
+ return src
groundingdino/models/GroundingDINO/utils.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+
8
+ import copy
9
+ import math
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import Tensor, nn
14
+ import loralib as lora
15
+
16
+ def _get_clones(module, N, layer_share=False):
17
+ # import ipdb; ipdb.set_trace()
18
+ if layer_share:
19
+ return nn.ModuleList([module for i in range(N)])
20
+ else:
21
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
22
+
23
+
24
+ def get_sine_pos_embed(
25
+ pos_tensor: torch.Tensor,
26
+ num_pos_feats: int = 128,
27
+ temperature: int = 10000,
28
+ exchange_xy: bool = True,
29
+ ):
30
+ """generate sine position embedding from a position tensor
31
+ Args:
32
+ pos_tensor (torch.Tensor): shape: [..., n].
33
+ num_pos_feats (int): projected shape for each float in the tensor.
34
+ temperature (int): temperature in the sine/cosine function.
35
+ exchange_xy (bool, optional): exchange pos x and pos y. \
36
+ For example, input tensor is [x,y], the results will be [pos(y), pos(x)]. Defaults to True.
37
+ Returns:
38
+ pos_embed (torch.Tensor): shape: [..., n*num_pos_feats].
39
+ """
40
+ scale = 2 * math.pi
41
+ dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device)
42
+ dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
43
+
44
+ def sine_func(x: torch.Tensor):
45
+ sin_x = x * scale / dim_t
46
+ sin_x = torch.stack((sin_x[..., 0::2].sin(), sin_x[..., 1::2].cos()), dim=3).flatten(2)
47
+ return sin_x
48
+
49
+ pos_res = [sine_func(x) for x in pos_tensor.split([1] * pos_tensor.shape[-1], dim=-1)]
50
+ if exchange_xy:
51
+ pos_res[0], pos_res[1] = pos_res[1], pos_res[0]
52
+ pos_res = torch.cat(pos_res, dim=-1)
53
+ return pos_res
54
+
55
+
56
+ def gen_encoder_output_proposals(
57
+ memory: Tensor, memory_padding_mask: Tensor, spatial_shapes: Tensor, learnedwh=None
58
+ ):
59
+ """
60
+ Input:
61
+ - memory: bs, \sum{hw}, d_model
62
+ - memory_padding_mask: bs, \sum{hw}
63
+ - spatial_shapes: nlevel, 2
64
+ - learnedwh: 2
65
+ Output:
66
+ - output_memory: bs, \sum{hw}, d_model
67
+ - output_proposals: bs, \sum{hw}, 4
68
+ """
69
+ N_, S_, C_ = memory.shape
70
+ proposals = []
71
+ _cur = 0
72
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
73
+ mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H_ * W_)].view(N_, H_, W_, 1)
74
+ valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
75
+ valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
76
+
77
+ # import ipdb; ipdb.set_trace()
78
+
79
+ grid_y, grid_x = torch.meshgrid(
80
+ torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
81
+ torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
82
+ )
83
+ grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2
84
+
85
+ scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
86
+ grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
87
+
88
+ if learnedwh is not None:
89
+ # import ipdb; ipdb.set_trace()
90
+ wh = torch.ones_like(grid) * learnedwh.sigmoid() * (2.0**lvl)
91
+ else:
92
+ wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
93
+
94
+ # scale = torch.cat([W_[None].unsqueeze(-1), H_[None].unsqueeze(-1)], 1).view(1, 1, 1, 2).repeat(N_, 1, 1, 1)
95
+ # grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
96
+ # wh = torch.ones_like(grid) / scale
97
+ proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
98
+ proposals.append(proposal)
99
+ _cur += H_ * W_
100
+ # import ipdb; ipdb.set_trace()
101
+ output_proposals = torch.cat(proposals, 1)
102
+ output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(
103
+ -1, keepdim=True
104
+ )
105
+ output_proposals = torch.log(output_proposals / (1 - output_proposals)) # unsigmoid
106
+ output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float("inf"))
107
+ output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf"))
108
+
109
+ output_memory = memory
110
+ output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
111
+ output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
112
+
113
+ # output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
114
+ # output_memory = output_memory.masked_fill(~output_proposals_valid, float('inf'))
115
+
116
+ return output_memory, output_proposals
117
+
118
+
119
+ class RandomBoxPerturber:
120
+ def __init__(
121
+ self, x_noise_scale=0.2, y_noise_scale=0.2, w_noise_scale=0.2, h_noise_scale=0.2
122
+ ) -> None:
123
+ self.noise_scale = torch.Tensor(
124
+ [x_noise_scale, y_noise_scale, w_noise_scale, h_noise_scale]
125
+ )
126
+
127
+ def __call__(self, refanchors: Tensor) -> Tensor:
128
+ nq, bs, query_dim = refanchors.shape
129
+ device = refanchors.device
130
+
131
+ noise_raw = torch.rand_like(refanchors)
132
+ noise_scale = self.noise_scale.to(device)[:query_dim]
133
+
134
+ new_refanchors = refanchors * (1 + (noise_raw - 0.5) * noise_scale)
135
+ return new_refanchors.clamp_(0, 1)
136
+
137
+
138
+ def sigmoid_focal_loss(
139
+ inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2, no_reduction=False
140
+ ):
141
+ """
142
+ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
143
+ Args:
144
+ inputs: A float tensor of arbitrary shape.
145
+ The predictions for each example.
146
+ targets: A float tensor with the same shape as inputs. Stores the binary
147
+ classification label for each element in inputs
148
+ (0 for the negative class and 1 for the positive class).
149
+ alpha: (optional) Weighting factor in range (0,1) to balance
150
+ positive vs negative examples. Default = -1 (no weighting).
151
+ gamma: Exponent of the modulating factor (1 - p_t) to
152
+ balance easy vs hard examples.
153
+ Returns:
154
+ Loss tensor
155
+ """
156
+ prob = inputs.sigmoid()
157
+ ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
158
+ p_t = prob * targets + (1 - prob) * (1 - targets)
159
+ loss = ce_loss * ((1 - p_t) ** gamma)
160
+
161
+ if alpha >= 0:
162
+ alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
163
+ loss = alpha_t * loss
164
+
165
+ if no_reduction:
166
+ return loss
167
+
168
+ return loss.mean(1).sum() / num_boxes
169
+
170
+
171
+ class MLP(nn.Module):
172
+ """Very simple multi-layer perceptron (also called FFN)"""
173
+
174
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
175
+ super().__init__()
176
+ self.num_layers = num_layers
177
+ r=16
178
+ h = [hidden_dim] * (num_layers - 1)
179
+ self.layers = nn.ModuleList(
180
+ [lora.Linear(n, k, r=r) for n, k in zip([input_dim] + h, h + [output_dim])]
181
+ )
182
+
183
+ def forward(self, x):
184
+ for i, layer in enumerate(self.layers):
185
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
186
+ return x
187
+
188
+
189
+ def _get_activation_fn(activation, d_model=256, batch_dim=0):
190
+ """Return an activation function given a string"""
191
+ if activation == "relu":
192
+ return F.relu
193
+ if activation == "gelu":
194
+ return F.gelu
195
+ if activation == "glu":
196
+ return F.glu
197
+ if activation == "prelu":
198
+ return nn.PReLU()
199
+ if activation == "selu":
200
+ return F.selu
201
+
202
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
203
+
204
+
205
+ def gen_sineembed_for_position(pos_tensor):
206
+ # n_query, bs, _ = pos_tensor.size()
207
+ # sineembed_tensor = torch.zeros(n_query, bs, 256)
208
+ scale = 2 * math.pi
209
+ dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
210
+ dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode='floor')) / 128)
211
+ x_embed = pos_tensor[:, :, 0] * scale
212
+ y_embed = pos_tensor[:, :, 1] * scale
213
+ pos_x = x_embed[:, :, None] / dim_t
214
+ pos_y = y_embed[:, :, None] / dim_t
215
+ pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
216
+ pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
217
+ if pos_tensor.size(-1) == 2:
218
+ pos = torch.cat((pos_y, pos_x), dim=2)
219
+ elif pos_tensor.size(-1) == 4:
220
+ w_embed = pos_tensor[:, :, 2] * scale
221
+ pos_w = w_embed[:, :, None] / dim_t
222
+ pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
223
+
224
+ h_embed = pos_tensor[:, :, 3] * scale
225
+ pos_h = h_embed[:, :, None] / dim_t
226
+ pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
227
+
228
+ pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
229
+ else:
230
+ raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
231
+ return pos
232
+
233
+
234
+ class ContrastiveEmbed(nn.Module):
235
+ def __init__(self, max_text_len=256):
236
+ """
237
+ Args:
238
+ max_text_len: max length of text.
239
+ """
240
+ super().__init__()
241
+ self.max_text_len = max_text_len
242
+
243
+ def forward(self, x, text_dict):
244
+ """_summary_
245
+
246
+ Args:
247
+ x (_type_): _description_
248
+ text_dict (_type_): _description_
249
+ {
250
+ 'encoded_text': encoded_text, # bs, 195, d_model
251
+ 'text_token_mask': text_token_mask, # bs, 195
252
+ # True for used tokens. False for padding tokens
253
+ }
254
+ Returns:
255
+ _type_: _description_
256
+ """
257
+ assert isinstance(text_dict, dict)
258
+
259
+ y = text_dict["encoded_text"]
260
+ text_token_mask = text_dict["text_token_mask"]
261
+
262
+ res = x @ y.transpose(-1, -2)
263
+ res.masked_fill_(~text_token_mask[:, None, :], float("-inf"))
264
+
265
+ # padding to max_text_len
266
+ new_res = torch.full((*res.shape[:-1], self.max_text_len), float("-inf"), device=res.device)
267
+ new_res[..., : res.shape[-1]] = res
268
+
269
+ return new_res
groundingdino/models/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
8
+ from .GroundingDINO import build_groundingdino
9
+
10
+
11
+ def build_model(args):
12
+ # we use register to maintain models from catdet6 on.
13
+ from .registry import MODULE_BUILD_FUNCS
14
+
15
+ assert args.modelname in MODULE_BUILD_FUNCS._module_dict
16
+ build_func = MODULE_BUILD_FUNCS.get(args.modelname)
17
+ model = build_func(args)
18
+ return model
groundingdino/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (504 Bytes). View file
 
groundingdino/models/__pycache__/registry.cpython-310.pyc ADDED
Binary file (2.11 kB). View file
 
groundingdino/models/registry.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # -*- coding: utf-8 -*-
8
+ # @Author: Yihao Chen
9
+ # @Date: 2021-08-16 16:03:17
10
+ # @Last Modified by: Shilong Liu
11
+ # @Last Modified time: 2022-01-23 15:26
12
+ # modified from mmcv
13
+
14
+ import inspect
15
+ from functools import partial
16
+
17
+
18
+ class Registry(object):
19
+ def __init__(self, name):
20
+ self._name = name
21
+ self._module_dict = dict()
22
+
23
+ def __repr__(self):
24
+ format_str = self.__class__.__name__ + "(name={}, items={})".format(
25
+ self._name, list(self._module_dict.keys())
26
+ )
27
+ return format_str
28
+
29
+ def __len__(self):
30
+ return len(self._module_dict)
31
+
32
+ @property
33
+ def name(self):
34
+ return self._name
35
+
36
+ @property
37
+ def module_dict(self):
38
+ return self._module_dict
39
+
40
+ def get(self, key):
41
+ return self._module_dict.get(key, None)
42
+
43
+ def registe_with_name(self, module_name=None, force=False):
44
+ return partial(self.register, module_name=module_name, force=force)
45
+
46
+ def register(self, module_build_function, module_name=None, force=False):
47
+ """Register a module build function.
48
+ Args:
49
+ module (:obj:`nn.Module`): Module to be registered.
50
+ """
51
+ if not inspect.isfunction(module_build_function):
52
+ raise TypeError(
53
+ "module_build_function must be a function, but got {}".format(
54
+ type(module_build_function)
55
+ )
56
+ )
57
+ if module_name is None:
58
+ module_name = module_build_function.__name__
59
+ if not force and module_name in self._module_dict:
60
+ raise KeyError("{} is already registered in {}".format(module_name, self.name))
61
+ self._module_dict[module_name] = module_build_function
62
+
63
+ return module_build_function
64
+
65
+
66
+ MODULE_BUILD_FUNCS = Registry("model build functions")
groundingdino/util/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (187 Bytes). View file