import torch
import torch.nn as nn
import numpy as np


class HDSNRandomLoss(nn.Module):
    """
    Hierarchical depth spatial normalization loss.

    Replaces the original grid masks with randomly created masks.
    loss = MAE((d - median(d)) / s - (d' - median(d')) / s'),
    where s = mean(|d - median(d)|) over the valid pixels.
    """

    def __init__(self, loss_weight=1.0, random_num=32,
                 data_type=['sfm', 'stereo', 'lidar', 'denselidar', 'denselidar_nometric', 'denselidar_syn'],
                 disable_dataset=['MapillaryPSD'], sky_id=142, batch_limit=8, **kwargs):
        super(HDSNRandomLoss, self).__init__()
        self.loss_weight = loss_weight
        self.random_num = random_num
        self.data_type = data_type
        self.sky_id = sky_id
        self.batch_limit = batch_limit
        self.eps = 1e-6
        self.disable_dataset = disable_dataset

    def get_random_masks_for_batch(self, image_size: list, device=None) -> torch.Tensor:
        height, width = image_size
        crop_h_min = int(0.125 * height)
        crop_h_max = int(0.5 * height)
        crop_w_min = int(0.125 * width)
        crop_w_max = int(0.5 * width)
        h_max = height - crop_h_min
        w_max = width - crop_w_min
        # Sample random crop sizes and top-left corners, then clamp the crop
        # ends to the image boundary.
        crop_height = np.random.choice(np.arange(crop_h_min, crop_h_max), self.random_num, replace=False)
        crop_width = np.random.choice(np.arange(crop_w_min, crop_w_max), self.random_num, replace=False)
        crop_y = np.random.choice(h_max, self.random_num, replace=False)
        crop_x = np.random.choice(w_max, self.random_num, replace=False)
        crop_y_end = crop_height + crop_y
        crop_y_end[crop_y_end >= height] = height
        crop_x_end = crop_width + crop_x
        crop_x_end[crop_x_end >= width] = width

        mask_new = torch.zeros((self.random_num, height, width), dtype=torch.bool, device=device)
        for i in range(self.random_num):
            mask_new[i, crop_y[i]:crop_y_end[i], crop_x[i]:crop_x_end[i]] = True

        return mask_new
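
    # Shape sketch (illustrative values): for image_size=(256, 512) and
    # random_num=32 this returns a bool tensor of shape [32, 256, 512];
    # each slice is True inside one random crop spanning 12.5%-50% of each
    # image dimension.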

    def reorder_sem_masks(self, sem_label):
        # Split a semantic label map into per-class binary masks, excluding
        # invalid (<= 0) regions and the sky class.
        assert sem_label.ndim == 3
        semantic_ids = torch.unique(sem_label[(sem_label > 0) & (sem_label != self.sky_id)])
        sem_masks = [sem_label == sem_id for sem_id in semantic_ids]
        if len(sem_masks) == 0:
            out = sem_label > 0
            return out

        sem_masks = torch.cat(sem_masks, dim=0)
        # Keep only masks covering more than 500 pixels, and subsample to at
        # most random_num masks.
        mask_batch = torch.sum(sem_masks.reshape(sem_masks.shape[0], -1), dim=1) > 500
        sem_masks = sem_masks[mask_batch]
        if sem_masks.shape[0] > self.random_num:
            balance_samples = np.random.choice(sem_masks.shape[0], self.random_num, replace=False)
            sem_masks = sem_masks[balance_samples, ...]

        if sem_masks.shape[0] == 0:
            out = sem_label > 0
            return out

        if sem_masks.ndim == 2:
            sem_masks = sem_masks[None, :, :]
        return sem_masks
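
    # Output sketch: a bool tensor of shape [K, H, W] with one channel per
    # retained semantic class (K <= random_num), falling back to the validity
    # mask sem_label > 0 when no class survives the filters.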

    def ssi_mae(self, prediction, target, mask_valid):
        B, C, H, W = target.shape
        prediction_nan = prediction.clone().detach()
        target_nan = target.clone()
        prediction_nan[~mask_valid] = float('nan')
        target_nan[~mask_valid] = float('nan')

        valid_pixs = mask_valid.reshape((B, C, -1)).sum(dim=2, keepdim=True) + 1e-10
        valid_pixs = valid_pixs[:, :, :, None]

        # Per-sample shift (median) and scale (mean absolute deviation) of
        # the target over valid pixels.
        gt_median = target_nan.reshape((B, C, -1)).nanmedian(2, keepdim=True)[0].unsqueeze(-1)
        gt_median[torch.isnan(gt_median)] = 0
        gt_diff = (torch.abs(target - gt_median) * mask_valid).reshape((B, C, -1))
        gt_s = gt_diff.sum(dim=2)[:, :, None, None] / valid_pixs
        gt_trans = (target - gt_median) / (gt_s + self.eps)

        pred_median = prediction_nan.reshape((B, C, -1)).nanmedian(2, keepdim=True)[0].unsqueeze(-1)
        pred_median[torch.isnan(pred_median)] = 0
        pred_diff = (torch.abs(prediction - pred_median) * mask_valid).reshape((B, C, -1))
        pred_s = pred_diff.sum(dim=2)[:, :, None, None] / valid_pixs
        pred_trans = (prediction - pred_median) / (pred_s + self.eps)

        loss_sum = torch.sum(torch.abs(gt_trans - pred_trans) * mask_valid)
        return loss_sum
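
    # Worked example of the normalization above (illustrative numbers): for
    # valid depths d = [1, 2, 4], median(d) = 2 and
    # s = mean(|d - median(d)|) = (1 + 0 + 2) / 3 = 1, so the normalized
    # depths are (d - 2) / 1 = [-1, 0, 2]. Any positive affine remap
    # a * d + b of a depth map yields the same normalized values, which is
    # what makes the MAE between gt_trans and pred_trans scale- and
    # shift-invariant.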

    def conditional_ssi_mae(self, prediction, target, mask_valid):
        B, C, H, W = target.shape
        # Shuffle the batch so each sample is normalized with the shift/scale
        # of a randomly paired sample.
        conditional_rank_ids = np.random.choice(B, B, replace=False)

        prediction_nan = prediction.clone()
        target_nan = target.clone()
        prediction_nan[~mask_valid] = float('nan')
        target_nan[~mask_valid] = float('nan')

        valid_pixs = mask_valid.reshape((B, C, -1)).sum(dim=2, keepdim=True) + self.eps
        valid_pixs = valid_pixs[:, :, :, None].contiguous()

        gt_median = target_nan.reshape((B, C, -1)).nanmedian(2, keepdim=True)[0].unsqueeze(-1)
        gt_median[torch.isnan(gt_median)] = 0
        gt_diff = (torch.abs(target - gt_median) * mask_valid).reshape((B, C, -1))
        gt_s = gt_diff.sum(dim=2)[:, :, None, None].contiguous() / valid_pixs

        # Guard against degenerate (near-zero) scales.
        gt_s_small_mask = gt_s < (torch.mean(gt_s) * 0.1)
        gt_s[gt_s_small_mask] = torch.mean(gt_s)
        gt_trans = (target - gt_median[conditional_rank_ids]) / (gt_s[conditional_rank_ids] + self.eps)

        pred_median = prediction_nan.reshape((B, C, -1)).nanmedian(2, keepdim=True)[0].unsqueeze(-1)
        pred_median[torch.isnan(pred_median)] = 0
        pred_diff = (torch.abs(prediction - pred_median) * mask_valid).reshape((B, C, -1))
        pred_s = pred_diff.sum(dim=2)[:, :, None, None].contiguous() / valid_pixs
        pred_s[gt_s_small_mask] = torch.mean(pred_s)
        pred_trans = (prediction - pred_median[conditional_rank_ids]) / (pred_s[conditional_rank_ids] + self.eps)

        loss_sum = torch.sum(torch.abs(gt_trans - pred_trans) * mask_valid)
        return loss_sum
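
    # Note: unlike ssi_mae, this variant normalizes each sample with the
    # median/scale of a randomly paired batch element (conditional_rank_ids),
    # and guards against near-zero scales. It is kept as a utility and is not
    # called from forward() below.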

    def forward(self, prediction, target, mask=None, sem_mask=None, **kwargs):
        """
        Calculate the hierarchical SSI-MAE loss over random crop masks and
        semantic masks, plus a whole-image term.
        """
        B, C, H, W = target.shape

        loss = 0.0
        valid_pix = 0.0

        device = target.device

        batches_dataset = kwargs['dataset']
        # Per-sample validity flag: zero out samples from disabled datasets.
        self.batch_valid = torch.tensor(
            [1 if batch_dataset not in self.disable_dataset else 0
             for batch_dataset in batches_dataset],
            device=device)[:, None, None, None]

        batch_limit = self.batch_limit

        random_sample_masks = self.get_random_masks_for_batch((H, W), device=device)
        for i in range(B):
            mask_i = mask[i, ...]
            if self.batch_valid[i, ...] < 0.5:
                # Disabled sample: contribute zero loss but keep the graph connected.
                loss += 0 * torch.sum(prediction[i, ...])
                valid_pix += 0 * torch.sum(mask_i)
                continue

            pred_i = prediction[i, ...].unsqueeze(0).repeat(batch_limit, 1, 1, 1)
            target_i = target[i, ...].unsqueeze(0).repeat(batch_limit, 1, 1, 1)

            sem_label_i = sem_mask[i, ...] if sem_mask is not None else None
            if sem_label_i is not None:
                sem_masks = self.reorder_sem_masks(sem_label_i)
                random_sem_masks = torch.cat([random_sample_masks, sem_masks], dim=0)
            else:
                random_sem_masks = random_sample_masks

            # Evaluate the masks in chunks of at most batch_limit.
            sampled_masks_num = random_sem_masks.shape[0]
            loops = int(np.ceil(sampled_masks_num / batch_limit))
            for j in range(loops):
                mask_random_sem_loopi = random_sem_masks[j * batch_limit:(j + 1) * batch_limit, ...]
                mask_sample = (mask_i & mask_random_sem_loopi).unsqueeze(1)
                loss += self.ssi_mae(
                    prediction=pred_i[:mask_sample.shape[0], ...],
                    target=target_i[:mask_sample.shape[0], ...],
                    mask_valid=mask_sample)
                valid_pix += torch.sum(mask_sample)

        # Whole-image SSI-MAE term over the samples from enabled datasets.
        mask = mask * self.batch_valid.bool()
        loss += self.ssi_mae(
            prediction=prediction,
            target=target,
            mask_valid=mask)
        valid_pix += torch.sum(mask)
        loss = loss / (valid_pix + self.eps)
        if torch.isnan(loss).item() or torch.isinf(loss).item():
            loss = 0 * torch.sum(prediction)
            print(f'HDSNL NAN error, {loss}, valid pix: {valid_pix}')
        return loss * self.loss_weight


if __name__ == '__main__':
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    ssil = HDSNRandomLoss()
    pred = torch.rand((8, 1, 256, 512)).cuda()
    gt = torch.rand((8, 1, 256, 512)).cuda()
    gt[1:, :, 100:256, 100:350] = -1
    gt[:2, ...] = -1
    mask = gt > 0
    sem_mask = np.random.randint(-1, 200, (8, 1, 256, 512))
    sem_mask[sem_mask > 0] = -1
    sem_mask_torch = torch.from_numpy(sem_mask).cuda()

    # forward() reads kwargs['dataset']; pass one dataset tag per sample
    # (any name not in disable_dataset keeps the sample enabled).
    out = ssil(pred, gt, mask, sem_mask_torch, dataset=['sfm'] * 8)
    print(out)
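
    # Minimal extra sketch (not in the original test): call ssi_mae directly
    # on the full batch and normalize by the valid-pixel count, mirroring the
    # whole-image term inside forward().
    single_term = ssil.ssi_mae(prediction=pred, target=gt, mask_valid=mask)
    print(single_term / (mask.sum() + ssil.eps))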