import torch
import torch.nn as nn
import numpy as np

from additional_modules.eg3d.networks import TriPlaneGenerator
from additional_modules.eg3d.camera_utils import LookAtPoseSampler, IntrinsicsSampler

NETWORK_PKL = 'experiments/pretrained_models/eg3d_ffhq_rebalance.pth'
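
# The checkpoint is expected to bundle everything needed to rebuild the
# generator: constructor kwargs, weights, neural rendering resolution, and
# rendering kwargs (all consumed in EG3DSampler.__init__ below).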


class EG3DSampler(nn.Module):
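    """Sampler around a pretrained EG3D tri-plane generator.

    forward() renders random identities from randomized cameras;
    render() renders a given latent from a caller-specified pose.
    """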

    def __init__(self):
        super().__init__()

        # Rebuild the generator from its saved constructor kwargs, then restore
        # the pretrained weights and rendering settings from the checkpoint.
        eg3d_data = torch.load(NETWORK_PKL, map_location='cpu')
        G = TriPlaneGenerator(**eg3d_data['init_kwargs'])
        G.init_kwargs = eg3d_data['init_kwargs']
        G.load_state_dict(eg3d_data['state_dict'])
        G.neural_rendering_resolution = eg3d_data['neural_rendering_resolution']
        G.rendering_kwargs = eg3d_data['rendering_kwargs']
        # Override the ray-marching interval and the coarse/importance sample
        # counts used by the volume renderer.
        G.rendering_kwargs['ray_start'] = 2.0
        G.rendering_kwargs['ray_end'] = 3.5
        G.rendering_kwargs['depth_resolution'] = 52
        G.rendering_kwargs['depth_resolution_importance'] = 60
        self.G = G
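
        # Pose and intrinsics samplers from the EG3D camera utilities.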
        self.pose_sampler = LookAtPoseSampler()
        self.intrinsics_sampler = IntrinsicsSampler()

        # Registered as a buffer so it follows .to(device)/.cuda() moves;
        # float values keep the downstream camera math in floating point.
        self.register_buffer('lookat_position', torch.tensor([0.0, 0.0, 0.0]))

    def render(self, z, yaw, pitch):
        device = self.lookat_position.device
        lookat_position = self.lookat_position.unsqueeze(0)

        # Camera looking at the origin from the requested yaw/pitch at the
        # standard EG3D FFHQ camera radius of 2.7.
        cam2world_pose = self.pose_sampler.sample(
            yaw, pitch, 2.7,
            lookat_position,
            yaw, pitch, 0.0,
            batch_size=1, device=device
        )
        # 18.837 is EG3D's default FFHQ field of view in degrees.
        intrinsics = self.intrinsics_sampler.sample(
            18.837, 0.5,
            0.0, 0.0,
            batch_size=1, device=device
        )

        # Condition the mapping network on a canonical frontal camera
        # (yaw = pitch = pi/2) at the same radius, so the rendering pose does
        # not leak into the sampled identity.
        radius = torch.linalg.vector_norm(cam2world_pose[:, :3, 3], dim=1, keepdim=True)
        conditioning_cam2world_pose = self.pose_sampler.sample(
            np.pi / 2, np.pi / 2, radius,
            lookat_position,
            np.pi / 2, np.pi / 2, 0,
            batch_size=1, device=device
        )

        # EG3D's 25-dim camera format: flattened 4x4 cam2world + 3x3 intrinsics.
        camera_params = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
        conditioning_params = torch.cat([conditioning_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)

        ws = self.G.mapping(z, conditioning_params, truncation_psi=0.7, truncation_cutoff=14)
        out = self.G.synthesis(ws, camera_params)

        return out['image']
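
    # Sample `batch_size` random identities and render each from `num_views`
    # randomly drawn cameras; returns one output dict per view.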
    @torch.no_grad()
    def forward(self, num_views, batch_size, z=None):
        device = self.lookat_position.device
        lookat_position = self.lookat_position.unsqueeze(0).repeat(batch_size, 1)

        if z is None:
            z = torch.randn((batch_size, self.G.z_dim), device=device)
        assert z.shape[0] == batch_size

        all_out = []
        for view_idx in range(num_views):
            # Randomized viewing camera and intrinsics for this view.
            cam2world_pose = self.pose_sampler.sample(
                0.71, 1.11, 2.7,
                lookat_position,
                2.42, 2.02, 0.1,
                batch_size=batch_size, device=device
            )
            intrinsics = self.intrinsics_sampler.sample(
                18.837, 0.5,
                1.5, 0.02,
                batch_size=batch_size, device=device
            )

            # Canonical frontal conditioning pose at the same camera radius,
            # as in render().
            radius = torch.linalg.vector_norm(cam2world_pose[:, :3, 3], dim=1, keepdim=True)
            conditioning_cam2world_pose = self.pose_sampler.sample(
                np.pi / 2, np.pi / 2, radius,
                lookat_position,
                np.pi / 2, np.pi / 2, 0,
                batch_size=batch_size, device=device
            )

            camera_params = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
            conditioning_params = torch.cat([conditioning_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)

            ws = self.G.mapping(z, conditioning_params, truncation_psi=0.7, truncation_cutoff=14)
            out = self.G.synthesis(ws, camera_params)
            # Attach the sampled camera so downstream code can reproject.
            out['cam2world'] = cam2world_pose
            out['intrinsics'] = intrinsics

            all_out.append(out)

        return all_out
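

if __name__ == '__main__':
    # Minimal smoke test: a sketch that assumes the checkpoint path above
    # exists and matches this TriPlaneGenerator build.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    sampler = EG3DSampler().to(device).eval()

    # Two random identities, each rendered from four randomized views.
    views = sampler(num_views=4, batch_size=2)
    print(views[0]['image'].shape)

    # Deterministic frontal render (yaw = pitch = pi/2) of a fixed latent.
    z = torch.randn(1, sampler.G.z_dim, device=device)
    with torch.no_grad():
        img = sampler.render(z, yaw=np.pi / 2, pitch=np.pi / 2)
    print(img.shape)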