|
import torch |
|
import numpy as np |
|
import torch.nn.functional as F |
|
import math |
|
import torch.nn as nn |
|
|
|
import torchvision.utils |
|
from torchvision.utils import save_image |
|
|
|
|
|
def rendering_python(sigma_x, sigma_y, rho, coords, colours_with_alpha, sr_size, step_size, device):
    """Render N anisotropic 2D Gaussians onto a (3, H, W) canvas in pure PyTorch.

    Each Gaussian is evaluated on a local num_step x num_step grid, peak-
    normalized, then warped into the full image at its centre via
    F.affine_grid / F.grid_sample. Gaussians are processed in chunks of
    2000 to bound memory; contributions are summed additively.

    Args:
        sigma_x, sigma_y: (N, 1) standard deviations along the two axes.
        rho: (N, 1) correlation coefficient, expected in (-1, 1).
        coords: (N, 2) Gaussian centres in [-1, 1] normalized coordinates.
            NOTE(review): column 0 is paired with the width scale in the
            affine translation below, implying column 0 is x — confirm
            against the caller.
        colours_with_alpha: (N, 3) premultiplied RGB weights per Gaussian.
        sr_size: (H, W) of the rendered image.
        step_size: spacing of the local evaluation grid; also fixes the grid
            resolution (num_step = 20 / step_size samples per axis).
        device: torch device on which intermediate tensors are created.

    Returns:
        (3, H, W) tensor: additive composite of all Gaussians.

    Raises:
        ValueError: if any covariance matrix has a negative determinant.
    """
    sr_h, sr_w = sr_size[0], sr_size[1]
    num_gs = sigma_x.shape[0]

    # Build per-Gaussian 2x2 covariance [[sx^2, r*sx*sy], [r*sx*sy, sy^2]].
    # The trailing None turns (N, 1) into (N, 1, 1) so the stacked matrices
    # later broadcast against the (num_step, num_step) evaluation grid.
    sigma_x = sigma_x[...,None]
    sigma_y = sigma_y[...,None]
    rho = rho[...,None]
    covariance = torch.stack(
        [torch.stack([sigma_x**2, rho*sigma_x*sigma_y], dim=-1),
        torch.stack([rho*sigma_x*sigma_y, sigma_y**2], dim=-1)],
        dim=-2
    )

    # det = sx^2 * sy^2 * (1 - r^2); a negative value means the parameters
    # do not describe a valid (positive semi-definite) covariance.
    determinant = (sigma_x**2) * (sigma_y**2) - (rho * sigma_x * sigma_y)**2
    if (determinant < 0).any():
        raise ValueError("Covariance matrix must be positive semi-definite")

    inv_covariance = torch.inverse(covariance)

    # Local evaluation axes: num_step samples spaced step_size apart,
    # shifted so they are centred on zero (total span of ~20 units).
    num_step = int(10 * 2 / step_size)
    ax_h_batch = torch.tensor([i * step_size for i in range(num_step)]).to(device)[None]
    ax_h_batch -= ax_h_batch.mean()
    ax_w_batch = torch.tensor([i * step_size for i in range(num_step)]).to(device)[None]
    ax_w_batch -= ax_w_batch.mean()

    # Expand the two 1-D axes into a full (1, num_step, num_step) meshgrid.
    ax_batch_expanded_x = ax_h_batch.unsqueeze(-1).expand(-1, -1, num_step)
    ax_batch_expanded_y = ax_w_batch.unsqueeze(1).expand(-1, num_step, -1)

    xx, yy = ax_batch_expanded_x, ax_batch_expanded_y

    # (1, num_step, num_step, 2): per-sample offset from the Gaussian centre.
    xy = torch.stack([xx, yy], dim=-1)

    # Chunk the Gaussians so the (B, num_step, num_step) kernels and the
    # (B, 3, H, W) warped images fit in memory.
    max_buffer = 2000
    final_image = torch.zeros((3, sr_h, sr_w), device=device)
    for i in range(num_gs // max_buffer + 1):

        s_idx, e_idx = i * max_buffer, min((i + 1) * max_buffer, num_gs)
        buffer_size = e_idx - s_idx
        if buffer_size == 0:
            # num_gs was an exact multiple of max_buffer; nothing left.
            break

        buff_inv_covariance = inv_covariance[s_idx:e_idx]
        buff_covariance = covariance[s_idx:e_idx]
        buffer_pixel_coords = coords[s_idx:e_idx]
        # (B, 3, 1, 1): each colour scales its whole warped kernel below.
        buffer_alpha = colours_with_alpha[s_idx:e_idx].unsqueeze(-1).unsqueeze(-1)

        # Quadratic form -0.5 * xy^T Sigma^-1 xy evaluated over the grid;
        # the size-1 dims of the inverse covariance broadcast against it.
        z = torch.einsum('b...i,b...ij,b...j->b...', xy, -0.5 * buff_inv_covariance, xy)
        # Bivariate normal density on the local grid.
        kernel = torch.exp(z) / (2 * torch.tensor(np.pi, device=device) * torch.sqrt(torch.det(buff_covariance)).view(buffer_size, 1, 1))

        # Peak-normalize each kernel to 1 (epsilon guards tiny maxima) so
        # colours_with_alpha directly sets each Gaussian's peak intensity.
        kernel_max = kernel.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[0]
        kernel_normalized = kernel / (kernel_max + 1e-4)
        # Replicate the single-channel kernel across RGB -> (B, 3, ns, ns).
        kernel_reshaped = kernel_normalized.repeat(1, 3, 1).view(buffer_size * 3, num_step, num_step)
        kernel_reshaped = kernel_reshaped.unsqueeze(0).reshape(buffer_size, 3, num_step, num_step)

        b, c, h, w = kernel_reshaped.shape

        # Per-Gaussian affine transform: scale the local num_step patch to
        # the output resolution and translate it to the Gaussian's centre
        # (grid_sample works in [-1, 1] normalized coordinates).
        theta = torch.zeros(b, 2, 3, dtype=torch.float32, device=device)
        theta[:, 0, 0] = 1 * sr_w / num_step
        theta[:, 1, 1] = 1 * sr_h / num_step
        theta[:, 0, 2] = -buffer_pixel_coords[:, 0] * sr_w / num_step
        theta[:, 1, 2] = -buffer_pixel_coords[:, 1] * sr_h / num_step

        grid = F.affine_grid(theta, size=(b, c, sr_h, sr_w), align_corners=False)
        kernel_reshaped_translated = F.grid_sample(kernel_reshaped, grid,
                                                   align_corners=False)
        # Weight by colour/alpha and accumulate additively into the canvas.
        buffer_final_image = buffer_alpha * kernel_reshaped_translated
        final_image += buffer_final_image.sum(0)

    return final_image
|
|
|
def rendering_cuda(sigma_x, sigma_y, rho, coords, colours_with_alpha, sr_size, step_size, device):
    """Render 2D Gaussians with the custom CUDA kernel.

    Args:
        sigma_x, sigma_y: (N, 1) per-Gaussian standard deviations.
        rho: (N, 1) correlation coefficients in (-1, 1).
        coords: (N, 2) Gaussian centres in [-1, 1] normalized coordinates.
        colours_with_alpha: (N, 3) premultiplied RGB weights.
        sr_size: (H, W) output resolution.
        step_size: sampling step used to rescale the sigmas.
        device: torch device for the output buffer.

    Returns:
        (3, H, W) float32 rendered image.
    """
    from utils.gs_cuda.gswrapper import GSCUDA

    # Rescale sigmas into the kernel's normalized per-axis units; note the
    # (sigma_y, sigma_x) ordering expected by GSCUDA.
    sigmas = torch.cat([sigma_y/step_size*2/(sr_size[1] - 1),
                        sigma_x/step_size*2/(sr_size[0] - 1),
                        rho], dim=-1).contiguous()
    # Fix: operate on a copy so the caller's `coords` tensor is not mutated
    # in place (the original double-transformed the centres on repeated calls
    # with the same tensor).
    coords = coords.clone()
    # Shift centres from pixel-centre convention to grid_sample-style [-1, 1].
    coords[:, 0] = (coords[:, 0] + 1 - 1/sr_size[1]) * sr_size[1] / (sr_size[1] - 1) - 1.0
    coords[:, 1] = (coords[:, 1] + 1 - 1/sr_size[0]) * sr_size[0] / (sr_size[0] - 1) - 1.0
    colours_with_alpha = colours_with_alpha.contiguous()
    # HWC accumulation buffer expected by the CUDA op.
    rendered_img = torch.zeros(sr_size[0], sr_size[1], 3, dtype=torch.float32, device=device).contiguous()

    final_image = GSCUDA.apply(sigmas, coords, colours_with_alpha, rendered_img)
    # HWC -> CHW for consistency with rendering_python.
    final_image = final_image.permute(2, 0, 1).contiguous()
    return final_image
|
|
|
def rendering_cuda_buffer(sigma_x, sigma_y, rho, coords, colours_with_alpha, sr_size, step_size, device, buffer_size = 1000000):
    """Render 2D Gaussians with the custom CUDA kernel, in chunks.

    Same contract as `rendering_cuda`, but the Gaussians are streamed
    through the kernel `buffer_size` at a time; GSCUDA accumulates into
    the running image, so chaining its return composites all chunks.

    Args:
        sigma_x, sigma_y: (N, 1) per-Gaussian standard deviations.
        rho: (N, 1) correlation coefficients in (-1, 1).
        coords: (N, 2) Gaussian centres in [-1, 1] normalized coordinates.
        colours_with_alpha: (N, 3) premultiplied RGB weights.
        sr_size: (H, W) output resolution.
        step_size: sampling step used to rescale the sigmas.
        device: torch device for the output buffer.
        buffer_size: number of Gaussians rendered per CUDA call.

    Returns:
        (3, H, W) float32 rendered image.
    """
    from utils.gs_cuda.gswrapper import GSCUDA

    # Rescale sigmas into the kernel's normalized per-axis units; note the
    # (sigma_y, sigma_x) ordering expected by GSCUDA.
    sigmas = torch.cat([sigma_y/step_size*2/(sr_size[1] - 1),
                        sigma_x/step_size*2/(sr_size[0] - 1),
                        rho], dim=-1).contiguous()
    # Fix: operate on a copy so the caller's `coords` tensor is not mutated
    # in place (the original double-transformed the centres on repeated calls).
    coords = coords.clone()
    coords[:, 0] = (coords[:, 0] + 1 - 1/sr_size[1]) * sr_size[1] / (sr_size[1] - 1) - 1.0
    coords[:, 1] = (coords[:, 1] + 1 - 1/sr_size[0]) * sr_size[0] / (sr_size[0] - 1) - 1.0
    colours_with_alpha = colours_with_alpha.contiguous()
    final_image = torch.zeros(sr_size[0], sr_size[1], 3, dtype=torch.float32, device=device).contiguous()

    # Fix: ceil division instead of `len // buffer_size + 1`, which issued a
    # trailing kernel call on empty slices whenever the Gaussian count was an
    # exact multiple of buffer_size.
    buffer_num = (len(sigma_x) + buffer_size - 1) // buffer_size
    for buffer_id in range(buffer_num):
        idx_start, idx_end = buffer_id * buffer_size, (buffer_id + 1) * buffer_size
        final_image = GSCUDA.apply(sigmas[idx_start:idx_end], coords[idx_start:idx_end],
                                   colours_with_alpha[idx_start:idx_end], final_image)

    # HWC -> CHW for consistency with rendering_python.
    final_image = final_image.permute(2, 0, 1).contiguous()
    return final_image
|
|
|
def rendering_cuda_dmax(sigma_x, sigma_y, rho, coords, colours_with_alpha, sr_size, step_size, device, dmax=1):
    """Render 2D Gaussians with the dmax-limited CUDA kernel.

    Same contract as `rendering_cuda`, but forwards `dmax` (an extent limit
    interpreted by the gs_cuda_dmax kernel) to the CUDA op.

    Args:
        sigma_x, sigma_y: (N, 1) per-Gaussian standard deviations.
        rho: (N, 1) correlation coefficients in (-1, 1).
        coords: (N, 2) Gaussian centres in [-1, 1] normalized coordinates.
        colours_with_alpha: (N, 3) premultiplied RGB weights.
        sr_size: (H, W) output resolution.
        step_size: sampling step used to rescale the sigmas.
        device: torch device for the output buffer.
        dmax: extent limit forwarded to the kernel.

    Returns:
        (3, H, W) float32 rendered image.
    """
    from utils.gs_cuda_dmax.gswrapper import GSCUDA

    # Rescale sigmas into the kernel's normalized per-axis units; note the
    # (sigma_y, sigma_x) ordering expected by GSCUDA.
    sigmas = torch.cat([sigma_y/step_size*2/(sr_size[1] - 1),
                        sigma_x/step_size*2/(sr_size[0] - 1),
                        rho], dim=-1).contiguous()
    # Fix: operate on a copy so the caller's `coords` tensor is not mutated
    # in place (the original double-transformed the centres on repeated calls).
    coords = coords.clone()
    coords[:, 0] = (coords[:, 0] + 1 - 1/sr_size[1]) * sr_size[1] / (sr_size[1] - 1) - 1.0
    coords[:, 1] = (coords[:, 1] + 1 - 1/sr_size[0]) * sr_size[0] / (sr_size[0] - 1) - 1.0
    colours_with_alpha = colours_with_alpha.contiguous()
    # HWC accumulation buffer expected by the CUDA op.
    rendered_img = torch.zeros(sr_size[0], sr_size[1], 3, dtype=torch.float32, device=device).contiguous()

    final_image = GSCUDA.apply(sigmas, coords, colours_with_alpha, rendered_img, dmax)
    # HWC -> CHW for consistency with rendering_python.
    final_image = final_image.permute(2, 0, 1).contiguous()
    return final_image
|
|
|
def rendering_cuda_dmax_buffer(sigma_x, sigma_y, rho, coords, colours_with_alpha, sr_size, step_size, device, dmax=1, buffer_size = 1000000):
    """Render 2D Gaussians with the dmax-limited CUDA kernel, in chunks.

    Same contract as `rendering_cuda_dmax`, but the Gaussians are streamed
    through the kernel `buffer_size` at a time; GSCUDA accumulates into
    the running image, so chaining its return composites all chunks.

    Args:
        sigma_x, sigma_y: (N, 1) per-Gaussian standard deviations.
        rho: (N, 1) correlation coefficients in (-1, 1).
        coords: (N, 2) Gaussian centres in [-1, 1] normalized coordinates.
        colours_with_alpha: (N, 3) premultiplied RGB weights.
        sr_size: (H, W) output resolution.
        step_size: sampling step used to rescale the sigmas.
        device: torch device for the output buffer.
        dmax: extent limit forwarded to the kernel.
        buffer_size: number of Gaussians rendered per CUDA call.

    Returns:
        (3, H, W) float32 rendered image.
    """
    from utils.gs_cuda_dmax.gswrapper import GSCUDA

    # Rescale sigmas into the kernel's normalized per-axis units; note the
    # (sigma_y, sigma_x) ordering expected by GSCUDA.
    sigmas = torch.cat([sigma_y/step_size*2/(sr_size[1] - 1),
                        sigma_x/step_size*2/(sr_size[0] - 1),
                        rho], dim=-1).contiguous()
    # Fix: operate on a copy so the caller's `coords` tensor is not mutated
    # in place (the original double-transformed the centres on repeated calls).
    coords = coords.clone()
    coords[:, 0] = (coords[:, 0] + 1 - 1/sr_size[1]) * sr_size[1] / (sr_size[1] - 1) - 1.0
    coords[:, 1] = (coords[:, 1] + 1 - 1/sr_size[0]) * sr_size[0] / (sr_size[0] - 1) - 1.0
    colours_with_alpha = colours_with_alpha.contiguous()
    final_image = torch.zeros(sr_size[0], sr_size[1], 3, dtype=torch.float32, device=device).contiguous()

    # Fix: ceil division instead of `len // buffer_size + 1`, which issued a
    # trailing kernel call on empty slices whenever the Gaussian count was an
    # exact multiple of buffer_size.
    buffer_num = (len(sigma_x) + buffer_size - 1) // buffer_size
    for buffer_id in range(buffer_num):
        idx_start, idx_end = buffer_id * buffer_size, (buffer_id + 1) * buffer_size
        final_image = GSCUDA.apply(sigmas[idx_start:idx_end], coords[idx_start:idx_end],
                                   colours_with_alpha[idx_start:idx_end], final_image, dmax)

    # HWC -> CHW for consistency with rendering_python.
    final_image = final_image.permute(2, 0, 1).contiguous()
    return final_image
|
|
|
|
|
def generate_2D_gaussian_splatting_step(sr_size, gs_parameters, scale, scale_modify,
                                        sample_coords = None, default_step_size = 1.2,
                                        cuda_rendering=True, mode = 'scale_modify',
                                        if_dmax = True,
                                        dmax_mode = 'fix',
                                        dmax = 25):
    """Decode raw Gaussian parameters and render them to an image.

    Args:
        sr_size: (H, W) output resolution.
        gs_parameters: (N, 9) raw per-Gaussian parameters, columns:
            0 sigma_x logit, 1 sigma_y logit, 2 rho (pre-tanh), 3 alpha logit,
            4:7 colour logits, 7:9 centre in [0, 1].
        scale: rendering scale used when mode == 'scale'.
        scale_modify: pair of (equal) scales used when mode == 'scale_modify'.
        sample_coords: optional sequence of (row, col) pixel coordinates; when
            given, the sampled RGB values (3, K) are returned instead of the
            full image.
        default_step_size: base grid step, divided by the selected scale.
        cuda_rendering: use the CUDA renderers instead of the Python fallback.
        mode: 'scale' or 'scale_modify' — selects which scale argument applies.
        if_dmax: whether to use the dmax-limited CUDA kernel.
        dmax_mode: 'fix' forwards dmax as given; 'dynamic' rescales it by the
            smaller image dimension.
        dmax: extent limit forwarded to the dmax CUDA kernel.

    Returns:
        (3, H, W) image, or (3, K) sampled values when sample_coords is given.

    Raises:
        ValueError: on an unknown `mode` or `dmax_mode`.
    """
    if mode == 'scale':
        final_scale = scale
    elif mode == 'scale_modify':
        assert scale_modify[0] == scale_modify[1], f"scale_modify is not the same-{scale_modify}"
        final_scale = scale_modify[0]
    else:
        # Fix: an unknown mode previously fell through and crashed later with
        # a confusing NameError on `final_scale`; fail loudly here instead
        # (mirrors the existing dmax_mode validation below).
        raise ValueError(f"mode-{mode} must be scale or scale_modify")
    step_size = default_step_size / final_scale

    # Squash raw parameters into valid ranges; the small epsilons keep the
    # covariance non-degenerate and |rho| strictly below 1.
    sigma_x = 0.99999 * torch.sigmoid(gs_parameters[:, 0:1]) + 1e-6
    sigma_y = 0.99999 * torch.sigmoid(gs_parameters[:, 1:2]) + 1e-6
    rho = 0.999999 * torch.tanh(gs_parameters[:, 2:3])
    alpha = torch.sigmoid(gs_parameters[:, 3:4])
    colours = torch.sigmoid(gs_parameters[:, 4:7])
    coords = (gs_parameters[:, 7:9] * 2 - 1)  # [0, 1] -> [-1, 1]
    colours_with_alpha = colours * alpha

    if cuda_rendering:
        if if_dmax:
            if dmax_mode == 'dynamic':
                # Scale the extent limit relative to the smaller image side.
                dmax = (dmax + 2) / min(sr_size[0], sr_size[1])
            elif dmax_mode == 'fix':
                pass
            else:
                raise ValueError(f"dmax_mode-{dmax_mode} must be fix or dynamic")
            final_image = rendering_cuda_dmax(sigma_x, sigma_y, rho, coords, colours_with_alpha, sr_size, step_size, dmax=dmax, device=sigma_x.device)
        else:
            final_image = rendering_cuda(sigma_x, sigma_y, rho, coords, colours_with_alpha, sr_size, step_size, device=sigma_x.device)
    else:
        final_image = rendering_python(sigma_x, sigma_y, rho, coords, colours_with_alpha, sr_size, step_size, device=sigma_x.device)

    if sample_coords is not None:
        # Gather RGB at the requested pixel coordinates -> (3, K).
        sample_RGB_values = [final_image[:, coord[0], coord[1]] for coord in sample_coords]
        final_image = torch.stack(sample_RGB_values, dim = 1)
    return final_image
|
|
|
def generate_2D_gaussian_splatting_step_buffer(sr_size, gs_parameters, scale, scale_modify,
                                               sample_coords = None, default_step_size = 1.2,
                                               cuda_rendering=True, mode = 'scale_modify',
                                               if_dmax = True,
                                               dmax_mode = 'fix',
                                               dmax = 25,
                                               buffer_size = 4000000):
    """Decode raw Gaussian parameters and render them via the buffered kernels.

    Same contract as `generate_2D_gaussian_splatting_step`, but the CUDA
    paths stream the Gaussians through the kernel `buffer_size` at a time.

    Args:
        sr_size: (H, W) output resolution.
        gs_parameters: (N, 9) raw per-Gaussian parameters, columns:
            0 sigma_x logit, 1 sigma_y logit, 2 rho (pre-tanh), 3 alpha logit,
            4:7 colour logits, 7:9 centre in [0, 1].
        scale: rendering scale used when mode == 'scale'.
        scale_modify: pair of (equal) scales used when mode == 'scale_modify'.
        sample_coords: optional sequence of (row, col) pixel coordinates; when
            given, the sampled RGB values (3, K) are returned instead of the
            full image.
        default_step_size: base grid step, divided by the selected scale.
        cuda_rendering: use the CUDA renderers instead of the Python fallback.
        mode: 'scale' or 'scale_modify' — selects which scale argument applies.
        if_dmax: whether to use the dmax-limited CUDA kernel.
        dmax_mode: 'fix' forwards dmax as given; 'dynamic' rescales it by the
            smaller image dimension.
        dmax: extent limit forwarded to the dmax CUDA kernel.
        buffer_size: number of Gaussians rendered per CUDA call.

    Returns:
        (3, H, W) image, or (3, K) sampled values when sample_coords is given.

    Raises:
        ValueError: on an unknown `mode` or `dmax_mode`.
    """
    if mode == 'scale':
        final_scale = scale
    elif mode == 'scale_modify':
        assert scale_modify[0] == scale_modify[1], f"scale_modify is not the same-{scale_modify}"
        final_scale = scale_modify[0]
    else:
        # Fix: an unknown mode previously fell through and crashed later with
        # a confusing NameError on `final_scale`; fail loudly here instead
        # (mirrors the existing dmax_mode validation below).
        raise ValueError(f"mode-{mode} must be scale or scale_modify")
    step_size = default_step_size / final_scale

    # Squash raw parameters into valid ranges; the small epsilons keep the
    # covariance non-degenerate and |rho| strictly below 1.
    sigma_x = 0.99999 * torch.sigmoid(gs_parameters[:, 0:1]) + 1e-6
    sigma_y = 0.99999 * torch.sigmoid(gs_parameters[:, 1:2]) + 1e-6
    rho = 0.999999 * torch.tanh(gs_parameters[:, 2:3])
    alpha = torch.sigmoid(gs_parameters[:, 3:4])
    colours = torch.sigmoid(gs_parameters[:, 4:7])
    coords = (gs_parameters[:, 7:9] * 2 - 1)  # [0, 1] -> [-1, 1]
    colours_with_alpha = colours * alpha

    if cuda_rendering:
        if if_dmax:
            if dmax_mode == 'dynamic':
                # Scale the extent limit relative to the smaller image side.
                dmax = (dmax + 2) / min(sr_size[0], sr_size[1])
            elif dmax_mode == 'fix':
                pass
            else:
                raise ValueError(f"dmax_mode-{dmax_mode} must be fix or dynamic")
            final_image = rendering_cuda_dmax_buffer(sigma_x, sigma_y, rho, coords, colours_with_alpha,
                                                     sr_size, step_size, dmax=dmax, device=sigma_x.device,
                                                     buffer_size = buffer_size)
        else:
            final_image = rendering_cuda_buffer(sigma_x, sigma_y, rho, coords, colours_with_alpha,
                                                sr_size, step_size, device=sigma_x.device,
                                                buffer_size = buffer_size)
    else:
        final_image = rendering_python(sigma_x, sigma_y, rho, coords, colours_with_alpha, sr_size, step_size, device=sigma_x.device)

    if sample_coords is not None:
        # Gather RGB at the requested pixel coordinates -> (3, K).
        sample_RGB_values = [final_image[:, coord[0], coord[1]] for coord in sample_coords]
        final_image = torch.stack(sample_RGB_values, dim = 1)
    return final_image