# Author: SkalskiP — commit 7b4534e (":tada: initial commit"), 1.04 kB.
from typing import Dict
import gradio as gr
import torch
from PIL import Image
from transformers import SamModel, SamProcessor
# Hugging Face checkpoint shared by the model and its processor, so the two
# cannot silently drift apart if the checkpoint is ever changed.
SAM_CHECKPOINT = "facebook/sam-vit-large"

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Loaded once at import time and moved to DEVICE.
MODEL = SamModel.from_pretrained(SAM_CHECKPOINT).to(DEVICE)
# Processor paired with the same checkpoint as MODEL.
PROCESSOR = SamProcessor.from_pretrained(SAM_CHECKPOINT)
def inference(masked_image: Dict[str, Image.Image]) -> Image.Image:
    """Handle a submission from the sketch-enabled image input.

    Currently a placeholder: the uploaded image is returned unchanged and the
    user's sketched mask is not used yet.

    Args:
        masked_image: Value produced by the ``gr.Image(tool="sketch")`` input.
            Expected to hold the uploaded picture under ``'image'`` and the
            brush strokes under ``'mask'``. It may be ``None`` when nothing was
            uploaded (presumably Gradio's behaviour for an empty input; confirm).

    Returns:
        The uploaded image, unchanged.

    Raises:
        gr.Error: If no image was provided. The UI then shows a readable
            message instead of a raw ``TypeError`` from subscripting ``None``.
    """
    if masked_image is None or masked_image.get('image') is None:
        raise gr.Error("Please upload an image before submitting.")
    # TODO: feed masked_image['mask'] to SAM. The first draft resized it to
    # 256x256 (presumably SAM's low-resolution mask-prompt size; confirm) but
    # discarded the result, so that work is skipped until it is actually used.
    return masked_image['image']
# Two-pane layout: sketch-enabled input (plus Submit button) in one column,
# the result image next to it in the same row.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # The sketch tool lets the user paint white strokes over the upload.
            # `inference` reads the result as a dict with 'image' and 'mask'.
            input_image = gr.Image(
                image_mode='RGB', type='pil', tool="sketch", interactive=True,
                brush_radius=20.0, brush_color="#FFFFFF", height=500)
            submit_button = gr.Button("Submit")
        output_image = gr.Image(image_mode='RGB', type='pil')

    # Event listeners have to be registered inside the Blocks context.
    submit_button.click(
        inference,
        inputs=[input_image],
        outputs=output_image)

# Start the server only when run as a script. Importing this module (from tests,
# or to mount `demo` in another app) then does not launch a server as a side effect.
if __name__ == "__main__":
    demo.launch(debug=False, show_error=True)