mt-cly
init
909940e
raw
history blame
5 kB
import torch
from gswrapper import gaussiansplatting_render
def torch_version(sigmas, coords, colors, image_size, dmax=100):
h, w = image_size
c = colors.shape[-1]
if h >= 50 or w >= 50:
logger.warning(f'too large values for h({h}), w({w}), torch version would be slow')
rendered_img = torch.zeros(h, w, c).to(colors.device).to(torch.float32)
for hi in range(h):
for wi in range(w):
curh = 2*hi/(h-1)-1.0
curw = 2*wi/(w-1)-1.0
v = (curw-coords[:,0])**2/sigmas[:,0]**2
v -= (2*sigmas[:,2])*(curw-coords[:,0])*(curh-coords[:,1])/sigmas[:,0]/sigmas[:,1]
v += (curh-coords[:,1])**2/sigmas[:,1]**2
v *= -1.0/(2.0*(1-sigmas[:,2]**2))
v = torch.exp(v)
mask_w = abs(curw-coords[:,0]) <= dmax
mask_h = abs(curh-coords[:,1]) <= dmax
mask = torch.logical_and(mask_w, mask_h)
for ci in range(c):
rendered_img[hi, wi, ci] = torch.sum((v*colors[:, ci])[mask])
return rendered_img
if __name__ == "__main__":
s = 4 # the number of gs
image_size = (10, 10)
for _ in range(1):
print(f"--------------------------- begins --------------------------------")
sigmas = 0.999*torch.rand(s, 3).to(torch.float32).to("cuda")
sigmas[:,:2] = 5*sigmas[:, :2]
coords = 2*torch.rand(s, 2).to(torch.float32).to("cuda")-1.0
colors = torch.rand(s, 3).to(torch.float32).to("cuda")
# colors = torch.rand(s, 5).to(torch.float32).to("cuda")
dmax = 0.5
# sigmas = torch.Tensor([[0.9196, 0.3979, 0.7784]]).to(torch.float32).to("cuda")
# coords = torch.Tensor([[-0.0469, -0.1726]]).to(torch.float32).to("cuda")
# colors = torch.Tensor([[0.3775, 0.2346, 0.1513]]).to(torch.float32).to("cuda")
# colors = torch.ones_like(coords[:,0:1])
print(f"sigmas: {sigmas}, \ncoords:{coords}, \ncolors:{colors}\ndmax:{dmax}")
# --- check forward ---
with torch.no_grad():
rendered_img_th = torch_version(sigmas,coords,colors,image_size,dmax)
rendered_img_cuda = gaussiansplatting_render(sigmas,coords,colors,image_size,dmax)
#
distance = (rendered_img_th-rendered_img_cuda)**2
print(f"check forward - torch: {rendered_img_th[:2,:2,0]}")
print(f"check forward - cuda: {rendered_img_cuda[:2,:2,0]}")
print(f"check forward - distance: {distance[:2, :2, 0]}")
print(f"check forward - sum: {torch.sum(distance)}\n")
# --- ends ---
# --- check backward ---
sigmas.requires_grad_(True)
coords.requires_grad_(True)
colors.requires_grad_(True)
# sigmas.retain_grad()
# coords.retain_grad()
# colors.retain_grad()
weight = torch.rand_like(rendered_img_th) # make each pixel has different grads
sigmas.grad = None
coords.grad = None
colors.grad = None
rendered_img_th = torch_version(sigmas,coords,colors,image_size,dmax)
loss_th = torch.sum(weight*rendered_img_th)
# loss_th = torch.sum(rendered_img_th)
loss_th.backward()
sigmas_grad_th = sigmas.grad
coords_grad_th = coords.grad
colors_grad_th = colors.grad
sigmas.grad = None
coords.grad = None
colors.grad = None
rendered_img_cuda = gaussiansplatting_render(sigmas,coords,colors,image_size,dmax)
loss_cuda = torch.sum(weight*rendered_img_cuda)
# loss_cuda = torch.sum(rendered_img_cuda)
loss_cuda.backward()
sigmas_grad_cuda = sigmas.grad
coords_grad_cuda = coords.grad
colors_grad_cuda = colors.grad
distance_sigmas_grad = (sigmas_grad_th-sigmas_grad_cuda)**2
distance_coords_grad = (coords_grad_th-coords_grad_cuda)**2
distance_colors_grad = (colors_grad_th-colors_grad_cuda)**2
print(f"check backward - sigmas - torch: {sigmas_grad_th[:2]}")
print(f"check backward - sigmas - cuda: {sigmas_grad_cuda[:2]}")
print(f"check backward - sigmas - distance: {distance_sigmas_grad[:2]}")
print(f"check backward - sigmas - sum: {torch.sum(distance_sigmas_grad)}\n")
print(f"check backward - coords - torch: {coords_grad_th[:2]}")
print(f"check backward - coords - cuda: {coords_grad_cuda[:2]}")
print(f"check backward - coords - distance: {distance_coords_grad[:2]}")
print(f"check backward - coords - sum: {torch.sum(distance_coords_grad)}\n")
print(f"check backward - colors - torch: {colors_grad_th[:2]}")
print(f"check backward - colors - cuda: {colors_grad_cuda[:2]}")
print(f"check backward - colors - distance: {distance_colors_grad[:2]}")
print(f"check backward - colors - sum: {torch.sum(distance_colors_grad)}\n")
print(f"--------------------------- ends --------------------------------\n\n")