import torch


class LiDAR2Depth(object):
    """Projects the most recent LiDAR sweep into each camera view to build
    sparse ground-truth depth maps, stored under ``features['gt_depth']``."""

    def __init__(self, grid_config):
        # grid_config values are string expressions (e.g. "[1.0, 60.0, 0.5]")
        # that evaluate to [min, max, step] lists.
        self.x = eval(grid_config['x'])
        self.y = eval(grid_config['y'])
        self.z = eval(grid_config['z'])
        self.depth = eval(grid_config['depth'])

    def points2depthmap(self, points, height, width):
        """Rasterize projected points (u, v, depth) into a (height, width)
        depth map, keeping the nearest point per pixel."""
        depth_map = torch.zeros((height, width),
                                dtype=torch.float32,
                                device=points.device)
        coor = torch.round(points[:, :2])
        depth = points[:, 2]
        # Keep points that land inside the image and within the depth range.
        kept1 = (coor[:, 0] >= 0) & (coor[:, 0] < width) & \
                (coor[:, 1] >= 0) & (coor[:, 1] < height) & \
                (depth >= self.depth[0]) & (depth < self.depth[1])
        coor, depth = coor[kept1], depth[kept1]
        # Sort by flattened pixel index; depth / 100 acts as a tie-breaker so
        # the closest point of each pixel comes first.
        ranks = coor[:, 0] + coor[:, 1] * width
        sort = (ranks + depth / 100.).argsort()
        coor, depth, ranks = coor[sort], depth[sort], ranks[sort]
        # Deduplicate: keep only the first (nearest) point per pixel.
        kept2 = torch.ones(coor.shape[0], device=coor.device, dtype=torch.bool)
        kept2[1:] = ranks[1:] != ranks[:-1]
        coor, depth = coor[kept2], depth[kept2]
        coor = coor.to(torch.long)
        depth_map[coor[:, 1], coor[:, 0]] = depth
        return depth_map

    def __call__(self, features, targets):
        # Expected keys: lidars_warped (list of length T), image (T, N, C, H, W),
        # sensor2lidar_rotation / sensor2lidar_translation / intrinsics,
        # post_rot / post_tran (image-space augmentation), bda (4x4 BEV
        # augmentation matrix), canvas (augmented images, one H x W x C per view).
        lidar_all_frames = features['lidars_warped']
        T, N, _, H, W = features['image'].shape
        rots, trans, intrinsics = (features['sensor2lidar_rotation'],
                                   features['sensor2lidar_translation'],
                                   features['intrinsics'])
        post_rot, post_tran, bda = (features['post_rot'],
                                    features['post_tran'],
                                    features['bda'])
        t = -1  # supervise depth on the most recent frame only
        depth_t = []
        # Undo the BEV data augmentation (bda) so the points match the raw
        # sensor calibration: subtract the translation, then invert the rotation.
        lidar_t = lidar_all_frames[t][:, :3]
        lidar_t = lidar_t - bda[:3, 3].view(1, 3)
        lidar_t = lidar_t.matmul(torch.inverse(bda[:3, :3]).T)
        for n in range(N):
            # LiDAR -> camera: translate, then rotate by the inverse of the
            # camera-to-LiDAR rotation.
            points_img = lidar_t - trans[t, n:n + 1, :]
            lidar2cam_rot = torch.inverse(rots[t, n])
            # Camera -> image plane via the intrinsics.
            points_img = points_img.matmul(lidar2cam_rot.T).matmul(intrinsics[t, n].T)
            # Perspective divide; keep the metric depth as the third column.
            points_img = torch.cat(
                [points_img[:, :2] / points_img[:, 2:3], points_img[:, 2:3]], 1)
            # Apply the image-space augmentation (resize / crop / flip).
            points_img = points_img.matmul(post_rot[t, n].T) + post_tran[t, n:n + 1, :]
            depth_curr = self.points2depthmap(points_img,
                                              features['canvas'][-1, n].shape[0],
                                              features['canvas'][-1, n].shape[1])
            depth_t.append(depth_curr)
        features['gt_depth'] = torch.stack(depth_t)
        return features, targets
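

# --------------------------------------------------------------------------
# Usage sketch (illustrative only). The grid_config expressions, tensor
# shapes, and calibration values below are synthetic assumptions chosen to
# satisfy the keys this transform reads; they are not taken from any real
# dataset pipeline.
if __name__ == '__main__':
    torch.manual_seed(0)
    T, N, H, W = 1, 2, 256, 448  # frames, cameras, image height/width

    grid_config = {
        'x': '[-32.0, 32.0, 0.5]',
        'y': '[-32.0, 32.0, 0.5]',
        'z': '[-3.0, 2.0, 5.0]',
        'depth': '[1.0, 60.0, 0.5]',  # points kept if 1.0 <= depth < 60.0
    }

    # Synthetic LiDAR points in front of the (identity-calibrated) cameras.
    pts = torch.randn(1000, 3)
    pts[:, 2] = pts[:, 2].abs() * 10 + 2.0  # positive depth along the optical axis
    lidar = torch.cat([pts, torch.ones(1000, 1)], dim=1)  # extra column, e.g. intensity

    intrinsics = torch.tensor([[200., 0., W / 2.],
                               [0., 200., H / 2.],
                               [0., 0., 1.]]).expand(T, N, 3, 3)

    features = {
        'lidars_warped': [lidar],
        'image': torch.zeros(T, N, 3, H, W),
        'canvas': torch.zeros(T, N, H, W, 3),
        'sensor2lidar_rotation': torch.eye(3).expand(T, N, 3, 3),
        'sensor2lidar_translation': torch.zeros(T, N, 3),
        'intrinsics': intrinsics,
        'post_rot': torch.eye(3).expand(T, N, 3, 3),  # no image-space augmentation
        'post_tran': torch.zeros(T, N, 3),
        'bda': torch.eye(4),  # no BEV augmentation
    }

    features, _ = LiDAR2Depth(grid_config)(features, targets={})
    print(features['gt_depth'].shape)        # torch.Size([2, 256, 448])
    print((features['gt_depth'] > 0).sum())  # number of pixels with a depth hit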