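"""Camera recovery helpers for DUSt3R pointmaps.

Given an (H, W, 3) pointmap and a validity mask, these utilities estimate the
focal length (Weiszfeld estimator from dust3r), recover the camera pose with
RANSAC-PnP (optionally sweeping candidate focals), and package the result as
Open3D pinhole camera parameters plus a per-pixel ray-depth map.
"""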
import numpy as np
import torch
import cv2
import open3d as o3d
from dust3r.post_process import estimate_focal_knowing_depth
from dust3r.utils.geometry import inv

def estimate_focal(pts3d_i, pp=None):
    if pp is None:
        # Principal point defaults to the image center
        H, W, THREE = pts3d_i.shape
        assert THREE == 3
        pp = torch.tensor((W/2, H/2), device=pts3d_i.device)
    focal = estimate_focal_knowing_depth(pts3d_i.unsqueeze(0), pp.unsqueeze(0), focal_mode='weiszfeld').ravel()
    return float(focal)

def pixel_grid(H, W):
    # (H, W, 2) grid of per-pixel (x, y) coordinates, matching the pts3d layout
    return np.mgrid[:W, :H].T.astype(np.float32)

def sRT_to_4x4(scale, R, T, device):
    # Pack a scaled rotation and a translation into a 4x4 homogeneous transform
    trf = torch.eye(4, device=device)
    trf[:3, :3] = R * scale
    trf[:3, 3] = T.ravel()  # doesn't need scaling
    return trf

def to_numpy(tensor):
    return tensor.cpu().numpy() if isinstance(tensor, torch.Tensor) else tensor

def calculate_depth_map(pts3d, rvec, T):
    """
    Calculate ray depths (Euclidean distance from the camera center) for
    each 3D point.

    Args:
    pts3d (np.ndarray): 3D points in world coordinates, shape (H, W, 3)
    rvec (np.ndarray): Rodrigues rotation vector (world to cam), shape (3, 1)
    T (np.ndarray): Translation vector (world to cam), shape (3, 1)

    Returns:
    np.ndarray: Depth map of shape (H, W)
    """
    # For a world-to-camera transform x_cam = R @ x_world + T, the camera
    # center in world coordinates is C = -R^T @ T (not simply -T)
    R = cv2.Rodrigues(rvec)[0]
    C = -R.T @ T.ravel()

    # Rays from the camera center to every 3D point
    ray_vectors = pts3d - C

    # Ray depth = Euclidean length of each ray
    depth_map = np.linalg.norm(ray_vectors, axis=2)

    return depth_map

def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
    # extract camera poses and focals with RANSAC-PnP
    if msk.sum() < 4:
        return None  # we need at least 4 points for PnP
    pts3d, msk = map(to_numpy, (pts3d, msk))

    H, W, THREE = pts3d.shape
    assert THREE == 3
    pixels = pixel_grid(H, W)

    if focal is None:
        # No focal given: sweep 21 candidates on a log scale around the image size
        S = max(W, H)
        tentative_focals = np.geomspace(S/2, S*3, 21)
    else:
        tentative_focals = [focal]

    if pp is None:
        pp = (W/2, H/2)
    else:
        pp = to_numpy(pp)

    best = 0, None, None, None, None  # (num_inliers, rvec, T, focal, depth_map)
    for focal in tentative_focals:
        K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])

        success, R, T, inliers = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None,
                                                    iterationsCount=niter_PnP, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)

        if not success:
            continue

        score = len(inliers)
        if score > best[0]:
            depth_map = calculate_depth_map(pts3d, R, T)
            best = score, R, T, focal, depth_map

    if not best[0]:
        return None

    _, R, T, best_focal, depth_map = best
    R = cv2.Rodrigues(R)[0]  # world to cam
    R, T = map(torch.from_numpy, (R, T))
    depth_map = torch.from_numpy(depth_map).to(device)

    cam_to_world = inv(sRT_to_4x4(1, R, T, device))  # cam to world

    return best_focal, cam_to_world, depth_map

def solve_cemara(pts3d, msk, device, focal=None, pp=None):
    """Recover camera parameters for a single view.

    Returns (o3d.camera.PinholeCameraParameters, focal, depth_map),
    or (None, focal, None) when PnP fails.
    """
    # Estimate the focal length from the pointmap if not provided
    if focal is None:
        focal = estimate_focal(pts3d, pp)
    
    # Compute camera pose using PnP
    result = fast_pnp(pts3d, focal, msk, device, pp)
    
    if result is None:
        return None, focal, None
    
    best_focal, camera_to_world, depth_map = result
    
    # Build Open3D pinhole intrinsics from the recovered focal and principal point
    H, W, _ = pts3d.shape
    if pp is None:
        pp = (W/2, H/2)
    
    camera_parameters = o3d.camera.PinholeCameraParameters()
    intrinsic = o3d.camera.PinholeCameraIntrinsic()
    intrinsic.set_intrinsics(W, H, 
                             best_focal, best_focal, 
                             pp[0], pp[1])

    camera_parameters.intrinsic = intrinsic
    camera_parameters.extrinsic = torch.inverse(camera_to_world).cpu().numpy()

    return camera_parameters, best_focal, depth_map
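
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative addition, not part of the original
# module): synthesizes a fronto-parallel pointmap from a hypothetical
# ground-truth focal and runs solve_cemara on it. Assumes dust3r and open3d
# are importable; shapes and values are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    H, W = 48, 64
    gt_focal = 60.0  # hypothetical focal used to synthesize the pointmap
    xs, ys = np.meshgrid(np.arange(W, dtype=np.float32),
                         np.arange(H, dtype=np.float32))
    z = np.full((H, W), 2.0, dtype=np.float32)  # flat plane 2 units away
    pts = np.stack([(xs - W / 2) * z / gt_focal,
                    (ys - H / 2) * z / gt_focal,
                    z], axis=-1)
    pts3d = torch.from_numpy(pts)
    msk = torch.ones(H, W, dtype=torch.bool)  # all pixels valid

    camera, focal, depth_map = solve_cemara(pts3d, msk, device="cpu")
    print(f"estimated focal: {focal:.2f} (synthesized with {gt_focal})")
    if camera is not None:
        print("extrinsic (world to cam):\n", camera.extrinsic)
        print("ray-depth range:", float(depth_map.min()), float(depth_map.max()))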