Spaces:
Sleeping
Sleeping
| import sys | |
| from pathlib import Path | |
| import torch | |
| from ..utils.base_model import BaseModel | |
| sys.path.append(str(Path(__file__).parent / "../../third_party")) | |
| from SuperGluePretrainedNetwork.models import superpoint # noqa E402 | |
| # The original keypoint sampling is incorrect. We patch it here but | |
| # we don't fix it upstream to not impact exisiting evaluations. | |
| def sample_descriptors_fix_sampling(keypoints, descriptors, s: int = 8): | |
| """Interpolate descriptors at keypoint locations""" | |
| b, c, h, w = descriptors.shape | |
| keypoints = (keypoints + 0.5) / (keypoints.new_tensor([w, h]) * s) | |
| keypoints = keypoints * 2 - 1 # normalize to (-1, 1) | |
| descriptors = torch.nn.functional.grid_sample( | |
| descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", align_corners=False | |
| ) | |
| descriptors = torch.nn.functional.normalize( | |
| descriptors.reshape(b, c, -1), p=2, dim=1 | |
| ) | |
| return descriptors | |
| class SuperPoint(BaseModel): | |
| default_conf = { | |
| "nms_radius": 4, | |
| "keypoint_threshold": 0.005, | |
| "max_keypoints": -1, | |
| "remove_borders": 4, | |
| "fix_sampling": False, | |
| } | |
| required_inputs = ["image"] | |
| detection_noise = 2.0 | |
| def _init(self, conf): | |
| if conf["fix_sampling"]: | |
| superpoint.sample_descriptors = sample_descriptors_fix_sampling | |
| self.net = superpoint.SuperPoint(conf) | |
| def _forward(self, data): | |
| return self.net(data) | |