File size: 1,264 Bytes
da2e2ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.cluster import MiniBatchKMeans
from tqdm import tqdm

from navsim.agents.scripts.gen_vocab_kmeans import vis


@torch.no_grad()
def main(
    *,
    vocab_size=4096,
    shift_xy=True,
    input_paths=(
        './traj_local/test-pt1.npy',
        './traj_local/test-pt2.npy',
        './traj_local/test-pt3.npy',
    ),
    output_dir='./traj_final',
    random_state=None,
):
    """Build a trajectory anchor vocabulary with mini-batch k-means and save it.

    Loads the ``.npy`` arrays in ``input_paths`` and concatenates them along
    axis 0. It keeps the first 30 steps and the first two coordinates of each
    trajectory, then subsamples steps 4, 9, ..., 29 (6 points). The kept
    coordinates are optionally swapped. The flattened trajectories are
    clustered into ``vocab_size`` centres, and the centres are saved as an
    array of shape ``(vocab_size, 6, dim)``. Finally they are plotted with
    ``vis``.

    All arguments are keyword-only and default to the values that were
    previously hard-coded, so ``main()`` behaves as before.

    Args:
        vocab_size: Number of k-means clusters (anchors) to fit.
        shift_xy: If true, reverse the last axis (i.e. swap the two kept
            coordinates) before clustering.
        input_paths: ``.npy`` files to concatenate along axis 0. Each must be
            3-D ``(num_traj, num_steps, num_coords)`` with ``num_steps >= 30``.
        output_dir: Directory for the saved vocabulary; created if missing.
        random_state: Forwarded to ``MiniBatchKMeans``. ``None`` (the default)
            keeps sklearn's default seeding, as before.

    Raises:
        ValueError: If the loaded trajectories have fewer than 30 steps.
    """
    # The '3sec' in the output name suggests 30 steps span 3 s
    # (i.e. 10 Hz input) -- NOTE(review): confirm the sampling rate.
    num_steps = 30
    stride = 5
    output_path = os.path.join(output_dir, f'{vocab_size}_kmeans_3sec_xy.npy')

    trajs = np.concatenate([np.load(path) for path in input_paths], axis=0)
    if trajs.shape[1] < num_steps:
        raise ValueError(
            f'expected at least {num_steps} steps per trajectory, got {trajs.shape[1]}'
        )
    trajs = trajs[:, :num_steps, :2]
    # Every `stride`-th step, ending on the last kept step: [4, 9, ..., 29].
    sampled_timepoints = list(range(stride - 1, num_steps, stride))
    trajs = trajs[:, sampled_timepoints]
    if shift_xy:
        trajs = trajs[..., ::-1]
    num_traj, horizon, dim = trajs.shape

    # Create the output directory before the (slow) fit: np.save does not
    # create parent directories and would otherwise fail after clustering.
    os.makedirs(output_dir, exist_ok=True)

    # Mini-batch k-means over the flattened (horizon * dim) trajectories.
    flat_trajs = trajs.reshape(num_traj, -1)
    clustering = MiniBatchKMeans(
        vocab_size,
        batch_size=1024,
        verbose=True,
        tol=0.0,
        random_state=random_state,
    ).fit(flat_trajs)
    anchors = clustering.cluster_centers_.reshape(vocab_size, horizon, dim)

    # Per-anchor assignment counts in one O(N) pass (replaces a K*N loop).
    cnt = np.bincount(clustering.labels_, minlength=vocab_size)
    num_empty = int(np.count_nonzero(cnt == 0))
    print(f'{num_empty} / {vocab_size} anchors have no assigned trajectories')

    np.save(output_path, anchors)
    print(f'result saved to {output_path}')
    # plot
    vis(anchors)


if __name__ == '__main__':
    # Run the vocabulary build only when executed as a script, not on import.
    main()