"""Build a k-means trajectory vocabulary and visualize the resulting anchors.

``main`` clusters flattened trajectories loaded from ``./traj_local`` with
``MiniBatchKMeans`` and stores the cluster centers plus per-cluster sample
counts under ``./traj_final``. ``vis`` and ``vis_pdm`` plot the first two
channels of each anchor.
"""
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm
from sklearn.cluster import MiniBatchKMeans


@torch.no_grad()
def main():
    """Cluster the test trajectories and save the anchors and their counts.

    Reads three ``.npy`` shards from ``./traj_local``, fits a mini-batch
    k-means with ``vocab_size`` clusters on the flattened trajectories,
    writes the centers and the clipped per-cluster counts to
    ``./traj_final`` and finally plots the centers with ``vis``.
    """
    vocab_size = 8192
    split = 'test'
    out_dir = './traj_final'
    # Create the output directory up front so the (long) clustering run
    # cannot be lost to a missing directory at save time.
    os.makedirs(out_dir, exist_ok=True)

    ori_traj = np.concatenate([
        np.load('./traj_local/test-pt1.npy'),
        np.load('./traj_local/test-pt2.npy'),
        np.load('./traj_local/test-pt3.npy')
    ], axis=0)
    L, HORIZON, DIM = ori_traj.shape

    # MINI-BATCH
    # Each trajectory is flattened to one feature vector of HORIZON * DIM.
    all_traj = ori_traj.reshape(L, -1)
    clustering = MiniBatchKMeans(
        vocab_size, batch_size=1024, verbose=True, tol=0.0
    ).fit(all_traj)
    anchors = clustering.cluster_centers_.reshape(vocab_size, HORIZON, DIM)

    # One pass over the labels instead of one full comparison per cluster.
    cnt = np.bincount(
        clustering.labels_, minlength=vocab_size
    ).astype(np.int64)
    # Clip counts into [1, vocab_size]; the lower bound keeps empty clusters
    # at 1. NOTE(review): the upper bound equals vocab_size - confirm that
    # capping large clusters at this value is intended.
    cnt = np.clip(cnt, 1, vocab_size)

    anchors_path = f'./traj_final/{split}_{vocab_size}_kmeans.npy'
    np.save(anchors_path, anchors)
    np.save(f'./traj_final/{split}_{vocab_size}_kmeans_cnt.npy', cnt)
    print(f'result saved to {anchors_path}')

    # plot
    vis(anchors)


def vis(data):
    """Plot channel 0 against channel 1 for every entry of ``data``.

    Args:
        data: array indexed as ``data[i, :, c]``; one line is drawn per
            index ``i`` using channels ``c = 0`` (x) and ``c = 1`` (y).
    """
    _, ax = plt.subplots()
    for traj in data:
        ax.plot(traj[:, 0], traj[:, 1])
    # No legend: the lines carry no labels, so ax.legend() had nothing to
    # show and only triggered a matplotlib warning.
    plt.show()


def vis_pdm(data, pdm):
    """Plot the entries of ``data`` whose score exceeds 0.95.

    Only the first value of ``pdm`` (in iteration order) is used; the
    remaining entries are ignored. Nothing is plotted when ``pdm`` is empty.

    Args:
        data: array indexed as ``data[i, :, c]``, as in ``vis``.
        pdm: mapping whose values are indexable per entry of ``data`` and
            comparable with ``> 0.95``.
    """
    if not pdm:
        return
    scores = next(iter(pdm.values()))
    mask = scores > 0.95

    _, ax = plt.subplots()
    for i in range(data.shape[0]):
        if mask[i]:
            ax.plot(data[i, :, 0], data[i, :, 1])
    plt.show()


if __name__ == '__main__':
    # main()
    # NOTE(review): this loads the 4096-entry vocabulary, while main()
    # currently writes an 8192-entry one - confirm the intended file.
    # SECURITY: pickle executes arbitrary code on load; only open trusted
    # files here.
    with open('./vocab_score_local/tiny.pkl', 'rb') as f:
        pdm_scores = pickle.load(f)
    vis_pdm(np.load('./traj_final/test_4096_kmeans.npy'), pdm_scores)