# Source-viewer metadata (commit c1a7f73, file size 9,564 bytes); the original
# line-number gutter was removed so this file parses as valid Python.
import os
import pickle
import torch
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from argparse import ArgumentParser
from dev.datasets.preprocess import TokenProcessor
from dev.transforms.target_builder import WaymoTargetBuilder
# (facecolor, edgecolor) hex pairs for agent boxes; draw_map indexes this list
# with its per-timestep color state minus one (valid / BOS / EOS / invalid).
colors = [
    ('#1f77b4', '#1a5a8a'),  # blue
    ('#2ca02c', '#217721'),  # green
    ('#ff7f0e', '#cc660b'),  # orange
    ('#9467bd', '#6f4a91'),  # purple
    ('#d62728', '#a31d1d'),  # red
    ('#000000', '#000000'),  # black
]
def draw_map(tokenize_data, token_processor: TokenProcessor, index, posfix):
    """Visualize tokenized agent trajectories frame by frame and build a GIF.

    For each motion-token step (one token covers ``shift`` timesteps) the
    vehicle token contours are rotated/translated into world frame at the
    agent's previous pose, drawn as colored rectangles on top of the map
    points, saved as one PNG per timestep under ``debug/tokenize/steps/``,
    and finally stitched into
    ``debug/tokenize/<scenario_id>_tokenize_<posfix>.gif``.

    Args:
        tokenize_data: preprocessed scenario dict (output of
            ``TokenProcessor.preprocess``) with ``map_point`` and ``agent``
            entries.  # assumes agent position/heading cover 91 timesteps — TODO confirm
        token_processor: supplies the vehicle trajectory token vocabularies.
        index: indices of the agents to draw.
        posfix: suffix for the output GIF filename.

    NOTE: reads the module-level ``args`` set by the CLI entry point.
    """
    # FIX: message previously said "raw data" although this draws tokenized data.
    print("Drawing tokenized data ...")
    shift = 5  # timesteps covered by one motion token
    token_size = 2048  # vocabulary size; ids 2048/2049/2050 are BOS/EOS/invalid
    traj_token = token_processor.trajectory_token["veh"]
    traj_token_all = token_processor.trajectory_token_all["veh"]

    fig, ax = plt.subplots()
    # FIX: adjust margins on the figure we actually draw on; the original call
    # ran before plt.subplots() and therefore targeted a throwaway figure.
    plt.subplots_adjust(left=0.3, right=0.7, top=0.7, bottom=0.3)
    ax.set_axis_off()
    # FIX: read the scenario id from the argument; the original referenced the
    # module-level global `data` and only worked when invoked from this script.
    scenario_id = tokenize_data['scenario_id']

    # static map background
    ax.scatter(tokenize_data["map_point"]["position"][:, 0],
               tokenize_data["map_point"]["position"][:, 1], s=0.2, c='black', edgecolors='none')

    index = np.array(index).astype(np.int32)
    agent_data = tokenize_data["agent"]
    token_index = agent_data["token_idx"][index]
    token_valid_mask = agent_data["agent_valid_mask"][index]
    num_agent, num_token = token_index.shape
    # look up the per-token contours: (num_agent, num_token, 4 corners, xy)
    tokens = traj_token[token_index.view(-1)].reshape(num_agent, num_token, 4, 2)
    # and the 6 intermediate poses per token: (num_agent, num_token, 6, 4, 2)
    tokens_all = traj_token_all[token_index.view(-1)].reshape(num_agent, num_token, 6, 4, 2)

    position = agent_data['position'][index, :, :2]  # (num_agent, 91, 2)
    heading = agent_data['heading'][index]  # (num_agent, 91)
    valid_mask = (position[..., 0] != 0) & (position[..., 1] != 0)  # (num_agent, 91)
    # TODO: fix this
    if args.smart:
        # expand the per-token validity to per-timestep validity
        for shifted_tid in range(token_valid_mask.shape[1]):
            valid_mask[:, shifted_tid * shift : (shifted_tid + 1) * shift] = token_valid_mask[:, shifted_tid : shifted_tid + 1].repeat(1, shift)
    else:
        # a timestep is valid unless its token is the `invalid` id (token_size + 2)
        for shifted_tid in range(token_index.shape[1]):
            valid_mask[:, shifted_tid * shift : (shifted_tid + 1) * shift] = token_index[:, shifted_tid : shifted_tid + 1] != token_size + 2
    # last valid timestep per agent; that frame's box/text stays on the canvas
    last_valid_step = valid_mask.shape[1] - 1 - torch.argmax(valid_mask.flip(dims=[1]).long(), dim=1)
    last_valid_step = {int(index[i]): int(last_valid_step[i]) for i in range(len(index))}

    _, token_num, token_contour_dim, feat_dim = tokens.shape
    # flatten contour corners so a single bmm rotates every token at once
    tokens_src = tokens.reshape(num_agent, token_num * token_contour_dim, feat_dim)
    tokens_all_src = tokens_all.reshape(num_agent, token_num * 6 * token_contour_dim, feat_dim)
    prev_heading = heading[:, 0]
    prev_pos = position[:, 0]
    fig_paths = []
    # per-agent, per-timestep draw state: 1=valid token, 2=BOS, 3=EOS, 4=invalid
    agent_colors = np.zeros((num_agent, position.shape[1]))
    shape = np.zeros((num_agent, position.shape[1], 2)) + 3.  # fixed 3x3 boxes
    for tid in tqdm(range(shift, position.shape[1], shift), leave=False, desc="Token ..."):
        # rotation matrices from the previous heading (local -> world)
        cos, sin = prev_heading.cos(), prev_heading.sin()
        rot_mat = prev_heading.new_zeros(num_agent, 2, 2)
        rot_mat[:, 0, 0] = cos
        rot_mat[:, 0, 1] = sin
        rot_mat[:, 1, 0] = -sin
        rot_mat[:, 1, 1] = cos
        tokens_world = torch.bmm(torch.from_numpy(tokens_src).float(), rot_mat).reshape(num_agent,
                                                                                        token_num,
                                                                                        token_contour_dim,
                                                                                        feat_dim)
        tokens_all_world = torch.bmm(torch.from_numpy(tokens_all_src).float(), rot_mat).reshape(num_agent,
                                                                                                token_num,
                                                                                                6,
                                                                                                token_contour_dim,
                                                                                                feat_dim)
        tokens_world += prev_pos[:, None, None, :2]
        tokens_all_world += prev_pos[:, None, None, None, :2]

        tokens_select = tokens_world[:, tid // shift - 1]  # (num_agent, token_contour_dim, feat_dim)
        tokens_all_select = tokens_all_world[:, tid // shift - 1]  # (num_agent, 6, token_contour_dim, feat_dim)
        # diff_xy feeds the commented-out token-derived heading update below
        diff_xy = tokens_select[:, 0, :] - tokens_select[:, 3, :]

        # advance the pose using ground truth (token-derived update kept for reference)
        prev_heading = heading[:, tid].clone()
        # prev_heading[valid_mask[:, tid - shift]] = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])[
        #     valid_mask[:, tid - shift]]
        prev_pos = position[:, tid].clone()
        # prev_pos[valid_mask[:, tid - shift]] = tokens_select.mean(dim=1)[valid_mask[:, tid - shift]]

        # NOTE tokens_pos equals to tokens_all_pos[:, -1]
        tokens_pos = tokens_select.mean(dim=1)  # (num_agent, 2)
        tokens_all_pos = tokens_all_select.mean(dim=2)  # (num_agent, 6, 2)

        # color state for every timestep covered by this token
        cur_token_index = token_index[:, tid // shift - 1]
        is_bos = cur_token_index == token_size
        is_eos = cur_token_index == token_size + 1
        is_invalid = cur_token_index == token_size + 2
        is_valid = ~is_bos & ~is_eos & ~is_invalid
        agent_colors[is_valid, tid - shift : tid] = 1
        agent_colors[is_bos, tid - shift : tid] = 2
        agent_colors[is_eos, tid - shift : tid] = 3
        agent_colors[is_invalid, tid - shift : tid] = 4

        for i in tqdm(range(shift), leave=False, desc="Timestep ..."):
            global_tid = tid - shift + i
            # only when the last tokenized timestep is valid is the current
            # shift's trajectory drawn
            cur_valid_mask = valid_mask[:, tid - shift]
            xs = tokens_all_pos[cur_valid_mask, i, 0]
            ys = tokens_all_pos[cur_valid_mask, i, 1]
            widths = shape[cur_valid_mask, global_tid, 1]
            lengths = shape[cur_valid_mask, global_tid, 0]
            angles = heading[cur_valid_mask, global_tid]
            cur_agent_colors = agent_colors[cur_valid_mask, global_tid]
            current_index = index[cur_valid_mask]

            drawn_agents = []
            drawn_texts = []
            # FIX: renamed loop variable `id` -> `agent_id` (shadowed the builtin)
            for x, y, width, length, angle, color_type, agent_id in zip(
                    xs, ys, widths, lengths, angles, cur_agent_colors, current_index):
                # NOTE(review): hard-coded viewport filter — confirm it is intentional
                if x < 3000: continue
                agent = plt.Rectangle((x, y), width, length,  # angle=((angle + np.pi / 2) / np.pi * 360) % 360,
                                      linewidth=0.2,
                                      facecolor=colors[int(color_type) - 1][0],
                                      edgecolor=colors[int(color_type) - 1][1])
                ax.add_patch(agent)
                text = plt.text(x - 4, y - 4, f"{str(agent_id)}:{str(global_tid)}", fontdict={'family': 'serif', 'size': 3, 'color': 'red'})
                # every patch except an agent's final one is removed after saving
                if global_tid != last_valid_step[agent_id]:
                    drawn_agents.append(agent)
                    drawn_texts.append(text)

                # draw timestep to be tokenized (outline only)
                if global_tid % shift == 0:
                    tokenize_agent = plt.Rectangle((x, y), width, length,  # angle=((angle + np.pi / 2) / np.pi * 360) % 360,
                                                   linewidth=0.2, fill=False,
                                                   edgecolor=colors[int(color_type) - 1][1])
                    ax.add_patch(tokenize_agent)

            plt.gca().set_aspect('equal', adjustable='box')
            fig_path = f"debug/tokenize/steps/{scenario_id}_{global_tid}.png"
            plt.savefig(fig_path, dpi=600, bbox_inches="tight")
            fig_paths.append(fig_path)

            # clear transient patches so the next frame starts clean
            for drawn_agent, drawn_text in zip(drawn_agents, drawn_texts):
                drawn_agent.remove()
                drawn_text.remove()
    plt.close()

    # stitch the per-timestep PNGs into a GIF
    import imageio.v2 as imageio
    images = []
    for fig_path in tqdm(fig_paths, leave=False, desc="Generate gif ..."):
        images.append(imageio.imread(fig_path))
    imageio.mimsave(f"debug/tokenize/{scenario_id}_tokenize_{posfix}.gif", images, duration=0.1)
def main(data):
    """Run the tokenization debug pipeline on one preprocessed scenario.

    Draws the raw trajectories (if not already rendered), tokenizes the
    scenario, draws the tokenized trajectories, then runs the Waymo target
    builder over the result.  Reads the module-level ``args`` from the CLI
    entry point.
    """
    token_size = 2048  # token vocabulary size passed to the TokenProcessor
    os.makedirs("debug/tokenize/steps/", exist_ok=True)
    scenario_id = data["scenario_id"]
    # agents hand-picked for this debug scenario
    selected_agents_index = [1, 21, 35, 36, 46]
    # raw data
    # NOTE(review): draw_raw is not defined in this file — confirm it is
    # provided elsewhere before running.
    if not os.path.exists(f"debug/tokenize/{scenario_id}_raw.gif"):
        draw_raw(data, selected_agents_index)
    # tokenization
    token_processor = TokenProcessor(token_size, disable_invalid=args.smart)
    print(f"Loaded token processor with token_size: {token_size}")
    data = token_processor.preprocess(data)
    # tokenized data
    # NOTE(review): draw_tokenize is not defined in this file either; its
    # signature matches draw_map above — presumably one of the names is stale.
    posfix = "smart" if args.smart else "ours"
    # if not os.path.exists(f"debug/tokenize/{scenario_id}_tokenize_{posfix}.gif"):
    draw_tokenize(data, token_processor, selected_agents_index, posfix)
    target_builder = WaymoTargetBuilder(num_historical_steps=11, num_future_steps=80)
    data = target_builder(data)
if __name__ == "__main__":
    parser = ArgumentParser(description="Testing script parameters")
    parser.add_argument("--smart", action="store_true")
    parser.add_argument("--data_path", type=str, default="/u/xiuyu/work/dev4/data/waymo_processed/training")
    # `args` is read as a module-level global by main() and draw_map()
    args = parser.parse_args()

    scenario_id = "74ad7b76d5906d39"  # hard-coded debug scenario
    data_path = os.path.join(args.data_path, f"{scenario_id}.pkl")
    # FIX: close the file handle deterministically; the original used
    # pickle.load(open(...)) and leaked the descriptor.
    with open(data_path, "rb") as f:
        data = pickle.load(f)
    print(f"Loaded scenario {scenario_id}")
    main(data)