# navsim_ours/det_map/data/pipelines/prepare_depth.py
import torch
import numpy as np
import PIL.Image as Image
class LiDAR2Depth(object):
    """Pipeline step that projects the latest warped LiDAR sweep into every
    camera and rasterizes it into sparse per-camera depth maps, stored in
    ``features['gt_depth']`` for depth supervision.

    Args:
        grid_config (dict): maps the keys ``'x'``, ``'y'``, ``'z'`` and
            ``'depth'`` to string-encoded lists (e.g. ``"[1.0, 60.0, 0.5]"``).
            Only ``depth[0]``/``depth[1]`` (min/max valid depth) are read here.
    """

    def __init__(self,
                 grid_config,
                 ):
        # NOTE(security): eval() executes arbitrary code from the config.
        # Only safe with trusted config files; ast.literal_eval would be a
        # safer drop-in if the values are always plain literals — confirm.
        self.x = eval(grid_config['x'])
        self.y = eval(grid_config['y'])
        self.z = eval(grid_config['z'])
        self.depth = eval(grid_config['depth'])

    def points2depthmap(self, points, height, width):
        """Scatter projected points into a dense (height, width) depth map.

        Args:
            points: (N, 3) tensor of (u, v, depth) in pixel coordinates.
            height, width: output map size in pixels.

        Returns:
            float32 tensor of shape (height, width); 0 where no point landed,
            otherwise the smallest depth that hit that pixel.
        """
        # Allocate on the same device as the input to avoid a CPU/GPU
        # mismatch when indexing with `coor` below.
        depth_map = torch.zeros((height, width), dtype=torch.float32,
                                device=points.device)
        coor = torch.round(points[:, :2])
        depth = points[:, 2]
        # Keep points that land inside the image and within [depth_min, depth_max).
        kept1 = (coor[:, 0] >= 0) & (coor[:, 0] < width) & (
                coor[:, 1] >= 0) & (coor[:, 1] < height) & (
                        depth < self.depth[1]) & (
                        depth >= self.depth[0])
        coor, depth = coor[kept1], depth[kept1]
        # rank = flattened pixel index; adding depth/100 as a tie-breaker
        # makes argsort place the closest point of each pixel first.
        ranks = coor[:, 0] + coor[:, 1] * width
        sort = (ranks + depth / 100.).argsort()
        coor, depth, ranks = coor[sort], depth[sort], ranks[sort]
        # Deduplicate: keep only the first (i.e. nearest) point per pixel.
        kept2 = torch.ones(coor.shape[0], device=coor.device, dtype=torch.bool)
        kept2[1:] = (ranks[1:] != ranks[:-1])
        coor, depth = coor[kept2], depth[kept2]
        coor = coor.to(torch.long)
        depth_map[coor[:, 1], coor[:, 0]] = depth
        return depth_map

    def __call__(self, features, targets):
        """Build ``features['gt_depth']`` (stacked per-camera depth maps)
        from the most recent LiDAR frame; ``targets`` is passed through.
        """
        # lidars_warped: list (length = frames) of per-frame point clouds.
        lidar_all_frames = features['lidars_warped']
        # image: (T, N_CAMS, C, H, W)
        T, N, _, H, W = features['image'].shape
        rots, trans, intrinsics = (features['sensor2lidar_rotation'],
                                   features['sensor2lidar_translation'],
                                   features['intrinsics'])
        post_rot, post_tran, bda = (features['post_rot'],
                                    features['post_tran'], features['bda'])
        t = -1  # only the latest frame is rasterized for supervision
        depth_t = []
        lidar_t = lidar_all_frames[t][:, :3]
        # Undo the BEV data augmentation (bda) so the points line up with
        # the un-augmented camera geometry again.
        lidar_t = lidar_t - bda[:3, 3].view(1, 3)
        lidar_t = lidar_t.matmul(torch.inverse(bda[:3, :3]).T)
        for n in range(N):
            # lidar -> camera frame: remove cam->lidar translation, then
            # apply the inverse cam->lidar rotation; then cam -> image via
            # the pinhole intrinsics.
            points_img = lidar_t - trans[t, n:n + 1, :]
            lidar2cam_rot = torch.inverse(rots[t, n])
            points_img = points_img.matmul(lidar2cam_rot.T).matmul(intrinsics[t, n].T)
            # Perspective divide; keep the metric depth in the third column.
            points_img = torch.cat(
                [points_img[:, :2] / points_img[:, 2:3], points_img[:, 2:3]],
                1)
            # Apply the image-space augmentation (resize/crop/flip) transform.
            points_img = points_img.matmul(
                post_rot[t, n].T) + post_tran[t, n:n + 1, :]
            # canvas holds the augmented image, so its shape gives the
            # target depth-map resolution.
            canvas = features['canvas'][-1, n]
            depth_curr = self.points2depthmap(points_img,
                                              canvas.shape[0],
                                              canvas.shape[1])
            depth_t.append(depth_curr)
        features['gt_depth'] = torch.stack(depth_t)
        return features, targets