"""Sampler that draws random identities and posed renderings from a pre-trained EG3D generator."""

import torch
import torch.nn as nn
import numpy as np

from additional_modules.eg3d.networks import TriPlaneGenerator
from additional_modules.eg3d.camera_utils import LookAtPoseSampler, IntrinsicsSampler

# Path to the pre-trained EG3D checkpoint (FFHQ, rebalanced).
NETWORK_PKL = 'experiments/pretrained_models/eg3d_ffhq_rebalance.pth'
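
# EG3D conditions both the mapping network and the renderer on a 25-dim camera
# vector: a flattened 4x4 cam2world extrinsics matrix (16 values) concatenated
# with a flattened 3x3 normalized intrinsics matrix (9 values). Both `render`
# and `forward` below build their camera tensors in that layout.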
class EG3DSampler(nn.Module):
    """Wraps a pre-trained TriPlaneGenerator and samples posed renderings from it."""

    def __init__(self):
        super().__init__()

        # Rebuild the generator from its stored constructor kwargs, then load weights.
        eg3d_data = torch.load(NETWORK_PKL, map_location='cpu')
        G = TriPlaneGenerator(**eg3d_data['init_kwargs'])
        G.init_kwargs = eg3d_data['init_kwargs']
        G.load_state_dict(eg3d_data['state_dict'])
        G.neural_rendering_resolution = eg3d_data['neural_rendering_resolution']
        G.rendering_kwargs = eg3d_data['rendering_kwargs']

        # Tighten the ray-marching interval around the head and raise the number
        # of depth samples (coarse + importance) for higher-quality volume rendering.
        G.rendering_kwargs['ray_start'] = 2.0
        G.rendering_kwargs['ray_end'] = 3.5
        G.rendering_kwargs['depth_resolution'] = 52
        G.rendering_kwargs['depth_resolution_importance'] = 60
        self.G = G

        self.pose_sampler = LookAtPoseSampler()
        self.intrinsics_sampler = IntrinsicsSampler()

        # All cameras look at the origin; registering the point as a buffer keeps
        # it on the module's device.
        self.register_buffer('lookat_position', torch.tensor([0.0, 0.0, 0.0]))
    def render(self, z, yaw, pitch):
        """Render a single image of latent `z` from an exact yaw/pitch."""
        device = self.lookat_position.device
        lookat_position = self.lookat_position.unsqueeze(0)

        # Camera at the requested yaw/pitch and radius 2.7; repeating yaw/pitch
        # with a 0.0 spread makes the sampled pose deterministic.
        cam2world_pose = self.pose_sampler.sample(
            yaw, pitch, 2.7,
            lookat_position,
            yaw, pitch, 0.0,
            batch_size=1, device=device
        )
        # Fixed intrinsics: 18.837 degree FOV (EG3D's FFHQ default) with the
        # principal point at 0.5 (image center) and no jitter.
        intrinsics = self.intrinsics_sampler.sample(
            18.837, 0.5,
            0.0, 0.0,
            batch_size=1, device=device
        )

        # Condition the mapping network on a canonical frontal camera at the same
        # radius, following EG3D's generator pose conditioning.
        radius = torch.linalg.vector_norm(cam2world_pose[:, :3, 3], dim=1, keepdim=True)
        conditioning_cam2world_pose = self.pose_sampler.sample(
            np.pi/2, np.pi/2, radius,
            lookat_position,
            np.pi/2, np.pi/2, 0,
            batch_size=1, device=device
        )

        # Flatten extrinsics (16) and intrinsics (9) into the 25-dim camera vector.
        camera_params = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
        conditioning_params = torch.cat([conditioning_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)

        ws = self.G.mapping(z, conditioning_params, truncation_psi=0.7, truncation_cutoff=14)
        out = self.G.synthesis(ws, camera_params)
        return out['image']
    @torch.no_grad()
    def forward(self, num_views, batch_size, z=None):
        """Render `num_views` randomly posed views for a batch of latents."""
        device = self.lookat_position.device
        lookat_position = self.lookat_position.unsqueeze(0).repeat(batch_size, 1)

        if z is None:
            z = torch.randn((batch_size, self.G.z_dim), device=device)
        assert z.shape[0] == batch_size

        all_out = []
        for view_idx in range(num_views):
            # Random camera pose: yaw/pitch drawn between the paired bounds,
            # radius 2.7 with a 0.1 jitter.
            cam2world_pose = self.pose_sampler.sample(
                0.71, 1.11, 2.7,
                lookat_position,
                2.42, 2.02, 0.1,
                batch_size=batch_size, device=device
            )
            # Random intrinsics jittered around the 18.837 degree FFHQ FOV and
            # the centered principal point.
            intrinsics = self.intrinsics_sampler.sample(
                18.837, 0.5,
                1.5, 0.02,
                batch_size=batch_size, device=device
            )

            # Canonical frontal camera at the same radius for pose conditioning.
            radius = torch.linalg.vector_norm(cam2world_pose[:, :3, 3], dim=1, keepdim=True)
            conditioning_cam2world_pose = self.pose_sampler.sample(
                np.pi/2, np.pi/2, radius,
                lookat_position,
                np.pi/2, np.pi/2, 0,
                batch_size=batch_size, device=device
            )

            camera_params = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
            conditioning_params = torch.cat([conditioning_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)

            # The same z is mapped on every iteration, so all views share identity.
            ws = self.G.mapping(z, conditioning_params, truncation_psi=0.7, truncation_cutoff=14)
            out = self.G.synthesis(ws, camera_params)

            # Attach the sampled camera so callers can reproject the outputs.
            out['cam2world'] = cam2world_pose
            out['intrinsics'] = intrinsics
            all_out.append(out)

        return all_out
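

# Minimal usage sketch (not part of the original module): renders three views of
# one random identity. Assumes the checkpoint at NETWORK_PKL exists on disk and
# that a CUDA device is available; swap 'cuda' for 'cpu' otherwise.
if __name__ == '__main__':
    sampler = EG3DSampler().to('cuda').eval()
    views = sampler(num_views=3, batch_size=1)
    for i, view in enumerate(views):
        # Each entry is the synthesis output dict with the sampled camera attached.
        print(i, view['image'].shape, view['cam2world'].shape, view['intrinsics'].shape)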