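""" DroidNet: the DROID-SLAM network. A feature encoder and a context
encoder feed a ConvGRU-based update operator that iteratively refines
correspondences, which differentiable bundle adjustment (BA) turns into
camera poses and dense disparity. """
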
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from modules.extractor import BasicEncoder
from modules.corr import CorrBlock
from modules.gru import ConvGRU
from modules.clipping import GradientClip
from lietorch import SE3
from geom.ba import BA
import geom.projective_ops as pops
from geom.graph_utils import graph_to_edge_list, keyframe_indicies
from torch_scatter import scatter_mean


def cvx_upsample(data, mask):
""" upsample pixel-wise transformation field """
batch, ht, wd, dim = data.shape
data = data.permute(0, 3, 1, 2)
    # the mask holds 9 softmax weights per fine pixel (a convex combination)
    mask = mask.view(batch, 1, 9, 8, 8, ht, wd)
    mask = torch.softmax(mask, dim=2)
    # gather the 3x3 neighborhood around each coarse pixel and blend
    up_data = F.unfold(data, [3,3], padding=1)
    up_data = up_data.view(batch, dim, 9, 1, 1, ht, wd)
    up_data = torch.sum(mask * up_data, dim=2)
up_data = up_data.permute(0, 4, 2, 5, 3, 1)
up_data = up_data.reshape(batch, 8*ht, 8*wd, dim)
return up_data


def upsample_disp(disp, mask):
    """ upsample a batch of 1/8-resolution disparity maps by 8x """
batch, num, ht, wd = disp.shape
disp = disp.view(batch*num, ht, wd, 1)
mask = mask.view(batch*num, -1, ht, wd)
return cvx_upsample(disp, mask).view(batch, num, 8*ht, 8*wd)


class GraphAgg(nn.Module):
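    """ aggregates GRU hidden states over all edges that share a source
    frame ii (via scatter_mean) and predicts the BA damping term eta and
    the convex upsampling mask for disparity """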
def __init__(self):
super(GraphAgg, self).__init__()
self.conv1 = nn.Conv2d(128, 128, 3, padding=1)
self.conv2 = nn.Conv2d(128, 128, 3, padding=1)
self.relu = nn.ReLU(inplace=True)
self.eta = nn.Sequential(
nn.Conv2d(128, 1, 3, padding=1),
GradientClip(),
nn.Softplus())
self.upmask = nn.Sequential(
nn.Conv2d(128, 8*8*9, 1, padding=0))

    def forward(self, net, ii):
batch, num, ch, ht, wd = net.shape
net = net.view(batch*num, ch, ht, wd)
_, ix = torch.unique(ii, return_inverse=True)
net = self.relu(self.conv1(net))
net = net.view(batch, num, 128, ht, wd)
net = scatter_mean(net, ix, dim=1)
net = net.view(-1, 128, ht, wd)
net = self.relu(self.conv2(net))
eta = self.eta(net).view(batch, -1, ht, wd)
upmask = self.upmask(net).view(batch, -1, 8*8*9, ht, wd)
        return .01 * eta, upmask  # eta kept small (x0.01) and positive (softplus)


class UpdateModule(nn.Module):
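    """ RAFT-style update operator: encodes correlation and flow features,
    advances the ConvGRU hidden state, and predicts correspondence
    revisions (delta) with per-pixel confidence weights """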
def __init__(self):
super(UpdateModule, self).__init__()
        cor_planes = 4 * (2*3 + 1)**2  # 4 pyramid levels x (2r+1)^2 lookup window, r=3
self.corr_encoder = nn.Sequential(
nn.Conv2d(cor_planes, 128, 1, padding=0),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, 3, padding=1),
nn.ReLU(inplace=True))
self.flow_encoder = nn.Sequential(
nn.Conv2d(4, 128, 7, padding=3),
nn.ReLU(inplace=True),
nn.Conv2d(128, 64, 3, padding=1),
nn.ReLU(inplace=True))
self.weight = nn.Sequential(
nn.Conv2d(128, 128, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 2, 3, padding=1),
GradientClip(),
nn.Sigmoid())
self.delta = nn.Sequential(
nn.Conv2d(128, 128, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 2, 3, padding=1),
GradientClip())
        self.gru = ConvGRU(128, 128+128+64)  # inputs: context(128) + corr(128) + flow(64)
self.agg = GraphAgg()

    def forward(self, net, inp, corr, flow=None, ii=None, jj=None, mask=None):
""" RaftSLAM update operator """
batch, num, ch, ht, wd = net.shape
if flow is None:
flow = torch.zeros(batch, num, 4, ht, wd, device=net.device)
output_dim = (batch, num, -1, ht, wd)
net = net.view(batch*num, -1, ht, wd)
inp = inp.view(batch*num, -1, ht, wd)
corr = corr.view(batch*num, -1, ht, wd)
flow = flow.view(batch*num, -1, ht, wd)
corr = self.corr_encoder(corr)
flow = self.flow_encoder(flow)
net = self.gru(net, inp, corr, flow)
### update variables ###
delta = self.delta(net).view(*output_dim)
weight = self.weight(net).view(*output_dim)
        # delta, weight: (batch, num, 2, ht, wd)
delta = delta.permute(0,1,3,4,2)[...,:2].contiguous()
weight = weight.permute(0,1,3,4,2)[...,:2].contiguous()
net = net.view(*output_dim)
if ii is not None:
eta, upmask = self.agg(net, ii.to(net.device))
return net, delta, weight, eta, upmask
else:
return net, delta, weight


class DroidNet(nn.Module):
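    """ end-to-end DROID-SLAM network: extract_features() runs the
    encoders, forward() alternates update-operator steps with
    differentiable bundle adjustment """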
def __init__(self):
super(DroidNet, self).__init__()
self.fnet = BasicEncoder(output_dim=128, norm_fn='instance')
self.cnet = BasicEncoder(output_dim=256, norm_fn='none')
self.update = UpdateModule()

    def extract_features(self, images):
        """ run feature extraction networks """
# normalize images
        images = images[:, :, [2,1,0]] / 255.0  # BGR -> RGB, scale to [0,1]
mean = torch.as_tensor([0.485, 0.456, 0.406], device=images.device)
std = torch.as_tensor([0.229, 0.224, 0.225], device=images.device)
images = images.sub_(mean[:, None, None]).div_(std[:, None, None])
fmaps = self.fnet(images)
net = self.cnet(images)
net, inp = net.split([128,128], dim=2)
net = torch.tanh(net)
inp = torch.relu(inp)
return fmaps, net, inp

    def forward(self, Gs, images, disps, intrinsics, graph=None, num_steps=12, fixedp=2):
        """ Estimates SE3 or Sim3 between pairs of frames """
u = keyframe_indicies(graph)
ii, jj, kk = graph_to_edge_list(graph)
ii = ii.to(device=images.device, dtype=torch.long)
jj = jj.to(device=images.device, dtype=torch.long)
fmaps, net, inp = self.extract_features(images)
net, inp = net[:,ii], inp[:,ii]
corr_fn = CorrBlock(fmaps[:,ii], fmaps[:,jj], num_levels=4, radius=3)
ht, wd = images.shape[-2:]
        coords0 = pops.coords_grid(ht//8, wd//8, device=images.device)  # identity grid at 1/8 res
        coords1, _ = pops.projective_transform(Gs, disps, intrinsics, ii, jj)  # reprojected pixels
target = coords1.clone()
Gs_list, disp_list, residual_list = [], [], []
for step in range(num_steps):
Gs = Gs.detach()
disps = disps.detach()
coords1 = coords1.detach()
target = target.detach()
# extract motion features
corr = corr_fn(coords1)
resd = target - coords1
flow = coords1 - coords0
motion = torch.cat([flow, resd], dim=-1)
motion = motion.permute(0,1,4,2,3).clamp(-64.0, 64.0)
net, delta, weight, eta, upmask = \
self.update(net, inp, corr, motion, ii, jj)
target = coords1 + delta
            # two inner BA iterations per update step; the first `fixedp` poses stay fixed
            for i in range(2):
                Gs, disps = BA(target, weight, eta, Gs, disps, intrinsics, ii, jj, fixedp=fixedp)
coords1, valid_mask = pops.projective_transform(Gs, disps, intrinsics, ii, jj)
residual = (target - coords1)
Gs_list.append(Gs)
disp_list.append(upsample_disp(disps, upmask))
residual_list.append(valid_mask * residual)
return Gs_list, disp_list, residual_list
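

if __name__ == "__main__":
    # Minimal smoke test for the convex upsampling path. Illustrative only:
    # the shapes are arbitrary and the mask is random rather than learned.
    batch, num, ht, wd = 1, 2, 48, 64
    disp = torch.rand(batch, num, ht, wd)
    upmask = torch.randn(batch, num, 8*8*9, ht, wd)
    up = upsample_disp(disp, upmask)
    assert up.shape == (batch, num, 8*ht, 8*wd)
    print("upsampled disparity:", tuple(up.shape))  # (1, 2, 384, 512)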