# File size: 1,264 Bytes
# da2e2ac
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.cluster import MiniBatchKMeans
from tqdm import tqdm

from navsim.agents.scripts.gen_vocab_kmeans import vis
# NOTE(review): main() performs no torch operations; the decorator is harmless.
@torch.no_grad()
def main(
    *,
    vocab_size=4096,
    shift_xy=True,
    input_paths=(
        './traj_local/test-pt1.npy',
        './traj_local/test-pt2.npy',
        './traj_local/test-pt3.npy',
    ),
    output_dir='./traj_final',
    random_state=None,
):
    """Build a trajectory-anchor vocabulary with mini-batch k-means and save it.

    Steps:
      1. Load the arrays in ``input_paths`` and concatenate them along axis 0.
      2. Keep the first 30 steps of axis 1 and the first 2 entries of axis 2.
      3. Subsample axis 1 at indices 4, 9, ..., 29.
      4. Optionally swap the two kept channels.
      5. Flatten each trajectory into one feature vector.
      6. Cluster the vectors into ``vocab_size`` centers.

    The centers are reshaped to ``(vocab_size, horizon, dim)`` and saved to
    ``<output_dir>/<vocab_size>_kmeans_3sec_xy.npy``. They are then plotted
    with ``vis``.

    Calling ``main()`` with no arguments reproduces the original hard-coded
    behaviour.

    Args:
        vocab_size: Number of k-means clusters (anchors) to fit.
        shift_xy: If True, reverse the order of the two kept channels
            ([c0, c1] -> [c1, c0]) before clustering.
        input_paths: ``.npy`` files whose arrays are concatenated along axis 0.
        output_dir: Directory for the saved anchors. Created if missing.
        random_state: Seed forwarded to ``MiniBatchKMeans``. ``None`` (the
            default) leaves clustering unseeded, as before.
    """
    # Presumably each array is (num_samples, timesteps, channels) with x, y in
    # the first two channels -- TODO confirm against the data generator.
    ori_traj = np.concatenate([np.load(path) for path in input_paths], axis=0)
    ori_traj = ori_traj[:, :30, :2]
    # Keep every 5th step: indices 4, 9, 14, 19, 24, 29 (6 points).
    # This raises IndexError if the input has fewer than 30 steps.
    sampled_timepoints = [5 * k - 1 for k in range(1, 7)]
    ori_traj = ori_traj[:, sampled_timepoints]
    if shift_xy:
        # Despite the flag's name, this swaps the channel order; it does not
        # shift values.
        ori_traj = ori_traj[..., ::-1]
    num_traj, horizon, dim = ori_traj.shape

    # MINI-BATCH k-means: each trajectory is one flat (horizon * dim) vector.
    all_traj = ori_traj.reshape(num_traj, -1)
    # tol=0.0 disables scikit-learn's center-change early-stopping criterion.
    clustering = MiniBatchKMeans(
        vocab_size,
        batch_size=1024,
        verbose=True,
        tol=0.0,
        random_state=random_state,
    ).fit(all_traj)
    anchors = clustering.cluster_centers_.reshape(vocab_size, horizon, dim)

    # Count cluster sizes in one vectorised pass. This replaces a per-cluster
    # Python loop whose result was never used.
    cnt = np.bincount(clustering.labels_, minlength=vocab_size)
    print(
        f'cluster sizes: min={cnt.min()}, max={cnt.max()}, '
        f'empty={int((cnt == 0).sum())}'
    )

    # np.save does not create parent directories, so make sure they exist.
    os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, f'{vocab_size}_kmeans_3sec_xy.npy')
    np.save(out_path, anchors)
    print(f'result saved to {out_path}')

    # plot
    vis(anchors)
# Run the vocabulary build only when executed as a script, not on import.
if __name__ == '__main__':
    main()