# Hugging Face Spaces status: Runtime error
from typing import Dict, Optional

import gradio as gr
import torch
from PIL import Image
from transformers import SamModel, SamProcessor
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Single source of truth for the checkpoint id, so the model weights and the
# processor's preprocessing config are always loaded from the same repository.
SAM_CHECKPOINT = "facebook/sam-vit-large"

# Loaded once, at module import time; the model is moved onto DEVICE.
MODEL = SamModel.from_pretrained(SAM_CHECKPOINT).to(DEVICE)
PROCESSOR = SamProcessor.from_pretrained(SAM_CHECKPOINT)
def inference(
    masked_image: "Optional[Dict[str, Image.Image]]",
) -> "Optional[Image.Image]":
    """Handle a Submit click from the sketch-enabled image input.

    The segmentation model (MODEL / PROCESSOR) is not called yet: for now
    this returns the uploaded picture unchanged.

    Args:
        masked_image: Value delivered by the sketch input -- a dict holding
            the uploaded picture under ``'image'`` (and the drawn strokes
            under ``'mask'``), or None when no image has been provided.

    Returns:
        The uploaded image, or None when ``masked_image`` is None (which
        clears the output component instead of raising).
    """
    if masked_image is None:
        # Subscripting None would raise TypeError and surface as an error
        # toast; an empty input simply yields an empty output.
        return None
    # TODO: wire in SAM inference. The user's strokes live in
    # masked_image["mask"]; presumably they will have to be resized to SAM's
    # 256x256 low-resolution mask-prompt size (e.g.
    # mask.resize((256, 256), Image.Resampling.LANCZOS)) once the model call
    # exists. Until then, resizing it here would be wasted work.
    return masked_image["image"]
# UI layout: one row holding a column (sketchable input stacked above the
# Submit button) and, beside it, the image returned by `inference`.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # NOTE(review): `tool`, `brush_radius` and `brush_color` are
            # Gradio 3.x gr.Image options; Gradio 4 presumably rejects them
            # (it moved sketching to gr.ImageEditor) -- confirm which gradio
            # version the Space pins.
            input_image = gr.Image(
                type='pil',
                image_mode='RGB',
                interactive=True,
                tool="sketch",
                brush_color="#FFFFFF",
                brush_radius=20.0,
                height=500,
            )
            submit_button = gr.Button("Submit")
        output_image = gr.Image(type='pil', image_mode='RGB')

    # Event wiring must happen inside the Blocks context.
    submit_button.click(
        fn=inference,
        inputs=[input_image],
        outputs=output_image,
    )

# Start the server; show_error=True surfaces exceptions raised by the handler
# in the browser.
demo.launch(show_error=True, debug=False)