Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -2,6 +2,7 @@ import torch
|
|
2 |
import torchvision
|
3 |
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
|
4 |
from transformers import DetrImageProcessor, DetrForObjectDetection
|
|
|
5 |
from PIL import Image
|
6 |
import numpy as np
|
7 |
import matplotlib.pyplot as plt
|
@@ -25,6 +26,11 @@ detr_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
|
|
25 |
maskrcnn_model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
|
26 |
maskrcnn_model.eval()
|
27 |
|
|
|
|
|
|
|
|
|
|
|
28 |
# COCO class names for Faster R-CNN and Mask R-CNN
|
29 |
COCO_INSTANCE_CATEGORY_NAMES = [
|
30 |
'__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
|
@@ -41,6 +47,9 @@ COCO_INSTANCE_CATEGORY_NAMES = [
|
|
41 |
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
|
42 |
]
|
43 |
|
|
|
|
|
|
|
44 |
def detect_objects_frcnn(image, threshold=0.5):
|
45 |
"""Run Faster R-CNN detection."""
|
46 |
if image is None:
|
@@ -186,7 +195,7 @@ def detect_objects_maskrcnn(image, threshold=0.5):
|
|
186 |
for i in range(len(masks)):
|
187 |
if scores[i] >= threshold:
|
188 |
mask = masks[i, 0].cpu().numpy()
|
189 |
-
mask = mask > 0.5
|
190 |
color = np.random.rand(3)
|
191 |
colored_mask = np.zeros_like(image_np, dtype=np.uint8)
|
192 |
for c in range(3):
|
@@ -216,17 +225,88 @@ def detect_objects_maskrcnn(image, threshold=0.5):
|
|
216 |
plt.close()
|
217 |
return error_path, 0
|
218 |
|
219 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
"""Analyze and compare model performance."""
|
221 |
if image is None:
|
222 |
-
return "Please upload an image first.", None, None, None, "No analysis available."
|
223 |
|
224 |
frcnn_result = None
|
225 |
detr_result = None
|
226 |
maskrcnn_result = None
|
|
|
227 |
frcnn_count = 0
|
228 |
detr_count = 0
|
229 |
maskrcnn_count = 0
|
|
|
230 |
|
231 |
if model_choice in ["Faster R-CNN", "All"]:
|
232 |
frcnn_result, frcnn_count = detect_objects_frcnn(image, frcnn_threshold)
|
@@ -237,14 +317,17 @@ def analyze_performance(image, model_choice, frcnn_threshold=0.5, detr_threshold
|
|
237 |
if model_choice in ["Mask R-CNN", "All"]:
|
238 |
maskrcnn_result, maskrcnn_count = detect_objects_maskrcnn(image, maskrcnn_threshold)
|
239 |
|
|
|
|
|
|
|
240 |
# Compare and analyze performance
|
241 |
analysis = ""
|
242 |
if model_choice == "All":
|
243 |
-
# Compare the models
|
244 |
counts = {
|
245 |
"Faster R-CNN": frcnn_count,
|
246 |
"DETR": detr_count,
|
247 |
-
"Mask R-CNN": maskrcnn_count
|
|
|
248 |
}
|
249 |
max_count = max(counts.values())
|
250 |
max_models = [model for model, count in counts.items() if count == max_count]
|
@@ -254,7 +337,7 @@ def analyze_performance(image, model_choice, frcnn_threshold=0.5, detr_threshold
|
|
254 |
else:
|
255 |
analysis = f"{', '.join(max_models)} detected the same number of objects ({max_count}). "
|
256 |
|
257 |
-
analysis += "Faster R-CNN is typically faster and good for general detection. DETR excels in complex scenes with better context understanding. Mask R-CNN
|
258 |
|
259 |
# Add image-specific recommendation
|
260 |
img_array = np.array(image)
|
@@ -262,27 +345,29 @@ def analyze_performance(image, model_choice, frcnn_threshold=0.5, detr_threshold
|
|
262 |
pixel_variance = np.var(img_array)
|
263 |
|
264 |
if height * width > 1000 * 1000:
|
265 |
-
analysis += "\n\nThis is a high-resolution image. DETR and
|
266 |
if pixel_variance > 1000:
|
267 |
-
analysis += "\n\nThis image has high contrast/complexity. DETR and
|
268 |
if height * width < 500 * 500:
|
269 |
analysis += "\n\nFor smaller images, Faster R-CNN often provides good results at lower computational cost."
|
270 |
if max_count > 0:
|
271 |
-
analysis += "\n\nSince Mask R-CNN
|
272 |
|
273 |
elif model_choice == "Faster R-CNN":
|
274 |
analysis = f"Faster R-CNN detected {frcnn_count} objects with a confidence threshold of {frcnn_threshold}."
|
275 |
elif model_choice == "DETR":
|
276 |
analysis = f"DETR detected {detr_count} objects with a confidence threshold of {detr_threshold}."
|
277 |
-
|
278 |
analysis = f"Mask R-CNN detected {maskrcnn_count} objects with a confidence threshold of {maskrcnn_threshold}. It also provides instance segmentation for precise object boundaries."
|
|
|
|
|
279 |
|
280 |
-
return "Analysis complete!", frcnn_result, detr_result, maskrcnn_result, analysis
|
281 |
|
282 |
# Create multi-step Gradio interface with a workflow
|
283 |
with gr.Blocks(title="Object Detection Comparison") as app:
|
284 |
-
gr.Markdown("# Object Detection: Faster R-CNN vs DETR vs Mask R-CNN")
|
285 |
-
gr.Markdown("### Upload an image and compare
|
286 |
|
287 |
# State variables
|
288 |
image_state = gr.State(None)
|
@@ -299,7 +384,7 @@ with gr.Blocks(title="Object Detection Comparison") as app:
|
|
299 |
gr.Markdown("## Step 2: Question")
|
300 |
gr.Markdown("Which model do you think will work better?")
|
301 |
model_choice = gr.Radio(
|
302 |
-
choices=["Faster R-CNN", "DETR", "Mask R-CNN", "All"],
|
303 |
label="Select Object Detection Model(s)",
|
304 |
value="All"
|
305 |
)
|
@@ -315,6 +400,10 @@ with gr.Blocks(title="Object Detection Comparison") as app:
|
|
315 |
minimum=0.0, maximum=1.0, value=0.5, step=0.05,
|
316 |
label="Mask R-CNN Confidence Threshold"
|
317 |
)
|
|
|
|
|
|
|
|
|
318 |
detect_button = gr.Button("Run", variant="primary")
|
319 |
|
320 |
# Step 3: Results display
|
@@ -327,11 +416,15 @@ with gr.Blocks(title="Object Detection Comparison") as app:
|
|
327 |
with gr.Column():
|
328 |
gr.Markdown("### DETR Result")
|
329 |
detr_result = gr.Image(type="filepath", label="DETR")
|
|
|
330 |
with gr.Column():
|
331 |
gr.Markdown("### Mask R-CNN Result")
|
332 |
maskrcnn_result = gr.Image(type="filepath", label="Mask R-CNN")
|
|
|
|
|
|
|
333 |
|
334 |
-
analysis_output = gr.Textbox(label="Performance Analysis", lines=
|
335 |
restart_button = gr.Button("Try Another Image", variant="secondary")
|
336 |
|
337 |
# Upload button click event
|
@@ -349,8 +442,8 @@ with gr.Blocks(title="Object Detection Comparison") as app:
|
|
349 |
# Detect button click event
|
350 |
detect_button.click(
|
351 |
fn=analyze_performance,
|
352 |
-
inputs=[image_state, model_choice, frcnn_threshold, detr_threshold, maskrcnn_threshold],
|
353 |
-
outputs=[gr.Textbox(visible=False), frcnn_result, detr_result, maskrcnn_result, analysis_output]
|
354 |
).then(
|
355 |
fn=lambda: (gr.update(visible=True)),
|
356 |
outputs=[results_panel]
|
|
|
2 |
import torchvision
|
3 |
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
|
4 |
from transformers import DetrImageProcessor, DetrForObjectDetection
|
5 |
+
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
|
6 |
from PIL import Image
|
7 |
import numpy as np
|
8 |
import matplotlib.pyplot as plt
|
|
|
26 |
maskrcnn_model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
|
27 |
maskrcnn_model.eval()
|
28 |
|
29 |
+
# Load the Mask2Former processor and model; both come from the same checkpoint id
_MASK2FORMER_CHECKPOINT = "facebook/mask2former-swin-small-coco-instance"
mask2former_processor = AutoImageProcessor.from_pretrained(_MASK2FORMER_CHECKPOINT)
mask2former_model = Mask2FormerForUniversalSegmentation.from_pretrained(_MASK2FORMER_CHECKPOINT)
# The model is only used for inference in this app, so put it in eval mode once at load time
mask2former_model.eval()
|
33 |
+
|
34 |
# COCO class names for Faster R-CNN and Mask R-CNN
|
35 |
COCO_INSTANCE_CATEGORY_NAMES = [
|
36 |
'__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
|
|
|
47 |
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
|
48 |
]
|
49 |
|
50 |
+
# Mask2Former label map
|
51 |
+
# Prefer the label map shipped with the model config; otherwise fall back to an
# identity map of 133 stringified ids.
# NOTE(review): the fallback map is keyed by strings, while transformers configs
# usually key id2label by int — confirm that lookups use a matching key type.
if hasattr(mask2former_model.config, "id2label"):
    MASK2FORMER_COCO_NAMES = mask2former_model.config.id2label
else:
    MASK2FORMER_COCO_NAMES = {str(idx): str(idx) for idx in range(133)}
|
52 |
+
|
53 |
def detect_objects_frcnn(image, threshold=0.5):
|
54 |
"""Run Faster R-CNN detection."""
|
55 |
if image is None:
|
|
|
195 |
for i in range(len(masks)):
|
196 |
if scores[i] >= threshold:
|
197 |
mask = masks[i, 0].cpu().numpy()
|
198 |
+
mask = mask > 0.5
|
199 |
color = np.random.rand(3)
|
200 |
colored_mask = np.zeros_like(image_np, dtype=np.uint8)
|
201 |
for c in range(3):
|
|
|
225 |
plt.close()
|
226 |
return error_path, 0
|
227 |
|
228 |
+
def _save_mask2former_message_image(message, output_path, fontsize, wrap=False):
    """Save a white 400x400 placeholder image with `message` centered on it.

    Args:
        message: Text drawn in the middle of the canvas.
        output_path: File path the figure is written to.
        fontsize: Font size used for the message.
        wrap: Whether matplotlib should wrap long text.

    Returns:
        `output_path`, so callers can return it directly.
    """
    canvas = Image.new('RGB', (400, 400), color='white')
    fig, ax = plt.subplots(1, figsize=(10, 10))
    try:
        ax.imshow(canvas)
        ax.text(0.5, 0.5, message, horizontalalignment='center', verticalalignment='center',
                transform=ax.transAxes, fontsize=fontsize, wrap=wrap)
        ax.axis('off')
        fig.savefig(output_path)
    finally:
        # Always release the figure, even when saving fails.
        plt.close(fig)
    return output_path


def detect_objects_mask2former(image, threshold=0.5):
    """Run Mask2Former detection and segmentation.

    Args:
        image: PIL image to analyze, or None when nothing was uploaded.
        threshold: Minimum instance score for a segment to be drawn and counted.

    Returns:
        A tuple `(output_path, count)`: the path of the saved visualization
        (a placeholder image when `image` is None or when an error occurred)
        and the number of segments whose score is at least `threshold`
        (0 for the placeholder cases).
    """
    if image is None:
        blank_path = _save_mask2former_message_image(
            "No image provided", "mask2former_blank_output.png", fontsize=20
        )
        return blank_path, 0

    try:
        image = image.convert('RGB')
        inputs = mask2former_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = mask2former_model(**inputs)

        # Forward the caller's threshold: the post-processor applies its own
        # default score cut-off otherwise, which would silently ignore any
        # user-selected threshold below that default.
        results = mask2former_processor.post_process_instance_segmentation(
            outputs, threshold=threshold, target_sizes=[image.size[::-1]]
        )[0]
        segmentation_map = results["segmentation"].cpu().numpy()
        segments_info = results["segments_info"]

        # Segments without a "score" entry are treated as fully confident.
        kept_segments = [
            segment for segment in segments_info
            if segment.get("score", 1.0) >= threshold
        ]
        valid_detections = len(kept_segments)

        image_np = np.array(image)
        overlay = image_np.copy()
        fig, ax = plt.subplots(1, figsize=(10, 10))
        try:
            ax.imshow(image_np)

            for segment in kept_segments:
                score = segment.get("score", 1.0)
                label_id = segment["label_id"]
                mask = segmentation_map == segment["id"]
                if not mask.any():
                    # Nothing of this segment is left in the map; nothing to draw.
                    continue

                # Blend a random color 50/50 into the pixels of this instance.
                color = np.random.rand(3)
                overlay[mask] = (overlay[mask] * 0.5 + color * 255 * 0.5).astype(np.uint8)

                # Derive a bounding box from the extent of the mask.
                y_indices, x_indices = np.where(mask)
                x1, x2 = x_indices.min(), x_indices.max()
                y1, y2 = y_indices.min(), y_indices.max()

                # The label map may be keyed by int (model config) or by str
                # (fallback map), so try both before showing the raw id.
                label_name = MASK2FORMER_COCO_NAMES.get(
                    label_id, MASK2FORMER_COCO_NAMES.get(str(label_id), str(label_id))
                )
                ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, color=color, linewidth=2))
                ax.text(x1, y1, f"{label_name}: {score:.2f}", bbox=dict(facecolor='yellow', alpha=0.5), fontsize=10)

            ax.imshow(overlay)
            ax.axis('off')
            output_path = "mask2former_output.png"
            fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
        finally:
            # Close this specific figure so a failure while drawing or saving
            # does not leave it open.
            plt.close(fig)
        return output_path, valid_detections
    except Exception as e:
        # Deliberate best-effort boundary: any failure is rendered into an
        # image so the caller still receives a path and a count.
        error_path = _save_mask2former_message_image(
            f"Error: {str(e)}", "mask2former_error_output.png", fontsize=12, wrap=True
        )
        return error_path, 0
|
296 |
+
|
297 |
+
def analyze_performance(image, model_choice, frcnn_threshold=0.5, detr_threshold=0.9, maskrcnn_threshold=0.5, mask2former_threshold=0.5):
|
298 |
"""Analyze and compare model performance."""
|
299 |
if image is None:
|
300 |
+
return "Please upload an image first.", None, None, None, None, "No analysis available."
|
301 |
|
302 |
frcnn_result = None
|
303 |
detr_result = None
|
304 |
maskrcnn_result = None
|
305 |
+
mask2former_result = None
|
306 |
frcnn_count = 0
|
307 |
detr_count = 0
|
308 |
maskrcnn_count = 0
|
309 |
+
mask2former_count = 0
|
310 |
|
311 |
if model_choice in ["Faster R-CNN", "All"]:
|
312 |
frcnn_result, frcnn_count = detect_objects_frcnn(image, frcnn_threshold)
|
|
|
317 |
if model_choice in ["Mask R-CNN", "All"]:
|
318 |
maskrcnn_result, maskrcnn_count = detect_objects_maskrcnn(image, maskrcnn_threshold)
|
319 |
|
320 |
+
if model_choice in ["Mask2Former", "All"]:
|
321 |
+
mask2former_result, mask2former_count = detect_objects_mask2former(image, mask2former_threshold)
|
322 |
+
|
323 |
# Compare and analyze performance
|
324 |
analysis = ""
|
325 |
if model_choice == "All":
|
|
|
326 |
counts = {
|
327 |
"Faster R-CNN": frcnn_count,
|
328 |
"DETR": detr_count,
|
329 |
+
"Mask R-CNN": maskrcnn_count,
|
330 |
+
"Mask2Former": mask2former_count
|
331 |
}
|
332 |
max_count = max(counts.values())
|
333 |
max_models = [model for model, count in counts.items() if count == max_count]
|
|
|
337 |
else:
|
338 |
analysis = f"{', '.join(max_models)} detected the same number of objects ({max_count}). "
|
339 |
|
340 |
+
analysis += "Faster R-CNN is typically faster and good for general detection. DETR excels in complex scenes with better context understanding. Mask R-CNN and Mask2Former provide instance segmentation for precise object boundaries, with Mask2Former leveraging a transformer-based architecture for potentially superior performance in complex scenes."
|
341 |
|
342 |
# Add image-specific recommendation
|
343 |
img_array = np.array(image)
|
|
|
345 |
pixel_variance = np.var(img_array)
|
346 |
|
347 |
if height * width > 1000 * 1000:
|
348 |
+
analysis += "\n\nThis is a high-resolution image. DETR and Mask2Former typically perform better on high-resolution images with complex scenes."
|
349 |
if pixel_variance > 1000:
|
350 |
+
analysis += "\n\nThis image has high contrast/complexity. DETR and Mask2Former may provide better context-aware detections."
|
351 |
if height * width < 500 * 500:
|
352 |
analysis += "\n\nFor smaller images, Faster R-CNN often provides good results at lower computational cost."
|
353 |
if max_count > 0:
|
354 |
+
analysis += "\n\nSince Mask R-CNN and Mask2Former provide segmentation, they may be preferable if precise object boundaries are needed, with Mask2Former potentially offering better performance due to its transformer-based design."
|
355 |
|
356 |
elif model_choice == "Faster R-CNN":
|
357 |
analysis = f"Faster R-CNN detected {frcnn_count} objects with a confidence threshold of {frcnn_threshold}."
|
358 |
elif model_choice == "DETR":
|
359 |
analysis = f"DETR detected {detr_count} objects with a confidence threshold of {detr_threshold}."
|
360 |
+
elif model_choice == "Mask R-CNN":
|
361 |
analysis = f"Mask R-CNN detected {maskrcnn_count} objects with a confidence threshold of {maskrcnn_threshold}. It also provides instance segmentation for precise object boundaries."
|
362 |
+
else: # Mask2Former
|
363 |
+
analysis = f"Mask2Former detected {mask2former_count} objects with a confidence threshold of {mask2former_threshold}. It provides instance segmentation with a transformer-based architecture, potentially offering superior performance in complex scenes."
|
364 |
|
365 |
+
return "Analysis complete!", frcnn_result, detr_result, maskrcnn_result, mask2former_result, analysis
|
366 |
|
367 |
# Create multi-step Gradio interface with a workflow
|
368 |
with gr.Blocks(title="Object Detection Comparison") as app:
|
369 |
+
gr.Markdown("# Object Detection: Faster R-CNN vs DETR vs Mask R-CNN vs Mask2Former")
|
370 |
+
gr.Markdown("### Upload an image and compare four state-of-the-art object detection models")
|
371 |
|
372 |
# State variables
|
373 |
image_state = gr.State(None)
|
|
|
384 |
gr.Markdown("## Step 2: Question")
|
385 |
gr.Markdown("Which model do you think will work better?")
|
386 |
model_choice = gr.Radio(
|
387 |
+
choices=["Faster R-CNN", "DETR", "Mask R-CNN", "Mask2Former", "All"],
|
388 |
label="Select Object Detection Model(s)",
|
389 |
value="All"
|
390 |
)
|
|
|
400 |
minimum=0.0, maximum=1.0, value=0.5, step=0.05,
|
401 |
label="Mask R-CNN Confidence Threshold"
|
402 |
)
|
403 |
+
mask2former_threshold = gr.Slider(
|
404 |
+
minimum=0.0, maximum=1.0, value=0.5, step=0.05,
|
405 |
+
label="Mask2Former Confidence Threshold"
|
406 |
+
)
|
407 |
detect_button = gr.Button("Run", variant="primary")
|
408 |
|
409 |
# Step 3: Results display
|
|
|
416 |
with gr.Column():
|
417 |
gr.Markdown("### DETR Result")
|
418 |
detr_result = gr.Image(type="filepath", label="DETR")
|
419 |
+
with gr.Row():
|
420 |
with gr.Column():
|
421 |
gr.Markdown("### Mask R-CNN Result")
|
422 |
maskrcnn_result = gr.Image(type="filepath", label="Mask R-CNN")
|
423 |
+
with gr.Column():
|
424 |
+
gr.Markdown("### Mask2Former Result")
|
425 |
+
mask2former_result = gr.Image(type="filepath", label="Mask2Former")
|
426 |
|
427 |
+
analysis_output = gr.Textbox(label="Performance Analysis", lines=10)
|
428 |
restart_button = gr.Button("Try Another Image", variant="secondary")
|
429 |
|
430 |
# Upload button click event
|
|
|
442 |
# Detect button click event
|
443 |
detect_button.click(
|
444 |
fn=analyze_performance,
|
445 |
+
inputs=[image_state, model_choice, frcnn_threshold, detr_threshold, maskrcnn_threshold, mask2former_threshold],
|
446 |
+
outputs=[gr.Textbox(visible=False), frcnn_result, detr_result, maskrcnn_result, mask2former_result, analysis_output]
|
447 |
).then(
|
448 |
fn=lambda: (gr.update(visible=True)),
|
449 |
outputs=[results_panel]
|