# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2020 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de

from typing import List, Union, Callable, Optional, Dict

import torch
from loguru import logger
from tqdm import tqdm

from SMPLX.transfer_model.utils import (
    from_torch, Tensor, Array, rel_change)


def _gradients_converged(params: List[Tensor], gtol: float) -> bool:
    ''' Checks whether every available gradient is below the tolerance

        Returns False when the test is disabled (gtol <= 0) or when no
        parameter currently holds a non-empty gradient.
    '''
    if gtol <= 0:
        # Test disabled: skip the per-tensor reductions and .item() calls.
        return False

    grad_maxima = [
        var.grad.abs().max().item()
        for var in params
        if var.grad is not None and var.grad.numel() > 0
    ]
    # all() over an empty sequence is True, which would wrongly report
    # convergence when there is no gradient to inspect at all.
    return bool(grad_maxima) and all(
        grad_max < gtol for grad_max in grad_maxima)


def _log_summary(
    step: int,
    loss_value: float,
    summary_closure: Optional[Callable[[], Dict[str, float]]] = None,
) -> None:
    ''' Logs the loss and the optional summary values for one step
    '''
    logger.info(f'[{step:05d}] Loss: {loss_value:.4f}')
    if summary_closure is not None:
        summaries = summary_closure()
        for key, val in summaries.items():
            logger.info(f'[{step:05d}] {key}: {val:.4f}')


def minimize(
    optimizer: torch.optim.Optimizer,
    closure,
    params: List[Tensor],
    summary_closure: Optional[Callable[[], Dict[str, float]]] = None,
    maxiters=100,
    ftol=-1.0,
    gtol=1e-9,
    interactive=True,
    summary_steps=10,
    **kwargs
):
    ''' Helper function for running an optimization process

        Args:
            - optimizer: The PyTorch optimizer object
            - closure: The function used to calculate the gradients. It is
              passed to `optimizer.step`, whose return value must support
              `.item()`
            - params: a list containing the parameters that will be
              optimized
        Keyword arguments:
            - summary_closure (None): Optional callable returning a
              dictionary of named values that are logged next to the loss
            - maxiters (100): The maximum number of iterations for the
              optimizer
            - ftol: The tolerance for the relative change in the loss
              function. If it is lower than this value, then the process
              stops. Values <= 0 disable the test
            - gtol: The tolerance for the maximum change in the gradient.
              If the maximum absolute values of the all gradient tensors
              are less than this, then the process will stop. Values <= 0
              disable the test. Parameters without a gradient are ignored
              and the test never triggers when no gradient is available
            - interactive (True): If set, logs the loss (and summaries)
              every `summary_steps` iterations and once more at the end
            - summary_steps (10): Number of iterations between two
              progress log messages
            - kwargs: Accepted for call-site compatibility and ignored
        Returns:
            The loss value of the last executed iteration as a Python
            number, or None if no iteration was executed
    '''
    prev_loss = None
    loss_value = None
    n = -1
    for n in tqdm(range(maxiters), desc='Fitting iterations'):
        loss = optimizer.step(closure)
        # Read the scalar once per iteration; every .item() call may force
        # a device synchronization.
        loss_value = loss.item()

        if n > 0 and prev_loss is not None and ftol > 0:
            loss_rel_change = rel_change(prev_loss, loss_value)
            if loss_rel_change <= ftol:
                prev_loss = loss_value
                break

        if _gradients_converged(params, gtol):
            prev_loss = loss_value
            break

        if interactive and n % summary_steps == 0:
            _log_summary(n, loss_value, summary_closure)

        prev_loss = loss_value

    # Save the final step. Skipped when the loop body never ran
    # (maxiters <= 0), since there is no loss to report in that case.
    if interactive and loss_value is not None:
        _log_summary(n + 1, loss_value, summary_closure)

    return prev_loss