# original code by zenafey

from utils import place_lora, get_exif_data
from css import css
from grutils import *
import inference


# Option lists looked up once at import time through `pipe.constant`.
# `lora_list` supplies the labels of the LoRA buttons and `samplers` the choices
# of the "Sampling Method" dropdowns further down.
# NOTE(review): presumably `pipe.constant` queries the backend for the
# available values - confirm in grutils.
lora_list, samplers = pipe.constant("/sd/loras"), pipe.constant("/sd/samplers")


# Build the whole web UI inside one Blocks context. Components are laid out in
# the order they are created, so the statement order below is significant.
# `gr`, `model_list`, `update_btn_start`, `update_btn_end`, `switch_to_t2i` and
# `send_to_txt2img` are not defined in this file; they come from the
# `grutils` star import.
with gr.Blocks(css=css, theme="zenafey/prodia-web") as demo:
    # Checkpoint selector; its value is passed to both the txt2img and the
    # img2img handlers.
    model = gr.Dropdown(interactive=True, value="anything-v4.5-pruned.ckpt [65745d25]", show_label=True,
                        label="Stable Diffusion Checkpoint", choices=model_list, elem_id="model_dd")

    with gr.Tabs() as tabs:
        # ------------------------------------------------------------------
        # txt2img tab
        # ------------------------------------------------------------------
        with gr.Tab("txt2img", id='t2i'):
            with gr.Row():
                with gr.Column(scale=6, min_width=600):
                    prompt = gr.Textbox("space warrior, beautiful, female, ultrarealistic, soft lighting, 8k",
                                        placeholder="Prompt", show_label=False, lines=3)
                    negative_prompt = gr.Textbox(placeholder="Negative Prompt", show_label=False, lines=3,
                                                 value="3d, cartoon, anime, (deformed eyes, nose, ears, nose), bad anatomy, ugly")
                with gr.Row():
                    # NOTE(review): both buttons share elem_id="generate";
                    # presumably the stylesheet targets that id - confirm
                    # before making the ids unique.
                    t2i_generate_btn = gr.Button("Generate", variant='primary', elem_id="generate")

                    # Hidden until a generation is running (toggled by the
                    # update_btn_* callbacks wired up below).
                    t2i_stop_btn = gr.Button("Cancel", variant="stop", elem_id="generate", visible=False)

            with gr.Row():
                with gr.Column():
                    with gr.Tab("Generation"):
                        with gr.Row():
                            with gr.Column(scale=1):
                                sampler = gr.Dropdown(value="DPM++ 2M Karras", show_label=True, label="Sampling Method",
                                                      choices=samplers)

                            with gr.Column(scale=1):
                                # A step count is a whole number, so move in
                                # increments of 1 (was 0.5, which allowed
                                # values such as 25.5).
                                steps = gr.Slider(label="Sampling Steps", minimum=1, maximum=60, value=25, step=1)

                        with gr.Row():
                            with gr.Column(scale=8):
                                # NOTE(review): no `minimum` is given, so the
                                # slider's default lower bound applies -
                                # confirm the smallest size the backend accepts.
                                width = gr.Slider(label="Width", maximum=1024, value=512, step=8)
                                height = gr.Slider(label="Height", maximum=1024, value=512, step=8)

                            with gr.Column(scale=1):
                                # Displayed only: batch_size is not passed to
                                # any handler in this file.
                                batch_size = gr.Slider(label="Batch Size", maximum=1, value=1)
                                batch_count = gr.Slider(label="Batch Count", minimum=1, maximum=50, value=1, step=1)

                        cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, value=7, step=0.5)
                        seed = gr.Number(label="Seed", value=-1)

                    with gr.Tab("Lora"):
                        with gr.Row():
                            # One small button per LoRA name. Each click passes
                            # the current prompt and the clicked button to
                            # place_lora, whose result replaces the prompt.
                            for lora in lora_list:
                                lora_btn = gr.Button(lora, size="sm")
                                lora_btn.click(place_lora, inputs=[prompt, lora_btn], outputs=prompt, queue=False)

                with gr.Column():
                    # Gallery is pre-filled with a sample image.
                    image_output = gr.Gallery(columns=3,
                        value=["https://images.prodia.xyz/8ede1a7c-c0ee-4ded-987d-6ffed35fc477.png"])

        # ------------------------------------------------------------------
        # img2img tab (mirrors txt2img, plus an input image and denoising)
        # ------------------------------------------------------------------
        with gr.Tab("img2img", id='i2i'):
            with gr.Row():
                with gr.Column(scale=6, min_width=600):
                    i2i_prompt = gr.Textbox("space warrior, beautiful, female, ultrarealistic, soft lighting, 8k",
                                            placeholder="Prompt", show_label=False, lines=3)
                    i2i_negative_prompt = gr.Textbox(placeholder="Negative Prompt", show_label=False, lines=3,
                                                     value="3d, cartoon, anime, (deformed eyes, nose, ears, nose), bad anatomy, ugly")
                with gr.Row():
                    # NOTE(review): same shared elem_id="generate" as on the
                    # txt2img tab - confirm against the stylesheet.
                    i2i_generate_btn = gr.Button("Generate", variant='primary', elem_id="generate")
                    i2i_stop_btn = gr.Button("Cancel", variant="stop", elem_id="generate", visible=False)

            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Tab("Generation"):
                        i2i_image_input = gr.Image(type="pil")

                        with gr.Row():
                            with gr.Column(scale=1):
                                i2i_sampler = gr.Dropdown(value="DPM++ 2M Karras", show_label=True,
                                                          label="Sampling Method", choices=samplers)

                            with gr.Column(scale=1):
                                # Whole-number increments, matching txt2img
                                # (was 0.5).
                                i2i_steps = gr.Slider(label="Sampling Steps", minimum=1, maximum=60, value=25, step=1)

                        with gr.Row():
                            with gr.Column(scale=6):
                                # NOTE(review): no `minimum` given - see the
                                # txt2img width/height sliders.
                                i2i_width = gr.Slider(label="Width", maximum=1024, value=512, step=8)
                                i2i_height = gr.Slider(label="Height", maximum=1024, value=512, step=8)

                            with gr.Column(scale=1):
                                # Displayed only: i2i_batch_size is not passed
                                # to any handler in this file.
                                i2i_batch_size = gr.Slider(label="Batch Size", maximum=1, value=1)
                                i2i_batch_count = gr.Slider(label="Batch Count", minimum=1, maximum=50, value=1, step=1)

                        # Same 0.5 granularity as the txt2img CFG slider
                        # (was step=1, inconsistent between the two tabs).
                        i2i_cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, value=7, step=0.5)
                        i2i_denoising = gr.Slider(label="Denoising Strength", minimum=0, maximum=1, value=0.5, step=0.05)
                        i2i_seed = gr.Number(label="Seed", value=-1)

                    with gr.Tab("Lora"):
                        with gr.Row():
                            # Same LoRA buttons as on the txt2img tab, but
                            # bound to the img2img prompt box.
                            for lora in lora_list:
                                lora_btn = gr.Button(lora, size="sm")
                                lora_btn.click(place_lora, inputs=[i2i_prompt, lora_btn], outputs=i2i_prompt, queue=False)

                with gr.Column(scale=1):
                    i2i_image_output = gr.Gallery(columns=3,
                        value=["https://images.prodia.xyz/8ede1a7c-c0ee-4ded-987d-6ffed35fc477.png"])

        # ------------------------------------------------------------------
        # Extras tab: upscaling of a single image
        # ------------------------------------------------------------------
        with gr.Tab("Extras"):
            with gr.Row():
                with gr.Tab("Single Image"):
                    with gr.Column():
                        upscale_image_input = gr.Image(type="pil")
                        upscale_btn = gr.Button("Generate", variant="primary")
                        upscale_stop_btn = gr.Button("Stop", variant="stop", visible=False)
                        with gr.Tab("Scale by"):
                            upscale_scale = gr.Radio([2, 4], value=2, label="Resize")

                upscale_output = gr.Image()

        # ------------------------------------------------------------------
        # PNG Info tab: show an uploaded image's metadata and optionally
        # forward it to the txt2img controls
        # ------------------------------------------------------------------
        with gr.Tab("PNG Info"):
            with gr.Row():
                with gr.Column():
                    image_input = gr.Image(type="pil")

                with gr.Column():
                    exif_output = gr.HTML(label="EXIF Data")
                    send_to_txt2img_btn = gr.Button("Send to txt2img")

        with gr.Tab("Past generations"):
            inference.gr_user_history.render()

        # ------------------------------------------------------------------
        # Event wiring
        #
        # Each generation button uses the same three-step chain:
        #   1. update_btn_start (unqueued) updates the generate/stop buttons,
        #   2. the inference call runs,
        #   3. update_btn_end (unqueued) updates the buttons again.
        # The stop button cancels step 2 and calls update_btn_end itself,
        # because a cancelled event does not continue its chain.
        # NOTE(review): presumably update_btn_start/update_btn_end swap which
        # of the two buttons is visible - confirm in grutils.
        # ------------------------------------------------------------------
        t2i_event_start = t2i_generate_btn.click(
            update_btn_start,
            outputs=[t2i_generate_btn, t2i_stop_btn],
            queue=False
        )
        t2i_event = t2i_event_start.then(
            inference.txt2img,
            inputs=[prompt, negative_prompt, model, steps, sampler, cfg_scale, width, height, seed, batch_count],
            outputs=[image_output]
        )
        t2i_event_end = t2i_event.then(
            update_btn_end,
            outputs=[t2i_generate_btn, t2i_stop_btn],
            queue=False
        )

        t2i_stop_btn.click(fn=update_btn_end, outputs=[t2i_generate_btn, t2i_stop_btn], cancels=[t2i_event], queue=False)

        # Render metadata as soon as an image is uploaded on the PNG Info tab.
        image_input.upload(get_exif_data, inputs=[image_input], outputs=exif_output)

        # First update the Tabs component via switch_to_t2i, then fill the
        # txt2img controls from the uploaded image via send_to_txt2img.
        send_to_txt2img_btn.click(
            fn=switch_to_t2i,
            outputs=[tabs],
            queue=False
        ).then(
            fn=send_to_txt2img,
            inputs=[image_input],
            outputs=[prompt, negative_prompt, steps, seed, model, sampler, width, height, cfg_scale],
            queue=False
        )

        i2i_event_start = i2i_generate_btn.click(
            update_btn_start,
            outputs=[i2i_generate_btn, i2i_stop_btn],
            queue=False
        )
        i2i_event = i2i_event_start.then(inference.img2img,
                                         inputs=[i2i_image_input, i2i_denoising, i2i_prompt, i2i_negative_prompt,
                                                 model, i2i_steps, i2i_sampler, i2i_cfg_scale, i2i_width, i2i_height,
                                                 i2i_seed, i2i_batch_count],
                                         outputs=[i2i_image_output])
        i2i_event_end = i2i_event.then(
            update_btn_end,
            outputs=[i2i_generate_btn, i2i_stop_btn],
            queue=False
        )
        i2i_stop_btn.click(fn=update_btn_end, outputs=[i2i_generate_btn, i2i_stop_btn], cancels=[i2i_event], queue=False)

        upscale_event_start = upscale_btn.click(
            fn=update_btn_start,
            outputs=[upscale_btn, upscale_stop_btn],
            queue=False
        )
        upscale_event = upscale_event_start.then(
            fn=inference.upscale,
            inputs=[upscale_image_input, upscale_scale],
            outputs=[upscale_output]
        )
        upscale_event_end = upscale_event.then(
            fn=update_btn_end,
            outputs=[upscale_btn, upscale_stop_btn],
            queue=False
        )

        upscale_stop_btn.click(fn=update_btn_end, outputs=[upscale_btn, upscale_stop_btn], cancels=[upscale_event], queue=False)

# Turn on the request queue (max_size=20, api_open=False), then start the
# server with max_threads=400. `queue()` hands back the Blocks app itself, so
# launching through the returned object is the same as chaining the two calls.
queued_demo = demo.queue(max_size=20, api_open=False)
queued_demo.launch(max_threads=400)