import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.cluster import MiniBatchKMeans
from tqdm import tqdm

from navsim.agents.scripts.gen_vocab_kmeans import vis

# Trajectory chunks loaded by default; they are concatenated along axis 0.
DEFAULT_TRAJ_FILES = (
    './traj_local/test-pt1.npy',
    './traj_local/test-pt2.npy',
    './traj_local/test-pt3.npy',
)


def _prepare_trajectories(traj, horizon=30, stride=5, swap_xy=True):
    """Crop, temporally subsample and optionally swap the two coordinate columns.

    Args:
        traj: array indexed as (trajectory, timestep, feature). Only the first
            two feature columns are kept (presumably x and y — confirm against
            whatever produced the input files).
        horizon: number of leading timesteps kept before subsampling.
        stride: keep timesteps ``stride-1, 2*stride-1, ...`` inside the horizon,
            i.e. ``horizon // stride`` points (4, 9, ..., 29 with the defaults).
        swap_xy: reverse the order of the two kept feature columns.

    Returns:
        Array of shape ``(N, horizon // stride, 2)``.

    Raises:
        ValueError: if ``traj`` is not 3-D, has fewer than two feature columns,
            or has fewer than ``horizon`` timesteps.
    """
    if traj.ndim != 3 or traj.shape[2] < 2:
        raise ValueError(
            f'expected an array of shape (N, T, >=2), got {traj.shape}')
    if traj.shape[1] < horizon:
        # Without this check the fancy index below fails with a bare IndexError.
        raise ValueError(
            f'expected at least {horizon} timesteps, got {traj.shape[1]}')
    traj = traj[:, :horizon, :2]
    sampled_timepoints = np.arange(stride - 1, horizon, stride)
    traj = traj[:, sampled_timepoints]
    if swap_xy:
        traj = traj[..., ::-1]
    return traj


# The script runs no torch ops; no_grad only disables autograd, so it is inert
# here and is kept only to preserve the original entry point unchanged.
@torch.no_grad()
def main(
    vocab_size=4096,
    shift_xy=True,
    traj_files=DEFAULT_TRAJ_FILES,
    output_dir='./traj_final',
    random_state=None,
):
    """Cluster subsampled trajectories with k-means and save the cluster centres.

    The trajectory arrays in ``traj_files`` are loaded and concatenated. Each
    one is reduced to 6 points (timesteps 4, 9, ..., 29 of the first 30) of its
    first two feature columns, optionally with those columns swapped. The
    result is flattened and clustered into ``vocab_size`` groups with
    ``MiniBatchKMeans``. The centres, shaped ``(vocab_size, 6, 2)``, are saved
    to ``{output_dir}/{vocab_size}_kmeans_3sec_xy.npy`` and passed to ``vis``.
    The file name says "3sec"; 30 steps for 3 s would imply 10 Hz data, which
    is not verifiable here — confirm.

    Args:
        vocab_size: number of clusters (anchors) to fit.
        shift_xy: if true, swap the order of the two coordinate columns before
            clustering (the name says "shift", but the operation is a swap).
        traj_files: paths of ``.npy`` arrays indexed as (trajectory, timestep,
            feature) to concatenate.
        output_dir: directory for the result; created if it does not exist.
        random_state: forwarded to ``MiniBatchKMeans``. The default ``None``
            keeps the original non-deterministic behavior; pass an int for a
            reproducible vocabulary.

    Raises:
        ValueError: if the inputs have too few timesteps or fewer trajectories
            than ``vocab_size``.
        OSError: if an input file cannot be read.
    """
    ori_traj = np.concatenate([np.load(path) for path in traj_files], axis=0)
    ori_traj = _prepare_trajectories(ori_traj, swap_xy=shift_xy)
    n_traj, horizon, dim = ori_traj.shape
    if n_traj < vocab_size:
        raise ValueError(
            f'need at least {vocab_size} trajectories to fit {vocab_size} '
            f'clusters, got {n_traj}')

    # MiniBatchKMeans clusters flat vectors: one row of horizon * dim values
    # per trajectory. The centres are reshaped back to (horizon, dim) below.
    all_traj = ori_traj.reshape(n_traj, -1)
    clustering = MiniBatchKMeans(
        vocab_size,
        batch_size=1024,
        verbose=True,
        tol=0.0,
        random_state=random_state,
    ).fit(all_traj)
    anchors = clustering.cluster_centers_.reshape(vocab_size, horizon, dim)

    # Cluster sizes in a single pass over the labels, instead of vocab_size
    # separate comparisons against every label.
    cnt = np.bincount(clustering.labels_, minlength=vocab_size)
    print(f'{int((cnt == 0).sum())} of {vocab_size} clusters are empty')

    os.makedirs(output_dir, exist_ok=True)
    out_path = f'{output_dir}/{vocab_size}_kmeans_3sec_xy.npy'
    np.save(out_path, anchors)
    print(f'result saved to {out_path}')

    # plot (vis comes from navsim.agents.scripts.gen_vocab_kmeans; its exact
    # output is defined there)
    vis(anchors)


if __name__ == '__main__':
    main()