ariG23498 HF Staff committed on
Commit
9052bb3
·
verified ·
1 Parent(s): 2eeb110

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ import torch
4
+ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
5
+ from PIL import Image
6
+ import time
7
+
8
def extract_model_short_name(model_id):
    """Return a display name for a Hub model id.

    Everything up to and including the last ``/`` is dropped, and every
    ``-`` or ``_`` in the remaining repo name is replaced by a space.
    """
    repo_name = model_id.rsplit("/", 1)[-1]
    return repo_name.translate(str.maketrans("-_", "  "))
10
+
11
# Hub checkpoints that the arena compares side by side.
model_llmdet_id = "iSEE-Laboratory/llmdet_tiny"
model_mm_grounding_id = "rziga/mm_grounding_dino_tiny_o365v1_goldg"

# Human-readable names derived from the ids; used in UI labels and timing text.
model_llmdet_name, model_mm_grounding_name = map(
    extract_model_short_name,
    (model_llmdet_id, model_mm_grounding_id),
)
16
+
17
# Loaded (processor, model) pairs keyed by (model_id, device), so a checkpoint
# is read and moved to the device once per process instead of on every request.
_DETECTOR_CACHE = {}


def _load_detector(model_id, device):
    """Return the ``(processor, model)`` pair for *model_id* on *device*, loading it on first use."""
    key = (model_id, device)
    if key not in _DETECTOR_CACHE:
        processor = AutoProcessor.from_pretrained(model_id)
        model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)
        _DETECTOR_CACHE[key] = (processor, model.to(device).eval())
    return _DETECTOR_CACHE[key]


def _run_zero_shot_detection(model_id, model_name, image, prompts, threshold):
    """Run one zero-shot detector on *image* and format its detections.

    Shared implementation behind ``detect_llmdet`` and ``detect_mm_grounding``.

    Returns:
        A 3-tuple ``(annotations, raw_text, time_taken)``:
        ``annotations`` is a list of ``((xmin, ymin, xmax, ymax), caption)``
        pairs with integer pixel coordinates; ``raw_text`` is one line per
        detection, or ``"No detections"``; ``time_taken`` is a Markdown string
        with the elapsed milliseconds for *model_name*.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor, model = _load_detector(model_id, device)

    # Start the clock only after the checkpoint is loaded, so the figure shown
    # as "Inference time" does not include download / load time.
    t0 = time.perf_counter()
    inputs = processor(images=image, text=[prompts], return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    results = processor.post_process_grounded_object_detection(
        outputs,
        threshold=threshold,
        # PIL reports (width, height); the post-processor is given (height, width).
        target_sizes=[image.size[::-1]],
    )
    result = results[0]

    # NOTE(review): newer transformers releases appear to expose the decoded
    # phrases under "text_labels" and deprecate "labels" - confirm against the
    # pinned version. Fall back to "labels" when the new key is absent.
    labels = result["text_labels"] if "text_labels" in result else result["labels"]

    annotations = []
    raw_results = []
    for box, score, label in zip(result["boxes"], result["scores"], labels):
        if score < threshold:
            continue
        xmin, ymin, xmax, ymax = (int(coord) for coord in box.tolist())
        annotations.append(((xmin, ymin, xmax, ymax), f"{label} {score:.2f}"))
        raw_results.append(
            f"Detected {label} with confidence {score:.2f} at location [{xmin}, {ymin}, {xmax}, {ymax}]"
        )

    elapsed_ms = (time.perf_counter() - t0) * 1000
    time_taken = f"**Inference time ({model_name}):** {elapsed_ms:.0f} ms"
    raw_text = "\n".join(raw_results) if raw_results else "No detections"
    return annotations, raw_text, time_taken


def detect_llmdet(image: Image.Image, prompts: list, threshold: float):
    """Detect *prompts* in *image* with the LLMDet checkpoint (``model_llmdet_id``).

    Args:
        image: PIL image to run detection on.
        prompts: Text prompts, one per object category to look for.
        threshold: Minimum confidence score for a detection to be reported.

    Returns:
        ``(annotations, raw_text, time_taken)`` as described in
        ``_run_zero_shot_detection``.
    """
    return _run_zero_shot_detection(
        model_llmdet_id, model_llmdet_name, image, prompts, threshold
    )


def detect_mm_grounding(image: Image.Image, prompts: list, threshold: float):
    """Detect *prompts* in *image* with the MM Grounding DINO checkpoint (``model_mm_grounding_id``).

    Args:
        image: PIL image to run detection on.
        prompts: Text prompts, one per object category to look for.
        threshold: Minimum confidence score for a detection to be reported.

    Returns:
        ``(annotations, raw_text, time_taken)`` as described in
        ``_run_zero_shot_detection``.
    """
    return _run_zero_shot_detection(
        model_mm_grounding_id, model_mm_grounding_name, image, prompts, threshold
    )
72
+
73
@spaces.GPU
def run_detection(image, prompts_str, threshold):
    """Run both detectors on *image* and return values for the six output widgets.

    Args:
        image: PIL image from the input component, or ``None`` when nothing
            has been uploaded yet.
        prompts_str: Comma-separated prompts typed by the user.
        threshold: Confidence threshold forwarded to both detectors.

    Returns:
        A 6-tuple ``(annotated_llm, text_llm, time_llm, annotated_mm, text_mm,
        time_mm)``. Each ``annotated_*`` value is ``(image, annotations)``, or
        ``None`` when there is no image to show.
    """
    if image is None:
        # Return None rather than (None, []): a tuple whose base image is None
        # gives the annotated-image outputs nothing to render, whereas None
        # simply clears them.
        return None, "No detections", "", None, "No detections", ""

    # Drop empty fragments (trailing commas, blank textbox) so the models are
    # never asked to ground an empty phrase.
    stripped = (part.strip() for part in (prompts_str or "").split(","))
    prompts = [prompt for prompt in stripped if prompt]
    if not prompts:
        return (image, []), "No detections", "", (image, []), "No detections", ""

    ann_llm, raw_llm, time_llm = detect_llmdet(image, prompts, threshold)
    ann_mm, raw_mm, time_mm = detect_mm_grounding(image, prompts, threshold)
    return (image, ann_llm), raw_llm, time_llm, (image, ann_mm), raw_mm, time_mm
81
+
82
# Build the comparison UI: one input column and one result column per model.
with gr.Blocks() as app:
    gr.Markdown("# Zero-Shot Object Detection Arena")
    gr.Markdown("### Compare different zero-shot object detection models on the same image and prompts.")

    with gr.Row():
        # Inputs: image, comma-separated prompts, score threshold, run button.
        with gr.Column(scale=1):
            image = gr.Image(type="pil", label="Upload an image", height=400)
            prompts = gr.Textbox(
                label="Prompts (comma-separated)",
                value="a cat, a remote control",
            )
            threshold = gr.Slider(
                label="Confidence Threshold",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.30,
            )
            generate_btn = gr.Button(value="Detect")

        # Results for the LLMDet model.
        with gr.Column(scale=2):
            output_image_llm = gr.AnnotatedImage(
                label=f"Annotated image for {model_llmdet_name}",
                height=400,
            )
            output_text_llm = gr.Textbox(
                label=f"Model detections for {model_llmdet_name}",
                lines=10,
            )
            output_time_llm = gr.Markdown()

        # Results for the MM Grounding DINO model.
        with gr.Column(scale=2):
            output_image_mm = gr.AnnotatedImage(
                label=f"Annotated image for {model_mm_grounding_name}",
                height=400,
            )
            output_text_mm = gr.Textbox(
                label=f"Model detections for {model_mm_grounding_name}",
                lines=10,
            )
            output_time_mm = gr.Markdown()

    gr.Markdown("### Examples")
    example_data = [
        ["http://images.cocodataset.org/val2017/000000039769.jpg", "a cat, a remote control", 0.4],
        ["http://images.cocodataset.org/val2017/000000000139.jpg", "a person, a tv, a remote", 0.3],
    ]
    gr.Examples(
        examples=example_data,
        inputs=[image, prompts, threshold],
        label="Click an example to populate the input",
    )

    # Clicking the button and uploading an image both trigger the same run
    # with the same wiring; each listener gets its own copy of the lists.
    detection_inputs = [image, prompts, threshold]
    detection_outputs = [
        output_image_llm,
        output_text_llm,
        output_time_llm,
        output_image_mm,
        output_text_mm,
        output_time_mm,
    ]
    for register in (generate_btn.click, image.upload):
        register(
            fn=run_detection,
            inputs=list(detection_inputs),
            outputs=list(detection_outputs),
        )

app.launch()