import numpy as np | |
import torch | |
import cv2 | |
import open3d as o3d | |
from dust3r.post_process import estimate_focal_knowing_depth | |
from dust3r.utils.geometry import inv | |
def estimate_focal(pts3d_i, pp=None):
    """Estimate one focal length from a single pointmap.

    Args:
        pts3d_i: torch tensor holding one 3D point per pixel; unpacked as
            (H, W, 3) when ``pp`` is not supplied.
        pp: optional principal point as a torch tensor ``(cx, cy)``. When
            omitted, the image centre ``(W/2, H/2)`` is used, created on the
            same device as ``pts3d_i``.

    Returns:
        float: the value returned by dust3r's ``estimate_focal_knowing_depth``
        called with ``focal_mode='weiszfeld'``.
    """
    if pp is None:
        height, width, channels = pts3d_i.shape
        assert channels == 3
        pp = torch.tensor((width / 2, height / 2), device=pts3d_i.device)

    # Both inputs get a leading axis of size 1 before the dust3r helper is called.
    batched_focal = estimate_focal_knowing_depth(
        pts3d_i.unsqueeze(0), pp.unsqueeze(0), focal_mode='weiszfeld')
    return float(batched_focal.ravel())
def pixel_grid(H, W):
    """Return the (x, y) pixel coordinate of every pixel of an H x W image.

    Args:
        H: image height in pixels.
        W: image width in pixels.

    Returns:
        np.ndarray: float32 array of shape (H, W, 2) where ``grid[y, x]`` is
        ``(x, y)``.
    """
    # Default 'xy' indexing yields two (H, W) arrays: column index, row index.
    cols, rows = np.meshgrid(np.arange(W), np.arange(H))
    return np.stack((cols, rows), axis=-1).astype(np.float32)
def sRT_to_4x4(scale, R, T, device):
    """Pack a scale, rotation and translation into one 4x4 transform.

    Args:
        scale: scalar multiplied into the rotation block.
        R: 3x3 rotation tensor.
        T: translation tensor with 3 elements (any shape; it is flattened).
        device: torch device on which the 4x4 matrix is created.

    Returns:
        torch.Tensor: 4x4 matrix ``[[scale * R, T], [0, 0, 0, 1]]`` in
        ``torch.eye``'s default dtype.
    """
    matrix = torch.eye(4, device=device)
    scaled_rotation = R * scale
    matrix[:3, :3] = scaled_rotation
    # The translation column is stored as-is; only the rotation is scaled.
    matrix[:3, 3] = T.ravel()
    return matrix
def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array; pass anything else through.

    Args:
        tensor: a ``torch.Tensor`` (on any device, with or without
            ``requires_grad``) or an arbitrary non-tensor object.

    Returns:
        A NumPy array for tensor input; otherwise the input object itself,
        unchanged.
    """
    if isinstance(tensor, torch.Tensor):
        # detach() first: .numpy() raises on tensors that require grad.
        return tensor.detach().cpu().numpy()
    return tensor
def calculate_depth_map(pts3d, R, T):
    """Compute, per pixel, the distance from the camera centre to its 3D point.

    ``R`` and ``T`` describe a world-to-camera pose ``X_cam = R @ X_world + T``,
    so the camera centre in world coordinates is ``-R.T @ T``. It equals ``-T``
    only when ``R`` is the identity.

    Args:
        pts3d (np.array): 3D points in world coordinates, shape (H, W, 3).
        R (np.array): world-to-camera rotation, either a (3, 3) matrix or a
            3-element Rodrigues vector (the form ``cv2.solvePnPRansac``
            returns).
        T (np.array): world-to-camera translation with 3 elements,
            e.g. shape (3, 1).

    Returns:
        np.array: ray-depth map of shape (H, W).

    Raises:
        ValueError: if ``R`` is neither a 3-vector nor a 3x3 matrix.
    """
    rotation = np.asarray(R, dtype=np.float64)
    if rotation.size == 3:
        # Rodrigues rotation vector -> 3x3 rotation matrix.
        rotation = cv2.Rodrigues(rotation.reshape(3, 1))[0]
    if rotation.shape != (3, 3):
        raise ValueError(
            f"R must be a 3x3 matrix or a 3-element Rodrigues vector, "
            f"got shape {rotation.shape}")
    translation = np.asarray(T, dtype=np.float64).reshape(3)

    # Camera centre C satisfies R @ C + T = 0.
    camera_center = -rotation.T @ translation

    # Ray depth is the Euclidean length of the ray from the centre to each point.
    ray_vectors = pts3d - camera_center
    return np.linalg.norm(ray_vectors, axis=2)
def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
    """Recover a camera pose (and optionally the focal) with RANSAC-PnP.

    Each masked pixel gives one 3D-2D correspondence: its 3D point from
    ``pts3d`` and its own (x, y) pixel coordinate. When ``focal`` is None,
    21 geometrically spaced candidates between ``max(W, H) / 2`` and
    ``max(W, H) * 3`` are tried, and the one with the most RANSAC inliers wins.

    Args:
        pts3d: pointmap of shape (H, W, 3), torch tensor or NumPy array.
        focal: known focal length in pixels, or None to sweep candidates.
        msk: (H, W) mask of usable pixels, torch tensor or NumPy array
            (assumed boolean, since it is used for boolean indexing).
        device: torch device for the returned pose and depth map.
        pp: optional principal point ``(cx, cy)``; defaults to ``(W/2, H/2)``.
        niter_PnP: RANSAC iteration count passed to ``cv2.solvePnPRansac``.

    Returns:
        None if fewer than 4 pixels are masked or no candidate focal gives a
        successful solve. Otherwise a tuple
        ``(best_focal, cam_to_world, depth_map)``: the winning focal, the 4x4
        camera-to-world transform, and an (H, W) ray-depth tensor on ``device``.
    """
    if msk.sum() < 4:
        return None  # we need at least 4 points for PnP
    pts3d, msk = map(to_numpy, (pts3d, msk))

    H, W, THREE = pts3d.shape
    assert THREE == 3
    pixels = pixel_grid(H, W)

    if focal is None:
        S = max(W, H)
        tentative_focals = np.geomspace(S / 2, S * 3, 21)
    else:
        tentative_focals = [focal]

    pp = (W / 2, H / 2) if pp is None else to_numpy(pp)

    # The correspondences do not depend on the focal: select them once.
    object_points = pts3d[msk]
    image_points = pixels[msk]

    best_score, best_rvec, best_tvec, best_focal = 0, None, None, None
    for candidate in tentative_focals:
        K = np.float32([(candidate, 0, pp[0]), (0, candidate, pp[1]), (0, 0, 1)])
        success, rvec, tvec, inliers = cv2.solvePnPRansac(
            object_points, image_points, K, None,
            iterationsCount=niter_PnP, reprojectionError=5,
            flags=cv2.SOLVEPNP_SQPNP)
        if not success or inliers is None:
            continue

        score = len(inliers)
        if score > best_score:
            best_score, best_rvec, best_tvec, best_focal = score, rvec, tvec, candidate

    if not best_score:
        return None

    # solvePnPRansac returns a Rodrigues vector; expand it to a 3x3 matrix.
    R = cv2.Rodrigues(best_rvec)[0]  # world to cam
    T = best_tvec

    # Computed once, for the winning pose only, with the 3x3 rotation matrix
    # that calculate_depth_map documents as its input.
    depth_map = torch.from_numpy(calculate_depth_map(pts3d, R, T)).to(device)

    R, T = map(torch.from_numpy, (R, T))
    cam_to_world = inv(sRT_to_4x4(1, R, T, device))  # cam to world
    return best_focal, cam_to_world, depth_map
def solve_cemara(pts3d, msk, device, focal=None, pp=None):
    """Estimate Open3D pinhole camera parameters from a pointmap.

    The focal length is estimated with ``estimate_focal`` unless one is given,
    then the pose is solved with ``fast_pnp``.

    Args:
        pts3d: pointmap of shape (H, W, 3). It must be a torch tensor when
            ``focal`` is None, because ``estimate_focal`` uses tensor methods.
        msk: (H, W) mask of pixels to use for PnP.
        device: torch device forwarded to ``fast_pnp``.
        focal: optional known focal length in pixels.
        pp: optional principal point ``(cx, cy)``; defaults to ``(W/2, H/2)``
            when building the intrinsics.

    Returns:
        tuple: ``(camera_parameters, focal, depth_map)``. ``camera_parameters``
        is an ``o3d.camera.PinholeCameraParameters`` whose extrinsic is the
        world-to-camera matrix. If PnP fails, returns ``(None, focal, None)``
        with the focal that was given or estimated.
    """
    # Estimate focal length
    if focal is None:
        focal = estimate_focal(pts3d, pp)

    # Compute camera pose using PnP
    result = fast_pnp(pts3d, focal, msk, device, pp)
    if result is None:
        return None, focal, None
    best_focal, camera_to_world, depth_map = result

    # Build the intrinsics; plain Python floats avoid handing tensor or NumPy
    # scalars to the Open3D binding.
    H, W, _ = pts3d.shape
    if pp is None:
        pp = (W / 2, H / 2)
    fx = fy = float(best_focal)
    cx, cy = float(pp[0]), float(pp[1])

    intrinsic = o3d.camera.PinholeCameraIntrinsic()
    intrinsic.set_intrinsics(W, H, fx, fy, cx, cy)

    camera_parameters = o3d.camera.PinholeCameraParameters()
    camera_parameters.intrinsic = intrinsic
    # The extrinsic is world-to-camera, i.e. the inverse of the PnP result.
    camera_parameters.extrinsic = torch.inverse(camera_to_world).cpu().numpy()
    return camera_parameters, best_focal, depth_map