import os
import pickle
from argparse import ArgumentParser

import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm

from dev.datasets.preprocess import TokenProcessor
from dev.transforms.target_builder import WaymoTargetBuilder


# (face color, edge color) pairs, indexed by ``category - 1``; draw_map assigns
# category 1 = regular token, 2 = BOS, 3 = EOS, 4 = invalid token.
colors = [
    ('#1f77b4', '#1a5a8a'),  # blue
    ('#2ca02c', '#217721'),  # green
    ('#ff7f0e', '#cc660b'),  # orange
    ('#9467bd', '#6f4a91'),  # purple
    ('#d62728', '#a31d1d'),  # red
    ('#000000', '#000000'),  # black
]


def draw_map(tokenize_data, token_processor: TokenProcessor, index, posfix, *,
             smart=None, scenario_id=None, shift=5, token_size=2048, min_x=3000.0):
    """Render the tokenized trajectories of selected agents as per-step PNGs and a GIF.

    For each token window of ``shift`` timesteps, the vehicle ``trajectory_token_all``
    contours selected by the agents' token ids are rotated by the heading and
    translated to the position at the previous token boundary. One frame per
    timestep is then saved to ``debug/tokenize/steps/``. Finally all frames are
    stitched into ``debug/tokenize/{scenario_id}_tokenize_{posfix}.gif``.

    Args:
        tokenize_data: Scenario dict after ``TokenProcessor.preprocess``. Reads
            ``["map_point"]["position"]`` and the ``["agent"]`` entries
            ``token_idx``, ``agent_valid_mask``, ``position`` and ``heading``.
        token_processor: Supplies ``trajectory_token_all["veh"]``.
        index: Indices of the agents to draw.
        posfix: Suffix of the output GIF file name.
        smart: If true, window validity comes from ``agent_valid_mask``.
            Otherwise a window is invalid when its token id is ``token_size + 2``.
            ``None`` falls back to the global CLI flag ``args.smart``, which is only
            defined when this file runs as a script.
        scenario_id: Used in output file names. ``None`` reads
            ``tokenize_data["scenario_id"]``.
        shift: Number of timesteps covered by one token.
        token_size: Vocabulary size. Ids ``token_size``, ``token_size + 1`` and
            ``token_size + 2`` are treated as BOS, EOS and invalid respectively.
        min_x: Agents whose drawn x-coordinate is below this are skipped. This is a
            hard-coded crop for the debug scenario.
    """
    print("Drawing tokenized data ...")

    if smart is None:
        smart = args.smart  # legacy behaviour: global flags parsed in __main__
    if scenario_id is None:
        scenario_id = tokenize_data["scenario_id"]

    traj_token_all = token_processor.trajectory_token_all["veh"]

    fig, ax = plt.subplots()
    ax.set_axis_off()
    ax.set_aspect('equal', adjustable='box')

    map_xy = tokenize_data["map_point"]["position"]
    ax.scatter(map_xy[:, 0], map_xy[:, 1], s=0.2, c='black', edgecolors='none')

    index = np.array(index).astype(np.int32)
    agent_data = tokenize_data["agent"]
    token_index = agent_data["token_idx"][index]  # (num_agent, num_token)
    token_valid_mask = agent_data["agent_valid_mask"][index]

    num_agent, num_token = token_index.shape
    # 6 sub-step contours of 4 (x, y) points per token.
    tokens_all = traj_token_all[token_index.view(-1)].reshape(num_agent, num_token, 6, 4, 2)

    position = agent_data['position'][index, :, :2]  # (num_agent, 91, 2)
    heading = agent_data['heading'][index]  # (num_agent, 91)

    # Per-timestep validity: start from non-zero positions, then overwrite each token
    # window with the tokenizer's validity. A (num_agent, 1) column is broadcast over
    # the window, which also works when the last window is cut off by the horizon.
    valid_mask = (position[..., 0] != 0) & (position[..., 1] != 0)  # (num_agent, 91)
    # TODO: fix this
    if smart:
        for tok in range(token_valid_mask.shape[1]):
            valid_mask[:, tok * shift:(tok + 1) * shift] = token_valid_mask[:, tok:tok + 1]
    else:
        for tok in range(num_token):
            valid_mask[:, tok * shift:(tok + 1) * shift] = token_index[:, tok:tok + 1] != token_size + 2

    # Last valid timestep per agent, keyed by the agent's dataset index.
    last_valid = valid_mask.shape[1] - 1 - torch.argmax(valid_mask.flip(dims=[1]).long(), dim=1)
    last_valid_step = {int(agent_idx): int(step) for agent_idx, step in zip(index, last_valid)}

    _, token_num, num_sub_steps, token_contour_dim, feat_dim = tokens_all.shape
    # Loop-invariant: flatten all contours once so each window needs a single bmm.
    tokens_all_src = torch.from_numpy(
        tokens_all.reshape(num_agent, token_num * num_sub_steps * token_contour_dim, feat_dim)).float()

    prev_heading = heading[:, 0]
    prev_pos = position[:, 0]

    num_steps = position.shape[1]
    fig_paths = []
    agent_colors = np.zeros((num_agent, num_steps))  # token category per timestep
    agent_sizes = np.full((num_agent, num_steps, 2), 3.)  # [..., 0] height, [..., 1] width of the box

    for tid in tqdm(range(shift, num_steps, shift), leave=False, desc="Token ..."):
        token_id = tid // shift - 1
        window = slice(tid - shift, tid)

        # Place this window's token contours at the pose of the previous boundary.
        cos, sin = prev_heading.cos(), prev_heading.sin()
        rot_mat = prev_heading.new_zeros(num_agent, 2, 2)
        rot_mat[:, 0, 0] = cos
        rot_mat[:, 0, 1] = sin
        rot_mat[:, 1, 0] = -sin
        rot_mat[:, 1, 1] = cos
        tokens_all_world = torch.bmm(tokens_all_src, rot_mat).reshape(
            num_agent, token_num, num_sub_steps, token_contour_dim, feat_dim)
        tokens_all_world += prev_pos[:, None, None, None, :2]
        tokens_all_select = tokens_all_world[:, token_id]  # (num_agent, 6, contour, feat)
        # Mean of the contour points of every sub-step.
        tokens_all_pos = tokens_all_select.mean(dim=2)  # (num_agent, 6, 2)

        # The next window is anchored on the ground-truth pose at this boundary.
        # Disabled alternative: for agents valid at ``tid - shift``, re-anchor on the
        # decoded token instead. Heading would come from the vector between contour
        # points 3 -> 0, and position from the contour mean.
        prev_heading = heading[:, tid].clone()
        prev_pos = position[:, tid].clone()

        # colors
        cur_token_index = token_index[:, token_id]
        is_bos = cur_token_index == token_size
        is_eos = cur_token_index == token_size + 1
        is_invalid = cur_token_index == token_size + 2
        is_valid = ~is_bos & ~is_eos & ~is_invalid
        agent_colors[is_valid, window] = 1
        agent_colors[is_bos, window] = 2
        agent_colors[is_eos, window] = 3
        agent_colors[is_invalid, window] = 4

        # Only when the last tokenized timestep is valid is the trajectory of the
        # current window valid.
        cur_valid_mask = valid_mask[:, tid - shift]
        current_index = index[cur_valid_mask]

        for i in tqdm(range(shift), leave=False, desc="Timestep ..."):
            global_tid = tid - shift + i
            xs = tokens_all_pos[cur_valid_mask, i, 0]
            ys = tokens_all_pos[cur_valid_mask, i, 1]
            widths = agent_sizes[cur_valid_mask, global_tid, 1]
            lengths = agent_sizes[cur_valid_mask, global_tid, 0]
            cur_agent_colors = agent_colors[cur_valid_mask, global_tid]

            # Artists removed again once this frame has been saved.
            transient_artists = []
            for x, y, width, length, color_type, agent_idx in zip(
                    xs, ys, widths, lengths, cur_agent_colors, current_index):
                if x < min_x:
                    continue
                face, edge = colors[int(color_type) - 1]
                # Boxes are axis-aligned. Rotation by heading is disabled; it would be
                # angle=((heading + np.pi / 2) / np.pi * 360) % 360
                agent = plt.Rectangle((x, y), width, length,
                                      linewidth=0.2, facecolor=face, edgecolor=edge)
                ax.add_patch(agent)
                text = ax.text(x - 4, y - 4, f"{agent_idx}:{global_tid}",
                               fontdict={'family': 'serif', 'size': 3, 'color': 'red'})
                # An agent's box at its last valid step stays on the canvas.
                if global_tid != last_valid_step[int(agent_idx)]:
                    transient_artists.extend((agent, text))

                # A permanent outline marks the timestep that gets tokenized.
                if global_tid % shift == 0:
                    ax.add_patch(plt.Rectangle((x, y), width, length,
                                               linewidth=0.2, fill=False, edgecolor=edge))

            fig_path = f"debug/tokenize/steps/{scenario_id}_{global_tid}.png"
            fig.savefig(fig_path, dpi=600, bbox_inches="tight")
            fig_paths.append(fig_path)

            for artist in transient_artists:
                artist.remove()

    plt.close(fig)

    # generate gif (imageio is only needed for this final step)
    import imageio.v2 as imageio
    images = [imageio.imread(path) for path in tqdm(fig_paths, leave=False, desc="Generate gif ...")]
    imageio.mimsave(f"debug/tokenize/{scenario_id}_tokenize_{posfix}.gif", images, duration=0.1)


def main(data, smart=None):
    """Tokenize one scenario, render the selected agents and run the target builder.

    Args:
        data: Raw scenario dict loaded from the processed pickle. Must contain
            ``"scenario_id"``.
        smart: Passed to ``TokenProcessor`` as ``disable_invalid`` and to
            ``draw_map``. ``None`` falls back to the global CLI flag ``args.smart``.

    Returns:
        The scenario dict after ``TokenProcessor.preprocess`` and
        ``WaymoTargetBuilder``.
    """
    if smart is None:
        smart = args.smart
    token_size = 2048

    os.makedirs("debug/tokenize/steps/", exist_ok=True)
    scenario_id = data["scenario_id"]

    selected_agents_index = [1, 21, 35, 36, 46]

    # raw data
    raw_gif = f"debug/tokenize/{scenario_id}_raw.gif"
    if not os.path.exists(raw_gif):
        # `draw_raw` is not defined in this script. Look it up at call time so a
        # missing renderer skips this optional step instead of raising NameError.
        draw_raw_fn = globals().get("draw_raw")
        if draw_raw_fn is None:
            print(f"Skipping raw drawing: `draw_raw` is not defined ({raw_gif} not generated)")
        else:
            draw_raw_fn(data, selected_agents_index)

    # tokenization
    token_processor = TokenProcessor(token_size, disable_invalid=smart)
    print(f"Loaded token processor with token_size: {token_size}")
    data = token_processor.preprocess(data)

    # tokenized data (re-rendered on every run, even if the GIF already exists)
    posfix = "smart" if smart else "ours"
    draw_map(data, token_processor, selected_agents_index, posfix,
             smart=smart, scenario_id=scenario_id, token_size=token_size)

    target_builder = WaymoTargetBuilder(num_historical_steps=11, num_future_steps=80)
    data = target_builder(data)
    return data


if __name__ == "__main__":
    parser = ArgumentParser(description="Testing script parameters")
    parser.add_argument("--smart", action="store_true")
    parser.add_argument("--data_path", type=str, default="/u/xiuyu/work/dev4/data/waymo_processed/training")
    args = parser.parse_args()

    scenario_id = "74ad7b76d5906d39"
    data_path = os.path.join(args.data_path, f"{scenario_id}.pkl")
    # NOTE: pickle.load can execute arbitrary code; only load trusted local data.
    with open(data_path, "rb") as f:
        data = pickle.load(f)
    print(f"Loaded scenario {scenario_id}")

    main(data, smart=args.smart)