import gradio as gr
import cv2
import numpy as np
import torch
from ultralytics import SAM, YOLOWorld
import os
# Initialize models with proper error handling and auto-download
def initialize_models():
    """Initialize models with proper error handling."""
    try:
        sam_model = SAM("mobile_sam.pt")  # Auto-downloads on first use
        print("✅ SAM model loaded successfully")
    except Exception as e:
        print(f"❌ Error loading SAM model: {e}")
        raise

    try:
        # YOLO-World model names auto-download as well
        yolo_model = YOLOWorld("yolov8s-world.pt")  # Small world model
        print("✅ YOLO-World model loaded successfully")
        return sam_model, yolo_model
    except Exception as e:
        print(f"❌ Error loading YOLO-World model: {e}")
        try:
            # Fall back to a standard YOLO model if YOLO-World fails
            from ultralytics import YOLO
            yolo_model = YOLO("yolov8n.pt")  # Regular YOLO nano model
            print("⚠️ Using regular YOLO model as fallback")
            return sam_model, yolo_model
        except Exception as e2:
            print(f"❌ Fallback YOLO model also failed: {e2}")
            raise
sam_model, yolo_model = initialize_models()
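
# A minimal smoke-test sketch for the loaded models, guarded behind an
# environment variable; the SMOKE_TEST flag, the blank 320x180 frame, and
# the placeholder box below are illustrative assumptions, not part of the
# app's request path.
if os.environ.get("SMOKE_TEST"):
    _dummy = np.zeros((180, 320, 3), dtype=np.uint8)
    _det = yolo_model.predict(_dummy, device="cpu", imgsz=320, verbose=False)
    _seg = sam_model.predict(source=_dummy, bboxes=np.array([[10, 10, 100, 100]]), device="cpu", verbose=False)
    print(f"Smoke test OK: {len(_det)} detection result(s), {len(_seg)} segmentation result(s)")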

def detect_motorcycles(first_frame, prompt="motorcycle"):
    """Detect objects matching the prompt in the first frame and return bounding boxes."""
    try:
        # Check whether this is a YOLO-World model (supports open-vocabulary classes)
        if hasattr(yolo_model, 'set_classes'):
            yolo_model.set_classes([prompt])
            results = yolo_model.predict(first_frame, device="cpu", max_det=2, imgsz=320, verbose=False)
        else:
            # Regular YOLO model - cannot set custom classes, so it detects all objects
            results = yolo_model.predict(first_frame, device="cpu", max_det=5, imgsz=320, verbose=False)
            print("⚠️ Using regular YOLO - detecting all objects, not just the specified prompt")
    except Exception as e:
        print(f"Error in YOLO prediction: {e}")
        return np.array([])

    boxes = []
    for result in results:
        if result.boxes is not None and len(result.boxes.xyxy) > 0:
            boxes.extend(result.boxes.xyxy.cpu().numpy())

    if len(boxes) > 0:
        boxes = np.vstack(boxes)
        print(f"Detected {len(boxes)} objects")
    else:
        boxes = np.array([])
        print("No objects detected")
    return boxes
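
# For reference, a hedged usage sketch: detect_motorcycles() returns an
# (N, 4) array of [x1, y1, x2, y2] boxes in the coordinates of the frame
# passed in, or an empty array. The image path below is a placeholder.
#
#   frame = cv2.imread("first_frame.jpg")
#   boxes = detect_motorcycles(frame, prompt="motorcycle")
#   print(boxes.shape)  # e.g. (2, 4) for two detections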

def segment_and_highlight_video(video_path, prompt="motorcycle", highlight_color="red"):
    """Segment and highlight objects matching the prompt in a video using MobileSAM and YOLO."""
    # Get video properties first
    cap = cv2.VideoCapture(video_path)
    original_fps = cap.get(cv2.CAP_PROP_FPS)
    if original_fps <= 0:
        original_fps = 30.0  # Some containers report 0 FPS; fall back to a sane default
    original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Target resolution for CPU-friendly processing
    target_width, target_height = 320, 180

    # Get first frame for detection
    ret, first_frame = cap.read()
    if not ret:
        cap.release()
        raise ValueError("Could not read first frame from video.")

    # Resize first frame for detection
    first_frame_resized = cv2.resize(first_frame, (target_width, target_height))
    cap.release()

    # Detect boxes in the resized first frame
    boxes = detect_motorcycles(first_frame_resized, prompt)
    if len(boxes) == 0:
        return video_path  # Nothing detected, return the original video

    # Boxes are already in the target resolution coordinate system
    print(f"Detected {len(boxes)} objects with boxes: {boxes}")

    # Color map for highlighting (OpenCV uses BGR channel order)
    color_map = {"red": (0, 0, 255), "green": (0, 255, 0), "blue": (255, 0, 0)}
    highlight_bgr = color_map.get(highlight_color.lower(), (0, 0, 255))

    # Process the video frame by frame instead of using SAM's video prediction
    cap = cv2.VideoCapture(video_path)
    output_path = "output.mp4"
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), original_fps, (target_width, target_height))

    frame_count = 0
    max_frames = min(total_frames, 150)  # Limit to 150 frames (~5 seconds at 30 fps)
    print(f"Processing {max_frames} frames...")

    while frame_count < max_frames:
        ret, frame = cap.read()
        if not ret:
            break

        # Resize frame to target resolution
        frame_resized = cv2.resize(frame, (target_width, target_height))

        try:
            # Run SAM on the individual frame with explicit resolution control
            sam_results = sam_model.predict(
                source=frame_resized,
                bboxes=boxes,
                device="cpu",
                imgsz=320,  # Force SAM resolution
                conf=0.25,
                verbose=False
            )

            highlighted_frame = frame_resized.copy()

            # Process SAM results
            if len(sam_results) > 0 and sam_results[0].masks is not None:
                masks = sam_results[0].masks.data.cpu().numpy()
                if len(masks) > 0:
                    # Combine all masks
                    combined_mask = np.any(masks, axis=0).astype(np.uint8)
                    # Create colored overlay
                    overlay = np.zeros_like(frame_resized)
                    overlay[combined_mask == 1] = highlight_bgr
                    # Blend with original frame
                    highlighted_frame = cv2.addWeighted(frame_resized, 0.7, overlay, 0.3, 0)
        except Exception as e:
            print(f"Error processing frame {frame_count}: {e}")
            highlighted_frame = frame_resized

        out.write(highlighted_frame)
        frame_count += 1

        # Progress indicator
        if frame_count % 30 == 0:
            print(f"Processed {frame_count}/{max_frames} frames")

    cap.release()
    out.release()
    print(f"Video processing complete. Output saved to {output_path}")
    return output_path
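
# A sketch of driving the pipeline directly, outside Gradio; the input
# path is a placeholder assumption.
#
#   result = segment_and_highlight_video("input.mp4", prompt="motorcycle", highlight_color="red")
#   print(result)  # "output.mp4" on success, or the input path if nothing was detected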
# Gradio interface
iface = gr.Interface(
    fn=segment_and_highlight_video,
    inputs=[
        gr.Video(label="Upload Video"),
        gr.Textbox(label="Prompt", placeholder="e.g., motorcycle", value="motorcycle"),
        gr.Dropdown(choices=["red", "green", "blue"], label="Highlight Color", value="red")
    ],
    outputs=gr.Video(label="Highlighted Video"),
    title="Video Segmentation with MobileSAM and YOLO (CPU Optimized)",
    description="Upload a short video (5-10 seconds), specify a text prompt (e.g., 'motorcycle'), and choose a highlight color. Uses MobileSAM + YOLO for CPU processing at 320x180 resolution.",
    examples=[
        [None, "motorcycle", "red"],
        [None, "car", "green"],
        [None, "person", "blue"]
    ]
)
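
# Note: the examples above pass None for the video input, so clicking one
# only pre-fills the prompt and color fields; the user still has to upload
# a video before submitting.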

if __name__ == "__main__":
    iface.launch()