import gradio as gr
import torch
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image, ImageOps

# Select the device: "cuda" when a GPU is available, otherwise "cpu".
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# CLIPSeg produces a segmentation mask from a free-text label.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)

# Stable Diffusion inpainting repaints the white regions of the mask.
inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting"
).to(device)


def numpy_to_pil(images):
    """Convert a float array in [0, 1] (HWC or NHWC) to a list of PIL images."""
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # Special case for grayscale (single-channel) images.
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]
    return pil_images


def get_mask(text, image):
    """Segment `image` with CLIPSeg using `text` as the query; return a PIL mask."""
    inputs = processor(
        text=[text], images=[image], padding="max_length", return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Sigmoid turns the logits into per-pixel probabilities in [0, 1].
    mask = torch.sigmoid(outputs.logits).cpu().unsqueeze(-1).numpy()
    mask_pil = numpy_to_pil(mask)[0].resize(image.size)
    return mask_pil


def predict(prompt, negative_prompt, image, obj2mask, generator):
    mask = get_mask(obj2mask, image)
    image = image.convert("RGB").resize((512, 512))
    mask_image = mask.convert("RGB").resize((512, 512))
    # Invert the mask so the named object is kept and the background is repainted.
    mask_image = ImageOps.invert(mask_image)
    images = inpainting_pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=image,
        mask_image=mask_image,
        generator=generator,
    ).images
    # Composite the generated background over the original object so the
    # masked-out object is preserved pixel-for-pixel.
    mask = mask_image.convert("L")
    result = Image.composite(images[0], image, mask)
    return result


def inference(prompt, negative_prompt, obj2mask, image_numpy):
    # Fix the seed so results are reproducible across runs.
    generator = torch.Generator(device=device_name)
    generator.manual_seed(52362)
    # Gradio delivers the image as a uint8 numpy array, so convert it directly
    # rather than through numpy_to_pil (which expects floats in [0, 1]).
    image = Image.fromarray(image_numpy).convert("RGB").resize((512, 512))
    img = predict(prompt, negative_prompt, image, obj2mask, generator)
    return img


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="cinematic, landscape, sharp focus")
            negative_prompt = gr.Textbox(label="Negative Prompt", value="illustration, 3d render")
            mask = gr.Textbox(label="Mask", value="shoe")
            input_img = gr.Image()
            run = gr.Button(value="Generate")
        with gr.Column():
            output_img = gr.Image()
    run.click(
        inference,
        inputs=[prompt, negative_prompt, mask, input_img],
        outputs=output_img,
    )

demo.queue(concurrency_count=1)
demo.launch()