John Ho committed on
Commit
af8b4a0
·
1 Parent(s): b2e3d42

added video inference imports

Browse files
Files changed (1) hide show
  1. app.py +55 -2
app.py CHANGED
@@ -2,7 +2,13 @@ import gradio as gr
2
  import spaces, torch, os, requests, json
3
  from pathlib import Path
4
  from tqdm import tqdm
5
- from samv2_handler import load_sam_image_model, run_sam_im_inference
 
 
 
 
 
 
6
  from PIL import Image
7
  from typing import Union
8
 
@@ -49,10 +55,53 @@ def load_im_model(variant, auto_mask_gen: bool = False):
49
  )
50
 
51
 
 
 
 
 
 
52
  @spaces.GPU
53
  @torch.inference_mode()
54
  @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
55
- def detect_image(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  im: Image.Image,
57
  variant: str,
58
  bboxes: Union[list, str] = None,
@@ -98,6 +147,9 @@ with gr.Blocks() as demo:
98
  ),
99
  gr.Textbox(
100
  label='Bounding Boxes (JSON list of dicts: [{"x0":..., "y0":..., "x1":..., "y1":...}, ...])',
 
 
 
101
  ),
102
  gr.Textbox(
103
  label='Points (JSON list of dicts: [{"x":..., "y":...}, ...])',
@@ -109,6 +161,7 @@ with gr.Blocks() as demo:
109
  outputs=gr.JSON(label="Output JSON"),
110
  title="SAM2 for Images",
111
  )
 
112
  # Download checkpoints before launching the app
113
  download_checkpoints()
114
  demo.launch(
 
2
  import spaces, torch, os, requests, json
3
  from pathlib import Path
4
  from tqdm import tqdm
5
+ from samv2_handler import (
6
+ load_sam_image_model,
7
+ run_sam_im_inference,
8
+ load_sam_video_model,
9
+ run_sam_video_inference,
10
+ logger,
11
+ )
12
  from PIL import Image
13
  from typing import Union
14
 
 
55
  )
56
 
57
 
58
+ @spaces.GPU
59
+ def load_vid_model(variant):
60
+ return load_sam_video_model(variant=variant, device="cuda")
61
+
62
+
63
@spaces.GPU
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def segment_image(
    im: Image.Image,
    variant: str,
    bboxes: Union[list, str] = None,
    points: Union[list, str] = None,
    point_labels: Union[list, str] = None,
):
    """Run SAM2 image segmentation on a single image.

    Args:
        im: input Pillow image.
        variant: SAM2 model variant name, forwarded to ``load_im_model``.
        bboxes: prompt boxes — a list of dicts or a JSON string encoding one
            (``[{"x0":..., "y0":..., "x1":..., "y1":...}, ...]``).
        points: prompt points, same list-or-JSON-string convention.
            NOTE(review): points are validated but not forwarded to
            ``run_sam_im_inference`` below — confirm whether that call
            should receive them.
        point_labels: per-point labels; must match ``points`` in length.

    Returns:
        The result of ``run_sam_im_inference`` with base64-encoded masks.

    Raises:
        ValueError: if neither bboxes nor points are provided, or the
            point/label counts disagree.
    """
    logger.debug(f"bboxes type: {type(bboxes)}, value: {bboxes}")

    def _parse_prompt(maybe_json):
        # Gradio textboxes deliver JSON strings; an untouched textbox
        # delivers "" — treat that as "not provided" instead of letting
        # json.loads("") raise.
        if isinstance(maybe_json, str):
            return json.loads(maybe_json) if maybe_json.strip() else None
        return maybe_json

    bboxes = _parse_prompt(bboxes)
    points = _parse_prompt(points)
    point_labels = _parse_prompt(point_labels)

    # Raise (not assert) so validation survives `python -O`.
    if not (bboxes or points):
        raise ValueError("either bboxes or points must be provided.")
    if points:
        n_labels = len(point_labels) if point_labels else 0
        if len(points) != n_labels:
            raise ValueError(
                f"{len(points)} points provided but there are {n_labels} labels."
            )

    model = load_im_model(variant=variant)
    return run_sam_im_inference(
        model, image=im, bboxes=bboxes, get_pil_mask=False, b64_encode_mask=True
    )
99
+
100
+
101
+ @spaces.GPU
102
+ @torch.inference_mode()
103
+ @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
104
+ def segment_video(
105
  im: Image.Image,
106
  variant: str,
107
  bboxes: Union[list, str] = None,
 
147
  ),
148
  gr.Textbox(
149
  label='Bounding Boxes (JSON list of dicts: [{"x0":..., "y0":..., "x1":..., "y1":...}, ...])',
150
+ value=None,
151
+ lines=5,
152
+ placeholder='JSON list of dicts: [{"x0":..., "y0":..., "x1":..., "y1":...}, ...]',
153
  ),
154
  gr.Textbox(
155
  label='Points (JSON list of dicts: [{"x":..., "y":...}, ...])',
 
161
  outputs=gr.JSON(label="Output JSON"),
162
  title="SAM2 for Images",
163
  )
164
+
165
  # Download checkpoints before launching the app
166
  download_checkpoints()
167
  demo.launch(