import tempfile
import time
from collections.abc import Sequence
from typing import Any, cast

import gradio as gr
import numpy as np
import pillow_heif
# import spaces
import torch
from gradio_image_annotation import image_annotator
from gradio_imageslider import ImageSlider
from PIL import Image
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
from refiners.fluxion.utils import no_grad
from refiners.solutions import BoxSegmenter
from transformers import GroundingDinoForObjectDetection, GroundingDinoProcessor

# Four-integer box; bbox_union below builds it as (min x0, min y0, max x1, max y1).
BoundingBox = tuple[int, int, int, int]

# Register pillow_heif's HEIF and AVIF openers (runs once, at import time).
pillow_heif.register_heif_opener()
pillow_heif.register_avif_opener()

# Prefer CUDA when torch reports it as available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# weird dance because ZeroGPU
# The segmenter is constructed on the CPU; its `device` attribute and its model
# are then moved to `device` by hand.
segmenter = BoxSegmenter(device="cpu")
segmenter.device = device
segmenter.model = segmenter.model.to(device=segmenter.device)

# Grounding DINO processor and detection model, loaded from the same checkpoint
# in float32; the model is then moved to `device`.
gd_model_path = "IDEA-Research/grounding-dino-base"
gd_processor = GroundingDinoProcessor.from_pretrained(gd_model_path)
gd_model = GroundingDinoForObjectDetection.from_pretrained(gd_model_path, torch_dtype=torch.float32)
gd_model = gd_model.to(device=device)  # type: ignore
# Narrow the type after `.to()` (the line above needs a type-ignore).
assert isinstance(gd_model, GroundingDinoForObjectDetection)


def bbox_union(bboxes: Sequence[list[int]]) -> BoundingBox | None:
    """Return the smallest box that encloses every box in *bboxes*.

    Each input box is a list of four ints; the first two entries are
    minimised and the last two maximised across all boxes.

    Returns:
        The enclosing ``(x0, y0, x1, y1)`` tuple, or ``None`` when *bboxes*
        is empty.

    Raises:
        AssertionError: If a box does not hold exactly four ``int`` values.
    """
    if not bboxes:
        return None

    # Sanity-check the shape and element type of every box before combining.
    for box in bboxes:
        assert len(box) == 4
        for coord in box:
            assert isinstance(coord, int)

    # Transpose to one list per coordinate so each extreme is a single call.
    columns = [[box[axis] for box in bboxes] for axis in range(4)]
    return (
        min(columns[0]),
        min(columns[1]),
        max(columns[2]),
        max(columns[3]),
    )


def corners_to_pixels_format(bboxes: torch.Tensor, width: int, height: int) -> torch.Tensor:
    """Round corner-format boxes to int32 and clamp them to the image extent.

    The last dimension of *bboxes* must hold exactly four values
    ``(x1, y1, x2, y2)``. The x values are clamped to ``[0, width]`` and the
    y values to ``[0, height]``. The input tensor is left unmodified.
    """
    pixels = bboxes.round().to(torch.int32)
    # Unpacking into four names keeps the requirement that the last dim is 4.
    left, top, right, bottom = pixels.unbind(-1)
    limits = (width, height, width, height)
    clamped = [
        coord.clamp(0, limit)
        for coord, limit in zip((left, top, right, bottom), limits)
    ]
    return torch.stack(clamped, dim=-1)


def gd_detect(img: Image.Image, prompt: str) -> BoundingBox | None:
    """Detect *prompt* in *img* with Grounding DINO and return one enclosing box.

    The prompt is normalised to the form the model expects (lowercase, a
    single trailing dot). All detected boxes are rounded and clamped to the
    image and merged into their union.

    Args:
        img: Image to search.
        prompt: Free-text description of the object to find.

    Returns:
        The union ``(x0, y0, x1, y1)`` of all detections, or ``None`` when the
        prompt is blank or nothing was detected.
    """
    assert isinstance(gd_processor, GroundingDinoProcessor)

    # Grounding DINO text queries must be lowercase and end with a dot; avoid
    # doubling the dot when the user already typed one.
    query = prompt.strip().lower()
    if not query:
        # Nothing to look for: skip the model call and report "no detection".
        return None
    if not query.endswith("."):
        query += "."

    inputs = gd_processor(images=img, text=query, return_tensors="pt").to(device=device)

    with no_grad():
        outputs = gd_model(**inputs)
    width, height = img.size
    # target_sizes is given as (height, width); only one image is processed,
    # hence the [0].
    results: dict[str, Any] = gd_processor.post_process_grounded_object_detection(
        outputs,
        inputs["input_ids"],
        target_sizes=[(height, width)],
    )[0]
    assert "boxes" in results and isinstance(results["boxes"], torch.Tensor)

    bboxes = corners_to_pixels_format(results["boxes"].cpu(), width, height)
    return bbox_union(bboxes.numpy().tolist())


def apply_mask(
    img: Image.Image,
    mask_img: Image.Image,
    defringe: bool = True,
) -> Image.Image:
    """Cut *img* out with *mask_img* and return the result as an RGBA image.

    Args:
        img: Source image; converted to RGB.
        mask_img: Mask of the same size as *img*; converted to ``L`` and used
            as the paste mask.
        defringe: When true, the colours are first replaced by the foreground
            estimated by ``estimate_foreground_ml`` to reduce edge halos.

    Returns:
        A new RGBA image of the same size, with *img* pasted through the mask.

    Raises:
        AssertionError: If the two images differ in size.
    """
    assert img.size == mask_img.size
    colour_img = img.convert("RGB")
    alpha_img = mask_img.convert("L")

    if defringe:
        # Colour decontamination: estimate the pure foreground colour so the
        # background does not bleed into semi-transparent edges.
        rgb = np.asarray(colour_img) / 255.0
        alpha = np.asarray(alpha_img) / 255.0
        # NOTE(review): the cast names uint8, but the value is scaled by 255
        # right after, so it is presumably a float array in [0, 1] - confirm.
        foreground = cast(np.ndarray[Any, np.dtype[np.uint8]], estimate_foreground_ml(rgb, alpha))
        colour_img = Image.fromarray((foreground * 255).astype("uint8"))

    cutout = Image.new("RGBA", colour_img.size)
    cutout.paste(colour_img, (0, 0), alpha_img)
    return cutout


# @spaces.GPU
def _gpu_process(
    img: Image.Image,
    prompt: str | BoundingBox | None,
) -> tuple[Image.Image, BoundingBox | None, list[str]]:
    """Run detection (for text prompts) and segmentation, timing each step.

    Args:
        img: Image to segment.
        prompt: A text prompt to detect first, or a box (or ``None``) that is
            passed straight to the segmenter.

    Returns:
        ``(mask, box, timings)`` where *box* is the box given to the segmenter
        and *timings* holds one ``"<step>: <seconds>"`` entry per step run.

    Raises:
        gr.Error: If a text prompt yields no detection.
    """
    # Because of ZeroGPU shenanigans, this must stay a *single* function that
    # would carry the `spaces.GPU` decorator and must *not* do postprocessing.

    timings: list[str] = []

    if not isinstance(prompt, str):
        # Box (or None) supplied directly: no detection step needed.
        box = prompt
    else:
        started = time.time()
        box = gd_detect(img, prompt)
        timings.append(f"detect: {time.time() - started}")
        if not box:
            print(timings[0])
            raise gr.Error("No object detected")

    started = time.time()
    mask = segmenter(img, box)
    timings.append(f"segment: {time.time() - started}")

    return mask, box, timings


def _process(
    img: Image.Image,
    prompt: str | BoundingBox | None,
) -> tuple[tuple[Image.Image, Image.Image], gr.DownloadButton]:
    """Cut the prompted object out of *img* and prepare the UI outputs.

    Images larger than 2048 px on either side are downscaled first (a box
    prompt is scaled by the same factor). The mask from ``_gpu_process`` is
    applied with defringing, composited on white for the preview, and the
    transparent cutout is cropped to the mask's bounding box and written to a
    temporary PNG for download.

    Args:
        img: Input image. It is not modified; downscaling works on a copy.
        prompt: Text prompt, bounding box, or ``None``.

    Returns:
        ``((image, preview), button)``: the (possibly downscaled) input and the
        cutout on white, plus a download button pointing at the PNG. The
        temporary file is created with ``delete=False`` and is not removed here.
    """
    # enforce max dimensions for pymatting performance reasons
    if img.width > 2048 or img.height > 2048:
        orig_res = max(img.width, img.height)
        # thumbnail() resizes in place; copy first so the caller's image is
        # left untouched.
        img = img.copy()
        img.thumbnail((2048, 2048))
        if isinstance(prompt, tuple):
            x0, y0, x1, y1 = (int(x * 2048 / orig_res) for x in prompt)
            prompt = (x0, y0, x1, y1)

    # The box returned by _gpu_process is not needed here.
    mask, _, time_log = _gpu_process(img, prompt)

    t0 = time.time()
    masked_alpha = apply_mask(img, mask, defringe=True)
    time_log.append(f"crop: {time.time() - t0}")
    print(", ".join(time_log))

    masked_rgb = Image.alpha_composite(Image.new("RGBA", masked_alpha.size, "white"), masked_alpha)

    # Crop the download to the mask extent, ignoring very faint mask values.
    thresholded = mask.point(lambda p: 255 if p > 10 else 0)
    to_dl = masked_alpha.crop(thresholded.getbbox())

    # The context manager closes the handle even if saving fails; the file
    # itself stays on disk for the download button.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp:
        to_dl.save(temp, format="PNG")

    return (img, masked_rgb), gr.DownloadButton(value=temp.name, interactive=True)


def process_bbox(prompts: dict[str, Any]) -> tuple[tuple[Image.Image, Image.Image], gr.DownloadButton]:
    """Run the cutout for an annotator payload.

    Args:
        prompts: Annotator value with an ``"image"`` entry (a PIL image) and a
            ``"boxes"`` list holding zero or one dict with the keys ``xmin``,
            ``ymin``, ``xmax`` and ``ymax``.

    Returns:
        The result of ``_process`` for the image and the box (``None`` when no
        box was drawn).

    Raises:
        gr.Error: If the payload has no image, more than one box, or a
            malformed box entry.
    """
    # Validate explicitly rather than binding names inside `assert`: asserts
    # are stripped under `python -O`, which would leave the names unbound.
    img = prompts["image"]
    if not isinstance(img, Image.Image):
        raise gr.Error("Please provide an image")

    boxes = prompts["boxes"]
    if not isinstance(boxes, list) or len(boxes) > 1:
        raise gr.Error("Expected at most one bounding box")

    bbox: BoundingBox | None = None
    if boxes:
        box = boxes[0]
        if not isinstance(box, dict):
            raise gr.Error("Malformed bounding box")
        bbox = (box["xmin"], box["ymin"], box["xmax"], box["ymax"])

    return _process(img, bbox)


def on_change_bbox(prompts: dict[str, Any] | None):
    """Return an update that makes the target interactive only when *prompts* is set."""
    has_payload = prompts is not None
    return gr.update(interactive=has_payload)


def process_prompt(img: Image.Image, prompt: str) -> tuple[tuple[Image.Image, Image.Image], gr.DownloadButton]:
    """Run the cutout for a text prompt by delegating to ``_process``."""
    outputs = _process(img, prompt)
    return outputs


def on_change_prompt(img: Image.Image | None, prompt: str | None):
    """Return an update that makes the target interactive only when both inputs are truthy."""
    ready = bool(img) and bool(prompt)
    return gr.update(interactive=ready)


# Custom stylesheet passed to gr.Blocks below; it hides the page's <footer> element.
css = """
footer {
    visibility: hidden;
}
"""


# UI definition: two tabs that share the same layout (inputs on the left,
# before/after slider and download button on the right).
# NOTE: `btn`, `oimg`, `dlbt`, `examples` and `ex` are rebound in the second tab;
# the first tab's handlers are already wired to its own components by then.
with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:

    # Tab 1: the object to cut out is described with a text prompt.
    with gr.Tab("By prompt", id="tab_prompt"):
        with gr.Row():
            with gr.Column():
                iimg = gr.Image(type="pil", label="Input")
                prompt = gr.Textbox(label="What should we cut?")
                # Starts disabled; on_change_prompt toggles it from the two inputs.
                btn = gr.ClearButton(value="Cut Out Object", interactive=False)
            with gr.Column():
                oimg = ImageSlider(label="Before / After", show_download_button=False, interactive=False)
                dlbt = gr.DownloadButton("Download Cutout", interactive=False)

        # Register the output slider with the ClearButton (gr.ClearButton.add).
        btn.add(oimg)

        # Re-evaluate the button state whenever the image or the prompt changes.
        for inp in [iimg, prompt]:
            inp.change(
                fn=on_change_prompt,
                inputs=[iimg, prompt],
                outputs=[btn],
            )
        btn.click(
            fn=process_prompt,
            inputs=[iimg, prompt],
            outputs=[oimg, dlbt],
            api_name=False,
        )

        # Each example row is [image path, prompt], matching inputs=[iimg, prompt].
        examples = [
            [
                "examples/text.jpg",
                "text",
            ],            
            [
                "examples/potted-plant.jpg",
                "potted plant",
            ],
            [
                "examples/chair.jpg",
                "chair",
            ],
            [
                "examples/black-lamp.jpg",
                "black lamp",
            ],
        ]

        ex = gr.Examples(
            examples=examples,
            inputs=[iimg, prompt],
            outputs=[oimg, dlbt],
            fn=process_prompt,
            cache_examples=True,
        )

    # Tab 2: the object to cut out is marked with a single box drawn on the image.
    with gr.Tab("By bounding box", id="tab_bb"):
        with gr.Row():
            with gr.Column():
                annotator = image_annotator(
                    image_type="pil",
                    disable_edit_boxes=True,
                    show_download_button=False,
                    show_share_button=False,
                    single_box=True,
                    label="Input",
                )
                # Starts disabled; on_change_bbox toggles it from the annotator value.
                btn = gr.ClearButton(value="Cut Out Object", interactive=False)
            with gr.Column():
                # NOTE(review): unlike the first tab, no interactive=False is passed
                # here - confirm whether that difference is intended.
                oimg = ImageSlider(label="Before / After", show_download_button=False)
                dlbt = gr.DownloadButton("Download Cutout", interactive=False)

        # Register the output slider with the ClearButton (gr.ClearButton.add).
        btn.add(oimg)

        annotator.change(
            fn=on_change_bbox,
            inputs=[annotator],
            outputs=[btn],
        )
        btn.click(
            fn=process_bbox,
            inputs=[annotator],
            outputs=[oimg, dlbt],
            api_name=False,
        )

        # Each example is an annotator payload: an image path plus one box.
        # NOTE(review): the first two examples use identical box coordinates -
        # confirm the box for examples/text.jpg is intended.
        examples = [
            {
                "image": "examples/text.jpg",
                "boxes": [{"xmin": 51, "ymin": 511, "xmax": 639, "ymax": 1255}],
            },            
            {
                "image": "examples/potted-plant.jpg",
                "boxes": [{"xmin": 51, "ymin": 511, "xmax": 639, "ymax": 1255}],
            },
            {
                "image": "examples/chair.jpg",
                "boxes": [{"xmin": 98, "ymin": 330, "xmax": 973, "ymax": 1468}],
            },
            {
                "image": "examples/black-lamp.jpg",
                "boxes": [{"xmin": 88, "ymin": 148, "xmax": 700, "ymax": 1414}],
            },
        ]

        ex = gr.Examples(
            examples=examples,
            inputs=[annotator],
            outputs=[oimg, dlbt],
            fn=process_bbox,
            cache_examples=True,
        )


# Configure the queue (max_size=30, API closed) and start the app with the API
# page hidden. This runs at import time; there is no `__main__` guard.
demo.queue(max_size=30, api_open=False)
demo.launch(show_api=False)