from typing import TypedDict

import diffusers.image_processor
import gradio as gr
import pillow_heif  # pyright: ignore[reportMissingTypeStubs]
import spaces  # pyright: ignore[reportMissingTypeStubs]
import torch
from PIL import Image

from pipeline import TryOffAnyone

# Register pillow-heif's openers with PIL so HEIF/HEIC and AVIF files can be
# opened like any other image format (e.g. when a user uploads one).
pillow_heif.register_heif_opener()  # pyright: ignore[reportUnknownMemberType]
pillow_heif.register_avif_opener()  # pyright: ignore[reportUnknownMemberType]

# Allow TF32 for float32 matrix multiplications on CUDA: faster, slightly less
# precise than full float32.
# NOTE(review): in recent PyTorch, precision "high" already enables TF32 for
# CUDA matmuls, so the allow_tf32 assignment is likely redundant. Kept as-is
# because both statements touch the same global state; confirm against the
# installed torch version before removing either.
torch.set_float32_matmul_precision("high")
torch.backends.cuda.matmul.allow_tf32 = True

# Markdown header rendered at the top of the demo (passed to gr.Markdown when
# the UI is built). This is user-facing runtime text: editing it changes the UI.
TITLE = """
# Try Off Anyone

## ⚠️ Important

1. Choose an example image or upload your own
2. Use the Pen tool to draw a mask over the clothing area you want to extract

[[arxiv:2412.08573]](https://arxiv.org/abs/2412.08573)
[[github:ixarchakos/try-off-anyone]](https://github.com/ixarchakos/try-off-anyone)
"""

# Run on the current CUDA device when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use bfloat16 only on a CUDA device that supports it, float32 otherwise.
# The device-type check comes first on purpose: on some torch versions
# is_bf16_supported() queries the current CUDA device unconditionally and
# raises on CPU-only hosts instead of returning False.
DTYPE = (
    torch.bfloat16
    if DEVICE.type == "cuda" and torch.cuda.is_bf16_supported()
    else torch.float32
)

# The try-off model, created once at import time and reused by every request.
pipeline_tryoff = TryOffAnyone(
    device=DEVICE,
    dtype=DTYPE,
)
# Prepares the user-drawn mask: converted to grayscale and binarized, without
# the normalization applied to RGB inputs.
# NOTE(review): vae_scale_factor=8 presumably matches the VAE's spatial
# downsampling factor used inside TryOffAnyone; confirm against the pipeline.
mask_processor = diffusers.image_processor.VaeImageProcessor(
    vae_scale_factor=8,
    do_normalize=False,
    do_binarize=True,
    do_convert_grayscale=True,
)
# Prepares the RGB input image using VaeImageProcessor's default settings.
vae_processor = diffusers.image_processor.VaeImageProcessor(
    vae_scale_factor=8,
)


# Value that the gr.ImageMask input hands to `process` (see the click wiring
# of the UI). All three keys are required.
ImageData = TypedDict(
    "ImageData",
    {
        # The uploaded or example picture; `process` uses it as the RGB input.
        "background": Image.Image,
        # Not read by `process`.
        "composite": Image.Image,
        # Drawing layers; `process` builds the mask from their alpha channel.
        "layers": list[Image.Image],
    },
)


@spaces.GPU
def process(
    image_data: ImageData,
    image_width: int,
    image_height: int,
    num_inference_steps: int,
    condition_scale: float,
    seed: int,
) -> Image.Image:
    """Extract the garment under the user-drawn mask with the TryOffAnyone pipeline.

    Args:
        image_data: Value of the ``gr.ImageMask`` input. The picture is read
            from ``"background"`` and the mask from the alpha channel of the
            drawn ``"layers"`` (all layers are combined).
        image_width: Width in pixels that the image and mask are resized to.
        image_height: Height in pixels that the image and mask are resized to.
        num_inference_steps: Passed to the pipeline as ``inference_steps``.
        condition_scale: Passed to the pipeline as ``scale``.
        seed: Seed of the ``torch.Generator`` handed to the pipeline, so the
            same inputs and settings reproduce the same output.

    Returns:
        The first image returned by the pipeline.

    Raises:
        gr.Error: If no image was provided, no mask was drawn, or a numeric
            setting is out of range. Gradio shows the message to the user.
    """
    # Local imports: only this function needs them.
    import functools

    from PIL import ImageChops

    # Slider values can arrive as floats; PIL's resize() and
    # torch.Generator.manual_seed() need ints.
    image_width = int(image_width)
    image_height = int(image_height)
    num_inference_steps = int(num_inference_steps)
    condition_scale = float(condition_scale)
    seed = int(seed)

    # Validate with gr.Error rather than assert: asserts are stripped under
    # `python -O`, and a bare AssertionError tells the user nothing.
    if image_width <= 0 or image_height <= 0:
        raise gr.Error("Image width and height must be positive.")
    if num_inference_steps <= 0:
        raise gr.Error("Number of inference steps must be positive.")
    if condition_scale <= 0:
        raise gr.Error("Scale must be positive.")
    if seed < 0:
        raise gr.Error("Seed must be non-negative.")

    # Reject a missing image or mask up front instead of failing with a
    # TypeError/IndexError inside the preprocessing below.
    if not image_data or image_data.get("background") is None:
        raise gr.Error("Please upload or select an input image.")
    layers = image_data.get("layers") or []
    if not layers:
        raise gr.Error("Please draw a mask over the clothing area to extract.")

    size = (image_width, image_height)

    # Build the mask as the pixel-wise maximum of every layer's alpha channel,
    # so strokes drawn on a layer other than the first are not silently
    # dropped. With a single layer this is just that layer's alpha channel.
    alphas = [layer.getchannel("A").resize(size) for layer in layers]
    mask = functools.reduce(ImageChops.lighter, alphas)
    if mask.getbbox() is None:  # every pixel is 0: nothing was drawn
        raise gr.Error("Please draw a mask over the clothing area to extract.")

    # Preprocess the image; [0] takes the single item of the returned batch.
    image = image_data["background"].convert("RGB").resize(size)
    image_preprocessed = vae_processor.preprocess(  # pyright: ignore[reportUnknownMemberType,reportAssignmentType]
        image=image,
        width=image_width,
        height=image_height,
    )[0]

    # Preprocess the mask (grayscale + binarized by mask_processor).
    mask_preprocessed = mask_processor.preprocess(  # pyright: ignore[reportUnknownMemberType]
        image=mask,
        width=image_width,
        height=image_height,
    )[0]

    # Generate the TryOff image with a seeded generator for reproducibility.
    generator = torch.Generator(device=DEVICE).manual_seed(seed)
    tryoff_image = pipeline_tryoff(
        image_preprocessed,
        mask_preprocessed,
        inference_steps=num_inference_steps,
        scale=condition_scale,
        generator=generator,
    )[0]

    return tryoff_image


with gr.Blocks() as demo:
    gr.Markdown(TITLE)

    with gr.Row():
        # Left column: the editable input, the run button, and sample pictures.
        with gr.Column():
            input_image = gr.ImageMask(
                label="Input Image",
                height=1024,  # https://github.com/gradio-app/gradio/issues/10236
                type="pil",
                interactive=True,
            )
            run_button = gr.Button(value="Extract Clothing")
            # examples/model_1.jpg ... examples/model_9.jpg, one per row.
            gr.Examples(
                examples=[[f"examples/model_{index}.jpg"] for index in range(1, 10)],
                inputs=[input_image],
            )
        # Right column: the extracted garment.
        with gr.Column():
            output_image = gr.Image(
                label="TryOff result",
                height=1024,
                image_mode="RGB",
                type="pil",
            )

    # Generation settings; `value=` is what a fresh page starts with.
    with gr.Accordion("Advanced Settings", open=True):
        seed = gr.Slider(
            label="Seed", minimum=0, maximum=100_000, value=69_420, step=1
        )
        scale = gr.Slider(
            label="Scale", minimum=0.5, maximum=5, value=2.5, step=0.05
        )
        num_inference_steps = gr.Slider(
            label="Number of inference steps", minimum=1, maximum=50, value=25, step=1
        )
        # Width and height share the same range and step.
        _size_limits = {"minimum": 64, "maximum": 1024, "step": 8}
        with gr.Row():
            image_width = gr.Slider(label="Image Width", value=384, **_size_limits)
            image_height = gr.Slider(label="Image Height", value=512, **_size_limits)

    # The inputs list must follow the positional parameter order of `process`.
    run_button.click(
        fn=process,
        inputs=[input_image, image_width, image_height, num_inference_steps, scale, seed],
        outputs=output_image,
    )

demo.launch()