# navsim_ours/navsim/agents/scripts/gen_vocab_kmeans.py
# Uploaded by lkllkl via huggingface_hub (commit da2e2ac, verified; 1.81 kB).
import pickle
import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm
from sklearn.cluster import MiniBatchKMeans
# torch.no_grad() has no effect on the body (no torch ops); kept so the entry point is unchanged.
@torch.no_grad()
def main(vocab_size=8192, split='test'):
    """Build a trajectory vocabulary with k-means and save it under ``./traj_final``.

    Loads ``./traj_local/{split}-pt1.npy`` .. ``-pt3.npy``, concatenates them along
    axis 0, flattens every trajectory and fits ``MiniBatchKMeans`` with
    ``vocab_size`` clusters. Writes:

    * ``./traj_final/{split}_{vocab_size}_kmeans.npy``: cluster centres reshaped
      back to ``(vocab_size, horizon, dim)``;
    * ``./traj_final/{split}_{vocab_size}_kmeans_cnt.npy``: int64 number of
      trajectories assigned to each cluster, clipped to ``[1, vocab_size]``.

    The anchors are then plotted with ``vis``.

    Args:
        vocab_size: number of clusters (anchors) to fit. Defaults to the previous
            hard-coded value, 8192.
        split: split name used in both input and output file names.
    """
    import os

    # Assumes each shard is (N_i, horizon, dim) with matching trailing dims — TODO confirm.
    ori_traj = np.concatenate(
        [np.load(f'./traj_local/{split}-pt{part}.npy') for part in (1, 2, 3)],
        axis=0,
    )
    num_traj, horizon, dim = ori_traj.shape

    # MINI-BATCH k-means operates on flat feature vectors, so flatten (horizon, dim).
    all_traj = ori_traj.reshape(num_traj, -1)
    clustering = MiniBatchKMeans(
        n_clusters=vocab_size, batch_size=1024, verbose=True, tol=0.0
    ).fit(all_traj)
    anchors = clustering.cluster_centers_.reshape(vocab_size, horizon, dim)

    # Single O(N) pass; minlength keeps empty clusters as explicit zeros.
    cnt = np.bincount(clustering.labels_, minlength=vocab_size).astype(np.int64)
    # Lower bound 1 removes zero counts for empty clusters.
    # NOTE(review): the upper bound vocab_size compares a sample count with a cluster
    # count and caps large clusters; kept to preserve existing outputs — confirm intent.
    cnt = np.clip(cnt, 1, vocab_size)

    out_dir = './traj_final'
    os.makedirs(out_dir, exist_ok=True)
    anchor_path = f'{out_dir}/{split}_{vocab_size}_kmeans.npy'
    np.save(anchor_path, anchors)
    np.save(f'{out_dir}/{split}_{vocab_size}_kmeans_cnt.npy', cnt)
    print(f'result saved to {anchor_path}')

    # plot
    vis(anchors)
def vis(data):
    """Draw every trajectory in ``data`` as a 2-D polyline and show the figure.

    Args:
        data: array indexed as ``data[i, t, c]``. Channel 0 is drawn on the x axis
            and channel 1 on the y axis (presumably (vocab_size, horizon, >=2) —
            confirm against callers).
    """
    _, ax = plt.subplots()
    # With 2-D inputs matplotlib draws one line per column, i.e. one per trajectory,
    # in the same colour-cycle order as plotting the rows one by one.
    ax.plot(data[:, :, 0].T, data[:, :, 1].T)
    # No ax.legend(): none of the lines carry a label, so a legend would only emit a
    # "No artists with labels found" warning and draw an empty box.
    plt.show()
def vis_pdm(data, pdm, threshold=0.95):
    """Plot the anchors whose score in the first entry of ``pdm`` exceeds ``threshold``.

    Only the first value of ``pdm`` (in dict iteration order) is visualised. An
    empty ``pdm`` draws nothing.

    Args:
        data: anchors indexed as ``data[i, t, c]``. Channels 0 and 1 are the x/y
            coordinates.
        pdm: mapping whose values are per-anchor scores supporting ``> threshold``
            and indexing by anchor id (presumably 1-D arrays of length
            ``data.shape[0]`` — confirm).
        threshold: anchors with a score strictly above this value are drawn.
            Defaults to the previous hard-coded 0.95.
    """
    scores = next(iter(pdm.values()), None)
    if scores is None:
        return
    mask = scores > threshold

    _, ax = plt.subplots()
    for i in range(data.shape[0]):
        if mask[i]:
            ax.plot(data[i, :, 0], data[i, :, 1])
    # No ax.legend(): the plotted lines have no labels, so it would only warn.
    plt.show()
if __name__ == '__main__':
    # Uncomment to (re)build the k-means vocabulary instead of visualising it.
    # main()
    # NOTE(review): loads the 4096-anchor vocabulary while main() defaults to 8192 — confirm.
    anchors = np.load('./traj_final/test_4096_kmeans.npy')
    # pickle can execute arbitrary code on load; only use with trusted, locally produced files.
    with open('./vocab_score_local/tiny.pkl', 'rb') as f:
        pdm_scores = pickle.load(f)
    vis_pdm(anchors, pdm_scores)