import os
import pickle

import matplotlib.pyplot as plt
import numpy as np


def _save_and_show(fig, path):
    """Write ``fig`` to ``path`` (creating its parent dir), display it, then free it.

    The figure is saved BEFORE ``plt.show()``: with an interactive backend,
    ``show()`` blocks and the figure is torn down once its window is closed,
    so a ``savefig`` issued afterwards would write an empty canvas.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    fig.savefig(path)
    plt.show()
    # Release the figure; otherwise every call leaks one open figure.
    plt.close(fig)


def vis(data, out_path='debug/traj.png'):
    """Plot every trajectory in ``data`` on one set of axes and save the image.

    Args:
        data: array whose first axis enumerates trajectories; each row is
            indexed as ``row[:, 0]`` (x-axis) and ``row[:, 1]`` (y-axis).
        out_path: destination image file; its parent directory is created
            if missing.
    """
    fig, ax = plt.subplots()
    for traj in data:
        ax.plot(traj[:, 0], traj[:, 1])
    _save_and_show(fig, out_path)


def vis_pdm(data, pdm, threshold=0.95, out_dir='debug'):
    """Highlight the trajectories whose score exceeds ``threshold``.

    For the first entry of ``pdm``, one figure is produced per score column.
    Trajectories at or below the threshold are drawn faint black, and those
    above it are drawn solid red on top. Each figure is saved as
    ``<out_dir>/traj_<m>.png``.

    Args:
        data: array indexed as ``data[i, :, 0]`` / ``data[i, :, 1]``;
            ``data.shape[0]`` is the number of trajectories.
        pdm: mapping ``k -> {m: v}``. Presumably each ``v`` holds one score
            per trajectory of ``data`` (only its first ``data.shape[0]``
            entries are read) -- TODO confirm against the score producer.
        threshold: scores strictly greater than this are highlighted.
        out_dir: directory the images are written to (created if missing).
    """
    if not pdm:
        return
    # Only the first entry of ``pdm`` is plotted: the output filenames depend
    # on ``m`` alone, so plotting further keys would overwrite these images.
    k, scores = next(iter(pdm.items()))
    print(k)

    vocab_size = data.shape[0]
    for m, v in scores.items():
        mask = np.asarray(v) > threshold
        fig, ax = plt.subplots()
        # Background first so the highlighted (red) trajectories stay on top.
        for i in range(vocab_size):
            if not mask[i]:
                ax.plot(data[i, :, 0], data[i, :, 1], 'k', alpha=0.1)
        for i in range(vocab_size):
            if mask[i]:
                ax.plot(data[i, :, 0], data[i, :, 1], 'r', alpha=1.0)
        _save_and_show(fig, os.path.join(out_dir, f'traj_{m}.png'))


if __name__ == '__main__':
    # vis(np.load('./traj_final/test_4096_kmeans.npy'))
    vocab = np.load('./traj_final/test_4096_kmeans.npy')
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file -- only point this at trusted score dumps.
    with open('/mnt/g/navsim/traj_pdm/vocab_score_full_4096/tiny.pkl', 'rb') as f:
        pdm_scores = pickle.load(f)
    vis_pdm(vocab, pdm_scores)