Martin Tomov committed
update app.py to match ZeroGPU Spaces
app.py CHANGED
@@ -1,7 +1,6 @@
 import random
 from dataclasses import dataclass
 from typing import Any, List, Dict, Optional, Union, Tuple
-
 import cv2
 import torch
 import requests
@@ -10,6 +9,7 @@ from PIL import Image
 import matplotlib.pyplot as plt
 from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline
 import gradio as gr
+import spaces
 
 @dataclass
 class BoundingBox:
@@ -31,12 +31,16 @@ class DetectionResult:
 
     @classmethod
     def from_dict(cls, detection_dict: Dict) -> 'DetectionResult':
-        return cls(
-            …
+        return cls(
+            score=detection_dict['score'],
+            label=detection_dict['label'],
+            box=BoundingBox(
+                xmin=detection_dict['box']['xmin'],
+                ymin=detection_dict['box']['ymin'],
+                xmax=detection_dict['box']['xmax'],
+                ymax=detection_dict['box']['ymax']
+            )
+        )
 
 def annotate(image: Union[Image.Image, np.ndarray], detection_results: List[DetectionResult]) -> np.ndarray:
     image_cv2 = np.array(image) if isinstance(image, Image.Image) else image
@@ -47,16 +51,16 @@ def annotate(image: Union[Image.Image, np.ndarray], detection_results: List[DetectionResult]) -> np.ndarray:
         score = detection.score
         box = detection.box
         mask = detection.mask
+        color = np.random.randint(0, 256, size=3).tolist()
 
-        …
-        cv2.putText(image_cv2, f'{label}: {score:.2f}', (box.xmin, box.ymin - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color.tolist(), 2)
+        cv2.rectangle(image_cv2, (box.xmin, box.ymin), (box.xmax, box.ymax), color, 2)
+        cv2.putText(image_cv2, f'{label}: {score:.2f}', (box.xmin, box.ymin - 10),
+                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
 
         if mask is not None:
             mask_uint8 = (mask * 255).astype(np.uint8)
             contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-            cv2.drawContours(image_cv2, contours, -1, color…
+            cv2.drawContours(image_cv2, contours, -1, color, 2)
 
     return cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
 
@@ -90,29 +94,26 @@ def refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
             masks[idx] = polygon_to_mask(polygon, shape)
     return list(masks)
 
+@spaces.GPU
 def detect(image: Image.Image, labels: List[str], threshold: float = 0.3, detector_id: Optional[str] = None) -> List[Dict[str, Any]]:
-    device = "cuda" if torch.cuda.is_available() else "cpu"
     detector_id = detector_id if detector_id else "IDEA-Research/grounding-dino-base"
-    object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device=device)
+    object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device="cuda")
     labels = [label if label.endswith(".") else label+"." for label in labels]
     results = object_detector(image, candidate_labels=labels, threshold=threshold)
     return [DetectionResult.from_dict(result) for result in results]
 
+@spaces.GPU
 def segment(image: Image.Image, detection_results: List[DetectionResult], polygon_refinement: bool = False, segmenter_id: Optional[str] = None) -> List[DetectionResult]:
-    device = "cuda" if torch.cuda.is_available() else "cpu"
     segmenter_id = segmenter_id if segmenter_id else "martintmv/InsectSAM"
-    segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)
+    segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to("cuda")
     processor = AutoProcessor.from_pretrained(segmenter_id)
-
     boxes = get_boxes(detection_results)
-    inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(device)
+    inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to("cuda")
     outputs = segmentator(**inputs)
     masks = processor.post_process_masks(masks=outputs.pred_masks, original_sizes=inputs.original_sizes, reshaped_input_sizes=inputs.reshaped_input_sizes)[0]
     masks = refine_masks(masks, polygon_refinement)
-
     for detection_result, mask in zip(detection_results, masks):
         detection_result.mask = mask
-
     return detection_results
 
 def grounded_segmentation(image: Union[Image.Image, str], labels: List[str], threshold: float = 0.3, polygon_refinement: bool = False, detector_id: Optional[str] = None, segmenter_id: Optional[str] = None) -> Tuple[np.ndarray, List[DetectionResult]]:
@@ -165,8 +166,10 @@ def draw_classification_boxes(image_with_insects: np.ndarray, detections: List[DetectionResult]) -> np.ndarray:
         color = np.random.randint(0, 256, size=3).tolist()
         cv2.rectangle(image_with_insects, (box.xmin, box.ymin), (box.xmax, box.ymax), color, 2)
         (text_width, text_height), baseline = cv2.getTextSize(f"{label}: {score:.2f}", cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
-        cv2.rectangle(image_with_insects, (box.xmin, box.ymin - text_height - baseline), (box.xmin + text_width, box.ymin),
-            …
+        cv2.rectangle(image_with_insects, (box.xmin, box.ymin - text_height - baseline), (box.xmin + text_width, box.ymin),
+                      color, thickness=cv2.FILLED)
+        cv2.putText(image_with_insects, f"{label}: {score:.2f}", (box.xmin, box.ymin - baseline),
+                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)
     return image_with_insects
 
 def process_image(image):
@@ -177,7 +180,6 @@ def process_image(image):
     yellow_background_with_masks = put_masks_on_yellow_background(original_image.shape[:2], insect_masks)
     yellow_background_with_insects = create_yellow_background_with_insects(original_image, detections)
     yellow_background_with_boxes = draw_classification_boxes(yellow_background_with_insects, detections)
-
    return masked_image, yellow_background_with_masks, yellow_background_with_boxes
 
 gr.Interface(
@@ -185,4 +187,4 @@ gr.Interface(
     inputs=gr.Image(type="pil"),
     outputs=[gr.Image(type="numpy"), gr.Image(type="numpy"), gr.Image(type="numpy")],
     title="Insect Detection and Masking"
-).launch()
+).launch()
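For context on the change: on ZeroGPU Spaces a GPU is attached only while a function decorated with @spaces.GPU is executing, which is why this commit imports spaces, decorates detect and segment, and targets "cuda" directly instead of probing torch.cuda.is_available() at import time. Below is a minimal sketch of that pattern, assuming a Space running on ZeroGPU hardware (which provides the spaces package); the model id and the classify function are illustrative only, not part of app.py.

import gradio as gr
import spaces
from transformers import pipeline

@spaces.GPU  # ZeroGPU attaches a GPU only for the duration of this call
def classify(image):
    # CUDA is guaranteed inside the decorated function, so the pipeline is
    # built on "cuda" here, mirroring detect() and segment() in app.py.
    classifier = pipeline("image-classification",
                          model="google/vit-base-patch16-224",
                          device="cuda")
    return classifier(image)[0]["label"]

gr.Interface(fn=classify, inputs=gr.Image(type="pil"), outputs="text").launch()

Building the model inside the decorated function keeps startup CPU-only at the cost of reloading weights on each request, the same trade-off app.py accepts by constructing its detector pipeline and InsectSAM model inside detect and segment.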