File size: 4,496 Bytes
c3d0293
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
import numpy as np
from torch import nn
import pickle as pkl
import torch.nn.functional as F

class Struct(object):
    """Lightweight namespace: every keyword argument becomes an attribute."""

    def __init__(self, **kwargs):
        for name in kwargs:
            setattr(self, name, kwargs[name])


def to_np(array, dtype=np.float32):
    """Return ``array`` converted to a NumPy ndarray of ``dtype``.

    Objects whose type name contains 'scipy.sparse' are densified with
    ``todense()`` before conversion; anything else goes straight to
    ``np.array``.
    """
    if 'scipy.sparse' not in str(type(array)):
        return np.array(array, dtype=dtype)
    return np.array(array.todense(), dtype=dtype)


class Get_Joints(nn.Module):
    """Compute posed 3D joint locations from per-joint axis-angle rotations.

    The constructor loads a pickled model dictionary providing
    ``v_template``, ``shapedirs``, ``J_regressor``, ``posedirs`` and
    ``kintree_table``. ``forward`` blends the template with the shape
    coefficients in ``self.betas``, regresses rest-pose joints from the
    resulting vertices, and poses them along the kinematic tree.

    Every tensor is stored as a frozen ``nn.Parameter``
    (``requires_grad=False``), so it follows ``.to(device)`` and appears in
    ``state_dict`` under its attribute name.
    """

    def __init__(self, path, batch_size=300) -> None:
        """Load the model data.

        Args:
            path: Path to the pickled model dictionary.
            batch_size: Number of rows of the (all-zero) ``betas`` table,
                i.e. the largest batch ``forward`` can process.
        """
        super().__init__()
        # Shape coefficients, shape (batch_size, 10); zero unless overwritten.
        self.betas = nn.parameter.Parameter(torch.zeros([batch_size, 10], dtype=torch.float32), requires_grad=False)
        # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
        with open(path, "rb") as f:
            smpl_prior = pkl.load(f, encoding="latin1")
        data_struct = Struct(**smpl_prior)

        self.v_template = nn.parameter.Parameter(torch.from_numpy(to_np(data_struct.v_template)), requires_grad=False)
        self.shapedirs = nn.parameter.Parameter(torch.from_numpy(to_np(data_struct.shapedirs)), requires_grad=False)
        self.J_regressor = nn.parameter.Parameter(torch.from_numpy(to_np(data_struct.J_regressor)), requires_grad=False)
        # Flattened to (num_pose_basis, -1); stored but not used by forward().
        posedirs = torch.from_numpy(to_np(data_struct.posedirs))
        num_pose_basis = posedirs.shape[-1]
        posedirs = posedirs.reshape([-1, num_pose_basis]).permute(1, 0)
        self.posedirs = nn.parameter.Parameter(posedirs, requires_grad=False)
        # Row 0 of the kinematic tree holds each joint's parent index.
        self.parents = nn.parameter.Parameter(torch.from_numpy(to_np(data_struct.kintree_table)[0]).long(), requires_grad=False)
        # The root joint has no parent.
        self.parents[0] = -1

        self.ident = nn.parameter.Parameter(torch.eye(3), requires_grad=False)
        # Not read anywhere; kept so existing state_dicts still load strictly.
        self.K = nn.parameter.Parameter(torch.zeros([1, 3, 3]), requires_grad=False)
        # Zero column used to build skew-symmetric matrices on the module's device.
        self.zeros = nn.parameter.Parameter(torch.zeros([1, 1]), requires_grad=False)

    def blend_shapes(self, betas, shape_disps):
        """Return per-sample shape offsets: sum_l betas[b, l] * shape_disps[m, k, l].

        Args:
            betas: (B, L) coefficients.
            shape_disps: (M, K, L) displacement basis.

        Returns:
            (B, M, K) offsets.
        """
        blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps])
        return blend_shape

    def vertices2joints(self, J_regressor, vertices):
        """Regress joints as linear combinations of vertices.

        Args:
            J_regressor: (J, V) regression weights.
            vertices: (B, V, K) vertex positions.

        Returns:
            (B, J, K) joint positions.
        """
        return torch.einsum('bik,ji->bjk', [vertices, J_regressor])

    def batch_rodrigues(
        self,
        rot_vecs,
        epsilon=1e-8,
    ):
        """Convert axis-angle vectors to rotation matrices (Rodrigues' formula).

        Args:
            rot_vecs: (N, 3) tensor; its direction is the rotation axis and
                its norm the rotation angle.
            epsilon: Added to each component before taking the norm so that
                a zero vector does not divide by zero (it yields identity).

        Returns:
            (N, 3, 3) tensor R = I + sin(a) K + (1 - cos(a)) K @ K, where K is
            the skew-symmetric matrix of the unit axis.
        """
        batch_size = rot_vecs.shape[0]
        angle = torch.norm(rot_vecs + epsilon, dim=1, keepdim=True)  # (N, 1)
        rot_dir = rot_vecs / angle
        cos = torch.unsqueeze(torch.cos(angle), dim=1)  # (N, 1, 1)
        sin = torch.unsqueeze(torch.sin(angle), dim=1)
        # Bx1 arrays
        rx, ry, rz = torch.split(rot_dir, 1, dim=1)
        zeros = self.zeros.repeat(batch_size, 1)
        K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view((batch_size, 3, 3))
        ident = self.ident.unsqueeze(0)
        rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
        return rot_mat

    def transform_mat(self, R, t):
        """Build homogeneous transforms [[R, t], [0, 1]].

        Args:
            R: (N, 3, 3) rotations.
            t: (N, 3, 1) translations.

        Returns:
            (N, 4, 4) tensor.
        """
        return torch.cat([F.pad(R, [0, 0, 0, 1]),
                          F.pad(t, [0, 0, 0, 1], value=1)], dim=2)

    def batch_rigid_transform(
        self,
        rot_mats,
        joints,
        parents,
    ):
        """Pose rest joints by chaining local rigid transforms along the tree.

        Args:
            rot_mats: (B, J, 3, 3) per-joint local rotations.
            joints: (B, J, 3) rest-pose joint positions.
            parents: (J,) parent indices; parents[0] is the root. Assumes each
                parent index is smaller than its child's (the chain is built
                in index order).

        Returns:
            (B, J, 3) posed joint positions.
        """
        joints = torch.unsqueeze(joints, dim=-1)  # (B, J, 3, 1)

        # Local translations: each joint's offset from its parent at rest.
        rel_joints = joints.clone()
        rel_joints[:, 1:] -= joints[:, parents[1:]]

        transforms_mat = self.transform_mat(
            rot_mats.reshape(-1, 3, 3),
            rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4)

        # Convert once to Python ints rather than reading the (possibly GPU)
        # tensor every time a list entry is looked up.
        parent_ids = parents.tolist()
        transform_chain = [transforms_mat[:, 0]]
        for i in range(1, len(parent_ids)):
            # World transform = parent's world transform @ this joint's local one.
            curr_res = torch.matmul(transform_chain[parent_ids[i]],
                                    transforms_mat[:, i])
            transform_chain.append(curr_res)

        transforms = torch.stack(transform_chain, dim=1)

        # The last column of the transformations contains the posed joints
        posed_joints = transforms[:, :, :3, 3]
        return posed_joints

    def forward(self, pose, trans=None):
        """Return posed joint positions.

        Args:
            pose: Tensor whose first dim is the batch; the remaining entries
                are read three at a time as one axis-angle rotation per joint.
                Need not be contiguous.
            trans: Optional translation added to every joint after
                ``unsqueeze(1)`` (e.g. shape (batch, 3)).

        Returns:
            (batch, num_joints, 3) tensor.

        Raises:
            ValueError: if the batch is larger than the number of rows in
                ``self.betas`` (the ``batch_size`` given at construction).
        """
        pose = pose.float()
        batch = pose.shape[0]
        if batch > self.betas.shape[0]:
            raise ValueError(
                f"pose batch size {batch} exceeds the {self.betas.shape[0]} rows of betas; "
                "construct Get_Joints with a larger batch_size"
            )
        betas = self.betas[:batch]
        v_shaped = self.v_template + self.blend_shapes(betas, self.shapedirs)
        J = self.vertices2joints(self.J_regressor, v_shaped)
        # reshape (not view) so non-contiguous pose tensors are accepted.
        rot_mats = self.batch_rodrigues(pose.reshape(-1, 3)).reshape(batch, -1, 3, 3)
        J_transformed = self.batch_rigid_transform(rot_mats, J, self.parents)
        if trans is not None:
            J_transformed += trans.unsqueeze(dim=1)
        return J_transformed