"""Hydra entry point that pre-computes and caches agent features/targets.

The scenes selected by ``cfg.scene_filter`` are grouped per log file, and the
groups are handed to ``cache_features`` through nuPlan's ``worker_map``.
"""

import logging
import os
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import hydra
import pytorch_lightning as pl
from hydra.utils import instantiate
from nuplan.planning.utils.multithreading.worker_pool import WorkerPool
from nuplan.planning.utils.multithreading.worker_utils import worker_map
from omegaconf import DictConfig

from navsim.agents.abstract_agent import AbstractAgent
from navsim.common.dataclasses import SceneFilter, SensorConfig
from navsim.common.dataloader import SceneLoader
from navsim.planning.training.dataset import Dataset

logger = logging.getLogger(__name__)

CONFIG_PATH = "config/training"
CONFIG_NAME = "default_training"


def cache_features(args: List[Dict[str, Any]]) -> List[Optional[Any]]:
    """Build a ``Dataset`` for one chunk of logs so its features get cached.

    :param args: chunk of work items as built in ``main``; each item is a dict
        with the keys ``"cfg"`` (the Hydra config), ``"log_file"`` (log name),
        ``"tokens"`` (list of scene tokens of that log) and ``"agent"``.
        The config and the agent are taken from the first item only.
    :return: always an empty list; the function is called for its side effects.
    """
    # An empty chunk has nothing to cache; bail out before indexing args[0].
    if not args:
        return []

    node_id = int(os.environ.get("NODE_RANK", 0))
    # Random identifier, used only to tell the log lines of workers apart.
    thread_id = str(uuid.uuid4())

    log_names = [item["log_file"] for item in args]
    tokens = [token for item in args for token in item["tokens"]]
    cfg: DictConfig = args[0]["cfg"]
    agent = args[0]["agent"]

    # Start from the configured filter, then narrow it to this chunk's
    # logs and tokens.
    scene_filter: SceneFilter = instantiate(cfg.scene_filter)
    scene_filter.log_names = log_names
    scene_filter.tokens = tokens

    scene_loader = SceneLoader(
        sensor_blobs_path=Path(cfg.sensor_blobs_path),
        data_path=Path(cfg.navsim_log_path),
        scene_filter=scene_filter,
        sensor_config=agent.get_sensor_config(),
    )
    logger.info(
        f"Extracted {len(scene_loader.tokens)} scenarios for thread_id={thread_id}, node_id={node_id}."
    )

    # NOTE(review): the instance is never used afterwards, so the caching
    # presumably happens inside the Dataset constructor -- confirm against
    # navsim.planning.training.dataset.Dataset.
    Dataset(
        scene_loader=scene_loader,
        feature_builders=agent.get_feature_builders(),
        target_builders=agent.get_target_builders(),
        cache_path=cfg.cache_path,
        force_cache_computation=cfg.force_cache_computation,
    )
    return []


@hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME)
def main(cfg: DictConfig) -> None:
    """Cache the dataset described by ``cfg``.

    Seeds the RNGs, instantiates the worker pool, scene filter and agent from
    the config, builds one work item per log file and maps ``cache_features``
    over them.

    :param cfg: Hydra config; this function reads ``worker``, ``scene_filter``,
        ``navsim_log_path``, ``sensor_blobs_path`` and ``agent`` from it.
    """
    logger.info("Global Seed set to 0")
    pl.seed_everything(0, workers=True)

    logger.info("Building Worker")
    worker: WorkerPool = instantiate(cfg.worker)

    logger.info("Building SceneLoader")
    scene_filter: SceneFilter = instantiate(cfg.scene_filter)
    data_path = Path(cfg.navsim_log_path)
    sensor_blobs_path = Path(cfg.sensor_blobs_path)
    # This loader is only used to enumerate tokens per log; each worker builds
    # its own loader with the agent's sensor config in cache_features.
    scene_loader = SceneLoader(
        sensor_blobs_path=sensor_blobs_path,
        data_path=data_path,
        scene_filter=scene_filter,
        sensor_config=SensorConfig.build_no_sensors(),
    )
    agent: AbstractAgent = instantiate(cfg.agent)
    logger.info(f"Extracted {len(scene_loader)} scenarios for training/validation dataset")

    # One work item per log file; cfg and agent are repeated in every item
    # because cache_features reads them from the first item of its chunk.
    data_points = [
        {
            "cfg": cfg,
            "log_file": log_file,
            "tokens": tokens_list,
            "agent": agent,
        }
        for log_file, tokens_list in scene_loader.get_tokens_list_per_log().items()
    ]

    _ = worker_map(worker, cache_features, data_points)
    logger.info(f"Finished caching {len(scene_loader)} scenarios for training/validation dataset")


if __name__ == "__main__":
    main()