from typing import Dict

import cv2
import numpy as np
import torch
from torchvision import transforms

from navsim.agents.dreamer.hydra_dreamer_config import HydraDreamerConfig
from navsim.common.dataclasses import AgentInput, Cameras, Scene
from navsim.planning.training.abstract_feature_target_builder import (
    AbstractFeatureBuilder,
    AbstractTargetBuilder,
)


def cat_flr_imgs(camera: Cameras, config: HydraDreamerConfig):
    """
    Stitch the left, front and right camera images into one resized tensor.

    The left (cam_l0) and right (cam_r0) images lose 28 rows at the top and
    bottom and 416 columns on each side; the front (cam_f0) image only loses
    the 28 rows. The crops are joined side by side in left/front/right order,
    resized to the configured size and converted with torchvision's ToTensor.
    :param camera: camera dataclass providing cam_l0, cam_f0 and cam_r0 images
    :param config: config providing camera_width and camera_height
    :return: stitched image as produced by transforms.ToTensor
    """
    # NOTE(review): the crop margins look tuned for a fixed source resolution
    # (presumably 1920x1080) - confirm against the dataset.
    left = camera.cam_l0.image[28:-28, 416:-416]
    front = camera.cam_f0.image[28:-28]
    right = camera.cam_r0.image[28:-28, 416:-416]

    # axis=1 joins along the second array axis, i.e. side by side.
    panorama = np.concatenate([left, front, right], axis=1)

    # cv2.resize takes its target size as (width, height).
    target_size = (config.camera_width, config.camera_height)
    return transforms.ToTensor()(cv2.resize(panorama, target_size))


class HydraDreamerWmFeatureBuilder(AbstractFeatureBuilder):
    """Feature builder producing an ego status vector and stitched camera images."""

    def __init__(self, config: HydraDreamerConfig):
        """
        Store the config used by the feature computations.
        :param config: config read for num_ego_status and the camera size
        """
        super().__init__()
        self._config = config

    def get_unique_name(self) -> str:
        """Inherited, see superclass."""
        return "hydra_dreamer_wm_feature"

    def _get_camera_feature(self, agent_input: AgentInput):
        """
        Stitch the cameras of the first three entries of agent_input.cameras
        :param agent_input: input dataclass
        :return: list with one stitched image tensor per processed entry
        """
        return [
            cat_flr_imgs(frame_cameras, self._config)
            for frame_cameras in agent_input.cameras[:3]
        ]

    def compute_features(self, agent_input: AgentInput) -> Dict[str, torch.Tensor]:
        """Inherited, see superclass."""
        status_parts = []
        # Walk ego_statuses backwards from the last entry:
        # offset 1 -> [-1], offset 2 -> [-2], ... up to num_ego_status entries.
        for offset in range(1, self._config.num_ego_status + 1):
            ego_status = agent_input.ego_statuses[-offset]
            status_parts.extend(
                [
                    torch.tensor(ego_status.driving_command, dtype=torch.float32),
                    torch.tensor(ego_status.ego_velocity, dtype=torch.float32),
                    torch.tensor(ego_status.ego_acceleration, dtype=torch.float32),
                ]
            )
        status_feature = torch.concatenate(status_parts)

        stitched = self._get_camera_feature(agent_input)

        # The image keys count down: the first stitched entry is stored as
        # "img_3" and the third as "img_1".
        # TODO: add perspective box, map and cam features.
        return {
            "status_feature": status_feature,
            "img_3": stitched[0],
            "img_2": stitched[1],
            "img_1": stitched[2],
        }


class HydraDreamerWmTargetBuilder(AbstractTargetBuilder):
    """Target builder producing the stitched image of the last camera entry."""

    def __init__(self, config: HydraDreamerConfig):
        """
        Store the config used when stitching the target image.
        :param config: config providing the camera size
        """
        super().__init__()
        self._config = config

    def get_unique_name(self) -> str:
        """Inherited, see superclass."""
        return "hydra_dreamer_wm_target"

    def compute_targets(self, scene: Scene) -> Dict[str, torch.Tensor]:
        """Inherited, see superclass."""
        # The target is built from the last entry of the agent input's cameras.
        last_cameras = scene.get_agent_input().cameras[-1]
        return {"img_gt": cat_flr_imgs(last_cameras, self._config)}