import gradio as gr
import cv2
import numpy as np
import torch
from ultralytics import SAM, YOLOWorld

# Initialize models
sam_model = SAM("mobile_sam.pt")  # MobileSAM: a lightweight SAM variant for faster CPU inference
yolo_model = YOLOWorld("yolov8n-world.pt")  # Nano YOLO-World model for fast open-vocabulary detection
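# Both checkpoints are resolved by ultralytics at load time; for its published
# weights they are downloaded automatically on first use if not found locally.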

def detect_motorcycles(first_frame, prompt="motorcycle"):
    """Detect motorcycles in the first frame using YOLO-World and return bounding boxes."""
    yolo_model.set_classes([prompt])
    results = yolo_model.predict(first_frame, device="cpu", max_det=2)  # Limit to 2 detections
    boxes = []
    for result in results:
        boxes.extend(result.boxes.xyxy.cpu().numpy())
    if len(boxes) > 0:
        boxes = np.vstack(boxes)  # (N, 4) array of [x1, y1, x2, y2]
    else:
        boxes = np.empty((0, 4))  # empty but correctly shaped, so callers can test len()
    return boxes
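# Example (hypothetical values): for a frame containing one motorcycle, detect_motorcycles
# might return np.array([[12.0, 40.5, 150.2, 170.9]]) -- pixel coords in the input frame.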

def segment_and_highlight_video(video_path, prompt="motorcycle", highlight_color="red"):
    """Segment and highlight motorcycles in a video using SAM 2 and YOLO-World."""
    # Get first frame for detection
    cap = cv2.VideoCapture(video_path)
    ret, first_frame = cap.read()
    if not ret:
        raise ValueError("Could not read first frame from video.")
    # Keep the original resolution so boxes can be rescaled later, then downsize for detection
    orig_h, orig_w = first_frame.shape[:2]
    first_frame = cv2.resize(first_frame, (320, 180))
    cap.release()

    # Detect boxes in first frame
    boxes = detect_motorcycles(first_frame, prompt)

    if len(boxes) == 0:
        return video_path  # Nothing matched the prompt; return the original video

    # Boxes were detected on the 320x180 frame; scale them back up to the original
    # video resolution, since SAM receives the full-size video frames as input
    scale_x = orig_w / 320
    scale_y = orig_h / 180
    boxes = boxes * [scale_x, scale_y, scale_x, scale_y]

    # Run SAM over the full video, prompted with the detected boxes
    results = sam_model.predict(source=video_path, bboxes=boxes, stream=True, imgsz=320)  # low inference resolution for CPU
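    # With stream=True, predict() returns a generator, so masks are produced
    # frame-by-frame as the loop below consumes them instead of all at once.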

    # Prepare output video
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # fall back to 30 fps if the container reports none
    width = 320
    height = 180
    output_path = "output.mp4"
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))

    # Color map for highlighting (BGR channel order, as OpenCV expects)
    color_map = {"red": (0, 0, 255), "green": (0, 255, 0), "blue": (255, 0, 0)}
    highlight_rgb = color_map.get(highlight_color.lower(), (0, 0, 255))

    for result in results:
        # Read frames from the already-open capture, in step with the streamed SAM results
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.resize(frame, (width, height))

        # Overlay the predicted masks for this frame
        if result.masks is not None:
            masks = result.masks.data.cpu().numpy()  # (num_masks, h, w)
            combined_mask = np.any(masks, axis=0).astype(np.uint8)
            # Masks may come back at the inference resolution; resize to the output size
            combined_mask = cv2.resize(combined_mask, (width, height), interpolation=cv2.INTER_NEAREST)
            mask_colored = np.zeros_like(frame)
            mask_colored[combined_mask == 1] = highlight_rgb  # paint masked pixels in the chosen color
            highlighted_frame = cv2.addWeighted(frame, 0.7, mask_colored, 0.3, 0)
        else:
            highlighted_frame = frame

        out.write(highlighted_frame)

    cap.release()
    out.release()

    return output_path
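# Note: "mp4v" output plays in most desktop players; if the Gradio video widget
# fails to render it in-browser, re-encoding to H.264 (e.g. with ffmpeg) is a
# common workaround.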

# Gradio interface
iface = gr.Interface(
    fn=segment_and_highlight_video,
    inputs=[
        gr.Video(label="Upload Video"),
        gr.Textbox(label="Prompt", value="motorcycle", placeholder="e.g., motorcycle"),
        gr.Dropdown(choices=["red", "green", "blue"], value="red", label="Highlight Color")
    ],
    outputs=gr.Video(label="Highlighted Video"),
    title="Video Segmentation with MobileSAM and YOLO-World (CPU)",
    description="Upload a short video (5-10 seconds), specify a text prompt (e.g., 'motorcycle'), and choose a highlight color. Optimized for CPU."
)
iface.launch()
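# To try it locally (assuming this script is saved as app.py): run `python app.py`
# and open the local URL that Gradio prints.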