# type: ignore
# Inspired by https://github.com/ixarchakos/try-off-anyone/blob/aa3045453013065573a647e4536922bac696b968/src/model/pipeline.py
# Inspired by https://github.com/ixarchakos/try-off-anyone/blob/aa3045453013065573a647e4536922bac696b968/src/model/attention.py
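"""TryOffAnyone pipeline: given an image of a dressed person and a garment
mask, a Stable Diffusion 1.5 inpainting UNet with fine-tuned attention blocks
generates a flat, catalog-style image of the garment (virtual try-off)."""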

import torch
from accelerate import load_checkpoint_in_model
from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor
from diffusers.utils.torch_utils import randn_tensor
from huggingface_hub import hf_hub_download
from PIL import Image


class Skip(torch.nn.Module):
    """Attention processor that skips attention entirely and returns the
    hidden states unchanged. Used to disable the UNet's cross-attention
    layers, since this pipeline provides no text conditioning."""

    def __call__(
        self,
        attn: torch.nn.Module,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        temb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        return hidden_states


def fine_tuned_modules(unet: UNet2DConditionModel) -> torch.nn.ModuleList:
    """Collect the attention blocks of the UNet; these are the only modules
    whose weights are replaced by the fine-tuned checkpoint."""
    trainable_modules = torch.nn.ModuleList()

    # `mid_block` exposes `attentions` directly, while `down_blocks` and
    # `up_blocks` are lists of blocks, only some of which contain attentions.
    for blocks in [unet.down_blocks, unet.mid_block, unet.up_blocks]:
        if hasattr(blocks, "attentions"):
            trainable_modules.append(blocks.attentions)
        else:
            for block in blocks:
                if hasattr(block, "attentions"):
                    trainable_modules.append(block.attentions)

    return trainable_modules


def skip_cross_attentions(unet: UNet2DConditionModel) -> dict[str, AttnProcessor | Skip]:
    """Keep the existing processor for every self-attention layer ("attn1")
    and replace every cross-attention layer ("attn2") with a no-op `Skip`."""
    attn_processors = unet.attn_processors
    return {
        name: processor if name.endswith("attn1.processor") else Skip()
        for name, processor in attn_processors.items()
    }


def encode(image: torch.Tensor, vae: AutoencoderKL) -> torch.Tensor:
    """Encode an image into the scaled VAE latent space."""
    image = image.to(memory_format=torch.contiguous_format).float().to(vae.device, dtype=vae.dtype)
    with torch.no_grad():
        return vae.encode(image).latent_dist.sample() * vae.config.scaling_factor


class TryOffAnyone:
    def __init__(
        self,
        device: torch.device,
        dtype: torch.dtype,
        concat_dim: int = -2,
    ) -> None:
        self.concat_dim = concat_dim
        self.device = device
        self.dtype = dtype

        self.noise_scheduler = DDIMScheduler.from_pretrained(
            pretrained_model_name_or_path="stable-diffusion-v1-5/stable-diffusion-inpainting",
            subfolder="scheduler",
        )
        self.vae = AutoencoderKL.from_pretrained(
            pretrained_model_name_or_path="stabilityai/sd-vae-ft-mse",
        ).to(device, dtype=dtype)
        self.unet = UNet2DConditionModel.from_pretrained(
            pretrained_model_name_or_path="stable-diffusion-v1-5/stable-diffusion-inpainting",
            subfolder="unet",
            variant="fp16",
        ).to(device, dtype=dtype)

        # Disable text conditioning, then overwrite only the attention blocks
        # with the fine-tuned TryOffAnyone weights.
        self.unet.set_attn_processor(skip_cross_attentions(self.unet))
        load_checkpoint_in_model(
            model=fine_tuned_modules(unet=self.unet),
            checkpoint=hf_hub_download(
                repo_id="ixarchakos/tryOffAnyone",
                filename="model.safetensors",
            ),
        )

    @torch.no_grad()
    def __call__(
        self,
        image: torch.Tensor,
        mask: torch.Tensor,
        inference_steps: int,
        scale: float,
        generator: torch.Generator,
    ) -> list[Image.Image]:
        # Binarize the garment mask and zero out the masked region of the image.
        image = image.unsqueeze(0).to(self.device, dtype=self.dtype)
        mask = (mask.unsqueeze(0) > 0.5).to(self.device, dtype=self.dtype)
        masked_image = image * (mask < 0.5)

        masked_latent = encode(masked_image, self.vae)
        image_latent = encode(image, self.vae)
        # Downsample the mask to the latent resolution.
        mask = torch.nn.functional.interpolate(mask, size=masked_latent.shape[-2:], mode="nearest")

        # Concatenate the masked image and the reference image along the
        # spatial axis; only the masked half is regenerated, while the
        # reference half (mask of zeros) is kept as conditioning.
        masked_latent_concat = torch.cat([masked_latent, image_latent], dim=self.concat_dim)
        mask_concat = torch.cat([mask, torch.zeros_like(mask)], dim=self.concat_dim)

        latents = randn_tensor(
            shape=masked_latent_concat.shape,
            generator=generator,
            device=self.device,
            dtype=self.dtype,
        )

        self.noise_scheduler.set_timesteps(inference_steps, device=self.device)
        timesteps = self.noise_scheduler.timesteps

        # For classifier-free guidance, add an unconditional branch whose
        # reference half is zeroed out.
        if do_classifier_free_guidance := (scale > 1.0):
            masked_latent_concat = torch.cat(
                [
                    torch.cat([masked_latent, torch.zeros_like(image_latent)], dim=self.concat_dim),
                    masked_latent_concat,
                ]
            )
            mask_concat = torch.cat([mask_concat] * 2)

        extra_step = {"generator": generator, "eta": 1.0}
        for t in timesteps:
            input_latents = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            input_latents = self.noise_scheduler.scale_model_input(input_latents, t)

            input_latents = torch.cat([input_latents, mask_concat, masked_latent_concat], dim=1)

            noise_pred = self.unet(
                input_latents,
                t.to(self.device),
                encoder_hidden_states=None,
                return_dict=False,
            )[0]

            if do_classifier_free_guidance:
                noise_pred_unc, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_unc + scale * (noise_pred_text - noise_pred_unc)

            latents = self.noise_scheduler.step(noise_pred, t, latents, **extra_step).prev_sample

        # Keep only the generated half, then decode back to pixel space.
        latents = latents.split(latents.shape[self.concat_dim] // 2, dim=self.concat_dim)[0]
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents.to(self.device, dtype=self.dtype)).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        image = (image * 255).round().astype("uint8")

        return [Image.fromarray(im) for im in image]
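

# --- Usage sketch (an assumption-labeled example, not part of the original
# pipeline). It assumes the caller prepares the inputs: an RGB image
# normalized to [-1, 1] with shape (3, H, W) and a binary garment mask with
# shape (1, H, W), with H and W multiples of 8 so the VAE downsamples them
# cleanly. "person.jpg", "mask.png", and the sampling parameters below are
# hypothetical example values.
if __name__ == "__main__":
    from torchvision.transforms.functional import to_tensor

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float16 if device.type == "cuda" else torch.float32
    pipeline = TryOffAnyone(device=device, dtype=dtype)

    # Scale the image from [0, 1] to [-1, 1]; keep the mask in [0, 1].
    image = to_tensor(Image.open("person.jpg").convert("RGB")) * 2 - 1
    mask = to_tensor(Image.open("mask.png").convert("L"))

    results = pipeline(
        image=image,
        mask=mask,
        inference_steps=50,
        scale=2.5,
        generator=torch.Generator(device=device).manual_seed(0),
    )
    results[0].save("tryoff_result.png")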