"""Gradio demo scaffold around the Segment Anything Model (SAM).

Loads the ``facebook/sam-vit-large`` model and processor at import time and
serves a small UI in which the user uploads an image and sketches a mask on
it. The callback currently returns the uploaded image unchanged.
"""

from typing import Dict, Optional

import gradio as gr
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL = SamModel.from_pretrained("facebook/sam-vit-large").to(DEVICE)
PROCESSOR = SamProcessor.from_pretrained("facebook/sam-vit-large")


def inference(masked_image: Optional[Dict[str, Image.Image]]) -> Image.Image:
    """Handle a Submit click from the sketch-enabled image component.

    Args:
        masked_image: Mapping with an ``'image'`` entry (the uploaded
            picture) and a ``'mask'`` entry (the user's brush strokes), or
            ``None`` when nothing has been uploaded.

    Returns:
        The uploaded image, unchanged.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of a ``TypeError`` from subscripting ``None``.
    """
    if masked_image is None:
        raise gr.Error("Please upload an image before submitting.")

    image = masked_image['image']
    # NOTE(review): `mask` is computed but not used yet; presumably it is
    # meant to become a prompt for MODEL -- confirm before removing.
    mask = masked_image['mask'].resize((256, 256), Image.Resampling.LANCZOS)
    return image


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # White brush on the uploaded image; strokes arrive as 'mask'.
            input_image = gr.Image(
                image_mode='RGB',
                type='pil',
                tool="sketch",
                interactive=True,
                brush_radius=20.0,
                brush_color="#FFFFFF",
                height=500,
            )
            submit_button = gr.Button("Submit")
        output_image = gr.Image(image_mode='RGB', type='pil')

    submit_button.click(
        inference,
        inputs=[input_image],
        outputs=output_image,
    )

demo.launch(debug=False, show_error=True)