"""Combine per-metric subscores into a single weighted ``total`` score.

Reads ``<root>/<subscores_name>.pkl`` (a dict mapping each token to a dict of
per-metric arrays), adds a ``'total'`` entry to every token's dict, and writes
the result to ``<root>/<subscores_name>-total.pkl``.
"""
import os
import pickle

import torch

# NOTE: if NAVSIM_EXP_ROOT is unset, os.getenv returns None and the path
# starts with the literal text "None"; the open() below will then fail.
root = f'{os.getenv("NAVSIM_EXP_ROOT")}/v299_vis'
subscores_name = 'v299-subscores'


def _total_score(subscore):
    """Return the weighted total score of one token as a torch tensor.

    ``subscore`` must hold numpy arrays under the keys ``'imi'``, ``'noc'``,
    ``'da'``, ``'ttc'``, ``'comfort'`` and ``'progress'``; any other keys
    (e.g. ``'trajectory'``) are ignored.

    Raises:
        KeyError: if one of the six metrics is missing.
        TypeError: if one of the six values is not a numpy array.
    """
    imi, noc, da, ttc, comfort, progress = (
        torch.from_numpy(subscore[key])
        for key in ('imi', 'noc', 'da', 'ttc', 'comfort', 'progress')
    )
    # 'ttc', 'comfort' and 'progress' are merged in exp-space: the log of the
    # 5:2:5 weighted mean of their exponentials, scaled by 8.  The expression
    # order is kept as-is so results stay bit-identical to earlier runs.
    return (
        0.02 * imi +
        0.7 * noc +
        0.1 * da +
        8.0 * ((
            5 * torch.exp(ttc) +
            2 * torch.exp(comfort) +
            5 * torch.exp(progress)
        ) / 12.0).log()
    )


# SECURITY: pickle.load can execute arbitrary code embedded in the file;
# only run this on subscore files produced by a trusted pipeline.
with open(f'{root}/{subscores_name}.pkl', 'rb') as src:
    subscores = pickle.load(src)

# Only the total is new; the existing per-metric arrays are left as loaded,
# so no numpy -> torch -> numpy round trip is needed for them.
for subscore in subscores.values():
    subscore['total'] = _total_score(subscore).numpy()

with open(f'{root}/{subscores_name}-total.pkl', 'wb') as dst:
    pickle.dump(subscores, dst)