// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

// NOTE(review): this file appears to have been damaged in extraction: line
// breaks were lost and some text between '<' and '>' characters is missing.
// The code tokens below are reproduced exactly as found; every spot where text
// is visibly missing carries its own NOTE(review). Restore those spots from
// the upstream file before building.

// Forward ray marching kernel.
//
// Each thread walks one ray in fixed increments of `stepsize`. At every step
// it loops over the primitive ids stored in its hit list (`hitboxes_ptr`,
// `nhitboxes`), maps the current ray position into each primitive's coordinate
// frame with PrimTransfT, samples the primitive with PrimSamplerT when the
// mapped position is valid, and feeds the sample to PrimAccumT. The
// accumulator is written out through `pa.write(primaccum_data)` at the end.
//
// Thread mapping: w comes from the x grid/block index, h from the y grid/block
// index and n from blockIdx.z.
//
// Template parameters:
//   maxhitboxes  capacity of one hit list.
//   nwarps       when > 0 the hit list lives in the __shared__ array
//                `hitboxes_sh` (maxhitboxes * nwarps ints, one slice of
//                maxhitboxes entries selected by `warpid`); otherwise the
//                per-thread array `hitboxes` is used.
//   RaySubsetT   provides the static `forward(...)` call that receives the hit
//                list pointer and count together with `rtminmax`.
//   PrimTransfT, PrimSamplerT, PrimAccumT
//                policy types for transform, sampling and accumulation; each
//                exposes a nested `Data` type passed in by value.
//   NOTE(review): the default template arguments may also have lost their own
//   '<...>' argument lists — confirm against upstream.
//
// Parameters:
//   N, H, W      bounds used for the n / h / w validity test.
//   K            forwarded to RaySubsetT::forward (presumably the primitive
//                count — confirm).
//   rayposim, raydirim, tminmaxim
//                not referenced in the visible code; presumably the per-ray
//                origin, direction and [tmin, tmax] range loaded in the
//                missing section — confirm.
//   stepsize     distance in t between consecutive samples.
//   sortedobjid, nodechildren, nodeaabb
//                forwarded unchanged (in the visible code) to
//                RaySubsetT::forward.
template<
    int maxhitboxes,
    int nwarps,
    class RaySubsetT=RaySubsetFixedBVH,
    class PrimTransfT=PrimTransfSRT,
    class PrimSamplerT=PrimSamplerTW,
    class PrimAccumT=PrimAccumAdditive>
__global__ void raymarch_subset_forward_kernel(
        int N, int H, int W, int K,
        float3 * rayposim,
        float3 * raydirim,
        float stepsize,
        float2 * tminmaxim,
        int * sortedobjid,
        int2 * nodechildren,
        float3 * nodeaabb,
        typename PrimTransfT::Data primtransf_data,
        typename PrimSamplerT::Data primsampler_data,
        typename PrimAccumT::Data primaccum_data
        ) {
    int w = blockIdx.x * blockDim.x + threadIdx.x;
    int h = blockIdx.y * blockDim.y + threadIdx.y;
    int n = blockIdx.z;
    bool validthread = (w < W) && (h < H) && (n
    // NOTE(review): source text is missing at this point. Everything between
    // "(n" and the tail of an array declaration on the next line has been
    // lost. The lost text must have declared warpmask, warpid, raypos, raydir,
    // tminmax and hitboxes, since all of them are used below and none is a
    // parameter. The fragment that follows looks like the end of
    // "int hitboxes[nwarps > 0 ? 1 : maxhitboxes];" — confirm against upstream.
    0 ? 1 : maxhitboxes];
    __shared__ int hitboxes_sh[nwarps > 0 ? maxhitboxes * nwarps : 1];
    // Pick this warp's slice of the shared list, or the per-thread array.
    int * hitboxes_ptr = nwarps > 0 ? hitboxes_sh + maxhitboxes * warpid : hitboxes;
    int nhitboxes = 0;

    // find raytminmax
    // Starts as the empty interval (+inf, -inf); presumably narrowed by
    // RaySubsetT::forward to the span covered by the hit primitives — confirm.
    // NOTE(review): std::numeric_limits is written without a template
    // argument here; the '<...>' part seems to have been stripped.
    float2 rtminmax = make_float2(std::numeric_limits::infinity(), -std::numeric_limits::infinity());
    RaySubsetT::forward(warpmask, K, raypos, raydir, tminmax, rtminmax,
            sortedobjid, nodechildren, nodeaabb,
            primtransf_data, hitboxes_ptr, nhitboxes);
    // Clamp the interval to the ray's own [tminmax.x, tminmax.y] range.
    rtminmax.x = max(rtminmax.x, tminmax.x);
    rtminmax.y = min(rtminmax.y, tminmax.y);
    __syncwarp(warpmask);

    // Start at tminmax.x, then skip a whole number of steps so that sampling
    // stays on the step grid anchored at tminmax.x while beginning at the last
    // grid point that does not exceed rtminmax.x.
    float t = tminmax.x;
    raypos = raypos + raydir * tminmax.x;

    int incs = floor((rtminmax.x - t) / stepsize);
    t += incs * stepsize;
    raypos += raydir * incs * stepsize;

    PrimAccumT pa;

    // Keep stepping until every lane in warpmask has either passed
    // rtminmax.y (with a 1e-5f tolerance) or reports pa.is_done(). Lanes that
    // are already finished keep iterating with the rest, which is why the
    // inner `if` repeats the per-lane termination test.
    while (!__all_sync(warpmask, t > rtminmax.y + 1e-5f || pa.is_done())) {
        for (int ks = 0; ks < nhitboxes; ++ks) {
            int k = hitboxes_ptr[ks];

            // compute primitive-relative coordinate
            PrimTransfT pt;
            float3 samplepos = pt.forward(primtransf_data, k, raypos);

            if (pt.valid(samplepos) && !pa.is_done() && t < rtminmax.y + 1e-5f) {
                // sample
                PrimSamplerT ps;
                float4 sample = ps.forward(primsampler_data, k, samplepos);

                // accumulate
                pa.forward_prim(primaccum_data, sample, stepsize);
            }
        }

        // update position
        t += stepsize;
        raypos += raydir * stepsize;
    }

    pa.write(primaccum_data);
}

// Backward ray marching kernel.
//
// Repeats the marching loop of the forward kernel, and for every evaluated
// sample obtains the sample gradient from `pa.forwardbackward_prim`, pushes it
// through `ps.backward` to get the gradient w.r.t. the primitive-relative
// position, and hands that to `pt.backward`.
//
// Template parameters: as for raymarch_subset_forward_kernel, plus
//   forwarddir   true: march from the start position in +stepsize increments,
//                visiting the hit list front to back.
//                false: first advance by `pa.get_nsteps()` steps, then march
//                in -stepsize increments, visiting the hit list back to front.
//
// Parameters: identical to raymarch_subset_forward_kernel.
template <
    bool forwarddir,
    int maxhitboxes,
    int nwarps,
    class RaySubsetT=RaySubsetFixedBVH,
    class PrimTransfT=PrimTransfSRT,
    class PrimSamplerT=PrimSamplerTW,
    class PrimAccumT=PrimAccumAdditive>
__global__ void raymarch_subset_backward_kernel(
        int N, int H, int W, int K,
        float3 * rayposim,
        float3 * raydirim,
        float stepsize,
        float2 * tminmaxim,
        int * sortedobjid,
        int2 * nodechildren,
        float3 * nodeaabb,
        typename PrimTransfT::Data primtransf_data,
        typename PrimSamplerT::Data primsampler_data,
        typename PrimAccumT::Data primaccum_data
        ) {
    int w = blockIdx.x * blockDim.x + threadIdx.x;
    int h = blockIdx.y * blockDim.y + threadIdx.y;
    int n = blockIdx.z;
    bool validthread = (w < W) && (h < H) && (n
    // NOTE(review): source text is missing at this point, exactly as in the
    // forward kernel. The lost text must have declared warpmask, warpid,
    // raypos, raydir, tminmax, hitboxes and the accumulator `pa`, all of which
    // are used below without a visible declaration. Restore from upstream.
    0 ? 1 : maxhitboxes];
    __shared__ int hitboxes_sh[nwarps > 0 ? maxhitboxes * nwarps : 1];
    // Pick this warp's slice of the shared list, or the per-thread array.
    int * hitboxes_ptr = nwarps > 0 ? hitboxes_sh + maxhitboxes * warpid : hitboxes;
    int nhitboxes = 0;

    // find raytminmax
    // NOTE(review): std::numeric_limits is written without a template
    // argument here; the '<...>' part seems to have been stripped.
    float2 rtminmax = make_float2(std::numeric_limits::infinity(), -std::numeric_limits::infinity());
    RaySubsetT::forward(warpmask, K, raypos, raydir, tminmax, rtminmax,
            sortedobjid, nodechildren, nodeaabb,
            primtransf_data, hitboxes_ptr, nhitboxes);
    // Clamp the interval to the ray's own [tminmax.x, tminmax.y] range.
    rtminmax.x = max(rtminmax.x, tminmax.x);
    rtminmax.y = min(rtminmax.y, tminmax.y);
    __syncwarp(warpmask);

    // set up raymarching position
    // Same start position computation as the forward kernel.
    float t = tminmax.x;
    raypos = raypos + raydir * tminmax.x;

    int incs = floor((rtminmax.x - t) / stepsize);
    t += incs * stepsize;
    raypos += raydir * incs * stepsize;

    // Reverse traversal begins pa.get_nsteps() steps past the start position.
    if (!forwarddir) {
        int nsteps = pa.get_nsteps();
        t += nsteps * stepsize;
        raypos += raydir * nsteps * stepsize;
    }

    // Continue while at least one lane in warpmask is still inside its range
    // (upper bound when marching forward, lower bound when marching in
    // reverse, each with a 1e-5f tolerance) and its accumulator is not done.
    while (__any_sync(warpmask, (
                    (forwarddir && t < rtminmax.y + 1e-5f ||
                     !forwarddir && t > rtminmax.x - 1e-5f) &&
                    !pa.is_done()))) {
        for (int ks = 0; ks < nhitboxes; ++ks) {
            // Reverse traversal also reverses the order of the hit list.
            int k = hitboxes_ptr[forwarddir ? ks : nhitboxes - ks - 1];

            PrimTransfT pt;
            float3 samplepos = pt.forward(primtransf_data, k, raypos);

            // Per-lane decision; note the upper-bound test on t is applied in
            // both marching directions.
            bool evalprim = pt.valid(samplepos) && !pa.is_done() && t < rtminmax.y + 1e-5f;

            // Stays zero for lanes that do not evaluate this primitive.
            float3 dL_samplepos = make_float3(0.f);
            if (evalprim) {
                PrimSamplerT ps;
                float4 sample = ps.forward(primsampler_data, k, samplepos);

                float4 dL_sample = pa.forwardbackward_prim(primaccum_data, sample, stepsize);

                dL_samplepos = ps.backward(primsampler_data, k, samplepos, sample, dL_sample, validthread);
            }

            // The vote result is the same for every lane in warpmask, so those
            // lanes enter pt.backward together; lanes without a sample pass
            // `false` as the last argument and a zero gradient. Presumably
            // pt.backward does warp-cooperative work — confirm.
            if (__any_sync(warpmask, evalprim)) {
                pt.backward(primtransf_data, k, samplepos, dL_samplepos, validthread && evalprim);
            }
        }

        // Advance one step in the marching direction.
        if (forwarddir) {
            t += stepsize;
            raypos += raydir * stepsize;
        } else {
            t -= stepsize;
            raypos -= raydir * stepsize;
        }
    }
}