|
import pickle
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import torch
|
|
from tqdm import tqdm
|
|
|
|
from sklearn.cluster import MiniBatchKMeans
|
|
|
|
|
|
@torch.no_grad()
def main():
    """Cluster trajectories into a k-means vocabulary and save the result.

    Loads the three ``{split}-pt{1,2,3}.npy`` shards from ``./traj_local``,
    concatenates them along the first axis, flattens each trajectory to a
    vector and fits ``MiniBatchKMeans`` with ``vocab_size`` clusters.

    Two arrays are written to ``./traj_final``:

    * ``{split}_{vocab_size}_kmeans.npy`` -- cluster centres reshaped back to
      ``(vocab_size, HORIZON, DIM)``.
    * ``{split}_{vocab_size}_kmeans_cnt.npy`` -- number of trajectories
      assigned to each cluster, clipped to ``[1, vocab_size]``.

    Finally the anchors are plotted with ``vis``.
    """
    vocab_size = 8192
    split = 'test'

    # Input shards use the same split name as the output files below.
    ori_traj = np.concatenate(
        [np.load(f'./traj_local/{split}-pt{part}.npy') for part in (1, 2, 3)],
        axis=0,
    )
    num_traj, horizon, dim = ori_traj.shape

    # K-means works on flat vectors: one row per trajectory.
    all_traj = ori_traj.reshape(num_traj, -1)
    clustering = MiniBatchKMeans(
        vocab_size, batch_size=1024, verbose=True, tol=0.0
    ).fit(all_traj)
    anchors = clustering.cluster_centers_.reshape(vocab_size, horizon, dim)

    # Count members of every cluster in a single pass; ``minlength`` keeps
    # an entry (0) for clusters that received no samples.
    cnt = np.bincount(clustering.labels_, minlength=vocab_size).astype(np.int64)
    # Empty clusters are reported as 1 and large ones are capped at vocab_size.
    # NOTE(review): the upper cap at vocab_size looks arbitrary -- confirm intent.
    cnt = np.clip(cnt, 1, vocab_size)

    anchors_path = f'./traj_final/{split}_{vocab_size}_kmeans.npy'
    np.save(anchors_path, anchors)
    np.save(f'./traj_final/{split}_{vocab_size}_kmeans_cnt.npy', cnt)
    print(f'result saved to {anchors_path}')

    vis(anchors)
|
|
|
|
|
|
def vis(data):
    """Draw every trajectory in ``data`` on one figure and display it.

    ``data`` is indexed as ``data[i, :, 0]`` and ``data[i, :, 1]``: the first
    axis enumerates trajectories and the first two entries of the last axis
    are plotted against each other.
    """
    _, axes = plt.subplots()

    for trajectory in data:
        axes.plot(trajectory[:, 0], trajectory[:, 1])

    axes.legend()
    plt.show()
|
|
|
|
def vis_pdm(data, pdm):
    """Plot the trajectories selected by the first entry of ``pdm``.

    Only the first ``(key, value)`` pair of ``pdm`` is used. Its value is
    thresholded at ``> 0.95`` and the resulting mask, indexed by trajectory
    number, decides which rows of ``data`` are drawn. Does nothing when
    ``pdm`` is empty.
    """
    first_entry = next(iter(pdm.items()), None)
    if first_entry is None:
        return

    _, scores = first_entry
    selected = scores > 0.95

    _, axes = plt.subplots()
    for idx in range(data.shape[0]):
        if not selected[idx]:
            continue
        axes.plot(data[idx, :, 0], data[idx, :, 1])

    axes.legend()
    plt.show()
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: show the saved 4096-anchor vocabulary, filtered by
    # the scores stored in the pickle file.
    anchors = np.load('./traj_final/test_4096_kmeans.npy')

    # Close the score file deterministically instead of leaking the handle.
    # NOTE: pickle can execute arbitrary code on load -- only use trusted files.
    with open('./vocab_score_local/tiny.pkl', 'rb') as score_file:
        pdm_scores = pickle.load(score_file)

    vis_pdm(anchors, pdm_scores)
|
|
|