Spaces:
Sleeping
Sleeping
Commit
·
2396fdf
1
Parent(s):
53c5524
Enable smoothing mask and expanding mask
Browse files
app.py
CHANGED
@@ -1,9 +1,10 @@
|
|
|
|
1 |
import json
|
2 |
import os
|
3 |
-
import subprocess
|
4 |
import sys
|
5 |
import tempfile
|
6 |
|
|
|
7 |
import gradio as gr
|
8 |
import numpy as np
|
9 |
import supervision as sv
|
@@ -86,7 +87,16 @@ grounding_dino_model = DinoModel(
|
|
86 |
)
|
87 |
|
88 |
|
89 |
-
def process(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
global tag2text_model, sam_predictor, sam_automask_generator, grounding_dino_model, device
|
91 |
output_gallery = []
|
92 |
detections = None
|
@@ -97,6 +107,7 @@ def process(image_path, task, prompt, box_threshold, text_threshold, iou_thresho
|
|
97 |
image = Image.open(image_path)
|
98 |
image_pil = image.convert("RGB")
|
99 |
image = np.array(image_pil)
|
|
|
100 |
|
101 |
# Extract image metadata
|
102 |
filename = os.path.basename(image_path)
|
@@ -106,7 +117,7 @@ def process(image_path, task, prompt, box_threshold, text_threshold, iou_thresho
|
|
106 |
metadata["image"]["height"] = h
|
107 |
|
108 |
# Generate tags
|
109 |
-
if task in ["auto", "
|
110 |
tags, caption = generate_tags(tag2text_model, image_pil, "None", device)
|
111 |
prompt = " . ".join(tags)
|
112 |
print(f"Caption: {caption}")
|
@@ -146,20 +157,38 @@ def process(image_path, task, prompt, box_threshold, text_threshold, iou_thresho
|
|
146 |
|
147 |
# Segmentation
|
148 |
if task in ["auto", "segment"]:
|
|
|
|
|
|
|
149 |
if detections:
|
150 |
masks, scores = segment(
|
151 |
-
sam_predictor, image=
|
152 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
detections.mask = masks
|
|
|
|
|
|
|
154 |
else:
|
155 |
-
masks = sam_automask_generator.generate(
|
156 |
sorted_generated_masks = sorted(
|
157 |
masks, key=lambda x: x["area"], reverse=True
|
158 |
)
|
159 |
|
160 |
xywh = np.array([mask["bbox"] for mask in sorted_generated_masks])
|
161 |
mask = np.array(
|
162 |
-
[
|
|
|
|
|
|
|
163 |
)
|
164 |
scores = np.array(
|
165 |
[mask["predicted_iou"] for mask in sorted_generated_masks]
|
@@ -167,9 +196,7 @@ def process(image_path, task, prompt, box_threshold, text_threshold, iou_thresho
|
|
167 |
detections = sv.Detections(
|
168 |
xyxy=xywh_to_xyxy(boxes_xywh=xywh), mask=mask
|
169 |
)
|
170 |
-
|
171 |
-
# mask_image, _ = show_anns_sam(masks)
|
172 |
-
# annotated_image = np.uint8(mask_image * opacity + image * (1 - opacity))
|
173 |
|
174 |
mask_annotator = sv.MaskAnnotator()
|
175 |
mask_image = np.zeros_like(image, dtype=np.uint8)
|
@@ -177,7 +204,13 @@ def process(image_path, task, prompt, box_threshold, text_threshold, iou_thresho
|
|
177 |
mask_image, detections=detections, opacity=1
|
178 |
)
|
179 |
annotated_image = mask_annotator.annotate(image, detections=detections)
|
|
|
180 |
output_gallery.append(mask_image)
|
|
|
|
|
|
|
|
|
|
|
181 |
output_gallery.append(annotated_image)
|
182 |
|
183 |
# ToDo: Extract metadata
|
@@ -203,7 +236,7 @@ def process(image_path, task, prompt, box_threshold, text_threshold, iou_thresho
|
|
203 |
|
204 |
meta_file = tempfile.NamedTemporaryFile(delete=False, suffix=".json")
|
205 |
meta_file_path = meta_file.name
|
206 |
-
with open(meta_file_path, "w") as fp:
|
207 |
json.dump(metadata, fp)
|
208 |
|
209 |
return output_gallery, meta_file_path
|
@@ -231,7 +264,6 @@ with gr.Blocks(css="style.css", title=title) as demo:
|
|
231 |
value=0.3,
|
232 |
step=0.05,
|
233 |
label="Box threshold",
|
234 |
-
info="Hash size to use for image hashing",
|
235 |
)
|
236 |
text_threshold = gr.Slider(
|
237 |
minimum=0,
|
@@ -239,7 +271,6 @@ with gr.Blocks(css="style.css", title=title) as demo:
|
|
239 |
value=0.25,
|
240 |
step=0.05,
|
241 |
label="Text threshold",
|
242 |
-
info="Number of history images used to find out duplicate image",
|
243 |
)
|
244 |
iou_threshold = gr.Slider(
|
245 |
minimum=0,
|
@@ -247,7 +278,18 @@ with gr.Blocks(css="style.css", title=title) as demo:
|
|
247 |
value=0.5,
|
248 |
step=0.05,
|
249 |
label="IOU threshold",
|
250 |
-
info="
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
251 |
)
|
252 |
run_button = gr.Button(label="Run")
|
253 |
|
@@ -256,12 +298,11 @@ with gr.Blocks(css="style.css", title=title) as demo:
|
|
256 |
label="Generated images", show_label=False, elem_id="gallery"
|
257 |
).style(preview=True, grid=2, object_fit="scale-down")
|
258 |
meta_file = gr.File(label="Metadata file")
|
259 |
-
|
260 |
with gr.Row(elem_classes=["container"]):
|
261 |
gr.Examples(
|
262 |
[
|
263 |
["examples/dog.png", "auto", ""],
|
264 |
-
["examples/eiffel.
|
265 |
["examples/eiffel.png", "segment", ""],
|
266 |
["examples/girl.png", "auto", "girl . face"],
|
267 |
["examples/horse.png", "detect", "horse"],
|
@@ -279,6 +320,8 @@ with gr.Blocks(css="style.css", title=title) as demo:
|
|
279 |
box_threshold,
|
280 |
text_threshold,
|
281 |
iou_threshold,
|
|
|
|
|
282 |
],
|
283 |
outputs=[gallery, meta_file],
|
284 |
)
|
|
|
1 |
+
import functools
|
2 |
import json
|
3 |
import os
|
|
|
4 |
import sys
|
5 |
import tempfile
|
6 |
|
7 |
+
import cv2
|
8 |
import gradio as gr
|
9 |
import numpy as np
|
10 |
import supervision as sv
|
|
|
87 |
)
|
88 |
|
89 |
|
90 |
+
def process(
|
91 |
+
image_path,
|
92 |
+
task,
|
93 |
+
prompt,
|
94 |
+
box_threshold,
|
95 |
+
text_threshold,
|
96 |
+
iou_threshold,
|
97 |
+
kernel_size,
|
98 |
+
expand_mask,
|
99 |
+
):
|
100 |
global tag2text_model, sam_predictor, sam_automask_generator, grounding_dino_model, device
|
101 |
output_gallery = []
|
102 |
detections = None
|
|
|
107 |
image = Image.open(image_path)
|
108 |
image_pil = image.convert("RGB")
|
109 |
image = np.array(image_pil)
|
110 |
+
orig_image = image.copy()
|
111 |
|
112 |
# Extract image metadata
|
113 |
filename = os.path.basename(image_path)
|
|
|
117 |
metadata["image"]["height"] = h
|
118 |
|
119 |
# Generate tags
|
120 |
+
if task in ["auto", "detection"] and prompt == "":
|
121 |
tags, caption = generate_tags(tag2text_model, image_pil, "None", device)
|
122 |
prompt = " . ".join(tags)
|
123 |
print(f"Caption: {caption}")
|
|
|
157 |
|
158 |
# Segmentation
|
159 |
if task in ["auto", "segment"]:
|
160 |
+
kernel = cv2.getStructuringElement(
|
161 |
+
cv2.MORPH_ELLIPSE, (2 * kernel_size + 1, 2 * kernel_size + 1)
|
162 |
+
)
|
163 |
if detections:
|
164 |
masks, scores = segment(
|
165 |
+
sam_predictor, image=orig_image, boxes=detections.xyxy
|
166 |
)
|
167 |
+
if expand_mask:
|
168 |
+
masks = [
|
169 |
+
cv2.dilate(mask.astype(np.uint8), kernel) for mask in masks
|
170 |
+
]
|
171 |
+
else:
|
172 |
+
masks = [
|
173 |
+
cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
|
174 |
+
for mask in masks
|
175 |
+
]
|
176 |
detections.mask = masks
|
177 |
+
binary_mask = functools.reduce(
|
178 |
+
lambda x, y: x + y, detections.mask
|
179 |
+
).astype(np.bool)
|
180 |
else:
|
181 |
+
masks = sam_automask_generator.generate(orig_image)
|
182 |
sorted_generated_masks = sorted(
|
183 |
masks, key=lambda x: x["area"], reverse=True
|
184 |
)
|
185 |
|
186 |
xywh = np.array([mask["bbox"] for mask in sorted_generated_masks])
|
187 |
mask = np.array(
|
188 |
+
[
|
189 |
+
cv2.dilate(mask["segmentation"].astype(np.uint8), kernel)
|
190 |
+
for mask in sorted_generated_masks
|
191 |
+
]
|
192 |
)
|
193 |
scores = np.array(
|
194 |
[mask["predicted_iou"] for mask in sorted_generated_masks]
|
|
|
196 |
detections = sv.Detections(
|
197 |
xyxy=xywh_to_xyxy(boxes_xywh=xywh), mask=mask
|
198 |
)
|
199 |
+
binary_mask = None
|
|
|
|
|
200 |
|
201 |
mask_annotator = sv.MaskAnnotator()
|
202 |
mask_image = np.zeros_like(image, dtype=np.uint8)
|
|
|
204 |
mask_image, detections=detections, opacity=1
|
205 |
)
|
206 |
annotated_image = mask_annotator.annotate(image, detections=detections)
|
207 |
+
|
208 |
output_gallery.append(mask_image)
|
209 |
+
if binary_mask is not None:
|
210 |
+
binary_mask_image = binary_mask * 255
|
211 |
+
cutout_image = np.expand_dims(binary_mask, axis=-1) * orig_image
|
212 |
+
output_gallery.append(binary_mask_image)
|
213 |
+
output_gallery.append(cutout_image)
|
214 |
output_gallery.append(annotated_image)
|
215 |
|
216 |
# ToDo: Extract metadata
|
|
|
236 |
|
237 |
meta_file = tempfile.NamedTemporaryFile(delete=False, suffix=".json")
|
238 |
meta_file_path = meta_file.name
|
239 |
+
with open(meta_file_path, "w", encoding="utf-8") as fp:
|
240 |
json.dump(metadata, fp)
|
241 |
|
242 |
return output_gallery, meta_file_path
|
|
|
264 |
value=0.3,
|
265 |
step=0.05,
|
266 |
label="Box threshold",
|
|
|
267 |
)
|
268 |
text_threshold = gr.Slider(
|
269 |
minimum=0,
|
|
|
271 |
value=0.25,
|
272 |
step=0.05,
|
273 |
label="Text threshold",
|
|
|
274 |
)
|
275 |
iou_threshold = gr.Slider(
|
276 |
minimum=0,
|
|
|
278 |
value=0.5,
|
279 |
step=0.05,
|
280 |
label="IOU threshold",
|
281 |
+
info="Intersection over Union threshold",
|
282 |
+
)
|
283 |
+
kernel_size = gr.Slider(
|
284 |
+
minimum=1,
|
285 |
+
maximum=5,
|
286 |
+
value=2,
|
287 |
+
step=1,
|
288 |
+
label="Kernel size",
|
289 |
+
info="Use to smooth segment masks",
|
290 |
+
)
|
291 |
+
expand_mask = gr.Checkbox(
|
292 |
+
label="Expand mask",
|
293 |
)
|
294 |
run_button = gr.Button(label="Run")
|
295 |
|
|
|
298 |
label="Generated images", show_label=False, elem_id="gallery"
|
299 |
).style(preview=True, grid=2, object_fit="scale-down")
|
300 |
meta_file = gr.File(label="Metadata file")
|
|
|
301 |
with gr.Row(elem_classes=["container"]):
|
302 |
gr.Examples(
|
303 |
[
|
304 |
["examples/dog.png", "auto", ""],
|
305 |
+
["examples/eiffel.jpg", "auto", "tower . lake . grass . sky"],
|
306 |
["examples/eiffel.png", "segment", ""],
|
307 |
["examples/girl.png", "auto", "girl . face"],
|
308 |
["examples/horse.png", "detect", "horse"],
|
|
|
320 |
box_threshold,
|
321 |
text_threshold,
|
322 |
iou_threshold,
|
323 |
+
kernel_size,
|
324 |
+
expand_mask,
|
325 |
],
|
326 |
outputs=[gallery, meta_file],
|
327 |
)
|