import pickle
import matplotlib.pyplot as plt
import torch
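# Detection range in the lidar frame, [x_min, y_min, z_min, x_max, y_max, z_max];
# the depth axis is discretized into depth_num bins starting at depth_start.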
position_range = [-32.0, -32.0, -10.0, 32.0, 32.0, 10.0]
depth_num = 64
depth_start = 1
def position_embedding(features, img_features):
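    """Build a PETR-style 3D position embedding for three concatenated camera views.

    Args:
        features: dict with per-camera 'sensor2lidar_rotation' (3, 3),
            'sensor2lidar_translation' (3,) and 'intrinsics' (3, 3) tensors,
            ordered left, front, right.
        img_features: (B, C, H, W) feature map whose width concatenates the
            left, front and right views.

    Returns:
        Tensor of shape (B * N, depth_num * 3, H, W) holding normalized
        lidar-frame coordinates for every pixel and depth bin.
    """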
eps = 1e-5
img_features = img_features.unsqueeze(1)
B, N, C, tar_H, tar_W = img_features.shape
device = img_features.device
crop_top = 28
crop_left = 416
H = [16 for _ in range(3)]
W = [
64 * 1088 // (1088 * 2 + 1920),
64 * 1920 // (1088 * 2 + 1920),
64 * 1088 // (1088 * 2 + 1920)
]
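    # The 16 x 64 feature map concatenates the three views along its width:
    # 64 * 1088 // 4096 = 17 columns for each side view and 64 * 1920 // 4096 = 30
    # for the front view (17 + 30 + 17 = 64); crop_top / crop_left shift the grid
    # back towards full-image pixel coordinates.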
    # Left view feature grid (16, 17)
coords_h_l = torch.arange(H[0], device=device).float() * 1080 / H[0] + crop_top / H[0]
coords_w_l = torch.arange(W[0], device=device).float() * 1920 / W[0] + crop_left / W[0]
    # Front view feature grid (16, 30)
coords_h_f = torch.arange(H[1], device=device).float() * 1080 / H[1] + crop_top / H[1]
coords_w_f = torch.arange(W[1], device=device).float() * 1920 / W[1]
    # Right view feature grid (16, 17)
coords_h_r = torch.arange(H[2], device=device).float() * 1080 / H[2] + crop_top / H[2]
coords_w_r = torch.arange(W[2], device=device).float() * 1920 / W[2] + crop_left / W[2]
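    # LID (linear-increasing discretization) depth bins, as in PETR:
    # coords_d[k] = depth_start + bin_size * k * (k + 1), so bin widths grow
    # linearly with k and the bins span roughly [depth_start, position_range[3]].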
    index = torch.arange(depth_num, device=device).float()
index_1 = index + 1
bin_size = (position_range[3] - depth_start) / (depth_num * (1 + depth_num))
coords_d = depth_start + bin_size * index * index_1
D = coords_d.shape[0]
    coords = [None] * 3  # 0, 1, 2 -> left, front, right
    coords[0] = torch.stack(torch.meshgrid([coords_w_l, coords_h_l, coords_d], indexing='ij')).permute(1, 2, 3, 0)  # W, H, D, 3
    coords[1] = torch.stack(torch.meshgrid([coords_w_f, coords_h_f, coords_d], indexing='ij')).permute(1, 2, 3, 0)  # W, H, D, 3
    coords[2] = torch.stack(torch.meshgrid([coords_w_r, coords_h_r, coords_d], indexing='ij')).permute(1, 2, 3, 0)  # W, H, D, 3
# coords = torch.cat((coords, torch.ones_like(coords[..., :1])), -1)
coords[0][..., :2] = coords[0][..., :2] * torch.max(coords[0][..., 2:3], torch.ones_like(coords[0][..., 2:3]) * eps)
coords[1][..., :2] = coords[1][..., :2] * torch.max(coords[1][..., 2:3], torch.ones_like(coords[1][..., 2:3]) * eps)
coords[2][..., :2] = coords[2][..., :2] * torch.max(coords[2][..., 2:3], torch.ones_like(coords[2][..., 2:3]) * eps)
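    # Scale (u, v) by the (eps-clamped) depth so every entry becomes a homogeneous
    # pixel coordinate (u*d, v*d, d) ready to be unprojected with the inverse intrinsics.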
    # Per-camera unprojection: compose the inverse intrinsics with the
    # sensor->lidar extrinsics instead of using a precomputed img2lidar matrix.
pos_3d_embed = None
for i in range(3):
sensor2lidar_rotation = features["sensor2lidar_rotation"][i]
sensor2lidar_translation = features["sensor2lidar_translation"][i]
intrinsics = features["intrinsics"][i]
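        # combine = R_cam2lidar @ K^-1 maps homogeneous pixel coordinates to
        # camera rays expressed in the lidar frame; the camera-to-lidar
        # translation is added below to obtain 3D points.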
        combine = torch.matmul(sensor2lidar_rotation, torch.inverse(intrinsics)).float()  # (3, 3)
        # coords[i]: (W, H, D, 3) -> column vectors of shape (B, N, W, H, D, 3, 1)
        coords3d = coords[i].view(1, N, W[i], H[i], D, 3, 1).repeat(B, 1, 1, 1, 1, 1, 1)
        combine = combine.view(B, N, 1, 1, 1, 3, 3).repeat(1, 1, W[i], H[i], D, 1, 1)
        coords3d = torch.matmul(combine, coords3d).squeeze(-1)  # (B, N, W, H, D, 3)
        sensor2lidar_translation = sensor2lidar_translation.float().view(B, N, 1, 1, 1, 3)
        coords3d += sensor2lidar_translation
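        # Normalize each axis into [0, 1] with respect to position_range.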
        coords3d[..., 0:1] = (coords3d[..., 0:1] - position_range[0]) / (position_range[3] - position_range[0])
        coords3d[..., 1:2] = (coords3d[..., 1:2] - position_range[1]) / (position_range[4] - position_range[1])
        coords3d[..., 2:3] = (coords3d[..., 2:3] - position_range[2]) / (position_range[5] - position_range[2])
# coords_mask = (coords3d > 1.0) | (coords3d < 0.0)
# coords_mask = coords_mask.flatten(-2).sum(-1) > (D * 0.5)
# coords_mask = coords_mask.permute(0, 1, 3, 2)
        # Per-view coords3d is (B, N, W_i, H, D, 3), e.g. (1, 1, 17, 16, 64, 3) for the
        # side views and (1, 1, 30, 16, 64, 3) for the front view; the views are
        # concatenated along the W dimension below.
if pos_3d_embed is None:
pos_3d_embed = coords3d
else:
pos_3d_embed = torch.cat((pos_3d_embed, coords3d), dim=2)
    # (B, N, W_total, H, D, 3) -> (B * N, depth_num * 3, H, W_total); channels are ordered (depth, xyz)
    pos_3d_embed = pos_3d_embed.permute(0, 1, 4, 5, 3, 2).contiguous().view(B * N, -1, tar_H, tar_W)
return pos_3d_embed
if __name__ == '__main__':
H, W = 16, 64
logs = pickle.load(open('/mnt/g/navsim/navsim_logs/tiny/2021.05.12.22.28.35_veh-35_00620_01164.pkl', 'rb'))
log = logs[0]
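    # Per-camera calibration from the navsim log: sensor->lidar rotation/translation
    # and intrinsics for CAM_L0 / CAM_F0 / CAM_R0, in the left/front/right order
    # that position_embedding expects.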
features = {
'sensor2lidar_rotation': [
torch.from_numpy(log['cams']['CAM_L0']['sensor2lidar_rotation']),
torch.from_numpy(log['cams']['CAM_F0']['sensor2lidar_rotation']),
torch.from_numpy(log['cams']['CAM_R0']['sensor2lidar_rotation']),
],
'sensor2lidar_translation': [
torch.from_numpy(log['cams']['CAM_L0']['sensor2lidar_translation']),
torch.from_numpy(log['cams']['CAM_F0']['sensor2lidar_translation']),
torch.from_numpy(log['cams']['CAM_R0']['sensor2lidar_translation']),
],
'intrinsics': [
torch.from_numpy(log['cams']['CAM_L0']['cam_intrinsic']),
torch.from_numpy(log['cams']['CAM_F0']['cam_intrinsic']),
torch.from_numpy(log['cams']['CAM_R0']['cam_intrinsic']),
]
}
img_features = torch.randn((1, 3, H, W))
coords_3d = position_embedding(
features, img_features
)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
    # (1, depth_num * 3, H, W) -> (H, W, depth_num, 3): normalized frustum points per pixel
    frustum_points = coords_3d.permute(2, 3, 1, 0).reshape(H, W, depth_num, 3)
    for i in range(H):
        for j in range(W):
            pixel_points = frustum_points[i, j]
            x_points = pixel_points[:, 0]
            y_points = pixel_points[:, 1]
            z_points = pixel_points[:, 2]
            ax.scatter(x_points, y_points, z_points)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.view_init(elev=90, azim=0)
ax.set_zlabel('Z')
plt.show()
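
    # How the embedding is typically consumed downstream (a minimal sketch, not part
    # of the original script): PETR-style models pass the (depth_num * 3)-channel
    # coordinate volume through a small 1x1-conv MLP and add the result to the image
    # features. The layer widths below are illustrative assumptions.
    import torch.nn as nn

    embed_dim = 3  # matches the channel count of the dummy img_features above
    position_encoder = nn.Sequential(
        nn.Conv2d(depth_num * 3, 4 * embed_dim, kernel_size=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(4 * embed_dim, embed_dim, kernel_size=1),
    )
    pos_embed = position_encoder(coords_3d)  # (1, embed_dim, H, W)
    fused = img_features + pos_embed         # position-aware image features
    print(fused.shape)                       # torch.Size([1, 3, 16, 64])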