| from typing import Any | |
| import torch.nn as nn | |
| from .backbone import Backbone | |
| from .backbone_croco_multiview import AsymmetricCroCoMulti | |
| from .backbone_dino import BackboneDino, BackboneDinoCfg | |
| from .backbone_resnet import BackboneResnet, BackboneResnetCfg | |
| from .backbone_croco import AsymmetricCroCo, BackboneCrocoCfg | |
| BACKBONES: dict[str, Backbone[Any]] = { | |
| "resnet": BackboneResnet, | |
| "dino": BackboneDino, | |
| "croco": AsymmetricCroCo, | |
| "croco_multi": AsymmetricCroCoMulti, | |
| } | |
| BackboneCfg = BackboneResnetCfg | BackboneDinoCfg | BackboneCrocoCfg | |
| def get_backbone(cfg: BackboneCfg, d_in: int = 3) -> nn.Module: | |
| return BACKBONES[cfg.name](cfg, d_in) | |