import gradio as gr
import spaces, torch, requests, json
from pathlib import Path
from tqdm import tqdm
from PIL import Image
from typing import Union
import numpy as np

from samv2_handler import (
    load_sam_image_model,
    run_sam_im_inference,
    load_sam_video_model,
    run_sam_video_inference,
    logger,
)
from toolbox.mask_encoding import b64_mask_decode

# Keep a global bfloat16 autocast context open for the app's lifetime (the
# pattern used in the official SAM2 demo code); the per-endpoint autocast
# decorators below already cover inference on their own.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
    # Ampere (SM 8.x) and newer GPUs: allow TF32 matmuls for extra throughput.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True


def download_checkpoints():
    """Fetch every SAM2 checkpoint listed in checkpoints/sam2_checkpoints_url.txt."""
    checkpoint_dir = Path("checkpoints")
    checkpoint_dir.mkdir(exist_ok=True)

    with open(checkpoint_dir / "sam2_checkpoints_url.txt", "r") as f:
        urls = [url.strip() for url in f.readlines() if url.strip()]

    for url in urls:
        filename = url.split("/")[-1]
        output_path = checkpoint_dir / filename

        if output_path.exists():
            print(f"Checkpoint {filename} already exists, skipping...")
            continue

        print(f"Downloading {filename}...")
        response = requests.get(url, stream=True, timeout=60)
        response.raise_for_status()  # fail fast on a bad URL instead of saving an error page
        total_size = int(response.headers.get("content-length", 0))

        with open(output_path, "wb") as f:
            with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
                        pbar.update(len(chunk))

        print(f"Downloaded {filename} successfully!")
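
# Illustrative contents for checkpoints/sam2_checkpoints_url.txt, one URL per
# line. These are the published SAM 2.1 checkpoint locations at the time of
# writing, but treat the exact paths as an assumption and keep the text file
# as the source of truth:
#
#   https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt
#   https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt
#   https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt
#   https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt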


@spaces.GPU
def load_im_model(variant, auto_mask_gen: bool = False):
    # @spaces.GPU ensures this runs on a worker where CUDA is available.
    return load_sam_image_model(
        variant=variant, device="cuda", auto_mask_gen=auto_mask_gen
    )


@spaces.GPU
def load_vid_model(variant):
    return load_sam_video_model(variant=variant, device="cuda")


@spaces.GPU
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process_image(
    im: Image.Image,
    variant: str,
    bboxes: Union[list, str] = None,
    points: Union[list, str] = None,
    point_labels: Union[list, str] = None,
):
    """
    SAM2 Image Segmentation

    Args:
        im: Pillow Image
        variant: SAM2 model variant
        bboxes: bounding boxes of objects to segment, expressed as a list of dicts: [{"x0":..., "y0":..., "x1":..., "y1":...}, ...]
        points: points of objects to segment, expressed as a list of dicts: [{"x":..., "y":...}, ...]
        point_labels: list of integers (1 = foreground, 0 = background), one per point
    Returns:
        list: a list of masks in the form of base64-encoded strings
    """

    has_bboxes = bboxes is not None and bboxes != ""
    has_points = points is not None and points != ""
    has_point_labels = point_labels is not None and point_labels != ""
    assert has_bboxes or has_points, "either bboxes or points must be provided."
    if has_points:
        assert has_point_labels, "point_labels is required if points are provided."

    # Inputs may arrive as JSON strings (from the API/UI) or as parsed lists.
    bboxes = json.loads(bboxes) if isinstance(bboxes, str) and has_bboxes else bboxes
    points = json.loads(points) if isinstance(points, str) and has_points else points
    point_labels = (
        json.loads(point_labels)
        if isinstance(point_labels, str) and has_point_labels
        else point_labels
    )
    if has_points:
        assert len(points) == len(
            point_labels
        ), f"{len(points)} points provided but there are {len(point_labels)} labels."

    model = load_im_model(variant=variant)
    return run_sam_im_inference(
        model,
        image=im,
        bboxes=bboxes,
        points=points,
        point_labels=point_labels,
        get_pil_mask=False,
        b64_encode_mask=True,
    )
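

# Illustrative client call (never invoked by this app): hitting the
# "process_image" endpoint with gradio_client. The server URL and image path
# are placeholder assumptions; `handle_file` and the "/process_image" route
# follow standard gradio_client conventions.
def _example_process_image_call():
    from gradio_client import Client, handle_file

    client = Client("http://localhost:7860")  # assumed local demo.launch() URL
    return client.predict(
        handle_file("example.jpg"),  # im
        "small",  # variant
        '[{"x0": 10, "y0": 10, "x1": 200, "y1": 160}]',  # bboxes as a JSON string
        "",  # points (unused when bboxes are given)
        "",  # point_labels (unused)
        api_name="/process_image",
    )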


@spaces.GPU(duration=120)
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process_video(
    video_path: str,
    variant: str,
    masks: Union[list, str],
    drop_masks: bool = False,
    ref_frame_idx: int = 0,
    async_frame_load: bool = True,
):
    """
    SAM2 Video Segmentation

    Args:
        video_path: path to the input video file
        variant: SAM2 model variant
        masks: a list of base64-encoded masks for the reference frame, indicating the objects to be tracked
        drop_masks: whether to drop the base64-encoded mask of each tracked object from the results; when dropped, only bounding box information is available
        ref_frame_idx: the frame index of the reference frame
        async_frame_load: whether to load frames asynchronously during video propagation, which improves inference time
    Returns:
        list: a list of tracked objects expressed as a list of dicts: [{"frame":..., "track_id":..., "x":..., "y":..., "w":..., "h":..., "conf":..., "mask_b64":...}, ...]
    """
    model = load_vid_model(variant=variant)
    masks = json.loads(masks) if isinstance(masks, str) else masks
    logger.debug(f"masks---\n{masks}")
    # Tolerate masks that arrive as stringified bytes literals, e.g. "b'iVBOR...'".
    masks = [
        m[2:-1].encode() if isinstance(m, str) and m.startswith("b'") and m.endswith("'") else m
        for m in masks
    ]
    masks = np.array([b64_mask_decode(m).astype(np.uint8) for m in masks])
    logger.debug(f"masks---\n{masks}")
    return run_sam_video_inference(
        model,
        video_path=video_path,
        masks=masks,
        device="cuda",
        do_tidy_up=True,
        drop_mask=drop_masks,
        async_frame_load=async_frame_load,
        ref_frame_idx=ref_frame_idx,
    )
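

# Illustrative end-to-end flow (never invoked by this app): segment an object
# in the reference frame with process_image, then track it through the clip
# with process_video. File paths and the server URL are placeholder
# assumptions; the dict payload for gr.Video follows gradio_client convention.
def _example_process_video_call():
    from gradio_client import Client, handle_file

    client = Client("http://localhost:7860")  # assumed local demo.launch() URL
    ref_masks = client.predict(
        handle_file("frame0.jpg"),  # reference frame extracted from the clip
        "small",
        '[{"x0": 10, "y0": 10, "x1": 200, "y1": 160}]',
        "",
        "",
        api_name="/process_image",
    )
    return client.predict(
        {"video": handle_file("clip.mp4")},  # gr.Video inputs are wrapped in a dict
        "small",
        json.dumps(ref_masks),  # reuse the base64 masks as tracking seeds
        True,  # drop_masks
        0,  # ref_frame_idx
        True,  # async_frame_load
        api_name="/process_video",
    )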


with gr.Blocks() as demo:
    with gr.Tab("Images"):
        gr.Interface(
            fn=process_image,
            inputs=[
                gr.Image(label="Input Image", type="pil"),
                gr.Dropdown(
                    label="Model Variant",
                    choices=["tiny", "small", "base_plus", "large"],
                ),
                gr.Textbox(
                    label="Bounding Boxes",
                    value=None,
                    lines=5,
                    placeholder='JSON list of dicts: [{"x0":..., "y0":..., "x1":..., "y1":...}, ...]',
                ),
                gr.Textbox(
                    label="Points",
                    lines=3,
                    placeholder='JSON list of dicts: [{"x":..., "y":...}, ...]',
                ),
                gr.Textbox(label="Points' Labels", placeholder="JSON list of integers"),
            ],
            outputs=gr.JSON(label="Output JSON"),
            title="SAM2 for Images",
            api_name="process_image",
        )
    with gr.Tab("Videos"):
        gr.Interface(
            fn=process_video,
            inputs=[
                gr.Video(label="Input Video"),
                gr.Dropdown(
                    label="Model Variant",
                    choices=["tiny", "small", "base_plus", "large"],
                ),
                gr.Textbox(
                    label="Masks for Objects of Interest in the Reference Frame",
                    value=None,
                    lines=5,
                    placeholder='JSON list of base64-encoded masks, e.g.: ["b\'iVBORw0KGgoAAAANSUhEUgAABDgAAAeAAQAAAAADGtqnAAAXz...\'", ...]',
                ),
                gr.Checkbox(
                    label="Drop Masks",
                    info="remove base64 encoded masks from the result JSON",
                    value=True,
                ),
                gr.Number(
                    label="Reference Frame Index",
                    info="frame index for the provided object masks",
                    value=0,
                    precision=0,
                ),
                gr.Checkbox(
                    label="Async Frame Load",
                    info="start inference in parallel with frame loading",
                ),
            ],
            outputs=gr.JSON(label="Output JSON"),
            title="SAM2 for Videos",
            api_name="process_video",
        )


download_checkpoints()  # fetch model weights at startup, before serving
demo.launch(
    mcp_server=True,  # also expose the endpoints as MCP tools
    app_kwargs={"docs_url": "/docs"},  # serve the FastAPI docs at /docs
)