# -*- coding: utf-8 -*-
#
# @File: __init__.py
# @Author: Jiaxiang Tang (@ashawkey)
# @Date: 2023-04-15 10:39:28
# @Last Modified by: Haozhe Xie
# @Last Modified at: 2023-04-15 13:08:46
# @Email: [email protected]
# @Ref: https://github.com/ashawkey/torch-ngp

import math
import numpy as np
import torch

import grid_encoder_ext
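
# NOTE: grid_encoder_ext is the compiled CUDA extension implementing the
# multi-resolution (hash) grid encoding kernels from torch-ngp. It is assumed
# to be built from the accompanying CUDA sources and importable; this module
# only provides the Python-side autograd wrapper around it.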


class GridEncoderFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        inputs,
        embeddings,
        offsets,
        per_level_scale,
        base_resolution,
        calc_grad_inputs=False,
        gridtype=0,
        align_corners=False,
    ):
        # inputs: [B, D], float, in [0, 1]
        # embeddings: [sO, C], float
        # offsets: [L + 1], int
        # RETURN: [B, L * C], float
        inputs = inputs.contiguous()
        # batch size, coord dim
        B, D = inputs.shape
        # number of levels
        L = offsets.shape[0] - 1
        # embedding dim for each level
        C = embeddings.shape[1]
        # resolution multiplier at each level; take log2 so the CUDA kernel
        # can recover it with exp2f
        S = math.log2(per_level_scale)
        # base resolution
        H = base_resolution
        # level-first layout optimizes cache locality in the CUDA kernel,
        # but requires an extra permute back to batch-first afterwards
        outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)
        if calc_grad_inputs:
            dy_dx = torch.empty(
                B, L * D * C, device=inputs.device, dtype=embeddings.dtype
            )
        else:
            # 1-element placeholder: the extension expects a tensor argument
            # even when input gradients are not needed
            dy_dx = torch.empty(1, device=inputs.device, dtype=embeddings.dtype)
        grid_encoder_ext.forward(
            inputs,
            embeddings,
            offsets,
            outputs,
            B,
            D,
            C,
            L,
            S,
            H,
            calc_grad_inputs,
            dy_dx,
            gridtype,
            align_corners,
        )
        # permute back to [B, L * C]
        outputs = outputs.permute(1, 0, 2).reshape(B, L * C)
        ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
        ctx.dims = [B, D, C, L, S, H, gridtype]
        ctx.calc_grad_inputs = calc_grad_inputs
        ctx.align_corners = align_corners
        return outputs

    @staticmethod
    def backward(ctx, grad):
        inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
        B, D, C, L, S, H, gridtype = ctx.dims
        calc_grad_inputs = ctx.calc_grad_inputs
        align_corners = ctx.align_corners
        # grad: [B, L * C] --> [L, B, C]
        grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()
        grad_embeddings = torch.zeros_like(embeddings)
        if calc_grad_inputs:
            grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
        else:
            # 1-element placeholder, mirroring dy_dx in forward
            grad_inputs = torch.zeros(1, device=inputs.device, dtype=embeddings.dtype)
        grid_encoder_ext.backward(
            grad,
            inputs,
            embeddings,
            offsets,
            grad_embeddings,
            B,
            D,
            C,
            L,
            S,
            H,
            calc_grad_inputs,
            dy_dx,
            grad_inputs,
            gridtype,
            align_corners,
        )
        if calc_grad_inputs:
            grad_inputs = grad_inputs.to(inputs.dtype)
            return grad_inputs, grad_embeddings, None, None, None, None, None, None
        else:
            return None, grad_embeddings, None, None, None, None, None, None
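

# torch-ngp exposes its Function through a small functional alias; mirroring
# that convention here (the name is an assumption, not part of the original
# module's public API):
grid_encode = GridEncoderFunction.apply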


class GridEncoder(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        n_levels,
        lvl_channels,
        desired_resolution,
        per_level_scale=2,
        base_resolution=16,
        log2_hashmap_size=19,
        gridtype="hash",
        align_corners=False,
    ):
        super(GridEncoder, self).__init__()
        self.in_channels = in_channels
        # number of levels; each level scales the resolution by per_level_scale
        self.n_levels = n_levels
        # encoding channels per level
        self.lvl_channels = lvl_channels
        # derive the scale from the desired finest resolution, overriding the
        # per_level_scale argument (cf. torch-ngp)
        self.per_level_scale = 2 ** (
            math.log2(desired_resolution / base_resolution) / (n_levels - 1)
        )
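        # Example (values assumed for illustration): base_resolution=16,
        # desired_resolution=2048, n_levels=16 gives
        # per_level_scale = 2 ** (log2(2048 / 16) / 15) = 2 ** (7 / 15) ≈ 1.382,
        # so level i has resolution ceil(16 * 1.382 ** i).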
        self.log2_hashmap_size = log2_hashmap_size
        self.base_resolution = base_resolution
        self.output_dim = n_levels * lvl_channels
        self.gridtype = gridtype
        self.gridtype_id = 0 if gridtype == "hash" else 1
        self.align_corners = align_corners
        # allocate parameters
        offsets = []
        offset = 0
        self.max_params = 2**log2_hashmap_size
        for i in range(n_levels):
            resolution = int(math.ceil(base_resolution * self.per_level_scale**i))
            # a dense grid at this resolution needs resolution**in_channels
            # entries (one more per axis if not align_corners); cap it at the
            # hash table size
            params_in_level = min(
                self.max_params,
                (resolution if align_corners else resolution + 1) ** in_channels,
            )
            # round up to a multiple of 8 for alignment
            params_in_level = int(math.ceil(params_in_level / 8) * 8)
            offsets.append(offset)
            offset += params_in_level
        offsets.append(offset)
        offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
        self.register_buffer("offsets", offsets)
        self.n_params = offsets[-1] * lvl_channels
        self.embeddings = torch.nn.Parameter(torch.empty(offset, lvl_channels))
        self._init_weights()

    def _init_weights(self):
        self.embeddings.data.uniform_(-1e-4, 1e-4)

    def forward(self, inputs, bound=1):
        # inputs: [..., in_channels], normalized real-world positions in [-bound, bound]
        # return: [..., n_levels * lvl_channels]
        inputs = (inputs + bound) / (2 * bound)  # map to [0, 1]
        prefix_shape = list(inputs.shape[:-1])
        inputs = inputs.view(-1, self.in_channels)
        outputs = GridEncoderFunction.apply(
            inputs,
            self.embeddings,
            self.offsets,
            self.per_level_scale,
            self.base_resolution,
            inputs.requires_grad,
            self.gridtype_id,
            self.align_corners,
        )
        return outputs.view(prefix_shape + [self.output_dim])
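

# Minimal usage sketch (assumes the grid_encoder_ext CUDA extension is built
# and a CUDA device is available; shapes and hyper-parameters are illustrative):
#
#     encoder = GridEncoder(
#         in_channels=3, n_levels=16, lvl_channels=2, desired_resolution=2048
#     ).cuda()
#     xyz = torch.rand(4096, 3, device="cuda") * 2 - 1  # positions in [-1, 1]
#     features = encoder(xyz, bound=1)  # -> [4096, 16 * 2] = [4096, 32]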