import os
import torch
from torch.utils.cpp_extension import load
from torch.autograd import Function
from torch.autograd.function import once_differentiable

# Directory for JIT build artifacts (used only by the commented-out `load` call below)
file_path = os.path.dirname(os.path.abspath(__file__))
build_path = os.path.join(file_path, 'build')
os.makedirs(build_path, exist_ok=True)
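# To JIT-compile the extension from source instead of importing the prebuilt
# gscuda module, uncomment the block below and remove the `import gscuda` lines.
# GSWrapper = load(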
#         name="gscuda",
#         # sources=["gs_cuda/gswrapper.cpp", "gs_cuda/gs.cu"],
#         sources=[os.path.join(file_path, "gswrapper.cpp"),
#                  os.path.join(file_path, "gs.cu")],
#         build_directory=build_path,
#         verbose=True)

# Prebuilt extension module exposing gs_render and gs_render_backward
import gscuda
GSWrapper = gscuda

class GSCUDA(Function):
    """Autograd wrapper around the gscuda Gaussian-splatting render kernels."""

    @staticmethod
    def forward(ctx, sigmas, coords, colors, rendered_img, dmax):
        # rendered_img is a preallocated (h, w, c) buffer that the kernel fills in place
        ctx.save_for_backward(sigmas, coords, colors)
        ctx.dmax = dmax
        h, w, c = rendered_img.shape
        s = sigmas.shape[0]  # number of Gaussians
        GSWrapper.gs_render(sigmas, coords, colors, rendered_img, s, h, w, c, dmax)
        # Tell autograd that rendered_img was modified in place
        ctx.mark_dirty(rendered_img)
        return rendered_img

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        sigmas, coords, colors = ctx.saved_tensors
        dmax = ctx.dmax
        h, w, c = grad_output.shape
        s = sigmas.shape[0]
        # The backward kernel fills these zero-initialized gradient buffers
        grads_sigmas = torch.zeros_like(sigmas)
        grads_coords = torch.zeros_like(coords)
        grads_colors = torch.zeros_like(colors)
        GSWrapper.gs_render_backward(sigmas, coords, colors, grad_output.contiguous(),
                                     grads_sigmas, grads_coords, grads_colors,
                                     s, h, w, c, dmax)
        # No gradients for the output buffer or for dmax
        return grads_sigmas, grads_coords, grads_colors, None, None

def gaussiansplatting_render(sigmas, coords, colors, image_size, dmax=100):
    """Render an (h, w, c) float32 image from N Gaussians with the gscuda kernel.

    image_size: (h, w, ...); only the first two entries are used.
    """
    sigmas = sigmas.contiguous()  # (N, 3)
    coords = coords.contiguous()  # (N, 2)
    colors = colors.contiguous()  # (N, c)
    h, w = image_size[:2]
    c = colors.shape[-1]
    rendered_img = torch.zeros(h, w, c, device=colors.device, dtype=torch.float32)
    return GSCUDA.apply(sigmas, coords, colors, rendered_img, dmax)
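
# Minimal CPU sketch of the same autograd pattern GSCUDA uses, so the
# forward/backward plumbing can be exercised without CUDA or the gscuda
# extension. _ToyBufferRender is hypothetical: its "renderer" is a plain
# weighted sum, out[h, w, :] = scale * sum_n weights[n, h, w] * colors[n, :],
# standing in for the real kernel, whose math is not shown in this file.
# Like GSCUDA, it fills a caller-provided buffer in place, marks that buffer
# dirty, and returns None for the buffer and for the non-tensor argument.
class _ToyBufferRender(Function):

    @staticmethod
    def forward(ctx, colors, weights, out, scale):
        # colors: (N, c), weights: (N, h, w), out: preallocated (h, w, c) buffer
        ctx.save_for_backward(colors, weights)
        ctx.scale = scale
        out.copy_(scale * torch.einsum("nhw,nc->hwc", weights, colors))
        ctx.mark_dirty(out)
        return out

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_out):
        colors, weights = ctx.saved_tensors
        s = ctx.scale
        # Analytic gradients of the weighted sum w.r.t. colors and weights
        grad_colors = s * torch.einsum("nhw,hwc->nc", weights, grad_out)
        grad_weights = s * torch.einsum("nc,hwc->nhw", colors, grad_out)
        # One entry per forward input: no gradient for out or scale
        return grad_colors, grad_weights, None, None

# A cheap way to validate such a Function is to compare its gradients, in
# float64, with plain autograd applied to the equivalent einsum expression.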

if __name__ == "__main__":
    sigmas = torch.randn(10, 3).cuda()
    coords = torch.randn(10, 2).cuda()
    colors = torch.randn(10, 3).cuda()
    image_size = (100, 100)
    dmax = 0.1
    rendered_img = gaussiansplatting_render(sigmas, coords, colors, image_size, dmax)
    print(rendered_img.shape)