John Ho commited on
Commit
1d8163a
·
1 Parent(s): af8b4a0

renamed image and video inference functions

Browse files
Files changed (1) hide show
  1. app.py +7 -13
app.py CHANGED
@@ -63,7 +63,7 @@ def load_vid_model(variant):
63
  @spaces.GPU
64
  @torch.inference_mode()
65
  @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
66
- def segment_image(
67
  im: Image.Image,
68
  variant: str,
69
  bboxes: Union[list, str] = None,
@@ -101,20 +101,14 @@ def segment_image(
101
  @spaces.GPU
102
  @torch.inference_mode()
103
  @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
104
- def segment_video(
105
- im: Image.Image,
106
- variant: str,
107
- bboxes: Union[list, str] = None,
108
- points: Union[list, str] = None,
109
- point_labels: Union[list, str] = None,
110
- ):
111
  """
112
- SAM2 Image Segmentation
113
 
114
  Args:
115
- im: Pillow Image
116
- object_name: the object you would like to detect
117
- mode: point or object_detection
118
  Returns:
119
  list: a list of masks
120
  """
@@ -138,7 +132,7 @@ def segment_video(
138
  with gr.Blocks() as demo:
139
  with gr.Tab("Images"):
140
  gr.Interface(
141
- fn=detect_image,
142
  inputs=[
143
  gr.Image(label="Input Image", type="pil"),
144
  gr.Dropdown(
 
63
  @spaces.GPU
64
  @torch.inference_mode()
65
  @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
66
+ def process_image(
67
  im: Image.Image,
68
  variant: str,
69
  bboxes: Union[list, str] = None,
 
101
  @spaces.GPU
102
  @torch.inference_mode()
103
  @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
104
+ def process_video(video_path: str, variant: str, masks: Union[list, str]):
 
 
 
 
 
 
105
  """
106
+ SAM2 Video Segmentation
107
 
108
  Args:
109
+ video_path: path to video object
110
+ variant: SAMv2's model variant
111
+ masks: a list of masks for the first frame of the video, indicating the objects to be tracked
112
  Returns:
113
  list: a list of masks
114
  """
 
132
  with gr.Blocks() as demo:
133
  with gr.Tab("Images"):
134
  gr.Interface(
135
+ fn=process_image,
136
  inputs=[
137
  gr.Image(label="Input Image", type="pil"),
138
  gr.Dropdown(