|
import io |
|
import logging |
|
import os |
|
import pickle |
|
import uuid |
|
from pathlib import Path |
|
|
|
import hydra |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
from hydra.utils import instantiate |
|
from matplotlib.collections import LineCollection |
|
from nuplan.planning.utils.multithreading.worker_utils import worker_map |
|
from omegaconf import DictConfig |
|
from tqdm import tqdm |
|
|
|
from navsim.common.dataclasses import AgentInput, Scene |
|
from navsim.common.dataclasses import SensorConfig |
|
from navsim.common.dataloader import SceneLoader |
|
from navsim.planning.script.builders.worker_pool_builder import build_worker |
|
|
|
# Module-level logger (used by worker_task).
logger = logging.getLogger(__name__)

# Hydra config location/name consumed by the @hydra.main decorator on main().
CONFIG_PATH = "../../planning/script/config/pdm_scoring"
CONFIG_NAME = "run_pdm_score_ddp"

# Shared colour mapping for the per-metric score plots: values in [0, 1] -> viridis.
cmap = plt.get_cmap('viridis')
norm = plt.Normalize(vmin=0.0, vmax=1.0)
|
|
|
def get_distribution(scores, vocab, gt_traj):
    """Render the ground truth and per-metric vocabulary scores as one PIL image.

    Draws a 2x3 grid of subplots. The 'gt' panel plots ``gt_traj`` in red. Each
    other panel draws every vocabulary trajectory as a line, coloured by its
    score for that metric through the module-level ``cmap``/``norm``.

    Args:
        scores: Mapping from metric name ('noc', 'tl', 'progress', 'lk', 'dr')
            to an iterable of scores. Presumably one score per vocabulary
            entry, in ``vocab`` order -- TODO confirm against the tmp.pkl producer.
        vocab: Array of vocabulary trajectories. Only ``vocab[..., :2]`` is used
            as line coordinates (assumed shape (N, T, >=2) -- TODO confirm).
        gt_traj: Array whose columns 0 and 1 are plotted as x and y.

    Returns:
        A fully loaded ``PIL.Image.Image`` holding the rendered PNG.
    """
    metrics = ['gt', 'noc', 'tl', 'progress', 'lk', 'dr']
    fig, axes = plt.subplots(2, 3, figsize=(16.2, 10.8))
    try:
        for ax, metric in zip(axes.flat, metrics):
            ax.set_xlim(-5, 65)
            ax.set_ylim(-25, 25)
            ax.set_title(f"Metric {metric}")

            if metric == 'gt':
                ax.plot(gt_traj[:, 0], gt_traj[:, 1], c='r', alpha=1.0)
                continue

            vocab_scores = scores[metric]
            line_collection = LineCollection(
                vocab[..., :2],
                colors=[cmap(norm(score)) for score in vocab_scores],
                # Candidates scoring <= 0.1 are drawn almost transparent so the
                # well-scored part of the vocabulary stands out.
                alpha=[1.0 if score > 0.1 else 0.001 for score in vocab_scores],
            )
            ax.add_collection(line_collection)

        # Single shared colorbar in a dedicated axis to the right of the grid.
        fig.colorbar(
            plt.cm.ScalarMappable(norm=norm, cmap=cmap),
            cax=fig.add_axes([0.92, 0.15, 0.02, 0.7]),
        )
        # Leave the right 10% of the canvas for the colorbar axis.
        fig.tight_layout(rect=[0, 0, 0.9, 1])

        with io.BytesIO() as buf:
            fig.savefig(buf, format='png')
            buf.seek(0)
            # Image.open is lazy; copy() decodes the pixels so the returned
            # image no longer depends on the buffer being closed here.
            image = Image.open(buf).copy()
    finally:
        # pyplot keeps every figure alive until it is closed. Without this,
        # each call (one per token) leaks a figure.
        plt.close(fig)

    return image
|
|
|
|
|
def worker_task(args):
    """Render and save one side-by-side visualisation PNG per work item.

    For each item this:
    - loads the AgentInput and the ground-truth future trajectory for the token;
    - takes the front camera image of the last history frame;
    - renders the score maps with ``get_distribution``;
    - pastes camera image and score maps side by side;
    - writes the result to ``<result_dir>/<token>/<token>.png``.

    Args:
        args: Iterable of dicts with keys 'token', 'gt_scores', 'vocab',
            'scene_loader' and 'result_dir' (as assembled in ``main``).

    Returns:
        An empty list; this task works purely through its file side effects.
    """
    node_id = int(os.environ.get("NODE_RANK", 0))
    thread_id = str(uuid.uuid4())
    logger.info(f"Starting worker in thread_id={thread_id}, node_id={node_id}")

    # Number of future frames loaded into the Scene.
    num_future_frames = 10
    # Number of future poses plotted: int(4 / 0.5) == 8
    # (presumably 4 s at 0.5 s spacing -- confirm).
    num_gt_poses = int(4 / 0.5)

    for arg in tqdm(args, desc="Running visualization"):
        token, gt_scores, vocab = arg['token'], arg['gt_scores'], arg['vocab']
        scene_loader = arg['scene_loader']
        scene_frames = scene_loader.scene_frames_dicts[token]
        num_history_frames = scene_loader._scene_filter.num_history_frames

        agent_input = AgentInput.from_scene_dict_list(
            scene_frames,
            scene_loader._sensor_blobs_path,
            num_history_frames,
            scene_loader._sensor_config,
        )
        gt_traj = Scene.from_scene_dict_list(
            scene_frames,
            scene_loader._sensor_blobs_path,
            num_history_frames,
            num_future_frames,
            scene_loader._sensor_config,
        ).get_future_trajectory(num_gt_poses).poses

        # Front camera of the most recent history frame.
        front_cam = agent_input.cameras[-1].cam_f0
        img = Image.fromarray(front_cam.image.astype('uint8'), 'RGB')

        score_maps = get_distribution(gt_scores, vocab, gt_traj)

        canvas = Image.new(
            'RGB',
            (img.width + score_maps.width, max(img.height, score_maps.height)),
        )
        canvas.paste(img, (0, 0))
        canvas.paste(score_maps, (img.width, 0))

        # Use this item's own result_dir (the original always read args[0]'s).
        output_dir = arg['result_dir']
        canvas.save(f'{output_dir}/{token}/{token}.png')

    return []
|
|
|
|
|
@hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME)
def main(cfg: DictConfig) -> None:
    """Collect per-token score pickles and render visualisations in parallel.

    Steps:
    1. Read result folders from
       ``$NAVSIM_TRAJPDM_ROOT/vocab_expanded_<vocab_size>_<scene_filter_name>``.
    2. Keep the folders whose name is a token of the filtered SceneLoader.
    3. Load each token's ``tmp.pkl`` score dict.
    4. Dispatch ``worker_task`` over the resulting work items via ``worker_map``.

    Args:
        cfg: Hydra config. This function reads ``navsim_log_path``,
            ``sensor_blobs_path``, ``scene_filter``, ``vocab_size`` and
            ``scene_filter_name``; ``build_worker`` may read more.

    Raises:
        RuntimeError: If NAVSIM_TRAJPDM_ROOT or NAVSIM_DEVKIT_ROOT is unset.
    """
    trajpdm_root = os.getenv("NAVSIM_TRAJPDM_ROOT")
    devkit_root = os.getenv("NAVSIM_DEVKIT_ROOT")
    if trajpdm_root is None or devkit_root is None:
        # Fail fast instead of building paths that start with the string "None".
        raise RuntimeError(
            "Environment variables NAVSIM_TRAJPDM_ROOT and NAVSIM_DEVKIT_ROOT must be set"
        )

    scene_loader = SceneLoader(
        data_path=Path(cfg.navsim_log_path),
        scene_filter=instantiate(cfg.scene_filter),
        sensor_blobs_path=Path(cfg.sensor_blobs_path),
        # worker_task only reads the front camera (cam_f0), so skip loading the
        # other cameras for every frame.
        sensor_config=SensorConfig(
            cam_f0=True,
            cam_l0=False,
            cam_l1=False,
            cam_l2=False,
            cam_r0=False,
            cam_r1=False,
            cam_r2=False,
            cam_b0=False,
            lidar_pc=False,
        ),
    )
    worker = build_worker(cfg)

    result_dir = f'{trajpdm_root}/vocab_expanded_{cfg.vocab_size}_{cfg.scene_filter_name}'
    vocab = np.load(f'{devkit_root}/traj_final/test_{cfg.vocab_size}_kmeans.npy')

    # Keep only entries of result_dir that are tokens of the filtered scenes.
    # Sort them so the work order is reproducible.
    valid_tokens = sorted(set(os.listdir(result_dir)) & set(scene_loader.tokens))

    data_points = []
    for token in tqdm(valid_tokens):
        score_path = f'{result_dir}/{token}/tmp.pkl'
        if not os.path.isfile(score_path):
            logger.warning(f"Skipping token {token}: missing {score_path}")
            continue
        # NOTE(review): pickle.load can execute arbitrary code; only point this
        # script at result directories you trust.
        with open(score_path, 'rb') as f:
            gt_scores = pickle.load(f)
        data_points.append({
            'token': token,
            'scene_loader': scene_loader,
            'result_dir': result_dir,
            'vocab': vocab,
            'gt_scores': gt_scores,
        })

    worker_map(worker, worker_task, data_points)
|
|
|
|
|
if __name__ == '__main__':
    # The @hydra.main decorator parses the command line and supplies cfg.
    main()
|
|