# (removed VCS-extraction residue; provenance: author "kleinhe", commit c3d0293)
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2020 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: Vassilis Choutas, [email protected]
from __future__ import print_function
from __future__ import absolute_import
from __future__ import division
import sys
import time
from typing import Callable, Iterator, Union, Optional, List
import os.path as osp
import yaml
from loguru import logger
import pickle
import numpy as np
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
from .utils import get_reduction_method
# Public API of this module; build_loss is the factory entry point.
# NOTE(review): WeightedMSELoss is returned by build_loss('l2') but is not
# exported here — presumably intentional; confirm before importing it directly.
__all__ = [
    'VertexEdgeLoss',
    'build_loss',
]
def build_loss(type='l2', reduction='mean', **kwargs) -> nn.Module:
    """Factory for loss modules.

    Args:
        type: Loss identifier — 'l2' (batch-averaged weighted MSE),
            'vertex-edge' (edge-consistency loss), or 'l1'.
            (The parameter name shadows the builtin; kept unchanged for
            backward compatibility with keyword callers.)
        reduction: Reduction mode forwarded to the constructed loss.
        **kwargs: Extra options forwarded to the loss constructor.

    Returns:
        An ``nn.Module`` implementing the requested loss.

    Raises:
        ValueError: If ``type`` is not one of the recognized identifiers.
    """
    logger.debug(f'Building loss: {type}')
    if type == 'l2':
        return WeightedMSELoss(reduction=reduction, **kwargs)
    elif type == 'vertex-edge':
        return VertexEdgeLoss(reduction=reduction, **kwargs)
    elif type == 'l1':
        # Bug fix: `reduction` was previously dropped here, so requesting
        # e.g. reduction='sum' silently produced a mean-reduced L1 loss.
        # The default ('mean') matches nn.L1Loss's default, so existing
        # callers that relied on the default are unaffected.
        return nn.L1Loss(reduction=reduction)
    else:
        raise ValueError(f'Unknown loss type: {type}')
class WeightedMSELoss(nn.Module):
    """Squared-error loss averaged over the batch dimension.

    The total squared error is divided by the leading (batch) dimension
    of the inputs, not by the number of elements. Optional per-element
    weights are broadcast against the last axis of the squared residual.
    """

    def __init__(self, reduction='mean', **kwargs):
        super(WeightedMSELoss, self).__init__()
        # Keep both the string and the resolved callable around; the
        # forward pass below does not consult them, but other code may.
        self.reduce_str = reduction
        self.reduce = get_reduction_method(reduction)

    def forward(self, input, target, weights=None):
        """Return the (optionally weighted) batch-averaged squared error.

        Args:
            input: Predicted tensor.
            target: Ground-truth tensor, same shape as ``input``.
            weights: Optional weight tensor; gets an extra trailing axis
                before multiplying the squared residuals — presumably
                one weight per point/row, TODO confirm with callers.
        """
        residual = input - target
        squared = residual.pow(2)
        if weights is not None:
            squared = weights.unsqueeze(dim=-1) * squared
        return squared.sum() / residual.shape[0]
class VertexEdgeLoss(nn.Module):
    """Edge-consistency loss between two vertex sets.

    For each (i, j) pair in a connectivity table, an edge vector
    ``points[:, j] - points[:, i]`` is computed on both the ground-truth
    and the estimated vertices (each with its own table), and the
    difference between corresponding edges is penalized under an L1 or
    squared-L2 norm.

    NOTE(review): ``robustifier`` and ``edge_thresh`` are accepted but
    never used by this implementation — presumably reserved for a
    robustified variant; confirm before relying on them.
    """

    def __init__(self, norm_type='l2',
                 gt_edges=None,
                 gt_edge_path='',
                 est_edges=None,
                 est_edge_path='',
                 robustifier=None,
                 edge_thresh=0.0, epsilon=1e-8,
                 reduction='sum',
                 **kwargs):
        """Build the edge loss.

        Args:
            norm_type: 'l1' or 'l2' penalty on the edge differences.
            gt_edges: Integer array of (i, j) vertex-index pairs for the
                ground-truth vertices; may instead be loaded from
                ``gt_edge_path``.
            gt_edge_path: Path to an ``np.load``-able edge array; takes
                effect only when ``gt_edges`` is None. Environment
                variables in the path are expanded.
            est_edges: Edge pairs for the estimated vertices (same
                convention as ``gt_edges``).
            est_edge_path: Path counterpart for ``est_edges``.
            robustifier: Unused (see class note).
            edge_thresh: Unused (see class note).
            epsilon: Stored for numerical use; not consulted here.
            reduction: 'sum' or 'mean' (mean divides by batch size only).
        """
        super(VertexEdgeLoss, self).__init__()
        assert norm_type in ['l1', 'l2'], 'Norm type must be [l1, l2]'
        self.norm_type = norm_type
        self.epsilon = epsilon
        self.reduction = reduction
        assert self.reduction in ['sum', 'mean']
        logger.info(f'Building edge loss with'
                    f' norm_type={norm_type},'
                    f' reduction={reduction},'
                    )

        gt_edge_path = osp.expandvars(gt_edge_path)
        est_edge_path = osp.expandvars(est_edge_path)
        assert osp.exists(gt_edge_path) or gt_edges is not None, (
            'gt_edges must not be None or gt_edge_path must exist'
        )
        assert osp.exists(est_edge_path) or est_edges is not None, (
            'est_edges must not be None or est_edge_path must exist'
        )
        # Explicit arrays take precedence over paths.
        if osp.exists(gt_edge_path) and gt_edges is None:
            gt_edges = np.load(gt_edge_path)
        if osp.exists(est_edge_path) and est_edges is None:
            est_edges = np.load(est_edge_path)

        # Buffers (not parameters): move with the module across devices
        # but receive no gradients.
        self.register_buffer(
            'gt_connections', torch.tensor(gt_edges, dtype=torch.long))
        self.register_buffer(
            'est_connections', torch.tensor(est_edges, dtype=torch.long))

    def extra_repr(self):
        msg = [
            f'Norm type: {self.norm_type}',
        ]
        # Bug fix: this previously tested `self.has_connections`, an
        # attribute that is never defined, so printing the module raised
        # AttributeError. The buffers registered in __init__ are the
        # actual source of truth.
        if hasattr(self, 'gt_connections') and hasattr(
                self, 'est_connections'):
            msg.append(
                f'GT Connections shape: {self.gt_connections.shape}'
            )
            msg.append(
                f'Est Connections shape: {self.est_connections.shape}'
            )
        return '\n'.join(msg)

    def compute_edges(self, points, connections):
        """Gather edge endpoint pairs and return the edge vectors.

        ``points`` is indexed along dim 1 (vertices), so the expected
        layout is (batch, num_vertices, 3); the reshape to
        (batch, num_edges, 2, 3) fixes the last axis at 3.
        """
        edge_points = torch.index_select(
            points, 1, connections.view(-1)).reshape(points.shape[0], -1, 2, 3)
        return edge_points[:, :, 1] - edge_points[:, :, 0]

    def forward(self, gt_vertices, est_vertices, weights=None):
        """Return the reduced edge-difference penalty.

        Args:
            gt_vertices: (batch, num_gt_vertices, 3) ground-truth points.
            est_vertices: (batch, num_est_vertices, 3) estimated points.
            weights: Accepted for interface compatibility; unused.
        """
        gt_edges = self.compute_edges(
            gt_vertices, connections=self.gt_connections)
        est_edges = self.compute_edges(
            est_vertices, connections=self.est_connections)

        raw_edge_diff = (gt_edges - est_edges)
        batch_size = gt_vertices.shape[0]
        if self.norm_type == 'l2':
            edge_diff = raw_edge_diff.pow(2)
        elif self.norm_type == 'l1':
            edge_diff = raw_edge_diff.abs()
        else:
            # Bug fix: the error message referenced the nonexistent
            # attribute `self.loss_type`, which would itself raise.
            raise NotImplementedError(
                f'Loss type not implemented: {self.norm_type}')
        if self.reduction == 'sum':
            return edge_diff.sum()
        elif self.reduction == 'mean':
            # Note: mean over the batch only, not over all elements.
            return edge_diff.sum() / batch_size
        else:
            # Bug fix: previously fell through and returned None.
            raise ValueError(f'Unknown reduction: {self.reduction}')