Hasanmog commited on
Commit
a844d95
·
1 Parent(s): 58e8422

Add application file

Browse files
Files changed (1) hide show
  1. gradio_app.py +125 -0
gradio_app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from functools import partial
3
+ import cv2
4
+ import requests
5
+ import os
6
+ from io import BytesIO
7
+ from PIL import Image
8
+ import numpy as np
9
+ from pathlib import Path
10
+
11
+
12
+ import warnings
13
+
14
+ import torch
15
+
16
+ # prepare the environment
17
+ os.system("python setup.py build develop --user")
18
+ os.system("pip install packaging==21.3")
19
+ os.system("pip install gradio==3.50.2")
20
+
21
+
22
+ warnings.filterwarnings("ignore")
23
+
24
+ import gradio as gr
25
+
26
+ from groundingdino.models import build_model
27
+ from groundingdino.util.slconfig import SLConfig
28
+ from groundingdino.util.utils import clean_state_dict
29
+ from groundingdino.util.inference import annotate, load_image, predict
30
+ import groundingdino.datasets.transforms as T
31
+
32
+ from huggingface_hub import hf_hub_download
33
+
34
+
35
+
36
+ # Use this command for evaluate the Grounding DINO model
37
+ config_file = "../config/cfg_odvg.py"
38
+ ckpt_repo_id = "Hasanmog/Peft-GroundingDINO"
39
+ ckpt_filenmae = "../Best.pth"
40
+
41
+
42
+ def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
43
+ args = SLConfig.fromfile(model_config_path)
44
+ model = build_model(args)
45
+ args.device = device
46
+
47
+ cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
48
+ checkpoint = torch.load(cache_file, map_location='cpu')
49
+ log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
50
+ print("Model loaded from {} \n => {}".format(cache_file, log))
51
+ _ = model.eval()
52
+ return model
53
+
54
+ def image_transform_grounding(init_image):
55
+ transform = T.Compose([
56
+ T.RandomResize([800], max_size=1333),
57
+ T.ToTensor(),
58
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
59
+ ])
60
+ image, _ = transform(init_image, None) # 3, h, w
61
+ return init_image, image
62
+
63
+ def image_transform_grounding_for_vis(init_image):
64
+ transform = T.Compose([
65
+ T.RandomResize([800], max_size=1333),
66
+ ])
67
+ image, _ = transform(init_image, None) # 3, h, w
68
+ return image
69
+
70
+ model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
71
+
72
+ def run_grounding(input_image, grounding_caption, box_threshold, text_threshold):
73
+ init_image = input_image.convert("RGB")
74
+ original_size = init_image.size
75
+
76
+ _, image_tensor = image_transform_grounding(init_image)
77
+ image_pil: Image = image_transform_grounding_for_vis(init_image)
78
+
79
+ # run grounidng
80
+ boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold, device='cpu')
81
+ annotated_frame = annotate(image_source=np.asarray(image_pil), boxes=boxes, logits=logits, phrases=phrases)
82
+ image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
83
+
84
+
85
+ return image_with_box
86
+
87
+ if __name__ == "__main__":
88
+
89
+ parser = argparse.ArgumentParser("Grounding DINO demo", add_help=True)
90
+ parser.add_argument("--debug", action="store_true", help="using debug mode")
91
+ parser.add_argument("--share", action="store_true", help="share the app")
92
+ args = parser.parse_args()
93
+
94
+ block = gr.Blocks().queue()
95
+ with block:
96
+ gr.Markdown("# [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO)")
97
+ gr.Markdown("### Open-World Detection with Grounding DINO")
98
+
99
+ with gr.Row():
100
+ with gr.Column():
101
+ input_image = gr.Image(source='upload', type="pil")
102
+ grounding_caption = gr.Textbox(label="Detection Prompt")
103
+ run_button = gr.Button(label="Run")
104
+ with gr.Accordion("Advanced options", open=False):
105
+ box_threshold = gr.Slider(
106
+ label="Box Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
107
+ )
108
+ text_threshold = gr.Slider(
109
+ label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
110
+ )
111
+
112
+ with gr.Column():
113
+ gallery = gr.outputs.Image(
114
+ type="pil",
115
+ # label="grounding results"
116
+ ).style(full_width=True, full_height=True)
117
+ # gallery = gr.Gallery(label="Generated images", show_label=False).style(
118
+ # grid=[1], height="auto", container=True, full_width=True, full_height=True)
119
+
120
+ run_button.click(fn=run_grounding, inputs=[
121
+ input_image, grounding_caption, box_threshold, text_threshold], outputs=[gallery])
122
+
123
+
124
+ block.launch(server_name='0.0.0.0', server_port=7579, debug=args.debug, share=args.share)
125
+