import torch
import torchvision
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import gradio as gr
import os
import sys
# Load the pre-trained model once at import time so every request reuses it
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
model.eval()  # inference mode: disables dropout and uses running batch-norm statistics
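# Optional tweak (not part of the original app): run on GPU when one is available. This
# would also require moving the input tensor in detect_objects onto the same device.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)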
# COCO class names
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]
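# The 'N/A' entries are deliberate placeholders: the COCO label IDs used by torchvision's
# detection models run from 1 to 90 with gaps, so keeping the unused IDs in the list lets
# a predicted label be used directly as an index.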
# Gradio-compatible detection function
def detect_objects(image, threshold=0.5):
    if image is None:
        return None
    try:
        # Preprocess with the transforms that ship with the pretrained weights
        transform = FasterRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
        image_tensor = transform(image).unsqueeze(0)  # add a batch dimension
        with torch.no_grad():
            prediction = model(image_tensor)[0]
        boxes = prediction['boxes'].cpu().numpy()    # (x1, y1, x2, y2) in pixel coordinates
        labels = prediction['labels'].cpu().numpy()
        scores = prediction['scores'].cpu().numpy()
        # Draw the detections above the confidence threshold onto the image
        image_np = np.array(image)
        plt.figure(figsize=(10, 10))
        plt.imshow(image_np)
        ax = plt.gca()
        for box, label, score in zip(boxes, labels, scores):
            if score >= threshold:
                x1, y1, x2, y2 = box
                ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                           fill=False, color='red', linewidth=2))
                class_name = COCO_INSTANCE_CATEGORY_NAMES[label]
                ax.text(x1, y1, f'{class_name}: {score:.2f}', bbox=dict(facecolor='yellow', alpha=0.5),
                        fontsize=12, color='black')
        plt.axis('off')
        plt.tight_layout()
        # Save the figure to disk and return its path as the Gradio output
        output_path = "output.png"
        plt.savefig(output_path)
        plt.close()
        return output_path
    except Exception as e:
        print(f"Error in detect_objects: {e}", file=sys.stderr)
        return None
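# Quick local sanity check (assumes an image called "sample.jpg" sits next to this script;
# that filename is only an example, not part of the app):
#   from PIL import Image
#   print(detect_objects(Image.open("sample.jpg").convert("RGB"), threshold=0.6))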
# Function to check if a file exists
def file_exists(filepath):
    return os.path.isfile(filepath)
# Find the base directory for the example images.
# On Hugging Face Spaces, this is typically the root directory of the repository.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Check all plausible locations for the example images
possible_dirs = [
    BASE_DIR,                                                    # repository root
    os.path.join(BASE_DIR, "Object-Detection"),                  # subdirectory
    os.path.join(BASE_DIR, "images"),                            # common image directory name
    os.path.join(os.path.dirname(BASE_DIR), "Object-Detection")  # parent/sibling directory
]
# Test image filenames with different case combinations
test_image_variations = [
    ["TEST_IMG_1.jpg"],
    ["TEST_IMG_1.JPG"],
    ["test_img_1.jpg"],
    ["Test_Img_1.jpg"]
]
# Find working examples by testing the different combinations
working_examples = []
# Check every combination of directory and filename
for directory in possible_dirs:
    print(f"Checking directory: {directory}", file=sys.stderr)
    if os.path.isdir(directory):
        for variation in test_image_variations:
            filepath = os.path.join(directory, variation[0])
            if file_exists(filepath):
                print(f"Found example image: {filepath}", file=sys.stderr)
                working_examples.append([filepath])
                # If we found the first image, try the others with the same naming pattern
                base_pattern = variation[0].split("1")[0]
                ext = variation[0].split(".")[-1]
                for i in range(2, 5):  # test images 2-4
                    test_path = os.path.join(directory, f"{base_pattern}{i}.{ext}")
                    if file_exists(test_path):
                        print(f"Found additional example: {test_path}", file=sys.stderr)
                        working_examples.append([test_path])
                # If we already have all 4 examples, stop trying other name variations
                if len(working_examples) >= 4:
                    break
    # If we found examples in this directory, no need to check the others
    if working_examples:
        break
# If no working examples were found, fall back to hard-coded paths
if not working_examples:
    print("No examples found automatically. Using hard-coded paths.", file=sys.stderr)
    example_images = [
        ["TEST_IMG_1.jpg"],
        ["TEST_IMG_2.JPG"],
        ["TEST_IMG_3.jpg"],
        ["TEST_IMG_4.jpg"]
    ]
else:
    example_images = working_examples[:4]  # use the first 4 examples found
print(f"Final example images: {example_images}", file=sys.stderr)
# Create the Gradio interface
interface = gr.Interface(
    fn=detect_objects,
    inputs=[
        gr.Image(type="pil"),
        gr.Slider(0, 1, value=0.5, label="Confidence Threshold")
    ],
    outputs=gr.Image(type="filepath"),
    examples=example_images,
    title="Faster R-CNN Object Detection",
    description="Upload an image to detect objects using a pretrained Faster R-CNN model.",
    allow_flagging="never"  # disable flagging to avoid potential issues
)
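# Possible tweak (not in the original app): CPU inference can take several seconds per image,
# so enabling Gradio's request queue before launching may help under concurrent load.
# interface.queue()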
# Launch with a configuration suitable for Hugging Face Spaces
if __name__ == "__main__":
    interface.launch(debug=True)