// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#ifndef MVPRAYMARCHER_UTILS_H_
#define MVPRAYMARCHER_UTILS_H_

#include <cstring>
#include <limits>

#include "helper_math.h"

// Device-side helpers for the ray marcher: scalar activations, warp-level
// atomics, trilinear grid sampling (forward and backward, for channel-first
// and channel-last volumes), ray/AABB slab tests, and a BVH traversal that
// collects the primitives hit by a ray.

// ---------------------------------------------------------------------------
// Timing and comparison helpers
// ---------------------------------------------------------------------------

// Number of ticks between two counter samples `start` and `end`.
// If `end < start` the counter is assumed to have wrapped once from
// `max_clock` back to zero.
static __forceinline__ __device__
float clock_diff(long long int end, long long int start) {
    long long int max_clock = std::numeric_limits<long long int>::max();
    return end < start ? (end + float(max_clock - start)) : float(end - start);
}

// True when every component of `a` is >= the matching component of `b`
// (inclusive comparison).
static __forceinline__ __device__
bool allgt(float3 a, float3 b) {
    return a.x >= b.x && a.y >= b.y && a.z >= b.z;
}

// True when every component of `a` is <= the matching component of `b`
// (inclusive comparison, despite the "lt" in the name).
static __forceinline__ __device__
bool alllt(float3 a, float3 b) {
    return a.x <= b.x && a.y <= b.y && a.z <= b.z;
}

// ---------------------------------------------------------------------------
// Activations
// ---------------------------------------------------------------------------

// Component-wise softplus log(1 + exp(x)). For x > 20 the result is x itself,
// which is accurate to float precision and keeps expf from overflowing.
static __forceinline__ __device__
float4 softplus(float4 x) {
    return make_float4(
        x.x > 20.f ? x.x : logf(1.f + expf(x.x)),
        x.y > 20.f ? x.y : logf(1.f + expf(x.y)),
        x.z > 20.f ? x.z : logf(1.f + expf(x.z)),
        x.w > 20.f ? x.w : logf(1.f + expf(x.w)));
}

// Scalar softplus in the overflow-free form log(1 + exp(-|x|)) + max(x, 0).
// Uses the fast, reduced-precision __logf/__expf intrinsics.
static __forceinline__ __device__
float softplus(float x) {
    return __logf(1.f + __expf(-abs(x))) + max(x, 0.f);
}

// Derivative of softplus, i.e. the logistic sigmoid 1 / (1 + exp(-x)).
// It is computed from e = exp(-|x|) so nothing overflows:
//   x >= 0: 1 - e / (1 + e)      x < 0: e / (1 + e)
static __forceinline__ __device__
float softplus_grad(float x) {
    float expnabsx = __expf(-abs(x));
    return (0.5f - expnabsx / (1.f + expnabsx)) * copysign(1.f, x) + 0.5f;
}

// Component-wise logistic sigmoid 1 / (1 + exp(-x)).
static __forceinline__ __device__
float4 sigmoid(float4 x) {
    return make_float4(
        1.f / (1.f + expf(-x.x)),
        1.f / (1.f + expf(-x.y)),
        1.f / (1.f + expf(-x.z)),
        1.f / (1.f + expf(-x.w)));
}

// ---------------------------------------------------------------------------
// Atomics
// ---------------------------------------------------------------------------

// Sums `val` over the warp with shuffles; lane 0 then issues a single
// atomicAdd of the total to `ptr`.
// Preconditions:
//  - all 32 lanes of the warp call this (the shuffle mask is 0xffffffff);
//  - every lane passes the same `ptr` (only lane 0's pointer is used).
// The lane id is derived from (threadIdx.x, threadIdx.y) only; threadIdx.z is
// not taken into account.
static __forceinline__ __device__
void fastAtomicAdd(float * ptr, float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_down_sync(0xffffffff, val, offset);
    }

    const int laneid = (threadIdx.y * blockDim.x + threadIdx.x) % 32;
    if (laneid == 0) {
        atomicAdd(ptr, val);
    }
}

// True when (d, h, w) indexes a voxel inside a D x H x W volume.
static __forceinline__ __device__
bool within_bounds_3d(int d, int h, int w, int D, int H, int W) {
    return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;
}

// Atomically adds `delta` to data[d*sD + h*sH + w*sW]; out-of-bounds
// indices are silently ignored.
static __forceinline__ __device__
void safe_add_3d(float *data, int d, int h, int w,
        int sD, int sH, int sW, int D, int H, int W,
        float delta) {
    if (within_bounds_3d(d, h, w, D, H, W)) {
        atomicAdd(data + d * sD + h * sH + w * sW, delta);
    }
}

// float3 variant: strides are in units of float3 elements. Each component
// is added with its own scalar atomicAdd (the 3 adds are not one atomic op).
static __forceinline__ __device__
void safe_add_3d(float3 *data, int d, int h, int w,
        int sD, int sH, int sW, int D, int H, int W,
        float3 delta) {
    if (within_bounds_3d(d, h, w, D, H, W)) {
        float *base = (float*)data + (d * sD + h * sH + w * sW) * 3;
        atomicAdd(base + 0, delta.x);
        atomicAdd(base + 1, delta.y);
        atomicAdd(base + 2, delta.z);
    }
}

// float4 variant: strides are in units of float4 elements. Each component
// is added with its own scalar atomicAdd.
static __forceinline__ __device__
void safe_add_3d(float4 *data, int d, int h, int w,
        int sD, int sH, int sW, int D, int H, int W,
        float4 delta) {
    if (within_bounds_3d(d, h, w, D, H, W)) {
        float *base = (float*)data + (d * sD + h * sH + w * sW) * 4;
        atomicAdd(base + 0, delta.x);
        atomicAdd(base + 1, delta.y);
        atomicAdd(base + 2, delta.z);
        atomicAdd(base + 3, delta.w);
    }
}

// ---------------------------------------------------------------------------
// Trilinear grid sampling
// ---------------------------------------------------------------------------

// Clamps an unnormalized coordinate to [0, clip_limit - 1].
static __forceinline__ __device__
float clip_coordinates(float in, int clip_limit) {
    return ::min(static_cast<float>(clip_limit - 1), ::max(in, 0.f));
}

// Same as clip_coordinates, additionally storing d(out)/d(in) in *grad_in:
// 0 if the coordinate was clipped, 1 otherwise.
template <typename scalar_t>
static __forceinline__ __device__
float clip_coordinates_set_grad(float in, int clip_limit, scalar_t *grad_in) {
    if (in < 0.f) {
        *grad_in = static_cast<scalar_t>(0);
        return 0.f;
    }
    float max = static_cast<float>(clip_limit - 1);
    if (in > max) {
        *grad_in = static_cast<scalar_t>(0);
        return max;
    }
    *grad_in = static_cast<scalar_t>(1);
    return in;
}

// Trilinearly samples a channel-first volume at normalized position `pos`.
//
// `vals` holds C channels back to back, each an inp_D x inp_H x inp_W block
// with W fastest. Components of `pos` in [-1, 1] map to
// [0, inp_W-1] (x), [0, inp_H-1] (y) and [0, inp_D-1] (z) (corner-aligned).
// (pos + 1) / 2 is pre-clamped to [-10, 10], presumably to keep the integer
// corner indices bounded.
// With `border`, coordinates are clamped to the volume. Otherwise corners
// outside the volume contribute zero.
//
// Channel c is written to the c-th float starting at result.x, so out_t must
// be a vector struct with at least C consecutive float members starting at
// .x (e.g. float3 or float4). Members past channel C-1 are zero.
template <typename out_t>
static __device__
out_t grid_sample_forward(int C, int inp_D, int inp_H, int inp_W,
        float* vals, float3 pos, bool border) {
    int inp_sW = 1, inp_sH = inp_W, inp_sD = inp_W * inp_H, inp_sC = inp_W * inp_H * inp_D;
    int out_sC = 1;

    // normalize ix, iy, iz from [-1, 1] to [0, inp_W-1] & [0, inp_H-1] & [0, inp_D-1]
    float ix = max(-10.f, min(10.f, ((pos.x + 1.f) * 0.5f))) * (inp_W - 1);
    float iy = max(-10.f, min(10.f, ((pos.y + 1.f) * 0.5f))) * (inp_H - 1);
    float iz = max(-10.f, min(10.f, ((pos.z + 1.f) * 0.5f))) * (inp_D - 1);

    if (border) {
        ix = clip_coordinates(ix, inp_W);
        iy = clip_coordinates(iy, inp_H);
        iz = clip_coordinates(iz, inp_D);
    }

    // Corner naming: t/b = lower/upper z, n/s = lower/upper y,
    // w/e = lower/upper x.
    int ix_tnw = static_cast<int>(::floor(ix));
    int iy_tnw = static_cast<int>(::floor(iy));
    int iz_tnw = static_cast<int>(::floor(iz));

    int ix_tne = ix_tnw + 1;
    int iy_tne = iy_tnw;
    int iz_tne = iz_tnw;

    int ix_tsw = ix_tnw;
    int iy_tsw = iy_tnw + 1;
    int iz_tsw = iz_tnw;

    int ix_tse = ix_tnw + 1;
    int iy_tse = iy_tnw + 1;
    int iz_tse = iz_tnw;

    int ix_bnw = ix_tnw;
    int iy_bnw = iy_tnw;
    int iz_bnw = iz_tnw + 1;

    int ix_bne = ix_tnw + 1;
    int iy_bne = iy_tnw;
    int iz_bne = iz_tnw + 1;

    int ix_bsw = ix_tnw;
    int iy_bsw = iy_tnw + 1;
    int iz_bsw = iz_tnw + 1;

    int ix_bse = ix_tnw + 1;
    int iy_bse = iy_tnw + 1;
    int iz_bse = iz_tnw + 1;

    // Trilinear weight of each corner: the volume of the sub-cell opposite it.
    float tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
    float tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
    float tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
    float tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
    float bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
    float bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
    float bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
    float bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);

    out_t result;
    // Zero the whole struct so members past channel C-1 are well defined.
    memset(&result, 0, sizeof(out_t));

    const float * inp_ptr_NC = vals;
    float * out_ptr_NCDHW = &result.x;
    for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) {
        float acc = 0.f;
        if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
            acc += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw;
        }
        if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
            acc += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne;
        }
        if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
            acc += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw;
        }
        if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
            acc += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse;
        }
        if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
            acc += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw;
        }
        if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
            acc += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne;
        }
        if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
            acc += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw;
        }
        if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
            acc += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse;
        }
        *out_ptr_NCDHW = acc;
    }
    return result;
}

// Backward pass of grid_sample_forward.
//
// `grad_out` is d(loss)/d(sample); only its first C floats (starting at .x)
// are read. The gradient w.r.t. the volume is accumulated with atomicAdd
// into `grad_vals` (same layout as `vals`).
// Returns d(loss)/d(pos). With `border`, a component whose coordinate was
// clipped gets zero gradient. The [-10, 10] pre-clamp is not differentiated.
template <typename out_t>
static __device__
float3 grid_sample_backward(int C, int inp_D, int inp_H, int inp_W,
        float* vals, float* grad_vals, float3 pos, out_t grad_out, bool border) {
    int inp_sW = 1, inp_sH = inp_W, inp_sD = inp_W * inp_H, inp_sC = inp_W * inp_H * inp_D;
    int gInp_sW = 1, gInp_sH = inp_W, gInp_sD = inp_W * inp_H, gInp_sC = inp_W * inp_H * inp_D;
    int gOut_sC = 1;

    // normalize ix, iy, iz from [-1, 1] to [0, inp_W-1] & [0, inp_H-1] & [0, inp_D-1]
    float ix = max(-10.f, min(10.f, ((pos.x + 1.f) * 0.5f))) * (inp_W - 1);
    float iy = max(-10.f, min(10.f, ((pos.y + 1.f) * 0.5f))) * (inp_H - 1);
    float iz = max(-10.f, min(10.f, ((pos.z + 1.f) * 0.5f))) * (inp_D - 1);

    // Chain-rule factors d(ix)/d(pos.x), d(iy)/d(pos.y), d(iz)/d(pos.z).
    float gix_mult = (inp_W - 1.f) / 2;
    float giy_mult = (inp_H - 1.f) / 2;
    float giz_mult = (inp_D - 1.f) / 2;

    if (border) {
        // Clipping contributes a factor of 0 (clipped) or 1 (inside). It must
        // scale the un-normalization factor above, not replace it.
        float gclip_x, gclip_y, gclip_z;
        ix = clip_coordinates_set_grad(ix, inp_W, &gclip_x);
        iy = clip_coordinates_set_grad(iy, inp_H, &gclip_y);
        iz = clip_coordinates_set_grad(iz, inp_D, &gclip_z);
        gix_mult *= gclip_x;
        giy_mult *= gclip_y;
        giz_mult *= gclip_z;
    }

    // Corner naming: t/b = lower/upper z, n/s = lower/upper y,
    // w/e = lower/upper x.
    int ix_tnw = static_cast<int>(::floor(ix));
    int iy_tnw = static_cast<int>(::floor(iy));
    int iz_tnw = static_cast<int>(::floor(iz));

    int ix_tne = ix_tnw + 1;
    int iy_tne = iy_tnw;
    int iz_tne = iz_tnw;

    int ix_tsw = ix_tnw;
    int iy_tsw = iy_tnw + 1;
    int iz_tsw = iz_tnw;

    int ix_tse = ix_tnw + 1;
    int iy_tse = iy_tnw + 1;
    int iz_tse = iz_tnw;

    int ix_bnw = ix_tnw;
    int iy_bnw = iy_tnw;
    int iz_bnw = iz_tnw + 1;

    int ix_bne = ix_tnw + 1;
    int iy_bne = iy_tnw;
    int iz_bne = iz_tnw + 1;

    int ix_bsw = ix_tnw;
    int iy_bsw = iy_tnw + 1;
    int iz_bsw = iz_tnw + 1;

    int ix_bse = ix_tnw + 1;
    int iy_bse = iy_tnw + 1;
    int iz_bse = iz_tnw + 1;

    // Trilinear weight of each corner: the volume of the sub-cell opposite it.
    float tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
    float tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
    float tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
    float tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
    float bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
    float bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
    float bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
    float bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);

    float gix = 0.f, giy = 0.f, giz = 0.f;

    const float *gOut_ptr_NCDHW = &grad_out.x;
    float *gInp_ptr_NC = grad_vals;
    const float *inp_ptr_NC = vals;
    for (int c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC += inp_sC) {
        float gOut = *gOut_ptr_NCDHW;

        // Gradient w.r.t. the volume: scatter gOut with the trilinear weights.
        safe_add_3d(gInp_ptr_NC, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw * gOut);
        safe_add_3d(gInp_ptr_NC, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne * gOut);
        safe_add_3d(gInp_ptr_NC, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw * gOut);
        safe_add_3d(gInp_ptr_NC, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse * gOut);
        safe_add_3d(gInp_ptr_NC, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw * gOut);
        safe_add_3d(gInp_ptr_NC, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne * gOut);
        safe_add_3d(gInp_ptr_NC, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw * gOut);
        safe_add_3d(gInp_ptr_NC, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse * gOut);

        // Gradient w.r.t. the (unnormalized) sample coordinates.
        if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
            float tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW];
            gix -= tnw_val * (iy_bse - iy) * (iz_bse - iz) * gOut;
            giy -= tnw_val * (ix_bse - ix) * (iz_bse - iz) * gOut;
            giz -= tnw_val * (ix_bse - ix) * (iy_bse - iy) * gOut;
        }
        if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
            float tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW];
            gix += tne_val * (iy_bsw - iy) * (iz_bsw - iz) * gOut;
            giy -= tne_val * (ix - ix_bsw) * (iz_bsw - iz) * gOut;
            giz -= tne_val * (ix - ix_bsw) * (iy_bsw - iy) * gOut;
        }
        if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
            float tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW];
            gix -= tsw_val * (iy - iy_bne) * (iz_bne - iz) * gOut;
            giy += tsw_val * (ix_bne - ix) * (iz_bne - iz) * gOut;
            giz -= tsw_val * (ix_bne - ix) * (iy - iy_bne) * gOut;
        }
        if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
            float tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW];
            gix += tse_val * (iy - iy_bnw) * (iz_bnw - iz) * gOut;
            giy += tse_val * (ix - ix_bnw) * (iz_bnw - iz) * gOut;
            giz -= tse_val * (ix - ix_bnw) * (iy - iy_bnw) * gOut;
        }
        if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
            float bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW];
            gix -= bnw_val * (iy_tse - iy) * (iz - iz_tse) * gOut;
            giy -= bnw_val * (ix_tse - ix) * (iz - iz_tse) * gOut;
            giz += bnw_val * (ix_tse - ix) * (iy_tse - iy) * gOut;
        }
        if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
            float bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW];
            gix += bne_val * (iy_tsw - iy) * (iz - iz_tsw) * gOut;
            giy -= bne_val * (ix - ix_tsw) * (iz - iz_tsw) * gOut;
            giz += bne_val * (ix - ix_tsw) * (iy_tsw - iy) * gOut;
        }
        if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
            float bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW];
            gix -= bsw_val * (iy - iy_tne) * (iz - iz_tne) * gOut;
            giy += bsw_val * (ix_tne - ix) * (iz - iz_tne) * gOut;
            giz += bsw_val * (ix_tne - ix) * (iy - iy_tne) * gOut;
        }
        if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
            float bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW];
            gix += bse_val * (iy - iy_tnw) * (iz - iz_tnw) * gOut;
            giy += bse_val * (ix - ix_tnw) * (iz - iz_tnw) * gOut;
            giz += bse_val * (ix - ix_tnw) * (iy - iy_tnw) * gOut;
        }
    }

    return make_float3(gix_mult * gix, giy_mult * giy, giz_mult * giz);
}

// Static wrapper around grid_sample_forward/backward for channel-first
// volumes, presumably so a sampler can be passed as a single class template
// argument.
template <typename out_t>
struct GridSampler {
    static __forceinline__ __device__
    out_t forward(int C, int inp_D, int inp_H, int inp_W,
            float* vals, float3 pos, bool border) {
        return grid_sample_forward<out_t>(C, inp_D, inp_H, inp_W, vals, pos, border);
    }

    static __forceinline__ __device__
    float3 backward(int C, int inp_D, int inp_H, int inp_W,
            float* vals, float* grad_vals, float3 pos, out_t grad_out, bool border) {
        return grid_sample_backward<out_t>(C, inp_D, inp_H, inp_W, vals, grad_vals, pos, grad_out, border);
    }
};

// Linear index ((d * H) + h) * W + w of voxel (d, h, w) in a D x H x W
// volume, or -1 when it is out of bounds.
static __forceinline__ __device__
int within_bounds_3d_ind(int d, int h, int w, int D, int H, int W) {
    return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W ? ((d * H) + h) * W + w : -1;
}

// Trilinearly samples a channel-last volume at normalized position `pos`.
//
// `vals` is reinterpreted as an inp_D x inp_H x inp_W array of out_t (all
// channels of a voxel stored together, W fastest). The first (channel count)
// argument is unused; the channel count is implied by out_t.
// Coordinate handling matches grid_sample_forward, except that (pos + 1) / 2
// is pre-clamped to [-100, 100]. out_t must support `+=` and `* float`
// (e.g. via helper_math.h).
template <typename out_t>
static __device__
out_t grid_sample_chlast_forward(int, int inp_D, int inp_H, int inp_W,
        float * vals, float3 pos, bool border) {
    int inp_sW = 1, inp_sH = inp_W, inp_sD = inp_W * inp_H;

    // normalize ix, iy, iz from [-1, 1] to [0, inp_W-1] & [0, inp_H-1] & [0, inp_D-1]
    float ix = max(-100.f, min(100.f, ((pos.x + 1.f) / 2))) * (inp_W - 1);
    float iy = max(-100.f, min(100.f, ((pos.y + 1.f) / 2))) * (inp_H - 1);
    float iz = max(-100.f, min(100.f, ((pos.z + 1.f) / 2))) * (inp_D - 1);

    if (border) {
        ix = clip_coordinates(ix, inp_W);
        iy = clip_coordinates(iy, inp_H);
        iz = clip_coordinates(iz, inp_D);
    }

    // Corner naming: t/b = lower/upper z, n/s = lower/upper y,
    // w/e = lower/upper x.
    int ix_tnw = static_cast<int>(::floor(ix));
    int iy_tnw = static_cast<int>(::floor(iy));
    int iz_tnw = static_cast<int>(::floor(iz));

    int ix_tne = ix_tnw + 1;
    int iy_tne = iy_tnw;
    int iz_tne = iz_tnw;

    int ix_tsw = ix_tnw;
    int iy_tsw = iy_tnw + 1;
    int iz_tsw = iz_tnw;

    int ix_tse = ix_tnw + 1;
    int iy_tse = iy_tnw + 1;
    int iz_tse = iz_tnw;

    int ix_bnw = ix_tnw;
    int iy_bnw = iy_tnw;
    int iz_bnw = iz_tnw + 1;

    int ix_bne = ix_tnw + 1;
    int iy_bne = iy_tnw;
    int iz_bne = iz_tnw + 1;

    int ix_bsw = ix_tnw;
    int iy_bsw = iy_tnw + 1;
    int iz_bsw = iz_tnw + 1;

    int ix_bse = ix_tnw + 1;
    int iy_bse = iy_tnw + 1;
    int iz_bse = iz_tnw + 1;

    // Trilinear weight of each corner: the volume of the sub-cell opposite it.
    float tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
    float tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
    float tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
    float tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
    float bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
    float bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
    float bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
    float bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);

    out_t result;
    memset(&result, 0, sizeof(out_t));

    const out_t * inp_ptr = (const out_t*)vals;
    if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
        result += inp_ptr[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw;
    }
    if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
        result += inp_ptr[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne;
    }
    if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
        result += inp_ptr[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw;
    }
    if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
        result += inp_ptr[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse;
    }
    if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
        result += inp_ptr[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw;
    }
    if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
        result += inp_ptr[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne;
    }
    if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
        result += inp_ptr[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw;
    }
    if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
        result += inp_ptr[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse;
    }
    return result;
}

// Backward pass of grid_sample_chlast_forward.
//
// Accumulates the volume gradient into `grad_vals`, reinterpreted as an
// out_t array with the same layout as `vals`. Accumulation uses
// safe_add_3d, so out_t must match one of its overloads (float, float3 or
// float4).
// Returns d(loss)/d(pos). Per corner this uses dot(corner value, grad_out).
// With `border`, a component whose coordinate was clipped gets zero
// gradient. The [-100, 100] pre-clamp is not differentiated.
template <typename out_t>
static __device__
float3 grid_sample_chlast_backward(int, int inp_D, int inp_H, int inp_W,
        float* vals, float* grad_vals, float3 pos, out_t grad_out, bool border) {
    int inp_sW = 1, inp_sH = inp_W, inp_sD = inp_W * inp_H;
    int gInp_sW = 1, gInp_sH = inp_W, gInp_sD = inp_W * inp_H;

    // normalize ix, iy, iz from [-1, 1] to [0, inp_W-1] & [0, inp_H-1] & [0, inp_D-1]
    float ix = max(-100.f, min(100.f, ((pos.x + 1.f) / 2))) * (inp_W - 1);
    float iy = max(-100.f, min(100.f, ((pos.y + 1.f) / 2))) * (inp_H - 1);
    float iz = max(-100.f, min(100.f, ((pos.z + 1.f) / 2))) * (inp_D - 1);

    // Chain-rule factors d(ix)/d(pos.x), d(iy)/d(pos.y), d(iz)/d(pos.z).
    float gix_mult = (inp_W - 1.f) / 2;
    float giy_mult = (inp_H - 1.f) / 2;
    float giz_mult = (inp_D - 1.f) / 2;

    if (border) {
        // Clipping contributes a factor of 0 (clipped) or 1 (inside). It must
        // scale the un-normalization factor above, not replace it.
        float gclip_x, gclip_y, gclip_z;
        ix = clip_coordinates_set_grad(ix, inp_W, &gclip_x);
        iy = clip_coordinates_set_grad(iy, inp_H, &gclip_y);
        iz = clip_coordinates_set_grad(iz, inp_D, &gclip_z);
        gix_mult *= gclip_x;
        giy_mult *= gclip_y;
        giz_mult *= gclip_z;
    }

    // Corner naming: t/b = lower/upper z, n/s = lower/upper y,
    // w/e = lower/upper x.
    int ix_tnw = static_cast<int>(::floor(ix));
    int iy_tnw = static_cast<int>(::floor(iy));
    int iz_tnw = static_cast<int>(::floor(iz));

    int ix_tne = ix_tnw + 1;
    int iy_tne = iy_tnw;
    int iz_tne = iz_tnw;

    int ix_tsw = ix_tnw;
    int iy_tsw = iy_tnw + 1;
    int iz_tsw = iz_tnw;

    int ix_tse = ix_tnw + 1;
    int iy_tse = iy_tnw + 1;
    int iz_tse = iz_tnw;

    int ix_bnw = ix_tnw;
    int iy_bnw = iy_tnw;
    int iz_bnw = iz_tnw + 1;

    int ix_bne = ix_tnw + 1;
    int iy_bne = iy_tnw;
    int iz_bne = iz_tnw + 1;

    int ix_bsw = ix_tnw;
    int iy_bsw = iy_tnw + 1;
    int iz_bsw = iz_tnw + 1;

    int ix_bse = ix_tnw + 1;
    int iy_bse = iy_tnw + 1;
    int iz_bse = iz_tnw + 1;

    // Trilinear weight of each corner: the volume of the sub-cell opposite it.
    float tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
    float tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
    float tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
    float tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
    float bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
    float bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
    float bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
    float bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);

    float gix = 0.f, giy = 0.f, giz = 0.f;

    const out_t gOut = grad_out;
    out_t *gInp_ptr = (out_t*)grad_vals;
    const out_t *inp_ptr = (const out_t*)vals;

    // Gradient w.r.t. the volume: scatter gOut with the trilinear weights.
    safe_add_3d(gInp_ptr, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw * gOut);
    safe_add_3d(gInp_ptr, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne * gOut);
    safe_add_3d(gInp_ptr, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw * gOut);
    safe_add_3d(gInp_ptr, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse * gOut);
    safe_add_3d(gInp_ptr, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw * gOut);
    safe_add_3d(gInp_ptr, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne * gOut);
    safe_add_3d(gInp_ptr, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw * gOut);
    safe_add_3d(gInp_ptr, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse * gOut);

    // Gradient w.r.t. the (unnormalized) sample coordinates. Each corner's
    // contribution is its value projected on gOut.
    if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
        float g = dot(inp_ptr[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW], gOut);
        gix -= (iy_bse - iy) * (iz_bse - iz) * g;
        giy -= (ix_bse - ix) * (iz_bse - iz) * g;
        giz -= (ix_bse - ix) * (iy_bse - iy) * g;
    }
    if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
        float g = dot(inp_ptr[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW], gOut);
        gix += (iy_bsw - iy) * (iz_bsw - iz) * g;
        giy -= (ix - ix_bsw) * (iz_bsw - iz) * g;
        giz -= (ix - ix_bsw) * (iy_bsw - iy) * g;
    }
    if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
        float g = dot(inp_ptr[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW], gOut);
        gix -= (iy - iy_bne) * (iz_bne - iz) * g;
        giy += (ix_bne - ix) * (iz_bne - iz) * g;
        giz -= (ix_bne - ix) * (iy - iy_bne) * g;
    }
    if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
        float g = dot(inp_ptr[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW], gOut);
        gix += (iy - iy_bnw) * (iz_bnw - iz) * g;
        giy += (ix - ix_bnw) * (iz_bnw - iz) * g;
        giz -= (ix - ix_bnw) * (iy - iy_bnw) * g;
    }
    if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
        float g = dot(inp_ptr[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW], gOut);
        gix -= (iy_tse - iy) * (iz - iz_tse) * g;
        giy -= (ix_tse - ix) * (iz - iz_tse) * g;
        giz += (ix_tse - ix) * (iy_tse - iy) * g;
    }
    if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
        float g = dot(inp_ptr[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW], gOut);
        gix += (iy_tsw - iy) * (iz - iz_tsw) * g;
        giy -= (ix - ix_tsw) * (iz - iz_tsw) * g;
        giz += (ix - ix_tsw) * (iy_tsw - iy) * g;
    }
    if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
        float g = dot(inp_ptr[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW], gOut);
        gix -= (iy - iy_tne) * (iz - iz_tne) * g;
        giy += (ix_tne - ix) * (iz - iz_tne) * g;
        giz += (ix_tne - ix) * (iy - iy_tne) * g;
    }
    if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
        float g = dot(inp_ptr[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW], gOut);
        gix += (iy - iy_tnw) * (iz - iz_tnw) * g;
        giy += (ix - ix_tnw) * (iz - iz_tnw) * g;
        giz += (ix - ix_tnw) * (iy - iy_tnw) * g;
    }

    return make_float3(gix_mult * gix, giy_mult * giy, giz_mult * giz);
}

// Static wrapper around grid_sample_chlast_forward/backward for
// channel-last volumes (same interface as GridSampler).
template <typename out_t>
struct GridSamplerChlast {
    static __forceinline__ __device__
    out_t forward(int C, int inp_D, int inp_H, int inp_W,
            float* vals, float3 pos, bool border) {
        return grid_sample_chlast_forward<out_t>(C, inp_D, inp_H, inp_W, vals, pos, border);
    }

    static __forceinline__ __device__
    float3 backward(int C, int inp_D, int inp_H, int inp_W,
            float* vals, float* grad_vals, float3 pos, out_t grad_out, bool border) {
        return grid_sample_chlast_backward<out_t>(C, inp_D, inp_H, inp_W, vals, grad_vals, pos, grad_out, border);
    }
};

// ---------------------------------------------------------------------------
// Ray / box intersection
// ---------------------------------------------------------------------------

// Smallest component of a float3.
inline __host__ __device__ float min_component(float3 a) {
    return fminf(fminf(a.x, a.y), a.z);
}

// Largest component of a float3.
inline __host__ __device__ float max_component(float3 a) {
    return fmaxf(fmaxf(a.x, a.y), a.z);
}

// Component-wise absolute value. fabsf is used so the float overload is
// selected on both host and device (an unqualified abs(float) may bind to
// int abs(int) on the host).
inline __host__ __device__ float3 abs(float3 a) {
    return make_float3(fabsf(a.x), fabsf(a.y), fabsf(a.z));
}

// Slab test against the box spanned by corners p0 and p1.
// Tests the full line raypos + t * raydir: t is not restricted to t >= 0,
// so boxes behind the origin also report a hit.
__forceinline__ __device__
bool ray_aabb_hit(float3 p0, float3 p1, float3 raypos, float3 raydir) {
    float3 t0 = (p0 - raypos) / raydir;
    float3 t1 = (p1 - raypos) / raydir;
    float3 tmin = fminf(t0, t1), tmax = fmaxf(t0, t1);
    return max_component(tmin) <= min_component(tmax);
}

// Same as ray_aabb_hit, but takes the precomputed inverse direction
// `ird` = 1 / raydir.
__forceinline__ __device__
bool ray_aabb_hit_ird(float3 p0, float3 p1, float3 raypos, float3 ird) {
    float3 t0 = (p0 - raypos) * ird;
    float3 t1 = (p1 - raypos) * ird;
    float3 tmin = fminf(t0, t1), tmax = fmaxf(t0, t1);
    return max_component(tmin) <= min_component(tmax);
}

// Slab test that returns the entry (otmin) and exit (otmax) line
// parameters. The line intersects the box iff otmin <= otmax.
__forceinline__ __device__
void ray_aabb_hit_ird_tminmax(float3 p0, float3 p1, float3 raypos, float3 ird,
        float &otmin, float &otmax) {
    float3 t0 = (p0 - raypos) * ird;
    float3 t1 = (p1 - raypos) * ird;
    float3 tmin = fminf(t0, t1), tmax = fmaxf(t0, t1);
    otmin = max_component(tmin);
    otmax = min_component(tmax);
}

// Axis-by-axis slab test of the line r0 + t * rd against box [p0, p1].
// Returns false as soon as two slab intervals are disjoint. Otherwise it
// returns true with [tmin, tmax] set to the overlap. As with the other
// tests, t is not restricted to be non-negative.
inline __device__
bool aabb_intersect(float3 p0, float3 p1, float3 r0, float3 rd, float &tmin, float &tmax) {
    float tymin, tymax, tzmin, tzmax;
    const float3 bounds[2] = {p0, p1};
    float3 ird = 1.0f / rd;
    int sx = (ird.x < 0) ? 1 : 0;
    int sy = (ird.y < 0) ? 1 : 0;
    int sz = (ird.z < 0) ? 1 : 0;

    tmin = (bounds[sx].x - r0.x) * ird.x;
    tmax = (bounds[1 - sx].x - r0.x) * ird.x;
    tymin = (bounds[sy].y - r0.y) * ird.y;
    tymax = (bounds[1 - sy].y - r0.y) * ird.y;

    if ((tmin > tymax) || (tymin > tmax)) {
        return false;
    }
    if (tymin > tmin) {
        tmin = tymin;
    }
    if (tymax < tmax) {
        tmax = tymax;
    }

    tzmin = (bounds[sz].z - r0.z) * ird.z;
    tzmax = (bounds[1 - sz].z - r0.z) * ird.z;

    if ((tmin > tzmax) || (tzmin > tmax)) {
        return false;
    }
    if (tzmin > tmin) {
        tmin = tzmin;
    }
    if (tzmax < tmax) {
        tmax = tzmax;
    }
    return true;
}

// ---------------------------------------------------------------------------
// BVH traversal
// ---------------------------------------------------------------------------

// Collects the primitives whose unit box [-1, 1]^3 (in the primitive's local
// frame) is crossed by the ray.
//
// Tree layout (implicit binary tree):
//  - node n has children 2n+1 and 2n+2;
//  - nodes >= K-1 are leaves, and leaf node n holds primitive n - (K-1);
//  - `nodeaabb` stores two corners per node (node i at [2i] and [2i+1]).
// The local-frame ray of each primitive comes from PrimTransfT::forward2.
//
// Outputs:
//  - hit primitive ids are appended to hitboxes[num..], up to `maxhitboxes`
//    entries in total; further hits are dropped;
//  - with `sortboxes`, each id is insertion-sorted into the existing
//    (assumed sorted) entries;
//  - `rtminmax` is widened to cover this thread's own hits, in the
//    parametrization of the transformed ray.
// With `sync`, visit and hit decisions are OR-ed across `warpmask`, so the
// lanes walk the tree together and record a primitive if any lane hits it.
// All lanes in `warpmask` must then call this together.
// `tminmax`, `sortedobjid` and `nodechildren` are currently unused.
//
// NOTE(review): the template parameter order must match the existing
// instantiations of RaySubsetFixedBVH; confirm against the kernels.
template <bool sortboxes, int maxhitboxes, bool sync, class PrimTransfT>
static __forceinline__ __device__
void ray_subset_fixedbvh(
        unsigned warpmask,
        int K,
        float3 raypos,
        float3 raydir,
        float2 tminmax,
        float2 &rtminmax,
        int * sortedobjid,
        int2 * nodechildren,
        float3 * nodeaabb,
        const typename PrimTransfT::Data & primtransf_data,
        int *hitboxes,
        int & num) {
    float3 iraydir = 1.0f / raydir;

    // Explicit traversal stack. -1 is the sentinel that terminates the loop.
    // The implicit tree over int-indexed nodes has depth < 32, so 64 entries
    // suffice.
    int stack[64];
    int* stack_ptr = stack;
    *stack_ptr++ = -1;

    int node = 0;
    do {
        if (node >= (K - 1)) {
            // Leaf: test the ray against the primitive's unit box.
            int k = node - (K - 1);

            float3 r0, rd;
            PrimTransfT::forward2(primtransf_data, k, raypos, raydir, r0, rd);

            float3 ird = 1.0f / rd;
            float3 t0 = (-1.f - r0) * ird;
            float3 t1 = (1.f - r0) * ird;
            float3 tmin = fminf(t0, t1), tmax = fmaxf(t0, t1);

            float trmin = max_component(tmin);
            float trmax = min_component(tmax);

            bool intersection = trmin <= trmax;
            if (intersection) {
                rtminmax.x = fminf(rtminmax.x, trmin);
                rtminmax.y = fmaxf(rtminmax.y, trmax);
            }

            if (sync) {
                intersection = __any_sync(warpmask, intersection);
            }

            if (intersection && num < maxhitboxes) {
                if (sortboxes) {
                    int j = num - 1;
                    while (j >= 0 && hitboxes[j] > k) {
                        hitboxes[j + 1] = hitboxes[j];
                        j = j - 1;
                    }
                    hitboxes[j + 1] = k;
                    num++;
                } else {
                    hitboxes[num++] = k;
                }
            }

            node = *--stack_ptr;
        } else {
            int2 children = make_int2(node * 2 + 1, node * 2 + 2);

            // The two children are adjacent, so their corners are stored back
            // to back: [left0, left1, right0, right1].
            float3 * nodeaabb_ptr = nodeaabb + children.x * 2;
            bool traverse_l = ray_aabb_hit_ird(nodeaabb_ptr[0], nodeaabb_ptr[1], raypos, iraydir);
            bool traverse_r = ray_aabb_hit_ird(nodeaabb_ptr[2], nodeaabb_ptr[3], raypos, iraydir);

            if (sync) {
                traverse_l = __any_sync(warpmask, traverse_l);
                traverse_r = __any_sync(warpmask, traverse_r);
            }

            if (!traverse_l && !traverse_r) {
                node = *--stack_ptr;
            } else {
                // Descend left first; defer the right child when both are hit.
                node = traverse_l ? children.x : children.y;
                if (traverse_l && traverse_r) {
                    *stack_ptr++ = children.y;
                }
            }

            if (sync) {
                __syncwarp(warpmask);
            }
        }
    } while (node != -1);
}

// Static wrapper around ray_subset_fixedbvh, presumably so the traversal
// strategy can be passed as a single class template argument.
template <bool sortboxes, int maxhitboxes, bool sync, class PrimTransfT>
struct RaySubsetFixedBVH {
    static __forceinline__ __device__
    void forward(
            unsigned warpmask,
            int K,
            float3 raypos,
            float3 raydir,
            float2 tminmax,
            float2 &rtminmax,
            int * sortedobjid,
            int2 * nodechildren,
            float3 * nodeaabb,
            const typename PrimTransfT::Data & primtransf_data,
            int *hitboxes,
            int & num) {
        ray_subset_fixedbvh<sortboxes, maxhitboxes, sync, PrimTransfT>(
            warpmask, K, raypos, raydir, tminmax, rtminmax,
            sortedobjid, nodechildren, nodeaabb,
            primtransf_data, hitboxes, num);
    }
};

#endif