import gradio as gr
import cv2
import numpy as np
import torch
from ultralytics import SAM, YOLOWorld
import os
# Initialize models with proper error handling and auto-download
def initialize_models():
    """Initialize the SAM and YOLO models with error handling and fallbacks."""
    try:
        sam_model = SAM("mobile_sam.pt")  # MobileSAM weights auto-download on first use
        print("✅ SAM model loaded successfully")
    except Exception as e:
        print(f"❌ Error loading SAM model: {e}")
        raise

    try:
        yolo_model = YOLOWorld("yolov8s-world.pt")  # Small YOLO-World model (auto-downloads)
        print("✅ YOLO-World model loaded successfully")
        return sam_model, yolo_model
    except Exception as e:
        print(f"❌ Error loading YOLO-World model: {e}")
        try:
            # Fall back to a regular YOLO model if YOLO-World fails
            from ultralytics import YOLO
            yolo_model = YOLO("yolov8n.pt")  # Regular YOLO nano model
            print("⚠️ Using regular YOLO model as fallback")
            return sam_model, yolo_model
        except Exception as e2:
            print(f"❌ Fallback YOLO model also failed: {e2}")
            raise

sam_model, yolo_model = initialize_models()
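
# Both models are loaded once at import time, so all Gradio requests reuse them.
# Inference below is pinned to CPU for this Space's hardware; if a GPU were
# available, a device switch along these lines could be used instead (a sketch,
# not tested here):
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   results = yolo_model.predict(frame, device=device, imgsz=320, verbose=False)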
def detect_motorcycles(first_frame, prompt="motorcycle"):
    """Detect objects matching the text prompt in a frame and return xyxy bounding boxes."""
    try:
        # YOLO-World supports open-vocabulary detection via set_classes()
        if hasattr(yolo_model, 'set_classes'):
            yolo_model.set_classes([prompt])
            results = yolo_model.predict(first_frame, device="cpu", max_det=2, imgsz=320, verbose=False)
        else:
            # Regular YOLO fallback cannot take custom classes and detects all trained classes
            results = yolo_model.predict(first_frame, device="cpu", max_det=5, imgsz=320, verbose=False)
            print("⚠️ Using regular YOLO - detecting all objects, not just the specified prompt")
    except Exception as e:
        print(f"Error in YOLO prediction: {e}")
        return np.array([])

    boxes = []
    for result in results:
        if result.boxes is not None and len(result.boxes.xyxy) > 0:
            boxes.extend(result.boxes.xyxy.cpu().numpy())

    if len(boxes) > 0:
        boxes = np.vstack(boxes)
        print(f"Detected {len(boxes)} objects")
    else:
        boxes = np.array([])
        print("No objects detected")
    return boxes
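
# Example usage outside the UI (hypothetical file name, a sketch only):
#   frame = cv2.imread("test_frame.jpg")
#   boxes = detect_motorcycles(frame, prompt="motorcycle")
#   print(boxes)  # (N, 4) array of xyxy boxes, or an empty array if nothing matched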
def segment_and_highlight_video(video_path, prompt="motorcycle", highlight_color="red"):
    """Segment and highlight objects matching the prompt using MobileSAM and YOLO-World."""
    # Get video properties first
    cap = cv2.VideoCapture(video_path)
    original_fps = cap.get(cv2.CAP_PROP_FPS)
    if original_fps <= 0:
        original_fps = 30  # Some containers report 0 fps; fall back to a sane default
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Target resolution (kept low for CPU processing)
    target_width, target_height = 320, 180

    # Get first frame for detection
    ret, first_frame = cap.read()
    if not ret:
        cap.release()
        raise ValueError("Could not read first frame from video.")

    # Resize first frame for detection
    first_frame_resized = cv2.resize(first_frame, (target_width, target_height))
    cap.release()

    # Detect boxes in the resized first frame; the same boxes are reused for every frame
    boxes = detect_motorcycles(first_frame_resized, prompt)
    if len(boxes) == 0:
        return video_path  # Nothing detected, return the original video

    # Boxes are already in the target resolution coordinate system
    print(f"Detected {len(boxes)} objects with boxes: {boxes}")

    # Color map for highlighting (OpenCV uses BGR channel order)
    color_map = {"red": (0, 0, 255), "green": (0, 255, 0), "blue": (255, 0, 0)}
    highlight_bgr = color_map.get(highlight_color.lower(), (0, 0, 255))

    # Process the video frame by frame instead of using SAM's video prediction
    cap = cv2.VideoCapture(video_path)
    output_path = "output.mp4"
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), original_fps, (target_width, target_height))

    frame_count = 0
    # Limit to 150 frames (~5 seconds at 30 fps); some codecs report an invalid frame count
    max_frames = min(total_frames, 150) if total_frames > 0 else 150
    print(f"Processing {max_frames} frames...")

    while frame_count < max_frames:
        ret, frame = cap.read()
        if not ret:
            break

        # Resize frame to target resolution
        frame_resized = cv2.resize(frame, (target_width, target_height))
        highlighted_frame = frame_resized

        try:
            # Run SAM on the individual frame with explicit resolution control
            sam_results = sam_model.predict(
                source=frame_resized,
                bboxes=boxes,
                device="cpu",
                imgsz=320,  # Force SAM inference resolution
                conf=0.25,
                verbose=False
            )

            # Process SAM results
            if len(sam_results) > 0 and sam_results[0].masks is not None:
                masks = sam_results[0].masks.data.cpu().numpy()
                if len(masks) > 0:
                    # Combine all instance masks into one binary mask
                    combined_mask = np.any(masks, axis=0).astype(np.uint8)
                    # Guard against masks returned at a different resolution than the frame
                    if combined_mask.shape != frame_resized.shape[:2]:
                        combined_mask = cv2.resize(combined_mask, (target_width, target_height),
                                                   interpolation=cv2.INTER_NEAREST)
                    # Create a colored overlay and blend it with the original frame
                    overlay = np.zeros_like(frame_resized)
                    overlay[combined_mask == 1] = highlight_bgr
                    highlighted_frame = cv2.addWeighted(frame_resized, 0.7, overlay, 0.3, 0)
        except Exception as e:
            print(f"Error processing frame {frame_count}: {e}")
            highlighted_frame = frame_resized

        out.write(highlighted_frame)
        frame_count += 1

        # Progress indicator
        if frame_count % 30 == 0:
            print(f"Processed {frame_count}/{max_frames} frames")

    cap.release()
    out.release()
    print(f"Video processing complete. Output saved to {output_path}")
    return output_path
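
# The pipeline can also be exercised directly (illustrative path, a sketch only):
#   result_path = segment_and_highlight_video("sample.mp4", prompt="car", highlight_color="green")
# Note: cv2.VideoWriter's "mp4v" fourcc is not always browser-playable; Gradio may
# re-encode the returned file before display.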
# Gradio interface
iface = gr.Interface(
    fn=segment_and_highlight_video,
    inputs=[
        gr.Video(label="Upload Video"),
        gr.Textbox(label="Prompt", placeholder="e.g., motorcycle", value="motorcycle"),
        gr.Dropdown(choices=["red", "green", "blue"], label="Highlight Color", value="red")
    ],
    outputs=gr.Video(label="Highlighted Video"),
    title="Video Segmentation with MobileSAM and YOLO (CPU Optimized)",
    description="Upload a short video (5-10 seconds), specify a text prompt (e.g., 'motorcycle'), and choose a highlight color. Uses MobileSAM + YOLO for CPU processing at 320x180 resolution.",
    examples=[
        [None, "motorcycle", "red"],
        [None, "car", "green"],
        [None, "person", "blue"]
    ]
)

if __name__ == "__main__":
    iface.launch()