# mtr-sd / server.py
# Uploaded by merligus ("Upload server.py", commit 9d12700).
import gradio as gr
from PIL import Image
import torch
from diffusers import StableDiffusionInpaintPipeline
# --- model loading (runs once, at import time) ---
# Use CUDA when PyTorch reports a usable GPU, otherwise fall back to the CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_ok else "cpu")

# Stable Diffusion 2 inpainting pipeline, resolved by its repo id via
# diffusers' from_pretrained and then moved onto `device`.
# No torch_dtype is passed, so the weights keep the library's default dtype;
# half precision (torch_dtype=torch.float16) is not enabled here.
_INPAINT_MODEL_ID = "stabilityai/stable-diffusion-2-inpainting"
sd_pipe = StableDiffusionInpaintPipeline.from_pretrained(_INPAINT_MODEL_ID)
sd_pipe = sd_pipe.to(device)
def generate_image(image, mask, prompt, negative_prompt, pipe, seed, size=512):
    """Inpaint the masked region of ``image`` and return it at its original size.

    ``image`` and ``mask`` are both resized to a ``size`` x ``size`` square
    before inference (aspect ratio is NOT preserved). The first image the
    pipeline returns is then resized back to ``image``'s original dimensions.

    Args:
        image: Source PIL image.
        mask: PIL mask image. Presumably white marks the area to repaint,
            per the diffusers inpainting convention (confirm).
        prompt: Text prompt guiding the inpainting.
        negative_prompt: Text describing what the result should avoid.
        pipe: Inpainting pipeline that exposes ``.device`` and is called as
            ``pipe(image=..., mask_image=..., prompt=..., negative_prompt=...,
            generator=..., height=..., width=...)``. It returns an object
            whose ``images`` list holds PIL images.
        seed: Integer seed for reproducible output, or ``None`` to let the
            pipeline draw fresh random noise on every call.
        size: Side length in pixels of the square working resolution.
            Defaults to 512. Should be a multiple of 8 for Stable Diffusion.

    Returns:
        The generated PIL image, resized to ``image.size``.
    """
    orig_w, orig_h = image.size
    work_size = (size, size)
    in_image = image.resize(work_size)
    in_mask = mask.resize(work_size)

    # Build the generator on the pipeline's own device, not the module-level
    # `device`, so a pipe placed elsewhere never gets a mismatched generator.
    generator = None if seed is None else torch.Generator(pipe.device).manual_seed(seed)

    result = pipe(
        image=in_image,
        mask_image=in_mask,
        prompt=prompt,
        negative_prompt=negative_prompt,
        generator=generator,
        # Keep the pipeline's output resolution in step with the resized inputs.
        height=size,
        width=size,
    )
    return result.images[0].resize((orig_w, orig_h))
# Fixed generation settings read by predict() on every request.
prompt, negative_prompt, seed = (
    "perfect skin",  # text the inpainted region should depict
    "",              # negative prompt: nothing is excluded
    7,               # fixed seed, for reproducibility
)
def predict(image, mask):
    """Gradio handler: inpaint ``image`` inside ``mask`` with the module settings.

    Uses the module-level ``prompt``, ``negative_prompt``, ``seed`` and
    ``sd_pipe``.

    Args:
        image: Source picture from the first Gradio ``"image"`` input, as an
            array accepted by ``PIL.Image.fromarray``. It is ``None`` when the
            user supplied nothing.
        mask: Mask picture from the second ``"image"`` input, with the same
            conventions.

    Returns:
        The inpainted PIL image at the source image's original size.

    Raises:
        gr.Error: If either input is missing. The UI then shows a readable
            message instead of an opaque PIL traceback.
    """
    if image is None or mask is None:
        raise gr.Error("Please provide both an image and a mask.")

    # Gradio delivers arrays; the generation helper works on PIL images.
    image_source_pil = Image.fromarray(image)
    image_mask_pil = Image.fromarray(mask)

    return generate_image(
        image=image_source_pil,
        mask=image_mask_pil,
        prompt=prompt,
        negative_prompt=negative_prompt,
        pipe=sd_pipe,
        seed=seed,
    )
if __name__ == "__main__":
    # Two image inputs map positionally onto predict(image, mask); the single
    # image output shows the inpainted result. queue() is applied before
    # launch(), which starts the server. The return value of launch() is not
    # needed, so it is no longer bound to `io` (which shadowed the stdlib module).
    demo = gr.Interface(predict, ["image", "image"], "image")
    demo.queue().launch()