#  -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de

import logging
import os
import os.path as osp
import pickle
from collections import namedtuple
from typing import Dict, Optional, Union

import numpy as np
import torch
import torch.nn as nn

# Import-time side effect: raise the threshold of the logger named "smplx" to
# ERROR so that its lower-severity messages are suppressed.
logging.getLogger("smplx").setLevel(logging.ERROR)

from .lbs import find_dynamic_lmk_idx_and_bcoords, lbs, vertices2landmarks
from .utils import (
    Array,
    FLAMEOutput,
    MANOOutput,
    SMPLHOutput,
    SMPLOutput,
    SMPLXOutput,
    Struct,
    Tensor,
    find_joint_kin_chain,
    to_np,
    to_tensor,
)
from .vertex_ids import vertex_ids as VERTEX_IDS
from .vertex_joint_selector import VertexJointSelector

# Generic container for model outputs; every field is optional so each model
# only fills in the entries it actually produces.
ModelOutput = namedtuple(
    "ModelOutput",
    "vertices joints full_pose betas global_orient body_pose expression "
    "left_hand_pose right_hand_pose jaw_pose",
)
# Give every positional/keyword argument a default of None.
ModelOutput.__new__.__defaults__ = tuple(None for _ in ModelOutput._fields)


class SMPL(nn.Module):
    """SMPL body model as a ``torch.nn.Module``.

    Loads the model data (template mesh, shape and pose blend shapes, joint
    regressor, kinematic tree and skinning weights) and poses the mesh with
    linear blend skinning in ``forward``. Two extensions over the stock model
    are supported: a kid template (``age="kid"``), appended to the shape space
    as one extra component, and per-subject vertex offsets (``v_personal``)
    added to the template mesh.
    """

    NUM_JOINTS = 23
    NUM_BODY_JOINTS = 23
    SHAPE_SPACE_DIM = 300

    def __init__(
        self,
        model_path: str,
        kid_template_path: str = "",
        data_struct: Optional[Struct] = None,
        create_betas: bool = True,
        betas: Optional[Tensor] = None,
        num_betas: int = 10,
        create_global_orient: bool = True,
        global_orient: Optional[Tensor] = None,
        create_body_pose: bool = True,
        body_pose: Optional[Tensor] = None,
        create_transl: bool = True,
        transl: Optional[Tensor] = None,
        dtype=torch.float32,
        batch_size: int = 1,
        joint_mapper=None,
        gender: str = "neutral",
        age: str = "adult",
        vertex_ids: Optional[Dict[str, int]] = None,
        v_template: Optional[Union[Tensor, Array]] = None,
        v_personal: Optional[Union[Tensor, Array]] = None,
        **kwargs,
    ) -> None:
        """SMPL model constructor

        Parameters
        ----------
        model_path: str
            The path to the folder or to the file where the model
            parameters are stored. For a folder, ``SMPL_<GENDER>.pkl`` is
            loaded from it.
        kid_template_path: str, optional
            Path to a file loadable with ``np.load`` that holds the kid
            template vertices. Only read when ``age == "kid"``.
            (default = "")
        data_struct: Struct
            A struct object. If given, then the parameters of the model are
            read from the object. Otherwise, the model tries to read the
            parameters from the given `model_path`. (default = None)
        create_global_orient: bool, optional
            Flag for creating a member variable for the global orientation
            of the body. (default = True)
        global_orient: torch.tensor, optional, Bx3
            The default value for the global orientation variable.
            (default = None)
        create_body_pose: bool, optional
            Flag for creating a member variable for the pose of the body.
            (default = True)
        body_pose: torch.tensor, optional, Bx(Body Joints * 3)
            The default value for the body pose variable.
            (default = None)
        num_betas: int, optional
            Number of shape components to use. Clamped to 10 when the model
            provides fewer than ``SHAPE_SPACE_DIM`` components and to
            ``SHAPE_SPACE_DIM`` otherwise. (default = 10)
        create_betas: bool, optional
            Flag for creating a member variable for the shape space
            (default = True).
        betas: torch.tensor, optional, Bx10
            The default value for the shape member variable.
            (default = None)
        create_transl: bool, optional
            Flag for creating a member variable for the translation
            of the body. (default = True)
        transl: torch.tensor, optional, Bx3
            The default value for the transl variable.
            (default = None)
        dtype: torch.dtype, optional
            The data type for the created variables. User-supplied default
            values are converted to this dtype.
        batch_size: int, optional
            The batch size used for creating the member variables
        joint_mapper: object, optional
            An object that re-maps the joints. Useful if one wants to
            re-order the SMPL joints to some other convention (e.g. MSCOCO)
            (default = None)
        gender: str, optional
            Which gender to load
        age: str, optional
            ``"kid"`` appends the mean-centred kid template, expressed as an
            offset from the model template, as one extra shape component;
            any other value uses the adult model. (default = "adult")
        vertex_ids: dict, optional
            A dictionary containing the indices of the extra vertices that
            will be selected
        v_template: torch.tensor or array, optional
            Replaces the template mesh read from the model data.
            (default = None)
        v_personal: torch.tensor or array, optional
            Per-vertex offsets added to the template mesh (after any
            ``v_template`` override). A caller-supplied ``v_template``
            tensor is never modified. (default = None)
        **kwargs
            Forwarded to ``VertexJointSelector``.
        """

        self.gender = gender
        self.age = age

        if data_struct is None:
            if osp.isdir(model_path):
                model_fn = "SMPL_{}.{ext}".format(gender.upper(), ext="pkl")
                smpl_path = os.path.join(model_path, model_fn)
            else:
                smpl_path = model_path
            assert osp.exists(smpl_path), "Path {} does not exist!".format(smpl_path)

            # NOTE: unpickling can execute arbitrary code; only load trusted
            # model files.
            with open(smpl_path, "rb") as smpl_file:
                data_struct = Struct(**pickle.load(smpl_file, encoding="latin1"))

        super(SMPL, self).__init__()
        self.batch_size = batch_size
        shapedirs = data_struct.shapedirs
        if shapedirs.shape[-1] < self.SHAPE_SPACE_DIM:
            # Model ships only the first 10 shape components.
            num_betas = min(num_betas, 10)
        else:
            num_betas = min(num_betas, self.SHAPE_SPACE_DIM)

        if self.age == "kid":
            # The kid template enters as one extra blend-shape direction: its
            # (centred) offset from the adult template.
            v_template_smil = np.load(kid_template_path)
            v_template_smil -= np.mean(v_template_smil, axis=0)
            v_template_diff = np.expand_dims(v_template_smil - data_struct.v_template, axis=2)
            shapedirs = np.concatenate((shapedirs[:, :, :num_betas], v_template_diff), axis=2)
            num_betas = num_betas + 1

        self._num_betas = num_betas
        shapedirs = shapedirs[:, :, :num_betas]
        # The shape components
        self.register_buffer("shapedirs", to_tensor(to_np(shapedirs), dtype=dtype))

        if vertex_ids is None:
            # SMPL and SMPL-H share the same topology, so any extra joints can
            # be drawn from the same place
            vertex_ids = VERTEX_IDS["smplh"]

        self.dtype = dtype

        self.joint_mapper = joint_mapper

        self.vertex_joint_selector = VertexJointSelector(vertex_ids=vertex_ids, **kwargs)

        self.faces = data_struct.f
        self.register_buffer(
            "faces_tensor",
            to_tensor(to_np(self.faces, dtype=np.int64), dtype=torch.long),
        )

        if create_betas:
            if betas is None:
                default_betas = torch.zeros([batch_size, self.num_betas], dtype=dtype)
            else:
                default_betas = self._to_default_tensor(betas, dtype)
            self.register_parameter("betas", nn.Parameter(default_betas, requires_grad=True))

        # The tensor that contains the global rotation of the model
        # It is separated from the pose of the joints in case we wish to
        # optimize only over one of them
        if create_global_orient:
            if global_orient is None:
                default_global_orient = torch.zeros([batch_size, 3], dtype=dtype)
            else:
                default_global_orient = self._to_default_tensor(global_orient, dtype)
            self.register_parameter(
                "global_orient", nn.Parameter(default_global_orient, requires_grad=True)
            )

        if create_body_pose:
            if body_pose is None:
                default_body_pose = torch.zeros([batch_size, self.NUM_BODY_JOINTS * 3], dtype=dtype)
            else:
                default_body_pose = self._to_default_tensor(body_pose, dtype)
            self.register_parameter(
                "body_pose", nn.Parameter(default_body_pose, requires_grad=True)
            )

        if create_transl:
            if transl is None:
                default_transl = torch.zeros([batch_size, 3], dtype=dtype)
            else:
                default_transl = self._to_default_tensor(transl, dtype)
            self.register_parameter("transl", nn.Parameter(default_transl, requires_grad=True))

        if v_template is None:
            v_template = data_struct.v_template

        if not torch.is_tensor(v_template):
            v_template = to_tensor(to_np(v_template), dtype=dtype)

        if v_personal is not None:
            # Out-of-place add: an in-place ``+=`` would silently modify a
            # ``v_template`` tensor owned by the caller.
            v_template = v_template + to_tensor(to_np(v_personal), dtype=dtype)

        # The vertices of the template model
        self.register_buffer("v_template", v_template)

        j_regressor = to_tensor(to_np(data_struct.J_regressor), dtype=dtype)
        self.register_buffer("J_regressor", j_regressor)

        # Pose blend shape basis: 6890 x 3 x 207, reshaped to 6890*3 x 207
        num_pose_basis = data_struct.posedirs.shape[-1]
        # 207 x 20670
        posedirs = np.reshape(data_struct.posedirs, [-1, num_pose_basis]).T
        self.register_buffer("posedirs", to_tensor(to_np(posedirs), dtype=dtype))

        # indices of parents for each joints
        parents = to_tensor(to_np(data_struct.kintree_table[0])).long()
        parents[0] = -1
        self.register_buffer("parents", parents)

        self.register_buffer("lbs_weights", to_tensor(to_np(data_struct.weights), dtype=dtype))

    @staticmethod
    def _to_default_tensor(value, dtype) -> Tensor:
        """Copy ``value`` (tensor or array-like) into a detached tensor of ``dtype``."""
        if torch.is_tensor(value):
            # ``torch.tensor(some_tensor)`` emits a UserWarning; clone instead.
            return value.clone().detach().to(dtype=dtype)
        return torch.tensor(value, dtype=dtype)

    @property
    def num_betas(self):
        """Number of shape components in use (including the kid component, if any)."""
        return self._num_betas

    @property
    def num_expression_coeffs(self):
        """SMPL has no expression space."""
        return 0

    def create_mean_pose(self, data_struct) -> Tensor:
        """Hook for subclasses; SMPL has no mean pose and returns None."""
        pass

    def name(self) -> str:
        return "SMPL"

    @torch.no_grad()
    def reset_params(self, **params_dict) -> None:
        """Set named parameters from ``params_dict``; zero all others."""
        for param_name, param in self.named_parameters():
            if param_name in params_dict:
                # ``as_tensor`` avoids the copy-construct warning for tensor values.
                param[:] = torch.as_tensor(params_dict[param_name])
            else:
                param.fill_(0)

    def get_num_verts(self) -> int:
        return self.v_template.shape[0]

    def get_num_faces(self) -> int:
        return self.faces.shape[0]

    def extra_repr(self) -> str:
        msg = [
            f"Gender: {self.gender.upper()}",
            f"Number of joints: {self.J_regressor.shape[0]}",
            f"Betas: {self.num_betas}",
        ]
        return "\n".join(msg)

    def forward(
        self,
        betas: Optional[Tensor] = None,
        body_pose: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        return_verts=True,
        return_full_pose: bool = False,
        pose2rot: bool = True,
        **kwargs,
    ) -> SMPLOutput:
        """Forward pass for the SMPL model

        Parameters
        ----------
        global_orient: torch.tensor, optional, shape Bx3
            If given, ignore the member variable and use it as the global
            rotation of the body. Useful if someone wishes to predicts this
            with an external model. (default=None)
        betas: torch.tensor, optional, shape BxN_b
            If given, ignore the member variable `betas` and use it
            instead. For example, it can used if shape parameters
            `betas` are predicted from some external model.
            (default=None)
        body_pose: torch.tensor, optional, shape Bx(J*3)
            If given, ignore the member variable `body_pose` and use it
            instead. For example, it can used if someone predicts the
            pose of the body joints are predicted from some external model.
            It should be a tensor that contains joint rotations in
            axis-angle format. (default=None)
        transl: torch.tensor, optional, shape Bx3
            If given, ignore the member variable `transl` and use it
            instead. For example, it can used if the translation
            `transl` is predicted from some external model.
            (default=None)
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Returns the full axis-angle pose vector (default=False)
        pose2rot: bool, optional
            Forwarded to ``lbs``. (default=True)

        Returns
        -------
        SMPLOutput
            Vertices (optional), joints (after the vertex joint selector and
            optional joint mapper), the poses/betas used and, optionally,
            the full pose.
        """
        # If no shape and pose parameters are passed along, then use the
        # ones from the module
        global_orient = global_orient if global_orient is not None else self.global_orient
        body_pose = body_pose if body_pose is not None else self.body_pose
        betas = betas if betas is not None else self.betas

        apply_trans = transl is not None or hasattr(self, "transl")
        if transl is None and hasattr(self, "transl"):
            transl = self.transl

        full_pose = torch.cat([global_orient, body_pose], dim=1)

        batch_size = max(betas.shape[0], global_orient.shape[0], body_pose.shape[0])

        if betas.shape[0] != batch_size:
            # Broadcast shape parameters over the batch (only a single-row
            # ``betas`` can be expanded).
            num_repeats = int(batch_size / betas.shape[0])
            betas = betas.expand(num_repeats, -1)

        vertices, joints = lbs(
            betas,
            full_pose,
            self.v_template,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=pose2rot,
        )

        joints = self.vertex_joint_selector(vertices, joints)
        # Map the joints to the current dataset
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        if apply_trans:
            joints += transl.unsqueeze(dim=1)
            vertices += transl.unsqueeze(dim=1)

        output = SMPLOutput(
            vertices=vertices if return_verts else None,
            global_orient=global_orient,
            body_pose=body_pose,
            joints=joints,
            betas=betas,
            full_pose=full_pose if return_full_pose else None,
        )

        return output


class SMPLLayer(SMPL):
    """SMPL as a stateless layer.

    No pose, shape or translation parameters are created on the module.
    Every input is passed to ``forward`` explicitly, with rotations given as
    rotation matrices.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Just create a SMPL module without any member variables
        super(SMPLLayer, self).__init__(
            create_body_pose=False,
            create_betas=False,
            create_global_orient=False,
            create_transl=False,
            *args,
            **kwargs,
        )

    def forward(
        self,
        betas: Optional[Tensor] = None,
        body_pose: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        return_verts=True,
        return_full_pose: bool = False,
        pose2rot: bool = True,
        **kwargs,
    ) -> SMPLOutput:
        """Forward pass for the SMPL model

        Missing inputs default to identity rotations, zero shape and zero
        translation. The batch size is the largest leading dimension among
        the given inputs (1 if none are given).

        Parameters
        ----------
        global_orient: torch.tensor, optional, shape Bx3x3
            Global rotation of the body.  Useful if someone wishes to
            predicts this with an external model. It is expected to be in
            rotation matrix format.  (default=None)
        betas: torch.tensor, optional, shape BxN_b
            Shape parameters. For example, it can used if shape parameters
            `betas` are predicted from some external model.
            (default=None)
        body_pose: torch.tensor, optional, shape BxJx3x3
            Body pose. For example, it can used if someone predicts the
            pose of the body joints are predicted from some external model.
            It should be a tensor that contains joint rotations in
            rotation matrix format. (default=None)
        transl: torch.tensor, optional, shape Bx3
            Translation vector of the body.
            For example, it can used if the translation
            `transl` is predicted from some external model.
            (default=None)
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Return the full pose as stacked rotation matrices of shape
            Bx(1 + NUM_BODY_JOINTS)x3x3 (default=False)
        pose2rot: bool, optional
            Ignored: inputs are always rotation matrices.

        Returns
        -------
        SMPLOutput
        """
        batch_size = 1
        for var in (betas, global_orient, body_pose, transl):
            if var is not None:
                batch_size = max(batch_size, len(var))
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        def identity_rotations(num_joints: int) -> Tensor:
            # B x num_joints x 3 x 3 stack of identity matrices.
            eye = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3)
            return eye.expand(batch_size, num_joints, -1, -1).contiguous()

        if global_orient is None:
            global_orient = identity_rotations(1)
        if body_pose is None:
            body_pose = identity_rotations(self.NUM_BODY_JOINTS)
        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        full_pose = torch.cat(
            [
                global_orient.reshape(-1, 1, 3, 3),
                body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
            ],
            dim=1,
        )

        vertices, joints = lbs(
            betas,
            full_pose,
            self.v_template,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=False,
        )

        joints = self.vertex_joint_selector(vertices, joints)
        # Map the joints to the current dataset
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        # ``transl`` is always set here (zeros when it was omitted).
        joints += transl.unsqueeze(dim=1)
        vertices += transl.unsqueeze(dim=1)

        output = SMPLOutput(
            vertices=vertices if return_verts else None,
            global_orient=global_orient,
            body_pose=body_pose,
            joints=joints,
            betas=betas,
            full_pose=full_pose if return_full_pose else None,
        )

        return output


class SMPLH(SMPL):
    """SMPL+H: the SMPL body with articulated hands.

    The two SMPL hand joints are replaced by ``NUM_HAND_JOINTS`` joints per
    hand. Hand poses are given either as ``num_pca_comps`` PCA coefficients
    per hand (``use_pca=True``) or as ``3 * NUM_HAND_JOINTS`` values per hand.
    """

    # The hand joints are replaced by MANO
    NUM_BODY_JOINTS = SMPL.NUM_JOINTS - 2
    NUM_HAND_JOINTS = 15
    NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS

    def __init__(
        self,
        model_path,
        kid_template_path: str = "",
        data_struct: Optional[Struct] = None,
        create_left_hand_pose: bool = True,
        left_hand_pose: Optional[Tensor] = None,
        create_right_hand_pose: bool = True,
        right_hand_pose: Optional[Tensor] = None,
        use_pca: bool = True,
        num_pca_comps: int = 6,
        flat_hand_mean: bool = False,
        batch_size: int = 1,
        gender: str = "neutral",
        age: str = "adult",
        dtype=torch.float32,
        vertex_ids=None,
        use_compressed: bool = True,
        ext: str = "pkl",
        **kwargs,
    ) -> None:
        """SMPLH model constructor

        Parameters
        ----------
        model_path: str
            The path to the folder or to the file where the model
            parameters are stored. For a folder, ``SMPLH_<GENDER>.<ext>`` is
            loaded from it.
        kid_template_path: str, optional
            Forwarded to the SMPL constructor (used when ``age == "kid"``).
        data_struct: Struct
            A struct object. If given, then the parameters of the model are
            read from the object. Otherwise, the model tries to read the
            parameters from the given `model_path`. (default = None)
        create_left_hand_pose: bool, optional
            Flag for creating a member variable for the pose of the left
            hand. (default = True)
        left_hand_pose: torch.tensor, optional, BxP
            The default value for the left hand pose member variable.
            (default = None)
        create_right_hand_pose: bool, optional
            Flag for creating a member variable for the pose of the right
            hand. (default = True)
        right_hand_pose: torch.tensor, optional, BxP
            The default value for the right hand pose member variable.
            (default = None)
        use_pca: bool, optional
            If True, hand poses are PCA coefficients (P = num_pca_comps);
            otherwise P = 3 * NUM_HAND_JOINTS. (default = True)
        num_pca_comps: int, optional
            The number of PCA components to use for each hand.
            (default = 6)
        flat_hand_mean: bool, optional
            If True, the mean hand pose is the flat (all-zero) hand;
            otherwise the mean hand poses stored in the model data
            (``hands_meanl`` / ``hands_meanr``) are used. The mean pose is
            added to the full pose in ``forward``. (default = False)
        batch_size: int, optional
            The batch size used for creating the member variables
        gender: str, optional
            Which gender to load
        age: str, optional
            Forwarded to the SMPL constructor. (default = "adult")
        dtype: torch.dtype, optional
            The data type for the created variables
        vertex_ids: dict, optional
            A dictionary containing the indices of the extra vertices that
            will be selected
        use_compressed: bool, optional
            Forwarded, through the SMPL constructor's ``**kwargs``, to the
            vertex joint selector.
        ext: str, optional
            Extension of the model file to load: "pkl" or "npz".
            (default = "pkl")

        Raises
        ------
        ValueError
            If the model must be loaded from disk and ``ext`` is neither
            "pkl" nor "npz".
        """

        # If no data structure is passed, then load the data from the given
        # model folder
        if data_struct is None:
            # Load the model
            if osp.isdir(model_path):
                model_fn = "SMPLH_{}.{ext}".format(gender.upper(), ext=ext)
                smplh_path = os.path.join(model_path, model_fn)
            else:
                smplh_path = model_path
            assert osp.exists(smplh_path), "Path {} does not exist!".format(smplh_path)

            if ext == "pkl":
                # NOTE: unpickling can execute arbitrary code; only load
                # trusted model files.
                with open(smplh_path, "rb") as smplh_file:
                    model_data = pickle.load(smplh_file, encoding="latin1")
            elif ext == "npz":
                # Materialise every array, then close the archive so the
                # underlying file handle is not leaked.
                with np.load(smplh_path, allow_pickle=True) as npz_file:
                    model_data = dict(npz_file)
            else:
                raise ValueError("Unknown extension: {}".format(ext))
            data_struct = Struct(**model_data)

        if vertex_ids is None:
            vertex_ids = VERTEX_IDS["smplh"]

        super(SMPLH, self).__init__(
            model_path=model_path,
            kid_template_path=kid_template_path,
            data_struct=data_struct,
            batch_size=batch_size,
            vertex_ids=vertex_ids,
            gender=gender,
            age=age,
            use_compressed=use_compressed,
            dtype=dtype,
            ext=ext,
            **kwargs,
        )

        self.use_pca = use_pca
        self.num_pca_comps = num_pca_comps
        self.flat_hand_mean = flat_hand_mean

        left_hand_components = data_struct.hands_componentsl[:num_pca_comps]
        right_hand_components = data_struct.hands_componentsr[:num_pca_comps]

        self.np_left_hand_components = left_hand_components
        self.np_right_hand_components = right_hand_components
        if self.use_pca:
            self.register_buffer(
                "left_hand_components", torch.tensor(left_hand_components, dtype=dtype)
            )
            self.register_buffer(
                "right_hand_components",
                torch.tensor(right_hand_components, dtype=dtype),
            )

        if self.flat_hand_mean:
            left_hand_mean = np.zeros_like(data_struct.hands_meanl)
            right_hand_mean = np.zeros_like(data_struct.hands_meanr)
        else:
            left_hand_mean = data_struct.hands_meanl
            right_hand_mean = data_struct.hands_meanr

        self.register_buffer("left_hand_mean", to_tensor(left_hand_mean, dtype=self.dtype))
        self.register_buffer("right_hand_mean", to_tensor(right_hand_mean, dtype=self.dtype))

        hand_pose_dim = num_pca_comps if use_pca else 3 * self.NUM_HAND_JOINTS

        def default_hand_pose(value) -> Tensor:
            # Zeros when no value is given; otherwise a detached copy in
            # ``dtype`` (``torch.tensor(tensor)`` would emit a UserWarning).
            if value is None:
                return torch.zeros([batch_size, hand_pose_dim], dtype=dtype)
            if torch.is_tensor(value):
                return value.clone().detach().to(dtype=dtype)
            return torch.tensor(value, dtype=dtype)

        # Create the parameters for the pose of the hands
        if create_left_hand_pose:
            self.register_parameter(
                "left_hand_pose",
                nn.Parameter(default_hand_pose(left_hand_pose), requires_grad=True),
            )

        if create_right_hand_pose:
            self.register_parameter(
                "right_hand_pose",
                nn.Parameter(default_hand_pose(right_hand_pose), requires_grad=True),
            )

        # Create the buffer for the mean pose.
        pose_mean_tensor = self.create_mean_pose(data_struct, flat_hand_mean=flat_hand_mean)
        if not torch.is_tensor(pose_mean_tensor):
            pose_mean_tensor = torch.tensor(pose_mean_tensor, dtype=dtype)
        self.register_buffer("pose_mean", pose_mean_tensor)

    def create_mean_pose(self, data_struct, flat_hand_mean=False):
        """Return the mean full pose: zeros for global orientation and body,
        followed by the (possibly flat) left and right hand means.

        ``flat_hand_mean`` is unused here: the hand-mean buffers were already
        zeroed in the constructor when it is True.
        """
        global_orient_mean = torch.zeros([3], dtype=self.dtype)
        body_pose_mean = torch.zeros([self.NUM_BODY_JOINTS * 3], dtype=self.dtype)

        pose_mean = torch.cat(
            [
                global_orient_mean,
                body_pose_mean,
                self.left_hand_mean,
                self.right_hand_mean,
            ],
            dim=0,
        )
        return pose_mean

    def name(self) -> str:
        return "SMPL+H"

    def extra_repr(self):
        msg = super(SMPLH, self).extra_repr()
        msg = [msg]
        if self.use_pca:
            msg.append(f"Number of PCA components: {self.num_pca_comps}")
        msg.append(f"Flat hand mean: {self.flat_hand_mean}")
        return "\n".join(msg)

    def forward(
        self,
        betas: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        body_pose: Optional[Tensor] = None,
        left_hand_pose: Optional[Tensor] = None,
        right_hand_pose: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        return_verts: bool = True,
        return_full_pose: bool = False,
        pose2rot: bool = True,
        **kwargs,
    ) -> SMPLHOutput:
        """Forward pass for the SMPL+H model

        Any argument left as None is replaced by the corresponding member
        variable.

        Parameters
        ----------
        betas: torch.tensor, optional, shape BxN_b
            Shape parameters.
        global_orient: torch.tensor, optional, shape Bx3
            Global rotation of the body.
        body_pose: torch.tensor, optional, shape Bx(NUM_BODY_JOINTS*3)
            Body joint rotations.
        left_hand_pose, right_hand_pose: torch.tensor, optional, shape BxP
            Hand poses: PCA coefficients when ``use_pca`` is True (projected
            through the hand PCA components), full per-joint values otherwise.
        transl: torch.tensor, optional, shape Bx3
            Translation applied to vertices and joints.
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Return the full pose, including the added mean pose.
            (default=False)
        pose2rot: bool, optional
            Forwarded to ``lbs``. (default=True)

        Returns
        -------
        SMPLHOutput
            The returned hand poses are the expanded per-joint vectors
            (after PCA projection, before the mean pose is added).
        """

        # If no shape and pose parameters are passed along, then use the
        # ones from the module
        global_orient = global_orient if global_orient is not None else self.global_orient
        body_pose = body_pose if body_pose is not None else self.body_pose
        betas = betas if betas is not None else self.betas
        left_hand_pose = left_hand_pose if left_hand_pose is not None else self.left_hand_pose
        right_hand_pose = right_hand_pose if right_hand_pose is not None else self.right_hand_pose

        apply_trans = transl is not None or hasattr(self, "transl")
        if transl is None and hasattr(self, "transl"):
            transl = self.transl

        if self.use_pca:
            left_hand_pose = torch.einsum("bi,ij->bj", [left_hand_pose, self.left_hand_components])
            right_hand_pose = torch.einsum(
                "bi,ij->bj", [right_hand_pose, self.right_hand_components]
            )

        full_pose = torch.cat([global_orient, body_pose, left_hand_pose, right_hand_pose], dim=1)

        # ``torch.cat`` returned a fresh tensor, so the in-place add is safe.
        full_pose += self.pose_mean

        vertices, joints = lbs(
            betas,
            full_pose,
            self.v_template,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=pose2rot,
        )

        # Add any extra joints that might be needed
        joints = self.vertex_joint_selector(vertices, joints)
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        if apply_trans:
            joints += transl.unsqueeze(dim=1)
            vertices += transl.unsqueeze(dim=1)

        output = SMPLHOutput(
            vertices=vertices if return_verts else None,
            joints=joints,
            betas=betas,
            global_orient=global_orient,
            body_pose=body_pose,
            left_hand_pose=left_hand_pose,
            right_hand_pose=right_hand_pose,
            full_pose=full_pose if return_full_pose else None,
        )

        return output


class SMPLHLayer(SMPLH):
    """SMPL+H as a stateless layer.

    No pose, shape or translation parameters are created on the module.
    Every input is passed to ``forward`` explicitly, with rotations given as
    rotation matrices.
    """

    def __init__(self, *args, **kwargs) -> None:
        """SMPL+H as a layer model constructor"""
        super(SMPLHLayer, self).__init__(
            create_global_orient=False,
            create_body_pose=False,
            create_left_hand_pose=False,
            create_right_hand_pose=False,
            create_betas=False,
            create_transl=False,
            *args,
            **kwargs,
        )

    def forward(
        self,
        betas: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        body_pose: Optional[Tensor] = None,
        left_hand_pose: Optional[Tensor] = None,
        right_hand_pose: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        return_verts: bool = True,
        return_full_pose: bool = False,
        pose2rot: bool = True,
        **kwargs,
    ) -> SMPLHOutput:
        """Forward pass for the SMPL+H model

        Missing inputs default to identity rotations, zero shape and zero
        translation. The batch size is the largest leading dimension among
        the given inputs (1 if none are given).

        Parameters
        ----------
        global_orient: torch.tensor, optional, shape Bx3x3
            Global rotation of the body. Useful if someone wishes to
            predicts this with an external model. It is expected to be in
            rotation matrix format. (default=None)
        betas: torch.tensor, optional, shape BxN_b
            Shape parameters. For example, it can used if shape parameters
            `betas` are predicted from some external model.
            (default=None)
        body_pose: torch.tensor, optional, shape BxJx3x3
            Pose of the body joints, e.g. as predicted by some external
            model. It should be a tensor that contains joint rotations in
            rotation matrix format. (default=None)
        left_hand_pose: torch.tensor, optional, shape Bx15x3x3
            If given, contains the pose of the left hand.
            It should be a tensor that contains joint rotations in
            rotation matrix format. (default=None)
        right_hand_pose: torch.tensor, optional, shape Bx15x3x3
            If given, contains the pose of the right hand.
            It should be a tensor that contains joint rotations in
            rotation matrix format. (default=None)
        transl: torch.tensor, optional, shape Bx3
            Translation vector of the body.
            For example, it can used if the translation
            `transl` is predicted from some external model.
            (default=None)
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Return the full pose as stacked rotation matrices of shape
            Bx(1 + NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS)x3x3 (default=False)
        pose2rot: bool, optional
            Ignored: inputs are always rotation matrices.

        Returns
        -------
        SMPLHOutput
        """
        model_vars = (
            betas,
            global_orient,
            body_pose,
            transl,
            left_hand_pose,
            right_hand_pose,
        )
        batch_size = 1
        for var in model_vars:
            if var is not None:
                batch_size = max(batch_size, len(var))
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        def identity_rotations(num_joints: int) -> Tensor:
            # B x num_joints x 3 x 3 stack of identity matrices.
            eye = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3)
            return eye.expand(batch_size, num_joints, -1, -1).contiguous()

        # Use the class joint counts (the same ones used in the reshape below)
        # instead of hard-coded literals.
        if global_orient is None:
            global_orient = identity_rotations(1)
        if body_pose is None:
            body_pose = identity_rotations(self.NUM_BODY_JOINTS)
        if left_hand_pose is None:
            left_hand_pose = identity_rotations(self.NUM_HAND_JOINTS)
        if right_hand_pose is None:
            right_hand_pose = identity_rotations(self.NUM_HAND_JOINTS)
        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        # Concatenate all pose vectors
        full_pose = torch.cat(
            [
                global_orient.reshape(-1, 1, 3, 3),
                body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
                left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
                right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
            ],
            dim=1,
        )

        vertices, joints = lbs(
            betas,
            full_pose,
            self.v_template,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=False,
        )

        # Add any extra joints that might be needed
        joints = self.vertex_joint_selector(vertices, joints)
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        # ``transl`` is always set here (zeros when it was omitted).
        joints += transl.unsqueeze(dim=1)
        vertices += transl.unsqueeze(dim=1)

        output = SMPLHOutput(
            vertices=vertices if return_verts else None,
            joints=joints,
            betas=betas,
            global_orient=global_orient,
            body_pose=body_pose,
            left_hand_pose=left_hand_pose,
            right_hand_pose=right_hand_pose,
            full_pose=full_pose if return_full_pose else None,
        )

        return output


class SMPLX(SMPLH):
    """
    SMPL-X (SMPL eXpressive) is a unified body model, with shape parameters
    trained jointly for the face, hands and body.
    SMPL-X uses standard vertex based linear blend skinning with learned
    corrective blend shapes, has N=10475 vertices and K=54 joints,
    which includes joints for the neck, jaw, eyeballs and fingers.
    """

    NUM_BODY_JOINTS = SMPLH.NUM_BODY_JOINTS    # 21
    NUM_HAND_JOINTS = 15
    NUM_FACE_JOINTS = 3
    NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS + NUM_FACE_JOINTS
    EXPRESSION_SPACE_DIM = 100
    NECK_IDX = 12

    def __init__(
        self,
        model_path: str,
        kid_template_path: str = "",
        num_expression_coeffs: int = 10,
        create_expression: bool = True,
        expression: Optional[Tensor] = None,
        create_jaw_pose: bool = True,
        jaw_pose: Optional[Tensor] = None,
        create_leye_pose: bool = True,
        leye_pose: Optional[Tensor] = None,
        create_reye_pose=True,
        reye_pose: Optional[Tensor] = None,
        use_face_contour: bool = False,
        batch_size: int = 1,
        gender: str = "neutral",
        age: str = "adult",
        dtype=torch.float32,
        ext: str = "npz",
        **kwargs,
    ) -> None:
        """SMPLX model constructor

        Parameters
        ----------
        model_path: str
            Hugging Face Hub repository id. The file
            `models/SMPLX_{GENDER}.{ext}` is downloaded from it, using the
            token stored in the `ICON` environment variable. The value is
            also forwarded to the parent constructor.
        kid_template_path: str, optional
            Forwarded unchanged to the parent constructor.
        num_expression_coeffs: int, optional
            Number of expression components to use. It is clamped to the
            number of expression directions stored in the model file
            (10 for compact models, EXPRESSION_SPACE_DIM otherwise).
            (default = 10)
        create_expression: bool, optional
            Flag for creating a member variable for the expression space
            (default = True).
        expression: torch.tensor, optional, BxN_e
            The default value for the expression member variable.
            (default = None)
        create_jaw_pose: bool, optional
            Flag for creating a member variable for the jaw pose.
            (default = True)
        jaw_pose: torch.tensor, optional, Bx3
            The default value for the jaw pose variable.
            (default = None)
        create_leye_pose: bool, optional
            Flag for creating a member variable for the left eye pose.
            (default = True)
        leye_pose: torch.tensor, optional, Bx3
            The default value for the left eye pose variable.
            (default = None)
        create_reye_pose: bool, optional
            Flag for creating a member variable for the right eye pose.
            (default = True)
        reye_pose: torch.tensor, optional, Bx3
            The default value for the right eye pose variable.
            (default = None)
        use_face_contour: bool, optional
            Whether to compute the keypoints that form the facial contour
        batch_size: int, optional
            The batch size used for creating the member variables
        gender: str, optional
            Which gender to load (upper-cased into the model file name)
        age: str, optional
            Forwarded unchanged to the parent constructor.
        dtype: torch.dtype
            The data type for the created variables
        ext: str, optional
            Model file extension, either "npz" or "pkl". (default = "npz")

        Raises
        ------
        ValueError
            If `ext` is neither "pkl" nor "npz".
        KeyError
            If the `ICON` environment variable is not set.
        """

        # Validate the extension before doing any network I/O.
        if ext not in ("pkl", "npz"):
            raise ValueError("Unknown extension: {}".format(ext))

        # Load the model
        from huggingface_hub import hf_hub_download

        model_fn = "SMPLX_{}.{ext}".format(gender.upper(), ext=ext)
        smplx_path = hf_hub_download(
            repo_id=model_path, use_auth_token=os.environ["ICON"], filename=f"models/{model_fn}"
        )

        # NOTE(security): both loaders below can execute arbitrary code from
        # the file (pickle / allow_pickle=True); only point `model_path` at
        # trusted repositories.
        if ext == "pkl":
            with open(smplx_path, "rb") as smplx_file:
                model_data = pickle.load(smplx_file, encoding="latin1")
        else:
            model_data = np.load(smplx_path, allow_pickle=True)

        data_struct = Struct(**model_data)

        super(SMPLX, self).__init__(
            model_path=model_path,
            kid_template_path=kid_template_path,
            data_struct=data_struct,
            dtype=dtype,
            batch_size=batch_size,
            vertex_ids=VERTEX_IDS["smplx"],
            gender=gender,
            age=age,
            ext=ext,
            **kwargs,
        )

        # Static facial landmarks: face indices + barycentric coordinates.
        lmk_faces_idx = data_struct.lmk_faces_idx
        self.register_buffer("lmk_faces_idx", torch.tensor(lmk_faces_idx, dtype=torch.long))
        lmk_bary_coords = data_struct.lmk_bary_coords
        self.register_buffer("lmk_bary_coords", torch.tensor(lmk_bary_coords, dtype=dtype))

        self.use_face_contour = use_face_contour
        if self.use_face_contour:
            # Contour landmarks are selected dynamically from the pose of the
            # kinematic chain leading to the neck joint.
            dynamic_lmk_faces_idx = data_struct.dynamic_lmk_faces_idx
            dynamic_lmk_faces_idx = torch.tensor(dynamic_lmk_faces_idx, dtype=torch.long)
            self.register_buffer("dynamic_lmk_faces_idx", dynamic_lmk_faces_idx)

            dynamic_lmk_bary_coords = data_struct.dynamic_lmk_bary_coords
            dynamic_lmk_bary_coords = torch.tensor(dynamic_lmk_bary_coords, dtype=dtype)
            self.register_buffer("dynamic_lmk_bary_coords", dynamic_lmk_bary_coords)

            neck_kin_chain = find_joint_kin_chain(self.NECK_IDX, self.parents)
            self.register_buffer("neck_kin_chain", torch.tensor(neck_kin_chain, dtype=torch.long))

        if create_jaw_pose:
            if jaw_pose is None:
                default_jaw_pose = torch.zeros([batch_size, 3], dtype=dtype)
            else:
                default_jaw_pose = torch.tensor(jaw_pose, dtype=dtype)
            jaw_pose_param = nn.Parameter(default_jaw_pose, requires_grad=True)
            self.register_parameter("jaw_pose", jaw_pose_param)

        if create_leye_pose:
            if leye_pose is None:
                default_leye_pose = torch.zeros([batch_size, 3], dtype=dtype)
            else:
                default_leye_pose = torch.tensor(leye_pose, dtype=dtype)
            leye_pose_param = nn.Parameter(default_leye_pose, requires_grad=True)
            self.register_parameter("leye_pose", leye_pose_param)

        if create_reye_pose:
            if reye_pose is None:
                default_reye_pose = torch.zeros([batch_size, 3], dtype=dtype)
            else:
                default_reye_pose = torch.tensor(reye_pose, dtype=dtype)
            reye_pose_param = nn.Parameter(default_reye_pose, requires_grad=True)
            self.register_parameter("reye_pose", reye_pose_param)

        shapedirs = data_struct.shapedirs
        if len(shapedirs.shape) < 3:
            shapedirs = shapedirs[:, :, None]
        if shapedirs.shape[-1] < self.SHAPE_SPACE_DIM + self.EXPRESSION_SPACE_DIM:
            # Compact model file: 10 shape directions followed by 10
            # expression directions.
            expr_start_idx = 10
            num_expression_coeffs = min(num_expression_coeffs, 10)
        else:
            expr_start_idx = self.SHAPE_SPACE_DIM
            num_expression_coeffs = min(num_expression_coeffs, self.EXPRESSION_SPACE_DIM)
        # Slice exactly as many expression directions as there are expression
        # coefficients (computed after clamping), so that
        # cat([betas, expression]) matches cat([shapedirs, expr_dirs]) in forward.
        expr_end_idx = expr_start_idx + num_expression_coeffs

        self._num_expression_coeffs = num_expression_coeffs

        expr_dirs = shapedirs[:, :, expr_start_idx:expr_end_idx]
        self.register_buffer("expr_dirs", to_tensor(to_np(expr_dirs), dtype=dtype))

        if create_expression:
            if expression is None:
                default_expression = torch.zeros([batch_size, self.num_expression_coeffs],
                                                 dtype=dtype)
            else:
                default_expression = torch.tensor(expression, dtype=dtype)
            expression_param = nn.Parameter(default_expression, requires_grad=True)
            self.register_parameter("expression", expression_param)

    def name(self) -> str:
        return "SMPL-X"

    @property
    def num_expression_coeffs(self):
        """Number of expression coefficients actually used (after clamping)."""
        return self._num_expression_coeffs

    def create_mean_pose(self, data_struct, flat_hand_mean=False):
        """Return the mean pose as a flat numpy array.

        The order is global orient, body, jaw, left eye, right eye, left
        hand, right hand. Every part is zero except the hands, which use
        `self.left_hand_mean` / `self.right_hand_mean` (presumably set up by
        the parent class; verify). `data_struct` and `flat_hand_mean` are not
        used by this implementation.
        """
        global_orient_mean = torch.zeros([3], dtype=self.dtype)
        body_pose_mean = torch.zeros([self.NUM_BODY_JOINTS * 3], dtype=self.dtype)
        jaw_pose_mean = torch.zeros([3], dtype=self.dtype)
        leye_pose_mean = torch.zeros([3], dtype=self.dtype)
        reye_pose_mean = torch.zeros([3], dtype=self.dtype)

        pose_mean = np.concatenate(
            [
                global_orient_mean,
                body_pose_mean,
                jaw_pose_mean,
                leye_pose_mean,
                reye_pose_mean,
                self.left_hand_mean,
                self.right_hand_mean,
            ],
            axis=0,
        )

        return pose_mean

    def extra_repr(self):
        msg = super(SMPLX, self).extra_repr()
        msg = [msg, f"Number of Expression Coefficients: {self.num_expression_coeffs}"]
        return "\n".join(msg)

    def forward(
        self,
        betas: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        body_pose: Optional[Tensor] = None,
        left_hand_pose: Optional[Tensor] = None,
        right_hand_pose: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        expression: Optional[Tensor] = None,
        jaw_pose: Optional[Tensor] = None,
        leye_pose: Optional[Tensor] = None,
        reye_pose: Optional[Tensor] = None,
        return_verts: bool = True,
        return_full_pose: bool = False,
        pose2rot: bool = True,
        return_joint_transformation: bool = False,
        return_vertex_transformation: bool = False,
        pose_type: str = 'posed',
        **kwargs,
    ) -> SMPLXOutput:
        """
        Forward pass for the SMPLX model

        Any argument left as None falls back to the corresponding member
        variable of the module.

        Parameters
        ----------
        betas: torch.tensor, optional, shape BxN_b
            Shape coefficients. A batch of 1 is expanded to the batch size.
        global_orient: torch.tensor, optional, shape Bx3
            Global rotation of the body.
        body_pose: torch.tensor, optional, shape Bx(J*3)
            Body joint rotations in axis-angle format.
        left_hand_pose, right_hand_pose: torch.tensor, optional, shape BxP
            Hand poses. These are PCA coefficients when `self.use_pca` is
            set, and axis-angle joint rotations otherwise.
        transl: torch.tensor, optional, shape Bx3
            Translation added to the vertices and joints.
        expression: torch.tensor, optional, shape BxN_e
            Expression coefficients.
        jaw_pose, leye_pose, reye_pose: torch.tensor, optional, shape Bx3
            Jaw and eye rotations in axis-angle format.
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Returns the full pose vector (default=False)
        pose2rot: bool, optional
            If True, poses are axis-angle vectors. If False, they are
            rotation matrices. The flag is forwarded to both `lbs` and the
            dynamic-landmark lookup. (default=True)
        return_joint_transformation, return_vertex_transformation: bool
            Also return the per-joint / per-vertex transformations produced
            by `lbs`. (default=False)
        pose_type: str, optional
            "posed" (default) uses the given poses. "t-pose" zeroes the full
            pose. "a-pose" and "da-pose" zero everything except two body_pose
            joints: (15, 16) rotated by -/+45 degrees and (0, 1) rotated by
            +/-30 degrees about the z axis, respectively. Any other value
            behaves like "posed".

        Returns
        -------
        output: SMPLXOutput
            Posed vertices (translated if a translation is available), joints
            including facial landmarks, and the input parameters.
        """

        # If no shape and pose parameters are passed along, then use the
        # ones from the module
        global_orient = (global_orient if global_orient is not None else self.global_orient)
        body_pose = body_pose if body_pose is not None else self.body_pose
        betas = betas if betas is not None else self.betas

        left_hand_pose = (left_hand_pose if left_hand_pose is not None else self.left_hand_pose)
        right_hand_pose = (right_hand_pose if right_hand_pose is not None else self.right_hand_pose)
        jaw_pose = jaw_pose if jaw_pose is not None else self.jaw_pose
        leye_pose = leye_pose if leye_pose is not None else self.leye_pose
        reye_pose = reye_pose if reye_pose is not None else self.reye_pose
        expression = expression if expression is not None else self.expression

        apply_trans = transl is not None or hasattr(self, "transl")
        if transl is None:
            if hasattr(self, "transl"):
                transl = self.transl

        if self.use_pca:
            # Map PCA coefficients back to full per-joint hand rotations.
            left_hand_pose = torch.einsum("bi,ij->bj", [left_hand_pose, self.left_hand_components])
            right_hand_pose = torch.einsum(
                "bi,ij->bj", [right_hand_pose, self.right_hand_components]
            )

        full_pose = torch.cat(
            [
                global_orient,
                body_pose,
                jaw_pose,
                leye_pose,
                reye_pose,
                left_hand_pose,
                right_hand_pose,
            ],
            dim=1,
        )

        if pose_type == "t-pose":
            full_pose *= 0.0
        elif pose_type == "a-pose":
            body_pose = torch.zeros_like(body_pose).view(body_pose.shape[0], -1, 3)
            body_pose[:, 15] = torch.tensor([0., 0., -45 * np.pi / 180.])
            body_pose[:, 16] = torch.tensor([0., 0., 45 * np.pi / 180.])
            body_pose = body_pose.view(body_pose.shape[0], -1)

            full_pose = torch.cat(
                [
                    global_orient * 0.,
                    body_pose,
                    jaw_pose * 0.,
                    leye_pose * 0.,
                    reye_pose * 0.,
                    left_hand_pose * 0.,
                    right_hand_pose * 0.,
                ],
                dim=1,
            )
        elif pose_type == "da-pose":
            body_pose = torch.zeros_like(body_pose).view(body_pose.shape[0], -1, 3)
            body_pose[:, 0] = torch.tensor([0., 0., 30 * np.pi / 180.])
            body_pose[:, 1] = torch.tensor([0., 0., -30 * np.pi / 180.])
            body_pose = body_pose.view(body_pose.shape[0], -1)

            full_pose = torch.cat(
                [
                    global_orient * 0.,
                    body_pose,
                    jaw_pose * 0.,
                    leye_pose * 0.,
                    reye_pose * 0.,
                    left_hand_pose * 0.,
                    right_hand_pose * 0.,
                ],
                dim=1,
            )

        # NOTE: the model's mean pose (self.pose_mean) is not added to
        # full_pose; that step is disabled in this version.

        batch_size = max(betas.shape[0], global_orient.shape[0], body_pose.shape[0])
        # Broadcast a single set of shape coefficients over the batch, then
        # concatenate the shape and expression coefficients.
        scale = int(batch_size / betas.shape[0])
        if scale > 1:
            betas = betas.expand(scale, -1)
        shape_components = torch.cat([betas, expression], dim=-1)

        shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)

        if return_joint_transformation or return_vertex_transformation:
            vertices, joints, joint_transformation, vertex_transformation = lbs(
                shape_components,
                full_pose,
                self.v_template,
                shapedirs,
                self.posedirs,
                self.J_regressor,
                self.parents,
                self.lbs_weights,
                pose2rot=pose2rot,
                return_transformation=True,
            )
        else:
            vertices, joints = lbs(
                shape_components,
                full_pose,
                self.v_template,
                shapedirs,
                self.posedirs,
                self.J_regressor,
                self.parents,
                self.lbs_weights,
                pose2rot=pose2rot,
            )

        # Both landmark tensors must use the runtime batch size (not the
        # construction-time self.batch_size), or they disagree with vertices.
        lmk_faces_idx = (self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1).contiguous())
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(batch_size, 1, 1)
        if self.use_face_contour:
            # full_pose holds axis-angle vectors or rotation matrices
            # depending on pose2rot, so the lookup must be told which.
            lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords(
                vertices,
                full_pose,
                self.dynamic_lmk_faces_idx,
                self.dynamic_lmk_bary_coords,
                self.neck_kin_chain,
                pose2rot=pose2rot,
            )
            dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords

            lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
            lmk_bary_coords = torch.cat([
                lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords
            ], 1)

        landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)

        # Add any extra joints that might be needed
        joints = self.vertex_joint_selector(vertices, joints)
        # Add the landmarks to the joints
        joints = torch.cat([joints, landmarks], dim=1)
        # Map the joints to the current dataset

        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints=joints, vertices=vertices)

        if apply_trans:
            joints += transl.unsqueeze(dim=1)
            vertices += transl.unsqueeze(dim=1)

        output = SMPLXOutput(
            vertices=vertices if return_verts else None,
            joints=joints,
            betas=betas,
            expression=expression,
            global_orient=global_orient,
            body_pose=body_pose,
            left_hand_pose=left_hand_pose,
            right_hand_pose=right_hand_pose,
            jaw_pose=jaw_pose,
            full_pose=full_pose if return_full_pose else None,
            joint_transformation=joint_transformation if return_joint_transformation else None,
            vertex_transformation=vertex_transformation if return_vertex_transformation else None,
        )
        return output


class SMPLXLayer(SMPLX):
    """SMPL-X module that owns no pose, shape, expression or translation
    member variables.

    Every input is passed to `forward`. Poses are rotation matrices. Missing
    rotations default to identity matrices and missing vectors default to
    zeros.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Just create a SMPLX module without any member variables
        super(SMPLXLayer, self).__init__(
            *args,
            create_global_orient=False,
            create_body_pose=False,
            create_left_hand_pose=False,
            create_right_hand_pose=False,
            create_jaw_pose=False,
            create_leye_pose=False,
            create_reye_pose=False,
            create_betas=False,
            create_expression=False,
            create_transl=False,
            **kwargs,
        )

    def forward(
        self,
        betas: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        body_pose: Optional[Tensor] = None,
        left_hand_pose: Optional[Tensor] = None,
        right_hand_pose: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        expression: Optional[Tensor] = None,
        jaw_pose: Optional[Tensor] = None,
        leye_pose: Optional[Tensor] = None,
        reye_pose: Optional[Tensor] = None,
        return_verts: bool = True,
        return_full_pose: bool = False,
        **kwargs,
    ) -> SMPLXOutput:
        """
        Forward pass for the SMPLX model

        The batch size is the largest leading dimension among the provided
        inputs, or 1 if none are provided. Inputs that are not provided get
        defaults of that batch size.

        Parameters
        ----------
        global_orient: torch.tensor, optional, shape Bx3x3
            Global rotation of the body, as a rotation matrix.
            (default=None -> identity)
        betas: torch.tensor, optional, shape BxN_b
            Shape coefficients. (default=None -> zeros)
        expression: torch.tensor, optional, shape BxN_e
            Expression coefficients. (default=None -> zeros)
        body_pose: torch.tensor, optional, shape BxJx3x3
            Body joint rotations in rotation matrix format.
            (default=None -> identity)
        left_hand_pose: torch.tensor, optional, shape Bx15x3x3
            Left hand joint rotations in rotation matrix format.
            (default=None -> identity)
        right_hand_pose: torch.tensor, optional, shape Bx15x3x3
            Right hand joint rotations in rotation matrix format.
            (default=None -> identity)
        jaw_pose, leye_pose, reye_pose: torch.tensor, optional, shape Bx3x3
            Jaw and eye rotations in rotation matrix format.
            (default=None -> identity)
        transl: torch.tensor, optional, shape Bx3
            Translation vector of the body. (default=None -> zeros)
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Returns the full pose vector (default=False)

        Returns
        -------
        output: SMPLXOutput
            A data class that contains the posed vertices and joints
        """
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        # Every provided input contributes to the batch size, including the
        # eye poses. Otherwise the defaults built below could have a smaller
        # batch than the inputs and torch.cat would fail.
        model_vars = [
            betas,
            global_orient,
            body_pose,
            transl,
            expression,
            left_hand_pose,
            right_hand_pose,
            jaw_pose,
            leye_pose,
            reye_pose,
        ]
        batch_size = max([1] + [len(var) for var in model_vars if var is not None])

        def identity_rotations(num_joints: int):
            """Return a contiguous (batch_size, num_joints, 3, 3) stack of identities."""
            return (
                torch.eye(3, device=device, dtype=dtype)
                .view(1, 1, 3, 3)
                .expand(batch_size, num_joints, -1, -1)
                .contiguous()
            )

        if global_orient is None:
            global_orient = identity_rotations(1)
        if body_pose is None:
            body_pose = identity_rotations(self.NUM_BODY_JOINTS)
        if left_hand_pose is None:
            left_hand_pose = identity_rotations(self.NUM_HAND_JOINTS)
        if right_hand_pose is None:
            right_hand_pose = identity_rotations(self.NUM_HAND_JOINTS)
        if jaw_pose is None:
            jaw_pose = identity_rotations(1)
        if leye_pose is None:
            leye_pose = identity_rotations(1)
        if reye_pose is None:
            reye_pose = identity_rotations(1)
        if expression is None:
            expression = torch.zeros([batch_size, self.num_expression_coeffs],
                                     dtype=dtype,
                                     device=device)
        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        # Concatenate all pose vectors
        full_pose = torch.cat(
            [
                global_orient.reshape(-1, 1, 3, 3),
                body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
                jaw_pose.reshape(-1, 1, 3, 3),
                leye_pose.reshape(-1, 1, 3, 3),
                reye_pose.reshape(-1, 1, 3, 3),
                left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
                right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
            ],
            dim=1,
        )
        shape_components = torch.cat([betas, expression], dim=-1)

        shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)

        vertices, joints = lbs(
            shape_components,
            full_pose,
            self.v_template,
            shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=False,
        )

        lmk_faces_idx = (self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1).contiguous())
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(batch_size, 1, 1)
        if self.use_face_contour:
            lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords(
                vertices,
                full_pose,
                self.dynamic_lmk_faces_idx,
                self.dynamic_lmk_bary_coords,
                self.neck_kin_chain,
                pose2rot=False,
            )
            dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords

            lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
            lmk_bary_coords = torch.cat([
                lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords
            ], 1)

        landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)

        # Add any extra joints that might be needed
        joints = self.vertex_joint_selector(vertices, joints)
        # Add the landmarks to the joints
        joints = torch.cat([joints, landmarks], dim=1)
        # Map the joints to the current dataset

        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints=joints, vertices=vertices)

        # transl is never None here (it defaults to zeros above).
        if transl is not None:
            joints += transl.unsqueeze(dim=1)
            vertices += transl.unsqueeze(dim=1)

        output = SMPLXOutput(
            vertices=vertices if return_verts else None,
            joints=joints,
            betas=betas,
            expression=expression,
            global_orient=global_orient,
            body_pose=body_pose,
            left_hand_pose=left_hand_pose,
            right_hand_pose=right_hand_pose,
            jaw_pose=jaw_pose,
            transl=transl,
            full_pose=full_pose if return_full_pose else None,
        )
        return output


class MANO(SMPL):
    # The hand joints are replaced by MANO
    NUM_BODY_JOINTS = 1
    NUM_HAND_JOINTS = 15
    NUM_JOINTS = NUM_BODY_JOINTS + NUM_HAND_JOINTS

    def __init__(
        self,
        model_path: str,
        is_rhand: bool = True,
        data_struct: Optional[Struct] = None,
        create_hand_pose: bool = True,
        hand_pose: Optional[Tensor] = None,
        use_pca: bool = True,
        num_pca_comps: int = 6,
        flat_hand_mean: bool = False,
        batch_size: int = 1,
        dtype=torch.float32,
        vertex_ids=None,
        use_compressed: bool = True,
        ext: str = "pkl",
        **kwargs,
    ) -> None:
        """MANO model constructor

        Parameters
        ----------
        model_path: str
            The path to the folder or to the file where the model
            parameters are stored. When a folder is given, the file
            ``MANO_RIGHT.<ext>`` or ``MANO_LEFT.<ext>`` inside it is loaded
            depending on ``is_rhand``. When a file is given (and
            ``data_struct`` is None), the handedness is inferred from the
            file name: a basename containing ``RIGHT`` means a right hand,
            anything else a left hand.
        is_rhand: bool, optional
            Whether to load the right-hand model when ``model_path`` is a
            folder. (default = True)
        data_struct: Struct
            A struct object. If given, then the parameters of the model are
            read from the object. Otherwise, the model tries to read the
            parameters from the given `model_path`. (default = None)
        create_hand_pose: bool, optional
            Flag for creating a member variable for the pose of the hand.
            (default = True)
        hand_pose: torch.tensor, optional, BxP
            The default value for the hand pose member variable.
            (default = None)
        use_pca: bool, optional
            Whether the hand pose is expressed as coefficients of the
            model's PCA hand components. It is forced to False when
            ``num_pca_comps == 45``. (default = True)
        num_pca_comps: int, optional
            The number of PCA components to use for the hand.
            (default = 6)
        flat_hand_mean: bool, optional
            If True, the mean hand pose is all zeros (flat hand); otherwise
            the ``hands_mean`` stored in the model data is used.
            (default = False)
        batch_size: int, optional
            The batch size used for creating the member variables
        dtype: torch.dtype, optional
            The data type for the created variables
        vertex_ids: dict, optional
            A dictionary containing the indices of the extra vertices that
            will be selected. Defaults to the SMPL+H vertex ids.
        use_compressed: bool, optional
            Forwarded unchanged to the SMPL constructor.
        ext: str, optional
            Extension of the model file, either ``"pkl"`` or ``"npz"``.
            (default = "pkl")

        Raises
        ------
        AssertionError
            If the resolved model file does not exist.
        ValueError
            If ``ext`` is neither ``"pkl"`` nor ``"npz"``.
        """

        self.num_pca_comps = num_pca_comps
        self.is_rhand = is_rhand
        # If no data structure is passed, then load the data from the given
        # model folder
        if data_struct is None:
            # Load the model
            if osp.isdir(model_path):
                model_fn = "MANO_{}.{ext}".format("RIGHT" if is_rhand else "LEFT", ext=ext)
                mano_path = os.path.join(model_path, model_fn)
            else:
                mano_path = model_path
                # An explicit file decides the handedness, overriding `is_rhand`.
                self.is_rhand = "RIGHT" in os.path.basename(model_path)
            assert osp.exists(mano_path), "Path {} does not exist!".format(mano_path)

            if ext == "pkl":
                with open(mano_path, "rb") as mano_file:
                    model_data = pickle.load(mano_file, encoding="latin1")
            elif ext == "npz":
                model_data = np.load(mano_path, allow_pickle=True)
            else:
                raise ValueError("Unknown extension: {}".format(ext))
            data_struct = Struct(**model_data)

        if vertex_ids is None:
            vertex_ids = VERTEX_IDS["smplh"]

        super(MANO, self).__init__(
            model_path=model_path,
            data_struct=data_struct,
            batch_size=batch_size,
            vertex_ids=vertex_ids,
            use_compressed=use_compressed,
            dtype=dtype,
            ext=ext,
            **kwargs,
        )

        # add only MANO tips to the extra joints
        self.vertex_joint_selector.extra_joints_idxs = to_tensor(
            list(VERTEX_IDS["mano"].values()), dtype=torch.long
        )

        self.use_pca = use_pca
        self.num_pca_comps = num_pca_comps
        # With all 45 components the PCA basis spans the full 15x3 axis-angle
        # space, so the raw pose is used directly.
        if self.num_pca_comps == 45:
            self.use_pca = False
        self.flat_hand_mean = flat_hand_mean

        hand_components = data_struct.hands_components[:num_pca_comps]

        self.np_hand_components = hand_components

        if self.use_pca:
            self.register_buffer("hand_components", torch.tensor(hand_components, dtype=dtype))

        if self.flat_hand_mean:
            hand_mean = np.zeros_like(data_struct.hands_mean)
        else:
            hand_mean = data_struct.hands_mean

        self.register_buffer("hand_mean", to_tensor(hand_mean, dtype=self.dtype))

        # Create the parameter for the pose of the hand; its width depends on
        # whether the effective pose space is PCA or full axis-angle.
        hand_pose_dim = num_pca_comps if self.use_pca else 3 * self.NUM_HAND_JOINTS
        if create_hand_pose:
            if hand_pose is None:
                default_hand_pose = torch.zeros([batch_size, hand_pose_dim], dtype=dtype)
            else:
                default_hand_pose = torch.tensor(hand_pose, dtype=dtype)

            hand_pose_param = nn.Parameter(default_hand_pose, requires_grad=True)
            self.register_parameter("hand_pose", hand_pose_param)

        # Create the buffer for the mean pose.
        pose_mean = self.create_mean_pose(data_struct, flat_hand_mean=flat_hand_mean)
        pose_mean_tensor = pose_mean.clone().to(dtype)
        self.register_buffer("pose_mean", pose_mean_tensor)

    def name(self) -> str:
        return "MANO"

    def create_mean_pose(self, data_struct, flat_hand_mean=False):
        """Return the mean full pose: zero global orientation + hand mean.

        ``data_struct`` and ``flat_hand_mean`` are accepted for interface
        compatibility but not used here: the flat/data choice has already
        been applied when the ``hand_mean`` buffer was registered.
        """
        global_orient_mean = torch.zeros([3], dtype=self.dtype)
        pose_mean = torch.cat([global_orient_mean, self.hand_mean], dim=0)
        return pose_mean

    def extra_repr(self):
        msg = [super(MANO, self).extra_repr()]
        if self.use_pca:
            msg.append(f"Number of PCA components: {self.num_pca_comps}")
        msg.append(f"Flat hand mean: {self.flat_hand_mean}")
        return "\n".join(msg)

    def forward(
        self,
        betas: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        hand_pose: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        return_verts: bool = True,
        return_full_pose: bool = False,
        **kwargs,
    ) -> MANOOutput:
        """Forward pass for the MANO model

        Any of ``betas``, ``global_orient``, ``hand_pose`` and ``transl``
        left as None falls back to the corresponding member variable.
        ``hand_pose`` holds PCA coefficients when ``self.use_pca`` is True,
        otherwise axis-angle values; the pose is converted with
        ``pose2rot=True``. Translation is applied only if ``transl`` is given
        or the module owns a ``transl`` member.

        Returns
        -------
        MANOOutput
            ``vertices`` and ``joints`` are None when ``return_verts`` is
            False; ``full_pose`` is None unless ``return_full_pose``.
        """
        # If no shape and pose parameters are passed along, then use the
        # ones from the module
        global_orient = (global_orient if global_orient is not None else self.global_orient)
        betas = betas if betas is not None else self.betas
        hand_pose = hand_pose if hand_pose is not None else self.hand_pose

        apply_trans = transl is not None or hasattr(self, "transl")
        if transl is None:
            if hasattr(self, "transl"):
                transl = self.transl

        if self.use_pca:
            # Project PCA coefficients back to the full axis-angle hand pose.
            hand_pose = torch.einsum("bi,ij->bj", [hand_pose, self.hand_components])

        full_pose = torch.cat([global_orient, hand_pose], dim=1)
        # In-place is safe: `full_pose` is a fresh tensor produced by cat.
        full_pose += self.pose_mean

        vertices, joints = lbs(
            betas,
            full_pose,
            self.v_template,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=True,
        )

        # # Add pre-selected extra joints that might be needed
        # joints = self.vertex_joint_selector(vertices, joints)

        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        if apply_trans:
            joints = joints + transl.unsqueeze(dim=1)
            vertices = vertices + transl.unsqueeze(dim=1)

        output = MANOOutput(
            vertices=vertices if return_verts else None,
            joints=joints if return_verts else None,
            betas=betas,
            global_orient=global_orient,
            hand_pose=hand_pose,
            full_pose=full_pose if return_full_pose else None,
        )

        return output


class MANOLayer(MANO):
    def __init__(self, *args, **kwargs) -> None:
        """MANO as a layer model constructor.

        Builds a :class:`MANO` model without member parameters for the
        global orientation, hand pose, shape and translation; those values
        are supplied to ``forward`` instead.
        """
        no_member_params = dict(
            create_global_orient=False,
            create_hand_pose=False,
            create_betas=False,
            create_transl=False,
        )
        super(MANOLayer, self).__init__(*args, **no_member_params, **kwargs)

    def name(self) -> str:
        return "MANO"

    def forward(
        self,
        betas: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        hand_pose: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        return_verts: bool = True,
        return_full_pose: bool = False,
        **kwargs,
    ) -> MANOOutput:
        """Forward pass for the MANO model

        Poses are rotation matrices (``lbs`` is called with
        ``pose2rot=False``). Missing inputs default to identity rotations
        (global orientation: one joint; hand pose: 15 joints) and zero shape
        and translation. The batch size comes from ``global_orient`` when it
        is given, otherwise it is 1.
        """
        tensor_device = self.shapedirs.device
        tensor_dtype = self.shapedirs.dtype

        def identity_rotations(num_rows, num_joints):
            # Batched stack of 3x3 identity matrices, materialized contiguously.
            eye = torch.eye(3, device=tensor_device, dtype=tensor_dtype)
            return eye.view(1, 1, 3, 3).expand(num_rows, num_joints, -1, -1).contiguous()

        if global_orient is not None:
            batch_size = global_orient.shape[0]
        else:
            batch_size = 1
            global_orient = identity_rotations(batch_size, 1)

        if hand_pose is None:
            hand_pose = identity_rotations(batch_size, 15)
        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas],
                                dtype=tensor_dtype,
                                device=tensor_device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=tensor_dtype, device=tensor_device)

        full_pose = torch.cat([global_orient, hand_pose], dim=1)
        vertices, joints = lbs(
            betas,
            full_pose,
            self.v_template,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=False,
        )

        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        # `transl` is always a tensor at this point (defaulted above).
        offset = transl.unsqueeze(dim=1)
        joints = joints + offset
        vertices = vertices + offset

        return MANOOutput(
            vertices=vertices if return_verts else None,
            joints=joints if return_verts else None,
            betas=betas,
            global_orient=global_orient,
            hand_pose=hand_pose,
            full_pose=full_pose if return_full_pose else None,
        )


class FLAME(SMPL):
    NUM_JOINTS = 5
    SHAPE_SPACE_DIM = 300
    EXPRESSION_SPACE_DIM = 100
    NECK_IDX = 0

    def __init__(
        self,
        model_path: str,
        data_struct=None,
        num_expression_coeffs=10,
        create_expression: bool = True,
        expression: Optional[Tensor] = None,
        create_neck_pose: bool = True,
        neck_pose: Optional[Tensor] = None,
        create_jaw_pose: bool = True,
        jaw_pose: Optional[Tensor] = None,
        create_leye_pose: bool = True,
        leye_pose: Optional[Tensor] = None,
        create_reye_pose=True,
        reye_pose: Optional[Tensor] = None,
        use_face_contour=False,
        batch_size: int = 1,
        gender: str = "neutral",
        dtype: torch.dtype = torch.float32,
        ext="pkl",
        **kwargs,
    ) -> None:
        """FLAME model constructor

        Parameters
        ----------
        model_path: str
            The path to the folder that holds ``FLAME_<GENDER>.<ext>``, or
            the path of a FLAME model file itself. The landmark embeddings
            (``flame_static_embedding.pkl`` and, with ``use_face_contour``,
            ``flame_dynamic_embedding.npy``) are read from that folder, or
            from the folder containing the given model file.
        data_struct: Struct, optional
            If given, the model parameters are read from this object instead
            of being loaded from ``model_path``. (default = None)
        num_expression_coeffs: int, optional
            Number of expression components to use, clamped to what the
            model provides (10 for reduced models, 100 otherwise).
            (default = 10).
        create_expression: bool, optional
            Flag for creating a member variable for the expression space
            (default = True).
        expression: torch.tensor, optional, BxN_e
            The default value for the expression member variable.
            (default = None)
        create_neck_pose: bool, optional
            Flag for creating a member variable for the neck pose.
            (default = True)
        neck_pose: torch.tensor, optional, Bx3
            The default value for the neck pose variable.
            (default = None)
        create_jaw_pose: bool, optional
            Flag for creating a member variable for the jaw pose.
            (default = True)
        jaw_pose: torch.tensor, optional, Bx3
            The default value for the jaw pose variable.
            (default = None)
        create_leye_pose: bool, optional
            Flag for creating a member variable for the left eye pose.
            (default = True)
        leye_pose: torch.tensor, optional, Bx3
            The default value for the left eye pose variable.
            (default = None)
        create_reye_pose: bool, optional
            Flag for creating a member variable for the right eye pose.
            (default = True)
        reye_pose: torch.tensor, optional, Bx3
            The default value for the right eye pose variable.
            (default = None)
        use_face_contour: bool, optional
            Whether to compute the keypoints that form the facial contour
        batch_size: int, optional
            The batch size used for creating the member variables
        gender: str, optional
            Which gender to load when ``model_path`` is a folder
        dtype: torch.dtype
            The data type for the created variables
        ext: str, optional
            Extension of the model file, ``"pkl"`` or ``"npz"``.

        Raises
        ------
        AssertionError
            If the model file must be loaded and does not exist.
        ValueError
            If the model file must be loaded and ``ext`` is unsupported.
        """
        # Resolve the model file and the folder holding the landmark
        # embeddings; `model_path` may be a folder or a model file.
        if osp.isdir(model_path):
            model_dir = model_path
            flame_path = os.path.join(model_path, f"FLAME_{gender.upper()}.{ext}")
        else:
            model_dir = osp.dirname(model_path)
            flame_path = model_path

        if data_struct is None:
            assert osp.exists(flame_path), "Path {} does not exist!".format(flame_path)
            if ext == "npz":
                file_data = np.load(flame_path, allow_pickle=True)
            elif ext == "pkl":
                with open(flame_path, "rb") as smpl_file:
                    file_data = pickle.load(smpl_file, encoding="latin1")
            else:
                raise ValueError("Unknown extension: {}".format(ext))
            data_struct = Struct(**file_data)

        super(FLAME, self).__init__(
            model_path=model_path,
            data_struct=data_struct,
            dtype=dtype,
            batch_size=batch_size,
            gender=gender,
            ext=ext,
            **kwargs,
        )

        self.use_face_contour = use_face_contour

        # No vertex-based extra joints; landmarks are appended in forward.
        self.vertex_joint_selector.extra_joints_idxs = to_tensor([], dtype=torch.long)

        if create_neck_pose:
            if neck_pose is None:
                default_neck_pose = torch.zeros([batch_size, 3], dtype=dtype)
            else:
                default_neck_pose = torch.tensor(neck_pose, dtype=dtype)
            neck_pose_param = nn.Parameter(default_neck_pose, requires_grad=True)
            self.register_parameter("neck_pose", neck_pose_param)

        if create_jaw_pose:
            if jaw_pose is None:
                default_jaw_pose = torch.zeros([batch_size, 3], dtype=dtype)
            else:
                default_jaw_pose = torch.tensor(jaw_pose, dtype=dtype)
            jaw_pose_param = nn.Parameter(default_jaw_pose, requires_grad=True)
            self.register_parameter("jaw_pose", jaw_pose_param)

        if create_leye_pose:
            if leye_pose is None:
                default_leye_pose = torch.zeros([batch_size, 3], dtype=dtype)
            else:
                default_leye_pose = torch.tensor(leye_pose, dtype=dtype)
            leye_pose_param = nn.Parameter(default_leye_pose, requires_grad=True)
            self.register_parameter("leye_pose", leye_pose_param)

        if create_reye_pose:
            if reye_pose is None:
                default_reye_pose = torch.zeros([batch_size, 3], dtype=dtype)
            else:
                default_reye_pose = torch.tensor(reye_pose, dtype=dtype)
            reye_pose_param = nn.Parameter(default_reye_pose, requires_grad=True)
            self.register_parameter("reye_pose", reye_pose_param)

        shapedirs = data_struct.shapedirs
        if len(shapedirs.shape) < 3:
            shapedirs = shapedirs[:, :, None]
        if shapedirs.shape[-1] < self.SHAPE_SPACE_DIM + self.EXPRESSION_SPACE_DIM:
            # Reduced model (10 shape + 10 expression directions): the
            # expression basis starts at column 10.
            expr_start_idx = 10
            num_expression_coeffs = min(num_expression_coeffs, 10)
        else:
            expr_start_idx = self.SHAPE_SPACE_DIM
            num_expression_coeffs = min(num_expression_coeffs, self.EXPRESSION_SPACE_DIM)
        # Take exactly one direction per expression coefficient so that the
        # `expression` vector and `expr_dirs` always have the same width.
        expr_end_idx = expr_start_idx + num_expression_coeffs

        self._num_expression_coeffs = num_expression_coeffs

        expr_dirs = shapedirs[:, :, expr_start_idx:expr_end_idx]
        self.register_buffer("expr_dirs", to_tensor(to_np(expr_dirs), dtype=dtype))

        if create_expression:
            if expression is None:
                default_expression = torch.zeros([batch_size, self.num_expression_coeffs],
                                                 dtype=dtype)
            else:
                default_expression = torch.tensor(expression, dtype=dtype)
            expression_param = nn.Parameter(default_expression, requires_grad=True)
            self.register_parameter("expression", expression_param)

        # The pickle file that contains the barycentric coordinates for
        # regressing the landmarks
        landmark_bcoord_filename = osp.join(model_dir, "flame_static_embedding.pkl")

        with open(landmark_bcoord_filename, "rb") as fp:
            landmarks_data = pickle.load(fp, encoding="latin1")

        lmk_faces_idx = landmarks_data["lmk_face_idx"].astype(np.int64)
        self.register_buffer("lmk_faces_idx", torch.tensor(lmk_faces_idx, dtype=torch.long))
        lmk_bary_coords = landmarks_data["lmk_b_coords"]
        self.register_buffer("lmk_bary_coords", torch.tensor(lmk_bary_coords, dtype=dtype))
        if self.use_face_contour:
            face_contour_path = os.path.join(model_dir, "flame_dynamic_embedding.npy")
            contour_embeddings = np.load(face_contour_path, allow_pickle=True,
                                         encoding="latin1")[()]

            dynamic_lmk_faces_idx = np.array(contour_embeddings["lmk_face_idx"], dtype=np.int64)
            dynamic_lmk_faces_idx = torch.tensor(dynamic_lmk_faces_idx, dtype=torch.long)
            self.register_buffer("dynamic_lmk_faces_idx", dynamic_lmk_faces_idx)

            dynamic_lmk_b_coords = torch.tensor(contour_embeddings["lmk_b_coords"], dtype=dtype)
            self.register_buffer("dynamic_lmk_bary_coords", dynamic_lmk_b_coords)

            neck_kin_chain = find_joint_kin_chain(self.NECK_IDX, self.parents)
            self.register_buffer("neck_kin_chain", torch.tensor(neck_kin_chain, dtype=torch.long))

    @property
    def num_expression_coeffs(self):
        return self._num_expression_coeffs

    def name(self) -> str:
        return "FLAME"

    def extra_repr(self):
        msg = [
            super(FLAME, self).extra_repr(),
            f"Number of Expression Coefficients: {self.num_expression_coeffs}",
            f"Use face contour: {self.use_face_contour}",
        ]
        return "\n".join(msg)

    def forward(
        self,
        betas: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        neck_pose: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        expression: Optional[Tensor] = None,
        jaw_pose: Optional[Tensor] = None,
        leye_pose: Optional[Tensor] = None,
        reye_pose: Optional[Tensor] = None,
        return_verts: bool = True,
        return_full_pose: bool = False,
        pose2rot: bool = True,
        **kwargs,
    ) -> FLAMEOutput:
        """
        Forward pass for the FLAME model

            Parameters
            ----------
            global_orient: torch.tensor, optional, shape Bx3
                If given, ignore the member variable and use it as the global
                rotation of the head. Useful if someone wishes to predicts this
                with an external model. (default=None)
            betas: torch.tensor, optional, shape BxN_b
                If given, ignore the member variable `betas` and use it
                instead. For example, it can used if shape parameters
                `betas` are predicted from some external model.
                (default=None)
            expression: torch.tensor, optional, shape BxN_e
                If given, ignore the member variable `expression` and use it
                instead. For example, it can used if expression parameters
                `expression` are predicted from some external model.
            neck_pose: torch.tensor, optional, shape Bx3
                If given, ignore the member variable `neck_pose` and
                use this instead.
            jaw_pose: torch.tensor, optional, shape Bx3
                If given, ignore the member variable `jaw_pose` and
                use this instead.
            leye_pose, reye_pose: torch.tensor, optional, shape Bx3
                If given, ignore the corresponding eye pose member variable
                and use this instead.
            transl: torch.tensor, optional, shape Bx3
                If given, ignore the member variable `transl` and use it
                instead. For example, it can used if the translation
                `transl` is predicted from some external model.
                (default=None)
            return_verts: bool, optional
                Return the vertices. (default=True)
            return_full_pose: bool, optional
                Returns the full pose vector (default=False)
            pose2rot: bool, optional
                Forwarded to both the skinning and the dynamic-landmark
                computation. True (the default) means the poses are
                axis-angle vectors, as the member variables are.

            Returns
            -------
                output: FLAMEOutput
                A named tuple of type `FLAMEOutput`
        """

        # If no shape and pose parameters are passed along, then use the
        # ones from the module
        global_orient = (global_orient if global_orient is not None else self.global_orient)
        jaw_pose = jaw_pose if jaw_pose is not None else self.jaw_pose
        neck_pose = neck_pose if neck_pose is not None else self.neck_pose

        leye_pose = leye_pose if leye_pose is not None else self.leye_pose
        reye_pose = reye_pose if reye_pose is not None else self.reye_pose

        betas = betas if betas is not None else self.betas
        expression = expression if expression is not None else self.expression

        apply_trans = transl is not None or hasattr(self, "transl")
        if transl is None:
            if hasattr(self, "transl"):
                transl = self.transl

        full_pose = torch.cat([global_orient, neck_pose, jaw_pose, leye_pose, reye_pose], dim=1)

        batch_size = max(betas.shape[0], global_orient.shape[0], jaw_pose.shape[0])
        # Broadcast a single shape vector over the batch before concatenating
        # the shape and expression coefficients.
        scale = batch_size // betas.shape[0]
        if scale > 1:
            betas = betas.expand(scale, -1)
        shape_components = torch.cat([betas, expression], dim=-1)
        shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)

        vertices, joints = lbs(
            shape_components,
            full_pose,
            self.v_template,
            shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=pose2rot,
        )

        # Both landmark tensors use the runtime batch size, which may differ
        # from the batch size the module was constructed with.
        lmk_faces_idx = (self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1).contiguous())
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(batch_size, 1, 1)
        if self.use_face_contour:
            lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords(
                vertices,
                full_pose,
                self.dynamic_lmk_faces_idx,
                self.dynamic_lmk_bary_coords,
                self.neck_kin_chain,
                pose2rot=pose2rot,
            )
            dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords
            lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
            lmk_bary_coords = torch.cat([
                lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords
            ], 1)

        landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)

        # Add any extra joints that might be needed
        joints = self.vertex_joint_selector(vertices, joints)
        # Add the landmarks to the joints
        joints = torch.cat([joints, landmarks], dim=1)

        # Map the joints to the current dataset
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints=joints, vertices=vertices)

        if apply_trans:
            joints += transl.unsqueeze(dim=1)
            vertices += transl.unsqueeze(dim=1)

        output = FLAMEOutput(
            vertices=vertices if return_verts else None,
            joints=joints,
            betas=betas,
            expression=expression,
            global_orient=global_orient,
            neck_pose=neck_pose,
            jaw_pose=jaw_pose,
            full_pose=full_pose if return_full_pose else None,
        )
        return output


class FLAMELayer(FLAME):
    def __init__(self, *args, **kwargs) -> None:
        """ FLAME as a layer model constructor

        Builds a FLAME model without member parameters for shape,
        expression or any pose; all of them are supplied to ``forward``.
        """
        super(FLAMELayer, self).__init__(
            create_betas=False,
            create_expression=False,
            create_global_orient=False,
            create_neck_pose=False,
            create_jaw_pose=False,
            create_leye_pose=False,
            create_reye_pose=False,
            *args,
            **kwargs,
        )

    def forward(
        self,
        betas: Optional[Tensor] = None,
        global_orient: Optional[Tensor] = None,
        neck_pose: Optional[Tensor] = None,
        transl: Optional[Tensor] = None,
        expression: Optional[Tensor] = None,
        jaw_pose: Optional[Tensor] = None,
        leye_pose: Optional[Tensor] = None,
        reye_pose: Optional[Tensor] = None,
        return_verts: bool = True,
        return_full_pose: bool = False,
        pose2rot: bool = True,
        **kwargs,
    ) -> FLAMEOutput:
        """
        Forward pass for the FLAME layer

            All poses are rotation matrices: the skinning and the
            dynamic-landmark computation are always called with
            ``pose2rot=False``, and the ``pose2rot`` argument is accepted
            only for signature compatibility and is ignored. Missing poses
            default to identity rotations, missing shape/expression/
            translation to zeros. The batch size comes from
            ``global_orient`` when it is given, otherwise it is 1.

            Parameters
            ----------
            global_orient: torch.tensor, optional, shape Bx1x3x3
                Global rotation of the head. Useful if someone wishes to
                predicts this with an external model. It is expected to be in
                rotation matrix format. (default=None)
            betas: torch.tensor, optional, shape BxN_b
                Shape parameters. For example, it can used if shape parameters
                `betas` are predicted from some external model.
                (default=None)
            expression: torch.tensor, optional, shape BxN_e
                Expression coefficients. For example, it can used if
                expression parameters `expression` are predicted from some
                external model.
            neck_pose, jaw_pose, leye_pose, reye_pose: torch.tensor, optional, shape Bx1x3x3
                Neck, jaw and eye rotations in rotation matrix format.
            transl: torch.tensor, optional, shape Bx3
                Translation vector of the head.
                For example, it can used if the translation
                `transl` is predicted from some external model.
                (default=None)
            return_verts: bool, optional
                Return the vertices. (default=True)
            return_full_pose: bool, optional
                Returns the full pose, as concatenated rotation matrices
                (default=False)

            Returns
            -------
                output: FLAMEOutput
                A named tuple of type `FLAMEOutput`
        """
        device, dtype = self.shapedirs.device, self.shapedirs.dtype
        if global_orient is None:
            batch_size = 1
            global_orient = (
                torch.eye(3, device=device,
                          dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
            )
        else:
            batch_size = global_orient.shape[0]
        if neck_pose is None:
            neck_pose = (
                torch.eye(3, device=device,
                          dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 1, -1, -1).contiguous()
            )
        if jaw_pose is None:
            jaw_pose = (
                torch.eye(3, device=device,
                          dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
            )
        if leye_pose is None:
            leye_pose = (
                torch.eye(3, device=device,
                          dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
            )
        if reye_pose is None:
            reye_pose = (
                torch.eye(3, device=device,
                          dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
            )
        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device)
        if expression is None:
            expression = torch.zeros([batch_size, self.num_expression_coeffs],
                                     dtype=dtype,
                                     device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        full_pose = torch.cat([global_orient, neck_pose, jaw_pose, leye_pose, reye_pose], dim=1)

        shape_components = torch.cat([betas, expression], dim=-1)
        shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)

        vertices, joints = lbs(
            shape_components,
            full_pose,
            self.v_template,
            shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=False,
        )

        # Both landmark tensors use the runtime batch size, which may differ
        # from the batch size the module was constructed with.
        lmk_faces_idx = (self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1).contiguous())
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(batch_size, 1, 1)
        if self.use_face_contour:
            lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords(
                vertices,
                full_pose,
                self.dynamic_lmk_faces_idx,
                self.dynamic_lmk_bary_coords,
                self.neck_kin_chain,
                pose2rot=False,
            )
            dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords
            lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
            lmk_bary_coords = torch.cat([
                lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords
            ], 1)

        landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)

        # Add any extra joints that might be needed
        joints = self.vertex_joint_selector(vertices, joints)
        # Add the landmarks to the joints
        joints = torch.cat([joints, landmarks], dim=1)

        # Map the joints to the current dataset
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints=joints, vertices=vertices)

        # `transl` is always a tensor here (zeros when not supplied).
        joints += transl.unsqueeze(dim=1)
        vertices += transl.unsqueeze(dim=1)

        output = FLAMEOutput(
            vertices=vertices if return_verts else None,
            joints=joints,
            betas=betas,
            expression=expression,
            global_orient=global_orient,
            neck_pose=neck_pose,
            jaw_pose=jaw_pose,
            full_pose=full_pose if return_full_pose else None,
        )
        return output


def build_layer(model_path: str,
                model_type: str = "smpl",
                **kwargs) -> Union[SMPLLayer, SMPLHLayer, SMPLXLayer, MANOLayer, FLAMELayer]:
    """Create a body-model layer from a path and a model type.

    Parameters
    ----------
    model_path: str
        Either the path to the model file you wish to load or a folder
        where each subfolder contains the different model types, i.e.:
        model_path:
        |
        |-- smpl
            |-- SMPL_FEMALE
            |-- SMPL_NEUTRAL
            |-- SMPL_MALE
        |-- smplh
            |-- SMPLH_FEMALE
            |-- SMPLH_MALE
        |-- smplx
            |-- SMPLX_FEMALE
            |-- SMPLX_NEUTRAL
            |-- SMPLX_MALE
        |-- mano
            |-- MANO_RIGHT
            |-- MANO_LEFT
        |-- flame
            |-- FLAME_FEMALE
            |-- FLAME_MALE
            |-- FLAME_NEUTRAL

    model_type: str, optional
        When model_path is a folder, this parameter specifies the type of
        model to be loaded; it is joined onto model_path as the subfolder
        name. When model_path is not a folder, this argument is ignored and
        the type is inferred (case-insensitively) from the file name: the
        part before the first underscore, with the extension removed,
        e.g. ``SMPLX_NEUTRAL.npz`` -> ``smplx`` and ``smpl.pkl`` -> ``smpl``.
    **kwargs: dict
        Keyword arguments forwarded unchanged to the layer constructor.

    Returns
    -------
        body_model: nn.Module
            The PyTorch module that implements the corresponding body model
    Raises
    ------
        ValueError: In case the model type is not one of SMPL, SMPLH,
        SMPLX, MANO or FLAME
    """

    if osp.isdir(model_path):
        # Folder layout: one subfolder per model type.
        model_path = osp.join(model_path, model_type)
    else:
        # Infer the type from the file name. Strip the extension first so
        # names without an underscore (e.g. "smplx.npz") still resolve.
        stem = osp.splitext(osp.basename(model_path))[0]
        model_type = stem.split("_")[0].lower()

    type_key = model_type.lower()
    if type_key == "smpl":
        return SMPLLayer(model_path, **kwargs)
    elif type_key == "smplh":
        return SMPLHLayer(model_path, **kwargs)
    elif type_key == "smplx":
        return SMPLXLayer(model_path, **kwargs)
    # MANO and FLAME are matched by substring so that file names
    # with extra text (e.g. version suffixes) are still accepted.
    elif "mano" in type_key:
        return MANOLayer(model_path, **kwargs)
    elif "flame" in type_key:
        return FLAMELayer(model_path, **kwargs)
    else:
        raise ValueError(f"Unknown model type {model_type}, exiting!")


def create(model_path: str,
           model_type: str = "smpl",
           **kwargs) -> Union[SMPL, SMPLH, SMPLX, MANO, FLAME]:
    """Create a body model from a path and a model type.

    Parameters
    ----------
    model_path: str
        Either the path to the model file you wish to load or a folder
        where each subfolder contains the different model types, i.e.:
        model_path:
        |
        |-- smpl
            |-- SMPL_FEMALE
            |-- SMPL_NEUTRAL
            |-- SMPL_MALE
        |-- smplh
            |-- SMPLH_FEMALE
            |-- SMPLH_MALE
        |-- smplx
            |-- SMPLX_FEMALE
            |-- SMPLX_NEUTRAL
            |-- SMPLX_MALE
        |-- mano
            |-- MANO_RIGHT
            |-- MANO_LEFT
        |-- flame
            |-- FLAME_FEMALE
            |-- FLAME_MALE
            |-- FLAME_NEUTRAL

    model_type: str, optional
        When model_path is a folder, this parameter specifies the type of
        model to be loaded; it is joined onto model_path as the subfolder
        name. When model_path is not a folder, this argument is ignored and
        the type is inferred (case-insensitively) from the file name: the
        part before the first underscore, with the extension removed,
        e.g. ``SMPLX_NEUTRAL.npz`` -> ``smplx`` and ``smpl.pkl`` -> ``smpl``.
    **kwargs: dict
        Keyword arguments forwarded unchanged to the model constructor.

    Returns
    -------
        body_model: nn.Module
            The PyTorch module that implements the corresponding body model
    Raises
    ------
        ValueError: In case the model type is not one of SMPL, SMPLH,
        SMPLX, MANO or FLAME
    """

    if osp.isdir(model_path):
        # If it's a folder, assume it contains one subfolder per model type
        # and descend into the one named by model_type.
        model_path = osp.join(model_path, model_type)
    else:
        # Otherwise it is a model file: infer the type from its name.
        # Strip the extension first so names without an underscore
        # (e.g. "smplx.npz") still resolve.
        stem = osp.splitext(osp.basename(model_path))[0]
        model_type = stem.split("_")[0].lower()

    type_key = model_type.lower()
    if type_key == "smpl":
        return SMPL(model_path, **kwargs)
    elif type_key == "smplh":
        return SMPLH(model_path, **kwargs)
    elif type_key == "smplx":
        return SMPLX(model_path, **kwargs)
    # MANO and FLAME are matched by substring so that file names
    # with extra text (e.g. version suffixes) are still accepted.
    elif "mano" in type_key:
        return MANO(model_path, **kwargs)
    elif "flame" in type_key:
        return FLAME(model_path, **kwargs)
    else:
        raise ValueError(f"Unknown model type {model_type}, exiting!")