# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2020 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de

from typing import Optional, Dict, Callable, List, Any
import copy
import sys

import numpy as np

import torch
import torch.nn as nn
from tqdm import tqdm

from loguru import logger

from SMPLX.transfer_model.utils import get_vertices_per_edge
from SMPLX.transfer_model.optimizers import build_optimizer, minimize
from SMPLX.transfer_model.utils import (
    Tensor, batch_rodrigues, apply_deformation_transfer)
from SMPLX.transfer_model.losses import build_loss


# Optimizer configuration used by ``run_fitting`` unless ``exp_cfg['optim']``
# overrides it. Forwarded to both ``build_optimizer`` and ``minimize``.
_DEFAULT_OPTIM_CFG = {
    'type': 'trust-ncg', 'lr': 1.0, 'gtol': 1e-06, 'ftol': -1.0,
    'maxiters': 100,
    'lbfgs': {'line_search_fn': 'strong_wolfe', 'max_iter': 50},
    'sgd': {'momentum': 0.9, 'nesterov': True},
    'adam': {'betas': [0.9, 0.999], 'eps': 1e-08, 'amsgrad': False},
    'trust_ncg': {'max_trust_radius': 1000.0, 'initial_trust_radius': 0.05,
                  'eta': 0.15, 'gtol': 1e-05},
}
# Edge-loss configuration used unless ``exp_cfg['edge_fitting']`` overrides it.
_DEFAULT_EDGE_FITTING_CFG = {'per_part': False, 'reduction': 'mean'}
# Vertex-loss configuration used unless ``exp_cfg['vertex_fitting']``
# overrides it (empty: ``build_loss`` defaults apply).
_DEFAULT_VERTEX_FITTING_CFG = {}
# Body-model output attributes copied into the dictionary returned by
# ``run_fitting``.
_OUTPUT_KEYS = ("vertices", "joints", "betas", "global_orient", "body_pose",
                "left_hand_pose", "right_hand_pose", "full_pose")


def _axis_angle_to_rotmat(var: Tensor) -> Tensor:
    '''Convert axis-angle vectors to rotation matrices.

    ``var`` is flattened to ``(-1, 3)`` axis-angle vectors, converted with
    ``batch_rodrigues`` and reshaped to ``(len(var), -1, 3, 3)``.
    '''
    return batch_rodrigues(var.reshape(-1, 3)).reshape(len(var), -1, 3, 3)


def _is_rotation_key(key: str) -> bool:
    '''Return True for variables stored as axis-angle rotations.'''
    return 'pose' in key or 'orient' in key


def _decode_params(var_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
    '''Build body-model keyword arguments from the optimization variables.

    Variables whose name contains ``'pose'`` or ``'orient'`` are converted
    from axis-angle to rotation matrices; all others pass through unchanged.
    '''
    return {
        key: _axis_angle_to_rotmat(var) if _is_rotation_key(key) else var
        for key, var in var_dict.items()
    }


def _backward(loss: Tensor, params_to_opt, create_graph: bool) -> None:
    '''Back-propagate ``loss``.

    With ``create_graph`` the gradients are computed only w.r.t.
    ``params_to_opt`` and accumulated with a differentiable graph; otherwise
    a plain ``loss.backward()`` is used.
    '''
    if create_graph:
        # Use this instead of .backward to avoid GPU memory leaks
        grads = torch.autograd.grad(
            loss, params_to_opt, create_graph=True)
        torch.autograd.backward(
            params_to_opt, grads, create_graph=True)
    else:
        loss.backward()


def _select_faces(faces, mask_ids) -> np.ndarray:
    '''Return a boolean mask over the rows of ``faces``.

    With ``mask_ids`` None every face is selected; otherwise a face is
    selected when at least one of its vertex indices is in ``mask_ids``.
    '''
    if mask_ids is None:
        return np.ones(len(faces), dtype=np.bool_)
    # int() also accepts 0-d tensors / numpy scalars as mask entries.
    mask = np.asarray([int(vid) for vid in mask_ids], dtype=np.int64)
    return np.isin(np.asarray(faces), mask).any(axis=1)


def summary_closure(gt_vertices, var_dict, body_model, mask_ids=None):
    '''Compute the mean vertex-to-vertex error of the current fit.

    Args:
        gt_vertices: target vertices, indexed as ``[:, mask_ids]``.
        var_dict: optimization variables (see ``get_variables``).
        body_model: body model called with the decoded variables.
        mask_ids: optional vertex indices restricting the metric.

    Returns:
        ``{'Vertex-to-Vertex': e}`` where ``e`` is the mean Euclidean
        distance between estimated and target vertices multiplied by 1000
        (presumably metres -> millimetres; confirm model units).
    '''
    body_model_output = body_model(
        return_full_pose=True, get_skin=True, **_decode_params(var_dict))
    est_vertices = body_model_output.vertices
    if mask_ids is not None:
        est_vertices = est_vertices[:, mask_ids]
        gt_vertices = gt_vertices[:, mask_ids]

    v2v = (est_vertices - gt_vertices).pow(2).sum(dim=-1).sqrt().mean()
    return {
        'Vertex-to-Vertex': v2v * 1000}


def build_model_forward_closure(
    body_model: nn.Module,
    var_dict: Dict[str, Tensor],
    per_part: bool = True,
    part_key: Optional[str] = None,
    jidx: Optional[int] = None,
    part: Optional[Tensor] = None
) -> Callable:
    '''Return a zero-argument callable that runs the body model.

    When ``per_part`` is True, the variable ``var_dict[part_key]`` is always
    decoded as rotations and its joint ``jidx`` is replaced by the rotation
    encoded in the axis-angle tensor ``part`` (the variable being optimized).

    Raises:
        AssertionError: if ``per_part`` is True and any of ``part``,
            ``part_key`` or ``jidx`` is None.
    '''
    if per_part:
        cond = part is not None and part_key is not None and jidx is not None
        assert cond, (
            'When per-part is True, "part", "part_key", "jidx" must not be'
            ' None.'
        )

        def model_forward():
            param_dict = {}
            for key, var in var_dict.items():
                if key == part_key:
                    rotmats = _axis_angle_to_rotmat(var)
                    # Substitute the joint currently being optimized.
                    rotmats[:, jidx] = batch_rodrigues(
                        part.reshape(-1, 3)).reshape(-1, 3, 3)
                    param_dict[key] = rotmats
                elif _is_rotation_key(key):
                    param_dict[key] = _axis_angle_to_rotmat(var)
                else:
                    # Simply pass the variable
                    param_dict[key] = var

            return body_model(
                return_full_pose=True, get_skin=True, **param_dict)
    else:
        def model_forward():
            return body_model(return_full_pose=True, get_skin=True,
                              **_decode_params(var_dict))

    return model_forward


def build_edge_closure(
    body_model: nn.Module,
    var_dict: Dict[str, Tensor],
    edge_loss: nn.Module,
    optimizer_dict,
    gt_vertices: Tensor,
    per_part: bool = True,
    part_key: Optional[str] = None,
    jidx: Optional[int] = None,
    part: Optional[Tensor] = None
) -> Callable:
    ''' Builds the closure for the edge objective

    ``optimizer_dict`` must provide ``'optimizer'`` and ``'create_graph'``.
    The returned ``closure(backward=True)`` evaluates
    ``edge_loss(est_vertices, gt_vertices)``; when ``backward`` is True it
    first zeroes the optimizer gradients and then back-propagates.

    NOTE: in the ``create_graph`` path gradients are taken only w.r.t.
    ``part`` (per-part) or the variables whose name contains ``'pose'``
    (so not ``global_orient``, ``transl`` or ``betas``).
    '''
    optimizer = optimizer_dict['optimizer']
    create_graph = optimizer_dict['create_graph']

    if per_part:
        params_to_opt = [part]
    else:
        params_to_opt = [p for key, p in var_dict.items() if 'pose' in key]

    model_forward = build_model_forward_closure(
        body_model, var_dict, per_part=per_part, part_key=part_key,
        jidx=jidx, part=part)

    def closure(backward=True):
        if backward:
            optimizer.zero_grad()

        est_vertices = model_forward().vertices
        loss = edge_loss(est_vertices, gt_vertices)
        if backward:
            _backward(loss, params_to_opt, create_graph)
        return loss

    return closure


def build_vertex_closure(
    body_model: nn.Module,
    var_dict: Dict[str, Tensor],
    optimizer_dict,
    gt_vertices: Tensor,
    vertex_loss: nn.Module,
    mask_ids=None,
    per_part: bool = True,
    part_key: Optional[str] = None,
    jidx: Optional[int] = None,
    part: Optional[Tensor] = None,
    params_to_opt: Optional[List[Tensor]] = None,
) -> Callable:
    ''' Builds the closure for the vertex objective

    The returned ``closure(backward=True)`` evaluates ``vertex_loss`` on the
    estimated and target vertices (restricted to ``mask_ids`` when given).
    In the ``create_graph`` path gradients are taken w.r.t.
    ``params_to_opt``, which defaults to every variable in ``var_dict``.
    '''
    optimizer = optimizer_dict['optimizer']
    create_graph = optimizer_dict['create_graph']

    model_forward = build_model_forward_closure(
        body_model, var_dict, per_part=per_part, part_key=part_key,
        jidx=jidx, part=part)

    if params_to_opt is None:
        params_to_opt = list(var_dict.values())

    # The target never changes between closure calls: mask it once.
    target_vertices = (
        gt_vertices[:, mask_ids] if mask_ids is not None else gt_vertices)

    def closure(backward=True):
        if backward:
            optimizer.zero_grad()

        est_vertices = model_forward().vertices
        if mask_ids is not None:
            est_vertices = est_vertices[:, mask_ids]
        loss = vertex_loss(est_vertices, target_vertices)
        if backward:
            _backward(loss, params_to_opt, create_graph)
        return loss

    return closure


def get_variables(
    batch_size: int,
    body_model: nn.Module,
    dtype: torch.dtype = torch.float32
) -> Dict[str, Tensor]:
    '''Create zero-initialised optimization variables for ``body_model``.

    The set of variables depends on ``body_model.name()``:
    'SMPL', 'SMPL+H' and 'SMPL-X' get ``transl``, ``global_orient``,
    ``body_pose`` and ``betas``; 'SMPL+H' and 'SMPL-X' also get hand poses;
    'SMPL-X' also gets jaw/eye poses and ``expression``. Any other name
    yields an empty dict. All tensors live on the device of the model's
    first buffer and have ``requires_grad`` enabled.
    '''
    device = next(body_model.buffers()).device
    model_name = body_model.name()

    def make_zeros(*shape):
        return torch.zeros(
            [batch_size, *shape], device=device, dtype=dtype)

    var_dict = {}
    if model_name in ('SMPL', 'SMPL+H', 'SMPL-X'):
        var_dict.update(
            transl=make_zeros(3),
            global_orient=make_zeros(1, 3),
            body_pose=make_zeros(body_model.NUM_BODY_JOINTS, 3),
            betas=make_zeros(body_model.num_betas),
        )
    if model_name in ('SMPL+H', 'SMPL-X'):
        var_dict.update(
            left_hand_pose=make_zeros(body_model.NUM_HAND_JOINTS, 3),
            right_hand_pose=make_zeros(body_model.NUM_HAND_JOINTS, 3),
        )
    if model_name == 'SMPL-X':
        var_dict.update(
            jaw_pose=make_zeros(1, 3),
            leye_pose=make_zeros(1, 3),
            reye_pose=make_zeros(1, 3),
            expression=make_zeros(body_model.num_expression_coeffs),
        )

    # Toggle gradients to True
    for val in var_dict.values():
        val.requires_grad_(True)

    return var_dict


def run_fitting(
    batch: Dict[str, Tensor],
    body_model: nn.Module,
    def_matrix: Tensor,
    mask_ids=None,
    exp_cfg: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    ''' Runs fitting

    Target vertices are obtained with
    ``apply_deformation_transfer(def_matrix, batch['vertices'],
    batch['faces'])``. The fit then runs three stages:

    1. edge-loss fitting of the pose variables (joint by joint when the
       edge-fitting config has ``per_part`` True);
    2. vertex-loss fitting of ``transl`` alone;
    3. vertex-loss fitting of all variables.

    Args:
        batch: must contain ``'vertices'`` and ``'faces'``.
        body_model: target body model (see ``get_variables``).
        def_matrix: deformation-transfer matrix.
        mask_ids: optional vertex indices. When given, the vertex loss and
            the logged metric use only these vertices, and the edge loss
            uses only edges of faces touching them.
        exp_cfg: optional overrides for ``'summary_steps'``,
            ``'interactive'``, ``'optim'``, ``'edge_fitting'`` and
            ``'vertex_fitting'``; missing keys use the module defaults.

    Returns:
        The fitted variables, the body-model outputs listed in
        ``_OUTPUT_KEYS`` (None when the output lacks the attribute) and
        ``'faces'`` (``body_model.faces``); tensors are converted to numpy.
    '''
    exp_cfg = {} if exp_cfg is None else exp_cfg

    vertices = batch['vertices']
    faces = batch['faces']

    batch_size = len(vertices)
    dtype, device = vertices.dtype, vertices.device
    summary_steps = exp_cfg.get('summary_steps', 100)
    interactive = exp_cfg.get('interactive', True)

    # Get the parameters from the model
    var_dict = get_variables(batch_size, body_model)

    # Copies so the module defaults can never be mutated downstream.
    optim_cfg = copy.deepcopy(exp_cfg.get('optim', _DEFAULT_OPTIM_CFG))
    edge_fitting_cfg = copy.deepcopy(
        exp_cfg.get('edge_fitting', _DEFAULT_EDGE_FITTING_CFG))
    vertex_fitting_cfg = copy.deepcopy(
        exp_cfg.get('vertex_fitting', _DEFAULT_VERTEX_FITTING_CFG))

    def_vertices = apply_deformation_transfer(def_matrix, vertices, faces)

    f_sel = _select_faces(body_model.faces, mask_ids)
    vpe = get_vertices_per_edge(
        body_model.v_template.detach().cpu().numpy(), body_model.faces[f_sel])

    def log_closure():
        return summary_closure(def_vertices, var_dict, body_model,
                               mask_ids=mask_ids)

    edge_loss = build_loss(type='vertex-edge', gt_edges=vpe, est_edges=vpe,
                           **edge_fitting_cfg)
    edge_loss = edge_loss.to(device=device)

    vertex_loss = build_loss(**vertex_fitting_cfg)
    vertex_loss = vertex_loss.to(device=device)

    def run_minimize(optimizer_dict, closure, params):
        minimize(optimizer_dict['optimizer'], closure,
                 params=params,
                 summary_closure=log_closure,
                 summary_steps=summary_steps,
                 interactive=interactive,
                 **optim_cfg)

    per_part = edge_fitting_cfg.get('per_part', True)
    logger.info(f'Per-part: {per_part}')

    # Stage 1: optimize edge-based loss to initialize pose
    if per_part:
        for key, var in tqdm(var_dict.items(), desc='Parts'):
            if 'pose' not in key:
                continue

            for jidx in tqdm(range(var.shape[1]), desc='Joints'):
                part = torch.zeros(
                    [batch_size, 3], dtype=dtype, device=device,
                    requires_grad=True)
                # Build the optimizer for the current part
                optimizer_dict = build_optimizer([part], optim_cfg)
                closure = build_edge_closure(
                    body_model, var_dict, edge_loss, optimizer_dict,
                    def_vertices, per_part=per_part, part_key=key,
                    jidx=jidx, part=part)
                run_minimize(optimizer_dict, closure, [part])
                # Write the fitted joint rotation back into the variable.
                with torch.no_grad():
                    var[:, jidx] = part
    else:
        all_params = list(var_dict.values())
        optimizer_dict = build_optimizer(all_params, optim_cfg)
        closure = build_edge_closure(
            body_model, var_dict, edge_loss, optimizer_dict,
            def_vertices, per_part=per_part)
        run_minimize(optimizer_dict, closure, all_params)

    # Stage 2: optimize translation alone. get_variables names this
    # variable 'transl' (checking for 'translation' silently skipped it).
    transl = var_dict.get('transl')
    if transl is not None:
        optimizer_dict = build_optimizer([transl], optim_cfg)
        closure = build_vertex_closure(
            body_model, var_dict,
            optimizer_dict,
            def_vertices,
            vertex_loss=vertex_loss,
            mask_ids=mask_ids,
            per_part=False,
            params_to_opt=[transl],
        )
        run_minimize(optimizer_dict, closure, [transl])

    # Stage 3: optimize all model parameters with vertex-based loss
    all_params = list(var_dict.values())
    optimizer_dict = build_optimizer(all_params, optim_cfg)
    closure = build_vertex_closure(
        body_model, var_dict,
        optimizer_dict,
        def_vertices,
        vertex_loss=vertex_loss,
        per_part=False,
        mask_ids=mask_ids)
    run_minimize(optimizer_dict, closure, all_params)

    # Final forward pass only produces outputs: no autograd graph needed.
    with torch.no_grad():
        body_model_output = body_model(
            return_full_pose=True, get_skin=True, **_decode_params(var_dict))

    for key in _OUTPUT_KEYS:
        var_dict[key] = getattr(body_model_output, key, None)
    var_dict['faces'] = body_model.faces

    # Convert tensors to numpy; other values (None, arrays) are kept as-is.
    return {
        key: value.detach().cpu().numpy() if torch.is_tensor(value)
        else value
        for key, value in var_dict.items()
    }