|
|
#include "gs.h" |
|
|
#include <torch/extension.h> |
|
|
#include <c10/cuda/CUDAGuard.h> |
|
|
|
|
|
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") |
|
|
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") |
|
|
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) |
|
|
|
|
|
void gs_render( |
|
|
torch::Tensor &sigmas, |
|
|
torch::Tensor &coords, |
|
|
torch::Tensor &colors, |
|
|
torch::Tensor &rendered_img, |
|
|
const int s, |
|
|
const int h, |
|
|
const int w, |
|
|
const int c, |
|
|
const float dmax |
|
|
){ |
|
|
|
|
|
CHECK_INPUT(sigmas); |
|
|
CHECK_INPUT(coords); |
|
|
CHECK_INPUT(colors); |
|
|
CHECK_INPUT(rendered_img); |
|
|
|
|
|
|
|
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(sigmas)); |
|
|
|
|
|
_gs_render( |
|
|
(const float *) sigmas.data_ptr(), |
|
|
(const float *) coords.data_ptr(), |
|
|
(const float *) colors.data_ptr(), |
|
|
(float *) rendered_img.data_ptr(), |
|
|
s, h, w, c, dmax); |
|
|
} |
|
|
|
|
|
void gs_render_backward( |
|
|
torch::Tensor &sigmas, |
|
|
torch::Tensor &coords, |
|
|
torch::Tensor &colors, |
|
|
torch::Tensor &grads, |
|
|
torch::Tensor &grads_sigmas, |
|
|
torch::Tensor &grads_coords, |
|
|
torch::Tensor &grads_colors, |
|
|
const int s, |
|
|
const int h, |
|
|
const int w, |
|
|
const int c, |
|
|
const float dmax |
|
|
){ |
|
|
|
|
|
CHECK_INPUT(sigmas); |
|
|
CHECK_INPUT(coords); |
|
|
CHECK_INPUT(colors); |
|
|
CHECK_INPUT(grads); |
|
|
CHECK_INPUT(grads_sigmas); |
|
|
CHECK_INPUT(grads_coords); |
|
|
CHECK_INPUT(grads_colors); |
|
|
|
|
|
|
|
|
|
|
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(sigmas)); |
|
|
|
|
|
_gs_render_backward( |
|
|
(const float *) sigmas.data_ptr(), |
|
|
(const float *) coords.data_ptr(), |
|
|
(const float *) colors.data_ptr(), |
|
|
(const float *) grads.data_ptr(), |
|
|
(float *) grads_sigmas.data_ptr(), |
|
|
(float *) grads_coords.data_ptr(), |
|
|
(float *) grads_colors.data_ptr(), |
|
|
s, h, w, c, dmax); |
|
|
} |
|
|
|
|
|
PYBIND11_MODULE( TORCH_EXTENSION_NAME, m) { |
|
|
m.def( "gs_render", |
|
|
&gs_render, |
|
|
"cuda forward wrapper"); |
|
|
m.def( "gs_render_backward", |
|
|
&gs_render_backward, |
|
|
"cuda backward wrapper"); |
|
|
} |
|
|
|