Spaces:
Running
Running
import torch | |
import lietorch | |
import numpy as np | |
from lietorch import SE3 | |
from factor_graph import FactorGraph | |
class DroidBackend:
    """Run global (backend) optimization over every frame in a shared video buffer.

    Each call builds a temporary ``FactorGraph`` over the whole buffer, adds
    proximity factors, runs the low-memory update loop, appends a mean
    reprojection-error figure to ``self.errors`` and then clears the edges again.
    """

    def __init__(self, net, video, args):
        """Store the video buffer, update operator and backend settings.

        Args:
            net: object whose ``update`` attribute is handed to ``FactorGraph``
                as its update operator (nothing else on ``net`` is used here).
            video: shared video/frame buffer the factor graph reads and writes.
            args: namespace providing ``upsample``, ``beta``, ``backend_thresh``,
                ``backend_radius`` and ``backend_nms``.
        """
        self.video = video
        self.update_op = net.update

        # global optimization window (not read by the methods of this class)
        self.t0 = 0
        self.t1 = 0

        self.upsample = args.upsample
        self.beta = args.beta
        self.backend_thresh = args.backend_thresh
        self.backend_radius = args.backend_radius
        self.backend_nms = args.backend_nms

        # one value appended per __call__ (the result of cal_err)
        self.errors = []

    # Inference only: nothing here calls backward(), so disable autograd to keep
    # the update-operator / optimization iterations from recording graphs.
    @torch.no_grad()
    def __call__(self, steps=12):
        """Run one global optimization pass over all frames currently in the video.

        Args:
            steps: number of iterations forwarded to ``graph.update_lowmem``.
        """
        # presumably the number of frames stored so far -- used as the slice
        # bound for ``dirty`` and to size the factor budget below
        t = self.video.counter.value

        # normalize only when the video is not stereo and ``disps_sens`` is all zero
        if not self.video.stereo and not torch.any(self.video.disps_sens):
            self.video.normalize()

        graph = FactorGraph(
            self.video,
            self.update_op,
            corr_impl="alt",
            max_factors=16 * t,
            upsample=self.upsample,
        )
        graph.add_proximity_factors(
            rad=self.backend_radius,
            nms=self.backend_nms,
            thresh=self.backend_thresh,
            beta=self.beta,
        )

        graph.update_lowmem(steps=steps)
        # measure the residual while the graph's edges still exist
        self.errors.append(self.cal_err(graph))
        graph.clear_edges()

        # flag every frame up to t as modified (presumably so consumers re-read them)
        self.video.dirty[:t] = True

    def cal_err(self, graph):
        """Return the mean reprojection residual of ``graph`` as a Python number.

        The edges ``(graph.ii, graph.jj)`` are reprojected through
        ``graph.video.reproject`` and compared with ``graph.target``. The result
        is the mean, over all remaining elements, of the L2 norm taken along the
        last axis.

        NOTE(review): the second value returned by ``reproject`` (presumably a
        validity mask) is ignored, so every pixel contributes. Assumes ``target``
        and the reprojected coordinates have matching shapes -- confirm.
        """
        coords, _ = graph.video.reproject(graph.ii, graph.jj)
        residual = graph.target - coords
        return residual.norm(dim=-1).mean().item()