import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm

from sklearn.cluster import MiniBatchKMeans

from navsim.agents.scripts.gen_vocab_kmeans import vis


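# Build a k-means trajectory vocabulary: cluster ground-truth trajectories into
# `vocab_size` anchor trajectories and save them as a fixed vocabulary.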
@torch.no_grad()
def main():
    vocab_size = 4096
    shift_xy = True
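
    # Load the raw trajectories, keep the first 30 timesteps, and keep only the
    # first two state dimensions (assumed to be the x/y positions).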
    ori_traj = np.concatenate([
        np.load('./traj_local/test-pt1.npy'),
        np.load('./traj_local/test-pt2.npy'),
        np.load('./traj_local/test-pt3.npy')
    ], axis=0)
    ori_traj = ori_traj[:, :30, :2]
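
    # Subsample 6 evenly spaced waypoints (indices 4, 9, ..., 29), i.e. every
    # 5th timestep over the 3-second horizon.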
    sampled_timepoints = [5 * k - 1 for k in range(1, 7)]
    ori_traj = ori_traj[:, sampled_timepoints]
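
    # Optionally swap the two coordinates along the last axis (e.g. (y, x) -> (x, y)).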
    if shift_xy:
        ori_traj = ori_traj[..., ::-1]
    L, HORIZON, DIM = ori_traj.shape
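
    # Flatten each trajectory into a single feature vector and cluster the set
    # into `vocab_size` anchor trajectories with mini-batch k-means.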
    all_traj = ori_traj.reshape(L, -1)
    clustering = MiniBatchKMeans(vocab_size, batch_size=1024, verbose=True, tol=0.0).fit(all_traj)
    anchors = clustering.cluster_centers_.reshape(vocab_size, HORIZON, DIM)
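
    # Count how many trajectories were assigned to each cluster (cluster occupancy).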
    cnt = np.bincount(clustering.labels_, minlength=vocab_size)
    np.save(f'./traj_final/{vocab_size}_kmeans_3sec_xy.npy', anchors)
    print(f'result saved to ./traj_final/{vocab_size}_kmeans_3sec_xy.npy')
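
    # Visualize the resulting anchor set (vis is assumed to plot the vocabulary).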
    vis(anchors)


if __name__ == '__main__':
    main()