|
import os |
|
import torch |
|
from torch.utils.cpp_extension import load |
|
from torch.autograd import Function |
|
from torch.autograd.function import once_differentiable |
|
|
|
|
|
build_path = os.path.join(os.path.split(os.path.abspath(__file__))[0], 'build') |
|
os.makedirs(build_path, exist_ok=True) |
|
|
|
file_path = os.path.split(os.path.abspath(__file__))[0] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import gscuda |
|
GSWrapper = gscuda |
|
|
|
class GSCUDA(Function): |
|
|
|
@staticmethod |
|
def forward(ctx, sigmas, coords, colors, rendered_img, dmax): |
|
ctx.save_for_backward(sigmas, coords, colors) |
|
ctx.dmax = dmax |
|
h, w, c = rendered_img.shape |
|
s = sigmas.shape[0] |
|
GSWrapper.gs_render(sigmas, coords, colors, rendered_img, s, h, w, c, dmax) |
|
return rendered_img |
|
|
|
@staticmethod |
|
@once_differentiable |
|
def backward(ctx, grad_output): |
|
sigmas, coords, colors = ctx.saved_tensors |
|
dmax = ctx.dmax |
|
h, w, c = grad_output.shape |
|
s = sigmas.shape[0] |
|
grads_sigmas = torch.zeros_like(sigmas) |
|
grads_coords = torch.zeros_like(coords) |
|
grads_colors = torch.zeros_like(colors) |
|
GSWrapper.gs_render_backward(sigmas, coords, colors, grad_output.contiguous(), grads_sigmas, grads_coords, grads_colors, s, h, w, c, dmax) |
|
return (grads_sigmas, grads_coords, grads_colors, None, None) |
|
|
|
def gaussiansplatting_render(sigmas, coords, colors, image_size,dmax=100): |
|
sigmas = sigmas.contiguous() |
|
coords = coords.contiguous() |
|
colors = colors.contiguous() |
|
h, w = image_size[:2] |
|
c = colors.shape[-1] |
|
rendered_img = torch.zeros(h, w, c).to(colors.device).to(torch.float32) |
|
return GSCUDA.apply(sigmas, coords, colors, rendered_img, dmax) |
|
|
|
if __name__ == "__main__": |
|
sigmas = torch.randn(10, 3).cuda() |
|
coords = torch.randn(10, 2).cuda() |
|
colors = torch.randn(10, 3).cuda() |
|
image_size = (100, 100) |
|
dmax = 0.1 |
|
rendered_img = gaussiansplatting_render(sigmas, coords, colors, image_size, dmax) |
|
print(rendered_img.shape) |