ariG23498 HF Staff commited on
Commit
cb90111
·
1 Parent(s): 9fc3606

label id vs label name

Browse files
Files changed (1) hide show
  1. app.py +46 -19
app.py CHANGED
@@ -1,7 +1,12 @@
1
  import gradio as gr
2
  import spaces
3
  import torch
4
- from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
 
 
 
 
 
5
  from PIL import Image
6
  import time
7
 
@@ -12,27 +17,21 @@ def extract_model_short_name(model_id):
12
 
13
  model_llmdet_id = "iSEE-Laboratory/llmdet_tiny"
14
  processor_llmdet = AutoProcessor.from_pretrained(model_llmdet_id)
15
- model_llmdet = (
16
- AutoModelForZeroShotObjectDetection.from_pretrained(model_llmdet_id)
17
- )
18
 
19
  model_mm_grounding_id = "rziga/mm_grounding_dino_tiny_o365v1_goldg"
20
  processor_mm_grounding = AutoProcessor.from_pretrained(model_mm_grounding_id)
21
- model_mm_grounding = (
22
- AutoModelForZeroShotObjectDetection.from_pretrained(model_mm_grounding_id)
23
  )
24
 
25
  model_omdet_id = "omlab/omdet-turbo-swin-tiny-hf"
26
  processor_omdet = AutoProcessor.from_pretrained(model_omdet_id)
27
- model_omdet = (
28
- AutoModelForZeroShotObjectDetection.from_pretrained(model_omdet_id)
29
- )
30
 
31
  model_owlv2_id = "google/owlv2-large-patch14-ensemble"
32
  processor_owlv2 = AutoProcessor.from_pretrained(model_owlv2_id)
33
- model_owlv2 = (
34
- AutoModelForZeroShotObjectDetection.from_pretrained(model_owlv2_id)
35
- )
36
 
37
  model_llmdet_name = extract_model_short_name(model_llmdet_id)
38
  model_mm_grounding_name = extract_model_short_name(model_mm_grounding_id)
@@ -44,7 +43,7 @@ model_owlv2_name = extract_model_short_name(model_owlv2_id)
44
  def detect(model, processor, image: Image.Image, prompts: list, threshold: float):
45
  t0 = time.perf_counter()
46
  device = "cuda" if torch.cuda.is_available() else "cpu"
47
- model = model.to(device).eval()
48
  texts = [prompts]
49
  inputs = processor(images=image, text=texts, return_tensors="pt").to(device)
50
  with torch.inference_mode():
@@ -54,8 +53,23 @@ def detect(model, processor, image: Image.Image, prompts: list, threshold: float
54
  )
55
  result = results[0]
56
  annotations = []
57
- for box, score, label_name in zip(result["boxes"], result["scores"], result["text_labels"]):
 
 
 
 
 
 
 
 
 
 
58
  if score >= threshold:
 
 
 
 
 
59
  xmin, ymin, xmax, ymax = [int(x) for x in box.tolist()]
60
  annotations.append(((xmin, ymin, xmax, ymax), f"{label_name} {score:.2f}"))
61
  elapsed_ms = (time.perf_counter() - t0) * 1000
@@ -64,13 +78,26 @@ def detect(model, processor, image: Image.Image, prompts: list, threshold: float
64
 
65
 
66
  def run_detection(
67
- image: Image.Image, prompts_str: str, threshold_llm, threshold_mm, threshold_owlv2, threshold_omdet,
 
 
 
 
 
68
  ):
69
  prompts = [p.strip() for p in prompts_str.split(",")]
70
- ann_llm, time_llm = detect(model_llmdet, processor_llmdet, image, prompts, threshold_llm)
71
- ann_mm, time_mm = detect(model_mm_grounding, processor_mm_grounding, image, prompts, threshold_mm)
72
- ann_owlv2, time_owlv2 = detect(model_owlv2, processor_owlv2, image, prompts, threshold_owlv2)
73
- ann_omdet, time_omdet = detect(model_omdet, processor_omdet, image, prompts, threshold_omdet)
 
 
 
 
 
 
 
 
74
  return (
75
  (image, ann_llm),
76
  time_llm,
 
1
  import gradio as gr
2
  import spaces
3
  import torch
4
+ from transformers import (
5
+ AutoProcessor,
6
+ AutoModelForZeroShotObjectDetection,
7
+ Owlv2ForObjectDetection,
8
+ OmDetTurboForObjectDetection,
9
+ )
10
  from PIL import Image
11
  import time
12
 
 
17
 
18
  model_llmdet_id = "iSEE-Laboratory/llmdet_tiny"
19
  processor_llmdet = AutoProcessor.from_pretrained(model_llmdet_id)
20
+ model_llmdet = AutoModelForZeroShotObjectDetection.from_pretrained(model_llmdet_id)
 
 
21
 
22
  model_mm_grounding_id = "rziga/mm_grounding_dino_tiny_o365v1_goldg"
23
  processor_mm_grounding = AutoProcessor.from_pretrained(model_mm_grounding_id)
24
+ model_mm_grounding = AutoModelForZeroShotObjectDetection.from_pretrained(
25
+ model_mm_grounding_id
26
  )
27
 
28
  model_omdet_id = "omlab/omdet-turbo-swin-tiny-hf"
29
  processor_omdet = AutoProcessor.from_pretrained(model_omdet_id)
30
+ model_omdet = AutoModelForZeroShotObjectDetection.from_pretrained(model_omdet_id)
 
 
31
 
32
  model_owlv2_id = "google/owlv2-large-patch14-ensemble"
33
  processor_owlv2 = AutoProcessor.from_pretrained(model_owlv2_id)
34
+ model_owlv2 = AutoModelForZeroShotObjectDetection.from_pretrained(model_owlv2_id)
 
 
35
 
36
  model_llmdet_name = extract_model_short_name(model_llmdet_id)
37
  model_mm_grounding_name = extract_model_short_name(model_mm_grounding_id)
 
43
  def detect(model, processor, image: Image.Image, prompts: list, threshold: float):
44
  t0 = time.perf_counter()
45
  device = "cuda" if torch.cuda.is_available() else "cpu"
46
+ model.to(device).eval()
47
  texts = [prompts]
48
  inputs = processor(images=image, text=texts, return_tensors="pt").to(device)
49
  with torch.inference_mode():
 
53
  )
54
  result = results[0]
55
  annotations = []
56
+
57
+ if isinstance(model, Owlv2ForObjectDetection) or isinstance(
58
+ model, OmDetTurboForObjectDetection
59
+ ):
60
+ key = "labels"
61
+ check = True
62
+ else:
63
+ key = "text_labels"
64
+ check = False
65
+
66
+ for box, score, label in zip(result["boxes"], result["scores"], result[key]):
67
  if score >= threshold:
68
+ if check:
69
+ label_id = label
70
+ label_name = prompts[label_id]
71
+ else:
72
+ label_name = label
73
  xmin, ymin, xmax, ymax = [int(x) for x in box.tolist()]
74
  annotations.append(((xmin, ymin, xmax, ymax), f"{label_name} {score:.2f}"))
75
  elapsed_ms = (time.perf_counter() - t0) * 1000
 
78
 
79
 
80
  def run_detection(
81
+ image: Image.Image,
82
+ prompts_str: str,
83
+ threshold_llm,
84
+ threshold_mm,
85
+ threshold_owlv2,
86
+ threshold_omdet,
87
  ):
88
  prompts = [p.strip() for p in prompts_str.split(",")]
89
+ ann_llm, time_llm = detect(
90
+ model_llmdet, processor_llmdet, image, prompts, threshold_llm
91
+ )
92
+ ann_mm, time_mm = detect(
93
+ model_mm_grounding, processor_mm_grounding, image, prompts, threshold_mm
94
+ )
95
+ ann_owlv2, time_owlv2 = detect(
96
+ model_owlv2, processor_owlv2, image, prompts, threshold_owlv2
97
+ )
98
+ ann_omdet, time_omdet = detect(
99
+ model_omdet, processor_omdet, image, prompts, threshold_omdet
100
+ )
101
  return (
102
  (image, ann_llm),
103
  time_llm,