jadechoghari
committed on
Create renderer.py
Browse files- renderer.py +314 -0
renderer.py
ADDED
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
7 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
8 |
+
#
|
9 |
+
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
10 |
+
# property and proprietary rights in and to this material, related
|
11 |
+
# documentation and any modifications thereto. Any use, reproduction,
|
12 |
+
# disclosure or distribution of this material and related documentation
|
13 |
+
# without an express license agreement from NVIDIA CORPORATION or
|
14 |
+
# its affiliates is strictly prohibited.
|
15 |
+
#
|
16 |
+
# Modified by Zexin He
|
17 |
+
# The modifications are subject to the same license as the original.
|
18 |
+
|
19 |
+
|
20 |
+
"""
|
21 |
+
The renderer is a module that takes in rays, decides where to sample along each
|
22 |
+
ray, and computes pixel colors using the volume rendering equation.
|
23 |
+
"""
|
24 |
+
|
25 |
+
import torch
|
26 |
+
import torch.nn as nn
|
27 |
+
import torch.nn.functional as F
|
28 |
+
|
29 |
+
from .ray_marcher import MipRayMarcher2
|
30 |
+
from . import math_utils
|
31 |
+
|
32 |
+
def generate_planes():
    """
    Build the axis triplets that define each feature plane.

    Every plane is described by three 3-vectors (its "axes"). The layout is
    meant to work with any number of planes in any orientation.

    Bugfix reference: https://github.com/NVlabs/eg3d/issues/67

    Returns:
        float32 tensor of shape (3, 3, 3): one 3x3 axis matrix per plane.
    """
    x_axis, y_axis, z_axis = [1, 0, 0], [0, 1, 0], [0, 0, 1]
    plane_axes = [
        [x_axis, y_axis, z_axis],
        [x_axis, z_axis, y_axis],
        [z_axis, y_axis, x_axis],
    ]
    return torch.tensor(plane_axes, dtype=torch.float32)
|
49 |
+
|
50 |
+
def project_onto_planes(planes, coordinates):
    """
    Project batches of 3D points onto each plane and return 2D plane coordinates.

    Args:
        planes: plane axes of shape (n_planes, 3, 3).
        coordinates: points of shape (N, M, 3).

    Returns:
        Projections of shape (N * n_planes, M, 2).
    """
    batch, n_points, _ = coordinates.shape
    n_planes, _, _ = planes.shape

    # Repeat every batch of points once per plane, then fold planes into the batch axis.
    tiled_points = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1)
    tiled_points = tiled_points.reshape(batch * n_planes, n_points, 3)

    # Invert the plane axes once and repeat them for every batch entry.
    inverse_axes = torch.linalg.inv(planes)
    inverse_axes = inverse_axes.unsqueeze(0).expand(batch, -1, -1, -1)
    inverse_axes = inverse_axes.reshape(batch * n_planes, 3, 3)

    tiled_points = tiled_points.to(inverse_axes.device)
    in_plane_frame = torch.bmm(tiled_points, inverse_axes)
    # Only the first two components are the in-plane coordinates.
    return in_plane_frame[..., :2]
|
66 |
+
|
67 |
+
def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
    """
    Sample per-plane features at the 2D projections of 3D points.

    Args:
        plane_axes: plane axes of shape (n_planes, 3, 3).
        plane_features: feature maps of shape (N, n_planes, C, H, W).
        coordinates: 3D points of shape (N, M, 3).
        mode: interpolation mode forwarded to ``grid_sample``.
        padding_mode: must be ``'zeros'``; no other mode is supported here.
        box_warp: box side length; coordinates are scaled by ``2 / box_warp``
            before projection. Required in practice despite the ``None`` default.

    Returns:
        Sampled features of shape (N, n_planes, M, C), computed in float32.

    Raises:
        AssertionError: if ``padding_mode`` is not ``'zeros'``.
        TypeError: if ``box_warp`` is ``None``.
    """
    # Explicit raises (same exception types as before) so the checks survive `python -O`
    # and the failure message names the offending argument.
    if padding_mode != 'zeros':
        raise AssertionError(f"sample_from_planes only supports padding_mode='zeros', got {padding_mode!r}")
    if box_warp is None:
        raise TypeError("sample_from_planes requires a numeric `box_warp`, got None")

    N, n_planes, C, H, W = plane_features.shape
    _, M, _ = coordinates.shape
    # reshape (not view) so non-contiguous feature tensors are accepted as well
    plane_features = plane_features.reshape(N*n_planes, C, H, W)

    coordinates = (2/box_warp) * coordinates  # add specific box bounds
    projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)
    # grid_sample is run in float32 regardless of the input dtypes
    sampled = torch.nn.functional.grid_sample(
        plane_features.float(), projected_coordinates.float(),
        mode=mode, padding_mode=padding_mode, align_corners=False)
    # (N*n_planes, C, 1, M) -> (N, n_planes, M, C)
    output_features = sampled.permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
    return output_features
|
79 |
+
|
80 |
+
def sample_from_3dgrid(grid, coordinates):
    """
    Look up features from a dense 3D grid at the given points.

    Args:
        grid: feature volume of shape (1, channels, H, W, D); a grid that
            already carries the batch size also works.
        coordinates: points of shape (batch_size, num_points_per_batch, 3).

    Returns:
        Sampled features of shape (batch_size, num_points_per_batch, feature_channels).
    """
    batch_size, n_coords, n_dims = coordinates.shape

    batched_grid = grid.expand(batch_size, -1, -1, -1, -1)
    # grid_sample wants a 5D lookup tensor: put all points on the last spatial axis.
    lookup = coordinates.reshape(batch_size, 1, 1, -1, n_dims)
    raw = torch.nn.functional.grid_sample(
        batched_grid, lookup,
        mode='bilinear', padding_mode='zeros', align_corners=False)

    out_batch, channels, dim_a, dim_b, dim_c = raw.shape
    # Move channels last and flatten the three spatial axes into one point axis.
    channels_last = raw.permute(0, 4, 3, 2, 1)
    return channels_last.reshape(out_batch, dim_a * dim_b * dim_c, channels)
|
94 |
+
|
95 |
+
class ImportanceRenderer(torch.nn.Module):
    """
    Renderer that samples plane features along rays in a coarse pass followed
    by an optional importance-sampled fine pass, then aggregates the samples
    with the ray marcher.

    Modified original version to filter out-of-box samples as TensoRF does.

    Reference:
    TensoRF: https://github.com/apchenstu/TensoRF/blob/main/models/tensorBase.py#L277
    """
    def __init__(self):
        super().__init__()
        self.activation_factory = self._build_activation_factory()
        self.ray_marcher = MipRayMarcher2(self.activation_factory)
        # Plain tensor attribute (not a registered buffer); run_model moves it
        # to the planes' device on every call.
        self.plane_axes = generate_planes()

    def _build_activation_factory(self):
        """Return a function that maps rendering options to the density activation."""
        def activation_factory(options: dict):
            if options['clamp_mode'] == 'softplus':
                return lambda x: F.softplus(x - 1)  # activation bias of -1 makes things initialize better
            # Raised explicitly instead of `assert False` so the check is not
            # stripped under `python -O` (which would silently return None).
            raise AssertionError("Renderer only supports `clamp_mode`=`softplus`!")
        return activation_factory

    def _forward_pass(self, depths: torch.Tensor, ray_directions: torch.Tensor, ray_origins: torch.Tensor,
                      planes: torch.Tensor, decoder: nn.Module, rendering_options: dict):
        """
        Run the decoder on every sample of every ray.

        Additional filtering is applied to filter out-of-box samples.
        Modifications made by Zexin He.

        Args:
            depths: sample depths, unpacked as (batch, rays, samples_per_ray, 1).
            ray_directions: per-ray directions; expanded per sample to 3 components.
            ray_origins: per-ray origins.
            planes: plane features; also defines the device everything is moved to.
            decoder: module called by `run_model`; must return a dict with 'rgb' and 'sigma'.
            rendering_options: reads 'sampler_bbox_min' / 'sampler_bbox_max' here,
                plus whatever `run_model` reads.

        Returns:
            (colors, densities), each reshaped to (batch, rays, samples_per_ray, channels).
        """

        # context related variables
        batch_size, num_rays, samples_per_ray, _ = depths.shape
        device = planes.device
        depths = depths.to(device)
        ray_directions = ray_directions.to(device)
        ray_origins = ray_origins.to(device)

        # define sample points with depths
        sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3)
        sample_coordinates = (ray_origins.unsqueeze(-2) + depths * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)

        # filter out-of-box samples: a sample is kept only if all 3 components are inside the box
        mask_inbox = \
            (rendering_options['sampler_bbox_min'] <= sample_coordinates) & \
            (sample_coordinates <= rendering_options['sampler_bbox_max'])
        mask_inbox = mask_inbox.all(-1)

        # forward model according to all samples
        _out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options)

        # set out-of-box samples to zeros(rgb) & -inf(sigma)
        # nan_to_num turns -inf into the most negative finite value of the dtype;
        # dividing by SAFE_GUARD keeps some headroom below that limit.
        SAFE_GUARD = 3
        DATA_TYPE = _out['sigma'].dtype
        colors_pass = torch.zeros(batch_size, num_rays * samples_per_ray, 3, device=device, dtype=DATA_TYPE)
        densities_pass = torch.nan_to_num(torch.full((batch_size, num_rays * samples_per_ray, 1), -float('inf'), device=device, dtype=DATA_TYPE)) / SAFE_GUARD
        colors_pass[mask_inbox], densities_pass[mask_inbox] = _out['rgb'][mask_inbox], _out['sigma'][mask_inbox]

        # reshape back
        colors_pass = colors_pass.reshape(batch_size, num_rays, samples_per_ray, colors_pass.shape[-1])
        densities_pass = densities_pass.reshape(batch_size, num_rays, samples_per_ray, densities_pass.shape[-1])

        return colors_pass, densities_pass

    def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options):
        """
        Render rays: coarse stratified pass, optional importance pass, then aggregation.

        Args:
            planes: plane features passed through to the decoder sampling.
            decoder: module producing 'rgb' and 'sigma' for sampled features.
            ray_origins: ray origins, unpacked as (N, M, 3) by `sample_stratified`.
            ray_directions: ray directions matching `ray_origins`.
            rendering_options: dict; reads 'ray_start', 'ray_end', 'box_warp',
                'depth_resolution', 'disparity_space_sampling' and
                'depth_resolution_importance' here.

        Returns:
            (rgb_final, depth_final, weights summed over dim 2) as produced by the ray marcher.
        """
        if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto':
            ray_start, ray_end = math_utils.get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp'])
            is_ray_valid = ray_end > ray_start
            if torch.any(is_ray_valid).item():
                # Give invalid rays a usable [start, end] range taken from the valid ones.
                # NOTE(review): the far bound uses ray_start (not ray_end) of the valid rays - confirm intended.
                ray_start[~is_ray_valid] = ray_start[is_ray_valid].min()
                ray_end[~is_ray_valid] = ray_start[is_ray_valid].max()
            depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])
        else:
            # Create stratified depth samples
            depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])

        depths_coarse = depths_coarse.to(planes.device)

        # Coarse Pass
        colors_coarse, densities_coarse = self._forward_pass(
            depths=depths_coarse, ray_directions=ray_directions, ray_origins=ray_origins,
            planes=planes, decoder=decoder, rendering_options=rendering_options)

        # Fine Pass
        N_importance = rendering_options['depth_resolution_importance']
        if N_importance > 0:
            _, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)

            depths_fine = self.sample_importance(depths_coarse, weights, N_importance)

            colors_fine, densities_fine = self._forward_pass(
                depths=depths_fine, ray_directions=ray_directions, ray_origins=ray_origins,
                planes=planes, decoder=decoder, rendering_options=rendering_options)

            all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse,
                                                                       depths_fine, colors_fine, densities_fine)

            # Aggregate
            rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options)
        else:
            rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)

        return rgb_final, depth_final, weights.sum(2)

    def run_model(self, planes, decoder, sample_coordinates, sample_directions, options):
        """
        Sample plane features at the given points and decode them.

        Adds Gaussian noise scaled by options['density_noise'] to 'sigma' when
        that option is present and positive. Returns the decoder's output dict.
        """
        plane_axes = self.plane_axes.to(planes.device)
        sampled_features = sample_from_planes(plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp'])

        out = decoder(sampled_features, sample_directions)
        if options.get('density_noise', 0) > 0:
            out['sigma'] += torch.randn_like(out['sigma']) * options['density_noise']
        return out

    def run_model_activated(self, planes, decoder, sample_coordinates, sample_directions, options):
        """Same as `run_model`, with the density activation applied to 'sigma'."""
        out = self.run_model(planes, decoder, sample_coordinates, sample_directions, options)
        out['sigma'] = self.activation_factory(options)(out['sigma'])
        return out

    def sort_samples(self, all_depths, all_colors, all_densities):
        """
        Sort samples along each ray (dim -2) by depth.

        Colors and densities are reordered with the same permutation as the depths.
        """
        _, indices = torch.sort(all_depths, dim=-2)
        all_depths = torch.gather(all_depths, -2, indices)
        all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
        all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
        return all_depths, all_colors, all_densities

    def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2):
        """Concatenate two sample sets along the sample axis (dim -2) and sort them by depth."""
        all_depths = torch.cat([depths1, depths2], dim=-2)
        all_colors = torch.cat([colors1, colors2], dim=-2)
        all_densities = torch.cat([densities1, densities2], dim=-2)

        # Shared with sort_samples instead of duplicating the sort/gather logic.
        return self.sort_samples(all_depths, all_colors, all_densities)

    def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False):
        """
        Return depths of approximately uniformly spaced samples along rays.

        Args:
            ray_origins: tensor unpacked as (N, M, 3); supplies batch/ray counts and device.
            ray_start: near bound, either a number or a tensor of per-ray values.
            ray_end: far bound, same kind as `ray_start`.
            depth_resolution: number of samples per ray.
            disparity_space_sampling: if True, space samples linearly in inverse depth.

        Returns:
            Jittered depths; shape (N, M, depth_resolution, 1) in the scalar-bound branches.
        """
        N, M, _ = ray_origins.shape
        if disparity_space_sampling:
            depths_coarse = torch.linspace(0,
                                           1,
                                           depth_resolution,
                                           device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
            depth_delta = 1/(depth_resolution - 1)
            depths_coarse += torch.rand_like(depths_coarse) * depth_delta
            depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse)
        elif isinstance(ray_start, torch.Tensor):
            # isinstance (not `type(...) ==`) so Tensor subclasses take the per-ray path too
            depths_coarse = math_utils.linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3)
            depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
            depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None]
        else:
            depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
            depth_delta = (ray_end - ray_start)/(depth_resolution - 1)
            depths_coarse += torch.rand_like(depths_coarse) * depth_delta

        return depths_coarse

    def sample_importance(self, z_vals, weights, N_importance):
        """
        Return depths of importance sampled points along rays. See NeRF importance sampling for more.

        Args:
            z_vals: coarse depths, unpacked as (batch, rays, samples_per_ray, 1).
            weights: coarse-pass weights; flattened to one row per ray.
            N_importance: number of fine samples to draw per ray.

        Returns:
            Detached tensor of shape (batch, rays, N_importance, 1).
        """
        with torch.no_grad():
            batch_size, num_rays, samples_per_ray, _ = z_vals.shape

            z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray)
            weights = weights.reshape(batch_size * num_rays, -1) # -1 to account for loss of 1 sample in MipRayMarcher

            # smooth weights
            weights = torch.nn.functional.max_pool1d(weights.unsqueeze(1).float(), 2, 1, padding=1)
            # squeeze only the channel axis: a bare squeeze() would also drop the
            # ray axis when batch_size * num_rays == 1 and break the 2D indexing below
            weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze(1)
            weights = weights + 0.01

            z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:])
            importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1],
                                                N_importance).detach().reshape(batch_size, num_rays, N_importance, 1)
        return importance_z_vals

    def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5):
        """
        Sample @N_importance samples from @bins with distribution defined by @weights.
        Inputs:
            bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2"
            weights: (N_rays, N_samples_)
            N_importance: the number of samples to draw from the distribution
            det: deterministic or not
            eps: a small number to prevent division by zero
        Outputs:
            samples: the sampled samples
        """
        N_rays, N_samples_ = weights.shape
        weights = weights + eps # prevent division by zero (don't do inplace op!)
        pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_)
        cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function
        cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1)  # (N_rays, N_samples_+1)
                                                                   # padded to 0~1 inclusive

        if det:
            u = torch.linspace(0, 1, N_importance, device=bins.device)
            u = u.expand(N_rays, N_importance)
        else:
            u = torch.rand(N_rays, N_importance, device=bins.device)
        u = u.contiguous()

        # locate the cdf interval containing each u, clamped to valid indices
        inds = torch.searchsorted(cdf, u, right=True)
        below = torch.clamp_min(inds-1, 0)
        above = torch.clamp_max(inds, N_samples_)

        inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance)
        cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2)
        bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2)

        denom = cdf_g[...,1]-cdf_g[...,0]
        denom[denom<eps] = 1 # denom equals 0 means a bin has weight 0, in which case it will not be sampled
                             # anyway, therefore any value for it is fine (set to 1 here)

        # linear interpolation inside the selected bin
        samples = bins_g[...,0] + (u-cdf_g[...,0])/denom * (bins_g[...,1]-bins_g[...,0])
        return samples
|