from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F

# ------------------------------------------------------------------------------------- #

# NOTE(review): this module-level name is not referenced anywhere in the
# visible code; CharacterDataset.__init__ takes its own `num_frequencies`
# argument, which shadows it there. It may be read from another module --
# confirm before removing.
num_frequencies = None

# ------------------------------------------------------------------------------------- #


class CharacterDataset(Dataset):
    """Per-sample loader for character features and, optionally, meshes.

    Samples are identified by file stems stored in ``self.filenames``. That
    attribute is initialised to ``None`` here and must be assigned before the
    dataset is indexed or its length is taken. For a stem ``s`` the loader
    reads:

    * ``<dataset_dir>/<name>/s.npy`` -- per-frame character features
      (frames along axis 0);
    * ``<dataset_dir>/vert_raw/s.npy`` -- a pickled dict with ``"vertices"``
      and ``"faces"`` entries (only when ``load_vertices`` is True);
    * ``<dataset_dir>/char_raw/s.npy`` -- its first row is subtracted from
      the vertices as a center offset (only when real vertices are present).

    Args:
        name: Sub-directory of ``dataset_dir`` holding the feature files.
        dataset_dir: Dataset root directory.
        standardize: Normalize features with the statistics passed as
            ``kwargs["standardization"]``.
        num_feats: Stored as ``self.num_feats``; not read within this class.
        num_cams: Number of frames every sample is zero-padded to.
            Presumably sequences are never longer than this -- TODO confirm.
        sequential: If True, features are returned as ``(feat_dim, num_cams)``;
            otherwise they are flattened to 1-D.
        num_frequencies: Stored as an attribute; not read within this class.
        min_freq: Stored as an attribute; not read within this class.
        max_freq: Stored as an attribute; not read within this class.
        load_vertices: Also load mesh vertices and faces for each sample.
        **kwargs: ``standardization`` -- a mapping with ``norm_mean_h``,
            ``norm_std_h`` and ``velocity`` keys. Required when
            ``standardize`` is True.
    """

    # Placeholder mesh size used when a sample has no vertices. Presumably
    # the SMPL topology (6890 vertices, 13776 triangles) -- TODO confirm.
    _NUM_PLACEHOLDER_VERTS = 6890
    _NUM_PLACEHOLDER_FACES = 13776

    def __init__(
        self,
        name: str,
        dataset_dir: str,
        standardize: bool,
        num_feats: int,
        num_cams: int,
        sequential: bool,
        num_frequencies: int,
        min_freq: int,
        max_freq: int,
        load_vertices: bool,
        **kwargs,
    ):
        super().__init__()
        self.modality = "char"
        self.name = name
        self.dataset_dir = Path(dataset_dir)
        self.traj_dir = self.dataset_dir / "traj"
        self.data_dir = self.dataset_dir / self.name
        self.vert_dir = self.dataset_dir / "vert_raw"
        self.center_dir = self.dataset_dir / "char_raw"

        self.filenames = None
        self.standardize = standardize
        # Velocity encoding is only configured through the standardization
        # statistics. Default to absolute features so __getitem__ also works
        # when standardize is False (it reads self.velocity unconditionally).
        self.velocity = False
        if self.standardize:
            mean_std = kwargs["standardization"]
            # Stored as column vectors of shape (D, 1).
            self.norm_mean = torch.Tensor(mean_std["norm_mean_h"])[:, None]
            self.norm_std = torch.Tensor(mean_std["norm_std_h"])[:, None]
            self.velocity = mean_std["velocity"]

        self.num_cams = num_cams
        self.num_feats = num_feats
        self.sequential = sequential
        self.num_frequencies = num_frequencies
        self.min_freq = min_freq
        self.max_freq = max_freq

        self.load_vertices = load_vertices

    def __len__(self):
        return len(self.filenames)

    def _standardize(self, char_feature):
        """Normalize ``char_feature`` in place with the stored mean/std.

        With 6 statistics, the first 3 apply to the first (origin) frame and
        the last 3 to the remaining (velocity) frames. Otherwise one set of
        statistics is applied to every frame. Returns the same tensor.
        """
        device = char_feature.device
        mean = self.norm_mean[:, 0].to(device)
        std = self.norm_std[:, 0].to(device)
        if len(mean) == 6:
            char_feature[0] -= mean[:3]
            char_feature[0] /= std[:3]
            char_feature[1:] -= mean[3:]
            char_feature[1:] /= std[3:]
        else:
            char_feature -= mean
            char_feature /= std
        return char_feature

    def _load_mesh(self, char_filename, padding_size):
        """Return ``(padded_verts, faces)`` for one sample file.

        Vertices are shifted by the sample's center offset and zero-padded
        along axis 0 by ``padding_size``. A sample whose ``"vertices"`` entry
        is ``None`` yields all-zero placeholders of a fixed size.
        """
        # NOTE(review): allow_pickle=True deserializes arbitrary Python
        # objects; only load vertex files from a trusted source.
        raw_verts = np.load(self.vert_dir / char_filename, allow_pickle=True)[()]
        if raw_verts["vertices"] is None:
            padded_verts = torch.zeros(
                (self.num_cams, self._NUM_PLACEHOLDER_VERTS, 3), dtype=torch.float32
            )
            faces = torch.zeros((self._NUM_PLACEHOLDER_FACES, 3), dtype=torch.int16)
            return padded_verts, faces

        # The first row of the char_raw file is used to center the mesh.
        center_offset = torch.from_numpy(
            np.load(self.center_dir / char_filename)[0]
        ).to(torch.float32)
        verts = torch.from_numpy(raw_verts["vertices"]).to(torch.float32)
        verts -= center_offset
        padded_verts = F.pad(verts, (0, 0, 0, 0, 0, padding_size))
        faces = torch.from_numpy(raw_verts["faces"]).to(torch.int16)
        return padded_verts, faces

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns:
            ``(char_filename, padded_char_feature, raw_feats)``.

            * ``char_filename`` is the stem with ``.npy`` appended.
            * ``padded_char_feature`` is the (optionally velocity-encoded and
              standardized) feature tensor, padded to ``num_cams`` frames. It
              is shaped ``(feat_dim, num_cams)`` when ``sequential`` is set,
              otherwise flattened.
            * ``raw_feats["char_raw_feat"]`` holds the unprocessed padded
              features as ``(feat_dim, num_cams)``. ``char_vertices`` and
              ``char_faces`` are added when ``load_vertices`` is set.
        """
        char_filename = self.filenames[index] + ".npy"

        raw_char_feature = torch.from_numpy(
            np.load(self.data_dir / char_filename)
        ).to(torch.float32)
        padding_size = self.num_cams - raw_char_feature.shape[0]
        padded_raw_char_feature = F.pad(
            raw_char_feature, (0, 0, 0, padding_size)
        ).permute(1, 0)

        char_feature = raw_char_feature.clone()
        if self.velocity:
            # Keep the first frame absolute; later frames become deltas.
            velocity = char_feature[1:] - char_feature[:-1]
            char_feature = torch.cat([raw_char_feature[0][None], velocity])

        if self.standardize:
            char_feature = self._standardize(char_feature)

        padded_char_feature = F.pad(
            char_feature,
            (0, 0, 0, self.num_cams - char_feature.shape[0]),
        )
        if self.sequential:
            padded_char_feature = padded_char_feature.permute(1, 0)
        else:
            padded_char_feature = padded_char_feature.reshape(-1)

        raw_feats = {"char_raw_feat": padded_raw_char_feature}
        if self.load_vertices:
            padded_verts, faces = self._load_mesh(char_filename, padding_size)
            raw_feats["char_vertices"] = padded_verts
            raw_feats["char_faces"] = faces

        return char_filename, padded_char_feature, raw_feats