Spaces:
Running
Running
File size: 3,830 Bytes
b7eedf7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
import torch
import lietorch
import numpy as np
from lietorch import SE3
from factor_graph import FactorGraph
class DroidFrontend:
    """Incremental front-end that grows and refines a ``FactorGraph`` over ``video``.

    Nothing happens until ``video.counter.value`` equals ``args.warmup``. At that
    point the graph is bootstrapped once (``__initialize``). Every later call that
    finds frames beyond ``t1`` runs one incremental step (``__update``) over a
    local window of recent frames.
    """

    def __init__(self, net, video, args):
        """Store collaborators and copy the front-end settings from ``args``.

        Args:
            net: object exposing an ``update`` operator. It is passed to the
                ``FactorGraph`` and also kept as ``self.update_op``.
            video: frame buffer. This class reads/writes its ``poses``,
                ``disps``, ``disps_sens``, ``tstamp``, ``dirty``, ``counter``
                and ``ready`` attributes and uses its ``get_lock()`` and
                ``distance()`` methods.
            args: namespace providing ``upsample``, ``warmup``, ``beta``,
                ``frontend_nms``, ``keyframe_thresh``, ``frontend_window``,
                ``frontend_thresh`` and ``frontend_radius``.
        """
        self.video = video
        self.update_op = net.update
        self.graph = FactorGraph(video, net.update, max_factors=48, upsample=args.upsample)

        # local optimization window; t1 is one past the newest processed frame
        self.t0 = 0
        self.t1 = 0

        # frontend state
        self.is_initialized = False
        self.count = 0  # number of incremental updates performed so far

        self.max_age = 25  # factors with age > max_age are removed (store=True) each update
        self.iters1 = 4    # graph updates run for every new frame
        self.iters2 = 2    # extra graph updates when frame t1-2 is kept

        self.warmup = args.warmup
        self.beta = args.beta
        self.frontend_nms = args.frontend_nms
        self.keyframe_thresh = args.keyframe_thresh
        self.frontend_window = args.frontend_window
        self.frontend_thresh = args.frontend_thresh
        self.frontend_radius = args.frontend_radius

    def __update(self):
        """Add edges for the newest frame and refine the local window.

        Steps:

        1. Advance ``t1`` by one.
        2. Prune factors older than ``max_age`` (only once ``graph.corr`` exists).
        3. Add proximity factors.
        4. Where the sensor disparity is positive, use it as the newest frame's
           disparity.
        5. Run ``iters1`` graph updates.
        6. Compare frames ``t1-3`` and ``t1-2`` with ``video.distance``:

           - If they are closer than ``keyframe_thresh``, frame ``t1-2`` is
             removed via ``graph.rm_keyframe``, and both the frame counter and
             ``t1`` are decremented.
           - Otherwise ``iters2`` more updates are run.

        7. Seed slot ``t1`` from frame ``t1-1``.
        8. Mark frames dirty.
        """
        self.count += 1
        self.t1 += 1
        newest = self.t1 - 1  # index of the frame that just became available

        # Age-based pruning is skipped while graph.corr is still None.
        if self.graph.corr is not None:
            self.graph.rm_factors(self.graph.age > self.max_age, store=True)

        self.graph.add_proximity_factors(
            self.t1 - 5,
            max(self.t1 - self.frontend_window, 0),
            rad=self.frontend_radius,
            nms=self.frontend_nms,
            thresh=self.frontend_thresh,
            beta=self.beta,
            remove=True,
        )

        # Prefer sensor disparity wherever it is available (> 0) for the newest frame.
        self.video.disps[newest] = torch.where(
            self.video.disps_sens[newest] > 0,
            self.video.disps_sens[newest],
            self.video.disps[newest],
        )

        for _ in range(self.iters1):
            self.graph.update(None, None, use_inactive=True)

        # Keyframe test: drop frame t1-2 if it is too close to frame t1-3.
        d = self.video.distance([self.t1 - 3], [self.t1 - 2], beta=self.beta, bidirectional=True)

        if d.item() < self.keyframe_thresh:
            self.graph.rm_keyframe(self.t1 - 2)

            # The counter is guarded by the video's lock; keep t1 in step with it.
            with self.video.get_lock():
                self.video.counter.value -= 1
                self.t1 -= 1
        else:
            for _ in range(self.iters2):
                self.graph.update(None, None, use_inactive=True)

        # Seed pose and (mean) disparity of the next slot t1 from frame t1-1.
        # NOTE: t1 may have been decremented above, so this is not always `newest`.
        self.video.poses[self.t1] = self.video.poses[self.t1 - 1]
        self.video.disps[self.t1] = self.video.disps[self.t1 - 1].mean()

        # Update visualization: flag frames from the smallest graph index ii up to t1.
        self.video.dirty[self.graph.ii.min():self.t1] = True

    def __initialize(self):
        """Bootstrap the factor graph over the first ``video.counter.value`` frames.

        Steps:

        1. Add neighborhood factors (``r=3``) and run 8 updates.
        2. Add proximity factors and run 8 more updates.
        3. Seed slot ``t1``.
        4. Record the last pose, disparity and timestamp.
        5. Set ``video.ready`` and mark frames ``[0, t1)`` dirty under the
           video lock.
        6. Remove (with ``store=True``) factors whose ``ii`` is below
           ``warmup - 4``.
        """
        self.t0 = 0
        self.t1 = self.video.counter.value

        self.graph.add_neighborhood_factors(self.t0, self.t1, r=3)
        for _ in range(8):
            self.graph.update(1, use_inactive=True)

        self.graph.add_proximity_factors(0, 0, rad=2, nms=2, thresh=self.frontend_thresh, remove=False)
        for _ in range(8):
            self.graph.update(1, use_inactive=True)

        # Seed the next slot: copy the last pose; mean disparity of the last 4 frames.
        self.video.poses[self.t1] = self.video.poses[self.t1 - 1].clone()
        self.video.disps[self.t1] = self.video.disps[self.t1 - 4:self.t1].mean()

        # initialization complete
        self.is_initialized = True
        self.last_pose = self.video.poses[self.t1 - 1].clone()
        self.last_disp = self.video.disps[self.t1 - 1].clone()
        self.last_time = self.video.tstamp[self.t1 - 1].clone()

        with self.video.get_lock():
            self.video.ready.value = 1
            self.video.dirty[:self.t1] = True

        self.graph.rm_factors(self.graph.ii < self.warmup - 4, store=True)

    def __call__(self):
        """Run one front-end step.

        Initializes exactly when the frame counter equals ``warmup``.
        Afterwards, performs an incremental update whenever new frames exist
        beyond ``t1``.
        """
        # do initialization
        if not self.is_initialized and self.video.counter.value == self.warmup:
            self.__initialize()

        # do update
        elif self.is_initialized and self.t1 < self.video.counter.value:
            self.__update()
|