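"""Compute expanded PDM scores for a fixed trajectory vocabulary over NAVSIM scenes.

The script loads a trajectory vocabulary (k-means clusters, per the file name), distributes
per-token scoring across a worker pool, caches intermediate results per token, and pickles
the final token-to-score mapping under NAVSIM_TRAJPDM_ROOT.
"""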
import logging
import lzma
import os
import pickle
import traceback
import uuid
from pathlib import Path
from typing import Any, Dict, List

import hydra
import numpy as np
from hydra.utils import instantiate
from nuplan.planning.script.builders.logging_builder import build_logger
from nuplan.planning.utils.multithreading.worker_utils import worker_map
from omegaconf import DictConfig

from navsim.agents.expansion.scoring.pdm_score import pdm_score_expanded
from navsim.common.dataclasses import SensorConfig
from navsim.common.dataloader import MetricCacheLoader
from navsim.common.dataloader import SceneLoader, SceneFilter
from navsim.planning.metric_caching.metric_cache import MetricCache
from navsim.planning.script.builders.worker_pool_builder import build_worker
from navsim.planning.simulation.planner.pdm_planner.simulation.pdm_simulator import (
    PDMSimulator
)

logger = logging.getLogger(__name__)
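
# Output roots: cached per-token PDM scores go under NAVSIM_TRAJPDM_ROOT; the Hydra configs
# and the trajectory vocabulary live under NAVSIM_DEVKIT_ROOT.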
trajpdm_root = os.getenv('NAVSIM_TRAJPDM_ROOT')
devkit_root = os.getenv('NAVSIM_DEVKIT_ROOT')

CONFIG_PATH = f"{devkit_root}/navsim/planning/script/config/pdm_scoring"
CONFIG_NAME = "expanded_run_pdm_score"


@hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME)
def main(cfg: DictConfig) -> None:
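    """Score the trajectory vocabulary on every selected scene token and pickle the per-token results."""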
    vocab_size = cfg.vocab_size
    scene_filter_name = cfg.scene_filter_name
    traj_path = f"{devkit_root}/traj_final/test_{vocab_size}_kmeans.npy"
    result_dir = f'vocab_expanded_{vocab_size}_{scene_filter_name}'

    build_logger(cfg)
    worker = build_worker(cfg)
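    # Fixed trajectory vocabulary (k-means, per the file name) shared by every work item.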
    vocab = np.load(traj_path)
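
    # Scene metadata only; sensor data is not loaded for scoring.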
    scene_loader = SceneLoader(
        sensor_blobs_path=None,
        data_path=Path(cfg.navsim_log_path),
        scene_filter=instantiate(cfg.scene_filter),
        sensor_config=SensorConfig.build_no_sensors(),
    )
    os.makedirs(f'{trajpdm_root}/{result_dir}', exist_ok=True)
    result_path = f'{trajpdm_root}/{result_dir}/{scene_filter_name}.pkl'
    print(f'Results will be written to {result_path}')
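
    # One work item per (log_file, token) pair so the tokens can be scored in parallel.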
    data_points = [
        {
            "cfg": cfg,
            "result_dir": result_dir,
            "log_file": log_file,
            "token": token,
            "vocab": vocab,
        }
        for log_file, tokens_list in scene_loader.get_tokens_list_per_log().items()
        for token in tokens_list
    ]
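
    # Fan the work items out to the worker pool; tokens that failed in a worker return no 'score' and are skipped below.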
    score_rows: List[Dict[str, Any]] = worker_map(worker, run_pdm_score, data_points)
    final = {row['token']: row['score'] for row in score_rows if 'score' in row}
    with open(result_path, 'wb') as f:
        pickle.dump(final, f)


def run_pdm_score(args: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
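    """Worker entry point: compute the expanded PDM score for each token assigned to this worker."""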
    node_id = int(os.environ.get("NODE_RANK", 0))
    thread_id = str(uuid.uuid4())
    logger.info(f"Starting worker in thread_id={thread_id}, node_id={node_id}")

    log_names = [a["log_file"] for a in args]
    tokens = [a["token"] for a in args]
    cfg: DictConfig = args[0]["cfg"]
    result_dir = args[0]["result_dir"]
    vocab = args[0]["vocab"]
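
    # The simulator and scorer must use identical proposal sampling (enforced by the assert below).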
    simulator: PDMSimulator = instantiate(cfg.simulator)
    scorer = instantiate(cfg.scorer)
    assert simulator.proposal_sampling == scorer.proposal_sampling, "Simulator and scorer proposal sampling has to be identical"
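
    # Restrict scene loading to exactly the logs and tokens assigned to this worker.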
    metric_cache_loader = MetricCacheLoader(Path(cfg.metric_cache_path))
    scene_filter: SceneFilter = instantiate(cfg.scene_filter)
    scene_filter.log_names = log_names
    scene_filter.tokens = tokens
    scene_loader = SceneLoader(
        sensor_blobs_path=Path(cfg.sensor_blobs_path),
        data_path=Path(cfg.navsim_log_path),
        scene_filter=scene_filter,
    )
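
    # Only tokens that also have a pre-computed metric cache can be scored.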
    tokens_to_evaluate = list(set(scene_loader.tokens) & set(metric_cache_loader.tokens))
    pdm_results: List[Dict[str, Any]] = []
    for idx, token in enumerate(tokens_to_evaluate):
        logger.info(
            f"Processing scenario {idx + 1} / {len(tokens_to_evaluate)} in thread_id={thread_id}, node_id={node_id}"
        )
        score_row: Dict[str, Any] = {"token": token}
        try:
            tmp_cache_path = f'{trajpdm_root}/{result_dir}/{token}/tmp.pkl'
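            # Reuse a previously cached per-token result unless recomputation is forced via cfg.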
            if not cfg.get('force_recompute_tmp', False) and os.path.exists(tmp_cache_path):
                print(f'Exists: {tmp_cache_path}')
                with open(tmp_cache_path, 'rb') as f:
                    score_row['score'] = pickle.load(f)
                pdm_results.append(score_row)
                continue

            metric_cache_path = metric_cache_loader.metric_cache_paths[token]
            with lzma.open(metric_cache_path, "rb") as f:
                metric_cache: MetricCache = pickle.load(f)
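
            # Score the full trajectory vocabulary against this scene's metric cache in one call.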
            pdm_result = pdm_score_expanded(
                metric_cache=metric_cache,
                vocab_trajectory=vocab,
                future_sampling=simulator.proposal_sampling,
                simulator=simulator,
                scorer=scorer,
                expansion_only=cfg.get('expansion_only', True),
            )

            score_row['score'] = pdm_result

            os.makedirs(os.path.dirname(tmp_cache_path), exist_ok=True)
            with open(tmp_cache_path, 'wb') as f:
                pickle.dump(pdm_result, f)
        except Exception:
            logger.warning(f"----------- Scoring failed for token {token}:")
            traceback.print_exc()

        pdm_results.append(score_row)
    return pdm_results


if __name__ == "__main__":
    main()