Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from einops.einops import rearrange | |
| INF = 1e9 | |
def mask_border(m, b: int, v):
    """Fill a border of width ``b`` with ``v`` on axes 1-4 of ``m``, in place.

    Args:
        m (torch.Tensor): [N, H0, W0, H1, W1]
        b (int): border width; the tensor is left untouched when ``b <= 0``
        v (m.dtype): value written into the border cells
    """
    if b > 0:
        # Axis 0 (batch) is never masked; every other axis loses its first
        # and last ``b`` entries.
        for axis in range(1, 5):
            lead = (slice(None),) * axis
            m[lead + (slice(None, b),)] = v
            m[lead + (slice(-b, None),)] = v
def mask_border_with_padding(m, bd, v, p_m0, p_m1):
    """Fill borders of ``m`` with ``v`` in place, honouring per-sample padding.

    The leading ``bd`` entries of axes 1-4 are overwritten for the whole
    batch. The trailing side is handled per sample: everything from
    ``extent - bd`` to the end of each axis is overwritten, where the extents
    are derived from the padding masks.

    Args:
        m (torch.Tensor): 5-D tensor indexed as [N, H0, W0, H1, W1].
        bd (int): border width; nothing is done when ``bd <= 0``.
        v: value written into the masked cells.
        p_m0, p_m1 (torch.Tensor): padding masks for the two images
            (assumed [N, H, W] with non-zero entries marking valid cells --
            TODO confirm against the caller).
    """
    if bd <= 0:
        return

    m[:, :bd] = v
    m[:, :, :bd] = v
    m[:, :, :, :bd] = v
    m[:, :, :, :, :bd] = v

    # Largest column sum / row sum of each mask, used as the per-sample valid
    # height / width. Converted to Python ints once so the loop below works
    # on plain integers.
    h0s = p_m0.sum(1).max(-1)[0].int().tolist()
    w0s = p_m0.sum(-1).max(-1)[0].int().tolist()
    h1s = p_m1.sum(1).max(-1)[0].int().tolist()
    w1s = p_m1.sum(-1).max(-1)[0].int().tolist()

    for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)):
        # Clamp at 0: a negative start would be interpreted relative to the
        # end of the axis and leave most of the padded region unmasked when
        # the valid extent is smaller than ``bd``.
        m[b_idx, max(h0 - bd, 0):] = v
        m[b_idx, :, max(w0 - bd, 0):] = v
        m[b_idx, :, :, max(h1 - bd, 0):] = v
        m[b_idx, :, :, :, max(w1 - bd, 0):] = v
def compute_max_candidates(p_m0, p_m1):
    """Compute the max candidates of all pairs within a batch.

    For every pair, the smaller of the two height * width products is taken,
    and these minima are summed over the batch.

    Args:
        p_m0, p_m1 (torch.Tensor): padded masks

    Returns:
        torch.Tensor: 0-dim tensor holding the summed candidate count.
    """

    def _area(p_m):
        # height: largest sum along dim 1; width: largest sum along the last dim
        height = p_m.sum(1).max(-1)[0]
        width = p_m.sum(-1).max(-1)[0]
        return height * width

    areas = torch.stack([_area(p_m0), _area(p_m1)], -1)
    smaller_area = torch.min(areas, -1)[0]
    return torch.sum(smaller_area)
class CoarseMatching(nn.Module):
    """Select coarse-level matches from a precomputed confidence matrix.

    Matches are the entries of the confidence matrix that exceed ``thr``, lie
    outside the ``border_rm`` border and are the maximum of both their row and
    their column. In training mode the selected matches are sub-sampled and
    padded with supervision indices taken from ``data``.
    """

    def __init__(self, config):
        """Read matching options from ``config``.

        Args:
            config (dict-like): must provide ``thr``, ``border_rm``,
                ``train_coarse_percent``, ``train_pad_num_gt_min`` and
                ``match_type``; plus ``dsmax_temperature`` for
                ``"dual_softmax"``, or ``skh_init_bin_score``, ``skh_iters``
                and ``skh_prefilter`` for ``"sinkhorn"``.

        Raises:
            ImportError: ``match_type`` is ``"sinkhorn"`` but the local
                ``superglue`` module cannot be imported.
            NotImplementedError: ``match_type`` is neither of the two
                supported values.
        """
        super().__init__()
        self.config = config
        # general config
        self.thr = config["thr"]
        self.border_rm = config["border_rm"]
        # -- # for training fine-level LoFTR
        self.train_coarse_percent = config["train_coarse_percent"]
        self.train_pad_num_gt_min = config["train_pad_num_gt_min"]

        # we provide 2 options for differentiable matching
        self.match_type = config["match_type"]
        if self.match_type == "dual_softmax":
            self.temperature = config["dsmax_temperature"]
        elif self.match_type == "sinkhorn":
            try:
                from .superglue import log_optimal_transport
            except ImportError as err:
                # Chain the original error so the real import failure stays
                # visible in the traceback.
                raise ImportError("download superglue.py first!") from err
            self.log_optimal_transport = log_optimal_transport
            self.bin_score = nn.Parameter(
                torch.tensor(config["skh_init_bin_score"], requires_grad=True)
            )
            self.skh_iters = config["skh_iters"]
            self.skh_prefilter = config["skh_prefilter"]
        else:
            raise NotImplementedError(
                "unsupported match_type: {!r}".format(self.match_type)
            )

    def forward(self, data):
        """Predict coarse matches and store them in ``data`` (in place).

        Args:
            data (dict): must contain ``'conf_matrix'`` plus the keys read by
                :meth:`get_coarse_match`.
        Update:
            data (dict): {
                'b_ids' (torch.Tensor): [M'],
                'i_ids' (torch.Tensor): [M'],
                'j_ids' (torch.Tensor): [M'],
                'gt_mask' (torch.Tensor): [M'],
                'm_bids' (torch.Tensor): [M],
                'mkpts0_c' (torch.Tensor): [M, 2],
                'mkpts1_c' (torch.Tensor): [M, 2],
                'mconf' (torch.Tensor): [M]}
            NOTE: M' != M during training.
        """
        conf_matrix = data["conf_matrix"]
        # predict coarse matches from conf_matrix
        data.update(**self.get_coarse_match(conf_matrix, data))

    def get_coarse_match(self, conf_matrix, data):
        """Extract mutual-nearest coarse matches from ``conf_matrix``.

        Args:
            conf_matrix (torch.Tensor): [N, L, S]
            data (dict): read keys are 'hw0_i', 'hw0_c', 'hw1_c'; optionally
                'mask0'/'mask1' (padding masks) and 'scale0'/'scale1'; in
                training mode also 'spv_b_ids', 'spv_i_ids', 'spv_j_ids'.
        Returns:
            coarse_matches (dict): {
                'b_ids' (torch.Tensor): [M'],
                'i_ids' (torch.Tensor): [M'],
                'j_ids' (torch.Tensor): [M'],
                'gt_mask' (torch.Tensor): [M'],
                'm_bids' (torch.Tensor): [M],
                'mkpts0_c' (torch.Tensor): [M, 2],
                'mkpts1_c' (torch.Tensor): [M, 2],
                'mconf' (torch.Tensor): [M]}
        """
        axes_lengths = {
            "h0c": data["hw0_c"][0],
            "w0c": data["hw0_c"][1],
            "h1c": data["hw1_c"][0],
            "w1c": data["hw1_c"][1],
        }
        _device = conf_matrix.device

        # 1. confidence thresholding
        mask = conf_matrix > self.thr
        mask = rearrange(
            mask, "b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c", **axes_lengths
        )
        if "mask0" not in data:
            mask_border(mask, self.border_rm, False)
        else:
            mask_border_with_padding(
                mask, self.border_rm, False, data["mask0"], data["mask1"]
            )
        mask = rearrange(
            mask, "b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)", **axes_lengths
        )

        # 2. mutual nearest: keep entries that are both the row maximum and
        # the column maximum of the confidence matrix
        mask = (
            mask
            * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0])
            * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0])
        )

        # 3. find all valid coarse matches
        # this only works when at most one `True` in each row
        mask_v, all_j_ids = mask.max(dim=2)
        b_ids, i_ids = torch.where(mask_v)
        j_ids = all_j_ids[b_ids, i_ids]
        mconf = conf_matrix[b_ids, i_ids, j_ids]

        # 4. Random sampling of training samples for fine-level LoFTR
        # (optional) pad samples with gt coarse-level matches
        if self.training:
            # NOTE:
            # The sampling is performed across all pairs in a batch without
            # manually balancing; #samples for fine-level increases w.r.t.
            # batch_size
            if "mask0" not in data:
                num_candidates_max = mask.size(0) * max(mask.size(1), mask.size(2))
            else:
                num_candidates_max = compute_max_candidates(
                    data["mask0"], data["mask1"]
                )
            num_matches_train = int(num_candidates_max * self.train_coarse_percent)
            num_matches_pred = len(b_ids)
            assert (
                self.train_pad_num_gt_min < num_matches_train
            ), "min-num-gt-pad should be less than num-train-matches"

            # pred_indices is to select from prediction
            if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:
                pred_indices = torch.arange(num_matches_pred, device=_device)
            else:
                pred_indices = torch.randint(
                    num_matches_pred,
                    (num_matches_train - self.train_pad_num_gt_min,),
                    device=_device,
                )

            # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200)
            gt_pad_indices = torch.randint(
                len(data["spv_b_ids"]),
                (max(num_matches_train - num_matches_pred, self.train_pad_num_gt_min),),
                device=_device,
            )
            mconf_gt = torch.zeros(
                len(data["spv_b_ids"]), device=_device
            )  # set conf of gt paddings to all zero

            # Concatenate the sampled predictions with the sampled gt padding.
            # The tuple of pairs is built before any name is rebound.
            b_ids, i_ids, j_ids, mconf = (
                torch.cat([pred[pred_indices], gt[gt_pad_indices]], dim=0)
                for pred, gt in (
                    (b_ids, data["spv_b_ids"]),
                    (i_ids, data["spv_i_ids"]),
                    (j_ids, data["spv_j_ids"]),
                    (mconf, mconf_gt),
                )
            )

        # These matches select patches that feed into fine-level network
        coarse_matches = {"b_ids": b_ids, "i_ids": i_ids, "j_ids": j_ids}

        # 5. Update with matches in original image resolution
        scale = data["hw0_i"][0] / data["hw0_c"][0]
        scale0 = scale * data["scale0"][b_ids] if "scale0" in data else scale
        scale1 = scale * data["scale1"][b_ids] if "scale1" in data else scale
        # Flattened index -> (x, y) = (index % width, index // width)
        mkpts0_c = (
            torch.stack([i_ids % data["hw0_c"][1], i_ids // data["hw0_c"][1]], dim=1)
            * scale0
        )
        mkpts1_c = (
            torch.stack([j_ids % data["hw1_c"][1], j_ids // data["hw1_c"][1]], dim=1)
            * scale1
        )

        # These matches are the current prediction (for visualization)
        pred_mask = mconf != 0  # mconf == 0 => gt matches
        coarse_matches.update(
            {
                "gt_mask": mconf == 0,
                "m_bids": b_ids[pred_mask],
                "mkpts0_c": mkpts0_c[pred_mask],
                "mkpts1_c": mkpts1_c[pred_mask],
                "mconf": mconf[pred_mask],
            }
        )
        return coarse_matches