from typing import Dict, Optional, Tuple

import torch
from torch import Tensor
from einops import rearrange, repeat

from src.models.gs_render.deferred_bp import deferred_bp
from src.models.gs_render.gs_util import GaussianModel, render
from src.options import Options
from src.utils import unproject_depth


class GaussianRenderer:
    def __init__(self, opt: Options):
        self.opt = opt

        self.scale_activation = lambda x: \
            self.opt.scale_min * x + self.opt.scale_max * (1. - x)  # linear map: 0 -> s_max, 1 -> s_min
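
        # Worked example (values assumed for illustration, not Options defaults):
        # with scale_min=0.005 and scale_max=0.02,
        #   scale_activation(0.) == 0.02   (x=0 gives scale_max)
        #   scale_activation(1.) == 0.005  (x=1 gives scale_min)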

    def render(self,
               model_outputs: Dict[str, Tensor],
               input_C2W: Tensor, input_fxfycxcy: Tensor,
               C2W: Tensor, fxfycxcy: Tensor,
               height: Optional[int] = None,
               width: Optional[int] = None,
               bg_color: Tuple[float, float, float] = (1., 1., 1.),
               scaling_modifier: float = 1.,
               opacity_threshold: float = 0.,
               input_normalized: bool = True,
               in_image_format: bool = True,
    ):
        if not in_image_format:
            assert height is not None and width is not None
            assert "xyz" in model_outputs  # `depth` is only usable in image format
        rgb, scale, rotation, opacity = \
            model_outputs["rgb"], model_outputs["scale"], model_outputs["rotation"], model_outputs["opacity"]
        depth = model_outputs.get("depth", None)
        xyz = model_outputs.get("xyz", None)
        # Exactly one of `depth` and `xyz` must be provided
        assert (depth is None) != (xyz is None)

        # The rendering resolution may differ from the input resolution
        H = height if height is not None else rgb.shape[-2]
        W = width if width is not None else rgb.shape[-1]

        # Reshape for rendering
        if in_image_format:
            rgb = rearrange(rgb, "b v c h w -> b (v h w) c")
            scale = rearrange(scale, "b v c h w -> b (v h w) c")
            rotation = rearrange(rotation, "b v c h w -> b (v h w) c")
            opacity = rearrange(opacity, "b v c h w -> b (v h w) c")

        # Prepare XYZ for rendering
        if xyz is None:
            if input_normalized:
                # [-1, 1] -> (camera-to-origin distance) + [-1, 1]: recenter the
                # normalized depth around each input camera's distance to the origin
                depth = depth + torch.norm(input_C2W[:, :, :3, 3], p=2, dim=2, keepdim=True)[..., None, None]
            xyz = unproject_depth(depth.squeeze(2), input_C2W, input_fxfycxcy)  # [-1, 1]
        xyz = xyz + model_outputs.get("offset", torch.zeros_like(xyz))  # optional predicted position offset
        if in_image_format:
            xyz = rearrange(xyz, "b v c h w -> b (v h w) c")

        # Map from [-1, 1] to valid value ranges
        if input_normalized:
            rgb = rgb * 0.5 + 0.5  # [-1, 1] -> [0, 1]
            scale = self.scale_activation(scale * 0.5 + 0.5)  # [-1, 1] -> [0, 1] -> [s_min, s_max]
            rotation = rotation  # unchanged; already L2-normalized
            opacity = opacity * 0.5 + 0.5  # [-1, 1] -> [0, 1]

        # Filter by opacity
        opacity = (opacity > opacity_threshold) * opacity
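        # The boolean mask zeroes out near-transparent Gaussians; multiplying by
        # `opacity` (rather than replacing it) keeps the result differentiable
        # w.r.t. the opacities of the surviving Gaussians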

        (B, V), device = C2W.shape[:2], C2W.device
        # NOTE: the render resolution `H`/`W` may differ from the input resolution `h`/`w`
        images = torch.zeros(B, V, 3, H, W, dtype=torch.float32, device=device)
        alphas = torch.zeros(B, V, 1, H, W, dtype=torch.float32, device=device)
        depths = torch.zeros(B, V, 1, H, W, dtype=torch.float32, device=device)
        normals = torch.zeros(B, V, 3, H, W, dtype=torch.float32, device=device)

        # One Gaussian point cloud per batch sample
        pcs = []
        for i in range(B):
            pcs.append(GaussianModel().set_data(xyz[i], rgb[i], scale[i], rotation[i], opacity[i]))

        if self.opt.render_type == "defered":  # (sic) spelling kept to match `Options`
            images, alphas, depths, normals = deferred_bp(
                xyz, rgb, scale, rotation, opacity,
                H, W, C2W, fxfycxcy,
                self.opt.deferred_bp_patch_size, GaussianModel(),
                self.opt.znear, self.opt.zfar,
                bg_color,
                scaling_modifier,
                self.opt.coord_weight > 0. or self.opt.normal_weight > 0. or \
                    self.opt.vis_coords or self.opt.vis_normals,  # whether to render depth & normal
            )
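            # `deferred_bp` presumably renders in patches of `deferred_bp_patch_size`
            # and recomputes them during the backward pass to cap peak memory
            # (see src/models/gs_render/deferred_bp.py)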
        else:  # default: per-sample, per-view rendering
            for i in range(B):
                pc = pcs[i]
                for j in range(V):
                    render_results = render(
                        pc, H, W, C2W[i, j], fxfycxcy[i, j],
                        self.opt.znear, self.opt.zfar,
                        bg_color,
                        scaling_modifier,
                        self.opt.coord_weight > 0. or self.opt.normal_weight > 0. or \
                            self.opt.vis_coords or self.opt.vis_normals,  # whether to render depth & normal
                    )
                    images[i, j] = render_results["image"]
                    alphas[i, j] = render_results["alpha"]
                    # Depth and normal are only available when requested above;
                    # otherwise the zero-initialized buffers are kept
                    if "depth" in render_results:
                        depths[i, j] = render_results["depth"]
                    if "normal" in render_results:
                        normals[i, j] = render_results["normal"]

        # Composite coordinates and world-space normals over the background
        if not isinstance(bg_color, Tensor):
            bg_color = torch.tensor(list(bg_color), dtype=torch.float32, device=device)
        bg_color = repeat(bg_color, "c -> b v c h w", b=B, v=V, h=H, w=W)
        coords = (unproject_depth(depths.squeeze(2), C2W, fxfycxcy)
                  * 0.5 + 0.5) * alphas + (1. - alphas) * bg_color
        normals_ = (torch.einsum("bvrc,bvchw->bvrhw", C2W[:, :, :3, :3], normals)
                    * 0.5 + 0.5) * alphas + (1. - alphas) * bg_color
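        # Note: the einsum rotates camera-space normals into world space using the
        # rotation block of `C2W`; both `coords` and `normals_` are alpha-composited
        # so that empty pixels match `bg_color`, while the raw buffers are returned
        # below as well, presumably for supervision (cf. `coord_weight` and
        # `normal_weight` in Options)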

        return {
            "image": images,
            "alpha": alphas,
            "coord": coords,
            "normal": normals_,
            "raw_depth": depths,
            "raw_normal": normals,
            "pc": pcs,
        }
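

# -----------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file). Tensor shapes
# follow the "b v c h w" layout assumed by the rearranges above; the channel
# counts (4D quaternion rotations, 3D scales), the pixel-space intrinsics
# convention for `fxfycxcy`, and a default-constructible `Options()` are all
# assumptions.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    B, V, H, W = 1, 2, 64, 64
    opt = Options()  # assumed to default-construct with the fields used above
    renderer = GaussianRenderer(opt)

    model_outputs = {  # all predictions normalized to [-1, 1] (input_normalized=True)
        "rgb": torch.rand(B, V, 3, H, W) * 2. - 1.,
        "depth": torch.rand(B, V, 1, H, W) * 2. - 1.,
        "scale": torch.rand(B, V, 3, H, W) * 2. - 1.,
        "rotation": torch.nn.functional.normalize(torch.randn(B, V, 4, H, W), dim=2),
        "opacity": torch.rand(B, V, 1, H, W) * 2. - 1.,
    }
    C2W = torch.eye(4).repeat(B, V, 1, 1)  # identity camera-to-world poses
    fxfycxcy = torch.tensor([float(W), float(W), W / 2., H / 2.]).repeat(B, V, 1)

    out = renderer.render(model_outputs, C2W, fxfycxcy, C2W, fxfycxcy)
    for k, v in out.items():
        print(k, v.shape if isinstance(v, Tensor) else type(v))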