import os
import pickle
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.cluster import MiniBatchKMeans
from tqdm import tqdm


def _farthest_point_sample(trajs, num_samples, start_idx):
    """Greedily select trajectories by farthest-point sampling.

    The distance between two trajectories is the squared Euclidean distance
    between the first two coordinates of their final timestep
    (``traj[-1, :2]``). Starting from ``trajs[start_idx]``, each step adds the
    not-yet-selected trajectory whose distance to its *nearest* already
    selected trajectory is largest.

    Args:
        trajs: array indexed as ``trajs[n, t, d]``; the last timestep and the
            first two feature columns are used for distances.
        num_samples: number of trajectories to select; clipped to the number
            of available trajectories.
        start_idx: index of the first trajectory to select (must be valid).

    Returns:
        A new array of shape ``(min(num_samples, N),) + trajs.shape[1:]``
        holding the selected trajectories in selection order. Each trajectory
        is selected at most once.
    """
    num_trajs = trajs.shape[0]
    num_samples = min(num_samples, num_trajs)
    if num_samples <= 0:
        return np.copy(trajs[:0])

    endpoints = trajs[:, -1, :2].astype(np.float64)
    selected_idx = [start_idx]
    # min_dist[i]: squared distance from trajectory i to its nearest selected
    # trajectory. Selected entries are pinned to -inf so argmax never re-picks
    # them, even when all remaining distances are 0 (duplicate endpoints).
    min_dist = ((endpoints - endpoints[start_idx]) ** 2).sum(-1)
    min_dist[start_idx] = -np.inf
    for _ in tqdm(range(num_samples - 1)):
        nxt = int(np.argmax(min_dist))
        selected_idx.append(nxt)
        new_dist = ((endpoints - endpoints[nxt]) ** 2).sum(-1)
        min_dist = np.minimum(min_dist, new_dist)
        min_dist[nxt] = -np.inf
    # Fancy indexing returns a copy, so the result does not alias `trajs`.
    return trajs[np.asarray(selected_idx)]


@torch.no_grad()
def main():
    """Build a trajectory vocabulary by farthest-point sampling and save it.

    Loads trajectories from ``./traj_local/{split}-pt1.npy``, selects up to
    ``vocab_size`` of them that are mutually far apart (by final-step x/y
    position), saves them to ``./traj_final/{split}_{vocab_size}_far.npy`` and
    plots the result.

    Raises:
        ValueError: if the loaded data is not 3-D or contains no trajectories.
    """
    vocab_size = 512
    split = 'test'
    ori_traj = np.concatenate([
        np.load(f'./traj_local/{split}-pt1.npy'),
        # Additional shards, currently disabled:
        # np.load(f'./traj_local/{split}-pt2.npy'),
        # np.load(f'./traj_local/{split}-pt3.npy'),
    ], axis=0)
    if ori_traj.ndim != 3:
        raise ValueError(
            f'expected trajectories of rank 3, got shape {ori_traj.shape}')
    num_trajs = ori_traj.shape[0]
    if num_trajs == 0:
        raise ValueError('no trajectories loaded; cannot build a vocabulary')

    start = random.randint(0, num_trajs - 1)
    anchors = _farthest_point_sample(ori_traj, vocab_size, start)

    out_path = f'./traj_final/{split}_{vocab_size}_far.npy'
    # np.save does not create missing parent directories.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    np.save(out_path, anchors)
    print(f'result saved to {out_path}')

    # plot
    vis(anchors)


def vis(data):
    """Plot every trajectory in ``data`` as an x/y polyline and show it.

    Args:
        data: array indexed as ``data[i, t, d]``; columns 0 and 1 of the last
            axis are plotted as x and y.
    """
    _, ax = plt.subplots()
    for i in range(data.shape[0]):
        ax.plot(data[i, :, 0], data[i, :, 1])
    # No legend: the lines carry no labels, so ax.legend() would only warn.
    plt.show()


def vis_pdm(data, pdm):
    """Plot the trajectories whose score in the first ``pdm`` entry exceeds 0.95.

    Only the first item of ``pdm`` is visualised (the loop breaks after one
    figure).

    Args:
        data: array indexed as ``data[i, t, d]``; columns 0 and 1 of the last
            axis are plotted as x and y.
        pdm: mapping whose values are compared element-wise with 0.95 and
            indexed per trajectory; presumably one score per row of ``data``
            (TODO confirm against caller).
    """
    for _key, scores in pdm.items():
        mask = scores > 0.95
        _, ax = plt.subplots()
        for i in range(data.shape[0]):
            if mask[i]:
                ax.plot(data[i, :, 0], data[i, :, 1])
        # No legend: the lines carry no labels, so ax.legend() would only warn.
        plt.show()
        break


if __name__ == '__main__':
    main()