rdjarbeng committed
Commit 483fb8b · verified · parent 87560e7

Create app.py

Files changed (1): app.py (+109 -0)
app.py ADDED
@@ -0,0 +1,109 @@
import gradio as gr
import cv2
import numpy as np
import torch
from sam2.sam2_video_predictor import SAM2VideoPredictor
from ultralytics import YOLO
import supervision as sv
import os

# Initialize models: SAM 2 from the Hugging Face Hub, YOLO-World from Ultralytics assets
predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2.1-hiera-tiny")
yolo_model = YOLO("yolov8s-world.pt")  # lightweight open-vocabulary YOLO-World checkpoint

def detect_motorcycles(frame, prompt="motorcycle"):
    """Detect objects matching the text prompt with YOLO-World; return xyxy boxes."""
    yolo_model.set_classes([prompt])  # point the open-vocabulary detector at the prompt
    results = yolo_model.predict(frame, device="cpu", verbose=False)
    boxes = []
    for result in results:
        for box in result.boxes:
            if result.names[int(box.cls)] == prompt:
                x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                boxes.append([x1, y1, x2, y2])
    return boxes

def segment_and_highlight_video(video_path, prompt="motorcycle", highlight_color="red"):
    """Segment and highlight objects matching the prompt in a video using SAM 2 and YOLO-World."""
    # Create temporary directory for video frames
    frames_dir = "video_frames"
    os.makedirs(frames_dir, exist_ok=True)

    # Extract frames
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Limit resolution to keep CPU inference tractable
    if width > 640:
        height = int(height * 640 / width)
        width = 640
    frame_paths = []

    # Save frames as JPEG (SAM 2's video predictor reads a directory of frames)
    frame_idx = 0
    with sv.ImageSink(target_dir_path=frames_dir, image_name_pattern="{:05d}.jpeg") as sink:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.resize(frame, (width, height))
            sink.save_image(frame)
            frame_paths.append(os.path.join(frames_dir, f"{frame_idx:05d}.jpeg"))
            frame_idx += 1
    cap.release()

    # Initialize SAM 2 inference state over the extracted frames
    with torch.inference_mode():
        state = predictor.init_state(video_path=frames_dir)

    # Detect objects matching the prompt in the first frame
    first_frame = cv2.imread(frame_paths[0])
    boxes = detect_motorcycles(first_frame, prompt)
    if not boxes:
        # Nothing matched the prompt; clean up and report rather than emit an empty video
        for frame_path in frame_paths:
            os.remove(frame_path)
        os.rmdir(frames_dir)
        raise gr.Error(f"No '{prompt}' detected in the first frame.")

    # Add each detected box as a SAM 2 prompt, one object ID per box
    with torch.inference_mode():
        for obj_id, box in enumerate(boxes, start=1):
            predictor.add_new_points_or_box(
                state, frame_idx=0, obj_id=obj_id, box=np.array(box, dtype=np.float32)
            )

    # Create output video
    output_path = "output.mp4"
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))

    # Color map in BGR order, matching OpenCV frames
    color_map = {"red": (0, 0, 255), "green": (0, 255, 0), "blue": (255, 0, 0)}
    highlight_bgr = color_map.get(highlight_color.lower(), (0, 0, 255))

    # Propagate masks through the video and blend in the highlight color
    with torch.inference_mode():
        for frame_idx, obj_ids, masks in predictor.propagate_in_video(state):
            frame = cv2.imread(frame_paths[frame_idx])
            # masks holds one logit mask per tracked object; threshold and combine them
            combined = np.zeros((height, width), dtype=bool)
            for mask in masks:
                combined |= (mask > 0.0).cpu().numpy().squeeze()
            mask_colored = np.zeros_like(frame)
            mask_colored[combined] = highlight_bgr
            highlighted_frame = cv2.addWeighted(frame, 0.7, mask_colored, 0.3, 0)
            out.write(highlighted_frame)

    out.release()

    # Clean up extracted frames
    for frame_path in frame_paths:
        os.remove(frame_path)
    os.rmdir(frames_dir)

    return output_path

# Gradio interface
iface = gr.Interface(
    fn=segment_and_highlight_video,
    inputs=[
        gr.Video(label="Upload Video"),
        gr.Textbox(label="Prompt", value="motorcycle", placeholder="e.g., motorcycle"),
        gr.Dropdown(choices=["red", "green", "blue"], value="red", label="Highlight Color"),
    ],
    outputs=gr.Video(label="Highlighted Video"),
    title="Video Segmentation with SAM 2 and YOLO-World (CPU)",
    description="Upload a short video, specify a text prompt (e.g., 'motorcycle'), and choose a highlight color. Optimized for CPU."
)
iface.launch()
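
For a quick check outside the web UI, a minimal smoke-test sketch follows. It assumes segment_and_highlight_video is in scope (for example, run in the same session as app.py, since importing app.py as a module would also execute iface.launch()) and that a short clip named sample.mp4 exists; the filename is illustrative.

# Smoke test: run the pipeline directly and print where the output landed.
# Assumes segment_and_highlight_video is already defined in the session and
# that sample.mp4 is a short local clip (illustrative name).
out_path = segment_and_highlight_video(
    "sample.mp4", prompt="motorcycle", highlight_color="green"
)
print(f"Highlighted video written to {out_path}")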