Martin Tomov committed on
Commit 59d81d4 · verified · 1 Parent(s): 3e420b8

update app.py to match ZeroGPU Spaces
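On ZeroGPU Spaces the `spaces` package is available and GPU-backed work has to run inside functions decorated with `@spaces.GPU`; a GPU is attached only for the duration of each decorated call, so code inside it can target `"cuda"` directly. That is the pattern this commit applies to `detect` and `segment` in the diff below. A minimal sketch of the pattern, where `run_inference` and its arguments are hypothetical names used only for illustration:

```python
import torch
import spaces  # available on Hugging Face ZeroGPU Spaces

@spaces.GPU  # a GPU is attached only while this function executes
def run_inference(model, batch):  # hypothetical helper, for illustration only
    # Inside the decorated call it is safe to target "cuda" directly,
    # which is why the torch.cuda.is_available() fallbacks were dropped.
    model = model.to("cuda")
    batch = {k: v.to("cuda") for k, v in batch.items()}
    with torch.no_grad():
        return model(**batch)
```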

Files changed (1)
app.py  +26 −24
app.py CHANGED
```diff
@@ -1,7 +1,6 @@
 import random
 from dataclasses import dataclass
 from typing import Any, List, Dict, Optional, Union, Tuple
-
 import cv2
 import torch
 import requests
@@ -10,6 +9,7 @@ from PIL import Image
 import matplotlib.pyplot as plt
 from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline
 import gradio as gr
+import spaces
 
 @dataclass
 class BoundingBox:
@@ -31,12 +31,16 @@ class DetectionResult:
 
     @classmethod
     def from_dict(cls, detection_dict: Dict) -> 'DetectionResult':
-        return cls(score=detection_dict['score'],
-                   label=detection_dict['label'],
-                   box=BoundingBox(xmin=detection_dict['box']['xmin'],
-                                   ymin=detection_dict['box']['ymin'],
-                                   xmax=detection_dict['box']['xmax'],
-                                   ymax=detection_dict['box']['ymax']))
+        return cls(
+            score=detection_dict['score'],
+            label=detection_dict['label'],
+            box=BoundingBox(
+                xmin=detection_dict['box']['xmin'],
+                ymin=detection_dict['box']['ymin'],
+                xmax=detection_dict['box']['xmax'],
+                ymax=detection_dict['box']['ymax']
+            )
+        )
 
 def annotate(image: Union[Image.Image, np.ndarray], detection_results: List[DetectionResult]) -> np.ndarray:
     image_cv2 = np.array(image) if isinstance(image, Image.Image) else image
@@ -47,16 +51,16 @@ def annotate(image: Union[Image.Image, np.ndarray], detection_results: List[Dete
         score = detection.score
         box = detection.box
         mask = detection.mask
+        color = np.random.randint(0, 256, size=3).tolist()
 
-        color = np.random.randint(0, 256, size=3)
-
-        cv2.rectangle(image_cv2, (box.xmin, box.ymin), (box.xmax, box.ymax), color.tolist(), 2)
-        cv2.putText(image_cv2, f'{label}: {score:.2f}', (box.xmin, box.ymin - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color.tolist(), 2)
+        cv2.rectangle(image_cv2, (box.xmin, box.ymin), (box.xmax, box.ymax), color, 2)
+        cv2.putText(image_cv2, f'{label}: {score:.2f}', (box.xmin, box.ymin - 10),
+                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
 
         if mask is not None:
             mask_uint8 = (mask * 255).astype(np.uint8)
             contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-            cv2.drawContours(image_cv2, contours, -1, color.tolist(), 2)
+            cv2.drawContours(image_cv2, contours, -1, color, 2)
 
     return cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
 
@@ -90,29 +94,26 @@ def refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> L
             masks[idx] = polygon_to_mask(polygon, shape)
     return list(masks)
 
+@spaces.GPU
 def detect(image: Image.Image, labels: List[str], threshold: float = 0.3, detector_id: Optional[str] = None) -> List[Dict[str, Any]]:
-    device = "cuda" if torch.cuda.is_available() else "cpu"
     detector_id = detector_id if detector_id else "IDEA-Research/grounding-dino-base"
-    object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device=device)
+    object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device="cuda")
     labels = [label if label.endswith(".") else label+"." for label in labels]
     results = object_detector(image, candidate_labels=labels, threshold=threshold)
     return [DetectionResult.from_dict(result) for result in results]
 
+@spaces.GPU
 def segment(image: Image.Image, detection_results: List[DetectionResult], polygon_refinement: bool = False, segmenter_id: Optional[str] = None) -> List[DetectionResult]:
-    device = "cuda" if torch.cuda.is_available() else "cpu"
     segmenter_id = segmenter_id if segmenter_id else "martintmv/InsectSAM"
-    segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)
+    segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to("cuda")
     processor = AutoProcessor.from_pretrained(segmenter_id)
-
     boxes = get_boxes(detection_results)
-    inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(device)
+    inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to("cuda")
     outputs = segmentator(**inputs)
     masks = processor.post_process_masks(masks=outputs.pred_masks, original_sizes=inputs.original_sizes, reshaped_input_sizes=inputs.reshaped_input_sizes)[0]
     masks = refine_masks(masks, polygon_refinement)
-
     for detection_result, mask in zip(detection_results, masks):
         detection_result.mask = mask
-
     return detection_results
 
 def grounded_segmentation(image: Union[Image.Image, str], labels: List[str], threshold: float = 0.3, polygon_refinement: bool = False, detector_id: Optional[str] = None, segmenter_id: Optional[str] = None) -> Tuple[np.ndarray, List[DetectionResult]]:
@@ -165,8 +166,10 @@ def draw_classification_boxes(image_with_insects: np.ndarray, detections: List[D
         color = np.random.randint(0, 256, size=3).tolist()
         cv2.rectangle(image_with_insects, (box.xmin, box.ymin), (box.xmax, box.ymax), color, 2)
         (text_width, text_height), baseline = cv2.getTextSize(f"{label}: {score:.2f}", cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
-        cv2.rectangle(image_with_insects, (box.xmin, box.ymin - text_height - baseline), (box.xmin + text_width, box.ymin), color, thickness=cv2.FILLED)
-        cv2.putText(image_with_insects, f"{label}: {score:.2f}", (box.xmin, box.ymin - baseline), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)
+        cv2.rectangle(image_with_insects, (box.xmin, box.ymin - text_height - baseline), (box.xmin + text_width, box.ymin),
+                      color, thickness=cv2.FILLED)
+        cv2.putText(image_with_insects, f"{label}: {score:.2f}", (box.xmin, box.ymin - baseline),
+                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)
     return image_with_insects
 
 def process_image(image):
@@ -177,7 +180,6 @@ def process_image(image):
     yellow_background_with_masks = put_masks_on_yellow_background(original_image.shape[:2], insect_masks)
     yellow_background_with_insects = create_yellow_background_with_insects(original_image, detections)
    yellow_background_with_boxes = draw_classification_boxes(yellow_background_with_insects, detections)
-
     return masked_image, yellow_background_with_masks, yellow_background_with_boxes
 
 gr.Interface(
@@ -185,4 +187,4 @@ gr.Interface(
     inputs=gr.Image(type="pil"),
     outputs=[gr.Image(type="numpy"), gr.Image(type="numpy"), gr.Image(type="numpy")],
     title="Insect Detection and Masking"
-).launch()
+).launch()
```
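For reference, a rough usage sketch of the updated flow, built only from the signatures visible in the diff: `grounded_segmentation` takes an image and candidate labels and returns an image array plus the `DetectionResult` list, with the GPU-decorated `detect` and `segment` doing the model work (on ZeroGPU, CUDA is only reachable inside those decorated calls, e.g. when the Gradio app invokes `process_image`). The input filename and label list below are placeholders.

```python
from PIL import Image

image = Image.open("insects.jpg").convert("RGB")  # placeholder input image
labels = ["insect"]                               # placeholder candidate labels

# Returns an image array and the DetectionResult list, per the signature above.
image_array, detections = grounded_segmentation(
    image,
    labels,
    threshold=0.3,
    polygon_refinement=True,
    detector_id="IDEA-Research/grounding-dino-base",
    segmenter_id="martintmv/InsectSAM",
)

for det in detections:
    # score, label, and box are set by DetectionResult.from_dict in the diff
    print(det.label, f"{det.score:.2f}", det.box)
```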