Spaces:
Running
on
Zero
Running
on
Zero
| from loguru import logger | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from kornia.geometry.subpix import dsnt | |
| from kornia.utils.grid import create_meshgrid | |
class LoFTRLoss(nn.Module):
    """Weighted sum of coarse-level, fine-level and sub-pixel matching losses.

    ``forward`` reads predictions and ground truth from a ``data`` dict and
    writes the total loss back under ``'loss'`` together with detached
    per-term scalars under ``'loss_scalars'``.
    """

    def __init__(self, config):
        """Cache the hyper-parameters read from ``config['loftr']``.

        Args:
            config (dict): config under the global namespace. This module
                reads ``config['loftr']['loss']``,
                ``config['loftr']['match_coarse']``,
                ``config['loftr']['match_fine']`` and
                ``config['loftr']['fine_window_size']``.
        """
        super().__init__()
        self.config = config  # config under the global namespace
        loftr_cfg = self.config['loftr']
        self.loss_config = loftr_cfg['loss']
        self.match_type = 'dual_softmax'
        self.sparse_spvs = loftr_cfg['match_coarse']['sparse_spvs']
        self.fine_sparse_spvs = loftr_cfg['match_fine']['sparse_spvs']

        # Threshold on the gt offset used by the sub-pixel loss, and the
        # positive / negative term weights shared by coarse and fine losses.
        self.correct_thr = self.loss_config['fine_correct_thr']
        self.c_pos_w = self.loss_config['pos_weight']
        self.c_neg_w = self.loss_config['neg_weight']

        # Flags enabling the multiplication of positive losses by an
        # externally supplied overlap weight.
        self.overlap_weightc = self.loss_config['coarse_overlap_weight']
        self.overlap_weightf = self.loss_config['fine_overlap_weight']

        # subpixel-level
        self.local_regressw = loftr_cfg['fine_window_size']
        self.local_regress_temperature = loftr_cfg['match_fine']['local_regress_temperature']

    @staticmethod
    def _focal_pos(conf, alpha, gamma):
        """Focal loss for confidences whose ground truth is 1."""
        return - alpha * torch.pow(1 - conf, gamma) * conf.log()

    @staticmethod
    def _focal_neg(conf, alpha, gamma):
        """Focal loss for confidences whose ground truth is 0."""
        return - alpha * torch.pow(conf, gamma) * (1 - conf).log()

    def compute_coarse_loss(self, conf, conf_gt, weight=None, overlap_weight=None):
        """Point-wise focal loss with 0 / 1 confidence as ground truth.

        Args:
            conf (torch.Tensor): (N, HW0, HW1) / (N, HW0+1, HW1+1)
            conf_gt (torch.Tensor): (N, HW0, HW1)
            weight (torch.Tensor | None): (N, HW0, HW1) element-wise weight.
                NOTE: entry ``[0, 0, 0]`` is zeroed in place when a dummy
                ground truth has to be assigned.
            overlap_weight (torch.Tensor | None): multiplied onto the positive
                losses when ``coarse_overlap_weight`` is enabled; expected to
                be already sliced to the positive entries by the supervision.

        Returns:
            torch.Tensor: scalar loss.

        Raises:
            ValueError: if ``loss_config['coarse_type']`` is not ``'focal'``.
        """
        pos_mask, neg_mask = conf_gt == 1, conf_gt == 0
        c_pos_w, c_neg_w = self.c_pos_w, self.c_neg_w

        # Corner case: no gt coarse-level match at all. Assign a wrong gt so
        # the indexing below is non-empty, and cancel it with zero weights.
        if not pos_mask.any():
            pos_mask[0, 0, 0] = True
            if weight is not None:
                weight[0, 0, 0] = 0.
            c_pos_w = 0.
        if not neg_mask.any():
            neg_mask[0, 0, 0] = True
            if weight is not None:
                weight[0, 0, 0] = 0.
            c_neg_w = 0.

        if self.loss_config['coarse_type'] != 'focal':
            raise ValueError('Unknown coarse loss: {type}'.format(type=self.loss_config['coarse_type']))

        # Clamp so that neither log(conf) nor log(1 - conf) can hit log(0).
        conf = torch.clamp(conf, 1e-6, 1-1e-6)
        alpha = self.loss_config['focal_alpha']
        gamma = self.loss_config['focal_gamma']

        pos_conf = conf[pos_mask]
        loss_pos = self._focal_pos(pos_conf, alpha, gamma)

        if self.sparse_spvs:
            if weight is not None:
                # Different from dense-spvs, the loss w.r.t. padded regions
                # isn't directly zeroed out, but only through manually setting
                # corresponding regions in sim_matrix to '-inf'.
                loss_pos = loss_pos * weight[pos_mask]
            if self.overlap_weightc:
                loss_pos = loss_pos * overlap_weight
            return c_pos_w * loss_pos.mean()

        # Dense supervision: negatives contribute as well. Each negative
        # element occupies a smaller proportion than a positive one, so a
        # higher negative loss weight may be needed.
        neg_conf = conf[neg_mask]
        loss_neg = self._focal_neg(neg_conf, alpha, gamma)
        # Logged once per call (the original emitted this same line twice).
        logger.info("conf_pos_c: {loss_pos}, conf_neg_c: {loss_neg}".format(loss_pos=pos_conf.mean(), loss_neg=neg_conf.mean()))
        if weight is not None:
            loss_pos = loss_pos * weight[pos_mask]
            loss_neg = loss_neg * weight[neg_mask]
        if self.overlap_weightc:
            loss_pos = loss_pos * overlap_weight
        return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean()

    def compute_fine_loss(self, conf_matrix_f, conf_matrix_f_gt, overlap_weight=None):
        """Focal loss on the fine-level confidence matrices.

        Args:
            conf_matrix_f (torch.Tensor): [m, WW, WW] <x, y>
            conf_matrix_f_gt (torch.Tensor): [m, WW, WW] <x, y>
            overlap_weight (torch.Tensor | None): multiplied onto the positive
                losses when ``fine_overlap_weight`` is enabled; expected to be
                already sliced to the positive entries by the supervision.

        Returns:
            torch.Tensor | None: scalar loss, or ``None`` when there is no
            fine-level ground truth and the module is in eval mode.
        """
        if conf_matrix_f_gt.shape[0] == 0:
            if not self.training:
                return None
            # This seldom happens during training, since predictions are
            # padded with gt; sometimes there is no coarse-level gt at all.
            # Indexing [0, 0, 0] of an empty mask would raise IndexError, so
            # return a zero loss that stays attached to the prediction graph
            # and every DDP rank still runs a backward pass.
            logger.warning("assign a false supervision to avoid ddp deadlock")
            return conf_matrix_f.sum() * 0.

        pos_mask, neg_mask = conf_matrix_f_gt == 1, conf_matrix_f_gt == 0
        c_pos_w, c_neg_w = self.c_pos_w, self.c_neg_w

        # Assign a wrong gt (cancelled by a zero weight) when a class is empty.
        if not pos_mask.any():
            pos_mask[0, 0, 0] = True
            c_pos_w = 0.
        if not neg_mask.any():
            neg_mask[0, 0, 0] = True
            c_neg_w = 0.

        conf = torch.clamp(conf_matrix_f, 1e-6, 1-1e-6)
        alpha = self.loss_config['focal_alpha']
        gamma = self.loss_config['focal_gamma']

        pos_conf = conf[pos_mask]
        loss_pos = self._focal_pos(pos_conf, alpha, gamma)

        if self.fine_sparse_spvs:
            if self.overlap_weightf:
                loss_pos = loss_pos * overlap_weight
            return c_pos_w * loss_pos.mean()

        neg_conf = conf[neg_mask]
        loss_neg = self._focal_neg(neg_conf, alpha, gamma)
        logger.info("conf_pos_f: {loss_pos}, conf_neg_f: {loss_neg}".format(loss_pos=pos_conf.mean(), loss_neg=neg_conf.mean()))
        if self.overlap_weightf:
            loss_pos = loss_pos * overlap_weight
        return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean()

    def _compute_local_loss_l2(self, expec_f, expec_f_gt):
        """Mean squared L2 distance between predicted and gt offsets.

        Only rows whose gt offset has an infinity norm below
        ``self.correct_thr`` are supervised.

        Args:
            expec_f (torch.Tensor): [M, 2] <x, y>
            expec_f_gt (torch.Tensor): [M, 2] <x, y>

        Returns:
            torch.Tensor | None: scalar loss, or ``None`` when no row
            qualifies and the module is in eval mode.
        """
        correct_mask = torch.linalg.norm(expec_f_gt, ord=float('inf'), dim=1) < self.correct_thr
        if not correct_mask.any():
            if not self.training:
                return None
            # This seldom happens when training, since predictions are padded
            # with gt.
            logger.warning("assign a false supervision to avoid ddp deadlock")
            if correct_mask.numel() == 0:
                # Nothing to index at all: graph-attached zero loss instead of
                # an IndexError on correct_mask[0].
                return expec_f.sum() * 0.
            correct_mask[0] = True
        offset_l2 = ((expec_f_gt[correct_mask] - expec_f[correct_mask]) ** 2).sum(-1)
        return offset_l2.mean()

    def compute_c_weight(self, data):
        """Compute element-wise weights for the coarse-level loss.

        Returns the outer product of the flattened ``data['mask0']`` and
        ``data['mask1']`` when ``'mask0'`` is present, otherwise ``None``.
        """
        if 'mask0' not in data:
            return None
        mask0 = data['mask0'].flatten(-2)
        mask1 = data['mask1'].flatten(-2)
        return mask0[..., None] * mask1[:, None]

    def _update_expec_f(self, data):
        """Compute ``data['expec_f']`` from the local 3x3 similarity windows.

        Consumes (and removes from ``data``) the keys ``'sim_matrix_ff'``,
        ``'m_ids_f'``, ``'i_ids_f'``, ``'j_ids_f_di'`` and ``'j_ids_f_dj'``.
        """
        keys = ('sim_matrix_ff', 'm_ids_f', 'i_ids_f', 'j_ids_f_di', 'j_ids_f_dj')
        # Read everything first so a missing key fails before anything is deleted.
        sim_matrix_f, m_ids, i_ids, j_ids_di, j_ids_dj = [data[key] for key in keys]
        for key in keys:
            del data[key]

        w = self.local_regressw
        delta = create_meshgrid(3, 3, True, sim_matrix_f.device).to(torch.long)  # [1, 3, 3, 2]
        m_ids = m_ids[..., None, None].expand(-1, 3, 3)
        i_ids = i_ids[..., None, None].expand(-1, 3, 3)
        # Note that j_ids_di & j_ids_dj are in (i, j) format while delta is in (x, y) format
        j_ids_di = j_ids_di[..., None, None].expand(-1, 3, 3) + delta[None, ..., 1]
        j_ids_dj = j_ids_dj[..., None, None].expand(-1, 3, 3) + delta[None, ..., 0]

        sim_matrix_f = sim_matrix_f.reshape(-1, w * w, w + 2, w + 2)  # [M, WW, W+2, W+2]
        sim_matrix_f = sim_matrix_f[m_ids, i_ids, j_ids_di, j_ids_dj]
        sim_matrix_f = sim_matrix_f.reshape(-1, 9)
        sim_matrix_f = F.softmax(sim_matrix_f / self.local_regress_temperature, dim=-1)
        heatmap = sim_matrix_f.reshape(-1, 3, 3)

        # compute coordinates from heatmap
        coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0]
        data.update({'expec_f': coords_normalized})

    def forward(self, data):
        """Compute all loss terms and store them in ``data``.

        Update:
            data (dict): update{
                'loss': [1] the reduced loss across a batch,
                'loss_scalars' (dict): loss scalars for tensorboard_record
            }
            ``'expec_f'`` is also added when it was not already present.
        """
        loss_scalars = {}

        # 0. compute element-wise loss weight
        c_weight = self.compute_c_weight(data)

        # 1. coarse-level loss
        use_bin = self.sparse_spvs and self.match_type == 'sinkhorn'
        conf_c = data['conf_matrix_with_bin'] if use_bin else data['conf_matrix']
        overlap_c = data['conf_matrix_error_gt'] if self.overlap_weightc else None
        loss_c = self.compute_coarse_loss(
            conf_c, data['conf_matrix_gt'],
            weight=c_weight, overlap_weight=overlap_c)
        loss = loss_c * self.loss_config['coarse_weight']
        loss_scalars.update({"loss_c": loss_c.clone().detach().cpu()})

        # 2. pixel-level loss (first-stage refinement)
        overlap_f = data['conf_matrix_f_error_gt'] if self.overlap_weightf else None
        loss_f = self.compute_fine_loss(data['conf_matrix_f'], data['conf_matrix_f_gt'], overlap_f)
        if loss_f is not None:
            loss = loss + loss_f * self.loss_config['fine_weight']
            loss_scalars.update({"loss_f": loss_f.clone().detach().cpu()})
        else:
            assert self.training is False
            loss_scalars.update({'loss_f': torch.tensor(1.)})  # 1 is the upper bound

        # 3. subpixel-level loss (second-stage refinement)
        # we calculate subpixel-level loss for all pixel-level gt
        if 'expec_f' not in data:
            self._update_expec_f(data)
        loss_l = self._compute_local_loss_l2(data['expec_f'], data['expec_f_gt'])
        if loss_l is not None:
            loss = loss + loss_l * self.loss_config['local_weight']
            loss_scalars.update({"loss_l": loss_l.clone().detach().cpu()})
        else:
            # Eval mode with no gt offset inside the threshold: skip the term
            # instead of multiplying None, mirroring the loss_f handling.
            assert self.training is False
            loss_scalars.update({'loss_l': torch.tensor(1.)})  # placeholder, same as loss_f

        loss_scalars.update({'loss': loss.clone().detach().cpu()})
        data.update({"loss": loss, "loss_scalars": loss_scalars})