import os
import torch
import gradio as gr
from PIL import Image

# Step 1: Search for best.pt in the training directory
base_path = "yolov5/runs/train/"

# Collect every best.pt below base_path. os.walk yields directories in an
# arbitrary, filesystem-dependent order, so taking the first hit could load
# weights from an older run; prefer the most recently modified file instead.
candidate_paths = [
    os.path.join(root, "best.pt")
    for root, _dirs, files in os.walk(base_path)
    if "best.pt" in files
]
best_path = max(candidate_paths, key=os.path.getmtime) if candidate_paths else None

# Step 2: Load the custom weights, falling back to pre-trained weights
model = None  # Ensure model is defined
load_error = None  # Last loading failure, kept so the final error can chain it

if best_path is None:
    print("Trained weights (best.pt) not found.")
    print("Using pre-trained YOLOv5 weights (yolov5s.pt) instead.")
    try:
        model = torch.hub.load('ultralytics/yolov5', 'yolov5s')  # Load pre-trained weights
    except Exception as e:
        print(f"Error loading pre-trained YOLOv5 model: {e}")
        load_error = e
else:
    print(f"Model weights found at: {best_path}")
    try:
        model = torch.hub.load('ultralytics/yolov5', 'custom', path=best_path)
    except Exception as e:
        print(f"Error loading custom model: {e}")
        # Fall back to the pre-trained model; guard this attempt as well so a
        # second failure is reported the same way as in the branch above.
        try:
            model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
        except Exception as fallback_error:
            print(f"Error loading pre-trained YOLOv5 model: {fallback_error}")
            load_error = fallback_error

# Ensure the model was loaded properly before proceeding
if model is None:
    raise RuntimeError(
        "Failed to load YOLOv5 model. Please check the weights or model path."
    ) from load_error

# Step 3: Class labels that count as weapons when they show up among the
# detected object names. Adjust based on your dataset.
weapon_classes = [
    'bomb',
    'gun',
    'pistol',
    'Automatic',
    'Rifle',
    'Bazooka',
    'Handgun',
    'Knife',
    'Grenade Launcher',
    'Shotgun',
    'SMG',
    'Sniper',
    'Sword',
]

def detect_weapons(image):
    """Run the loaded model on an image and report whether a weapon was found.

    Args:
        image: Image to analyse, in whatever form the module-level ``model``
            accepts. ``None`` (no image supplied) is reported as an error
            instead of being handed to the model.

    Returns:
        A ``(message, annotated_image)`` tuple. On success ``message`` holds
        the threat verdict followed by the names of the objects detected
        with confidence >= 0.5, and ``annotated_image`` is a PIL image built
        from ``results.render()``. On failure ``message`` describes the
        error and ``annotated_image`` is ``None``.
    """
    # Nothing to analyse: answer with a clear message rather than letting the
    # model fail on a missing input.
    if image is None:
        return "Error during detection: no image provided", None

    try:
        results = model(image)
    except Exception as e:
        return f"Error during detection: {e}", None

    # Check available model class names
    model_classes = results.names
    print("Model class names:", model_classes)

    # Filter detections by confidence threshold (0.5 or higher). Build the
    # DataFrame once instead of calling results.pandas() twice.
    confidence_threshold = 0.5
    detections = results.pandas().xyxy[0]
    filtered_results = detections[detections['confidence'] >= confidence_threshold]

    # Get the detected classes with high confidence
    detected_classes = filtered_results['name'].unique()
    print("Detected classes:", detected_classes)

    # Compare labels case-insensitively so that e.g. a model emitting 'knife'
    # still matches the configured 'Knife'.
    detected_lookup = {str(name).casefold() for name in detected_classes}
    detected_threats = [
        weapon for weapon in weapon_classes
        if weapon.casefold() in detected_lookup
    ]

    # Determine threat message based on weapons detected
    if detected_threats:
        threat_message = "Threat detected: Be careful"
    else:
        threat_message = "No threat detected. But all other features are good."

    # Create a string with the detected objects' names
    detected_objects = ', '.join(detected_classes)

    # NOTE(review): render() presumably draws every detection the model
    # returned, not only those above confidence_threshold - confirm.
    annotated_image = Image.fromarray(results.render()[0])
    return f"{threat_message}\nDetected objects: {detected_objects}", annotated_image

# Step 4: Gradio Interface
def inference(image):
    """Gradio callback: run weapon detection on the uploaded image.

    Returns the two values produced by ``detect_weapons`` - the status
    message and the annotated image - as a tuple, in that order.
    """
    message, annotated = detect_weapons(image)
    return message, annotated

# Wire the inference callback to a one-image-in, text-plus-image-out UI.
iface = gr.Interface(
    title="Weapon Detection AI",
    description="Upload an image to detect weapons like bombs, guns, and pistols.",
    fn=inference,
    # The uploaded picture reaches the callback as a numpy array.
    inputs=gr.Image(label="Upload Image", type="numpy"),
    # Output order matches the callback's return order: message, then image.
    outputs=[
        gr.Textbox(label="Threat Detection"),
        gr.Image(label="Detected Image"),
    ],
)

# Step 5: Launch Gradio App
# Start the Gradio server for the interface defined above.
iface.launch()