ariG23498 HF Staff committed on
Commit
4fa3f07
·
verified ·
1 Parent(s): ff97ba5

update text labels usage

Browse files
Files changed (1) hide show
  1. app.py +2 -22
app.py CHANGED
@@ -20,9 +20,6 @@ def extract_model_short_name(model_id: str) -> str:
20
 
21
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
22
 
23
- # (Optional) modest speed-ups
24
- torch.set_grad_enabled(False)
25
-
26
  # Model bundles for cleaner wiring
27
  @dataclass
28
  class ZSDetBundle:
@@ -30,7 +27,6 @@ class ZSDetBundle:
30
  model_name: str
31
  processor: AutoProcessor
32
  model: AutoModelForZeroShotObjectDetection
33
- use_label_ids: bool # True for OWLv2/OMDet (labels are indices), False for others
34
 
35
  # LLMDet
36
  model_llmdet_id = "iSEE-Laboratory/llmdet_tiny"
@@ -41,7 +37,6 @@ bundle_llmdet = ZSDetBundle(
41
  model_name=extract_model_short_name(model_llmdet_id),
42
  processor=processor_llmdet,
43
  model=model_llmdet,
44
- use_label_ids=False,
45
  )
46
 
47
  # MM GroundingDINO
@@ -53,7 +48,6 @@ bundle_mm_grounding = ZSDetBundle(
53
  model_name=extract_model_short_name(model_mm_grounding_id),
54
  processor=processor_mm_grounding,
55
  model=model_mm_grounding,
56
- use_label_ids=False,
57
  )
58
 
59
  # OMDet Turbo
@@ -65,7 +59,6 @@ bundle_omdet = ZSDetBundle(
65
  model_name=extract_model_short_name(model_omdet_id),
66
  processor=processor_omdet,
67
  model=model_omdet,
68
- use_label_ids=True, # returns label indices
69
  )
70
 
71
  # OWLv2
@@ -77,7 +70,6 @@ bundle_owlv2 = ZSDetBundle(
77
  model_name=extract_model_short_name(model_owlv2_id),
78
  processor=processor_owlv2,
79
  model=model_owlv2,
80
- use_label_ids=True, # returns label indices
81
  )
82
 
83
  # ---------------------------
@@ -106,27 +98,15 @@ def detect(
106
  outputs = model(**inputs)
107
 
108
  results = bundle.processor.post_process_grounded_object_detection(
109
- outputs, threshold=threshold, target_sizes=[image.size[::-1]]
110
  )[0]
111
 
112
  annotations = []
113
- key = "labels" if bundle.use_label_ids else "text_labels"
114
 
115
- for box, score, label in zip(results["boxes"], results["scores"], results[key]):
116
  if float(score) < threshold:
117
  continue
118
 
119
- if bundle.use_label_ids:
120
- # Map label index -> prompt string
121
- label_idx = int(label) if isinstance(label, torch.Tensor) else int(label)
122
- if 0 <= label_idx < len(prompts):
123
- label_name = prompts[label_idx]
124
- else:
125
- label_name = str(label_idx)
126
- else:
127
- # Direct text label
128
- label_name = label if isinstance(label, str) else str(label)
129
-
130
  xmin, ymin, xmax, ymax = map(lambda v: int(v), box.tolist())
131
  annotations.append(((xmin, ymin, xmax, ymax), f"{label_name} {float(score):.2f}"))
132
 
 
20
 
21
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
22
 
 
 
 
23
  # Model bundles for cleaner wiring
24
  @dataclass
25
  class ZSDetBundle:
 
27
  model_name: str
28
  processor: AutoProcessor
29
  model: AutoModelForZeroShotObjectDetection
 
30
 
31
  # LLMDet
32
  model_llmdet_id = "iSEE-Laboratory/llmdet_tiny"
 
37
  model_name=extract_model_short_name(model_llmdet_id),
38
  processor=processor_llmdet,
39
  model=model_llmdet,
 
40
  )
41
 
42
  # MM GroundingDINO
 
48
  model_name=extract_model_short_name(model_mm_grounding_id),
49
  processor=processor_mm_grounding,
50
  model=model_mm_grounding,
 
51
  )
52
 
53
  # OMDet Turbo
 
59
  model_name=extract_model_short_name(model_omdet_id),
60
  processor=processor_omdet,
61
  model=model_omdet,
 
62
  )
63
 
64
  # OWLv2
 
70
  model_name=extract_model_short_name(model_owlv2_id),
71
  processor=processor_owlv2,
72
  model=model_owlv2,
 
73
  )
74
 
75
  # ---------------------------
 
98
  outputs = model(**inputs)
99
 
100
  results = bundle.processor.post_process_grounded_object_detection(
101
+ outputs, threshold=threshold, target_sizes=[image.size[::-1]], text_labels=texts,
102
  )[0]
103
 
104
  annotations = []
 
105
 
106
+ for box, score, label_name in zip(results["boxes"], results["scores"], results["text_labels"]):
107
  if float(score) < threshold:
108
  continue
109
 
 
 
 
 
 
 
 
 
 
 
 
110
  xmin, ymin, xmax, ymax = map(lambda v: int(v), box.tolist())
111
  annotations.append(((xmin, ymin, xmax, ymax), f"{label_name} {float(score):.2f}"))
112