import gradio as gr
import torch
from diffusers import (
    AutoPipelineForText2Image,
    StableDiffusionXLControlNetPipeline,
    DiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionInpaintPipeline,
    StableDiffusionAdapterPipeline,
    StableDiffusionControlNetPipeline,
    StableDiffusionXLAdapterPipeline,
    StableDiffusionXLImg2ImgPipeline,
    StableDiffusionXLInpaintPipeline,
    ControlNetModel,
    T2IAdapter,
)
import time
import utils


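# All benchmarks assume a CUDA GPU and run in float16.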
dtype = torch.float16
device = torch.device("cuda")

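# Maps a display name to (pipeline class, base checkpoint) plus an optional
# ControlNet / T2I-Adapter checkpoint.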
pipeline_mapping = {
    "SD T2I": (DiffusionPipeline, "runwayml/stable-diffusion-v1-5"),
    "SD I2I": (StableDiffusionImg2ImgPipeline, "runwayml/stable-diffusion-v1-5"),
    "SD Inpainting": (
        StableDiffusionInpaintPipeline,
        "runwayml/stable-diffusion-inpainting",
    ),
    "SD ControlNet": (
        StableDiffusionControlNetPipeline,
        "runwayml/stable-diffusion-v1-5",
        "lllyasviel/sd-controlnet-canny",
    ),
    "SD T2I Adapters": (
        StableDiffusionAdapterPipeline,
        "CompVis/stable-diffusion-v1-4" "TencentARC/t2iadapter_canny_sd14v1",
    ),
    "SDXL T2I": (DiffusionPipeline, "stabilityai/stable-diffusion-xl-base-1.0"),
    "SDXL I2I": (
        StableDiffusionXLImg2ImgPipeline,
        "stabilityai/stable-diffusion-xl-base-1.0",
    ),
    "SDXL Inpainting": (
        StableDiffusionXLInpaintPipeline,
        "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
    ),
    "SDXL ControlNet": (
        StableDiffusionXLControlNetPipeline,
        "stabilityai/stable-diffusion-xl-base-1.0",
        "diffusers/controlnet-canny-sdxl-1.0",
    ),
    "SDXL T2I Adapters": (
        StableDiffusionXLAdapterPipeline,
        "stabilityai/stable-diffusion-xl-base-1.0",
        "TencentARC/t2i-adapter-canny-sdxl-1.0",
    ),
    "Kandinsky 2.2 (T2I)": (AutoPipelineForText2Image, "kandinsky-community/kandinsky-2-2-decoder"),
    "Würstchen (T2I)": (AutoPipelineForText2Image, "warp-ai/wuerstchen")
}


def load_pipeline(
    pipeline_to_benchmark: str,
    use_channels_last: bool = False,
    do_torch_compile: bool = False,
):
    # Get pipeline details.
    pipeline_details = pipeline_mapping[pipeline_to_benchmark]
    pipeline_cls = pipeline_details[0]
    pipeline_ckpt = pipeline_details[1]

    # Load adapter if needed.
    if "ControlNet" in pipeline_to_benchmark:
        controlnet_ckpt = pipeline_details[2]
        controlnet = ControlNetModel.from_pretrained(
            controlnet_ckpt, variant="fp16", torch_dtype=torch.float16
        ).to(device)
    elif "Adapters" in pipeline_to_benchmark:
        adapter_ckpt = pipeline_details[2]
        adapter = T2IAdapter.from_pretrained(
            adapter_ckpt, variant="fp16", torch_dtype=torch.float16
        ).to(device)

    # Load pipeline.
    if (
        "ControlNet" not in pipeline_to_benchmark
        and "Adapters" not in pipeline_to_benchmark
    ):
        pipeline = pipeline_cls.from_pretrained(
            pipeline_ckpt, variant="fp16", torch_dtype=dtype
        )
    elif "ControlNet" in pipeline_to_benchmark:
        pipeline = pipeline_cls.from_pretrained(
            pipeline_ckpt, controlnet=controlnet, torch_dtype=dtype
        )
    elif "Adapters" in pipeline_to_benchmark:
        pipeline = pipeline_cls.from_pretrained(
            pipeline_ckpt, adapter=adapter, torch_dtype=dtype
        )

    pipeline.to(device)

    # Optionally set memory layout.
    if use_channels_last:
        if pipeline_to_benchmark == "Würstchen (T2I)":
            pipeline.prior.to(memory_format=torch.channels_last)
            pipeline.decoder.to(memory_format=torch.channels_last)
        else:
            pipeline.unet.to(memory_format=torch.channels_last)

        if hasattr(pipeline, "controlnet"):
            pipeline.controlnet.to(memory_format=torch.channels_last)
        elif hasattr(pipeline, "adapter"):
            pipeline.adapter.to(memory_format=torch.channels_last)

    # Optional torch compilation.
    if do_torch_compile:
        if pipeline_to_benchmark == "Würstchen (T2I)":
            pipeline.prior = torch.compile(
                pipeline.prior, mode="reduce-overhead", fullgraph=True
            )
            pipeline.decoder = torch.compile(
                pipeline.decoder, mode="reduce-overhead", fullgraph=True
            )
        else:
            pipeline.unet = torch.compile(
                pipeline.unet, mode="reduce-overhead", fullgraph=True
            )

        if hasattr(pipeline, "controlnet"):
            pipeline.controlnet = torch.compile(
                pipeline.controlnet, mode="reduce-overhead", fullgraph=True
            )
        elif hasattr(pipeline, "adapter"):
            pipeline.adapter = torch.compile(
                pipeline.adapter, mode="reduce-overhead", fullgraph=True
            )

    return pipeline


def generate(
    pipeline_to_benchmark: str,
    num_images_per_prompt: int = 1,
    use_channels_last: bool = False,
    do_torch_compile: bool = False,
):
    print("Start...")
    print("Torch version", torch.__version__)
    print("Torch CUDA version", torch.version.cuda)

    pipeline = load_pipeline(
        pipeline_to_benchmark=pipeline_to_benchmark,
        use_channels_last=use_channels_last,
        do_torch_compile=do_torch_compile,
    )
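    # Run three iterations; when torch.compile is enabled, the first call also
    # pays the compilation cost.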
    for _ in range(3):
        prompt = 77 * "a"
        num_inference_steps = 20
        call_args = dict(
            prompt=prompt,
            num_images_per_prompt=num_images_per_prompt,
            num_inference_steps=num_inference_steps,
        )

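        # Add the conditioning inputs required by the selected pipeline.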
        if pipeline_to_benchmark in ["SD I2I", "SDXL I2I"]:
            image = utils.get_image_for_img_to_img(pipeline_to_benchmark)
            call_args.update({"image": image})
        elif "Inpainting" in pipeline_to_benchmark:
            image, mask_image = utils.get_image_from_inpainting(pipeline_to_benchmark)
            call_args.update({"image": image, "mask_image": mask_image})
        elif "ControlNet" in pipeline_to_benchmark:
            image = utils.get_image_for_controlnet(pipeline_to_benchmark)
            call_args.update({"image": image})

        elif "Adapters" in pipeline_to_benchmark:
            image = utils.get_image_for_adapters(pipeline_to_benchmark)
            call_args.update({"image": image})

        start_time = time.time()
        _ = pipeline(**call_args).images
        end_time = time.time()

        print(f"For {num_inference_steps} steps", end_time - start_time)
        print("Avg per step", (end_time - start_time) / num_inference_steps)


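# Minimal Gradio UI: choose a pipeline and optimization flags, then run the benchmark.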
with gr.Blocks() as demo:
    do_torch_compile = gr.Checkbox(label="Enable torch.compile()?")
    use_channels_last = gr.Checkbox(label="Use `channels_last` memory layout?")
    pipeline_to_benchmark = gr.Dropdown(
        list(pipeline_mapping.keys()),
        value="SD T2I",
        multiselect=False,
        label="Pipeline to benchmark",
    )
    batch_size = gr.Slider(
        label="Number of images per prompt",
        minimum=1,
        maximum=16,
        step=1,
        value=1,
    )
    btn = gr.Button("Benchmark!").style(
        margin=False,
        rounded=(False, True, True, False),
        full_width=False,
    )

    btn.click(
        fn=generate,
        inputs=[pipeline_to_benchmark, batch_size, use_channels_last, do_torch_compile],
    )

demo.launch()