# -*- coding: utf-8 -*- # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is # holder of all proprietary rights on this computer program. # You can only use this computer program if you have closed # a license agreement with MPG or you get the right to use the computer # program from someone who is authorized to grant you that right. # Any use of the computer program without a valid license is prohibited and # liable to prosecution. # # Copyright©2020 Max-Planck-Gesellschaft zur Förderung # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute # for Intelligent Systems. All rights reserved. # # Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de from __future__ import print_function from __future__ import absolute_import from __future__ import division import sys import time from typing import Callable, Iterator, Union, Optional, List import os.path as osp import yaml from loguru import logger import pickle import numpy as np import torch import torch.autograd as autograd import torch.nn as nn import torch.nn.functional as F from .utils import get_reduction_method __all__ = [ 'VertexEdgeLoss', 'build_loss', ] def build_loss(type='l2', reduction='mean', **kwargs) -> nn.Module: logger.debug(f'Building loss: {type}') if type == 'l2': return WeightedMSELoss(reduction=reduction, **kwargs) elif type == 'vertex-edge': return VertexEdgeLoss(reduction=reduction, **kwargs) elif type == 'l1': return nn.L1Loss() else: raise ValueError(f'Unknown loss type: {type}') class WeightedMSELoss(nn.Module): def __init__(self, reduction='mean', **kwargs): super(WeightedMSELoss, self).__init__() self.reduce_str = reduction self.reduce = get_reduction_method(reduction) def forward(self, input, target, weights=None): diff = input - target if weights is None: return diff.pow(2).sum() / diff.shape[0] else: return ( weights.unsqueeze(dim=-1) * diff.pow(2)).sum() / diff.shape[0] class VertexEdgeLoss(nn.Module): def __init__(self, norm_type='l2', gt_edges=None, gt_edge_path='', est_edges=None, est_edge_path='', robustifier=None, edge_thresh=0.0, epsilon=1e-8, reduction='sum', **kwargs): super(VertexEdgeLoss, self).__init__() assert norm_type in ['l1', 'l2'], 'Norm type must be [l1, l2]' self.norm_type = norm_type self.epsilon = epsilon self.reduction = reduction assert self.reduction in ['sum', 'mean'] logger.info(f'Building edge loss with' f' norm_type={norm_type},' f' reduction={reduction},' ) gt_edge_path = osp.expandvars(gt_edge_path) est_edge_path = osp.expandvars(est_edge_path) assert osp.exists(gt_edge_path) or gt_edges is not None, ( 'gt_edges must not be None or gt_edge_path must exist' ) assert osp.exists(est_edge_path) or est_edges is not None, ( 'est_edges must not be None or est_edge_path must exist' ) if osp.exists(gt_edge_path) and gt_edges is None: gt_edges = np.load(gt_edge_path) if osp.exists(est_edge_path) and est_edges is None: est_edges = np.load(est_edge_path) self.register_buffer( 'gt_connections', torch.tensor(gt_edges, dtype=torch.long)) self.register_buffer( 'est_connections', torch.tensor(est_edges, dtype=torch.long)) def extra_repr(self): msg = [ f'Norm type: {self.norm_type}', ] if self.has_connections: msg.append( f'GT Connections shape: {self.gt_connections.shape}' ) msg.append( f'Est Connections shape: {self.est_connections.shape}' ) return '\n'.join(msg) def compute_edges(self, points, connections): edge_points = torch.index_select( points, 1, connections.view(-1)).reshape(points.shape[0], -1, 2, 3) return edge_points[:, :, 1] - edge_points[:, :, 0] def forward(self, gt_vertices, est_vertices, weights=None): gt_edges = self.compute_edges( gt_vertices, connections=self.gt_connections) est_edges = self.compute_edges( est_vertices, connections=self.est_connections) raw_edge_diff = (gt_edges - est_edges) batch_size = gt_vertices.shape[0] if self.norm_type == 'l2': edge_diff = raw_edge_diff.pow(2) elif self.norm_type == 'l1': edge_diff = raw_edge_diff.abs() else: raise NotImplementedError( f'Loss type not implemented: {self.loss_type}') if self.reduction == 'sum': return edge_diff.sum() elif self.reduction == 'mean': return edge_diff.sum() / batch_size