"""Gradio demo that turns a user sketch plus a text prompt into an image.

The generated image is handed to ``image_to_sketch_gif`` and the resulting
frames are played back, one at a time, into the sketchpad so the user can
stop the playback, pick a frame with the slider, and keep drawing on it.
"""

# NOTE: several imports below are not referenced in this file's visible code
# (base64, os, pdb, sys, BytesIO, transforms, make_1step_sched); they are kept
# untouched in case other tooling relies on them.
import base64
import os
import pdb
import random
import sys
import time
from io import BytesIO

import gradio as gr
import numpy as np
import spaces
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from torchvision import transforms

from src.img2skt import image_to_sketch_gif
from src.model import make_1step_sched
from src.pix2pix_turbo import Pix2Pix_Turbo

model = Pix2Pix_Turbo("sketch_to_image_stochastic")

# Prompt templates; "{prompt}" is substituted with the user's text in `run`.
style_list = [
    {
        "name": "No Style",
        "prompt": "{prompt}",
    },
    {
        "name": "Cinematic",
        "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
    },
    {
        "name": "3D Model",
        "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
    },
    {
        "name": "Anime",
        "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
    },
    {
        "name": "Digital Art",
        "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
    },
    {
        "name": "Photographic",
        "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
    },
    {
        "name": "Pixel art",
        "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
    },
    {
        "name": "Fantasy art",
        "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
    },
    {
        "name": "Neonpunk",
        "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
    },
    {
        "name": "Manga",
        "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
    },
]

styles = {k["name"]: k["prompt"] for k in style_list}
STYLE_NAMES = list(styles.keys())
DEFAULT_STYLE_NAME = "Manga"
MAX_SEED = np.iinfo(np.int32).max
HEIGHT = 512 * 3  # Display height
WIDTH = 512 * 3  # Display width
PROC_WIDTH = 512  # Processing width
PROC_HEIGHT = 512  # Processing height
ITER_DELAY = 1.0  # Seconds slept before each frame is pushed to the sketchpad


def create_white_background(width, height):
    """Return a solid white RGB PIL image of the given size."""
    return Image.new("RGB", (width, height), color="white")


white_background = create_white_background(WIDTH, HEIGHT)


@spaces.GPU(duration=45)
def run(image, prompt, prompt_template, style_name, seed, val_r):
    """Generate an image from the sketchpad contents and the prompt.

    Args:
        image: Sketchpad value; only its ``"composite"`` PIL image is read.
        prompt: User text substituted for ``{prompt}`` in ``prompt_template``.
        prompt_template: Style template containing the ``{prompt}`` marker.
        style_name: Unused here; accepted so the signature matches the
            ``inputs`` list wired to every event.
        seed: Seed for ``torch.manual_seed``; arrives from a ``gr.Textbox``,
            so it may be text and is converted with ``int``.
        val_r: Passed to the model as ``r`` (labelled "Sketch guidance" in
            the UI).

    Returns:
        The generated PIL image, resized to ``(WIDTH, HEIGHT)`` if needed.
    """
    sketch = image["composite"]
    if sketch.size != (PROC_WIDTH, PROC_HEIGHT):
        sketch = sketch.resize((PROC_WIDTH, PROC_HEIGHT))
    prompt = prompt_template.replace("{prompt}", prompt)

    # Invert the sketch (255 - pixel) and threshold it to a binary mask.
    sketch = sketch.convert("RGB")
    sketch = Image.fromarray(255 - np.array(sketch))
    image_t = TF.to_tensor(sketch) > 0.5

    with torch.no_grad():
        c_t = image_t.unsqueeze(0).cuda().float()
        torch.manual_seed(int(seed))
        # Only the spatial dims are needed; the noise map is 1/8 of them.
        height, width = c_t.shape[-2:]
        noise = torch.randn((1, 4, height // 8, width // 8), device=c_t.device)
        output_image = model(
            c_t, prompt, deterministic=False, r=val_r, noise_map=noise
        )

    # Map the model output from [-1, 1] to [0, 1] before converting to PIL.
    # NOTE(review): the [-1, 1] range is inferred from the `* 0.5 + 0.5` - confirm.
    output_pil = TF.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
    if output_pil.size != (WIDTH, HEIGHT):
        output_pil = output_pil.resize((WIDTH, HEIGHT))
    return output_pil


def clear_image_editor():
    """Reset the sketchpad and every derived component after a clear.

    Returns a tuple matching the ``image.clear`` outputs, in order:
    sketchpad value, ``result``, ``sketches``, ``one_frame``, ``frames``,
    ``frame_selector`` and ``apply_button``.
    """
    return (
        {"background": white_background, "layers": None, "composite": None},
        gr.Image(
            value=None,
            height=HEIGHT,
            width=WIDTH,
            elem_id="output_image",
            type="pil",
            show_label=False,
            show_download_button=True,
            interactive=False,
        ),
        gr.Image(
            value=None,
            height=HEIGHT,
            width=WIDTH,
            show_label=False,
            show_download_button=True,
            type="pil",
            interactive=False,
        ),
        gr.Image(
            value=None,
            height=HEIGHT,
            width=WIDTH,
            show_label=False,
            show_download_button=True,
            type="pil",
            interactive=False,
        ),
        # A State output takes the new value itself; returning a gr.State
        # component here would store the component object as the value.
        [],
        gr.Slider(
            minimum=0,
            maximum=1,
            value=0,
            step=1,
            visible=False,
            scale=4,
            label="Frame Selector",
        ),
        gr.Button("Stop", scale=1, visible=True),
    )


def iter_frames(frames):
    """Yield one sketchpad value per frame, sleeping ITER_DELAY before each."""
    for frame in frames:
        time.sleep(ITER_DELAY)
        yield {
            "background": white_background,
            "layers": [frame],
            "composite": None,
        }


def apply_func_click():
    """Make the frame selector slider visible."""
    return gr.Slider(
        visible=True,
    )


def frame_selector_change(frame_idx, frames):
    """Show the frame chosen on the slider in the sketchpad.

    When there are no frames yet (e.g. Stop was clicked before any run) the
    sketchpad is left untouched instead of raising ``IndexError``.
    """
    if not frames:
        return gr.update()
    # Clamp defensively in case the slider range and the frame list disagree.
    idx = min(max(int(frame_idx), 0), len(frames) - 1)
    return {
        "background": white_background,
        "layers": [frames[idx]],
        "composite": None,
    }


with gr.Blocks(fill_width=True, fill_height=True) as demo:
    image = gr.Sketchpad(
        value={
            "background": white_background,
            "layers": None,
            "composite": white_background,
        },
        image_mode="L",
        type="pil",
        sources=None,
        label="Sketch",
        show_label=True,
        show_download_button=True,
        interactive=True,
        layers=False,
        canvas_size=(WIDTH, HEIGHT),
        show_fullscreen_button=False,
        brush=gr.Brush(
            colors=["#000000"],
            color_mode="fixed",
            default_size=4,
        ),
    )
    prompt = gr.Textbox(label="Prompt", value="", show_label=True)
    with gr.Row():
        run_button = gr.Button("Run", scale=1)
        randomize_seed = gr.Button("Random", scale=1)
    with gr.Row():
        apply_button = gr.Button("Stop", scale=1, visible=True)
    with gr.Row():
        frame_selector = gr.Slider(
            minimum=0,
            maximum=1,
            value=0,
            step=1,
            visible=False,
            scale=4,
            label="Frame Selector",
        )
    with gr.Row():
        style = gr.Dropdown(
            label="Style",
            choices=STYLE_NAMES,
            value=DEFAULT_STYLE_NAME,
            scale=1,
            visible=False,
        )
        prompt_temp = gr.Textbox(
            label="Prompt Style Template",
            value=styles[DEFAULT_STYLE_NAME],
            max_lines=1,
            scale=2,
            visible=False,
        )
    with gr.Row():
        val_r = gr.Slider(
            label="Sketch guidance: ",
            show_label=True,
            minimum=0,
            maximum=1,
            value=0.4,
            step=0.01,
            scale=4,
            visible=False,
        )
        seed = gr.Textbox(label="Seed", value=42, scale=4, visible=False)
    result = gr.Image(
        height=HEIGHT,
        width=WIDTH,
        elem_id="output_image",
        type="pil",
        show_label=False,
        show_download_button=True,
        interactive=False,
        visible=False,
    )

    gr.Markdown("### Instructions")
    gr.Markdown("1. Enter a text prompt (e.g. cat)")
    gr.Markdown("2. Draw some sketches on the Sketchpad")
    gr.Markdown("3. Click on **Run** to generate the sketches powered by AI")
    gr.Markdown(
        "4. While you see the sketches coming out, click on **Stop** to stop more frames coming out"
    )
    gr.Markdown("5. Then you can select a frame by the Frame Selector")
    gr.Markdown(
        "6. You may then continue to draw more sketches or change the prompt and repeat the process"
    )
    gr.Markdown("7. You may try different random seeds by clicking on **Random**")
    gr.Markdown(
        "**Thanks to the [paper](https://arxiv.org/abs/2403.12036) and their open-sourced models!**"
    )

    frames = gr.State([])
    sketches = gr.Image(
        height=HEIGHT,
        width=WIDTH,
        show_label=False,
        show_download_button=True,
        type="pil",
        visible=False,
    )
    one_frame = gr.Image(
        height=HEIGHT,
        width=WIDTH,
        show_label=False,
        show_download_button=True,
        type="pil",
        interactive=False,
        visible=False,
    )

    inputs = [image, prompt, prompt_temp, style, seed, val_r]
    outputs = [result]

    def _then_play_sketch_frames(event):
        """Chain the shared post-generation steps onto ``event``.

        After ``result`` is produced, ``image_to_sketch_gif`` fills
        ``sketches``, ``frames``, ``frame_selector`` and ``apply_button``,
        then ``iter_frames`` streams the frames into the sketchpad. The last
        chained event is returned so callers can list it in ``cancels``.
        """
        return event.then(
            image_to_sketch_gif,
            inputs=[result],
            outputs=[sketches, frames, frame_selector, apply_button],
        ).then(
            iter_frames,
            inputs=[frames],
            outputs=[image],
        )

    # Each trigger below ends in the same generate -> frames -> playback tail.
    randomize_seed_click = _then_play_sketch_frames(
        randomize_seed.click(
            lambda: random.randint(0, MAX_SEED),
            inputs=[],
            outputs=seed,
        ).then(
            fn=run,
            inputs=inputs,
            outputs=outputs,
        )
    )
    prompt_submit = _then_play_sketch_frames(
        prompt.submit(fn=run, inputs=inputs, outputs=outputs)
    )
    style_change = _then_play_sketch_frames(
        style.change(
            lambda x: styles[x], inputs=[style], outputs=[prompt_temp]
        ).then(
            fn=run,
            inputs=inputs,
            outputs=outputs,
        )
    )
    val_r_change = _then_play_sketch_frames(
        val_r.change(run, inputs=inputs, outputs=outputs)
    )
    run_button_click = _then_play_sketch_frames(
        run_button.click(fn=run, inputs=inputs, outputs=outputs)
    )
    image_apply = _then_play_sketch_frames(
        image.apply(
            run,
            inputs=inputs,
            outputs=outputs,
        )
    )

    # Events interrupted by both the Stop button and a sketchpad clear.
    cancellable_events = [
        run_button_click,
        randomize_seed_click,
        prompt_submit,
        style_change,
        val_r_change,
        image_apply,
    ]

    apply_button.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=cancellable_events,
    )
    apply_button.click(
        fn=apply_func_click,
        inputs=None,
        outputs=[frame_selector],
    )
    frame_selector.release(
        fn=frame_selector_change,
        inputs=[frame_selector, frames],
        outputs=[image],
    )
    image.clear(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=cancellable_events,
    )
    image.clear(
        fn=clear_image_editor,
        inputs=None,
        outputs=[
            image,
            result,
            sketches,
            one_frame,
            frames,
            frame_selector,
            apply_button,
        ],
    )

if __name__ == "__main__":
    demo.queue().launch()