import gradio as gr
import cv2
import numpy as np
import torch
from ultralytics import SAM, YOLOWorld
import os

# Initialize models with proper error handling and auto-download
def initialize_models():
    """Initialize models with proper error handling."""
    try:
        sam_model = SAM("mobile_sam.pt")  # This auto-downloads
        print("✅ SAM model loaded successfully")
    except Exception as e:
        print(f"❌ Error loading SAM model: {e}")
        raise

    try:
        # Try different YOLO-World model names that auto-download
        yolo_model = YOLOWorld("yolov8s-world.pt")  # Small world model (auto-downloads)
        print("✅ YOLO-World model loaded successfully")
        return sam_model, yolo_model
    except Exception as e:
        print(f"❌ Error loading YOLO-World model: {e}")
        try:
            # Fallback to regular YOLO if YOLO-World fails
            from ultralytics import YOLO
            yolo_model = YOLO("yolov8n.pt")  # Regular YOLO nano model
            print("⚠️ Using regular YOLO model as fallback")
            return sam_model, yolo_model
        except Exception as e2:
            print(f"❌ Fallback YOLO model also failed: {e2}")
            raise

sam_model, yolo_model = initialize_models()
def detect_motorcycles(first_frame, prompt="motorcycle"):
    """Detect motorcycles in the first frame using YOLO-World and return bounding boxes."""
    try:
        # Check if it's a YOLO-World model
        if hasattr(yolo_model, 'set_classes'):
            yolo_model.set_classes([prompt])
            results = yolo_model.predict(first_frame, device="cpu", max_det=2, imgsz=320, verbose=False)
        else:
            # Regular YOLO model - can't set custom classes, will detect all objects
            results = yolo_model.predict(first_frame, device="cpu", max_det=5, imgsz=320, verbose=False)
            print("⚠️ Using regular YOLO - detecting all objects, not just the specified prompt")
    except Exception as e:
        print(f"Error in YOLO prediction: {e}")
        return np.array([])

    boxes = []
    for result in results:
        if result.boxes is not None and len(result.boxes.xyxy) > 0:
            boxes.extend(result.boxes.xyxy.cpu().numpy())

    if len(boxes) > 0:
        boxes = np.vstack(boxes)
        print(f"Detected {len(boxes)} objects")
    else:
        boxes = np.array([])
        print("No objects detected")

    return boxes
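
# Example (hypothetical, for local debugging only; not used by the app): calling the
# detector on a single image. "test_frame.jpg" is a placeholder path; the return value
# is an (N, 4) NumPy array of xyxy boxes, or an empty array if nothing matched.
#
#   frame = cv2.imread("test_frame.jpg")
#   boxes = detect_motorcycles(frame, prompt="motorcycle")
#   print(boxes.shape if boxes.size else "no detections")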
def segment_and_highlight_video(video_path, prompt="motorcycle", highlight_color="red"):
    """Segment and highlight objects matching the prompt in a video using MobileSAM and YOLO-World."""
    # Get video properties first
    cap = cv2.VideoCapture(video_path)
    original_fps = cap.get(cv2.CAP_PROP_FPS)
    original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Target resolution
    target_width, target_height = 320, 180

    # Get first frame for detection
    ret, first_frame = cap.read()
    if not ret:
        cap.release()
        raise ValueError("Could not read first frame from video.")

    # Resize first frame for detection
    first_frame_resized = cv2.resize(first_frame, (target_width, target_height))
    cap.release()

    # Detect boxes in resized first frame
    boxes = detect_motorcycles(first_frame_resized, prompt)
    if len(boxes) == 0:
        return video_path  # No motorcycles detected, return original

    # Boxes are already in the target resolution coordinate system
    print(f"Detected {len(boxes)} objects with boxes: {boxes}")

    # Color map for highlighting (BGR order, as used by OpenCV)
    color_map = {"red": (0, 0, 255), "green": (0, 255, 0), "blue": (255, 0, 0)}
    highlight_rgb = color_map.get(highlight_color.lower(), (0, 0, 255))

    # Process video frame by frame instead of using SAM's video prediction
    cap = cv2.VideoCapture(video_path)
    output_path = "output.mp4"
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), original_fps, (target_width, target_height))

    frame_count = 0
    max_frames = min(total_frames, 150)  # Limit to 150 frames (~5 seconds at 30fps)
    print(f"Processing {max_frames} frames...")
    while frame_count < max_frames:
        ret, frame = cap.read()
        if not ret:
            break

        # Resize frame to target resolution
        frame_resized = cv2.resize(frame, (target_width, target_height))

        try:
            # Run SAM on individual frame with explicit resolution control
            sam_results = sam_model.predict(
                source=frame_resized,
                bboxes=boxes,
                device="cpu",
                imgsz=320,  # Force SAM resolution
                conf=0.25,
                verbose=False
            )

            highlighted_frame = frame_resized.copy()

            # Process SAM results
            if len(sam_results) > 0 and sam_results[0].masks is not None:
                masks = sam_results[0].masks.data.cpu().numpy()
                if len(masks) > 0:
                    # Combine all masks
                    combined_mask = np.any(masks, axis=0).astype(np.uint8)

                    # The mask can come back at the model's inference resolution rather
                    # than the frame size, so resize it to match before indexing
                    if combined_mask.shape[:2] != frame_resized.shape[:2]:
                        combined_mask = cv2.resize(
                            combined_mask,
                            (target_width, target_height),
                            interpolation=cv2.INTER_NEAREST
                        )

                    # Create colored overlay
                    overlay = np.zeros_like(frame_resized)
                    overlay[combined_mask == 1] = highlight_rgb

                    # Blend with original frame
                    highlighted_frame = cv2.addWeighted(frame_resized, 0.7, overlay, 0.3, 0)
        except Exception as e:
            print(f"Error processing frame {frame_count}: {e}")
            highlighted_frame = frame_resized

        out.write(highlighted_frame)
        frame_count += 1

        # Progress indicator
        if frame_count % 30 == 0:
            print(f"Processed {frame_count}/{max_frames} frames")

    cap.release()
    out.release()

    print(f"Video processing complete. Output saved to {output_path}")
    return output_path
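
# Example (hypothetical, for local testing without the Gradio UI): run the full pipeline
# on a short clip. "sample_clip.mp4" is a placeholder; any small mp4 readable by OpenCV
# should work, and the highlighted result is written to "output.mp4" in the working directory.
#
#   result_path = segment_and_highlight_video("sample_clip.mp4", prompt="motorcycle", highlight_color="red")
#   print(f"Highlighted video saved to {result_path}")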
# Gradio interface
iface = gr.Interface(
    fn=segment_and_highlight_video,
    inputs=[
        gr.Video(label="Upload Video"),
        gr.Textbox(label="Prompt", placeholder="e.g., motorcycle", value="motorcycle"),
        gr.Dropdown(choices=["red", "green", "blue"], label="Highlight Color", value="red")
    ],
    outputs=gr.Video(label="Highlighted Video"),
    title="Video Segmentation with MobileSAM and YOLO (CPU Optimized)",
    description="Upload a short video (5-10 seconds), specify a text prompt (e.g., 'motorcycle'), and choose a highlight color. Uses MobileSAM + YOLO for CPU processing at 320x180 resolution.",
    examples=[
        [None, "motorcycle", "red"],
        [None, "car", "green"],
        [None, "person", "blue"]
    ]
)

if __name__ == "__main__":
    iface.launch()