kleinhe
init
c3d0293
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2020 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: Vassilis Choutas, [email protected]
import sys
from typing import NewType, List, Dict, Optional
import os
import os.path as osp
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from loguru import logger
from SMPLX.transfer_model.utils.typing import Tensor
def rotation_matrix_to_cont_repr(x: Tensor) -> Tensor:
    ''' Converts rotation matrices to the continuous representation

        The continuous representation keeps the first two columns of each
        matrix and drops the third one.

        Parameters
        ----------
        x: torch.tensor Bx3x3
            The batch of rotation matrices

        Returns
        -------
        cont_repr: torch.tensor Bx3x2
            The first two columns of every matrix in the batch
    '''
    num_dims = len(x.shape)
    assert num_dims == 3, (
        f'Expects an array of size Bx3x3, but received {x.shape}')
    # Keep (at most) the first three rows and the first two columns
    first_two_columns = x[:, 0:3, 0:2]
    return first_two_columns
def cont_repr_to_rotation_matrix(
    x: Tensor
) -> Tensor:
    ''' Converts tensors in continuous representation to rotation matrices

        The two 3D vectors stored for every rotation are orthonormalized
        with the Gram-Schmidt process and the third basis vector is obtained
        with a cross product.

        Parameters
        ----------
        x: torch.tensor
            Tensor whose first dimension is the batch dimension and whose
            remaining elements can be grouped into 3x2 blocks, e.g. Bx6,
            BxJx6 or BxJx3x2. Non-contiguous tensors are accepted.

        Returns
        -------
        rot_mats: torch.tensor BxJx3x3
            The rotation matrices, whose columns are the orthonormal basis
            vectors built from each 3x2 block
    '''
    batch_size = x.shape[0]
    # reshape (instead of view) also handles non-contiguous inputs, e.g.
    # tensors that were transposed or permuted by the caller
    reshaped_input = x.reshape(-1, 3, 2)
    first_col = reshaped_input[:, :, 0]
    second_col = reshaped_input[:, :, 1]

    # Normalize the first vector
    b1 = F.normalize(first_col, dim=1)

    # Projection of the second vector on the first basis vector
    dot_prod = torch.sum(b1 * second_col, dim=1, keepdim=True)
    # Compute the second vector by removing the component along b1
    b2 = F.normalize(second_col - dot_prod * b1, dim=1)
    # Finish building the basis by taking the cross product
    b3 = torch.cross(b1, b2, dim=1)

    # The basis vectors become the columns of the rotation matrices
    rot_mats = torch.stack([b1, b2, b3], dim=-1)
    return rot_mats.view(batch_size, -1, 3, 3)
def batch_rodrigues(
    rot_vecs: Tensor,
    epsilon: float = 1e-8
) -> Tensor:
    ''' Calculates the rotation matrices for a batch of rotation vectors

        Parameters
        ----------
        rot_vecs: torch.tensor Nx3
            array of N axis-angle vectors
        epsilon: float, optional
            Offset added to every component of the vectors before their
            norm is computed, so that the division by the angle stays
            finite for all-zero rotation vectors

        Returns
        -------
        R: torch.tensor Nx3x3
            The rotation matrices for the given axis-angle parameters
    '''
    assert len(rot_vecs.shape) == 2, (
        f'Expects an array of size Bx3, but received {rot_vecs.shape}')

    batch_size = rot_vecs.shape[0]
    device = rot_vecs.device
    dtype = rot_vecs.dtype

    # Rotation angle (Bx1) and unit rotation axis (Bx3)
    angle = torch.norm(rot_vecs + epsilon, dim=1, keepdim=True, p=2)
    rot_dir = rot_vecs / angle

    # Bx1x1 arrays, so that they broadcast against the Bx3x3 matrices
    cos = torch.unsqueeze(torch.cos(angle), dim=1)
    sin = torch.unsqueeze(torch.sin(angle), dim=1)

    # Bx1 arrays
    rx, ry, rz = torch.split(rot_dir, 1, dim=1)
    zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
    # Skew-symmetric cross product matrix of the rotation axis
    K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \
        .view((batch_size, 3, 3))

    # Rodrigues' formula: R = I + sin(angle) * K + (1 - cos(angle)) * K^2
    ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
    rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
    return rot_mat
def batch_rot2aa(
    Rs: Tensor, epsilon: float = 1e-7
) -> Tensor:
    ''' Converts a batch of rotation matrices to axis-angle vectors

        The angle is recovered from the trace of each matrix. The axis is
        recovered from the skew-symmetric part R - R^T, which is
        proportional to sin(angle). For rotations close to 180 degrees that
        part vanishes, so there the axis is taken from the symmetric part
        of the matrix instead.

        Parameters
        ----------
        Rs: torch.tensor Bx3x3
            The batch of rotation matrices
        epsilon: float, optional
            Margin used to clamp the cosine of the angle inside (-1, 1)
            before calling acos, also added under the square root that
            normalizes the skew-symmetric part

        Returns
        -------
        aa: torch.tensor Bx3
            The rotation axes scaled by the rotation angles. Because of the
            clamping, the angle never reaches exactly 0 or pi.
    '''
    # trace(R) = 1 + 2 * cos(theta)
    cos = 0.5 * (torch.einsum('bii->b', [Rs]) - 1)
    cos = torch.clamp(cos, -1 + epsilon, 1 - epsilon)
    theta = torch.acos(cos)

    # Components of R - R^T, which equal 2 * sin(theta) * axis
    m21 = Rs[:, 2, 1] - Rs[:, 1, 2]
    m02 = Rs[:, 0, 2] - Rs[:, 2, 0]
    m10 = Rs[:, 1, 0] - Rs[:, 0, 1]
    skew_sq_norm = m21 * m21 + m02 * m02 + m10 * m10
    denom = torch.sqrt(skew_sq_norm + epsilon)

    tiny_angle = torch.abs(theta) < 0.00001
    axis0 = torch.where(tiny_angle, m21, m21 / denom)
    axis1 = torch.where(tiny_angle, m02, m02 / denom)
    axis2 = torch.where(tiny_angle, m10, m10 / denom)
    skew = torch.stack([m21, m02, m10], 1)
    skew_axis = torch.stack([axis0, axis1, axis2], 1)

    # Close to 180 degrees sin(theta) -> 0 and the skew-symmetric part no
    # longer determines the axis. Use the symmetric part instead:
    # (R + R^T) / 2 - cos(theta) * I = (1 - cos(theta)) * axis * axis^T
    sym = 0.5 * (Rs + Rs.transpose(1, 2))
    ident = torch.eye(3, dtype=Rs.dtype, device=Rs.device).unsqueeze(dim=0)
    outer = sym - cos.view(-1, 1, 1) * ident
    # The row with the largest diagonal entry is the best conditioned one;
    # it is proportional to the axis, up to a sign
    strongest = torch.diagonal(outer, dim1=1, dim2=2).argmax(dim=1)
    batch_idx = torch.arange(Rs.shape[0], device=Rs.device)
    sym_axis = F.normalize(outer[batch_idx, strongest], dim=1)
    # Resolve the sign with whatever is left of the skew-symmetric part
    flip = torch.sum(sym_axis * skew, dim=1, keepdim=True) < 0
    sym_axis = torch.where(flip, -sym_axis, sym_axis)

    # Only switch to the symmetric estimate where the skew-symmetric part
    # is too small to be reliable, all other rotations are left untouched
    near_pi = (cos < 0) & (skew_sq_norm < 1e-4)
    axis = torch.where(near_pi.unsqueeze(1), sym_axis, skew_axis)
    return theta.unsqueeze(1) * axis