import gradio as gr
import spaces
import torch
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from PIL import Image
import time

def extract_model_short_name(model_id):
    return model_id.split("/")[-1].replace("-", " ").replace("_", " ")
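# e.g. extract_model_short_name("omlab/omdet-turbo-swin-tiny-hf") -> "omdet turbo swin tiny hf"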

model_llmdet_id = "iSEE-Laboratory/llmdet_tiny"
model_mm_grounding_id = "rziga/mm_grounding_dino_tiny_o365v1_goldg"
model_omdet_id = "omlab/omdet-turbo-swin-tiny-hf"
model_owlv2_id = "google/owlv2-large-patch14-ensemble"

model_llmdet_name = extract_model_short_name(model_llmdet_id)
model_mm_grounding_name = extract_model_short_name(model_mm_grounding_id)
model_omdet_name = extract_model_short_name(model_omdet_id)
model_owlv2_name = extract_model_short_name(model_owlv2_id)
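
# NOTE: each detect_* function below reloads its processor and weights on every
# call, which dominates the reported inference times. A minimal caching sketch
# (optional, not used by the functions below; `load_detector` is a name
# introduced here, and it assumes all four models fit in memory at once):
from functools import lru_cache

@lru_cache(maxsize=None)
def load_detector(model_id):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device).eval()
    return processor, model, device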

def detect_omdet(image: Image.Image, prompts: list, threshold: float):
    t0 = time.perf_counter()
    model_id = model_omdet_id
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device).eval()
    texts = [prompts]
    inputs = processor(images=image, text=texts, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # PIL's image.size is (width, height); the post-processor expects (height, width).
    results = processor.post_process_grounded_object_detection(
        outputs,
        threshold=threshold,
        target_sizes=[image.size[::-1]],
    )
    result = results[0]
    annotations = []
    raw_results = []
    for box, score, label in zip(result["boxes"], result["scores"], result["labels"]):
        if score >= threshold:
            # `label` is an index into the prompt list here.
            label_name = prompts[label]
            xmin, ymin, xmax, ymax = [int(x) for x in box.tolist()]
            annotations.append(((xmin, ymin, xmax, ymax), f"{label_name} {score:.2f}"))
            raw_results.append(f"Detected {label_name} with confidence {score:.2f} at location [{xmin}, {ymin}, {xmax}, {ymax}]")
    elapsed_ms = (time.perf_counter() - t0) * 1000
    time_taken = f"**Inference time ({model_omdet_name}):** {elapsed_ms:.0f} ms"
    raw_text = "\n".join(raw_results) if raw_results else "No detections"
    return annotations, raw_text, time_taken
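
# Standalone smoke test (hypothetical image path, not part of the app flow):
#   img = Image.open("cats.jpg")
#   annotations, raw_text, timing = detect_omdet(img, ["a cat", "a remote control"], 0.3)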

def detect_llmdet(image: Image.Image, prompts: list, threshold: float):
    t0 = time.perf_counter()
    model_id = model_llmdet_id
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device).eval()
    texts = [prompts]
    inputs = processor(images=image, text=texts, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    results = processor.post_process_grounded_object_detection(
        outputs,
        threshold=threshold,
        target_sizes=[image.size[::-1]],
    )
    result = results[0]
    annotations = []
    raw_results = []
    for box, score, label in zip(result["boxes"], result["scores"], result["labels"]):
        if score >= threshold:
            xmin, ymin, xmax, ymax = [int(x) for x in box.tolist()]
            annotations.append(((xmin, ymin, xmax, ymax), f"{label} {score:.2f}"))
            raw_results.append(f"Detected {label} with confidence {score:.2f} at location [{xmin}, {ymin}, {xmax}, {ymax}]")
    elapsed_ms = (time.perf_counter() - t0) * 1000
    time_taken = f"**Inference time ({model_llmdet_name}):** {elapsed_ms:.0f} ms"
    raw_text = "\n".join(raw_results) if raw_results else "No detections"
    return annotations, raw_text, time_taken

def detect_mm_grounding(image: Image.Image, prompts: list, threshold: float):
    t0 = time.perf_counter()
    model_id = model_mm_grounding_id
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device).eval()
    texts = [prompts]
    inputs = processor(images=image, text=texts, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    results = processor.post_process_grounded_object_detection(
        outputs,
        threshold=threshold,
        target_sizes=[image.size[::-1]],
    )
    result = results[0]
    annotations = []
    raw_results = []
    for box, score, label in zip(result["boxes"], result["scores"], result["labels"]):
        if score >= threshold:
            xmin, ymin, xmax, ymax = [int(x) for x in box.tolist()]
            annotations.append(((xmin, ymin, xmax, ymax), f"{label} {score:.2f}"))
            raw_results.append(f"Detected {label} with confidence {score:.2f} at location [{xmin}, {ymin}, {xmax}, {ymax}]")
    elapsed_ms = (time.perf_counter() - t0) * 1000
    time_taken = f"**Inference time ({model_mm_grounding_name}):** {elapsed_ms:.0f} ms"
    raw_text = "\n".join(raw_results) if raw_results else "No detections"
    return annotations, raw_text, time_taken

def detect_owlv2(image: Image.Image, prompts: list, threshold: float):
    t0 = time.perf_counter()
    model_id = model_owlv2_id
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device).eval()
    texts = [prompts]
    inputs = processor(images=image, text=texts, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    results = processor.post_process_grounded_object_detection(
        outputs,
        threshold=threshold,
        target_sizes=[image.size[::-1]],
    )
    result = results[0]
    annotations = []
    raw_results = []
    for box, score, label in zip(result["boxes"], result["scores"], result["labels"]):
        if score >= threshold:
            # As with OmDet, `label` is an index into the prompt list.
            label_name = prompts[label]
            xmin, ymin, xmax, ymax = [int(x) for x in box.tolist()]
            annotations.append(((xmin, ymin, xmax, ymax), f"{label_name} {score:.2f}"))
            raw_results.append(f"Detected {label_name} with confidence {score:.2f} at location [{xmin}, {ymin}, {xmax}, {ymax}]")
    elapsed_ms = (time.perf_counter() - t0) * 1000
    time_taken = f"**Inference time ({model_owlv2_name}):** {elapsed_ms:.0f} ms"
    raw_text = "\n".join(raw_results) if raw_results else "No detections"
    return annotations, raw_text, time_taken
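
# The four detect_* functions above differ only in the model id and in whether
# `label` is an index into `prompts` (OmDet, OWLv2) or an already-decoded
# string (LLMDet, MM Grounding DINO). A hypothetical consolidated helper, shown
# as a sketch and not wired into the app below (`detect_generic` and
# `labels_are_indices` are names introduced here):
def detect_generic(image, prompts, threshold, model_id, model_name, labels_are_indices):
    t0 = time.perf_counter()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device).eval()
    inputs = processor(images=image, text=[prompts], return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    result = processor.post_process_grounded_object_detection(
        outputs, threshold=threshold, target_sizes=[image.size[::-1]]
    )[0]
    annotations, raw_results = [], []
    for box, score, label in zip(result["boxes"], result["scores"], result["labels"]):
        if score >= threshold:
            name = prompts[label] if labels_are_indices else label
            xmin, ymin, xmax, ymax = [int(x) for x in box.tolist()]
            annotations.append(((xmin, ymin, xmax, ymax), f"{name} {score:.2f}"))
            raw_results.append(f"Detected {name} with confidence {score:.2f} at location [{xmin}, {ymin}, {xmax}, {ymax}]")
    elapsed_ms = (time.perf_counter() - t0) * 1000
    time_taken = f"**Inference time ({model_name}):** {elapsed_ms:.0f} ms"
    raw_text = "\n".join(raw_results) if raw_results else "No detections"
    return annotations, raw_text, time_taken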

# On ZeroGPU Spaces, CUDA is only available inside functions decorated with
# @spaces.GPU (the decorator is a no-op on other hardware).
@spaces.GPU
def run_detection(image, prompts_str, threshold_llm, threshold_mm, threshold_owlv2, threshold_omdet):
    if image is None:
        # One (annotated image, text, timing) triple per model, 12 values in
        # total, matching the outputs wired up below.
        return ((None, []), "No detections", "") * 4
    prompts = [p.strip() for p in prompts_str.split(",") if p.strip()]
    ann_llm, raw_llm, time_llm = detect_llmdet(image, prompts, threshold_llm)
    ann_mm, raw_mm, time_mm = detect_mm_grounding(image, prompts, threshold_mm)
    ann_owlv2, raw_owlv2, time_owlv2 = detect_owlv2(image, prompts, threshold_owlv2)
    ann_omdet, raw_omdet, time_omdet = detect_omdet(image, prompts, threshold_omdet)
    return (image, ann_llm), raw_llm, time_llm, (image, ann_mm), raw_mm, time_mm, (image, ann_owlv2), raw_owlv2, time_owlv2, (image, ann_omdet), raw_omdet, time_omdet

with gr.Blocks() as app:
    gr.Markdown("# Zero-Shot Object Detection Arena")
    gr.Markdown("### Compare different zero-shot object detection models on the same image and prompts.")
    with gr.Row():
        with gr.Column(scale=1):
            image = gr.Image(type="pil", label="Upload an image", height=400)
            prompts = gr.Textbox(label="Prompts (comma-separated)", value="a cat, a remote control")
            with gr.Accordion("Per-model confidence thresholds", open=True):
                threshold_llm = gr.Slider(label="Threshold for LLMDet", minimum=0.0, maximum=1.0, value=0.3)
                threshold_mm = gr.Slider(label="Threshold for MM GroundingDINO Tiny", minimum=0.0, maximum=1.0, value=0.3)
                threshold_owlv2 = gr.Slider(label="Threshold for OwlV2 Large", minimum=0.0, maximum=1.0, value=0.1)
                threshold_omdet = gr.Slider(label="Threshold for OMDet Turbo Swin Tiny", minimum=0.0, maximum=1.0, value=0.2)
            generate_btn = gr.Button(value="Detect")
    with gr.Row():
        with gr.Column(scale=2):
            output_image_llm = gr.AnnotatedImage(label=f"Annotated image for {model_llmdet_name}", height=400)
            output_text_llm = gr.Textbox(label=f"Model detections for {model_llmdet_name}", lines=5)
            output_time_llm = gr.Markdown()
        with gr.Column(scale=2):
            output_image_mm = gr.AnnotatedImage(label=f"Annotated image for {model_mm_grounding_name}", height=400)
            output_text_mm = gr.Textbox(label=f"Model detections for {model_mm_grounding_name}", lines=5)
            output_time_mm = gr.Markdown()
    with gr.Row():
        with gr.Column(scale=2):
            output_image_owlv2 = gr.AnnotatedImage(label=f"Annotated image for {model_owlv2_name}", height=400)
            output_text_owlv2 = gr.Textbox(label=f"Model detections for {model_owlv2_name}", lines=5)
            output_time_owlv2 = gr.Markdown()
        with gr.Column(scale=2):
            output_image_omdet = gr.AnnotatedImage(label=f"Annotated image for {model_omdet_name}", height=400)
            output_text_omdet = gr.Textbox(label=f"Model detections for {model_omdet_name}", lines=5)
            output_time_omdet = gr.Markdown()
gr.Markdown("### Examples") | |
example_data = [ | |
["http://images.cocodataset.org/val2017/000000039769.jpg", "a cat, a remote control", 0.30, 0.30, 0.10, 0.30], | |
["http://images.cocodataset.org/val2017/000000000139.jpg", "a person, a tv, a remote", 0.35, 0.30, 0.12, 0.30], | |
] | |
gr.Examples( | |
examples=example_data, | |
inputs=[image, prompts, threshold_llm, threshold_mm, threshold_owlv2, threshold_omdet], | |
label="Click an example to populate the inputs", | |
) | |
    inputs = [image, prompts, threshold_llm, threshold_mm, threshold_owlv2, threshold_omdet]
    outputs = [
        output_image_llm, output_text_llm, output_time_llm,
        output_image_mm, output_text_mm, output_time_mm,
        output_image_owlv2, output_text_owlv2, output_time_owlv2,
        output_image_omdet, output_text_omdet, output_time_omdet,
    ]
    generate_btn.click(
        fn=run_detection,
        inputs=inputs,
        outputs=outputs,
    )
    # Run detection automatically when a new image is uploaded, too.
    image.upload(
        fn=run_detection,
        inputs=inputs,
        outputs=outputs,
    )

app.launch()
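
# Run locally with `python app.py`; on a Hugging Face Space this file is the
# app entry point. If concurrent users are expected, calling app.queue() before
# launch() is one common option for serializing GPU work.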