# Full bundle adjustment (BA) and motion-only bundle adjustment (MoBA) solvers.
import lietorch | |
import torch | |
import torch.nn.functional as F | |
from .chol import block_solve, schur_solve | |
import geom.projective_ops as pops | |
from torch_scatter import scatter_sum | |
# Utility functions for bounds-checked scatter-add operations.
def safe_scatter_add_mat(A, ii, jj, n, m):
    """Scatter-add per-edge blocks into a flattened ``n x m`` grid of blocks.

    Entry ``k`` along dim 1 of ``A`` is accumulated into flat cell
    ``ii[k] * m + jj[k]``. Edges whose row index is outside ``[0, n)`` or whose
    column index is outside ``[0, m)`` are dropped. Callers in this module use
    this to exclude fixed poses, by shifting frame indices so that fixed poses
    become negative.

    Args:
        A: tensor whose dim 1 enumerates edges (one entry per ``ii``/``jj`` pair).
        ii: row (block) index for each edge.
        jj: column (block) index for each edge.
        n: number of block rows.
        m: number of block columns.

    Returns:
        Tensor shaped like ``A`` except that dim 1 has size ``n * m``.
    """
    # Move the indices onto A's device, as the retraction helpers in this
    # module do, so CPU index tensors can be combined with GPU data.
    ii = ii.to(device=A.device)
    jj = jj.to(device=A.device)
    v = (ii >= 0) & (jj >= 0) & (ii < n) & (jj < m)
    return scatter_sum(A[:, v], ii[v] * m + jj[v], dim=1, dim_size=n * m)
def safe_scatter_add_vec(b, ii, n):
    """Scatter-add per-edge entries of ``b`` into ``n`` slots along dim 1.

    Entry ``k`` along dim 1 of ``b`` is accumulated into slot ``ii[k]``.
    Entries whose index lies outside ``[0, n)`` are dropped.

    Args:
        b: tensor whose dim 1 enumerates edges.
        ii: destination slot for each edge.
        n: number of destination slots.

    Returns:
        Tensor shaped like ``b`` except that dim 1 has size ``n``.
    """
    # Move the index onto b's device, as the retraction helpers in this module
    # do, so a CPU index tensor can be combined with GPU data.
    ii = ii.to(device=b.device)
    v = (ii >= 0) & (ii < n)
    return scatter_sum(b[:, v], ii[v], dim=1, dim_size=n)
# Apply the retraction operator to inverse-depth maps.
def disp_retr(disps, dz, ii):
    """Add scattered inverse-depth increments to ``disps``.

    Increment ``k`` along dim 1 of ``dz`` is summed into frame slot ``ii[k]``
    of ``disps``. Frames that receive no increment are left unchanged.

    Args:
        disps: inverse-depth maps; dim 1 indexes frames.
        dz: increments; dim 1 lines up with ``ii``.
        ii: destination frame index for each increment (moved to ``dz``'s device).

    Returns:
        A new tensor: ``disps`` plus the scattered increments.
    """
    target_idx = ii.to(device=dz.device)
    num_frames = disps.shape[1]
    delta = scatter_sum(dz, target_idx, dim=1, dim_size=num_frames)
    return disps + delta
# Apply the retraction operator to poses.
def pose_retr(poses, dx, ii):
    """Retract ``poses`` along scattered tangent-space updates.

    Update ``k`` along dim 1 of ``dx`` is summed into pose slot ``ii[k]``. The
    accumulated tangent vectors are then applied with ``poses.retr``.

    Args:
        poses: pose group object exposing ``shape`` and ``retr``.
        dx: tangent updates; dim 1 lines up with ``ii``.
        ii: destination pose index for each update (moved to ``dx``'s device).

    Returns:
        The retracted poses.
    """
    target_idx = ii.to(device=dx.device)
    tangent = scatter_sum(dx, target_idx, dim=1, dim_size=poses.shape[1])
    return poses.retr(tangent)
def BA(target, weight, eta, poses, disps, intrinsics, ii, jj, fixedp=1, rig=1):
    """ Full Bundle Adjustment

    Performs one linearized weighted-least-squares update of both poses and
    inverse depths. Per-edge normal-equation blocks are built from the
    Jacobians returned by ``pops.projective_transform``. They are accumulated
    with the safe scatter helpers and solved with ``schur_solve``.

    Args:
        target: target coordinates, compared against the projected coords.
        weight: per-residual weights; multiplied by the validity mask and
            scaled by 0.001.
        eta: damping added to the depth block ``C``; must be viewable as
            ``C``'s shape.
        poses: pose group object; ``manifold_dim`` gives the pose DOF.
        disps: inverse-depth maps shaped ``(B, P, ht, wd)``.
        intrinsics: camera intrinsics, forwarded to the projection.
        ii, jj: per-edge frame indices (source ``ii``, target ``jj``).
        fixedp: poses whose ``index // rig`` is below this are held fixed.
        rig: frame indices are integer-divided by this before pose indexing
            (presumably frames per rig -- confirm with callers).

    Returns:
        ``(poses, disps)`` after retraction. Disparities above 10 are reset
        to 0 and negative values are clamped to 0.
    """
    batch, num_frames, ht, wd = disps.shape
    num_edges = ii.shape[0]
    dof = poses.manifold_dim
    pix = ht * wd

    # --- Step 1: Jacobians and residuals ---
    coords, valid, (Ji, Jj, Jz) = pops.projective_transform(
        poses, disps, intrinsics, ii, jj, jacobian=True)

    resid = (target - coords).view(batch, num_edges, -1, 1)
    wts = .001 * (valid * weight).view(batch, num_edges, -1, 1)

    # --- Step 2: per-edge blocks of the normal equations ---
    Ji = Ji.reshape(batch, num_edges, -1, dof)
    Jj = Jj.reshape(batch, num_edges, -1, dof)
    wJiT = (wts * Ji).transpose(2, 3)
    wJjT = (wts * Jj).transpose(2, 3)
    Jz = Jz.reshape(batch, num_edges, pix, -1)

    Hii = wJiT @ Ji
    Hij = wJiT @ Jj
    Hji = wJjT @ Ji
    Hjj = wJjT @ Jj

    vi = (wJiT @ resid).squeeze(-1)
    vj = (wJjT @ resid).squeeze(-1)

    # Pose/depth coupling: contract each pixel's residual components.
    Ei = (wJiT.view(batch, num_edges, dof, pix, -1) * Jz[:, :, None]).sum(dim=-1)
    Ej = (wJjT.view(batch, num_edges, dof, pix, -1) * Jz[:, :, None]).sum(dim=-1)

    wts_px = wts.view(batch, num_edges, pix, -1)
    resid_px = resid.view(batch, num_edges, pix, -1)
    wk = (wts_px * resid_px * Jz).sum(dim=-1)
    Ck = (wts_px * Jz * Jz).sum(dim=-1)

    # One depth map is optimized per distinct source frame in ii.
    kx, kk = torch.unique(ii, return_inverse=True)
    num_depths = kx.shape[0]

    # Only keyframe poses are optimized. After shifting, fixed poses map to
    # negative indices, which the safe scatter helpers drop.
    num_poses = num_frames // rig - fixedp
    pi = ii // rig - fixedp
    pj = jj // rig - fixedp

    H = (safe_scatter_add_mat(Hii, pi, pi, num_poses, num_poses)
         + safe_scatter_add_mat(Hij, pi, pj, num_poses, num_poses)
         + safe_scatter_add_mat(Hji, pj, pi, num_poses, num_poses)
         + safe_scatter_add_mat(Hjj, pj, pj, num_poses, num_poses))

    E = (safe_scatter_add_mat(Ei, pi, kk, num_poses, num_depths)
         + safe_scatter_add_mat(Ej, pj, kk, num_poses, num_depths))

    rhs_pose = (safe_scatter_add_vec(vi, pi, num_poses)
                + safe_scatter_add_vec(vj, pj, num_poses))

    C = safe_scatter_add_vec(Ck, kk, num_depths)
    rhs_depth = safe_scatter_add_vec(wk, kk, num_depths)

    C = C + eta.view(*C.shape) + 1e-7

    H = H.view(batch, num_poses, num_poses, dof, dof)
    E = E.view(batch, num_poses, num_depths, dof, pix)

    # --- Step 3: solve ---
    dx, dz = schur_solve(H, E, C, rhs_pose, rhs_depth)

    # --- Step 4: retraction ---
    poses = pose_retr(poses, dx, torch.arange(num_poses) + fixedp)
    disps = disp_retr(disps, dz.view(batch, -1, ht, wd), kx)
    disps = torch.where(disps > 10, torch.zeros_like(disps), disps)
    disps = disps.clamp(min=0.0)
    return poses, disps
def MoBA(target, weight, eta, poses, disps, intrinsics, ii, jj, fixedp=1, rig=1):
    """ Motion only bundle adjustment

    Performs one linearized weighted-least-squares update of the poses only.
    The inverse depths ``disps`` are used for projection but are not updated.
    Per-edge pose blocks are accumulated with the safe scatter helpers and
    solved with ``block_solve``.

    Args:
        target: target coordinates, compared against the projected coords.
        weight: per-residual weights; multiplied by the validity mask and
            scaled by 0.001.
        eta: accepted for signature parity with ``BA``; not used here.
        poses: pose group object; ``manifold_dim`` gives the pose DOF.
        disps: inverse-depth maps shaped ``(B, P, ht, wd)``.
        intrinsics: camera intrinsics, forwarded to the projection.
        ii, jj: per-edge frame indices (source ``ii``, target ``jj``).
        fixedp: poses whose ``index // rig`` is below this are held fixed.
        rig: frame indices are integer-divided by this before pose indexing
            (presumably frames per rig -- confirm with callers).

    Returns:
        The retracted poses.
    """
    batch, num_frames = disps.shape[0], disps.shape[1]
    num_edges = ii.shape[0]
    dof = poses.manifold_dim

    # --- Step 1: Jacobians and residuals ---
    coords, valid, (Ji, Jj, Jz) = pops.projective_transform(
        poses, disps, intrinsics, ii, jj, jacobian=True)

    resid = (target - coords).view(batch, num_edges, -1, 1)
    wts = .001 * (valid * weight).view(batch, num_edges, -1, 1)

    # --- Step 2: per-edge pose blocks of the normal equations ---
    Ji = Ji.reshape(batch, num_edges, -1, dof)
    Jj = Jj.reshape(batch, num_edges, -1, dof)
    wJiT = (wts * Ji).transpose(2, 3)
    wJjT = (wts * Jj).transpose(2, 3)

    Hii = wJiT @ Ji
    Hij = wJiT @ Jj
    Hji = wJjT @ Ji
    Hjj = wJjT @ Jj

    vi = (wJiT @ resid).squeeze(-1)
    vj = (wJjT @ resid).squeeze(-1)

    # Only keyframe poses are optimized. After shifting, fixed poses map to
    # negative indices, which the safe scatter helpers drop.
    num_poses = num_frames // rig - fixedp
    pi = ii // rig - fixedp
    pj = jj // rig - fixedp

    H = (safe_scatter_add_mat(Hii, pi, pi, num_poses, num_poses)
         + safe_scatter_add_mat(Hij, pi, pj, num_poses, num_poses)
         + safe_scatter_add_mat(Hji, pj, pi, num_poses, num_poses)
         + safe_scatter_add_mat(Hjj, pj, pj, num_poses, num_poses))

    rhs_pose = (safe_scatter_add_vec(vi, pi, num_poses)
                + safe_scatter_add_vec(vj, pj, num_poses))

    H = H.view(batch, num_poses, num_poses, dof, dof)

    # --- Step 3: solve ---
    dx = block_solve(H, rhs_pose)

    # --- Step 4: retraction ---
    return pose_retr(poses, dx, torch.arange(num_poses) + fixedp)