from typing import Literal import numpy as np import roma import scipy.interpolate import torch import torch.nn.functional as F DEFAULT_FOV_RAD = 0.9424777960769379 # 54 degrees by default def get_camera_dist( source_c2ws: torch.Tensor, # N x 3 x 4 target_c2ws: torch.Tensor, # M x 3 x 4 mode: str = "translation", ): if mode == "rotation": dists = torch.acos( ( ( torch.matmul( source_c2ws[:, None, :3, :3], target_c2ws[None, :, :3, :3].transpose(-1, -2), ) .diagonal(offset=0, dim1=-2, dim2=-1) .sum(-1) - 1 ) / 2 ).clamp(-1, 1) ) * (180 / torch.pi) elif mode == "translation": dists = torch.norm( source_c2ws[:, None, :3, 3] - target_c2ws[None, :, :3, 3], dim=-1 ) else: raise NotImplementedError( f"Mode {mode} is not implemented for finding nearest source indices." ) return dists def to_hom(X): # get homogeneous coordinates of the input X_hom = torch.cat([X, torch.ones_like(X[..., :1])], dim=-1) return X_hom def to_hom_pose(pose): # get homogeneous coordinates of the input pose if pose.shape[-2:] == (3, 4): pose_hom = torch.eye(4, device=pose.device)[None].repeat(pose.shape[0], 1, 1) pose_hom[:, :3, :] = pose return pose_hom return pose def get_default_intrinsics( fov_rad=DEFAULT_FOV_RAD, aspect_ratio=1.0, ): if not isinstance(fov_rad, torch.Tensor): fov_rad = torch.tensor( [fov_rad] if isinstance(fov_rad, (int, float)) else fov_rad ) if aspect_ratio >= 1.0: # W >= H focal_x = 0.5 / torch.tan(0.5 * fov_rad) focal_y = focal_x * aspect_ratio else: # W < H focal_y = 0.5 / torch.tan(0.5 * fov_rad) focal_x = focal_y / aspect_ratio intrinsics = focal_x.new_zeros((focal_x.shape[0], 3, 3)) intrinsics[:, torch.eye(3, device=focal_x.device, dtype=bool)] = torch.stack( [focal_x, focal_y, torch.ones_like(focal_x)], dim=-1 ) intrinsics[:, :, -1] = torch.tensor( [0.5, 0.5, 1.0], device=focal_x.device, dtype=focal_x.dtype ) return intrinsics def get_image_grid(img_h, img_w): # add 0.5 is VERY important especially when your img_h and img_w # is not very large (e.g., 72)!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! y_range = torch.arange(img_h, dtype=torch.float32).add_(0.5) x_range = torch.arange(img_w, dtype=torch.float32).add_(0.5) Y, X = torch.meshgrid(y_range, x_range, indexing="ij") # [H,W] xy_grid = torch.stack([X, Y], dim=-1).view(-1, 2) # [HW,2] return to_hom(xy_grid) # [HW,3] def img2cam(X, cam_intr): return X @ cam_intr.inverse().transpose(-1, -2) def cam2world(X, pose): X_hom = to_hom(X) pose_inv = torch.linalg.inv(to_hom_pose(pose))[..., :3, :4] return X_hom @ pose_inv.transpose(-1, -2) def get_center_and_ray( img_h, img_w, pose, intr, zero_center_for_debugging=False ): # [HW,2] # given the intrinsic/extrinsic matrices, get the camera center and ray directions] # assert(opt.camera.model=="perspective") # compute center and ray grid_img = get_image_grid(img_h, img_w) # [HW,3] grid_3D_cam = img2cam(grid_img.to(intr.device), intr.float()) # [B,HW,3] center_3D_cam = torch.zeros_like(grid_3D_cam) # [B,HW,3] # transform from camera to world coordinates grid_3D = cam2world(grid_3D_cam, pose) # [B,HW,3] center_3D = cam2world(center_3D_cam, pose) # [B,HW,3] ray = grid_3D - center_3D # [B,HW,3] return center_3D_cam if zero_center_for_debugging else center_3D, ray, grid_3D_cam def get_plucker_coordinates( extrinsics_src, extrinsics, intrinsics=None, fov_rad=DEFAULT_FOV_RAD, mode="plucker", rel_zero_translation=True, zero_center_for_debugging=False, target_size=[72, 72], # 576-size image return_grid_cam=False, # save for later use if want restore ): if intrinsics is None: intrinsics = get_default_intrinsics(fov_rad).to(extrinsics.device) else: # for some data preprocessed in the early stage (e.g., MVI and CO3D), # intrinsics are expressed in raw pixel space (e.g., 576x576) instead # of normalized image coordinates if not ( torch.all(intrinsics[:, :2, -1] >= 0) and torch.all(intrinsics[:, :2, -1] <= 1) ): intrinsics[:, :2] /= intrinsics.new_tensor(target_size).view(1, -1, 1) * 8 # you should ensure the intrisics are expressed in # resolution-independent normalized image coordinates just performing a # very simple verification here checking if principal points are # between 0 and 1 assert ( torch.all(intrinsics[:, :2, -1] >= 0) and torch.all(intrinsics[:, :2, -1] <= 1) ), "Intrinsics should be expressed in resolution-independent normalized image coordinates." c2w_src = torch.linalg.inv(extrinsics_src) if not rel_zero_translation: c2w_src[:3, 3] = c2w_src[3, :3] = 0.0 # transform coordinates from the source camera's coordinate system to the coordinate system of the respective camera extrinsics_rel = torch.einsum( "vnm,vmp->vnp", extrinsics, c2w_src[None].repeat(extrinsics.shape[0], 1, 1) ) intrinsics[:, :2] *= extrinsics.new_tensor( [ target_size[1], # w target_size[0], # h ] ).view(1, -1, 1) centers, rays, grid_cam = get_center_and_ray( img_h=target_size[0], img_w=target_size[1], pose=extrinsics_rel[:, :3, :], intr=intrinsics, zero_center_for_debugging=zero_center_for_debugging, ) if mode == "plucker" or "v1" in mode: rays = torch.nn.functional.normalize(rays, dim=-1) plucker = torch.cat((rays, torch.cross(centers, rays, dim=-1)), dim=-1) else: raise ValueError(f"Unknown Plucker coordinate mode: {mode}") plucker = plucker.permute(0, 2, 1).reshape(plucker.shape[0], -1, *target_size) if return_grid_cam: return plucker, grid_cam.reshape(-1, *target_size, 3) return plucker def rt_to_mat4( R: torch.Tensor, t: torch.Tensor, s: torch.Tensor | None = None ) -> torch.Tensor: """ Args: R (torch.Tensor): (..., 3, 3). t (torch.Tensor): (..., 3). s (torch.Tensor): (...,). Returns: torch.Tensor: (..., 4, 4) """ mat34 = torch.cat([R, t[..., None]], dim=-1) if s is None: bottom = ( mat34.new_tensor([[0.0, 0.0, 0.0, 1.0]]) .reshape((1,) * (mat34.dim() - 2) + (1, 4)) .expand(mat34.shape[:-2] + (1, 4)) ) else: bottom = F.pad(1.0 / s[..., None, None], (3, 0), value=0.0) mat4 = torch.cat([mat34, bottom], dim=-2) return mat4 def get_preset_pose_fov( option: Literal[ "orbit", "spiral", "lemniscate", "zoom-in", "zoom-out", "dolly zoom-in", "dolly zoom-out", "move-forward", "move-backward", "move-up", "move-down", "move-left", "move-right", "roll", ], num_frames: int, start_w2c: torch.Tensor, look_at: torch.Tensor, up_direction: torch.Tensor | None = None, fov: float = DEFAULT_FOV_RAD, spiral_radii: list[float] = [0.5, 0.5, 0.2], zoom_factor: float | None = None, ): poses = fovs = None if option == "orbit": poses = torch.linalg.inv( get_arc_horizontal_w2cs( start_w2c, look_at, up_direction, num_frames=num_frames, endpoint=False, ) ).numpy() fovs = np.full((num_frames,), fov) elif option == "spiral": poses = generate_spiral_path( torch.linalg.inv(start_w2c)[None].numpy() @ np.diagflat([1, -1, -1, 1]), np.array([1, 5]), n_frames=num_frames, n_rots=2, zrate=0.5, radii=spiral_radii, endpoint=False, ) @ np.diagflat([1, -1, -1, 1]) poses = np.concatenate( [ poses, np.array([0.0, 0.0, 0.0, 1.0])[None, None].repeat(len(poses), 0), ], 1, ) # We want the spiral trajectory to always start from start_w2c. Thus we # apply the relative pose to get the final trajectory. poses = ( np.linalg.inv(start_w2c.numpy())[None] @ np.linalg.inv(poses[:1]) @ poses ) fovs = np.full((num_frames,), fov) elif option == "lemniscate": poses = torch.linalg.inv( get_lemniscate_w2cs( start_w2c, look_at, up_direction, num_frames, degree=60.0, endpoint=False, ) ).numpy() fovs = np.full((num_frames,), fov) elif option == "roll": poses = torch.linalg.inv( get_roll_w2cs( start_w2c, look_at, None, num_frames, degree=360.0, endpoint=False, ) ).numpy() fovs = np.full((num_frames,), fov) elif option in [ "dolly zoom-in", "dolly zoom-out", "zoom-in", "zoom-out", ]: if option.startswith("dolly"): direction = "backward" if option == "dolly zoom-in" else "forward" poses = torch.linalg.inv( get_moving_w2cs( start_w2c, look_at, up_direction, num_frames, endpoint=True, direction=direction, ) ).numpy() else: poses = torch.linalg.inv(start_w2c)[None].repeat(num_frames, 1, 1).numpy() fov_rad_start = fov if zoom_factor is None: zoom_factor = 0.28 if option.endswith("zoom-in") else 1.5 fov_rad_end = zoom_factor * fov fovs = ( np.linspace(0, 1, num_frames) * (fov_rad_end - fov_rad_start) + fov_rad_start ) elif option in [ "move-forward", "move-backward", "move-up", "move-down", "move-left", "move-right", ]: poses = torch.linalg.inv( get_moving_w2cs( start_w2c, look_at, up_direction, num_frames, endpoint=True, direction=option.removeprefix("move-"), ) ).numpy() fovs = np.full((num_frames,), fov) else: raise ValueError(f"Unknown preset option {option}.") return poses, fovs def get_lookat(origins: torch.Tensor, viewdirs: torch.Tensor) -> torch.Tensor: """Triangulate a set of rays to find a single lookat point. Args: origins (torch.Tensor): A (N, 3) array of ray origins. viewdirs (torch.Tensor): A (N, 3) array of ray view directions. Returns: torch.Tensor: A (3,) lookat point. """ viewdirs = torch.nn.functional.normalize(viewdirs, dim=-1) eye = torch.eye(3, device=origins.device, dtype=origins.dtype)[None] # Calculate projection matrix I - rr^T I_min_cov = eye - (viewdirs[..., None] * viewdirs[..., None, :]) # Compute sum of projections sum_proj = I_min_cov.matmul(origins[..., None]).sum(dim=-3) # Solve for the intersection point using least squares lookat = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0] # Check NaNs. assert not torch.any(torch.isnan(lookat)) return lookat def get_lookat_w2cs( positions: torch.Tensor, lookat: torch.Tensor, up: torch.Tensor, face_off: bool = False, ): """ Args: positions: (N, 3) tensor of camera positions lookat: (3,) tensor of lookat point up: (3,) or (N, 3) tensor of up vector Returns: w2cs: (N, 3, 3) tensor of world to camera rotation matrices """ forward_vectors = F.normalize(lookat - positions, dim=-1) if face_off: forward_vectors = -forward_vectors if up.dim() == 1: up = up[None] right_vectors = F.normalize(torch.cross(forward_vectors, up, dim=-1), dim=-1) down_vectors = F.normalize( torch.cross(forward_vectors, right_vectors, dim=-1), dim=-1 ) Rs = torch.stack([right_vectors, down_vectors, forward_vectors], dim=-1) w2cs = torch.linalg.inv(rt_to_mat4(Rs, positions)) return w2cs def get_arc_horizontal_w2cs( ref_w2c: torch.Tensor, lookat: torch.Tensor, up: torch.Tensor | None, num_frames: int, clockwise: bool = True, face_off: bool = False, endpoint: bool = False, degree: float = 360.0, ref_up_shift: float = 0.0, ref_radius_scale: float = 1.0, **_, ) -> torch.Tensor: ref_c2w = torch.linalg.inv(ref_w2c) ref_position = ref_c2w[:3, 3] if up is None: up = -ref_c2w[:3, 1] assert up is not None ref_position += up * ref_up_shift ref_position *= ref_radius_scale thetas = ( torch.linspace(0.0, torch.pi * degree / 180, num_frames, device=ref_w2c.device) if endpoint else torch.linspace( 0.0, torch.pi * degree / 180, num_frames + 1, device=ref_w2c.device )[:-1] ) if not clockwise: thetas = -thetas positions = ( torch.einsum( "nij,j->ni", roma.rotvec_to_rotmat(thetas[:, None] * up[None]), ref_position - lookat, ) + lookat ) return get_lookat_w2cs(positions, lookat, up, face_off=face_off) def get_lemniscate_w2cs( ref_w2c: torch.Tensor, lookat: torch.Tensor, up: torch.Tensor | None, num_frames: int, degree: float, endpoint: bool = False, **_, ) -> torch.Tensor: ref_c2w = torch.linalg.inv(ref_w2c) a = torch.linalg.norm(ref_c2w[:3, 3] - lookat) * np.tan(degree / 360 * np.pi) # Lemniscate curve in camera space. Starting at the origin. thetas = ( torch.linspace(0, 2 * torch.pi, num_frames, device=ref_w2c.device) if endpoint else torch.linspace(0, 2 * torch.pi, num_frames + 1, device=ref_w2c.device)[:-1] ) + torch.pi / 2 positions = torch.stack( [ a * torch.cos(thetas) / (1 + torch.sin(thetas) ** 2), a * torch.cos(thetas) * torch.sin(thetas) / (1 + torch.sin(thetas) ** 2), torch.zeros(num_frames, device=ref_w2c.device), ], dim=-1, ) # Transform to world space. positions = torch.einsum( "ij,nj->ni", ref_c2w[:3], F.pad(positions, (0, 1), value=1.0) ) if up is None: up = -ref_c2w[:3, 1] assert up is not None return get_lookat_w2cs(positions, lookat, up) def get_moving_w2cs( ref_w2c: torch.Tensor, lookat: torch.Tensor, up: torch.Tensor | None, num_frames: int, endpoint: bool = False, direction: str = "forward", tilt_xy: torch.Tensor = None, ): """ Args: ref_w2c: (4, 4) tensor of the reference wolrd-to-camera matrix lookat: (3,) tensor of lookat point up: (3,) tensor of up vector Returns: w2cs: (N, 3, 3) tensor of world to camera rotation matrices """ ref_c2w = torch.linalg.inv(ref_w2c) ref_position = ref_c2w[:3, -1] if up is None: up = -ref_c2w[:3, 1] direction_vectors = { "forward": (lookat - ref_position).clone(), "backward": -(lookat - ref_position).clone(), "up": up.clone(), "down": -up.clone(), "right": torch.cross((lookat - ref_position), up, dim=0), "left": -torch.cross((lookat - ref_position), up, dim=0), } if direction not in direction_vectors: raise ValueError( f"Invalid direction: {direction}. Must be one of {list(direction_vectors.keys())}" ) positions = ref_position + ( F.normalize(direction_vectors[direction], dim=0) * ( torch.linspace(0, 0.99, num_frames, device=ref_w2c.device) if endpoint else torch.linspace(0, 1, num_frames + 1, device=ref_w2c.device)[:-1] )[:, None] ) if tilt_xy is not None: positions[:, :2] += tilt_xy return get_lookat_w2cs(positions, lookat, up) def get_roll_w2cs( ref_w2c: torch.Tensor, lookat: torch.Tensor, up: torch.Tensor | None, num_frames: int, endpoint: bool = False, degree: float = 360.0, **_, ) -> torch.Tensor: ref_c2w = torch.linalg.inv(ref_w2c) ref_position = ref_c2w[:3, 3] if up is None: up = -ref_c2w[:3, 1] # Infer the up vector from the reference. # Create vertical angles thetas = ( torch.linspace(0.0, torch.pi * degree / 180, num_frames, device=ref_w2c.device) if endpoint else torch.linspace( 0.0, torch.pi * degree / 180, num_frames + 1, device=ref_w2c.device )[:-1] )[:, None] lookat_vector = F.normalize(lookat[None].float(), dim=-1) up = up[None] up = ( up * torch.cos(thetas) + torch.cross(lookat_vector, up) * torch.sin(thetas) + lookat_vector * torch.einsum("ij,ij->i", lookat_vector, up)[:, None] * (1 - torch.cos(thetas)) ) # Normalize the camera orientation return get_lookat_w2cs(ref_position[None].repeat(num_frames, 1), lookat, up) def normalize(x): """Normalization helper function.""" return x / np.linalg.norm(x) def viewmatrix(lookdir, up, position, subtract_position=False): """Construct lookat view matrix.""" vec2 = normalize((lookdir - position) if subtract_position else lookdir) vec0 = normalize(np.cross(up, vec2)) vec1 = normalize(np.cross(vec2, vec0)) m = np.stack([vec0, vec1, vec2, position], axis=1) return m def poses_avg(poses): """New pose using average position, z-axis, and up vector of input poses.""" position = poses[:, :3, 3].mean(0) z_axis = poses[:, :3, 2].mean(0) up = poses[:, :3, 1].mean(0) cam2world = viewmatrix(z_axis, up, position) return cam2world def generate_spiral_path( poses, bounds, n_frames=120, n_rots=2, zrate=0.5, endpoint=False, radii=None ): """Calculates a forward facing spiral path for rendering.""" # Find a reasonable 'focus depth' for this dataset as a weighted average # of near and far bounds in disparity space. close_depth, inf_depth = bounds.min() * 0.9, bounds.max() * 5.0 dt = 0.75 focal = 1 / ((1 - dt) / close_depth + dt / inf_depth) # Get radii for spiral path using 90th percentile of camera positions. positions = poses[:, :3, 3] if radii is None: radii = np.percentile(np.abs(positions), 90, 0) radii = np.concatenate([radii, [1.0]]) # Generate poses for spiral path. render_poses = [] cam2world = poses_avg(poses) up = poses[:, :3, 1].mean(0) for theta in np.linspace(0.0, 2.0 * np.pi * n_rots, n_frames, endpoint=endpoint): t = radii * [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0] position = cam2world @ t lookat = cam2world @ [0, 0, -focal, 1.0] z_axis = position - lookat render_poses.append(viewmatrix(z_axis, up, position)) render_poses = np.stack(render_poses, axis=0) return render_poses def generate_interpolated_path( poses: np.ndarray, n_interp: int, spline_degree: int = 5, smoothness: float = 0.03, rot_weight: float = 0.1, endpoint: bool = False, ): """Creates a smooth spline path between input keyframe camera poses. Spline is calculated with poses in format (position, lookat-point, up-point). Args: poses: (n, 3, 4) array of input pose keyframes. n_interp: returned path will have n_interp * (n - 1) total poses. spline_degree: polynomial degree of B-spline. smoothness: parameter for spline smoothing, 0 forces exact interpolation. rot_weight: relative weighting of rotation/translation in spline solve. Returns: Array of new camera poses with shape (n_interp * (n - 1), 3, 4). """ def poses_to_points(poses, dist): """Converts from pose matrices to (position, lookat, up) format.""" pos = poses[:, :3, -1] lookat = poses[:, :3, -1] - dist * poses[:, :3, 2] up = poses[:, :3, -1] + dist * poses[:, :3, 1] return np.stack([pos, lookat, up], 1) def points_to_poses(points): """Converts from (position, lookat, up) format to pose matrices.""" return np.array([viewmatrix(p - l, u - p, p) for p, l, u in points]) def interp(points, n, k, s): """Runs multidimensional B-spline interpolation on the input points.""" sh = points.shape pts = np.reshape(points, (sh[0], -1)) k = min(k, sh[0] - 1) tck, _ = scipy.interpolate.splprep(pts.T, k=k, s=s) u = np.linspace(0, 1, n, endpoint=endpoint) new_points = np.array(scipy.interpolate.splev(u, tck)) new_points = np.reshape(new_points.T, (n, sh[1], sh[2])) return new_points points = poses_to_points(poses, dist=rot_weight) new_points = interp( points, n_interp * (points.shape[0] - 1), k=spline_degree, s=smoothness ) return points_to_poses(new_points) def similarity_from_cameras(c2w, strict_scaling=False, center_method="focus"): """ reference: nerf-factory Get a similarity transform to normalize dataset from c2w (OpenCV convention) cameras :param c2w: (N, 4) :return T (4,4) , scale (float) """ t = c2w[:, :3, 3] R = c2w[:, :3, :3] # (1) Rotate the world so that z+ is the up axis # we estimate the up axis by averaging the camera up axes ups = np.sum(R * np.array([0, -1.0, 0]), axis=-1) world_up = np.mean(ups, axis=0) world_up /= np.linalg.norm(world_up) up_camspace = np.array([0.0, -1.0, 0.0]) c = (up_camspace * world_up).sum() cross = np.cross(world_up, up_camspace) skew = np.array( [ [0.0, -cross[2], cross[1]], [cross[2], 0.0, -cross[0]], [-cross[1], cross[0], 0.0], ] ) if c > -1: R_align = np.eye(3) + skew + (skew @ skew) * 1 / (1 + c) else: # In the unlikely case the original data has y+ up axis, # rotate 180-deg about x axis R_align = np.array([[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) # R_align = np.eye(3) # DEBUG R = R_align @ R fwds = np.sum(R * np.array([0, 0.0, 1.0]), axis=-1) t = (R_align @ t[..., None])[..., 0] # (2) Recenter the scene. if center_method == "focus": # find the closest point to the origin for each camera's center ray nearest = t + (fwds * -t).sum(-1)[:, None] * fwds translate = -np.median(nearest, axis=0) elif center_method == "poses": # use center of the camera positions translate = -np.median(t, axis=0) else: raise ValueError(f"Unknown center_method {center_method}") transform = np.eye(4) transform[:3, 3] = translate transform[:3, :3] = R_align # (3) Rescale the scene using camera distances scale_fn = np.max if strict_scaling else np.median inv_scale = scale_fn(np.linalg.norm(t + translate, axis=-1)) if inv_scale == 0: inv_scale = 1.0 scale = 1.0 / inv_scale transform[:3, :] *= scale return transform def align_principle_axes(point_cloud): # Compute centroid centroid = np.median(point_cloud, axis=0) # Translate point cloud to centroid translated_point_cloud = point_cloud - centroid # Compute covariance matrix covariance_matrix = np.cov(translated_point_cloud, rowvar=False) # Compute eigenvectors and eigenvalues eigenvalues, eigenvectors = np.linalg.eigh(covariance_matrix) # Sort eigenvectors by eigenvalues (descending order) so that the z-axis # is the principal axis with the smallest eigenvalue. sort_indices = eigenvalues.argsort()[::-1] eigenvectors = eigenvectors[:, sort_indices] # Check orientation of eigenvectors. If the determinant of the eigenvectors is # negative, then we need to flip the sign of one of the eigenvectors. if np.linalg.det(eigenvectors) < 0: eigenvectors[:, 0] *= -1 # Create rotation matrix rotation_matrix = eigenvectors.T # Create SE(3) matrix (4x4 transformation matrix) transform = np.eye(4) transform[:3, :3] = rotation_matrix transform[:3, 3] = -rotation_matrix @ centroid return transform def transform_points(matrix, points): """Transform points using a SE(4) matrix. Args: matrix: 4x4 SE(4) matrix points: Nx3 array of points Returns: Nx3 array of transformed points """ assert matrix.shape == (4, 4) assert len(points.shape) == 2 and points.shape[1] == 3 return points @ matrix[:3, :3].T + matrix[:3, 3] def transform_cameras(matrix, camtoworlds): """Transform cameras using a SE(4) matrix. Args: matrix: 4x4 SE(4) matrix camtoworlds: Nx4x4 array of camera-to-world matrices Returns: Nx4x4 array of transformed camera-to-world matrices """ assert matrix.shape == (4, 4) assert len(camtoworlds.shape) == 3 and camtoworlds.shape[1:] == (4, 4) camtoworlds = np.einsum("nij, ki -> nkj", camtoworlds, matrix) scaling = np.linalg.norm(camtoworlds[:, 0, :3], axis=1) camtoworlds[:, :3, :3] = camtoworlds[:, :3, :3] / scaling[:, None, None] return camtoworlds def normalize_scene(camtoworlds, points=None, camera_center_method="focus"): T1 = similarity_from_cameras(camtoworlds, center_method=camera_center_method) camtoworlds = transform_cameras(T1, camtoworlds) if points is not None: points = transform_points(T1, points) T2 = align_principle_axes(points) camtoworlds = transform_cameras(T2, camtoworlds) points = transform_points(T2, points) return camtoworlds, points, T2 @ T1 else: return camtoworlds, T1