from functools import reduce
from pathlib import Path

import torch
import torch.nn.functional as F


class NullableArgs:
    """Attribute bag copied from a namespace that answers ``None`` for unknown options.

    Every entry of ``namespace.__dict__`` is copied onto the instance. Looking up a
    name that was never set returns ``None`` instead of raising, except for:

    * ``align_mask_width``: 1 if ``use_alignment_mask`` is set and truthy, else 0.
    * ``no_head_pose``: ``not self.predict_head_pose``.
    * ``no_use_learnable_pe``: ``not self.use_learnable_pe``.
    * dunder names (``__x__``): raise ``AttributeError`` as on a normal object.
    """

    def __init__(self, namespace):
        for key, value in namespace.__dict__.items():
            setattr(self, key, value)

    def __getattr__(self, key):
        # Only invoked when the regular attribute lookup has not found ``key``.
        if key.startswith('__') and key.endswith('__'):
            # Protocol probes (``__setstate__``, ``__deepcopy__``, ``__getstate__`` ...)
            # must look absent. Answering ``None`` makes copy/pickle believe the hook
            # exists and then fail with "'NoneType' object is not callable".
            raise AttributeError(f'{type(self).__name__!r} object has no attribute {key!r}')
        if key == 'align_mask_width':
            # Read through __dict__ so a missing flag counts as "off".
            return 1 if self.__dict__.get('use_alignment_mask') else 0
        if key == 'no_head_pose':
            return not self.predict_head_pose
        if key == 'no_use_learnable_pe':
            return not self.use_learnable_pe

        return None


def count_parameters(model):
    """Return the total number of elements in the parameters of ``model`` that require gradients.

    ``model`` only needs a ``parameters()`` method yielding objects with
    ``requires_grad`` and ``numel()``.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def get_option_text(args, parser):
    """Render the options in ``args`` as one aligned text line per option.

    Options are listed in sorted order. Each line holds the name right-aligned
    to 30 columns, the value left-aligned to 30 columns and, when the value
    differs from ``parser.get_default(name)``, a tab-separated note with the default.

    Returns:
        str: the concatenated lines, each terminated by a newline.
    """
    rows = []
    for name, value in sorted(vars(args).items()):
        default = parser.get_default(name)
        note = f'\t[default: {str(default)}]' if value != default else ''
        rows.append(f'{str(name):>30}: {str(value):<30}{note}\n')
    return ''.join(rows)


def get_model_path(exp_name, iteration, model_type='DPT'):
    """Locate the checkpoint file of an experiment.

    Experiments are looked up under ``<dir of this file>/../experiments/<model_type>``.
    If ``exp_name`` is not an existing entry there, it is treated as a prefix and the
    first match of ``<exp_name>*`` in sorted order is used, so the choice is
    deterministic when several directories share the prefix.

    Args:
        exp_name: Experiment directory name, or a prefix of it.
        iteration: Iteration number; zero-padded to 7 digits in the file name.
        model_type: Sub-directory of ``experiments`` to search in.

    Returns:
        tuple: ``(model_path, exp_dir relative to the experiments root)``. The
        checkpoint file itself is not checked for existence.

    Raises:
        FileNotFoundError: if no experiment directory matches ``exp_name``.
    """
    exp_root_dir = Path(__file__).parent.parent / 'experiments' / model_type
    exp_dir = exp_root_dir / exp_name
    if not exp_dir.exists():
        # Fall back to prefix matching; report a clear error instead of leaking
        # StopIteration from next() when nothing matches.
        candidates = sorted(exp_root_dir.glob(f'{exp_name}*'))
        if not candidates:
            raise FileNotFoundError(f'No experiment matching "{exp_name}*" found in {exp_root_dir}')
        exp_dir = candidates[0]
    model_path = exp_dir / f'checkpoints/iter_{iteration:07}.pt'
    return model_path, exp_dir.relative_to(exp_root_dir)


def get_pose_input(coef_dict, rot_repr, with_global_pose):
    """Select the pose entries that are fed to the model.

    Args:
        coef_dict: Mapping with a ``'pose'`` tensor; only its last axis is sliced.
        rot_repr: Rotation representation; only ``'aa'`` is supported.
        with_global_pose: If true, start from the full pose vector, otherwise from
            its last three entries only.

    Returns:
        The selected pose tensor with its last two entries dropped.

    Raises:
        ValueError: if ``rot_repr`` is not ``'aa'``.
    """
    if rot_repr != 'aa':
        raise ValueError(f'Unknown rotation representation: {rot_repr}')

    pose = coef_dict['pose']
    if not with_global_pose:
        pose = pose[..., -3:]
    # Drop the last two entries (original note: mouth rotation around the y and z axes).
    return pose[..., :-2]


def get_motion_coef(coef_dict, rot_repr, with_global_pose=False, norm_stats=None):
    """Pack expression and pose coefficients into a single motion tensor.

    Args:
        coef_dict: Mapping with ``'exp'`` and ``'pose'`` tensors.
        rot_repr: Rotation representation; only ``'aa'`` is supported.
        with_global_pose: Forwarded to ``get_pose_input``.
        norm_stats: Optional mapping with ``'<key>_mean'`` / ``'<key>_std'`` entries for
            ``'exp'`` and ``'pose'``. When given, both are normalized as
            ``(x - mean) / std`` before packing.

    Returns:
        ``'exp'`` concatenated with the selected pose entries along the last axis.

    Raises:
        ValueError: if ``rot_repr`` is not ``'aa'``.
    """
    if norm_stats is not None:
        if rot_repr != 'aa':
            raise ValueError(f'Unknown rotation representation {rot_repr}!')
        normalized = {}
        for name in ('exp', 'pose'):
            normalized[name] = (coef_dict[name] - norm_stats[f'{name}_mean']) / norm_stats[f'{name}_std']
        coef_dict = normalized

    pose_part = get_pose_input(coef_dict, rot_repr, with_global_pose)
    exp_part = coef_dict['exp']
    return torch.cat([exp_part, pose_part], dim=-1)


def get_coef_dict(motion_coef, shape_coef=None, denorm_stats=None, with_global_pose=False, rot_repr='aa'):
    """Unpack a motion tensor into a dict of ``'exp'``, ``'pose'`` and optionally ``'shape'``.

    Args:
        motion_coef: Tensor whose last axis holds 50 expression entries followed by
            the pose entries.
        shape_coef: Optional shape tensor. When ``motion_coef`` is 3-D, a 2-D shape
            tensor gets a middle axis inserted, and a middle axis of size 1 is
            expanded to ``motion_coef.shape[1]``.
        denorm_stats: Optional mapping with ``'<key>_std'`` / ``'<key>_mean'`` for every
            key of the result (including ``'shape'`` when present); applied as
            ``x * std + mean``.
        with_global_pose: If true, all entries from index 50 on are used as pose.
            Otherwise only the last entry is used, placed after three zeros, and the
            first three pose entries are reset to zero after de-normalization.
        rot_repr: Rotation representation; only ``'aa'`` is supported.

    Returns:
        dict: the unpacked coefficients; ``'pose'`` carries two extra trailing entries
        that are zero before de-normalization.

    Raises:
        ValueError: if ``rot_repr`` is not ``'aa'``.
    """
    result = {'exp': motion_coef[..., :50]}
    if rot_repr != 'aa':
        raise ValueError(f'Unknown rotation representation {rot_repr}!')

    if with_global_pose:
        pose = motion_coef[..., 50:]
    else:
        leading_zeros = torch.zeros_like(motion_coef[..., :3])
        pose = torch.cat([leading_zeros, motion_coef[..., -1:]], dim=-1)
    # Add back the two entries removed when packing (original note: rotation around y, z axis).
    trailing_zeros = torch.zeros_like(motion_coef[..., :2])
    result['pose'] = torch.cat([pose, trailing_zeros], dim=-1)

    if shape_coef is not None:
        if motion_coef.ndim == 3:
            if shape_coef.ndim == 2:
                shape_coef = shape_coef.unsqueeze(1)
            if shape_coef.shape[1] == 1:
                shape_coef = shape_coef.expand(-1, motion_coef.shape[1], -1)
        result['shape'] = shape_coef

    if denorm_stats is not None:
        result = {
            name: value * denorm_stats[f'{name}_std'] + denorm_stats[f'{name}_mean']
            for name, value in result.items()
        }

    if not with_global_pose:
        # De-normalization may have shifted the placeholder entries; reset them to zero.
        result['pose'][..., :3] = 0

    return result


def coef_dict_to_vertices(coef_dict, flame, rot_repr='aa', ignore_global_rot=False, flame_batch_size=512):
    """Run ``flame`` over a coefficient dict in chunks and return the vertices.

    Args:
        coef_dict: Mapping with ``'shape'``, ``'exp'`` and ``'pose'`` tensors. Every
            tensor is flattened to 2-D; the leading dimensions of ``'exp'`` define
            the output's leading dimensions.
        flame: Callable invoked as ``flame(shape, exp, pose, pose2rot=True,
            ignore_global_rot=..., return_lm2d=False, return_lm3d=False)``; the first
            of its three return values is used as the vertices.
        rot_repr: Rotation representation; only ``'aa'`` is supported.
        ignore_global_rot: Forwarded to ``flame``.
        flame_batch_size: Maximum number of flattened samples per ``flame`` call.

    Returns:
        Tensor of shape ``(*leading_dims, n_vertices, 3)``.

    Raises:
        ValueError: if ``rot_repr`` is not ``'aa'`` (raised when the first chunk is processed).
    """
    lead_shape = coef_dict['exp'].shape[:-1]
    flat = {name: tensor.view(-1, tensor.shape[-1]) for name, tensor in coef_dict.items()}
    total = reduce(lambda acc, dim: acc * dim, lead_shape, 1)

    # Evaluate the model chunk by chunk to bound the size of each call.
    chunks = []
    for start in range(0, total, flame_batch_size):
        stop = start + flame_batch_size
        part = {name: tensor[start:stop] for name, tensor in flat.items()}
        if rot_repr != 'aa':
            raise ValueError(f'Unknown rot_repr: {rot_repr}')
        vertices, _, _ = flame(
            part['shape'], part['exp'], part['pose'],
            pose2rot=True, ignore_global_rot=ignore_global_rot, return_lm2d=False, return_lm3d=False)
        chunks.append(vertices)

    stacked = torch.cat(chunks, dim=0)  # (total, n_vertices, 3)
    return stacked.view(*lead_shape, -1, 3)  # (..., n_vertices, 3)


def compute_loss(args, is_starting_sample, shape_coef, motion_coef_gt, noise, target, prev_motion_coef, coef_stats,
                 flame, end_idx=None):
    """Compute the masked main loss and the optional auxiliary losses.

    All element-wise losses are computed with ``reduction='none'``, masked per frame
    and averaged. Tensors are indexed as ``(batch, frame, feature)`` by this function
    (axis 1 is differenced for velocities, the last axis is sliced at 50:53 for head pose).

    Args:
        args: Options object. Reads ``criterion`` (``'l1'``/``'l2'``, case-insensitive),
            ``target`` (``'noise'``/``'sample'``), ``n_prev_motions``, ``n_motions``,
            ``no_constrain_prev``, ``rot_repr``, ``no_head_pose`` and the weights ``l_vert``,
            ``l_vel``, ``l_smooth``, ``l_head_angle``, ``l_head_vel``, ``l_head_smooth``,
            ``l_head_trans`` (only compared against zero here).
        is_starting_sample: If true, the first ``n_prev_motions`` frames of ``target`` are
            dropped; otherwise ``prev_motion_coef`` is prepended to ``motion_coef_gt``.
        shape_coef: Shape coefficients passed to ``get_coef_dict`` (vertex losses only).
        motion_coef_gt: Ground-truth motion coefficients.
        noise: Ground-truth noise; only used when ``args.target == 'noise'``.
        target: Model output, including the ``n_prev_motions`` leading frames.
        prev_motion_coef: Previous motion coefficients; used when not a starting sample.
        coef_stats: De-normalization stats passed to ``get_coef_dict``.
        flame: Callable mapping ``(shape, exp, pose)`` to vertices (first return value);
            the result is viewed as ``(-1, seq_len, 5023, 3)``.
        end_idx: Optional per-sample tensor; frames with index ``>= end_idx`` are masked
            out. ``None`` keeps all ``n_motions`` frames.

    Returns:
        tuple: ``(loss_noise, loss_vert, loss_vel, loss_smooth, loss_head_angle,
        loss_head_vel, loss_head_smooth, loss_head_trans)``. ``loss_noise`` is always a
        scalar tensor; every other entry is a scalar tensor or ``None`` when the term is
        disabled or no element survives the mask.

    Raises:
        NotImplementedError: for an unknown ``args.criterion``.
        ValueError: for an unknown ``args.target`` or ``args.rot_repr``.
    """
    if args.criterion.lower() == 'l2':
        criterion_func = F.mse_loss
    elif args.criterion.lower() == 'l1':
        criterion_func = F.l1_loss
    else:
        raise NotImplementedError(f'Criterion {args.criterion} not implemented.')

    loss_vert = None
    loss_vel = None
    loss_smooth = None
    loss_head_angle = None
    loss_head_vel = None
    loss_head_smooth = None
    loss_head_trans_vel = None
    loss_head_trans_accel = None
    loss_head_trans = None
    if args.target == 'noise':
        loss_noise = criterion_func(noise, target[:, args.n_prev_motions:], reduction='none')
    elif args.target == 'sample':
        if is_starting_sample:
            target = target[:, args.n_prev_motions:]
        else:
            motion_coef_gt = torch.cat([prev_motion_coef, motion_coef_gt], dim=1)
            if args.no_constrain_prev:
                # Copy the previous motions into the prediction so they add no error
                # (they are masked out below as well).
                target = torch.cat([prev_motion_coef, target[:, args.n_prev_motions:]], dim=1)

        loss_noise = criterion_func(motion_coef_gt, target, reduction='none')

        # A missing/None smoothness weight counts as "disabled".
        l_smooth = getattr(args, 'l_smooth', None) or 0

        # The smoothness term needs the vertices too, so it must be part of this gate;
        # otherwise enabling only l_smooth would silently compute nothing.
        if args.l_vert > 0 or args.l_vel > 0 or l_smooth > 0:
            coef_gt = get_coef_dict(motion_coef_gt, shape_coef, coef_stats, with_global_pose=False,
                                    rot_repr=args.rot_repr)
            coef_pred = get_coef_dict(target, shape_coef, coef_stats, with_global_pose=False,
                                      rot_repr=args.rot_repr)
            seq_len = target.shape[1]

            if args.rot_repr == 'aa':
                verts_gt, _, _ = flame(coef_gt['shape'].view(-1, 100), coef_gt['exp'].view(-1, 50),
                                       coef_gt['pose'].view(-1, 6), return_lm2d=False, return_lm3d=False)
                verts_pred, _, _ = flame(coef_pred['shape'].view(-1, 100), coef_pred['exp'].view(-1, 50),
                                         coef_pred['pose'].view(-1, 6), return_lm2d=False, return_lm3d=False)
            else:
                raise ValueError(f'Unknown rotation representation {args.rot_repr}!')
            verts_gt = verts_gt.view(-1, seq_len, 5023, 3)
            verts_pred = verts_pred.view(-1, seq_len, 5023, 3)

            if args.l_vert > 0:
                loss_vert = criterion_func(verts_gt, verts_pred, reduction='none')

            if args.l_vel > 0:
                vel_gt = verts_gt[:, 1:] - verts_gt[:, :-1]
                vel_pred = verts_pred[:, 1:] - verts_pred[:, :-1]
                loss_vel = criterion_func(vel_gt, vel_pred, reduction='none')

            if l_smooth > 0:
                # Penalize changes of the predicted velocity between consecutive frames.
                vel_pred = verts_pred[:, 1:] - verts_pred[:, :-1]
                loss_smooth = criterion_func(vel_pred[:, 1:], vel_pred[:, :-1], reduction='none')

        # head pose
        if not args.no_head_pose:
            if args.rot_repr == 'aa':
                head_pose_gt = motion_coef_gt[:, :, 50:53]
                head_pose_pred = target[:, :, 50:53]
            else:
                raise ValueError(f'Unknown rotation representation {args.rot_repr}!')

            if args.l_head_angle > 0:
                loss_head_angle = criterion_func(head_pose_gt, head_pose_pred, reduction='none')

            if args.l_head_vel > 0:
                head_vel_gt = head_pose_gt[:, 1:] - head_pose_gt[:, :-1]
                head_vel_pred = head_pose_pred[:, 1:] - head_pose_pred[:, :-1]
                loss_head_vel = criterion_func(head_vel_gt, head_vel_pred, reduction='none')

            if args.l_head_smooth > 0:
                head_vel_pred = head_pose_pred[:, 1:] - head_pose_pred[:, :-1]
                loss_head_smooth = criterion_func(head_vel_pred[:, 1:], head_vel_pred[:, :-1], reduction='none')

            if not is_starting_sample and args.l_head_trans > 0:
                # # version 1: constrain both the predicted previous and current motions (x_{-3} ~ x_{2})
                # head_pose_trans = head_pose_pred[:, args.n_prev_motions - 3:args.n_prev_motions + 3]
                # head_vel_pred = head_pose_trans[:, 1:] - head_pose_trans[:, :-1]
                # head_accel_pred = head_vel_pred[:, 1:] - head_vel_pred[:, :-1]

                # version 2: constrain only the predicted current motions (x_{0} ~ x_{2})
                head_pose_trans = torch.cat([head_pose_gt[:, args.n_prev_motions - 3:args.n_prev_motions],
                                             head_pose_pred[:, args.n_prev_motions:args.n_prev_motions + 3]], dim=1)
                head_vel_pred = head_pose_trans[:, 1:] - head_pose_trans[:, :-1]
                head_accel_pred = head_vel_pred[:, 1:] - head_vel_pred[:, :-1]

                # will constrain x_{-2|0} ~ x_{1}
                loss_head_trans_vel = criterion_func(head_vel_pred[:, 2:4], head_vel_pred[:, 1:3], reduction='none')
                # will constrain x_{-3|0} ~ x_{2}
                loss_head_trans_accel = criterion_func(head_accel_pred[:, 1:], head_accel_pred[:, :-1],
                                                       reduction='none')
    else:
        raise ValueError(f'Unknown diffusion target: {args.target}')

    # Per-frame validity mask over the n_motions current frames.
    if end_idx is None:
        mask = torch.ones((target.shape[0], args.n_motions), dtype=torch.bool, device=target.device)
    else:
        mask = torch.arange(args.n_motions, device=target.device).expand(target.shape[0], -1) < end_idx.unsqueeze(1)

    if args.target == 'sample' and not is_starting_sample:
        # Extend the mask over the prepended previous motions.
        if args.no_constrain_prev:
            # Warning: this option will be deprecated in the future
            mask = torch.cat([torch.zeros_like(mask[:, :args.n_prev_motions]), mask], dim=1)
        else:
            mask = torch.cat([torch.ones_like(mask[:, :args.n_prev_motions]), mask], dim=1)

    loss_noise = loss_noise[mask].mean()
    if loss_vert is not None:
        loss_vert = loss_vert[mask].mean()
    # Difference-based terms have one (velocity) or two (smoothness) frames fewer, so
    # the mask is shifted accordingly; an empty selection yields None instead of NaN.
    if loss_vel is not None:
        loss_vel = loss_vel[mask[:, 1:]]
        loss_vel = loss_vel.mean() if torch.numel(loss_vel) > 0 else None
    if loss_smooth is not None:
        loss_smooth = loss_smooth[mask[:, 2:]]
        loss_smooth = loss_smooth.mean() if torch.numel(loss_smooth) > 0 else None
    if loss_head_angle is not None:
        loss_head_angle = loss_head_angle[mask].mean()
    if loss_head_vel is not None:
        loss_head_vel = loss_head_vel[mask[:, 1:]]
        loss_head_vel = loss_head_vel.mean() if torch.numel(loss_head_vel) > 0 else None
    if loss_head_smooth is not None:
        loss_head_smooth = loss_head_smooth[mask[:, 2:]]
        loss_head_smooth = loss_head_smooth.mean() if torch.numel(loss_head_smooth) > 0 else None
    if loss_head_trans_vel is not None:
        vel_mask = mask[:, args.n_prev_motions:args.n_prev_motions + 2]
        accel_mask = mask[:, args.n_prev_motions:args.n_prev_motions + 3]
        loss_head_trans_vel = loss_head_trans_vel[vel_mask].mean()
        loss_head_trans_accel = loss_head_trans_accel[accel_mask].mean()
        loss_head_trans = loss_head_trans_vel + loss_head_trans_accel

    return loss_noise, loss_vert, loss_vel, loss_smooth, loss_head_angle, loss_head_vel, loss_head_smooth, \
           loss_head_trans


def _truncate_audio(audio, end_idx, pad_mode='zero'):
    batch_size = audio.shape[0]
    audio_trunc = audio.clone()
    if pad_mode == 'replicate':
        for i in range(batch_size):
            audio_trunc[i, end_idx[i]:] = audio_trunc[i, end_idx[i] - 1]
    elif pad_mode == 'zero':
        for i in range(batch_size):
            audio_trunc[i, end_idx[i]:] = 0
    else:
        raise ValueError(f'Unknown pad mode {pad_mode}!')

    return audio_trunc


def _truncate_coef_dict(coef_dict, end_idx, pad_mode='zero'):
    batch_size = coef_dict['exp'].shape[0]
    coef_dict_trunc = {k: v.clone() for k, v in coef_dict.items()}
    if pad_mode == 'replicate':
        for i in range(batch_size):
            for k in coef_dict_trunc:
                coef_dict_trunc[k][i, end_idx[i]:] = coef_dict_trunc[k][i, end_idx[i] - 1]
    elif pad_mode == 'zero':
        for i in range(batch_size):
            for k in coef_dict:
                coef_dict_trunc[k][i, end_idx[i]:] = 0
    else:
        raise ValueError(f'Unknown pad mode: {pad_mode}!')

    return coef_dict_trunc


def truncate_coef_dict_and_audio(audio, coef_dict, n_motions, audio_unit=640, pad_mode='zero'):
    """Randomly truncate a batch of audio and its coefficient dict at matching positions.

    A cut frame is drawn per sample from ``[1, n_motions)``; the audio is cut at
    ``cut * audio_unit``. Everything after the cut is padded according to ``pad_mode``.

    Args:
        audio: Batched audio tensor.
        coef_dict: Mapping of batched coefficient tensors (must contain ``'exp'``).
        n_motions: Exclusive upper bound of the random cut frame.
        audio_unit: Multiplier converting a frame index into an audio index.
        pad_mode: ``'zero'`` or ``'replicate'``.

    Returns:
        tuple: ``(truncated audio, truncated coef dict, end_idx)``.
    """
    n_rows = audio.shape[0]
    end_idx = torch.randint(1, n_motions, (n_rows,), device=audio.device)
    audio_cut = (end_idx * audio_unit).long()

    return (
        _truncate_audio(audio, audio_cut, pad_mode=pad_mode),
        _truncate_coef_dict(coef_dict, end_idx, pad_mode=pad_mode),
        end_idx,
    )


def truncate_motion_coef_and_audio(audio, motion_coef, n_motions, audio_unit=640, pad_mode='zero'):
    """Randomly truncate a batch of audio and its packed motion tensor at matching positions.

    A cut frame is drawn per sample from ``[1, n_motions)``; the audio is cut at
    ``cut * audio_unit``. Everything after the cut is padded according to ``pad_mode``.

    Args:
        audio: Batched audio tensor.
        motion_coef: Batched motion tensor; split at index 50 of the last axis for
            truncation and re-joined afterwards.
        n_motions: Exclusive upper bound of the random cut frame.
        audio_unit: Multiplier converting a frame index into an audio index.
        pad_mode: ``'zero'`` or ``'replicate'``.

    Returns:
        tuple: ``(truncated audio, truncated motion tensor, end_idx)``.
    """
    n_rows = audio.shape[0]
    end_idx = torch.randint(1, n_motions, (n_rows,), device=audio.device)
    audio_cut = (end_idx * audio_unit).long()

    audio_trunc = _truncate_audio(audio, audio_cut, pad_mode=pad_mode)

    # The dict helper reads the batch size from an 'exp' entry, so split first.
    parts = {'exp': motion_coef[..., :50], 'pose_any': motion_coef[..., 50:]}
    parts = _truncate_coef_dict(parts, end_idx, pad_mode=pad_mode)
    motion_trunc = torch.cat([parts['exp'], parts['pose_any']], dim=-1)

    return audio_trunc, motion_trunc, end_idx


def nt_xent_loss(feature_a, feature_b, temperature):
    """
    Normalized temperature-scaled cross entropy loss.

    Row ``i`` of ``feature_a`` and row ``i`` of ``feature_b`` form a positive pair; all
    other rows of the combined batch act as negatives.

    (Adapted from https://github.com/sthalles/SimCLR/blob/master/simclr.py)

    Args:
        feature_a (torch.Tensor): shape (batch_size, feature_dim)
        feature_b (torch.Tensor): shape (batch_size, feature_dim)
        temperature (float): temperature scaling factor

    Returns:
        torch.Tensor: scalar
    """
    n_pairs = feature_a.shape[0]
    n_views = 2 * n_pairs
    device = feature_a.device

    # Cosine similarity between every pair of rows of the combined batch.
    embeddings = F.normalize(torch.cat([feature_a, feature_b], dim=0), dim=1)
    similarity = torch.matmul(embeddings, embeddings.T)

    # same_pair[i, j] is True when rows i and j come from the same original index.
    pair_ids = torch.arange(n_pairs).repeat(2)
    same_pair = (pair_ids.unsqueeze(0) == pair_ids.unsqueeze(1)).to(device)

    # Drop self-similarities (the main diagonal) from both matrices.
    off_diagonal = ~torch.eye(n_views, dtype=torch.bool).to(device)
    same_pair = same_pair[off_diagonal].view(n_views, -1)
    similarity = similarity[off_diagonal].view(n_views, -1)

    positives = similarity[same_pair].view(n_views, -1)
    negatives = similarity[~same_pair].view(n_views, -1)

    # The positive is placed in column 0, so the target class is 0 for every row.
    logits = torch.cat([positives, negatives], dim=1) / temperature
    targets = torch.zeros(n_views, dtype=torch.long).to(device)

    return F.cross_entropy(logits, targets)