|
import logging |
|
import os |
|
import pickle |
|
|
|
import numpy as np |
|
import torch |
|
|
|
logger = logging.getLogger(__name__)

# Search the weights used to fuse predicted sub-scores (imi, noc, da, ttc,
# comfort, progress) into a single score per candidate trajectory, and report
# the weight combination whose selected candidates achieve the highest mean
# 'total' (PDM) score on the navtest scenes. See ``main`` for the procedure.

# Pickle of per-scene predicted sub-scores.
pkl_path = '/mnt/g/navsim_vis/subscores/dreamer_wm_2sec.pkl'

# Pickle whose keys select which scenes are evaluated.
# NOTE(review): currently identical to ``pkl_path``, so every scene that has a
# prediction is evaluated — confirm this is intended and not a copy-paste slip.
valid_k_path = '/mnt/g/navsim_vis/subscores/dreamer_wm_2sec.pkl'
|
|
|
def _load_pickle(path):
    """Return the object unpickled from ``path``, closing the file afterwards."""
    # NOTE(review): pickle.load can execute arbitrary code; only use trusted files.
    with open(path, 'rb') as f:
        return pickle.load(f)


def _stack_field(records, field, device):
    """Concatenate ``record[field][None]`` over ``records`` and move it to ``device``.

    Each ``record[field]`` is passed to ``torch.from_numpy``, so it must be a
    NumPy array. The ``[None]`` adds a leading axis, so row ``i`` of the result
    comes from ``records[i]``.
    """
    rows = [torch.from_numpy(record[field][None]) for record in records]
    return torch.cat(rows, 0).to(device).contiguous()


def _fuse_scores(preds, imi_weight, noc_weight, da_weight, tpc_weight,
                 ttc_weight, comfort_weight, progress_weight):
    """Fuse the predicted sub-scores in ``preds`` into one score tensor.

    ``imi``, ``noc`` and ``da`` are added linearly. ``ttc``, ``comfort`` and
    ``progress`` are exponentiated, weighted, summed and taken back to log
    space; that term is scaled by ``tpc_weight``.
    """
    tpc_term = (
        ttc_weight * torch.exp(preds['ttc'])
        + comfort_weight * torch.exp(preds['comfort'])
        + progress_weight * torch.exp(preds['progress'])
    ).log()
    return (
        imi_weight * preds['imi']
        + noc_weight * preds['noc']
        + da_weight * preds['da']
        + tpc_weight * tpc_term
    )


def main() -> None:
    """Grid-search the score-fusion weights and print the best combination.

    Loads predicted sub-scores from ``pkl_path`` and the navtest reference
    scores. The scenes are restricted to the keys found in ``valid_k_path``.

    For every weight combination, the predictions are fused into one score per
    candidate and the argmax candidate is picked in each scene. The mean
    reference 'total' score of the picked candidates is computed. The
    combination with the highest mean is printed together with its sub-score
    means.

    Raises:
        ValueError: if no navtest scene survives the ``valid_k_path`` filter.
        KeyError: if a selected scene has no entry in the predictions pickle.
    """
    import itertools  # only this entry point needs it

    # Fall back to CPU instead of crashing when no GPU is present.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    merged_predictions = _load_pickle(pkl_path)
    # Both module-level paths may name the same file; don't unpickle it twice.
    if valid_k_path == pkl_path:
        valid_keys = set(merged_predictions.keys())
    else:
        valid_keys = set(_load_pickle(valid_k_path).keys())

    navtest_scores = _load_pickle(
        '/mnt/g/navsim/traj_pdm/vocab_score_full_8192_navtest/navtest.pkl'
    )
    navtest_scores = {key: value for key, value in navtest_scores.items() if key in valid_keys}

    # Candidate values per weight; the search covers their Cartesian product.
    imi_weights = [0.01]
    noc_weights = [0.1]
    da_weights = [0.5]
    tpc_weights = [3.0]
    ttc_weights = [5.0]
    progress_weights = [5.0]
    comfort_weights = [2.0]

    # Same order as the original nested loops (tpc innermost), so that ties
    # in the strict '>' comparison below keep resolving to the same combo.
    weight_grid = list(itertools.product(
        imi_weights, noc_weights, da_weights, ttc_weights,
        comfort_weights, progress_weights, tpc_weights,
    ))
    print(f'Search space: {len(weight_grid)}')

    total_scene_cnt = len(navtest_scores)
    print(f'total_scene_cnt: {total_scene_cnt}')
    if total_scene_cnt == 0:
        raise ValueError('No navtest scenes left after filtering by valid_k_path.')
    missing = [key for key in navtest_scores if key not in merged_predictions]
    if missing:
        raise KeyError(
            f'{len(missing)} navtest scenes have no predictions in {pkl_path}, '
            f'e.g. {missing[0]!r}'
        )

    # One row per scene, in navtest_scores iteration order, for both tables.
    scene_records = list(navtest_scores.values())
    pred_records = [merged_predictions[key] for key in navtest_scores]
    reference = {
        name: _stack_field(scene_records, name, device)
        for name in ('total', 'noc', 'da', 'dd', 'ttc', 'progress', 'comfort')
    }
    preds = {
        name: _stack_field(pred_records, name, device)
        for name in ('imi', 'noc', 'da', 'ttc', 'progress', 'comfort')
    }
    # Loop-invariant row index used to gather each scene's chosen candidate.
    scene_idx = torch.arange(total_scene_cnt, device=device)

    highest_info = {'score': float('-inf')}
    for combo_idx, (imi_weight, noc_weight, da_weight, ttc_weight,
                    comfort_weight, progress_weight, tpc_weight) in enumerate(weight_grid):
        scores = _fuse_scores(
            preds, imi_weight, noc_weight, da_weight, tpc_weight,
            ttc_weight, comfort_weight, progress_weight,
        )
        chosen_idx = scores.argmax(-1)
        selected = {name: table[scene_idx, chosen_idx] for name, table in reference.items()}
        pdm_score = selected['total'].mean().item()

        if pdm_score > highest_info['score']:
            # Key order is kept stable because it determines the final printout.
            highest_info = {
                'score': pdm_score,
                'noc': selected['noc'].float().mean().item(),
                'da': selected['da'].float().mean().item(),
                'dd': selected['dd'].float().mean().item(),
                'ttc': selected['ttc'].float().mean().item(),
                'progress': selected['progress'].float().mean().item(),
                'comfort': selected['comfort'].float().mean().item(),
                'imi_weight': imi_weight,
                'noc_weight': noc_weight,
                'da_weight': da_weight,
                'ttc_weight': ttc_weight,
                'progress_weight': progress_weight,
                'comfort_weight': comfort_weight,
                'tpc_weight': tpc_weight,
            }
            print(f'Done: {combo_idx}. score: {pdm_score}')

    for k, v in highest_info.items():
        print(k, v)
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point. The search only evaluates tensors, so wrap main()
    # with autograd disabled (torch.no_grad used as a function decorator).
    run_without_grad = torch.no_grad()(main)
    run_without_grad()
|
|