import torch
import lietorch
import numpy as np

from lietorch import SE3
from factor_graph import FactorGraph


class DroidBackend:
    """Global optimization pass over all frames held by ``video``.

    Each call builds a fresh ``FactorGraph`` from proximity factors over every
    frame up to ``video.counter.value``. It then runs ``update_lowmem`` on the
    graph, records a mean reprojection residual in ``self.errors`` and marks
    those frames dirty.
    """

    def __init__(self, net, video, args):
        """Store the shared video buffer, the network update operator and backend settings.

        Args:
            net: object whose ``update`` attribute is handed to ``FactorGraph``
                as the update operator.
            video: shared video/state object that is optimized in place.
            args: namespace providing ``upsample``, ``beta``,
                ``backend_thresh``, ``backend_radius`` and ``backend_nms``.
        """
        self.video = video
        self.update_op = net.update

        # global optimization window (not read by the methods in this class)
        self.t0 = 0
        self.t1 = 0

        self.upsample = args.upsample
        self.beta = args.beta
        self.backend_thresh = args.backend_thresh
        self.backend_radius = args.backend_radius
        self.backend_nms = args.backend_nms

        # One residual value appended per __call__; grows for the object's lifetime.
        self.errors = []

    @torch.no_grad()
    def __call__(self, steps=12):
        """Run one global update over every frame currently stored in the video.

        Args:
            steps: iteration count forwarded to ``FactorGraph.update_lowmem``.

        Returns:
            None. The video is modified in place, and the mean residual is
            appended to ``self.errors``.
        """
        video = self.video
        count = video.counter.value

        # NOTE(review): normalization is skipped for stereo input or when any
        # sensor disparity is present; presumably because scale is otherwise
        # unobservable — confirm against the video class.
        if not (video.stereo or torch.any(video.disps_sens)):
            video.normalize()

        graph = FactorGraph(
            video,
            self.update_op,
            corr_impl="alt",
            max_factors=16 * count,
            upsample=self.upsample,
        )
        graph.add_proximity_factors(
            rad=self.backend_radius,
            nms=self.backend_nms,
            thresh=self.backend_thresh,
            beta=self.beta,
        )
        graph.update_lowmem(steps=steps)

        # Measure the residual while the graph's edges still exist.
        self.errors.append(self.cal_err(graph))

        graph.clear_edges()
        video.dirty[:count] = True

    def cal_err(self, graph):
        """Return the mean reprojection residual of ``graph`` as a Python scalar.

        Reprojects for the edge pairs ``(graph.ii, graph.jj)`` and subtracts the
        result from ``graph.target``. It then takes the Euclidean norm over the
        last axis and averages over all remaining elements.
        """
        coords, _unused = graph.video.reproject(graph.ii, graph.jj)
        # NOTE(review): the second value from reproject (possibly a validity
        # mask) is discarded, so every element contributes to the mean — confirm.
        residual = graph.target - coords
        return residual.norm(dim=-1).mean().item()