# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2020 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de

import sys
from typing import NewType, List, Dict, Optional
import os
import os.path as osp
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F

from omegaconf import OmegaConf
from loguru import logger

from SMPLX.transfer_model.utils.typing import Tensor


def rotation_matrix_to_cont_repr(x: Tensor) -> Tensor:
    ''' Converts a batch of rotation matrices to the continuous repr.

        The continuous representation keeps only the first two columns of
        each matrix; `cont_repr_to_rotation_matrix` rebuilds the third one.

        Parameters
        ----------
        x: torch.tensor Bx3x3
            Batch of rotation matrices
        Returns
        -------
        torch.tensor Bx3x2
            A view (not a copy) of the first two columns of each matrix
    '''
    assert len(x.shape) == 3, (
        f'Expects an array of size Bx3x3, but received {x.shape}')
    return x[:, :3, :2]


def cont_repr_to_rotation_matrix(
    x: Tensor
) -> Tensor:
    ''' Converts tensor in continuous representation to rotation matrices

        Parameters
        ----------
        x: torch.tensor
            Tensor whose first dimension is the batch size B. All remaining
            elements are read in groups of six, each group interpreted as a
            3x2 matrix whose columns are two (not necessarily orthonormal)
            vectors spanning the first two columns of a rotation.
        Returns
        -------
        torch.tensor BxNx3x3
            Rotation matrices, with N = (number of elements) / (6 * B)
    '''
    batch_size = x.shape[0]
    # reshape (unlike view) also accepts inputs whose memory layout view
    # cannot reinterpret, e.g. transposed tensors; it still returns a view
    # whenever one is possible.
    reshaped_input = x.reshape(-1, 3, 2)
    first_col = reshaped_input[:, :, 0]
    second_col = reshaped_input[:, :, 1]

    # Gram-Schmidt: normalize the first vector ...
    b1 = F.normalize(first_col, dim=1)
    dot_prod = torch.sum(b1 * second_col, dim=1, keepdim=True)
    # ... remove its component from the second vector and normalize
    b2 = F.normalize(second_col - dot_prod * b1, dim=1)
    # Finish building the right-handed basis with the cross product
    b3 = torch.cross(b1, b2, dim=1)

    # The basis vectors become the columns of the rotation matrix
    rot_mats = torch.stack([b1, b2, b3], dim=-1)
    return rot_mats.view(batch_size, -1, 3, 3)


def batch_rodrigues(
    rot_vecs: Tensor,
    epsilon: float = 1e-8
) -> Tensor:
    ''' Calculates the rotation matrices for a batch of rotation vectors

        Parameters
        ----------
        rot_vecs: torch.tensor Nx3
            array of N axis-angle vectors
        epsilon: float, optional
            Added to every component of `rot_vecs` before taking the norm so
            that the angle (the divisor below) is non-zero for all-zero
            vectors.
        Returns
        -------
        R: torch.tensor Nx3x3
            The rotation matrices for the given axis-angle parameters
    '''
    assert len(rot_vecs.shape) == 2, (
        f'Expects an array of size Bx3, but received {rot_vecs.shape}')

    batch_size = rot_vecs.shape[0]
    device = rot_vecs.device
    dtype = rot_vecs.dtype

    # For a zero vector the shifted norm is sqrt(3) * epsilon > 0, and
    # rot_dir becomes exactly zero, giving the identity below.
    # NOTE(review): a vector whose components all equal -epsilon still
    # yields a zero angle (division by zero).
    angle = torch.norm(rot_vecs + epsilon, dim=1, keepdim=True, p=2)
    rot_dir = rot_vecs / angle

    # Bx1x1 so they broadcast against the Bx3x3 matrices below
    cos = torch.unsqueeze(torch.cos(angle), dim=1)
    sin = torch.unsqueeze(torch.sin(angle), dim=1)

    # Bx1 arrays
    rx, ry, rz = torch.split(rot_dir, 1, dim=1)
    zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
    # Skew-symmetric cross-product matrix of the rotation axis
    K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \
        .view((batch_size, 3, 3))

    ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
    # Rodrigues' formula: R = I + sin(t) K + (1 - cos(t)) K^2
    rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
    return rot_mat


def batch_rot2aa(
    Rs: Tensor, epsilon: float = 1e-7
) -> Tensor:
    """ Converts a batch of rotation matrices to axis-angle vectors

        Inverse of `batch_rodrigues`. For R = I + sin(t) K + (1 - cos(t)) K^2
        with K the skew matrix of the unit axis a:
          * the trace gives    cos(t) = (tr(R) - 1) / 2
          * R - R^T = 2 sin(t) K gives m = (R21 - R12, R02 - R20, R10 - R01)
            = 2 sin(t) a, hence ||m|| = 2 sin(t) for t in [0, pi].
        The angle is recovered with atan2(||m|| / 2, cos(t)), which stays
        accurate for small angles where acos of a cosine close to 1 does
        not, and the result is t * a = t / (2 sin(t)) * m.

        Parameters
        ----------
        Rs: torch.tensor Bx3x3
            Batch of rotation matrices
        epsilon: float, optional
            When ||m|| = 2 sin(t) drops below this value the limit
            t / (2 sin(t)) -> 1/2 (exact at t = 0) is used instead of the
            division.
        Returns
        -------
        torch.tensor Bx3
            Axis-angle vectors with angle in [0, pi]

        Notes
        -----
        Near t = pi the antisymmetric part vanishes, so the axis cannot be
        recovered reliably from `m` there.
        TODO(review): handle the near-pi case from the symmetric part.
    """
    cos = 0.5 * (torch.einsum('bii->b', [Rs]) - 1)

    m21 = Rs[:, 2, 1] - Rs[:, 1, 2]
    m02 = Rs[:, 0, 2] - Rs[:, 2, 0]
    m10 = Rs[:, 1, 0] - Rs[:, 0, 1]
    # Equals 2 * sin(theta) * axis for a proper rotation matrix
    scaled_axis = torch.stack([m21, m02, m10], dim=1)
    two_sin = torch.norm(scaled_axis, dim=1)

    theta = torch.atan2(0.5 * two_sin, cos)

    # theta / (2 sin(theta)); clamp keeps the unused branch of torch.where
    # finite when two_sin is (near) zero.
    tiny = two_sin < epsilon
    scale = torch.where(
        tiny,
        torch.full_like(theta, 0.5),
        theta / two_sin.clamp_min(epsilon))

    return scale.unsqueeze(1) * scaled_axis