Spaces:
Sleeping
Sleeping
| import sys | |
| from pathlib import Path | |
| import torchvision.transforms as tvf | |
| from ..utils.base_model import BaseModel | |
# Directory holding vendored third-party code; it is appended to sys.path so
# that the ``r2d2`` package imported from it can be resolved.
base_path = Path(__file__).parent / "../../third_party"
sys.path.append(str(base_path))

# Root of the vendored R2D2 checkout, i.e. the "r2d2" entry under base_path.
r2d2_path = base_path / "r2d2"
| from r2d2.extract import load_network, NonMaxSuppression, extract_multiscale | |
class R2D2(BaseModel):
    """Keypoint detector/descriptor wrapping the vendored R2D2 network.

    ``_init`` loads the checkpoint ``conf["model_name"]`` from
    ``r2d2_path / "models"`` and builds a ``NonMaxSuppression`` detector.
    ``_forward`` runs ``extract_multiscale`` on the normalized image and keeps
    the highest-scoring keypoints.

    Config keys:
        model_name: checkpoint file name inside ``r2d2_path / "models"``.
        max_keypoints: number of top-scoring keypoints to keep. ``None`` or a
            value <= 0 keeps every detection.
        scale_factor, min_size, max_size, min_scale, max_scale: forwarded to
            ``extract_multiscale`` (``scale_factor`` is passed as ``scale_f``).
        reliability_threshold, repetability_threshold: forwarded to
            ``NonMaxSuppression`` as ``rel_thr`` and ``rep_thr``.
    """

    default_conf = {
        "model_name": "r2d2_WASF_N16.pt",
        "max_keypoints": 5000,
        "scale_factor": 2**0.25,
        "min_size": 256,
        "max_size": 1024,
        "min_scale": 0,
        "max_scale": 1,
        "reliability_threshold": 0.7,
        # Key spelling ("repetability") is kept as-is so existing configs work.
        "repetability_threshold": 0.7,
    }
    required_inputs = ["image"]

    def _init(self, conf):
        """Load the R2D2 weights and set up normalization and NMS from *conf*."""
        model_fn = r2d2_path / "models" / conf["model_name"]
        # Per-channel mean/std normalization (the commonly used ImageNet values).
        self.norm_rgb = tvf.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        self.net = load_network(model_fn)
        self.detector = NonMaxSuppression(
            rel_thr=conf["reliability_threshold"],
            rep_thr=conf["repetability_threshold"],
        )

    @staticmethod
    def _top_k_indices(scores, max_keypoints):
        """Return indices of the ``max_keypoints`` highest entries of *scores*.

        Indices come back in ``argsort`` order (ascending score). ``None`` or a
        non-positive ``max_keypoints`` keeps every index.
        """
        order = scores.argsort()
        # Guard before negating: ``-None`` raises, and ``-(-k)`` would turn into
        # a positive start index that drops the k lowest-scoring points.
        if max_keypoints is None or max_keypoints <= 0:
            return order
        return order[-max_keypoints:]

    def _forward(self, data):
        """Detect and describe keypoints in ``data["image"]``.

        Returns a dict with ``"keypoints"``, ``"descriptors"`` and ``"scores"``,
        each given a leading axis of size 1 via ``[None]``.
        """
        # NOTE(review): extract_multiscale presumably expects a batch holding a
        # single image — confirm against the third-party r2d2 code.
        img = self.norm_rgb(data["image"])
        xys, desc, scores = extract_multiscale(
            self.net,
            img,
            self.detector,
            scale_f=self.conf["scale_factor"],
            min_size=self.conf["min_size"],
            max_size=self.conf["max_size"],
            min_scale=self.conf["min_scale"],
            max_scale=self.conf["max_scale"],
        )
        idxs = self._top_k_indices(scores, self.conf["max_keypoints"])
        # Keep only the first two columns of xys (presumably x, y).
        xy = xys[idxs, :2]
        # Transposed so descriptors become (dim, num_keypoints) — assumes desc
        # arrives as (num_keypoints, dim); TODO confirm.
        desc = desc[idxs].t()
        scores = scores[idxs]
        return {
            "keypoints": xy[None],
            "descriptors": desc[None],
            "scores": scores[None],
        }