from argparse import ArgumentParser

from diffusers import DDIMScheduler, StableDiffusionXLImg2ImgPipeline
import gradio as gr
import torch
import yaml

from ctrl_x.pipelines.pipeline_sdxl import CtrlXStableDiffusionXLPipeline
from ctrl_x.utils import *
from ctrl_x.utils.sdxl import *

import spaces  # Hugging Face Spaces GPU helper; provides the @spaces.GPU decorator used below


parser = ArgumentParser()
parser.add_argument("-m", "--model", type=str, default=None)  # Optionally, load model checkpoint from single file
args = parser.parse_args()

torch.backends.cudnn.enabled = False  # Sometimes necessary to suppress CUDNN_STATUS_NOT_SUPPORTED

torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
refiner_id_or_path = "stabilityai/stable-diffusion-xl-refiner-1.0"
device = "cuda" if torch.cuda.is_available() else "cpu"
#variant = "fp16" if device == "cuda" else "fp32"

scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler")  # TODO: Support other schedulers
if args.model is None:
    pipe = CtrlXStableDiffusionXLPipeline.from_pretrained(
        model_id_or_path, scheduler=scheduler, torch_dtype=torch_dtype, use_safetensors=True
    )
else:
    print(f"Using weights {args.model} for SDXL base model.")
    pipe = CtrlXStableDiffusionXLPipeline.from_single_file(args.model, scheduler=scheduler, torch_dtype=torch_dtype)
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    refiner_id_or_path, scheduler=scheduler, text_encoder_2=pipe.text_encoder_2, vae=pipe.vae,
    torch_dtype=torch_dtype, use_safetensors=True,
)
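# The refiner shares the base pipeline's second text encoder and VAE (the standard SDXL
# base + refiner setup), which avoids loading duplicate weights.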

if torch.cuda.is_available():
    pipe = pipe.to("cuda")
    refiner = refiner.to("cuda")
    

def get_control_config(structure_schedule, appearance_schedule):
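    """Build a Ctrl-X control config as a YAML string (parsed with `yaml.safe_load` in `inference`).

    `structure_schedule` and `appearance_schedule` come from the demo sliders and are assumed to be
    fractions in [0, 1] of the denoising process; they fill in the per-layer structure and appearance
    schedules below. A minimal sketch of how the string is consumed (mirroring `inference`):

        config = yaml.safe_load(get_control_config(0.6, 0.6))
        config["self_recurrence_schedule"]  # -> [[0.1, 0.5, 2]]
    """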
    s = structure_schedule
    a = appearance_schedule
    
    control_config =\
f"""control_schedule:
    #       structure_conv   structure_attn   appearance_attn  conv/attn
    encoder:                                                # (num layers)
        0: [[             ], [             ], [             ]]  # 2/0
        1: [[             ], [             ], [{a}, {a}     ]]  # 2/2
        2: [[             ], [             ], [{a}, {a}     ]]  # 2/2
    middle: [[            ], [             ], [             ]]  # 2/1
    decoder:
        0: [[{s}          ], [{s}, {s}, {s}], [0.0, {a}, {a}]]  # 3/3
        1: [[             ], [             ], [{a}, {a}     ]]  # 3/3
        2: [[             ], [             ], [             ]]  # 3/0

control_target:
    - [output_tensor]  # structure_conv   choices: {{hidden_states, output_tensor}}
    - [query, key]     # structure_attn   choices: {{query, key, value}}
    - [before]         # appearance_attn  choices: {{before, value, after}}

self_recurrence_schedule:
    - [0.1, 0.5, 2]  # format: [start, end, num_recurrence]"""
    
    return control_config
    

css = """
.config textarea {font-family: monospace; font-size: 80%; white-space: pre}
.mono {font-family: monospace}
"""

title = """
<div style="display: flex; align-items: center; justify-content: center;margin-bottom: -15px">
    <h1 style="margin-left: 12px;text-align: center;display: inline-block">
        Ctrl-X: Controlling Structure and Appearance for Text-To-Image Generation Without Guidance
    </h1>
    <h3 style="display: inline-block; margin-left: 10px; margin-top: 7.5px; font-weight: 500">
        SDXL v1.0
    </h3>
</div>
<div style="display: flex; align-items: center; justify-content: center;margin-bottom: 25px">
    <h3 style="text-align: center">
        [<a href="https://genforce.github.io/ctrl-x/">Page</a>]
        &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
        [<a href="https://arxiv.org/abs/2406.07540">Paper</a>]
        &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
        [<a href="https://github.com/genforce/ctrl-x">Code</a>]
    </h3>
</div>
"""
description = """<div>
    <p>
        <b>Ctrl-X</b> is a simple training-free and guidance-free framework for text-to-image (T2I) generation with 
        structure and appearance control. Given a structure image and an appearance image, Ctrl-X applies feedforward 
        structure control to align the output with the (arbitrary) structure image and semantic-aware appearance 
        transfer to carry over the appearance of the appearance image.
    </p>
    <p>
        Here are some notes and tips for this demo:
    </p>
    <ul>
        <li> On input images:
            <ul>
                <li>
                    If both the structure and appearance images are provided, then Ctrl-X does <i>structure and 
                    appearance</i> control.
                </li>
                <li>
                    If only the structure image is provided, then Ctrl-X does <i>structure-only</i> control and the 
                    appearance image is jointly generated with the output image.
                </li>
                <li>
                    Similarly, if only the appearance image is provided, then Ctrl-X does <i>appearance-only</i> 
                    control.
                </li>
            </ul>
        </li>
        <li> On prompts:
            <ul>
                <li>
                    Though the output prompt can affect the output image to a noticeable extent, the "accuracy" of the 
                    structure and appearance prompts has little impact on the final image.
                </li>
                <li>
                    If the structure or appearance prompt is left blank, it defaults to the (required) output prompt.
                </li>
            </ul>
        </li>
        <li> On control schedules:
            <ul>
                <li>
                    When "Use advanced config" is <b>OFF</b>, the demo uses the structure schedule 
                    (<span class="mono">structure_conv</span> and <span class="mono">structure_attn</span> 
                    in the advanced config) and appearance schedule (<span class="mono">appearance_attn</span> in the 
                    advanced config) sliders to set the control schedules.
                </li>
                <li>
                    Otherwise, the demo uses the "Advanced control config," which allows per-layer structure and 
                    appearance schedule control, along with self-recurrence control. <i>This should be used 
                    carefully</i>, and we recommend keeping "Use advanced config" <b>OFF</b> in most cases. (For the 
                    examples at the bottom of the demo, the advanced config uses the default schedules, which may not 
                    be the best settings for those examples.)
                </li>
            </ul>
        </li>
    </ul>
    <p>
        Have fun! :D
    </p>
</div>
"""

@spaces.GPU
def inference(
    structure_image,
    appearance_image,
    prompt,
    structure_prompt,
    appearance_prompt,
    positive_prompt="high quality",
    negative_prompt="ugly, blurry, dark, low res, unrealistic",
    guidance_scale=5.0,
    structure_guidance_scale=5.0,
    appearance_guidance_scale=5.0,
    num_inference_steps=28,
    eta=1.0,
    seed=42,
    width=1024,
    height=1024,
    structure_schedule=0.6,
    appearance_schedule=0.6,
    use_advanced_config=False,
    control_config="",
    progress=gr.Progress(track_tqdm=True)
):
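    """Run Ctrl-X generation with the SDXL base pipeline (plus the refiner) and return
    [output image, refined output image, structure reconstruction, appearance reconstruction]."""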
    torch.manual_seed(seed)
    
    pipe.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = pipe.scheduler.timesteps
    
    print(f"\nUsing the following control config (use_advanced_config={use_advanced_config}):")
    if not use_advanced_config:
        control_config = get_control_config(structure_schedule, appearance_schedule)
    print(control_config, end="\n\n")
    
    config = yaml.safe_load(control_config)
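    # register_control (from ctrl_x.utils) is expected to wire the parsed schedules into the
    # pipeline's UNet so that structure/appearance features are injected at the scheduled
    # layers and timesteps.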
    register_control(
        model = pipe,
        timesteps = timesteps,
        control_schedule = config["control_schedule"],
        control_target = config["control_target"],
    )
    
    pipe.safety_checker = None
    pipe.requires_safety_checker = False
    
    self_recurrence_schedule = get_self_recurrence_schedule(config["self_recurrence_schedule"], num_inference_steps)
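    # get_self_recurrence_schedule (from ctrl_x.utils) presumably expands the fractional
    # [start, end, num_recurrence] entries into a per-step schedule of length num_inference_steps.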

    pipe.set_progress_bar_config(desc="Ctrl-X inference")
    refiner.set_progress_bar_config(desc="Refiner")
    
    result, structure, appearance = pipe(
        prompt = prompt,
        structure_prompt = structure_prompt,
        appearance_prompt = appearance_prompt,
        structure_image = structure_image,
        appearance_image = appearance_image,
        num_inference_steps = num_inference_steps,
        negative_prompt = negative_prompt,
        positive_prompt = positive_prompt,
        height = height,
        width = width,
        guidance_scale = guidance_scale,
        structure_guidance_scale = structure_guidance_scale,
        appearance_guidance_scale = appearance_guidance_scale,
        eta = eta,
        output_type = "pil",
        return_dict = False,
        control_schedule = config["control_schedule"],
        self_recurrence_schedule = self_recurrence_schedule,
    )
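    # The pipeline call above also stores final latents and prompts in `pipe.refiner_args`
    # for the refiner image-to-image pass below; they are deleted again after use.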
    
    result_refiner = refiner(
        image = pipe.refiner_args["latents"],
        prompt = pipe.refiner_args["prompt"],
        negative_prompt = pipe.refiner_args["negative_prompt"],
        height = height,
        width = width,
        num_inference_steps = num_inference_steps,
        guidance_scale = guidance_scale,
        guidance_rescale = 0.7,
        num_images_per_prompt = 1,
        eta = eta,
        output_type = "pil",
    ).images
    del pipe.refiner_args
    
    return [result[0], result_refiner[0], structure[0], appearance[0]]
    
    
with gr.Blocks(theme=gr.themes.Default(), css=css, title="Ctrl-X (SDXL v1.0)") as app:
    gr.HTML(title)
    with gr.Accordion("Instructions", open=False):
        gr.HTML(description)
    with gr.Row():
        with gr.Column(scale=45):
            with gr.Group():
                kwargs = {}  # {"width": 400, "height": 400}
                with gr.Row():
                    structure_image = gr.Image(label="Upload structure image (optional)", type="pil", **kwargs)
                    appearance_image = gr.Image(label="Upload appearance image (optional)", type="pil", **kwargs)
                with gr.Row():
                    structure_prompt = gr.Textbox(label="Structure prompt (optional)", placeholder="Describes the structure image")
                    appearance_prompt = gr.Textbox(label="Appearance prompt (optional)", placeholder="Describes the appearance image")
                with gr.Row():
                    prompt = gr.Textbox(label="Output prompt", placeholder="Prompt which describes the output image")
                with gr.Row():
                    positive_prompt = gr.Textbox(label="Positive prompt", value="high quality", placeholder="")
                    negative_prompt = gr.Textbox(label="Negative prompt", value="ugly, blurry, dark, low res, unrealistic", placeholder="")
                with gr.Accordion("Advanced Options", open=False):
                    with gr.Row():
                        guidance_scale = gr.Slider(label="Target guidance scale", value=5.0, minimum=1, maximum=10)
                        structure_guidance_scale = gr.Slider(label="Structure guidance scale", value=5.0, minimum=1, maximum=10)
                        appearance_guidance_scale = gr.Slider(label="Appearance guidance scale", value=5.0, minimum=1, maximum=10)
                    with gr.Row():
                        num_inference_steps = gr.Slider(label="# inference steps", value=28, minimum=1, maximum=200, step=1)
                        eta = gr.Slider(label="Eta (noise)", value=1.0, minimum=0, maximum=1.0, step=0.01)
                        seed = gr.Slider(0, 2147483647, label="Seed", value=90095, step=1)
                    with gr.Row():
                        width = gr.Slider(label="Width", value=1024, minimum=256, maximum=2048, step=pipe.vae_scale_factor)
                        height = gr.Slider(label="Height", value=1024, minimum=256, maximum=2048, step=pipe.vae_scale_factor)
                    with gr.Row():
                        structure_schedule = gr.Slider(label="Structure schedule", value=0.6, minimum=0.0, maximum=1.0, step=0.01, scale=2)
                        appearance_schedule = gr.Slider(label="Appearance schedule", value=0.6, minimum=0.0, maximum=1.0, step=0.01, scale=2)
                        use_advanced_config = gr.Checkbox(label="Use advanced config", value=False, scale=1)
                    with gr.Row():
                        control_config = gr.Textbox(
                            label="Advanced control config", lines=20, value=get_control_config(0.6, 0.6), elem_classes=["config"], visible=False,
                        )
                        use_advanced_config.change(
                            fn=lambda value: gr.update(visible=value), inputs=use_advanced_config, outputs=control_config,
                        )
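                        # Toggling "Use advanced config" shows or hides the raw YAML config textbox above.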
                with gr.Row():
                    generate = gr.Button(value="Run")
        
        with gr.Column(scale=55):
            with gr.Group():
                with gr.Row():
                    result_refiner = gr.Image(label="Output image w/ refiner", format="jpg", **kwargs)
                with gr.Row():
                    result = gr.Image(label="Output image", format="jpg", **kwargs)
                    structure_recon = gr.Image(label="Structure image", format="jpg", **kwargs)
                    appearance_recon = gr.Image(label="Appearance image", format="jpg", **kwargs)
                        
    inputs = [
        structure_image, appearance_image,
        prompt, structure_prompt, appearance_prompt,
        positive_prompt, negative_prompt,
        guidance_scale, structure_guidance_scale, appearance_guidance_scale,
        num_inference_steps, eta, seed,
        width, height,
        structure_schedule, appearance_schedule, use_advanced_config,
        control_config,
    ]
    outputs = [result, result_refiner, structure_recon, appearance_recon]
    
    generate.click(inference, inputs=inputs, outputs=outputs)

    examples = gr.Examples(
        [
            [
                "assets/images/horse__point_cloud.jpg",
                "assets/images/horse.jpg",
                "a photo of a horse standing on grass",
                "a 3D point cloud of a horse",
                "",
            ],
            [
                "assets/images/cat__mesh.jpg",
                "assets/images/tiger.jpg",
                "a photo of a tiger standing on snow",
                "a 3D mesh of a cat",
                "",
            ],
            [
                "assets/images/dog__sketch.jpg",
                "assets/images/squirrel.jpg",
                "a photo of a squirrel",
                "a sketch of a dog",
                "",
            ],
            [
                "assets/images/living_room__seg.jpg",
                "assets/images/van_gogh.jpg",
                "a Van Gogh painting of a living room",
                "a segmentation map of a living room",
                "",
            ],
            [
                "assets/images/bedroom__sketch.jpg",
                "assets/images/living_room_modern.jpg",
                "a sketch of a bedroom",
                "a photo of a modern bedroom during sunset",
                "",
            ],
            [
                "assets/images/running__pose.jpg",
                "assets/images/man_park.jpg",
                "a photo of a man running in a park",
                "a pose image of a person running",
                "",
            ],
            [
                "assets/images/fruit_bowl.jpg",
                "assets/images/grapes.jpg",
                "a photo of a bowl of grapes in the trees",
                "a photo of a bowl of fruits",
                "",
            ],
            [
                "assets/images/bear_avocado__spatext.jpg",
                None,
                "a realistic photo of a bear and an avocado in a forest",
                "a segmentation map of a bear and an avocado",
                "",
            ],
            [
                "assets/images/cat__point_cloud.jpg",
                None,
                "an embroidery of a white cat sitting on a rock under the night sky",
                "a 3D point cloud of a cat",
                "",
            ],
            [
                "assets/images/library__mesh.jpg",
                None,
                "a Polaroid photo of an old library, sunlight streaming in",
                "a 3D mesh of a library",
                "",
            ],
            [
                "assets/images/knight__humanoid.jpg",
                None,
                "a photo of a medieval soldier standing on a barren field, raining",
                "a 3D model of a person holding a sword and shield",
                "",
            ],
            [
                "assets/images/person__mesh.jpg",
                None,
                "a photo of a Karate man performing in a cyberpunk city at night",
                "a 3D mesh of a person",
                "",
            ],
        ],
        [
            structure_image,
            appearance_image,
            prompt,
            structure_prompt,
            appearance_prompt,
        ],
        examples_per_page=50,
        cache_examples="lazy",
        fn=inference,
        outputs=[result, result_refiner, structure_recon, appearance_recon]
    )

app.launch(debug=False, share=False)