John Ho commited on
Commit
e7334c8
·
1 Parent(s): df7d2e0

init comit

Browse files
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces, torch
3
+ from samv2_handler import load_sam_image_model, run_sam_im_inference
4
+ from PIL import Image
5
+ from typing import Union
6
+
7
+
8
+ @spaces.GPU
9
+ def load_im_model(variant, auto_mask_gen: bool = False):
10
+ return load_sam_image_model(
11
+ variant=variant, device="cuda", auto_mask_gen=auto_mask_gen
12
+ )
13
+
14
+
15
+ @spaces.GPU
16
+ def detect_image(
17
+ im: Image.Image,
18
+ variant: str,
19
+ bboxes: Union[list, str] = None,
20
+ points: Union[list, str] = None,
21
+ point_labels: Union[list, str] = None,
22
+ ):
23
+ """
24
+ SAM2 Image Segmentation
25
+
26
+ Args:
27
+ im: Pillow Image
28
+ object_name: the object you would like to detect
29
+ mode: point or object_detection
30
+ Returns:
31
+ list: a list of masks
32
+ """
33
+ bboxes = json.loads(bboxes) if isinstance(bboxes, str) else bboxes
34
+ model = load_im_model(variant=variant)
35
+ return run_sam_im_inference(
36
+ model, image=im, bboxes=bboxes, get_pil_mask=False, b64_encode_mask=True
37
+ )
38
+
39
+
40
+ with gr.Blocks() as demo:
41
+ with gr.Tab("Images"):
42
+ gr.Interface(
43
+ fn=detect_image,
44
+ inputs=[
45
+ gr.Image(label="Input Image", type="pil"),
46
+ gr.Dropdown(
47
+ label="Model Variant",
48
+ choices=["tiny", "small", "base_plus", "large"],
49
+ ),
50
+ gr.JSON(
51
+ label='Bounding Boxes (JSON list of dicts: [{"x0":..., "y0":..., "x1":..., "y1":...}, ...])',
52
+ optional=True,
53
+ ),
54
+ ],
55
+ outputs=gr.JSON(label="Output JSON"),
56
+ title="SAM2 for Images",
57
+ )
58
+ demo.launch(
59
+ mcp_server=True, app_kwargs={"docs_url": "/docs"} # add FastAPI Swagger API Docs
60
+ )
ffmpeg_extractor.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ffmpeg, typer, os, sys, json, shutil
2
+ from loguru import logger
3
+
4
+ logger.remove()
5
+ logger.add(
6
+ sys.stderr,
7
+ format="<d>{time:YYYY-MM-DD ddd HH:mm:ss}</d> | <lvl>{level}</lvl> | <lvl>{message}</lvl>",
8
+ )
9
+ app = typer.Typer(pretty_exceptions_show_locals=False)
10
+
11
+
12
+ def parse_frame_name(fname: str):
13
+ """return a tuple of frame_type and frame_index"""
14
+ fn, fext = os.path.splitext(os.path.basename(fname))
15
+ frame_type, frame_index = fn.split("_")
16
+ return frame_type, int(frame_index)
17
+
18
+
19
+ def get_fps_ffmpeg(video_path: str):
20
+ probe = ffmpeg.probe(video_path)
21
+ # Find the first video stream
22
+ video_stream = next(
23
+ (stream for stream in probe["streams"] if stream["codec_type"] == "video"), None
24
+ )
25
+ if video_stream is None:
26
+ raise ValueError("No video stream found")
27
+ # Frame rate is given as a string fraction, e.g., '30000/1001'
28
+ r_frame_rate = video_stream["r_frame_rate"]
29
+ num, denom = map(int, r_frame_rate.split("/"))
30
+ return num / denom
31
+
32
+
33
+ @app.command()
34
+ def extract_keyframes_greedy(
35
+ video_path: str,
36
+ output_dir: str = None,
37
+ threshold: float = 0.2,
38
+ overwrite: bool = False,
39
+ ):
40
+ """
41
+ run i-frames extractions and keyframes extraction and return a list of keyframe's paths
42
+ """
43
+ assert (
44
+ threshold > 0
45
+ ), f"threshold must be no negative, for i-frame extraction use extract-keyframes instead"
46
+
47
+ iframes = extract_keyframes(
48
+ video_path,
49
+ output_dir=output_dir,
50
+ threshold=0,
51
+ overwrite=overwrite,
52
+ append=False,
53
+ )
54
+ assert type(iframes) != type(None), f"i-frames extraction failed"
55
+ kframes = extract_keyframes(
56
+ video_path,
57
+ output_dir=output_dir,
58
+ threshold=threshold,
59
+ overwrite=False,
60
+ append=True,
61
+ )
62
+ assert type(kframes) != type(None), f"keyframes extraction failed"
63
+
64
+ # remove kframes that are also iframes
65
+ removed_kframes = []
66
+ for fn in kframes:
67
+ fname = os.path.basename(fn)
68
+ if os.path.isfile(
69
+ os.path.join(os.path.dirname(fn), fname.replace("kframe_", "iframe_"))
70
+ ):
71
+ os.remove(fn)
72
+ removed_kframes.append(fn)
73
+ if len(removed_kframes) > 0:
74
+ logger.warning(f"removed {len(removed_kframes)} redundant kframes")
75
+ kframes = [kf for kf in kframes if kf not in removed_kframes]
76
+
77
+ frames = iframes + kframes
78
+ logger.success(f"extracted {len(frames)} total frames")
79
+ return frames
80
+
81
+
82
+ @app.command()
83
+ def extract_keyframes(
84
+ video_path: str,
85
+ output_dir: str = None,
86
+ threshold: float = 0.3,
87
+ overwrite: bool = False,
88
+ append: bool = False,
89
+ ):
90
+ """extract keyframes as images into output_dir and return a list of keyframe's paths
91
+
92
+ Args:
93
+ output_dir: if not provided, will be in video_name/keyframes/
94
+ """
95
+ # Create output directory if it doesn't exist
96
+ output_dir = output_dir if output_dir else os.path.dirname(video_path)
97
+ vname, vext = os.path.splitext(os.path.basename(video_path))
98
+ output_dir = os.path.join(output_dir, vname, "keyframes")
99
+ if os.path.isdir(output_dir):
100
+ if overwrite:
101
+ shutil.rmtree(output_dir)
102
+ logger.warning(f"removed existing data: {output_dir}")
103
+ elif not append:
104
+ logger.error(f"overwrite is false and data already exists!")
105
+ return None
106
+ os.makedirs(output_dir, exist_ok=True)
107
+
108
+ # Construct the ffmpeg-python pipeline
109
+ stream = ffmpeg.input(video_path)
110
+ config_dict = {
111
+ "vsync": "0",
112
+ "frame_pts": "true",
113
+ }
114
+
115
+ if threshold:
116
+ # always add in the first frame by default
117
+ filter_value = f"eq(n,0)+gt(scene,{threshold})"
118
+ frame_name = "kframe"
119
+ logger.info(f"Extracting Scene-changing frames with {filter_value}")
120
+ else:
121
+ filter_value = f"eq(pict_type,I)"
122
+ # config_dict["skip_frame"] = "nokey"
123
+ frame_name = "iframe"
124
+ logger.info(f"Extracting I-Frames since no threshold provided: {filter_value}")
125
+
126
+ stream = ffmpeg.filter(stream, "select", filter_value)
127
+ stream = ffmpeg.output(stream, f"{output_dir}/{frame_name}_%d.jpg", **config_dict)
128
+
129
+ # Execute the ffmpeg command
130
+ try:
131
+ ffmpeg.run(stream, capture_stdout=True, capture_stderr=True)
132
+ frames = [
133
+ os.path.join(output_dir, f)
134
+ for f in os.listdir(output_dir)
135
+ if f.endswith(".jpg") and frame_name in f
136
+ ]
137
+ logger.success(f"{len(frames)} {frame_name} extracted to {output_dir}")
138
+ return frames
139
+ except ffmpeg.Error as e:
140
+ logger.error(f"Error executing FFmpeg command: {e.stderr.decode()}")
141
+ return None
142
+
143
+
144
+ @app.command()
145
+ def extract_audio(video_path: str, output_dir: str = None, overwrite: bool = False):
146
+ """extracting audio of a video file into m4a without re-encoding
147
+ ref: https://www.baeldung.com/linux/ffmpeg-audio-from-video#1-extracting-audio-without-re-encoding
148
+ """
149
+ # Create output directory if it doesn't exist
150
+ output_dir = output_dir if output_dir else os.path.dirname(video_path)
151
+ vname, vext = os.path.splitext(os.path.basename(video_path))
152
+ output_dir = os.path.join(output_dir, vname)
153
+ output_fname = os.path.join(output_dir, vname + ".m4a")
154
+ if os.path.isfile(output_fname):
155
+ if overwrite:
156
+ os.remove(output_fname)
157
+ logger.warning(f"removed existing data: {output_fname}")
158
+ else:
159
+ logger.error(f"overwrite is false and data already exists!")
160
+ return None
161
+ os.makedirs(output_dir, exist_ok=True)
162
+
163
+ # Construct the ffmpeg-python pipeline
164
+ stream = ffmpeg.input(video_path)
165
+ config_dict = {"map": "0:a", "acodec": "copy"}
166
+ stream = ffmpeg.output(stream, output_fname, **config_dict)
167
+
168
+ # Execute the ffmpeg command
169
+ try:
170
+ ffmpeg.run(stream, capture_stdout=True, capture_stderr=True)
171
+ logger.success(f"audio extracted to {output_fname}")
172
+ return output_fname
173
+ except ffmpeg.Error as e:
174
+ logger.error(f"Error executing FFmpeg command: {e.stderr.decode()}")
175
+ return None
176
+
177
+
178
+ @app.command()
179
+ def extract_frames(
180
+ video_path: str,
181
+ output_dir: str = None,
182
+ fps: int = None,
183
+ every_x: int = None,
184
+ overwrite: bool = False,
185
+ append: bool = False,
186
+ im_name_pattern: str = "frame_%05d.jpg",
187
+ ):
188
+ """extract frames as images into output_dir and return the list of frames' paths
189
+
190
+ Args:
191
+ output_dir: if not provided, will be in video_name/keyframes/
192
+ """
193
+ # Create output directory if it doesn't exist
194
+ vname, vext = os.path.splitext(os.path.basename(video_path))
195
+ output_dir = output_dir if output_dir else os.path.dirname(video_path)
196
+ output_dir = os.path.join(output_dir, vname, "keyframes")
197
+ if os.path.isdir(output_dir):
198
+ if overwrite:
199
+ shutil.rmtree(output_dir)
200
+ logger.warning(f"removed existing data: {output_dir}")
201
+ elif not append:
202
+ logger.error(f"overwrite is false and data already exists in {output_dir}!")
203
+ return None
204
+ os.makedirs(output_dir, exist_ok=True)
205
+
206
+ # Construct the ffmpeg-python pipeline
207
+ stream = ffmpeg.input(video_path)
208
+ config_dict = {
209
+ "vsync": 0, # preserves the original timestamps
210
+ "frame_pts": 1, # set output file's %d to the frame's PTS
211
+ }
212
+ if fps:
213
+ # check FPS
214
+ vid_fps = get_fps_ffmpeg(video_path)
215
+ fps = min(vid_fps, fps)
216
+ logger.info(f"{vname}{vext} FPS: {vid_fps}, extraction FPS: {fps}")
217
+ config_dict["vf"] = f"fps={fps}"
218
+ elif every_x:
219
+ config_dict["vf"] = f"select=not(mod(n\,{every_x}))"
220
+
221
+ logger.info(
222
+ f"Extracting Frames into {output_dir} with these configs: \n{config_dict}"
223
+ )
224
+ stream = ffmpeg.output(stream, f"{output_dir}/{im_name_pattern}", **config_dict)
225
+
226
+ # Execute the ffmpeg command
227
+ try:
228
+ ffmpeg.run(stream, capture_stdout=True, capture_stderr=True)
229
+ frames = [
230
+ os.path.join(output_dir, f)
231
+ for f in os.listdir(output_dir)
232
+ if f.endswith(".jpg")
233
+ ]
234
+ logger.success(f"{len(frames)} frames extracted to {output_dir}")
235
+ return frames
236
+ except ffmpeg.Error as e:
237
+ logger.error(f"Error executing FFmpeg command: {e.stderr.decode()}")
238
+ return None
239
+
240
+
241
+ if __name__ == "__main__":
242
+ app()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ ffmpeg-python>=0.2.0
2
+ imageio[ffmpeg]>=2.37.0
3
+ loguru>=0.7.3
4
+ pydantic
5
+ retrying>=1.3.4
6
+ samv2==0.0.4
7
+ validators>=0.35.0
samv2_handler.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, shutil
2
+ import numpy as np
3
+ from PIL import Image
4
+ from typing import Literal, Any, Union, Generic, List
5
+ from pydantic import BaseModel
6
+ from sam2.build_sam import build_sam2, build_sam2_video_predictor
7
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
8
+ from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
9
+ from sam2.utils.misc import variant_to_config_mapping
10
+ from sam2.utils.visualization import show_masks
11
+ from ffmpeg_extractor import extract_frames, logger
12
+ from toolbox.vid_utils import VidInfo
13
+ from toolbox.mask_encoding import b64_mask_encode
14
+
15
+ variant_checkpoints_mapping = {
16
+ "tiny": "checkpoints/sam2_hiera_tiny.pt",
17
+ "small": "checkpoints/sam2_hiera_small.pt",
18
+ "base_plus": "checkpoints/sam2_hiera_base_plus.pt",
19
+ "large": "checkpoints/sam2_hiera_large.pt",
20
+ }
21
+
22
+
23
+ class bbox_xyxy(BaseModel):
24
+ x0: Union[int, float]
25
+ y0: Union[int, float]
26
+ x1: Union[int, float]
27
+ y1: Union[int, float]
28
+
29
+
30
+ class point_xy(BaseModel):
31
+ x: Union[int, float]
32
+ y: Union[int, float]
33
+
34
+
35
+ def mask_to_xyxy(mask: np.ndarray) -> tuple:
36
+ """Convert a binary mask of shape (h, w) to
37
+ xyxy bounding box format (top-left and bottom-right coordinates).
38
+ """
39
+ ys, xs = np.where(mask)
40
+ if len(xs) == 0 or len(ys) == 0:
41
+ logger.warning("mask_to_xyxy: No object found in the mask")
42
+ return None
43
+ x_min = np.min(xs)
44
+ y_min = np.min(ys)
45
+ x_max = np.max(xs)
46
+ y_max = np.max(ys)
47
+ xyxy = (x_min, y_min, x_max, y_max)
48
+ xyxy = tuple([int(i) for i in xyxy])
49
+ return xyxy
50
+
51
+
52
+ def load_sam_image_model(
53
+ # variant: Literal[*variant_checkpoints_mapping.keys()],
54
+ variant: Literal["tiny", "small", "base_plus", "large"],
55
+ device: str = "cpu",
56
+ auto_mask_gen: bool = False,
57
+ ) -> SAM2ImagePredictor:
58
+ model = build_sam2(
59
+ config_file=variant_to_config_mapping[variant],
60
+ ckpt_path=variant_checkpoints_mapping[variant],
61
+ device=device,
62
+ )
63
+ return (
64
+ SAM2AutomaticMaskGenerator(model)
65
+ if auto_mask_gen
66
+ else SAM2ImagePredictor(sam_model=model)
67
+ )
68
+
69
+
70
+ def load_sam_video_model(
71
+ variant: Literal["tiny", "small", "base_plus", "large"] = "small",
72
+ device: str = "cpu",
73
+ ) -> Any:
74
+ return build_sam2_video_predictor(
75
+ config_file=variant_to_config_mapping[variant],
76
+ ckpt_path=variant_checkpoints_mapping[variant],
77
+ device=device,
78
+ )
79
+
80
+
81
+ def run_sam_im_inference(
82
+ model: Any,
83
+ image: Image.Image,
84
+ points: Union[List[point_xy], List[dict]] = [],
85
+ point_labels: List[int] = [],
86
+ bboxes: Union[List[bbox_xyxy], List[dict]] = [],
87
+ get_pil_mask: bool = False,
88
+ b64_encode_mask: bool = False,
89
+ ):
90
+ """returns a list of np masks, each with the shape (h,w) and dtype uint8"""
91
+ assert (
92
+ points or bboxes
93
+ ), f"SAM2 Image Inference must have either bounding boxes or points. Neither were provided."
94
+ if points:
95
+ assert len(points) == len(
96
+ point_labels
97
+ ), f"{len(points)} points provided but {len(point_labels)} labels given."
98
+
99
+ # determine multimask_output
100
+ has_multi = False
101
+ if points and bboxes:
102
+ has_multi = True
103
+ elif points and len(list(set(point_labels))) > 1:
104
+ has_multi = True
105
+ elif bboxes and len(bboxes) > 1:
106
+ has_multi = True
107
+
108
+ # parse provided bboxes
109
+ bboxes = (
110
+ [bbox_xyxy(**bbox) if isinstance(bbox, dict) else bbox for bbox in bboxes]
111
+ if bboxes
112
+ else []
113
+ )
114
+ points = (
115
+ [point_xy(**p) if isinstance(p, dict) else p for p in points] if points else []
116
+ )
117
+
118
+ # setup inference
119
+ image = np.array(image.convert("RGB"))
120
+ model.set_image(image)
121
+
122
+ box_coords = (
123
+ np.array([[b.x0, b.y0, b.x1, b.y1] for b in bboxes]) if bboxes else None
124
+ )
125
+ point_coords = np.array([[p.x, p.y] for p in points]) if points else None
126
+ point_labels = np.array(point_labels) if point_labels else None
127
+
128
+ masks, scores, _ = model.predict(
129
+ box=box_coords,
130
+ point_coords=point_coords,
131
+ point_labels=point_labels,
132
+ multimask_output=has_multi,
133
+ )
134
+ # mask here is of shape (X, h, w) of np array, X = number of masks
135
+
136
+ if get_pil_mask:
137
+ return show_masks(image, masks, scores=None, display_image=False)
138
+ else:
139
+ output_masks = []
140
+ for i, mask in enumerate(masks):
141
+ if mask.ndim > 2: # shape (3, h, w)
142
+ mask = np.transpose(mask, (1, 2, 0)) # shape (h,w,3)
143
+ mask = Image.fromarray((mask * 255).astype(np.uint8)).convert("L")
144
+ output_masks.append(np.array(mask))
145
+ else:
146
+ output_masks.append(mask.squeeze().astype(np.uint8))
147
+ return (
148
+ [b64_mask_encode(m) for m in output_masks]
149
+ if b64_encode_mask
150
+ else output_masks
151
+ )
152
+
153
+
154
+ def run_sam_video_inference(
155
+ model: Any,
156
+ video_path: str,
157
+ masks: np.ndarray,
158
+ device: str = "cpu",
159
+ sample_fps: int = None,
160
+ every_x: int = None,
161
+ do_tidy_up: bool = False,
162
+ drop_mask: bool = True,
163
+ ):
164
+ # put video frames into directory
165
+ # TODO:
166
+ # change frame size
167
+ # async frame load
168
+ l_frames_fp = extract_frames(
169
+ video_path,
170
+ fps=sample_fps,
171
+ every_x=every_x,
172
+ overwrite=True,
173
+ im_name_pattern="%05d.jpg",
174
+ )
175
+ vframes_dir = os.path.dirname(l_frames_fp[0])
176
+ vinfo = VidInfo(video_path)
177
+ w = vinfo["frame_width"]
178
+ h = vinfo["frame_height"]
179
+
180
+ inference_state = model.init_state(video_path=vframes_dir, device=device)
181
+ for i, mask in enumerate(masks):
182
+ model.add_new_mask(
183
+ inference_state=inference_state, frame_idx=0, obj_id=i, mask=mask
184
+ )
185
+ masks_generator = model.propagate_in_video(inference_state)
186
+
187
+ detections = []
188
+ for i, tracker_ids, mask_logits in masks_generator:
189
+ masks = (mask_logits > 0.0).cpu().numpy().astype(np.uint8)
190
+ for id, mask in zip(tracker_ids, masks):
191
+ mask = mask.squeeze().astype(np.uint8)
192
+ xyxy = mask_to_xyxy(mask)
193
+ if not xyxy: # mask is empty
194
+ logger.debug(f"track_id {id} is missing mask at frame {i}")
195
+ continue
196
+ x0, y0, x1, y1 = xyxy
197
+ det = { # miro's detections format for videos
198
+ "frame": i,
199
+ "track_id": id,
200
+ "x": x0 / w,
201
+ "y": y0 / h,
202
+ "w": (x1 - x0) / w,
203
+ "h": (y1 - y0) / h,
204
+ "conf": 1,
205
+ }
206
+ if not drop_mask:
207
+ det["mask_b64"] = b64_mask_encode(mask)
208
+ detections.append(det)
209
+
210
+ if do_tidy_up:
211
+ # remove vframes_dir
212
+ shutil.rmtree(vframes_dir)
213
+
214
+ return detections
toolbox/mask_encoding.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64, os, io, random, time
2
+ from PIL import Image
3
+ import numpy as np
4
+
5
+ def b64_mask_encode(mask_np_arr, tmp_dir = '/tmp/miro/mask_encoding/'):
6
+ '''
7
+ turn a binary mask in numpy into a base64 string
8
+ '''
9
+ mask_im = Image.fromarray(np.array(mask_np_arr).astype(np.uint8)*255)
10
+ mask_im = mask_im.convert(mode = '1') # convert to 1bit image
11
+
12
+ if not os.path.isdir(tmp_dir):
13
+ print(f'b64_mask_encode: making tmp dir for mask encoding...')
14
+ os.makedirs(tmp_dir)
15
+
16
+ timestr = time.strftime("%Y%m%d-%H%M%S")
17
+ hash_str = random.getrandbits(128)
18
+ tmp_fname = tmp_dir + f'{timestr}_{hash_str}_mask.png'
19
+ mask_im.save(tmp_fname)
20
+ return base64.b64encode(open(tmp_fname, 'rb').read())
21
+
22
+ def b64_mask_decode(b64_string):
23
+ '''
24
+ decode a base64 string back to a binary mask numpy array
25
+ '''
26
+ im_bytes = base64.b64decode(b64_string)
27
+ im_decode = Image.open(io.BytesIO(im_bytes))
28
+ return np.array(im_decode)
29
+
30
+ def get_true_mask(mask_arr, im_w_h:tuple, x0, y0, x1, y1):
31
+ '''
32
+ decode the mask of CM output to get a mask that's the same size as source im
33
+ '''
34
+ if x0 > im_w_h[0] or x1 > im_w_h[0] or y0 > im_w_h[1] or y1 > im_w_h[1]:
35
+ raise ValueError(f'get_true_mask: Xs and Ys exceeded im_w_h bound: {im_w_h}')
36
+
37
+ if mask_arr.shape != (y1 - y0, x1 - x0):
38
+ raise ValueError(f'get_true_mask: Bounding Box h: {y1-y0} w: {x1-x0} does not match mask shape: {mask_arr.shape}')
39
+
40
+ w, h = im_w_h
41
+ mask = np.zeros((h,w), dtype = np.uint8)
42
+ mask[y0:y1, x0:x1] = mask_arr
43
+ return mask
toolbox/vid_utils.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from tqdm import tqdm
3
+ import cv2, imageio, ffmpeg, os, time, shutil
4
+
5
+ def VidInfo(vid_path):
6
+ '''
7
+ returns a dictonary of 'duration', 'fps', 'frame_count', 'frame_height', 'frame_width',
8
+ 'format', 'fourcc'
9
+ '''
10
+ vcap = cv2.VideoCapture(vid_path)
11
+ if not vcap.isOpened():
12
+ # cannot read video
13
+ if vid_path.startswith('https://'):
14
+ # likely a ffmpeg without open-ssl support issue
15
+ # https://github.com/opencv/opencv-python/issues/204
16
+ return VidInfo(vid_path.replace('https://','http://'))
17
+ else:
18
+ return None
19
+
20
+ info_dict = {
21
+ 'fps' : round(vcap.get(cv2.CAP_PROP_FPS),2), #int(vcap.get(cv2.CAP_PROP_FPS)),
22
+ 'frame_count': int(vcap.get(cv2.CAP_PROP_FRAME_COUNT)), # number of frames should integars
23
+ 'duration': round(
24
+ int(vcap.get(cv2.CAP_PROP_FRAME_COUNT)) / vcap.get(cv2.CAP_PROP_FPS),
25
+ 2), # round number of seconds to 2 decimals
26
+ 'frame_height': vcap.get(cv2.CAP_PROP_FRAME_HEIGHT),
27
+ 'frame_width': vcap.get(cv2.CAP_PROP_FRAME_WIDTH),
28
+ 'format': vcap.get(cv2.CAP_PROP_FORMAT),
29
+ 'fourcc': vcap.get(cv2.CAP_PROP_FOURCC)
30
+ }
31
+ vcap.release()
32
+ return info_dict
33
+
34
+ def VidReader(vid_path, verbose = False, use_imageio = True):
35
+ '''
36
+ given a video file path, returns a list of images
37
+ Args:
38
+ vid_path: a MP4 file path
39
+ use_imageio: if true, function returns a ImageIO reader object (RGB);
40
+ otherwise, a list of CV2 array will be returned
41
+ '''
42
+
43
+ if use_imageio:
44
+ vid = imageio.get_reader(vid_path, 'ffmpeg')
45
+ return vid
46
+
47
+ vcap = cv2.VideoCapture(vid_path)
48
+ s_time = time.time()
49
+
50
+ # try to determine the total number of frames in Vid
51
+ frame_count = int(vcap.get(cv2.CAP_PROP_FRAME_COUNT))
52
+ frame_rate = int(vcap.get(cv2.CAP_PROP_FPS))
53
+ if verbose:
54
+ print(f'\t{frame_count} total frames in video {vid_path}')
55
+ print(f'\t\t FPS: {frame_rate}')
56
+ print(f'\t\t Video Duration: {frame_count/ frame_rate}s')
57
+
58
+ # loop over frames
59
+ results = []
60
+ for i in tqdm(range(frame_count)):
61
+ grabbed, frame = vcap.read()
62
+ if grabbed:
63
+ results.append(frame)
64
+
65
+ # Output
66
+ r_time = "{:.2f}".format(time.time() - s_time)
67
+ if verbose:
68
+ print(f'\t{vid_path} loaded in {r_time} ({frame_count/float(r_time)} fps)')
69
+ vcap.release()
70
+ return results
71
+
72
+ def get_vid_frame(n, vid_path):
73
+ '''
74
+ return frame(s) in np.array specified by i
75
+ Args:
76
+ n: list of int
77
+ '''
78
+ vreader = VidReader(vid_path, verbose = False, use_imageio = True)
79
+ fcount = VidInfo(vid_path)['frame_count']
80
+
81
+ if type(n) == list:
82
+ return [vreader.get_data(i) if i in range(fcount) else None for i in n]
83
+ elif type(n) == int:
84
+ return vreader.get_data(n) if n in range(fcount) else None
85
+ else:
86
+ raise ValueError(f'n must be either int or list, {type(n)} detected.')
87
+
88
+ def vid_slicer(vid_path, output_path, start_frame, end_frame, keep_audio = False, overwrite = False):
89
+ '''
90
+ ref https://github.com/kkroening/ffmpeg-python/issues/184#issuecomment-493847192
91
+ '''
92
+ if not( os.path.isdir(os.path.dirname(output_path))):
93
+ raise ValueError(f'output_path directory does not exists: {os.path.dirname(output_path)}')
94
+
95
+ if os.path.isfile(output_path) and not overwrite:
96
+ warnings.warn(f'{output_path} already exists but overwrite switch is False, nothing done.')
97
+ return None
98
+
99
+ input_vid = ffmpeg.input(vid_path)
100
+ vid_info = VidInfo(vid_path)
101
+ end_frame += 1
102
+
103
+ if keep_audio:
104
+ vid = (
105
+ input_vid
106
+ .trim(start_frame = start_frame, end_frame = end_frame)
107
+ .setpts('PTS-STARTPTS')
108
+ )
109
+ aud = (
110
+ input_vid
111
+ .filter_('atrim', start = start_frame / vid_info['fps'], end = end_frame / vid_info['fps'])
112
+ .filter_('asetpts', 'PTS-STARTPTS')
113
+ )
114
+ joined = ffmpeg.concat(vid, aud, v = 1, a =1).node
115
+ output = ffmpeg.output(joined[0], joined[1], f'{output_path}').overwrite_output()
116
+ output.run()
117
+ else:
118
+ (
119
+ input_vid
120
+ .trim (start_frame = start_frame, end_frame = end_frame )
121
+ .setpts ('PTS-STARTPTS')
122
+ .output (f'{output_path}')
123
+ .overwrite_output()
124
+ .run()
125
+ )
126
+ return output_path
127
+
128
+ def vid_resize(vid_path, output_path, width, overwrite = False):
129
+ '''
130
+ use ffmpeg to resize the input video to the width given, keeping aspect ratio
131
+ '''
132
+ if not( os.path.isdir(os.path.dirname(output_path))):
133
+ raise ValueError(f'output_path directory does not exists: {os.path.dirname(output_path)}')
134
+
135
+ if os.path.isfile(output_path) and not overwrite:
136
+ warnings.warn(f'{output_path} already exists but overwrite switch is False, nothing done.')
137
+ return None
138
+
139
+ input_vid = ffmpeg.input(vid_path)
140
+ vid = (
141
+ input_vid
142
+ .filter('scale', width, -1)
143
+ .output(output_path)
144
+ .overwrite_output()
145
+ .run()
146
+ )
147
+ return output_path
148
+
149
+ def vid_reduce_framerate(vid_path, output_path, new_fps, overwrite = False):
150
+ '''
151
+ use ffmpeg to resize the input video to the width given, keeping aspect ratio
152
+ '''
153
+ if not( os.path.isdir(os.path.dirname(output_path))):
154
+ raise ValueError(f'output_path directory does not exists: {os.path.dirname(output_path)}')
155
+
156
+ if os.path.isfile(output_path) and not overwrite:
157
+ warnings.warn(f'{output_path} already exists but overwrite switch is False, nothing done.')
158
+ return None
159
+
160
+ input_vid = ffmpeg.input(vid_path)
161
+ vid = (
162
+ input_vid
163
+ .filter('fps', fps = new_fps, round = 'up')
164
+ .output(output_path)
165
+ .overwrite_output()
166
+ .run()
167
+ )
168
+ return output_path
169
+
170
+ def seek_frame_count(VidReader, cv2_frame_count, guess_within = 0.1,
171
+ seek_rate = 1, bDebug = False):
172
+ '''
173
+ imageio/ffmpeg frame count could be different than cv2. this function
174
+ returns the true frame count in the given vid reader. Returns None if frame
175
+ count can't be determined
176
+ Args:
177
+ VidReader: ImageIO video reader object with method .get_data()
178
+ cv2_frame_count: frame count from cv2
179
+ guess_within: look for actual frame count within X% of cv2_frame_count
180
+ '''
181
+ max_guess = int(cv2_frame_count * (1-guess_within))
182
+ seek_rate = max(seek_rate, 1)
183
+ pbar = reversed(range(max_guess, cv2_frame_count, seek_rate))
184
+ if bDebug:
185
+ pbar = tqdm(pbar, desc = f'seeking frame')
186
+ print(f'seeking from {max_guess} to {cv2_frame_count} with seek_rate of {seek_rate}')
187
+
188
+ for i in pbar:
189
+ try:
190
+ im = VidReader.get_data(i)
191
+ except IndexError:
192
+ if bDebug:
193
+ print(f'{i} not found.')
194
+ continue
195
+ # Frame Found
196
+ if i+1 == cv2_frame_count:
197
+ print(f'seek_frame_count: found frame count at {i+1}')
198
+ return i + 1
199
+ else:
200
+ return seek_frame_count(VidReader, cv2_frame_count = i + seek_rate,
201
+ guess_within= seek_rate / (i + seek_rate),
202
+ seek_rate= int(seek_rate/2),
203
+ bDebug = bDebug)
204
+ return None
205
+
206
+ def VidWriter(lFrames, output_path, strFourcc = 'MP4V', verbose = False, intFPS = 20, crf = None,
207
+ use_imageio = False):
208
+ '''
209
+ Given a list of images in numpy array format, it outputs a MP4 file
210
+ Args:
211
+ lFrames: list of numpy arrays or filename
212
+ output_path: a MP4 file path
213
+ strFourcc: four letter video codec; XVID is more preferable. MJPG results in high size video. X264 gives very small size video; see https://opencv-python-tutroals.readthedocs.io/en/latest/py_tutorials/py_gui/py_video_display/py_video_display.html
214
+ crf: Constant Rate Factor for ffmpeg video compression
215
+ '''
216
+ s_time = time.time()
217
+
218
+ if not output_path.endswith('.mp4'):
219
+ raise ValueError(f'VidWriter: only mp4 video output supported.')
220
+
221
+ if crf:
222
+ crf = int(crf)
223
+ if crf > 24 or crf < 18:
224
+ raise ValueError(f'VidWriter: crf must be between 18 and 24')
225
+
226
+ if not os.path.exists(os.path.dirname(output_path)):
227
+ output_dir = os.path.dirname(output_path)
228
+ print(f'\t{output_dir} does not exist.\n\tCreating video file output directory: {output_dir}')
229
+ os.makedirs(output_dir)
230
+
231
+ if use_imageio:
232
+ writer = imageio.get_writer(output_path, fps = intFPS)
233
+ for frame in tqdm(lFrames, desc = "Writing video using ImageIO"):
234
+ if not type(frame) == np.ndarray:
235
+ # read from filename
236
+ if not os.path.isfile(frame):
237
+ raise ValueError(f'VidWriter: lFrames must be list of images (np.array) or filenames')
238
+ frame = imageio.imread(frame)
239
+
240
+ writer.append_data(frame)
241
+ writer.close()
242
+ else:
243
+ #init OpenCV Vid Writer:
244
+ H , W = lFrames[0].shape[:2]
245
+ #fourcc = cv2.VideoWriter_fourcc(*'MP4V')
246
+ fourcc = cv2.VideoWriter_fourcc(*strFourcc)
247
+ if verbose:
248
+ print(f'\tEncoding using fourcc: {strFourcc}')
249
+ writer = cv2.VideoWriter(output_path, fourcc, fps = intFPS, frameSize = (W, H), isColor = True)
250
+
251
+ for frame in tqdm(lFrames, desc = "Writing video using OpenCV"):
252
+ writer.write(frame)
253
+ writer.release()
254
+
255
+ # Output
256
+ r_time = "{:.2f}".format( max(time.time() - s_time, 0.01))
257
+ if verbose:
258
+ print(f'\t{output_path} written in {r_time} ({len(lFrames)/float(r_time)} fps)')
259
+
260
+ if crf:
261
+ if verbose:
262
+ print(f'\tCompressing {output_path} with FFmpeg using crf: {crf}')
263
+
264
+ isCompressed = VidCompress(output_path, crf = crf, use_ffmpy = False)
265
+
266
+ if verbose:
267
+ print(f'\tCompressed: {isCompressed}')
268
+
269
+ return output_path
270
+
271
+ def im_dir_to_video(im_dir, output_path, fps, tup_im_extension = ('.jpg'),
272
+ max_long_edge = 600, filename_len = 6, pixel_format = 'yuv420p',
273
+ tqdm_func = tqdm):
274
+ '''turn a directory of images into video using ffmpeg
275
+ ref: https://github.com/kkroening/ffmpeg-python/issues/95#issuecomment-401428324
276
+ Args:
277
+ pixel_format: for list of supported formats see https://en.wikipedia.org/wiki/FFmpeg#Pixel_formats
278
+ filename_len: ensure frame number are zero padded; 0 will skip this step
279
+ '''
280
+ if filename_len:
281
+ # Ensure Filenames are Zero padded
282
+ l_im_fp = [f for f in os.listdir(im_dir) if f.endswith(tup_im_extension)]
283
+ l_im_fp = sorted(l_im_fp, key = lambda f: int(f.split('.')[0]))
284
+ for f in tqdm_func(l_im_fp, desc = 'ensuring image filenames are zero padded'):
285
+ fname, fext = os.path.splitext(f)
286
+ padded_f = fname.zfill(filename_len) + fext
287
+ if not os.path.isfile(os.path.join(im_dir,padded_f)):
288
+ shutil.move(os.path.join(im_dir, f), os.path.join(im_dir, padded_f))
289
+ # removed symlink to f as it will duplicate the frames in video generation
290
+ # os.symlink(src = os.path.join(im_dir, padded_f), dst = os.path.join(im_dir, f))
291
+ #TODO: ensure image size are divisible by 2
292
+
293
+ im_dir += '' if im_dir.endswith('/') else '/'
294
+ im_stream_string = f'{im_dir}*.jpg'
295
+ # we need to escape special characters
296
+ im_stream_string = im_stream_string.translate(
297
+ str.maketrans(
298
+ {'[': r'\[',
299
+ ']': r'\]'})
300
+ )
301
+ r = (
302
+ ffmpeg
303
+ .input(im_stream_string, pattern_type = 'glob', framerate=fps)
304
+ .filter('format', pixel_format)
305
+ # .filter('pad', 'ceil(iw/2)*2:ceil(ih/2)*2')
306
+ .output(output_path)
307
+ .run()
308
+ )
309
+ return output_path
310
+ #
311
+ # def VidCompress(input_path, output_path = None, crf = 24, use_ffmpy = False):
312
+ # '''
313
+ # Compress input_path video (mp4 only) using ffmpy
314
+ # crf: Constant Rate Factor for ffmpeg video compression, must be between 18 and 24
315
+ # use_ffmpy: use ffmpy instead of commandline call to ffmpeg
316
+ # '''
317
+ # if not input_path.endswith('.mp4'):
318
+ # print(f'\tFATAL: only mp4 videos supported.')
319
+ # return None
320
+ #
321
+ # output_fname = output_path if output_path else input_path
322
+ # tmp_fname = input_path.replace(".mp4","_tmp.mp4")
323
+ # os.rename(input_path, tmp_fname)
324
+ #
325
+ # try:
326
+ # if not use_ffmpy:
327
+ # #os.popen(f'ffmpeg -i {tmp_fname} -vcodec libx264 -crf {crf} {output_fname}')
328
+ #
329
+ # cmdOut = subprocess.Popen(['ffmpeg', '-i', tmp_fname, '-vcodec', 'libx264', '-crf', str(crf), output_fname],
330
+ # stdout = subprocess.PIPE,
331
+ # stderr = subprocess.STDOUT)
332
+ # stdout, stderr = cmdOut.communicate()
333
+ # if not stderr:
334
+ # os.remove(tmp_fname)
335
+ # return True
336
+ # else:
337
+ # return False
338
+ # else:
339
+ # ff = FFmpeg(
340
+ # inputs = {tmp_fname : None},
341
+ # outputs = {output_fname : f'-vcodec libx264 -crf {crf}'}
342
+ # )
343
+ # ff.run()
344
+ #
345
+ # os.remove(tmp_fname)
346
+ # return True
347
+ #
348
+ # except OSError as e:
349
+ # print(f'\tWARNING: Compression Failed; OSError\n\tLikely out of RAM\n\tError Msg: {e}')
350
+ # os.rename(tmp_fname, output_fname)
351
+ # return False