John Ho committed
Commit f8e7037 · 1 Parent(s): 1d8163a

added interface for video

Files changed (1)
  1. app.py +52 -25
app.py CHANGED
@@ -11,6 +11,7 @@ from samv2_handler import (
 )
 from PIL import Image
 from typing import Union
+import numpy as np
 
 torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
 if torch.cuda.get_device_properties(0).major >= 8:
@@ -75,26 +76,39 @@ def process_image(
 
     Args:
         im: Pillow Image
-        object_name: the object you would like to detect
-        mode: point or object_detection
+        variant: SAM2 model variant
+        bboxes: bounding boxes of objects to segment, expressed as a list of dicts: [{"x0":..., "y0":..., "x1":..., "y1":...}, ...]
+        points: points of objects to segment, expressed as a list of dicts: [{"x":..., "y":...}, ...]
+        point_labels: list of integers
     Returns:
-        list: a list of masks
+        list: a list of masks in the form of base64-encoded strings
     """
+    # input validation
     logger.debug(f"bboxes type: {type(bboxes)}, value: {bboxes}")
-    bboxes = (
-        json.loads(bboxes)
-        if isinstance(bboxes, str) and type(bboxes) != type(None)
-        else bboxes
-    )
-    assert bboxes or points, f"either bboxes or points must be provided."
-    if points:
+    has_bboxes = bboxes is not None and bboxes != ""
+    has_points = points is not None and points != ""
+    assert has_bboxes or has_points, "either bboxes or points must be provided."
+    if has_points:
         assert len(points) == len(
             point_labels
         ), f"{len(points)} points provided but there are {len(point_labels)} labels."
 
+    bboxes = json.loads(bboxes) if isinstance(bboxes, str) and has_bboxes else bboxes
+    points = json.loads(points) if isinstance(points, str) and has_points else points
+    point_labels = (
+        json.loads(point_labels)
+        if isinstance(point_labels, str) and has_points
+        else point_labels
+    )
     model = load_im_model(variant=variant)
     return run_sam_im_inference(
-        model, image=im, bboxes=bboxes, get_pil_mask=False, b64_encode_mask=True
+        model,
+        image=im,
+        bboxes=bboxes,
+        points=points,
+        point_labels=point_labels,
+        get_pil_mask=False,
+        b64_encode_mask=True,
     )
 
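For context, process_image now takes box/point prompts that may arrive either as Python lists or as JSON strings posted by the Gradio textboxes. A minimal sketch of a direct call with list inputs (the image path and coordinates are illustrative, not part of the commit):

from PIL import Image

im = Image.open("test.jpg")  # hypothetical local test image

masks = process_image(
    im=im,
    variant="tiny",
    bboxes=[{"x0": 10, "y0": 20, "x1": 200, "y1": 180}],
    points=[{"x": 105, "y": 100}],
    point_labels=[1],  # SAM convention: 1 = foreground, 0 = background
)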
 
@@ -112,20 +126,14 @@ def process_video(video_path: str, variant: str, masks: Union[list, str]):
     Returns:
         list: a list of masks
     """
-    bboxes = (
-        json.loads(bboxes)
-        if isinstance(bboxes, str) and type(bboxes) != type(None)
-        else bboxes
-    )
-    assert bboxes or points, f"either bboxes or points must be provided."
-    if points:
-        assert len(points) == len(
-            point_labels
-        ), f"{len(points)} points provided but there are {len(point_labels)} labels."
-
-    model = load_im_model(variant=variant)
-    return run_sam_im_inference(
-        model, image=im, bboxes=bboxes, get_pil_mask=False, b64_encode_mask=True
+    model = load_vid_model(variant=variant)
+    return run_sam_video_inference(
+        model,
+        video_path=video_path,
+        masks=np.array(masks),
+        device="cuda",
+        do_tidy_up=True,
+        drop_mask=False,
     )
 
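On the video side, process_video simply loads the video variant of the model and forwards the first-frame masks as a NumPy array. A hedged sketch of a direct call, assuming masks is a list of per-object binary first-frame masks (the H x W shape and the clip path are assumptions; the exact contract of run_sam_video_inference lives in samv2_handler):

import numpy as np

# hypothetical binary mask for one object in the first frame
first_frame_mask = np.zeros((480, 640), dtype=bool)
first_frame_mask[100:200, 150:300] = True

result = process_video(
    video_path="clip.mp4",  # hypothetical local file
    variant="tiny",
    masks=[first_frame_mask],  # wrapped internally via np.array(masks)
)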
 
@@ -155,6 +163,25 @@ with gr.Blocks() as demo:
         outputs=gr.JSON(label="Output JSON"),
         title="SAM2 for Images",
     )
+    with gr.Tab("Videos"):
+        gr.Interface(
+            fn=process_video,
+            inputs=[
+                gr.Video(label="Input Video"),
+                gr.Dropdown(
+                    label="Model Variant",
+                    choices=["tiny", "small", "base_plus", "large"],
+                ),
+                gr.Textbox(
+                    label='Masks for Objects of Interest in the First Frame (JSON list of dicts: [{"x0":..., "y0":..., "x1":..., "y1":...}, ...])',
+                    value=None,
+                    lines=5,
+                    placeholder='JSON list of dicts: [{"x0":..., "y0":..., "x1":..., "y1":...}, ...]',
+                ),
+            ],
+            outputs=gr.JSON(label="Output JSON"),
+            title="SAM2 for Videos",
+        )
 
     # Download checkpoints before launching the app
     download_checkpoints()
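A usage note for the new tab: the masks textbox is handed to process_video as raw text, so whatever is pasted in should follow the JSON shape its placeholder advertises, for example (coordinates illustrative):

[{"x0": 10, "y0": 20, "x1": 200, "y1": 180}]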
 