from typing import Dict, List, Tuple

import torch

from det_map.data.datasets.dataloader import SceneLoader
from det_map.data.datasets.dataset import Dataset
from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder, AbstractTargetBuilder


class DetDataset(Dataset):
    """Dataset that post-processes each sample with a fixed pipeline chain.

    All construction arguments are forwarded unchanged to ``Dataset``.
    NOTE(review): ``_scene_loader``, ``_feature_builders``,
    ``_target_builders`` and ``pipelines`` are read here but never assigned
    in this class; presumably the ``Dataset`` base class provides them --
    confirm against its definition.
    """

    def __init__(self, **kwargs):
        """Forward every keyword argument to the ``Dataset`` constructor."""
        super().__init__(**kwargs)

    def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """Build the ``(features, targets)`` pair for the sample at ``idx``.

        The scene is looked up through the token stored at position ``idx``
        of the scene loader. Each feature builder contributes entries
        computed from the scene's agent input, each target builder
        contributes entries computed from the scene itself, and the merged
        dictionaries are then passed through the pipeline stages
        ``lidar_aug``, ``depth``, ``lidar_filter`` and ``point_shuffle``,
        in that order.

        Args:
            idx: Position of the sample's token in the scene loader's
                token list.

        Returns:
            A ``(features, targets)`` tuple as produced by the last
            pipeline stage.
        """
        loader = self._scene_loader
        scene = loader.get_scene_from_token(loader.tokens[idx])

        # Later builders overwrite earlier ones on key collisions
        # (plain ``dict.update`` semantics).
        features: Dict[str, torch.Tensor] = {}
        for feature_builder in self._feature_builders:
            agent_input = scene.get_agent_input()
            features.update(feature_builder.compute_features(agent_input))

        targets: Dict[str, torch.Tensor] = {}
        for target_builder in self._target_builders:
            targets.update(target_builder.compute_targets(scene))

        # TODO: sampler
        # Each stage receives the output of the previous one, so the order
        # of this tuple is significant.
        stage_order = ('lidar_aug', 'depth', 'lidar_filter', 'point_shuffle')
        for stage_name in stage_order:
            features, targets = self.pipelines[stage_name](features, targets)

        return (features, targets)