|
import pickle |
|
|
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import torch |
|
from tqdm import tqdm |
|
|
|
from sklearn.cluster import MiniBatchKMeans |
|
import random |
|
from tqdm import tqdm |
|
@torch.no_grad()
def main(vocab_size=512, split='test', input_paths=('./traj_local/test-pt1.npy',),
         seed=None, show=True, output_dir='./traj_final'):
    """Build a trajectory anchor vocabulary by farthest-point sampling on endpoints.

    Loads every array in ``input_paths`` and concatenates them along axis 0.
    A random first trajectory is picked. Then, repeatedly, the trajectory
    whose final point (first two channels) is farthest, in squared Euclidean
    distance, from its nearest already-selected endpoint is added. This
    continues until ``vocab_size`` anchors are chosen. The anchors are saved to
    ``{output_dir}/{split}_{vocab_size}_far.npy`` and optionally plotted.

    Args:
        vocab_size: number of anchors to select.
        split: tag used in the output file name.
        input_paths: ``.npy`` files to load; each must be a rank-3 array
            (num_traj, horizon, dim), presumably with (x, y) in channels 0/1.
        seed: optional seed for the random first pick. ``None`` uses the
            module-level ``random`` generator (the original behavior).
        show: if true, plot the anchors with ``vis``.
        output_dir: directory for the saved file; created if missing.

    Returns:
        The selected anchors, an array of shape (vocab_size, horizon, dim).

    Raises:
        ValueError: if ``vocab_size`` is not in ``[1, number of trajectories]``.
    """
    import os

    ori_traj = np.concatenate([np.load(path) for path in input_paths], axis=0)
    num_traj = ori_traj.shape[0]
    if not 0 < vocab_size <= num_traj:
        raise ValueError(
            f'vocab_size must be in [1, {num_traj}] (number of loaded '
            f'trajectories), got {vocab_size}')

    rng = random if seed is None else random.Random(seed)
    first_idx = rng.randint(0, num_traj - 1)

    # Only the final point's first two channels take part in the distance.
    endpoints = ori_traj[:, -1, :2]
    indices = _farthest_point_sample(endpoints, vocab_size, first_idx)
    anchors = ori_traj[indices]  # fancy indexing returns a copy, in pick order

    os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, f'{split}_{vocab_size}_far.npy')
    np.save(out_path, anchors)
    print(f'result saved to {out_path}')

    if show:
        vis(anchors)
    return anchors


def _farthest_point_sample(points, num_samples, first_idx):
    """Return ``num_samples`` row indices of ``points`` chosen by greedy farthest-point sampling.

    ``points`` is an (N, D) array; distance is squared Euclidean. Starting from
    ``first_idx``, each step picks the unselected row whose distance to its
    nearest selected row is largest (ties go to the lowest index).
    """
    points = np.asarray(points, dtype=np.float64)
    selected = [first_idx]
    # min_dis[j] = squared distance from row j to its nearest selected row.
    # Selected rows are set to -1 so they are never picked again, even when
    # every remaining row coincides with a selected one (distance 0).
    min_dis = ((points - points[first_idx]) ** 2).sum(-1)
    min_dis[first_idx] = -1.0
    for _ in range(num_samples - 1):
        idx = int(np.argmax(min_dis))
        selected.append(idx)
        np.minimum(min_dis, ((points - points[idx]) ** 2).sum(-1), out=min_dis)
        min_dis[idx] = -1.0
    return np.asarray(selected)
|
|
|
|
|
def vis(data, max_legend=20):
    """Plot every trajectory in ``data`` on one axes and display the figure.

    Args:
        data: indexable as ``data[i, :, c]``, with ``data.shape[0]``
            trajectories; channel 0 is drawn as x and channel 1 as y.
        max_legend: lines are labeled with their index, and a legend is drawn,
            only when there are at most this many trajectories. With more
            lines the legend would be unreadable.

    Returns:
        The ``(fig, ax)`` pair, so callers can further customize or save it.
    """
    num_traj = data.shape[0]
    fig, ax = plt.subplots()
    labeled = num_traj <= max_legend
    for i in range(num_traj):
        line_kwargs = {'label': str(i)} if labeled else {}
        ax.plot(data[i, :, 0], data[i, :, 1], **line_kwargs)

    # ax.legend() with no labeled lines only emits a "No artists with labels"
    # warning, so call it only when labels were actually set.
    if labeled and num_traj > 0:
        ax.legend()

    plt.show()
    return fig, ax
|
|
|
def vis_pdm(data, pdm, thresh=0.95, max_legend=20):
    """Plot the trajectories of ``data`` whose score in ``pdm`` exceeds ``thresh``.

    Only the FIRST entry of ``pdm`` is visualized: the loop exits after
    showing one figure. Each value is compared element-wise with ``thresh``
    and indexed per trajectory, so it presumably holds one score per row of
    ``data`` (length >= ``data.shape[0]``) — TODO confirm against the producer.

    Args:
        data: indexable as ``data[i, :, c]``; channel 0 is drawn as x, 1 as y.
        pdm: mapping whose values support ``> thresh`` and integer indexing.
        thresh: score above which a trajectory is plotted (default 0.95).
        max_legend: label lines and draw a legend only when at most this many
            trajectories are plotted.
    """
    for _key, scores in pdm.items():
        mask = scores > thresh
        kept = [i for i in range(data.shape[0]) if mask[i]]

        fig, ax = plt.subplots()
        labeled = len(kept) <= max_legend
        for i in kept:
            line_kwargs = {'label': str(i)} if labeled else {}
            ax.plot(data[i, :, 0], data[i, :, 1], **line_kwargs)

        # Avoid matplotlib's "No artists with labels" warning from an empty legend.
        if kept and labeled:
            ax.legend()

        plt.show()
        # Only the first pdm entry is plotted.
        break
|
|
|
|
|
if __name__ == '__main__':
    # Executed as a script: run main() with its default arguments.
    main()