Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d | |
from sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d | |
import einops | |
from modules.util import UpBlock2d, DownBlock2d | |
def make_coordinate_grid(spatial_size, type):
    """Create a normalized 3D meshgrid.

    Args:
        spatial_size: (d, h, w) size of the grid.
        type: tensor type accepted by ``Tensor.type`` (for example the string
            returned by ``some_tensor.type()``); it selects the dtype and
            device of the returned grid.

    Returns:
        Tensor of shape (d, h, w, 3) whose last axis holds (x, y, z). Each
        coordinate is linearly spaced over [-1, 1] along its own axis; an axis
        of size 1 gets the single coordinate 0.
    """
    d, h, w = spatial_size

    def _axis(n):
        # Evenly spaced values in [-1, 1]. A singleton axis would otherwise
        # compute 0 / 0 (NaN), so it is left at its centre value 0.
        coords = torch.arange(n).type(type)
        if n <= 1:
            return coords
        return 2 * (coords / (n - 1)) - 1

    # x varies along w, y along h, z along d.
    xx = _axis(w).view(1, 1, -1).expand(d, h, w)
    yy = _axis(h).view(1, -1, 1).expand(d, h, w)
    zz = _axis(d).view(-1, 1, 1).expand(d, h, w)
    return torch.stack([xx, yy, zz], dim=3)
def kp2gaussian_3d(kp, spatial_size, kp_variance):
    """
    Transform keypoints into gaussian-like 3D heatmaps.

    Args:
        kp: tensor whose last dimension has size 3 (a point in the same
            normalized coordinates as ``make_coordinate_grid``); any leading
            dimensions (presumably batch / keypoint index -- confirm with
            callers) are preserved.
        spatial_size: (d, h, w) of the output heatmap.
        kp_variance: variance of the isotropic gaussian.

    Returns:
        Tensor of shape (*kp.shape[:-1], d, h, w) with values
        exp(-0.5 * ||grid - kp||^2 / kp_variance).
    """
    mean = kp
    coordinate_grid = make_coordinate_grid(spatial_size, mean.type())  # (d, h, w, 3)
    leading_shape = tuple(mean.shape[:-1])

    # Broadcast instead of repeating: the grid gets singleton leading dims and
    # each keypoint gets singleton spatial dims, so no per-keypoint copy of
    # the grid has to be materialized.
    grid = coordinate_grid.view((1,) * len(leading_shape) + tuple(coordinate_grid.shape))
    # reshape (not view) so non-contiguous keypoint tensors are accepted too.
    mean = mean.reshape(leading_shape + (1, 1, 1, 3))

    mean_sub = grid - mean
    return torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
class ResBlock3d(nn.Module):
    """
    Pre-activation 3D residual block: (BN -> ReLU -> Conv3d) twice, plus the input.

    Both convolutions keep ``in_features`` channels. Because the input is added
    back at the end, ``kernel_size``/``padding`` must preserve the spatial
    resolution (e.g. kernel 3 with padding 1).
    """

    def __init__(self, in_features, kernel_size, padding):
        super().__init__()
        conv_kwargs = dict(
            in_channels=in_features,
            out_channels=in_features,
            kernel_size=kernel_size,
            padding=padding,
        )
        # Registration order (conv1, conv2, norm1, norm2) is kept stable so the
        # state_dict layout and parameter order do not change.
        self.conv1 = nn.Conv3d(**conv_kwargs)
        self.conv2 = nn.Conv3d(**conv_kwargs)
        self.norm1 = BatchNorm3d(in_features, affine=True)
        self.norm2 = BatchNorm3d(in_features, affine=True)

    def forward(self, x):
        residual = self.conv1(F.relu(self.norm1(x)))
        residual = self.conv2(F.relu(self.norm2(residual)))
        residual += x
        return residual
class rgb_predictor(nn.Module):
    """
    Project a 2D feature map into per-pixel, per-floor colour encodings.

    A 3x3 convolution maps the input to ``simpled_channel`` channels, which are
    then split into ``floor_num`` groups of ``simpled_channel // floor_num``
    channels each.
    """

    def __init__(self, in_channels, simpled_channel=128, floor_num=8):
        super().__init__()
        self.floor_num = floor_num
        self.down_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=simpled_channel,
            kernel_size=3,
            padding=1,
        )

    def forward(self, feature):
        """
        Args:
            feature: warped feature map, bs * c * h * w
        Returns:
            rgb encoding: bs * h * w * floor_num * (simpled_channel // floor_num)
        """
        projected = self.down_conv(feature)
        # Split channels as (c f) with f = floor_num, then move floors/channels last.
        return einops.rearrange(projected, 'b (c f) h w -> b h w f c', f=self.floor_num)
class sigma_predictor(nn.Module):
    """
    Predict per-pixel, per-floor density encodings from a 2D feature map.

    A 3x3 convolution maps the input to ``simpled_channel`` channels. These are
    split into ``floor_num`` floors of ``simpled_channel // floor_num`` channels
    and refined by a stack of 3D residual blocks.
    """

    def __init__(self, in_channels, simpled_channel=128, floor_num=8):
        super(sigma_predictor, self).__init__()
        self.floor_num = floor_num
        self.down_conv = nn.Conv2d(in_channels=in_channels, out_channels=simpled_channel, kernel_size=3, padding=1)
        # Channels per floor after the split in forward(). The 3D stack must use
        # exactly this width; deriving it keeps non-default (simpled_channel,
        # floor_num) combinations working. The defaults give 128 // 8 == 16.
        floor_channels = simpled_channel // floor_num
        self.res_conv3d = nn.Sequential(
            ResBlock3d(floor_channels, 3, 1),
            nn.BatchNorm3d(floor_channels),
            ResBlock3d(floor_channels, 3, 1),
            nn.BatchNorm3d(floor_channels),
            ResBlock3d(floor_channels, 3, 1),
            nn.BatchNorm3d(floor_channels)
        )

    def forward(self, feature):
        """
        Args:
            feature: 2D feature map, bs * in_channels * h * w (input of ``down_conv``).
        Returns:
            sigma: bs * h * w * floor * (simpled_channel // floor_num)
            point_dict: {'sigma_map': bs * (simpled_channel // floor_num) * floor * h * w},
                the refined volume before the axes are moved; ``sigma`` is a
                permuted view of the same tensor.
        """
        heatmap = self.down_conv(feature)
        heatmap = einops.rearrange(heatmap, "b (c f) h w -> b c f h w", f=self.floor_num)
        heatmap = self.res_conv3d(heatmap)
        sigma = einops.rearrange(heatmap, "b c f h w -> b h w f c")
        point_dict = {'sigma_map': heatmap}
        return sigma, point_dict
class MultiHeadNeRFModel(torch.nn.Module):
    """
    Two-branch MLP mapping per-sample encodings to colour features and density.

    The density branch (layer1 -> layer2 -> layer3_1) consumes ``sigma_in``. Its
    hidden state is also projected by layer3_2 and concatenated with the colour
    branch (layer3_3 applied to ``rgb_in``), then passed through layer4..layer6.
    Every layer is a ``Linear`` over the last axis, so leading axes pass through.
    """

    def __init__(self, hidden_size=128, num_encoding_rgb=16, num_encoding_sigma=16):
        super(MultiHeadNeRFModel, self).__init__()
        # Input widths of the density and colour branches.
        self.xyz_encoding_dims = num_encoding_sigma
        self.viewdir_encoding_dims = num_encoding_rgb
        # Input layer (default: 16 -> 128)
        self.layer1 = torch.nn.Linear(self.xyz_encoding_dims, hidden_size)
        # Layer 2 (default: 128 -> 128)
        self.layer2 = torch.nn.Linear(hidden_size, hidden_size)
        # Layer 3_1 (default: 128 -> 1): predicts volume density ("sigma")
        self.layer3_1 = torch.nn.Linear(hidden_size, 1)
        # Layer 3_2 (default: 128 -> 32): density-branch feature fed to the colour head
        self.layer3_2 = torch.nn.Linear(hidden_size, hidden_size // 4)
        # Layer 3_3 (default: 16 -> 128): embeds the colour-branch input
        self.layer3_3 = torch.nn.Linear(self.viewdir_encoding_dims, hidden_size)
        # Layer 4 (default: 32 + 128 -> 128)
        self.layer4 = torch.nn.Linear(
            hidden_size // 4 + hidden_size, hidden_size
        )
        # Layer 5 (default: 128 -> 128)
        self.layer5 = torch.nn.Linear(hidden_size, hidden_size)
        # Layer 6 (128 -> 256): colour features (output width is fixed at 256)
        self.layer6 = torch.nn.Linear(hidden_size, 256)
        # Short hand for torch.nn.functional.relu
        self.relu = torch.nn.functional.relu

    def forward(self, rgb_in, sigma_in):
        """
        Args:
            rgb_in: bs * h * w * floor * num_encoding_rgb (e.g. the output of rgb_predictor).
            sigma_in: ... * num_encoding_sigma (e.g. the sigma output of sigma_predictor).
        Returns:
            x: ... * 256 colour features from layer6 (no activation applied).
            sigma: ... * 1 raw density from layer3_1 (no activation applied).
        Raises:
            ValueError: if ``rgb_in`` is not 5-dimensional.
        """
        if rgb_in.dim() != 5:
            raise ValueError(
                f"rgb_in must be 5-D (bs, h, w, floor, c), got shape {tuple(rgb_in.shape)}"
            )
        out = self.relu(self.layer1(sigma_in))
        out = self.relu(self.layer2(out))
        sigma = self.layer3_1(out)
        feat_sigma = self.relu(self.layer3_2(out))
        feat_rgb = self.relu(self.layer3_3(rgb_in))
        x = torch.cat((feat_sigma, feat_rgb), dim=-1)
        x = self.relu(self.layer4(x))
        x = self.relu(self.layer5(x))
        x = self.layer6(x)
        return x, sigma
def volume_render(rgb_pred, sigma_pred):
    """
    Alpha-composite per-floor colours along the floor axis.

    Discrete volume rendering with unit sample spacing:
        C = sum_i T_i * (1 - exp(-sigma_i)) * rgb_i,   T_i = exp(-sum_{j<i} sigma_j)
    Floor index 0 is treated as the front-most sample. Negative densities are
    clamped to zero with ReLU.

    Args:
        rgb_pred: result of Nerf, [bs, h, w, floor, rgb_channel]
        sigma_pred: result of Nerf, [bs, h, w, floor, sigma_channel];
            sigma_channel must broadcast against rgb_channel (e.g. 1).
    Returns:
        Composited tensor [bs, channel, h, w], where channel is the broadcast
        of rgb_channel and sigma_channel. Zeros if there are no floors.
    """
    sigma = torch.nn.functional.relu(sigma_pred)
    # Transmittance must only account for samples *in front of* sample i,
    # so use an exclusive cumulative sum (shift the running sum by one floor).
    accumulated = torch.cumsum(sigma, dim=3)
    exclusive = torch.cat(
        [torch.zeros_like(sigma[:, :, :, :1]), accumulated[:, :, :, :-1]], dim=3
    )
    transmittance = torch.exp(-exclusive)
    # -expm1(-s) == 1 - exp(-s), computed accurately for small densities.
    alpha = -torch.expm1(-sigma)
    color = (transmittance * alpha * rgb_pred).sum(dim=3)  # [bs, h, w, channel]
    return color.permute(0, 3, 1, 2)
class RenderModel(nn.Module):
    """
    Full rendering head: colour/density prediction, NeRF MLP, compositing, decoding.

    ``forward`` returns a dict with:
        'render':     sigmoid of the composited image, bs * 256 * h * w
                      (256 is the width of MultiHeadNeRFModel.layer6).
        'mini_pred':  ``render`` passed through ``mini_decoder`` (two UpBlock2d
                      stages ending in Sigmoid; the exact output resolution
                      depends on UpBlock2d -- confirm in modules.util).
        'point_pred': the auxiliary dict returned by ``sigma_predictor``.
    """

    def __init__(self, in_channels, simpled_channel_rgb, simpled_channel_sigma, floor_num, hidden_size):
        super().__init__()
        # Submodules are created in the same order as before so the parameter
        # order and state_dict layout are unchanged.
        self.rgb_predict = rgb_predictor(
            in_channels=in_channels,
            simpled_channel=simpled_channel_rgb,
            floor_num=floor_num,
        )
        self.sigma_predict = sigma_predictor(
            in_channels=in_channels,
            simpled_channel=simpled_channel_sigma,
            floor_num=floor_num,
        )
        # Per-floor encoding widths produced by the two predictors above.
        self.nerf_module = MultiHeadNeRFModel(
            hidden_size=hidden_size,
            num_encoding_rgb=simpled_channel_rgb // floor_num,
            num_encoding_sigma=simpled_channel_sigma // floor_num,
        )
        self.mini_decoder = nn.Sequential(
            UpBlock2d(256, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            UpBlock2d(64, 3, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, feature):
        rgb_in = self.rgb_predict(feature)
        # The density branch sees the same, non-detached feature as the colour
        # branch, so gradients from both heads reach the shared input.
        sigma_in, point_dict = self.sigma_predict(feature)
        rgb_out, sigma_out = self.nerf_module(rgb_in, sigma_in)
        render_result = torch.sigmoid(volume_render(rgb_out, sigma_out))
        return {
            'render': render_result,
            'mini_pred': self.mini_decoder(render_result),
            'point_pred': point_dict,
        }