import gradio as gr
import torch
import spaces
from PIL import Image, ImageDraw, ImageFont
from src.condition import Condition
from diffusers.pipelines import FluxPipeline
import numpy as np

from src.generate import seed_everything, generate

# Load the FLUX.1-schnell base pipeline in bfloat16 and move it to the GPU.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")
# Load the two subject-driven LoRA adapters: the 512-px weights and the
# 1024-px beta weights, registered under separate adapter names.
pipe.load_lora_weights(
    "Yuanshi/OminiControl",
    weight_name="omini/subject_512.safetensors",
    adapter_name="subject_512",
)
pipe.load_lora_weights(
    "Yuanshi/OminiControl",
    weight_name="omini/subject_1024_beta.safetensors",
    adapter_name="subject_1024",
)


@spaces.GPU
def process_image_and_text(image, resolution, text):
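    """Center-crop the condition image to a square, resize it to 512x512, and
    run subject-driven generation at the requested output resolution."""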
    w, h, min_size = image.size[0], image.size[1], min(image.size)
    image = image.crop(
        (
            (w - min_size) // 2,
            (h - min_size) // 2,
            (w + min_size) // 2,
            (h + min_size) // 2,
        )
    )
    image = image.resize((512, 512))

    condition = Condition("subject", image)
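    # Hedged addition: both subject LoRA adapters are loaded at startup, so
    # activate the one matching the requested resolution here. This assumes
    # diffusers' set_adapters API and that src.generate does not already
    # handle adapter selection.
    pipe.set_adapters(f"subject_{resolution}")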

    result_img = generate(
        pipe,
        prompt=text.strip(),
        conditions=[condition],
        num_inference_steps=8,
        height=resolution,
        width=resolution,
    ).images[0]

    return result_img


def get_samples():
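    """Return example [condition image, resolution, prompt] triples for the demo."""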
    sample_list = [
        {
            "image": "assets/oranges.jpg",
            "resolution": 512,
            "text": "A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!'",
        },
        {
            "image": "assets/penguin.jpg",
            "resolution": 512,
            "text": "On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat, holding a sign that reads 'Omini Control!'",
        },
        {
            "image": "assets/rc_car.jpg",
            "resolution": 1024,
            "text": "A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.",
        },
        {
            "image": "assets/clock.jpg",
            "resolution": 1024,
            "text": "In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.",
        },
    ]
    return [
        [
            Image.open(sample["image"]).resize((512, 512)),
            sample["resolution"],
            sample["text"],
        ]
        for sample in sample_list
    ]


header = """
# 🌍 OminiControl / FLUX

<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
<a href="https://arxiv.org/abs/2411.15098"><img src="https://img.shields.io/badge/ariXv-Paper-A42C25.svg" alt="arXiv"></a>
<a href="https://huggingface.co/Yuanshi/OminiControl"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a>
<a href="https://github.com/Yuanshi9815/OminiControl"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
</div>
"""


def create_app():
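    """Build the Gradio Blocks app: a subject-driven tab plus placeholder tabs for Fill, Canny, and Depth."""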
    with gr.Blocks() as app:
        gr.Markdown(header)
        with gr.Tabs():
            with gr.Tab("Subject-driven"):
                gr.Interface(
                    fn=process_image_and_text,
                    inputs=[
                        gr.Image(type="pil", label="Condition Image", width=300),
                        gr.Radio(
                            [("512", 512), ("1024(beta)", 1024)],
                            label="Resolution",
                            value=512,
                        ),
                        # gr.Slider(4, 16, 4, step=4, label="Inference Steps"),
                        gr.Textbox(lines=2, label="Text Prompt"),
                    ],
                    outputs=gr.Image(type="pil"),
                    examples=get_samples(),
                )
            with gr.Tab("Fill"):
                gr.Markdown("Coming soon")
            with gr.Tab("Canny"):
                gr.Markdown("Coming soon")
            with gr.Tab("Depth"):
                gr.Markdown("Coming soon")
    return app


if __name__ == "__main__":
    create_app().launch(debug=True, ssr_mode=False)