import pickle as pkl
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class Struct(object):
    """Expose the entries of a keyword mapping as instance attributes."""

    def __init__(self, **kwargs):
        for key, val in kwargs.items():
            setattr(self, key, val)


def to_np(array, dtype=np.float32):
    """Convert *array* to a dense ``numpy.ndarray`` of the given dtype.

    Objects whose type name mentions ``scipy.sparse`` are densified with
    ``.todense()`` first; the check is done on the type name so that scipy
    does not have to be imported here.
    """
    if 'scipy.sparse' in str(type(array)):
        array = array.todense()
    return np.array(array, dtype=dtype)


class Get_Joints(nn.Module):
    """Compute posed joint locations from axis-angle joint rotations.

    The model data is read from a pickle file that must provide the keys
    ``v_template``, ``shapedirs``, ``J_regressor``, ``posedirs`` and
    ``kintree_table`` (presumably an SMPL-style body model file -- confirm
    against the data actually used).

    All tensors are stored as non-trainable parameters so that they follow
    the module across devices and appear in ``state_dict()``.
    """

    def __init__(self, path, batch_size=300) -> None:
        """Load the model data from *path*.

        Args:
            path: Path of the pickled model dictionary.
            batch_size: Number of rows allocated for the (all-zero) shape
                coefficients; this is the largest batch ``forward`` accepts.
        """
        super().__init__()
        self.betas = nn.parameter.Parameter(
            torch.zeros([batch_size, 10], dtype=torch.float32),
            requires_grad=False,
        )

        # SECURITY NOTE: unpickling executes arbitrary code embedded in the
        # file. Only load model files obtained from a trusted source.
        with open(path, "rb") as f:
            smpl_prior = pkl.load(f, encoding="latin1")
        data_struct = Struct(**smpl_prior)

        self.v_template = nn.parameter.Parameter(
            torch.from_numpy(to_np(data_struct.v_template)),
            requires_grad=False,
        )
        self.shapedirs = nn.parameter.Parameter(
            torch.from_numpy(to_np(data_struct.shapedirs)),
            requires_grad=False,
        )
        self.J_regressor = nn.parameter.Parameter(
            torch.from_numpy(to_np(data_struct.J_regressor)),
            requires_grad=False,
        )

        # Flatten every leading axis of posedirs and store it transposed as
        # (num_pose_basis, -1). It is not used by the methods in this class.
        posedirs = torch.from_numpy(to_np(data_struct.posedirs))
        num_pose_basis = posedirs.shape[-1]
        posedirs = posedirs.reshape([-1, num_pose_basis]).permute(1, 0)
        self.posedirs = nn.parameter.Parameter(posedirs, requires_grad=False)

        # First row of the kinematic tree table holds each joint's parent
        # index; the root joint is marked with -1.
        parents = torch.from_numpy(to_np(data_struct.kintree_table)[0]).long()
        parents[0] = -1
        self.parents = nn.parameter.Parameter(parents, requires_grad=False)

        self.ident = nn.parameter.Parameter(torch.eye(3), requires_grad=False)
        # `K` is no longer read by `batch_rodrigues`; it is kept so existing
        # checkpoints (state_dict keys) continue to load.
        self.K = nn.parameter.Parameter(
            torch.zeros([1, 3, 3]), requires_grad=False)
        self.zeros = nn.parameter.Parameter(
            torch.zeros([1, 1]), requires_grad=False)

    def blend_shapes(self, betas, shape_disps):
        """Return per-vertex offsets for the given shape coefficients.

        Args:
            betas: Tensor of shape (B, L).
            shape_disps: Tensor of shape (M, K, L).

        Returns:
            Tensor of shape (B, M, K): ``shape_disps`` contracted with
            ``betas`` over the last axis.
        """
        blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps])
        return blend_shape

    def vertices2joints(self, J_regressor, vertices):
        """Regress joints from vertices.

        Args:
            J_regressor: Tensor of shape (J, I).
            vertices: Tensor of shape (B, I, K).

        Returns:
            Tensor of shape (B, J, K).
        """
        return torch.einsum('bik,ji->bjk', [vertices, J_regressor])

    def batch_rodrigues(self, rot_vecs, epsilon=1e-8):
        """Convert axis-angle vectors to rotation matrices (Rodrigues).

        Args:
            rot_vecs: Tensor of shape (N, 3) holding axis-angle vectors.
            epsilon: Small value added to every component before taking the
                norm so that the division below never hits an exact zero.

        Returns:
            Tensor of shape (N, 3, 3).
        """
        batch_size = rot_vecs.shape[0]

        angle = torch.norm(rot_vecs + epsilon, dim=1, keepdim=True)
        rot_dir = rot_vecs / angle

        # (N, 1, 1) so they broadcast against the (N, 3, 3) matrices.
        cos = torch.unsqueeze(torch.cos(angle), dim=1)
        sin = torch.unsqueeze(torch.sin(angle), dim=1)

        # Nx1 columns of the unit rotation axis.
        rx, ry, rz = torch.split(rot_dir, 1, dim=1)

        # Skew-symmetric cross-product matrix of the rotation axis.
        zeros = self.zeros.repeat(batch_size, 1)
        K = torch.cat(
            [zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1,
        ).view((batch_size, 3, 3))

        ident = self.ident.unsqueeze(0)
        rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
        return rot_mat

    def transform_mat(self, R, t):
        """Assemble homogeneous 4x4 transforms from rotations/translations.

        Args:
            R: Tensor of shape (N, 3, 3).
            t: Tensor of shape (N, 3, 1).

        Returns:
            Tensor of shape (N, 4, 4): ``R`` padded with a zero row,
            concatenated with ``t`` padded with a trailing 1.
        """
        return torch.cat(
            [F.pad(R, [0, 0, 0, 1]), F.pad(t, [0, 0, 0, 1], value=1)],
            dim=2,
        )

    def batch_rigid_transform(self, rot_mats, joints, parents):
        """Apply the rotations along the kinematic chain.

        Args:
            rot_mats: Tensor of shape (B, J, 3, 3) with per-joint rotations.
            joints: Tensor of shape (B, J, 3) with rest-pose joint positions.
            parents: Integer tensor of length J; ``parents[i]`` is the index
                of the parent of joint ``i``. Entry 0 is never read. A joint's
                parent must have a smaller index than the joint itself.

        Returns:
            Tensor of shape (B, J, 3) with the posed joint positions.
        """
        joints = torch.unsqueeze(joints, dim=-1)

        # Express every non-root joint relative to its parent.
        rel_joints = joints.clone()
        rel_joints[:, 1:] -= joints[:, parents[1:]]

        transforms_mat = self.transform_mat(
            rot_mats.reshape(-1, 3, 3),
            rel_joints.reshape(-1, 3, 1),
        ).reshape(-1, joints.shape[1], 4, 4)

        # Read the parent indices once instead of converting a 0-dim tensor
        # to a Python int on every loop iteration.
        parent_ids = parents.tolist()

        transform_chain = [transforms_mat[:, 0]]
        for i in range(1, len(parent_ids)):
            # Compose the parent's accumulated transform with this joint's
            # local transform.
            curr_res = torch.matmul(
                transform_chain[parent_ids[i]], transforms_mat[:, i])
            transform_chain.append(curr_res)

        transforms = torch.stack(transform_chain, dim=1)

        # The last column of the transformations contains the posed joints.
        posed_joints = transforms[:, :, :3, 3]
        return posed_joints

    def forward(self, pose, trans: Optional[torch.Tensor] = None):
        """Return posed joint positions for a batch of poses.

        Args:
            pose: Tensor whose first axis is the batch and whose remaining
                elements are axis-angle triplets, one per joint (e.g. shape
                (B, J * 3)). It is cast to float32.
            trans: Optional tensor of shape (B, 3) added to every joint.

        Returns:
            Tensor of shape (B, J, 3).

        Raises:
            ValueError: If the batch is larger than the ``batch_size`` the
                module was constructed with.
        """
        pose = pose.float()
        batch = pose.shape[0]
        max_batch = self.betas.shape[0]
        if batch > max_batch:
            raise ValueError(
                f"pose batch size {batch} exceeds the configured "
                f"batch_size {max_batch}"
            )

        betas = self.betas[:batch]
        v_shaped = self.v_template + self.blend_shapes(betas, self.shapedirs)
        J = self.vertices2joints(self.J_regressor, v_shaped)

        # reshape (not view) so non-contiguous poses are accepted as well.
        rot_mats = self.batch_rodrigues(
            pose.reshape(-1, 3)).view([batch, -1, 3, 3])
        J_transformed = self.batch_rigid_transform(rot_mats, J, self.parents)

        if trans is not None:
            J_transformed += trans.unsqueeze(dim=1)

        return J_transformed