#!/usr/bin/env python
"""Gradio demo for ViTMatte: predict an alpha matte from an image and a trimap."""

import os

import gradio as gr
import numpy as np
import PIL.Image
import spaces
import torch
from transformers import VitMatteForImageMatting, VitMatteImageProcessor

DESCRIPTION = "# [ViTMatte](https://github.com/hustvl/ViTMatte)"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Largest accepted image side in pixels; overridable through the environment.
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1500"))
MODEL_ID = os.getenv("MODEL_ID", "hustvl/vitmatte-small-distinctions-646")
processor = VitMatteImageProcessor.from_pretrained(MODEL_ID)
model = VitMatteForImageMatting.from_pretrained(MODEL_ID).to(device)


def check_image_size(image: PIL.Image.Image | None) -> None:
    """Reject images whose longest side exceeds ``MAX_IMAGE_SIZE``.

    ``None`` is accepted silently: this runs on every change of the input component,
    including when the user clears it.

    Raises:
        gr.Error: if ``max(image.size) > MAX_IMAGE_SIZE``.
    """
    if image is None:
        return
    if max(image.size) > MAX_IMAGE_SIZE:
        raise gr.Error(f"Image size is too large. Max image size is {MAX_IMAGE_SIZE} pixels.")


def binarize_mask(mask: np.ndarray) -> np.ndarray:
    """Return a 0/1 copy of *mask*: values >= 128 become 1, all others 0.

    The input array is not modified (it belongs to the caller's component value);
    the result keeps the input dtype.
    """
    return (mask >= 128).astype(mask.dtype)


def update_trimap(
    foreground_mask: dict[str, np.ndarray] | None, unknown_mask: dict[str, np.ndarray] | None
) -> np.ndarray:
    """Build a uint8 trimap from the two sketch-pad masks.

    Pixels painted on the unknown pad become 128, pixels painted on the foreground pad
    become 255 (foreground wins where both were painted), everything else stays 0.

    Args:
        foreground_mask: value of the "Foreground" sketch component. Its ``"mask"`` entry is
            indexed as ``[:, :, 0]``, so it is expected to be an (H, W, C) array.
        unknown_mask: value of the "Unknown" sketch component, same layout.

    Raises:
        gr.Error: if either sketch pad has no image loaded (its value is ``None``).
    """
    if foreground_mask is None or unknown_mask is None:
        raise gr.Error("Load the image into both sketch pads before setting the trimap.")
    foreground = binarize_mask(foreground_mask["mask"][:, :, 0])
    unknown = binarize_mask(unknown_mask["mask"][:, :, 0])
    trimap = np.zeros(foreground.shape, dtype=np.uint8)
    trimap[unknown > 0] = 128
    trimap[foreground > 0] = 255
    return trimap


def adjust_background_image(background_image: PIL.Image.Image, target_size: tuple[int, int]) -> PIL.Image.Image:
    """Scale *background_image* to cover *target_size* (w, h), then center-crop to exactly that size."""
    target_w, target_h = target_size
    bg_w, bg_h = background_image.size
    scale = max(target_w / bg_w, target_h / bg_h)
    # round() + max() instead of int(): bg_w * (target_w / bg_w) can come out a hair below
    # target_w in floating point; truncating it would make the crop run past the image edge,
    # which PIL fills with black.
    new_bg_w = max(target_w, round(bg_w * scale))
    new_bg_h = max(target_h, round(bg_h * scale))
    background_image = background_image.resize((new_bg_w, new_bg_h))
    left = (new_bg_w - target_w) // 2
    top = (new_bg_h - target_h) // 2
    return background_image.crop((left, top, left + target_w, top + target_h))


def replace_background(
    image: PIL.Image.Image, alpha: np.ndarray, background_image: PIL.Image.Image | None
) -> np.ndarray | None:
    """Alpha-blend *image* over *background_image*.

    Args:
        image: RGB foreground image.
        alpha: per-pixel blend weight with the same height/width as *image*.
        background_image: replacement background; resized/cropped to *image*'s size.

    Returns:
        The composite as a uint8 (H, W, 3) array, or ``None`` when no background is given.

    Raises:
        gr.Error: if *image* is not RGB.
    """
    if background_image is None:
        return None
    if image.mode != "RGB":
        raise gr.Error("Image must be RGB.")
    background_image = adjust_background_image(background_image.convert("RGB"), image.size)
    fg = np.array(image).astype(float) / 255
    bg = np.array(background_image).astype(float) / 255
    weight = alpha[:, :, None]
    result = fg * weight + bg * (1 - weight)
    # Clip before the uint8 cast so any out-of-range alpha cannot wrap around.
    return (np.clip(result, 0, 1) * 255).astype(np.uint8)


@spaces.GPU
@torch.inference_mode()
def run(
    image: PIL.Image.Image,
    trimap: PIL.Image.Image,
    apply_background_replacement: bool,
    background_image: PIL.Image.Image | None,
) -> tuple[np.ndarray, PIL.Image.Image, np.ndarray | None]:
    """Predict an alpha matte for *image* guided by *trimap*.

    Returns:
        (alpha, foreground over white, optional background-replacement composite).

    Raises:
        gr.Error: on missing inputs, size mismatch, oversized image, or wrong image modes.
    """
    if image is None:
        raise gr.Error("Please provide an input image.")
    if trimap is None:
        raise gr.Error("Please provide a trimap.")
    if image.size != trimap.size:
        raise gr.Error("Image and trimap must have the same size.")
    check_image_size(image)
    if image.mode != "RGB":
        raise gr.Error("Image must be RGB.")
    if trimap.mode != "L":
        raise gr.Error("Trimap must be grayscale.")

    pixel_values = processor(images=image, trimaps=trimap, return_tensors="pt").to(device).pixel_values
    out = model(pixel_values=pixel_values)
    alpha = out.alphas[0, 0].to("cpu").numpy()
    # The prediction can be larger than the input (presumably processor padding);
    # crop it back to the input's height and width.
    w, h = image.size
    alpha = alpha[:h, :w]

    # Show the matted foreground composited over a white background.
    foreground = np.array(image).astype(float) / 255 * alpha[:, :, None] + (1 - alpha[:, :, None])
    foreground = PIL.Image.fromarray((foreground * 255).astype(np.uint8))

    if apply_background_replacement:
        res_bg_replacement = replace_background(image, alpha, background_image)
    else:
        res_bg_replacement = None
    return alpha, foreground, res_bg_replacement


with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Row():
        with gr.Column():
            with gr.Box():
                image = gr.Image(label="Input image", type="pil", height=500)
                with gr.Tabs():
                    with gr.Tab(label="Trimap"):
                        trimap = gr.Image(label="Trimap", type="pil", image_mode="L", height=500)
                    with gr.Tab(label="Draw trimap"):
                        load_image_button = gr.Button("Load image")
                        foreground_mask = gr.Image(
                            label="Foreground",
                            tool="sketch",
                            type="numpy",
                            brush_color="green",
                            mask_opacity=0.7,
                            height=500,
                        )
                        unknown_mask = gr.Image(
                            label="Unknown",
                            tool="sketch",
                            type="numpy",
                            brush_color="green",
                            mask_opacity=0.7,
                            height=500,
                        )
                        set_trimap_button = gr.Button("Set trimap")
                # Checkbox's initial state is passed as `value`; `checked` is not a Checkbox parameter.
                apply_background_replacement = gr.Checkbox(label="Apply background replacement", value=False)
                background_image = gr.Image(label="Background image", type="pil", height=500, visible=False)
                run_button = gr.Button("Run")
        with gr.Column():
            with gr.Box():
                out_alpha = gr.Image(label="Alpha", height=500)
                out_foreground = gr.Image(label="Foreground", height=500)
                out_background_replacement = gr.Image(label="Background replacement", height=500, visible=False)

    inputs = [
        image,
        trimap,
        apply_background_replacement,
        background_image,
    ]
    outputs = [
        out_alpha,
        out_foreground,
        out_background_replacement,
    ]
    gr.Examples(
        examples=[
            ["assets/retriever_rgb.png", "assets/retriever_trimap.png", False, None],
            ["assets/bulb_rgb.png", "assets/bulb_trimap.png", True, "assets/new_bg.jpg"],
        ],
        inputs=inputs,
        outputs=outputs,
        fn=run,
        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
    )

    image.change(
        fn=check_image_size,
        inputs=image,
        queue=False,
        api_name=False,
    )
    # Copy the input image into both sketch pads so the user can paint on it.
    load_image_button.click(
        fn=lambda image: (image, image),
        inputs=image,
        outputs=[foreground_mask, unknown_mask],
        queue=False,
        api_name=False,
    )
    set_trimap_button.click(
        fn=update_trimap,
        inputs=[foreground_mask, unknown_mask],
        outputs=trimap,
        queue=False,
        api_name=False,
    )
    # Show the background input and its output only while replacement is enabled.
    apply_background_replacement.change(
        fn=lambda checked: (gr.update(visible=checked), gr.update(visible=checked)),
        inputs=apply_background_replacement,
        outputs=[background_image, out_background_replacement],
        queue=False,
        api_name=False,
    )
    run_button.click(
        fn=run,
        inputs=inputs,
        outputs=outputs,
        api_name="run",
    )

if __name__ == "__main__":
    demo.queue(max_size=20).launch()