import torch
import torch.nn as nn


class GRUSequenceLoss(nn.Module):
    """
    Loss function defined over a sequence of depth predictions: a masked L1
    (or log-space L1) depth term, an optional confidence term, and a weakly
    weighted stereo-depth term, each summed over the sequence with
    exponentially decaying iteration weights as in RAFT-Stereo.
    """

    def __init__(self, loss_weight=1, data_type=('lidar', 'denselidar', 'stereo', 'denselidar_syn'),
                 loss_gamma=0.9, silog=False, stereo_sup=0.001,
                 stereo_dataset=('KITTI', 'NYU'), **kwargs):
        super(GRUSequenceLoss, self).__init__()
        self.loss_weight = loss_weight
        self.data_type = data_type
        self.eps = 1e-6
        self.loss_gamma = loss_gamma          # decay rate of per-iteration weights
        self.silog = silog                    # use log-space L1 instead of plain L1
        self.variance_focus = 0.5
        self.stereo_sup = stereo_sup          # weight of the stereo-depth term
        self.stereo_dataset = stereo_dataset  # datasets that provide stereo depth

    def silog_loss(self, prediction, target, mask):
        # Despite its name, this computes an L1 loss in log space; the
        # original scale-invariant (SILog) formulation is kept for reference:
        #   d_square_mean = torch.sum(d ** 2) / (d.numel() + self.eps)
        #   d_mean = torch.sum(d) / (d.numel() + self.eps)
        #   loss = d_square_mean - self.variance_focus * (d_mean ** 2)
        mask = mask & (prediction > 0.01) & (target > 0.01)  # keep log() well-defined
        d = torch.log(prediction[mask]) - torch.log(target[mask])
        loss = torch.sum(torch.abs(d)) / (d.numel() + self.eps)
        return loss

    def conf_loss(self, confidence, prediction, target, mask):
        # Supervise confidence with an L1 target in [0, 1]: target confidence
        # approaches 1 as the relative depth error approaches 0, and pixels
        # whose absolute error exceeds the target depth are excluded.
        conf_mask = torch.abs(target - prediction) < target
        conf_mask = conf_mask & mask
        gt_confidence = (1 - torch.abs((prediction - target) / target)) * conf_mask
        loss = torch.sum(torch.abs(confidence - gt_confidence) * conf_mask) / (torch.sum(conf_mask) + self.eps)
        if torch.isnan(loss).item() or torch.isinf(loss).item():
            print(f'GRUSequenceLoss-confidence NAN error, {loss}')
            loss = 0 * torch.sum(confidence)  # zero loss that keeps the graph connected
        return loss

    def forward(self, predictions_list, target, stereo_depth, confidence_list=None, mask=None, **kwargs):
        device = target.device
        batches_dataset = kwargs['dataset']
        # Per-sample flag of shape (B, 1, 1, 1): the stereo term is applied
        # only to batch elements drawn from a configured stereo dataset.
        self.batch_with_stereo = torch.tensor(
            [1 if batch_dataset in self.stereo_dataset else 0 for batch_dataset in batches_dataset],
            device=device)[:, None, None, None]

        n_predictions = len(predictions_list)
        assert n_predictions >= 1
        loss = 0.0
        for i, prediction in enumerate(predictions_list):
            # Adjust loss_gamma so the decay schedule is consistent for any
            # number of GRU iterations (following RAFT-Stereo); the max()
            # guards against division by zero when n_predictions == 1.
            adjusted_loss_gamma = self.loss_gamma ** (15 / max(n_predictions - 1, 1))
            i_weight = adjusted_loss_gamma ** (n_predictions - i - 1)
            # Depth loss: log-space L1 when silog is enabled and ground truth
            # is available, otherwise a masked L1 in depth space.
            if self.silog and mask.sum() > 0:
                curr_loss = self.silog_loss(prediction, target, mask)
            else:
                diff = torch.abs(prediction - target) * mask
                curr_loss = torch.sum(diff) / (torch.sum(mask) + self.eps)
            if torch.isnan(curr_loss).item() or torch.isinf(curr_loss).item():
                print(f'GRUSequenceLoss-depth NAN error, {curr_loss}')
                curr_loss = 0 * torch.sum(prediction)  # zero loss, graph intact
            # Confidence L1 loss (only if confidence maps are provided).
            conf_loss = 0
            if confidence_list is not None:
                conf_loss = self.conf_loss(confidence_list[i], prediction, target, mask)
            # Stereo-depth loss on pixels *without* ground truth: the nested
            # max-pools compute a 3x3 morphological closing of the GT mask
            # (max_pool(-max_pool(m)) = -erode(dilate(m))), so mask_stereo
            # equals 1 - closing(mask) and selects GT-free regions.
            mask_stereo = 1 + torch.nn.functional.max_pool2d(
                -torch.nn.functional.max_pool2d(mask * 1.0, 3, stride=1, padding=1, dilation=1),
                3, stride=1, padding=1, dilation=1)
            stereo_diff = torch.abs(prediction - stereo_depth) * mask_stereo
            stereo_depth_loss = torch.sum(self.batch_with_stereo * stereo_diff * mask_stereo) / (torch.sum(mask_stereo) + self.eps)
            stereo_depth_loss = self.stereo_sup * stereo_depth_loss

            loss += (conf_loss + curr_loss + stereo_depth_loss) * i_weight

        return loss * self.loss_weight
# A previous revision of this class used stereo_dataset=['BigData'] and a
# 5x5 closing (kernel 5, padding 2) when building mask_stereo.
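

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only: the shapes, dataset names, and
# random inputs below are assumptions, not values from any training
# pipeline). Each prediction in the sequence is a (B, 1, H, W) depth map;
# the `dataset` kwarg names the source dataset of each batch element so the
# stereo term is applied only to samples from `stereo_dataset`.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    B, H, W = 2, 32, 32
    criterion = GRUSequenceLoss(loss_weight=1.0, silog=False)

    predictions = [torch.rand(B, 1, H, W) + 0.1 for _ in range(4)]  # one map per GRU iteration
    confidences = [torch.rand(B, 1, H, W) for _ in range(4)]
    target = torch.rand(B, 1, H, W) + 0.1
    stereo_depth = torch.rand(B, 1, H, W) + 0.1
    mask = target > 0.5  # pixels with valid ground truth

    loss = criterion(predictions, target, stereo_depth,
                     confidence_list=confidences, mask=mask,
                     dataset=['KITTI', 'ScanNet'])  # stereo term hits only the KITTI sample
    print(loss.item())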