Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*- | |
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is | |
# holder of all proprietary rights on this computer program. | |
# You can only use this computer program if you have closed | |
# a license agreement with MPG or you get the right to use the computer | |
# program from someone who is authorized to grant you that right. | |
# Any use of the computer program without a valid license is prohibited and | |
# liable to prosecution. | |
# | |
# Copyright©2020 Max-Planck-Gesellschaft zur Förderung | |
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute | |
# for Intelligent Systems. All rights reserved. | |
# | |
# Contact: Vassilis Choutas, [email protected] | |
from __future__ import print_function | |
from __future__ import absolute_import | |
from __future__ import division | |
import sys | |
import time | |
from typing import Callable, Iterator, Union, Optional, List | |
import os.path as osp | |
import yaml | |
from loguru import logger | |
import pickle | |
import numpy as np | |
import torch | |
import torch.autograd as autograd | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from .utils import get_reduction_method | |
__all__ = [ | |
'VertexEdgeLoss', | |
'build_loss', | |
] | |
def build_loss(type='l2', reduction='mean', **kwargs) -> nn.Module: | |
logger.debug(f'Building loss: {type}') | |
if type == 'l2': | |
return WeightedMSELoss(reduction=reduction, **kwargs) | |
elif type == 'vertex-edge': | |
return VertexEdgeLoss(reduction=reduction, **kwargs) | |
elif type == 'l1': | |
return nn.L1Loss() | |
else: | |
raise ValueError(f'Unknown loss type: {type}') | |
class WeightedMSELoss(nn.Module): | |
def __init__(self, reduction='mean', **kwargs): | |
super(WeightedMSELoss, self).__init__() | |
self.reduce_str = reduction | |
self.reduce = get_reduction_method(reduction) | |
def forward(self, input, target, weights=None): | |
diff = input - target | |
if weights is None: | |
return diff.pow(2).sum() / diff.shape[0] | |
else: | |
return ( | |
weights.unsqueeze(dim=-1) * diff.pow(2)).sum() / diff.shape[0] | |
class VertexEdgeLoss(nn.Module): | |
def __init__(self, norm_type='l2', | |
gt_edges=None, | |
gt_edge_path='', | |
est_edges=None, | |
est_edge_path='', | |
robustifier=None, | |
edge_thresh=0.0, epsilon=1e-8, | |
reduction='sum', | |
**kwargs): | |
super(VertexEdgeLoss, self).__init__() | |
assert norm_type in ['l1', 'l2'], 'Norm type must be [l1, l2]' | |
self.norm_type = norm_type | |
self.epsilon = epsilon | |
self.reduction = reduction | |
assert self.reduction in ['sum', 'mean'] | |
logger.info(f'Building edge loss with' | |
f' norm_type={norm_type},' | |
f' reduction={reduction},' | |
) | |
gt_edge_path = osp.expandvars(gt_edge_path) | |
est_edge_path = osp.expandvars(est_edge_path) | |
assert osp.exists(gt_edge_path) or gt_edges is not None, ( | |
'gt_edges must not be None or gt_edge_path must exist' | |
) | |
assert osp.exists(est_edge_path) or est_edges is not None, ( | |
'est_edges must not be None or est_edge_path must exist' | |
) | |
if osp.exists(gt_edge_path) and gt_edges is None: | |
gt_edges = np.load(gt_edge_path) | |
if osp.exists(est_edge_path) and est_edges is None: | |
est_edges = np.load(est_edge_path) | |
self.register_buffer( | |
'gt_connections', torch.tensor(gt_edges, dtype=torch.long)) | |
self.register_buffer( | |
'est_connections', torch.tensor(est_edges, dtype=torch.long)) | |
def extra_repr(self): | |
msg = [ | |
f'Norm type: {self.norm_type}', | |
] | |
if self.has_connections: | |
msg.append( | |
f'GT Connections shape: {self.gt_connections.shape}' | |
) | |
msg.append( | |
f'Est Connections shape: {self.est_connections.shape}' | |
) | |
return '\n'.join(msg) | |
def compute_edges(self, points, connections): | |
edge_points = torch.index_select( | |
points, 1, connections.view(-1)).reshape(points.shape[0], -1, 2, 3) | |
return edge_points[:, :, 1] - edge_points[:, :, 0] | |
def forward(self, gt_vertices, est_vertices, weights=None): | |
gt_edges = self.compute_edges( | |
gt_vertices, connections=self.gt_connections) | |
est_edges = self.compute_edges( | |
est_vertices, connections=self.est_connections) | |
raw_edge_diff = (gt_edges - est_edges) | |
batch_size = gt_vertices.shape[0] | |
if self.norm_type == 'l2': | |
edge_diff = raw_edge_diff.pow(2) | |
elif self.norm_type == 'l1': | |
edge_diff = raw_edge_diff.abs() | |
else: | |
raise NotImplementedError( | |
f'Loss type not implemented: {self.loss_type}') | |
if self.reduction == 'sum': | |
return edge_diff.sum() | |
elif self.reduction == 'mean': | |
return edge_diff.sum() / batch_size | |