Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import poselib | |
class DenseMatcher(nn.Module):
    """Two-stage matcher: sparse mutual matches, then epipolar-guided dense matches.

    Sparse descriptors are matched with a dual-softmax mutual-nearest-neighbour
    test. A fundamental matrix is estimated from those matches with PoseLib and,
    when enough inliers support it, used to filter and refine a second set of
    matches computed from the dense descriptors.
    """

    def __init__(self, inv_temperature=20, thr=0.01):
        """Store the matching hyper-parameters.

        Args:
            inv_temperature: Factor applied to the descriptor similarity
                matrix before the softmaxes.
            thr: Default minimum dual-softmax score for a pair to be accepted.
        """
        super().__init__()
        self.inv_temperature = inv_temperature
        self.thr = thr

    def forward(self, info0, info1, thr=None, err_thr=4, min_num_inliers=30):
        """Match two images described by ``info0`` and ``info1``.

        Args:
            info0, info1: Dicts providing ``'descriptors'`` and ``'keypoints'``
                and, for the dense stage, ``'descriptors_dense'`` and
                ``'keypoints_dense'``. Keypoints are assumed to be (N, 2)
                coordinates and descriptors (N, D) -- TODO confirm against
                the feature extractor.
            thr: Dual-softmax score threshold; ``None`` uses ``self.thr``.
            err_thr: Maximum per-axis displacement allowed when snapping a
                dense match onto its epipolar line.
            min_num_inliers: Minimum number of fundamental-matrix inliers
                required before the dense stage runs.

        Returns:
            ``(mkpts_0, mkpts_1, mconf)``: matched keypoints in each image and
            their dual-softmax scores. If the inlier count is below
            ``min_num_inliers`` these are all sparse mutual matches, unfiltered.
            Otherwise they are the sparse inliers followed by the refined
            dense matches.
        """
        inds, P = self.dual_softmax(info0['descriptors'], info1['descriptors'], thr=thr)
        mkpts_0 = info0['keypoints'][inds[:, 0]]
        mkpts_1 = info1['keypoints'][inds[:, 1]]
        mconf = P[inds[:, 0], inds[:, 1]]

        Fm, inliers = self.get_fundamental_matrix(mkpts_0, mkpts_1)

        # Too little support for the estimated geometry: fall back to the raw
        # sparse matches rather than trusting Fm to guide the dense stage.
        if inliers.sum() < min_num_inliers:
            return mkpts_0, mkpts_1, mconf

        inds_dense, P_dense = self.dual_softmax(
            info0['descriptors_dense'], info1['descriptors_dense'], thr=thr
        )
        mkpts_0_dense = info0['keypoints_dense'][inds_dense[:, 0]]
        mkpts_1_dense = info1['keypoints_dense'][inds_dense[:, 1]]
        mconf_dense = P_dense[inds_dense[:, 0], inds_dense[:, 1]]
        mkpts_0_dense, mkpts_1_dense, mconf_dense = self.refine_matches(
            mkpts_0_dense, mkpts_1_dense, mconf_dense, Fm, err_thr=err_thr
        )

        # Keep only the geometrically verified sparse matches, then append the
        # dense ones.
        mkpts_0 = torch.cat([mkpts_0[inliers], mkpts_0_dense], dim=0)
        mkpts_1 = torch.cat([mkpts_1[inliers], mkpts_1_dense], dim=0)
        mconf = torch.cat([mconf[inliers], mconf_dense], dim=0)
        return mkpts_0, mkpts_1, mconf

    def get_fundamental_matrix(self, kpts_0, kpts_1):
        """Estimate a fundamental matrix from matched keypoints with PoseLib.

        The keypoints are moved to CPU/NumPy for
        ``poselib.estimate_fundamental``, which is called with a maximum
        epipolar error of 1 and progressive sampling enabled.

        Returns:
            ``(Fm, inliers)``: the estimated matrix as a tensor with the
            device and dtype of ``kpts_0``, and a boolean tensor built from
            PoseLib's ``info['inliers']``.
        """
        Fm, info = poselib.estimate_fundamental(
            kpts_0.cpu().numpy(),
            kpts_1.cpu().numpy(),
            {'max_epipolar_error': 1, 'progressive_sampling': True},
            {},
        )
        Fm = torch.tensor(Fm, device=kpts_0.device, dtype=kpts_0.dtype)
        inliers = torch.tensor(info['inliers'], device=kpts_0.device, dtype=torch.bool)
        return Fm, inliers

    def dual_softmax(self, desc0, desc1, thr=None):
        """Find mutual best matches under a dual-softmax score.

        Args:
            desc0, desc1: 2-D descriptor matrices (``desc1.t()`` is taken), one
                row per feature. Presumably L2-normalised -- verify at caller.
            thr: Minimum score for a pair to be kept; ``None`` uses ``self.thr``.

        Returns:
            ``(inds, P)``: ``P`` is the product of the column-wise and row-wise
            softmaxes of the scaled similarity matrix. ``inds`` is a (K, 2)
            tensor of ``(row, col)`` indices whose score is the maximum of both
            its row and its column and is at least ``thr``.
        """
        if thr is None:
            thr = self.thr
        dist_mat = (desc0 @ desc1.t()) * self.inv_temperature
        P = dist_mat.softmax(dim=-2) * dist_mat.softmax(dim=-1)
        is_row_max = P == P.max(dim=-1, keepdim=True).values
        is_col_max = P == P.max(dim=-2, keepdim=True).values
        inds = torch.nonzero(is_row_max & is_col_max & (P >= thr))
        return inds, P

    def refine_matches(self, mkpts_0, mkpts_1, mconf, Fm, err_thr=4):
        """Snap second-image points onto their epipolar lines and drop outliers.

        For each match, the epipolar line ``Fm @ [x0, y0, 1]`` is computed and
        the second-image point is orthogonally projected onto it. Matches
        whose projection moves the point by ``err_thr`` or more along either
        axis are discarded.

        Args:
            mkpts_0, mkpts_1: (N, 2) matched coordinates in image 0 / image 1.
            mconf: (N,) match scores.
            Fm: 3x3 fundamental matrix, used here so that ``Fm @ x0`` is a
                line in image 1 -- NOTE(review): confirm this matches the
                convention of the estimator that produced it.
            err_thr: Per-axis displacement bound, in the keypoints' units.

        Returns:
            ``(mkpts_0, refined_mkpts_1, mconf)`` restricted to kept matches,
            with ``refined_mkpts_1`` lying on the epipolar lines.
        """
        # new_ones keeps the homogeneous column on the keypoints' dtype/device.
        ones = mkpts_0.new_ones(mkpts_0.shape[0], 1)
        mkpts_0_h = torch.cat([mkpts_0, ones], dim=1)  # (N, 3)

        lines_1 = torch.matmul(Fm, mkpts_0_h.T).T  # (N, 3): a*x + b*y + c = 0
        a, b, c = lines_1[:, 0], lines_1[:, 1], lines_1[:, 2]
        x1, y1 = mkpts_1[:, 0], mkpts_1[:, 1]

        # Foot of the perpendicular from (x1, y1) onto the line, minus the
        # point itself; the epsilon guards against a degenerate (a = b = 0) line.
        denom = a**2 + b**2 + 1e-8
        x_offset = (b * (b * x1 - a * y1) - a * c) / denom - x1
        y_offset = (a * (a * y1 - b * x1) - b * c) / denom - y1

        # Both components must be small. The offset is perpendicular to the
        # line, so for a near-horizontal or near-vertical line one component
        # is always ~0; OR-ing the two tests would accept points arbitrarily
        # far from such lines.
        keep = (x_offset.abs() < err_thr) & (y_offset.abs() < err_thr)

        shift = torch.stack([x_offset[keep], y_offset[keep]], dim=1)
        refined_mkpts_1 = mkpts_1[keep] + shift
        return mkpts_0[keep], refined_mkpts_1, mconf[keep]