from functools import reduce
from pathlib import Path

import torch
import torch.nn.functional as F

class NullableArgs:
    def __init__(self, namespace):
        for key, value in namespace.__dict__.items():
            setattr(self, key, value)

    def __getattr__(self, key):
        # Only invoked when normal attribute lookup fails, i.e. for options the
        # stored namespace does not have; map them to backward-compatible defaults.
        if key == 'align_mask_width':
            if 'use_alignment_mask' in self.__dict__:
                return 1 if self.use_alignment_mask else 0
            else:
                return 0
        if key == 'no_head_pose':
            return not self.predict_head_pose
        if key == 'no_use_learnable_pe':
            return not self.use_learnable_pe

        return None
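
# Illustrative usage (a sketch, not part of the original module): NullableArgs
# typically wraps an argparse.Namespace restored from a checkpoint, so options
# missing from older runs resolve to safe defaults instead of raising
# AttributeError.
#   from argparse import Namespace
#   old_args = NullableArgs(Namespace(predict_head_pose=True))
#   old_args.predict_head_pose  # -> True (copied from the namespace)
#   old_args.no_head_pose       # -> False (derived fallback)
#   old_args.some_new_option    # -> None (unknown keys are nullable)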

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def get_option_text(args, parser):
    message = ''
    for k, v in sorted(vars(args).items()):
        comment = ''
        default = parser.get_default(k)
        if v != default:
            comment = f'\t[default: {str(default)}]'
        message += f'{str(k):>30}: {str(v):<30}{comment}\n'
    return message

def get_model_path(exp_name, iteration, model_type='DPT'):
    exp_root_dir = Path(__file__).parent.parent / 'experiments' / model_type
    exp_dir = exp_root_dir / exp_name
    if not exp_dir.exists():
        # Fall back to prefix matching when no exact experiment name exists
        exp_dir = next(exp_root_dir.glob(f'{exp_name}*'))
    model_path = exp_dir / f'checkpoints/iter_{iteration:07}.pt'
    return model_path, exp_dir.relative_to(exp_root_dir)
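
# Example (hypothetical experiment name): with a run saved under
# experiments/DPT/head-L4H8-230101/, a prefix query also resolves it:
#   model_path, exp_name = get_model_path('head-L4H8', 50000)
#   # model_path -> .../experiments/DPT/head-L4H8-230101/checkpoints/iter_0050000.pt
#   # exp_name   -> PosixPath('head-L4H8-230101')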

def get_pose_input(coef_dict, rot_repr, with_global_pose):
    if rot_repr == 'aa':
        pose_input = coef_dict['pose'] if with_global_pose else coef_dict['pose'][..., -3:]
        # Remove mouth rotation around the y and z axes
        pose_input = pose_input[..., :-2]
    else:
        raise ValueError(f'Unknown rotation representation: {rot_repr}')
    return pose_input
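
# The 'aa' (axis-angle) branch assumes a 6-dim FLAME pose vector of 3 global
# plus 3 mouth (jaw) rotation components; dropping the last two dims removes
# the mouth rotation around the y and z axes, leaving 4 dims with the global
# pose or a single dim without it. A quick shape check:
#   coefs = {'pose': torch.zeros(2, 10, 6)}
#   get_pose_input(coefs, 'aa', with_global_pose=True).shape   # -> (2, 10, 4)
#   get_pose_input(coefs, 'aa', with_global_pose=False).shape  # -> (2, 10, 1)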

def get_motion_coef(coef_dict, rot_repr, with_global_pose=False, norm_stats=None):
    if norm_stats is not None:
        if rot_repr == 'aa':
            keys = ['exp', 'pose']
        else:
            raise ValueError(f'Unknown rotation representation {rot_repr}!')
        coef_dict = {k: (coef_dict[k] - norm_stats[f'{k}_mean']) / norm_stats[f'{k}_std'] for k in keys}

    pose_coef = get_pose_input(coef_dict, rot_repr, with_global_pose)
    return torch.cat([coef_dict['exp'], pose_coef], dim=-1)
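
# Illustrative usage (assumed sizes: 50 expression and 6 pose coefficients per
# frame, as used throughout this module):
#   coefs = {'exp': torch.randn(4, 25, 50), 'pose': torch.randn(4, 25, 6)}
#   get_motion_coef(coefs, 'aa').shape                         # -> (4, 25, 51)
#   get_motion_coef(coefs, 'aa', with_global_pose=True).shape  # -> (4, 25, 54)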

def get_coef_dict(motion_coef, shape_coef=None, denorm_stats=None, with_global_pose=False, rot_repr='aa'):
    coef_dict = {
        'exp': motion_coef[..., :50]
    }
    if rot_repr == 'aa':
        if with_global_pose:
            coef_dict['pose'] = motion_coef[..., 50:]
        else:
            placeholder = torch.zeros_like(motion_coef[..., :3])
            coef_dict['pose'] = torch.cat([placeholder, motion_coef[..., -1:]], dim=-1)
        # Add back the (zeroed) mouth rotation around the y and z axes
        coef_dict['pose'] = torch.cat([coef_dict['pose'], torch.zeros_like(motion_coef[..., :2])], dim=-1)
    else:
        raise ValueError(f'Unknown rotation representation {rot_repr}!')

    if shape_coef is not None:
        if motion_coef.ndim == 3:
            if shape_coef.ndim == 2:
                shape_coef = shape_coef.unsqueeze(1)
            if shape_coef.shape[1] == 1:
                shape_coef = shape_coef.expand(-1, motion_coef.shape[1], -1)
        coef_dict['shape'] = shape_coef

    if denorm_stats is not None:
        coef_dict = {k: coef_dict[k] * denorm_stats[f'{k}_std'] + denorm_stats[f'{k}_mean'] for k in coef_dict}

    if not with_global_pose:
        if rot_repr == 'aa':
            coef_dict['pose'][..., :3] = 0
        else:
            raise ValueError(f'Unknown rotation representation {rot_repr}!')

    return coef_dict
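
# get_coef_dict is roughly the inverse of get_motion_coef: it unpacks the flat
# motion vector into FLAME 'exp' (50) and 'pose' (6) coefficients, re-inserting
# zeros for the dims that were dropped. A round-trip sketch:
#   motion = torch.randn(4, 25, 54)          # 50 exp + 4 pose dims
#   coefs = get_coef_dict(motion, with_global_pose=True)
#   coefs['exp'].shape, coefs['pose'].shape  # -> (4, 25, 50), (4, 25, 6)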

def coef_dict_to_vertices(coef_dict, flame, rot_repr='aa', ignore_global_rot=False, flame_batch_size=512):
    shape = coef_dict['exp'].shape[:-1]
    coef_dict = {k: v.view(-1, v.shape[-1]) for k, v in coef_dict.items()}
    n_samples = reduce(lambda x, y: x * y, shape, 1)

    # Convert to vertices
    vert_list = []
    for i in range(0, n_samples, flame_batch_size):
        batch_coef_dict = {k: v[i:i + flame_batch_size] for k, v in coef_dict.items()}
        if rot_repr == 'aa':
            vert, _, _ = flame(
                batch_coef_dict['shape'], batch_coef_dict['exp'], batch_coef_dict['pose'],
                pose2rot=True, ignore_global_rot=ignore_global_rot, return_lm2d=False, return_lm3d=False)
        else:
            raise ValueError(f'Unknown rot_repr: {rot_repr}')
        vert_list.append(vert)

    vert_list = torch.cat(vert_list, dim=0)  # (n_samples, 5023, 3)
    vert_list = vert_list.view(*shape, -1, 3)  # (..., 5023, 3)
    return vert_list
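
# Decoding every frame through FLAME at once can exhaust GPU memory, hence the
# flame_batch_size chunking above. Illustrative usage (assumes a FLAME instance
# matching the call signature used in this module):
#   coefs = {'shape': torch.zeros(2, 25, 100), 'exp': torch.zeros(2, 25, 50),
#            'pose': torch.zeros(2, 25, 6)}
#   coef_dict_to_vertices(coefs, flame).shape  # -> (2, 25, 5023, 3)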

def compute_loss(args, is_starting_sample, shape_coef, motion_coef_gt, noise, target, prev_motion_coef, coef_stats,
                 flame, end_idx=None):
    if args.criterion.lower() == 'l2':
        criterion_func = F.mse_loss
    elif args.criterion.lower() == 'l1':
        criterion_func = F.l1_loss
    else:
        raise NotImplementedError(f'Criterion {args.criterion} not implemented.')

    loss_vert = None
    loss_vel = None
    loss_smooth = None
    loss_head_angle = None
    loss_head_vel = None
    loss_head_smooth = None
    loss_head_trans_vel = None
    loss_head_trans_accel = None
    loss_head_trans = None
    if args.target == 'noise':
        loss_noise = criterion_func(noise, target[:, args.n_prev_motions:], reduction='none')
    elif args.target == 'sample':
        if is_starting_sample:
            target = target[:, args.n_prev_motions:]
        else:
            motion_coef_gt = torch.cat([prev_motion_coef, motion_coef_gt], dim=1)
            if args.no_constrain_prev:
                target = torch.cat([prev_motion_coef, target[:, args.n_prev_motions:]], dim=1)

        loss_noise = criterion_func(motion_coef_gt, target, reduction='none')
        if args.l_vert > 0 or args.l_vel > 0:
            coef_gt = get_coef_dict(motion_coef_gt, shape_coef, coef_stats, with_global_pose=False,
                                    rot_repr=args.rot_repr)
            coef_pred = get_coef_dict(target, shape_coef, coef_stats, with_global_pose=False,
                                      rot_repr=args.rot_repr)
            seq_len = target.shape[1]

            if args.rot_repr == 'aa':
                verts_gt, _, _ = flame(coef_gt['shape'].view(-1, 100), coef_gt['exp'].view(-1, 50),
                                       coef_gt['pose'].view(-1, 6), return_lm2d=False, return_lm3d=False)
                verts_pred, _, _ = flame(coef_pred['shape'].view(-1, 100), coef_pred['exp'].view(-1, 50),
                                         coef_pred['pose'].view(-1, 6), return_lm2d=False, return_lm3d=False)
            else:
                raise ValueError(f'Unknown rotation representation {args.rot_repr}!')
            verts_gt = verts_gt.view(-1, seq_len, 5023, 3)
            verts_pred = verts_pred.view(-1, seq_len, 5023, 3)

            if args.l_vert > 0:
                loss_vert = criterion_func(verts_gt, verts_pred, reduction='none')

            if args.l_vel > 0:
                vel_gt = verts_gt[:, 1:] - verts_gt[:, :-1]
                vel_pred = verts_pred[:, 1:] - verts_pred[:, :-1]
                loss_vel = criterion_func(vel_gt, vel_pred, reduction='none')

            if args.l_smooth > 0:
                vel_pred = verts_pred[:, 1:] - verts_pred[:, :-1]
                loss_smooth = criterion_func(vel_pred[:, 1:], vel_pred[:, :-1], reduction='none')
        # head pose
        if not args.no_head_pose:
            if args.rot_repr == 'aa':
                head_pose_gt = motion_coef_gt[:, :, 50:53]
                head_pose_pred = target[:, :, 50:53]
            else:
                raise ValueError(f'Unknown rotation representation {args.rot_repr}!')

            if args.l_head_angle > 0:
                loss_head_angle = criterion_func(head_pose_gt, head_pose_pred, reduction='none')

            if args.l_head_vel > 0:
                head_vel_gt = head_pose_gt[:, 1:] - head_pose_gt[:, :-1]
                head_vel_pred = head_pose_pred[:, 1:] - head_pose_pred[:, :-1]
                loss_head_vel = criterion_func(head_vel_gt, head_vel_pred, reduction='none')

            if args.l_head_smooth > 0:
                head_vel_pred = head_pose_pred[:, 1:] - head_pose_pred[:, :-1]
                loss_head_smooth = criterion_func(head_vel_pred[:, 1:], head_vel_pred[:, :-1], reduction='none')

            if not is_starting_sample and args.l_head_trans > 0:
                # # version 1: constrain both the predicted previous and current motions (x_{-3} ~ x_{2})
                # head_pose_trans = head_pose_pred[:, args.n_prev_motions - 3:args.n_prev_motions + 3]
                # head_vel_pred = head_pose_trans[:, 1:] - head_pose_trans[:, :-1]
                # head_accel_pred = head_vel_pred[:, 1:] - head_vel_pred[:, :-1]

                # version 2: constrain only the predicted current motions (x_{0} ~ x_{2})
                head_pose_trans = torch.cat([head_pose_gt[:, args.n_prev_motions - 3:args.n_prev_motions],
                                             head_pose_pred[:, args.n_prev_motions:args.n_prev_motions + 3]], dim=1)
                head_vel_pred = head_pose_trans[:, 1:] - head_pose_trans[:, :-1]
                head_accel_pred = head_vel_pred[:, 1:] - head_vel_pred[:, :-1]

                # will constrain x_{-2|0} ~ x_{1}
                loss_head_trans_vel = criterion_func(head_vel_pred[:, 2:4], head_vel_pred[:, 1:3], reduction='none')
                # will constrain x_{-3|0} ~ x_{2}
                loss_head_trans_accel = criterion_func(head_accel_pred[:, 1:], head_accel_pred[:, :-1],
                                                       reduction='none')
    else:
        raise ValueError(f'Unknown diffusion target: {args.target}')

    if end_idx is None:
        mask = torch.ones((target.shape[0], args.n_motions), dtype=torch.bool, device=target.device)
    else:
        mask = torch.arange(args.n_motions, device=target.device).expand(target.shape[0], -1) < end_idx.unsqueeze(1)

    if args.target == 'sample' and not is_starting_sample:
        if args.no_constrain_prev:
            # Warning: this option will be deprecated in the future
            mask = torch.cat([torch.zeros_like(mask[:, :args.n_prev_motions]), mask], dim=1)
        else:
            mask = torch.cat([torch.ones_like(mask[:, :args.n_prev_motions]), mask], dim=1)
    loss_noise = loss_noise[mask].mean()
    if loss_vert is not None:
        loss_vert = loss_vert[mask].mean()
    if loss_vel is not None:
        loss_vel = loss_vel[mask[:, 1:]]
        loss_vel = loss_vel.mean() if torch.numel(loss_vel) > 0 else None
    if loss_smooth is not None:
        loss_smooth = loss_smooth[mask[:, 2:]]
        loss_smooth = loss_smooth.mean() if torch.numel(loss_smooth) > 0 else None

    if loss_head_angle is not None:
        loss_head_angle = loss_head_angle[mask].mean()
    if loss_head_vel is not None:
        loss_head_vel = loss_head_vel[mask[:, 1:]]
        loss_head_vel = loss_head_vel.mean() if torch.numel(loss_head_vel) > 0 else None
    if loss_head_smooth is not None:
        loss_head_smooth = loss_head_smooth[mask[:, 2:]]
        loss_head_smooth = loss_head_smooth.mean() if torch.numel(loss_head_smooth) > 0 else None
    if loss_head_trans_vel is not None:
        vel_mask = mask[:, args.n_prev_motions:args.n_prev_motions + 2]
        accel_mask = mask[:, args.n_prev_motions:args.n_prev_motions + 3]
        loss_head_trans_vel = loss_head_trans_vel[vel_mask].mean()
        loss_head_trans_accel = loss_head_trans_accel[accel_mask].mean()
        loss_head_trans = loss_head_trans_vel + loss_head_trans_accel

    return loss_noise, loss_vert, loss_vel, loss_smooth, loss_head_angle, loss_head_vel, loss_head_smooth, \
        loss_head_trans
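
# A minimal sketch of the validity mask used above (hypothetical sizes): with
# truncated sequences, frames at or beyond end_idx are excluded from every
# loss term.
#   n_motions = 6
#   end_idx = torch.tensor([4, 6])
#   mask = torch.arange(n_motions).expand(2, -1) < end_idx.unsqueeze(1)
#   # -> [[True, True, True, True, False, False],
#   #     [True, True, True, True, True,  True ]]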

def _truncate_audio(audio, end_idx, pad_mode='zero'):
    batch_size = audio.shape[0]
    audio_trunc = audio.clone()
    if pad_mode == 'replicate':
        for i in range(batch_size):
            audio_trunc[i, end_idx[i]:] = audio_trunc[i, end_idx[i] - 1]
    elif pad_mode == 'zero':
        for i in range(batch_size):
            audio_trunc[i, end_idx[i]:] = 0
    else:
        raise ValueError(f'Unknown pad mode {pad_mode}!')
    return audio_trunc
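
# A vectorized equivalent of the 'zero' branch (a sketch, not the original
# implementation): mask out samples past each end index in one step.
#   idx = torch.arange(audio.shape[1], device=audio.device)
#   audio_trunc = audio * (idx.unsqueeze(0) < end_idx.unsqueeze(1))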

def _truncate_coef_dict(coef_dict, end_idx, pad_mode='zero'):
    batch_size = coef_dict['exp'].shape[0]
    coef_dict_trunc = {k: v.clone() for k, v in coef_dict.items()}
    if pad_mode == 'replicate':
        for i in range(batch_size):
            for k in coef_dict_trunc:
                coef_dict_trunc[k][i, end_idx[i]:] = coef_dict_trunc[k][i, end_idx[i] - 1]
    elif pad_mode == 'zero':
        for i in range(batch_size):
            for k in coef_dict_trunc:
                coef_dict_trunc[k][i, end_idx[i]:] = 0
    else:
        raise ValueError(f'Unknown pad mode: {pad_mode}!')
    return coef_dict_trunc

def truncate_coef_dict_and_audio(audio, coef_dict, n_motions, audio_unit=640, pad_mode='zero'):
    batch_size = audio.shape[0]
    end_idx = torch.randint(1, n_motions, (batch_size,), device=audio.device)
    audio_end_idx = (end_idx * audio_unit).long()
    # mask = torch.arange(n_motions, device=audio.device).expand(batch_size, -1) < end_idx.unsqueeze(1)

    # truncate audio
    audio_trunc = _truncate_audio(audio, audio_end_idx, pad_mode=pad_mode)

    # truncate coef dict
    coef_dict_trunc = _truncate_coef_dict(coef_dict, end_idx, pad_mode=pad_mode)

    return audio_trunc, coef_dict_trunc, end_idx

def truncate_motion_coef_and_audio(audio, motion_coef, n_motions, audio_unit=640, pad_mode='zero'):
    batch_size = audio.shape[0]
    end_idx = torch.randint(1, n_motions, (batch_size,), device=audio.device)
    audio_end_idx = (end_idx * audio_unit).long()
    # mask = torch.arange(n_motions, device=audio.device).expand(batch_size, -1) < end_idx.unsqueeze(1)

    # truncate audio
    audio_trunc = _truncate_audio(audio, audio_end_idx, pad_mode=pad_mode)

    # prepare coef dict and stats
    coef_dict = {'exp': motion_coef[..., :50], 'pose_any': motion_coef[..., 50:]}

    # truncate coef dict
    coef_dict_trunc = _truncate_coef_dict(coef_dict, end_idx, pad_mode=pad_mode)
    motion_coef_trunc = torch.cat([coef_dict_trunc['exp'], coef_dict_trunc['pose_any']], dim=-1)

    return audio_trunc, motion_coef_trunc, end_idx
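
# Illustrative usage (assumed setup: 16 kHz audio with audio_unit=640 samples
# per visual frame, i.e. 25 fps, and a 54-dim packed motion vector):
#   audio = torch.randn(4, 100 * 640)
#   motion = torch.randn(4, 100, 54)
#   audio_t, motion_t, end_idx = truncate_motion_coef_and_audio(audio, motion, 100)
#   # frames (and the matching audio samples) at or beyond end_idx[i] are zeroed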

def nt_xent_loss(feature_a, feature_b, temperature):
    """
    Normalized temperature-scaled cross entropy loss.

    (Adapted from https://github.com/sthalles/SimCLR/blob/master/simclr.py)

    Args:
        feature_a (torch.Tensor): shape (batch_size, feature_dim)
        feature_b (torch.Tensor): shape (batch_size, feature_dim)
        temperature (float): temperature scaling factor

    Returns:
        torch.Tensor: scalar
    """
    batch_size = feature_a.shape[0]
    device = feature_a.device

    features = torch.cat([feature_a, feature_b], dim=0)
    labels = torch.cat([torch.arange(batch_size), torch.arange(batch_size)], dim=0)
    labels = (labels.unsqueeze(0) == labels.unsqueeze(1))
    labels = labels.to(device)

    features = F.normalize(features, dim=1)
    similarity_matrix = torch.matmul(features, features.T)

    # discard the main diagonal from both the labels and the similarity matrix
    mask = torch.eye(labels.shape[0], dtype=torch.bool).to(device)
    labels = labels[~mask].view(labels.shape[0], -1)
    similarity_matrix = similarity_matrix[~mask].view(labels.shape[0], -1)

    # select the positives and negatives
    positives = similarity_matrix[labels].view(labels.shape[0], -1)
    negatives = similarity_matrix[~labels].view(labels.shape[0], -1)

    logits = torch.cat([positives, negatives], dim=1)
    logits = logits / temperature
    labels = torch.zeros(labels.shape[0], dtype=torch.long).to(device)
    loss = F.cross_entropy(logits, labels)
    return loss
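
# Quick sanity check (illustrative): identical views should give a much lower
# loss than unrelated ones.
#   feats = torch.randn(8, 128)
#   nt_xent_loss(feats, feats, temperature=0.1)                # small: positive pairs dominate
#   nt_xent_loss(feats, torch.randn(8, 128), temperature=0.1)  # noticeably larger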