File size: 2,549 Bytes
c3d0293
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2020 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: Vassilis Choutas, [email protected]

import os
import os.path as osp
import pickle

import numpy as np
import torch
from loguru import logger

from SMPLX.transfer_model.utils.typing import Tensor


def _to_dense(matrix):
    ''' Returns a dense NumPy array for objects exposing ``todense`` (e.g.
        sparse matrices); any other input is returned unchanged.
    '''
    if hasattr(matrix, 'todense'):
        # ``todense`` may return an ``np.matrix``; normalize to a plain array.
        return np.asarray(matrix.todense())
    return matrix


def read_deformation_transfer(
    deformation_transfer_path: str,
    device=None,
    use_normal: bool = False,
) -> Tensor:
    ''' Reads a deformation transfer matrix from a pickle file.

        Parameters
        ----------
        deformation_transfer_path: str
            Path to a pickle file containing a dict with either an ``'mtx'``
            or a ``'matrix'`` entry.
        device: torch.device, optional
            Device of the returned tensor. Defaults to the CPU.
        use_normal: bool, optional
            Only used for the ``'mtx'`` entry. When False (default), only the
            first half of the matrix columns is kept; when True, all columns
            are kept.

        Returns
        -------
        Tensor
            The deformation transfer matrix as a float32 tensor on `device`.

        Raises
        ------
        AssertionError
            If `deformation_transfer_path` does not exist.
        KeyError
            If the loaded dict contains neither ``'mtx'`` nor ``'matrix'``.
    '''
    if device is None:
        device = torch.device('cpu')
    assert osp.exists(deformation_transfer_path), (
        'Deformation transfer path does not exist:'
        f' {deformation_transfer_path}')
    logger.info(
        f'Loading deformation transfer from: {deformation_transfer_path}')
    # NOTE(security): pickle.load can execute arbitrary code; only load
    # deformation transfer files that come from a trusted source.
    with open(deformation_transfer_path, 'rb') as f:
        # 'latin1' lets Python 3 unpickle NumPy data written by Python 2.
        def_transfer_setup = pickle.load(f, encoding='latin1')
    if 'mtx' in def_transfer_setup:
        def_matrix = np.array(
            _to_dense(def_transfer_setup['mtx']), dtype=np.float32)
        if not use_normal:
            # Keep only the first half of the columns; the second half is
            # presumably the normal-related part (see `use_normal`) — TODO
            # confirm against the file format.
            num_verts = def_matrix.shape[1] // 2
            def_matrix = def_matrix[:, :num_verts]
    elif 'matrix' in def_transfer_setup:
        # Densify here as well so sparse matrices stored under this key do
        # not make the tensor conversion below fail.
        def_matrix = _to_dense(def_transfer_setup['matrix'])
    else:
        valid_keys = ['mtx', 'matrix']
        raise KeyError(
            f'Deformation transfer setup must contain one of {valid_keys}')

    if isinstance(def_matrix, torch.Tensor):
        # torch.tensor() on an existing tensor warns; make the detached copy
        # explicit instead (same result: a new, grad-free float32 tensor).
        return def_matrix.detach().to(
            device=device, dtype=torch.float32, copy=True)
    return torch.tensor(def_matrix, device=device, dtype=torch.float32)


def apply_deformation_transfer(
    def_matrix: Tensor,
    vertices: Tensor,
    faces: Tensor,
    use_normals=False
) -> Tensor:
    ''' Maps a batch of vertices through a deformation transfer matrix.

        Every output vertex is a linear combination of the input vertices,
        weighted by one row of `def_matrix`:
        ``out[b, m, i] = sum_n def_matrix[m, n] * vertices[b, n, i]``.

        Parameters
        ----------
        def_matrix: Tensor
            Matrix of shape (M, N).
        vertices: Tensor
            Batched vertices of shape (B, N, I).
        faces: Tensor
            Not used by the current implementation.
        use_normals: bool, optional
            Normal-based transfer is not implemented; a truthy value raises.

        Returns
        -------
        Tensor
            Transferred vertices of shape (B, M, I).

        Raises
        ------
        NotImplementedError
            If `use_normals` is truthy.
    '''
    # Guard: only the vertex-only transfer path exists.
    if use_normals:
        raise NotImplementedError
    return torch.einsum('mn,bni->bmi', def_matrix, vertices)