import torch
import torch.nn as nn

from typing import List, Dict

from additional_modules.deeplabv3.deeplabv3 import DeepLabV3
from additional_modules.segformer.backbone import Block, OverlapPatchEmbed
from additional_modules.eg3d.networks import OSGDecoder
from models import get_model
from rendering.ray_sampler import RaySampler
from rendering.triplane_rendering.renderer import ImportanceRenderer
from utils.registry import MODEL_REGISTRY

class PositionalEncoder(nn.Module):
    """Appends a fixed 2D coordinate grid in [-1, 1] to the input as two extra channels."""

    def __init__(self, img_size: int):
        super().__init__()

        h_linspace = torch.linspace(-1, 1, img_size)
        w_linspace = torch.linspace(-1, 1, img_size)
        gh, gw = torch.meshgrid(h_linspace, w_linspace, indexing='xy')
        gh, gw = gh.unsqueeze(0), gw.unsqueeze(0)
        id_grid = torch.cat((gh, gw), dim=0).unsqueeze(0)  # (1, 2, img_size, img_size)
        self.register_buffer('id_grid', id_grid)

    def _add_positional_encoding(self, img):
        # Tile the grid across the batch and append it along the channel dimension.
        id_grid = self.id_grid.repeat(img.shape[0], 1, 1, 1)
        x = torch.cat((img, id_grid), dim=1)

        return x

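# Minimal shape sanity check (a sketch, not part of the model): a (B, C, S, S)
# image gains the two grid channels and becomes (B, C + 2, S, S).
#
#     enc = PositionalEncoder(img_size=512)
#     out = enc._add_positional_encoding(torch.randn(2, 3, 512, 512))
#     assert out.shape == (2, 5, 512, 512)
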
class ELow(PositionalEncoder):
    """Low-frequency branch: a DeepLabV3 backbone followed by SegFormer-style
    transformer blocks and three 2x upsampling stages. For a 512x512 input it
    yields a (B, 96, 256, 256) feature map (half the input resolution)."""

    def __init__(self, img_size: int = 512, img_channels: int = 3):
        super().__init__(img_size)

        self.deeplabv3_backbone = DeepLabV3(input_channels=img_channels + 2)
        self.patch_embed = OverlapPatchEmbed(
            img_size=img_size // 8, patch_size=3, stride=2, in_chans=256, embed_dim=1024
        )

        self.block1 = Block(dim=1024, num_heads=4, mlp_ratio=2, sr_ratio=1)
        self.block2 = Block(dim=1024, num_heads=4, mlp_ratio=2, sr_ratio=1)
        self.block3 = Block(dim=1024, num_heads=4, mlp_ratio=2, sr_ratio=1)
        self.block4 = Block(dim=1024, num_heads=4, mlp_ratio=2, sr_ratio=1)
        self.block5 = Block(dim=1024, num_heads=4, mlp_ratio=2, sr_ratio=1)

        self.up1 = nn.PixelShuffle(upscale_factor=2)
        self.up2 = nn.Upsample(scale_factor=2, mode='bilinear')

        self.conv1 = nn.Conv2d(256, 128, 3, 1, 1, bias=True)
        self.act1 = nn.ReLU()
        self.up3 = nn.Upsample(scale_factor=2, mode='bilinear')

        self.conv2 = nn.Conv2d(128, 128, 3, 1, 1, bias=True)
        self.act2 = nn.ReLU()

        self.conv3 = nn.Conv2d(128, 96, 3, 1, 1, bias=True)

    def forward(self, img: torch.Tensor):
        x = self._add_positional_encoding(img)

        x = self.deeplabv3_backbone(x)  # (B, 256, S/8, S/8)
        x, H, W = self.patch_embed(x)   # token sequence at S/16

        x = self.block1(x, H, W)
        x = self.block2(x, H, W)
        x = self.block3(x, H, W)
        x = self.block4(x, H, W)
        x = self.block5(x, H, W)

        # Back to a spatial map: (B, N, C) -> (B, C, H, W)
        x = x.reshape(img.shape[0], H, W, -1).permute(0, 3, 1, 2).contiguous()

        x = self.up1(x)  # PixelShuffle: 1024 ch -> 256 ch, 2x spatial
        x = self.up2(x)

        x = self.conv1(x)
        x = self.act1(x)
        x = self.up3(x)

        x = self.conv2(x)
        x = self.act2(x)
        x = self.conv3(x)

        return x

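# Resolution trace for the default 512x512 input, derived from the layer
# hyper-parameters above (a sketch, not executed code):
#
#     512 --DeepLabV3 (/8)-->         64x64,   256 ch
#     64  --OverlapPatchEmbed (/2)->  32x32,  1024 ch
#     32  --PixelShuffle x2-->        64x64,   256 ch
#     64  --bilinear x2-->           128x128,  256 ch
#     128 --conv1 + bilinear x2-->   256x256,  128 ch
#     256 --conv2, conv3-->          256x256,   96 ch
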
class EHigh(PositionalEncoder):
    """High-frequency branch: a shallow stack of convolutions that preserves
    fine detail. The stride-2 first layer brings a 512x512 input down to a
    (B, 96, 256, 256) map, matching the output shape of ELow."""

    def __init__(self, img_size: int = 512, img_channels: int = 3):
        super().__init__(img_size)

        self.conv1 = nn.Conv2d(img_channels + 2, 64, 7, 2, 3, bias=True)
        self.act1 = nn.LeakyReLU(0.01)

        self.conv2 = nn.Conv2d(64, 96, 3, 1, 1, bias=True)
        self.act2 = nn.LeakyReLU(0.01)

        self.conv3 = nn.Conv2d(96, 96, 3, 1, 1, bias=True)
        self.act3 = nn.LeakyReLU(0.01)

        self.conv4 = nn.Conv2d(96, 96, 3, 1, 1, bias=True)
        self.act4 = nn.LeakyReLU(0.01)

        self.conv5 = nn.Conv2d(96, 96, 3, 1, 1, bias=True)
        self.act5 = nn.LeakyReLU(0.01)

    def forward(self, img: torch.Tensor):
        x = self._add_positional_encoding(img)
        x = self.conv1(x)
        x = self.act1(x)

        x = self.conv2(x)
        x = self.act2(x)

        x = self.conv3(x)
        x = self.act3(x)

        x = self.conv4(x)
        x = self.act4(x)

        x = self.conv5(x)
        x = self.act5(x)

        return x

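# Quick usage sketch: the two branches are interchangeable shape-wise, which is
# what lets Lp3DEncoder concatenate them channel-wise (96 + 96 = 192):
#
#     f_lo = ELow()(torch.randn(1, 3, 512, 512))    # (1, 96, 256, 256)
#     f_hi = EHigh()(torch.randn(1, 3, 512, 512))   # (1, 96, 256, 256)
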
class Lp3DEncoder(nn.Module):
    """Fuses the low- and high-frequency branches and predicts a triplane with
    triplane_nd channels per plane (conv6 emits triplane_nd * 3 channels)."""

    def __init__(self, img_size: int = 512, img_channels: int = 3, triplane_nd: int = 32):
        super().__init__()
        self.img_size = img_size

        self.elo = ELow(img_size, img_channels)
        self.ehi = EHigh(img_size, img_channels)

        self.conv1 = nn.Conv2d(192, 256, 3, 1, 1, bias=True)
        self.act1 = nn.LeakyReLU(0.01)

        self.conv2 = nn.Conv2d(256, 128, 3, 1, 1, bias=True)
        self.act2 = nn.LeakyReLU(0.01)

        self.patch_embed = OverlapPatchEmbed(
            img_size=img_size // 2, patch_size=3, stride=2, in_chans=128, embed_dim=1024
        )
        self.transformer_block = Block(dim=1024, num_heads=2, mlp_ratio=2, sr_ratio=2)

        self.up = nn.PixelShuffle(upscale_factor=2)

        self.conv3 = nn.Conv2d(352, 256, 3, 1, 1, bias=True)
        self.act3 = nn.LeakyReLU(0.01)

        self.conv4 = nn.Conv2d(256, 128, 3, 1, 1, bias=True)
        self.act4 = nn.LeakyReLU(0.01)

        self.conv5 = nn.Conv2d(128, 128, 3, 1, 1, bias=True)
        self.act5 = nn.LeakyReLU(0.01)

        self.conv6 = nn.Conv2d(128, triplane_nd * 3, 3, 1, 1, bias=True)

    def forward(self, img: torch.Tensor):
        assert img.shape[-1] == self.img_size and img.shape[-2] == self.img_size

        f_lo = self.elo(img)
        f_hi = self.ehi(img)

        f = torch.cat((f_lo, f_hi), dim=1)  # (B, 192, S/2, S/2)
        f = self.conv1(f)
        f = self.act1(f)

        f = self.conv2(f)
        f = self.act2(f)

        f, H, W = self.patch_embed(f)
        f = self.transformer_block(f, H, W)
        f = f.reshape(img.shape[0], H, W, -1).permute(0, 3, 1, 2).contiguous()

        f = self.up(f)  # PixelShuffle: 1024 ch -> 256 ch, back to S/2 resolution
        f = torch.cat((f, f_lo), dim=1)  # skip connection: 256 + 96 = 352 ch

        f = self.conv3(f)
        f = self.act3(f)

        f = self.conv4(f)
        f = self.act4(f)

        f = self.conv5(f)
        f = self.act5(f)

        f = self.conv6(f)

        return f

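# Encoder output sketch (default settings): 3 planes x 32 channels, flattened
# along the channel axis at half the input resolution.
#
#     encoder = Lp3DEncoder(triplane_nd=32)
#     planes = encoder(torch.randn(1, 3, 512, 512))   # (1, 96, 256, 256)
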
@MODEL_REGISTRY.register()
class Lp3D(nn.Module):
    def __init__(
        self,
        neural_rendering_resolution: int,
        triplane_nd: int,
        triplane_h: int,
        triplane_w: int,
        rendering_kwargs,
        superresolution_kwargs,
    ):
        super().__init__()

        self.triplane_nd = triplane_nd
        self.triplane_h = triplane_h
        self.triplane_w = triplane_w
        self.neural_rendering_resolution = neural_rendering_resolution
        self.superresolution_opt = superresolution_kwargs
        self.rendering_kwargs = rendering_kwargs

        self._setup_modules()

    def _setup_modules(self):
        self.triplane_encoder = Lp3DEncoder(triplane_nd=self.triplane_nd)

        self.renderer = ImportanceRenderer()
        self.decoder = OSGDecoder(
            self.triplane_nd,
            {
                'decoder_lr_mul': self.rendering_kwargs.get('decoder_lr_mul', 1),
                'decoder_output_dim': self.triplane_nd
            }
        )
        self.ray_sampler = RaySampler()
        self.superresolution = get_model(self.superresolution_opt)

    def render(self, planes, cam2world, intrinsics, upsample=True):
        """
        Render the triplanes using the cam2world and intrinsics matrices.

        Parameters:
        - planes (Tensor): The triplane features
        - cam2world (Tensor): Camera-to-world extrinsics
        - intrinsics (Tensor): Camera intrinsics
        - upsample (bool): Whether to run the super-resolution module
        Returns:
        - image (Tensor) [Range -1..1]: The rendered images
        - cam2world (Tensor): The input cam2world. Can be useful to stack multiple renderings.
        - intrinsics (Tensor): The input intrinsics. Can be useful to stack multiple renderings.
        """
        ray_origins, ray_directions = self.ray_sampler(
            cam2world, intrinsics, self.neural_rendering_resolution
        )

        batch_size = cam2world.shape[0]

        feature_samples, depth_samples, _, _ = self.renderer(
            planes,
            self.decoder,
            ray_origins,
            ray_directions,
            self.rendering_kwargs
        )

        # Fold the (B, H*W, C) ray samples back into a (B, C, H, W) feature
        # image; the first three channels are the raw low-resolution RGB.
        H = W = self.neural_rendering_resolution
        feature_image = feature_samples.permute(0, 2, 1).reshape(
            batch_size, feature_samples.shape[-1], H, W
        ).contiguous()
        rgb_image_raw = feature_image[:, :3]

        if upsample:
            rgb_image = self.superresolution(rgb_image_raw, feature_image)
        else:
            rgb_image = None

        depth_image = depth_samples.permute(0, 2, 1).reshape(batch_size, 1, H, W).contiguous()

        return {
            'image_raw': rgb_image_raw,
            'image': rgb_image,
            'planes': planes,
            'depth': depth_image,
            'feature_image': feature_image,
            'cam2world': cam2world,
            'intrinsics': intrinsics
        }

    def canonicalize(self, image: torch.Tensor):
        """
        Transform the input image into the canonical 3D space, represented as a triplane.

        Parameters:
        - image (Tensor): Input image in [-1, 1]
        Returns:
        - triplanes (Tensor): The canonical triplane representation of the input
        """
        # Map from [-1, 1] to [0, 1] for the encoder.
        image = (image + 1) / 2.
        triplanes = self.triplane_encoder(image)
        B = triplanes.shape[0]
        # Split the flat (B, 3 * triplane_nd, H, W) output into three planes.
        triplanes = triplanes.view(
            B, 3, self.triplane_nd, triplanes.shape[-2], triplanes.shape[-1]
        ).contiguous()

        return triplanes

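    # Shape sketch for the canonical representation (default triplane_nd=32),
    # where c2w and K stand in for hypothetical cam2world / intrinsics tensors
    # and R is neural_rendering_resolution:
    #
    #     planes = model.canonicalize(image)   # (B, 3, 32, 256, 256)
    #     out = model.render(planes, c2w, K)   # out['image_raw']: (B, 3, R, R)
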
    def forward(
        self,
        xs_data: Dict[str, torch.Tensor],
        all_xds_data: List[Dict[str, torch.Tensor]]
    ):
        """
        Render the source image using camera parameters from the driver(s).

        This inference function supports multiple camera inputs, which is useful
        during training when the loss is computed on multiple views of a single
        source image.

        Parameters:
        - xs_data: The source's data. Must contain an 'image' key
        - all_xds_data: All drivers' data. Each entry must contain 'cam2world' and 'intrinsics'
        """
        xs_triplane = self.canonicalize(xs_data['image'])

        all_out = []
        for xd_data in all_xds_data:
            driver_out = self.render(xs_triplane, xd_data['cam2world'], xd_data['intrinsics'])
            all_out.append(driver_out)

        return all_out
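
# End-to-end usage sketch. The kwargs below are placeholders (hypothetical
# values); real configs come from the project's option files, and the
# super-resolution module is resolved through get_model() / MODEL_REGISTRY.
#
#     model = Lp3D(
#         neural_rendering_resolution=128,
#         triplane_nd=32,
#         triplane_h=256,
#         triplane_w=256,
#         rendering_kwargs={...},
#         superresolution_kwargs={...},
#     )
#     outs = model(
#         xs_data={'image': src_img},   # (B, 3, 512, 512), range [-1, 1]
#         all_xds_data=[{'cam2world': c2w, 'intrinsics': K}],
#     )
#     pred = outs[0]['image']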