|
import copy |
|
import pickle |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from tqdm import tqdm |
|
|
|
# Directory that holds the per-model prediction pickles.
root = '/mnt/g/navsim_vis/subscores'
# Full path of the ground-truth score pickle (loaded as `gt` in main()).
gt_path = '/mnt/g/navsim/traj_pdm/vocab_score_full_8192_navtest/navtest.pkl'

# Prediction pickle file names, resolved relative to `root`.
dreamer_pkl, hydra_vitl_pkl = 'dreamer_wm_3f.pkl', 'hydra_vitl_subscores.pkl'
|
|
|
|
|
def analyze(results, threshold=0.5, gt_threshold=0.8):
    """Print per-metric agreement between predicted and ground-truth sub-scores.

    For every metric key in ``results['gt']`` this prints:

    * the mean binary cross-entropy of each model's predictions against the
      ground truth, and
    * for ``'progress'``, the squared error summed over all elements and
      divided by the number of samples; for every other metric, the accuracy
      after thresholding predictions and ground truth.

    Args:
        results: Mapping with keys ``'gt'``, ``'dreamer'`` and ``'hydra'``.
            Each value maps a metric name to a tensor. ``results['gt']`` must
            contain ``'noc'``. Predictions are fed directly to
            ``F.binary_cross_entropy``, so they must lie in [0, 1].
        threshold: Predictions ``>= threshold`` count as positive for accuracy.
        gt_threshold: Ground-truth scores ``>= gt_threshold`` count as
            positive for accuracy.
    """
    gt = results['gt']
    # Insertion order fixes the print order: dreamer lines before hydra lines.
    models = {'dreamer': results['dreamer'], 'hydra': results['hydra']}

    # Samples are concatenated along the last axis (see main()).
    length = gt['noc'].shape[-1]
    print(f'Data points: {length}')

    for metric, gt_curr in gt.items():
        target = gt_curr.float()
        # Cast predictions to float32 as well so input/target dtypes always match.
        preds = {name: scores[metric].float() for name, scores in models.items()}

        for name, pred in preds.items():
            bce = F.binary_cross_entropy(pred, target, reduction="mean")
            print(f'metric {metric}: bce {name}: {bce.item()}')

        if metric == 'progress':
            # NOTE(review): for 1-D tensors sum/length equals the mean. If each
            # sample holds several values (e.g. a (V, N) tensor), this is the
            # per-sample *summed* squared error, unlike the element-wise mean
            # used for BCE. Confirm which normalisation is intended.
            for name, pred in preds.items():
                mse = F.mse_loss(pred, target, reduction="sum") / length
                print(f'metric {metric}: mse {name}: {mse.item()}')
        else:
            for name, pred in preds.items():
                acc = ((pred >= threshold) == (gt_curr >= gt_threshold)).float().mean()
                print(f'metric {metric}: acc {name}: {acc.item()}')
|
|
|
|
|
def _load_pickle(path):
    """Load and return the object pickled at *path*, closing the file afterwards."""
    # NOTE: pickle.load can execute arbitrary code; only point this at trusted files.
    with open(path, 'rb') as f:
        return pickle.load(f)


def main():
    """Load ground-truth and per-model sub-score pickles, align them and compare.

    Only keys present in the ground truth, the dreamer predictions AND the
    hydra predictions are used. For every metric, each sample's array gets a
    trailing axis and the samples are concatenated along it. Prediction
    tensors are then exponentiated, while ground-truth tensors are used as-is,
    and everything is handed to ``analyze``. Tensors go to CUDA when it is
    available, otherwise they stay on the CPU.
    """
    gt = _load_pickle(gt_path)
    dreamer = _load_pickle(f'{root}/{dreamer_pkl}')
    hydra = _load_pickle(f'{root}/{hydra_vitl_pkl}')

    metrics = ('noc', 'da', 'ttc', 'comfort', 'progress')
    sources = ('gt', 'dreamer', 'hydra')
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Gather per-sample arrays on the CPU; they are moved to `device` once, below.
    collected = {source: {metric: [] for metric in metrics} for source in sources}

    # Require the key in both prediction files; a missing hydra entry used to raise KeyError.
    valid_keys = dreamer.keys() & hydra.keys()

    for k, gt_score in tqdm(gt.items()):
        if k not in valid_keys:
            continue
        for source, scores in (('gt', gt_score), ('dreamer', dreamer[k]), ('hydra', hydra[k])):
            for metric in metrics:
                collected[source][metric].append(torch.from_numpy(scores[metric][..., None]))

    if not collected['gt'][metrics[0]]:
        print('No samples shared by gt, dreamer and hydra; nothing to analyze.')
        return

    results = {}
    for source, per_metric in collected.items():
        results[source] = {}
        for metric, chunks in per_metric.items():
            tensor = torch.cat(chunks, dim=-1).to(device)
            # Only predictions are exponentiated (presumably stored as
            # log-probabilities — TODO confirm); GT scores are used as stored.
            results[source][metric] = tensor if source == 'gt' else tensor.exp()

    analyze(results)
|
|
|
|
|
if __name__ == '__main__':
    # The script only evaluates stored scores, so autograd tracking is
    # switched off for the whole run and restored on exit.
    with torch.set_grad_enabled(False):
        main()
|
|