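# Source file: navsim_ours/navsim/planning/script/run_dataset_caching.py
# Hugging Face file-view header captured along with the code:
#   uploaded by lkllkl (profile picture) via "Upload folder using huggingface_hub",
#   commit da2e2ac (verified), file size 3.16 kB; page links: raw, history, blame.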
import hydra
from hydra.utils import instantiate
import logging
from omegaconf import DictConfig
import os
from pathlib import Path
import pytorch_lightning as pl
from typing import Any, Dict, List, Optional, Union
import uuid
from navsim.planning.training.dataset import Dataset
from navsim.common.dataloader import SceneLoader
from navsim.common.dataclasses import SceneFilter, SensorConfig
from navsim.agents.abstract_agent import AbstractAgent
from nuplan.planning.utils.multithreading.worker_pool import WorkerPool
from nuplan.planning.utils.multithreading.worker_utils import worker_map
# Hydra config directory and config name consumed by the @hydra.main decorator on `main`.
CONFIG_NAME = "default_training"
CONFIG_PATH = "config/training"

# Module-level logger shared by `cache_features` and `main`.
logger = logging.getLogger(__name__)
def cache_features(args: List[Dict[str, Union[List[str], DictConfig]]]) -> List[Optional[Any]]:
    """
    Build feature/target caches for one chunk of logs.

    Used as the ``fn`` passed to ``worker_map`` in ``main``. Each element of ``args`` is a data point
    with the keys ``"cfg"``, ``"log_file"``, ``"tokens"`` and ``"agent"``. The logs and tokens of all
    entries in the chunk are merged into a single scene filter. A SceneLoader and then a Dataset are
    built over that filter, using ``cfg.cache_path`` and ``cfg.force_cache_computation``.

    :param args: chunk of data points. ``cfg`` and ``agent`` are read from the first entry only;
        ``main`` puts the same ``cfg`` and ``agent`` into every data point.
    :return: always an empty list. The useful work is a side effect of constructing the Dataset.
    """
    if not args:
        # An empty chunk has nothing to cache; return early instead of failing on ``args[0]``.
        return []

    node_id = int(os.environ.get("NODE_RANK", 0))
    # Random per-call id, used only to tell workers apart in the log output.
    thread_id = str(uuid.uuid4())

    log_names = [a["log_file"] for a in args]
    tokens = [t for a in args for t in a["tokens"]]
    cfg: DictConfig = args[0]["cfg"]
    agent: AbstractAgent = args[0]["agent"]

    # Narrow the configured filter to exactly the logs/tokens assigned to this chunk.
    scene_filter: SceneFilter = instantiate(cfg.scene_filter)
    scene_filter.log_names = log_names
    scene_filter.tokens = tokens
    scene_loader = SceneLoader(
        sensor_blobs_path=Path(cfg.sensor_blobs_path),
        data_path=Path(cfg.navsim_log_path),
        scene_filter=scene_filter,
        # Unlike the sensor-less loader in `main`, sensor selection here comes from the agent.
        sensor_config=agent.get_sensor_config(),
    )
    logger.info(
        f"Extracted {len(scene_loader.tokens)} scenarios for thread_id={thread_id}, node_id={node_id}."
    )

    # The instance is discarded; it is constructed only for its construction-time side effects
    # (presumably computing/writing feature and target caches under cfg.cache_path — confirm
    # against navsim.planning.training.dataset.Dataset).
    Dataset(
        scene_loader=scene_loader,
        feature_builders=agent.get_feature_builders(),
        target_builders=agent.get_target_builders(),
        cache_path=cfg.cache_path,
        force_cache_computation=cfg.force_cache_computation,
    )
    return []
@hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME)
def main(cfg: DictConfig) -> None:
    """
    Enumerate the scenes selected by ``cfg.scene_filter`` and cache agent features/targets for them.

    Scenes are first indexed without any sensor data and grouped per log. Each log becomes one data
    point, and the data points are dispatched to ``cache_features`` through ``worker_map`` on the
    worker pool built from ``cfg.worker``.

    :param cfg: Hydra config composed from ``CONFIG_PATH`` / ``CONFIG_NAME``.
    """
    logger.info("Global Seed set to 0")
    pl.seed_everything(0, workers=True)

    logger.info("Building Worker")
    worker: WorkerPool = instantiate(cfg.worker)

    logger.info("Building SceneLoader")
    scene_filter: SceneFilter = instantiate(cfg.scene_filter)
    # Only token bookkeeping is needed at this stage, so no sensors are loaded.
    scene_loader = SceneLoader(
        data_path=Path(cfg.navsim_log_path),
        sensor_blobs_path=Path(cfg.sensor_blobs_path),
        scene_filter=scene_filter,
        sensor_config=SensorConfig.build_no_sensors(),
    )

    agent: AbstractAgent = instantiate(cfg.agent)
    logger.info(f"Extracted {len(scene_loader)} scenarios for training/validation dataset")

    # One payload per log; cache_features receives lists of these payloads.
    tokens_per_log = scene_loader.get_tokens_list_per_log()
    data_points = [
        dict(cfg=cfg, log_file=log_file, tokens=log_tokens, agent=agent)
        for log_file, log_tokens in tokens_per_log.items()
    ]

    worker_map(worker, cache_features, data_points)
    logger.info(f"Finished caching {len(scene_loader)} scenarios for training/validation dataset")
if __name__ == "__main__":
    # Script entry: the @hydra.main decorator on `main` supplies `cfg`, so it is called without arguments.
    main()