// Copyright (c) Meta Platforms, Inc. and affiliates. // All rights reserved. // // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. #ifndef MVPRAYMARCHER_PRIMSAMPLER_H_ #define MVPRAYMARCHER_PRIMSAMPLER_H_ struct PrimSamplerDataBase { typedef PrimSamplerDataBase base; }; template< bool dowarp, template class GridSamplerT=GridSamplerChlast> struct PrimSamplerTW { struct Data : public PrimSamplerDataBase { float fadescale, fadeexp; int tplate_nstride; int TD, TH, TW; float * tplate; float * grad_tplate; int warp_nstride; int WD, WH, WW; float * warp; float * grad_warp; __forceinline__ __device__ void n_stride(int n) { tplate += n * tplate_nstride; grad_tplate += n * tplate_nstride; warp += n * warp_nstride; grad_warp += n * warp_nstride; } }; float fade; float * tplate_ptr; float * warp_ptr; float3 yy1; __forceinline__ __device__ float4 forward( const Data & data, int k, float3 y0) { fade = __expf(-data.fadescale * ( __powf(abs(y0.x), data.fadeexp) + __powf(abs(y0.y), data.fadeexp) + __powf(abs(y0.z), data.fadeexp))); if (dowarp) { warp_ptr = data.warp + (k * 3 * data.WD * data.WH * data.WW); yy1 = GridSamplerT::forward(3, data.WD, data.WH, data.WW, warp_ptr, y0, false); } else { yy1 = y0; } tplate_ptr = data.tplate + (k * 4 * data.TD * data.TH * data.TW); float4 sample = GridSamplerT::forward(4, data.TD, data.TH, data.TW, tplate_ptr, yy1, false); sample.w *= fade; return sample; } __forceinline__ __device__ float3 backward(const Data & data, int k, float3 y0, float4 sample, float4 dL_sample, bool validthread) { float3 dfade_y0 = -(data.fadescale * data.fadeexp) * make_float3( __powf(abs(y0.x), data.fadeexp - 1.f) * (y0.x > 0.f ? 1.f : -1.f), __powf(abs(y0.y), data.fadeexp - 1.f) * (y0.y > 0.f ? 1.f : -1.f), __powf(abs(y0.z), data.fadeexp - 1.f) * (y0.z > 0.f ? 1.f : -1.f)); float3 dL_y0 = dfade_y0 * sample.w * dL_sample.w; dL_sample.w *= fade; float * grad_tplate_ptr = data.grad_tplate + (k * 4 * data.TD * data.TH * data.TW); float3 dL_y1 = GridSamplerT::backward(4, data.TD, data.TH, data.TW, tplate_ptr, grad_tplate_ptr, yy1, validthread ? dL_sample : make_float4(0.f), false); if (dowarp) { float * grad_warp_ptr = data.grad_warp + (k * 3 * data.WD * data.WH * data.WW); dL_y0 += GridSamplerT::backward(3, data.WD, data.WH, data.WW, warp_ptr, grad_warp_ptr, y0, validthread ? dL_y1 : make_float3(0.f), false); } else { dL_y0 += dL_y1; } return dL_y0; } }; #endif