from __future__ import annotations

from typing import Dict

import numpy as np
import torch
from mmcv.parallel import DataContainer as DC

from det_map.data.datasets.dataclasses import AgentInput, Camera
from det_map.data.datasets.lidar_utils import transform_points, render_image
from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder


class LiDARCameraFeatureBuilder(AbstractFeatureBuilder):
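    """Feature builder that produces multi-view camera tensors (with per-camera
    calibration) and ego-motion-compensated LiDAR sweeps as model inputs."""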

    def __init__(self, pipelines):
        super().__init__()
        self.pipelines = pipelines

    def compute_features(self, agent_input: AgentInput) -> Dict[str, torch.Tensor]:
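        """Compute camera and LiDAR input features for a single AgentInput sample."""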
        img_pipeline = self.pipelines['img']
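
        # Relative timestamps in seconds: elapsed time from each history frame to
        # the most recent (key) frame; raw timestamps are in microseconds.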
        timestamps_ori = agent_input.timestamps
        timestamps = [(timestamps_ori[-1] - tmp) / 1e6 for tmp in timestamps_ori]

        lidars = [np.copy(tmp.lidar_pc) for tmp in agent_input.lidars]
        ego2globals = [tmp for tmp in agent_input.ego2globals]

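        # The last entry is the key frame: invert its ego->global transform so that
        # points in global coordinates can be mapped into the key-frame ego frame.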
        global2ego_key = np.linalg.inv(ego2globals[-1])

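        # Warp every past sweep into the key-frame ego frame:
        # past ego -> global (ego2global), then global -> key-frame ego.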
        lidars_warped = [transform_points(transform_points(pts, mat), global2ego_key)
                         for pts, mat in zip(lidars[:-1], ego2globals[:-1])]
        lidars_warped.append(lidars[-1])
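
        # Overwrite channel 4 with the relative timestamp and keep the first five
        # channels, transposed to shape (num_points, 5) per sweep.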
        for i, l in enumerate(lidars_warped):
            l[4] = timestamps[i]
            lidars_warped[i] = torch.from_numpy(l[:5]).t()

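        # Collect the surround-view camera set for every frame in the history
        # buffer; the key frame is last.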
        cams_all_frames = [[
            tmp.cam_f0,
            tmp.cam_l0,
            tmp.cam_l1,
            tmp.cam_l2,
            tmp.cam_r0,
            tmp.cam_r1,
            tmp.cam_r2,
            tmp.cam_b0,
        ] for tmp in agent_input.cameras]

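        # Run the image pipeline on each camera and accumulate the processed image
        # together with its calibration (sensor2lidar extrinsics, intrinsics,
        # distortion) and the image-space transforms applied by the pipeline
        # (post_rot, post_tran).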
        (image, canvas, sensor2lidar_rotation, sensor2lidar_translation,
         intrinsics, distortion, post_rot, post_tran) = [], [], [], [], [], [], [], []
        for cams_frame_t in cams_all_frames:
            (image_t, canvas_t, sensor2lidar_rotation_t, sensor2lidar_translation_t,
             intrinsics_t, distortion_t, post_rot_t, post_tran_t) = [], [], [], [], [], [], [], []
            for cam in cams_frame_t:
                cam_processed: Camera = img_pipeline(cam)
                image_t.append(cam_processed.image)
                canvas_t.append(cam_processed.canvas)
                sensor2lidar_rotation_t.append(cam_processed.sensor2lidar_rotation)
                sensor2lidar_translation_t.append(cam_processed.sensor2lidar_translation)
                intrinsics_t.append(cam_processed.intrinsics)
                distortion_t.append(cam_processed.distortion)
                post_rot_t.append(cam_processed.post_rot)
                post_tran_t.append(cam_processed.post_tran)
            image.append(torch.stack(image_t))
            canvas.append(torch.stack(canvas_t))
            sensor2lidar_rotation.append(torch.stack(sensor2lidar_rotation_t))
            sensor2lidar_translation.append(torch.stack(sensor2lidar_translation_t))
            intrinsics.append(torch.stack(intrinsics_t))
            distortion.append(torch.stack(distortion_t))
            post_rot.append(torch.stack(post_rot_t))
            post_tran.append(torch.stack(post_tran_t))

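        # Stack over time: each camera tensor becomes (num_frames, num_cams, ...);
        # .to(imgs) moves the calibration tensors to the image dtype and device.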
        imgs = torch.stack(image)
        return {
            'image': imgs,
            'canvas': torch.stack(canvas).to(imgs),
            'sensor2lidar_rotation': torch.stack(sensor2lidar_rotation).to(imgs),
            'sensor2lidar_translation': torch.stack(sensor2lidar_translation).to(imgs),
            'intrinsics': torch.stack(intrinsics).to(imgs),
            'distortion': torch.stack(distortion).to(imgs),
            'post_rot': torch.stack(post_rot).to(imgs),
            'post_tran': torch.stack(post_tran).to(imgs),
            'lidars_warped': lidars_warped,
        }