# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Raymarching of a Mixture of Volumetric Primitives (MVP).

This module wraps the custom CUDA extension ``mvpraymarchlib`` in a
``torch.autograd.Function`` (:class:`MVPRaymarch`) plus a user-facing entry
point (:func:`mvpraymarch`).  It also contains a pure-PyTorch reference
raymarcher inside :func:`gradcheck`, used to compare forward results and
gradients against the CUDA kernels.

NOTE(review): the exact semantics of the ``mvpraymarchlib`` calls
(``compute_morton``, ``build_tree``, ``compute_aabb``,
``raymarch_forward``/``raymarch_backward``) are defined by the CUDA
extension and cannot be verified from this file alone; argument order is
assumed to match the extension's signatures.
"""
import numpy as np
import time
import torch
import torch.nn as nn
from torch.autograd import Function
import torch.nn.functional as F

# Import the compiled CUDA extension, either as a package-relative module or
# from the working directory (e.g. when running this file as a script).
# NOTE(review): bare ``except:`` also swallows KeyboardInterrupt/SystemExit;
# ``except ImportError:`` would be safer — left unchanged here.
try:
    from . import mvpraymarchlib
except:
    import mvpraymarchlib


def build_accel(primtransfin, algo, fixedorder=False):
    """build bvh structure given primitive centers and sizes

    Parameters:
    ----------
    primtransfin : tuple[tensor, tensor, tensor]
        primitive transform tensors (primpos [N, K, 3], primrot [N, K, 3, 3],
        primscale [N, K, 3])
    algo : int
        raymarching algorithm
    fixedorder : optional[str]
        True means the bvh builder will not reorder primitives and will use a
        trivial tree structure. Likely to be slow for arbitrary
        configurations of primitives.

    Returns:
    -------
    sortedobjid : int32 tensor [N, K]
        primitive indices in BVH traversal order
    nodechildren : int32 tensor [N, 2K-1, 2]
        child indices per tree node (negative values appear to encode leaf
        references -- TODO confirm against the CUDA kernel)
    nodeaabb : float32 tensor [N, 2K-1, 2, 3]
        min/max corners of the axis-aligned bounding box of each node
    """
    primpos, primrot, primscale = primtransfin

    N = primpos.size(0)
    K = primpos.size(1)
    dev = primpos.device

    # compute and sort morton codes
    if fixedorder:
        # Identity ordering: primitive k stays at slot k for every batch item.
        sortedobjid = (torch.arange(N*K, dtype=torch.int32, device=dev) % K).view(N, K)
    else:
        # Normalize primitive centers into [0, 1]^3 per batch item before
        # quantizing to Morton codes; clamp guards against a degenerate
        # (zero-extent) bounding box.
        cmax = primpos.max(dim=1, keepdim=True)[0]
        cmin = primpos.min(dim=1, keepdim=True)[0]
        centers_norm = (primpos - cmin) / (cmax - cmin).clamp(min=1e-8)
        mortoncode = torch.empty((N, K), dtype=torch.int32, device=dev)
        mvpraymarchlib.compute_morton(centers_norm, mortoncode, algo)
        sortedcode, sortedobjid_long = torch.sort(mortoncode, dim=-1)
        sortedobjid = sortedobjid_long.int()

    if fixedorder:
        # Trivial complete binary tree laid out analytically: the first K-1
        # entries are internal nodes, negative child values reference leaves.
        nodechildren = torch.cat([
            torch.arange(1, (K - 1) * 2 + 1, dtype=torch.int32, device=dev),
            torch.div(torch.arange(-2, -(K * 2 + 1) - 1, -1, dtype=torch.int32, device=dev),
                2, rounding_mode="floor")],
            dim=0).view(1, K + K - 1, 2).repeat(N, 1, 1)
        nodeparent = (
            torch.div(torch.arange(-1, K * 2 - 2, dtype=torch.int32, device=dev),
                2, rounding_mode="floor")
            .view(1, -1).repeat(N, 1))
    else:
        # Tree topology is built on the GPU from the sorted Morton codes.
        nodechildren = torch.empty((N, K + K - 1, 2), dtype=torch.int32, device=dev)
        nodeparent = torch.full((N, K + K - 1), -1, dtype=torch.int32, device=dev)
        mvpraymarchlib.build_tree(sortedcode, nodechildren, nodeparent)

    # Per-node AABBs computed bottom-up from the primitive transforms.
    nodeaabb = torch.empty((N, K + K - 1, 2, 3), dtype=torch.float32, device=dev)
    mvpraymarchlib.compute_aabb(*primtransfin, sortedobjid, nodechildren,
        nodeparent, nodeaabb, algo)

    return sortedobjid, nodechildren, nodeaabb


class MVPRaymarch(Function):
    """Custom Function for raymarching Mixture of Volumetric Primitives."""
    @staticmethod
    def forward(self, raypos, raydir, stepsize, tminmax,
            primpos, primrot, primscale,
            template, warp, rayterm,
            gradmode, options):
        """March rays through the primitives with the CUDA kernel.

        ``self`` is the autograd context (named ``self`` by this codebase's
        convention).  ``gradmode`` selects whether saturation buffers needed
        by the backward pass are allocated.  Returns ``rayrgba`` of shape
        [N, H, W, 4].

        NOTE(review): the incoming ``rayterm`` argument is unconditionally
        overwritten with ``None`` below, so the value passed by callers is
        never used — confirm whether this is intentional.
        """
        algo = options["algo"]
        usebvh = options["usebvh"]
        sortprims = options["sortprims"]
        randomorder = options["randomorder"]
        maxhitboxes = options["maxhitboxes"]
        synchitboxes = options["synchitboxes"]
        chlast = options["chlast"]
        fadescale = options["fadescale"]
        fadeexp = options["fadeexp"]
        accum = options["accum"]
        termthresh = options["termthresh"]
        griddim = options["griddim"]
        # blocksize may be given as (x, y) or as a single integer (1D grid).
        if isinstance(options["blocksize"], tuple):
            blocksizex, blocksizey = options["blocksize"]
        else:
            blocksizex = options["blocksize"]
            blocksizey = 1

        # Validate layouts before handing raw pointers to the CUDA kernel.
        assert raypos.is_contiguous() and raypos.size(3) == 3
        assert raydir.is_contiguous() and raydir.size(3) == 3
        assert tminmax.is_contiguous() and tminmax.size(3) == 2
        assert primpos is None or primpos.is_contiguous() and primpos.size(2) == 3
        assert primrot is None or primrot.is_contiguous() and primrot.size(2) == 3
        assert primscale is None or primscale.is_contiguous() and primscale.size(2) == 3
        if chlast:
            # channels-last layout: [N, K, TD, TH, TW, C]
            assert template.is_contiguous() and len(template.size()) == 6 and template.size(-1) == 4
            assert warp is None or (warp.is_contiguous() and warp.size(-1) == 3)
        else:
            # channels-first layout: [N, K, C, TD, TH, TW]
            assert template.is_contiguous() and len(template.size()) == 6 and template.size(2) == 4
            assert warp is None or (warp.is_contiguous() and warp.size(2) == 3)

        primtransfin = (primpos, primrot, primscale)

        # Build bvh
        if usebvh is not False:
            # compute radius of primitives
            sortedobjid, nodechildren, nodeaabb = build_accel(
                primtransfin, algo, fixedorder=usebvh=="fixedorder")
            assert sortedobjid.is_contiguous()
            assert nodechildren.is_contiguous()
            assert nodeaabb.is_contiguous()
            if randomorder:
                # NOTE(review): randperm(len(sortedobjid)) permutes the batch
                # dimension (N), not the per-batch primitive order — confirm
                # this is the intended shuffle.
                sortedobjid = sortedobjid[torch.randperm(len(sortedobjid))]
        else:
            _, sortedobjid, nodechildren, nodeaabb = None, None, None, None

        # march through boxes
        N, H, W = raypos.size(0), raypos.size(1), raypos.size(2)
        rayrgba = torch.empty((N, H, W, 4), device=raypos.device)
        if gradmode:
            # raysat records, per ray, data needed by the backward pass;
            # initialized to -1 (presumably "not saturated" -- TODO confirm
            # against the CUDA kernel).
            raysat = torch.full((N, H, W, 3), -1, dtype=torch.float32, device=raypos.device)
            rayterm = None
        else:
            raysat = None
            rayterm = None

        mvpraymarchlib.raymarch_forward(
            raypos, raydir, stepsize, tminmax,
            sortedobjid, nodechildren, nodeaabb,
            *primtransfin,
            template, warp,
            rayrgba, raysat, rayterm,
            algo, sortprims, maxhitboxes, synchitboxes, chlast,
            fadescale, fadeexp, accum, termthresh,
            griddim, blocksizex, blocksizey)

        # Order here must match the unpacking in backward().
        self.save_for_backward(
            raypos, raydir, tminmax,
            sortedobjid, nodechildren, nodeaabb,
            primpos, primrot, primscale,
            template, warp,
            rayrgba, raysat, rayterm)
        self.options = options
        self.stepsize = stepsize

        return rayrgba

    @staticmethod
    def backward(self, grad_rayrgba):
        """Backward pass; dispatches to the CUDA gradient kernel.

        Returns gradients positionally matching forward's 12 inputs; only
        primpos/primrot/primscale/template/warp receive gradients.
        """
        (raypos, raydir, tminmax,
            sortedobjid, nodechildren, nodeaabb,
            primpos, primrot, primscale,
            template, warp,
            rayrgba, raysat, rayterm) = self.saved_tensors
        algo = self.options["algo"]
        usebvh = self.options["usebvh"]
        sortprims = self.options["sortprims"]
        maxhitboxes = self.options["maxhitboxes"]
        synchitboxes = self.options["synchitboxes"]
        chlast = self.options["chlast"]
        fadescale = self.options["fadescale"]
        fadeexp = self.options["fadeexp"]
        accum = self.options["accum"]
        termthresh = self.options["termthresh"]
        griddim = self.options["griddim"]
        # Backward pass may use a different (bwd) blocksize than forward.
        if isinstance(self.options["bwdblocksize"], tuple):
            blocksizex, blocksizey = self.options["bwdblocksize"]
        else:
            blocksizex = self.options["bwdblocksize"]
            blocksizey = 1
        stepsize = self.stepsize

        # Gradient accumulators, zero-initialized; the kernel adds into them.
        grad_primpos = torch.zeros_like(primpos)
        grad_primrot = torch.zeros_like(primrot)
        grad_primscale = torch.zeros_like(primscale)
        # Interleaved (tensor, grad) pairs, as expected by raymarch_backward.
        primtransfin = (primpos, grad_primpos, primrot, grad_primrot, primscale, grad_primscale)

        grad_template = torch.zeros_like(template)
        grad_warp = torch.zeros_like(warp) if warp is not None else None

        mvpraymarchlib.raymarch_backward(raypos, raydir, stepsize, tminmax,
            sortedobjid, nodechildren, nodeaabb,
            *primtransfin,
            template, grad_template, warp, grad_warp,
            rayrgba, grad_rayrgba.contiguous(), raysat, rayterm,
            algo, sortprims, maxhitboxes, synchitboxes, chlast,
            fadescale, fadeexp, accum, termthresh,
            griddim, blocksizex, blocksizey)

        # One entry per forward input:
        # (raypos, raydir, stepsize, tminmax) -> None,
        # (primpos, primrot, primscale, template, warp) -> grads,
        # (rayterm, gradmode, options) -> None.
        return (None, None, None, None,
            grad_primpos, grad_primrot, grad_primscale,
            grad_template, grad_warp,
            None, None, None)


def mvpraymarch(raypos, raydir, stepsize, tminmax, primtransf,
        template, warp, rayterm=None,
        algo=0, usebvh="fixedorder", sortprims=False, randomorder=False,
        maxhitboxes=512, synchitboxes=True,
        chlast=True, fadescale=8., fadeexp=8.,
        accum=0, termthresh=0.,
        griddim=3, blocksize=(8, 16), bwdblocksize=(8, 16)):
    """Main entry point for raymarching MVP.

    Parameters:
    ----------
    raypos: N x H x W x 3 tensor of ray origins
    raydir: N x H x W x 3 tensor of ray directions
    stepsize: raymarching step size
    tminmax: N x H x W x 2 tensor of raymarching min/max bounds
    primtransf: either a tuple (primpos, primrot, primscale) of the three
        tensors described below, or a single packed N x K x 5 x 3 tensor
        whose slices [:, :, 0], [:, :, 1:4], [:, :, 4] are primpos, primrot,
        and primscale respectively
    template: N x K x 4 x TD x TH x TW tensor of K RGBA primitives
    warp: N x K x 3 x TD x TH x TW tensor of K warp fields (optional)
    primpos: N x K x 3 tensor of primitive centers
    primrot: N x K x 3 x 3 tensor of primitive orientations
    primscale: N x K x 3 tensor of primitive inverse dimension lengths
    algo: algorithm for raymarching (valid values: 0, 1). algo=0 is the
        fastest. Currently algo=0 has a limit of 512 primitives per ray, so
        problems can occur if there are many more boxes. all sortprims=True
        options have this limitation, but you can use (algo=1,
        sortprims=False, usebvh="fixedorder") which works correctly and has
        no primitive number limitation (but is slightly slower).
    usebvh: True to use bvh, "fixedorder" for a simple BVH, False for no bvh
    sortprims: True to sort overlapping primitives at a sample point. Must
        be True for gradients to match the PyTorch gradients. Seems unstable
        if False but also not a big performance bottleneck.
    chlast: whether template is provided as channels last or not. True
        tends to be faster.
    fadescale: Opacity is faded at the borders of the primitives by the
        equation exp(-fadescale * x ** fadeexp) where x is the normalized
        coordinates of the primitive.
    fadeexp: Opacity is faded at the borders of the primitives by the
        equation exp(-fadescale * x ** fadeexp) where x is the normalized
        coordinates of the primitive.
    griddim: CUDA grid dimensionality.
    blocksize: blocksize of CUDA kernels. Should be 2-element tuple if
        griddim>1, or integer if griddim==1."""
    # Accept either the unpacked tuple or the packed [N, K, 5, 3] layout.
    if isinstance(primtransf, tuple):
        primpos, primrot, primscale = primtransf
    else:
        primpos, primrot, primscale = (
            primtransf[:, :, 0, :].contiguous(),
            primtransf[:, :, 1:4, :].contiguous(),
            primtransf[:, :, 4, :].contiguous())
    primtransfin = (primpos, primrot, primscale)

    out = MVPRaymarch.apply(raypos, raydir, stepsize, tminmax,
        *primtransfin,
        template, warp, rayterm, torch.is_grad_enabled(),
        {"algo": algo, "usebvh": usebvh, "sortprims": sortprims,
        "randomorder": randomorder,
        "maxhitboxes": maxhitboxes, "synchitboxes": synchitboxes,
        "chlast": chlast, "fadescale": fadescale, "fadeexp": fadeexp,
        "accum": accum, "termthresh": termthresh,
        "griddim": griddim, "blocksize": blocksize,
        "bwdblocksize": bwdblocksize})
    return out


class Rodrigues(nn.Module):
    """Converts axis-angle vectors [B, 3] to rotation matrices [B, 3, 3].

    Implements the Rodrigues rotation formula; the 1e-5 inside the sqrt
    keeps the normalization differentiable at the zero vector.
    """
    def __init__(self):
        super(Rodrigues, self).__init__()

    def forward(self, rvec):
        theta = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=1))
        rvec = rvec / theta[:, None]
        costh = torch.cos(theta)
        sinth = torch.sin(theta)
        # Rows of the rotation matrix, stacked then reshaped to [B, 3, 3].
        return torch.stack((
            rvec[:, 0] ** 2 + (1. - rvec[:, 0] ** 2) * costh,
            rvec[:, 0] * rvec[:, 1] * (1. - costh) - rvec[:, 2] * sinth,
            rvec[:, 0] * rvec[:, 2] * (1. - costh) + rvec[:, 1] * sinth,

            rvec[:, 0] * rvec[:, 1] * (1. - costh) + rvec[:, 2] * sinth,
            rvec[:, 1] ** 2 + (1. - rvec[:, 1] ** 2) * costh,
            rvec[:, 1] * rvec[:, 2] * (1. - costh) - rvec[:, 0] * sinth,

            rvec[:, 0] * rvec[:, 2] * (1. - costh) - rvec[:, 1] * sinth,
            rvec[:, 1] * rvec[:, 2] * (1. - costh) + rvec[:, 0] * sinth,
            rvec[:, 2] ** 2 + (1. - rvec[:, 2] ** 2) * costh), dim=1).view(-1, 3, 3)


def gradcheck(usebvh=True, sortprims=True, maxhitboxes=512, synchitboxes=False,
        dowarp=False, chlast=False, fadescale=8., fadeexp=8.,
        accum=0, termthresh=0., algo=0, griddim=2,
        blocksize=(8, 16), bwdblocksize=(8, 16)):
    """Compare the CUDA raymarcher against a pure-PyTorch reference.

    Generates random-but-coherent rays and primitives, runs a Python
    reference raymarching loop and the CUDA kernel with the same inputs,
    then prints forward outputs, gradients, per-parameter max abs
    differences, dot products, and norms.  Requires a CUDA device and the
    compiled extension; prints results rather than asserting.
    """
    N = 2
    H = 65
    W = 65
    k3 = 4
    K = k3*k3*k3   # number of primitives (a k3^3 grid)
    M = 32         # template voxel resolution per axis

    print("=================================================================")
    print("usebvh={}, sortprims={}, maxhb={}, synchb={}, dowarp={}, chlast={}, "
            "fadescale={}, fadeexp={}, accum={}, termthresh={}, algo={}, griddim={}, "
            "blocksize={}, bwdblocksize={}".format(
                usebvh, sortprims, maxhitboxes, synchitboxes, dowarp, chlast,
                fadescale, fadeexp, accum, termthresh, algo, griddim,
                blocksize, bwdblocksize))

    # generate random inputs
    torch.manual_seed(1112)

    coherent_rays = True
    if not coherent_rays:
        # Fully random rays (disabled branch kept for debugging).
        _raypos = torch.randn(N, H, W, 3).to("cuda")
        _raydir = torch.randn(N, H, W, 3).to("cuda")
        _raydir /= torch.sqrt(torch.sum(_raydir ** 2, dim=-1, keepdim=True))
    else:
        # Camera-like rays through a pinhole model (focal, principal point).
        focal = torch.tensor([[W*4.0, W*4.0] for n in range(N)])
        princpt = torch.tensor([[W*0.5, H*0.5] for n in range(N)])
        pixely, pixelx = torch.meshgrid(torch.arange(H).float(), torch.arange(W).float())
        pixelcoords = torch.stack([pixelx, pixely], dim=-1)[None, :, :, :].repeat(N, 1, 1, 1)

        raydir = (pixelcoords - princpt[:, None, None, :]) / focal[:, None, None, :]
        raydir = torch.cat([raydir, torch.ones_like(raydir[:, :, :, 0:1])], dim=-1)
        raydir = raydir / torch.sqrt(torch.sum(raydir ** 2, dim=-1, keepdim=True))

        # All rays share one origin, placed in front of the primitive cloud.
        _raypos = torch.tensor([-0.0, 0.0, -4.])[None, None, None, :].repeat(N, H, W, 1).to("cuda")
        _raydir = raydir.to("cuda")
        _raydir /= torch.sqrt(torch.sum(_raydir ** 2, dim=-1, keepdim=True))

    max_len = 6.0
    # Non-integer step count avoids sampling exactly on voxel boundaries.
    _stepsize = max_len / 15.386928
    # tmin in [0, 1), tmax in [max_len, max_len + 1): jittered per ray.
    _tminmax = max_len*torch.arange(2, dtype=torch.float32)[None, None, None, :].repeat(N, H, W, 1).to("cuda") + \
            torch.rand(N, H, W, 2, device="cuda") * 1.

    # RGBA templates; alpha channel shifted down so most space is empty.
    _template = torch.randn(N, K, 4, M, M, M, requires_grad=True)
    _template.data[:, :, -1, :, :, :] -= 3.5
    _template = _template.contiguous().detach().clone()
    _template.requires_grad = True

    # Identity-ish warp field: a regular grid plus small noise.  [::-1]
    # reverses meshgrid outputs to get xyz channel order.
    gridxyz = torch.stack(torch.meshgrid(
        torch.linspace(-1., 1., M//2),
        torch.linspace(-1., 1., M//2),
        torch.linspace(-1., 1., M//2))[::-1], dim=0).contiguous()

    _warp = (torch.randn(N, K, 3, M//2, M//2, M//2) * 0.01 +
            gridxyz[None, None, :, :, :, :]).contiguous().detach().clone()
    _warp.requires_grad = True

    # NOTE(review): duplicated statement — the second assignment immediately
    # discards the first tensor; harmless but likely unintended.
    _primpos = torch.randn(N, K, 3, requires_grad=True)
    _primpos = torch.randn(N, K, 3, requires_grad=True)

    coherent_centers = True
    if coherent_centers:
        # Place primitive centers on a jittered regular 3D grid instead of
        # fully at random.
        ns = k3
        #assert ns*ns*ns==K
        grid3d = torch.stack(torch.meshgrid(
            torch.linspace(-1., 1., ns),
            torch.linspace(-1., 1., ns),
            torch.linspace(-1., 1., K//(ns*ns)))[::-1], dim=0)[None]
        _primpos = ((
            grid3d.permute((0, 2, 3, 4, 1)).reshape(1, K, 3).expand(N, -1, -1) +
            0.1 * torch.randn(N, K, 3, requires_grad=True)
            )).contiguous().detach().clone()
        _primpos.requires_grad = True

    scale_ws = 1.

    # Random rotations via Rodrigues on random axis-angle vectors.
    _primrot = torch.randn(N, K, 3)
    rodrigues = Rodrigues()
    _primrot = rodrigues(_primrot.view(-1, 3)).view(N, K, 3, 3).contiguous().detach().clone()
    _primrot.requires_grad = True

    # Scales start at exactly 1 (exp(0.1 * 0)); still a leaf with grads.
    _primscale = torch.randn(N, K, 3, requires_grad=True)
    _primscale.data *= 0.0

    if dowarp:
        params = [_template, _warp, _primscale, _primrot, _primpos]
        paramnames = ["template", "warp", "primscale", "primrot", "primpos"]
    else:
        params = [_template, _primscale, _primrot, _primpos]
        paramnames = ["template", "primscale", "primrot", "primpos"]

    termthreshorig = termthresh

    ########################### run pytorch version ###########################

    raypos = _raypos
    raydir = _raydir
    stepsize = _stepsize
    tminmax = _tminmax

    #template = F.softplus(_template.to("cuda") * 1.5)
    # algo==2 appears to use the raw (non-softplus) template -- TODO confirm.
    template = F.softplus(_template.to("cuda") * 1.5) if algo != 2 else _template.to("cuda") * 1.5
    warp = _warp.to("cuda")
    primpos = _primpos.to("cuda") * 0.3
    primrot = _primrot.to("cuda")
    primscale = scale_ws * torch.exp(0.1 * _primscale.to("cuda"))

    # python raymarching implementation
    rayrgba = torch.zeros((N, H, W, 4)).to("cuda")
    raypos = raypos + raydir * tminmax[:, :, :, 0, None]
    t = tminmax[:, :, :, 0]

    step = 0
    t0 = t.detach().clone()
    raypos0 = raypos.detach().clone()

    torch.cuda.synchronize()
    time0 = time.time()

    # Fixed-step march: every ray advances in lockstep until all rays pass
    # their tmax; out-of-range samples are masked by `valid` below.
    while (t < tminmax[:, :, :, 1]).any():
        valid2 = torch.ones_like(rayrgba[:, :, :, 3:4])
        for k in range(K):
            # Transform sample points into primitive k's local coordinates.
            y0 = torch.bmm(
                (raypos - primpos[:, k, None, None, :]).view(raypos.size(0), -1, raypos.size(3)),
                primrot[:, k, :, :]).view_as(raypos) * primscale[:, k, None, None, :]
            # Border fade matching the documented exp(-fadescale * x**fadeexp).
            fade = torch.exp(-fadescale * torch.sum(torch.abs(y0) ** fadeexp, dim=-1, keepdim=True))
            if dowarp:
                # Indirect sampling: warp field gives template coordinates.
                y1 = F.grid_sample(
                    warp[:, k, :, :, :, :],
                    y0[:, None, :, :, :], align_corners=True)[:, :, 0, :, :].permute(0, 2, 3, 1)
            else:
                y1 = y0

            sample = F.grid_sample(
                template[:, k, :, :, :, :],
                y1[:, None, :, :, :], align_corners=True)[:, :, 0, :, :].permute(0, 2, 3, 1)

            # valid1: sample point inside the primitive's [-1, 1]^3 box.
            valid1 = (
                torch.prod(y0[:, :, :, :] >= -1., dim=-1, keepdim=True) *
                torch.prod(y0[:, :, :, :] <= 1., dim=-1, keepdim=True))
            # valid: ray parameter t still within this ray's [tmin, tmax).
            valid = ((t >= tminmax[:, :, :, 0]) & (t < tminmax[:, :, :, 1])).float()[:, :, :, None]

            alpha0 = sample[:, :, :, 3:4]

            rgb = sample[:, :, :, 0:3] * valid * valid1
            alpha = alpha0 * fade * stepsize * valid * valid1

            if accum == 0:
                # Additive accumulation with total alpha clamped at 1;
                # contrib is the portion of alpha that still fits.
                newalpha = rayrgba[:, :, :, 3:4] + alpha
                contrib = (newalpha.clamp(max=1.0) - rayrgba[:, :, :, 3:4]) * valid * valid1
                rayrgba = rayrgba + contrib * torch.cat([rgb, torch.ones_like(alpha)], dim=-1)
            else:
                # Only accum==0 is implemented in this reference.
                raise

        step += 1
        t = t0 + stepsize * step
        raypos = raypos0 + raydir * stepsize * step

    print(rayrgba[..., -1].min().item(), rayrgba[..., -1].max().item())

    sample0 = rayrgba

    torch.cuda.synchronize()
    time1 = time.time()

    sample0.backward(torch.ones_like(sample0))

    torch.cuda.synchronize()
    time2 = time.time()

    print("{:<10} {:>10} {:>10} {:>10}".format("", "fwd", "bwd", "total"))
    print("{:<10} {:10.5} {:10.5} {:10.5}".format("pytime", time1 - time0, time2 - time1, time2 - time0))

    # Snapshot reference gradients, then zero them for the CUDA run.
    grads0 = [p.grad.detach().clone() for p in params]

    for p in params:
        p.grad.detach_()
        p.grad.zero_()

    ############################## run cuda version ###########################

    raypos = _raypos
    raydir = _raydir
    stepsize = _stepsize
    tminmax = _tminmax

    template = F.softplus(_template.to("cuda") * 1.5) if algo != 2 else _template.to("cuda") * 1.5
    warp = _warp.to("cuda")
    if chlast:
        # Kernel expects channels-last voxel layout in this mode.
        template = template.permute(0, 1, 3, 4, 5, 2).contiguous()
        warp = warp.permute(0, 1, 3, 4, 5, 2).contiguous()

    primpos = _primpos.to("cuda") * 0.3
    primrot = _primrot.to("cuda")
    primscale = scale_ws * torch.exp(0.1 * _primscale.to("cuda"))

    niter = 1

    tf, tb = 0., 0.

    for i in range(niter):
        for p in params:
            # Grads may be unset on the first iteration.
            # NOTE(review): bare ``except:`` — ``except AttributeError:``
            # would be more precise; left unchanged here.
            try:
                p.grad.detach_()
                p.grad.zero_()
            except:
                pass

        t0 = time.time()
        torch.cuda.synchronize()

        sample1 = mvpraymarch(raypos, raydir, stepsize, tminmax,
            (primpos, primrot, primscale),
            template, warp if dowarp else None,
            algo=algo, usebvh=usebvh, sortprims=sortprims,
            maxhitboxes=maxhitboxes, synchitboxes=synchitboxes,
            chlast=chlast, fadescale=fadescale, fadeexp=fadeexp,
            accum=accum, termthresh=termthreshorig,
            griddim=griddim, blocksize=blocksize, bwdblocksize=bwdblocksize)

        t1 = time.time()
        torch.cuda.synchronize()

        sample1.backward(torch.ones_like(sample1), retain_graph=True)

        torch.cuda.synchronize()
        t2 = time.time()

        tf += t1 - t0
        tb += t2 - t1

    print("{:<10} {:10.5} {:10.5} {:10.5}".format("time", tf / niter, tb / niter, (tf + tb) / niter))

    grads1 = [p.grad.detach().clone() for p in params]

    ############# compare results #############
    # For the forward output and each parameter gradient: max abs diff,
    # normalized dot product, both norms, and the values at the worst index.
    print("-----------------------------------------------------------------")
    print("{:>10} {:>10} {:>10} {:>10} {:>10} {:>10} {:>10} {:>10}".format(
        "", "maxabsdiff", "dp", "||py||", "||cuda||", "index", "py", "cuda"))

    ind = torch.argmax(torch.abs(sample0 - sample1))
    print("{:<10} {:>10.5} {:>10.5} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format(
        "fwd",
        torch.max(torch.abs(sample0 - sample1)).item(),
        (torch.sum(sample0 * sample1) /
            torch.sqrt(torch.sum(sample0 * sample0) * torch.sum(sample1 * sample1))).item(),
        torch.sqrt(torch.sum(sample0 * sample0)).item(),
        torch.sqrt(torch.sum(sample1 * sample1)).item(),
        ind.item(),
        sample0.view(-1)[ind].item(),
        sample1.view(-1)[ind].item()))

    for p, g0, g1 in zip(paramnames, grads0, grads1):
        ind = torch.argmax(torch.abs(g0 - g1))
        print("{:<10} {:>10.5} {:>10.5} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format(
            p,
            torch.max(torch.abs(g0 - g1)).item(),
            (torch.sum(g0 * g1) /
                torch.sqrt(torch.sum(g0 * g0) * torch.sum(g1 * g1))).item(),
            torch.sqrt(torch.sum(g0 * g0)).item(),
            torch.sqrt(torch.sum(g1 * g1)).item(),
            ind.item(),
            g0.view(-1)[ind].item(),
            g1.view(-1)[ind].item()))


if __name__ == "__main__":
    # Exercise both raymarching algorithms (algo=0 without warp, algo=1 with
    # warp) against the PyTorch reference.
    gradcheck(usebvh="fixedorder", sortprims=False, maxhitboxes=512,
        synchitboxes=True, dowarp=False, chlast=True, fadescale=6.5,
        fadeexp=7.5, accum=0, algo=0, griddim=3)
    gradcheck(usebvh="fixedorder", sortprims=False, maxhitboxes=512,
        synchitboxes=True, dowarp=True, chlast=True, fadescale=6.5,
        fadeexp=7.5, accum=0, algo=1, griddim=3)