File size: 4,197 Bytes
08fc0b2
0a2d6b8
dd2b2f9
fec64c5
08fc0b2
cfe97ad
 
 
bc11c81
 
 
a1deaea
 
 
 
 
d16752a
a1deaea
 
d16752a
 
 
 
 
 
 
 
 
 
 
 
 
7139441
fd69a13
7139441
 
 
98833b0
9539987
d040427
 
d16752a
 
965c284
d040427
 
d16752a
d040427
 
3e711ca
d040427
 
d16752a
 
9539987
e0e4048
d040427
8dcef4c
 
3ceed42
d040427
93b65f4
 
 
 
 
 
 
 
 
d040427
a11c80c
eec3e26
d040427
eec3e26
10e906d
d040427
10e906d
642c2d0
10e906d
 
 
a11c80c
3d2adfb
08fc0b2
 
 
 
 
d040427
 
08fc0b2
 
 
f0cce29
429369b
e0d2fcf
f0cce29
e0fc88a
940b814
6b87482
ba2bf14
3d2adfb
ca1560a
 
bc11c81
 
3d2adfb
7139441
08fc0b2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import gradio as gr
from diffusers import AutoPipelineForInpainting, AutoencoderKL
import torch
from PIL import Image, ImageOps

# Load the VAE in half precision so it can be handed to the pipeline below.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
)

# Build the SDXL inpainting pipeline around that VAE (fp16 safetensors weights).
pipeline = AutoPipelineForInpainting.from_pretrained(
    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
# Move the whole pipeline to the GPU; inference below assumes CUDA is present.
pipeline = pipeline.to("cuda")

def get_select_index(evt: gr.SelectData):
    """Return the index carried by a Gradio select event.

    Used as the gallery ``select`` handler so that the position of the
    clicked version is written into the hidden ``selected`` number field.

    Args:
        evt: The select event emitted by the component.

    Returns:
        The ``index`` attribute of *evt*.
    """
    chosen_index = evt.index
    return chosen_index

def squarify_image(img):
    """Pad *img* onto a white square canvas sized to its longer edge.

    The picture is pasted at the top-left corner, so the white padding ends
    up on the right for portrait images and at the bottom for landscape ones.

    Args:
        img: PIL image to pad.

    Returns:
        A new RGB image of size ``(side, side)`` where ``side`` is
        ``max(img.width, img.height)``.
    """
    side = max(img.height, img.width)
    canvas = Image.new(mode="RGB", size=(side, side), color="white")
    # The paste offset is deliberately (0, 0) rather than centred: generate()
    # pastes the mask at (0, 0) and crops the result starting from (0, 0),
    # so the image content has to stay anchored to the top-left corner.
    canvas.paste(img, (0, 0))
    return canvas
    
def divisible_by_8(image):
    """Resize *image* so that both of its sides are multiples of 8.

    Each side is rounded down to the nearest multiple of 8, but never below
    8 itself: a side shorter than 8 pixels would otherwise round down to 0
    and produce an invalid (zero-sized) resize target.

    Args:
        image: Image-like object exposing ``size`` (width, height) and
            ``resize((width, height))`` -- a PIL image in this app.

    Returns:
        Whatever ``image.resize`` returns for the adjusted size.
    """
    width, height = image.size

    # Round down to a multiple of 8, clamping tiny sides up to 8.
    new_width = max(8, width - width % 8)
    new_height = max(8, height - height % 8)

    return image.resize((new_width, new_height))


def restore_version(index, versions):
    """Build an image-editor value from a previously generated version.

    Args:
        index: Position of the wanted version inside *versions*. It is read
            from a numeric input field and may therefore arrive as a float
            (e.g. ``1.0``); it is truncated to an int before indexing.
        versions: Sequence of gallery entries; the image is taken from the
            first element of the selected entry.

    Returns:
        A dict with the keys ``'background'``, ``'layers'`` and
        ``'composite'``, where background and composite are the selected
        image and layers is ``None``.
    """
    # A float such as 1.0 cannot index a list, so coerce it first.
    chosen_image = versions[int(index)][0]
    return {'background': chosen_image, 'layers': None, 'composite': chosen_image}


def generate(image_editor, prompt, neg_prompt, versions):
    """Inpaint the masked area of the edited image and record the result.

    Steps: shrink the background to fit 1024x1024, make its sides multiples
    of 8, pad it to a square, build a white-on-black mask from the first
    drawn layer, run the inpainting pipeline, then crop the padding off the
    output again.

    Args:
        image_editor: Value of the image editor: a dict with a PIL image
            under ``'background'`` and a list of drawn layers under
            ``'layers'`` (the first layer is used as the mask).
        prompt: Text prompt passed to the pipeline.
        neg_prompt: Negative prompt passed to the pipeline; an empty value
            is sent as ``None``.
        versions: Current gallery value (list of earlier versions), or
            ``None``/empty when nothing has been generated yet.

    Returns:
        A 3-tuple of: the new editor value (dict), a ``gr.Gallery`` update
        holding all versions, and an update that makes the restore button
        visible.

    Raises:
        gr.Error: If there is no background image or no mask layer.
    """
    background = image_editor['background'] if image_editor else None
    if background is None:
        raise gr.Error("Please upload an image first.")
    layers = image_editor["layers"]
    if not layers:
        raise gr.Error("Please draw a mask over the area to inpaint.")

    image = background.convert('RGB')

    # Resize image
    image.thumbnail((1024, 1024))
    image = divisible_by_8(image)
    # Size of the real picture before white padding is added around it.
    content_width, content_height = image.size

    # Mask layer
    layer = layers[0].resize(image.size)

    # Make image a square
    image = squarify_image(image)

    # Make sure mask is white with a black background
    mask = Image.new("RGBA", image.size, "WHITE")
    mask.paste(layer, (0, 0), layer)
    mask = ImageOps.invert(mask.convert('L'))

    # Inpaint
    final_image = pipeline(prompt=prompt,
                           negative_prompt=neg_prompt or None,
                           image=image,
                           mask_image=mask).images[0]

    # The pipeline output may have a different resolution than the padded
    # input (presumably 1024x1024 for this model), so scale the content size
    # by the output/input ratio instead of assuming a fixed output size.
    scale_x = final_image.width / image.width
    scale_y = final_image.height / image.height
    crop_box = (0, 0,
                round(content_width * scale_x),
                round(content_height * scale_y))

    # Crop image to original aspect ratio (the content sits at the top-left)
    final_image = final_image.crop(crop_box)

    # gradio.ImageEditor requires a dictionary
    final_dict = {'background': final_image, 'layers': None, 'composite': final_image}

    # Add generated image to version gallery; the first run also stores the
    # untouched input so it can be restored later. A new list is built so the
    # incoming gallery value is not mutated.
    if not versions:
        final_gallery = [background, final_image]
    else:
        final_gallery = [*versions, final_image]

    return final_dict, gr.Gallery(value=final_gallery, visible=True), gr.update(visible=True)

# UI definition: a two-column layout with the mask editor and prompt inputs on
# the left and the (initially hidden) version gallery on the right.
with gr.Blocks() as demo:
    gr.Markdown("""
    # Inpainting Sketch Pad
    by [Tony Assi](https://www.tonyassi.com/)

    Please ❤️ this Space. I build custom AI apps for companies. <a href="mailto: [email protected]">Email me</a> for business inquiries.
    """)
    
    with gr.Row():
        # Left column: image + mask editor, prompts and the generate button.
        with gr.Column():
            sketch_pad = gr.ImageMask(type='pil', label='Inpaint')
            prompt = gr.Textbox(label="Prompt")
            generate_button = gr.Button("Generate")
            with gr.Accordion("Advanced Settings", open=False):
                neg_prompt = gr.Textbox(label='Negative Prompt', value='ugly, deformed')
        # Right column: version history; generate() makes the gallery and the
        # restore button visible through its return values.
        with gr.Column():
            version_gallery = gr.Gallery(label="Versions", type="pil", object_fit='contain', visible=False)
            restore_button = gr.Button("Restore Version", visible=False)
            # Hidden field that stores the index of the selected gallery item.
            selected = gr.Number(show_label=False, visible=False)
    

    # Event wiring:
    # - selecting a gallery item stores its index in `selected`
    # - "Generate" runs inpainting and updates editor, gallery and restore button
    # - "Restore Version" loads the selected gallery image back into the editor
    version_gallery.select(get_select_index, None, selected)
    generate_button.click(fn=generate, inputs=[sketch_pad,prompt, neg_prompt, version_gallery], outputs=[sketch_pad, version_gallery, restore_button])
    restore_button.click(fn=restore_version, inputs=[selected, version_gallery], outputs=sketch_pad)

# Start the Gradio server (runs on import, as this file is the app entry point).
demo.launch()