Spaces:
Running
on
Zero
Running
on
Zero
| from typing import * | |
| from torch import Tensor | |
| import os | |
| import numpy as np | |
| from plyfile import PlyData, PlyElement | |
| import torch | |
| from diff_gaussian_rasterization import ( | |
| GaussianRasterizationSettings, | |
| GaussianRasterizer, | |
| ) | |
class Camera:
    """Camera description converted into the matrices the Gaussian rasterizer is fed.

    Args:
        C2W: camera-to-world transform, used as a 4x4 matrix (it is inverted,
            transposed, and ``C2W[:3, 3]`` is taken as the camera center).
        fxfycxcy: intrinsics ``(fx, fy, cx, cy)``. The formulas below
            (``tan(fov / 2) = 1 / (2 * fx)``, ``shift = 2 * c - 1``) only make
            sense if these are normalized by image size -- presumably; confirm
            against callers.
        h, w: image height and width.
        znear, zfar: near/far clipping planes of the projection.
    """

    def __init__(
        self,
        C2W: Tensor, fxfycxcy: Tensor, h: int, w: int,
        znear: float = 0.01, zfar: float = 100.,
    ):
        self.fxfycxcy = fxfycxcy.clone().float()
        self.C2W = C2W.clone().float()
        self.W2C = self.C2W.inverse()
        self.znear, self.zfar = znear, zfar
        self.h, self.w = h, w

        fx, fy, cx, cy = (self.fxfycxcy[i] for i in range(4))
        # Despite the attribute names, these hold tan(fov / 2).
        self.tanfovX = 1 / (2 * fx)
        self.tanfovY = 1 / (2 * fy)
        self.fovX = 2 * torch.atan(self.tanfovX)
        self.fovY = 2 * torch.atan(self.tanfovY)
        # Principal-point offset: maps c in [0, 1] to [-1, 1] (0 when centered).
        self.shiftX = 2 * cx - 1
        self.shiftY = 2 * cy - 1

        # Both matrices are stored transposed, so
        # full_proj_transform = W2C^T @ P^T = (P @ W2C)^T.
        self.world_view_transform = self.W2C.transpose(0, 1)
        proj = Camera._projection_matrix(
            self.znear, self.zfar, self.fovX, self.fovY, self.shiftX, self.shiftY,
        )
        self.projection_matrix = proj.transpose(0, 1)
        self.full_proj_transform = self.world_view_transform @ self.projection_matrix
        self.camera_center = self.C2W[:3, 3]

    @staticmethod
    def _projection_matrix(znear, zfar, fovX, fovY, shiftX, shiftY) -> Tensor:
        """Build the (untransposed) 4x4 off-center perspective matrix.

        The depth rows map view-space z = znear to 0 and z = zfar to 1 after
        division by w (w = z, from ``P[3, 2] = 1``).
        """
        half_tan_y = torch.tan((fovY / 2))
        half_tan_x = torch.tan((fovX / 2))
        top = half_tan_y * znear
        bottom = -top
        right = half_tan_x * znear
        left = -right

        z_sign = 1
        mat = torch.zeros(4, 4, device=fovX.device)
        mat[0, 0] = 2 * znear / (right - left)
        mat[1, 1] = 2 * znear / (top - bottom)
        mat[0, 2] = (right + left) / (right - left) + shiftX
        mat[1, 2] = (top + bottom) / (top - bottom) + shiftY
        mat[3, 2] = z_sign
        mat[2, 2] = z_sign * zfar / (zfar - znear)
        mat[2, 3] = -(zfar * znear) / (zfar - znear)
        return mat
class GaussianModel:
    """Container for a set of 3D Gaussians plus PLY save/load.

    All tensors are ``None`` until :meth:`set_data` or :meth:`load_ply` fills
    them. No activation is applied anywhere in this class: values are stored,
    saved and loaded exactly as given.
    """

    def __init__(self):
        self.xyz = None       # centers; presumably (N, 3) -- load_ply builds it that way
        self.rgb = None       # per-point color, written under the f_dc_* PLY names
        self.scale = None     # presumably (N, 3)
        self.rotation = None  # presumably (N, 4) quaternions (rot_0..rot_3)
        self.opacity = None   # (N, 1) as produced by load_ply
        self.sh_degree = 0

    def set_data(self, xyz: Tensor, rgb: Tensor, scale: Tensor, rotation: Tensor, opacity: Tensor):
        """Store the given tensors (no copy) and return ``self`` for chaining."""
        self.xyz = xyz
        self.rgb = rgb
        self.scale = scale
        self.rotation = rotation
        self.opacity = opacity
        return self

    def to(self, device: torch.device = None, dtype: torch.dtype = None) -> "GaussianModel":
        """Convert every tensor in place with ``Tensor.to(device, dtype)``; returns ``self``."""
        self.xyz = self.xyz.to(device, dtype)
        self.rgb = self.rgb.to(device, dtype)
        self.scale = self.scale.to(device, dtype)
        self.rotation = self.rotation.to(device, dtype)
        self.opacity = self.opacity.to(device, dtype)
        return self

    def save_ply(self, path: str, opacity_threshold: float = 0.):
        """Write Gaussians with ``opacity > opacity_threshold`` to a PLY file.

        Missing parent directories are created. Each vertex stores float32
        ``x, y, z, f_dc_*, opacity, scale_*, rot_*`` taken verbatim from the
        tensors, plus uint8 ``red, green, blue`` = ``clip(rgb * 255, 0, 255)``
        (truncated). NOTE(review): the raw ``rgb`` values are written under the
        ``f_dc_*`` names, which standard 3DGS viewers may interpret as SH
        coefficients -- confirm this is intended.
        """
        directory = os.path.dirname(path)
        # os.makedirs("") raises FileNotFoundError for bare names like "out.ply".
        if directory:
            os.makedirs(directory, exist_ok=True)
        elements = self._build_vertex_array(opacity_threshold)
        el = PlyElement.describe(elements, "vertex")
        PlyData([el]).write(path)

    def _build_vertex_array(self, opacity_threshold: float = 0.) -> np.ndarray:
        """Return the structured vertex array written by :meth:`save_ply`."""
        xyz = self.xyz.detach().cpu().numpy()
        f_dc = self.rgb.detach().cpu().numpy()
        opacity = self.opacity.detach().cpu().numpy().reshape(-1, 1)
        scale = self.scale.detach().cpu().numpy()
        rotation = self.rotation.detach().cpu().numpy()
        rgb = (f_dc * 255.).clip(0., 255.).astype(np.uint8)

        # reshape(-1), not squeeze(): squeeze makes a single point's mask 0-d,
        # and indexing with a 0-d boolean adds an axis instead of selecting rows.
        mask = (opacity > opacity_threshold).reshape(-1)
        xyz, f_dc, opacity, scale, rotation, rgb = (
            arr[mask] for arr in (xyz, f_dc, opacity, scale, rotation, rgb)
        )

        float_names = self._construct_list_of_attributes()
        color_names = ("red", "green", "blue")
        dtype_full = [(name, "f4") for name in float_names]
        dtype_full.extend((name, "u1") for name in color_names)

        elements = np.empty(xyz.shape[0], dtype=dtype_full)
        # Column order matches _construct_list_of_attributes().
        float_columns = np.concatenate((xyz, f_dc, opacity, scale, rotation), axis=1)
        for i, name in enumerate(float_names):
            elements[name] = float_columns[:, i]
        for i, name in enumerate(color_names):
            elements[name] = rgb[:, i]
        return elements

    def load_ply(self, path: str):
        """Load Gaussians written by :meth:`save_ply` into this model and return ``self``.

        Expects vertex properties ``x, y, z``, ``f_dc_0..2``, ``opacity``,
        ``scale_0..2`` and ``rot_0..3`` in the first PLY element. The
        ``red/green/blue`` fields are ignored.
        """
        plydata = PlyData.read(path)
        vertex = plydata.elements[0]

        def stack(*names: str) -> np.ndarray:
            # astype also normalizes byte order (torch.from_numpy rejects non-native).
            return np.stack([np.asarray(vertex[n]) for n in names], axis=1).astype(np.float32)

        xyz = stack("x", "y", "z")
        f_dc = stack("f_dc_0", "f_dc_1", "f_dc_2")
        opacity = stack("opacity")
        scale = stack("scale_0", "scale_1", "scale_2")
        rotation = stack("rot_0", "rot_1", "rot_2", "rot_3")

        self.xyz = torch.from_numpy(xyz).float()
        self.rgb = torch.from_numpy(f_dc).float()
        self.opacity = torch.from_numpy(opacity).float()
        self.scale = torch.from_numpy(scale).float()
        self.rotation = torch.from_numpy(rotation).float()
        return self

    def _construct_list_of_attributes(self):
        """Float PLY property names, derived from the channel counts of rgb/scale/rotation."""
        names = ["x", "y", "z"]
        names.extend(f"f_dc_{i}" for i in range(self.rgb.shape[1]))
        names.append("opacity")
        names.extend(f"scale_{i}" for i in range(self.scale.shape[1]))
        names.extend(f"rot_{i}" for i in range(self.rotation.shape[1]))
        return names
def render(
    pc: GaussianModel,
    height: int,
    width: int,
    C2W: Tensor,
    fxfycxcy: Tensor,
    znear: float = 0.01,
    zfar: float = 100.,
    bg_color: Union[Tensor, Tuple[float, float, float]] = (1., 1., 1.),
    scaling_modifier: float = 1.,
    render_dn: bool = False,
):
    """Rasterize the Gaussians in ``pc`` from one camera.

    Args:
        pc: Gaussians to draw. ``pc.rgb`` is passed as precomputed colors
            (``shs=None``). Opacity, scale and rotation are passed as stored,
            cast to float32. ``pc`` itself is not modified.
        height, width: output image size.
        C2W, fxfycxcy: camera pose and intrinsics, forwarded to ``Camera``.
        znear, zfar: clipping planes.
        bg_color: background as a tensor or an RGB tuple; moved to
            ``C2W.device`` as float32.
        scaling_modifier: forwarded to the rasterizer as ``scale_modifier``.
        render_dn: accepted but currently unused.

    Returns:
        dict with ``"image"`` (first rasterizer output) and ``"alpha"``,
        ``"depth"``, ``"normal"``, which are always ``None`` here.
    """
    viewpoint_camera = Camera(C2W, fxfycxcy, height, width, znear, zfar)
    device = C2W.device
    if isinstance(bg_color, Tensor):
        bg_color = bg_color.to(device, dtype=torch.float32)
    else:
        bg_color = torch.tensor(list(bg_color), dtype=torch.float32, device=device)

    # Cast locally: `pc.to(...)` would rebind the caller's model attributes as
    # a side effect of rendering.
    xyz, rgb, opacity, scale, rotation = (
        t.to(dtype=torch.float32)
        for t in (pc.xyz, pc.rgb, pc.opacity, pc.scale, pc.rotation)
    )

    img_h, img_w = int(viewpoint_camera.h), int(viewpoint_camera.w)
    # All-zero offsets (no subpixel shift). Allocated next to the camera tensors
    # instead of on the current "cuda" device, which may differ from C2W's GPU.
    subpixel_offset = torch.zeros((img_h, img_w, 2), dtype=torch.float32, device=device)

    raster_settings = GaussianRasterizationSettings(
        image_height=img_h,
        image_width=img_w,
        tanfovx=viewpoint_camera.tanfovX,
        tanfovy=viewpoint_camera.tanfovY,
        subpixel_offset=subpixel_offset,
        kernel_size=0.,  # cf. Mip-Splatting; not used
        bg=bg_color,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=pc.sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=False,
        debug=False,
    )

    rasterizer = GaussianRasterizer(raster_settings=raster_settings)
    # Rasterize visible Gaussians to image; radii (on screen) are not used.
    image, radii = rasterizer(
        means3D=xyz,
        means2D=torch.zeros_like(xyz, dtype=torch.float32, device=xyz.device),
        shs=None,
        colors_precomp=rgb,
        opacities=opacity,
        scales=scale,
        rotations=rotation,
        cov3D_precomp=None,
    )

    return {
        "image": image,
        "alpha": None,
        "depth": None,
        "normal": None,
    }