Spaces:
Running
Running
| import kornia | |
| from kornia.feature.laf import laf_from_center_scale_ori, extract_patches_from_pyramid | |
| import numpy as np | |
| import torch | |
| import pycolmap | |
| from ..utils.base_model import BaseModel | |
# Small constant guarding divisions and the square root against zeros.
EPS = 1e-6


def sift_to_rootsift(x):
    """Map SIFT descriptors to RootSIFT descriptors.

    Every descriptor (the last axis of ``x``) is divided by its L1 norm,
    clipped from below at ``EPS`` so the element-wise square root is well
    defined, square-rooted, and finally divided by its L2 norm. ``EPS`` is
    added to both norms to avoid dividing by zero.

    Args:
        x: NumPy array of descriptors; normalization runs along the last axis.

    Returns:
        Array with the same shape as ``x`` holding the RootSIFT descriptors.
    """
    l1_norm = np.linalg.norm(x, ord=1, axis=-1, keepdims=True)
    rooted = np.sqrt((x / (l1_norm + EPS)).clip(min=EPS))
    l2_norm = np.linalg.norm(rooted, axis=-1, keepdims=True)
    return rooted / (l2_norm + EPS)
class DoG(BaseModel):
    """DoG keypoints from pycolmap's SIFT extractor with a choice of descriptor.

    Keypoints, scores, scales and orientations always come from
    ``pycolmap.Sift``. Descriptors are either COLMAP's own SIFT/RootSIFT
    descriptors, or are computed by a pretrained kornia network (SOSNet or
    HardNet) on patches extracted around each keypoint.
    """

    default_conf = {
        # Forwarded to pycolmap.SiftExtractionOptions; "normalization" is
        # filled in automatically from "descriptor".
        "options": {
            "first_octave": 0,
            "peak_threshold": 0.01,
        },
        # One of "sift", "rootsift", "sosnet", "hardnet".
        "descriptor": "rootsift",
        # Keep only the highest-scoring keypoints; -1 (or any negative value,
        # or None) keeps all of them.
        "max_keypoints": -1,
        # Patch size (PS) passed to kornia's extract_patches_from_pyramid.
        "patch_size": 32,
        # The LAF scale of each patch is SIFT scale * mr_size / 2.
        "mr_size": 12,
    }
    required_inputs = ["image"]
    # Not read inside this class; presumably consumed by callers — TODO confirm.
    detection_noise = 1.0
    # Maximum number of patches fed to the descriptor network per call.
    max_batch_size = 1024

    def _init(self, conf):
        """Build the learned descriptor network (if any) and validate ``conf``.

        Raises:
            ValueError: if ``conf["descriptor"]`` is not a supported name.
        """
        descriptor = conf["descriptor"]
        if descriptor == "sosnet":
            self.describe = kornia.feature.SOSNet(pretrained=True)
        elif descriptor == "hardnet":
            self.describe = kornia.feature.HardNet(pretrained=True)
        elif descriptor not in ("sift", "rootsift"):
            raise ValueError(f'Unknown descriptor: {conf["descriptor"]}')
        # The pycolmap extractor depends on the target device, so it is built
        # lazily on the first image (see _get_sift).
        self.sift = None
        self.device = torch.device("cpu")

    def to(self, *args, **kwargs):
        """Move the module like ``torch.nn.Module.to`` and track the device.

        The device is read from the ``device`` keyword, or else from the first
        positional argument that is a ``torch.device``, a device string, or a
        tensor (whose ``.device`` is used). When the device changes, the cached
        pycolmap extractor is discarded so the next forward pass rebuilds it
        for the new device instead of silently reusing the old one.
        """
        device = kwargs.get("device")
        if device is None:
            for arg in args:
                if isinstance(arg, torch.Tensor):
                    device = arg.device
                    break
                if isinstance(arg, (torch.device, str)):
                    device = arg
                    break
        if device is not None:
            device = torch.device(device)
            if device != getattr(self, "device", None):
                self.sift = None  # rebuilt lazily for the new device
            self.device = device
        return super().to(*args, **kwargs)

    def _get_sift(self):
        """Return the pycolmap SIFT extractor, building it on first use."""
        if self.sift is None:
            use_gpu = pycolmap.has_cuda and self.device.type == "cuda"
            options = {**self.conf["options"]}
            if self.conf["descriptor"] == "rootsift":
                options["normalization"] = pycolmap.Normalization.L1_ROOT
            else:
                options["normalization"] = pycolmap.Normalization.L2
            self.sift = pycolmap.Sift(
                options=pycolmap.SiftExtractionOptions(options),
                device=getattr(pycolmap.Device, "cuda" if use_gpu else "cpu"),
            )
        return self.sift

    def _describe_patches(self, image, keypoints, scales, oris):
        """Describe keypoints with the kornia network on extracted patches.

        Args:
            image: the input image tensor, used to extract the patches.
            keypoints: NumPy array whose first two columns are x, y.
            scales: NumPy array of per-keypoint SIFT scales.
            oris: NumPy array of per-keypoint orientations in degrees.

        Returns:
            Tensor of shape (N, 128) on the same device as the patches.
        """
        # NOTE(review): the +0.5 shift and the orientation sign flip convert
        # between pycolmap's and kornia's conventions — confirm both.
        center = keypoints[:, :2] + 0.5
        laf_scale = scales * self.conf["mr_size"] / 2
        laf_ori = -oris
        lafs = laf_from_center_scale_ori(
            torch.from_numpy(center)[None],
            torch.from_numpy(laf_scale)[None, :, None, None],
            torch.from_numpy(laf_ori)[None, :, None],
        ).to(image.device)
        patches = extract_patches_from_pyramid(
            image, lafs, PS=self.conf["patch_size"]
        )[0]
        descriptors = patches.new_zeros((len(patches), 128))
        # Process in chunks to bound memory; the loop is a no-op when N == 0.
        for start in range(0, len(patches), self.max_batch_size):
            end = min(len(patches), start + self.max_batch_size)
            descriptors[start:end] = self.describe(patches[start:end])
        return descriptors

    def _forward(self, data):
        """Detect DoG keypoints on one grayscale image and describe them.

        Args:
            data: dict whose ``"image"`` tensor is indexed as
                (batch, channel, H, W), has exactly one channel, and holds
                values in [0, 1] (up to ``EPS``). Only the first batch element
                is processed.

        Returns:
            dict with a leading batch dimension of 1 on every entry:
            ``keypoints`` (1, N, 2) as x, y; ``scales`` (1, N); ``oris`` (1, N)
            in degrees; ``scores`` (1, N); ``descriptors`` (1, D, N).
            Network descriptors stay on the image's device; all other outputs
            are CPU tensors.

        Raises:
            AssertionError: if the image is not single-channel or out of range.
            ValueError: if the configured descriptor is unknown.
        """
        image = data["image"]
        assert image.shape[1] == 1
        image_np = image.cpu().numpy()[0, 0]
        assert image_np.min() >= -EPS and image_np.max() <= 1 + EPS

        keypoints, scores, descriptors = self._get_sift().extract(image_np)
        scales = keypoints[:, 2]
        oris = np.rad2deg(keypoints[:, 3])  # column 3 is treated as radians

        descriptor = self.conf["descriptor"]
        if descriptor in ("sift", "rootsift"):
            # We still renormalize because COLMAP does not normalize well,
            # maybe due to numerical errors.
            # NOTE(review): for "rootsift" the extractor is already configured
            # with Normalization.L1_ROOT, while sift_to_rootsift applies another
            # L1-normalize + sqrt, not only a renormalization. Confirm what
            # pycolmap returns here; a plain L2 renormalization may be intended.
            if descriptor == "rootsift":
                descriptors = sift_to_rootsift(descriptors)
            descriptors = torch.from_numpy(descriptors)
        elif descriptor in ("sosnet", "hardnet"):
            descriptors = self._describe_patches(image, keypoints, scales, oris)
        else:
            raise ValueError(f'Unknown descriptor: {self.conf["descriptor"]}')

        keypoints = torch.from_numpy(keypoints[:, :2])  # keep only x, y
        scales = torch.from_numpy(scales)
        oris = torch.from_numpy(oris)
        scores = torch.from_numpy(scores)

        max_keypoints = self.conf["max_keypoints"]
        if max_keypoints is not None and max_keypoints >= 0:
            # TODO: check that the scores from PyCOLMAP are 100% correct,
            # follow https://github.com/mihaidusmanu/pycolmap/issues/8
            _, indices = torch.topk(scores, min(scores.shape[0], max_keypoints))
            keypoints = keypoints[indices]
            scales = scales[indices]
            oris = oris[indices]
            scores = scores[indices]
            descriptors = descriptors[indices]

        return {
            "keypoints": keypoints[None],
            "scales": scales[None],
            "oris": oris[None],
            "scores": scores[None],
            "descriptors": descriptors.T[None],
        }