import torch
import lietorch
import numpy as np

from lietorch import SE3
from factor_graph import FactorGraph


class DroidFrontend:
    """Front end that maintains a local ``FactorGraph`` over a shared ``video``.

    Calling the instance either runs the one-time initialization (when
    ``video.counter`` equals ``args.warmup``) or, once initialized, processes
    one new frame whenever ``video.counter`` is ahead of ``self.t1``.
    """

    def __init__(self, net, video, args):
        """Store the shared buffers and read tuning parameters from ``args``.

        Args:
            net: object exposing an ``update`` attribute; it is handed to
                ``FactorGraph`` and kept as ``self.update_op``.
            video: shared frame buffer. This class reads and writes its
                ``poses``, ``disps``, ``disps_sens``, ``tstamp``, ``dirty``,
                ``counter`` and ``ready`` members.
            args: namespace providing ``upsample``, ``warmup``, ``beta``,
                ``frontend_nms``, ``keyframe_thresh``, ``frontend_window``,
                ``frontend_thresh`` and ``frontend_radius``.
        """
        self.video = video
        self.update_op = net.update
        self.graph = FactorGraph(
            video, net.update, max_factors=48, upsample=args.upsample)

        # local optimization window
        self.t0 = 0
        self.t1 = 0

        # frontend variables
        self.is_initialized = False
        self.count = 0

        # factors whose age exceeds this are removed at the start of __update
        self.max_age = 25
        # graph updates run before the keyframe test in __update
        self.iters1 = 4
        # extra graph updates run when the newest frame is kept
        self.iters2 = 2

        self.warmup = args.warmup
        self.beta = args.beta
        self.frontend_nms = args.frontend_nms
        self.keyframe_thresh = args.keyframe_thresh
        self.frontend_window = args.frontend_window
        self.frontend_thresh = args.frontend_thresh
        self.frontend_radius = args.frontend_radius

    def __update(self) -> None:
        """Add edges for the newest frame and perform the graph update."""
        self.count += 1
        self.t1 += 1

        # drop factors that have exceeded max_age (only once the graph holds
        # correlation data)
        if self.graph.corr is not None:
            self.graph.rm_factors(self.graph.age > self.max_age, store=True)

        self.graph.add_proximity_factors(
            self.t1 - 5,
            max(self.t1 - self.frontend_window, 0),
            rad=self.frontend_radius,
            nms=self.frontend_nms,
            thresh=self.frontend_thresh,
            beta=self.beta,
            remove=True,
        )

        # where disps_sens is positive it overrides the current disparity
        # (presumably sensor depth -- confirm against the video buffer)
        self.video.disps[self.t1 - 1] = torch.where(
            self.video.disps_sens[self.t1 - 1] > 0,
            self.video.disps_sens[self.t1 - 1],
            self.video.disps[self.t1 - 1],
        )

        for _ in range(self.iters1):
            self.graph.update(None, None, use_inactive=True)

        # keyframe test: distance between frames t1-3 and t1-2
        d = self.video.distance(
            [self.t1 - 3], [self.t1 - 2], beta=self.beta, bidirectional=True)

        if d.item() < self.keyframe_thresh:
            # below the threshold: remove keyframe t1-2 and roll both the
            # shared counter and the local window end back by one
            self.graph.rm_keyframe(self.t1 - 2)

            with self.video.get_lock():
                self.video.counter.value -= 1
                self.t1 -= 1
        else:
            for _ in range(self.iters2):
                self.graph.update(None, None, use_inactive=True)

        # set pose for next iteration
        self.video.poses[self.t1] = self.video.poses[self.t1 - 1]
        self.video.disps[self.t1] = self.video.disps[self.t1 - 1].mean()

        # update visualization
        self.video.dirty[self.graph.ii.min():self.t1] = True

    def __initialize(self) -> None:
        """Initialize the SLAM system from the frames currently in ``video``."""
        self.t0 = 0
        self.t1 = self.video.counter.value

        self.graph.add_neighborhood_factors(self.t0, self.t1, r=3)

        for _ in range(8):
            self.graph.update(1, use_inactive=True)

        self.graph.add_proximity_factors(
            0, 0, rad=2, nms=2, thresh=self.frontend_thresh, remove=False)

        for _ in range(8):
            self.graph.update(1, use_inactive=True)

        # seed slot t1 from the most recent frames
        # NOTE(review): the slice start t1-4 is negative when t1 < 4 --
        # confirm warmup is always >= 4
        self.video.poses[self.t1] = self.video.poses[self.t1 - 1].clone()
        self.video.disps[self.t1] = self.video.disps[self.t1 - 4:self.t1].mean()

        # initialization complete
        self.is_initialized = True
        self.last_pose = self.video.poses[self.t1 - 1].clone()
        self.last_disp = self.video.disps[self.t1 - 1].clone()
        self.last_time = self.video.tstamp[self.t1 - 1].clone()

        with self.video.get_lock():
            self.video.ready.value = 1
            self.video.dirty[:self.t1] = True

        self.graph.rm_factors(self.graph.ii < self.warmup - 4, store=True)

    def __call__(self) -> None:
        """Main update: initialize once, then process one frame per call."""
        # do initialization
        if not self.is_initialized and self.video.counter.value == self.warmup:
            self.__initialize()

        # do update
        elif self.is_initialized and self.t1 < self.video.counter.value:
            self.__update()