"""Gradio app that flags weapons in uploaded images using a YOLOv5 model.

Trained weights (``best.pt``) found under ``yolov5/runs/train/`` are preferred;
otherwise the stock ``yolov5s`` weights are loaded through ``torch.hub``.
"""

import os

import gradio as gr
import torch
from PIL import Image

# Step 1: Search for best.pt in the training directory
base_path = "yolov5/runs/train/"


def _find_best_weights(root_dir):
    """Return the most recently modified ``best.pt`` under *root_dir*, or ``None``.

    ``os.walk`` visits directories in filesystem order, so taking the first
    match would pick an arbitrary run when several exist (exp, exp2, ...).
    Choosing the newest file makes the choice deterministic and favours the
    latest training run. A missing *root_dir* simply yields no candidates.
    """
    candidates = [
        os.path.join(dirpath, "best.pt")
        for dirpath, _dirnames, filenames in os.walk(root_dir)
        if "best.pt" in filenames
    ]
    if not candidates:
        return None
    return max(candidates, key=os.path.getmtime)


def _load_pretrained():
    """Load the stock ``yolov5s`` model via torch.hub; return ``None`` on failure."""
    try:
        return torch.hub.load('ultralytics/yolov5', 'yolov5s')  # Load pre-trained weights
    except Exception as e:  # top-level loader boundary: report and let caller decide
        print(f"Error loading pre-trained YOLOv5 model: {e}")
        return None


def _load_model(weights_path):
    """Load custom weights from *weights_path*, falling back to ``yolov5s``.

    Returns the loaded model, or ``None`` if neither the custom nor the
    pre-trained weights could be loaded.
    """
    if weights_path is None:
        print("Trained weights (best.pt) not found.")
        print("Using pre-trained YOLOv5 weights (yolov5s.pt) instead.")
        return _load_pretrained()

    print(f"Model weights found at: {weights_path}")
    try:
        return torch.hub.load('ultralytics/yolov5', 'custom', path=weights_path)
    except Exception as e:
        print(f"Error loading custom model: {e}")
        # The fallback is guarded as well, so a second failure ends in the
        # RuntimeError raised at module level rather than a raw hub exception.
        return _load_pretrained()


# Step 2: Load trained weights if present, otherwise pre-trained weights
best_path = _find_best_weights(base_path)
model = _load_model(best_path)

# Ensure the model was loaded properly before proceeding
if model is None:
    raise RuntimeError(
        "Failed to load YOLOv5 model. Please check the weights or model path."
    )

# Step 3: Define weapon classes to detect
weapon_classes = [
    'bomb', 'gun', 'pistol', 'Automatic', 'Rifle', 'Bazooka', 'Handgun',
    'Knife', 'Grenade Launcher', 'Shotgun', 'SMG', 'Sniper', 'Sword',
]  # Adjust based on your dataset

# Minimum confidence for a detection to be reported in the text summary.
CONFIDENCE_THRESHOLD = 0.5


def detect_weapons(image, confidence_threshold=CONFIDENCE_THRESHOLD):
    """Run the model on *image* and report whether any weapon class was detected.

    Args:
        image: The uploaded image as delivered by ``gr.Image(type="numpy")``
            (a NumPy array), or ``None`` when nothing was uploaded.
        confidence_threshold: Detections below this confidence are ignored
            when building the text summary.

    Returns:
        A ``(message, annotated_image)`` tuple. ``message`` states whether a
        threat was found and lists the confidently detected labels;
        ``annotated_image`` is a PIL image with boxes drawn, or ``None`` when
        no image was given or detection failed.
    """
    if image is None:
        return "Please upload an image.", None

    try:
        results = model(image)
    except Exception as e:  # surface inference errors in the UI instead of crashing
        return f"Error during detection: {e}", None

    # Check available model class names
    print("Model class names:", results.names)

    # Detections for the first (only) input image; build the DataFrame once.
    detections = results.pandas().xyxy[0]
    confident = detections[detections['confidence'] >= confidence_threshold]

    # Get the detected classes with high confidence
    detected_classes = confident['name'].unique()
    print("Detected classes:", detected_classes)

    # Compare case-insensitively: the stock COCO model reports e.g. "knife",
    # while weapon_classes spells it "Knife".
    weapon_lookup = {weapon.lower() for weapon in weapon_classes}
    detected_threats = [
        name for name in detected_classes if name.lower() in weapon_lookup
    ]

    if detected_threats:
        threat_message = "Threat detected: Be careful"
    else:
        threat_message = "No threat detected. But all other features are good."

    detected_objects = ', '.join(detected_classes)

    # NOTE(review): render() draws the boxes the model itself kept, which are
    # filtered by the model's own confidence setting rather than
    # confidence_threshold — confirm whether the drawn boxes should match
    # the text summary.
    annotated_image = Image.fromarray(results.render()[0])
    return f"{threat_message}\nDetected objects: {detected_objects}", annotated_image


# Step 4: Gradio Interface
def inference(image):
    """Gradio callback: return ``(threat_text, annotated_image)`` for *image*."""
    return detect_weapons(image)


iface = gr.Interface(
    fn=inference,
    inputs=gr.Image(type="numpy", label="Upload Image"),
    outputs=[
        gr.Textbox(label="Threat Detection"),
        gr.Image(label="Detected Image"),
    ],
    title="Weapon Detection AI",
    description="Upload an image to detect weapons like bombs, guns, and pistols.",
)

# Step 5: Launch Gradio App
iface.launch()