rdjarbeng committed
Commit 801bb4d (verified) · Parent: 93c5c43

Try to process individual frames to fix error
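In outline, the commit drops the streaming SAM2VideoPredictor state (init_state / add_new_points_or_box / propagate_in_video) and the supervision-based frame extraction, and instead prompts a plain SAM model once with YOLO-World boxes detected in the first frame, then overlays the per-frame masks it returns. A minimal sketch of that detect-then-segment flow, using the same weights as app.py; "input.mp4" is a placeholder path, and treating one SAM call over a video as per-frame segmentation with fixed box prompts mirrors what this commit tries rather than documented tracking behavior:

from ultralytics import SAM, YOLOWorld
import cv2

video = "input.mp4"  # placeholder input path

# Open-vocabulary detection on the first frame only
yolo = YOLOWorld("yolov8s-world.pt")
yolo.set_classes(["motorcycle"])
cap = cv2.VideoCapture(video)
ok, first_frame = cap.read()
cap.release()
boxes = yolo.predict(first_frame, device="cpu")[0].boxes.xyxy.cpu().numpy()

# Prompt SAM 2 with those boxes; ultralytics returns one Results object per frame
sam = SAM("sam2.1_t.pt")
results = sam(source=video, bboxes=boxes)
for r in results:
    if r.masks is not None:
        print(r.masks.data.shape)  # (num_objects, H, W) mask tensor for this frame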

Files changed (1):
  1. app.py +52 -65
app.py CHANGED
@@ -2,98 +2,85 @@ import gradio as gr
 import cv2
 import numpy as np
 import torch
-from ultralytics.models.sam import SAM2VideoPredictor
-from ultralytics import YOLOWorld
-import supervision as sv
+from ultralytics import SAM, YOLOWorld
 import os
 
 # Initialize models
-overrides = dict(model="sam2.1_t.pt", device="cpu")
-predictor = SAM2VideoPredictor(overrides=overrides)
+sam_model = SAM("sam2.1_t.pt", device="cpu")
 yolo_model = YOLOWorld("yolov8s-world.pt")  # Lightweight YOLO-World model
 
-def detect_motorcycles(frame, prompt="motorcycle"):
-    """Detect motorcycles in a frame using YOLO-World and return bounding boxes."""
+def detect_motorcycles(first_frame, prompt="motorcycle"):
+    """Detect motorcycles in the first frame using YOLO-World and return bounding boxes."""
     yolo_model.set_classes([prompt])
-    results = yolo_model.predict(frame, device="cpu")
+    results = yolo_model.predict(first_frame, device="cpu")
     boxes = []
     for result in results:
-        for box in result.boxes:
-            # Check if the detected class matches the prompt
-            if result.names[int(box.cls)] == prompt:
-                x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
-                boxes.append([x1, y1, x2, y2])
+        boxes.append(result.boxes.xyxy.cpu().numpy())
+    if len(boxes) > 0:
+        boxes = np.vstack(boxes)  # Stack all boxes if multiple results
+    else:
+        boxes = np.array([])
     return boxes
 
 def segment_and_highlight_video(video_path, prompt="motorcycle", highlight_color="red"):
     """Segment and highlight motorcycles in a video using SAM 2 and YOLO-World."""
-    # Create temporary directory for video frames
-    frames_dir = "video_frames"
-    os.makedirs(frames_dir, exist_ok=True)
-
-    # Extract frames
+    # Get first frame for detection
+    cap = cv2.VideoCapture(video_path)
+    ret, first_frame = cap.read()
+    if not ret:
+        raise ValueError("Could not read first frame from video.")
+    cap.release()
+
+    # Detect boxes in first frame
+    boxes = detect_motorcycles(first_frame, prompt)
+
+    if len(boxes) == 0:
+        return video_path  # No motorcycles detected, return original
+
+    # Run SAM2 on video with boxes prompt
+    results = sam_model(source=video_path, bboxes=boxes)
+
+    # Prepare output video
     cap = cv2.VideoCapture(video_path)
     fps = cap.get(cv2.CAP_PROP_FPS)
     width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
     height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
     # Limit resolution for CPU
     if width > 640:
-        height = int(height * 640 / width)
+        scale = 640 / width
         width = 640
-    frame_paths = []
-
-    # Save frames as JPEG
-    frame_idx = 0
-    with sv.ImageSink(target_dir_path=frames_dir, image_name_pattern="{:05d}.jpeg") as sink:
-        while cap.isOpened():
-            ret, frame = cap.read()
-            if not ret:
-                break
-            frame = cv2.resize(frame, (width, height))
-            sink.save_image(frame)
-            frame_paths.append(os.path.join(frames_dir, f"{frame_idx:05d}.jpeg"))
-            frame_idx += 1
-    cap.release()
-
-    # Initialize SAM 2 inference state
-    with torch.inference_mode():
-        state = predictor.init_state(video_path=frames_dir)
-
-        # Detect motorcycles in the first frame
-        first_frame = cv2.imread(frame_paths[0])
-        boxes = detect_motorcycles(first_frame, prompt)
-
-        # Add boxes as prompts for SAM 2
-        if boxes:
-            frame_idx, obj_ids, masks = predictor.add_new_points_or_box(
-                state, frame_idx=0, obj_ids=[1], boxes=np.array(boxes)
-            )
-
-        # Create output video
-        output_path = "output.mp4"
-        out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
-
-        # Color map for highlighting
-        color_map = {"red": (0, 0, 255), "green": (0, 255, 0), "blue": (255, 0, 0)}
-        highlight_rgb = color_map.get(highlight_color.lower(), (0, 0, 255))
-
-        # Propagate masks and apply highlights
-        for frame_idx, obj_ids, masks in predictor.propagate_in_video(state):
-            frame = cv2.imread(frame_paths[frame_idx])
-            mask = masks[0].astype(np.uint8) * 255  # Assuming one object
+        height = int(height * scale)
+    output_path = "output.mp4"
+    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
+
+    # Color map for highlighting
+    color_map = {"red": (0, 0, 255), "green": (0, 255, 0), "blue": (255, 0, 0)}
+    highlight_rgb = color_map.get(highlight_color.lower(), (0, 0, 255))
+
+    frame_idx = 0
+    while cap.isOpened():
+        ret, frame = cap.read()
+        if not ret:
+            break
+        frame = cv2.resize(frame, (width, height))
+
+        # Get masks for this frame
+        if results[frame_idx].masks is not None:
+            masks = results[frame_idx].masks.data.cpu().numpy()  # (num_masks, h, w)
+            combined_mask = np.any(masks, axis=0).astype(np.uint8) * 255  # Combine all masks
             mask_colored = np.zeros_like(frame)
-            mask_colored[:, :, 0] = mask * highlight_rgb[0]
-            mask_colored[:, :, 1] = mask * highlight_rgb[1]
-            mask_colored[:, :, 2] = mask * highlight_rgb[2]
+            mask_colored[:, :, 0] = combined_mask * highlight_rgb[0]
+            mask_colored[:, :, 1] = combined_mask * highlight_rgb[1]
+            mask_colored[:, :, 2] = combined_mask * highlight_rgb[2]
             highlighted_frame = cv2.addWeighted(frame, 0.7, mask_colored, 0.3, 0)
-            out.write(highlighted_frame)
-
-    out.release()
-
-    # Clean up
-    for frame_path in frame_paths:
-        os.remove(frame_path)
-    os.rmdir(frames_dir)
+        else:
+            highlighted_frame = frame
+
+        out.write(highlighted_frame)
+        frame_idx += 1
+
+    cap.release()
+    out.release()
 
     return output_path
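What the diff does not show is the Gradio side of app.py that calls segment_and_highlight_video. A hypothetical minimal wiring, consistent with the import gradio as gr at the top of the file (the component choices here are assumptions, not the app's actual UI):

import gradio as gr

# Hypothetical UI wiring; the real interface code in app.py is outside this diff
demo = gr.Interface(
    fn=segment_and_highlight_video,  # defined above; receives the uploaded file's path
    inputs=[
        gr.Video(label="Input video"),
        gr.Textbox(value="motorcycle", label="Object prompt"),
        gr.Dropdown(["red", "green", "blue"], value="red", label="Highlight color"),
    ],
    outputs=gr.Video(label="Highlighted video"),
)

if __name__ == "__main__":
    demo.launch()

gr.Video hands fn a file path, which matches the video_path parameter, and the output.mp4 path the function returns feeds the output component.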