from __future__ import annotations

import math
import random

import gradio as gr
import torch
from PIL import Image, ImageOps
from diffusers import StableDiffusionPipeline

# Markdown rendered underneath the controls; currently just a blank line.
help_text = "\n"

# Prompts the "Load Example" button chooses from at random.
example_instructions = ["A river"]

# Model identifier handed to StableDiffusionPipeline.from_pretrained.
model_id = "dimentox/heightmapstyle"


def main():
    """Load the heightmap Stable Diffusion pipeline and launch a Gradio UI for it.

    The UI offers Generate / Load Example / Reset buttons, a prompt textbox,
    an output image, and controls for steps, seed and CFG scales. It is
    served through a queue with ``concurrency_count=1`` and ``share=False``.
    """
    # Initial control values, shared with ``reset`` so "Reset" really
    # restores the starting state (it previously reset Steps to 0).
    default_steps = 50
    default_seed = 1371
    default_text_cfg = 7.5
    default_image_cfg = 1.5

    # float16 is only practical on a GPU; fall back to float32 on CPU-only
    # hosts. The pipeline has to be moved to the device explicitly, otherwise
    # it stays on the CPU in half precision.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=dtype, safety_checker=None
    ).to(device)

    def load_example(
            steps: int,
            randomize_seed: int,
            seed: int,
            randomize_cfg: int,
            text_cfg_scale: float,
            image_cfg_scale: float,
    ):
        """Pick a random entry of ``example_instructions`` and generate it.

        Returns the chosen prompt followed by the four values returned by
        ``generate`` (seed, text CFG, image CFG, image).
        """
        example_instruction = random.choice(example_instructions)
        return [example_instruction] + generate(
            example_instruction,
            steps,
            randomize_seed,
            seed,
            randomize_cfg,
            text_cfg_scale,
            image_cfg_scale,
        )

    def generate(
            instruction: str,
            steps: int,
            randomize_seed: int,
            seed: int,
            randomize_cfg: int,
            text_cfg_scale: float,
            image_cfg_scale: float,
    ):
        """Generate one image for the ``instruction`` prompt.

        ``randomize_seed`` / ``randomize_cfg`` are Radio indices: 0 keeps the
        given value, 1 draws a random one.

        Returns:
            ``[seed, text_cfg_scale, image_cfg_scale, image]``, i.e. the values
            actually used, always four items to match the wired outputs.
            ``image`` is ``None`` when the prompt is empty.
        """
        seed = random.randint(0, 100000) if randomize_seed else seed
        text_cfg_scale = round(random.uniform(6.0, 9.0), ndigits=2) if randomize_cfg else text_cfg_scale
        image_cfg_scale = round(random.uniform(1.2, 1.8), ndigits=2) if randomize_cfg else image_cfg_scale

        if not instruction:
            # Still return one value per output component, or Gradio errors.
            return [seed, text_cfg_scale, image_cfg_scale, None]

        # A dedicated generator gives the same per-seed results as
        # torch.manual_seed without reseeding torch's global RNG.
        generator = torch.Generator(device="cpu").manual_seed(seed)
        # StableDiffusionPipeline has no image-guidance parameter (that belongs
        # to InstructPix2Pix), so the Image CFG value is only echoed to the UI.
        image = pipe(
            instruction,
            guidance_scale=text_cfg_scale,
            num_inference_steps=steps,
            generator=generator,
        ).images[0]
        return [seed, text_cfg_scale, image_cfg_scale, image]

    def reset():
        """Restore every control to its initial value and clear the image."""
        return [
            default_steps,
            "Randomize Seed",
            default_seed,
            "Fix CFG",
            default_text_cfg,
            default_image_cfg,
            None,
        ]

    with gr.Blocks() as demo:
        gr.HTML("""

        """)
        with gr.Row():
            with gr.Column(scale=1, min_width=100):
                generate_button = gr.Button("Generate")
            with gr.Column(scale=1, min_width=100):
                load_button = gr.Button("Load Example")
            with gr.Column(scale=1, min_width=100):
                reset_button = gr.Button("Reset")
            with gr.Column(scale=3):
                instruction = gr.Textbox(lines=1, label="Edit Instruction", interactive=True)
        with gr.Row():
            edited_image = gr.Image(label="Edited Image", type="pil", interactive=False)
            edited_image.style(height=512, width=512)
        with gr.Row():
            steps = gr.Number(value=default_steps, precision=0, label="Steps", interactive=True)
            # type="index" makes the handler receive 0/1 instead of the label.
            randomize_seed = gr.Radio(
                ["Fix Seed", "Randomize Seed"],
                value="Randomize Seed",
                type="index",
                show_label=False,
                interactive=True,
            )
            seed = gr.Number(value=default_seed, precision=0, label="Seed", interactive=True)
            randomize_cfg = gr.Radio(
                ["Fix CFG", "Randomize CFG"],
                value="Fix CFG",
                type="index",
                show_label=False,
                interactive=True,
            )
            text_cfg_scale = gr.Number(value=default_text_cfg, label="Text CFG", interactive=True)
            image_cfg_scale = gr.Number(value=default_image_cfg, label="Image CFG", interactive=True)

        gr.Markdown(help_text)

        load_button.click(
            fn=load_example,
            inputs=[
                steps,
                randomize_seed,
                seed,
                randomize_cfg,
                text_cfg_scale,
                image_cfg_scale,
            ],
            outputs=[instruction, seed, text_cfg_scale, image_cfg_scale, edited_image],
        )
        generate_button.click(
            fn=generate,
            inputs=[
                instruction,
                steps,
                randomize_seed,
                seed,
                randomize_cfg,
                text_cfg_scale,
                image_cfg_scale,
            ],
            outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image],
        )
        reset_button.click(
            fn=reset,
            inputs=[],
            outputs=[steps, randomize_seed, seed, randomize_cfg, text_cfg_scale, image_cfg_scale, edited_image],
        )

    # One job at a time: a single pipeline instance is shared by all requests.
    demo.queue(concurrency_count=1)
    demo.launch(share=False)


if __name__ == "__main__":
    # Build and serve the app only when executed directly, not on import.
    main()

import gradio as gr

# NOTE(review): example prompt rows kept as inert data. A gr.Examples(...) /
# gr.load().launch() call here could not work and raised on import:
# - it referenced undefined components (`txt`, `txt_2`);
# - it ran outside any gr.Blocks context;
# - its rows did not match the two inputs;
# - it called gr.load() without a model name.
# To offer these prompts in the UI, add their text to `example_instructions`.
extra_example_prompts = [
    ["heightmapsstyle", "a lake with a river"],
    ["heightmapsstyle", "greyscale", "a river running though flat planes"],
]



# sr_b64 = super_resolution(hmap_b64)