"""Visualize a PETR-style 3D position embedding for navsim's stitched camera view.

For every cell of the image feature map, a ray of depth bins is un-projected
through the camera intrinsics into the lidar frame, normalized by the detection
range, and scattered with matplotlib as a sanity check.
"""
import pickle

import matplotlib.pyplot as plt
import torch

# Detection range in the lidar frame: [x_min, y_min, z_min, x_max, y_max, z_max], in meters.
position_range = [-32.0, -32.0, -10.0, 32.0, 32.0, 10.0]
# Number of depth bins per feature cell and the depth of the first bin, in meters.
depth_num = 64
depth_start = 1


def position_embedding(features, img_features):
    """Build a PETR-style 3D position embedding for the stitched camera features.

    `features` holds per-camera calibration (sensor2lidar rotation/translation
    and intrinsics); `img_features` is the (B, C, H, W) feature map of the
    stitched left/front/right image. Returns a (B * N, depth_num * 3, H, W)
    tensor of normalized 3D coordinates.
    """
    eps = 1e-5
    img_features = img_features.unsqueeze(1)  # add a camera axis: (B, N=1, C, H, W)
    B, N, C, tar_H, tar_W = img_features.shape
    device = img_features.device

    # Crop applied to the raw 1920x1080 images before stitching.
    crop_top = 28
    crop_left = 416

    # Feature-grid rows/columns assigned to each camera; the 64 columns are split
    # proportionally to the stitched widths (1088 + 1920 + 1088 = 4096 pixels).
    H = [16, 16, 16]
    W = [
        64 * 1088 // (1088 * 2 + 1920),  # left:  17 columns
        64 * 1920 // (1088 * 2 + 1920),  # front: 30 columns
        64 * 1088 // (1088 * 2 + 1920),  # right: 17 columns
    ]
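
    # Map each feature cell back to pixel coordinates in its ORIGINAL image.
    # Assumption (TransFuser-style stitching in navsim): 28 px are cropped from
    # the top and bottom of every camera (1080 -> 1024 rows) and 416 px from the
    # left and right of the side cameras (1920 -> 1088 columns), so a feature
    # index is scaled by the cropped extent and shifted by the crop offset.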
    coords_h_l = torch.arange(H[0], device=device).float() * (1080 - 2 * crop_top) / H[0] + crop_top
    coords_w_l = torch.arange(W[0], device=device).float() * (1920 - 2 * crop_left) / W[0] + crop_left

    coords_h_f = torch.arange(H[1], device=device).float() * (1080 - 2 * crop_top) / H[1] + crop_top
    coords_w_f = torch.arange(W[1], device=device).float() * 1920 / W[1]  # front camera keeps full width

    coords_h_r = torch.arange(H[2], device=device).float() * (1080 - 2 * crop_top) / H[2] + crop_top
    coords_w_r = torch.arange(W[2], device=device).float() * (1920 - 2 * crop_left) / W[2] + crop_left

    # Linear-increasing depth discretization (LID, as in PETR): bin k sits at
    # depth_start + bin_size * k * (k + 1), so bins widen with distance.
    index = torch.arange(start=0, end=depth_num, step=1, device=device).float()
    index_1 = index + 1
    bin_size = (position_range[3] - depth_start) / (depth_num * (1 + depth_num))
    coords_d = depth_start + bin_size * index * index_1
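    # Example: with depth_num = 64 and depth_start = 1, bin_size = 31 / (64 * 65)
    # ~= 0.00745 m, so coords_d runs from 1.0 m (k = 0) up to
    # 1 + 0.00745 * 63 * 64 ~= 31.05 m (k = 63), just inside position_range[3].
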
    D = coords_d.shape[0]

    # One (W, H, D, 3) frustum grid of (u, v, d) per camera.
    coords = [None] * 3
    coords[0] = torch.stack(torch.meshgrid([coords_w_l, coords_h_l, coords_d], indexing="ij")).permute(1, 2, 3, 0)
    coords[1] = torch.stack(torch.meshgrid([coords_w_f, coords_h_f, coords_d], indexing="ij")).permute(1, 2, 3, 0)
    coords[2] = torch.stack(torch.meshgrid([coords_w_r, coords_h_r, coords_d], indexing="ij")).permute(1, 2, 3, 0)

    # Scale the pixel coordinates by depth so each point becomes the homogeneous
    # (u * d, v * d, d); eps guards against multiplying by a zero depth.
    for c in coords:
        c[..., :2] = c[..., :2] * torch.max(c[..., 2:3], torch.ones_like(c[..., 2:3]) * eps)
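
    # For each camera, lift its frustum into the lidar frame:
    #   X_lidar = R_sensor2lidar @ K^-1 @ (u * d, v * d, d)^T + t_sensor2lidar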
    pos_3d_embed = None
    for i in range(3):
        sensor2lidar_rotation = features["sensor2lidar_rotation"][i]
        sensor2lidar_translation = features["sensor2lidar_translation"][i]
        intrinsics = features["intrinsics"][i]
        # Fold un-projection (K^-1) and camera-to-lidar rotation into one 3x3 matrix.
        combine = torch.matmul(sensor2lidar_rotation, torch.inverse(intrinsics)).float()

        coords3d = coords[i].view(1, N, W[i], H[i], D, 3, 1).repeat(B, 1, 1, 1, 1, 1, 1)
        combine = combine.view(B, N, 1, 1, 1, 3, 3).repeat(1, 1, W[i], H[i], D, 1, 1)
        coords3d = torch.matmul(combine, coords3d).squeeze(-1)
        # .float() keeps the (possibly float64) translation compatible with coords3d.
        coords3d += sensor2lidar_translation.view(B, N, 1, 1, 1, 3).float()

        # Normalize each axis into [0, 1] over the detection range.
        coords3d[..., 0:1] = (coords3d[..., 0:1] - position_range[0]) / (position_range[3] - position_range[0])
        coords3d[..., 1:2] = (coords3d[..., 1:2] - position_range[1]) / (position_range[4] - position_range[1])
        coords3d[..., 2:3] = (coords3d[..., 2:3] - position_range[2]) / (position_range[5] - position_range[2])
        # Cameras are concatenated along the width axis, mirroring the stitched image.
        if pos_3d_embed is None:
            pos_3d_embed = coords3d
        else:
            pos_3d_embed = torch.cat((pos_3d_embed, coords3d), dim=2)

    # (B, N, W, H, D, 3) -> (B * N, D * 3, H, W): depth bins and xyz become channels.
    pos_3d_embed = pos_3d_embed.permute(0, 1, 4, 5, 3, 2).contiguous().view(B * N, -1, tar_H, tar_W)
    return pos_3d_embed

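
# Usage note (assumption, not in the original code): a PETR-style head would
# typically project the depth_num * 3 coordinate channels returned above through
# a small position encoder (e.g. two 1x1 convolutions with a ReLU in between) to
# the transformer embedding dimension before adding them to the image tokens.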

if __name__ == '__main__':
    H, W = 16, 64
    logs = pickle.load(open('/mnt/g/navsim/navsim_logs/tiny/2021.05.12.22.28.35_veh-35_00620_01164.pkl', 'rb'))
    log = logs[0]

    # Calibration of the left, front, and right cameras for the first frame.
    cams = ['CAM_L0', 'CAM_F0', 'CAM_R0']
    features = {
        'sensor2lidar_rotation': [torch.from_numpy(log['cams'][c]['sensor2lidar_rotation']) for c in cams],
        'sensor2lidar_translation': [torch.from_numpy(log['cams'][c]['sensor2lidar_translation']) for c in cams],
        'intrinsics': [torch.from_numpy(log['cams'][c]['cam_intrinsic']) for c in cams],
    }
    img_features = torch.randn((1, 3, H, W))  # dummy feature map; only its shape is used
    coords_3d = position_embedding(features, img_features)
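    # Sanity check: with B = N = 1 the embedding flattens to
    # (B * N, depth_num * 3, H, W) = (1, 192, 16, 64).
    assert coords_3d.shape == (1, depth_num * 3, H, W)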

    # Scatter each feature cell's depth ray to visualize the frustums in the
    # (normalized) lidar frame.
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # (1, D * 3, H, W) -> (H, W, D, 3); hoisted out of the loop below.
    frustum_points = coords_3d.permute(2, 3, 1, 0).reshape(H, W, depth_num, 3)
    for i in range(H):
        for j in range(W):
            pixel_points = frustum_points[i, j]
            ax.scatter(pixel_points[:, 0], pixel_points[:, 1], pixel_points[:, 2])

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.view_init(elev=90, azim=0)  # top-down view
    plt.show()