import gradio as gr
import cv2
import numpy as np
from ultralytics import SAM, YOLOWorld


# Initialize models with proper error handling and auto-download
def initialize_models():
    """Initialize models with error handling; weights auto-download on first use."""
    try:
        sam_model = SAM("mobile_sam.pt")  # MobileSAM weights auto-download
        print("✅ SAM model loaded successfully")
    except Exception as e:
        print(f"❌ Error loading SAM model: {e}")
        raise

    try:
        yolo_model = YOLOWorld("yolov8s-world.pt")  # Small YOLO-World model, auto-downloads
        print("✅ YOLO-World model loaded successfully")
        return sam_model, yolo_model
    except Exception as e:
        print(f"❌ Error loading YOLO-World model: {e}")
        try:
            # Fall back to a regular YOLO model if YOLO-World fails
            from ultralytics import YOLO
            yolo_model = YOLO("yolov8n.pt")  # Regular YOLO nano model
            print("⚠️ Using regular YOLO model as fallback")
            return sam_model, yolo_model
        except Exception as e2:
            print(f"❌ Fallback YOLO model also failed: {e2}")
            raise


sam_model, yolo_model = initialize_models()


def detect_motorcycles(first_frame, prompt="motorcycle"):
    """Detect objects matching the prompt in the first frame using YOLO-World
    (or the YOLO fallback) and return bounding boxes."""
    try:
        # YOLO-World supports open-vocabulary prompts via set_classes()
        if hasattr(yolo_model, "set_classes"):
            yolo_model.set_classes([prompt])
            results = yolo_model.predict(first_frame, device="cpu", max_det=2, imgsz=320, verbose=False)
        else:
            # Regular YOLO cannot take custom classes and will detect all objects
            results = yolo_model.predict(first_frame, device="cpu", max_det=5, imgsz=320, verbose=False)
            print("⚠️ Using regular YOLO - detecting all objects, not just the specified prompt")
    except Exception as e:
        print(f"Error in YOLO prediction: {e}")
        return np.array([])

    boxes = []
    for result in results:
        if result.boxes is not None and len(result.boxes.xyxy) > 0:
            boxes.extend(result.boxes.xyxy.cpu().numpy())

    if len(boxes) > 0:
        boxes = np.vstack(boxes)
        print(f"Detected {len(boxes)} objects")
    else:
        boxes = np.array([])
        print("No objects detected")
    return boxes
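
# Quick manual sanity check for the detector (a sketch; "test_frame.jpg" is a
# placeholder path, not a file shipped with this script):
#
#   frame = cv2.resize(cv2.imread("test_frame.jpg"), (320, 180))
#   print(detect_motorcycles(frame, "motorcycle"))  # expect an (N, 4) array of xyxy boxes
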
def segment_and_highlight_video(video_path, prompt="motorcycle", highlight_color="red"):
    """Segment and highlight the prompted objects in a video using MobileSAM and YOLO-World."""
    # Get video properties first
    cap = cv2.VideoCapture(video_path)
    original_fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Target resolution for CPU-friendly processing
    target_width, target_height = 320, 180

    # Get the first frame for detection
    ret, first_frame = cap.read()
    if not ret:
        cap.release()
        raise ValueError("Could not read first frame from video.")
    first_frame_resized = cv2.resize(first_frame, (target_width, target_height))
    cap.release()

    # Detect boxes in the resized first frame; they are already in the
    # target-resolution coordinate system
    boxes = detect_motorcycles(first_frame_resized, prompt)
    if len(boxes) == 0:
        return video_path  # Nothing detected, return the original video
    print(f"Detected {len(boxes)} objects with boxes: {boxes}")

    # Color map for highlighting (OpenCV uses BGR channel order)
    color_map = {"red": (0, 0, 255), "green": (0, 255, 0), "blue": (255, 0, 0)}
    highlight_bgr = color_map.get(highlight_color.lower(), (0, 0, 255))

    # Process the video frame by frame instead of using SAM's video prediction
    cap = cv2.VideoCapture(video_path)
    output_path = "output.mp4"
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"),
                          original_fps, (target_width, target_height))

    frame_count = 0
    max_frames = min(total_frames, 150)  # Limit to 150 frames (~5 seconds at 30 fps)
    print(f"Processing {max_frames} frames...")

    while frame_count < max_frames:
        ret, frame = cap.read()
        if not ret:
            break
        frame_resized = cv2.resize(frame, (target_width, target_height))

        try:
            # Run SAM on the individual frame with explicit resolution control
            sam_results = sam_model.predict(
                source=frame_resized,
                bboxes=boxes,
                device="cpu",
                imgsz=320,  # Force SAM resolution
                conf=0.25,
                verbose=False,
            )
            highlighted_frame = frame_resized.copy()
            if len(sam_results) > 0 and sam_results[0].masks is not None:
                masks = sam_results[0].masks.data.cpu().numpy()
                if len(masks) > 0:
                    # Combine all masks and blend a colored overlay onto the frame
                    combined_mask = np.any(masks, axis=0).astype(np.uint8)
                    overlay = np.zeros_like(frame_resized)
                    overlay[combined_mask == 1] = highlight_bgr
                    highlighted_frame = cv2.addWeighted(frame_resized, 0.7, overlay, 0.3, 0)
        except Exception as e:
            print(f"Error processing frame {frame_count}: {e}")
            highlighted_frame = frame_resized

        out.write(highlighted_frame)
        frame_count += 1

        # Progress indicator
        if frame_count % 30 == 0:
            print(f"Processed {frame_count}/{max_frames} frames")

    cap.release()
    out.release()
    print(f"Video processing complete. Output saved to {output_path}")
    return output_path


# Gradio interface
iface = gr.Interface(
    fn=segment_and_highlight_video,
    inputs=[
        gr.Video(label="Upload Video"),
        gr.Textbox(label="Prompt", placeholder="e.g., motorcycle", value="motorcycle"),
        gr.Dropdown(choices=["red", "green", "blue"], label="Highlight Color", value="red"),
    ],
    outputs=gr.Video(label="Highlighted Video"),
    title="Video Segmentation with MobileSAM and YOLO (CPU Optimized)",
    description=(
        "Upload a short video (5-10 seconds), specify a text prompt (e.g., 'motorcycle'), "
        "and choose a highlight color. Uses MobileSAM + YOLO for CPU processing at 320x180 resolution."
    ),
    examples=[
        [None, "motorcycle", "red"],
        [None, "car", "green"],
        [None, "person", "blue"],
    ],
)

if __name__ == "__main__":
    iface.launch()
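
# To drive the pipeline without the UI (a sketch; "input.mp4" is a placeholder
# path, not a file provided with this script):
#
#   result = segment_and_highlight_video("input.mp4", prompt="motorcycle", highlight_color="red")
#   print(result)  # "output.mp4" on success, or the input path if nothing was detected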