Spaces:
Runtime error
Runtime error
File size: 3,329 Bytes
9423eeb b0c391e 69ebf82 b0c391e 751f5cc b0c391e fd21413 b0c391e cdef2a7 b0c391e 751f5cc b0c391e 69ebf82 9423eeb cdef2a7 b0c391e fd21413 b0c391e 69ebf82 b0c391e 69ebf82 b0c391e 69ebf82 b0c391e 69ebf82 b0c391e 69ebf82 9423eeb 69ebf82 9423eeb bbe4eba 9423eeb b0c391e 9423eeb b0c391e 69ebf82 751f5cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import os
import torch
import gradio as gr
from PIL import Image
# Step 1: Search for best.pt in the training directory
base_path = "yolov5/runs/train/"


def _find_latest_weights(search_root, filename="best.pt"):
    """Return the most recently modified *filename* found under *search_root*.

    ``os.walk`` visits directories in filesystem-dependent order, so taking the
    first match would pick an arbitrary training run. Choosing the newest file
    makes the selection deterministic and favours the latest run.

    Returns:
        The path to the newest match, or ``None`` when no match exists
        (``os.walk`` yields nothing for a missing *search_root*).
    """
    candidates = [
        os.path.join(root, filename)
        for root, _dirs, files in os.walk(search_root)
        if filename in files
    ]
    if not candidates:
        return None
    return max(candidates, key=os.path.getmtime)


def _load_pretrained():
    """Load the stock YOLOv5s model via torch.hub; return ``None`` on failure."""
    try:
        return torch.hub.load('ultralytics/yolov5', 'yolov5s')  # Load pre-trained weights
    except Exception as e:
        print(f"Error loading pre-trained YOLOv5 model: {e}")
        return None


best_path = _find_latest_weights(base_path)

# Step 2: If best.pt is not found, use pre-trained weights
model = None  # Ensure model is defined
if best_path is None:
    print("Trained weights (best.pt) not found.")
    print("Using pre-trained YOLOv5 weights (yolov5s.pt) instead.")
    model = _load_pretrained()
else:
    try:
        print(f"Model weights found at: {best_path}")
        model = torch.hub.load('ultralytics/yolov5', 'custom', path=best_path)
    except Exception as e:
        print(f"Error loading custom model: {e}")
        # Fallback to pre-trained model; guarded so a second failure reaches
        # the explicit RuntimeError below instead of escaping raw.
        model = _load_pretrained()

# Ensure the model was loaded properly before proceeding
if model is None:
    raise RuntimeError("Failed to load YOLOv5 model. Please check the weights or model path.")
# Step 3: Define weapon classes to detect
# Class names the detector treats as weapons when the model reports them.
# Adjust based on your dataset.
weapon_classes = [
    'bomb',
    'gun',
    'pistol',
    'Automatic',
    'Rifle',
    'Bazooka',
    'Handgun',
    'Knife',
    'Grenade Launcher',
    'Shotgun',
    'SMG',
    'Sniper',
    'Sword',
]
def detect_weapons(image):
    """Run the YOLOv5 model on *image* and report whether a weapon was seen.

    Args:
        image: The image passed in by the Gradio ``gr.Image(type="numpy")``
            input, or ``None`` when nothing was uploaded.

    Returns:
        A ``(message, annotated_image)`` tuple. ``message`` states whether a
        weapon class was detected and lists every class detected with
        confidence >= 0.5. ``annotated_image`` is a PIL image with the model's
        boxes drawn, or ``None`` if detection could not run.
    """
    if image is None:
        # Gradio calls the function with None when no image was uploaded.
        return "Error during detection: no image provided", None

    try:
        results = model(image)
    except Exception as e:
        return f"Error during detection: {e}", None

    # Check available model class names
    print("Model class names:", results.names)

    # Filter detections by confidence threshold (0.5 or higher).
    # Convert to pandas once rather than once per use.
    confidence_threshold = 0.5
    detections = results.pandas().xyxy[0]
    filtered_results = detections[detections['confidence'] >= confidence_threshold]

    # Get the detected classes with high confidence
    detected_classes = filtered_results['name'].unique()
    print("Detected classes:", detected_classes)

    # Compare case-insensitively: weapon_classes mixes 'gun' with 'Knife', and
    # model class names may be lowercase (the stock COCO yolov5s uses 'knife').
    weapon_names = {str(weapon).lower() for weapon in weapon_classes}
    detected_threats = [name for name in detected_classes if str(name).lower() in weapon_names]

    # Determine threat message based on weapons detected
    if detected_threats:
        threat_message = "Threat detected: Be careful"
    else:
        threat_message = "No threat detected. But all other features are good."

    # Create a string with the detected objects' names
    detected_objects = ', '.join(detected_classes)

    # render() draws every detection held in `results`, including ones below
    # confidence_threshold; the threshold only filters the text report above.
    annotated = Image.fromarray(results.render()[0])
    return f"{threat_message}\nDetected objects: {detected_objects}", annotated
# Step 4: Gradio Interface
def inference(image):
    """Gradio callback: forward *image* to ``detect_weapons``.

    Returns the ``(threat_text, annotated_image)`` pair for the two
    interface outputs.
    """
    message, rendered = detect_weapons(image)
    return (message, rendered)
# Gradio UI: one numpy image in; threat text and the annotated image out.
iface = gr.Interface(
    fn=inference,
    inputs=gr.Image(type="numpy", label="Upload Image"),
    outputs=[
        gr.Textbox(label="Threat Detection"),
        gr.Image(label="Detected Image"),
    ],
    title="Weapon Detection AI",
    description="Upload an image to detect weapons like bombs, guns, and pistols."
)

# Step 5: Launch Gradio App
# Guarded so importing this module (e.g. for tests) does not start a server;
# running the file as a script (`python app.py`) still launches it.
if __name__ == "__main__":
    iface.launch()
v |