"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

import json
from typing import Callable, Optional

import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange

from model.guide import GuideTransformer
from model.modules.audio_encoder import Wav2VecEncoder
from model.modules.rotary_embedding_torch import RotaryEmbedding
from model.modules.transformer_modules import (
    DecoderLayerStack,
    FiLMTransformerDecoderLayer,
    RegressionTransformer,
    TransformerEncoderLayerRotary,
)
from model.utils import (
    init_weight,
    PositionalEncoding,
    prob_mask_like,
    setup_lip_regressor,
    SinusoidalPosEmb,
)
from model.vqvae import setup_tokenizer
from torch.nn import functional as F
from utils.misc import prGreen, prRed


class Audio2LipRegressionTransformer(torch.nn.Module):
    def __init__(
        self,
        n_vertices: int = 338,
        causal: bool = False,
        train_wav2vec: bool = False,
        transformer_encoder_layers: int = 2,
        transformer_decoder_layers: int = 4,
    ):
        super().__init__()
        self.n_vertices = n_vertices

        self.audio_encoder = Wav2VecEncoder()
        if not train_wav2vec:
            self.audio_encoder.eval()
            for param in self.audio_encoder.parameters():
                param.requires_grad = False

        self.regression_model = RegressionTransformer(
            transformer_encoder_layers=transformer_encoder_layers,
            transformer_decoder_layers=transformer_decoder_layers,
            d_model=512,
            d_cond=512,
            num_heads=4,
            causal=causal,
        )
        self.project_output = torch.nn.Linear(512, self.n_vertices * 3)

    def forward(self, audio):
        """
        :param audio: tensor of shape B x T x 1600
        :return: tensor of shape B x T x n_vertices x 3 containing reconstructed lip geometry
        """
        B, T = audio.shape[0], audio.shape[1]

        cond = self.audio_encoder(audio)

        x = torch.zeros(B, T, 512, device=audio.device)
        x = self.regression_model(x, cond)
        x = self.project_output(x)

        verts = x.view(B, T, self.n_vertices, 3)
        return verts


class FiLMTransformer(nn.Module):
    def __init__(
        self,
        args,
        nfeats: int,
        latent_dim: int = 512,
        ff_size: int = 1024,
        num_layers: int = 4,
        num_heads: int = 4,
        dropout: float = 0.1,
        cond_feature_dim: int = 4800,
        activation: Callable[[torch.Tensor], torch.Tensor] = F.gelu,
        use_rotary: bool = True,
        cond_mode: str = "audio",
        split_type: str = "train",
        device: str = "cuda",
        **kwargs,
    ) -> None:
        super().__init__()
        self.nfeats = nfeats
        self.cond_mode = cond_mode
        self.cond_feature_dim = cond_feature_dim
        self.add_frame_cond = args.add_frame_cond
        self.data_format = args.data_format
        self.split_type = split_type
        self.device = device

        # positional embeddings
        self.rotary = None
        self.abs_pos_encoding = nn.Identity()
        # if rotary, replace absolute embedding with a rotary embedding instance (absolute becomes an identity)
        if use_rotary:
            self.rotary = RotaryEmbedding(dim=latent_dim)
        else:
            self.abs_pos_encoding = PositionalEncoding(
                latent_dim, dropout, batch_first=True
            )

        # time embedding processing
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(latent_dim),
            nn.Linear(latent_dim, latent_dim * 4),
            nn.Mish(),
        )
        self.to_time_cond = nn.Sequential(
            nn.Linear(latent_dim * 4, latent_dim),
        )
        self.to_time_tokens = nn.Sequential(
            nn.Linear(latent_dim * 4, latent_dim * 2),
            Rearrange("b (r d) -> b r d", r=2),
        )

        # null embeddings for guidance dropout
        self.seq_len = args.max_seq_length
        emb_len = 1998  # hardcoded for now
        self.null_cond_embed = nn.Parameter(torch.randn(1, emb_len, latent_dim))
        self.null_cond_hidden = nn.Parameter(torch.randn(1, latent_dim))
        self.norm_cond = nn.LayerNorm(latent_dim)
        self.setup_audio_models()

        # set up pose/face specific parts of the model
        self.input_projection = nn.Linear(self.nfeats, latent_dim)
        if self.data_format == "pose":
            cond_feature_dim = 1024
            key_feature_dim = 104
            self.step = 30
            self.use_cm = True
            self.setup_guide_models(args, latent_dim, key_feature_dim)
            self.post_pose_layers = self._build_single_pose_conv(self.nfeats)
            self.post_pose_layers.apply(init_weight)
            self.final_conv = torch.nn.Conv1d(self.nfeats, self.nfeats, kernel_size=1)
            self.receptive_field = 25
        elif self.data_format == "face":
            self.use_cm = False
            cond_feature_dim = 1024 + 1014
            self.setup_lip_models()
            self.cond_encoder = nn.Sequential()
            for _ in range(2):
                self.cond_encoder.append(
                    TransformerEncoderLayerRotary(
                        d_model=latent_dim,
                        nhead=num_heads,
                        dim_feedforward=ff_size,
                        dropout=dropout,
                        activation=activation,
                        batch_first=True,
                        rotary=self.rotary,
                    )
                )
            self.cond_encoder.apply(init_weight)

        self.cond_projection = nn.Linear(cond_feature_dim, latent_dim)
        self.non_attn_cond_projection = nn.Sequential(
            nn.LayerNorm(latent_dim),
            nn.Linear(latent_dim, latent_dim),
            nn.SiLU(),
            nn.Linear(latent_dim, latent_dim),
        )

        # decoder
        decoderstack = nn.ModuleList([])
        for _ in range(num_layers):
            decoderstack.append(
                FiLMTransformerDecoderLayer(
                    latent_dim,
                    num_heads,
                    dim_feedforward=ff_size,
                    dropout=dropout,
                    activation=activation,
                    batch_first=True,
                    rotary=self.rotary,
                    use_cm=self.use_cm,
                )
            )
        self.seqTransDecoder = DecoderLayerStack(decoderstack)
        self.seqTransDecoder.apply(init_weight)
        self.final_layer = nn.Linear(latent_dim, self.nfeats)
        self.final_layer.apply(init_weight)

    def _build_single_pose_conv(self, nfeats: int) -> nn.ModuleList:
        post_pose_layers = torch.nn.ModuleList(
            [
                torch.nn.Conv1d(nfeats, max(256, nfeats), kernel_size=3, dilation=1),
                torch.nn.Conv1d(max(256, nfeats), nfeats, kernel_size=3, dilation=2),
                torch.nn.Conv1d(nfeats, nfeats, kernel_size=3, dilation=3),
                torch.nn.Conv1d(nfeats, nfeats, kernel_size=3, dilation=1),
                torch.nn.Conv1d(nfeats, nfeats, kernel_size=3, dilation=2),
                torch.nn.Conv1d(nfeats, nfeats, kernel_size=3, dilation=3),
            ]
        )
        return post_pose_layers

    def _run_single_pose_conv(self, output: torch.Tensor) -> torch.Tensor:
        output = torch.nn.functional.pad(output, pad=[self.receptive_field - 1, 0])
        for _, layer in enumerate(self.post_pose_layers):
            y = torch.nn.functional.leaky_relu(layer(output), negative_slope=0.2)
            if self.split_type == "train":
                y = torch.nn.functional.dropout(y, 0.2)
            if output.shape[1] == y.shape[1]:
                output = (output[:, :, -y.shape[-1] :] + y) / 2.0  # skip connection
            else:
                output = y
        return output

    def setup_guide_models(self, args, latent_dim: int, key_feature_dim: int) -> None:
        # set up conditioning info
        max_keyframe_len = len(list(range(self.seq_len))[:: self.step])
        self.null_pose_embed = nn.Parameter(
            torch.randn(1, max_keyframe_len, latent_dim)
        )
        prGreen(f"using keyframes: {self.null_pose_embed.shape}")
        self.frame_cond_projection = nn.Linear(key_feature_dim, latent_dim)
        self.frame_norm_cond = nn.LayerNorm(latent_dim)
        # for test time set up keyframe transformer
        self.resume_trans = None
        if self.split_type == "test":
            if hasattr(args, "resume_trans") and args.resume_trans is not None:
                self.resume_trans = args.resume_trans
                self.setup_guide_predictor(args.resume_trans)
            else:
                prRed("not using transformer, just using ground truth")

    def setup_guide_predictor(self, cp_path: str) -> None:
        cp_dir = cp_path.split("checkpoints/iter-")[0]
        with open(f"{cp_dir}/args.json") as f:
            trans_args = json.load(f)

        # set up tokenizer based on trans_arg load point
        self.tokenizer = setup_tokenizer(trans_args["resume_pth"])

        # set up transformer
        self.transformer = GuideTransformer(
            tokens=self.tokenizer.n_clusters,
            num_layers=trans_args["layers"],
            dim=trans_args["dim"],
            emb_len=1998,
            num_audio_layers=trans_args["num_audio_layers"],
        )
        for param in self.transformer.parameters():
            param.requires_grad = False
        prGreen("loading TRANSFORMER checkpoint from {}".format(cp_path))
        cp = torch.load(cp_path)
        missing_keys, unexpected_keys = self.transformer.load_state_dict(
            cp["model_state_dict"], strict=False
        )
        assert len(missing_keys) == 0, missing_keys
        assert len(unexpected_keys) == 0, unexpected_keys

    def setup_audio_models(self) -> None:
        self.audio_model, self.audio_resampler = setup_lip_regressor()

    def setup_lip_models(self) -> None:
        self.lip_model = Audio2LipRegressionTransformer()
        cp_path = "./assets/iter-0200000.pt"
        cp = torch.load(cp_path, map_location=torch.device(self.device))
        self.lip_model.load_state_dict(cp["model_state_dict"])
        for param in self.lip_model.parameters():
            param.requires_grad = False
        prGreen(f"adding lip conditioning {cp_path}")

    def parameters_w_grad(self):
        return [p for p in self.parameters() if p.requires_grad]

    def encode_audio(self, raw_audio: torch.Tensor) -> torch.Tensor:
        device = next(self.parameters()).device
        a0 = self.audio_resampler(raw_audio[:, :, 0].to(device))
        a1 = self.audio_resampler(raw_audio[:, :, 1].to(device))
        with torch.no_grad():
            z0 = self.audio_model.feature_extractor(a0)
            z1 = self.audio_model.feature_extractor(a1)
            emb = torch.cat((z0, z1), axis=1).permute(0, 2, 1)
        return emb

    def encode_lip(self, audio: torch.Tensor, cond_embed: torch.Tensor) -> torch.Tensor:
        reshaped_audio = audio.reshape((audio.shape[0], -1, 1600, 2))[..., 0]
        # processes 4 seconds at a time
        B, T, _ = reshaped_audio.shape
        lip_cond = torch.zeros(
            (audio.shape[0], T, 338, 3),
            device=audio.device,
            dtype=audio.dtype,
        )
        for i in range(0, T, 120):
            lip_cond[:, i : i + 120, ...] = self.lip_model(
                reshaped_audio[:, i : i + 120, ...]
            )
        lip_cond = lip_cond.permute(0, 2, 3, 1).reshape((B, 338 * 3, -1))
        lip_cond = torch.nn.functional.interpolate(
            lip_cond, size=cond_embed.shape[1], mode="nearest-exact"
        ).permute(0, 2, 1)
        cond_embed = torch.cat((cond_embed, lip_cond), dim=-1)
        return cond_embed

    def encode_keyframes(
        self, y: torch.Tensor, cond_drop_prob: float, batch_size: int
    ) -> torch.Tensor:
        pred = y["keyframes"]
        new_mask = y["mask"][..., :: self.step].squeeze((1, 2))
        pred[~new_mask] = 0.0  # pad the unknown
        pose_hidden = self.frame_cond_projection(pred.detach().clone().cuda())
        pose_embed = self.abs_pos_encoding(pose_hidden)
        pose_tokens = self.frame_norm_cond(pose_embed)
        # do conditional dropout for guide poses
        key_cond_drop_prob = cond_drop_prob
        keep_mask_pose = prob_mask_like(
            (batch_size,), 1 - key_cond_drop_prob, device=pose_tokens.device
        )
        keep_mask_pose_embed = rearrange(keep_mask_pose, "b -> b 1 1")
        null_pose_embed = self.null_pose_embed.to(pose_tokens.dtype)
        pose_tokens = torch.where(
            keep_mask_pose_embed,
            pose_tokens,
            null_pose_embed[:, : pose_tokens.shape[1], :],
        )
        return pose_tokens

    def forward(
        self,
        x: torch.Tensor,
        times: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        cond_drop_prob: float = 0.0,
    ) -> torch.Tensor:
        if x.dim() == 4:
            x = x.permute(0, 3, 1, 2).squeeze(-1)
        batch_size, device = x.shape[0], x.device
        if self.cond_mode == "uncond":
            cond_embed = torch.zeros(
                (x.shape[0], x.shape[1], self.cond_feature_dim),
                dtype=x.dtype,
                device=x.device,
            )
        else:
            cond_embed = y["audio"]
            cond_embed = self.encode_audio(cond_embed)
            if self.data_format == "face":
                cond_embed = self.encode_lip(y["audio"], cond_embed)
                pose_tokens = None
            if self.data_format == "pose":
                pose_tokens = self.encode_keyframes(y, cond_drop_prob, batch_size)
        assert cond_embed is not None, "cond emb should not be none"
        # process conditioning information
        x = self.input_projection(x)
        x = self.abs_pos_encoding(x)
        audio_cond_drop_prob = cond_drop_prob
        keep_mask = prob_mask_like(
            (batch_size,), 1 - audio_cond_drop_prob, device=device
        )
        keep_mask_embed = rearrange(keep_mask, "b -> b 1 1")
        keep_mask_hidden = rearrange(keep_mask, "b -> b 1")
        cond_tokens = self.cond_projection(cond_embed)
        cond_tokens = self.abs_pos_encoding(cond_tokens)
        if self.data_format == "face":
            cond_tokens = self.cond_encoder(cond_tokens)
        null_cond_embed = self.null_cond_embed.to(cond_tokens.dtype)
        cond_tokens = torch.where(
            keep_mask_embed, cond_tokens, null_cond_embed[:, : cond_tokens.shape[1], :]
        )
        mean_pooled_cond_tokens = cond_tokens.mean(dim=-2)
        cond_hidden = self.non_attn_cond_projection(mean_pooled_cond_tokens)

        # create t conditioning
        t_hidden = self.time_mlp(times)
        t = self.to_time_cond(t_hidden)
        t_tokens = self.to_time_tokens(t_hidden)
        null_cond_hidden = self.null_cond_hidden.to(t.dtype)
        cond_hidden = torch.where(keep_mask_hidden, cond_hidden, null_cond_hidden)
        t += cond_hidden

        # cross-attention conditioning
        c = torch.cat((cond_tokens, t_tokens), dim=-2)
        cond_tokens = self.norm_cond(c)

        # Pass through the transformer decoder
        output = self.seqTransDecoder(x, cond_tokens, t, memory2=pose_tokens)
        output = self.final_layer(output)
        if self.data_format == "pose":
            output = output.permute(0, 2, 1)
            output = self._run_single_pose_conv(output)
            output = self.final_conv(output)
            output = output.permute(0, 2, 1)
        return output