import lietorch import torch import torch.nn.functional as F from .chol import block_solve, schur_solve import geom.projective_ops as pops from torch_scatter import scatter_sum # utility functions for scattering ops def safe_scatter_add_mat(A, ii, jj, n, m): v = (ii >= 0) & (jj >= 0) & (ii < n) & (jj < m) return scatter_sum(A[:,v], ii[v]*m + jj[v], dim=1, dim_size=n*m) def safe_scatter_add_vec(b, ii, n): v = (ii >= 0) & (ii < n) return scatter_sum(b[:,v], ii[v], dim=1, dim_size=n) # apply retraction operator to inv-depth maps def disp_retr(disps, dz, ii): ii = ii.to(device=dz.device) return disps + scatter_sum(dz, ii, dim=1, dim_size=disps.shape[1]) # apply retraction operator to poses def pose_retr(poses, dx, ii): ii = ii.to(device=dx.device) return poses.retr(scatter_sum(dx, ii, dim=1, dim_size=poses.shape[1])) def BA(target, weight, eta, poses, disps, intrinsics, ii, jj, fixedp=1, rig=1): """ Full Bundle Adjustment """ B, P, ht, wd = disps.shape N = ii.shape[0] D = poses.manifold_dim ### 1: commpute jacobians and residuals ### coords, valid, (Ji, Jj, Jz) = pops.projective_transform( poses, disps, intrinsics, ii, jj, jacobian=True) r = (target - coords).view(B, N, -1, 1) w = .001 * (valid * weight).view(B, N, -1, 1) ### 2: construct linear system ### Ji = Ji.reshape(B, N, -1, D) Jj = Jj.reshape(B, N, -1, D) wJiT = (w * Ji).transpose(2,3) wJjT = (w * Jj).transpose(2,3) Jz = Jz.reshape(B, N, ht*wd, -1) Hii = torch.matmul(wJiT, Ji) Hij = torch.matmul(wJiT, Jj) Hji = torch.matmul(wJjT, Ji) Hjj = torch.matmul(wJjT, Jj) vi = torch.matmul(wJiT, r).squeeze(-1) vj = torch.matmul(wJjT, r).squeeze(-1) Ei = (wJiT.view(B,N,D,ht*wd,-1) * Jz[:,:,None]).sum(dim=-1) Ej = (wJjT.view(B,N,D,ht*wd,-1) * Jz[:,:,None]).sum(dim=-1) w = w.view(B, N, ht*wd, -1) r = r.view(B, N, ht*wd, -1) wk = torch.sum(w*r*Jz, dim=-1) Ck = torch.sum(w*Jz*Jz, dim=-1) kx, kk = torch.unique(ii, return_inverse=True) M = kx.shape[0] # only optimize keyframe poses P = P // rig - fixedp ii = ii // rig - fixedp jj = jj // rig - fixedp H = safe_scatter_add_mat(Hii, ii, ii, P, P) + \ safe_scatter_add_mat(Hij, ii, jj, P, P) + \ safe_scatter_add_mat(Hji, jj, ii, P, P) + \ safe_scatter_add_mat(Hjj, jj, jj, P, P) E = safe_scatter_add_mat(Ei, ii, kk, P, M) + \ safe_scatter_add_mat(Ej, jj, kk, P, M) v = safe_scatter_add_vec(vi, ii, P) + \ safe_scatter_add_vec(vj, jj, P) C = safe_scatter_add_vec(Ck, kk, M) w = safe_scatter_add_vec(wk, kk, M) C = C + eta.view(*C.shape) + 1e-7 H = H.view(B, P, P, D, D) E = E.view(B, P, M, D, ht*wd) ### 3: solve the system ### dx, dz = schur_solve(H, E, C, v, w) ### 4: apply retraction ### poses = pose_retr(poses, dx, torch.arange(P) + fixedp) disps = disp_retr(disps, dz.view(B,-1,ht,wd), kx) disps = torch.where(disps > 10, torch.zeros_like(disps), disps) disps = disps.clamp(min=0.0) return poses, disps def MoBA(target, weight, eta, poses, disps, intrinsics, ii, jj, fixedp=1, rig=1): """ Motion only bundle adjustment """ B, P, ht, wd = disps.shape N = ii.shape[0] D = poses.manifold_dim ### 1: commpute jacobians and residuals ### coords, valid, (Ji, Jj, Jz) = pops.projective_transform( poses, disps, intrinsics, ii, jj, jacobian=True) r = (target - coords).view(B, N, -1, 1) w = .001 * (valid * weight).view(B, N, -1, 1) ### 2: construct linear system ### Ji = Ji.reshape(B, N, -1, D) Jj = Jj.reshape(B, N, -1, D) wJiT = (w * Ji).transpose(2,3) wJjT = (w * Jj).transpose(2,3) Hii = torch.matmul(wJiT, Ji) Hij = torch.matmul(wJiT, Jj) Hji = torch.matmul(wJjT, Ji) Hjj = torch.matmul(wJjT, Jj) vi = torch.matmul(wJiT, r).squeeze(-1) vj = torch.matmul(wJjT, r).squeeze(-1) # only optimize keyframe poses P = P // rig - fixedp ii = ii // rig - fixedp jj = jj // rig - fixedp H = safe_scatter_add_mat(Hii, ii, ii, P, P) + \ safe_scatter_add_mat(Hij, ii, jj, P, P) + \ safe_scatter_add_mat(Hji, jj, ii, P, P) + \ safe_scatter_add_mat(Hjj, jj, jj, P, P) v = safe_scatter_add_vec(vi, ii, P) + \ safe_scatter_add_vec(vj, jj, P) H = H.view(B, P, P, D, D) ### 3: solve the system ### dx = block_solve(H, v) ### 4: apply retraction ### poses = pose_retr(poses, dx, torch.arange(P) + fixedp) return poses