File size: 2,794 Bytes
da2e2ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from __future__ import annotations

import os
import uuid

import numpy as np
from nuplan.planning.scenario_builder.nuplan_db.nuplan_scenario_builder import NuPlanScenarioBuilder
from nuplan.planning.scenario_builder.nuplan_db.nuplan_scenario_utils import ScenarioMapping
from nuplan.planning.scenario_builder.scenario_filter import ScenarioFilter
from nuplan.planning.utils.multithreading.worker_utils import worker_map

from navsim.planning.simulation.planner.pdm_planner.utils.pdm_geometry_utils import \
    convert_absolute_to_relative_se2_array
from navsim.planning.utils.multithreading.worker_ray_no_torch import RayDistributedNoTorch


def get_local_ego_poses(scenarios):
    results = []
    thread_id = str(uuid.uuid4())
    for idx, scenario in enumerate(scenarios):
        print(
            f"Processing scenario {idx + 1} / {len(scenarios)} in thread_id={thread_id}"
        )
        init_ego_state = scenario.initial_ego_state
        future_traj = scenario.get_ego_future_trajectory(0, 4)
        local_ego_poses = convert_absolute_to_relative_se2_array(
            init_ego_state.center, np.array([tmp.center.serialize() for tmp in future_traj], dtype=np.float64)
        )
        results.append(local_ego_poses[None].astype(np.float32))
    return results


def main():
    root = '/mnt/g'
    split = 'test'
    logs = os.listdir(f'{root}/nuplan/nuplan-v1.1/splits/{split}')
    logs = [tmp.replace('.db', '') for tmp in logs]
    navsim_logs = [log.replace('.pkl', '') for log in os.listdir(f'{root}/navsim/navsim_logs/{split}')]
    start_idx = 400000
    end_idx = 600000
    save_dir = './traj_local'
    os.makedirs(save_dir, exist_ok=True)
    save_file = f'{save_dir}/{split}-pt3.npy'

    logs = list(set(logs) & set(navsim_logs))
    print(f'total logs: {len(logs)}')
    filter = ScenarioFilter(
        None, None,
        logs,
        None, None, None, None, None, False, False, False
    )
    worker = RayDistributedNoTorch(threads_per_node=16)

    builder = NuPlanScenarioBuilder(
        data_root=f'{root}/nuplan/',
        map_root=f'{root}/nuplan/maps',
        sensor_root=f'{root}/nuplan/',
        db_files=f'{root}/nuplan/nuplan-v1.1/splits/{split}',
        map_version='nuplan-maps-v1.0',
        scenario_mapping=ScenarioMapping({}, 0.5)
    )
    scenarios = builder.get_scenarios(filter, worker)

    print(f'total scenarios: {len(scenarios)}, now: {start_idx} to {end_idx}')
    all_ego_poses = worker_map(worker, get_local_ego_poses, scenarios[start_idx:end_idx])

    all_ego_poses = np.concatenate(all_ego_poses, axis=0)
    print(f'save to: {save_file}')
    np.save(save_file, all_ego_poses)
    print(all_ego_poses.shape)


if __name__ == '__main__':
    main()