Spaces:
Running
Running
File size: 1,518 Bytes
b7eedf7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import torch
import lietorch
import numpy as np
from lietorch import SE3
from factor_graph import FactorGraph
class DroidBackend:
    """Global optimization pass over every frame currently held in ``video``.

    Each call builds a fresh ``FactorGraph`` over the stored frames and
    connects them with proximity factors. It then runs the graph's
    low-memory update loop and records one scalar error in ``self.errors``.
    """

    def __init__(self, net, video, args):
        """Store the update operator, the video buffer and the backend settings.

        Args:
            net: object whose ``update`` attribute is passed to
                ``FactorGraph`` as the update operator.
            video: the frame buffer this backend optimizes.
            args: namespace providing ``upsample``, ``beta``,
                ``backend_thresh``, ``backend_radius`` and ``backend_nms``.
        """
        self.video = video
        self.update_op = net.update

        # Settings forwarded to FactorGraph / add_proximity_factors on each call.
        self.upsample = args.upsample
        self.beta = args.beta
        self.backend_thresh = args.backend_thresh
        self.backend_radius = args.backend_radius
        self.backend_nms = args.backend_nms

        # Bounds of the global optimization window.
        # NOTE(review): not read anywhere in this class; presumably consumed
        # by callers. Confirm before removing.
        self.t0 = 0
        self.t1 = 0

        # One mean reprojection error per backend call, in call order.
        self.errors = []

    @torch.no_grad()
    def __call__(self, steps=12):
        """Run one global optimization over all frames currently in the video.

        Args:
            steps: iteration count passed to ``FactorGraph.update_lowmem``.

        Side effects:
            - May call ``video.normalize()``.
            - Appends one float to ``self.errors``.
            - Sets ``video.dirty[:t]`` to True, where ``t`` is
              ``video.counter.value``.
        """
        num_frames = self.video.counter.value

        # Normalize only when there is no stereo input and every
        # sensor-provided disparity is zero.
        if not (self.video.stereo or torch.any(self.video.disps_sens)):
            self.video.normalize()

        graph = FactorGraph(
            self.video,
            self.update_op,
            corr_impl="alt",
            max_factors=16 * num_frames,
            upsample=self.upsample,
        )
        graph.add_proximity_factors(
            rad=self.backend_radius,
            nms=self.backend_nms,
            thresh=self.backend_thresh,
            beta=self.beta,
        )
        graph.update_lowmem(steps=steps)

        # The error is measured before clear_edges().
        # NOTE(review): presumably clear_edges() drops graph.ii / graph.jj /
        # graph.target, which cal_err reads. Confirm.
        self.errors.append(self.cal_err(graph))
        graph.clear_edges()

        self.video.dirty[:num_frames] = True

    def cal_err(self, graph):
        """Return the mean L2 distance between ``graph.target`` and a fresh reprojection.

        The graph's current edges ``(graph.ii, graph.jj)`` are reprojected with
        ``graph.video.reproject``. The difference from ``graph.target`` is
        reduced with an L2 norm over its last axis, then averaged over all
        remaining elements.

        The second value returned by ``reproject`` is discarded, so every
        element contributes to the mean.
        NOTE(review): if that value is a validity mask, invalid points are
        averaged in too. Confirm this is intended.

        Returns:
            float: the mean error as a Python scalar.
        """
        projected, _unused = graph.video.reproject(graph.ii, graph.jj)
        residual = graph.target - projected
        per_point = residual.norm(dim=-1)
        return per_point.mean().item()
|