Spaces:
Running
Running
| """ | |
| Code for loading models trained with CosPlace as a global features extractor | |
| for geolocalization through image retrieval. | |
| Multiple models are available with different backbones. Below is a summary of | |
| the available models, listed as "backbone: supported output descriptor | |
| dimensionalities". For example, you can use a model based on a ResNet50 with | |
| a descriptor dimensionality of 1024. | |
| ResNet18: [32, 64, 128, 256, 512] | |
| ResNet50: [32, 64, 128, 256, 512, 1024, 2048] | |
| ResNet101: [32, 64, 128, 256, 512, 1024, 2048] | |
| ResNet152: [32, 64, 128, 256, 512, 1024, 2048] | |
| VGG16: [ 64, 128, 256, 512] | |
| CosPlace paper: https://arxiv.org/abs/2204.02287 | |
| """ | |
| import torch | |
| import torchvision.transforms as tvf | |
| from ..utils.base_model import BaseModel | |
class CosPlace(BaseModel):
    """Global-descriptor extractor wrapping a pretrained CosPlace network.

    Configuration keys (see ``default_conf``):
        backbone: name of the backbone passed to the hub entry point.
        fc_output_dim: descriptor dimensionality passed to the hub entry point.

    The input dictionary must provide an ``"image"`` entry.
    """

    default_conf = {"backbone": "ResNet50", "fc_output_dim": 2048}
    required_inputs = ["image"]

    def _init(self, conf):
        """Load the pretrained network through ``torch.hub`` and build the
        per-channel normalization applied before inference.

        Args:
            conf: mapping providing ``"backbone"`` and ``"fc_output_dim"``.
        """
        pretrained = torch.hub.load(
            "gmberton/CosPlace",
            "get_trained_model",
            backbone=conf["backbone"],
            fc_output_dim=conf["fc_output_dim"],
        )
        # Switch to evaluation mode before storing the network.
        self.net = pretrained.eval()

        # Per-channel statistics for a 3-channel input
        # (presumably the usual ImageNet mean/std -- confirm against the
        # CosPlace training pipeline).
        self.norm_rgb = tvf.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

    def _forward(self, data):
        """Normalize ``data["image"]`` and run it through the network.

        Args:
            data: mapping holding the input under the ``"image"`` key.

        Returns:
            A dict with a single ``"global_descriptor"`` entry containing the
            network output.
        """
        normalized = self.norm_rgb(data["image"])
        return {"global_descriptor": self.net(normalized)}