FrozenBurning
single view to 3D init release
81ecb2b
raw
history blame
7.4 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import time
import torch
import torch.nn as nn
from torch.autograd import Function
import torch.nn.functional as F
try:
from . import utilslib
except:
import utilslib
class ComputeRaydirs(Function):
    """Autograd wrapper around the ``utilslib.compute_raydirs_forward`` extension.

    Allocates per-pixel ``raypos``, ``raydirs`` and ``tminmax`` buffers and
    lets the extension fill them. No gradient is propagated through this op:
    ``backward`` returns ``None`` for every input.
    """

    @staticmethod
    def forward(ctx, viewpos, viewrot, focal, princpt, pixelcoords, volradius):
        """Run the extension kernel.

        Args:
            ctx: autograd context (unused).
            viewpos, viewrot, focal, princpt: contiguous tensors; only the
                leading dimension N of ``viewpos`` is read here.
            pixelcoords: either a contiguous tensor whose dims 1 and 2 give
                H and W, or a ``(W, H)`` tuple. In the tuple case ``None`` is
                passed to the kernel (presumably it then builds the pixel
                grid itself — TODO confirm against utilslib).
            volradius: forwarded to the kernel unchanged.

        Returns:
            ``(raypos, raydirs, tminmax)`` with shapes ``(N, H, W, 3)``,
            ``(N, H, W, 3)`` and ``(N, H, W, 2)`` on ``viewpos.device``.
        """
        tensors = [viewpos, viewrot, focal, princpt]
        if isinstance(pixelcoords, tuple):
            W, H = pixelcoords
            pixelcoords = None
        else:
            H = pixelcoords.size(1)
            W = pixelcoords.size(2)
            tensors.append(pixelcoords)
        # Only tensors have a memory layout; checking the tuple form would
        # raise AttributeError and make the tuple branch above unreachable.
        for tensor in tensors:
            assert tensor.is_contiguous()
        N = viewpos.size(0)
        raypos = torch.empty((N, H, W, 3), device=viewpos.device)
        raydirs = torch.empty((N, H, W, 3), device=viewpos.device)
        tminmax = torch.empty((N, H, W, 2), device=viewpos.device)
        utilslib.compute_raydirs_forward(viewpos, viewrot, focal, princpt,
                                         pixelcoords, W, H, volradius, raypos, raydirs, tminmax)
        return raypos, raydirs, tminmax

    @staticmethod
    def backward(ctx, grad_raypos, grad_raydirs, grad_tminmax):
        """Return no gradients.

        ``forward`` has three outputs, so autograd passes three incoming
        gradients. One ``None`` is returned per ``forward`` input (six).
        """
        return None, None, None, None, None, None
def compute_raydirs(viewpos, viewrot, focal, princpt, pixelcoords, volradius):
    """Functional entry point for the ``ComputeRaydirs`` autograd op.

    Forwards all arguments positionally to ``ComputeRaydirs.apply`` and
    returns its ``(raypos, raydirs, tminmax)`` triple.
    """
    op_args = (viewpos, viewrot, focal, princpt, pixelcoords, volradius)
    raypos, raydirs, tminmax = ComputeRaydirs.apply(*op_args)
    return raypos, raydirs, tminmax
class Rodrigues(nn.Module):
    """Convert batched axis-angle vectors into 3x3 rotation matrices.

    The norm of each input row is used as the rotation angle and its
    direction as the axis. ``1e-5`` is added under the square root so the
    angle stays non-zero and the normalisation is finite for zero vectors.
    """

    def __init__(self):
        super().__init__()

    def forward(self, rvec):
        """Return an ``(N, 3, 3)`` rotation tensor.

        Assumes ``rvec`` is ``(N, 3)`` (columns 0..2 are indexed and the norm
        is taken over dim 1) — TODO confirm for other ranks.
        """
        angle = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=1))
        axis = rvec / angle[:, None]
        kx = axis[:, 0]
        ky = axis[:, 1]
        kz = axis[:, 2]
        cos_a = torch.cos(angle)
        sin_a = torch.sin(angle)
        versine = 1. - cos_a  # shared (1 - cos) factor of the outer-product term

        # Row-major entries of cos*I + (1 - cos)*k k^T + sin*[k]_x
        entries = [
            kx ** 2 + (1. - kx ** 2) * cos_a,
            kx * ky * versine - kz * sin_a,
            kx * kz * versine + ky * sin_a,

            kx * ky * versine + kz * sin_a,
            ky ** 2 + (1. - ky ** 2) * cos_a,
            ky * kz * versine - kx * sin_a,

            kx * kz * versine - ky * sin_a,
            ky * kz * versine + kx * sin_a,
            kz ** 2 + (1. - kz ** 2) * cos_a,
        ]
        flat = torch.stack(entries, dim=1)
        return flat.view(-1, 3, 3)
def _as_grad_leaf(tensor):
    """Return a contiguous, detached copy of ``tensor`` that requires grad."""
    leaf = tensor.contiguous().detach().clone()
    leaf.requires_grad = True
    return leaf


def _zero_grads(params):
    """Zero every existing ``.grad`` in ``params`` in place; ``None`` grads are left as-is."""
    for p in params:
        if p.grad is not None:
            p.grad.detach_()
            p.grad.zero_()


def _reference_raydirs(viewpos, viewrot, focal, princpt, pixelcoords):
    """Pure-PyTorch reference for the ray kernel.

    Returns ``(raypos, raydir, tminmax)``:
    - ``raypos``: per-pixel copies of ``viewpos``.
    - ``raydir``: unit ray directions.
    - ``tminmax``: slab-test entry/exit distances against the box [-1, 1]^3,
      with ``tmin`` clamped at 0.
    """
    H, W = pixelcoords.shape[1:3]
    raypos = viewpos[:, None, None, :].repeat(1, H, W, 1)
    # Camera-space direction (x, y, 1) from pixel coordinates and intrinsics.
    raydir = (pixelcoords - princpt[:, None, None, :]) / focal[:, None, None, :]
    raydir = torch.cat([raydir, torch.ones_like(raydir[:, :, :, 0:1])], dim=-1)
    # out[j] = sum_i viewrot[i, j] * d[i], i.e. multiply by viewrot transposed.
    raydir = torch.sum(viewrot[:, None, None, :, :] * raydir[:, :, :, :, None], dim=-2)
    raydir = raydir / torch.sqrt(torch.sum(raydir ** 2, dim=-1, keepdim=True))
    t_lo = (-1. - viewpos[:, None, None, :]) / raydir
    t_hi = (1. - viewpos[:, None, None, :]) / raydir
    tmin = torch.max(torch.min(t_lo[..., 0], t_hi[..., 0]),
                     torch.max(torch.min(t_lo[..., 1], t_hi[..., 1]),
                               torch.min(t_lo[..., 2], t_hi[..., 2]))).clamp(min=0.)
    tmax = torch.min(torch.max(t_lo[..., 0], t_hi[..., 0]),
                     torch.min(torch.max(t_lo[..., 1], t_hi[..., 1]),
                               torch.max(t_lo[..., 2], t_hi[..., 2])))
    tminmax = torch.stack([tmin, tmax], dim=-1)
    return raypos, raydir, tminmax


def _report_diff(label, ref, test):
    """Print one comparison row: max |diff|, cosine similarity, argmax index, both values.

    If either side is ``None`` (e.g. a parameter that received no gradient),
    print an ``n/a`` row instead of failing on ``None - None``.
    """
    if ref is None or test is None:
        print("{:<10} {:>10}  (py: {}, cuda: {})".format(
            label, "n/a",
            "None" if ref is None else "set",
            "None" if test is None else "set"))
        return
    absdiff = torch.abs(ref - test)
    ind = torch.argmax(absdiff)
    print("{:<10} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format(
        label,
        torch.max(absdiff).item(),
        (torch.sum(ref * test) / torch.sqrt(torch.sum(ref * ref) * torch.sum(test * test))).item(),
        ind.item(),
        ref.reshape(-1)[ind].item(),
        test.reshape(-1)[ind].item()))


def gradcheck():
    """Compare the CUDA ``compute_raydirs`` kernel against a PyTorch reference.

    Builds N random cameras positioned near (0, 0, -4) with near-identity
    rotations. Ray directions are computed both with plain PyTorch ops and
    with the ``utilslib`` kernel. The function then prints:
    - the forward difference,
    - CUDA forward/backward timings,
    - per-parameter gradient differences.

    Requires a CUDA device.

    Note: ``ComputeRaydirs.backward`` returns ``None`` for every input, so
    the CUDA-side gradients stay zero (or ``None``). The gradient rows
    therefore measure the missing backward, not kernel accuracy.
    """
    N = 2
    H = 64
    W = 64
    volradius = 1.
    device = "cuda"

    # Generate random inputs. The order of the randn calls fixes the sampled values.
    torch.manual_seed(1113)
    rodrigues = Rodrigues()
    _viewpos = torch.tensor([[-0.0, 0.0, -4.] for _ in range(N)], device=device) \
        + torch.randn(N, 3, device=device) * 0.1
    viewrvec = torch.randn(N, 3, device=device) * 0.01
    _viewrot = rodrigues(viewrvec)
    _focal = torch.tensor([[W * 4.0, W * 4.0] for _ in range(N)], device=device)
    _princpt = torch.tensor([[W * 0.5, H * 0.5] for _ in range(N)], device=device)
    pixely, pixelx = torch.meshgrid(torch.arange(H, device=device).float(),
                                    torch.arange(W, device=device).float())
    _pixelcoords = torch.stack([pixelx, pixely], dim=-1)[None, :, :, :].repeat(N, 1, 1, 1)

    viewpos = _as_grad_leaf(_viewpos)
    viewrot = _as_grad_leaf(_viewrot)
    focal = _as_grad_leaf(_focal)
    princpt = _as_grad_leaf(_princpt)
    pixelcoords = _as_grad_leaf(_pixelcoords)

    params = [viewpos, viewrot, focal, princpt]
    paramnames = ["viewpos", "viewrot", "focal", "princpt"]

    ########################### run pytorch version ###########################
    sample0 = _reference_raydirs(viewpos, viewrot, focal, princpt, pixelcoords)[1]
    sample0.backward(torch.ones_like(sample0))
    # viewpos does not feed raydir, so its grad stays None here.
    grads0 = [p.grad.detach().clone() if p.grad is not None else None for p in params]
    _zero_grads(params)

    ############################## run cuda version ###########################
    niter = 1
    # CUDA launches are asynchronous: synchronize before each clock read so
    # the interval covers exactly the work being timed.
    torch.cuda.synchronize()
    t_fwd_start = time.time()
    sample1 = compute_raydirs(viewpos, viewrot, focal, princpt, pixelcoords, volradius)[1]
    torch.cuda.synchronize()
    t_fwd_end = time.time()

    print("-----------------------------------------------------------------")
    print("{:>10} {:>10} {:>10} {:>10} {:>10} {:>10}".format("", "maxabsdiff", "dp", "index", "py", "cuda"))
    _report_diff("fwd", sample0, sample1)

    torch.cuda.synchronize()
    t_bwd_start = time.time()
    sample1.backward(torch.ones_like(sample1), retain_graph=True)
    torch.cuda.synchronize()
    t_bwd_end = time.time()

    tf = t_fwd_end - t_fwd_start
    tb = t_bwd_end - t_bwd_start
    print("{:<10} {:10.5} {:10.5} {:10.5}".format("time", tf / niter, tb / niter, (tf + tb) / niter))
    grads1 = [p.grad.detach().clone() if p.grad is not None else None for p in params]

    ############# compare results #############
    for name, g0, g1 in zip(paramnames, grads0, grads1):
        _report_diff(name, g0, g1)
if __name__ == "__main__":
gradcheck()