import logging
import os
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import hydra
import pytorch_lightning as pl
from hydra.utils import instantiate
from omegaconf import DictConfig

from navsim.agents.abstract_agent import AbstractAgent
from navsim.common.dataclasses import SceneFilter, SensorConfig
from navsim.common.dataloader import SceneLoader
from navsim.planning.training.dataset import Dataset
from nuplan.planning.utils.multithreading.worker_pool import WorkerPool
from nuplan.planning.utils.multithreading.worker_utils import worker_map

logger = logging.getLogger(__name__)

CONFIG_PATH = "config/training"
CONFIG_NAME = "default_training"
|
|
|
|
def cache_features(args: List[Dict[str, Union[List[str], DictConfig]]]) -> List[Optional[Any]]:
    """Cache features/targets for the scenarios assigned to this worker.

    Runs inside ``worker_map``: each element of ``args`` is a work item dict
    with keys ``"cfg"`` (the shared Hydra config), ``"log_file"``,
    ``"tokens"``, and ``"agent"``. The config and agent are taken from the
    first item since they are identical across items.

    :param args: list of per-log work items produced by ``main``.
    :return: an empty list — caching happens as a side effect; ``worker_map``
        only requires that a list be returned.
    """
    node_id = int(os.environ.get("NODE_RANK", 0))
    # Unique id so log lines from concurrent workers can be told apart.
    thread_id = str(uuid.uuid4())
    log_names = [a["log_file"] for a in args]
    tokens = [t for a in args for t in a["tokens"]]
    cfg: DictConfig = args[0]["cfg"]
    agent = args[0]["agent"]

    # Restrict the configured scene filter to exactly the logs/tokens this
    # worker was handed, so the SceneLoader only touches its own share.
    scene_filter: SceneFilter = instantiate(cfg.scene_filter)
    scene_filter.log_names = log_names
    scene_filter.tokens = tokens

    scene_loader = SceneLoader(
        sensor_blobs_path=Path(cfg.sensor_blobs_path),
        data_path=Path(cfg.navsim_log_path),
        scene_filter=scene_filter,
        sensor_config=agent.get_sensor_config(),
    )
    logger.info(
        f"Extracted {len(scene_loader.tokens)} scenarios for thread_id={thread_id}, node_id={node_id}."
    )

    # Instantiating Dataset triggers the feature/target cache computation as a
    # side effect; the object itself is not needed afterwards, so it is not
    # bound to a name.
    Dataset(
        scene_loader=scene_loader,
        feature_builders=agent.get_feature_builders(),
        target_builders=agent.get_target_builders(),
        cache_path=cfg.cache_path,
        force_cache_computation=cfg.force_cache_computation,
    )
    return []
|
|
|
|
|
|
|
|
|
|
@hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME)
def main(cfg: DictConfig) -> None:
    """Hydra entry point: fan feature caching out over a worker pool.

    Builds the worker, a sensor-free SceneLoader (tokens only — blobs are
    loaded later inside each worker), and the agent, then dispatches one
    work item per log file to ``cache_features`` via ``worker_map``.

    :param cfg: resolved Hydra training configuration.
    """
    logger.info("Global Seed set to 0")
    pl.seed_everything(0, workers=True)

    logger.info("Building Worker")
    worker: WorkerPool = instantiate(cfg.worker)

    logger.info("Building SceneLoader")
    scene_filter: SceneFilter = instantiate(cfg.scene_filter)
    # No sensors here: this loader only enumerates tokens per log; the
    # workers construct their own sensor-aware loaders.
    scene_loader = SceneLoader(
        sensor_blobs_path=Path(cfg.sensor_blobs_path),
        data_path=Path(cfg.navsim_log_path),
        scene_filter=scene_filter,
        sensor_config=SensorConfig.build_no_sensors(),
    )
    agent: AbstractAgent = instantiate(cfg.agent)

    logger.info(f"Extracted {len(scene_loader)} scenarios for training/validation dataset")

    # One work item per log file; cfg and agent are shared across items.
    data_points = []
    for log_file, tokens_list in scene_loader.get_tokens_list_per_log().items():
        data_points.append(
            {
                "cfg": cfg,
                "log_file": log_file,
                "tokens": tokens_list,
                "agent": agent,
            }
        )

    _ = worker_map(worker, cache_features, data_points)

    logger.info(f"Finished caching {len(scene_loader)} scenarios for training/validation dataset")
|
|
|
|
# Script entry point: Hydra parses CLI overrides and invokes main().
if __name__ == "__main__":
    main()