JohnJoelMota's picture
Update app.py
453991c verified
raw
history blame
21.5 kB
import torch
import torchvision
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
from transformers import DetrImageProcessor, DetrForObjectDetection
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import gradio as gr
import os
import io
# --- Model loading -----------------------------------------------------------
# Everything below runs once at import time and fetches pretrained weights,
# so the first start needs network access (or a warm cache).

# Load Faster R-CNN model with proper weight assignment.
# The architecture is built without detection weights, then the COCO checkpoint
# referenced by the weights enum is loaded explicitly onto the CPU.
frcnn_weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
frcnn_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None, progress=True)
state_dict = torch.hub.load_state_dict_from_url(
    frcnn_weights.url, progress=True, map_location=torch.device('cpu')
)
# NOTE(review): strict=False tolerates missing/unexpected checkpoint keys
# without raising -- confirm this leniency is intentional.
frcnn_model.load_state_dict(state_dict, strict=False)
frcnn_model.eval()

# Load DETR model and processor
detr_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
detr_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

# Load Mask R-CNN model.
# Uses the `weights=` API (as the Faster R-CNN code above already does) instead
# of the deprecated `pretrained=True` flag.
maskrcnn_model = torchvision.models.detection.maskrcnn_resnet50_fpn(
    weights=torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights.DEFAULT
)
maskrcnn_model.eval()

# Load Mask2Former model and processor
mask2former_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-small-coco-instance")
mask2former_model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-small-coco-instance")
mask2former_model.eval()
# COCO class names for Faster R-CNN and Mask R-CNN.
# The list is indexed directly by the integer label the detectors return;
# 'N/A' entries are placeholders that keep the remaining names aligned with
# their label ids. The trailing comment on each row gives the index range.
COCO_INSTANCE_CATEGORY_NAMES = [
    "__background__", "person", "bicycle", "car", "motorcycle",          # 0-4
    "airplane", "bus", "train", "truck", "boat",                         # 5-9
    "traffic light", "fire hydrant", "N/A", "stop sign", "parking meter",  # 10-14
    "bench", "bird", "cat", "dog", "horse",                              # 15-19
    "sheep", "cow", "elephant", "bear", "zebra",                         # 20-24
    "giraffe", "N/A", "backpack", "umbrella", "N/A",                     # 25-29
    "N/A", "handbag", "tie", "suitcase", "frisbee",                      # 30-34
    "skis", "snowboard", "sports ball", "kite", "baseball bat",          # 35-39
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",  # 40-44
    "N/A", "wine glass", "cup", "fork", "knife",                         # 45-49
    "spoon", "bowl", "banana", "apple", "sandwich",                      # 50-54
    "orange", "broccoli", "carrot", "hot dog", "pizza",                  # 55-59
    "donut", "cake", "chair", "couch", "potted plant",                   # 60-64
    "bed", "N/A", "dining table", "N/A", "N/A",                          # 65-69
    "toilet", "N/A", "tv", "laptop", "mouse",                            # 70-74
    "remote", "keyboard", "cell phone", "microwave", "oven",             # 75-79
    "toaster", "sink", "refrigerator", "N/A", "book",                    # 80-84
    "clock", "vase", "scissors", "teddy bear", "hair drier",             # 85-89
    "toothbrush",                                                        # 90
]
# Mask2Former label map, keyed by the *string* form of the label id.
# The Hugging Face config stores id2label with integer keys, while the lookup
# in detect_objects_mask2former uses str(label_id); normalising the keys here
# makes that lookup hit and keeps the key type identical to the fallback map.
if hasattr(mask2former_model.config, "id2label"):
    MASK2FORMER_COCO_NAMES = {
        str(class_id): class_name
        for class_id, class_name in mask2former_model.config.id2label.items()
    }
else:
    MASK2FORMER_COCO_NAMES = {str(class_id): str(class_id) for class_id in range(133)}
def detect_objects_frcnn(image, threshold=0.5):
    """Run Faster R-CNN detection and save an annotated copy of the image.

    Args:
        image: PIL image to analyse, or None when nothing was provided.
        threshold: Minimum confidence score for a detection to be drawn and
            counted. None falls back to 0.5.

    Returns:
        A tuple ``(path, count)``: the path of the PNG that was written and
        the number of detections scoring at least ``threshold``. When
        ``image`` is None or anything fails, the PNG is a white placeholder
        carrying a message and ``count`` is 0.
    """

    def save_message_figure(message, output_path, **text_kwargs):
        # Render a white 400x400 canvas with a centred message and save it.
        msg_fig, msg_ax = plt.subplots(1, figsize=(10, 10))
        try:
            msg_ax.imshow(Image.new('RGB', (400, 400), color='white'))
            msg_ax.text(0.5, 0.5, message, horizontalalignment='center', verticalalignment='center',
                        transform=msg_ax.transAxes, **text_kwargs)
            msg_ax.axis('off')
            msg_fig.savefig(output_path)
        finally:
            plt.close(msg_fig)
        return output_path

    if image is None:
        return save_message_figure("No image provided", "frcnn_blank_output.png", fontsize=20), 0

    fig = None
    try:
        threshold = float(threshold) if threshold is not None else 0.5
        image = image.convert('RGB')

        # The transform is fed the PIL image directly, exactly as before, but
        # without the uint8 -> float32 -> uint8 detour, which was redundant
        # and could truncate pixel values when casting back.
        transform = frcnn_weights.transforms()
        image_tensor = transform(image).unsqueeze(0)

        with torch.no_grad():
            prediction = frcnn_model(image_tensor)[0]
        boxes = prediction['boxes'].cpu().numpy()
        labels = prediction['labels'].cpu().numpy()
        scores = prediction['scores'].cpu().numpy()

        keep = scores >= threshold
        valid_detections = int(keep.sum())

        fig, ax = plt.subplots(1, figsize=(10, 10))
        ax.imshow(np.array(image))
        for box, label, score in zip(boxes[keep], labels[keep], scores[keep]):
            x1, y1, x2, y2 = box
            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, color='red', linewidth=2))
            class_name = COCO_INSTANCE_CATEGORY_NAMES[label]
            ax.text(x1, y1, f'{class_name}: {score:.2f}', bbox=dict(facecolor='yellow', alpha=0.5),
                    fontsize=12, color='black')
        ax.axis('off')
        fig.tight_layout()
        output_path = "frcnn_output.png"
        fig.savefig(output_path)
        return output_path, valid_detections
    except Exception as e:
        # Best-effort UI behaviour: show the failure as an image instead of raising.
        return save_message_figure(f"Error: {str(e)}", "frcnn_error_output.png", fontsize=12, wrap=True), 0
    finally:
        # Close the result figure even when drawing/saving failed part-way,
        # otherwise it stays registered with pyplot.
        if fig is not None:
            plt.close(fig)
def detect_objects_detr(image, threshold=0.9):
    """Run DETR detection and save an annotated copy of the image.

    Args:
        image: PIL image to analyse, or None when nothing was provided.
        threshold: Minimum confidence score passed to the processor's
            post-processing step. None falls back to 0.9.

    Returns:
        A tuple ``(path, count)``: the path of the PNG that was written and
        the number of detections kept by post-processing. When ``image`` is
        None or anything fails, the PNG is a white placeholder carrying a
        message and ``count`` is 0.
    """

    def save_message_figure(message, output_path, **text_kwargs):
        # Render a white 400x400 canvas with a centred message and save it.
        msg_fig, msg_ax = plt.subplots(1, figsize=(10, 10))
        try:
            msg_ax.imshow(Image.new('RGB', (400, 400), color='white'))
            msg_ax.text(0.5, 0.5, message, horizontalalignment='center', verticalalignment='center',
                        transform=msg_ax.transAxes, **text_kwargs)
            msg_ax.axis('off')
            msg_fig.savefig(output_path)
        finally:
            plt.close(msg_fig)
        return output_path

    if image is None:
        return save_message_figure("No image provided", "detr_blank_output.png", fontsize=20), 0

    fig = None
    try:
        threshold = float(threshold) if threshold is not None else 0.9
        image = image.convert('RGB')
        inputs = detr_processor(images=image, return_tensors="pt")

        # Inference only: skip autograd bookkeeping, as the other detectors do.
        with torch.no_grad():
            outputs = detr_model(**inputs)

        # PIL reports (width, height); the processor expects (height, width).
        target_sizes = torch.tensor([image.size[::-1]])
        results = detr_processor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=threshold
        )[0]
        valid_detections = len(results["scores"])

        fig, ax = plt.subplots(1, figsize=(10, 10))
        ax.imshow(image)
        for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
            xmin, ymin, xmax, ymax = box.tolist()
            ax.add_patch(patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                           linewidth=2, edgecolor='red', facecolor='none'))
            ax.text(xmin, ymin, f"{detr_model.config.id2label[label.item()]}: {round(score.item(), 2)}",
                    bbox=dict(facecolor='yellow', alpha=0.5), fontsize=8)
        ax.axis('off')
        output_path = "detr_output.png"
        fig.savefig(output_path)
        return output_path, valid_detections
    except Exception as e:
        # Best-effort UI behaviour: show the failure as an image instead of raising.
        return save_message_figure(f"Error: {str(e)}", "detr_error_output.png", fontsize=12, wrap=True), 0
    finally:
        # Close the result figure even when drawing/saving failed part-way.
        if fig is not None:
            plt.close(fig)
def detect_objects_maskrcnn(image, threshold=0.5):
    """Run Mask R-CNN detection and segmentation and save an annotated image.

    Every instance scoring at least ``threshold`` is tinted with a random
    colour (50/50 blend with the image) and outlined with its bounding box
    and a ``label: score`` caption.

    Args:
        image: PIL image to analyse, or None when nothing was provided.
        threshold: Minimum confidence score for an instance to be drawn and
            counted. None falls back to 0.5.

    Returns:
        A tuple ``(path, count)``: the path of the PNG that was written and
        the number of instances scoring at least ``threshold``. When
        ``image`` is None or anything fails, the PNG is a white placeholder
        carrying a message and ``count`` is 0.
    """

    def save_message_figure(message, output_path, **text_kwargs):
        # Render a white 400x400 canvas with a centred message and save it.
        msg_fig, msg_ax = plt.subplots(1, figsize=(10, 10))
        try:
            msg_ax.imshow(Image.new('RGB', (400, 400), color='white'))
            msg_ax.text(0.5, 0.5, message, horizontalalignment='center', verticalalignment='center',
                        transform=msg_ax.transAxes, **text_kwargs)
            msg_ax.axis('off')
            msg_fig.savefig(output_path)
        finally:
            plt.close(msg_fig)
        return output_path

    if image is None:
        return save_message_figure("No image provided", "maskrcnn_blank_output.png", fontsize=20), 0

    fig = None
    try:
        threshold = float(threshold) if threshold is not None else 0.5
        image = image.convert('RGB')
        img_tensor = torchvision.transforms.ToTensor()(image).unsqueeze(0)

        with torch.no_grad():
            output = maskrcnn_model(img_tensor)[0]
        masks = output['masks']
        boxes = output['boxes'].cpu().numpy()
        labels = output['labels'].cpu().numpy()
        scores = output['scores'].cpu().numpy()

        kept_indices = np.flatnonzero(scores >= threshold)
        valid_detections = len(kept_indices)

        # First pass: tint the pixels of every kept instance and remember
        # what has to be drawn on top afterwards.
        blended = np.array(image).copy()
        annotations = []
        for i in kept_indices:
            # Mask values above 0.5 are treated as belonging to the instance.
            mask = masks[i, 0].cpu().numpy() > 0.5
            color = np.random.rand(3)
            tint = (color * 255).astype(np.uint8)
            blended[mask] = (0.5 * blended[mask] + 0.5 * tint).astype(np.uint8)
            annotations.append((boxes[i], color, COCO_INSTANCE_CATEGORY_NAMES[labels[i]], scores[i]))

        # Second pass: draw the blended image once, then boxes and captions.
        fig, ax = plt.subplots(1, figsize=(10, 10))
        ax.imshow(blended)
        for (x1, y1, x2, y2), color, label, score in annotations:
            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, color=color, linewidth=2))
            ax.text(x1, y1, f"{label}: {score:.2f}", bbox=dict(facecolor='yellow', alpha=0.5), fontsize=10)
        ax.axis('off')
        output_path = "maskrcnn_output.png"
        fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
        return output_path, valid_detections
    except Exception as e:
        # Best-effort UI behaviour: show the failure as an image instead of raising.
        return save_message_figure(f"Error: {str(e)}", "maskrcnn_error_output.png", fontsize=12, wrap=True), 0
    finally:
        # Close the result figure even when drawing/saving failed part-way.
        if fig is not None:
            plt.close(fig)
def detect_objects_mask2former(image, threshold=0.5):
    """Run Mask2Former instance segmentation and save an annotated image.

    Every segment scoring at least ``threshold`` is tinted with a random
    colour (50/50 blend with the image) and outlined with the bounding box of
    its mask and a ``label: score`` caption.

    Args:
        image: PIL image to analyse, or None when nothing was provided.
        threshold: Minimum confidence score for a segment to be kept. It is
            forwarded to the processor's post-processing step so that values
            below that step's own default also take effect. None falls back
            to 0.5.

    Returns:
        A tuple ``(path, count)``: the path of the PNG that was written and
        the number of segments scoring at least ``threshold``. When ``image``
        is None or anything fails, the PNG is a white placeholder carrying a
        message and ``count`` is 0.
    """

    def save_message_figure(message, output_path, **text_kwargs):
        # Render a white 400x400 canvas with a centred message and save it.
        msg_fig, msg_ax = plt.subplots(1, figsize=(10, 10))
        try:
            msg_ax.imshow(Image.new('RGB', (400, 400), color='white'))
            msg_ax.text(0.5, 0.5, message, horizontalalignment='center', verticalalignment='center',
                        transform=msg_ax.transAxes, **text_kwargs)
            msg_ax.axis('off')
            msg_fig.savefig(output_path)
        finally:
            plt.close(msg_fig)
        return output_path

    if image is None:
        return save_message_figure("No image provided", "mask2former_blank_output.png", fontsize=20), 0

    fig = None
    try:
        threshold = float(threshold) if threshold is not None else 0.5
        image = image.convert('RGB')
        inputs = mask2former_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = mask2former_model(**inputs)

        # PIL reports (width, height); the processor expects (height, width).
        results = mask2former_processor.post_process_instance_segmentation(
            outputs, threshold=threshold, target_sizes=[image.size[::-1]]
        )[0]
        segmentation_map = results["segmentation"].cpu().numpy()
        segments_info = results["segments_info"]

        kept_segments = [seg for seg in segments_info if seg.get("score", 1.0) >= threshold]
        valid_detections = len(kept_segments)

        # First pass: tint each kept segment and remember its box and caption.
        overlay = np.array(image).copy()
        annotations = []
        for segment in kept_segments:
            score = segment.get("score", 1.0)
            # label_id may arrive as a plain int or a 0-dim tensor.
            label_id = int(segment["label_id"])
            mask = segmentation_map == segment["id"]
            color = np.random.rand(3)
            overlay[mask] = (overlay[mask] * 0.5 + np.array(color) * 255 * 0.5).astype(np.uint8)

            y_indices, x_indices = np.where(mask)
            if len(x_indices) == 0 or len(y_indices) == 0:
                # The segment owns no pixels in the final map: nothing to outline.
                continue
            x1, x2 = x_indices.min(), x_indices.max()
            y1, y2 = y_indices.min(), y_indices.max()

            # The label map may be keyed by int (model config) or by str
            # (fallback map); try both before showing the raw id.
            label_name = MASK2FORMER_COCO_NAMES.get(label_id)
            if label_name is None:
                label_name = MASK2FORMER_COCO_NAMES.get(str(label_id), str(label_id))
            annotations.append(((x1, y1, x2, y2), color, label_name, score))

        # Second pass: draw the tinted image once, then boxes and captions.
        fig, ax = plt.subplots(1, figsize=(10, 10))
        ax.imshow(overlay)
        for (x1, y1, x2, y2), color, label_name, score in annotations:
            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, color=color, linewidth=2))
            ax.text(x1, y1, f"{label_name}: {score:.2f}", bbox=dict(facecolor='yellow', alpha=0.5), fontsize=10)
        ax.axis('off')
        output_path = "mask2former_output.png"
        fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
        return output_path, valid_detections
    except Exception as e:
        # Best-effort UI behaviour: show the failure as an image instead of raising.
        return save_message_figure(f"Error: {str(e)}", "mask2former_error_output.png", fontsize=12, wrap=True), 0
    finally:
        # Close the result figure even when drawing/saving failed part-way.
        if fig is not None:
            plt.close(fig)
def analyze_performance(image, model_choice, frcnn_threshold=0.5, detr_threshold=0.9, maskrcnn_threshold=0.5, mask2former_threshold=0.5):
    """Run the selected detector(s) on an image and describe the outcome.

    Args:
        image: PIL image to analyse, or None when nothing was uploaded.
        model_choice: "Faster R-CNN", "DETR", "Mask R-CNN", "Mask2Former" or
            "All". Any other value runs no detector and produces the
            Mask2Former summary text.
        frcnn_threshold: Confidence threshold handed to Faster R-CNN.
        detr_threshold: Confidence threshold handed to DETR.
        maskrcnn_threshold: Confidence threshold handed to Mask R-CNN.
        mask2former_threshold: Confidence threshold handed to Mask2Former.

    Returns:
        A 6-tuple ``(status, frcnn_path, detr_path, maskrcnn_path,
        mask2former_path, analysis)``. Paths of detectors that were not run
        are None.
    """
    if image is None:
        return "Please upload an image first.", None, None, None, None, "No analysis available."

    compare_all = model_choice == "All"

    # (display name, detector, threshold), in the order results are returned.
    detectors = (
        ("Faster R-CNN", detect_objects_frcnn, frcnn_threshold),
        ("DETR", detect_objects_detr, detr_threshold),
        ("Mask R-CNN", detect_objects_maskrcnn, maskrcnn_threshold),
        ("Mask2Former", detect_objects_mask2former, mask2former_threshold),
    )
    rendered = {}
    counts = {}
    for model_name, detector, cutoff in detectors:
        if compare_all or model_choice == model_name:
            rendered[model_name], counts[model_name] = detector(image, cutoff)
        else:
            rendered[model_name], counts[model_name] = None, 0

    def finish(text):
        # Assemble the fixed-shape tuple the UI callback expects.
        return (
            "Analysis complete!",
            rendered["Faster R-CNN"],
            rendered["DETR"],
            rendered["Mask R-CNN"],
            rendered["Mask2Former"],
            text,
        )

    # Single-model summaries.
    if not compare_all:
        if model_choice == "Faster R-CNN":
            n_found = counts["Faster R-CNN"]
            return finish(f"Faster R-CNN detected {n_found} objects with a confidence threshold of {frcnn_threshold}.")
        if model_choice == "DETR":
            n_found = counts["DETR"]
            return finish(f"DETR detected {n_found} objects with a confidence threshold of {detr_threshold}.")
        if model_choice == "Mask R-CNN":
            n_found = counts["Mask R-CNN"]
            return finish(f"Mask R-CNN detected {n_found} objects with a confidence threshold of {maskrcnn_threshold}. It also provides instance segmentation for precise object boundaries.")
        # Mask2Former (also the fallback for any unrecognised choice)
        n_found = counts["Mask2Former"]
        return finish(f"Mask2Former detected {n_found} objects with a confidence threshold of {mask2former_threshold}. It provides instance segmentation with a transformer-based architecture, potentially offering superior performance in complex scenes.")

    # Comparison across all four models.
    best = max(counts.values())
    leaders = [name for name, found in counts.items() if found == best]
    if len(leaders) == 1:
        headline = f"{leaders[0]} detected the most objects ({best}). "
    else:
        headline = f"{', '.join(leaders)} detected the same number of objects ({best}). "
    pieces = [
        headline,
        "Faster R-CNN is typically faster and good for general detection. DETR excels in complex scenes with better context understanding. Mask R-CNN and Mask2Former provide instance segmentation for precise object boundaries, with Mask2Former leveraging a transformer-based architecture for potentially superior performance in complex scenes.",
    ]

    # Image-specific recommendations based on size and pixel variance.
    pixels = np.array(image)
    height, width = pixels.shape[:2]
    area = height * width
    variance = np.var(pixels)
    if area > 1000 * 1000:
        pieces.append("\n\nThis is a high-resolution image. DETR and Mask2Former typically perform better on high-resolution images with complex scenes.")
    if variance > 1000:
        pieces.append("\n\nThis image has high contrast/complexity. DETR and Mask2Former may provide better context-aware detections.")
    if area < 500 * 500:
        pieces.append("\n\nFor smaller images, Faster R-CNN often provides good results at lower computational cost.")
    if best > 0:
        pieces.append("\n\nSince Mask R-CNN and Mask2Former provide segmentation, they may be preferable if precise object boundaries are needed, with Mask2Former potentially offering better performance due to its transformer-based design.")

    return finish("".join(pieces))
# Create multi-step Gradio interface with a workflow:
#   1. upload an image  ->  2. pick model(s) and thresholds  ->  3. view results
with gr.Blocks(title="Object Detection Comparison") as app:
    gr.Markdown("# Object Detection: Faster R-CNN vs DETR vs Mask R-CNN vs Mask2Former")
    gr.Markdown("### Upload an image and compare four state-of-the-art object detection models")

    # State variables: keeps the uploaded image between steps 1 and 2
    image_state = gr.State(None)

    with gr.Row():
        with gr.Column():
            # Step 1: Image upload
            gr.Markdown("## Step 1: Upload an image")
            image_input = gr.Image(type="pil", label="Upload Image")
            upload_button = gr.Button("Upload Image", variant="primary")

            # Step 2: Detection configuration (revealed after a successful upload)
            with gr.Group(visible=False) as detection_config:
                gr.Markdown("## Step 2: Question")
                gr.Markdown("Which model do you think will work better?")
                model_choice = gr.Radio(
                    choices=["Faster R-CNN", "DETR", "Mask R-CNN", "Mask2Former", "All"],
                    label="Select Object Detection Model(s)",
                    value="All"
                )
                frcnn_threshold = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.5, step=0.05,
                    label="Faster R-CNN Confidence Threshold"
                )
                detr_threshold = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.9, step=0.05,
                    label="DETR Confidence Threshold"
                )
                maskrcnn_threshold = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.5, step=0.05,
                    label="Mask R-CNN Confidence Threshold"
                )
                mask2former_threshold = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.5, step=0.05,
                    label="Mask2Former Confidence Threshold"
                )
                detect_button = gr.Button("Run", variant="primary")

        # Step 3: Results display (revealed after detection has run)
        with gr.Column(visible=False) as results_panel:
            gr.Markdown("## Step 3: Results")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Faster R-CNN Result")
                    frcnn_result = gr.Image(type="filepath", label="Faster R-CNN")
                with gr.Column():
                    gr.Markdown("### DETR Result")
                    detr_result = gr.Image(type="filepath", label="DETR")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Mask R-CNN Result")
                    maskrcnn_result = gr.Image(type="filepath", label="Mask R-CNN")
                with gr.Column():
                    gr.Markdown("### Mask2Former Result")
                    mask2former_result = gr.Image(type="filepath", label="Mask2Former")
            analysis_output = gr.Textbox(label="Performance Analysis", lines=10)
            restart_button = gr.Button("Try Another Image", variant="secondary")

    # Upload button click event
    def upload_image(img):
        """Store the uploaded image and reveal step 2; keep everything hidden if no image was given."""
        if img is None:
            return None, gr.update(visible=False), gr.update(visible=False)
        return img, gr.update(visible=True), gr.update(visible=False)

    upload_button.click(
        fn=upload_image,
        inputs=[image_input],
        outputs=[image_state, detection_config, results_panel]
    )

    # analyze_performance returns a status string first; it is routed to a
    # named hidden component rather than one created inside the outputs list.
    status_output = gr.Textbox(visible=False)

    # Detect button click event: run the analysis, then reveal the results panel
    detect_button.click(
        fn=analyze_performance,
        inputs=[image_state, model_choice, frcnn_threshold, detr_threshold, maskrcnn_threshold, mask2former_threshold],
        outputs=[status_output, frcnn_result, detr_result, maskrcnn_result, mask2former_result, analysis_output]
    ).then(
        fn=lambda: (gr.update(visible=True)),
        outputs=[results_panel]
    )

    # Restart button click event: go back to step 1. The stored image is
    # dropped, the results and configuration panels are hidden, and the upload
    # widget is cleared and kept visible so another image can be chosen.
    restart_button.click(
        fn=lambda: (
            None,
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(value=None, visible=True),
        ),
        outputs=[image_state, results_panel, detection_config, image_input]
    )

    # Example images: only offered when the files exist in the working directory
    example_images = [
        os.path.join(os.getcwd(), "TEST_IMG_1.jpg"),
        os.path.join(os.getcwd(), "TEST_IMG_2.JPG"),
        os.path.join(os.getcwd(), "TEST_IMG_3.jpg"),
        os.path.join(os.getcwd(), "TEST_IMG_4.jpg")
    ]
    valid_examples = [img for img in example_images if os.path.exists(img)]
    if valid_examples:
        gr.Examples(
            examples=valid_examples,
            inputs=image_input,
        )

if __name__ == "__main__":
    app.launch(debug=True)