#!/usr/bin/env python

import gradio as gr
import PIL.Image
import spaces
import torch
from controlnet_aux import CannyDetector
from diffusers.pipelines import BlipDiffusionControlNetPipeline

from settings import DEFAULT_NEGATIVE_PROMPT, MAX_INFERENCE_STEPS
from utils import MAX_SEED, randomize_seed_fn

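# Canny edge detector that converts the condition image into the edge map used
# as the ControlNet conditioning input.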
canny_detector = CannyDetector()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
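# Load the BLIP-Diffusion ControlNet pipeline in half precision when a GPU is
# available. Without CUDA the pipeline is not loaded and generation is disabled.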
if torch.cuda.is_available():
    pipe = BlipDiffusionControlNetPipeline.from_pretrained(
        "Salesforce/blipdiffusion-controlnet", torch_dtype=torch.float16
    ).to(device)
else:
    pipe = None


@spaces.GPU
def run(
    condition_image: PIL.Image.Image,
    style_image: PIL.Image.Image,
    condition_subject: str,
    style_subject: str,
    prompt: str,
    negative_prompt: str = DEFAULT_NEGATIVE_PROMPT,
    seed: int = 0,
    guidance_scale: float = 7.5,
    num_inference_steps: int = 25,
) -> PIL.Image.Image:
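    """Render the style image's subject in the layout given by the condition image.

    The condition image is reduced to a Canny edge map that guides the ControlNet,
    while the style image and the two subject names drive BLIP-Diffusion's subject
    transfer. Returns a single 512x512 generated image.
    """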
    if pipe is None:
        raise gr.Error("CUDA is not available. This demo requires a GPU to run.")
    if num_inference_steps > MAX_INFERENCE_STEPS:
        error_message = f"Number of inference steps must not exceed {MAX_INFERENCE_STEPS}."
        raise gr.Error(error_message)
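    # Extract a Canny edge map (low/high thresholds 30/70) to condition the ControlNet.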
    condition_image = canny_detector(
        condition_image, low_threshold=30, high_threshold=70, output_type="pil"
    )
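    # Arguments are passed positionally in the order expected by the pipeline:
    # prompt, reference (style) image, Canny conditioning image,
    # source subject category (style), target subject category (condition).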
    return pipe(
        prompt,
        style_image,
        condition_image,
        style_subject,
        condition_subject,
        generator=torch.Generator(device=device).manual_seed(seed),
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        neg_prompt=negative_prompt,
        height=512,
        width=512,
    ).images[0]


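# Gradio UI: inputs and advanced options in the left column, the generated image on the right.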
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            condition_image = gr.Image(label="Condition Image", type="pil")
            style_image = gr.Image(label="Style Image", type="pil")
            condition_subject = gr.Textbox(label="Condition Subject")
            style_subject = gr.Textbox(label="Style Subject")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button()
            with gr.Accordion(label="Advanced options", open=False):
                negative_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE_PROMPT)
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=0,
                    maximum=10,
                    step=0.1,
                    value=7.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=MAX_INFERENCE_STEPS,
                    step=1,
                    value=25,
                )
        with gr.Column():
            result = gr.Image(label="Result")

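    # Example inputs for the five main fields; the advanced options keep the
    # default values defined above.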
    gr.Examples(
        examples=[
            [
                "images/kettle.jpg",
                "images/flower.jpg",
                "teapot",
                "flower",
                "on a marble table",
            ],
        ],
        inputs=[
            condition_image,
            style_image,
            condition_subject,
            style_subject,
            prompt,
        ],
        outputs=result,
        fn=run,
    )

    inputs = [
        condition_image,
        style_image,
        condition_subject,
        style_subject,
        prompt,
        negative_prompt,
        seed,
        guidance_scale,
        num_inference_steps,
    ]
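    # Submitting any textbox or clicking the button first (optionally) randomizes
    # the seed, then runs generation on the shared GPU queue, one job at a time.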
    gr.on(
        triggers=[
            condition_subject.submit,
            style_subject.submit,
            prompt.submit,
            negative_prompt.submit,
            run_button.click,
        ],
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        api_name=False,
        concurrency_limit=None,
    ).then(
        fn=run,
        inputs=inputs,
        outputs=result,
        api_name="run-stylization",
        concurrency_id="gpu",
        concurrency_limit=1,
    )

if __name__ == "__main__":
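    # Queue incoming requests (up to 20 pending) before launching the app.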
    demo.queue(max_size=20).launch()