# Author: SamiKhokhar — last commit 751f5cc ("Update app.py", verified)
import os
import torch
import gradio as gr
from PIL import Image
# Step 1: Search for best.pt in the training directory
base_path = "yolov5/runs/train/"


def _find_latest_weights(search_root, filename="best.pt"):
    """Return the most recently modified *filename* found under *search_root*.

    Several matching files can exist below the search root (e.g. one per
    training run). ``os.walk`` yields directories in arbitrary filesystem
    order, so taking the first hit would select an arbitrary run; the newest
    file by modification time is returned instead.

    Returns ``None`` when no match exists, including when *search_root* does
    not exist (``os.walk`` simply yields nothing in that case).
    """
    candidates = [
        os.path.join(root, filename)
        for root, _dirs, files in os.walk(search_root)
        if filename in files
    ]
    return max(candidates, key=os.path.getmtime, default=None)


best_path = _find_latest_weights(base_path)

# Step 2: Load the custom weights if found; otherwise (or if that load fails)
# fall back to the pre-trained yolov5s weights.
model = None  # Ensure model is defined
_load_error = None  # Last load exception; chained onto the final RuntimeError.
if best_path is None:
    print("Trained weights (best.pt) not found.")
    print("Using pre-trained YOLOv5 weights (yolov5s.pt) instead.")
else:
    print(f"Model weights found at: {best_path}")
    try:
        model = torch.hub.load('ultralytics/yolov5', 'custom', path=best_path)
    except Exception as e:
        print(f"Error loading custom model: {e}")
        _load_error = e

if model is None:
    # The fallback is guarded as well, so a failure here is reported and then
    # surfaces through the explicit check below instead of escaping raw.
    try:
        model = torch.hub.load('ultralytics/yolov5', 'yolov5s')  # Load pre-trained weights
    except Exception as e:
        print(f"Error loading pre-trained YOLOv5 model: {e}")
        _load_error = e

# Ensure the model was loaded properly before proceeding
if model is None:
    raise RuntimeError(
        "Failed to load YOLOv5 model. Please check the weights or model path."
    ) from _load_error
# Step 3: Define weapon classes to detect.
# Detected class names that appear in this list count as threats.
# Adjust the entries based on your dataset's class names.
weapon_classes = [
    'bomb',
    'gun',
    'pistol',
    'Automatic',
    'Rifle',
    'Bazooka',
    'Handgun',
    'Knife',
    'Grenade Launcher',
    'Shotgun',
    'SMG',
    'Sniper',
    'Sword',
]
def detect_weapons(image, confidence_threshold=0.5):
    """Run the YOLOv5 model on *image* and report whether a weapon class was found.

    Args:
        image: The value supplied by the Gradio input component, or ``None``
            when the user submits without uploading an image.
        confidence_threshold: Minimum detection confidence (inclusive) for a
            detection to be listed and considered a threat. Defaults to 0.5.

    Returns:
        A ``(message, rendered_image)`` tuple. ``message`` holds the threat
        verdict followed by the comma-separated names of the confident
        detections. ``rendered_image`` is a PIL image with YOLOv5's bounding
        boxes drawn, or ``None`` when no image was given or inference failed.
    """
    if image is None:
        return "Please upload an image.", None
    try:
        results = model(image)
    except Exception as e:
        return f"Error during detection: {e}", None

    # Check available model class names
    print("Model class names:", results.names)

    # Build the detections DataFrame once, then keep only confident rows.
    detections = results.pandas().xyxy[0]
    filtered_results = detections[detections['confidence'] >= confidence_threshold]

    # Get the detected classes with high confidence
    detected_classes = filtered_results['name'].unique()
    print("Detected classes:", detected_classes)

    # Compare case-insensitively. The pre-trained yolov5s fallback uses COCO's
    # lowercase class names (e.g. 'knife'), which would never match 'Knife'.
    detected_lower = {str(name).lower() for name in detected_classes}
    detected_threats = [
        weapon for weapon in weapon_classes if weapon.lower() in detected_lower
    ]

    # Determine threat message based on weapons detected
    if detected_threats:
        threat_message = "Threat detected: Be careful"
    else:
        threat_message = "No threat detected. But all other features are good."

    detected_objects = ', '.join(detected_classes)
    # NOTE(review): render() presumably draws every detection the model kept
    # (model.conf threshold), not only those >= confidence_threshold, so the
    # image may show boxes that are absent from the text. Confirm if it matters.
    return (
        f"{threat_message}\nDetected objects: {detected_objects}",
        Image.fromarray(results.render()[0]),
    )
# Step 4: Gradio Interface
def inference(image):
    """Gradio callback: forward *image* to detect_weapons and return its outputs.

    Returns a ``(threat_text, annotated_image)`` pair, in the order expected by
    the interface's Textbox and Image output components.
    """
    outputs = detect_weapons(image)
    threat_text, annotated_image = outputs
    return (threat_text, annotated_image)
# Build the Gradio UI: one image input; a text verdict plus the annotated image.
iface = gr.Interface(
    fn=inference,
    inputs=gr.Image(type="numpy", label="Upload Image"),
    outputs=[
        gr.Textbox(label="Threat Detection"),
        gr.Image(label="Detected Image"),
    ],
    title="Weapon Detection AI",
    description="Upload an image to detect weapons like bombs, guns, and pistols.",
)

# Step 5: Launch Gradio App
# (A stray bare name `v` that followed this call was removed; it raised
# NameError as soon as launch() returned, e.g. in notebook environments.)
iface.launch()