import gradio as gr
import spaces, torch, os, requests, json
from pathlib import Path
from tqdm import tqdm
from PIL import Image
from typing import Union
import numpy as np
from samv2_handler import (
load_sam_image_model,
run_sam_im_inference,
load_sam_video_model,
run_sam_video_inference,
logger,
)
from toolbox.mask_encoding import b64_mask_decode
# Enter a global bfloat16 autocast context for the lifetime of the process
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
    # Allow TF32 matmuls/convolutions on Ampere (compute capability 8.x) and newer GPUs
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
def download_checkpoints():
checkpoint_dir = Path("checkpoints")
checkpoint_dir.mkdir(exist_ok=True)
# Read URLs from the file
with open(checkpoint_dir / "sam2_checkpoints_url.txt", "r") as f:
urls = [url.strip() for url in f.readlines() if url.strip()]
for url in urls:
filename = url.split("/")[-1]
output_path = checkpoint_dir / filename
if output_path.exists():
print(f"Checkpoint {filename} already exists, skipping...")
continue
print(f"Downloading {filename}...")
response = requests.get(url, stream=True)
total_size = int(response.headers.get("content-length", 0))
with open(output_path, "wb") as f:
with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
for chunk in response.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
pbar.update(len(chunk))
print(f"Downloaded {filename} successfully!")
@spaces.GPU
def load_im_model(variant, auto_mask_gen: bool = False):
return load_sam_image_model(
variant=variant, device="cuda", auto_mask_gen=auto_mask_gen
)
@spaces.GPU
def load_vid_model(variant):
return load_sam_video_model(variant=variant, device="cuda")
@spaces.GPU
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process_image(
im: Image.Image,
variant: str,
bboxes: Union[list, str] = None,
points: Union[list, str] = None,
point_labels: Union[list, str] = None,
):
"""
SAM2 Image Segmentation
Args:
im: Pillow Image
variant: SAM2 model variant
bboxes: bounding boxes of objects to segment, expressed as a list of dicts: [{"x0":..., "y0":..., "x1":..., "y1":...}, ...]
points: points of objects to segment, expressed as a list of dicts [{"x":..., "y":...}, ...]
point_labels: list of integar
Returns:
list: a list of masks in the form of bit64 encoded strings
"""
    # input validation
    has_bboxes = bboxes is not None and bboxes != ""
    has_points = points is not None and points != ""
    has_point_labels = point_labels is not None and point_labels != ""
    assert has_bboxes or has_points, "either bboxes or points must be provided."
    if has_points:
        assert has_point_labels, "point_labels is required if points are provided."
bboxes = json.loads(bboxes) if isinstance(bboxes, str) and has_bboxes else bboxes
points = json.loads(points) if isinstance(points, str) and has_points else points
point_labels = (
json.loads(point_labels)
if isinstance(point_labels, str) and has_point_labels
else point_labels
)
if has_points:
assert len(points) == len(
point_labels
), f"{len(points)} points provided but there are {len(point_labels)} labels."
model = load_im_model(variant=variant)
return run_sam_im_inference(
model,
image=im,
bboxes=bboxes,
points=points,
point_labels=point_labels,
get_pil_mask=False,
b64_encode_mask=True,
)
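# A minimal client-side sketch of calling the "process_image" endpoint via
# gradio_client. The Space name ("user/sam2-space") and the image path are
# placeholders, not part of this repo; `handle_file` requires a recent
# gradio_client release.
#
#   from gradio_client import Client, handle_file
#   client = Client("user/sam2-space")
#   masks_b64 = client.predict(
#       handle_file("example.jpg"),                          # im
#       "small",                                             # variant
#       '[{"x0": 10, "y0": 20, "x1": 200, "y1": 180}]',      # bboxes
#       "",                                                  # points
#       "",                                                  # point_labels
#       api_name="/process_image",
#   )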
@spaces.GPU(
    duration=120
)  # the caller must have at least 2 minutes of inference time left at call time
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process_video(
video_path: str,
variant: str,
masks: Union[list, str],
drop_masks: bool = False,
ref_frame_idx: int = 0,
async_frame_load: bool = True,
):
"""
SAM2 Video Segmentation
Args:
video_path: path to video object
variant: SAMv2's model variant
masks: a list of b64 encoded masks for the first frame of the video, indicating the objects to be tracked
Returns:
list: a list of tracked objects expressed as a list of dictionary [{"frame":..., "track_id":..., "x":..., "y":...,"w":...,"h":...,"conf":..., "mask_b64":...},...]
"""
model = load_vid_model(variant=variant)
masks = json.loads(masks) if isinstance(masks, str) else masks
logger.debug(f"masks---\n{masks}")
    masks = [
        m[2:-1].encode() if m.startswith("b'") and m.endswith("'") else m for m in masks
    ]  # strip the b'...' wrapper if a mask was serialized as a stringified bytes literal
masks = np.array([b64_mask_decode(m).astype(np.uint8) for m in masks])
logger.debug(f"masks---\n{masks}")
return run_sam_video_inference(
model,
video_path=video_path,
masks=masks,
device="cuda",
do_tidy_up=True,
drop_mask=drop_masks,
async_frame_load=async_frame_load,
ref_frame_idx=ref_frame_idx,
)
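# A minimal client-side sketch of calling the "process_video" endpoint via
# gradio_client; the Space name, file path and mask string are placeholders.
# Depending on the Gradio version, the video input may need to be passed as a
# bare handle_file(...) instead of the dict form shown here.
#
#   from gradio_client import Client, handle_file
#   client = Client("user/sam2-space")
#   tracks = client.predict(
#       {"video": handle_file("clip.mp4")},    # video_path
#       "small",                               # variant
#       '["iVBORw0KGgoAAAANSUhEUgAABDg..."]',  # masks (base64 strings, e.g. from process_image)
#       True,                                  # drop_masks
#       0,                                     # ref_frame_idx
#       True,                                  # async_frame_load
#       api_name="/process_video",
#   )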
with gr.Blocks() as demo:
with gr.Tab("Images"):
gr.Interface(
fn=process_image,
inputs=[
gr.Image(label="Input Image", type="pil"),
gr.Dropdown(
label="Model Variant",
choices=["tiny", "small", "base_plus", "large"],
),
gr.Textbox(
label="Bounding Boxes",
value=None,
lines=5,
placeholder='JSON list of dicts: [{"x0":..., "y0":..., "x1":..., "y1":...}, ...]',
),
gr.Textbox(
label="Points",
lines=3,
placeholder='JSON list of dicts: [{"x":..., "y":...}, ...]',
),
gr.Textbox(label="Points' Labels", placeholder="JSON List of Integars"),
],
outputs=gr.JSON(label="Output JSON"),
title="SAM2 for Images",
api_name="process_image",
)
with gr.Tab("Videos"):
gr.Interface(
fn=process_video,
inputs=[
gr.Video(label="Input Video"),
gr.Dropdown(
label="Model Variant",
choices=["tiny", "small", "base_plus", "large"],
),
gr.Textbox(
label="Masks for Objects of Interest in the First Frame",
value=None,
lines=5,
placeholder="""
JSON list of base64 encoded masks, e.g.: ["b'iVBORw0KGgoAAAANSUhEUgAABDgAAAeAAQAAAAADGtqnAAAXz...'",...]
""",
),
gr.Checkbox(
label="Drop Masks",
info="remove base64 encoded masks from result JSON",
value=True,
),
gr.Number(
label="Reference Frame Index",
info="frame index for the provided object masks",
value=0,
precision=0,
),
gr.Checkbox(
label="async frame load",
info="start inference in parallel to frame loading",
),
],
outputs=gr.JSON(label="Output JSON"),
title="SAM2 for Videos",
api_name="process_video",
)
# Download checkpoints before launching the app
download_checkpoints()
demo.launch(
mcp_server=True, app_kwargs={"docs_url": "/docs"} # add FastAPI Swagger API Docs
)