import torch
import torch.nn as nn
from typing import List, Dict

from additional_modules.deeplabv3.deeplabv3 import DeepLabV3
from additional_modules.segformer.backbone import Block, OverlapPatchEmbed
from additional_modules.eg3d.networks import OSGDecoder
from models import get_model
from rendering.ray_sampler import RaySampler
from rendering.triplane_rendering.renderer import ImportanceRenderer
from utils.registry import MODEL_REGISTRY


class PositionalEncoder(nn.Module):
    """Base module that appends a fixed 2-channel coordinate grid to its input.

    The grid has shape (1, 2, img_size, img_size), holds values in [-1, 1], and
    is registered as the buffer ``id_grid`` so it follows the module across
    devices and is part of the state dict.
    """

    def __init__(self, img_size: int):
        super().__init__()
        h_linspace = torch.linspace(-1, 1, img_size)
        w_linspace = torch.linspace(-1, 1, img_size)
        # NOTE: with indexing='xy', `gh` varies along the last (width) axis and
        # `gw` along the height axis. Kept as-is: trained weights depend on
        # this channel order.
        gh, gw = torch.meshgrid(h_linspace, w_linspace, indexing='xy')
        gh, gw = gh.unsqueeze(0), gw.unsqueeze(0)
        id_grid = torch.cat((gh, gw), dim=0).unsqueeze(0)
        self.register_buffer('id_grid', id_grid)

    def _add_positional_encoding(self, img):
        """Concatenate the coordinate grid to `img` along the channel axis.

        Assumes `img` is (B, C, img_size, img_size); the result has C + 2
        channels.
        """
        # expand() broadcasts the batch dim without copying; torch.cat below
        # materializes the output anyway, so the result matches repeat().
        id_grid = self.id_grid.expand(img.shape[0], -1, -1, -1)
        return torch.cat((img, id_grid), dim=1)


class ELow(PositionalEncoder):
    """Low-frequency branch: DeepLabV3 backbone, 5 transformer blocks, then upsampling.

    Produces a 96-channel feature map (see ``conv3``).
    """

    def __init__(self, img_size: int = 512, img_channels: int = 3):
        super().__init__(img_size)
        # +2 input channels for the coordinate grid added by the base class.
        self.deeplabv3_backbone = DeepLabV3(input_channels=img_channels + 2)
        # Assumes the backbone returns 256 channels at 1/8 of the input
        # resolution (matches in_chans / img_size below) -- confirm in DeepLabV3.
        self.patch_embed = OverlapPatchEmbed(
            img_size=img_size // 8, patch_size=3, stride=2, in_chans=256, embed_dim=1024
        )
        # Separate attributes (not a ModuleList) keep checkpoint keys stable.
        self.block1 = Block(dim=1024, num_heads=4, mlp_ratio=2, sr_ratio=1)
        self.block2 = Block(dim=1024, num_heads=4, mlp_ratio=2, sr_ratio=1)
        self.block3 = Block(dim=1024, num_heads=4, mlp_ratio=2, sr_ratio=1)
        self.block4 = Block(dim=1024, num_heads=4, mlp_ratio=2, sr_ratio=1)
        self.block5 = Block(dim=1024, num_heads=4, mlp_ratio=2, sr_ratio=1)
        # PixelShuffle(2): 1024 channels -> 256 channels at 2x spatial size.
        self.up1 = nn.PixelShuffle(upscale_factor=2)
        self.up2 = nn.Upsample(scale_factor=2, mode='bilinear')
        self.conv1 = nn.Conv2d(256, 128, 3, 1, 1, bias=True)
        self.act1 = nn.ReLU()
        self.up3 = nn.Upsample(scale_factor=2, mode='bilinear')
        self.conv2 = nn.Conv2d(128, 128, 3, 1, 1, bias=True)
        self.act2 = nn.ReLU()
        self.conv3 = nn.Conv2d(128, 96, 3, 1, 1, bias=True)

    def forward(self, img: torch.Tensor):
        """Encode `img` into a 96-channel low-frequency feature map."""
        x = self._add_positional_encoding(img)
        x = self.deeplabv3_backbone(x)
        # Tokenize into an H x W grid of 1024-dim tokens.
        x, H, W = self.patch_embed(x)
        for block in (self.block1, self.block2, self.block3, self.block4, self.block5):
            x = block(x, H, W)
        # Tokens back to a channels-first feature map (B, 1024, H, W).
        x = x.reshape(img.shape[0], H, W, -1).permute(0, 3, 1, 2).contiguous()
        x = self.up1(x)
        x = self.up2(x)
        x = self.act1(self.conv1(x))
        x = self.up3(x)
        x = self.act2(self.conv2(x))
        x = self.conv3(x)
        return x


class EHigh(PositionalEncoder):
    """High-frequency branch: a shallow conv stack producing 96 channels.

    ``conv1`` has stride 2, so the output is at half the input resolution.
    """

    def __init__(self, img_size: int = 512, img_channels: int = 3):
        super().__init__(img_size)
        # +2 input channels for the coordinate grid added by the base class.
        self.conv1 = nn.Conv2d(img_channels + 2, 64, 7, 2, 3, bias=True)
        self.act1 = nn.LeakyReLU(0.01)
        self.conv2 = nn.Conv2d(64, 96, 3, 1, 1, bias=True)
        self.act2 = nn.LeakyReLU(0.01)
        self.conv3 = nn.Conv2d(96, 96, 3, 1, 1, bias=True)
        self.act3 = nn.LeakyReLU(0.01)
        self.conv4 = nn.Conv2d(96, 96, 3, 1, 1, bias=True)
        self.act4 = nn.LeakyReLU(0.01)
        self.conv5 = nn.Conv2d(96, 96, 3, 1, 1, bias=True)
        self.act5 = nn.LeakyReLU(0.01)

    def forward(self, img: torch.Tensor):
        """Encode `img` into a 96-channel high-frequency feature map."""
        x = self._add_positional_encoding(img)
        for conv, act in (
            (self.conv1, self.act1),
            (self.conv2, self.act2),
            (self.conv3, self.act3),
            (self.conv4, self.act4),
            (self.conv5, self.act5),
        ):
            x = act(conv(x))
        return x


class Lp3DEncoder(nn.Module):
    """Fuse the low- and high-frequency branches into a stacked triplane tensor.

    The output has ``triplane_nd * 3`` channels (see ``conv6``): three planes
    of ``triplane_nd`` channels each, concatenated along the channel axis.
    """

    def __init__(self, img_size: int = 512, img_channels: int = 3, triplane_nd: int = 32):
        super().__init__()
        self.img_size = img_size
        self.elo = ELow(img_size, img_channels)
        self.ehi = EHigh(img_size, img_channels)
        # 192 = 96 (ELow output) + 96 (EHigh output).
        self.conv1 = nn.Conv2d(192, 256, 3, 1, 1, bias=True)
        self.act1 = nn.LeakyReLU(0.01)
        self.conv2 = nn.Conv2d(256, 128, 3, 1, 1, bias=True)
        self.act2 = nn.LeakyReLU(0.01)
        self.patch_embed = OverlapPatchEmbed(
            img_size=img_size // 2, patch_size=3, stride=2, in_chans=128, embed_dim=1024
        )
        self.transformer_block = Block(dim=1024, num_heads=2, mlp_ratio=2, sr_ratio=2)
        # PixelShuffle(2): 1024 channels -> 256 channels at 2x spatial size.
        self.up = nn.PixelShuffle(upscale_factor=2)
        # 352 = 256 (after PixelShuffle) + 96 (skip connection from ELow).
        self.conv3 = nn.Conv2d(352, 256, 3, 1, 1, bias=True)
        self.act3 = nn.LeakyReLU(0.01)
        self.conv4 = nn.Conv2d(256, 128, 3, 1, 1, bias=True)
        self.act4 = nn.LeakyReLU(0.01)
        self.conv5 = nn.Conv2d(128, 128, 3, 1, 1, bias=True)
        self.act5 = nn.LeakyReLU(0.01)
        self.conv6 = nn.Conv2d(128, triplane_nd * 3, 3, 1, 1, bias=True)

    def forward(self, img: torch.Tensor):
        """Encode a square `img` of side ``self.img_size`` into stacked triplanes.

        Raises:
            AssertionError: if the spatial size of `img` is not img_size x img_size.
        """
        assert img.shape[-1] == self.img_size and img.shape[-2] == self.img_size, (
            f"Lp3DEncoder expects {self.img_size}x{self.img_size} input, "
            f"got {tuple(img.shape[-2:])}"
        )
        f_lo = self.elo(img)
        f_hi = self.ehi(img)
        f = torch.cat((f_lo, f_hi), dim=1)
        f = self.act1(self.conv1(f))
        f = self.act2(self.conv2(f))

        f, H, W = self.patch_embed(f)
        f = self.transformer_block(f, H, W)
        # Tokens back to a channels-first feature map (B, 1024, H, W).
        f = f.reshape(img.shape[0], H, W, -1).permute(0, 3, 1, 2).contiguous()
        f = self.up(f)

        # Skip connection from the low-frequency branch.
        f = torch.cat((f, f_lo), dim=1)
        f = self.act3(self.conv3(f))
        f = self.act4(self.conv4(f))
        f = self.act5(self.conv5(f))
        f = self.conv6(f)
        return f


@MODEL_REGISTRY.register()
class Lp3D(nn.Module):
    """Single-image-to-triplane model with neural rendering and super-resolution."""

    def __init__(
        self,
        neural_rendering_resolution: int,  # Render at this resolution and use superres to upsample to 512x512
        triplane_nd: int,  # Triplane's number of channels
        triplane_h: int,  # Triplane height
        triplane_w: int,  # Triplane width
        rendering_kwargs,
        superresolution_kwargs,
    ):
        super().__init__()

        self.triplane_nd = triplane_nd
        # NOTE(review): triplane_h / triplane_w are stored but not used to size
        # the encoder in this file; the encoder's output size follows its
        # img_size instead.
        self.triplane_h = triplane_h
        self.triplane_w = triplane_w
        self.neural_rendering_resolution = neural_rendering_resolution
        self.superresolution_opt = superresolution_kwargs
        self.rendering_kwargs = rendering_kwargs

        self._setup_modules()

    def _setup_modules(self):
        """Build the encoder, renderer, decoder, ray sampler and super-resolution modules."""
        # For now only support 512x512 input image and 256x256 triplane
        self.triplane_encoder = Lp3DEncoder(triplane_nd=self.triplane_nd)
        self.renderer = ImportanceRenderer()
        self.decoder = OSGDecoder(
            self.triplane_nd,
            {
                'decoder_lr_mul': self.rendering_kwargs.get('decoder_lr_mul', 1),
                'decoder_output_dim': self.triplane_nd
            }
        )
        self.ray_sampler = RaySampler()
        self.superresolution = get_model(self.superresolution_opt)

    def render(self, planes, cam2world, intrinsics, upsample=True):
        """
        Render the triplane using the cam2world and intrinsics matrices.

        Parameters:
            - planes (Tensor): Triplane features, e.g. the output of `canonicalize`.
            - cam2world (Tensor): Camera-to-world matrices; the batch size is taken from dim 0.
            - intrinsics (Tensor): Camera intrinsics matrices.
            - upsample (bool): If True, run the super-resolution module on the raw render.
        Returns:
            dict with keys:
            - 'image_raw' (Tensor): First 3 channels of the rendered feature image
              at neural_rendering_resolution.
            - 'image' (Tensor or None) [Range -1..1]: The super-resolved rendered images;
              None when `upsample` is False.
            - 'planes' (Tensor): The input planes.
            - 'depth' (Tensor): Rendered depth, shaped (B, 1, H, W).
            - 'feature_image' (Tensor): All rendered feature channels, shaped (B, C, H, W).
            - 'cam2world' (Tensor): The input cam2world. Can be useful to stack multiple renderings.
            - 'intrinsics' (Tensor): The input intrinsics. Can be useful to stack multiple renderings.
        """
        ray_origins, ray_directions = self.ray_sampler(
            cam2world, intrinsics, self.neural_rendering_resolution
        )
        batch_size = cam2world.shape[0]

        feature_samples, depth_samples, _, _ = self.renderer(
            planes, self.decoder, ray_origins, ray_directions, self.rendering_kwargs
        )  # channels last

        # Per-ray samples (B, H*W, C) -> channels-first images (B, C, H, W).
        H = W = self.neural_rendering_resolution
        feature_image = feature_samples.permute(0, 2, 1).reshape(
            batch_size, feature_samples.shape[-1], H, W
        ).contiguous()
        rgb_image_raw = feature_image[:, :3]

        rgb_image = self.superresolution(rgb_image_raw, feature_image) if upsample else None

        depth_image = depth_samples.permute(0, 2, 1).reshape(batch_size, 1, H, W).contiguous()

        return {
            'image_raw': rgb_image_raw,
            'image': rgb_image,
            'planes': planes,
            'depth': depth_image,
            'feature_image': feature_image,
            'cam2world': cam2world,
            'intrinsics': intrinsics
        }

    def canonicalize(self, image: torch.Tensor):
        """
        Transform the input image into the canonical 3D space, which is represented by a triplane.

        Parameters:
            - image (Tensor): Input image, presumably in [-1, 1] (it is rescaled
              by (image + 1) / 2 before encoding).
        Returns:
            - triplanes (Tensor): The canonical representation of the input,
              shaped (B, 3, triplane_nd, H, W).
        """
        # Legacy: the encoder consumes (image + 1) / 2 rather than the raw input.
        image = (image + 1) / 2.
        triplanes = self.triplane_encoder(image)
        B = triplanes.shape[0]
        # The encoder emits triplane_nd * 3 channels; split them into 3 planes.
        # Previously hard-coded to 32, which broke any triplane_nd != 32.
        triplanes = triplanes.view(
            B, 3, self.triplane_nd, triplanes.shape[-2], triplanes.shape[-1]
        ).contiguous()
        return triplanes

    def forward(
        self,
        xs_data: Dict[str, torch.Tensor],
        all_xds_data: List[Dict[str, torch.Tensor]]
    ):
        """
        Render the source image using camera parameters from the driver(s).
        This inference function supports multiple camera inputs. It can be useful
        during training, when the loss is computed on multiple views of a single source image.

        Parameters:
            - xs_data: The source's data. Must have an 'image' key.
            - all_xds_data: All drivers' data. Each must have 'cam2world' and 'intrinsics' keys.
        Returns:
            - list of `render` output dicts, one per driver, in the order of `all_xds_data`.
        """
        # Encode the source once and reuse the triplane for every driver view.
        xs_triplane = self.canonicalize(xs_data['image'])

        return [
            self.render(xs_triplane, xd_data['cam2world'], xd_data['intrinsics'])
            for xd_data in all_xds_data
        ]