# Source: navsim_ours/navsim/agents/scripts/gen_vocab_kmeans_3sec.py
# (uploaded by lkllkl via huggingface_hub, commit da2e2ac, 1.26 kB)
import pickle
import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm
from sklearn.cluster import MiniBatchKMeans
from navsim.agents.scripts.gen_vocab_kmeans import vis
@torch.no_grad()
def main(*, vocab_size: int = 4096, shift_xy: bool = True, random_state=None) -> None:
    """Build a trajectory anchor vocabulary with mini-batch k-means and save it.

    Loads and concatenates ``./traj_local/test-pt{1,2,3}.npy`` along axis 0,
    keeps the first two channels of the first 30 time steps, and subsamples
    every 5th step (indices 4, 9, 14, 19, 24, 29 -> 6 points). Optionally it
    swaps the two coordinate columns. The flattened trajectories are clustered
    into ``vocab_size`` centres. The centres are saved, with shape
    ``(vocab_size, 6, 2)``, to ``./traj_final/{vocab_size}_kmeans_3sec_xy.npy``
    and finally passed to ``vis``.

    Args:
        vocab_size: Number of k-means clusters (anchors) to produce.
        shift_xy: If True, reverse the last axis before clustering, i.e. swap
            the two coordinate columns.
        random_state: Forwarded to ``MiniBatchKMeans``. ``None`` (default)
            keeps the non-deterministic behaviour; an int makes the vocabulary
            reproducible.

    Raises:
        ValueError: If the loaded array is not rank 3 or has fewer than 30
            time steps.
    """
    import os  # local import: only needed to create the output directory

    horizon_steps = 30  # number of leading time steps considered
    stride = 5          # keep every 5th step -> indices 4, 9, 14, 19, 24, 29
    # NOTE(review): if the source trajectories are sampled at 10 Hz this gives
    # 0.5 s spacing over 3 s, matching "3sec" in the output name — confirm.

    ori_traj = np.concatenate([
        np.load('./traj_local/test-pt1.npy'),
        np.load('./traj_local/test-pt2.npy'),
        np.load('./traj_local/test-pt3.npy'),
    ], axis=0)
    # Fail early with a clear message instead of an opaque IndexError below.
    if ori_traj.ndim != 3 or ori_traj.shape[1] < horizon_steps:
        raise ValueError(
            f'expected trajectories of shape (N, T>={horizon_steps}, D), '
            f'got {ori_traj.shape}'
        )

    # One indexing step: pick the subsampled steps and the first two channels.
    sampled_timepoints = list(range(stride - 1, horizon_steps, stride))
    traj = ori_traj[:, sampled_timepoints, :2]
    if shift_xy:
        # Reverse the last axis, i.e. swap the two coordinate columns.
        # NOTE(review): presumably converts the stored order into the order
        # implied by the "_xy" output name — confirm against the data loader.
        traj = traj[..., ::-1]
    num_samples, horizon, dim = traj.shape

    # Mini-batch k-means on one flattened feature vector per trajectory.
    # tol=0.0 disables convergence detection based on centre changes.
    clustering = MiniBatchKMeans(
        vocab_size,
        batch_size=1024,
        verbose=True,
        tol=0.0,
        random_state=random_state,
    ).fit(traj.reshape(num_samples, -1))
    anchors = clustering.cluster_centers_.reshape(vocab_size, horizon, dim)

    # Per-cluster membership counts in one pass, reported as a diagnostic of
    # vocabulary quality (empty clusters mean wasted anchors).
    counts = np.bincount(clustering.labels_, minlength=vocab_size)
    print(f'cluster sizes: min={counts.min()}, max={counts.max()}, '
          f'empty={int((counts == 0).sum())}')

    out_dir = './traj_final'
    os.makedirs(out_dir, exist_ok=True)  # np.save does not create parent dirs
    out_path = f'{out_dir}/{vocab_size}_kmeans_3sec_xy.npy'
    np.save(out_path, anchors)
    print(f'result saved to {out_path}')

    # plot — `vis` comes from navsim.agents.scripts.gen_vocab_kmeans
    # (presumably draws the anchors; confirm there).
    vis(anchors)
# Script entry point: build the vocabulary only when this file is executed
# directly, never as a side effect of importing the module.
if __name__ == '__main__':
    main()