Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from einops.einops import rearrange | |
| from .backbone import build_backbone | |
| from .modules import LocalFeatureTransformer, FinePreprocess, TopicFormer | |
| from .utils.coarse_matching import CoarseMatching | |
| from .utils.fine_matching import FineMatching | |
| class TopicFM(nn.Module): | |
| def __init__(self, config): | |
| super().__init__() | |
| # Misc | |
| self.config = config | |
| # Modules | |
| self.backbone = build_backbone(config) | |
| self.loftr_coarse = TopicFormer(config["coarse"]) | |
| self.coarse_matching = CoarseMatching(config["match_coarse"]) | |
| self.fine_preprocess = FinePreprocess(config) | |
| self.loftr_fine = LocalFeatureTransformer(config["fine"]) | |
| self.fine_matching = FineMatching() | |
| def forward(self, data): | |
| """ | |
| Update: | |
| data (dict): { | |
| 'image0': (torch.Tensor): (N, 1, H, W) | |
| 'image1': (torch.Tensor): (N, 1, H, W) | |
| 'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position | |
| 'mask1'(optional) : (torch.Tensor): (N, H, W) | |
| } | |
| """ | |
| # 1. Local Feature CNN | |
| data.update( | |
| { | |
| "bs": data["image0"].size(0), | |
| "hw0_i": data["image0"].shape[2:], | |
| "hw1_i": data["image1"].shape[2:], | |
| } | |
| ) | |
| if data["hw0_i"] == data["hw1_i"]: # faster & better BN convergence | |
| feats_c, feats_f = self.backbone( | |
| torch.cat([data["image0"], data["image1"]], dim=0) | |
| ) | |
| (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split( | |
| data["bs"] | |
| ), feats_f.split(data["bs"]) | |
| else: # handle different input shapes | |
| (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone( | |
| data["image0"] | |
| ), self.backbone(data["image1"]) | |
| data.update( | |
| { | |
| "hw0_c": feat_c0.shape[2:], | |
| "hw1_c": feat_c1.shape[2:], | |
| "hw0_f": feat_f0.shape[2:], | |
| "hw1_f": feat_f1.shape[2:], | |
| } | |
| ) | |
| # 2. coarse-level loftr module | |
| feat_c0 = rearrange(feat_c0, "n c h w -> n (h w) c") | |
| feat_c1 = rearrange(feat_c1, "n c h w -> n (h w) c") | |
| mask_c0 = mask_c1 = None # mask is useful in training | |
| if "mask0" in data: | |
| mask_c0, mask_c1 = data["mask0"].flatten(-2), data["mask1"].flatten(-2) | |
| feat_c0, feat_c1, conf_matrix, topic_matrix = self.loftr_coarse( | |
| feat_c0, feat_c1, mask_c0, mask_c1 | |
| ) | |
| data.update({"conf_matrix": conf_matrix, "topic_matrix": topic_matrix}) ###### | |
| # 3. match coarse-level | |
| self.coarse_matching(data) | |
| # 4. fine-level refinement | |
| feat_f0_unfold, feat_f1_unfold = self.fine_preprocess( | |
| feat_f0, feat_f1, feat_c0.detach(), feat_c1.detach(), data | |
| ) | |
| if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted | |
| feat_f0_unfold, feat_f1_unfold = self.loftr_fine( | |
| feat_f0_unfold, feat_f1_unfold | |
| ) | |
| # 5. match fine-level | |
| self.fine_matching(feat_f0_unfold, feat_f1_unfold, data) | |
| def load_state_dict(self, state_dict, *args, **kwargs): | |
| for k in list(state_dict.keys()): | |
| if k.startswith("matcher."): | |
| state_dict[k.replace("matcher.", "", 1)] = state_dict.pop(k) | |
| return super().load_state_dict(state_dict, *args, **kwargs) | |