rdjarbeng committed on
Commit
e5ea4ad
·
verified ·
1 Parent(s): 5ada4a5

Use MobileSAM to reduce processing time

Browse files
Files changed (1) hide show
  1. app.py +21 -21
app.py CHANGED
@@ -6,18 +6,18 @@ from ultralytics import SAM, YOLOWorld
6
  import os
7
 
8
  # Initialize models
9
- sam_model = SAM("sam2.1_t.pt") # SAM 2.1 tiny model, no device argument
10
- yolo_model = YOLOWorld("yolov8s-world.pt") # Lightweight YOLO-World model
11
 
12
  def detect_motorcycles(first_frame, prompt="motorcycle"):
13
  """Detect motorcycles in the first frame using YOLO-World and return bounding boxes."""
14
  yolo_model.set_classes([prompt])
15
- results = yolo_model.predict(first_frame, device="cpu")
16
  boxes = []
17
  for result in results:
18
  boxes.extend(result.boxes.xyxy.cpu().numpy())
19
  if len(boxes) > 0:
20
- boxes = np.vstack(boxes) # Stack all boxes if multiple results
21
  else:
22
  boxes = np.array([])
23
  return boxes
@@ -29,6 +29,8 @@ def segment_and_highlight_video(video_path, prompt="motorcycle", highlight_color
29
  ret, first_frame = cap.read()
30
  if not ret:
31
  raise ValueError("Could not read first frame from video.")
 
 
32
  cap.release()
33
 
34
  # Detect boxes in first frame
@@ -37,19 +39,19 @@ def segment_and_highlight_video(video_path, prompt="motorcycle", highlight_color
37
  if len(boxes) == 0:
38
  return video_path # No motorcycles detected, return original
39
 
40
- # Run SAM2 on video with boxes prompt
41
- results = sam_model.predict(source=video_path, bboxes=boxes)
 
 
 
 
 
42
 
43
  # Prepare output video
44
  cap = cv2.VideoCapture(video_path)
45
  fps = cap.get(cv2.CAP_PROP_FPS)
46
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
47
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
48
- # Limit resolution for CPU
49
- if width > 640:
50
- scale = 640 / width
51
- width = 640
52
- height = int(height * scale)
53
  output_path = "output.mp4"
54
  out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
55
 
@@ -58,16 +60,14 @@ def segment_and_highlight_video(video_path, prompt="motorcycle", highlight_color
58
  highlight_rgb = color_map.get(highlight_color.lower(), (0, 0, 255))
59
 
60
  frame_idx = 0
61
- while cap.isOpened():
62
- ret, frame = cap.read()
63
- if not ret:
64
- break
65
  frame = cv2.resize(frame, (width, height))
66
 
67
  # Get masks for this frame
68
- if results[frame_idx].masks is not None:
69
- masks = results[frame_idx].masks.data.cpu().numpy() # (num_masks, h, w)
70
- combined_mask = np.any(masks, axis=0).astype(np.uint8) * 255 # Combine all masks
71
  mask_colored = np.zeros_like(frame)
72
  mask_colored[:, :, 0] = combined_mask * highlight_rgb[0]
73
  mask_colored[:, :, 1] = combined_mask * highlight_rgb[1]
@@ -93,7 +93,7 @@ iface = gr.Interface(
93
  gr.Dropdown(choices=["red", "green", "blue"], label="Highlight Color")
94
  ],
95
  outputs=gr.Video(label="Highlighted Video"),
96
- title="Video Segmentation with SAM 2 and YOLO-World (CPU)",
97
  description="Upload a short video (5-10 seconds), specify a text prompt (e.g., 'motorcycle'), and choose a highlight color. Optimized for CPU."
98
  )
99
  iface.launch()
 
6
  import os
7
 
8
  # Initialize models
9
+ sam_model = SAM("mobile_sam.pt") # Switch to MobileSAM for faster CPU inference
10
+ yolo_model = YOLOWorld("yolov8n-world.pt") # Nano model for faster detection
11
 
12
  def detect_motorcycles(first_frame, prompt="motorcycle"):
13
  """Detect motorcycles in the first frame using YOLO-World and return bounding boxes."""
14
  yolo_model.set_classes([prompt])
15
+ results = yolo_model.predict(first_frame, device="cpu", max_det=2) # Limit to 2 detections
16
  boxes = []
17
  for result in results:
18
  boxes.extend(result.boxes.xyxy.cpu().numpy())
19
  if len(boxes) > 0:
20
+ boxes = np.vstack(boxes)
21
  else:
22
  boxes = np.array([])
23
  return boxes
 
29
  ret, first_frame = cap.read()
30
  if not ret:
31
  raise ValueError("Could not read first frame from video.")
32
+ # Resize first frame for detection
33
+ first_frame = cv2.resize(first_frame, (320, 180))
34
  cap.release()
35
 
36
  # Detect boxes in first frame
 
39
  if len(boxes) == 0:
40
  return video_path # No motorcycles detected, return original
41
 
42
+ # Resize boxes to match SAM input resolution (320x180)
43
+ scale_x = 320 / first_frame.shape[1]
44
+ scale_y = 180 / first_frame.shape[0]
45
+ boxes = boxes * [scale_x, scale_y, scale_x, scale_y]
46
+
47
+ # Run SAM on video with boxes prompt
48
+ results = sam_model.predict(source=video_path, bboxes=boxes, stream=True, imgsz=320) # Stream and low resolution
49
 
50
  # Prepare output video
51
  cap = cv2.VideoCapture(video_path)
52
  fps = cap.get(cv2.CAP_PROP_FPS)
53
+ width = 320
54
+ height = 180
 
 
 
 
 
55
  output_path = "output.mp4"
56
  out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
57
 
 
60
  highlight_rgb = color_map.get(highlight_color.lower(), (0, 0, 255))
61
 
62
  frame_idx = 0
63
+ for result in results:
64
+ frame = cv2.VideoCapture(video_path).read()[1]
 
 
65
  frame = cv2.resize(frame, (width, height))
66
 
67
  # Get masks for this frame
68
+ if result.masks is not None:
69
+ masks = result.masks.data.cpu().numpy() # (num_masks, h, w)
70
+ combined_mask = np.any(masks, axis=0).astype(np.uint8) * 255
71
  mask_colored = np.zeros_like(frame)
72
  mask_colored[:, :, 0] = combined_mask * highlight_rgb[0]
73
  mask_colored[:, :, 1] = combined_mask * highlight_rgb[1]
 
93
  gr.Dropdown(choices=["red", "green", "blue"], label="Highlight Color")
94
  ],
95
  outputs=gr.Video(label="Highlighted Video"),
96
+ title="Video Segmentation with MobileSAM and YOLO-World (CPU)",
97
  description="Upload a short video (5-10 seconds), specify a text prompt (e.g., 'motorcycle'), and choose a highlight color. Optimized for CPU."
98
  )
99
  iface.launch()