import gradio as gr
import cv2
import numpy as np
import torch
from ultralytics import SAM, YOLOWorld
import os
# Initialize models
sam_model = SAM("mobile_sam.pt") # Switch to MobileSAM for faster CPU inference
yolo_model = YOLOWorld("yolov8n-world.pt") # Nano model for faster detection

def detect_motorcycles(first_frame, prompt="motorcycle"):
    """Detect motorcycles in the first frame using YOLO-World and return bounding boxes."""
    yolo_model.set_classes([prompt])
    results = yolo_model.predict(first_frame, device="cpu", max_det=2)  # Limit to 2 detections
    boxes = []
    for result in results:
        boxes.extend(result.boxes.xyxy.cpu().numpy())
    if len(boxes) > 0:
        boxes = np.vstack(boxes)
    else:
        boxes = np.array([])
    return boxes
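
# Illustrative usage (not executed): running detection on a single frame loaded with OpenCV.
# "frame.jpg" is a placeholder path; boxes come back as an (N, 4) array of xyxy pixel coordinates.
#   frame = cv2.imread("frame.jpg")
#   boxes = detect_motorcycles(frame, prompt="motorcycle")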

def segment_and_highlight_video(video_path, prompt="motorcycle", highlight_color="red"):
    """Segment and highlight motorcycles in a video using MobileSAM and YOLO-World."""
    # Grab the first frame for detection
    cap = cv2.VideoCapture(video_path)
    ret, first_frame = cap.read()
    if not ret:
        raise ValueError("Could not read first frame from video.")
    orig_height, orig_width = first_frame.shape[:2]
    # Downscale the first frame so detection stays fast on CPU
    first_frame = cv2.resize(first_frame, (320, 180))
    cap.release()
    # Detect boxes in the downscaled first frame
    boxes = detect_motorcycles(first_frame, prompt)
    if len(boxes) == 0:
        return video_path  # No motorcycles detected, return the original video
    # Scale boxes from the 320x180 detection frame back to the original video resolution,
    # since SAM reads the full-resolution frames from video_path
    scale_x = orig_width / 320
    scale_y = orig_height / 180
    boxes = boxes * [scale_x, scale_y, scale_x, scale_y]
    # Run SAM on the video with the detected boxes as prompts (streamed, low resolution)
    results = sam_model.predict(source=video_path, bboxes=boxes, stream=True, imgsz=320)
    # Prepare the output video
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = 320
    height = 180
    output_path = "output.mp4"
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
    # Color map for highlighting (OpenCV uses BGR channel order)
    color_map = {"red": (0, 0, 255), "green": (0, 255, 0), "blue": (255, 0, 0)}
    highlight_bgr = color_map.get(highlight_color.lower(), (0, 0, 255))
    for result in results:
        # Read the matching frame from the capture opened above,
        # instead of re-opening the video each iteration (which always returned frame 0)
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.resize(frame, (width, height))
        # Get masks for this frame
        if result.masks is not None:
            masks = result.masks.data.cpu().numpy()  # (num_masks, h, w)
            combined_mask = np.any(masks, axis=0).astype(np.uint8)  # binary 0/1 mask
            # Masks may come back at the model's inference size; match the output frame
            combined_mask = cv2.resize(combined_mask, (width, height), interpolation=cv2.INTER_NEAREST)
            mask_colored = np.zeros_like(frame)
            mask_colored[combined_mask > 0] = highlight_bgr
            highlighted_frame = cv2.addWeighted(frame, 0.7, mask_colored, 0.3, 0)
        else:
            highlighted_frame = frame
        out.write(highlighted_frame)
    cap.release()
    out.release()
    return output_path
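
# Illustrative usage (not executed): calling the pipeline directly, without the Gradio UI.
# "sample.mp4" is a placeholder path for a short local clip.
#   out_path = segment_and_highlight_video("sample.mp4", prompt="motorcycle", highlight_color="green")
#   print(f"Highlighted video written to {out_path}")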

# Gradio interface
iface = gr.Interface(
    fn=segment_and_highlight_video,
    inputs=[
        gr.Video(label="Upload Video"),
        gr.Textbox(label="Prompt", placeholder="e.g., motorcycle"),
        gr.Dropdown(choices=["red", "green", "blue"], label="Highlight Color")
    ],
    outputs=gr.Video(label="Highlighted Video"),
    title="Video Segmentation with MobileSAM and YOLO-World (CPU)",
    description="Upload a short video (5-10 seconds), specify a text prompt (e.g., 'motorcycle'), and choose a highlight color. Optimized for CPU."
)

iface.launch()