import torch
import torch.nn as nn
import numpy as np
#from numba import jit


class HDSNRandomLoss(nn.Module):
    """
    Hierarchical depth spatial normalization loss.
    Replaces the original grid masks with randomly created masks.

    loss = MAE((d - median(d)) / s, (d' - median(d')) / s'),
    where s = mean(|d - median(d)|) over valid pixels.
    """
    def __init__(self,
                 loss_weight=1.0,
                 random_num=32,
                 data_type=['sfm', 'stereo', 'lidar', 'denselidar', 'denselidar_nometric', 'denselidar_syn'],
                 disable_dataset=['MapillaryPSD'],
                 sky_id=142,
                 batch_limit=8,
                 **kwargs):
        super(HDSNRandomLoss, self).__init__()
        self.loss_weight = loss_weight
        self.random_num = random_num
        self.data_type = data_type
        self.sky_id = sky_id
        self.batch_limit = batch_limit
        self.eps = 1e-6
        self.disable_dataset = disable_dataset

    def get_random_masks_for_batch(self, image_size: list, device=None) -> torch.Tensor:
        # Sample self.random_num rectangular crops whose sides span 1/8 to 1/2
        # of the image; each crop becomes one boolean mask. Requires the image
        # to be large enough that each sampling range holds random_num distinct
        # values (replace=False).
        height, width = image_size
        crop_h_min = int(0.125 * height)
        crop_h_max = int(0.5 * height)
        crop_w_min = int(0.125 * width)
        crop_w_max = int(0.5 * width)
        h_max = height - crop_h_min
        w_max = width - crop_w_min
        crop_height = np.random.choice(np.arange(crop_h_min, crop_h_max), self.random_num, replace=False)
        crop_width = np.random.choice(np.arange(crop_w_min, crop_w_max), self.random_num, replace=False)
        crop_y = np.random.choice(h_max, self.random_num, replace=False)
        crop_x = np.random.choice(w_max, self.random_num, replace=False)
        crop_y_end = crop_height + crop_y
        crop_y_end[crop_y_end >= height] = height
        crop_x_end = crop_width + crop_x
        crop_x_end[crop_x_end >= width] = width

        # Take the device from the caller instead of hard-coding "cuda".
        mask_new = torch.zeros((self.random_num, height, width), dtype=torch.bool, device=device)  # [N, H, W]
        for i in range(self.random_num):
            mask_new[i, crop_y[i]:crop_y_end[i], crop_x[i]:crop_x_end[i]] = True
        return mask_new
        #return crop_y, crop_y_end, crop_x, crop_x_end

    def reorder_sem_masks(self, sem_label):
        # Split the semantic map of one sample into per-segment binary masks,
        # dropping the background, the sky, and segments smaller than 500 pixels.
        assert sem_label.ndim == 3
        semantic_ids = torch.unique(sem_label[(sem_label > 0) & (sem_label != self.sky_id)])
        sem_masks = [sem_label == sid for sid in semantic_ids]
        if len(sem_masks) == 0:
            # no valid semantic labels
            out = sem_label > 0
            return out
        sem_masks = torch.cat(sem_masks, dim=0)
        mask_batch = torch.sum(sem_masks.reshape(sem_masks.shape[0], -1), dim=1) > 500
        sem_masks = sem_masks[mask_batch]
        if sem_masks.shape[0] > self.random_num:
            balance_samples = np.random.choice(sem_masks.shape[0], self.random_num, replace=False)
            sem_masks = sem_masks[balance_samples, ...]
        if sem_masks.shape[0] == 0:
            # no valid semantic labels survived the size filter
            out = sem_label > 0
            return out
        if sem_masks.ndim == 2:
            sem_masks = sem_masks[None, :, :]
        return sem_masks
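    # Note on the normalization used below: ssi_mae maps a depth map d to
    # (d - median(d)) / s with s = mean(|d - median(d)|) taken over valid
    # pixels, which makes the comparison invariant to the scale and shift of
    # each masked region. The method returns an unreduced sum of absolute
    # differences; forward() divides once by the total valid-pixel count.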
    def ssi_mae(self, prediction, target, mask_valid):
        B, C, H, W = target.shape
        # Detached, NaN-masked copies are used only to compute the medians.
        prediction_nan = prediction.clone().detach()
        target_nan = target.clone()
        prediction_nan[~mask_valid] = float('nan')
        target_nan[~mask_valid] = float('nan')

        valid_pixs = mask_valid.reshape((B, C, -1)).sum(dim=2, keepdim=True) + 1e-10
        valid_pixs = valid_pixs[:, :, :, None]

        gt_median = target_nan.reshape((B, C, -1)).nanmedian(2, keepdim=True)[0].unsqueeze(-1)  # [b,c,1,1]
        gt_median[torch.isnan(gt_median)] = 0
        # Restrict the scale to valid pixels; without the mask, invalid pixels
        # would inflate the mean absolute deviation.
        gt_diff = (torch.abs(target - gt_median) * mask_valid).reshape((B, C, -1))
        gt_s = gt_diff.sum(dim=2)[:, :, None, None] / valid_pixs
        gt_trans = (target - gt_median) / (gt_s + self.eps)

        pred_median = prediction_nan.reshape((B, C, -1)).nanmedian(2, keepdim=True)[0].unsqueeze(-1)  # [b,c,1,1]
        pred_median[torch.isnan(pred_median)] = 0
        pred_diff = (torch.abs(prediction - pred_median) * mask_valid).reshape((B, C, -1))
        pred_s = pred_diff.sum(dim=2)[:, :, None, None] / valid_pixs
        pred_trans = (prediction - pred_median) / (pred_s + self.eps)

        loss_sum = torch.sum(torch.abs(gt_trans - pred_trans) * mask_valid)
        return loss_sum

    def conditional_ssi_mae(self, prediction, target, mask_valid):
        # Variant of ssi_mae that normalizes each sample with the median/scale
        # of a randomly permuted partner sample (currently unused in forward).
        B, C, H, W = target.shape
        conditional_rank_ids = np.random.choice(B, B, replace=False)

        prediction_nan = prediction.clone()
        target_nan = target.clone()
        prediction_nan[~mask_valid] = float('nan')
        target_nan[~mask_valid] = float('nan')

        valid_pixs = mask_valid.reshape((B, C, -1)).sum(dim=2, keepdim=True) + self.eps
        valid_pixs = valid_pixs[:, :, :, None].contiguous()

        gt_median = target_nan.reshape((B, C, -1)).nanmedian(2, keepdim=True)[0].unsqueeze(-1)  # [b,c,1,1]
        gt_median[torch.isnan(gt_median)] = 0
        gt_diff = (torch.abs(target - gt_median) * mask_valid).reshape((B, C, -1))
        gt_s = gt_diff.sum(dim=2)[:, :, None, None].contiguous() / valid_pixs
        # in case some batches have no valid pixels
        gt_s_small_mask = gt_s < (torch.mean(gt_s) * 0.1)
        gt_s[gt_s_small_mask] = torch.mean(gt_s)
        gt_trans = (target - gt_median[conditional_rank_ids]) / (gt_s[conditional_rank_ids] + self.eps)

        pred_median = prediction_nan.reshape((B, C, -1)).nanmedian(2, keepdim=True)[0].unsqueeze(-1)  # [b,c,1,1]
        pred_median[torch.isnan(pred_median)] = 0
        pred_diff = (torch.abs(prediction - pred_median) * mask_valid).reshape((B, C, -1))
        pred_s = pred_diff.sum(dim=2)[:, :, None, None].contiguous() / valid_pixs
        pred_s[gt_s_small_mask] = torch.mean(pred_s)
        pred_trans = (prediction - pred_median[conditional_rank_ids]) / (pred_s[conditional_rank_ids] + self.eps)

        loss_sum = torch.sum(torch.abs(gt_trans - pred_trans) * mask_valid)
        # print(torch.abs(gt_trans - pred_trans)[mask_valid])
        return loss_sum
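    # forward() applies ssi_mae hierarchically per sample: over random
    # rectangular crops, over semantic-segment masks when sem_mask is given,
    # and finally over the whole image, chunking the masks by batch_limit to
    # bound peak memory before a single normalization by valid pixels.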
    def forward(self, prediction, target, mask=None, sem_mask=None, **kwargs):
        """
        Calculate the hierarchical SSI loss.
        """
        B, C, H, W = target.shape
        loss = 0.0
        valid_pix = 0.0
        device = target.device

        batches_dataset = kwargs['dataset']
        self.batch_valid = torch.tensor([1 if batch_dataset not in self.disable_dataset else 0
                                         for batch_dataset in batches_dataset], device=device)[:, None, None, None]

        batch_limit = self.batch_limit
        random_sample_masks = self.get_random_masks_for_batch((H, W), device=device)  # [N, H, W]
        for i in range(B):  # each sample in the batch
            mask_i = mask[i, ...]  # [1, H, W]
            if self.batch_valid[i, ...] < 0.5:
                # Disabled dataset: contribute a zero term so the computation
                # graph still touches this sample.
                loss += 0 * torch.sum(prediction[i, ...])
                valid_pix += 0 * torch.sum(mask_i)
                continue

            pred_i = prediction[i, ...].unsqueeze(0).repeat(batch_limit, 1, 1, 1)
            target_i = target[i, ...].unsqueeze(0).repeat(batch_limit, 1, 1, 1)

            # get semantic masks
            sem_label_i = sem_mask[i, ...] if sem_mask is not None else None
            if sem_label_i is not None:
                sem_masks = self.reorder_sem_masks(sem_label_i)  # [N, H, W]
                random_sem_masks = torch.cat([random_sample_masks, sem_masks], dim=0)
            else:
                random_sem_masks = random_sample_masks

            # Evaluate the masks in chunks of batch_limit to bound memory.
            sampled_masks_num = random_sem_masks.shape[0]
            loops = int(np.ceil(sampled_masks_num / batch_limit))
            conditional_rank_ids = np.random.choice(sampled_masks_num, sampled_masks_num, replace=False)
            for j in range(loops):
                mask_random_sem_loopi = random_sem_masks[j * batch_limit:(j + 1) * batch_limit, ...]
                mask_sample = (mask_i & mask_random_sem_loopi).unsqueeze(1)  # [N, 1, H, W]
                loss += self.ssi_mae(
                    prediction=pred_i[:mask_sample.shape[0], ...],
                    target=target_i[:mask_sample.shape[0], ...],
                    mask_valid=mask_sample)
                valid_pix += torch.sum(mask_sample)

                # conditional ssi loss
                # rerank_mask_random_sem_loopi = random_sem_masks[conditional_rank_ids, ...][j*batch_limit:(j+1)*batch_limit, ...]
                # rerank_mask_sample = (mask_i & rerank_mask_random_sem_loopi).unsqueeze(1) # [N, 1, H, W]
                # loss_cond = self.conditional_ssi_mae(
                #     prediction=pred_i[:rerank_mask_sample.shape[0], ...],
                #     target=target_i[:rerank_mask_sample.shape[0], ...],
                #     mask_valid=rerank_mask_sample)
                # print(loss_cond / (torch.sum(rerank_mask_sample) + 1e-10), loss_cond, torch.sum(rerank_mask_sample))
                # loss += loss_cond
                # valid_pix += torch.sum(rerank_mask_sample)

        # crop_y, crop_y_end, crop_x, crop_x_end = self.get_random_masks_for_batch((H, W)) # [N,]
        # for j in range(B):
        #     for i in range(self.random_num):
        #         mask_crop = mask[j, :, crop_y[i]:crop_y_end[i], crop_x[i]:crop_x_end[i]][None, ...] # [1, 1, crop_h, crop_w]
        #         target_crop = target[j, :, crop_y[i]:crop_y_end[i], crop_x[i]:crop_x_end[i]][None, ...]
        #         pred_crop = prediction[j, :, crop_y[i]:crop_y_end[i], crop_x[i]:crop_x_end[i]][None, ...]
        #         loss += self.ssi_mae(prediction=pred_crop, target=target_crop, mask_valid=mask_crop)
        #         valid_pix += torch.sum(mask_crop)

        # the whole image
        mask = mask * self.batch_valid.bool()
        loss += self.ssi_mae(
            prediction=prediction,
            target=target,
            mask_valid=mask)
        valid_pix += torch.sum(mask)
        loss = loss / (valid_pix + self.eps)
        if torch.isnan(loss).item() or torch.isinf(loss).item():
            loss = 0 * torch.sum(prediction)
            print(f'HDSNL NAN error, {loss}, valid pix: {valid_pix}')
        return loss * self.loss_weight


if __name__ == '__main__':
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    ssil = HDSNRandomLoss()
    pred = torch.rand((8, 1, 256, 512)).cuda()
    gt = torch.rand((8, 1, 256, 512)).cuda()  # torch.zeros_like(pred).cuda()
    # gt[1:, :, 100:256, 100:350] = -1
    gt[:2, ...] = -1
    mask = gt > 0
    sem_mask = np.random.randint(-1, 200, (8, 1, 256, 512))
    sem_mask[sem_mask > 0] = -1  # suppresses all labels, exercising the no-valid-semantic-mask path
    sem_mask_torch = torch.from_numpy(sem_mask).cuda()
    # forward() reads kwargs['dataset']; pass placeholder names so the smoke
    # test runs (any names outside disable_dataset keep every sample active).
    out = ssil(pred, gt, mask, sem_mask_torch, dataset=['dummy'] * 8)
    print(out)
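    # Illustrative check (an addition, not part of the original script): the
    # SSI normalization is invariant to a positive affine transform of the
    # prediction, so scoring 3.0 * pred + 1.2 (arbitrary values) against the
    # same targets with the same random masks should reproduce the loss up to
    # floating-point error and the eps terms. numpy is reseeded so both calls
    # draw identical random crop masks.
    np.random.seed(1)
    out_a = ssil(pred, gt, mask, sem_mask_torch, dataset=['dummy'] * 8)
    np.random.seed(1)
    out_b = ssil(3.0 * pred + 1.2, gt, mask, sem_mask_torch, dataset=['dummy'] * 8)
    print(out_a.item(), out_b.item())  # expected to be nearly identical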