import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict

from modules.extractor import BasicEncoder
from modules.corr import CorrBlock
from modules.gru import ConvGRU
from modules.clipping import GradientClip

from lietorch import SE3
from geom.ba import BA

import geom.projective_ops as pops
from geom.graph_utils import graph_to_edge_list, keyframe_indicies

from torch_scatter import scatter_mean


def cvx_upsample(data, mask):
    """ upsample pixel-wise transformation field """
    batch, ht, wd, dim = data.shape
    data = data.permute(0, 3, 1, 2)

    # softmax over the 9 coarse neighbors gives convex combination weights
    # for each pixel of every 8x8 upsampled output cell
    mask = mask.view(batch, 1, 9, 8, 8, ht, wd)
    mask = torch.softmax(mask, dim=2)

    # gather the 3x3 neighborhood of every coarse pixel, then blend
    up_data = F.unfold(data, [3,3], padding=1)
    up_data = up_data.view(batch, dim, 9, 1, 1, ht, wd)

    up_data = torch.sum(mask * up_data, dim=2)
    up_data = up_data.permute(0, 4, 2, 5, 3, 1)
    up_data = up_data.reshape(batch, 8*ht, 8*wd, dim)

    return up_data


def upsample_disp(disp, mask):
    batch, num, ht, wd = disp.shape
    disp = disp.view(batch*num, ht, wd, 1)
    mask = mask.view(batch*num, -1, ht, wd)
    return cvx_upsample(disp, mask).view(batch, num, 8*ht, 8*wd)
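
# A minimal shape sanity-check for cvx_upsample (illustrative sketch only; the
# tensor sizes are arbitrary assumptions, not values used by the network):
#
#   data = torch.randn(1, 30, 40, 2)       # (batch, ht, wd, dim) field at coarse resolution
#   mask = torch.randn(1, 9*8*8, 30, 40)   # 9 blend logits per pixel of each 8x8 output cell
#   up = cvx_upsample(data, mask)
#   assert up.shape == (1, 240, 320, 2)    # upsampled 8x in each spatial dimension
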
class GraphAgg(nn.Module):
    def __init__(self):
        super(GraphAgg, self).__init__()
        self.conv1 = nn.Conv2d(128, 128, 3, padding=1)
        self.conv2 = nn.Conv2d(128, 128, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        self.eta = nn.Sequential(
            nn.Conv2d(128, 1, 3, padding=1),
            GradientClip(),
            nn.Softplus())

        self.upmask = nn.Sequential(
            nn.Conv2d(128, 8*8*9, 1, padding=0))

    def forward(self, net, ii):
        batch, num, ch, ht, wd = net.shape
        net = net.view(batch*num, ch, ht, wd)

        # average hidden states over all edges that share a source frame
        _, ix = torch.unique(ii, return_inverse=True)
        net = self.relu(self.conv1(net))

        net = net.view(batch, num, 128, ht, wd)
        net = scatter_mean(net, ix, dim=1)
        net = net.view(-1, 128, ht, wd)

        net = self.relu(self.conv2(net))

        # eta: small positive per-pixel damping term consumed by the BA layer
        # upmask: convex upsampling weights for the disparity field
        eta = self.eta(net).view(batch, -1, ht, wd)
        upmask = self.upmask(net).view(batch, -1, 8*8*9, ht, wd)

        return .01 * eta, upmask


class UpdateModule(nn.Module):
    def __init__(self):
        super(UpdateModule, self).__init__()
        cor_planes = 4 * (2*3 + 1)**2

        self.corr_encoder = nn.Sequential(
            nn.Conv2d(cor_planes, 128, 1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True))

        self.flow_encoder = nn.Sequential(
            nn.Conv2d(4, 128, 7, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(inplace=True))

        self.weight = nn.Sequential(
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 2, 3, padding=1),
            GradientClip(),
            nn.Sigmoid())

        self.delta = nn.Sequential(
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 2, 3, padding=1),
            GradientClip())

        self.gru = ConvGRU(128, 128+128+64)
        self.agg = GraphAgg()

    def forward(self, net, inp, corr, flow=None, ii=None, jj=None, mask=None):
        """ RaftSLAM update operator """

        batch, num, ch, ht, wd = net.shape

        if flow is None:
            flow = torch.zeros(batch, num, 4, ht, wd, device=net.device)

        output_dim = (batch, num, -1, ht, wd)
        net = net.view(batch*num, -1, ht, wd)
        inp = inp.view(batch*num, -1, ht, wd)
        corr = corr.view(batch*num, -1, ht, wd)
        flow = flow.view(batch*num, -1, ht, wd)

        corr = self.corr_encoder(corr)
        flow = self.flow_encoder(flow)
        net = self.gru(net, inp, corr, flow)

        ### update variables ###
        delta = self.delta(net).view(*output_dim)    # (batch, num, 2, ht, wd) flow revision
        weight = self.weight(net).view(*output_dim)  # (batch, num, 2, ht, wd) confidence

        delta = delta.permute(0,1,3,4,2)[...,:2].contiguous()
        weight = weight.permute(0,1,3,4,2)[...,:2].contiguous()

        net = net.view(*output_dim)

        if ii is not None:
            eta, upmask = self.agg(net, ii.to(net.device))
            return net, delta, weight, eta, upmask

        else:
            return net, delta, weight


class DroidNet(nn.Module):
    def __init__(self):
        super(DroidNet, self).__init__()
        self.fnet = BasicEncoder(output_dim=128, norm_fn='instance')
        self.cnet = BasicEncoder(output_dim=256, norm_fn='none')
        self.update = UpdateModule()

    def extract_features(self, images):
        """ run feature extraction networks """

        # normalize images (BGR -> RGB, ImageNet mean/std)
        images = images[:, :, [2,1,0]] / 255.0
        mean = torch.as_tensor([0.485, 0.456, 0.406], device=images.device)
        std = torch.as_tensor([0.229, 0.224, 0.225], device=images.device)
        images = images.sub_(mean[:, None, None]).div_(std[:, None, None])

        fmaps = self.fnet(images)
        net = self.cnet(images)

        net, inp = net.split([128,128], dim=2)
        net = torch.tanh(net)
        inp = torch.relu(inp)
        return fmaps, net, inp

    def forward(self, Gs, images, disps, intrinsics, graph=None, num_steps=12, fixedp=2):
        """ Estimates SE3 or Sim3 between pairs of frames """

        u = keyframe_indicies(graph)
        ii, jj, kk = graph_to_edge_list(graph)

        ii = ii.to(device=images.device, dtype=torch.long)
        jj = jj.to(device=images.device, dtype=torch.long)

        fmaps, net, inp = self.extract_features(images)
        net, inp = net[:,ii], inp[:,ii]
        corr_fn = CorrBlock(fmaps[:,ii], fmaps[:,jj], num_levels=4, radius=3)

        ht, wd = images.shape[-2:]
        coords0 = pops.coords_grid(ht//8, wd//8, device=images.device)
        coords1, _ = pops.projective_transform(Gs, disps, intrinsics, ii, jj)
        target = coords1.clone()

        Gs_list, disp_list, residual_list = [], [], []
        for step in range(num_steps):
            Gs = Gs.detach()
            disps = disps.detach()
            coords1 = coords1.detach()
            target = target.detach()

            # extract motion features
            corr = corr_fn(coords1)
            resd = target - coords1
            flow = coords1 - coords0

            motion = torch.cat([flow, resd], dim=-1)
            motion = motion.permute(0,1,4,2,3).clamp(-64.0, 64.0)

            net, delta, weight, eta, upmask = \
                self.update(net, inp, corr, motion, ii, jj)

            target = coords1 + delta

            # two inner bundle adjustment iterations per update step
            for i in range(2):
                Gs, disps = BA(target, weight, eta, Gs, disps, intrinsics, ii, jj, fixedp=fixedp)

            coords1, valid_mask = pops.projective_transform(Gs, disps, intrinsics, ii, jj)
            residual = (target - coords1)

            Gs_list.append(Gs)
            disp_list.append(upsample_disp(disps, upmask))
            residual_list.append(valid_mask * residual)

        return Gs_list, disp_list, residual_list
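
# A minimal smoke test for the full model. This is an illustrative sketch: the
# frame count, image size, chain-style covisibility graph, identity poses, and
# unit inverse depth below are assumptions, not values from the training code.
# Requires a CUDA device (lietorch and the correlation kernels run on GPU).
if __name__ == "__main__":
    device = "cuda"
    b, n, ht, wd = 1, 5, 384, 512

    model = DroidNet().to(device)

    images = torch.randint(0, 255, (b, n, 3, ht, wd), device=device).float()
    disps = torch.ones(b, n, ht//8, wd//8, device=device)        # inverse depth at 1/8 resolution
    intrinsics = torch.as_tensor([256.0, 256.0, wd/2, ht/2], device=device)
    intrinsics = intrinsics.view(1, 1, 4).repeat(b, n, 1) / 8.0  # [fx, fy, cx, cy] at 1/8 resolution
    Gs = SE3.Identity(b, n, device=device)                       # identity camera poses

    # chain graph: every frame is connected to its temporal neighbors
    graph = OrderedDict((i, [j for j in (i-1, i+1) if 0 <= j < n]) for i in range(n))

    poses, disps_up, residuals = model(Gs, images, disps, intrinsics, graph, num_steps=2)
    print(len(poses), disps_up[-1].shape, residuals[-1].shape)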