File size: 5,776 Bytes
da2e2ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import logging
import lzma
import os
import pickle
import traceback
import uuid
from pathlib import Path
from typing import Any, Dict, List, Union, Tuple

import hydra
import numpy as np
from hydra.utils import instantiate
from nuplan.planning.script.builders.logging_builder import build_logger
from nuplan.planning.utils.multithreading.worker_utils import worker_map
from omegaconf import DictConfig

from navsim.agents.expansion.scoring.pdm_score import pdm_score_expanded
from navsim.common.dataclasses import SensorConfig
from navsim.common.dataloader import MetricCacheLoader
from navsim.common.dataloader import SceneLoader, SceneFilter
from navsim.planning.metric_caching.metric_cache import MetricCache
from navsim.planning.script.builders.worker_pool_builder import build_worker
from navsim.planning.simulation.planner.pdm_planner.simulation.pdm_simulator import (
    PDMSimulator
)

logger = logging.getLogger(__name__)

# Output root for the per-vocabulary result directories and the navsim devkit checkout.
# Both come straight from the environment and are ``None`` when the variable is unset.
trajpdm_root = os.environ.get('NAVSIM_TRAJPDM_ROOT')
devkit_root = os.environ.get('NAVSIM_DEVKIT_ROOT')

# Hydra config directory / config file name consumed by ``@hydra.main`` below.
CONFIG_PATH = "{}/navsim/planning/script/config/pdm_scoring".format(devkit_root)
CONFIG_NAME = "expanded_run_pdm_score"


@hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME)
def main(cfg: DictConfig) -> None:
    """Score a k-means trajectory vocabulary on every filtered scene and pickle the results.

    Loads ``{devkit_root}/traj_final/test_{vocab_size}_kmeans.npy`` and builds one task per
    scene token. The tasks are distributed over the configured worker pool via ``worker_map``
    and ``run_pdm_score``. The collected ``{token: score}`` dict is written to
    ``{trajpdm_root}/vocab_expanded_{vocab_size}_{scene_filter_name}/{scene_filter_name}.pkl``.

    Tokens whose worker row has no ``"score"`` (scoring raised in the worker) are logged and
    left out of the output instead of aborting the whole run.

    :param cfg: Hydra config. Reads ``vocab_size``, ``scene_filter_name``,
        ``navsim_log_path`` and ``scene_filter`` here, plus whatever ``build_logger``,
        ``build_worker`` and ``run_pdm_score`` read.
    :raises RuntimeError: if ``NAVSIM_TRAJPDM_ROOT`` is not set. Without this check the
        output would silently land in a relative ``None/`` directory.
    """
    if not trajpdm_root:
        raise RuntimeError("Environment variable NAVSIM_TRAJPDM_ROOT must be set")

    vocab_size = cfg.vocab_size
    scene_filter_name = cfg.scene_filter_name
    traj_path = f"{devkit_root}/traj_final/test_{vocab_size}_kmeans.npy"
    result_dir = f'vocab_expanded_{vocab_size}_{scene_filter_name}'

    build_logger(cfg)
    worker = build_worker(cfg)
    vocab = np.load(traj_path)
    # Extract scenes based on scene-loader to know which tokens to distribute across workers
    scene_loader = SceneLoader(
        sensor_blobs_path=None,
        data_path=Path(cfg.navsim_log_path),
        scene_filter=instantiate(cfg.scene_filter),
        sensor_config=SensorConfig.build_no_sensors(),
    )
    os.makedirs(f'{trajpdm_root}/{result_dir}', exist_ok=True)
    result_path = f'{trajpdm_root}/{result_dir}/{scene_filter_name}.pkl'
    print(f'Results will be written to {result_path}')

    # One task per scene token; cfg/result_dir/vocab are shared by every task.
    data_points = [
        {
            "cfg": cfg,
            "result_dir": result_dir,
            "log_file": log_file,
            "token": token,
            "vocab": vocab,
        }
        for log_file, tokens_list in scene_loader.get_tokens_list_per_log().items()
        for token in tokens_list
    ]

    # worker_map calls run_pdm_score on chunks of data_points and concatenates the returned row lists.
    score_rows: List[Dict[str, Any]] = worker_map(worker, run_pdm_score, data_points)

    final: Dict[str, Any] = {}
    failed_tokens: List[str] = []
    for row in score_rows:
        # Rows for which scoring raised carry only the token, not a score.
        if "score" in row:
            final[row["token"]] = row["score"]
        else:
            failed_tokens.append(row["token"])
    if failed_tokens:
        logger.warning(
            f"{len(failed_tokens)} of {len(score_rows)} tokens failed to score and are "
            f"missing from {result_path}"
        )

    with open(result_path, 'wb') as f:
        pickle.dump(final, f)


def _load_cached_score(path: str) -> Any:
    """Unpickle and return the cached PDM result stored at ``path``."""
    with open(path, 'rb') as f:
        return pickle.load(f)


def _save_cached_score(path: str, value: Any) -> None:
    """Pickle ``value`` to ``path`` through a sibling temp file and ``os.replace``.

    The result is written to a temporary file in the same directory first and then renamed
    into place. This way an interrupted run never leaves a truncated ``tmp.pkl`` that a
    later run would try to load.
    """
    os.makedirs(os.path.dirname(path), exist_ok=True)
    partial_path = f"{path}.{uuid.uuid4().hex}.partial"
    try:
        with open(partial_path, 'wb') as f:
            pickle.dump(value, f)
        os.replace(partial_path, path)
    except BaseException:
        # Don't leave orphaned partial files behind when the dump or rename fails.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise


def run_pdm_score(args: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Score the trajectory vocabulary on every scene token in one worker chunk.

    Each element of ``args`` is a task dict built by ``main`` with the keys ``cfg``,
    ``result_dir``, ``log_file``, ``token`` and ``vocab``. ``cfg``, ``result_dir`` and
    ``vocab`` are taken from the first element, on the assumption that they are identical
    for every task.

    The per-token result is cached at ``{trajpdm_root}/{result_dir}/{token}/tmp.pkl`` and
    reused on later runs unless ``cfg.force_recompute_tmp`` is true. An unreadable cache
    file is recomputed.

    :param args: chunk of task dicts; may be empty.
    :return: one ``{"token": token, "score": result}`` dict per evaluated token. ``"score"``
        is absent when scoring raised. Tokens missing from either the scene loader or the
        metric cache are skipped and produce no row.
    """
    if not args:
        return []

    node_id = int(os.environ.get("NODE_RANK", 0))
    thread_id = str(uuid.uuid4())
    logger.info(f"Starting worker in thread_id={thread_id}, node_id={node_id}")

    log_names = [a["log_file"] for a in args]
    tokens = [a["token"] for a in args]
    cfg: DictConfig = args[0]["cfg"]
    result_dir = args[0]["result_dir"]
    vocab = args[0]["vocab"]

    simulator: PDMSimulator = instantiate(cfg.simulator)
    scorer = instantiate(cfg.scorer)
    assert simulator.proposal_sampling == scorer.proposal_sampling, "Simulator and scorer proposal sampling has to be identical"

    metric_cache_loader = MetricCacheLoader(Path(cfg.metric_cache_path))
    scene_filter: SceneFilter = instantiate(cfg.scene_filter)
    scene_filter.log_names = log_names
    scene_filter.tokens = tokens
    scene_loader = SceneLoader(
        sensor_blobs_path=Path(cfg.sensor_blobs_path),
        data_path=Path(cfg.navsim_log_path),
        scene_filter=scene_filter,
    )

    tokens_to_evaluate = list(set(scene_loader.tokens) & set(metric_cache_loader.tokens))
    skipped = len(set(tokens) - set(tokens_to_evaluate))
    if skipped:
        logger.warning(
            f"{skipped} token(s) have no loaded scene or metric cache and are skipped "
            f"in thread_id={thread_id}, node_id={node_id}"
        )

    force_recompute = cfg.get('force_recompute_tmp', False)
    expansion_only = cfg.get('expansion_only', True)

    pdm_results: List[Dict[str, Any]] = []
    for idx, token in enumerate(tokens_to_evaluate):
        logger.info(
            f"Processing scenario {idx + 1} / {len(tokens_to_evaluate)} in thread_id={thread_id}, node_id={node_id}"
        )
        score_row: Dict[str, Any] = {"token": token}
        tmp_cache_path = f'{trajpdm_root}/{result_dir}/{token}/tmp.pkl'

        if not force_recompute and os.path.exists(tmp_cache_path):
            try:
                score_row['score'] = _load_cached_score(tmp_cache_path)
            except Exception:
                # A corrupt/unreadable cache entry is recomputed below instead of failing the token.
                logger.warning(f"Could not read cached result {tmp_cache_path}, recomputing", exc_info=True)
            else:
                print(f'Exists: {tmp_cache_path}')
                pdm_results.append(score_row)
                continue

        try:
            metric_cache_path = metric_cache_loader.metric_cache_paths[token]
            with lzma.open(metric_cache_path, "rb") as f:
                metric_cache: MetricCache = pickle.load(f)

            # transform vocab into traj
            pdm_result = pdm_score_expanded(
                metric_cache=metric_cache,
                vocab_trajectory=vocab,
                future_sampling=simulator.proposal_sampling,
                simulator=simulator,
                scorer=scorer,
                expansion_only=expansion_only,
            )
            # Record the score before caching so a failed cache write does not discard it.
            score_row['score'] = pdm_result
            _save_cached_score(tmp_cache_path, pdm_result)
        except Exception:
            logger.warning(f"----------- Agent failed for token {token}:", exc_info=True)

        pdm_results.append(score_row)
    return pdm_results


if __name__ == "__main__":
    # ``main`` is wrapped by ``@hydra.main``, which builds ``cfg`` from CONFIG_PATH/CONFIG_NAME
    # (plus any command-line overrides Hydra accepts), so it is called without arguments.
    main()