Hasanmog commited on
Commit
7b8b8a8
·
1 Parent(s): 23a2072
run.py → .ipynb_checkpoints/app-checkpoint.py RENAMED
@@ -36,9 +36,9 @@ from huggingface_hub import hf_hub_download
36
 
37
 
38
  # Use this command for evaluate the Grounding DINO model
39
- config_file = "groundingdino/config/GroundingDINO_SwinT_OGC.py"
40
- ckpt_repo_id = "ShilongLiu/GroundingDINO"
41
- ckpt_filenmae = "groundingdino_swint_ogc.pth"
42
 
43
 
44
  def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
 
36
 
37
 
38
  # Use this command for evaluate the Grounding DINO model
39
+ config_file = "cfg_odvg.py"
40
+ ckpt_repo_id = "Hasanmog/Peft-GroundingDINO"
41
+ ckpt_filenmae = "Best.pth"
42
 
43
 
44
  def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
app.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ <<<<<<< HEAD
3
+ <<<<<<< HEAD
4
+ from functools import partial
5
+ import cv2
6
+ import requests
7
+ import os
8
+ from io import BytesIO
9
+ from PIL import Image
10
+ import numpy as np
11
+ from pathlib import Path
12
+
13
+
14
+ import warnings
15
+
16
+ import torch
17
+
18
+ # prepare the environment
19
+ os.system("python setup.py build develop --user")
20
+ os.system("pip install packaging==21.3")
21
+ os.system("pip install gradio")
22
+
23
+
24
+ warnings.filterwarnings("ignore")
25
+
26
+ import gradio as gr
27
+
28
+ from groundingdino.models import build_model
29
+ from groundingdino.util.slconfig import SLConfig
30
+ from groundingdino.util.utils import clean_state_dict
31
+ from groundingdino.util.inference import annotate, load_image, predict
32
+ import groundingdino.datasets.transforms as T
33
+
34
+ from huggingface_hub import hf_hub_download
35
+
36
+
37
+
38
+ # Use this command for evaluate the Grounding DINO model
39
+ config_file = "cfg_odvg.py"
40
+ ckpt_repo_id = "Hasanmog/Peft-GroundingDINO"
41
+ ckpt_filenmae = "Best.pth"
42
+
43
+
44
+ def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
45
+ args = SLConfig.fromfile(model_config_path)
46
+ model = build_model(args)
47
+ args.device = device
48
+
49
+ cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
50
+ checkpoint = torch.load(cache_file, map_location='cpu')
51
+ log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
52
+ print("Model loaded from {} \n => {}".format(cache_file, log))
53
+ _ = model.eval()
54
+ return model
55
+
56
+ def image_transform_grounding(init_image):
57
+ transform = T.Compose([
58
+ T.RandomResize([800], max_size=1333),
59
+ T.ToTensor(),
60
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
61
+ ])
62
+ image, _ = transform(init_image, None) # 3, h, w
63
+ return init_image, image
64
+
65
+ def image_transform_grounding_for_vis(init_image):
66
+ transform = T.Compose([
67
+ T.RandomResize([800], max_size=1333),
68
+ ])
69
+ image, _ = transform(init_image, None) # 3, h, w
70
+ return image
71
+
72
+ model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
73
+
74
+ def run_grounding(input_image, grounding_caption, box_threshold, text_threshold):
75
+ init_image = input_image.convert("RGB")
76
+ original_size = init_image.size
77
+
78
+ _, image_tensor = image_transform_grounding(init_image)
79
+ image_pil: Image = image_transform_grounding_for_vis(init_image)
80
+
81
+ # run grounidng
82
+ boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold, device='cpu')
83
+ annotated_frame = annotate(image_source=np.asarray(image_pil), boxes=boxes, logits=logits, phrases=phrases)
84
+ image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
85
+
86
+
87
+ return image_with_box
88
+
89
+ if __name__ == "__main__":
90
+
91
+ parser = argparse.ArgumentParser("Grounding DINO demo", add_help=True)
92
+ parser.add_argument("--debug", action="store_true", help="using debug mode")
93
+ parser.add_argument("--share", action="store_true", help="share the app")
94
+ args = parser.parse_args()
95
+
96
+ block = gr.Blocks().queue()
97
+ with block:
98
+ gr.Markdown("# [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO)")
99
+ gr.Markdown("### Open-World Detection with Grounding DINO")
100
+
101
+ with gr.Row():
102
+ with gr.Column():
103
+ input_image = gr.Image(source='upload', type="pil")
104
+ grounding_caption = gr.Textbox(label="Detection Prompt")
105
+ run_button = gr.Button(label="Run")
106
+ with gr.Accordion("Advanced options", open=False):
107
+ box_threshold = gr.Slider(
108
+ label="Box Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
109
+ )
110
+ text_threshold = gr.Slider(
111
+ label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
112
+ )
113
+
114
+ with gr.Column():
115
+ gallery = gr.outputs.Image(
116
+ type="pil",
117
+ # label="grounding results"
118
+ ).style(full_width=True, full_height=True)
119
+ # gallery = gr.Gallery(label="Generated images", show_label=False).style(
120
+ # grid=[1], height="auto", container=True, full_width=True, full_height=True)
121
+
122
+ run_button.click(fn=run_grounding, inputs=[
123
+ input_image, grounding_caption, box_threshold, text_threshold], outputs=[gallery])
124
+
125
+
126
+ block.launch(server_name='0.0.0.0', server_port=7579, debug=args.debug, share=args.share)
127
+
128
+ =======
129
+ =======
130
+ >>>>>>> e7662d3789ee2d5b878c7399e1f04cb075927919
131
+ import os
132
+ import numpy as np
133
+ import torch
134
+ from PIL import Image, ImageDraw, ImageFont
135
+ # please make sure https://github.com/IDEA-Research/GroundingDINO is installed correctly.
136
+ import groundingdino.datasets.transforms as T
137
+ from groundingdino.models import build_model
138
+ from groundingdino.util import box_ops
139
+ from groundingdino.util.slconfig import SLConfig
140
+ from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
141
+ from groundingdino.util.vl_utils import create_positive_map_from_span
142
+
143
+
144
+ def plot_boxes_to_image(image_pil, tgt):
145
+ H, W = tgt["size"]
146
+ boxes = tgt["boxes"]
147
+ labels = tgt["labels"]
148
+ assert len(boxes) == len(labels), "boxes and labels must have same length"
149
+
150
+ draw = ImageDraw.Draw(image_pil)
151
+ mask = Image.new("L", image_pil.size, 0)
152
+ mask_draw = ImageDraw.Draw(mask)
153
+
154
+ # draw boxes and masks
155
+ for box, label in zip(boxes, labels):
156
+ # from 0..1 to 0..W, 0..H
157
+ box = box * torch.Tensor([W, H, W, H])
158
+ # from xywh to xyxy
159
+ box[:2] -= box[2:] / 2
160
+ box[2:] += box[:2]
161
+ # random color
162
+ color = tuple(np.random.randint(0, 255, size=3).tolist())
163
+ # draw
164
+ x0, y0, x1, y1 = box
165
+ x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
166
+
167
+ draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
168
+ # draw.text((x0, y0), str(label), fill=color)
169
+
170
+ font = ImageFont.load_default()
171
+ if hasattr(font, "getbbox"):
172
+ bbox = draw.textbbox((x0, y0), str(label), font)
173
+ else:
174
+ w, h = draw.textsize(str(label), font)
175
+ bbox = (x0, y0, w + x0, y0 + h)
176
+ # bbox = draw.textbbox((x0, y0), str(label))
177
+ draw.rectangle(bbox, fill=color)
178
+ draw.text((x0, y0), str(label), fill="white")
179
+
180
+ mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6)
181
+
182
+ return image_pil, mask
183
+
184
+
185
+ def load_image(image_path):
186
+ # load image
187
+ image_pil = Image.open(image_path).convert("RGB") # load image
188
+
189
+ transform = T.Compose(
190
+ [
191
+ T.RandomResize([800], max_size=1333),
192
+ T.ToTensor(),
193
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
194
+ ]
195
+ )
196
+ image, _ = transform(image_pil, None) # 3, h, w
197
+ return image_pil, image
198
+
199
+
200
+ def load_model(model_config_path, model_checkpoint_path, cpu_only=False):
201
+ args = SLConfig.fromfile(model_config_path)
202
+ args.device = "cuda" if not cpu_only else "cpu"
203
+ model = build_model(args)
204
+ checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
205
+ load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
206
+ print(load_res)
207
+ _ = model.eval()
208
+ return model
209
+
210
+
211
+ def get_grounding_output(model, image, caption, box_threshold, text_threshold=None, with_logits=True, cpu_only=False, token_spans=None):
212
+ assert text_threshold is not None or token_spans is not None, "text_threshould and token_spans should not be None at the same time!"
213
+ caption = caption.lower()
214
+ caption = caption.strip()
215
+ if not caption.endswith("."):
216
+ caption = caption + "."
217
+ device = "cuda" if not cpu_only else "cpu"
218
+ model = model.to(device)
219
+ image = image.to(device)
220
+ with torch.no_grad():
221
+ outputs = model(image[None], captions=[caption])
222
+ logits = outputs["pred_logits"].sigmoid()[0] # (nq, 256)
223
+ boxes = outputs["pred_boxes"][0] # (nq, 4)
224
+
225
+ # filter output
226
+ if token_spans is None:
227
+ logits_filt = logits.cpu().clone()
228
+ boxes_filt = boxes.cpu().clone()
229
+ filt_mask = logits_filt.max(dim=1)[0] > box_threshold
230
+ logits_filt = logits_filt[filt_mask] # num_filt, 256
231
+ boxes_filt = boxes_filt[filt_mask] # num_filt, 4
232
+
233
+ # get phrase
234
+ tokenlizer = model.tokenizer
235
+ tokenized = tokenlizer(caption)
236
+ # build pred
237
+ pred_phrases = []
238
+ for logit, box in zip(logits_filt, boxes_filt):
239
+ pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
240
+ if with_logits:
241
+ pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
242
+ else:
243
+ pred_phrases.append(pred_phrase)
244
+ else:
245
+ # given-phrase mode
246
+ positive_maps = create_positive_map_from_span(
247
+ model.tokenizer(text_prompt),
248
+ token_span=token_spans
249
+ ).to(image.device) # n_phrase, 256
250
+
251
+ logits_for_phrases = positive_maps @ logits.T # n_phrase, nq
252
+ all_logits = []
253
+ all_phrases = []
254
+ all_boxes = []
255
+ for (token_span, logit_phr) in zip(token_spans, logits_for_phrases):
256
+ # get phrase
257
+ phrase = ' '.join([caption[_s:_e] for (_s, _e) in token_span])
258
+ # get mask
259
+ filt_mask = logit_phr > box_threshold
260
+ # filt box
261
+ all_boxes.append(boxes[filt_mask])
262
+ # filt logits
263
+ all_logits.append(logit_phr[filt_mask])
264
+ if with_logits:
265
+ logit_phr_num = logit_phr[filt_mask]
266
+ all_phrases.extend([phrase + f"({str(logit.item())[:4]})" for logit in logit_phr_num])
267
+ else:
268
+ all_phrases.extend([phrase for _ in range(len(filt_mask))])
269
+ boxes_filt = torch.cat(all_boxes, dim=0).cpu()
270
+ pred_phrases = all_phrases
271
+
272
+
273
+ return boxes_filt, pred_phrases
274
+
275
+
276
+ if __name__ == "__main__":
277
+
278
+ parser = argparse.ArgumentParser("Grounding DINO example", add_help=True)
279
+ parser.add_argument("--config_file", "-c", type=str, required=True, help="path to config file")
280
+ parser.add_argument(
281
+ "--checkpoint_path", "-p", type=str, required=True, help="path to checkpoint file"
282
+ )
283
+ parser.add_argument("--image_path", "-i", type=str, required=True, help="path to image file")
284
+ parser.add_argument("--text_prompt", "-t", type=str, required=True, help="text prompt")
285
+ parser.add_argument(
286
+ "--output_dir", "-o", type=str, default="outputs", required=True, help="output directory"
287
+ )
288
+
289
+ parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
290
+ parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
291
+ parser.add_argument("--token_spans", type=str, default=None, help=
292
+ "The positions of start and end positions of phrases of interest. \
293
+ For example, a caption is 'a cat and a dog', \
294
+ if you would like to detect 'cat', the token_spans should be '[[[2, 5]], ]', since 'a cat and a dog'[2:5] is 'cat'. \
295
+ if you would like to detect 'a cat', the token_spans should be '[[[0, 1], [2, 5]], ]', since 'a cat and a dog'[0:1] is 'a', and 'a cat and a dog'[2:5] is 'cat'. \
296
+ ")
297
+
298
+ parser.add_argument("--cpu-only", action="store_true", help="running on cpu only!, default=False")
299
+ args = parser.parse_args()
300
+
301
+ # cfg
302
+ config_file = args.config_file # change the path of the model config file
303
+ checkpoint_path = args.checkpoint_path # change the path of the model
304
+ image_path = args.image_path
305
+ text_prompt = args.text_prompt
306
+ output_dir = args.output_dir
307
+ box_threshold = args.box_threshold
308
+ text_threshold = args.text_threshold
309
+ token_spans = args.token_spans
310
+
311
+ # make dir
312
+ os.makedirs(output_dir, exist_ok=True)
313
+ # load image
314
+ image_pil, image = load_image(image_path)
315
+ # load model
316
+ model = load_model(config_file, checkpoint_path, cpu_only=args.cpu_only)
317
+
318
+ # visualize raw image
319
+ image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
320
+
321
+ # set the text_threshold to None if token_spans is set.
322
+ if token_spans is not None:
323
+ text_threshold = None
324
+ print("Using token_spans. Set the text_threshold to None.")
325
+
326
+
327
+ # run model
328
+ boxes_filt, pred_phrases = get_grounding_output(
329
+ model, image, text_prompt, box_threshold, text_threshold, cpu_only=args.cpu_only, token_spans=token_spans
330
+ )
331
+
332
+ # visualize pred
333
+ size = image_pil.size
334
+ pred_dict = {
335
+ "boxes": boxes_filt,
336
+ "size": [size[1], size[0]], # H,W
337
+ "labels": pred_phrases,
338
+ }
339
+ image_with_box = plot_boxes_to_image(image_pil, pred_dict)[0]
340
+ save_path = os.path.join(output_dir, "pred.jpg")
341
+ image_with_box.save(save_path)
342
+ print(f"\n======================\n{save_path} saved.\nThe program runs successfully!")
343
+ <<<<<<< HEAD
344
+ >>>>>>> e7662d3789ee2d5b878c7399e1f04cb075927919
345
+ =======
346
+ >>>>>>> e7662d3789ee2d5b878c7399e1f04cb075927919