# Source: commit b7eedf7 ("add thirdparty") by ThunderVVV.
import torch
import lietorch
import numpy as np
from lietorch import SE3
from factor_graph import FactorGraph
class DroidBackend:
def __init__(self, net, video, args):
self.video = video
self.update_op = net.update
# global optimization window
self.t0 = 0
self.t1 = 0
self.upsample = args.upsample
self.beta = args.beta
self.backend_thresh = args.backend_thresh
self.backend_radius = args.backend_radius
self.backend_nms = args.backend_nms
self.errors = []
@torch.no_grad()
def __call__(self, steps=12):
""" main update """
t = self.video.counter.value
if not self.video.stereo and not torch.any(self.video.disps_sens):
self.video.normalize()
graph = FactorGraph(self.video, self.update_op, corr_impl="alt", max_factors=16*t, upsample=self.upsample)
graph.add_proximity_factors(rad=self.backend_radius,
nms=self.backend_nms,
thresh=self.backend_thresh,
beta=self.beta)
graph.update_lowmem(steps=steps)
self.errors.append(self.cal_err(graph))
graph.clear_edges()
self.video.dirty[:t] = True
return
def cal_err(self, graph):
coord, _ = graph.video.reproject(graph.ii, graph.jj)
diff = graph.target - coord
err = diff.norm(dim=-1).mean().item()
return err