John Ho committed
Commit af8b4a0 · 1 Parent(s): b2e3d42
added video inference imports
app.py CHANGED
@@ -2,7 +2,13 @@ import gradio as gr
 import spaces, torch, os, requests, json
 from pathlib import Path
 from tqdm import tqdm
-from samv2_handler import
+from samv2_handler import (
+    load_sam_image_model,
+    run_sam_im_inference,
+    load_sam_video_model,
+    run_sam_video_inference,
+    logger,
+)
 from PIL import Image
 from typing import Union

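The import block above expects samv2_handler to export five names. For smoke-testing app.py without the real handler, a minimal stub could look like this (a sketch: only the imported names and the call sites visible in this diff are real; the signatures are assumptions):

    # stub_samv2_handler.py -- hypothetical stand-in; signatures inferred
    # from the call sites in this commit, not from the real module.
    import logging

    logger = logging.getLogger("samv2_handler")

    def load_sam_image_model(variant, device="cuda"):
        raise NotImplementedError  # real handler loads SAM2 image weights

    def run_sam_im_inference(model, image, bboxes=None, get_pil_mask=False, b64_encode_mask=True):
        return []  # real handler returns masks

    def load_sam_video_model(variant, device="cuda"):
        raise NotImplementedError  # real handler loads SAM2 video weights

    def run_sam_video_inference(model, *args, **kwargs):
        return []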
@@ -49,10 +55,53 @@ def load_im_model(variant, auto_mask_gen: bool = False):
 )


+@spaces.GPU
+def load_vid_model(variant):
+    return load_sam_video_model(variant=variant, device="cuda")
+
+
 @spaces.GPU
 @torch.inference_mode()
 @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
-def
+def segment_image(
+    im: Image.Image,
+    variant: str,
+    bboxes: Union[list, str] = None,
+    points: Union[list, str] = None,
+    point_labels: Union[list, str] = None,
+):
+    """
+    SAM2 Image Segmentation
+
+    Args:
+        im: Pillow Image to segment
+        variant: SAM2 model variant to load
+        bboxes / points / point_labels: prompts, as lists or JSON strings
+    Returns:
+        list: a list of masks
+    """
+    logger.debug(f"bboxes type: {type(bboxes)}, value: {bboxes}")
+    bboxes = (
+        json.loads(bboxes)
+        if isinstance(bboxes, str)
+        else bboxes
+    )
+    assert bboxes or points, "Either bboxes or points must be provided."
+    if points:
+        assert len(points) == len(
+            point_labels
+        ), f"{len(points)} points provided but there are {len(point_labels)} labels."
+
+    model = load_im_model(variant=variant)
+    return run_sam_im_inference(
+        model, image=im, bboxes=bboxes, get_pil_mask=False, b64_encode_mask=True
+    )
+
+
+@spaces.GPU
+@torch.inference_mode()
+@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+def segment_video(
     im: Image.Image,
     variant: str,
     bboxes: Union[list, str] = None,
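Note that segment_image parses and validates points and point_labels but forwards only bboxes to run_sam_im_inference; if the handler accepts point prompts, they would need to be passed through as well. A direct call with a JSON-string box prompt would look like this (illustrative coordinates; the image path and variant name are placeholders):

    import json
    from PIL import Image

    im = Image.open("example.jpg")  # placeholder image
    boxes = json.dumps([{"x0": 10, "y0": 10, "x1": 120, "y1": 90}])
    masks = segment_image(im, variant="tiny", bboxes=boxes)  # placeholder variant name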
@@ -98,6 +147,9 @@ with gr.Blocks() as demo:
         ),
         gr.Textbox(
             label='Bounding Boxes (JSON list of dicts: [{"x0":..., "y0":..., "x1":..., "y1":...}, ...])',
+            value=None,
+            lines=5,
+            placeholder='JSON list of dicts: [{"x0":..., "y0":..., "x1":..., "y1":...}, ...]',
         ),
         gr.Textbox(
             label='Points (JSON list of dicts: [{"x":..., "y":...}, ...])',
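The two textboxes take raw JSON, so a submission that fills them would carry strings of this shape (illustrative coordinates matching the labels above):

    import json

    # Bounding Boxes field
    bboxes = json.loads('[{"x0": 10, "y0": 10, "x1": 120, "y1": 90}]')
    # Points field
    points = json.loads('[{"x": 64, "y": 48}]')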
@@ -109,6 +161,7 @@
     outputs=gr.JSON(label="Output JSON"),
     title="SAM2 for Images",
 )
+
 # Download checkpoints before launching the app
 download_checkpoints()
 demo.launch(
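The app calls download_checkpoints() before demo.launch() so the first GPU request never races a weight download. A guarded fetch of this shape would fit here (a sketch only: the Space's actual download_checkpoints body is not shown in this diff, though requests, tqdm, and Path are all imported in app.py):

    import requests
    from pathlib import Path
    from tqdm import tqdm

    def fetch(url: str, dest: Path, chunk=1 << 20):
        # Skip if already present; otherwise stream to disk with a progress bar.
        if dest.exists():
            return dest
        dest.parent.mkdir(parents=True, exist_ok=True)
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            total = int(r.headers.get("content-length", 0))
            with open(dest, "wb") as f, tqdm(total=total, unit="B", unit_scale=True) as bar:
                for part in r.iter_content(chunk):
                    f.write(part)
                    bar.update(len(part))
        return dest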