# SemanticBoost / SMPLX / read_joints_from_pose.py
# kleinhe
# init
# c3d0293
import torch
import numpy as np
from torch import nn
import pickle as pkl
import torch.nn.functional as F
class Struct:
    """Attribute bag: every keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        # Go through setattr (not __dict__) so attribute assignment keeps
        # its normal semantics for every supplied name.
        for name in kwargs:
            setattr(self, name, kwargs[name])
def to_np(array, dtype=np.float32):
    """Return ``array`` as a dense ``numpy`` array of the requested dtype.

    Objects whose type string contains ``'scipy.sparse'`` are first
    densified through their ``todense()`` method; anything else is handed
    to ``np.array`` directly.
    """
    type_name = str(type(array))
    dense = array.todense() if 'scipy.sparse' in type_name else array
    return np.array(dense, dtype=dtype)
class Get_Joints(nn.Module):
    """Compute posed skeleton joint positions from axis-angle pose vectors.

    The constructor loads a pickled body-model dictionary and stores the
    fields used here (``v_template``, ``shapedirs``, ``J_regressor``,
    ``posedirs`` and ``kintree_table``) as non-trainable parameters.
    ``forward`` shapes the template with the stored ``betas`` (all zeros at
    construction), regresses the rest-pose joints from the shaped vertices
    and propagates per-joint rotations down the kinematic chain.
    """

    def __init__(self, path, batch_size=300) -> None:
        """Load the model file at ``path`` and register its tensors.

        Args:
            path: Location of the pickled model dictionary.
            batch_size: Number of rows allocated for ``betas``. ``forward``
                slices ``betas[:batch]``, so batches larger than this value
                are not supported (the shapes no longer match downstream).
        """
        super().__init__()
        self.betas = nn.parameter.Parameter(
            torch.zeros([batch_size, 10], dtype=torch.float32), requires_grad=False)

        # SECURITY NOTE: unpickling executes arbitrary code embedded in the
        # file. Only load model files that come from a trusted source.
        with open(path, "rb") as f:
            smpl_prior = pkl.load(f, encoding="latin1")
        data_struct = Struct(**smpl_prior)

        self.v_template = nn.parameter.Parameter(
            torch.from_numpy(to_np(data_struct.v_template)), requires_grad=False)
        self.shapedirs = nn.parameter.Parameter(
            torch.from_numpy(to_np(data_struct.shapedirs)), requires_grad=False)
        self.J_regressor = nn.parameter.Parameter(
            torch.from_numpy(to_np(data_struct.J_regressor)), requires_grad=False)

        # Flatten the pose basis to (num_pose_basis, -1). It is stored for
        # completeness but is not read by any method of this class.
        posedirs = torch.from_numpy(to_np(data_struct.posedirs))
        num_pose_basis = posedirs.shape[-1]
        posedirs = posedirs.reshape([-1, num_pose_basis]).permute(1, 0)
        self.posedirs = nn.parameter.Parameter(posedirs, requires_grad=False)

        # First row of the kinematic tree table holds each joint's parent
        # index; the root entry is overwritten with -1.
        parents = torch.from_numpy(to_np(data_struct.kintree_table)[0]).long()
        parents[0] = -1
        self.parents = nn.parameter.Parameter(parents, requires_grad=False)

        # Constants kept on the module so they follow it across devices.
        self.ident = nn.parameter.Parameter(torch.eye(3), requires_grad=False)
        # ``K`` is no longer read by batch_rodrigues; it is still registered
        # so the module's parameter / state_dict layout stays unchanged.
        self.K = nn.parameter.Parameter(torch.zeros([1, 3, 3]), requires_grad=False)
        self.zeros = nn.parameter.Parameter(torch.zeros([1, 1]), requires_grad=False)

    def blend_shapes(self, betas, shape_disps):
        """Return per-vertex offsets for the given shape coefficients.

        Args:
            betas: Tensor of shape (B, L).
            shape_disps: Tensor of shape (M, K, L).

        Returns:
            Tensor of shape (B, M, K): ``shape_disps`` contracted with
            ``betas`` over the last axis.
        """
        return torch.einsum('bl,mkl->bmk', betas, shape_disps)

    def vertices2joints(self, J_regressor, vertices):
        """Regress joints from vertices.

        Args:
            J_regressor: Tensor of shape (J, V).
            vertices: Tensor of shape (B, V, K).

        Returns:
            Tensor of shape (B, J, K).
        """
        return torch.einsum('bik,ji->bjk', vertices, J_regressor)

    def batch_rodrigues(
        self,
        rot_vecs,
        epsilon=1e-8,
    ):
        """Convert axis-angle vectors to rotation matrices.

        Args:
            rot_vecs: Tensor of shape (N, 3).
            epsilon: Added to every component before the norm is taken, so
                an all-zero vector does not divide by zero.

        Returns:
            Tensor of shape (N, 3, 3).
        """
        batch_size = rot_vecs.shape[0]

        angle = torch.norm(rot_vecs + epsilon, dim=1, keepdim=True)
        rot_dir = rot_vecs / angle

        cos = torch.unsqueeze(torch.cos(angle), dim=1)
        sin = torch.unsqueeze(torch.sin(angle), dim=1)

        # Nx1 components of the unit axis.
        rx, ry, rz = torch.split(rot_dir, 1, dim=1)
        zeros = self.zeros.repeat(batch_size, 1)

        # Skew-symmetric cross-product matrix of the axis, row-major.
        K = torch.cat(
            [zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1,
        ).view((batch_size, 3, 3))

        ident = self.ident.unsqueeze(0)
        rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
        return rot_mat

    def transform_mat(self, R, t):
        """Assemble 4x4 transforms from rotations and translations.

        Args:
            R: Tensor of shape (N, 3, 3).
            t: Tensor of shape (N, 3, 1).

        Returns:
            Tensor of shape (N, 4, 4): ``R`` padded with a zero bottom row,
            concatenated with ``t`` padded with a trailing 1.
        """
        return torch.cat([F.pad(R, [0, 0, 0, 1]),
                          F.pad(t, [0, 0, 0, 1], value=1)], dim=2)

    def batch_rigid_transform(
        self,
        rot_mats,
        joints,
        parents,
    ):
        """Apply per-joint rotations along the kinematic chain.

        Args:
            rot_mats: Tensor reshaped internally to (-1, 3, 3); one rotation
                per joint per batch element.
            joints: Tensor of shape (B, J, 3) with the rest-pose joints.
            parents: Integer tensor of length J with each joint's parent
                index. Entry 0 is never used as an index.

        Returns:
            Tensor of shape (B, J, 3) with the posed joint locations.
        """
        joints = torch.unsqueeze(joints, dim=-1)

        # Express every non-root joint relative to its parent.
        rel_joints = joints.clone()
        rel_joints[:, 1:] -= joints[:, parents[1:]]

        transforms_mat = self.transform_mat(
            rot_mats.reshape(-1, 3, 3),
            rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4)

        # Read the parent indices once as Python ints instead of indexing
        # the list with a 0-dim tensor on every iteration.
        parent_ids = parents.tolist()

        transform_chain = [transforms_mat[:, 0]]
        for i in range(1, len(parent_ids)):
            # Compose the parent's accumulated transform with this joint's
            # local transform.
            curr_res = torch.matmul(transform_chain[parent_ids[i]],
                                    transforms_mat[:, i])
            transform_chain.append(curr_res)

        transforms = torch.stack(transform_chain, dim=1)

        # The last column of the transformations contains the posed joints
        posed_joints = transforms[:, :, :3, 3]
        return posed_joints

    def forward(self, pose, trans=None):
        """Return posed joint positions for a batch of poses.

        Args:
            pose: Tensor whose first dimension is the batch and whose
                remaining elements form axis-angle triplets, one per joint.
                It is cast to float32.
            trans: Optional tensor added to every joint of the matching
                batch element (broadcast over the joint axis).

        Returns:
            Tensor of shape (B, J, 3).
        """
        pose = pose.float()
        batch = pose.shape[0]
        betas = self.betas[:batch]

        v_shaped = self.v_template + self.blend_shapes(betas, self.shapedirs)
        J = self.vertices2joints(self.J_regressor, v_shaped)

        # reshape (rather than view) so non-contiguous pose tensors, e.g.
        # transposed or sliced inputs, are accepted too.
        rot_mats = self.batch_rodrigues(pose.reshape(-1, 3)).view([batch, -1, 3, 3])
        J_transformed = self.batch_rigid_transform(rot_mats, J, self.parents)

        if trans is not None:
            J_transformed += trans.unsqueeze(dim=1)
        return J_transformed