JohnJoelMota committed on
Commit
453991c
·
verified ·
1 Parent(s): 22fb9d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -17
app.py CHANGED
@@ -2,6 +2,7 @@ import torch
2
  import torchvision
3
  from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
4
  from transformers import DetrImageProcessor, DetrForObjectDetection
 
5
  from PIL import Image
6
  import numpy as np
7
  import matplotlib.pyplot as plt
@@ -25,6 +26,11 @@ detr_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
25
  maskrcnn_model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
26
  maskrcnn_model.eval()
27
 
 
 
 
 
 
28
  # COCO class names for Faster R-CNN and Mask R-CNN
29
  COCO_INSTANCE_CATEGORY_NAMES = [
30
  '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
@@ -41,6 +47,9 @@ COCO_INSTANCE_CATEGORY_NAMES = [
41
  'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
42
  ]
43
 
 
 
 
44
  def detect_objects_frcnn(image, threshold=0.5):
45
  """Run Faster R-CNN detection."""
46
  if image is None:
@@ -186,7 +195,7 @@ def detect_objects_maskrcnn(image, threshold=0.5):
186
  for i in range(len(masks)):
187
  if scores[i] >= threshold:
188
  mask = masks[i, 0].cpu().numpy()
189
- mask = mask > 0.5 # Convert to binary mask
190
  color = np.random.rand(3)
191
  colored_mask = np.zeros_like(image_np, dtype=np.uint8)
192
  for c in range(3):
@@ -216,17 +225,88 @@ def detect_objects_maskrcnn(image, threshold=0.5):
216
  plt.close()
217
  return error_path, 0
218
 
219
- def analyze_performance(image, model_choice, frcnn_threshold=0.5, detr_threshold=0.9, maskrcnn_threshold=0.5):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  """Analyze and compare model performance."""
221
  if image is None:
222
- return "Please upload an image first.", None, None, None, "No analysis available."
223
 
224
  frcnn_result = None
225
  detr_result = None
226
  maskrcnn_result = None
 
227
  frcnn_count = 0
228
  detr_count = 0
229
  maskrcnn_count = 0
 
230
 
231
  if model_choice in ["Faster R-CNN", "All"]:
232
  frcnn_result, frcnn_count = detect_objects_frcnn(image, frcnn_threshold)
@@ -237,14 +317,17 @@ def analyze_performance(image, model_choice, frcnn_threshold=0.5, detr_threshold
237
  if model_choice in ["Mask R-CNN", "All"]:
238
  maskrcnn_result, maskrcnn_count = detect_objects_maskrcnn(image, maskrcnn_threshold)
239
 
 
 
 
240
  # Compare and analyze performance
241
  analysis = ""
242
  if model_choice == "All":
243
- # Compare the models
244
  counts = {
245
  "Faster R-CNN": frcnn_count,
246
  "DETR": detr_count,
247
- "Mask R-CNN": maskrcnn_count
 
248
  }
249
  max_count = max(counts.values())
250
  max_models = [model for model, count in counts.items() if count == max_count]
@@ -254,7 +337,7 @@ def analyze_performance(image, model_choice, frcnn_threshold=0.5, detr_threshold
254
  else:
255
  analysis = f"{', '.join(max_models)} detected the same number of objects ({max_count}). "
256
 
257
- analysis += "Faster R-CNN is typically faster and good for general detection. DETR excels in complex scenes with better context understanding. Mask R-CNN provides instance segmentation, which is useful for precise object boundaries but may be slower."
258
 
259
  # Add image-specific recommendation
260
  img_array = np.array(image)
@@ -262,27 +345,29 @@ def analyze_performance(image, model_choice, frcnn_threshold=0.5, detr_threshold
262
  pixel_variance = np.var(img_array)
263
 
264
  if height * width > 1000 * 1000:
265
- analysis += "\n\nThis is a high-resolution image. DETR and Mask R-CNN typically perform better on high-resolution images with complex scenes."
266
  if pixel_variance > 1000:
267
- analysis += "\n\nThis image has high contrast/complexity. DETR and Mask R-CNN may provide better context-aware detections."
268
  if height * width < 500 * 500:
269
  analysis += "\n\nFor smaller images, Faster R-CNN often provides good results at lower computational cost."
270
  if max_count > 0:
271
- analysis += "\n\nSince Mask R-CNN provides segmentation, it may be preferable if precise object boundaries are needed."
272
 
273
  elif model_choice == "Faster R-CNN":
274
  analysis = f"Faster R-CNN detected {frcnn_count} objects with a confidence threshold of {frcnn_threshold}."
275
  elif model_choice == "DETR":
276
  analysis = f"DETR detected {detr_count} objects with a confidence threshold of {detr_threshold}."
277
- else: # Mask R-CNN
278
  analysis = f"Mask R-CNN detected {maskrcnn_count} objects with a confidence threshold of {maskrcnn_threshold}. It also provides instance segmentation for precise object boundaries."
 
 
279
 
280
- return "Analysis complete!", frcnn_result, detr_result, maskrcnn_result, analysis
281
 
282
  # Create multi-step Gradio interface with a workflow
283
  with gr.Blocks(title="Object Detection Comparison") as app:
284
- gr.Markdown("# Object Detection: Faster R-CNN vs DETR vs Mask R-CNN")
285
- gr.Markdown("### Upload an image and compare three state-of-the-art object detection models")
286
 
287
  # State variables
288
  image_state = gr.State(None)
@@ -299,7 +384,7 @@ with gr.Blocks(title="Object Detection Comparison") as app:
299
  gr.Markdown("## Step 2: Question")
300
  gr.Markdown("Which model do you think will work better?")
301
  model_choice = gr.Radio(
302
- choices=["Faster R-CNN", "DETR", "Mask R-CNN", "All"],
303
  label="Select Object Detection Model(s)",
304
  value="All"
305
  )
@@ -315,6 +400,10 @@ with gr.Blocks(title="Object Detection Comparison") as app:
315
  minimum=0.0, maximum=1.0, value=0.5, step=0.05,
316
  label="Mask R-CNN Confidence Threshold"
317
  )
 
 
 
 
318
  detect_button = gr.Button("Run", variant="primary")
319
 
320
  # Step 3: Results display
@@ -327,11 +416,15 @@ with gr.Blocks(title="Object Detection Comparison") as app:
327
  with gr.Column():
328
  gr.Markdown("### DETR Result")
329
  detr_result = gr.Image(type="filepath", label="DETR")
 
330
  with gr.Column():
331
  gr.Markdown("### Mask R-CNN Result")
332
  maskrcnn_result = gr.Image(type="filepath", label="Mask R-CNN")
 
 
 
333
 
334
- analysis_output = gr.Textbox(label="Performance Analysis", lines=8)
335
  restart_button = gr.Button("Try Another Image", variant="secondary")
336
 
337
  # Upload button click event
@@ -349,8 +442,8 @@ with gr.Blocks(title="Object Detection Comparison") as app:
349
  # Detect button click event
350
  detect_button.click(
351
  fn=analyze_performance,
352
- inputs=[image_state, model_choice, frcnn_threshold, detr_threshold, maskrcnn_threshold],
353
- outputs=[gr.Textbox(visible=False), frcnn_result, detr_result, maskrcnn_result, analysis_output]
354
  ).then(
355
  fn=lambda: (gr.update(visible=True)),
356
  outputs=[results_panel]
 
2
  import torchvision
3
  from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
4
  from transformers import DetrImageProcessor, DetrForObjectDetection
5
+ from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
6
  from PIL import Image
7
  import numpy as np
8
  import matplotlib.pyplot as plt
 
26
  maskrcnn_model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
27
  maskrcnn_model.eval()
28
 
29
+ # Load Mask2Former model and processor
30
+ mask2former_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-small-coco-instance")
31
+ mask2former_model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-small-coco-instance")
32
+ mask2former_model.eval()
33
+
34
  # COCO class names for Faster R-CNN and Mask R-CNN
35
  COCO_INSTANCE_CATEGORY_NAMES = [
36
  '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
 
47
  'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
48
  ]
49
 
50
+ # Mask2Former label map
51
+ MASK2FORMER_COCO_NAMES = mask2former_model.config.id2label if hasattr(mask2former_model.config, "id2label") else {str(i): str(i) for i in range(133)}
52
+
53
  def detect_objects_frcnn(image, threshold=0.5):
54
  """Run Faster R-CNN detection."""
55
  if image is None:
 
195
  for i in range(len(masks)):
196
  if scores[i] >= threshold:
197
  mask = masks[i, 0].cpu().numpy()
198
+ mask = mask > 0.5
199
  color = np.random.rand(3)
200
  colored_mask = np.zeros_like(image_np, dtype=np.uint8)
201
  for c in range(3):
 
225
  plt.close()
226
  return error_path, 0
227
 
228
def _mask2former_label_name(label_id):
    """Return the display name for a Mask2Former class id.

    The label map may be keyed by int (transformers' ``config.id2label``) or
    by str (the module's fallback map), so both key forms are tried before
    falling back to the numeric id as text.
    """
    for key in (label_id, str(label_id)):
        if key in MASK2FORMER_COCO_NAMES:
            return MASK2FORMER_COCO_NAMES[key]
    return str(label_id)


def _save_mask2former_message(message, output_path, fontsize, wrap=False):
    """Render *message* centred on a blank white canvas, save it and return the path."""
    canvas = Image.new('RGB', (400, 400), color='white')
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(canvas)
    ax.text(0.5, 0.5, message, horizontalalignment='center', verticalalignment='center',
            transform=ax.transAxes, fontsize=fontsize, wrap=wrap)
    ax.axis('off')
    fig.savefig(output_path)
    plt.close(fig)
    return output_path


def detect_objects_mask2former(image, threshold=0.5):
    """Run Mask2Former detection and segmentation.

    Args:
        image: PIL image to analyse, or ``None``.
        threshold: Minimum instance score for a segment to be kept and drawn.

    Returns:
        A ``(output_path, count)`` tuple. ``output_path`` is the PNG written by
        this call: the annotated image on success, or a placeholder image when
        no image was given or an error occurred (``count`` is 0 in both of
        those cases). ``count`` is the number of instances actually drawn.
    """
    if image is None:
        return _save_mask2former_message(
            "No image provided", "mask2former_blank_output.png", fontsize=20
        ), 0

    fig = None
    try:
        image = image.convert('RGB')
        inputs = mask2former_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = mask2former_model(**inputs)

        # Forward the user threshold: the post-processor otherwise filters at
        # its own default (0.5), which made lower slider values ineffective.
        # PIL's size is (width, height); target_sizes expects (height, width).
        results = mask2former_processor.post_process_instance_segmentation(
            outputs, threshold=threshold, target_sizes=[image.size[::-1]]
        )[0]
        segmentation_map = results["segmentation"].cpu().numpy()
        segments_info = results["segments_info"]

        overlay = np.array(image).copy()
        fig, ax = plt.subplots(1, figsize=(10, 10))

        valid_detections = 0
        for segment in segments_info:
            score = segment.get("score", 1.0)
            if score < threshold:
                continue

            mask = segmentation_map == segment["id"]
            y_indices, x_indices = np.nonzero(mask)
            if x_indices.size == 0:
                # No pixels belong to this segment in the final map; nothing to draw.
                continue

            # Count only what is drawn so the reported number matches the image.
            valid_detections += 1

            color = np.random.rand(3)
            # Blend the instance colour 50/50 with the underlying pixels.
            overlay[mask] = (overlay[mask] * 0.5 + color * 255 * 0.5).astype(np.uint8)

            # Bounding box derived from the mask extent.
            x1, x2 = x_indices.min(), x_indices.max()
            y1, y2 = y_indices.min(), y_indices.max()

            label_name = _mask2former_label_name(segment["label_id"])
            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, color=color, linewidth=2))
            ax.text(x1, y1, f"{label_name}: {score:.2f}", bbox=dict(facecolor='yellow', alpha=0.5), fontsize=10)

        ax.imshow(overlay)
        ax.axis('off')
        output_path = "mask2former_output.png"
        fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
        return output_path, valid_detections
    except Exception as e:
        # UI boundary: surface the failure as an image instead of crashing the app.
        return _save_mask2former_message(
            f"Error: {str(e)}", "mask2former_error_output.png", fontsize=12, wrap=True
        ), 0
    finally:
        # Close the detection figure on both success and failure paths.
        if fig is not None:
            plt.close(fig)
296
+
297
+ def analyze_performance(image, model_choice, frcnn_threshold=0.5, detr_threshold=0.9, maskrcnn_threshold=0.5, mask2former_threshold=0.5):
298
  """Analyze and compare model performance."""
299
  if image is None:
300
+ return "Please upload an image first.", None, None, None, None, "No analysis available."
301
 
302
  frcnn_result = None
303
  detr_result = None
304
  maskrcnn_result = None
305
+ mask2former_result = None
306
  frcnn_count = 0
307
  detr_count = 0
308
  maskrcnn_count = 0
309
+ mask2former_count = 0
310
 
311
  if model_choice in ["Faster R-CNN", "All"]:
312
  frcnn_result, frcnn_count = detect_objects_frcnn(image, frcnn_threshold)
 
317
  if model_choice in ["Mask R-CNN", "All"]:
318
  maskrcnn_result, maskrcnn_count = detect_objects_maskrcnn(image, maskrcnn_threshold)
319
 
320
+ if model_choice in ["Mask2Former", "All"]:
321
+ mask2former_result, mask2former_count = detect_objects_mask2former(image, mask2former_threshold)
322
+
323
  # Compare and analyze performance
324
  analysis = ""
325
  if model_choice == "All":
 
326
  counts = {
327
  "Faster R-CNN": frcnn_count,
328
  "DETR": detr_count,
329
+ "Mask R-CNN": maskrcnn_count,
330
+ "Mask2Former": mask2former_count
331
  }
332
  max_count = max(counts.values())
333
  max_models = [model for model, count in counts.items() if count == max_count]
 
337
  else:
338
  analysis = f"{', '.join(max_models)} detected the same number of objects ({max_count}). "
339
 
340
+ analysis += "Faster R-CNN is typically faster and good for general detection. DETR excels in complex scenes with better context understanding. Mask R-CNN and Mask2Former provide instance segmentation for precise object boundaries, with Mask2Former leveraging a transformer-based architecture for potentially superior performance in complex scenes."
341
 
342
  # Add image-specific recommendation
343
  img_array = np.array(image)
 
345
  pixel_variance = np.var(img_array)
346
 
347
  if height * width > 1000 * 1000:
348
+ analysis += "\n\nThis is a high-resolution image. DETR and Mask2Former typically perform better on high-resolution images with complex scenes."
349
  if pixel_variance > 1000:
350
+ analysis += "\n\nThis image has high contrast/complexity. DETR and Mask2Former may provide better context-aware detections."
351
  if height * width < 500 * 500:
352
  analysis += "\n\nFor smaller images, Faster R-CNN often provides good results at lower computational cost."
353
  if max_count > 0:
354
+ analysis += "\n\nSince Mask R-CNN and Mask2Former provide segmentation, they may be preferable if precise object boundaries are needed, with Mask2Former potentially offering better performance due to its transformer-based design."
355
 
356
  elif model_choice == "Faster R-CNN":
357
  analysis = f"Faster R-CNN detected {frcnn_count} objects with a confidence threshold of {frcnn_threshold}."
358
  elif model_choice == "DETR":
359
  analysis = f"DETR detected {detr_count} objects with a confidence threshold of {detr_threshold}."
360
+ elif model_choice == "Mask R-CNN":
361
  analysis = f"Mask R-CNN detected {maskrcnn_count} objects with a confidence threshold of {maskrcnn_threshold}. It also provides instance segmentation for precise object boundaries."
362
+ else: # Mask2Former
363
+ analysis = f"Mask2Former detected {mask2former_count} objects with a confidence threshold of {mask2former_threshold}. It provides instance segmentation with a transformer-based architecture, potentially offering superior performance in complex scenes."
364
 
365
+ return "Analysis complete!", frcnn_result, detr_result, maskrcnn_result, mask2former_result, analysis
366
 
367
  # Create multi-step Gradio interface with a workflow
368
  with gr.Blocks(title="Object Detection Comparison") as app:
369
+ gr.Markdown("# Object Detection: Faster R-CNN vs DETR vs Mask R-CNN vs Mask2Former")
370
+ gr.Markdown("### Upload an image and compare four state-of-the-art object detection models")
371
 
372
  # State variables
373
  image_state = gr.State(None)
 
384
  gr.Markdown("## Step 2: Question")
385
  gr.Markdown("Which model do you think will work better?")
386
  model_choice = gr.Radio(
387
+ choices=["Faster R-CNN", "DETR", "Mask R-CNN", "Mask2Former", "All"],
388
  label="Select Object Detection Model(s)",
389
  value="All"
390
  )
 
400
  minimum=0.0, maximum=1.0, value=0.5, step=0.05,
401
  label="Mask R-CNN Confidence Threshold"
402
  )
403
+ mask2former_threshold = gr.Slider(
404
+ minimum=0.0, maximum=1.0, value=0.5, step=0.05,
405
+ label="Mask2Former Confidence Threshold"
406
+ )
407
  detect_button = gr.Button("Run", variant="primary")
408
 
409
  # Step 3: Results display
 
416
  with gr.Column():
417
  gr.Markdown("### DETR Result")
418
  detr_result = gr.Image(type="filepath", label="DETR")
419
+ with gr.Row():
420
  with gr.Column():
421
  gr.Markdown("### Mask R-CNN Result")
422
  maskrcnn_result = gr.Image(type="filepath", label="Mask R-CNN")
423
+ with gr.Column():
424
+ gr.Markdown("### Mask2Former Result")
425
+ mask2former_result = gr.Image(type="filepath", label="Mask2Former")
426
 
427
+ analysis_output = gr.Textbox(label="Performance Analysis", lines=10)
428
  restart_button = gr.Button("Try Another Image", variant="secondary")
429
 
430
  # Upload button click event
 
442
  # Detect button click event
443
  detect_button.click(
444
  fn=analyze_performance,
445
+ inputs=[image_state, model_choice, frcnn_threshold, detr_threshold, maskrcnn_threshold, mask2former_threshold],
446
+ outputs=[gr.Textbox(visible=False), frcnn_result, detr_result, maskrcnn_result, mask2former_result, analysis_output]
447
  ).then(
448
  fn=lambda: (gr.update(visible=True)),
449
  outputs=[results_panel]