import numpy as np
import torch
import cv2
import open3d as o3d
from dust3r.post_process import estimate_focal_knowing_depth
from dust3r.utils.geometry import inv


def estimate_focal(pts3d_i, pp=None):
    if pp is None:
        H, W, THREE = pts3d_i.shape
        assert THREE == 3
        pp = torch.tensor((W / 2, H / 2), device=pts3d_i.device)
    focal = estimate_focal_knowing_depth(pts3d_i.unsqueeze(0), pp.unsqueeze(0), focal_mode='weiszfeld').ravel()
    return float(focal)

def pixel_grid(H, W):
    # (H, W, 2) grid of (x, y) pixel coordinates
    return np.mgrid[:W, :H].T.astype(np.float32)

def sRT_to_4x4(scale, R, T, device):
    # Build a 4x4 transform from a (scaled) rotation R and translation T
    trf = torch.eye(4, device=device)
    trf[:3, :3] = R * scale
    trf[:3, 3] = T.ravel()  # doesn't need scaling
    return trf

def to_numpy(tensor):
    return tensor.cpu().numpy() if isinstance(tensor, torch.Tensor) else tensor

def calculate_depth_map(pts3d, R, T):
    """
    Calculate ray depths directly from the camera center to each 3D point.

    Args:
        pts3d (np.array): 3D points in world coordinates, shape (H, W, 3)
        R (np.array): World-to-camera rotation matrix, shape (3, 3)
        T (np.array): World-to-camera translation vector, shape (3, 1)

    Returns:
        np.array: Depth map of shape (H, W)
    """
    # Camera center in world coordinates: with x_cam = R @ x_world + T,
    # the center satisfies R @ C + T = 0, hence C = -R^T @ T (not simply -T)
    C = (-R.T @ T).ravel()
    # Rays from the camera center to every 3D point
    ray_vectors = pts3d - C
    # Ray depth = Euclidean length of each ray
    depth_map = np.linalg.norm(ray_vectors, axis=2)
    return depth_map

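
# Illustrative sanity check (an addition, not part of the original module):
# with an identity rotation and zero translation the camera sits at the world
# origin, so each ray depth is just the Euclidean norm of the point.
def _check_depth_map_identity():
    pts3d = np.random.rand(4, 5, 3).astype(np.float32)
    depth = calculate_depth_map(pts3d, np.eye(3), np.zeros((3, 1)))
    assert np.allclose(depth, np.linalg.norm(pts3d, axis=2), atol=1e-6)
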
def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
    # extract camera poses and focals with RANSAC-PnP
    if msk.sum() < 4:
        return None  # we need at least 4 points for PnP

    pts3d, msk = map(to_numpy, (pts3d, msk))

    H, W, THREE = pts3d.shape
    assert THREE == 3
    pixels = pixel_grid(H, W)

    if focal is None:
        # no focal given: sweep a log-spaced range of candidate focals
        S = max(W, H)
        tentative_focals = np.geomspace(S / 2, S * 3, 21)
    else:
        tentative_focals = [focal]

    if pp is None:
        pp = (W / 2, H / 2)
    else:
        pp = to_numpy(pp)

    best = 0, None, None, None, None
    for focal in tentative_focals:
        K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])
        success, R, T, inliers = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None,
                                                    iterationsCount=niter_PnP, reprojectionError=5,
                                                    flags=cv2.SOLVEPNP_SQPNP)
        if not success:
            continue

        score = len(inliers)
        if score > best[0]:
            # solvePnPRansac returns a Rodrigues rotation *vector*; convert it
            # to a rotation matrix before computing ray depths
            depth_map = calculate_depth_map(pts3d, cv2.Rodrigues(R)[0], T)
            best = score, R, T, focal, depth_map

    if not best[0]:
        return None

    _, R, T, best_focal, depth_map = best
    R = cv2.Rodrigues(R)[0]  # world-to-cam rotation vector -> matrix
    R, T = map(torch.from_numpy, (R, T))
    depth_map = torch.from_numpy(depth_map).to(device)

    cam_to_world = inv(sRT_to_4x4(1, R, T, device))  # cam-to-world pose
    return best_focal, cam_to_world, depth_map

def solve_cemara(pts3d, msk, device, focal=None, pp=None):
    # Estimate focal length
    if focal is None:
        focal = estimate_focal(pts3d, pp)

    # Compute camera pose using PnP
    result = fast_pnp(pts3d, focal, msk, device, pp)
    if result is None:
        return None, focal, None

    best_focal, camera_to_world, depth_map = result

    # Construct the intrinsics
    H, W, _ = pts3d.shape
    if pp is None:
        pp = (W / 2, H / 2)

    camera_parameters = o3d.camera.PinholeCameraParameters()
    intrinsic = o3d.camera.PinholeCameraIntrinsic()
    intrinsic.set_intrinsics(W, H,
                             best_focal, best_focal,
                             pp[0], pp[1])
    camera_parameters.intrinsic = intrinsic
    camera_parameters.extrinsic = torch.inverse(camera_to_world).cpu().numpy()  # world-to-cam
    return camera_parameters, best_focal, depth_map
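
# Example usage (a minimal sketch; `pts3d` stands in for a DUSt3R-style point
# map of shape (H, W, 3) and `msk` for a boolean confidence mask of shape
# (H, W); the random inputs here are illustrative only, so the RANSAC-PnP
# step may legitimately fail and return None):
if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    H, W = 240, 320
    pts3d = torch.randn(H, W, 3)               # stand-in for a predicted point map
    msk = torch.ones(H, W, dtype=torch.bool)   # stand-in for a confidence mask

    camera_parameters, focal, depth_map = solve_cemara(pts3d, msk, device)
    if camera_parameters is not None:
        print('focal:', focal)
        print('extrinsic:\n', camera_parameters.extrinsic)
        print('depth map shape:', tuple(depth_map.shape))
    else:
        print('PnP failed; only the focal estimate is available:', focal)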