import torch
import torchvision
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import gradio as gr
# Load the pre-trained model once
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
model.eval()
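# (eval() puts the detector in inference mode: it returns predicted
# boxes/labels/scores instead of training losses, and disables dropout
# and batch-norm updates.)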
# COCO class names
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]
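# The 'N/A' placeholders keep the list aligned with the original 91 COCO
# category ids; some ids are unused by the 80-class detection label set.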
# Gradio-compatible detection function
def detect_objects(image, threshold=0.5):
    # Ensure a 3-channel RGB image (PNG uploads can carry an alpha channel)
    image = image.convert("RGB")

    # Preprocess with the transforms that match the pretrained weights
    transform = FasterRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
    image_tensor = transform(image).unsqueeze(0)

    # Run inference without tracking gradients
    with torch.no_grad():
        prediction = model(image_tensor)[0]

    boxes = prediction['boxes'].cpu().numpy()
    labels = prediction['labels'].cpu().numpy()
    scores = prediction['scores'].cpu().numpy()

    # Plot the image and draw each detection above the confidence threshold
    image_np = np.array(image)
    plt.figure(figsize=(10, 10))
    plt.imshow(image_np)
    ax = plt.gca()

    for box, label, score in zip(boxes, labels, scores):
        if score >= threshold:
            x1, y1, x2, y2 = box
            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                       fill=False, color='red', linewidth=2))
            class_name = COCO_INSTANCE_CATEGORY_NAMES[label]
            ax.text(x1, y1, f'{class_name}: {score:.2f}', bbox=dict(facecolor='yellow', alpha=0.5),
                    fontsize=12, color='black')

    plt.axis('off')
    plt.tight_layout()

    # Save the figure to return as the output image.
    # Note: a fixed filename is overwritten by concurrent requests; a unique
    # temporary file would be safer if several users hit the app at once.
    plt.savefig("output.png")
    plt.close()
    return "output.png"
# Create Gradio interface
gr.Interface(
    fn=detect_objects,
    inputs=[
        gr.Image(type="pil"),
        gr.Slider(0, 1, value=0.5, label="Confidence Threshold")
    ],
    outputs=gr.Image(type="filepath"),
    title="Faster R-CNN Object Detection",
    description="Upload an image to detect objects using a pretrained Faster R-CNN model."
).launch()
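# To try it locally (assuming this script is saved as app.py): run `python app.py`
# and open the URL Gradio prints (http://127.0.0.1:7860 by default).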