import pickle

import matplotlib.pyplot as plt
import torch

position_range = [-32.0, -32.0, -10.0, 32.0, 32.0, 10.0]
depth_num = 64
depth_start = 1


def position_embedding(features, img_features):
    eps = 1e-5
    img_features = img_features.unsqueeze(1)
    B, N, C, tar_H, tar_W = img_features.shape
    device = img_features.device

    # The stitched input comes from three 1920x1080 cameras: every view is
    # cropped 28 px at the top and bottom, and the side views additionally lose
    # 416 px on each side, so each feature-map cell covers a 64x64 px patch.
    crop_top = 28
    crop_left = 416
    H = [16 for _ in range(3)]
    W = [
        64 * 1088 // (1088 * 2 + 1920),  # left view  -> 17 columns
        64 * 1920 // (1088 * 2 + 1920),  # front view -> 30 columns
        64 * 1088 // (1088 * 2 + 1920),  # right view -> 17 columns
    ]

    # Map feature-map indices back to pixel coordinates in the original image:
    # scale by the cropped extent, then add the crop offset. (Previously the
    # offset was divided by the feature size, which shifted the grid.)
    # Left view (16, 17)
    coords_h_l = torch.arange(H[0], device=device).float() * (1080 - 2 * crop_top) / H[0] + crop_top
    coords_w_l = torch.arange(W[0], device=device).float() * (1920 - 2 * crop_left) / W[0] + crop_left
    # Front view (16, 30)
    coords_h_f = torch.arange(H[1], device=device).float() * (1080 - 2 * crop_top) / H[1] + crop_top
    coords_w_f = torch.arange(W[1], device=device).float() * 1920 / W[1]
    # Right view (16, 17)
    coords_h_r = torch.arange(H[2], device=device).float() * (1080 - 2 * crop_top) / H[2] + crop_top
    coords_w_r = torch.arange(W[2], device=device).float() * (1920 - 2 * crop_left) / W[2] + crop_left

    # Linear-increasing discretization (LID) of depth, as in PETR: bin widths
    # grow linearly, so near depths are sampled more densely than far ones.
    index = torch.arange(start=0, end=depth_num, step=1, device=device).float()
    index_1 = index + 1
    bin_size = (position_range[3] - depth_start) / (depth_num * (1 + depth_num))
    coords_d = depth_start + bin_size * index * index_1
    D = coords_d.shape[0]

    coords = [None] * 3  # 0, 1, 2 -> left, front, right
    coords[0] = torch.stack(torch.meshgrid([coords_w_l, coords_h_l, coords_d], indexing='ij')).permute(1, 2, 3, 0)  # (W, H, D, 3)
    coords[1] = torch.stack(torch.meshgrid([coords_w_f, coords_h_f, coords_d], indexing='ij')).permute(1, 2, 3, 0)  # (W, H, D, 3)
    coords[2] = torch.stack(torch.meshgrid([coords_w_r, coords_h_r, coords_d], indexing='ij')).permute(1, 2, 3, 0)  # (W, H, D, 3)

    # Store each frustum point as (u*d, v*d, d), clamping the depth away from
    # zero for numerical stability.
    for c in coords:
        c[..., :2] = c[..., :2] * torch.max(c[..., 2:3], torch.ones_like(c[..., 2:3]) * eps)

    pos_3d_embed = None
    for i in range(3):
        sensor2lidar_rotation = features["sensor2lidar_rotation"][i]
        sensor2lidar_translation = features["sensor2lidar_translation"][i]
        intrinsics = features["intrinsics"][i]
        combine = torch.matmul(sensor2lidar_rotation, torch.inverse(intrinsics)).float()  # (3, 3)
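        # Note on `combine`: each frustum point is stored as (u*d, v*d, d), so
        # K^-1 @ (u*d, v*d, d)^T recovers the point in the camera frame and
        # R @ (.) + t moves it into the lidar frame. combine = R @ K^-1 folds
        # both linear maps into one 3x3 matmul; t is added separately below.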
        # (W, H, D, 3) -> (B, N, W, H, D, 3, 1); the view assumes N == 1, i.e.
        # one camera per loop iteration.
        coords3d = coords[i].view(1, N, W[i], H[i], D, 3, 1).repeat(B, 1, 1, 1, 1, 1, 1)
        # (3, 3) -> (B, N, W, H, D, 3, 3)
        combine = combine.view(B, N, 1, 1, 1, 3, 3).repeat(1, 1, W[i], H[i], D, 1, 1)
        coords3d = torch.matmul(combine, coords3d).squeeze(-1)  # (B, N, W, H, D, 3)
        # Cast the translation to float32 to match coords3d before the in-place add.
        sensor2lidar_translation = sensor2lidar_translation.view(B, N, 1, 1, 1, 3).float()
        coords3d += sensor2lidar_translation

        # Normalize each axis into [0, 1] over the configured position range.
        # (PETR additionally builds an out-of-range mask here; omitted in this script.)
        coords3d[..., 0:1] = (coords3d[..., 0:1] - position_range[0]) / (position_range[3] - position_range[0])
        coords3d[..., 1:2] = (coords3d[..., 1:2] - position_range[1]) / (position_range[4] - position_range[1])
        coords3d[..., 2:3] = (coords3d[..., 2:3] - position_range[2]) / (position_range[5] - position_range[2])

        # Concatenate the three views along the width axis:
        # (B, N, 17, 16, 64, 3) ++ (B, N, 30, 16, 64, 3) ++ (B, N, 17, 16, 64, 3)
        # -> (B, N, 64, 16, 64, 3)
        if pos_3d_embed is None:
            pos_3d_embed = coords3d
        else:
            pos_3d_embed = torch.cat((pos_3d_embed, coords3d), dim=2)

    # (B, N, W, H, D, 3) -> (B*N, D*3, H, W)
    pos_3d_embed = pos_3d_embed.permute(0, 1, 4, 5, 3, 2).contiguous().view(B * N, -1, tar_H, tar_W)
    return pos_3d_embed


if __name__ == '__main__':
    H, W = 16, 64
    with open('/mnt/g/navsim/navsim_logs/tiny/2021.05.12.22.28.35_veh-35_00620_01164.pkl', 'rb') as f:
        logs = pickle.load(f)
    log = logs[0]
    features = {
        'sensor2lidar_rotation': [
            torch.from_numpy(log['cams']['CAM_L0']['sensor2lidar_rotation']),
            torch.from_numpy(log['cams']['CAM_F0']['sensor2lidar_rotation']),
            torch.from_numpy(log['cams']['CAM_R0']['sensor2lidar_rotation']),
        ],
        'sensor2lidar_translation': [
            torch.from_numpy(log['cams']['CAM_L0']['sensor2lidar_translation']),
            torch.from_numpy(log['cams']['CAM_F0']['sensor2lidar_translation']),
            torch.from_numpy(log['cams']['CAM_R0']['sensor2lidar_translation']),
        ],
        'intrinsics': [
            torch.from_numpy(log['cams']['CAM_L0']['cam_intrinsic']),
            torch.from_numpy(log['cams']['CAM_F0']['cam_intrinsic']),
            torch.from_numpy(log['cams']['CAM_R0']['cam_intrinsic']),
        ],
    }
    img_features = torch.randn((1, 3, H, W))
    coords_3d = position_embedding(features, img_features)

    # Scatter the normalized frustum points, one depth column per feature cell.
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    # (B*N, D*3, H, W) -> (H, W, D, 3); hoisted out of the loop since it does
    # not depend on i or j.
    frustum_points = coords_3d.permute(2, 3, 1, 0).reshape(H, W, depth_num, 3)
    for i in range(H):
        for j in range(W):
            pixel_points = frustum_points[i, j]
            ax.scatter(pixel_points[:, 0], pixel_points[:, 1], pixel_points[:, 2])
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.view_init(elev=90, azim=0)
    plt.show()
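
    # Sanity check (a minimal sketch, assuming B = N = 1 as constructed above):
    # the embedding holds one normalized (x, y, z) per depth bin per feature cell.
    assert coords_3d.shape == (1, depth_num * 3, H, W)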