ThunderVVV's picture
add thirdparty
b7eedf7
import torch
import torch.nn.functional as F
import geom.projective_ops as pops
class CholeskySolver(torch.autograd.Function):
    """Batched ``H x = b`` solve via a Cholesky factorisation, with a custom backward.

    ``H`` is expected as ``(..., n, n)`` and ``b`` as ``(..., n, k)``, the
    layouts accepted by ``torch.linalg.cholesky`` / ``torch.cholesky_solve``.

    Failure handling is deliberately best-effort. If anything in the forward
    solve raises (e.g. ``H`` is not positive definite), the exception is
    printed and a zero solution shaped like ``b`` is returned. The backward
    pass then yields no gradient for either input.
    """

    @staticmethod
    def forward(ctx, H, b):
        # Never let a failed factorisation abort training; fall back to x = 0.
        try:
            chol = torch.linalg.cholesky(H)      # lower-triangular L with H = L L^T
            sol = torch.cholesky_solve(b, chol)  # x = H^{-1} b using that factor
            ctx.save_for_backward(chol, sol)
        except Exception as err:
            print(err)
            ctx.failed = True
            return torch.zeros_like(b)
        ctx.failed = False
        return sol

    @staticmethod
    def backward(ctx, grad_x):
        # Nothing was saved on the fallback path, so propagate no gradient.
        if ctx.failed:
            return None, None
        chol, sol = ctx.saved_tensors
        # Reuse the forward factor: grad_b solves H grad_b = grad_x.
        grad_b = torch.cholesky_solve(grad_x, chol)
        # NOTE(review): this is -x @ grad_b^T, i.e. the transpose of the textbook
        # -grad_b @ x^T. The two have the same effect when H is constructed
        # symmetrically upstream; confirm before reusing for non-symmetric H.
        grad_H = -(sol @ grad_b.transpose(-1, -2))
        return grad_H, grad_b
def block_solve(H, b, ep=0.1, lm=0.0001):
    """Solve the damped block normal equations ``(H + damping) x = b``.

    Args:
        H: block matrix unpacked as ``(B, N, N, D, D)``. Block ``(i, j)`` is
            the ``D x D`` sub-matrix whose rows belong to variable ``i`` and
            whose columns belong to variable ``j``. It is assembled into a
            dense ``(B, N*D, N*D)`` matrix before solving.
        b: right-hand side with ``B * N * D`` elements, reshaped to
            ``(B, N*D, 1)``.
        ep: constant added to every diagonal entry of the assembled matrix.
        lm: each diagonal entry ``H_kk`` of the assembled matrix is further
            increased by ``lm * H_kk`` (Levenberg-Marquardt-style scaling).

    Returns:
        The solution reshaped to ``(B, N, D)``. It is all zeros if the Cholesky
        solve fails (see ``CholeskySolver``).
    """
    B, N, _, D, _ = H.shape
    # Assemble the dense matrix *before* damping so that only its true diagonal
    # is modified. Broadcasting a (D, D) identity over the (N, N) grid of blocks
    # would also add ep / lm*H to the diagonals of the off-diagonal blocks
    # (i != j). That solves H + ep * (1 1^T ⊗ I) instead of H + ep * I, and can
    # leave the system singular. This matches how schur_solve damps.
    H = H.permute(0, 1, 3, 2, 4).reshape(B, N * D, N * D)
    eye = torch.eye(N * D, device=H.device)
    H = H + (ep + lm * H) * eye
    b = b.reshape(B, N * D, 1)
    x = CholeskySolver.apply(H, b)
    return x.reshape(B, N, D)
def schur_solve(H, E, C, v, w, ep=0.1, lm=0.0001, sless=False):
    """Solve a bordered linear system by Schur-complement elimination.

    The system solved is::

        [ H_damped  E       ] [dx]   [v]
        [ E^T       diag(C) ] [dz] = [w]

    ``dz`` is eliminated using the elementwise reciprocal of ``C``. The reduced
    system ``(H_damped - E diag(1/C) E^T) dx = v - E diag(1/C) w`` is solved
    with ``CholeskySolver``, and ``dz`` is recovered by back-substitution.

    Args:
        H: ``(B, P, P, D, D)`` block matrix, assembled to ``(B, P*D, P*D)``.
        E: ``(B, P, M, D, HW)`` coupling tensor, assembled to ``(B, P*D, M*HW)``.
        C: diagonal of the second block, viewed as ``(B, M*HW, 1)``.
        v: first right-hand side, reshaped to ``(B, P*D, 1)``.
        w: second right-hand side, reshaped to ``(B, M*HW, 1)``.
        ep: constant added to every diagonal entry of the assembled ``H``.
        lm: each diagonal entry ``H_kk`` is further increased by ``lm * H_kk``.
        sless: if true, skip back-substitution and return only ``dx``.

    Returns:
        ``dx`` shaped ``(B, P, D)`` when ``sless`` is true. Otherwise the tuple
        ``(dx, dz)`` with ``dz`` shaped ``(B, M, HW)``.
    """
    B, P, M, D, HW = E.shape
    n_x = P * D
    n_z = M * HW

    # Dense layouts: row index = block_row * D + within-block row.
    H = H.permute(0, 1, 3, 2, 4).reshape(B, n_x, n_x)
    E = E.permute(0, 1, 3, 2, 4).reshape(B, n_x, n_z)
    Et = E.transpose(1, 2)
    Q = (1.0 / C).view(B, n_z, 1)  # inverse of the diagonal block, as a column

    # Damp only the diagonal of the assembled matrix.
    eye = torch.eye(n_x, device=H.device)
    H = H + (ep + lm * H) * eye

    w = w.reshape(B, n_z, 1)
    v = v.reshape(B, n_x, 1)

    # Reduced (Schur-complement) system for dx.
    schur = H - E @ (Q * Et)
    rhs = v - E @ (Q * w)
    dx = CholeskySolver.apply(schur, rhs)

    if sless:
        return dx.reshape(B, P, D)

    # Back-substitute for the eliminated block.
    dz = Q * (w - Et @ dx)
    return dx.reshape(B, P, D), dz.reshape(B, M, HW)