Spaces:
Running
on
Zero
Running
on
Zero
| from typing import Tuple | |
| import numpy as np | |
| import utils.constants as constants | |
| import torch | |
| class HybrIKJointsToRotmat: | |
| def __init__(self): | |
| self.naive_hybrik = constants.SMPL_HYBRIK | |
| self.num_nodes = 22 | |
| self.parents = constants.SMPL_BODY_PARENTS | |
| self.child = constants.SMPL_BODY_CHILDS | |
| self.bones = np.array(constants.SMPL_BODY_BONES).reshape(24, 3)[ | |
| : self.num_nodes | |
| ] | |
| def multi_child_rot( | |
| self, t: np.ndarray, p: np.ndarray, pose_global_parent: np.ndarray | |
| ) -> Tuple[np.ndarray]: | |
| """ | |
| t: B x 3 x child_num | |
| p: B x 3 x child_num | |
| pose_global_parent: B x 3 x 3 | |
| """ | |
| m = np.matmul( | |
| t, np.transpose(np.matmul(np.linalg.inv(pose_global_parent), p), [0, 2, 1]) | |
| ) | |
| u, s, vt = np.linalg.svd(m) | |
| r = np.matmul(np.transpose(vt, [0, 2, 1]), np.transpose(u, [0, 2, 1])) | |
| err_det_mask = (np.linalg.det(r) < 0.0).reshape(-1, 1, 1) | |
| id_fix = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]]).reshape( | |
| 1, 3, 3 | |
| ) | |
| r_fix = np.matmul( | |
| np.transpose(vt, [0, 2, 1]), np.matmul(id_fix, np.transpose(u, [0, 2, 1])) | |
| ) | |
| r = r * (1.0 - err_det_mask) + r_fix * err_det_mask | |
| return r, np.matmul(pose_global_parent, r) | |
| def single_child_rot( | |
| self, | |
| t: np.ndarray, | |
| p: np.ndarray, | |
| pose_global_parent: np.ndarray, | |
| twist: np.ndarray = None, | |
| ) -> Tuple[np.ndarray]: | |
| """ | |
| t: B x 3 x 1 | |
| p: B x 3 x 1 | |
| pose_global_parent: B x 3 x 3 | |
| twist: B x 2 if given, default to None | |
| """ | |
| p_rot = np.matmul(np.linalg.inv(pose_global_parent), p) | |
| cross = np.cross(t, p_rot, axisa=1, axisb=1, axisc=1) | |
| sina = np.linalg.norm(cross, axis=1, keepdims=True) / ( | |
| np.linalg.norm(t, axis=1, keepdims=True) | |
| * np.linalg.norm(p_rot, axis=1, keepdims=True) | |
| ) | |
| cross = cross / np.linalg.norm(cross, axis=1, keepdims=True) | |
| cosa = np.sum(t * p_rot, axis=1, keepdims=True) / ( | |
| np.linalg.norm(t, axis=1, keepdims=True) | |
| * np.linalg.norm(p_rot, axis=1, keepdims=True) | |
| ) | |
| sina = sina.reshape(-1, 1, 1) | |
| cosa = cosa.reshape(-1, 1, 1) | |
| skew_sym_t = np.stack( | |
| [ | |
| 0.0 * cross[:, 0], | |
| -cross[:, 2], | |
| cross[:, 1], | |
| cross[:, 2], | |
| 0.0 * cross[:, 0], | |
| -cross[:, 0], | |
| -cross[:, 1], | |
| cross[:, 0], | |
| 0.0 * cross[:, 0], | |
| ], | |
| 1, | |
| ) | |
| skew_sym_t = skew_sym_t.reshape(-1, 3, 3) | |
| dsw_rotmat = ( | |
| np.eye(3).reshape(1, 3, 3) | |
| + sina * skew_sym_t | |
| + (1.0 - cosa) * np.matmul(skew_sym_t, skew_sym_t) | |
| ) | |
| if twist is not None: | |
| skew_sym_t = np.stack( | |
| [ | |
| 0.0 * t[:, 0], | |
| -t[:, 2], | |
| t[:, 1], | |
| t[:, 2], | |
| 0.0 * t[:, 0], | |
| -t[:, 0], | |
| -t[:, 1], | |
| t[:, 0], | |
| 0.0 * t[:, 0], | |
| ], | |
| 1, | |
| ) | |
| skew_sym_t = skew_sym_t.reshape(-1, 3, 3) | |
| sina = twist[:, 1].reshape(-1, 1, 1) | |
| cosa = twist[:, 0].reshape(-1, 1, 1) | |
| dtw_rotmat = ( | |
| np.eye(3).reshape([1, 3, 3]) | |
| + sina * skew_sym_t | |
| + (1.0 - cosa) * np.matmul(skew_sym_t, skew_sym_t) | |
| ) | |
| dsw_rotmat = np.matmul(dsw_rotmat, dtw_rotmat) | |
| return dsw_rotmat, np.matmul(pose_global_parent, dsw_rotmat) | |
| def __call__(self, joints: np.ndarray, twist: np.ndarray = None) -> np.ndarray: | |
| """ | |
| joints: B x N x 3 | |
| twist: B x N x 2 if given, default to None | |
| """ | |
| expand_dim = False | |
| if len(joints.shape) == 2: | |
| expand_dim = True | |
| joints = np.expand_dims(joints, 0) | |
| if twist is not None: | |
| twist = np.expand_dims(twist, 0) | |
| assert len(joints.shape) == 3 | |
| batch_size = np.shape(joints)[0] | |
| joints_rel = joints - joints[:, self.parents] | |
| joints_hybrik = 0.0 * joints_rel | |
| pose_global = np.zeros([batch_size, self.num_nodes, 3, 3]) | |
| pose = np.zeros([batch_size, self.num_nodes, 3, 3]) | |
| for i in range(self.num_nodes): | |
| if i == 0: | |
| joints_hybrik[:, 0] = joints[:, 0] | |
| else: | |
| joints_hybrik[:, i] = ( | |
| np.matmul( | |
| pose_global[:, self.parents[i]], | |
| self.bones[i].reshape(1, 3, 1), | |
| ).reshape(-1, 3) | |
| + joints_hybrik[:, self.parents[i]] | |
| ) | |
| if self.child[i] == -2: | |
| pose[:, i] = pose[:, i] + np.eye(3).reshape(1, 3, 3) | |
| pose_global[:, i] = pose_global[:, self.parents[i]] | |
| continue | |
| if i == 0: | |
| r, rg = self.multi_child_rot( | |
| np.transpose(self.bones[[1, 2, 3]].reshape(1, 3, 3), [0, 2, 1]), | |
| np.transpose(joints_rel[:, [1, 2, 3]], [0, 2, 1]), | |
| np.eye(3).reshape(1, 3, 3), | |
| ) | |
| elif i == 9: | |
| r, rg = self.multi_child_rot( | |
| np.transpose(self.bones[[12, 13, 14]].reshape(1, 3, 3), [0, 2, 1]), | |
| np.transpose(joints_rel[:, [12, 13, 14]], [0, 2, 1]), | |
| pose_global[:, self.parents[9]], | |
| ) | |
| else: | |
| p = joints_rel[:, self.child[i]] | |
| if self.naive_hybrik[i] == 0: | |
| p = joints[:, self.child[i]] - joints_hybrik[:, i] | |
| twi = None | |
| if twist is not None: | |
| twi = twist[:, i] | |
| r, rg = self.single_child_rot( | |
| self.bones[self.child[i]].reshape(1, 3, 1), | |
| p.reshape(-1, 3, 1), | |
| pose_global[:, self.parents[i]], | |
| twi, | |
| ) | |
| pose[:, i] = r | |
| pose_global[:, i] = rg | |
| if expand_dim: | |
| pose = pose[0] | |
| return pose | |
| class HybrIKJointsToRotmat_Tensor: | |
| def __init__(self): | |
| self.naive_hybrik = constants.SMPL_HYBRIK | |
| self.num_nodes = 22 | |
| self.parents = constants.SMPL_BODY_PARENTS | |
| self.child = constants.SMPL_BODY_CHILDS | |
| self.bones = torch.tensor(constants.SMPL_BODY_BONES).reshape(24, 3)[:self.num_nodes] | |
| def multi_child_rot(self, t, p, pose_global_parent): | |
| """ | |
| t: B x 3 x child_num | |
| p: B x 3 x child_num | |
| pose_global_parent: B x 3 x 3 | |
| """ | |
| m = torch.matmul( | |
| t, torch.transpose(torch.matmul(torch.inverse(pose_global_parent), p), 1, 2) | |
| ) | |
| u, s, vt = torch.linalg.svd(m) | |
| r = torch.matmul(torch.transpose(vt, 1, 2), torch.transpose(u, 1, 2)) | |
| err_det_mask = (torch.det(r) < 0.0).reshape(-1, 1, 1) | |
| id_fix = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]]).reshape(1, 3, 3) | |
| r_fix = torch.matmul( | |
| torch.transpose(vt, 1, 2), torch.matmul(id_fix, torch.transpose(u, 1, 2)) | |
| ) | |
| r = r * (~err_det_mask) + r_fix * err_det_mask | |
| return r, torch.matmul(pose_global_parent, r) | |
| def single_child_rot( | |
| self, | |
| t, | |
| p, | |
| pose_global_parent, | |
| twist = None, | |
| ) -> Tuple[torch.Tensor]: | |
| """ | |
| t: B x 3 x 1 | |
| p: B x 3 x 1 | |
| pose_global_parent: B x 3 x 3 | |
| twist: B x 2 if given, default to None | |
| """ | |
| t_tensor = t.clone().detach()#torch.tensor(t) | |
| p_tensor = p.clone().detach()#torch.tensor(p) | |
| pose_global_parent_tensor = pose_global_parent.clone().detach()#torch.tensor(pose_global_parent) | |
| p_rot = torch.matmul(torch.linalg.inv(pose_global_parent_tensor), p_tensor) | |
| cross = torch.cross(t_tensor, p_rot, dim=1) | |
| sina = torch.linalg.norm(cross, dim=1, keepdim=True) / ( | |
| torch.linalg.norm(t_tensor, dim=1, keepdim=True) | |
| * torch.linalg.norm(p_rot, dim=1, keepdim=True) | |
| ) | |
| cross = cross / torch.linalg.norm(cross, dim=1, keepdim=True) | |
| cosa = torch.sum(t_tensor * p_rot, dim=1, keepdim=True) / ( | |
| torch.linalg.norm(t_tensor, dim=1, keepdim=True) | |
| * torch.linalg.norm(p_rot, dim=1, keepdim=True) | |
| ) | |
| sina = sina.reshape(-1, 1, 1) | |
| cosa = cosa.reshape(-1, 1, 1) | |
| skew_sym_t = torch.stack( | |
| [ | |
| 0.0 * cross[:, 0], | |
| -cross[:, 2], | |
| cross[:, 1], | |
| cross[:, 2], | |
| 0.0 * cross[:, 0], | |
| -cross[:, 0], | |
| -cross[:, 1], | |
| cross[:, 0], | |
| 0.0 * cross[:, 0], | |
| ], | |
| 1, | |
| ) | |
| skew_sym_t = skew_sym_t.reshape(-1, 3, 3) | |
| dsw_rotmat = ( | |
| torch.eye(3).reshape(1, 3, 3) | |
| + sina * skew_sym_t | |
| + (1.0 - cosa) * torch.matmul(skew_sym_t, skew_sym_t) | |
| ) | |
| if twist is not None: | |
| twist_tensor = torch.tensor(twist) | |
| skew_sym_t = torch.stack( | |
| [ | |
| 0.0 * t_tensor[:, 0], | |
| -t_tensor[:, 2], | |
| t_tensor[:, 1], | |
| t_tensor[:, 2], | |
| 0.0 * t_tensor[:, 0], | |
| -t_tensor[:, 0], | |
| -t_tensor[:, 1], | |
| t_tensor[:, 0], | |
| 0.0 * t_tensor[:, 0], | |
| ], | |
| 1, | |
| ) | |
| skew_sym_t = skew_sym_t.reshape(-1, 3, 3) | |
| sina = twist_tensor[:, 1].reshape(-1, 1, 1) | |
| cosa = twist_tensor[:, 0].reshape(-1, 1, 1) | |
| dtw_rotmat = ( | |
| torch.eye(3).reshape([1, 3, 3]) | |
| + sina * skew_sym_t | |
| + (1.0 - cosa) * torch.matmul(skew_sym_t, skew_sym_t) | |
| ) | |
| dsw_rotmat = torch.matmul(dsw_rotmat, dtw_rotmat) | |
| return dsw_rotmat, torch.matmul(pose_global_parent_tensor, dsw_rotmat) | |
| def __call__(self, joints, twist = None) -> torch.Tensor: | |
| """ | |
| joints: B x N x 3 | |
| twist: B x N x 2 if given, default to None | |
| """ | |
| expand_dim = False | |
| if len(joints.shape) == 2: | |
| expand_dim = True | |
| joints = joints.unsqueeze(0) | |
| if twist is not None: | |
| twist = twist.unsqueeze(0) | |
| assert len(joints.shape) == 3 | |
| batch_size = joints.shape[0] | |
| joints_rel = joints - joints[:, self.parents] | |
| joints_hybrik = torch.zeros_like(joints_rel) | |
| pose_global = torch.zeros([batch_size, self.num_nodes, 3, 3]) | |
| pose = torch.zeros([batch_size, self.num_nodes, 3, 3]) | |
| for i in range(self.num_nodes): | |
| if i == 0: | |
| joints_hybrik[:, 0] = joints[:, 0] | |
| else: | |
| joints_hybrik[:, i] = ( | |
| torch.matmul( | |
| pose_global[:, self.parents[i]], | |
| self.bones[i].reshape(1, 3, 1), | |
| ).reshape(-1, 3) | |
| + joints_hybrik[:, self.parents[i]] | |
| ) | |
| if self.child[i] == -2: | |
| pose[:, i] = pose[:, i] + torch.eye(3).reshape(1, 3, 3) | |
| pose_global[:, i] = pose_global[:, self.parents[i]] | |
| continue | |
| if i == 0: | |
| t = self.bones[[1, 2, 3]].reshape(1, 3, 3).permute(0, 2, 1) | |
| p = joints_rel[:, [1, 2, 3]].permute(0, 2, 1) | |
| pose_global_parent = torch.eye(3).reshape(1, 3, 3) | |
| r, rg = self.multi_child_rot(t, p, pose_global_parent) | |
| elif i == 9: | |
| t = self.bones[[12, 13, 14]].reshape(1, 3, 3).permute(0, 2, 1) | |
| p = joints_rel[:, [12, 13, 14]].permute(0, 2, 1) | |
| r, rg = self.multi_child_rot(t, p, pose_global[:, self.parents[9]],) | |
| else: | |
| p = joints_rel[:, self.child[i]] | |
| if self.naive_hybrik[i] == 0: | |
| p = joints[:, self.child[i]] - joints_hybrik[:, i] | |
| twi = None | |
| if twist is not None: | |
| twi = twist[:, i] | |
| t = self.bones[self.child[i]].reshape(-1, 3, 1) | |
| p = p.reshape(-1, 3, 1) | |
| nframes, _, _ = p.shape | |
| t = t.repeat(nframes, 1, 1) | |
| r, rg = self.single_child_rot(t, p, pose_global[:, self.parents[i]], twi) | |
| pose[:, i] = r | |
| pose_global[:, i] = rg | |
| if expand_dim: | |
| pose = pose[0] | |
| return pose | |
| if __name__ == "__main__": | |
| jts2rot_hybrik = HybrIKJointsToRotmat_Tensor() | |
| joints = torch.tensor(constants.SMPL_BODY_BONES).reshape(1, 24, 3)[:, :22] | |
| parents = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19] | |
| for i in range(1, 22): | |
| joints[:, i] = joints[:, i] + joints[:, parents[i]] | |
| print(joints.shape) | |
| pose = jts2rot_hybrik(joints) | |
| print(pose.shape) | |