Spaces: Runtime error
update

- run.py → .ipynb_checkpoints/app-checkpoint.py (+3 -3)
- app.py (+346 -0)
run.py → .ipynb_checkpoints/app-checkpoint.py
RENAMED
@@ -36,9 +36,9 @@ from huggingface_hub import hf_hub_download
 
 
 # Use this command to evaluate the Grounding DINO model
-config_file = "
-ckpt_repo_id = "
-ckpt_filenmae = "
+config_file = "cfg_odvg.py"
+ckpt_repo_id = "Hasanmog/Peft-GroundingDINO"
+ckpt_filenmae = "Best.pth"
 
 
 def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
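For context, the three values above are consumed at startup roughly as sketched below (the repo id and file names come from this diff; the sketch itself is illustrative, not part of the commit):

from huggingface_hub import hf_hub_download
import torch

# resolve and cache Best.pth from the Hasanmog/Peft-GroundingDINO model repo,
# then load the checkpoint onto CPU, as load_model_hf in app.py does
cache_file = hf_hub_download(repo_id="Hasanmog/Peft-GroundingDINO", filename="Best.pth")
checkpoint = torch.load(cache_file, map_location="cpu")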
app.py
ADDED
@@ -0,0 +1,346 @@
import argparse
<<<<<<< HEAD
<<<<<<< HEAD
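# NOTE: the two "<<<<<<< HEAD" lines above, with their matching "=======" and
# ">>>>>>>" lines further down, are unresolved git merge conflict markers committed
# into app.py; they make the file unparseable, which is consistent with the Space's
# "Runtime error" status.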
from functools import partial
import cv2
import requests
import os
from io import BytesIO
from PIL import Image
import numpy as np
from pathlib import Path


import warnings

import torch

# prepare the environment
os.system("python setup.py build develop --user")
os.system("pip install packaging==21.3")
os.system("pip install gradio")
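# (These three calls shell out at import time to build GroundingDINO's extension in
# place and pip-install runtime dependencies, so they re-run on every cold start.)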


warnings.filterwarnings("ignore")

import gradio as gr

from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from groundingdino.util.inference import annotate, load_image, predict
import groundingdino.datasets.transforms as T

from huggingface_hub import hf_hub_download


# Use this command to evaluate the Grounding DINO model
config_file = "cfg_odvg.py"
ckpt_repo_id = "Hasanmog/Peft-GroundingDINO"
ckpt_filenmae = "Best.pth"


def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
    args = SLConfig.fromfile(model_config_path)
    model = build_model(args)
    args.device = device

    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    checkpoint = torch.load(cache_file, map_location='cpu')
    log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    print("Model loaded from {} \n => {}".format(cache_file, log))
    _ = model.eval()
    return model

def image_transform_grounding(init_image):
    transform = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet mean/std
    ])
    image, _ = transform(init_image, None)  # 3, h, w
    return init_image, image

def image_transform_grounding_for_vis(init_image):
    transform = T.Compose([
        T.RandomResize([800], max_size=1333),
    ])
    image, _ = transform(init_image, None)  # resized PIL image, no tensor conversion
    return image

model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)

def run_grounding(input_image, grounding_caption, box_threshold, text_threshold):
    init_image = input_image.convert("RGB")
    original_size = init_image.size

    _, image_tensor = image_transform_grounding(init_image)
    image_pil: Image = image_transform_grounding_for_vis(init_image)

    # run grounding
    boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold, device='cpu')
    annotated_frame = annotate(image_source=np.asarray(image_pil), boxes=boxes, logits=logits, phrases=phrases)
    image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))


    return image_with_box
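# Illustrative local use (not part of the commit; "demo.jpg" is a placeholder):
#   run_grounding(Image.open("demo.jpg"), "dog . leash .", 0.25, 0.25)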

if __name__ == "__main__":

    parser = argparse.ArgumentParser("Grounding DINO demo", add_help=True)
    parser.add_argument("--debug", action="store_true", help="use debug mode")
    parser.add_argument("--share", action="store_true", help="share the app")
    args = parser.parse_args()

    block = gr.Blocks().queue()
    with block:
        gr.Markdown("# [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO)")
        gr.Markdown("### Open-World Detection with Grounding DINO")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(source='upload', type="pil")
                grounding_caption = gr.Textbox(label="Detection Prompt")
                run_button = gr.Button(label="Run")
                with gr.Accordion("Advanced options", open=False):
                    box_threshold = gr.Slider(
                        label="Box Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
                    )
                    text_threshold = gr.Slider(
                        label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
                    )

            with gr.Column():
                gallery = gr.outputs.Image(
                    type="pil",
                    # label="grounding results"
                ).style(full_width=True, full_height=True)
                # gallery = gr.Gallery(label="Generated images", show_label=False).style(
                #     grid=[1], height="auto", container=True, full_width=True, full_height=True)

        run_button.click(fn=run_grounding, inputs=[
            input_image, grounding_caption, box_threshold, text_threshold], outputs=[gallery])


    block.launch(server_name='0.0.0.0', server_port=7579, debug=args.debug, share=args.share)
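# Note: gr.outputs.Image(...).style(...) and gr.Image(source=...) are Gradio 3.x-era
# APIs that were removed in 4.x, and Spaces expect a Gradio app on port 7860 rather
# than the hard-coded 7579; either mismatch could also surface as a runtime error.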

=======
=======
>>>>>>> e7662d3789ee2d5b878c7399e1f04cb075927919
import os
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
# please make sure https://github.com/IDEA-Research/GroundingDINO is installed correctly.
import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
from groundingdino.util import box_ops
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
from groundingdino.util.vl_utils import create_positive_map_from_span


def plot_boxes_to_image(image_pil, tgt):
    H, W = tgt["size"]
    boxes = tgt["boxes"]
    labels = tgt["labels"]
    assert len(boxes) == len(labels), "boxes and labels must have same length"

    draw = ImageDraw.Draw(image_pil)
    mask = Image.new("L", image_pil.size, 0)
    mask_draw = ImageDraw.Draw(mask)

    # draw boxes and masks
    for box, label in zip(boxes, labels):
        # from 0..1 to 0..W, 0..H
        box = box * torch.Tensor([W, H, W, H])
        # from xywh to xyxy
        box[:2] -= box[2:] / 2
        box[2:] += box[:2]
        # random color
        color = tuple(np.random.randint(0, 255, size=3).tolist())
        # draw
        x0, y0, x1, y1 = box
        x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)

        draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
        # draw.text((x0, y0), str(label), fill=color)

        font = ImageFont.load_default()
        if hasattr(font, "getbbox"):
            bbox = draw.textbbox((x0, y0), str(label), font)
        else:
            w, h = draw.textsize(str(label), font)
            bbox = (x0, y0, w + x0, y0 + h)
        # bbox = draw.textbbox((x0, y0), str(label))
        draw.rectangle(bbox, fill=color)
        draw.text((x0, y0), str(label), fill="white")

        mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6)

    return image_pil, mask
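# Expected input, as assembled in the CLI __main__ below:
#   tgt = {"size": [H, W], "boxes": Tensor[n, 4] of normalized cxcywh, "labels": [str, ...]}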


def load_image(image_path):
    # load image
    image_pil = Image.open(image_path).convert("RGB")

    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image, _ = transform(image_pil, None)  # 3, h, w
    return image_pil, image


def load_model(model_config_path, model_checkpoint_path, cpu_only=False):
    args = SLConfig.fromfile(model_config_path)
    args.device = "cuda" if not cpu_only else "cpu"
    model = build_model(args)
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
    print(load_res)
    _ = model.eval()
    return model


def get_grounding_output(model, image, caption, box_threshold, text_threshold=None, with_logits=True, cpu_only=False, token_spans=None):
    assert text_threshold is not None or token_spans is not None, "text_threshold and token_spans must not both be None!"
    caption = caption.lower()
    caption = caption.strip()
    if not caption.endswith("."):
        caption = caption + "."
    device = "cuda" if not cpu_only else "cpu"
    model = model.to(device)
    image = image.to(device)
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])
    logits = outputs["pred_logits"].sigmoid()[0]  # (nq, 256)
    boxes = outputs["pred_boxes"][0]  # (nq, 4)

    # filter output
    if token_spans is None:
        logits_filt = logits.cpu().clone()
        boxes_filt = boxes.cpu().clone()
        filt_mask = logits_filt.max(dim=1)[0] > box_threshold
        logits_filt = logits_filt[filt_mask]  # num_filt, 256
        boxes_filt = boxes_filt[filt_mask]  # num_filt, 4

        # get phrase
        tokenlizer = model.tokenizer
        tokenized = tokenlizer(caption)
        # build pred
        pred_phrases = []
        for logit, box in zip(logits_filt, boxes_filt):
            pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
            if with_logits:
                pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
            else:
                pred_phrases.append(pred_phrase)
    else:
        # given-phrase mode
        positive_maps = create_positive_map_from_span(
            model.tokenizer(caption),  # fixed: was `text_prompt`, a name only defined in the CLI __main__ below
            token_span=token_spans
        ).to(image.device)  # n_phrase, 256

        logits_for_phrases = positive_maps @ logits.T  # n_phrase, nq
        all_logits = []
        all_phrases = []
        all_boxes = []
        for (token_span, logit_phr) in zip(token_spans, logits_for_phrases):
            # get phrase
            phrase = ' '.join([caption[_s:_e] for (_s, _e) in token_span])
            # get mask
            filt_mask = logit_phr > box_threshold
            # filt box
            all_boxes.append(boxes[filt_mask])
            # filt logits
            all_logits.append(logit_phr[filt_mask])
            if with_logits:
                logit_phr_num = logit_phr[filt_mask]
                all_phrases.extend([phrase + f"({str(logit.item())[:4]})" for logit in logit_phr_num])
            else:
                # fixed: len(filt_mask) counted all queries, not the kept boxes
                all_phrases.extend([phrase for _ in range(int(filt_mask.sum()))])
        boxes_filt = torch.cat(all_boxes, dim=0).cpu()
        pred_phrases = all_phrases


    return boxes_filt, pred_phrases
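# Example for token_spans mode (mirrors the --token_spans help text below): with the
# caption "a cat and a dog", token_spans=[[[2, 5]]] grounds the phrase "cat";
# text_threshold is ignored on that path.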


if __name__ == "__main__":

    parser = argparse.ArgumentParser("Grounding DINO example", add_help=True)
    parser.add_argument("--config_file", "-c", type=str, required=True, help="path to config file")
    parser.add_argument(
        "--checkpoint_path", "-p", type=str, required=True, help="path to checkpoint file"
    )
    parser.add_argument("--image_path", "-i", type=str, required=True, help="path to image file")
    parser.add_argument("--text_prompt", "-t", type=str, required=True, help="text prompt")
    parser.add_argument(
        "--output_dir", "-o", type=str, default="outputs", required=True, help="output directory"  # note: required=True makes the default moot
    )

    parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
    parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
    parser.add_argument("--token_spans", type=str, default=None, help=
                        "Start and end positions of the phrases of interest. \
                        For example, given the caption 'a cat and a dog': \
                        to detect 'cat', pass '[[[2, 5]], ]', since 'a cat and a dog'[2:5] is 'cat'; \
                        to detect 'a cat', pass '[[[0, 1], [2, 5]], ]', since 'a cat and a dog'[0:1] is 'a' and 'a cat and a dog'[2:5] is 'cat'. \
                        ")

    parser.add_argument("--cpu-only", action="store_true", help="run on CPU only (default: False)")
    args = parser.parse_args()

    # cfg
    config_file = args.config_file  # path of the model config file
    checkpoint_path = args.checkpoint_path  # path of the model checkpoint
    image_path = args.image_path
    text_prompt = args.text_prompt
    output_dir = args.output_dir
    box_threshold = args.box_threshold
    text_threshold = args.text_threshold
    token_spans = args.token_spans

    # make dir
    os.makedirs(output_dir, exist_ok=True)
    # load image
    image_pil, image = load_image(image_path)
    # load model
    model = load_model(config_file, checkpoint_path, cpu_only=args.cpu_only)

    # visualize raw image
    image_pil.save(os.path.join(output_dir, "raw_image.jpg"))

    # set the text_threshold to None if token_spans is set.
    if token_spans is not None:
        text_threshold = None
        print("Using token_spans. Set the text_threshold to None.")


    # run model
    boxes_filt, pred_phrases = get_grounding_output(
        model, image, text_prompt, box_threshold, text_threshold, cpu_only=args.cpu_only, token_spans=token_spans
    )

    # visualize pred
    size = image_pil.size
    pred_dict = {
        "boxes": boxes_filt,
        "size": [size[1], size[0]],  # H, W
        "labels": pred_phrases,
    }
    image_with_box = plot_boxes_to_image(image_pil, pred_dict)[0]
    save_path = os.path.join(output_dir, "pred.jpg")
    image_with_box.save(save_path)
    print(f"\n======================\n{save_path} saved.\nThe program runs successfully!")
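# Caveat: args.token_spans is forwarded as the raw string from the command line; it
# would need to be parsed into nested lists first (e.g. with ast.literal_eval) for
# the token_spans branch of get_grounding_output to iterate it, a latent bug here.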
<<<<<<< HEAD
>>>>>>> e7662d3789ee2d5b878c7399e1f04cb075927919
=======
>>>>>>> e7662d3789ee2d5b878c7399e1f04cb075927919
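# To make app.py importable, the conflict has to be resolved by keeping exactly one
# branch and deleting every "<<<<<<<", "=======", and ">>>>>>>" marker line. For a
# Space, that presumably means keeping the first (Gradio) section, since the CLI
# section's parser requires arguments that a Space never passes.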