import torch.nn as nn
import torch
import numpy as np

from .skeleton_DME import SkeletonConv, SkeletonPool, find_neighbor, build_edge_topology
from .skeleton import SkeletonResidual
from .decoders import VQDecoderV3
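
# This module defines the skeleton-aware motion VAE components:
#   - LocalEncoder: a skeleton-convolutional encoder built from SkeletonConv,
#     SkeletonPool and SkeletonResidual blocks over the kinematic edge topology.
#   - reparameterize: the standard VAE reparameterization trick.
#   - VAEConv / VAESKConv: encoder-decoder wrappers; VAESKConv derives its edge
#     topology from the SMPL-X skeleton and decodes with VQDecoderV3.
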
class LocalEncoder(nn.Module):
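    """Skeleton-aware convolutional encoder over the kinematic edge topology.

    Each level applies a strided skeleton convolution (halving the temporal
    resolution) followed by skeleton pooling that merges neighbouring edges.
    """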

    def __init__(self, args, topology):
        super(LocalEncoder, self).__init__()
        # Fixed encoder hyperparameters; only the layer count and growth schedule
        # are taken from the config (args.vae_layer / args.vae_grow).
        args.channel_base = 6
        args.activation = "tanh"
        args.use_residual_blocks = True
        args.z_dim = 1024
        args.temporal_scale = 8
        args.kernel_size = 4
        args.num_layers = args.vae_layer
        args.skeleton_dist = 2
        args.extra_conv = 0
        args.padding_mode = "constant"
        args.skeleton_pool = "mean"
        args.upsampling = "linear"

        # Per-level bookkeeping of topologies, channel widths and pooling maps.
        self.topologies = [topology]
        self.channel_base = [args.channel_base]
        self.channel_list = []
        self.edge_num = [len(topology)]
        self.pooling_list = []
        self.layers = nn.ModuleList()
        self.args = args

        kernel_size = args.kernel_size
        kernel_even = kernel_size % 2 == 0  # even kernels are shrunk by one for the stride-1 convs below
        padding = (kernel_size - 1) // 2
        bias = True
        self.grow = args.vae_grow
        # Channel width per edge grows by args.vae_grow[i] at each encoder level.
        for i in range(args.num_layers):
            self.channel_base.append(self.channel_base[-1] * self.grow[i])

        for i in range(args.num_layers):
            seq = []
            neighbour_list = find_neighbor(self.topologies[i], args.skeleton_dist)
            in_channels = self.channel_base[i] * self.edge_num[i]
            out_channels = self.channel_base[i + 1] * self.edge_num[i]
            if i == 0:
                self.channel_list.append(in_channels)
            self.channel_list.append(out_channels)
            last_pool = i == args.num_layers - 1

            # Skeleton pooling merges neighbouring edges, coarsening the topology.
            pool = SkeletonPool(
                edges=self.topologies[i],
                pooling_mode=args.skeleton_pool,
                channels_per_edge=out_channels // len(neighbour_list),
                last_pool=last_pool,
            )

            if args.use_residual_blocks:
                # Residual skeleton block: strided convolution and pooling in one module.
                seq.append(
                    SkeletonResidual(
                        self.topologies[i],
                        neighbour_list,
                        joint_num=self.edge_num[i],
                        in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=kernel_size,
                        stride=2,
                        padding=padding,
                        padding_mode=args.padding_mode,
                        bias=bias,
                        extra_conv=args.extra_conv,
                        pooling_mode=args.skeleton_pool,
                        activation=args.activation,
                        last_pool=last_pool,
                    )
                )
            else:
                # Optional extra stride-1 skeleton convolutions before downsampling.
                for _ in range(args.extra_conv):
                    seq.append(
                        SkeletonConv(
                            neighbour_list,
                            in_channels=in_channels,
                            out_channels=in_channels,
                            joint_num=self.edge_num[i],
                            kernel_size=kernel_size - 1 if kernel_even else kernel_size,
                            stride=1,
                            padding=padding,
                            padding_mode=args.padding_mode,
                            bias=bias,
                        )
                    )
                    seq.append(nn.PReLU() if args.activation == "relu" else nn.Tanh())
                # Strided skeleton convolution halves the temporal resolution.
                seq.append(
                    SkeletonConv(
                        neighbour_list,
                        in_channels=in_channels,
                        out_channels=out_channels,
                        joint_num=self.edge_num[i],
                        kernel_size=kernel_size,
                        stride=2,
                        padding=padding,
                        padding_mode=args.padding_mode,
                        bias=bias,
                        add_offset=False,
                        in_offset_channel=3 * self.channel_base[i] // self.channel_base[0],
                    )
                )
                seq.append(pool)
                seq.append(nn.PReLU() if args.activation == "relu" else nn.Tanh())
            self.layers.append(nn.Sequential(*seq))

            # Record the coarsened topology produced by this level's pooling.
            self.topologies.append(pool.new_edges)
            self.pooling_list.append(pool.pooling_list)
            self.edge_num.append(len(self.topologies[-1]))

    def forward(self, input):
        # (batch, time, channels) -> (batch, channels, time) for the 1D skeleton convolutions.
        output = input.permute(0, 2, 1)
        for layer in self.layers:
            output = layer(output)
        # Back to (batch, time, channels).
        output = output.permute(0, 2, 1)
        return output

def reparameterize(mu, logvar):
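    """Sample z = mu + sigma * eps with eps ~ N(0, I) (the VAE reparameterization trick)."""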
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std
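

# Note: the KL regularizer that usually accompanies this sampling step is not computed
# in this module. For reference only, the standard closed form against a unit-Gaussian
# prior would be something like:
#   kld = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
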
class VAEConv(nn.Module):
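    """Encoder/decoder wrapper with an optional variational bottleneck.

    Subclasses are expected to assign `self.encoder` and `self.decoder`; when
    `args.variational` is true, the latent is resampled via `reparameterize`.
    """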

    def __init__(self, args):
        super(VAEConv, self).__init__()
        self.fc_mu = nn.Linear(args.vae_length, args.vae_length)
        self.fc_logvar = nn.Linear(args.vae_length, args.vae_length)
        self.variational = args.variational

    def forward(self, inputs):
        pre_latent = self.encoder(inputs)
        mu, logvar = None, None
        if self.variational:
            mu = self.fc_mu(pre_latent)
            logvar = self.fc_logvar(pre_latent)
            pre_latent = reparameterize(mu, logvar)
        rec_pose = self.decoder(pre_latent)
        return {
            "poses_feat": pre_latent,
            "rec_pose": rec_pose,
            "pose_mu": mu,
            "pose_logvar": logvar,
        }

    def map2latent(self, inputs):
        # Encode to the latent space without decoding.
        pre_latent = self.encoder(inputs)
        if self.variational:
            mu = self.fc_mu(pre_latent)
            logvar = self.fc_logvar(pre_latent)
            pre_latent = reparameterize(mu, logvar)
        return pre_latent

    def decode(self, pre_latent):
        rec_pose = self.decoder(pre_latent)
        return rec_pose

class VAESKConv(VAEConv):
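    """Skeleton VAE: LocalEncoder over the SMPL-X edge topology, decoded by VQDecoderV3."""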

    def __init__(self, args, model_save_path="./emage/"):
        super(VAESKConv, self).__init__(args)
        # Build the kinematic edge topology from the SMPL-X neutral model's parent table.
        smpl_fname = model_save_path + "smplx_models/smplx/SMPLX_NEUTRAL_2020.npz"
        smpl_data = np.load(smpl_fname, encoding="latin1")
        parents = smpl_data["kintree_table"][0].astype(np.int32)
        edges = build_edge_topology(parents)
        self.encoder = LocalEncoder(args, edges)
        self.decoder = VQDecoderV3(args)
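

# Minimal usage sketch (illustrative only). The attribute names on `args` below
# (vae_layer, vae_grow, vae_length, variational) appear in this file; the values and
# any further fields required by VQDecoderV3 are placeholder assumptions.
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(vae_layer=4, vae_grow=[1, 1, 2, 1], vae_length=240,
#                          variational=False)  # plus whatever VQDecoderV3 expects
#   model = VAESKConv(args, model_save_path="./emage/")
#   out = model(poses)  # poses: (batch, frames, per-frame pose channels)
#   rec_pose, latent = out["rec_pose"], out["poses_feat"]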