JiantaoLin committed
Commit 73f9b1b · 1 Parent(s): 5a1fa77
new
- models/lrm/data/irrmaps/README.txt +3 -0
- models/lrm/env_mipmap/6/specular_3.pth +3 -0
- models/lrm/env_mipmap/6/specular_4.pth +3 -0
- models/lrm/env_mipmap/6/specular_5.pth +3 -0
- models/lrm/online_render/src/__init__.py +0 -0
- models/lrm/online_render/src/data/irrmaps/README.txt +3 -0
- models/lrm/online_render/utils/camera_util.py +111 -0
- models/lrm/online_render/utils/camera_utils.py +83 -0
- models/lrm/online_render/utils/general_utils.py +133 -0
- models/lrm/online_render/utils/graphics_utils.py +90 -0
- models/lrm/online_render/utils/image_utils.py +19 -0
- models/lrm/online_render/utils/loss_utils.py +183 -0
- models/lrm/online_render/utils/obj.py +209 -0
- models/lrm/online_render/utils/sh_utils.py +118 -0
- models/lrm/online_render/utils/system_utils.py +28 -0
- models/lrm/online_render/utils/taming/modules/autoencoder/lpips/vgg.pth +3 -0
- models/lrm/online_render/utils/vis_utils.py +151 -0
models/lrm/data/irrmaps/README.txt
ADDED
@@ -0,0 +1,3 @@
The aerodynamics_workshop_2k.hdr HDR probe is from https://polyhaven.com/a/aerodynamics_workshop
CC0 License.
models/lrm/env_mipmap/6/specular_3.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:09f83af826453e34c70afff7c617092baed35ff644f5c6799f4c2b1bdecc8d69
size 296107
models/lrm/env_mipmap/6/specular_4.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e94f34638e65bba737806ad3b3031482f80f6ecdc1c97cde5d216c2e90eb9017
size 74923
models/lrm/env_mipmap/6/specular_5.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3742584ea8fe108908757165d66eb8a003e121858514ec2b3980f48366e4e4f1
size 19627
models/lrm/online_render/src/__init__.py
ADDED
File without changes
models/lrm/online_render/src/data/irrmaps/README.txt
ADDED
@@ -0,0 +1,3 @@
The aerodynamics_workshop_2k.hdr HDR probe is from https://polyhaven.com/a/aerodynamics_workshop
CC0 License.
models/lrm/online_render/utils/camera_util.py
ADDED
@@ -0,0 +1,111 @@
import torch
import torch.nn.functional as F
import numpy as np


def pad_camera_extrinsics_4x4(extrinsics):
    if extrinsics.shape[-2] == 4:
        return extrinsics
    padding = torch.tensor([[0, 0, 0, 1]]).to(extrinsics)
    if extrinsics.ndim == 3:
        padding = padding.unsqueeze(0).repeat(extrinsics.shape[0], 1, 1)
    extrinsics = torch.cat([extrinsics, padding], dim=-2)
    return extrinsics


def center_looking_at_camera_pose(camera_position: torch.Tensor, look_at: torch.Tensor = None, up_world: torch.Tensor = None):
    """
    Create OpenGL camera extrinsics from camera locations and look-at position.

    camera_position: (M, 3) or (3,)
    look_at: (3)
    up_world: (3)
    return: (M, 3, 4) or (3, 4)
    """
    # by default, looking at the origin and world up is z-axis
    if look_at is None:
        look_at = torch.tensor([0, 0, 0], dtype=torch.float32)
    if up_world is None:
        up_world = torch.tensor([0, 0, 1], dtype=torch.float32)
    if camera_position.ndim == 2:
        look_at = look_at.unsqueeze(0).repeat(camera_position.shape[0], 1)
        up_world = up_world.unsqueeze(0).repeat(camera_position.shape[0], 1)

    # OpenGL camera: z-backward, x-right, y-up
    z_axis = camera_position - look_at
    z_axis = F.normalize(z_axis, dim=-1).float()
    x_axis = torch.linalg.cross(up_world, z_axis, dim=-1)
    x_axis = F.normalize(x_axis, dim=-1).float()
    y_axis = torch.linalg.cross(z_axis, x_axis, dim=-1)
    y_axis = F.normalize(y_axis, dim=-1).float()

    extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1)
    extrinsics = pad_camera_extrinsics_4x4(extrinsics)
    return extrinsics


def spherical_camera_pose(azimuths: np.ndarray, elevations: np.ndarray, radius=2.5):
    azimuths = np.deg2rad(azimuths)
    elevations = np.deg2rad(elevations)

    xs = radius * np.cos(elevations) * np.cos(azimuths)
    ys = radius * np.cos(elevations) * np.sin(azimuths)
    zs = radius * np.sin(elevations)

    cam_locations = np.stack([xs, ys, zs], axis=-1)
    cam_locations = torch.from_numpy(cam_locations).float()

    c2ws = center_looking_at_camera_pose(cam_locations)
    return c2ws


def get_circular_camera_poses(M=120, radius=2.5, elevation=30.0):
    # M: number of circular views
    # radius: camera dist to center
    # elevation: elevation degrees of the camera
    # return: (M, 4, 4)
    assert M > 0 and radius > 0

    elevation = np.deg2rad(elevation)

    camera_positions = []
    for i in range(M):
        azimuth = 2 * np.pi * i / M
        x = radius * np.cos(elevation) * np.cos(azimuth)
        y = radius * np.cos(elevation) * np.sin(azimuth)
        z = radius * np.sin(elevation)
        camera_positions.append([x, y, z])
    camera_positions = np.array(camera_positions)
    camera_positions = torch.from_numpy(camera_positions).float()
    extrinsics = center_looking_at_camera_pose(camera_positions)
    return extrinsics


def FOV_to_intrinsics(fov, device='cpu'):
    """
    Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees.
    Note the intrinsics are returned as normalized by image size, rather than in pixel units.
    Assumes principal point is at image center.
    """
    focal_length = 0.5 / np.tan(np.deg2rad(fov) * 0.5)
    intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
    return intrinsics


def get_zero123plus_input_cameras(batch_size=1, radius=4.0, fov=30.0):
    """
    Get the input camera parameters.
    """
    azimuths = np.array([30, 90, 150, 210, 270, 330]).astype(float)
    elevations = np.array([20, -10, 20, -10, 20, -10]).astype(float)

    c2ws = spherical_camera_pose(azimuths, elevations, radius)
    c2ws = c2ws.float().flatten(-2)

    Ks = FOV_to_intrinsics(fov).unsqueeze(0).repeat(6, 1, 1).float().flatten(-2)

    extrinsics = c2ws[:, :12]
    intrinsics = torch.stack([Ks[:, 0], Ks[:, 4], Ks[:, 2], Ks[:, 5]], dim=-1)
    cameras = torch.cat([extrinsics, intrinsics], dim=-1)

    return cameras.unsqueeze(0).repeat(batch_size, 1, 1)
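A minimal usage sketch (not part of the commit); the import path below is assumed from where this commit places the file, so adjust it to your setup:

import numpy as np
from models.lrm.online_render.utils.camera_util import (
    spherical_camera_pose, FOV_to_intrinsics, get_zero123plus_input_cameras)

# Six camera-to-world poses on a radius-4 sphere, matching the Zero123++ view layout.
azimuths = np.array([30, 90, 150, 210, 270, 330], dtype=float)
elevations = np.array([20, -10, 20, -10, 20, -10], dtype=float)
c2ws = spherical_camera_pose(azimuths, elevations, radius=4.0)   # (6, 4, 4)

K = FOV_to_intrinsics(30.0)                                      # (3, 3), normalized by image size

# Flattened per-view vectors: 12 extrinsic entries + 4 intrinsic entries each.
cameras = get_zero123plus_input_cameras(batch_size=2)            # (2, 6, 16)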
models/lrm/online_render/utils/camera_utils.py
ADDED
@@ -0,0 +1,83 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#

from scene.cameras import Camera
import numpy as np
from utils.general_utils import PILtoTorch
from utils.graphics_utils import fov2focal

WARNED = False

def loadCam(args, id, cam_info, resolution_scale):
    orig_w, orig_h = cam_info.image.size

    if args.resolution in [1, 2, 4, 8]:
        resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
    else:  # should be a type that converts to float
        if args.resolution == -1:
            if orig_w > 1600:
                global WARNED
                if not WARNED:
                    print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n "
                        "If this is not desired, please explicitly specify '--resolution/-r' as 1")
                    WARNED = True
                global_down = orig_w / 1600
            else:
                global_down = 1
        else:
            global_down = orig_w / args.resolution

        scale = float(global_down) * float(resolution_scale)
        resolution = (int(orig_w / scale), int(orig_h / scale))

    resized_image_rgb = PILtoTorch(cam_info.image, resolution)

    gt_image = resized_image_rgb[:3, ...]
    loaded_mask = None

    if resized_image_rgb.shape[1] == 4:
        loaded_mask = resized_image_rgb[3:4, ...]
    mask_image = cam_info.mask
    # breakpoint()
    return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
                  FoVx=cam_info.FovX, FoVy=cam_info.FovY,
                  image=gt_image, gt_alpha_mask=loaded_mask, mask=mask_image,
                  image_name=cam_info.image_name, uid=id, data_device=args.data_device)

def cameraList_from_camInfos(cam_infos, resolution_scale, args):
    camera_list = []

    for id, c in enumerate(cam_infos):
        camera_list.append(loadCam(args, id, c, resolution_scale))

    return camera_list

def camera_to_JSON(id, camera : Camera):
    Rt = np.zeros((4, 4))
    Rt[:3, :3] = camera.R.transpose()
    Rt[:3, 3] = camera.T
    Rt[3, 3] = 1.0

    W2C = np.linalg.inv(Rt)
    pos = W2C[:3, 3]
    rot = W2C[:3, :3]
    serializable_array_2d = [x.tolist() for x in rot]
    camera_entry = {
        'id' : id,
        'img_name' : camera.image_name,
        'width' : camera.width,
        'height' : camera.height,
        'position': pos.tolist(),
        'rotation': serializable_array_2d,
        'fy' : fov2focal(camera.FovY, camera.height),
        'fx' : fov2focal(camera.FovX, camera.width)
    }
    return camera_entry
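As a side note (not part of the commit), the pose handling in camera_to_JSON amounts to building the 4x4 view matrix from R and T and inverting it to recover the camera position and rotation stored in the JSON entry. A self-contained numpy sketch of that step, with stand-in values in place of a real scene.cameras.Camera:

import numpy as np

R = np.eye(3)                      # stand-in rotation
T = np.array([0.0, 0.0, 4.0])      # stand-in translation

Rt = np.zeros((4, 4))
Rt[:3, :3] = R.transpose()
Rt[:3, 3] = T
Rt[3, 3] = 1.0

inv = np.linalg.inv(Rt)
position = inv[:3, 3]              # camera centre in world space -> [0, 0, -4]
rotation = inv[:3, :3]             # orientation rows serialised into the JSON entry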
models/lrm/online_render/utils/general_utils.py
ADDED
@@ -0,0 +1,133 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#

import torch
import sys
from datetime import datetime
import numpy as np
import random

def inverse_sigmoid(x):
    return torch.log(x/(1-x))

def PILtoTorch(pil_image, resolution):
    resized_image_PIL = pil_image.resize(resolution)
    resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
    if len(resized_image.shape) == 3:
        return resized_image.permute(2, 0, 1)
    else:
        return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)

def get_expon_lr_func(
    lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
):
    """
    Copied from Plenoxels

    Continuous learning rate decay function. Adapted from JaxNeRF
    The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
    is log-linearly interpolated elsewhere (equivalent to exponential decay).
    If lr_delay_steps>0 then the learning rate will be scaled by some smooth
    function of lr_delay_mult, such that the initial learning rate is
    lr_init*lr_delay_mult at the beginning of optimization but will be eased back
    to the normal learning rate when steps>lr_delay_steps.
    :param conf: config subtree 'lr' or similar
    :param max_steps: int, the number of steps during optimization.
    :return HoF which takes step as input
    """

    def helper(step):
        if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
            # Disable this parameter
            return 0.0
        if lr_delay_steps > 0:
            # A kind of reverse cosine decay.
            delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
                0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
            )
        else:
            delay_rate = 1.0
        t = np.clip(step / max_steps, 0, 1)
        log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
        return delay_rate * log_lerp

    return helper

def strip_lowerdiag(L):
    uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")

    uncertainty[:, 0] = L[:, 0, 0]
    uncertainty[:, 1] = L[:, 0, 1]
    uncertainty[:, 2] = L[:, 0, 2]
    uncertainty[:, 3] = L[:, 1, 1]
    uncertainty[:, 4] = L[:, 1, 2]
    uncertainty[:, 5] = L[:, 2, 2]
    return uncertainty

def strip_symmetric(sym):
    return strip_lowerdiag(sym)

def build_rotation(r):
    norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])

    q = r / norm[:, None]

    R = torch.zeros((q.size(0), 3, 3), device='cuda')

    r = q[:, 0]
    x = q[:, 1]
    y = q[:, 2]
    z = q[:, 3]

    R[:, 0, 0] = 1 - 2 * (y*y + z*z)
    R[:, 0, 1] = 2 * (x*y - r*z)
    R[:, 0, 2] = 2 * (x*z + r*y)
    R[:, 1, 0] = 2 * (x*y + r*z)
    R[:, 1, 1] = 1 - 2 * (x*x + z*z)
    R[:, 1, 2] = 2 * (y*z - r*x)
    R[:, 2, 0] = 2 * (x*z - r*y)
    R[:, 2, 1] = 2 * (y*z + r*x)
    R[:, 2, 2] = 1 - 2 * (x*x + y*y)
    return R

def build_scaling_rotation(s, r):
    L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
    R = build_rotation(r)

    L[:,0,0] = s[:,0]
    L[:,1,1] = s[:,1]
    L[:,2,2] = s[:,2]

    L = R @ L
    return L

def safe_state(silent):
    old_f = sys.stdout
    class F:
        def __init__(self, silent):
            self.silent = silent

        def write(self, x):
            if not self.silent:
                if x.endswith("\n"):
                    old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
                else:
                    old_f.write(x)

        def flush(self):
            old_f.flush()

    sys.stdout = F(silent)

    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.set_device(torch.device("cuda:0"))
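A usage sketch (not part of the commit) for the learning-rate schedule; the import path is assumed from this commit's layout and the hyperparameters are illustrative:

from models.lrm.online_render.utils.general_utils import get_expon_lr_func

lr_fn = get_expon_lr_func(lr_init=1.6e-4, lr_final=1.6e-6, max_steps=30_000)
print(lr_fn(0))        # 1.6e-4 at step 0
print(lr_fn(15_000))   # ~1.6e-5: log-linear (geometric) midpoint
print(lr_fn(30_000))   # 1.6e-6 at max_steps
# Typical use: assign lr_fn(step) to an optimizer param_group["lr"] each iteration.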
models/lrm/online_render/utils/graphics_utils.py
ADDED
@@ -0,0 +1,90 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#

import torch
import math
import numpy as np
from typing import NamedTuple

class BasicPointCloud(NamedTuple):
    points : np.array
    colors : np.array
    normals : np.array

def geom_transform_points(points, transf_matrix):
    P, _ = points.shape
    ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
    points_hom = torch.cat([points, ones], dim=1)
    points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))

    denom = points_out[..., 3:] + 0.0000001
    return (points_out[..., :3] / denom).squeeze(dim=0)

def getWorld2View(R, t):
    Rt = np.zeros((4, 4))
    Rt[:3, :3] = R.transpose()
    Rt[:3, 3] = t
    Rt[3, 3] = 1.0
    return np.float32(Rt)

def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
    Rt = np.zeros((4, 4))
    Rt[:3, :3] = R.transpose()
    Rt[:3, 3] = t
    Rt[3, 3] = 1.0

    C2W = np.linalg.inv(Rt)
    cam_center = C2W[:3, 3]
    cam_center = (cam_center + translate) * scale
    C2W[:3, 3] = cam_center
    Rt = np.linalg.inv(C2W)
    return np.float32(Rt)

def getView2World(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
    Rt = np.zeros((4, 4))
    Rt[:3, :3] = R.transpose()
    Rt[:3, 3] = t
    Rt[3, 3] = 1.0

    C2W = np.linalg.inv(Rt)
    cam_center = C2W[:3, 3]
    cam_center = (cam_center + translate) * scale
    C2W[:3, 3] = cam_center
    Rt = C2W
    return np.float32(Rt)

def getProjectionMatrix(znear, zfar, fovX, fovY):
    tanHalfFovY = math.tan((fovY / 2))
    tanHalfFovX = math.tan((fovX / 2))

    top = tanHalfFovY * znear
    bottom = -top
    right = tanHalfFovX * znear
    left = -right

    P = torch.zeros(4, 4)

    z_sign = 1.0

    P[0, 0] = 2.0 * znear / (right - left)
    P[1, 1] = 2.0 * znear / (top - bottom)
    P[0, 2] = (right + left) / (right - left)
    P[1, 2] = (top + bottom) / (top - bottom)
    P[3, 2] = z_sign
    P[2, 2] = z_sign * zfar / (zfar - znear)
    P[2, 3] = -(zfar * znear) / (zfar - znear)
    return P

def fov2focal(fov, pixels):
    return pixels / (2 * math.tan(fov / 2))

def focal2fov(focal, pixels):
    return 2*math.atan(pixels/(2*focal))
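A usage sketch (not part of the commit); the import path is assumed and the camera values are illustrative:

import math
import numpy as np
from models.lrm.online_render.utils.graphics_utils import (
    getWorld2View2, getProjectionMatrix, fov2focal, focal2fov)

R = np.eye(3)                              # camera rotation
t = np.array([0.0, 0.0, 4.0])              # camera translation
w2v = getWorld2View2(R, t)                 # (4, 4) float32 world-to-view matrix

fovy = math.radians(49.1)
P = getProjectionMatrix(znear=0.01, zfar=100.0, fovX=fovy, fovY=fovy)  # (4, 4) torch tensor

focal = fov2focal(fovy, 800)               # focal length in pixels for an 800-px image
assert math.isclose(focal2fov(focal, 800), fovy)  # fov <-> focal round-trip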
models/lrm/online_render/utils/image_utils.py
ADDED
@@ -0,0 +1,19 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#

import torch

def mse(img1, img2):
    return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)

def psnr(img1, img2):
    mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
    return 20 * torch.log10(1.0 / torch.sqrt(mse))
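A quick sanity check (not part of the commit); psnr expects tensors with a leading batch dimension, and the import path is assumed:

import torch
from models.lrm.online_render.utils.image_utils import psnr

img = torch.rand(1, 3, 64, 64)
print(psnr(img, img))          # inf: zero MSE against itself
print(psnr(img, img + 0.1))    # ~20 dB: a uniform 0.1 offset gives MSE = 0.01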
models/lrm/online_render/utils/loss_utils.py
ADDED
@@ -0,0 +1,183 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#

import torch
import torch.nn.functional as F
from torch.autograd import Variable
from math import exp

def l1_loss(network_output, gt):
    return torch.abs((network_output - gt)).mean()

def l2_loss(network_output, gt):
    return ((network_output - gt) ** 2).mean()

def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()

def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window

def ssim(img1, img2, window_size=11, size_average=True):
    channel = img1.size(-3)
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)

def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

import torch
import torch.nn as nn

from taming.modules.losses.vqperceptual import *  # TODO: taming dependency yes/no?


class LPIPSWithDiscriminator(nn.Module):
    def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_loss="hinge"):

        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        self.kl_weight = kl_weight
        self.pixel_weight = pixelloss_weight
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        # output log variance
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm
                                                 ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(self, inputs, reconstructions, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train"):
        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss

        # nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            logits_fake = self.discriminator(reconstructions.contiguous())
            # g_loss = -torch.mean(logits_fake)
            g_loss = F.relu(1 - logits_fake).mean()
            # if self.disc_factor > 0.0:
            #     try:
            #         d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
            #     except RuntimeError:
            #         assert not self.training
            #         d_weight = torch.tensor(0.0)
            # else:
            #     d_weight = torch.tensor(0.0)

            # disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            # loss = d_weight * disc_factor * g_loss

            # return loss, log
            return g_loss

        if optimizer_idx == 1:
            # second pass for discriminator update

            logits_real = self.discriminator(inputs.contiguous().detach())
            logits_fake = self.discriminator(reconstructions.contiguous().detach())

            # disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            # d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            # log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
            #        "{}/logits_real".format(split): logits_real.detach().mean(),
            #        "{}/logits_fake".format(split): logits_fake.detach().mean()
            #        }
            # return d_loss, log

            d_loss = self.disc_loss(logits_real, logits_fake)
            return d_loss

import torch
from chamfer_distance import ChamferDistance

# Initialize the Chamfer Distance module
chamfer_dist_module = ChamferDistance()

def calculate_chamfer_loss(pred, gt):
    """
    Compute the Chamfer Distance loss.
    Args:
        pred (torch.Tensor): predicted point cloud of shape (batch_size, num_points, 3)
        gt (torch.Tensor): ground-truth point cloud of shape (batch_size, num_points, 3)
        chamfer_dist_module (ChamferDistance): a pre-initialized Chamfer Distance module

    Returns:
        torch.Tensor: the Chamfer Distance loss
    """
    # Compute the Chamfer Distance
    dist1, dist2, idx1, idx2 = chamfer_dist_module(pred, gt)
    loss = (torch.mean(dist1) + torch.mean(dist2)) / 2

    return loss

if __name__ == "__main__":

    discriminator = LPIPSWithDiscriminator(disc_start=0, disc_weight=0.5)
models/lrm/online_render/utils/obj.py
ADDED
@@ -0,0 +1,209 @@
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import os
import torch

from . import texture
from . import mesh
from . import material

######################################################################################
# Utility functions
######################################################################################

def _find_mat(materials, name):
    for mat in materials:
        if mat['name'] == name:
            return mat
    return materials[0] # Materials 0 is the default


def normalize_mesh(vertices):
    # Compute the bounding box
    min_vals, _ = torch.min(vertices, dim=0)
    max_vals, _ = torch.max(vertices, dim=0)

    # Compute the center point
    center = (max_vals + min_vals) / 2

    # Translate the vertices
    vertices = vertices - center

    # Compute the scale factor
    max_extent = torch.max(max_vals - min_vals)
    scale = 2.0 / max_extent

    # Scale the vertices
    vertices = vertices * scale

    return vertices

######################################################################################
# Create mesh object from objfile
######################################################################################
def rotate_y_90(v_pos):
    # Rotation matrix for a 90-degree rotation around the Y axis
    rotate_y = torch.tensor([[0, 0, 1, 0],
                             [0, 1, 0, 0],
                             [-1, 0, 0, 0],
                             [0, 0, 0, 1]], dtype=torch.float32, device=v_pos.device)
    return rotate_y

def load_obj(filename, clear_ks=True, mtl_override=None, return_attributes=False, path_is_attributrs=False):
    obj_path = os.path.dirname(filename)

    # Read entire file
    with open(filename, 'r') as f:
        lines = f.readlines()

    # Load materials
    all_materials = [
        {
            'name' : '_default_mat',
            'bsdf' : 'pbr',
            'kd'   : texture.Texture2D(torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32, device='cuda')),
            'ks'   : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'))
        }
    ]
    if mtl_override is None:
        for line in lines:
            if len(line.split()) == 0:
                continue
            if line.split()[0] == 'mtllib':
                all_materials += material.load_mtl(os.path.join(obj_path, line.split()[1]), clear_ks) # Read in entire material library
    else:
        all_materials += material.load_mtl(mtl_override)

    # load vertices
    vertices, texcoords, normals = [], [], []
    for line in lines:
        if len(line.split()) == 0:
            continue

        prefix = line.split()[0].lower()
        if prefix == 'v':
            vertices.append([float(v) for v in line.split()[1:]])
        elif prefix == 'vt':
            val = [float(v) for v in line.split()[1:]]
            texcoords.append([val[0], 1.0 - val[1]])
        elif prefix == 'vn':
            normals.append([float(v) for v in line.split()[1:]])

    # load faces
    activeMatIdx = None
    used_materials = []
    faces, tfaces, nfaces, mfaces = [], [], [], []
    for line in lines:
        if len(line.split()) == 0:
            continue

        prefix = line.split()[0].lower()
        if prefix == 'usemtl': # Track used materials
            mat = _find_mat(all_materials, line.split()[1])
            if not mat in used_materials:
                used_materials.append(mat)
            activeMatIdx = used_materials.index(mat)
        elif prefix == 'f': # Parse face
            vs = line.split()[1:]
            nv = len(vs)
            vv = vs[0].split('/')
            v0 = int(vv[0]) - 1
            t0 = int(vv[1]) - 1 if vv[1] != "" else -1
            n0 = int(vv[2]) - 1 if vv[2] != "" else -1
            for i in range(nv - 2): # Triangulate polygons
                vv = vs[i + 1].split('/')
                v1 = int(vv[0]) - 1
                t1 = int(vv[1]) - 1 if vv[1] != "" else -1
                n1 = int(vv[2]) - 1 if vv[2] != "" else -1
                vv = vs[i + 2].split('/')
                v2 = int(vv[0]) - 1
                t2 = int(vv[1]) - 1 if vv[1] != "" else -1
                n2 = int(vv[2]) - 1 if vv[2] != "" else -1
                mfaces.append(activeMatIdx)
                faces.append([v0, v1, v2])
                tfaces.append([t0, t1, t2])
                nfaces.append([n0, n1, n2])
    assert len(tfaces) == len(faces) and len(nfaces) == len (faces)

    # Create an "uber" material by combining all textures into a larger texture
    if len(used_materials) > 1:
        uber_material, texcoords, tfaces = material.merge_materials(used_materials, texcoords, tfaces, mfaces)
    else:
        uber_material = used_materials[0]

    vertices = torch.tensor(vertices, dtype=torch.float32, device='cuda')
    texcoords = torch.tensor(texcoords, dtype=torch.float32, device='cuda') if len(texcoords) > 0 else None
    normals = torch.tensor(normals, dtype=torch.float32, device='cuda') if len(normals) > 0 else None

    faces = torch.tensor(faces, dtype=torch.int64, device='cuda')
    tfaces = torch.tensor(tfaces, dtype=torch.int64, device='cuda') if texcoords is not None else None
    nfaces = torch.tensor(nfaces, dtype=torch.int64, device='cuda') if normals is not None else None

    vertices = normalize_mesh(vertices)
    # vertices = vertices @ rotate_y_90(vertices)[:3,:3]

    if return_attributes:
        return mesh.Mesh(vertices, faces, normals, nfaces, texcoords, tfaces, material=uber_material), vertices, faces, normals, nfaces, texcoords, tfaces, uber_material
    return mesh.Mesh(vertices, faces, normals, nfaces, texcoords, tfaces, material=uber_material)

######################################################################################
# Save mesh object to objfile
######################################################################################

def write_obj(folder, mesh, save_material=True):
    obj_file = os.path.join(folder, 'mesh.obj')
    print("Writing mesh: ", obj_file)
    with open(obj_file, "w") as f:
        f.write("mtllib mesh.mtl\n")
        f.write("g default\n")

        v_pos = mesh.v_pos.detach().cpu().numpy() if mesh.v_pos is not None else None
        v_nrm = mesh.v_nrm.detach().cpu().numpy() if mesh.v_nrm is not None else None
        v_tex = mesh.v_tex.detach().cpu().numpy() if mesh.v_tex is not None else None

        t_pos_idx = mesh.t_pos_idx.detach().cpu().numpy() if mesh.t_pos_idx is not None else None
        t_nrm_idx = mesh.t_nrm_idx.detach().cpu().numpy() if mesh.t_nrm_idx is not None else None
        t_tex_idx = mesh.t_tex_idx.detach().cpu().numpy() if mesh.t_tex_idx is not None else None

        print("    writing %d vertices" % len(v_pos))
        for v in v_pos:
            f.write('v {} {} {} \n'.format(v[0], v[1], v[2]))

        if v_tex is not None:
            print("    writing %d texcoords" % len(v_tex))
            assert(len(t_pos_idx) == len(t_tex_idx))
            for v in v_tex:
                f.write('vt {} {} \n'.format(v[0], 1.0 - v[1]))

        if v_nrm is not None:
            print("    writing %d normals" % len(v_nrm))
            assert(len(t_pos_idx) == len(t_nrm_idx))
            for v in v_nrm:
                f.write('vn {} {} {}\n'.format(v[0], v[1], v[2]))

        # faces
        f.write("s 1 \n")
        f.write("g pMesh1\n")
        f.write("usemtl defaultMat\n")

        # Write faces
        print("    writing %d faces" % len(t_pos_idx))
        for i in range(len(t_pos_idx)):
            f.write("f ")
            for j in range(3):
                f.write(' %s/%s/%s' % (str(t_pos_idx[i][j]+1), '' if v_tex is None else str(t_tex_idx[i][j]+1), '' if v_nrm is None else str(t_nrm_idx[i][j]+1)))
            f.write("\n")

    if save_material:
        mtl_file = os.path.join(folder, 'mesh.mtl')
        print("Writing material: ", mtl_file)
        material.save_mtl(mtl_file, mesh.material)

    print("Done exporting mesh")
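A small check (not part of the commit) of normalize_mesh, which recentres the vertices at the origin and scales the longest bounding-box edge to length 2. The import assumes the package is importable and that the sibling texture/mesh/material modules referenced at the top of the file are present:

import torch
from models.lrm.online_render.utils.obj import normalize_mesh

verts = torch.tensor([[0.0, 0.0, 0.0],
                      [4.0, 0.0, 0.0],
                      [0.0, 2.0, 0.0],
                      [0.0, 0.0, 1.0]])
out = normalize_mesh(verts)
print(out.min(dim=0).values, out.max(dim=0).values)  # longest axis now spans [-1, 1]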
models/lrm/online_render/utils/sh_utils.py
ADDED
@@ -0,0 +1,118 @@
# Copyright 2021 The PlenOctree Authors.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

import torch

C0 = 0.28209479177387814
C1 = 0.4886025119029199
C2 = [
    1.0925484305920792,
    -1.0925484305920792,
    0.31539156525252005,
    -1.0925484305920792,
    0.5462742152960396
]
C3 = [
    -0.5900435899266435,
    2.890611442640554,
    -0.4570457994644658,
    0.3731763325901154,
    -0.4570457994644658,
    1.445305721320277,
    -0.5900435899266435
]
C4 = [
    2.5033429417967046,
    -1.7701307697799304,
    0.9461746957575601,
    -0.6690465435572892,
    0.10578554691520431,
    -0.6690465435572892,
    0.47308734787878004,
    -1.7701307697799304,
    0.6258357354491761,
]


def eval_sh(deg, sh, dirs):
    """
    Evaluate spherical harmonics at unit directions
    using hardcoded SH polynomials.
    Works with torch/np/jnp.
    ... Can be 0 or more batch dimensions.
    Args:
        deg: int SH deg. Currently, 0-3 supported
        sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
        dirs: jnp.ndarray unit directions [..., 3]
    Returns:
        [..., C]
    """
    assert deg <= 4 and deg >= 0
    coeff = (deg + 1) ** 2
    assert sh.shape[-1] >= coeff

    result = C0 * sh[..., 0]
    if deg > 0:
        x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
        result = (result -
                  C1 * y * sh[..., 1] +
                  C1 * z * sh[..., 2] -
                  C1 * x * sh[..., 3])

        if deg > 1:
            xx, yy, zz = x * x, y * y, z * z
            xy, yz, xz = x * y, y * z, x * z
            result = (result +
                      C2[0] * xy * sh[..., 4] +
                      C2[1] * yz * sh[..., 5] +
                      C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
                      C2[3] * xz * sh[..., 7] +
                      C2[4] * (xx - yy) * sh[..., 8])

            if deg > 2:
                result = (result +
                          C3[0] * y * (3 * xx - yy) * sh[..., 9] +
                          C3[1] * xy * z * sh[..., 10] +
                          C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] +
                          C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
                          C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
                          C3[5] * z * (xx - yy) * sh[..., 14] +
                          C3[6] * x * (xx - 3 * yy) * sh[..., 15])

                if deg > 3:
                    result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
                              C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
                              C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
                              C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
                              C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
                              C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
                              C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
                              C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
                              C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
    return result

def RGB2SH(rgb):
    return (rgb - 0.5) / C0

def SH2RGB(sh):
    return sh * C0 + 0.5
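A degree-0 round-trip check (not part of the commit): eval_sh with deg=0 reduces to C0 * sh[..., 0], so RGB2SH followed by a degree-0 evaluation plus the 0.5 offset recovers the input colour. Import path assumed:

import torch
from models.lrm.online_render.utils.sh_utils import eval_sh, RGB2SH

rgb = torch.tensor([[0.2, 0.5, 0.9]])        # (N, 3)
sh_dc = RGB2SH(rgb).unsqueeze(-1)            # (N, 3, 1): one SH coefficient per channel
dirs = torch.tensor([[0.0, 0.0, 1.0]])       # unit view directions (unused at degree 0)
recovered = eval_sh(0, sh_dc, dirs) + 0.5
print(torch.allclose(recovered, rgb))        # True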
models/lrm/online_render/utils/system_utils.py
ADDED
@@ -0,0 +1,28 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#

from errno import EEXIST
from os import makedirs, path
import os

def mkdir_p(folder_path):
    # Creates a directory. equivalent to using mkdir -p on the command line
    try:
        makedirs(folder_path)
    except OSError as exc: # Python >2.5
        if exc.errno == EEXIST and path.isdir(folder_path):
            pass
        else:
            raise

def searchForMaxIteration(folder):
    saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)]
    return max(saved_iters)
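A usage sketch (not part of the commit); the paths are illustrative and the import path is assumed:

from models.lrm.online_render.utils.system_utils import mkdir_p, searchForMaxIteration

mkdir_p("output/point_cloud/iteration_7000")
mkdir_p("output/point_cloud/iteration_30000")
mkdir_p("output/point_cloud/iteration_30000")        # repeat call is a no-op, like mkdir -p

print(searchForMaxIteration("output/point_cloud"))   # 30000: parses the trailing number of each entry name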
models/lrm/online_render/utils/taming/modules/autoencoder/lpips/vgg.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
size 7289
models/lrm/online_render/utils/vis_utils.py
ADDED
@@ -0,0 +1,151 @@
"""
Save a couple images to grids with cond, render cond, novel render, novel gt
Also save images to a render video
"""
import glob
import os
from PIL import Image
import numpy as np
import torch

from matplotlib import pyplot as plt
from utils.sh_utils import eval_sh
from einops import rearrange

def gridify():

    out_folder = "grids_objaverse"
    os.makedirs(out_folder, exist_ok=True)

    folder_paths = glob.glob("/scratch/shared/beegfs/stan/scaling_splatter_image/objaverse/*")
    # pixelnerf_root = "/scratch/shared/beegfs/stan/splatter_image/pixelnerf/teddybears"
    folder_paths_test = sorted([fpath for fpath in folder_paths if "gt" not in fpath], key= lambda x: int(os.path.basename(x).split("_")[0]))
    """folder_paths_test = [folder_paths_test[i] for i in [5, 7, 12, 15,
                                                        18, 19, 30, 33,
                                                        37, 42, 43, 44,
                                                        48, 51, 64, 66,
                                                        70, 74, 78, 85,
                                                        89, 91, 92]]"""

    # Initialize variables for grid dimensions
    num_examples_row = 6
    rows = num_examples_row
    num_per_ex = 2
    cols = num_examples_row * num_per_ex # 7 * 2
    im_res = 128

    for im_idx in range(100):
        print("Doing frame {}".format(im_idx))
        # for im_name in ["xyz", "colours", "opacity", "scaling"]:
        grid = np.zeros((rows*im_res, cols*im_res, 3), dtype=np.uint8)

        # Iterate through the folders in the out_folder
        for f_idx, folder_path_test in enumerate(folder_paths_test[:num_examples_row*num_examples_row]):
            # if im_name == "xyz":
            #     print(folder_path_test)
            row_idx = f_idx // num_examples_row
            col_idx = f_idx % num_examples_row
            im_path = os.path.join(folder_path_test, "{:05d}.png".format(im_idx))
            im_path_gt = os.path.join(folder_path_test + "_gt", "{:05d}.png".format(im_idx))
            """im_path_pixelnerf = os.path.join(pixelnerf_root, os.path.basename(folder_path_test),
                                             "{:06d}.png".format(im_idx))"""

            # im_path = os.path.join(folder_path_test, "{}.png".format(im_name))
            try:
                im = np.array(Image.open(im_path))
                im_gt = np.array(Image.open(im_path_gt))
                #im_pn = np.array(Image.open(im_path_pixelnerf))
                grid[row_idx*im_res: (row_idx+1)*im_res,
                     col_idx * num_per_ex *im_res: (col_idx * num_per_ex+1)*im_res, : ] = im[:, :, :3]
                grid[row_idx*im_res: (row_idx+1)*im_res,
                     (col_idx * num_per_ex + 1) *im_res: (col_idx* num_per_ex +2)*im_res, : ] = im_gt[:, :, :3]
                """grid[row_idx*im_res: (row_idx+1)*im_res,
                     (col_idx * num_per_ex + 2) *im_res: (col_idx* num_per_ex +3)*im_res, : ] = im_pn[:, :, :3]"""
            except FileNotFoundError:
                a = 0
        im_out = Image.fromarray(grid)
        im_out.save(os.path.join(out_folder, "{:05d}.png".format(im_idx)))
        # im_out.save(os.path.join(out_folder, "{}.png".format(im_name)))

def comparisons():

    out_root = "hydrants_comparisons"
    os.makedirs(out_root, exist_ok=True)

    folder_paths = glob.glob("/users/stan/pixel-nerf/full_eval_hydrant/*")
    folder_paths_test = sorted(folder_paths)
    folder_paths_ours_root = "/scratch/shared/beegfs/stan/out_hydrants_with_lpips_ours"

    # Initialize variables for grid dimensions
    rows = 3
    cols = 1
    im_res = 128

    for f_idx, folder_path_test in enumerate(folder_paths_test):

        example_id = "_".join(os.path.basename(folder_path_test).split("_")[1:])
        out_folder = os.path.join(out_root, example_id)
        os.makedirs(out_folder, exist_ok=True)
        num_images = len([p for p in glob.glob(os.path.join(folder_path_test, "*.png")) if "gt" not in p])

        grid = np.zeros((rows*im_res, cols*im_res, 3), dtype=np.uint8)

        for im_idx in range(num_images):

            im_path_pixelnerf = os.path.join(folder_path_test, "{:06d}.png".format(im_idx+1))
            im_path_ours = os.path.join(folder_paths_ours_root, example_id, "{:05d}.png".format(im_idx))
            im_path_gt = os.path.join(folder_paths_ours_root, example_id + "_gt", "{:05d}.png".format(im_idx))
            # im_path = os.path.join(folder_path_test, "{}.png".format(im_name))

            im_pn = np.array(Image.open(im_path_pixelnerf))
            im_ours = np.array(Image.open(im_path_ours))
            im_gt = np.array(Image.open(im_path_gt))

            grid[:im_res, :, :] = im_pn
            grid[im_res:2*im_res, :, :] = im_ours
            grid[2*im_res:3*im_res, :, :] = im_gt

            im_out = Image.fromarray(grid)
            im_out.save(os.path.join(out_folder, "{:05d}.png".format(im_idx)))

def vis_image_preds(image_preds: dict, folder_out: str):
    """
    Visualises network's image predictions.
    Args:
        image_preds: a dictionary of xyz, opacity, scaling, rotation, features_dc and features_rest
    """
    image_preds_reshaped = {}
    ray_dirs = (image_preds["xyz"].detach().cpu() / torch.norm(image_preds["xyz"].detach().cpu(), dim=-1, keepdim=True)).reshape(128, 128, 3)

    for k, v in image_preds.items():
        image_preds_reshaped[k] = v
        if k == "xyz":
            image_preds_reshaped[k] = (image_preds_reshaped[k] - torch.min(image_preds_reshaped[k], dim=0, keepdim=True)[0]) / (
                torch.max(image_preds_reshaped[k], dim=0, keepdim=True)[0] - torch.min(image_preds_reshaped[k], dim=0, keepdim=True)[0]
                )
        if k == "scaling":
            image_preds_reshaped["scaling"] = (image_preds_reshaped["scaling"] - torch.min(image_preds_reshaped["scaling"], dim=0, keepdim=True)[0]) / (
                torch.max(image_preds_reshaped["scaling"], dim=0, keepdim=True)[0] - torch.min(image_preds_reshaped["scaling"], dim=0, keepdim=True)[0]
                )
        if k != "features_rest":
            image_preds_reshaped[k] = image_preds_reshaped[k].reshape(128, 128, -1).detach().cpu()
        else:
            image_preds_reshaped[k] = image_preds_reshaped[k].reshape(128, 128, 3, 3).detach().cpu().permute(0, 1, 3, 2)
        if k == "opacity":
            image_preds_reshaped[k] = image_preds_reshaped[k].expand(128, 128, 3)


    colours = torch.cat([image_preds_reshaped["features_dc"].unsqueeze(-1), image_preds_reshaped["features_rest"]], dim=-1)
    colours = eval_sh(1, colours, ray_dirs)

    plt.imsave(os.path.join(folder_out, "colours.png"),
               colours.numpy())
    plt.imsave(os.path.join(folder_out, "opacity.png"),
               image_preds_reshaped["opacity"].numpy())
    plt.imsave(os.path.join(folder_out, "xyz.png"),
               (image_preds_reshaped["xyz"] * image_preds_reshaped["opacity"] + 1 - image_preds_reshaped["opacity"]).numpy())
    plt.imsave(os.path.join(folder_out, "scaling.png"),
               (image_preds_reshaped["scaling"] * image_preds_reshaped["opacity"] + 1 - image_preds_reshaped["opacity"]).numpy())

if __name__ == "__main__":
    gridify()
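A usage sketch (not part of the commit) for vis_image_preds with random stand-in predictions in the per-pixel layout it expects (128*128 Gaussians, SH degree 1). gridify() and comparisons() are tied to hard-coded cluster paths, and because the module does "from utils.sh_utils import eval_sh" it is assumed to be imported with the repo's online_render directory on the path, with einops and matplotlib installed:

import os
import torch
from utils.vis_utils import vis_image_preds

N = 128 * 128
image_preds = {
    "xyz": torch.randn(N, 3),
    "opacity": torch.rand(N, 1),
    "scaling": torch.rand(N, 3),
    "rotation": torch.randn(N, 4),
    "features_dc": torch.rand(N, 3),
    "features_rest": torch.zeros(N, 3, 3),   # zero higher-order SH keeps colours in [0, 1] for imsave
}
os.makedirs("vis_debug", exist_ok=True)
vis_image_preds(image_preds, "vis_debug")    # writes colours.png, opacity.png, xyz.png, scaling.png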