JiantaoLin committed
Commit 73f9b1b
1 Parent(s): 5a1fa77
models/lrm/data/irrmaps/README.txt ADDED
@@ -0,0 +1,3 @@
+ The aerodynamics_workshop_2k.hdr HDR probe is from https://polyhaven.com/a/aerodynamics_workshop
+ CC0 License.
+
models/lrm/env_mipmap/6/specular_3.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:09f83af826453e34c70afff7c617092baed35ff644f5c6799f4c2b1bdecc8d69
+ size 296107
models/lrm/env_mipmap/6/specular_4.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e94f34638e65bba737806ad3b3031482f80f6ecdc1c97cde5d216c2e90eb9017
+ size 74923
models/lrm/env_mipmap/6/specular_5.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3742584ea8fe108908757165d66eb8a003e121858514ec2b3980f48366e4e4f1
+ size 19627
models/lrm/online_render/src/__init__.py ADDED
File without changes
models/lrm/online_render/src/data/irrmaps/README.txt ADDED
@@ -0,0 +1,3 @@
+ The aerodynamics_workshop_2k.hdr HDR probe is from https://polyhaven.com/a/aerodynamics_workshop
+ CC0 License.
+
models/lrm/online_render/utils/camera_util.py ADDED
@@ -0,0 +1,111 @@
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+
+
+ def pad_camera_extrinsics_4x4(extrinsics):
+     if extrinsics.shape[-2] == 4:
+         return extrinsics
+     padding = torch.tensor([[0, 0, 0, 1]]).to(extrinsics)
+     if extrinsics.ndim == 3:
+         padding = padding.unsqueeze(0).repeat(extrinsics.shape[0], 1, 1)
+     extrinsics = torch.cat([extrinsics, padding], dim=-2)
+     return extrinsics
+
+
+ def center_looking_at_camera_pose(camera_position: torch.Tensor, look_at: torch.Tensor = None, up_world: torch.Tensor = None):
+     """
+     Create OpenGL camera extrinsics from camera locations and look-at position.
+
+     camera_position: (M, 3) or (3,)
+     look_at: (3)
+     up_world: (3)
+     return: (M, 3, 4) or (3, 4)
+     """
+     # by default, looking at the origin and world up is z-axis
+     if look_at is None:
+         look_at = torch.tensor([0, 0, 0], dtype=torch.float32)
+     if up_world is None:
+         up_world = torch.tensor([0, 0, 1], dtype=torch.float32)
+     if camera_position.ndim == 2:
+         look_at = look_at.unsqueeze(0).repeat(camera_position.shape[0], 1)
+         up_world = up_world.unsqueeze(0).repeat(camera_position.shape[0], 1)
+
+     # OpenGL camera: z-backward, x-right, y-up
+     z_axis = camera_position - look_at
+     z_axis = F.normalize(z_axis, dim=-1).float()
+     x_axis = torch.linalg.cross(up_world, z_axis, dim=-1)
+     x_axis = F.normalize(x_axis, dim=-1).float()
+     y_axis = torch.linalg.cross(z_axis, x_axis, dim=-1)
+     y_axis = F.normalize(y_axis, dim=-1).float()
+
+     extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1)
+     extrinsics = pad_camera_extrinsics_4x4(extrinsics)
+     return extrinsics
+
+
+ def spherical_camera_pose(azimuths: np.ndarray, elevations: np.ndarray, radius=2.5):
+     azimuths = np.deg2rad(azimuths)
+     elevations = np.deg2rad(elevations)
+
+     xs = radius * np.cos(elevations) * np.cos(azimuths)
+     ys = radius * np.cos(elevations) * np.sin(azimuths)
+     zs = radius * np.sin(elevations)
+
+     cam_locations = np.stack([xs, ys, zs], axis=-1)
+     cam_locations = torch.from_numpy(cam_locations).float()
+
+     c2ws = center_looking_at_camera_pose(cam_locations)
+     return c2ws
+
+
+ def get_circular_camera_poses(M=120, radius=2.5, elevation=30.0):
+     # M: number of circular views
+     # radius: camera dist to center
+     # elevation: elevation degrees of the camera
+     # return: (M, 4, 4)
+     assert M > 0 and radius > 0
+
+     elevation = np.deg2rad(elevation)
+
+     camera_positions = []
+     for i in range(M):
+         azimuth = 2 * np.pi * i / M
+         x = radius * np.cos(elevation) * np.cos(azimuth)
+         y = radius * np.cos(elevation) * np.sin(azimuth)
+         z = radius * np.sin(elevation)
+         camera_positions.append([x, y, z])
+     camera_positions = np.array(camera_positions)
+     camera_positions = torch.from_numpy(camera_positions).float()
+     extrinsics = center_looking_at_camera_pose(camera_positions)
+     return extrinsics
+
+
+ def FOV_to_intrinsics(fov, device='cpu'):
+     """
+     Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees.
+     Note the intrinsics are returned as normalized by image size, rather than in pixel units.
+     Assumes principal point is at image center.
+     """
+     focal_length = 0.5 / np.tan(np.deg2rad(fov) * 0.5)
+     intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
+     return intrinsics
+
+
+ def get_zero123plus_input_cameras(batch_size=1, radius=4.0, fov=30.0):
+     """
+     Get the input camera parameters.
+     """
+     azimuths = np.array([30, 90, 150, 210, 270, 330]).astype(float)
+     elevations = np.array([20, -10, 20, -10, 20, -10]).astype(float)
+
+     c2ws = spherical_camera_pose(azimuths, elevations, radius)
+     c2ws = c2ws.float().flatten(-2)
+
+     Ks = FOV_to_intrinsics(fov).unsqueeze(0).repeat(6, 1, 1).float().flatten(-2)
+
+     extrinsics = c2ws[:, :12]
+     intrinsics = torch.stack([Ks[:, 0], Ks[:, 4], Ks[:, 2], Ks[:, 5]], dim=-1)
+     cameras = torch.cat([extrinsics, intrinsics], dim=-1)
+
+     return cameras.unsqueeze(0).repeat(batch_size, 1, 1)
models/lrm/online_render/utils/camera_utils.py ADDED
@@ -0,0 +1,83 @@
+ #
+ # Copyright (C) 2023, Inria
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
+ # All rights reserved.
+ #
+ # This software is free for non-commercial, research and evaluation use
+ # under the terms of the LICENSE.md file.
+ #
+ # For inquiries contact [email protected]
+ #
+
+ from scene.cameras import Camera
+ import numpy as np
+ from utils.general_utils import PILtoTorch
+ from utils.graphics_utils import fov2focal
+
+ WARNED = False
+
+ def loadCam(args, id, cam_info, resolution_scale):
+     orig_w, orig_h = cam_info.image.size
+
+     if args.resolution in [1, 2, 4, 8]:
+         resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
+     else:  # should be a type that converts to float
+         if args.resolution == -1:
+             if orig_w > 1600:
+                 global WARNED
+                 if not WARNED:
+                     print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n "
+                           "If this is not desired, please explicitly specify '--resolution/-r' as 1")
+                     WARNED = True
+                 global_down = orig_w / 1600
+             else:
+                 global_down = 1
+         else:
+             global_down = orig_w / args.resolution
+
+         scale = float(global_down) * float(resolution_scale)
+         resolution = (int(orig_w / scale), int(orig_h / scale))
+
+     resized_image_rgb = PILtoTorch(cam_info.image, resolution)
+
+     gt_image = resized_image_rgb[:3, ...]
+     loaded_mask = None
+
+     if resized_image_rgb.shape[1] == 4:
+         loaded_mask = resized_image_rgb[3:4, ...]
+     mask_image = cam_info.mask
+     # breakpoint()
+     return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
+                   FoVx=cam_info.FovX, FoVy=cam_info.FovY,
+                   image=gt_image, gt_alpha_mask=loaded_mask, mask=mask_image,
+                   image_name=cam_info.image_name, uid=id, data_device=args.data_device)
+
+ def cameraList_from_camInfos(cam_infos, resolution_scale, args):
+     camera_list = []
+
+     for id, c in enumerate(cam_infos):
+         camera_list.append(loadCam(args, id, c, resolution_scale))
+
+     return camera_list
+
+ def camera_to_JSON(id, camera : Camera):
+     Rt = np.zeros((4, 4))
+     Rt[:3, :3] = camera.R.transpose()
+     Rt[:3, 3] = camera.T
+     Rt[3, 3] = 1.0
+
+     W2C = np.linalg.inv(Rt)
+     pos = W2C[:3, 3]
+     rot = W2C[:3, :3]
+     serializable_array_2d = [x.tolist() for x in rot]
+     camera_entry = {
+         'id' : id,
+         'img_name' : camera.image_name,
+         'width' : camera.width,
+         'height' : camera.height,
+         'position': pos.tolist(),
+         'rotation': serializable_array_2d,
+         'fy' : fov2focal(camera.FovY, camera.height),
+         'fx' : fov2focal(camera.FovX, camera.width)
+     }
+     return camera_entry
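For reference, a small sketch (not from the commit) of the downscaling arithmetic that loadCam applies for the default -1 resolution; the 1920x1080 source size and unit resolution_scale are assumed values.

    orig_w, orig_h = 1920, 1080
    resolution_scale = 1.0
    global_down = orig_w / 1600                               # width above 1.6K triggers rescaling to 1.6K
    scale = float(global_down) * float(resolution_scale)
    resolution = (int(orig_w / scale), int(orig_h / scale))   # (1600, 900)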
models/lrm/online_render/utils/general_utils.py ADDED
@@ -0,0 +1,133 @@
+ #
+ # Copyright (C) 2023, Inria
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
+ # All rights reserved.
+ #
+ # This software is free for non-commercial, research and evaluation use
+ # under the terms of the LICENSE.md file.
+ #
+ # For inquiries contact [email protected]
+ #
+
+ import torch
+ import sys
+ from datetime import datetime
+ import numpy as np
+ import random
+
+ def inverse_sigmoid(x):
+     return torch.log(x/(1-x))
+
+ def PILtoTorch(pil_image, resolution):
+     resized_image_PIL = pil_image.resize(resolution)
+     resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
+     if len(resized_image.shape) == 3:
+         return resized_image.permute(2, 0, 1)
+     else:
+         return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
+
+ def get_expon_lr_func(
+     lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
+ ):
+     """
+     Copied from Plenoxels
+
+     Continuous learning rate decay function. Adapted from JaxNeRF
+     The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
+     is log-linearly interpolated elsewhere (equivalent to exponential decay).
+     If lr_delay_steps>0 then the learning rate will be scaled by some smooth
+     function of lr_delay_mult, such that the initial learning rate is
+     lr_init*lr_delay_mult at the beginning of optimization but will be eased back
+     to the normal learning rate when steps>lr_delay_steps.
+     :param conf: config subtree 'lr' or similar
+     :param max_steps: int, the number of steps during optimization.
+     :return HoF which takes step as input
+     """
+
+     def helper(step):
+         if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
+             # Disable this parameter
+             return 0.0
+         if lr_delay_steps > 0:
+             # A kind of reverse cosine decay.
+             delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
+                 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
+             )
+         else:
+             delay_rate = 1.0
+         t = np.clip(step / max_steps, 0, 1)
+         log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
+         return delay_rate * log_lerp
+
+     return helper
+
+ def strip_lowerdiag(L):
+     uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
+
+     uncertainty[:, 0] = L[:, 0, 0]
+     uncertainty[:, 1] = L[:, 0, 1]
+     uncertainty[:, 2] = L[:, 0, 2]
+     uncertainty[:, 3] = L[:, 1, 1]
+     uncertainty[:, 4] = L[:, 1, 2]
+     uncertainty[:, 5] = L[:, 2, 2]
+     return uncertainty
+
+ def strip_symmetric(sym):
+     return strip_lowerdiag(sym)
+
+ def build_rotation(r):
+     norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
+
+     q = r / norm[:, None]
+
+     R = torch.zeros((q.size(0), 3, 3), device='cuda')
+
+     r = q[:, 0]
+     x = q[:, 1]
+     y = q[:, 2]
+     z = q[:, 3]
+
+     R[:, 0, 0] = 1 - 2 * (y*y + z*z)
+     R[:, 0, 1] = 2 * (x*y - r*z)
+     R[:, 0, 2] = 2 * (x*z + r*y)
+     R[:, 1, 0] = 2 * (x*y + r*z)
+     R[:, 1, 1] = 1 - 2 * (x*x + z*z)
+     R[:, 1, 2] = 2 * (y*z - r*x)
+     R[:, 2, 0] = 2 * (x*z - r*y)
+     R[:, 2, 1] = 2 * (y*z + r*x)
+     R[:, 2, 2] = 1 - 2 * (x*x + y*y)
+     return R
+
+ def build_scaling_rotation(s, r):
+     L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
+     R = build_rotation(r)
+
+     L[:,0,0] = s[:,0]
+     L[:,1,1] = s[:,1]
+     L[:,2,2] = s[:,2]
+
+     L = R @ L
+     return L
+
+ def safe_state(silent):
+     old_f = sys.stdout
+     class F:
+         def __init__(self, silent):
+             self.silent = silent
+
+         def write(self, x):
+             if not self.silent:
+                 if x.endswith("\n"):
+                     old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
+                 else:
+                     old_f.write(x)
+
+         def flush(self):
+             old_f.flush()
+
+     sys.stdout = F(silent)
+
+     random.seed(0)
+     np.random.seed(0)
+     torch.manual_seed(0)
+     torch.cuda.set_device(torch.device("cuda:0"))
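A short sketch (not from the commit) of how the returned schedule behaves; the step counts and learning rates are illustrative assumptions.

    lr_fn = get_expon_lr_func(lr_init=1e-2, lr_final=1e-4, max_steps=1000)
    print(lr_fn(0))      # 0.01   (lr_init at step 0)
    print(lr_fn(500))    # ~0.001 (log-linear midpoint: sqrt(1e-2 * 1e-4))
    print(lr_fn(1000))   # 0.0001 (lr_final at max_steps)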
models/lrm/online_render/utils/graphics_utils.py ADDED
@@ -0,0 +1,90 @@
+ #
+ # Copyright (C) 2023, Inria
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
+ # All rights reserved.
+ #
+ # This software is free for non-commercial, research and evaluation use
+ # under the terms of the LICENSE.md file.
+ #
+ # For inquiries contact [email protected]
+ #
+
+ import torch
+ import math
+ import numpy as np
+ from typing import NamedTuple
+
+ class BasicPointCloud(NamedTuple):
+     points : np.array
+     colors : np.array
+     normals : np.array
+
+ def geom_transform_points(points, transf_matrix):
+     P, _ = points.shape
+     ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
+     points_hom = torch.cat([points, ones], dim=1)
+     points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))
+
+     denom = points_out[..., 3:] + 0.0000001
+     return (points_out[..., :3] / denom).squeeze(dim=0)
+
+ def getWorld2View(R, t):
+     Rt = np.zeros((4, 4))
+     Rt[:3, :3] = R.transpose()
+     Rt[:3, 3] = t
+     Rt[3, 3] = 1.0
+     return np.float32(Rt)
+
+ def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
+     Rt = np.zeros((4, 4))
+     Rt[:3, :3] = R.transpose()
+     Rt[:3, 3] = t
+     Rt[3, 3] = 1.0
+
+     C2W = np.linalg.inv(Rt)
+     cam_center = C2W[:3, 3]
+     cam_center = (cam_center + translate) * scale
+     C2W[:3, 3] = cam_center
+     Rt = np.linalg.inv(C2W)
+     return np.float32(Rt)
+
+ def getView2World(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
+     Rt = np.zeros((4, 4))
+     Rt[:3, :3] = R.transpose()
+     Rt[:3, 3] = t
+     Rt[3, 3] = 1.0
+
+     C2W = np.linalg.inv(Rt)
+     cam_center = C2W[:3, 3]
+     cam_center = (cam_center + translate) * scale
+     C2W[:3, 3] = cam_center
+     Rt = C2W
+     return np.float32(Rt)
+
+ def getProjectionMatrix(znear, zfar, fovX, fovY):
+     tanHalfFovY = math.tan((fovY / 2))
+     tanHalfFovX = math.tan((fovX / 2))
+
+     top = tanHalfFovY * znear
+     bottom = -top
+     right = tanHalfFovX * znear
+     left = -right
+
+     P = torch.zeros(4, 4)
+
+     z_sign = 1.0
+
+     P[0, 0] = 2.0 * znear / (right - left)
+     P[1, 1] = 2.0 * znear / (top - bottom)
+     P[0, 2] = (right + left) / (right - left)
+     P[1, 2] = (top + bottom) / (top - bottom)
+     P[3, 2] = z_sign
+     P[2, 2] = z_sign * zfar / (zfar - znear)
+     P[2, 3] = -(zfar * znear) / (zfar - znear)
+     return P
+
+ def fov2focal(fov, pixels):
+     return pixels / (2 * math.tan(fov / 2))
+
+ def focal2fov(focal, pixels):
+     return 2*math.atan(pixels/(2*focal))
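A brief sanity sketch (not part of the commit) relating the two FoV helpers; the 60-degree field of view and 800-pixel width are assumed values.

    import math
    fov = math.radians(60.0)
    f = fov2focal(fov, 800)                    # ~692.8 px focal length for 60 degrees over 800 px
    assert abs(focal2fov(f, 800) - fov) < 1e-6 # the two conversions are inverses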
models/lrm/online_render/utils/image_utils.py ADDED
@@ -0,0 +1,19 @@
+ #
+ # Copyright (C) 2023, Inria
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
+ # All rights reserved.
+ #
+ # This software is free for non-commercial, research and evaluation use
+ # under the terms of the LICENSE.md file.
+ #
+ # For inquiries contact [email protected]
+ #
+
+ import torch
+
+ def mse(img1, img2):
+     return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
+
+ def psnr(img1, img2):
+     mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
+     return 20 * torch.log10(1.0 / torch.sqrt(mse))
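A minimal check (not from the commit) of the PSNR helper; the tensor shape and offset are assumptions, and values are expected in [0, 1].

    import torch
    a = torch.rand(1, 3, 64, 64)
    print(psnr(a, a))                      # inf: identical images
    print(psnr(a, (a + 0.1).clamp(0, 1)))  # roughly 20 dB or a bit higher for a ~0.1 offset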
models/lrm/online_render/utils/loss_utils.py ADDED
@@ -0,0 +1,183 @@
+ #
+ # Copyright (C) 2023, Inria
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
+ # All rights reserved.
+ #
+ # This software is free for non-commercial, research and evaluation use
+ # under the terms of the LICENSE.md file.
+ #
+ # For inquiries contact [email protected]
+ #
+
+ import torch
+ import torch.nn.functional as F
+ from torch.autograd import Variable
+ from math import exp
+
+ def l1_loss(network_output, gt):
+     return torch.abs((network_output - gt)).mean()
+
+ def l2_loss(network_output, gt):
+     return ((network_output - gt) ** 2).mean()
+
+ def gaussian(window_size, sigma):
+     gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
+     return gauss / gauss.sum()
+
+ def create_window(window_size, channel):
+     _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+     _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+     window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
+     return window
+
+ def ssim(img1, img2, window_size=11, size_average=True):
+     channel = img1.size(-3)
+     window = create_window(window_size, channel)
+
+     if img1.is_cuda:
+         window = window.cuda(img1.get_device())
+     window = window.type_as(img1)
+
+     return _ssim(img1, img2, window, window_size, channel, size_average)
+
+ def _ssim(img1, img2, window, window_size, channel, size_average=True):
+     mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
+     mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
+
+     mu1_sq = mu1.pow(2)
+     mu2_sq = mu2.pow(2)
+     mu1_mu2 = mu1 * mu2
+
+     sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
+     sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
+     sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
+
+     C1 = 0.01 ** 2
+     C2 = 0.03 ** 2
+
+     ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+
+     if size_average:
+         return ssim_map.mean()
+     else:
+         return ssim_map.mean(1).mean(1).mean(1)
+
+ import torch
+ import torch.nn as nn
+
+ from taming.modules.losses.vqperceptual import *  # TODO: taming dependency yes/no?
+
+
+ class LPIPSWithDiscriminator(nn.Module):
+     def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
+                  disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
+                  perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
+                  disc_loss="hinge"):
+
+         super().__init__()
+         assert disc_loss in ["hinge", "vanilla"]
+         self.kl_weight = kl_weight
+         self.pixel_weight = pixelloss_weight
+         self.perceptual_loss = LPIPS().eval()
+         self.perceptual_weight = perceptual_weight
+         # output log variance
+         self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
+         self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
+                                                  n_layers=disc_num_layers,
+                                                  use_actnorm=use_actnorm
+                                                  ).apply(weights_init)
+         self.discriminator_iter_start = disc_start
+         self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
+         self.disc_factor = disc_factor
+         self.discriminator_weight = disc_weight
+         self.disc_conditional = disc_conditional
+
+     def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+         if last_layer is not None:
+             nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+             g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+         else:
+             nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
+             g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
+
+         d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+         d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+         d_weight = d_weight * self.discriminator_weight
+         return d_weight
+
+     def forward(self, inputs, reconstructions, optimizer_idx,
+                 global_step, last_layer=None, cond=None, split="train"):
+         rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
+         if self.perceptual_weight > 0:
+             p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
+             rec_loss = rec_loss + self.perceptual_weight * p_loss
+
+         # nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
+         # now the GAN part
+         if optimizer_idx == 0:
+             # generator update
+             logits_fake = self.discriminator(reconstructions.contiguous())
+             # g_loss = -torch.mean(logits_fake)
+             g_loss = F.relu(1 - logits_fake).mean()
+             # if self.disc_factor > 0.0:
+             #     try:
+             #         d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
+             #     except RuntimeError:
+             #         assert not self.training
+             #         d_weight = torch.tensor(0.0)
+             # else:
+             #     d_weight = torch.tensor(0.0)
+
+             # disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+             # loss = d_weight * disc_factor * g_loss
+
+             # return loss, log
+             return g_loss
+
+         if optimizer_idx == 1:
+             # second pass for discriminator update
+
+             logits_real = self.discriminator(inputs.contiguous().detach())
+             logits_fake = self.discriminator(reconstructions.contiguous().detach())
+
+             # disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+             # d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
+
+             # log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
+             #        "{}/logits_real".format(split): logits_real.detach().mean(),
+             #        "{}/logits_fake".format(split): logits_fake.detach().mean()
+             #        }
+             # return d_loss, log
+
+             d_loss = self.disc_loss(logits_real, logits_fake)
+             return d_loss
+
+ import torch
+ from chamfer_distance import ChamferDistance
+
+ # Initialize the Chamfer Distance module
+ chamfer_dist_module = ChamferDistance()
+
+ def calculate_chamfer_loss(pred, gt):
+     """
+     Compute the Chamfer Distance loss.
+     Args:
+         pred (torch.Tensor): predicted point cloud of shape (batch_size, num_points, 3)
+         gt (torch.Tensor): ground-truth point cloud of shape (batch_size, num_points, 3)
+         chamfer_dist_module (ChamferDistance): pre-initialized Chamfer Distance module
+
+     Returns:
+         torch.Tensor: Chamfer Distance loss
+     """
+     # Compute the Chamfer Distance in both directions and average
+     dist1, dist2, idx1, idx2 = chamfer_dist_module(pred, gt)
+     loss = (torch.mean(dist1) + torch.mean(dist2)) / 2
+
+     return loss
+
+ if __name__ == "__main__":
+
+     discriminator = LPIPSWithDiscriminator(disc_start=0, disc_weight=0.5)
+
+
+
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4
+ # property and proprietary rights in and to this material, related
5
+ # documentation and any modifications thereto. Any use, reproduction,
6
+ # disclosure or distribution of this material and related documentation
7
+ # without an express license agreement from NVIDIA CORPORATION or
8
+ # its affiliates is strictly prohibited.
9
+
10
+ import os
11
+ import torch
12
+
13
+ from . import texture
14
+ from . import mesh
15
+ from . import material
16
+
17
+ ######################################################################################
18
+ # Utility functions
19
+ ######################################################################################
20
+
21
+ def _find_mat(materials, name):
22
+ for mat in materials:
23
+ if mat['name'] == name:
24
+ return mat
25
+ return materials[0] # Materials 0 is the default
26
+
27
+
28
+ def normalize_mesh(vertices):
29
+ # 计算边界框
30
+ min_vals, _ = torch.min(vertices, dim=0)
31
+ max_vals, _ = torch.max(vertices, dim=0)
32
+
33
+ # 计算中心点
34
+ center = (max_vals + min_vals) / 2
35
+
36
+ # 平移顶点
37
+ vertices = vertices - center
38
+
39
+ # 计算缩放因子
40
+ max_extent = torch.max(max_vals - min_vals)
41
+ scale = 2.0 / max_extent
42
+
43
+ # 缩放顶点
44
+ vertices = vertices * scale
45
+
46
+ return vertices
47
+
48
+ ######################################################################################
49
+ # Create mesh object from objfile
50
+ ######################################################################################
51
+ def rotate_y_90(v_pos):
52
+ # 定义绕X轴旋转90度的旋转矩阵
53
+ rotate_y = torch.tensor([[0, 0, 1, 0],
54
+ [0, 1, 0, 0],
55
+ [-1, 0, 0, 0],
56
+ [0, 0, 0, 1]], dtype=torch.float32, device=v_pos.device)
57
+ return rotate_y
58
+
59
+ def load_obj(filename, clear_ks=True, mtl_override=None, return_attributes=False, path_is_attributrs=False):
60
+ obj_path = os.path.dirname(filename)
61
+
62
+ # Read entire file
63
+ with open(filename, 'r') as f:
64
+ lines = f.readlines()
65
+
66
+ # Load materials
67
+ all_materials = [
68
+ {
69
+ 'name' : '_default_mat',
70
+ 'bsdf' : 'pbr',
71
+ 'kd' : texture.Texture2D(torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32, device='cuda')),
72
+ 'ks' : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'))
73
+ }
74
+ ]
75
+ if mtl_override is None:
76
+ for line in lines:
77
+ if len(line.split()) == 0:
78
+ continue
79
+ if line.split()[0] == 'mtllib':
80
+ all_materials += material.load_mtl(os.path.join(obj_path, line.split()[1]), clear_ks) # Read in entire material library
81
+ else:
82
+ all_materials += material.load_mtl(mtl_override)
83
+
84
+ # load vertices
85
+ vertices, texcoords, normals = [], [], []
86
+ for line in lines:
87
+ if len(line.split()) == 0:
88
+ continue
89
+
90
+ prefix = line.split()[0].lower()
91
+ if prefix == 'v':
92
+ vertices.append([float(v) for v in line.split()[1:]])
93
+ elif prefix == 'vt':
94
+ val = [float(v) for v in line.split()[1:]]
95
+ texcoords.append([val[0], 1.0 - val[1]])
96
+ elif prefix == 'vn':
97
+ normals.append([float(v) for v in line.split()[1:]])
98
+
99
+ # load faces
100
+ activeMatIdx = None
101
+ used_materials = []
102
+ faces, tfaces, nfaces, mfaces = [], [], [], []
103
+ for line in lines:
104
+ if len(line.split()) == 0:
105
+ continue
106
+
107
+ prefix = line.split()[0].lower()
108
+ if prefix == 'usemtl': # Track used materials
109
+ mat = _find_mat(all_materials, line.split()[1])
110
+ if not mat in used_materials:
111
+ used_materials.append(mat)
112
+ activeMatIdx = used_materials.index(mat)
113
+ elif prefix == 'f': # Parse face
114
+ vs = line.split()[1:]
115
+ nv = len(vs)
116
+ vv = vs[0].split('/')
117
+ v0 = int(vv[0]) - 1
118
+ t0 = int(vv[1]) - 1 if vv[1] != "" else -1
119
+ n0 = int(vv[2]) - 1 if vv[2] != "" else -1
120
+ for i in range(nv - 2): # Triangulate polygons
121
+ vv = vs[i + 1].split('/')
122
+ v1 = int(vv[0]) - 1
123
+ t1 = int(vv[1]) - 1 if vv[1] != "" else -1
124
+ n1 = int(vv[2]) - 1 if vv[2] != "" else -1
125
+ vv = vs[i + 2].split('/')
126
+ v2 = int(vv[0]) - 1
127
+ t2 = int(vv[1]) - 1 if vv[1] != "" else -1
128
+ n2 = int(vv[2]) - 1 if vv[2] != "" else -1
129
+ mfaces.append(activeMatIdx)
130
+ faces.append([v0, v1, v2])
131
+ tfaces.append([t0, t1, t2])
132
+ nfaces.append([n0, n1, n2])
133
+ assert len(tfaces) == len(faces) and len(nfaces) == len (faces)
134
+
135
+ # Create an "uber" material by combining all textures into a larger texture
136
+ if len(used_materials) > 1:
137
+ uber_material, texcoords, tfaces = material.merge_materials(used_materials, texcoords, tfaces, mfaces)
138
+ else:
139
+ uber_material = used_materials[0]
140
+
141
+ vertices = torch.tensor(vertices, dtype=torch.float32, device='cuda')
142
+ texcoords = torch.tensor(texcoords, dtype=torch.float32, device='cuda') if len(texcoords) > 0 else None
143
+ normals = torch.tensor(normals, dtype=torch.float32, device='cuda') if len(normals) > 0 else None
144
+
145
+ faces = torch.tensor(faces, dtype=torch.int64, device='cuda')
146
+ tfaces = torch.tensor(tfaces, dtype=torch.int64, device='cuda') if texcoords is not None else None
147
+ nfaces = torch.tensor(nfaces, dtype=torch.int64, device='cuda') if normals is not None else None
148
+
149
+ vertices = normalize_mesh(vertices)
150
+ # vertices = vertices @ rotate_y_90(vertices)[:3,:3]
151
+
152
+ if return_attributes:
153
+ return mesh.Mesh(vertices, faces, normals, nfaces, texcoords, tfaces, material=uber_material), vertices, faces, normals, nfaces, texcoords, tfaces, uber_material
154
+ return mesh.Mesh(vertices, faces, normals, nfaces, texcoords, tfaces, material=uber_material)
155
+
156
+ ######################################################################################
157
+ # Save mesh object to objfile
158
+ ######################################################################################
159
+
160
+ def write_obj(folder, mesh, save_material=True):
161
+ obj_file = os.path.join(folder, 'mesh.obj')
162
+ print("Writing mesh: ", obj_file)
163
+ with open(obj_file, "w") as f:
164
+ f.write("mtllib mesh.mtl\n")
165
+ f.write("g default\n")
166
+
167
+ v_pos = mesh.v_pos.detach().cpu().numpy() if mesh.v_pos is not None else None
168
+ v_nrm = mesh.v_nrm.detach().cpu().numpy() if mesh.v_nrm is not None else None
169
+ v_tex = mesh.v_tex.detach().cpu().numpy() if mesh.v_tex is not None else None
170
+
171
+ t_pos_idx = mesh.t_pos_idx.detach().cpu().numpy() if mesh.t_pos_idx is not None else None
172
+ t_nrm_idx = mesh.t_nrm_idx.detach().cpu().numpy() if mesh.t_nrm_idx is not None else None
173
+ t_tex_idx = mesh.t_tex_idx.detach().cpu().numpy() if mesh.t_tex_idx is not None else None
174
+
175
+ print(" writing %d vertices" % len(v_pos))
176
+ for v in v_pos:
177
+ f.write('v {} {} {} \n'.format(v[0], v[1], v[2]))
178
+
179
+ if v_tex is not None:
180
+ print(" writing %d texcoords" % len(v_tex))
181
+ assert(len(t_pos_idx) == len(t_tex_idx))
182
+ for v in v_tex:
183
+ f.write('vt {} {} \n'.format(v[0], 1.0 - v[1]))
184
+
185
+ if v_nrm is not None:
186
+ print(" writing %d normals" % len(v_nrm))
187
+ assert(len(t_pos_idx) == len(t_nrm_idx))
188
+ for v in v_nrm:
189
+ f.write('vn {} {} {}\n'.format(v[0], v[1], v[2]))
190
+
191
+ # faces
192
+ f.write("s 1 \n")
193
+ f.write("g pMesh1\n")
194
+ f.write("usemtl defaultMat\n")
195
+
196
+ # Write faces
197
+ print(" writing %d faces" % len(t_pos_idx))
198
+ for i in range(len(t_pos_idx)):
199
+ f.write("f ")
200
+ for j in range(3):
201
+ f.write(' %s/%s/%s' % (str(t_pos_idx[i][j]+1), '' if v_tex is None else str(t_tex_idx[i][j]+1), '' if v_nrm is None else str(t_nrm_idx[i][j]+1)))
202
+ f.write("\n")
203
+
204
+ if save_material:
205
+ mtl_file = os.path.join(folder, 'mesh.mtl')
206
+ print("Writing material: ", mtl_file)
207
+ material.save_mtl(mtl_file, mesh.material)
208
+
209
+ print("Done exporting mesh")
models/lrm/online_render/utils/sh_utils.py ADDED
@@ -0,0 +1,118 @@
+ # Copyright 2021 The PlenOctree Authors.
+ # Redistribution and use in source and binary forms, with or without
+ # modification, are permitted provided that the following conditions are met:
+ #
+ # 1. Redistributions of source code must retain the above copyright notice,
+ # this list of conditions and the following disclaimer.
+ #
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
+ # this list of conditions and the following disclaimer in the documentation
+ # and/or other materials provided with the distribution.
+ #
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+ # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ # POSSIBILITY OF SUCH DAMAGE.
+
+ import torch
+
+ C0 = 0.28209479177387814
+ C1 = 0.4886025119029199
+ C2 = [
+     1.0925484305920792,
+     -1.0925484305920792,
+     0.31539156525252005,
+     -1.0925484305920792,
+     0.5462742152960396
+ ]
+ C3 = [
+     -0.5900435899266435,
+     2.890611442640554,
+     -0.4570457994644658,
+     0.3731763325901154,
+     -0.4570457994644658,
+     1.445305721320277,
+     -0.5900435899266435
+ ]
+ C4 = [
+     2.5033429417967046,
+     -1.7701307697799304,
+     0.9461746957575601,
+     -0.6690465435572892,
+     0.10578554691520431,
+     -0.6690465435572892,
+     0.47308734787878004,
+     -1.7701307697799304,
+     0.6258357354491761,
+ ]
+
+
+ def eval_sh(deg, sh, dirs):
+     """
+     Evaluate spherical harmonics at unit directions
+     using hardcoded SH polynomials.
+     Works with torch/np/jnp.
+     ... Can be 0 or more batch dimensions.
+     Args:
+         deg: int SH deg. Currently, 0-3 supported
+         sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
+         dirs: jnp.ndarray unit directions [..., 3]
+     Returns:
+         [..., C]
+     """
+     assert deg <= 4 and deg >= 0
+     coeff = (deg + 1) ** 2
+     assert sh.shape[-1] >= coeff
+
+     result = C0 * sh[..., 0]
+     if deg > 0:
+         x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
+         result = (result -
+                   C1 * y * sh[..., 1] +
+                   C1 * z * sh[..., 2] -
+                   C1 * x * sh[..., 3])
+
+         if deg > 1:
+             xx, yy, zz = x * x, y * y, z * z
+             xy, yz, xz = x * y, y * z, x * z
+             result = (result +
+                       C2[0] * xy * sh[..., 4] +
+                       C2[1] * yz * sh[..., 5] +
+                       C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
+                       C2[3] * xz * sh[..., 7] +
+                       C2[4] * (xx - yy) * sh[..., 8])
+
+             if deg > 2:
+                 result = (result +
+                           C3[0] * y * (3 * xx - yy) * sh[..., 9] +
+                           C3[1] * xy * z * sh[..., 10] +
+                           C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] +
+                           C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
+                           C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
+                           C3[5] * z * (xx - yy) * sh[..., 14] +
+                           C3[6] * x * (xx - 3 * yy) * sh[..., 15])
+
+                 if deg > 3:
+                     result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
+                               C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
+                               C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
+                               C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
+                               C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
+                               C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
+                               C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
+                               C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
+                               C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
+     return result
+
+ def RGB2SH(rgb):
+     return (rgb - 0.5) / C0
+
+ def SH2RGB(sh):
+     return sh * C0 + 0.5
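A degree-0 sanity sketch (not part of the commit) showing that RGB2SH/SH2RGB invert each other and that eval_sh reduces to the DC term; tensor sizes are assumed.

    import torch
    rgb = torch.rand(10, 3)
    sh = RGB2SH(rgb).unsqueeze(-1)           # [..., 3, 1] DC-only coefficients
    dirs = torch.randn(10, 3)
    dirs = dirs / dirs.norm(dim=-1, keepdim=True)
    out = eval_sh(0, sh, dirs)               # C0 * sh[..., 0] == rgb - 0.5
    assert torch.allclose(out + 0.5, rgb, atol=1e-6)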
models/lrm/online_render/utils/system_utils.py ADDED
@@ -0,0 +1,28 @@
+ #
+ # Copyright (C) 2023, Inria
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
+ # All rights reserved.
+ #
+ # This software is free for non-commercial, research and evaluation use
+ # under the terms of the LICENSE.md file.
+ #
+ # For inquiries contact [email protected]
+ #
+
+ from errno import EEXIST
+ from os import makedirs, path
+ import os
+
+ def mkdir_p(folder_path):
+     # Creates a directory. Equivalent to using mkdir -p on the command line
+     try:
+         makedirs(folder_path)
+     except OSError as exc:  # Python >2.5
+         if exc.errno == EEXIST and path.isdir(folder_path):
+             pass
+         else:
+             raise
+
+ def searchForMaxIteration(folder):
+     saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)]
+     return max(saved_iters)
models/lrm/online_render/utils/taming/modules/autoencoder/lpips/vgg.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
+ size 7289
models/lrm/online_render/utils/vis_utils.py ADDED
@@ -0,0 +1,151 @@
+ """
+ Save a couple images to grids with cond, render cond, novel render, novel gt
+ Also save images to a render video
+ """
+ import glob
+ import os
+ from PIL import Image
+ import numpy as np
+ import torch
+
+ from matplotlib import pyplot as plt
+ from utils.sh_utils import eval_sh
+ from einops import rearrange
+
+ def gridify():
+
+     out_folder = "grids_objaverse"
+     os.makedirs(out_folder, exist_ok=True)
+
+     folder_paths = glob.glob("/scratch/shared/beegfs/stan/scaling_splatter_image/objaverse/*")
+     # pixelnerf_root = "/scratch/shared/beegfs/stan/splatter_image/pixelnerf/teddybears"
+     folder_paths_test = sorted([fpath for fpath in folder_paths if "gt" not in fpath], key= lambda x: int(os.path.basename(x).split("_")[0]))
+     """folder_paths_test = [folder_paths_test[i] for i in [5, 7, 12, 15,
+                                                            18, 19, 30, 33,
+                                                            37, 42, 43, 44,
+                                                            48, 51, 64, 66,
+                                                            70, 74, 78, 85,
+                                                            89, 91, 92]]"""
+
+     # Initialize variables for grid dimensions
+     num_examples_row = 6
+     rows = num_examples_row
+     num_per_ex = 2
+     cols = num_examples_row * num_per_ex  # 7 * 2
+     im_res = 128
+
+     for im_idx in range(100):
+         print("Doing frame {}".format(im_idx))
+         # for im_name in ["xyz", "colours", "opacity", "scaling"]:
+         grid = np.zeros((rows*im_res, cols*im_res, 3), dtype=np.uint8)
+
+         # Iterate through the folders in the out_folder
+         for f_idx, folder_path_test in enumerate(folder_paths_test[:num_examples_row*num_examples_row]):
+             # if im_name == "xyz":
+             #     print(folder_path_test)
+             row_idx = f_idx // num_examples_row
+             col_idx = f_idx % num_examples_row
+             im_path = os.path.join(folder_path_test, "{:05d}.png".format(im_idx))
+             im_path_gt = os.path.join(folder_path_test + "_gt", "{:05d}.png".format(im_idx))
+             """im_path_pixelnerf = os.path.join(pixelnerf_root, os.path.basename(folder_path_test),
+                                              "{:06d}.png".format(im_idx))"""
+
+             # im_path = os.path.join(folder_path_test, "{}.png".format(im_name))
+             try:
+                 im = np.array(Image.open(im_path))
+                 im_gt = np.array(Image.open(im_path_gt))
+                 #im_pn = np.array(Image.open(im_path_pixelnerf))
+                 grid[row_idx*im_res: (row_idx+1)*im_res,
+                      col_idx * num_per_ex *im_res: (col_idx * num_per_ex+1)*im_res, : ] = im[:, :, :3]
+                 grid[row_idx*im_res: (row_idx+1)*im_res,
+                      (col_idx * num_per_ex + 1) *im_res: (col_idx* num_per_ex +2)*im_res, : ] = im_gt[:, :, :3]
+                 """grid[row_idx*im_res: (row_idx+1)*im_res,
+                      (col_idx * num_per_ex + 2) *im_res: (col_idx* num_per_ex +3)*im_res, : ] = im_pn[:, :, :3]"""
+             except FileNotFoundError:
+                 a = 0
+         im_out = Image.fromarray(grid)
+         im_out.save(os.path.join(out_folder, "{:05d}.png".format(im_idx)))
+         # im_out.save(os.path.join(out_folder, "{}.png".format(im_name)))
+
+ def comparisons():
+
+     out_root = "hydrants_comparisons"
+     os.makedirs(out_root, exist_ok=True)
+
+     folder_paths = glob.glob("/users/stan/pixel-nerf/full_eval_hydrant/*")
+     folder_paths_test = sorted(folder_paths)
+     folder_paths_ours_root = "/scratch/shared/beegfs/stan/out_hydrants_with_lpips_ours"
+
+     # Initialize variables for grid dimensions
+     rows = 3
+     cols = 1
+     im_res = 128
+
+     for f_idx, folder_path_test in enumerate(folder_paths_test):
+
+         example_id = "_".join(os.path.basename(folder_path_test).split("_")[1:])
+         out_folder = os.path.join(out_root, example_id)
+         os.makedirs(out_folder, exist_ok=True)
+         num_images = len([p for p in glob.glob(os.path.join(folder_path_test, "*.png")) if "gt" not in p])
+
+         grid = np.zeros((rows*im_res, cols*im_res, 3), dtype=np.uint8)
+
+         for im_idx in range(num_images):
+
+             im_path_pixelnerf = os.path.join(folder_path_test, "{:06d}.png".format(im_idx+1))
+             im_path_ours = os.path.join(folder_paths_ours_root, example_id, "{:05d}.png".format(im_idx))
+             im_path_gt = os.path.join(folder_paths_ours_root, example_id + "_gt", "{:05d}.png".format(im_idx))
+             # im_path = os.path.join(folder_path_test, "{}.png".format(im_name))
+
+             im_pn = np.array(Image.open(im_path_pixelnerf))
+             im_ours = np.array(Image.open(im_path_ours))
+             im_gt = np.array(Image.open(im_path_gt))
+
+             grid[:im_res, :, :] = im_pn
+             grid[im_res:2*im_res, :, :] = im_ours
+             grid[2*im_res:3*im_res, :, :] = im_gt
+
+             im_out = Image.fromarray(grid)
+             im_out.save(os.path.join(out_folder, "{:05d}.png".format(im_idx)))
+
+ def vis_image_preds(image_preds: dict, folder_out: str):
+     """
+     Visualises network's image predictions.
+     Args:
+         image_preds: a dictionary of xyz, opacity, scaling, rotation, features_dc and features_rest
+     """
+     image_preds_reshaped = {}
+     ray_dirs = (image_preds["xyz"].detach().cpu() / torch.norm(image_preds["xyz"].detach().cpu(), dim=-1, keepdim=True)).reshape(128, 128, 3)
+
+     for k, v in image_preds.items():
+         image_preds_reshaped[k] = v
+         if k == "xyz":
+             image_preds_reshaped[k] = (image_preds_reshaped[k] - torch.min(image_preds_reshaped[k], dim=0, keepdim=True)[0]) / (
+                 torch.max(image_preds_reshaped[k], dim=0, keepdim=True)[0] - torch.min(image_preds_reshaped[k], dim=0, keepdim=True)[0]
+             )
+         if k == "scaling":
+             image_preds_reshaped["scaling"] = (image_preds_reshaped["scaling"] - torch.min(image_preds_reshaped["scaling"], dim=0, keepdim=True)[0]) / (
+                 torch.max(image_preds_reshaped["scaling"], dim=0, keepdim=True)[0] - torch.min(image_preds_reshaped["scaling"], dim=0, keepdim=True)[0]
+             )
+         if k != "features_rest":
+             image_preds_reshaped[k] = image_preds_reshaped[k].reshape(128, 128, -1).detach().cpu()
+         else:
+             image_preds_reshaped[k] = image_preds_reshaped[k].reshape(128, 128, 3, 3).detach().cpu().permute(0, 1, 3, 2)
+         if k == "opacity":
+             image_preds_reshaped[k] = image_preds_reshaped[k].expand(128, 128, 3)
+
+
+     colours = torch.cat([image_preds_reshaped["features_dc"].unsqueeze(-1), image_preds_reshaped["features_rest"]], dim=-1)
+     colours = eval_sh(1, colours, ray_dirs)
+
+     plt.imsave(os.path.join(folder_out, "colours.png"),
+                colours.numpy())
+     plt.imsave(os.path.join(folder_out, "opacity.png"),
+                image_preds_reshaped["opacity"].numpy())
+     plt.imsave(os.path.join(folder_out, "xyz.png"),
+                (image_preds_reshaped["xyz"] * image_preds_reshaped["opacity"]+ 1 - image_preds_reshaped["opacity"]).numpy())
+     plt.imsave(os.path.join(folder_out, "scaling.png"),
+                (image_preds_reshaped["scaling"] * image_preds_reshaped["opacity"] + 1 - image_preds_reshaped["opacity"]).numpy())
+
+ if __name__ == "__main__":
+     gridify()