from typing import List, Optional, Union

import torch
from torchvision.transforms import ToPILImage

from models.vqvae import VQVAEHF
from models.clip import FrozenCLIPEmbedder
from models.var import TVARHF, sample_with_top_k_top_p_, gumbel_softmax_with_rng


class TVARPipeline:
    vae_path = "michellemoorre/vae-test"
    text_encoder_path = "openai/clip-vit-large-patch14"
    text_encoder_2_path = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"

    def __init__(self, var, vae, text_encoder, text_encoder_2, device):
        self.var = var
        self.vae = vae
        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2
        self.var.eval()
        self.vae.eval()
        self.device = device

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, device="cuda"):
        var = TVARHF.from_pretrained(pretrained_model_name_or_path).to(device)
        vae = VQVAEHF.from_pretrained(cls.vae_path).to(device)
        text_encoder = FrozenCLIPEmbedder(cls.text_encoder_path, device=device)
        text_encoder_2 = FrozenCLIPEmbedder(cls.text_encoder_2_path, device=device)
        return cls(var, vae, text_encoder, text_encoder_2, device)

    @staticmethod
    def to_image(tensor):
        return [
            ToPILImage()((255 * img.cpu().detach()).to(torch.uint8)) for img in tensor
        ]

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        null_prompt: str = "",
        encode_null: bool = True,
    ):
        prompt = [prompt] if isinstance(prompt, str) else prompt
        encodings = [
            self.text_encoder.encode(prompt),
            self.text_encoder_2.encode(prompt),
        ]
        # Concatenate the hidden states of both CLIP encoders along the
        # channel dimension; pooled output and attention bias come from the
        # larger (second) encoder.
        prompt_embeds = torch.concat(
            [encoding.last_hidden_state for encoding in encodings], dim=-1
        )
        pooled_prompt_embeds = encodings[-1].pooler_output
        attn_bias = encodings[-1].attn_bias

        if encode_null:
            null_prompt = [null_prompt] if isinstance(null_prompt, str) else null_prompt
            null_encodings = [
                self.text_encoder.encode(null_prompt),
                self.text_encoder_2.encode(null_prompt),
            ]
            null_prompt_embeds = torch.concat(
                [encoding.last_hidden_state for encoding in null_encodings], dim=-1
            )
            null_pooled_prompt_embeds = null_encodings[-1].pooler_output
            null_attn_bias = null_encodings[-1].attn_bias

            B, L, hidden_dim = prompt_embeds.shape
            pooled_dim = pooled_prompt_embeds.shape[1]
            # Broadcast the single null encoding to the batch size and
            # truncate it to the conditional sequence length.
            null_prompt_embeds = (
                null_prompt_embeds[:, :L]
                .expand(B, L, hidden_dim)
                .to(prompt_embeds.device)
            )
            null_pooled_prompt_embeds = null_pooled_prompt_embeds.expand(
                B, pooled_dim
            ).to(pooled_prompt_embeds.device)
            null_attn_bias = null_attn_bias[:, :L].expand(B, L).to(attn_bias.device)

            # Stack [cond; uncond] along the batch dimension so one forward
            # pass computes both classifier-free guidance branches.
            prompt_embeds = torch.cat([prompt_embeds, null_prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat(
                [pooled_prompt_embeds, null_pooled_prompt_embeds], dim=0
            )
            attn_bias = torch.cat([attn_bias, null_attn_bias], dim=0)

        return prompt_embeds, pooled_prompt_embeds, attn_bias
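
    # The [cond; uncond] batch layout produced by `encode_prompt` is what
    # `__call__` relies on: the first B rows of each returned tensor hold the
    # conditional encodings and the last B rows the null encodings. A minimal
    # sketch of the guidance mixing that consumes this layout (hypothetical
    # names, not part of the pipeline):
    #
    #     logits = model(torch.cat([cond, uncond], dim=0))  # (2B, l, V)
    #     t = cfg * ratio                                   # guidance ramps with stage
    #     guided = (1 + t) * logits[:B] - t * logits[B:]    # (B, l, V)
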
    @torch.inference_mode()
    def __call__(
        self,
        prompt=None,
        null_prompt="",
        g_seed: Optional[int] = None,
        cfg=4.0,
        top_k=450,
        top_p=0.95,
        more_smooth=False,
        re=False,
        re_max_depth=10,
        re_start_iter=2,
        return_pil=True,
        encoded_prompt=None,
        encoded_null_prompt=None,
    ) -> torch.Tensor:
        """
        Autoregressive inference only.

        :param prompt: text prompt or list of prompts to condition on
        :param null_prompt: prompt for the unconditional branch of CFG
        :param g_seed: random seed; if None, the sampler is not reseeded
        :param cfg: classifier-free guidance ratio
        :param top_k: top-k sampling
        :param top_p: top-p sampling
        :param more_smooth: smooth the prediction with Gumbel softmax; only
            used for visualization, not for FID/IS benchmarking
        :return: reconstructed images as a (B, 3, H, W) tensor in [0, 1], or a
            list of PIL images if `return_pil` is True
        """
        assert not self.var.training
        var = self.var
        vae = self.vae
        vae_quant = self.vae.quantize
        if g_seed is None:
            rng = None
        else:
            var.rng.manual_seed(g_seed)
            rng = var.rng

        if encoded_prompt is not None:
            assert encoded_null_prompt is not None
            context, cond_vector, context_attn_bias = self.var.parse_batch(
                encoded_prompt,
                encoded_null_prompt,
            )
        else:
            context, cond_vector, context_attn_bias = self.encode_prompt(
                prompt, null_prompt
            )
        B = context.shape[0] // 2  # context holds [cond; uncond], so 2B rows

        cond_vector = var.text_pooler(cond_vector)
        sos = cond_BD = cond_vector

        lvl_pos = var.lvl_embed(var.lvl_1L)
        if not var.rope:
            lvl_pos += var.pos_1LC
        next_token_map = (
            sos.unsqueeze(1)
            + var.pos_start.expand(2 * B, var.first_l, -1)
            + lvl_pos[:, : var.first_l]
        )

        cur_L = 0
        f_hat = sos.new_zeros(B, var.Cvae, var.patch_nums[-1], var.patch_nums[-1])

        for b in var.blocks:
            b.attn.kv_caching(True)
            b.cross_attn.kv_caching(True)

        for si, pn in enumerate(var.patch_nums):  # si: i-th segment
            ratio = si / var.num_stages_minus_1

            cond_BD_or_gss = var.shared_ada_lin(cond_BD)
            x_BLC = next_token_map

            if var.rope:
                freqs_cis = var.freqs_cis[:, cur_L : cur_L + pn * pn]
            else:
                freqs_cis = var.freqs_cis

            for block in var.blocks:
                x_BLC = block(
                    x=x_BLC,
                    cond_BD=cond_BD_or_gss,
                    attn_bias=None,
                    context=context,
                    context_attn_bias=context_attn_bias,
                    freqs_cis=freqs_cis,
                )
            cur_L += pn * pn

            logits_BlV = var.get_logits(x_BLC, cond_BD)
            # Ramp the guidance strength linearly over stages, then mix the
            # conditional and unconditional halves of the batch.
            t = cfg * ratio
            logits_BlV = (1 + t) * logits_BlV[:B] - t * logits_BlV[B:]

            idx_Bl = sample_with_top_k_top_p_(
                logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1
            )[:, :, 0]

            if re and si >= re_start_iter:
                # Resampling: redraw the tokens up to `re_max_depth` times and
                # keep, per sample, the draw with the highest total logit.
                selected_logits = torch.gather(
                    logits_BlV, -1, idx_Bl.unsqueeze(-1)
                )[:, :, 0]
                mx = selected_logits.sum(dim=-1)[:, None]
                for _ in range(re_max_depth):
                    new_idx_Bl = sample_with_top_k_top_p_(
                        logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1
                    )[:, :, 0]
                    selected_logits = torch.gather(
                        logits_BlV, -1, new_idx_Bl.unsqueeze(-1)
                    )[:, :, 0]
                    new_mx = selected_logits.sum(dim=-1)[:, None]
                    idx_Bl = idx_Bl * (mx >= new_mx) + new_idx_Bl * (mx < new_mx)
                    mx = mx * (mx >= new_mx) + new_mx * (mx < new_mx)

            if not more_smooth:  # this is the default case
                h_BChw = vae_quant.embedding(idx_Bl)  # B, l, Cvae
            else:  # not used when evaluating FID/IS/Precision/Recall
                gum_t = max(0.27 * (1 - ratio * 0.95), 0.005)  # refer to MaskGIT
                h_BChw = gumbel_softmax_with_rng(
                    logits_BlV.mul(1 + ratio), tau=gum_t, hard=False, dim=-1, rng=rng
                ) @ vae_quant.embedding.weight.unsqueeze(0)

            h_BChw = h_BChw.transpose_(1, 2).reshape(B, var.Cvae, pn, pn)
            f_hat, next_token_map = vae_quant.get_next_autoregressive_input(
                si, len(var.patch_nums), f_hat, h_BChw
            )

            if si != var.num_stages_minus_1:  # prepare for next stage
                next_token_map = next_token_map.view(B, var.Cvae, -1).transpose(1, 2)
                next_token_map = (
                    var.word_embed(next_token_map)
                    + lvl_pos[:, cur_L : cur_L + var.patch_nums[si + 1] ** 2]
                )
                # Double the batch size again for the CFG split.
                next_token_map = next_token_map.repeat(2, 1, 1)

        for b in var.blocks:
            b.attn.kv_caching(False)
            b.cross_attn.kv_caching(False)

        # De-normalize from [-1, 1] to [0, 1].
        img = vae.fhat_to_img(f_hat).add(1).mul(0.5)
        if return_pil:
            img = self.to_image(img)
        return img
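

# Minimal usage sketch. The checkpoint id "michellemoorre/var-test" below is a
# placeholder assumption; any TVARHF checkpoint accepted by
# `TVARPipeline.from_pretrained` would be used the same way.
if __name__ == "__main__":
    pipe = TVARPipeline.from_pretrained("michellemoorre/var-test", device="cuda")
    images = pipe(
        prompt="a photo of an astronaut riding a horse",
        g_seed=42,        # reseed the sampler for reproducibility
        cfg=4.0,          # classifier-free guidance strength
        top_k=450,        # top-k truncation for sampling
        top_p=0.95,       # nucleus sampling threshold
        return_pil=True,  # convert the output tensor to PIL images
    )
    images[0].save("sample.png")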