import gradio as gr
import cv2
import numpy as np
import torch
from ultralytics import SAM, YOLOWorld

# Initialize models
sam_model = SAM("mobile_sam.pt")  # MobileSAM for faster CPU inference
yolo_model = YOLOWorld("yolov8n-world.pt")  # Nano model for faster detection

def detect_motorcycles(first_frame, prompt="motorcycle"):
"""Detect motorcycles in the first frame using YOLO-World and return bounding boxes."""
yolo_model.set_classes([prompt])
results = yolo_model.predict(first_frame, device="cpu", max_det=2) # Limit to 2 detections
boxes = []
for result in results:
boxes.extend(result.boxes.xyxy.cpu().numpy())
if len(boxes) > 0:
boxes = np.vstack(boxes)
else:
boxes = np.array([])
return boxes
def segment_and_highlight_video(video_path, prompt="motorcycle", highlight_color="red"):
"""Segment and highlight motorcycles in a video using SAM 2 and YOLO-World."""
# Get first frame for detection
cap = cv2.VideoCapture(video_path)
ret, first_frame = cap.read()
if not ret:
raise ValueError("Could not read first frame from video.")
# Resize first frame for detection
first_frame = cv2.resize(first_frame, (320, 180))
cap.release()
# Detect boxes in first frame
boxes = detect_motorcycles(first_frame, prompt)
if len(boxes) == 0:
return video_path # No motorcycles detected, return original
# Resize boxes to match SAM input resolution (320x180)
scale_x = 320 / first_frame.shape[1]
scale_y = 180 / first_frame.shape[0]
boxes = boxes * [scale_x, scale_y, scale_x, scale_y]
# Run SAM on video with boxes prompt
results = sam_model.predict(source=video_path, bboxes=boxes, stream=True, imgsz=320) # Stream and low resolution
# Prepare output video
cap = cv2.VideoCapture(video_path)
fps = cap.get(cv2.CAP_PROP_FPS)
width = 320
height = 180
output_path = "output.mp4"
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
# Color map for highlighting
color_map = {"red": (0, 0, 255), "green": (0, 255, 0), "blue": (255, 0, 0)}
highlight_rgb = color_map.get(highlight_color.lower(), (0, 0, 255))
frame_idx = 0
for result in results:
frame = cv2.VideoCapture(video_path).read()[1]
frame = cv2.resize(frame, (width, height))
# Get masks for this frame
if result.masks is not None:
masks = result.masks.data.cpu().numpy() # (num_masks, h, w)
combined_mask = np.any(masks, axis=0).astype(np.uint8) * 255
mask_colored = np.zeros_like(frame)
mask_colored[:, :, 0] = combined_mask * highlight_rgb[0]
mask_colored[:, :, 1] = combined_mask * highlight_rgb[1]
mask_colored[:, :, 2] = combined_mask * highlight_rgb[2]
highlighted_frame = cv2.addWeighted(frame, 0.7, mask_colored, 0.3, 0)
else:
highlighted_frame = frame
out.write(highlighted_frame)
frame_idx += 1
cap.release()
out.release()
return output_path
# Gradio interface
iface = gr.Interface(
    fn=segment_and_highlight_video,
    inputs=[
        gr.Video(label="Upload Video"),
        gr.Textbox(label="Prompt", value="motorcycle", placeholder="e.g., motorcycle"),
        gr.Dropdown(choices=["red", "green", "blue"], value="red", label="Highlight Color"),
    ],
    outputs=gr.Video(label="Highlighted Video"),
    title="Video Segmentation with MobileSAM and YOLO-World (CPU)",
    description="Upload a short video (5-10 seconds), specify a text prompt (e.g., 'motorcycle'), and choose a highlight color. Optimized for CPU.",
)

iface.launch()
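
# A minimal sketch of calling the pipeline directly, without the Gradio UI
# (illustrative only; "sample_clip.mp4" is a placeholder path, not a file
# shipped with this Space):
#
#   result_path = segment_and_highlight_video("sample_clip.mp4", prompt="motorcycle", highlight_color="green")
#   print(f"Highlighted video written to {result_path}")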