import time
from dataclasses import dataclass
from typing import List, Tuple

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import (
    AutoProcessor,
    AutoModelForZeroShotObjectDetection,
)


# ---------------------------
# Setup
# ---------------------------
def extract_model_short_name(model_id: str) -> str:
    return model_id.split("/")[-1].replace("-", " ").replace("_", " ")


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# Model bundles for cleaner wiring
@dataclass
class ZSDetBundle:
    model_id: str
    model_name: str
    processor: AutoProcessor
    model: AutoModelForZeroShotObjectDetection


# LLMDet
model_llmdet_id = "iSEE-Laboratory/llmdet_tiny"
processor_llmdet = AutoProcessor.from_pretrained(model_llmdet_id)
model_llmdet = AutoModelForZeroShotObjectDetection.from_pretrained(model_llmdet_id)
bundle_llmdet = ZSDetBundle(
    model_id=model_llmdet_id,
    model_name=extract_model_short_name(model_llmdet_id),
    processor=processor_llmdet,
    model=model_llmdet,
)

# MM GroundingDINO
model_mm_grounding_id = "rziga/mm_grounding_dino_tiny_o365v1_goldg"
processor_mm_grounding = AutoProcessor.from_pretrained(model_mm_grounding_id)
model_mm_grounding = AutoModelForZeroShotObjectDetection.from_pretrained(model_mm_grounding_id)
bundle_mm_grounding = ZSDetBundle(
    model_id=model_mm_grounding_id,
    model_name=extract_model_short_name(model_mm_grounding_id),
    processor=processor_mm_grounding,
    model=model_mm_grounding,
)

# OMDet Turbo
model_omdet_id = "omlab/omdet-turbo-swin-tiny-hf"
processor_omdet = AutoProcessor.from_pretrained(model_omdet_id)
model_omdet = AutoModelForZeroShotObjectDetection.from_pretrained(model_omdet_id)
bundle_omdet = ZSDetBundle(
    model_id=model_omdet_id,
    model_name=extract_model_short_name(model_omdet_id),
    processor=processor_omdet,
    model=model_omdet,
)

# OWLv2
model_owlv2_id = "google/owlv2-large-patch14-ensemble"
processor_owlv2 = AutoProcessor.from_pretrained(model_owlv2_id)
model_owlv2 = AutoModelForZeroShotObjectDetection.from_pretrained(model_owlv2_id)
bundle_owlv2 = ZSDetBundle(
    model_id=model_owlv2_id,
    model_name=extract_model_short_name(model_owlv2_id),
    processor=processor_owlv2,
    model=model_owlv2,
)


# ---------------------------
# Inference
# ---------------------------
@spaces.GPU
def detect(
    bundle: ZSDetBundle,
    image: Image.Image,
    prompts: List[str],
    threshold: float,
) -> Tuple[List[Tuple[Tuple[int, int, int, int], str]], str]:
    """
    Returns [(bbox, label_score_str), ...], time_str
    """
    t0 = time.perf_counter()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # HF zero-shot OD expects list-of-list text
    texts = [prompts]
    inputs = bundle.processor(images=image, text=texts, return_tensors="pt").to(device)
    model = bundle.model.to(device).eval()

    with torch.inference_mode():
        outputs = model(**inputs)

    results = bundle.processor.post_process_grounded_object_detection(
        outputs,
        threshold=threshold,
        target_sizes=[image.size[::-1]],
        text_labels=texts,
    )[0]

    annotations = []
    for box, score, label_name in zip(results["boxes"], results["scores"], results["text_labels"]):
        if float(score) < threshold:
            continue
        xmin, ymin, xmax, ymax = map(int, box.tolist())
        annotations.append(((xmin, ymin, xmax, ymax), f"{label_name} {float(score):.2f}"))

    elapsed_ms = (time.perf_counter() - t0) * 1000
    time_taken = f"**Inference time ({bundle.model_name}):** {elapsed_ms:.0f} ms"
    return annotations, time_taken


def parse_prompts(prompts_str: str) -> List[str]:
    return [p.strip() for p in prompts_str.split(",") if p.strip()]


def run_detection(
    image: Image.Image,
    prompts_str: str,
    threshold_llm: float,
    threshold_mm: float,
    threshold_owlv2: float,
    threshold_omdet: float,
):
    prompts = parse_prompts(prompts_str)
    ann_llm, time_llm = detect(bundle_llmdet, image, prompts, threshold_llm)
    ann_mm, time_mm = detect(bundle_mm_grounding, image, prompts, threshold_mm)
    ann_owlv2, time_owlv2 = detect(bundle_owlv2, image, prompts, threshold_owlv2)
    ann_omdet, time_omdet = detect(bundle_omdet, image, prompts, threshold_omdet)
    return (
        (image, ann_llm),
        time_llm,
        (image, ann_mm),
        time_mm,
        (image, ann_owlv2),
        time_owlv2,
        (image, ann_omdet),
        time_omdet,
    )


# ---------------------------
# Compact Description
# ---------------------------
description_md = """
# Zero-Shot Object Detection Arena

Compare **four zero-shot object detectors** on the same image + prompts.
Upload an image (or pick an example), add **comma-separated prompts**, tweak per-model **thresholds**, and hit **Detect**.
You'll see bounding boxes, scores, and **per-model inference time**.

**Models**
- LLMDet Tiny — [`iSEE-Laboratory/llmdet_tiny`](https://huggingface.co/iSEE-Laboratory/llmdet_tiny)
- MM GroundingDINO Tiny O365v1 GoldG — [`rziga/mm_grounding_dino_tiny_o365v1_goldg`](https://huggingface.co/rziga/mm_grounding_dino_tiny_o365v1_goldg)
- OMDet Turbo Swin Tiny — [`omlab/omdet-turbo-swin-tiny-hf`](https://huggingface.co/omlab/omdet-turbo-swin-tiny-hf)
- OWLv2 Large Patch14 Ensemble — [`google/owlv2-large-patch14-ensemble`](https://huggingface.co/google/owlv2-large-patch14-ensemble)

**Tip:** Lower thresholds ↑ recall but may ↑ false positives.
"""

# ---------------------------
# UI
# ---------------------------
with gr.Blocks() as app:
    gr.Markdown(description_md)

    with gr.Row():
        with gr.Column(scale=1):
            image = gr.Image(type="pil", label="Upload an image", height=400)
            prompts = gr.Textbox(
                label="Prompts (comma-separated)",
                value="a cat, a remote control",
                placeholder="e.g., a cat, a remote control",
            )
            with gr.Accordion("Per-model confidence thresholds", open=True):
                threshold_llm = gr.Slider(label=f"Threshold — {bundle_llmdet.model_name}", minimum=0.0, maximum=1.0, value=0.3)
                threshold_mm = gr.Slider(label=f"Threshold — {bundle_mm_grounding.model_name}", minimum=0.0, maximum=1.0, value=0.3)
                threshold_owlv2 = gr.Slider(label=f"Threshold — {bundle_owlv2.model_name}", minimum=0.0, maximum=1.0, value=0.1)
                threshold_omdet = gr.Slider(label=f"Threshold — {bundle_omdet.model_name}", minimum=0.0, maximum=1.0, value=0.2)
            generate_btn = gr.Button(value="Detect")

    with gr.Row():
        with gr.Column(scale=2):
            output_image_llm = gr.AnnotatedImage(label=f"Annotated — {bundle_llmdet.model_name}", height=400)
            output_time_llm = gr.Markdown()
        with gr.Column(scale=2):
            output_image_mm = gr.AnnotatedImage(label=f"Annotated — {bundle_mm_grounding.model_name}", height=400)
            output_time_mm = gr.Markdown()

    with gr.Row():
        with gr.Column(scale=2):
            output_image_owlv2 = gr.AnnotatedImage(label=f"Annotated — {bundle_owlv2.model_name}", height=400)
            output_time_owlv2 = gr.Markdown()
        with gr.Column(scale=2):
            output_image_omdet = gr.AnnotatedImage(label=f"Annotated — {bundle_omdet.model_name}", height=400)
            output_time_omdet = gr.Markdown()

    gr.Markdown("### Examples")
    example_data = [
        ["http://images.cocodataset.org/val2017/000000039769.jpg", "a cat, a remote control", 0.30, 0.30, 0.10, 0.30],
        ["http://images.cocodataset.org/val2017/000000000139.jpg", "a person, a tv, a remote", 0.35, 0.30, 0.12, 0.30],
    ]
    gr.Examples(
        examples=example_data,
        inputs=[image, prompts, threshold_llm, threshold_mm, threshold_owlv2, threshold_omdet],
        label="Click an example to populate the inputs",
    )

    inputs = [image, prompts, threshold_llm, threshold_mm, threshold_owlv2, threshold_omdet]
    outputs = [
        output_image_llm,
        output_time_llm,
        output_image_mm,
        output_time_mm,
        output_image_owlv2,
        output_time_owlv2,
        output_image_omdet,
        output_time_omdet,
    ]

    generate_btn.click(fn=run_detection, inputs=inputs, outputs=outputs)
    image.upload(fn=run_detection, inputs=inputs, outputs=outputs)

# Optional: queue to handle multiple users gracefully (tune as needed)
app.launch()
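
# A minimal sketch of the optional queue mentioned above, assuming Gradio's built-in
# `Blocks.queue()` API (the max_size value is an illustrative assumption, not taken
# from this app; recent Gradio versions may already enable queuing by default).
# Enabling it before launch() lets concurrent users wait in line instead of all
# hitting the GPU at once:
#
#   app.queue(max_size=16)
#   app.launch()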