# navsim_ours/navsim/agents/dreamer/hydra_dreamer_wm_features.py
# Uploaded by lkllkl via huggingface_hub (commit da2e2ac).
from typing import Dict
import cv2
import numpy as np
import torch
from torchvision import transforms
from navsim.agents.dreamer.hydra_dreamer_config import HydraDreamerConfig
from navsim.common.dataclasses import AgentInput, Scene
from navsim.common.dataclasses import Cameras
from navsim.planning.training.abstract_feature_target_builder import (
AbstractFeatureBuilder,
AbstractTargetBuilder,
)
def cat_flr_imgs(camera: Cameras, config: HydraDreamerConfig):
    """
    Stitch the left, front and right images of one camera frame into a single tensor.

    Every image loses 28 rows at the top and bottom; the left and right images
    additionally lose 416 columns on each side. The three crops are concatenated
    along axis 1 (the width axis, assuming H x W x C arrays), resized to
    ``(config.camera_width, config.camera_height)`` and converted with
    torchvision's ``ToTensor``.

    :param camera: camera frame providing ``cam_l0``, ``cam_f0`` and ``cam_r0`` images
    :param config: agent config supplying the output ``camera_width`` / ``camera_height``
    :return: tensor produced by ``transforms.ToTensor`` from the resized panorama
    """
    vertical_crop = slice(28, -28)
    side_crop = slice(416, -416)

    left = camera.cam_l0.image[vertical_crop, side_crop]
    front = camera.cam_f0.image[vertical_crop]
    right = camera.cam_r0.image[vertical_crop, side_crop]

    panorama = np.concatenate((left, front, right), axis=1)
    # cv2.resize takes its target size as (width, height).
    output_size = (config.camera_width, config.camera_height)
    return transforms.ToTensor()(cv2.resize(panorama, output_size))
class HydraDreamerWmFeatureBuilder(AbstractFeatureBuilder):
    """Builds world-model input features: an ego-status vector and stitched camera frames."""

    def __init__(self, config: HydraDreamerConfig):
        """
        :param config: agent config; this builder reads ``num_ego_status`` and,
            through ``cat_flr_imgs``, ``camera_width`` / ``camera_height``.
        """
        super().__init__()
        self._config = config

    def get_unique_name(self) -> str:
        """Inherited, see superclass."""
        return "hydra_dreamer_wm_feature"

    def _get_camera_feature(self, agent_input: AgentInput):
        """
        Stitch the left/front/right cameras of (up to) the first three camera frames.
        :param agent_input: input dataclass
        :return: list of stitched image tensors, one per selected frame (see
            ``cat_flr_imgs``), in the same order as ``agent_input.cameras``
        """
        return [cat_flr_imgs(camera, self._config) for camera in agent_input.cameras[:3]]

    def compute_features(self, agent_input: AgentInput) -> Dict[str, torch.Tensor]:
        """
        Inherited, see superclass.

        Produced keys:
          * ``"status_feature"``: float32 tensor built by concatenating, for each of
            the ``num_ego_status`` entries walking backwards from
            ``ego_statuses[-1]``, the driving command, ego velocity and ego
            acceleration.
          * ``"img_3"``, ``"img_2"``, ``"img_1"``: stitched images of
            ``agent_input.cameras[0]``, ``[1]`` and ``[2]`` respectively.

        :raises ValueError: if ``agent_input`` holds fewer than three camera frames.
        """
        features = {}

        status_parts = []
        # back=1 -> ego_statuses[-1], back=2 -> ego_statuses[-2], ...
        for back in range(1, self._config.num_ego_status + 1):
            status = agent_input.ego_statuses[-back]
            status_parts += [
                torch.tensor(status.driving_command, dtype=torch.float32),
                torch.tensor(status.ego_velocity, dtype=torch.float32),
                torch.tensor(status.ego_acceleration, dtype=torch.float32),
            ]
        features["status_feature"] = torch.concatenate(status_parts)

        # Fail early with a clear message rather than an IndexError on the third
        # image after all the stitching work has already been done.
        num_frames = len(agent_input.cameras)
        if num_frames < 3:
            raise ValueError(
                f"{self.get_unique_name()} needs at least 3 camera frames, got {num_frames}"
            )

        # NOTE(review): if agent_input.cameras is ordered oldest -> newest, cameras[:3]
        # are the three OLDEST frames. They are the three frames right before the
        # current one (the target builder's cameras[-1]) only when exactly four
        # history frames are loaded -- confirm against the config.
        imgs = self._get_camera_feature(agent_input)
        features["img_3"], features["img_2"], features["img_1"] = imgs

        # TODO: perspective boxes, map and camera features are not built yet.
        return features
class HydraDreamerWmTargetBuilder(AbstractTargetBuilder):
    """Builds the world-model image target from the scene's last camera frame."""

    def __init__(self, config: HydraDreamerConfig):
        """
        :param config: agent config, forwarded to ``cat_flr_imgs`` for the output size.
        """
        super().__init__()
        self._config = config

    def get_unique_name(self) -> str:
        """Inherited, see superclass."""
        return "hydra_dreamer_wm_target"

    def compute_targets(self, scene: Scene) -> Dict[str, torch.Tensor]:
        """
        Stitch the left/front/right cameras of ``scene.get_agent_input().cameras[-1]``.
        :param scene: scene dataclass
        :return: dict with the single key ``"img_gt"`` holding the stitched tensor
        """
        last_frame = scene.get_agent_input().cameras[-1]
        target_image = cat_flr_imgs(last_frame, self._config)
        return {"img_gt": target_image}