rdjarbeng committed
Commit 7a3f21a · verified · 1 Parent(s): e5ea4ad

Limit frames, optimize code


SAM Resolution Problem: the original code called sam_model.predict(source=video_path, ...), which processes the entire video at SAM's default resolution (1024). This is changed to process individual frames with an explicit imgsz=320 parameter.
Inefficient Frame Processing: the original code opened a new VideoCapture for every frame in the loop (cv2.VideoCapture(video_path).read()[1]), which is extremely inefficient; frames are now read sequentially from a single capture (see the sketch after this list).
Missing Resolution Control for YOLO: added imgsz=320 to the YOLO prediction to ensure consistent resolution.
Box Scaling Issues: removed the unnecessary box scaling, since detection and segmentation now work consistently at the target resolution.
Memory Leaks: fixed VideoCapture resource management.
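
For context, a minimal sketch of the frame-reading pattern this commit moves to, using only OpenCV. This is illustrative, not code from app.py: read_frames and "sample.mp4" are assumed names, and the try/finally mirrors the resource-management fix so the capture is released even if processing fails mid-video.

```python
import cv2

def read_frames(video_path, max_frames=150, size=(320, 180)):
    """Yield resized frames from a single VideoCapture instead of reopening the file per frame."""
    cap = cv2.VideoCapture(video_path)
    try:
        count = 0
        while count < max_frames:
            ret, frame = cap.read()  # sequential read from one open capture
            if not ret:
                break  # end of stream reached before the frame limit
            yield cv2.resize(frame, size)
            count += 1
    finally:
        cap.release()  # released even if the consumer raises mid-iteration

# Hypothetical usage; "sample.mp4" is an assumed local file.
for frame in read_frames("sample.mp4"):
    pass  # run detection/segmentation per frame here
```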

Files changed (1)
app.py +96 -51
app.py CHANGED
```diff
@@ -12,10 +12,12 @@ yolo_model = YOLOWorld("yolov8n-world.pt")  # Nano model for faster detection
 def detect_motorcycles(first_frame, prompt="motorcycle"):
     """Detect motorcycles in the first frame using YOLO-World and return bounding boxes."""
     yolo_model.set_classes([prompt])
-    results = yolo_model.predict(first_frame, device="cpu", max_det=2)  # Limit to 2 detections
+    results = yolo_model.predict(first_frame, device="cpu", max_det=2, imgsz=320)  # Force YOLO to use 320 resolution
     boxes = []
     for result in results:
-        boxes.extend(result.boxes.xyxy.cpu().numpy())
+        if result.boxes is not None and len(result.boxes.xyxy) > 0:
+            boxes.extend(result.boxes.xyxy.cpu().numpy())
+
     if len(boxes) > 0:
         boxes = np.vstack(boxes)
     else:
@@ -24,64 +26,100 @@ def detect_motorcycles(first_frame, prompt="motorcycle"):
 
 def segment_and_highlight_video(video_path, prompt="motorcycle", highlight_color="red"):
     """Segment and highlight motorcycles in a video using SAM 2 and YOLO-World."""
-    # Get first frame for detection
+
+    # Get video properties first
     cap = cv2.VideoCapture(video_path)
+    original_fps = cap.get(cv2.CAP_PROP_FPS)
+    original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    # Target resolution
+    target_width, target_height = 320, 180
+
+    # Get first frame for detection
     ret, first_frame = cap.read()
     if not ret:
+        cap.release()
         raise ValueError("Could not read first frame from video.")
+
     # Resize first frame for detection
-    first_frame = cv2.resize(first_frame, (320, 180))
+    first_frame_resized = cv2.resize(first_frame, (target_width, target_height))
     cap.release()
-
-    # Detect boxes in first frame
-    boxes = detect_motorcycles(first_frame, prompt)
-
+
+    # Detect boxes in resized first frame
+    boxes = detect_motorcycles(first_frame_resized, prompt)
     if len(boxes) == 0:
         return video_path  # No motorcycles detected, return original
-
-    # Resize boxes to match SAM input resolution (320x180)
-    scale_x = 320 / first_frame.shape[1]
-    scale_y = 180 / first_frame.shape[0]
-    boxes = boxes * [scale_x, scale_y, scale_x, scale_y]
-
-    # Run SAM on video with boxes prompt
-    results = sam_model.predict(source=video_path, bboxes=boxes, stream=True, imgsz=320)  # Stream and low resolution
-
-    # Prepare output video
-    cap = cv2.VideoCapture(video_path)
-    fps = cap.get(cv2.CAP_PROP_FPS)
-    width = 320
-    height = 180
-    output_path = "output.mp4"
-    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
-
+
+    # Boxes are already in the target resolution coordinate system
+    print(f"Detected {len(boxes)} objects with boxes: {boxes}")
+
     # Color map for highlighting
     color_map = {"red": (0, 0, 255), "green": (0, 255, 0), "blue": (255, 0, 0)}
     highlight_rgb = color_map.get(highlight_color.lower(), (0, 0, 255))
-
-    frame_idx = 0
-    for result in results:
-        frame = cv2.VideoCapture(video_path).read()[1]
-        frame = cv2.resize(frame, (width, height))
-
-        # Get masks for this frame
-        if result.masks is not None:
-            masks = result.masks.data.cpu().numpy()  # (num_masks, h, w)
-            combined_mask = np.any(masks, axis=0).astype(np.uint8) * 255
-            mask_colored = np.zeros_like(frame)
-            mask_colored[:, :, 0] = combined_mask * highlight_rgb[0]
-            mask_colored[:, :, 1] = combined_mask * highlight_rgb[1]
-            mask_colored[:, :, 2] = combined_mask * highlight_rgb[2]
-            highlighted_frame = cv2.addWeighted(frame, 0.7, mask_colored, 0.3, 0)
-        else:
-            highlighted_frame = frame
-
+
+    # Process video frame by frame instead of using SAM's video prediction
+    cap = cv2.VideoCapture(video_path)
+    output_path = "output.mp4"
+    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), original_fps, (target_width, target_height))
+
+    frame_count = 0
+    max_frames = min(total_frames, 150)  # Limit to 150 frames (~5 seconds at 30fps)
+
+    print(f"Processing {max_frames} frames...")
+
+    while frame_count < max_frames:
+        ret, frame = cap.read()
+        if not ret:
+            break
+
+        # Resize frame to target resolution
+        frame_resized = cv2.resize(frame, (target_width, target_height))
+
+        try:
+            # Run SAM on individual frame with explicit resolution control
+            sam_results = sam_model.predict(
+                source=frame_resized,
+                bboxes=boxes,
+                device="cpu",
+                imgsz=320,  # Force SAM resolution
+                conf=0.25,
+                verbose=False
+            )
+
+            highlighted_frame = frame_resized.copy()
+
+            # Process SAM results
+            if len(sam_results) > 0 and sam_results[0].masks is not None:
+                masks = sam_results[0].masks.data.cpu().numpy()
+
+                if len(masks) > 0:
+                    # Combine all masks
+                    combined_mask = np.any(masks, axis=0).astype(np.uint8)
+
+                    # Create colored overlay
+                    overlay = np.zeros_like(frame_resized)
+                    overlay[combined_mask == 1] = highlight_rgb
+
+                    # Blend with original frame
+                    highlighted_frame = cv2.addWeighted(frame_resized, 0.7, overlay, 0.3, 0)
+
+        except Exception as e:
+            print(f"Error processing frame {frame_count}: {e}")
+            highlighted_frame = frame_resized
+
         out.write(highlighted_frame)
-        frame_idx += 1
-
+        frame_count += 1
+
+        # Progress indicator
+        if frame_count % 30 == 0:
+            print(f"Processed {frame_count}/{max_frames} frames")
+
     cap.release()
     out.release()
-
+
+    print(f"Video processing complete. Output saved to {output_path}")
     return output_path
 
 # Gradio interface
@@ -89,11 +127,18 @@ iface = gr.Interface(
     fn=segment_and_highlight_video,
     inputs=[
         gr.Video(label="Upload Video"),
-        gr.Textbox(label="Prompt", placeholder="e.g., motorcycle"),
-        gr.Dropdown(choices=["red", "green", "blue"], label="Highlight Color")
+        gr.Textbox(label="Prompt", placeholder="e.g., motorcycle", value="motorcycle"),
+        gr.Dropdown(choices=["red", "green", "blue"], label="Highlight Color", value="red")
     ],
     outputs=gr.Video(label="Highlighted Video"),
-    title="Video Segmentation with MobileSAM and YOLO-World (CPU)",
-    description="Upload a short video (5-10 seconds), specify a text prompt (e.g., 'motorcycle'), and choose a highlight color. Optimized for CPU."
+    title="Video Segmentation with MobileSAM and YOLO-World (CPU Optimized)",
+    description="Upload a short video (5-10 seconds), specify a text prompt (e.g., 'motorcycle'), and choose a highlight color. Optimized for CPU with 320x180 resolution.",
+    examples=[
+        [None, "motorcycle", "red"],
+        [None, "car", "green"],
+        [None, "person", "blue"]
+    ]
 )
-iface.launch()
+
+if __name__ == "__main__":
+    iface.launch()
```
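
As a quick sanity check on the effect of these changes, the written file can be inspected with OpenCV; a minimal sketch, assuming "output.mp4" exists in the working directory after a run:

```python
import cv2

# Confirm the frame limit and target resolution took effect in the output file.
cap = cv2.VideoCapture("output.mp4")
try:
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))    # expected: 320
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))  # expected: 180
    frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))   # expected: at most 150
    print(f"{width}x{height}, {frames} frames")
finally:
    cap.release()
```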