# This script is extended based on https://github.com/nkolot/SPIN/blob/master/models/smpl.py

import json
import os
import pickle
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch
import torch.nn as nn

from lib.pymafx.core import constants, path_config
from lib.smplx import SMPL as _SMPL
from lib.smplx import FLAMELayer, MANOLayer, SMPLXLayer
from lib.smplx.body_models import SMPLXOutput
from lib.smplx.lbs import (
    batch_rodrigues,
    blend_shapes,
    transform_mat,
    vertices2joints,
)

# Module-level aliases for the paths configured in ``path_config``.
SMPL_MEAN_PARAMS, SMPL_MODEL_DIR = path_config.SMPL_MEAN_PARAMS, path_config.SMPL_MODEL_DIR


@dataclass
class ModelOutput(SMPLXOutput):
    """Output container shared by the SMPL / SMPL-X / MANO / FLAME wrappers in this module.

    Extends ``SMPLXOutput`` with extra optional fields. Each wrapper fills only
    the fields it computes; every other field stays ``None``.
    """
    # First 24 joints of the model's joint output (SMPL and SMPLX_ALL).
    smpl_joints: Optional[torch.Tensor] = None
    # Subset of the last 24 re-mapped joints, selected with constants.J24_TO_J19.
    joints_J19: Optional[torch.Tensor] = None
    # Raw SMPL-X mesh vertices (SMPLX_ALL; `vertices` then holds SMPL-topology vertices).
    smplx_vertices: Optional[torch.Tensor] = None
    # FLAME head-mesh vertices (FLAME wrapper).
    flame_vertices: Optional[torch.Tensor] = None
    # Left-hand vertices picked from the SMPL-topology mesh (SMPLX_ALL).
    lhand_vertices: Optional[torch.Tensor] = None
    # Right-hand vertices (SMPLX_ALL: picked from the SMPL-topology mesh; MANO: full hand mesh).
    rhand_vertices: Optional[torch.Tensor] = None
    # Left-hand joints selected from the SMPL-X joints via constants.HAND_NAMES (SMPLX_ALL).
    lhand_joints: Optional[torch.Tensor] = None
    # Right-hand joints (SMPLX_ALL via constants.HAND_NAMES; MANO: joints + fingertip vertices).
    rhand_joints: Optional[torch.Tensor] = None
    # Facial landmarks (SMPLX_ALL: trailing 68/51 SMPL-X joints; FLAME: joints[:, 5:]).
    face_joints: Optional[torch.Tensor] = None
    # Left-foot joints selected via constants.FOOT_NAMES (SMPLX_ALL).
    lfoot_joints: Optional[torch.Tensor] = None
    # Right-foot joints selected via constants.FOOT_NAMES (SMPLX_ALL).
    rfoot_joints: Optional[torch.Tensor] = None


class SMPL(_SMPL):
    """ Extension of the official SMPL implementation to support more joints

    On top of the parent SMPL model this class:
      * loads an extra joint regressor (``J_regressor_extra``) that is applied
        to the posed vertices in :meth:`forward`,
      * re-orders the concatenated joints with ``constants.JOINT_MAP`` /
        ``constants.JOINT_NAMES`` (``self.joint_map``),
      * caches the joints of the template mesh (``tpose_joints``) for
        :meth:`get_global_rotation`.
    """
    def __init__(
        self,
        create_betas=False,
        create_global_orient=False,
        create_body_pose=False,
        create_transl=False,
        *args,
        **kwargs
    ):
        # NOTE(review): the create_* flags are the first positional parameters,
        # so any positional argument binds to them before reaching the parent.
        # Pass model_path (and everything else) by keyword.
        super().__init__(
            create_betas=create_betas,
            create_global_orient=create_global_orient,
            create_body_pose=create_body_pose,
            create_transl=create_transl,
            *args,
            **kwargs
        )
        joints = [constants.JOINT_MAP[i] for i in constants.JOINT_NAMES]
        J_regressor_extra = np.load(path_config.JOINT_REGRESSOR_TRAIN_EXTRA)
        self.register_buffer(
            'J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32)
        )
        # Plain tensor (not a buffer): used only as an index into the joint axis.
        self.joint_map = torch.tensor(joints, dtype=torch.long)

        # Joints of the template mesh (no shape blend applied), cached once for
        # get_global_rotation.
        tpose_joints = vertices2joints(self.J_regressor, self.v_template.unsqueeze(0))
        self.register_buffer('tpose_joints', tpose_joints)

    def forward(self, *args, **kwargs):
        """Run the parent SMPL forward pass and add the extra / re-mapped joints.

        All arguments are forwarded to the parent; ``get_skin`` is forced to True.

        Returns:
            ModelOutput with ``vertices``, re-mapped ``joints``, ``joints_J19``,
            ``smpl_joints`` (first 24 parent joints) and the parent's
            ``global_orient``, ``body_pose``, ``betas`` and ``full_pose``.
        """
        kwargs['get_skin'] = True
        smpl_output = super().forward(*args, **kwargs)
        extra_joints = vertices2joints(self.J_regressor_extra, smpl_output.vertices)
        # smpl_output.joints: [B, 45, 3]  extra_joints: [B, 9, 3]
        vertices = smpl_output.vertices
        joints = torch.cat([smpl_output.joints, extra_joints], dim=1)
        smpl_joints = smpl_output.joints[:, :24]
        joints = joints[:, self.joint_map, :]    # [B, 49, 3]
        joints_J24 = joints[:, -24:, :]
        joints_J19 = joints_J24[:, constants.J24_TO_J19, :]
        output = ModelOutput(
            vertices=vertices,
            global_orient=smpl_output.global_orient,
            body_pose=smpl_output.body_pose,
            joints=joints,
            joints_J19=joints_J19,
            smpl_joints=smpl_joints,
            betas=smpl_output.betas,
            full_pose=smpl_output.full_pose
        )
        return output

    def get_global_rotation(
        self,
        global_orient: Optional[torch.Tensor] = None,
        body_pose: Optional[torch.Tensor] = None,
        **kwargs
    ):
        '''
        Compute per-joint global rotations and posed joint positions for SMPL.

        Local joint rotations are composed along the kinematic tree
        (``self.parents``), starting from the cached template joints
        ``self.tpose_joints``. Shape parameters (betas) are not applied.

            Parameters
            ----------
            global_orient: torch.tensor, optional, shape Bx3x3
                Root rotation in rotation-matrix format. Identity when None.
            body_pose: torch.tensor, optional, shape BxJx3x3
                Body joint rotations in rotation-matrix format
                (J = self.NUM_BODY_JOINTS). Identity when None.
            **kwargs:
                Accepted and ignored.
            Returns
            -------
                global_rotmat: torch.tensor, shape BxKx3x3
                    Accumulated (global) rotation of every joint of the
                    kinematic tree.
                posed_joints: torch.tensor, shape BxKx3
                    Joint positions after posing the template skeleton.
        '''
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        # The batch size is the largest leading dimension among supplied inputs.
        model_vars = [global_orient, body_pose]
        batch_size = 1
        for var in model_vars:
            if var is None:
                continue
            batch_size = max(batch_size, len(var))

        if global_orient is None:
            global_orient = torch.eye(3, device=device,
                                      dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1,
                                                                           -1).contiguous()
        if body_pose is None:
            body_pose = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(
                batch_size, self.NUM_BODY_JOINTS, -1, -1
            ).contiguous()

        # Concatenate all pose vectors
        full_pose = torch.cat([
            global_orient.reshape(-1, 1, 3, 3),
            body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3)
        ],
                              dim=1)

        rot_mats = full_pose.view(batch_size, -1, 3, 3)

        # Get the joints
        # NxJx3 array
        # joints = vertices2joints(self.J_regressor, self.v_template.unsqueeze(0).expand(batch_size, -1, -1))
        # joints = torch.unsqueeze(joints, dim=-1)

        # Cached template joints, broadcast over the batch, as column vectors.
        joints = self.tpose_joints.expand(batch_size, -1, -1).unsqueeze(-1)

        # Offsets of each joint relative to its parent (the root keeps its position).
        rel_joints = joints.clone()
        rel_joints[:, 1:] -= joints[:, self.parents[1:]]

        transforms_mat = transform_mat(rot_mats.reshape(-1, 3, 3),
                                       rel_joints.reshape(-1, 3,
                                                          1)).reshape(-1, joints.shape[1], 4, 4)

        # Compose local 4x4 transforms from the root down the tree.
        transform_chain = [transforms_mat[:, 0]]
        for i in range(1, self.parents.shape[0]):
            # Subtract the joint location at the rest pose
            # No need for rotation, since it's identity when at rest
            curr_res = torch.matmul(transform_chain[self.parents[i]], transforms_mat[:, i])
            transform_chain.append(curr_res)

        transforms = torch.stack(transform_chain, dim=1)

        global_rotmat = transforms[:, :, :3, :3]

        # The last column of the transformations contains the posed joints
        posed_joints = transforms[:, :, :3, 3]

        return global_rotmat, posed_joints


class SMPLX(SMPLXLayer):
    """ Extension of the official SMPLX implementation to support more functions """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def get_global_rotation(
        self,
        global_orient: Optional[torch.Tensor] = None,
        body_pose: Optional[torch.Tensor] = None,
        left_hand_pose: Optional[torch.Tensor] = None,
        right_hand_pose: Optional[torch.Tensor] = None,
        jaw_pose: Optional[torch.Tensor] = None,
        leye_pose: Optional[torch.Tensor] = None,
        reye_pose: Optional[torch.Tensor] = None,
        **kwargs
    ):
        '''
        Compute per-joint global rotations and posed joint positions for SMPL-X.

        Local joint rotations are composed along the kinematic tree
        (``self.parents``), starting from the joints regressed from the
        template mesh ``self.v_template``. Shape and expression are not applied.

            Parameters
            ----------
            global_orient: torch.tensor, optional, shape Bx3x3
                Root rotation in rotation-matrix format. Identity when None.
            body_pose: torch.tensor, optional, shape BxJx3x3
                Body joint rotations (J = self.NUM_BODY_JOINTS). Identity when None.
            left_hand_pose: torch.tensor, optional, shape BxHx3x3
                Left-hand joint rotations (H = self.NUM_HAND_JOINTS). Identity when None.
            right_hand_pose: torch.tensor, optional, shape BxHx3x3
                Right-hand joint rotations. Identity when None.
            jaw_pose: torch.tensor, optional, shape Bx3x3
                Jaw rotation. Identity when None.
            leye_pose: torch.tensor, optional, shape Bx3x3
                Left-eye rotation. Identity when None.
            reye_pose: torch.tensor, optional, shape Bx3x3
                Right-eye rotation. Identity when None.
            **kwargs:
                Accepted and ignored.
            Returns
            -------
                global_rotmat: torch.tensor, shape BxKx3x3
                    Accumulated (global) rotation of every joint of the
                    kinematic tree.
                posed_joints: torch.tensor, shape BxKx3
                    Joint positions after posing the template skeleton.
        '''
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        # Infer the batch size from every supplied pose (eye poses included), so
        # the identity defaults below are expanded to a matching batch.
        model_vars = [
            global_orient, body_pose, left_hand_pose, right_hand_pose, jaw_pose, leye_pose,
            reye_pose
        ]
        batch_size = max([1] + [len(var) for var in model_vars if var is not None])

        def identity_rotmats(num_joints):
            # Rest-pose rotations of shape (batch_size, num_joints, 3, 3).
            return torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(
                batch_size, num_joints, -1, -1
            ).contiguous()

        if global_orient is None:
            global_orient = identity_rotmats(1)
        if body_pose is None:
            body_pose = identity_rotmats(self.NUM_BODY_JOINTS)
        if left_hand_pose is None:
            left_hand_pose = identity_rotmats(self.NUM_HAND_JOINTS)
        if right_hand_pose is None:
            right_hand_pose = identity_rotmats(self.NUM_HAND_JOINTS)
        if jaw_pose is None:
            jaw_pose = identity_rotmats(1)
        if leye_pose is None:
            leye_pose = identity_rotmats(1)
        if reye_pose is None:
            reye_pose = identity_rotmats(1)

        # Concatenate all pose vectors in SMPL-X joint order.
        full_pose = torch.cat([
            global_orient.reshape(-1, 1, 3, 3),
            body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
            jaw_pose.reshape(-1, 1, 3, 3),
            leye_pose.reshape(-1, 1, 3, 3),
            reye_pose.reshape(-1, 1, 3, 3),
            left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
            right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3)
        ],
                              dim=1)

        rot_mats = full_pose.view(batch_size, -1, 3, 3)

        # Template joints (B x K x 3) as column vectors (B x K x 3 x 1).
        joints = vertices2joints(
            self.J_regressor,
            self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
        )
        joints = torch.unsqueeze(joints, dim=-1)

        # Offsets of each joint relative to its parent (the root keeps its position).
        rel_joints = joints.clone()
        rel_joints[:, 1:] -= joints[:, self.parents[1:]]

        transforms_mat = transform_mat(rot_mats.reshape(-1, 3, 3),
                                       rel_joints.reshape(-1, 3,
                                                          1)).reshape(-1, joints.shape[1], 4, 4)

        # Compose local 4x4 transforms from the root down the tree.
        transform_chain = [transforms_mat[:, 0]]
        for i in range(1, self.parents.shape[0]):
            # Subtract the joint location at the rest pose
            # No need for rotation, since it's identity when at rest
            curr_res = torch.matmul(transform_chain[self.parents[i]], transforms_mat[:, i])
            transform_chain.append(curr_res)

        transforms = torch.stack(transform_chain, dim=1)

        global_rotmat = transforms[:, :, :3, :3]

        # The last column of the transformations contains the posed joints
        posed_joints = transforms[:, :, :3, 3]

        return global_rotmat, posed_joints


class SMPLX_ALL(nn.Module):
    """ Extension of the official SMPLX implementation to support more joints

    Holds one SMPL-X layer per requested gender and converts their outputs to
    SMPL topology (vertices via the ``smplx_to_smpl`` transfer matrix) together
    with several joint subsets (J19, hands, feet, face).
    """
    def __init__(self, batch_size=1, use_face_contour=True, all_gender=False, **kwargs):
        super().__init__()
        numBetas = 10
        self.use_face_contour = use_face_contour
        if all_gender:
            self.genders = ['male', 'female', 'neutral']
        else:
            self.genders = ['neutral']
        for gender in self.genders:
            assert gender in ['male', 'female', 'neutral']
        self.model_dict = nn.ModuleDict({
            gender: SMPLX(
                path_config.SMPL_MODEL_DIR,
                gender=gender,
                ext='npz',
                num_betas=numBetas,
                use_pca=False,
                batch_size=batch_size,
                use_face_contour=use_face_contour,
                num_pca_comps=45,
                **kwargs
            )
            for gender in self.genders
        })
        self.model_neutral = self.model_dict['neutral']
        joints = [constants.JOINT_MAP[i] for i in constants.JOINT_NAMES]
        J_regressor_extra = np.load(path_config.JOINT_REGRESSOR_TRAIN_EXTRA)
        self.register_buffer(
            'J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32)
        )
        self.joint_map = torch.tensor(joints, dtype=torch.long)
        # smplx_to_smpl.pkl, file source: https://smpl-x.is.tue.mpg.de
        # (trusted model asset; pickle must never be used on untrusted files).
        with open(os.path.join(SMPL_MODEL_DIR, 'model_transfer/smplx_to_smpl.pkl'), 'rb') as f:
            smplx_to_smpl = pickle.load(f)
        self.register_buffer(
            'smplx2smpl', torch.tensor(smplx_to_smpl['matrix'][None], dtype=torch.float32)
        )

        smpl2limb_vert_faces = get_partial_smpl('smpl')
        self.smpl2lhand = torch.from_numpy(smpl2limb_vert_faces['lhand']['vids']).long()
        self.smpl2rhand = torch.from_numpy(smpl2limb_vert_faces['rhand']['vids']).long()

        # left and right hand joint mapping
        smplx2lhand_joints = [
            constants.SMPLX_JOINT_IDS['left_{}'.format(name)] for name in constants.HAND_NAMES
        ]
        smplx2rhand_joints = [
            constants.SMPLX_JOINT_IDS['right_{}'.format(name)] for name in constants.HAND_NAMES
        ]
        self.smplx2lh_joint_map = torch.tensor(smplx2lhand_joints, dtype=torch.long)
        self.smplx2rh_joint_map = torch.tensor(smplx2rhand_joints, dtype=torch.long)

        # left and right foot joint mapping
        smplx2lfoot_joints = [
            constants.SMPLX_JOINT_IDS['left_{}'.format(name)] for name in constants.FOOT_NAMES
        ]
        smplx2rfoot_joints = [
            constants.SMPLX_JOINT_IDS['right_{}'.format(name)] for name in constants.FOOT_NAMES
        ]
        self.smplx2lf_joint_map = torch.tensor(smplx2lfoot_joints, dtype=torch.long)
        self.smplx2rf_joint_map = torch.tensor(smplx2rfoot_joints, dtype=torch.long)

        # Per-gender template joints and their shape blend directions (first 24
        # regressor rows), used by get_tpose.
        for g in self.genders:
            J_template = torch.einsum(
                'ji,ik->jk', [self.model_dict[g].J_regressor[:24], self.model_dict[g].v_template]
            )
            J_dirs = torch.einsum(
                'ji,ikl->jkl', [self.model_dict[g].J_regressor[:24], self.model_dict[g].shapedirs]
            )

            self.register_buffer(f'{g}_J_template', J_template)
            self.register_buffer(f'{g}_J_dirs', J_dirs)

    @staticmethod
    def _restore_order(gender_idx_list, num_samples, device):
        """Return indices that undo the per-gender grouping of batch samples.

        Args:
            gender_idx_list: original batch index of each row, in the order the
                per-gender outputs were concatenated.
            num_samples: number of samples in the batch (entries of ``gender``).
            device: device of the returned index tensor.

        Returns:
            Long tensor ``idx`` such that ``concatenated[idx]`` is in the
            original batch order.

        Raises:
            ValueError: if some sample was not assigned to any gender, i.e. its
                gender code is not 0 (male), 1 (female) or 2 (neutral).
        """
        if sorted(gender_idx_list) != list(range(num_samples)):
            raise ValueError(
                'every gender entry must be 0 (male), 1 (female) or 2 (neutral)'
            )
        order = torch.tensor(gender_idx_list, dtype=torch.long, device=device)
        # order is a permutation of 0..n-1, so argsort yields its inverse.
        return torch.argsort(order)

    def forward(self, *args, **kwargs):
        """Run the per-gender SMPL-X layers and derive SMPL-topology outputs.

        Keyword arguments (others are forwarded to the SMPL-X layers):
            body_pose: required; its first dimension defines the batch size.
                If it has 23 joints, only the first 21 are used.
            gender: optional tensor of per-sample codes 0 (male), 1 (female),
                2 (neutral); defaults to all neutral.
            pose2rot: optional, default True; when True the pose entries are
                converted with ``batch_rodrigues`` before use.
            betas, global_orient, left_hand_pose, right_hand_pose, jaw_pose,
            leye_pose, reye_pose: optional per-sample parameters.

        Returns:
            ModelOutput with SMPL-topology ``vertices``, ``smplx_vertices``,
            hand vertices, re-mapped ``joints``, ``joints_J19``, ``smpl_joints``
            and hand / foot / face joint subsets.

        Raises:
            ValueError: if a gender code is not 0, 1 or 2.
        """
        batch_size = kwargs['body_pose'].shape[0]
        kwargs['get_skin'] = True
        if 'pose2rot' not in kwargs:
            kwargs['pose2rot'] = True
        if 'gender' not in kwargs:
            kwargs['gender'] = 2 * torch.ones(batch_size).to(kwargs['body_pose'].device)

        # pose for 55 joints: 1, 21, 15, 15, 1, 1, 1
        pose_keys = [
            'global_orient', 'body_pose', 'left_hand_pose', 'right_hand_pose', 'jaw_pose',
            'leye_pose', 'reye_pose'
        ]
        param_keys = ['betas'] + pose_keys
        if kwargs['pose2rot']:
            for key in pose_keys:
                if key in kwargs:
                    kwargs[key] = batch_rodrigues(kwargs[key].contiguous().view(-1, 3)).view([
                        batch_size, -1, 3, 3
                    ])
        if kwargs['body_pose'].shape[1] == 23:
            # remove hand pose in the body_pose
            kwargs['body_pose'] = kwargs['body_pose'][:, :21]

        # Run each gender's model on its subset of the batch.
        gender_idx_list = []
        smplx_vertices, smplx_joints = [], []
        for gi, g in enumerate(['male', 'female', 'neutral']):
            gender_idx = ((kwargs['gender'] == gi).nonzero(as_tuple=True)[0])
            if len(gender_idx) == 0:
                continue
            gender_idx_list.extend([int(idx) for idx in gender_idx])
            gender_kwargs = {'get_skin': kwargs['get_skin'], 'pose2rot': kwargs['pose2rot']}
            gender_kwargs.update({k: kwargs[k][gender_idx] for k in param_keys if k in kwargs})
            gender_smplx_output = self.model_dict[g].forward(*args, **gender_kwargs)
            smplx_vertices.append(gender_smplx_output.vertices)
            smplx_joints.append(gender_smplx_output.joints)

        # Put the per-gender results back into the original batch order.
        idx_rearrange = self._restore_order(
            gender_idx_list, kwargs['gender'].numel(), kwargs['body_pose'].device
        )

        smplx_vertices = torch.cat(smplx_vertices)[idx_rearrange]
        smplx_joints = torch.cat(smplx_joints)[idx_rearrange]

        # constants.HAND_NAMES
        lhand_joints = smplx_joints[:, self.smplx2lh_joint_map]
        rhand_joints = smplx_joints[:, self.smplx2rh_joint_map]
        # constants.FACIAL_LANDMARKS
        face_joints = smplx_joints[:, -68:] if self.use_face_contour else smplx_joints[:, -51:]
        # constants.FOOT_NAMES
        lfoot_joints = smplx_joints[:, self.smplx2lf_joint_map]
        rfoot_joints = smplx_joints[:, self.smplx2rf_joint_map]

        # Convert SMPL-X vertices to SMPL topology with the transfer matrix.
        smpl_vertices = torch.bmm(self.smplx2smpl.expand(batch_size, -1, -1), smplx_vertices)
        lhand_vertices = smpl_vertices[:, self.smpl2lhand]
        rhand_vertices = smpl_vertices[:, self.smpl2rhand]
        extra_joints = vertices2joints(self.J_regressor_extra, smpl_vertices)
        # smpl_output.joints: [B, 45, 3]  extra_joints: [B, 9, 3]
        smplx_j45 = smplx_joints[:, constants.SMPLX2SMPL_J45]
        joints = torch.cat([smplx_j45, extra_joints], dim=1)
        smpl_joints = smplx_j45[:, :24]
        joints = joints[:, self.joint_map, :]    # [B, 49, 3]
        joints_J24 = joints[:, -24:, :]
        joints_J19 = joints_J24[:, constants.J24_TO_J19, :]
        output = ModelOutput(
            vertices=smpl_vertices,
            smplx_vertices=smplx_vertices,
            lhand_vertices=lhand_vertices,
            rhand_vertices=rhand_vertices,
            joints=joints,
            joints_J19=joints_J19,
            smpl_joints=smpl_joints,
            lhand_joints=lhand_joints,
            rhand_joints=rhand_joints,
            lfoot_joints=lfoot_joints,
            rfoot_joints=rfoot_joints,
            face_joints=face_joints,
        )
        return output

    def get_tpose(self, betas=None, gender=None):
        """Return the first 24 rest-pose joints for the given shapes.

        Args:
            betas: optional (B, 10) shape coefficients; a single zero vector
                when None.
            gender: optional tensor of per-sample codes 0 (male), 1 (female),
                2 (neutral); defaults to all neutral.

        Returns:
            Tensor of shape (B, 24, 3): per-gender template joints plus the
            shape-blended offsets, in the original batch order.

        Raises:
            ValueError: if a gender code is not 0, 1 or 2.
        """
        if betas is None:
            betas = torch.zeros(1, 10).to(self.J_regressor_extra.device)

        batch_size = betas.shape[0]
        device = betas.device

        if gender is None:
            gender = 2 * torch.ones(batch_size).to(device)

        gender_idx_list = []
        smplx_joints = []
        for gi, g in enumerate(['male', 'female', 'neutral']):
            gender_idx = ((gender == gi).nonzero(as_tuple=True)[0])
            if len(gender_idx) == 0:
                continue
            gender_idx_list.extend([int(idx) for idx in gender_idx])

            J = getattr(self, f'{g}_J_template').unsqueeze(0) + blend_shapes(
                betas[gender_idx], getattr(self, f'{g}_J_dirs')
            )

            smplx_joints.append(J)

        idx_rearrange = self._restore_order(gender_idx_list, gender.numel(), device)

        return torch.cat(smplx_joints)[idx_rearrange]


class MANO(MANOLayer):
    """MANO layer whose forward returns right-hand vertices and 21 joints.

    The model's joints are extended with five mesh vertices used as fingertips
    and then re-ordered (see manopth, linked in ``forward``).
    """
    # Vertex ids sampled as fingertip joints.
    _FINGERTIP_VERTS = [745, 317, 445, 556, 673]
    # Re-ordering of (model joints + fingertips) applied to the output joints.
    _JOINT_ORDER = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """Pose the hand and return a ModelOutput with rhand_vertices / rhand_joints.

        ``global_orient`` and ``right_hand_pose`` are required keyword
        arguments; with ``pose2rot`` (default True) they are converted with
        ``batch_rodrigues``. ``right_hand_pose`` is passed on as ``hand_pose``.
        """
        kwargs.setdefault('pose2rot', True)
        num_samples = kwargs['global_orient'].shape[0]
        if kwargs['pose2rot']:
            for name in ('global_orient', 'right_hand_pose'):
                if name not in kwargs:
                    continue
                flat_axis_angle = kwargs[name].contiguous().view(-1, 3)
                kwargs[name] = batch_rodrigues(flat_axis_angle).view([num_samples, -1, 3, 3])
        kwargs['hand_pose'] = kwargs.pop('right_hand_pose')

        mano_output = super().forward(*args, **kwargs)
        hand_verts = mano_output.vertices
        # https://github.com/hassony2/manopth/blob/master/manopth/manolayer.py#L248-L260
        # Fingertips are sampled from the mesh, appended, then joints re-ordered.
        fingertips = hand_verts[:, self._FINGERTIP_VERTS]
        hand_joints = torch.cat([mano_output.joints, fingertips], 1)[:, self._JOINT_ORDER]
        return ModelOutput(rhand_vertices=hand_verts, rhand_joints=hand_joints)


class FLAME(FLAMELayer):
    """FLAME layer whose forward returns face vertices and facial joints."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """Pose the head and return a ModelOutput with flame_vertices / face_joints.

        ``global_orient`` is a required keyword argument. With ``pose2rot``
        (default True) every supplied pose among global_orient, jaw_pose,
        leye_pose and reye_pose is converted with ``batch_rodrigues``.
        ``face_joints`` are the layer's joints from index 5 onwards.
        """
        kwargs.setdefault('pose2rot', True)
        num_samples = kwargs['global_orient'].shape[0]
        if kwargs['pose2rot']:
            supplied = [
                name for name in ('global_orient', 'jaw_pose', 'leye_pose', 'reye_pose')
                if name in kwargs
            ]
            for name in supplied:
                flat_axis_angle = kwargs[name].contiguous().view(-1, 3)
                kwargs[name] = batch_rodrigues(flat_axis_angle).view([num_samples, -1, 3, 3])

        result = super().forward(*args, **kwargs)
        return ModelOutput(
            flame_vertices=result.vertices,
            face_joints=result.joints[:, 5:],
        )


class SMPL_Family():
    """Factory wrapper that builds one body model and forwards calls to it.

    Args:
        model_type: one of 'smpl', 'smplx', 'mano' (right hand) or 'flame'.
        *args, **kwargs: forwarded to the selected model's constructor.

    Raises:
        ValueError: if ``model_type`` is not one of the supported names.
    """
    def __init__(self, model_type='smpl', *args, **kwargs):
        if model_type == 'smpl':
            self.model = SMPL(model_path=SMPL_MODEL_DIR, *args, **kwargs)
        elif model_type == 'smplx':
            self.model = SMPLX_ALL(*args, **kwargs)
        elif model_type == 'mano':
            self.model = MANO(
                model_path=SMPL_MODEL_DIR, is_rhand=True, use_pca=False, *args, **kwargs
            )
        elif model_type == 'flame':
            self.model = FLAME(model_path=SMPL_MODEL_DIR, use_face_contour=True, *args, **kwargs)
        else:
            # Fail fast instead of leaving self.model unset (which would only
            # surface later as an AttributeError on the first call).
            raise ValueError(
                f"Unsupported model_type '{model_type}'; "
                "expected one of 'smpl', 'smplx', 'mano', 'flame'"
            )

    def __call__(self, *args, **kwargs):
        """Run the wrapped model's forward pass."""
        return self.model(*args, **kwargs)

    def get_tpose(self, *args, **kwargs):
        """Delegate to the wrapped model's ``get_tpose``."""
        return self.model.get_tpose(*args, **kwargs)


def get_smpl_faces():
    """Return the face array of the SMPL mesh (from a batch-size-1 model)."""
    return SMPL(model_path=SMPL_MODEL_DIR, batch_size=1).faces


def get_smplx_faces():
    """Return the face array of the SMPL-X mesh (from a batch-size-1 model)."""
    return SMPLX(SMPL_MODEL_DIR, batch_size=1).faces


def get_mano_faces(hand_type='right'):
    """Return the MANO face array for ``hand_type`` ('right' or 'left')."""
    assert hand_type in ['right', 'left']
    model = MANO(SMPL_MODEL_DIR, batch_size=1, is_rhand=(hand_type == 'right'))
    return model.faces


def get_flame_faces():
    """Return the face array of the FLAME head mesh (from a batch-size-1 model)."""
    return FLAME(SMPL_MODEL_DIR, batch_size=1).faces


def get_model_faces(type='smpl'):
    """Return the mesh faces of the requested body model.

    Args:
        type: 'smpl', 'smplx', 'mano' (right hand) or 'flame'.

    Raises:
        ValueError: for any other model name.
    """
    if type == 'smpl':
        return get_smpl_faces()
    elif type == 'smplx':
        return get_smplx_faces()
    elif type == 'mano':
        return get_mano_faces()
    elif type == 'flame':
        return get_flame_faces()
    # Previously fell through and returned None, which crashed callers later.
    raise ValueError(
        f"Unsupported model type '{type}'; expected one of 'smpl', 'smplx', 'mano', 'flame'"
    )


def get_model_tpose(type='smpl'):
    """Return the rest-pose vertices of the requested body model.

    Args:
        type: 'smpl', 'smplx', 'mano' (right hand) or 'flame'.

    Raises:
        ValueError: for any other model name.
    """
    if type == 'smpl':
        return get_smpl_tpose()
    elif type == 'smplx':
        return get_smplx_tpose()
    elif type == 'mano':
        return get_mano_tpose()
    elif type == 'flame':
        return get_flame_tpose()
    # Previously fell through and returned None, which crashed callers later
    # (e.g. len(None) in get_partial_smpl).
    raise ValueError(
        f"Unsupported model type '{type}'; expected one of 'smpl', 'smplx', 'mano', 'flame'"
    )


def get_smpl_tpose():
    """Return the vertices of a default-parameter SMPL model, detached.

    The model is built with its own betas / global_orient / body_pose
    parameters (create_* flags) and called without arguments.
    """
    model = SMPL(
        model_path=SMPL_MODEL_DIR,
        batch_size=1,
        create_betas=True,
        create_global_orient=True,
        create_body_pose=True,
    )
    output = model()
    first_sample = output.vertices[0]
    return first_sample.detach()


def get_smpl_tpose_joint():
    """Return the 24 SMPL joints of a default-parameter SMPL model, detached."""
    model = SMPL(
        model_path=SMPL_MODEL_DIR,
        batch_size=1,
        create_betas=True,
        create_global_orient=True,
        create_body_pose=True,
    )
    output = model()
    first_sample = output.smpl_joints[0]
    return first_sample.detach()


def get_smplx_tpose():
    """Return the vertices of a default-parameter SMPL-X layer (first sample)."""
    layer = SMPLXLayer(SMPL_MODEL_DIR, batch_size=1)
    output = layer()
    return output.vertices[0]


def get_smplx_tpose_joint():
    """Return the joints of a default-parameter SMPL-X layer (first sample)."""
    layer = SMPLXLayer(SMPL_MODEL_DIR, batch_size=1)
    output = layer()
    return output.joints[0]


def get_mano_tpose():
    """Return right-hand MANO vertices for zero global orientation and hand pose."""
    model = MANO(SMPL_MODEL_DIR, batch_size=1, is_rhand=True)
    zero_orient = torch.zeros(1, 3)
    zero_hand_pose = torch.zeros(1, 15 * 3)
    output = model(global_orient=zero_orient, right_hand_pose=zero_hand_pose)
    return output.rhand_vertices[0]


def get_flame_tpose():
    """Return FLAME vertices for zero global orientation (other inputs default)."""
    model = FLAME(SMPL_MODEL_DIR, batch_size=1)
    output = model(global_orient=torch.zeros(1, 3))
    return output.flame_vertices[0]


def get_part_joints(smpl_joints):
    """Derive 23 part joints from SMPL joints.

    The first 18 outputs are midpoints of the joint pairs listed below (in
    order); the last 5 are copies of single joints.

    Args:
        smpl_joints: tensor indexed along dim 1 by joint id (ids up to 23 are
            used); presumably (B, 24, 3) — TODO confirm with callers.

    Returns:
        Tensor with dim 1 of size 23 and all other dims as in ``smpl_joints``.
    """
    # Joint pairs whose midpoint is taken; the order defines the output order.
    segment_pairs = [
        (0, 1), (0, 2), (0, 3), (3, 6), (9, 12), (9, 13), (9, 14), (12, 15), (13, 16), (14, 17),
        (1, 4), (2, 5), (4, 7), (5, 8), (16, 18), (17, 19), (18, 20), (19, 21),
    ]
    # Joints copied through unchanged.
    single_joints = [10, 11, 15, 22, 23]

    device = smpl_joints.device
    pair_idx = torch.tensor(segment_pairs, dtype=torch.long, device=device)
    single_idx = torch.tensor(single_joints, dtype=torch.long, device=device)

    # smpl_joints[:, pair_idx] has shape (B, 18, 2, ...); average over the pair axis.
    pair_midpoints = smpl_joints[:, pair_idx].mean(dim=2)
    singles = smpl_joints[:, single_idx]

    return torch.cat([pair_midpoints, singles], dim=1)


# Segmentation labels (keys of '<body_model>_vert_segmentation.json') whose
# vertices seed each segmentation-based partial mesh built by get_partial_smpl.
_SEGMENTATION_PARTS = {
    'face': ['head'],
    'arm': [
        'rightHand',
        'leftArm',
        'leftShoulder',
        'rightShoulder',
        'rightArm',
        'leftHandIndex1',
        'rightHandIndex1',
        'leftForeArm',
        'rightForeArm',
        'leftHand',
    ],
    'forearm': [
        'rightHand',
        'leftHandIndex1',
        'rightHandIndex1',
        'leftForeArm',
        'rightForeArm',
        'leftHand',
    ],
    # Not among the parts get_partial_smpl iterates over; kept for reference.
    'arm_eval': ['leftArm', 'rightArm', 'leftForeArm', 'rightForeArm'],
    'larm': ['leftForeArm'],
    'rarm': ['rightForeArm'],
}

# Threshold on the *squared* distance between a template vertex and a wrist
# joint (sqrt(0.005) ~= 0.07 model units, presumably metres -- confirm).
_WRIST_SQ_DIST_THRESH = 0.005


def _extract_submesh(faces, num_verts, seed_vids):
    """Cut out every face touching at least one vertex id in ``seed_vids``.

    Returns ``(vids, sub_faces)``: the sorted unique vertex ids referenced by
    the selected faces (int64), and those faces re-indexed into ``vids``
    (int64) so that ``verts[vids]`` with ``sub_faces`` is a standalone mesh.
    """
    seed = {int(v) for v in seed_vids}  # set for O(1) membership per vertex
    face_ids = [
        f_id for f_id, face in enumerate(faces) if any(int(v) in seed for v in face)
    ]
    selected_faces = faces[face_ids]
    vids = np.unique(selected_faces).astype(np.int64)

    # Old-id -> new-id lookup; only the entries listed in ``vids`` are ever read.
    remap = np.arange(num_verts)
    remap[vids] = np.arange(len(vids))
    return vids, remap[selected_faces].astype(np.int64)


def _hand_vids_faces(part):
    """Map each MANO vertex of ``part`` ('lhand'/'rhand') to its nearest SMPL vertex.

    MANO vertices are taken from the SMPL-X template via
    'MANO_SMPLX_vertex_ids.pkl'; the SMPL template is obtained by applying the
    'smplx_to_smpl.pkl' matrix to the SMPL-X template. Returns the SMPL vertex
    ids (int64, one per MANO vertex) and the MANO faces (int64).
    """
    with open(
        os.path.join(SMPL_MODEL_DIR, 'model_transfer/MANO_SMPLX_vertex_ids.pkl'), 'rb'
    ) as pkl_file:
        smplx_mano_id = pickle.load(pkl_file)
    with open(os.path.join(SMPL_MODEL_DIR, 'model_transfer/smplx_to_smpl.pkl'), 'rb') as pkl_file:
        smplx_smpl_id = pickle.load(pkl_file)

    smplx_tpose = get_smplx_tpose().detach().cpu()
    smpl_tpose = torch.as_tensor(np.matmul(smplx_smpl_id['matrix'], smplx_tpose.numpy()))

    hand_key = 'left_hand' if part == 'lhand' else 'right_hand'
    mano_vert = smplx_tpose[smplx_mano_id[hand_key]]

    # Nearest neighbour by squared Euclidean distance, one MANO vertex at a
    # time to avoid materialising a full (num_mano x num_smpl) distance matrix.
    smpl2mano_vids = np.array(
        [int(torch.argmin(torch.sum((smpl_tpose - vert) ** 2, dim=1))) for vert in mano_vert],
        dtype=np.int64,
    )
    mano_faces = get_mano_faces(hand_type='right' if part == 'rhand' else 'left').astype(np.int64)
    return smpl2mano_vids, mano_faces


def _segmentation_seed_vids(body_model, part):
    """Return the vertex ids of all segmentation labels that make up ``part``."""
    with open(
        os.path.join(SMPL_MODEL_DIR, '{}_vert_segmentation.json'.format(body_model)), 'rb'
    ) as json_file:
        segmentation = json.load(json_file)

    seed_vids = []
    for label in _SEGMENTATION_PARTS[part]:
        seed_vids.extend(segmentation[label])
    return seed_vids


def _wrist_seed_vids(body_model, part):
    """Return template vertex ids within the wrist radius of joint 20 ('lwrist') or 21 ('rwrist')."""
    if body_model == 'smplx':
        body_model_verts = get_smplx_tpose()
        tpose_joint = get_smplx_tpose_joint()
    elif body_model == 'smpl':
        body_model_verts = get_smpl_tpose()
        tpose_joint = get_smpl_tpose_joint()
    else:
        raise ValueError(
            "wrist parts are only supported for 'smpl' and 'smplx', got {!r}".format(body_model)
        )

    wrist_joint = tpose_joint[20] if part == 'lwrist' else tpose_joint[21]
    sq_dist = torch.sum((body_model_verts - wrist_joint) ** 2, dim=1)
    return torch.nonzero(sq_dist < _WRIST_SQ_DIST_THRESH).flatten().tolist()


def get_partial_smpl(body_model='smpl', device=torch.device('cuda')):
    """Return vertex ids and re-indexed faces of partial meshes of ``body_model``.

    For each part in ('lhand', 'rhand', 'face', 'arm', 'forearm', 'larm',
    'rarm', 'lwrist', 'rwrist') the result is loaded from
    '<PARTIAL_MESH_DIR>/<body_model>_<part>_vids.npz' when that file exists;
    otherwise it is computed and cached to that file.

    Args:
        body_model: model name used for faces/template lookup and cache file
            names (the wrist parts support only 'smpl' and 'smplx').
        device: unused; kept for backward compatibility.

    Returns:
        dict mapping part name -> {'vids': vertex ids, 'faces': faces indexed
        into 'vids'}.

    Raises:
        ValueError: if a wrist part must be computed for an unsupported
            ``body_model``.
    """
    body_model_faces = get_model_faces(body_model)
    body_model_num_verts = len(get_model_tpose(body_model))

    part_vert_faces = {}

    for part in ['lhand', 'rhand', 'face', 'arm', 'forearm', 'larm', 'rarm', 'lwrist', 'rwrist']:
        part_vid_fname = '{}/{}_{}_vids.npz'.format(path_config.PARTIAL_MESH_DIR, body_model, part)

        if os.path.exists(part_vid_fname):
            # Context manager closes the underlying .npz file handle.
            with np.load(part_vid_fname) as part_vids:
                part_vert_faces[part] = {'vids': part_vids['vids'], 'faces': part_vids['faces']}
            continue

        if part in ('lhand', 'rhand'):
            # NOTE(review): vids index the SMPL template even when
            # body_model == 'smplx' -- confirm this is intended.
            vids, faces = _hand_vids_faces(part)
        elif part in ('lwrist', 'rwrist'):
            seed_vids = _wrist_seed_vids(body_model, part)
            vids, faces = _extract_submesh(body_model_faces, body_model_num_verts, seed_vids)
        else:  # 'face', 'arm', 'forearm', 'larm', 'rarm'
            seed_vids = _segmentation_seed_vids(body_model, part)
            vids, faces = _extract_submesh(body_model_faces, body_model_num_verts, seed_vids)

        os.makedirs(path_config.PARTIAL_MESH_DIR, exist_ok=True)
        np.savez(part_vid_fname, vids=vids, faces=faces)
        part_vert_faces[part] = {'vids': vids, 'faces': faces}

    return part_vert_faces