# Hugging Face Space status when this source was captured: Runtime error
import os | |
import torch | |
import gradio as gr | |
from PIL import Image | |
# Step 1: Search for best.pt in the training directory | |
base_path = "yolov5/runs/train/" | |
best_path = None | |
# Search through the directory structure to find best.pt | |
for root, dirs, files in os.walk(base_path): | |
if "best.pt" in files: | |
best_path = os.path.join(root, "best.pt") | |
break | |
# Step 2: Load the model, falling back to pre-trained weights when needed.


def _load_model(weights_path):
    """Load a YOLOv5 model through ``torch.hub``.

    Tries the custom weights at *weights_path* first (when one was found) and
    falls back to the pre-trained ``yolov5s`` weights if they are missing or
    fail to load. Both attempts are guarded, so every failure ends in the same
    ``RuntimeError``.

    Args:
        weights_path: Path to custom ``best.pt`` weights, or ``None``.

    Returns:
        The loaded model object returned by ``torch.hub.load``.

    Raises:
        RuntimeError: if no model could be loaded. The last loading error is
            chained as the cause.
    """
    last_error = None
    if weights_path is None:
        print("Trained weights (best.pt) not found.")
        print("Using pre-trained YOLOv5 weights (yolov5s.pt) instead.")
    else:
        print(f"Model weights found at: {weights_path}")
        try:
            return torch.hub.load('ultralytics/yolov5', 'custom', path=weights_path)
        except Exception as e:
            print(f"Error loading custom model: {e}")
            last_error = e
    # Fallback (or primary, when no custom weights exist): pre-trained yolov5s.
    try:
        return torch.hub.load('ultralytics/yolov5', 'yolov5s')
    except Exception as e:
        print(f"Error loading pre-trained YOLOv5 model: {e}")
        last_error = e
    raise RuntimeError(
        "Failed to load YOLOv5 model. Please check the weights or model path."
    ) from last_error


model = _load_model(best_path)
# Step 3: Class labels treated as weapons.
# Edit this list so it matches the class names your dataset/model reports.
weapon_classes = [
    'bomb',
    'gun',
    'pistol',
    'Automatic',
    'Rifle',
    'Bazooka',
    'Handgun',
    'Knife',
    'Grenade Launcher',
    'Shotgun',
    'SMG',
    'Sniper',
    'Sword',
]
def detect_weapons(image): | |
try: | |
results = model(image) | |
except Exception as e: | |
return f"Error during detection: {e}", None | |
# Check available model class names | |
model_classes = results.names | |
print("Model class names:", model_classes) | |
# Filter detections by confidence threshold (0.5 or higher) | |
confidence_threshold = 0.5 | |
filtered_results = results.pandas().xyxy[0][results.pandas().xyxy[0]['confidence'] >= confidence_threshold] | |
# Get the detected classes with high confidence | |
detected_classes = filtered_results['name'].unique() | |
print("Detected classes:", detected_classes) | |
# Check if any of the detected objects are weapons | |
detected_threats = [weapon for weapon in weapon_classes if weapon in detected_classes] | |
# Determine threat message based on weapons detected | |
if detected_threats: | |
threat_message = "Threat detected: Be careful" | |
else: | |
threat_message = "No threat detected. But all other features are good." | |
# Create a string with the detected objects' names | |
detected_objects = ', '.join(detected_classes) | |
# Render the image with bounding boxes | |
return f"{threat_message}\nDetected objects: {detected_objects}", Image.fromarray(results.render()[0]) | |
# Step 4: Gradio Interface
def inference(image):
    """Gradio callback: delegate to ``detect_weapons`` and return its two outputs.

    Returns the ``(message, annotated_image)`` pair unchanged, one value per
    Gradio output component.
    """
    pair = detect_weapons(image)
    text_out, image_out = pair  # unpacking enforces exactly two outputs
    return (text_out, image_out)
# Input: the user's uploaded image, handed to the callback as a numpy array.
upload_input = gr.Image(type="numpy", label="Upload Image")
# Outputs, in the same order as the tuple returned by ``inference``.
result_outputs = [
    gr.Textbox(label="Threat Detection"),
    gr.Image(label="Detected Image"),
]
iface = gr.Interface(
    fn=inference,
    inputs=upload_input,
    outputs=result_outputs,
    title="Weapon Detection AI",
    description="Upload an image to detect weapons like bombs, guns, and pistols.",
)

# Step 5: Launch Gradio App
iface.launch()
v |