Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*- | |
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is | |
# holder of all proprietary rights on this computer program. | |
# You can only use this computer program if you have closed | |
# a license agreement with MPG or you get the right to use the computer | |
# program from someone who is authorized to grant you that right. | |
# Any use of the computer program without a valid license is prohibited and | |
# liable to prosecution. | |
# | |
# Copyright©2020 Max-Planck-Gesellschaft zur Förderung | |
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute | |
# for Intelligent Systems. All rights reserved. | |
# | |
# Contact: Vassilis Choutas, [email protected] | |
import sys | |
from typing import NewType, List, Dict, Optional | |
import os | |
import os.path as osp | |
import pickle | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from omegaconf import OmegaConf | |
from loguru import logger | |
from SMPLX.transfer_model.utils.typing import Tensor | |
def rotation_matrix_to_cont_repr(x: Tensor) -> Tensor:
    ''' Converts rotation matrices to the continuous 6D representation

    Parameters
    ----------
    x: torch.tensor Bx3x3
        Batch of rotation matrices; only the top-left 3x2 block is read

    Returns
    -------
    torch.tensor Bx3x2
        The first two columns of every matrix. This is a slice of ``x``
        and therefore shares memory with the input.
    '''
    assert x.ndim == 3, (
        f'Expects an array of size Bx3x3, but received {x.shape}')
    first_rows, first_cols = slice(None, 3), slice(None, 2)
    return x[:, first_rows, first_cols]
def cont_repr_to_rotation_matrix(
    x: Tensor
) -> Tensor:
    ''' Converts tensor in continuous representation to rotation matrices

    Every consecutive group of 6 values is read row-major as a 3x2 matrix.
    Its two columns are orthonormalized with Gram-Schmidt, and the third
    column is their cross product.

    Parameters
    ----------
    x: torch.tensor
        Tensor whose first dimension is the batch size and whose number of
        elements is a multiple of 6 (for example Bx6, BxJx6 or Bx3x2).
        It does not need to be contiguous.

    Returns
    -------
    torch.tensor Bx(-1)x3x3
        Rotation matrices, grouped per batch element
    '''
    batch_size = x.shape[0]
    # reshape (not view): a tensor sliced along a middle dimension cannot
    # have its leading dimensions merged without a copy, and view() raises
    reshaped_input = x.reshape(-1, 3, 2)
    # Normalize the first vector
    b1 = F.normalize(reshaped_input[:, :, 0], dim=1)
    # Compute the second vector by removing its projection onto b1
    dot_prod = torch.sum(b1 * reshaped_input[:, :, 1], dim=1, keepdim=True)
    b2 = F.normalize(reshaped_input[:, :, 1] - dot_prod * b1, dim=1)
    # Finish building the basis by taking the cross product
    b3 = torch.cross(b1, b2, dim=1)
    # b1, b2, b3 become the columns of each rotation matrix
    rot_mats = torch.stack([b1, b2, b3], dim=-1)
    return rot_mats.view(batch_size, -1, 3, 3)
def batch_rodrigues(
    rot_vecs: Tensor,
    epsilon: float = 1e-8
) -> Tensor:
    ''' Calculates the rotation matrices for a batch of rotation vectors

    Applies Rodrigues' formula R = I + sin(a) K + (1 - cos(a)) K^2, where
    a is the rotation angle and K the cross-product matrix of the axis.

    Parameters
    ----------
    rot_vecs: torch.tensor Nx3
        array of N axis-angle vectors
    epsilon: float, optional
        Offset added to every component of ``rot_vecs`` before taking the
        norm, so that zero vectors do not divide by zero (default: 1e-8)

    Returns
    -------
    R: torch.tensor Nx3x3
        The rotation matrices for the given axis-angle parameters
    '''
    assert len(rot_vecs.shape) == 2, (
        f'Expects an array of size Bx3, but received {rot_vecs.shape}')

    batch_size = rot_vecs.shape[0]
    device = rot_vecs.device
    dtype = rot_vecs.dtype

    # NOTE: epsilon is added to the components, not to the norm. A zero
    # vector then yields rot_dir == 0 (and R == I) instead of 0 / 0, at the
    # cost of slightly perturbing the angle and leaving rot_dir only
    # approximately unit length.
    angle = torch.norm(rot_vecs + epsilon, dim=1, keepdim=True, p=2)
    rot_dir = rot_vecs / angle

    # Bx1x1, so they broadcast against the Bx3x3 matrices below
    cos = torch.unsqueeze(torch.cos(angle), dim=1)
    sin = torch.unsqueeze(torch.sin(angle), dim=1)

    # Bx1 arrays
    rx, ry, rz = torch.split(rot_dir, 1, dim=1)
    zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
    # Skew-symmetric cross-product matrix of the axis, filled row by row
    K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \
        .view((batch_size, 3, 3))

    ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
    rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
    return rot_mat
def batch_rot2aa(
    Rs: Tensor, epsilon: float = 1e-7
) -> Tensor:
    ''' Converts a batch of rotation matrices to axis-angle vectors

    The angle is computed with atan2(2 sin(theta), 2 cos(theta)), which
    stays accurate near 0 and near pi. (The previous acos of a cosine
    clamped to +-(1 - epsilon) could not return angles below ~4.5e-4 rad.)

    Parameters
    ----------
    Rs: torch.tensor Bx3x3
        Batch of rotation matrices
    epsilon: float, optional
        Threshold on |vee(R - R^T)| = 2 sin(theta) below which the ratio
        theta / (2 sin(theta)) is taken from its Taylor series instead of
        being computed by division (default: 1e-7)

    Returns
    -------
    torch.tensor Bx3
        Axis-angle vectors theta * axis, with theta in [0, pi]
    '''
    # m = vee(R - R^T) = 2 * sin(theta) * axis
    m = torch.stack([
        Rs[:, 2, 1] - Rs[:, 1, 2],
        Rs[:, 0, 2] - Rs[:, 2, 0],
        Rs[:, 1, 0] - Rs[:, 0, 1],
    ], dim=1)
    two_sin = torch.norm(m, dim=1)
    # trace(R) = 1 + 2 * cos(theta)
    two_cos = torch.diagonal(Rs, dim1=1, dim2=2).sum(dim=-1) - 1
    theta = torch.atan2(two_sin, two_cos)

    # Regular case: aa = theta / (2 sin(theta)) * m. The ratio tends to 1/2
    # as theta -> 0, so its series 1/2 + theta^2 / 12 is used for tiny
    # angles. clamp_min keeps the unselected branch finite, so torch.where
    # does not propagate NaN/inf gradients.
    ratio = torch.where(
        two_sin < epsilon,
        0.5 + theta ** 2 / 12,
        theta / two_sin.clamp_min(epsilon))
    aa_regular = ratio.unsqueeze(1) * m

    # theta > pi/2: m ~ sin(theta) vanishes as theta -> pi, so the axis is
    # read from the symmetric part instead:
    #   (R + R^T) / 2 - cos(theta) * I = (1 - cos(theta)) * axis axis^T
    eye = torch.eye(3, dtype=Rs.dtype, device=Rs.device)
    outer = 0.5 * (Rs + Rs.transpose(1, 2)) - \
        0.5 * two_cos.view(-1, 1, 1) * eye
    # Column j equals (1 - cos) * axis_j * axis; the one with the largest
    # diagonal entry is the best conditioned
    col_idx = torch.diagonal(outer, dim1=1, dim2=2).argmax(dim=1)
    column = outer.gather(
        2, col_idx.view(-1, 1, 1).expand(-1, 3, 1)).squeeze(2)
    axis = F.normalize(column, dim=1)
    # The column fixes the axis only up to sign; since sin(theta) >= 0 the
    # axis must point along m (either sign is valid when theta == pi)
    axis = torch.where(
        torch.sum(axis * m, dim=1, keepdim=True) < 0, -axis, axis)
    aa_large = theta.unsqueeze(1) * axis

    # Both formulas agree at theta == pi/2, so the switch is continuous
    return torch.where((two_cos < 0).unsqueeze(1), aa_large, aa_regular)