Spaces:
Running
Running
import torch | |
import torch.nn.functional as F | |
import geom.projective_ops as pops | |
class CholeskySolver(torch.autograd.Function):
    """Batched solve of ``H @ x = b`` through a Cholesky factorization.

    If the factorization fails (for example because ``H`` is not
    positive-definite), the error is printed and a zero solution is returned
    instead of raising, so a single bad system does not crash training. In
    that case no gradient flows back to ``H`` or ``b``.
    """

    @staticmethod
    def forward(ctx, H, b):
        """Return ``xs`` solving ``H @ xs = b``.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            H: batched matrix handed to ``torch.linalg.cholesky``; expected
                to be symmetric positive-definite.
            b: right-hand side(s), batched like ``H``.

        Returns:
            The solution ``xs``, or ``torch.zeros_like(b)`` if the
            factorization/solve raised a ``RuntimeError``.
        """
        # Don't crash training if the Cholesky decomposition fails. Only
        # RuntimeError is caught: torch's LinAlgError derives from it (older
        # versions raised RuntimeError directly), while programmer errors
        # such as TypeError should still surface.
        try:
            L = torch.linalg.cholesky(H)  # lower-triangular factor
            xs = torch.cholesky_solve(b, L)
        except RuntimeError as e:
            print(e)
            ctx.failed = True
            return torch.zeros_like(b)

        ctx.save_for_backward(L, xs)
        ctx.failed = False
        return xs

    @staticmethod
    def backward(ctx, grad_x):
        """Return gradients ``(dH, db)`` for ``x = H^{-1} b``.

        ``db = H^{-1} grad_x``. ``dH = -x db^T`` is the transpose of the
        general-matrix gradient ``-db x^T``. The two share the same symmetric
        part, which is what matters when ``H`` is symmetric.
        Returns ``(None, None)`` when ``forward`` fell back to zeros.
        """
        if ctx.failed:
            return None, None

        L, xs = ctx.saved_tensors
        dz = torch.cholesky_solve(grad_x, L)
        dH = -torch.matmul(xs, dz.transpose(-1, -2))
        return dH, dz
def block_solve(H, b, ep=0.1, lm=0.0001):
    """Solve the damped block normal equations ``H x = b``.

    Args:
        H: block matrix of shape (B, N, N, D, D); ``H[:, i, j]`` is the
            D x D block at block-row ``i``, block-column ``j``.
        b: right-hand side, reshaped to (B, N*D, 1), so it must hold N*D
            values per batch element (e.g. shape (B, N, D)).
        ep: constant added to every diagonal entry of the assembled system.
        lm: multiplicative damping; each diagonal entry ``h`` becomes
            ``h + ep + lm * h``.

    Returns:
        Solution of shape (B, N, D). It is all zeros if the Cholesky
        factorization fails.
    """
    B, N, _, D, _ = H.shape

    # Assemble the dense (N*D x N*D) system: row i*D+a, column j*D+c.
    H = H.permute(0, 1, 3, 2, 4).reshape(B, N * D, N * D)
    b = b.reshape(B, N * D, 1)

    # Damp only the true diagonal of the assembled matrix, as schur_solve
    # does. Broadcasting a D x D identity over the 5-D block tensor (the
    # previous behaviour) also added damping to the diagonals of the
    # off-diagonal blocks H[:, i, j] with i != j.
    I = torch.eye(N * D, device=H.device)
    H = H + (ep + lm * H) * I

    x = CholeskySolver.apply(H, b)
    return x.reshape(B, N, D)
def schur_solve(H, E, C, v, w, ep=0.1, lm=0.0001, sless=False):
    """Solve a damped bordered system by eliminating the ``C`` block (Schur complement).

    The system is

        [ H    E ] [dx]   [v]
        [ E^T  C ] [dz] = [w]

    with ``C`` treated as diagonal (it is inverted elementwise). Before the
    reduction, ``H`` is damped on its diagonal: ``h -> h + ep + lm * h``.

    Args:
        H: block matrix, presumably (B, P, P, D, D); flattened to
            (B, P*D, P*D) after swapping its middle axes.
        E: coupling tensor of shape (B, P, M, D, HW).
        C: diagonal of the second block; must be viewable as (B, M*HW, 1).
        v: first right-hand side, reshaped to (B, P*D, 1).
        w: second right-hand side, reshaped to (B, M*HW, 1).
        ep, lm: additive and multiplicative diagonal damping for ``H``.
        sless: if true, return only ``dx`` and skip back-substitution.

    Returns:
        ``dx`` of shape (B, P, D) when ``sless`` is set, otherwise the pair
        ``(dx, dz)`` with ``dz`` of shape (B, M, HW).
    """
    B, P, M, D, HW = E.shape
    n_pose = P * D
    n_point = M * HW

    # Flatten the block layouts into dense batched matrices.
    H_dense = H.permute(0, 1, 3, 2, 4).reshape(B, n_pose, n_pose)
    E_dense = E.permute(0, 1, 3, 2, 4).reshape(B, n_pose, n_point)
    E_dense_t = E_dense.transpose(1, 2)

    # C is diagonal, so its inverse is an elementwise reciprocal that is
    # applied by broadcasting as a per-row scale.
    C_inv = (1.0 / C).view(B, n_point, 1)

    # damping of the diagonal of H only
    eye = torch.eye(n_pose, device=H_dense.device)
    H_dense = H_dense + (ep + lm * H_dense) * eye

    rhs_pose = v.reshape(B, n_pose, 1)
    rhs_point = w.reshape(B, n_point, 1)

    # Reduced system: (H - E C^-1 E^T) dx = v - E C^-1 w
    schur = H_dense - E_dense @ (C_inv * E_dense_t)
    reduced_rhs = rhs_pose - E_dense @ (C_inv * rhs_point)
    dx = CholeskySolver.apply(schur, reduced_rhs)

    if sless:
        return dx.reshape(B, P, D)

    # Back-substitute: dz = C^-1 (w - E^T dx)
    dz = C_inv * (rhs_point - E_dense_t @ dx)
    return dx.reshape(B, P, D), dz.reshape(B, M, HW)