ariG23498 HF Staff commited on
Commit
ff97ba5
·
1 Parent(s): 8983f6a

same device

Browse files
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -35,7 +35,7 @@ class ZSDetBundle:
35
  # LLMDet
36
  model_llmdet_id = "iSEE-Laboratory/llmdet_tiny"
37
  processor_llmdet = AutoProcessor.from_pretrained(model_llmdet_id)
38
- model_llmdet = AutoModelForZeroShotObjectDetection.from_pretrained(model_llmdet_id).to(DEVICE).eval()
39
  bundle_llmdet = ZSDetBundle(
40
  model_id=model_llmdet_id,
41
  model_name=extract_model_short_name(model_llmdet_id),
@@ -47,7 +47,7 @@ bundle_llmdet = ZSDetBundle(
47
  # MM GroundingDINO
48
  model_mm_grounding_id = "rziga/mm_grounding_dino_tiny_o365v1_goldg"
49
  processor_mm_grounding = AutoProcessor.from_pretrained(model_mm_grounding_id)
50
- model_mm_grounding = AutoModelForZeroShotObjectDetection.from_pretrained(model_mm_grounding_id).to(DEVICE).eval()
51
  bundle_mm_grounding = ZSDetBundle(
52
  model_id=model_mm_grounding_id,
53
  model_name=extract_model_short_name(model_mm_grounding_id),
@@ -59,7 +59,7 @@ bundle_mm_grounding = ZSDetBundle(
59
  # OMDet Turbo
60
  model_omdet_id = "omlab/omdet-turbo-swin-tiny-hf"
61
  processor_omdet = AutoProcessor.from_pretrained(model_omdet_id)
62
- model_omdet = AutoModelForZeroShotObjectDetection.from_pretrained(model_omdet_id).to(DEVICE).eval()
63
  bundle_omdet = ZSDetBundle(
64
  model_id=model_omdet_id,
65
  model_name=extract_model_short_name(model_omdet_id),
@@ -71,7 +71,7 @@ bundle_omdet = ZSDetBundle(
71
  # OWLv2
72
  model_owlv2_id = "google/owlv2-large-patch14-ensemble"
73
  processor_owlv2 = AutoProcessor.from_pretrained(model_owlv2_id)
74
- model_owlv2 = AutoModelForZeroShotObjectDetection.from_pretrained(model_owlv2_id).to(DEVICE).eval()
75
  bundle_owlv2 = ZSDetBundle(
76
  model_id=model_owlv2_id,
77
  model_name=extract_model_short_name(model_owlv2_id),
@@ -95,13 +95,15 @@ def detect(
95
  Returns [(bbox, label_score_str), ...], time_str
96
  """
97
  t0 = time.perf_counter()
 
98
 
99
  # HF zero-shot OD expects list-of-list text
100
  texts = [prompts]
101
- inputs = bundle.processor(images=image, text=texts, return_tensors="pt").to(DEVICE)
 
102
 
103
  with torch.inference_mode():
104
- outputs = bundle.model(**inputs)
105
 
106
  results = bundle.processor.post_process_grounded_object_detection(
107
  outputs, threshold=threshold, target_sizes=[image.size[::-1]]
 
35
  # LLMDet
36
  model_llmdet_id = "iSEE-Laboratory/llmdet_tiny"
37
  processor_llmdet = AutoProcessor.from_pretrained(model_llmdet_id)
38
+ model_llmdet = AutoModelForZeroShotObjectDetection.from_pretrained(model_llmdet_id)
39
  bundle_llmdet = ZSDetBundle(
40
  model_id=model_llmdet_id,
41
  model_name=extract_model_short_name(model_llmdet_id),
 
47
  # MM GroundingDINO
48
  model_mm_grounding_id = "rziga/mm_grounding_dino_tiny_o365v1_goldg"
49
  processor_mm_grounding = AutoProcessor.from_pretrained(model_mm_grounding_id)
50
+ model_mm_grounding = AutoModelForZeroShotObjectDetection.from_pretrained(model_mm_grounding_id)
51
  bundle_mm_grounding = ZSDetBundle(
52
  model_id=model_mm_grounding_id,
53
  model_name=extract_model_short_name(model_mm_grounding_id),
 
59
  # OMDet Turbo
60
  model_omdet_id = "omlab/omdet-turbo-swin-tiny-hf"
61
  processor_omdet = AutoProcessor.from_pretrained(model_omdet_id)
62
+ model_omdet = AutoModelForZeroShotObjectDetection.from_pretrained(model_omdet_id)
63
  bundle_omdet = ZSDetBundle(
64
  model_id=model_omdet_id,
65
  model_name=extract_model_short_name(model_omdet_id),
 
71
  # OWLv2
72
  model_owlv2_id = "google/owlv2-large-patch14-ensemble"
73
  processor_owlv2 = AutoProcessor.from_pretrained(model_owlv2_id)
74
+ model_owlv2 = AutoModelForZeroShotObjectDetection.from_pretrained(model_owlv2_id)
75
  bundle_owlv2 = ZSDetBundle(
76
  model_id=model_owlv2_id,
77
  model_name=extract_model_short_name(model_owlv2_id),
 
95
  Returns [(bbox, label_score_str), ...], time_str
96
  """
97
  t0 = time.perf_counter()
98
+ device = "cuda" if torch.cuda.is_available() else "cpu"
99
 
100
  # HF zero-shot OD expects list-of-list text
101
  texts = [prompts]
102
+ inputs = bundle.processor(images=image, text=texts, return_tensors="pt").to(device)
103
+ model = bundle.model.to(device).eval()
104
 
105
  with torch.inference_mode():
106
+ outputs = model(**inputs)
107
 
108
  results = bundle.processor.post_process_grounded_object_detection(
109
  outputs, threshold=threshold, target_sizes=[image.size[::-1]]