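# NOTE: the next three lines appear to be repository-page header text that was
# captured together with the file; they are kept here as comments only.
#   ThunderVVV's picture
#   add thirdparty
#   b7eedf7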
import torch
import lietorch
import numpy as np
from lietorch import SE3
from factor_graph import FactorGraph
class DroidFrontend:
    """Incremental front end that grows and optimizes a factor graph over ``video``.

    Calling the instance performs at most one step:

    * while not yet initialized, it runs the one-time initialization as soon
      as ``video.counter.value`` equals ``args.warmup``;
    * afterwards, it processes one new frame whenever ``video.counter.value``
      is ahead of the internal frame index ``t1``.

    NOTE(review): ``video`` and ``FactorGraph`` are defined outside this file;
    their exact semantics (tensor shapes, what ``distance`` measures, ...) are
    assumed from how they are used here and should be confirmed there.
    """

    def __init__(self, net, video, args):
        """Store the shared video buffer, build the factor graph, read options.

        Args:
            net: Object exposing an ``update`` attribute; it is kept as
                ``update_op`` and also handed to ``FactorGraph``.
            video: Frame buffer shared with the rest of the system. This class
                reads/writes its ``poses``, ``disps``, ``disps_sens``,
                ``dirty``, ``tstamp``, ``counter`` and ``ready`` members and
                calls ``distance()`` and ``get_lock()`` on it.
            args: Options object providing ``upsample``, ``warmup``, ``beta``,
                ``frontend_nms``, ``keyframe_thresh``, ``frontend_window``,
                ``frontend_thresh`` and ``frontend_radius``.
        """
        self.video = video
        self.update_op = net.update
        self.graph = FactorGraph(video, net.update, max_factors=48, upsample=args.upsample)

        # local optimization window: [t0, t1) frame indices
        self.t0 = 0
        self.t1 = 0

        # frontend state
        self.is_initialized = False
        self.count = 0  # number of __update() calls so far

        # factors whose graph.age exceeds max_age are removed (and stored)
        self.max_age = 25
        # graph updates run before the keyframe-distance test
        self.iters1 = 4
        # extra graph updates run when the newest keyframe is kept
        self.iters2 = 2

        self.warmup = args.warmup
        self.beta = args.beta
        self.frontend_nms = args.frontend_nms
        self.keyframe_thresh = args.keyframe_thresh
        self.frontend_window = args.frontend_window
        self.frontend_thresh = args.frontend_thresh
        self.frontend_radius = args.frontend_radius

    def __update(self):
        """Add edges for the newest frame and run the graph update.

        Advances ``t1`` by one, refreshes the factor graph, and then either
        drops keyframe ``t1 - 2`` (when its distance to ``t1 - 3`` is below
        ``keyframe_thresh``) or runs ``iters2`` further updates. Finally seeds
        the pose/disparity of slot ``t1`` from slot ``t1 - 1`` and marks the
        touched frames dirty.
        """
        self.count += 1
        self.t1 += 1

        # drop factors that have been in the graph for too long
        if self.graph.corr is not None:
            self.graph.rm_factors(self.graph.age > self.max_age, store=True)

        self.graph.add_proximity_factors(
            self.t1 - 5,
            max(self.t1 - self.frontend_window, 0),
            rad=self.frontend_radius,
            nms=self.frontend_nms,
            thresh=self.frontend_thresh,
            beta=self.beta,
            remove=True,
        )

        # wherever disps_sens holds a positive value for the newest frame,
        # use it in place of the current disparity
        newest = self.t1 - 1
        sensed = self.video.disps_sens[newest]
        self.video.disps[newest] = torch.where(sensed > 0, sensed, self.video.disps[newest])

        for _ in range(self.iters1):
            self.graph.update(None, None, use_inactive=True)

        # distance between the two frames preceding the newest one decides
        # whether keyframe t1-2 is kept
        d = self.video.distance([self.t1 - 3], [self.t1 - 2], beta=self.beta, bidirectional=True)

        if d.item() < self.keyframe_thresh:
            self.graph.rm_keyframe(self.t1 - 2)

            # one keyframe fewer: roll back both the shared counter and t1
            with self.video.get_lock():
                self.video.counter.value -= 1
                self.t1 -= 1
        else:
            for _ in range(self.iters2):
                self.graph.update(None, None, use_inactive=True)

        # set pose and (mean) disparity for the next iteration
        self.video.poses[self.t1] = self.video.poses[self.t1 - 1]
        self.video.disps[self.t1] = self.video.disps[self.t1 - 1].mean()

        # update visualization: flag every frame from the oldest graph node on
        self.video.dirty[self.graph.ii.min():self.t1] = True

    def __initialize(self):
        """Run the one-time initialization over the first ``counter.value`` frames.

        Builds neighborhood and proximity factors, runs two rounds of eight
        graph updates, seeds slot ``t1``, records the last pose/disparity/
        timestamp, sets ``video.ready`` and finally removes factors whose
        ``ii`` index is below ``warmup - 4``.
        """
        self.t0 = 0
        self.t1 = self.video.counter.value

        self.graph.add_neighborhood_factors(self.t0, self.t1, r=3)

        for _ in range(8):
            self.graph.update(1, use_inactive=True)

        self.graph.add_proximity_factors(0, 0, rad=2, nms=2, thresh=self.frontend_thresh, remove=False)

        for _ in range(8):
            self.graph.update(1, use_inactive=True)

        # self.video.normalize()
        # seed the next slot: copy the last pose, use the mean disparity of
        # the four most recent frames
        self.video.poses[self.t1] = self.video.poses[self.t1 - 1].clone()
        self.video.disps[self.t1] = self.video.disps[self.t1 - 4:self.t1].mean()

        # initialization complete
        self.is_initialized = True
        self.last_pose = self.video.poses[self.t1 - 1].clone()
        self.last_disp = self.video.disps[self.t1 - 1].clone()
        self.last_time = self.video.tstamp[self.t1 - 1].clone()

        with self.video.get_lock():
            self.video.ready.value = 1
            self.video.dirty[:self.t1] = True

        self.graph.rm_factors(self.graph.ii < self.warmup - 4, store=True)

    def __call__(self):
        """Perform one front-end step (initialization or update), if any is due.

        Does nothing when neither condition holds. Returns ``None``.
        """
        # do initialization
        if not self.is_initialized and self.video.counter.value == self.warmup:
            self.__initialize()

        # do update
        elif self.is_initialized and self.t1 < self.video.counter.value:
            self.__update()