# navsim_ours/det_map/data/datasets/feature_builders.py
# Hugging Face file-viewer header: uploaded by lkllkl ("Upload folder using
# huggingface_hub"), commit da2e2ac (verified), file size 4.42 kB.
from __future__ import annotations
from typing import Dict
import numpy as np
import torch
from det_map.data.datasets.dataclasses import AgentInput, Camera
from det_map.data.datasets.lidar_utils import transform_points, render_image
from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder
from mmcv.parallel import DataContainer as DC
class LiDARCameraFeatureBuilder(AbstractFeatureBuilder):
    """Build camera and LiDAR input features from an ``AgentInput``.

    Every selected camera of every frame is passed through the ``'img'``
    pipeline and the results are stacked per frame and then across frames.
    LiDAR sweeps are warped into the ego frame of the last frame, which is
    treated as the key frame.
    """

    def __init__(self, pipelines):
        """Store the preprocessing pipelines.

        Args:
            pipelines: Indexable container of pipeline callables.
                ``compute_features`` looks up the ``'img'`` entry and calls
                it once per camera.
        """
        super().__init__()
        self.pipelines = pipelines

    def compute_features(self, agent_input: AgentInput) -> Dict[str, torch.Tensor]:
        """Compute stacked camera features and key-frame-aligned LiDAR sweeps.

        Args:
            agent_input: Source of ``timestamps``, ``lidars``, ``ego2globals``
                and ``cameras``; the last entry of each sequence is the key
                frame.

        Returns:
            A dict with the keys ``image``, ``canvas``,
            ``sensor2lidar_rotation``, ``sensor2lidar_translation``,
            ``intrinsics``, ``distortion``, ``post_rot`` and ``post_tran``
            (each stacked over cameras, then over frames, and converted with
            ``.to(image)``), plus ``lidars_warped``, which is a *list* of
            per-frame tensors rather than a single tensor.
        """
        camera_pipeline = self.pipelines['img']

        # Age of every frame relative to the key (last) frame.
        # NOTE(review): the 1e6 divisor presumably converts microseconds to
        # seconds -- confirm against the AgentInput timestamp unit.
        raw_stamps = agent_input.timestamps
        frame_ages = [(raw_stamps[-1] - stamp) / 1e6 for stamp in raw_stamps]

        # Work on copies so the point clouds held by agent_input are never
        # modified by the timestamp write below.
        clouds = [np.copy(sweep.lidar_pc) for sweep in agent_input.lidars]
        poses = list(agent_input.ego2globals)
        key_global2ego = np.linalg.inv(poses[-1])

        # Past sweeps: ego(t) -> global -> ego(key). The key sweep is already
        # expressed in the key frame and is appended unchanged.
        aligned = []
        for cloud, ego2global in zip(clouds[:-1], poses[:-1]):
            in_global = transform_points(cloud, ego2global)
            aligned.append(transform_points(in_global, key_global2ego))
        aligned.append(clouds[-1])

        # Channel layout after this step: x, y, z, intensity, timestamp.
        # Index 4 is overwritten with the frame age, the first five channels
        # are kept and the result is transposed.
        # NOTE(review): assumes channel-first point clouds, i.e. (C, N) -- confirm.
        lidar_tensors = []
        for frame_idx, cloud in enumerate(aligned):
            cloud[4] = frame_ages[frame_idx]
            lidar_tensors.append(torch.from_numpy(cloud[:5]).t())

        # Debug aid: render_image(cloud, name) can be used here to visualise
        # the warped and the original point clouds.

        # Only the front (cam_f0) and back (cam_b0) cameras are used; the side
        # cameras cam_l0, cam_l1, cam_l2, cam_r0, cam_r1 and cam_r2 are disabled.
        selected_cams = [(frame.cam_f0, frame.cam_b0) for frame in agent_input.cameras]

        # Attributes read from each processed Camera, in output-key order.
        camera_fields = (
            'image',
            'canvas',
            'sensor2lidar_rotation',
            'sensor2lidar_translation',
            'intrinsics',
            'distortion',
            'post_rot',
            'post_tran',
        )

        per_frame = {field: [] for field in camera_fields}
        for frame_cams in selected_cams:
            per_camera = {field: [] for field in camera_fields}
            for cam in frame_cams:
                processed: Camera = camera_pipeline(cam)
                for field in camera_fields:
                    per_camera[field].append(getattr(processed, field))
            # Collapse the per-camera lists of this frame into one tensor each.
            for field in camera_fields:
                per_frame[field].append(torch.stack(per_camera[field]))

        # img: T, N_CAM, C, H, W
        imgs = torch.stack(per_frame['image'])
        features = {'image': imgs}
        # All remaining camera tensors are converted with .to(imgs) so they
        # follow the stacked image tensor.
        for field in camera_fields[1:]:
            features[field] = torch.stack(per_frame[field]).to(imgs)
        features['lidars_warped'] = lidar_tensors
        return features