import torch
import lietorch
import numpy as np

from droid_net import DroidNet
from depth_video import DepthVideo
from motion_filter import MotionFilter
from droid_frontend import DroidFrontend
from droid_backend import DroidBackend
from trajectory_filler import PoseTrajectoryFiller

from collections import OrderedDict
from torch.multiprocessing import Process
class Droid:
    def __init__(self, args):
        super(Droid, self).__init__()
        self.load_weights(args.weights)
        self.args = args
        self.disable_vis = args.disable_vis

        # store images, depth, poses, intrinsics (shared between processes)
        self.video = DepthVideo(args.image_size, args.buffer, stereo=args.stereo)

        # filter incoming frames so that there is enough motion
        self.filterx = MotionFilter(self.net, self.video, thresh=args.filter_thresh)

        # frontend process
        self.frontend = DroidFrontend(self.net, self.video, self.args)

        # backend process
        self.backend = DroidBackend(self.net, self.video, self.args)

        # visualizer (runs in a separate process so it does not block tracking)
        if not self.disable_vis:
            # from visualization import droid_visualization
            from vis_headless import droid_visualization
            print('Using headless ...')
            self.visualizer = Process(target=droid_visualization, args=(self.video, '.'))
            self.visualizer.start()

        # post processor - fill in poses for non-keyframes
        self.traj_filler = PoseTrajectoryFiller(self.net, self.video)
def load_weights(self, weights): | |
""" load trained model weights """ | |
self.net = DroidNet() | |
state_dict = OrderedDict([ | |
(k.replace("module.", ""), v) for (k, v) in torch.load(weights).items()]) | |
state_dict["update.weight.2.weight"] = state_dict["update.weight.2.weight"][:2] | |
state_dict["update.weight.2.bias"] = state_dict["update.weight.2.bias"][:2] | |
state_dict["update.delta.2.weight"] = state_dict["update.delta.2.weight"][:2] | |
state_dict["update.delta.2.bias"] = state_dict["update.delta.2.bias"][:2] | |
self.net.load_state_dict(state_dict) | |
self.net.to("cuda:0").eval() | |
    def track(self, tstamp, image, depth=None, intrinsics=None, mask=None):
        """ main thread - update map """

        with torch.no_grad():
            # check there is enough motion
            self.filterx.track(tstamp, image, depth, intrinsics, mask)

            # local bundle adjustment
            self.frontend()

            # global bundle adjustment
            # self.backend()
    def terminate(self, stream=None, backend=True):
        """ run optional global bundle adjustment, return poses [t, q] """

        del self.frontend

        if backend:
            torch.cuda.empty_cache()
            # two rounds of global bundle adjustment
            # (the argument is the number of BA iterations)
            self.backend(7)

            torch.cuda.empty_cache()
            self.backend(12)

        # fill in poses for the non-keyframes before returning the trajectory
        camera_trajectory = self.traj_filler(stream)
        return camera_trajectory.inv().data.cpu().numpy()
    def compute_error(self):
        """ compute slam reprojection error """

        del self.frontend
        torch.cuda.empty_cache()
        self.backend(12)

        return self.backend.errors[-1]
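

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It assumes an
# `args` namespace carrying exactly the fields referenced in __init__ above;
# the real frontend/backend may consume additional fields. `frame_stream` is
# a hypothetical placeholder you would replace with your own dataloader
# yielding (tstamp, image, intrinsics) tuples.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--weights", default="droid.pth")
    parser.add_argument("--buffer", type=int, default=512)
    parser.add_argument("--image_size", type=int, nargs=2, default=[240, 320])
    parser.add_argument("--filter_thresh", type=float, default=2.4)
    parser.add_argument("--stereo", action="store_true")
    parser.add_argument("--disable_vis", action="store_true")
    args = parser.parse_args()

    def frame_stream():
        """ hypothetical placeholder - yield (tstamp, image, intrinsics) """
        raise NotImplementedError("plug in your own dataloader here")
        yield  # makes this function a generator

    droid = Droid(args)
    for tstamp, image, intrinsics in frame_stream():
        droid.track(tstamp, image, intrinsics=intrinsics)

    # final global BA, then interpolate poses for the non-keyframes;
    # returns an (N, 7) array of [t, q] poses as documented in terminate()
    trajectory = droid.terminate(stream=frame_stream())
    print("trajectory shape:", trajectory.shape)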