Spaces:
Sleeping
Sleeping
| import torch | |
| import torchvision | |
| from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights | |
| from transformers import DetrImageProcessor, DetrForObjectDetection | |
| from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation | |
| from PIL import Image | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as patches | |
| import gradio as gr | |
| import os | |
| import io | |
# Load all four models once at import time and put each in inference mode.

# Faster R-CNN: hand the weights enum to the builder so torchvision downloads
# the COCO checkpoint and loads it with strict key matching. The previous
# weights=None + load_state_dict(strict=False) route would silently leave any
# mismatched parameters at their random initialisation.
frcnn_weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
frcnn_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=frcnn_weights, progress=True)
frcnn_model.eval()

# DETR model and its matching pre/post-processor.
detr_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
detr_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
detr_model.eval()  # explicit, like the other three models

# Mask R-CNN: use the weights enum; the boolean `pretrained=` flag is
# deprecated in torchvision's multi-weight API.
maskrcnn_weights = torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights.DEFAULT
maskrcnn_model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=maskrcnn_weights)
maskrcnn_model.eval()

# Mask2Former (checkpoint: facebook/mask2former-swin-small-coco-instance) and its processor.
mask2former_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-small-coco-instance")
mask2former_model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-small-coco-instance")
mask2former_model.eval()
# COCO category names indexed by the integer labels that the torchvision
# Faster R-CNN and Mask R-CNN models return. Index 0 is the background class;
# the 'N/A' placeholders keep every other name aligned with its label id.
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

# Mask2Former label map keyed by the *string* form of each label id, so that
# lookups via str(label_id) succeed. transformers configs hold id2label with
# int keys, so the keys are normalised here instead of being used as-is.
if hasattr(mask2former_model.config, "id2label"):
    MASK2FORMER_COCO_NAMES = {str(label_id): name for label_id, name in mask2former_model.config.id2label.items()}
else:
    MASK2FORMER_COCO_NAMES = {str(i): str(i) for i in range(133)}
def detect_objects_frcnn(image, threshold=0.5):
    """Run Faster R-CNN on ``image`` and save an annotated rendering.

    Args:
        image: PIL image to analyse, or ``None``.
        threshold: Minimum confidence for a detection to be drawn and counted.
            Coerced with ``float()``; ``None`` means 0.5.

    Returns:
        ``(path, count)``: the PNG written to the working directory and the
        number of detections with score >= ``threshold``. If ``image`` is
        ``None`` or inference fails, a placeholder image carrying a message is
        written instead and ``count`` is 0.
    """
    if image is None:
        blank_img = Image.new('RGB', (400, 400), color='white')
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(blank_img)
        ax.text(0.5, 0.5, "No image provided", horizontalalignment='center', verticalalignment='center',
                transform=ax.transAxes, fontsize=20)
        ax.axis('off')
        output_path = "frcnn_blank_output.png"
        fig.savefig(output_path)
        # Close this specific figure: a bare plt.close() targets pyplot's
        # "current" figure, which may belong to another call if requests overlap.
        plt.close(fig)
        return output_path, 0
    try:
        threshold = float(threshold) if threshold is not None else 0.5
        image = image.convert('RGB')
        # The weights' preset transform takes the PIL image directly; the old
        # uint8 -> float32 -> uint8 round trip could truncate pixel values.
        image_tensor = frcnn_weights.transforms()(image).unsqueeze(0)
        with torch.no_grad():
            prediction = frcnn_model(image_tensor)[0]
        boxes = prediction['boxes'].cpu().numpy()
        labels = prediction['labels'].cpu().numpy()
        scores = prediction['scores'].cpu().numpy()
        keep = scores >= threshold
        valid_detections = int(keep.sum())

        fig, ax = plt.subplots(figsize=(10, 10))
        try:
            ax.imshow(np.array(image))
            for (x1, y1, x2, y2), label, score in zip(boxes[keep], labels[keep], scores[keep]):
                ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, color='red', linewidth=2))
                class_name = COCO_INSTANCE_CATEGORY_NAMES[label]
                ax.text(x1, y1, f'{class_name}: {score:.2f}', bbox=dict(facecolor='yellow', alpha=0.5),
                        fontsize=12, color='black')
            ax.axis('off')
            fig.tight_layout()
            output_path = "frcnn_output.png"
            fig.savefig(output_path)
        finally:
            # Release the figure even if drawing fails so errors do not leak figures.
            plt.close(fig)
        return output_path, valid_detections
    except Exception as e:
        error_img = Image.new('RGB', (400, 400), color='white')
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(error_img)
        ax.text(0.5, 0.5, f"Error: {str(e)}", horizontalalignment='center', verticalalignment='center',
                transform=ax.transAxes, fontsize=12, wrap=True)
        ax.axis('off')
        error_path = "frcnn_error_output.png"
        fig.savefig(error_path)
        plt.close(fig)
        return error_path, 0
def detect_objects_detr(image, threshold=0.9):
    """Run DETR on ``image`` and save an annotated rendering.

    Args:
        image: PIL image to analyse, or ``None``.
        threshold: Confidence cut-off handed to the processor's
            post-processing. Coerced with ``float()``; ``None`` means 0.9.

    Returns:
        ``(path, count)``: the PNG written to the working directory and the
        number of detections kept by post-processing. If ``image`` is ``None``
        or inference fails, a placeholder image carrying a message is written
        instead and ``count`` is 0.
    """
    if image is None:
        blank_img = Image.new('RGB', (400, 400), color='white')
        fig, ax = plt.subplots(1, figsize=(10, 10))
        ax.imshow(blank_img)
        ax.text(0.5, 0.5, "No image provided", horizontalalignment='center', verticalalignment='center',
                transform=ax.transAxes, fontsize=20)
        ax.axis('off')
        output_path = "detr_blank_output.png"
        fig.savefig(output_path)
        plt.close(fig)
        return output_path, 0
    try:
        threshold = float(threshold) if threshold is not None else 0.9
        image = image.convert('RGB')
        inputs = detr_processor(images=image, return_tensors="pt")
        # Inference only: without no_grad the forward pass records an autograd graph.
        with torch.no_grad():
            outputs = detr_model(**inputs)
        # PIL's size is (width, height); reverse it to (height, width).
        target_sizes = torch.tensor([image.size[::-1]])
        results = detr_processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=threshold)[0]
        valid_detections = len(results["scores"])

        fig, ax = plt.subplots(1, figsize=(10, 10))
        try:
            ax.imshow(image)
            for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
                xmin, ymin, xmax, ymax = box.tolist()
                ax.add_patch(patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, linewidth=2,
                                               edgecolor='red', facecolor='none'))
                ax.text(xmin, ymin, f"{detr_model.config.id2label[label.item()]}: {round(score.item(), 2)}",
                        bbox=dict(facecolor='yellow', alpha=0.5), fontsize=8)
            ax.axis('off')
            output_path = "detr_output.png"
            fig.savefig(output_path)
        finally:
            # Release the figure even if drawing fails so errors do not leak figures.
            plt.close(fig)
        return output_path, valid_detections
    except Exception as e:
        error_img = Image.new('RGB', (400, 400), color='white')
        fig, ax = plt.subplots(1, figsize=(10, 10))
        ax.imshow(error_img)
        ax.text(0.5, 0.5, f"Error: {str(e)}", horizontalalignment='center', verticalalignment='center',
                transform=ax.transAxes, fontsize=12, wrap=True)
        ax.axis('off')
        error_path = "detr_error_output.png"
        fig.savefig(error_path)
        plt.close(fig)
        return error_path, 0
def detect_objects_maskrcnn(image, threshold=0.5):
    """Run Mask R-CNN and save a rendering with tinted instance masks and boxes.

    Args:
        image: PIL image to analyse, or ``None``.
        threshold: Minimum confidence for an instance to be drawn and counted.
            Coerced with ``float()``; ``None`` means 0.5.

    Returns:
        ``(path, count)``: the PNG written to the working directory and the
        number of instances with score >= ``threshold``. If ``image`` is
        ``None`` or inference fails, a placeholder image carrying a message is
        written instead and ``count`` is 0.
    """
    if image is None:
        blank_img = Image.new('RGB', (400, 400), color='white')
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(blank_img)
        ax.text(0.5, 0.5, "No image provided", horizontalalignment='center', verticalalignment='center',
                transform=ax.transAxes, fontsize=20)
        ax.axis('off')
        output_path = "maskrcnn_blank_output.png"
        fig.savefig(output_path)
        plt.close(fig)
        return output_path, 0
    try:
        threshold = float(threshold) if threshold is not None else 0.5
        image = image.convert('RGB')
        img_tensor = torchvision.transforms.ToTensor()(image).unsqueeze(0)
        with torch.no_grad():
            output = maskrcnn_model(img_tensor)[0]
        masks = output['masks']
        boxes = output['boxes'].cpu().numpy()
        labels = output['labels'].cpu().numpy()
        scores = output['scores'].cpu().numpy()
        kept = np.flatnonzero(scores >= threshold)
        valid_detections = len(kept)

        # Blend every kept mask into the image first, then draw the result once.
        image_np = np.array(image)
        colors = []
        for i in kept:
            mask = masks[i, 0].cpu().numpy() > 0.5  # soft mask -> binary
            color = np.random.rand(3)
            tint = (color * 255).astype(np.uint8)
            image_np[mask] = (0.5 * image_np[mask] + 0.5 * tint).astype(np.uint8)
            colors.append(color)

        fig, ax = plt.subplots(1, figsize=(10, 10))
        try:
            ax.imshow(image_np)
            for i, color in zip(kept, colors):
                x1, y1, x2, y2 = boxes[i]
                ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, color=color, linewidth=2))
                label = COCO_INSTANCE_CATEGORY_NAMES[labels[i]]
                ax.text(x1, y1, f"{label}: {scores[i]:.2f}", bbox=dict(facecolor='yellow', alpha=0.5), fontsize=10)
            ax.axis('off')
            output_path = "maskrcnn_output.png"
            fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
        finally:
            # Close this figure specifically (not pyplot's "current" one), even on error.
            plt.close(fig)
        return output_path, valid_detections
    except Exception as e:
        error_img = Image.new('RGB', (400, 400), color='white')
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(error_img)
        ax.text(0.5, 0.5, f"Error: {str(e)}", horizontalalignment='center', verticalalignment='center',
                transform=ax.transAxes, fontsize=12, wrap=True)
        ax.axis('off')
        error_path = "maskrcnn_error_output.png"
        fig.savefig(error_path)
        plt.close(fig)
        return error_path, 0
def detect_objects_mask2former(image, threshold=0.5):
    """Run Mask2Former instance segmentation and save an annotated rendering.

    Each kept segment is tinted with a random colour and outlined by the
    bounding box of its mask pixels.

    Args:
        image: PIL image to analyse, or ``None``.
        threshold: Minimum segment score to draw and count. Coerced with
            ``float()``; ``None`` means 0.5. Also forwarded to the processor's
            post-processing so it does not apply a different cut-off.

    Returns:
        ``(path, count)``: the PNG written to the working directory and the
        number of segments with score >= ``threshold``. If ``image`` is
        ``None`` or inference fails, a placeholder image carrying a message is
        written instead and ``count`` is 0.
    """
    if image is None:
        blank_img = Image.new('RGB', (400, 400), color='white')
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(blank_img)
        ax.text(0.5, 0.5, "No image provided", horizontalalignment='center', verticalalignment='center',
                transform=ax.transAxes, fontsize=20)
        ax.axis('off')
        output_path = "mask2former_blank_output.png"
        fig.savefig(output_path)
        plt.close(fig)
        return output_path, 0
    try:
        threshold = float(threshold) if threshold is not None else 0.5
        image = image.convert('RGB')
        inputs = mask2former_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = mask2former_model(**inputs)
        # The post-processor applies its own score cut-off (0.5 unless told
        # otherwise), so forward the user's threshold; without it, slider
        # values below 0.5 had no effect.
        results = mask2former_processor.post_process_instance_segmentation(
            outputs, threshold=threshold, target_sizes=[image.size[::-1]]
        )[0]
        segmentation_map = results["segmentation"].cpu().numpy()
        segments_info = results["segments_info"]
        valid_detections = sum(1 for segment in segments_info if segment.get("score", 1.0) >= threshold)

        overlay = np.array(image)
        annotations = []  # (box corners, colour, caption) for each drawn segment
        for segment in segments_info:
            score = segment.get("score", 1.0)
            if score < threshold:
                continue
            mask = segmentation_map == segment["id"]
            color = np.random.rand(3)
            overlay[mask] = (overlay[mask] * 0.5 + color * 255 * 0.5).astype(np.uint8)
            y_indices, x_indices = np.nonzero(mask)
            if x_indices.size == 0:
                continue
            label_id = segment["label_id"]
            # The label map may be keyed by int (as transformers configs are)
            # or by str; accept either before falling back to the bare id.
            label_name = MASK2FORMER_COCO_NAMES.get(
                label_id, MASK2FORMER_COCO_NAMES.get(str(label_id), str(label_id))
            )
            annotations.append((
                (x_indices.min(), y_indices.min(), x_indices.max(), y_indices.max()),
                color,
                f"{label_name}: {score:.2f}",
            ))

        fig, ax = plt.subplots(1, figsize=(10, 10))
        try:
            ax.imshow(overlay)
            for (x1, y1, x2, y2), color, caption in annotations:
                ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, color=color, linewidth=2))
                ax.text(x1, y1, caption, bbox=dict(facecolor='yellow', alpha=0.5), fontsize=10)
            ax.axis('off')
            output_path = "mask2former_output.png"
            fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
        finally:
            # Close this figure specifically (not pyplot's "current" one), even on error.
            plt.close(fig)
        return output_path, valid_detections
    except Exception as e:
        error_img = Image.new('RGB', (400, 400), color='white')
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(error_img)
        ax.text(0.5, 0.5, f"Error: {str(e)}", horizontalalignment='center', verticalalignment='center',
                transform=ax.transAxes, fontsize=12, wrap=True)
        ax.axis('off')
        error_path = "mask2former_error_output.png"
        fig.savefig(error_path)
        plt.close(fig)
        return error_path, 0
def _compare_all_models(image, counts):
    """Compose the comparison text used when every model has been run.

    ``counts`` maps model name -> detection count, in display order. The text
    names the model(s) with the highest count, then appends hints derived from
    the image's pixel area and pixel variance.
    """
    top = max(counts.values())
    leaders = [name for name, n in counts.items() if n == top]
    if len(leaders) > 1:
        text = f"{', '.join(leaders)} detected the same number of objects ({top}). "
    else:
        text = f"{leaders[0]} detected the most objects ({top}). "
    text += "Faster R-CNN is typically faster and good for general detection. DETR excels in complex scenes with better context understanding. Mask R-CNN and Mask2Former provide instance segmentation for precise object boundaries, with Mask2Former leveraging a transformer-based architecture for potentially superior performance in complex scenes."

    pixels = np.array(image)
    height, width = pixels.shape[:2]
    area = height * width
    hints = []
    if area > 1000 * 1000:
        hints.append("\n\nThis is a high-resolution image. DETR and Mask2Former typically perform better on high-resolution images with complex scenes.")
    if np.var(pixels) > 1000:
        hints.append("\n\nThis image has high contrast/complexity. DETR and Mask2Former may provide better context-aware detections.")
    if area < 500 * 500:
        hints.append("\n\nFor smaller images, Faster R-CNN often provides good results at lower computational cost.")
    if top > 0:
        hints.append("\n\nSince Mask R-CNN and Mask2Former provide segmentation, they may be preferable if precise object boundaries are needed, with Mask2Former potentially offering better performance due to its transformer-based design.")
    return text + "".join(hints)


def analyze_performance(image, model_choice, frcnn_threshold=0.5, detr_threshold=0.9, maskrcnn_threshold=0.5, mask2former_threshold=0.5):
    """Run the selected detector(s) on ``image`` and summarise the outcome.

    Args:
        image: PIL image, or ``None``.
        model_choice: One of "Faster R-CNN", "DETR", "Mask R-CNN",
            "Mask2Former" or "All". Any other value runs no detector and
            reports the Mask2Former summary line.
        *_threshold: Confidence threshold forwarded to each detector.

    Returns:
        A 6-tuple: status message, the Faster R-CNN / DETR / Mask R-CNN /
        Mask2Former output image paths (``None`` for models that were not
        run), and the analysis text.
    """
    if image is None:
        return "Please upload an image first.", None, None, None, None, "No analysis available."

    run_all = model_choice == "All"
    # Detectors in their fixed execution and display order.
    detectors = (
        ("Faster R-CNN", detect_objects_frcnn, frcnn_threshold),
        ("DETR", detect_objects_detr, detr_threshold),
        ("Mask R-CNN", detect_objects_maskrcnn, maskrcnn_threshold),
        ("Mask2Former", detect_objects_mask2former, mask2former_threshold),
    )
    result_paths = {}
    counts = {}
    for name, detect, threshold in detectors:
        if run_all or model_choice == name:
            result_paths[name], counts[name] = detect(image, threshold)
        else:
            result_paths[name], counts[name] = None, 0

    if run_all:
        analysis = _compare_all_models(image, counts)
    elif model_choice == "Faster R-CNN":
        analysis = f"Faster R-CNN detected {counts['Faster R-CNN']} objects with a confidence threshold of {frcnn_threshold}."
    elif model_choice == "DETR":
        analysis = f"DETR detected {counts['DETR']} objects with a confidence threshold of {detr_threshold}."
    elif model_choice == "Mask R-CNN":
        analysis = f"Mask R-CNN detected {counts['Mask R-CNN']} objects with a confidence threshold of {maskrcnn_threshold}. It also provides instance segmentation for precise object boundaries."
    else:  # Mask2Former (also the fallback for unrecognised choices)
        analysis = f"Mask2Former detected {counts['Mask2Former']} objects with a confidence threshold of {mask2former_threshold}. It provides instance segmentation with a transformer-based architecture, potentially offering superior performance in complex scenes."

    return ("Analysis complete!", *result_paths.values(), analysis)
# Multi-step Gradio workflow: upload an image, choose model(s) and thresholds,
# then inspect the rendered detections and a textual comparison.
with gr.Blocks(title="Object Detection Comparison") as app:
    gr.Markdown("# Object Detection: Faster R-CNN vs DETR vs Mask R-CNN vs Mask2Former")
    gr.Markdown("### Upload an image and compare four state-of-the-art object detection models")

    # Server-side copy of the uploaded image, set by the "Upload Image" button.
    image_state = gr.State(None)

    with gr.Row():
        with gr.Column():
            # Step 1: image upload
            gr.Markdown("## Step 1: Upload an image")
            image_input = gr.Image(type="pil", label="Upload Image")
            upload_button = gr.Button("Upload Image", variant="primary")

            # Step 2: model selection and thresholds (hidden until an upload)
            with gr.Group(visible=False) as detection_config:
                gr.Markdown("## Step 2: Question")
                gr.Markdown("Which model do you think will work better?")
                model_choice = gr.Radio(
                    choices=["Faster R-CNN", "DETR", "Mask R-CNN", "Mask2Former", "All"],
                    label="Select Object Detection Model(s)",
                    value="All",
                )
                frcnn_threshold = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.5, step=0.05,
                    label="Faster R-CNN Confidence Threshold",
                )
                detr_threshold = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.9, step=0.05,
                    label="DETR Confidence Threshold",
                )
                maskrcnn_threshold = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.5, step=0.05,
                    label="Mask R-CNN Confidence Threshold",
                )
                mask2former_threshold = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.5, step=0.05,
                    label="Mask2Former Confidence Threshold",
                )
                detect_button = gr.Button("Run", variant="primary")

        # Step 3: results (hidden until a run completes)
        with gr.Column(visible=False) as results_panel:
            gr.Markdown("## Step 3: Results")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Faster R-CNN Result")
                    frcnn_result = gr.Image(type="filepath", label="Faster R-CNN")
                with gr.Column():
                    gr.Markdown("### DETR Result")
                    detr_result = gr.Image(type="filepath", label="DETR")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Mask R-CNN Result")
                    maskrcnn_result = gr.Image(type="filepath", label="Mask R-CNN")
                with gr.Column():
                    gr.Markdown("### Mask2Former Result")
                    mask2former_result = gr.Image(type="filepath", label="Mask2Former")
            analysis_output = gr.Textbox(label="Performance Analysis", lines=10)
            restart_button = gr.Button("Try Another Image", variant="secondary")

    # Hidden sink for the status string that analyze_performance returns first.
    status_output = gr.Textbox(visible=False)

    def upload_image(img):
        """Store the uploaded image and reveal Step 2; Step 3 stays hidden."""
        if img is None:
            return None, gr.update(visible=False), gr.update(visible=False)
        return img, gr.update(visible=True), gr.update(visible=False)

    upload_button.click(
        fn=upload_image,
        inputs=[image_input],
        outputs=[image_state, detection_config, results_panel],
    )

    # Run the analysis, then reveal the results panel.
    detect_button.click(
        fn=analyze_performance,
        inputs=[image_state, model_choice, frcnn_threshold, detr_threshold, maskrcnn_threshold, mask2former_threshold],
        outputs=[status_output, frcnn_result, detr_result, maskrcnn_result, mask2former_result, analysis_output],
    ).then(
        fn=lambda: gr.update(visible=True),
        outputs=[results_panel],
    )

    def reset_workflow():
        """Return to Step 1: forget the image, clear and show the uploader, hide Steps 2-3.

        The uploader must stay visible here; hiding it would leave the user
        with no way to provide another image.
        """
        return None, gr.update(visible=False), gr.update(visible=False), gr.update(value=None, visible=True)

    restart_button.click(
        fn=reset_workflow,
        outputs=[image_state, results_panel, detection_config, image_input],
    )

    # Offer bundled sample images from the working directory when present.
    example_images = [
        os.path.join(os.getcwd(), name)
        for name in ("TEST_IMG_1.jpg", "TEST_IMG_2.JPG", "TEST_IMG_3.jpg", "TEST_IMG_4.jpg")
    ]
    valid_examples = [img for img in example_images if os.path.exists(img)]
    if valid_examples:
        gr.Examples(
            examples=valid_examples,
            inputs=image_input,
        )

if __name__ == "__main__":
    app.launch(debug=True)