import cv2
import torch
import lietorch

from lietorch import SE3
from collections import OrderedDict

from factor_graph import FactorGraph
from droid_net import DroidNet
import geom.projective_ops as pops


class PoseTrajectoryFiller:
    """ This class is used to fill in non-keyframe poses """

    def __init__(self, net, video, device="cuda:0"):
        # split net modules
        self.cnet = net.cnet
        self.fnet = net.fnet
        self.update = net.update

        self.count = 0
        self.video = video
        self.device = device

        # mean, std for image normalization; shaped (3, 1, 1) so they
        # broadcast over (C, H, W) image tensors
        self.MEAN = torch.as_tensor([0.485, 0.456, 0.406], device=self.device)[:, None, None]
        self.STDV = torch.as_tensor([0.229, 0.224, 0.225], device=self.device)[:, None, None]

    @torch.cuda.amp.autocast(enabled=True)
    def __feature_encoder(self, image):
        """ features for correlation volume """
        return self.fnet(image)

    def __fill(self, tstamps, images, intrinsics):
        """ fill operator

        Initializes poses for the given non-keyframe timestamps by linear
        interpolation between the bracketing keyframes, then refines them
        with motion-only bundle adjustment against those keyframes.
        Returns a one-element list containing the refined SE3 poses.
        """
        tt = torch.as_tensor(tstamps, device="cuda")
        images = torch.stack(images, 0)
        intrinsics = torch.stack(intrinsics, 0)
        # channel reorder [2,1,0] (presumably BGR -> RGB from cv2 capture),
        # scaled to [0, 1] before normalization below
        inputs = images[:, :, [2, 1, 0]].to(self.device) / 255.0

        ### linear pose interpolation ###
        N = self.video.counter.value    # number of keyframes
        M = len(tstamps)                # 16 frames to fill

        ts = self.video.tstamp[:N]      # tstamp of keyframes
        Ps = SE3(self.video.poses[:N])  # pose of keyframes

        # t0: index of the last keyframe at or before each timestamp;
        # t1: the next keyframe (clamped at the final keyframe).
        t0 = torch.as_tensor([ts[ts <= t].shape[0] - 1 for t in tstamps])
        # NOTE(review): everything from the next statement through the body of
        # __call__ was destroyed in this copy of the file (the span between a
        # '<' and the following '>' was stripped). Reconstructed from the
        # upstream DROID-SLAM trajectory_filler.py -- verify against that source.
        t1 = torch.where(t0 < N - 1, t0 + 1, t0)

        # relative motion between bracketing keyframes, converted to a
        # per-unit-time twist; 1e-3 guards against division by zero when
        # t0 == t1 (timestamp past the last keyframe)
        dt = ts[t1] - ts[t0] + 1e-3
        dP = Ps[t1] * Ps[t0].inv()

        v = dP.log() / dt.unsqueeze(-1)
        w = v * (tt - ts[t0]).unsqueeze(-1)
        Gs = SE3.exp(w) * Ps[t0]

        # extract features (no need for context features)
        inputs = inputs.sub_(self.MEAN).div_(self.STDV)
        fmap = self.__feature_encoder(inputs)

        # temporarily append the M interpolated frames to the video buffer so
        # the factor graph can optimize them against their keyframe anchors
        self.video.counter.value += M
        self.video[N:N + M] = (tt, images[:, 0], Gs.data, 1, None, intrinsics / 8.0, fmap)

        graph = FactorGraph(self.video, self.update)
        graph.add_factors(t0.cuda(), torch.arange(N, N + M).cuda())
        graph.add_factors(t1.cuda(), torch.arange(N, N + M).cuda())

        # motion-only bundle adjustment: keyframe poses stay fixed
        for itr in range(6):
            graph.update(N, N + M, motion_only=True)

        # read back the refined poses, then restore the keyframe counter
        Gs = SE3(self.video.poses[N:N + M].clone())
        self.video.counter.value -= M

        return [Gs]

    @torch.no_grad()
    def __call__(self, image_stream):
        """ fill in poses of non-keyframe images """

        # store all camera poses
        pose_list = []

        tstamps = []
        images = []
        intrinsics = []

        # process the stream in batches of 16 frames
        for (tstamp, image, intrinsic) in image_stream:
            tstamps.append(tstamp)
            images.append(image)
            intrinsics.append(intrinsic)

            if len(tstamps) == 16:
                pose_list += self.__fill(tstamps, images, intrinsics)
                tstamps, images, intrinsics = [], [], []

        # flush the final partial batch
        if len(tstamps) > 0:
            pose_list += self.__fill(tstamps, images, intrinsics)

        # stitch pose segments together
        return lietorch.cat(pose_list, 0)