import math |
import os |
import torch |
import pickle |
import matplotlib.pyplot as plt |
import tensorflow as tf |
import numpy as np |
import numpy.typing as npt |
import fnmatch |
import io |
import seaborn as sns |
import matplotlib.axes as Axes |
import matplotlib.transforms as mtransforms |
from PIL import Image |
from functools import wraps |
from typing import Sequence, Union, Optional |
from tqdm import tqdm |
from typing import List, Literal |
from argparse import ArgumentParser |
from scipy.ndimage.filters import gaussian_filter |
from matplotlib.patches import FancyBboxPatch, Polygon, Rectangle, Circle |
from matplotlib.collections import LineCollection |
from torch_geometric.data import HeteroData, Dataset |
from waymo_open_dataset.protos import scenario_pb2 |
from dev.utils.func import CONSOLE |
from dev.modules.attr_tokenizer import Attr_Tokenizer |
from dev.datasets.preprocess import TokenProcessor, cal_polygon_contour, AGENT_TYPE |
from dev.datasets.scalable_dataset import WaymoTargetBuilder |
# Public plotting helpers exported by ``from <module> import *``.
__all__ = [
    'plot_occ_grid',
    'plot_interact_edge',
    'plot_map_edge',
    'plot_insert_grid',
    'plot_binary_map',
    'plot_map_token',
    'plot_prob_seed',
    'plot_scenario',
    'get_heatmap',
    'draw_heatmap',
    'plot_val',
    'plot_tokenize',
]
def safe_run(func):
    """Decorate a plotting helper so that failures are reported instead of raised.

    The mode is chosen once, when the decorator is applied, from the ``DEBUG``
    environment variable:

    * debug enabled: the wrapped function runs unprotected and exceptions propagate
      (handy under a debugger);
    * otherwise: any ``Exception`` is caught, a one-line message naming the function
      and the exception type is printed, and ``None`` is returned.

    ``DEBUG`` is read as an integer (non-zero enables it, as before). Non-numeric
    values no longer crash the import: ``true``/``yes``/``on`` (any case) enable debug
    mode and anything else disables it.
    """
    raw = os.getenv('DEBUG', '0').strip()
    try:
        debug = int(raw) != 0
    except ValueError:
        debug = raw.lower() in ('true', 'yes', 'on')

    if debug:
        @wraps(func)
        def passthrough(*args, **kwargs):
            return func(*args, **kwargs)
        return passthrough

    name = getattr(func, '__name__', repr(func))

    @wraps(func)
    def guarded(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            # Plotting is best-effort: report the failure and keep the caller running.
            print(f'[safe_run] {name} failed: {type(e).__name__}: {e}')
            return None
    return guarded
@safe_run
def plot_occ_grid(scenario_id, occ, gt_occ=None, save_path='', mode='agent', prefix=''):
    """Save a grid of occupancy heat maps, optionally outlining ground-truth cells.

    Subplot (i, j) shows ``occ[i, j]`` reshaped to an ``n x n`` image, where ``n * n`` is the
    size of the last axis of ``occ``. At most 3 rows (first axis) and 5 columns (second axis)
    are drawn; smaller inputs are drawn in full instead of raising IndexError.

    Args:
        scenario_id: identifier used in the output file name.
        occ: array indexed as ``occ[i, j]`` whose last axis is a flattened square grid.
        gt_occ: optional array indexed like ``occ``; cells equal to 1 are outlined in blue and
            cells equal to -1 in red.
        save_path: output directory, created if missing; '' writes to the working directory.
        mode: tag inserted into the file name.
        prefix: string prepended to the file name.
    """
    def generate_box_edges(matrix, find_value=1):
        """Return the four unit-square edges around every cell of ``matrix`` equal to ``find_value``."""
        ys, xs = np.where(matrix == find_value)
        edges = []
        for xi, yi in zip(xs, ys):
            edges.append([(xi - 0.5, yi - 0.5), (xi + 0.5, yi - 0.5)])
            edges.append([(xi + 0.5, yi - 0.5), (xi + 0.5, yi + 0.5)])
            edges.append([(xi + 0.5, yi + 0.5), (xi - 0.5, yi + 0.5)])
            edges.append([(xi - 0.5, yi + 0.5), (xi - 0.5, yi - 0.5)])
        return edges

    # os.makedirs('') raises FileNotFoundError, so only create a directory when one is given.
    if save_path:
        os.makedirs(save_path, exist_ok=True)
    n = int(math.sqrt(occ.shape[-1]))
    plot_n = min(3, occ.shape[0])
    plot_t = min(5, occ.shape[1])
    cells = [(i, j) for i in range(plot_n) for j in range(plot_t)]
    occ_list = [occ[i, j].reshape(n, n) for i, j in cells]
    occ_gt_list = [gt_occ[i, j].reshape(n, n) for i, j in cells] if gt_occ is not None else []
    row_labels = [f'n={i}' for i in range(plot_n)]
    col_labels = [f't={t}' for t in range(plot_t)]

    # squeeze=False keeps ``axes`` 2-D even when a single row or column is drawn.
    fig, axes = plt.subplots(plot_n, plot_t, figsize=(9, 6), squeeze=False)
    try:
        fig.subplots_adjust(wspace=0.1, hspace=0.1)
        for i, ax in enumerate(axes.flat):
            ax.imshow(occ_list[i], cmap='viridis', interpolation='nearest')
            ax.axis('off')
            if occ_gt_list:
                gt_edges = generate_box_edges(occ_gt_list[i])
                ax.add_collection(LineCollection(gt_edges, colors='blue', linewidths=0.5))
                insert_edges = generate_box_edges(occ_gt_list[i], find_value=-1)
                ax.add_collection(LineCollection(insert_edges, colors='red', linewidths=0.5))
            # Black frame around the whole image.
            ax.add_patch(Rectangle((-0.5, -0.5), occ_list[i].shape[1], occ_list[i].shape[0],
                                   linewidth=2, edgecolor='black', facecolor='none'))
        for i, ax in enumerate(axes[:, 0]):
            ax.annotate(row_labels[i], xy=(-0.1, 0.5), xycoords="axes fraction",
                        fontsize=12, ha="right", va="center", rotation=0)
        for j, ax in enumerate(axes[0, :]):
            ax.annotate(col_labels[j], xy=(0.5, 1.05), xycoords="axes fraction",
                        fontsize=12, ha="center", va="bottom")
        fig.savefig(os.path.join(save_path, f'{prefix}{scenario_id}_occ_{mode}.png'), dpi=500, bbox_inches='tight')
    finally:
        # Close even on failure so repeated calls do not accumulate open figures.
        plt.close(fig)
@safe_run
def plot_interact_edge(edge_index, scenario_ids, batch_sizes, num_seed, num_step, save_path='interact_edge_map',
                       **kwargs):
    """Save one source/target mask image per source node of ``edge_index``.

    Node indices are flattened as ``step * num_agent + row``. Within each step the rows are all
    real agents of every batch element (``batch_sizes.sum()`` rows, ordered by batch) followed by
    ``num_seed`` seed rows per batch element (ordered by batch). Each batch element's mask has
    its real agents first and its seeds right below them.

    Args:
        edge_index: edge tensor (presumably shape (2, E)); row 1 is treated as the source node
            and row 0 as the target node.
        scenario_ids: one id per batch element, used in the output file names.
        batch_sizes: 1-D int tensor with the number of real agents of each batch element.
        num_seed: seed rows appended per batch element at every step.
        num_step: number of steps the row layout is repeated for.
        save_path: output directory forwarded to ``plot_binary_map``.
        **kwargs: optional ``av_index`` (global AV row per batch element) and ``is_bos``
            (array over real agents x steps) are consumed here; ``t``/``n`` override the file-name
            fields; everything left is forwarded to ``plot_binary_map``.
    """
    num_batch = len(scenario_ids)
    sizes = [int(size) for size in batch_sizes]
    num_real = sum(sizes)
    num_agent = num_real + num_seed * num_batch
    # Batch id of every flattened node: real agents by batch, then seeds by batch, per step.
    batches = torch.cat([
        torch.arange(num_batch).repeat_interleave(repeats=batch_sizes, dim=0),
        torch.arange(num_batch).repeat_interleave(repeats=num_seed, dim=0),
    ], dim=0).repeat(num_step).numpy()
    # ptr[b] is the first global row of batch b's real agents; ptr[-1] == num_real.
    ptr = np.cumsum([0] + sizes).tolist()

    def to_local(node):
        """Map a flattened node index to (batch, row inside that batch's mask, step)."""
        batch = int(batches[node])
        row = node % num_agent
        if row >= num_real:
            # Seeds of batch b start at global row num_real + b * num_seed (a true cumulative
            # offset) and are drawn right below the batch's real agents.
            row = sizes[batch] + (row - num_real - batch * num_seed)
        else:
            row = row - ptr[batch]
        return batch, row, node // num_agent

    all_av_index = None
    if 'av_index' in kwargs:
        # Convert global AV rows to rows local to each batch element.
        av_global = kwargs.pop('av_index').cpu().reshape(-1).tolist()
        all_av_index = [int(av) - start for av, start in zip(av_global, ptr[:-1])]
    is_bos = np.zeros((num_real, num_step), dtype=np.bool_)
    if 'is_bos' in kwargs:
        is_bos = kwargs.pop('is_bos').cpu().numpy()

    for idx, src in enumerate(tqdm(torch.unique(edge_index[1]))):
        src = int(src)
        src_batch, src_row, src_col = to_local(src)
        num_rows = sizes[src_batch] + num_seed
        src_mask = np.zeros((num_rows, num_step))
        src_mask[src_row, src_col] = 1
        tgt_mask = np.zeros((num_rows, num_step))
        for tgt in edge_index[0, edge_index[1] == src].tolist():
            tgt_batch, tgt_row, tgt_col = to_local(int(tgt))
            tgt_mask[tgt_row, tgt_col] = 1
            assert tgt_batch == src_batch
        # Debug output for sources whose targets span more than one step.
        selected_step = tgt_mask.sum(axis=0) > 0
        if selected_step.sum() > 1:
            print(f"\nidx={idx}", src, src_row, src_col)
            print(selected_step)
            print(edge_index[:, edge_index[1] == src].tolist())
        if all_av_index is not None:
            kwargs['av_index'] = all_av_index[src_batch]
        t = kwargs.get('t', src_col)
        n = kwargs.get('n', 0)
        is_bos_batch = is_bos[ptr[src_batch]:ptr[src_batch + 1]]
        plot_binary_map(src_mask, tgt_mask, save_path,
                        suffix=f'_{scenario_ids[src_batch]}_{t:02d}_{n:02d}_{idx:04d}',
                        is_bos=is_bos_batch, **kwargs)
@safe_run
def plot_map_edge(edge_index, pos_a, data, save_path='map_edge_map'):
    """Save one figure per agent showing the map tokens it is connected to.

    Args:
        edge_index: edge tensor; row 1 holds agent indices and row 0 the connected map-token
            indices (taken modulo the number of tokens).
        pos_a: per-agent positions; the first two entries of ``pos_a[i]`` are drawn as (x, y).
        data: mapping providing ``data['map_point']['position']`` and
            ``data['pt_token']['position']`` / ``data['pt_token']['orientation']`` tensors.
        save_path: output directory; one ``map_{agent}.png`` is written per connected agent.
    """
    map_points = data['map_point']['position'][:, :2].cpu().numpy()
    token_pos = data['pt_token']['position'][:, :2].cpu().numpy()
    token_heading = data['pt_token']['orientation'].cpu().numpy()
    num_pt = token_pos.shape[0]
    os.makedirs(save_path, exist_ok=True)
    for i in tqdm(torch.unique(edge_index[1])):
        xy = pos_a[i].cpu().numpy()
        pt_index = edge_index[0, edge_index[1] == i].cpu().numpy() % num_pt
        # Create the figure first: plt.subplots_adjust() issued *before* plt.subplots() acts on
        # whatever figure is current (creating a stray one), never on the figure saved here.
        fig, ax = plt.subplots()
        try:
            ax.set_axis_off()
            plot_map_token(ax, map_points, token_pos[pt_index], token_heading[pt_index], colors='blue')
            ax.scatter(xy[0], xy[1], s=0.5, c='red', edgecolors='none')
            fig.savefig(os.path.join(save_path, f'map_{int(i)}.png'), dpi=600, bbox_inches='tight')
        finally:
            plt.close(fig)
def get_heatmap(x, y, prob, s=3, bins=1000):
    """Return a Gaussian-smoothed, probability-weighted 2-D density of points and its extent.

    Args:
        x, y: point coordinates.
        prob: per-point weights for the histogram.
        s: Gaussian smoothing sigma, in bins.
        bins: histogram bins (forwarded to ``np.histogram2d``).

    Returns:
        ``(heatmap, extent)``: the heat map with rows along y and columns along x (ready for
        ``imshow``), and ``[xmin, xmax, ymin, ymax]`` from the histogram edges.
    """
    density, x_edges, y_edges = np.histogram2d(x, y, bins=bins, weights=prob, density=True)
    smoothed = gaussian_filter(density, sigma=s)
    bounds = [x_edges[0], x_edges[-1], y_edges[0], y_edges[-1]]
    # histogram2d indexes [x_bin, y_bin]; transpose so rows follow y.
    return smoothed.T, bounds
@safe_run
def draw_heatmap(vector, vector_prob, gt_idx):
    """Draw every segment of ``vector`` shaded by its probability; return the pyplot module.

    Args:
        vector: array whose rows start with ``x0, y0, x1, y1`` (a segment per row).
        vector_prob: tensor of per-row probabilities; higher values draw a deeper red.
        gt_idx: container of row indices drawn in pure blue instead.
    """
    _, ax = plt.subplots(figsize=(10, 10))
    probs = vector_prob.cpu().numpy()
    for k in range(vector.shape[0]):
        shade = None if k in gt_idx else max(0, 0.9 - probs[k])
        color = (0, 0, 1) if shade is None else (0.9, shade, shade)
        x0, y0, x1, y1 = vector[k, :4]
        ax.plot((x0, x1), (y0, y1), color=color, linewidth=2)
    return plt
@safe_run
def plot_insert_grid(scenario_id, prob, grid, ego_pos, map, save_path='', prefix='', inference=False, indices=None,
                     all_t_in_one=False):
    """Save one probability heat map per step, outlining the selected grid cells in red.

    NOTE: a second ``plot_insert_grid`` with a different signature is defined later in this
    module, so that later definition is the one bound to the name after import.

    prob: float array of shape (num_step, num_grid)
    grid: float array of shape (num_grid, 2)

    ``grid``, ``map`` and ``inference`` are not used here. The number of steps comes from
    ``ego_pos.shape[0]``. ``indices[t]`` is a flat cell index or an iterable of them, and -1
    entries are skipped. With ``all_t_in_one`` only the first step is written.
    """
    os.makedirs(save_path, exist_ok=True)
    side = int(math.sqrt(prob.shape[1]))
    for t in range(ego_pos.shape[0]):
        plt.subplots_adjust(left=0.3, right=0.7, top=0.7, bottom=0.3)
        _, ax = plt.subplots()
        ax.imshow(prob[t].reshape(side, side), cmap='viridis', interpolation='nearest')
        if indices is not None:
            selected = indices[t]
            if isinstance(selected, (int, float, np.int_)):
                selected = [selected]
            for flat in selected:
                if flat != -1:
                    row, col = flat // side, flat % side
                    ax.add_patch(Rectangle((col - 0.5, row - 0.5), 1, 1, edgecolor='red', facecolor='none', lw=2))
        ax.grid(False)
        ax.set_aspect('equal', adjustable='box')
        ax.set_title('Prob of Rel Position Grid')
        plt.savefig(os.path.join(save_path, f'{prefix}{scenario_id}_heat_map_{t}.png'), dpi=300, bbox_inches='tight')
        plt.close()
        if all_t_in_one:
            break
@safe_run
def plot_insert_grid(scenario_id, prob, indices=None, save_path='', prefix='', inference=False):
    """Save a grid of insert-position probability maps with the chosen cell outlined in red.

    prob: float array of shape (num_seed, num_step, num_grid)

    Subplot (i, j) shows ``prob[i, j]`` reshaped to ``n x n`` (``n * n`` = last axis size).
    At most 3 rows and 5 columns are drawn; smaller inputs are drawn in full.

    Args:
        scenario_id: identifier used in the output file name.
        prob: see above.
        indices: optional array indexed as ``indices[i, j]`` holding the selected flat cell;
            -1 means "nothing selected" and is not drawn.
        save_path: output directory, created if non-empty.
        prefix: string prepended to the file name.
        inference: unused; kept for call compatibility.
    """
    # os.makedirs('') raises FileNotFoundError, so only create a directory when one is given.
    if save_path:
        os.makedirs(save_path, exist_ok=True)
    n = int(math.sqrt(prob.shape[-1]))
    plot_n = min(3, prob.shape[0])
    plot_t = min(5, prob.shape[1])
    cells = [(i, j) for i in range(plot_n) for j in range(plot_t)]
    prob_list = [prob[i, j].reshape(n, n) for i, j in cells]
    indice_list = [indices[i, j] for i, j in cells] if indices is not None else []
    row_labels = [f'n={i}' for i in range(plot_n)]
    col_labels = [f't={t}' for t in range(plot_t)]

    # squeeze=False keeps ``axes`` 2-D even when a single row or column is drawn.
    fig, axes = plt.subplots(plot_n, plot_t, figsize=(9, 6), squeeze=False)
    try:
        fig.suptitle('Prob of Insert Position Grid')
        fig.subplots_adjust(wspace=0.1, hspace=0.1)
        for i, ax in enumerate(axes.flat):
            ax.imshow(prob_list[i], cmap='viridis', interpolation='nearest')
            ax.axis('off')
            ax.grid(False)
            ax.set_aspect('equal', adjustable='box')
            # -1 marks "no insertion" (skipped like in the other plotting helpers).
            if indice_list and int(indice_list[i]) != -1:
                row, col = divmod(int(indice_list[i]), n)
                ax.add_patch(Rectangle((col - .5, row - .5), 1, 1, edgecolor='red', facecolor='none', lw=2))
            # Black frame around the whole image.
            ax.add_patch(Rectangle((-0.5, -0.5), prob_list[i].shape[1], prob_list[i].shape[0],
                                   linewidth=2, edgecolor='black', facecolor='none'))
        for i, ax in enumerate(axes[:, 0]):
            ax.annotate(row_labels[i], xy=(-0.1, 0.5), xycoords="axes fraction",
                        fontsize=12, ha="right", va="center", rotation=0)
        for j, ax in enumerate(axes[0, :]):
            ax.annotate(col_labels[j], xy=(0.5, 1.05), xycoords="axes fraction",
                        fontsize=12, ha="center", va="bottom")
        fig.savefig(os.path.join(save_path, f'{prefix}{scenario_id}_insert_map.png'), dpi=500, bbox_inches='tight')
    finally:
        plt.close(fig)
@safe_run
def plot_binary_map(src_mask, tgt_mask, save_path='', suffix='', av_index=None, is_bos=None, **kwargs):
    """Save a source mask (green, left) and a target mask (orange, right) side by side.

    Args:
        src_mask: 2-D 0/1 array (rows x steps).
        tgt_mask: 2-D 0/1 array of the same shape.
        save_path: output directory, created if non-empty; the file is ``map{suffix}.png``.
        suffix: appended to the file name.
        av_index: optional row outlined in red on both panels.
        is_bos: optional boolean array; every True cell is outlined in blue on both panels.
        **kwargs: optional ``t`` and ``n`` values that make up the title.
    """
    from matplotlib.colors import ListedColormap
    # os.makedirs('') raises FileNotFoundError, so only create a directory when one is given.
    if save_path:
        os.makedirs(save_path, exist_ok=True)
    fig, axes = plt.subplots(1, 2, figsize=(10, 8))
    try:
        title = []
        if kwargs.get('t', None) is not None:
            title.append(f"t={kwargs['t']}")
        if kwargs.get('n', None) is not None:
            title.append(f"n={kwargs['n']}")
        # Same target as plt.title(): the current (right-hand) axes.
        axes[1].set_title(' '.join(title))
        # Pin the colour range to [0, 1]; with auto-scaling a constant mask (e.g. all ones)
        # normalises to 0 and is drawn entirely white.
        axes[0].imshow(src_mask, cmap=ListedColormap(['white', 'green']), interpolation='nearest',
                       vmin=0, vmax=1)
        axes[1].imshow(tgt_mask, cmap=ListedColormap(['white', 'orange']), interpolation='nearest',
                       vmin=0, vmax=1)
        bos_cells = list(zip(*np.where(is_bos))) if is_bos is not None else []
        for ax, mask in zip(axes, (src_mask, tgt_mask)):
            if av_index is not None:
                ax.add_patch(Rectangle((-0.5, av_index - 0.5), mask.shape[1], 1,
                                       edgecolor='red', facecolor='none', lw=2))
            for row, col in bos_cells:
                ax.add_patch(Rectangle((col - 0.5, row - 0.5), 1, 1, edgecolor='blue', facecolor='none', lw=1))
        for ax in axes:
            ax.set_xticks(range(src_mask.shape[1] + 1), minor=False)
            ax.set_yticks(range(src_mask.shape[0] + 1), minor=False)
            ax.grid(which='major', color='gray', linestyle='--', linewidth=0.5)
        fig.savefig(os.path.join(save_path, f'map{suffix}.png'), dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)
@safe_run
def plot_prob_seed(scenario_id, prob, save_path, prefix='', indices=None):
    """Save a heat map of seed probabilities with a colour bar, outlining chosen cells.

    ``indices`` (optional) is read column by column: each entry in column ``col`` is a row to
    outline in red at that column, and -1 entries are skipped.
    """
    os.makedirs(save_path, exist_ok=True)
    plt.figure(figsize=(8, 5))
    plt.imshow(prob, cmap='viridis', aspect='auto')
    plt.colorbar()
    plt.title('Seed Probability')
    if indices is not None:
        target = plt.gca()
        for col in range(indices.shape[1]):
            for row in indices[:, col]:
                if row != -1:
                    target.add_patch(Rectangle((col - 0.5, row - 0.5), 1, 1,
                                               edgecolor='red', facecolor='none', lw=2))
    plt.tight_layout()
    out_file = os.path.join(save_path, f'{prefix}{scenario_id}_prob_seed.png')
    plt.savefig(out_file, dpi=300, bbox_inches='tight')
    plt.close()
@safe_run
def plot_raw():
    """Debug helper: draw raw Waymo scenarios until one with invalid track states shows up.

    Reads every TFRecord file under a hard-coded directory and draws onto one 30x30 figure:
    lanes (green points labelled with the map-feature id), road edges (labelled with their
    type number), road lines, and every track's states (a square at the first state; valid
    states red for the SDC / blue otherwise, invalid states magenta).

    When a scenario contains an invalid state, the figure is saved as
    ``scenario_{idx}_{scenario_id}.png`` in the working directory (not in the
    ``data/vis/raw/0/`` folder created here), cleared, ``breakpoint()`` is entered, and the
    function moves on to the next file.

    NOTE(review): interactive (``breakpoint()``) and hard-coded to one user's data path. The
    figure is only cleared on that branch, so scenarios without invalid states keep being drawn
    on top of each other — confirm that is intended.
    """
    plt.figure(figsize=(30, 30))
    plt.rcParams['axes.facecolor']='white'
    data_path = '/u/xiuyu/work/dev4/data/waymo/scenario/training'
    os.makedirs("data/vis/raw/0/", exist_ok=True)
    file_list = os.listdir(data_path)
    for cnt_file, file in enumerate(file_list):
        file_path = os.path.join(data_path, file)
        dataset = tf.data.TFRecordDataset(file_path, compression_type='')
        for scenario_idx, data in enumerate(dataset):
            scenario = scenario_pb2.Scenario()
            scenario.ParseFromString(bytearray(data.numpy()))
            tqdm.write(f"scenario id: {scenario.scenario_id}")
            for i in range(len(scenario.map_features)):
                # Lanes: green points, labelled with the map-feature id at the first point.
                if str(scenario.map_features[i].lane) != '':
                    line_x = [z.x for z in scenario.map_features[i].lane.polyline]
                    line_y = [z.y for z in scenario.map_features[i].lane.polyline]
                    plt.scatter(line_x, line_y, c='g', s=5)
                    plt.text(line_x[0], line_y[0], str(scenario.map_features[i].id), fontdict={'family': 'serif', 'size': 20, 'color': 'green'})
                # Road edges: scattered once in the default colour and labelled with their type,
                # then redrawn black (type 2 and unknown types) or purple (type 3, also printed).
                if str(scenario.map_features[i].road_edge) != '':
                    road_edge_x = [polyline.x for polyline in scenario.map_features[i].road_edge.polyline]
                    road_edge_y = [polyline.y for polyline in scenario.map_features[i].road_edge.polyline]
                    plt.scatter(road_edge_x, road_edge_y)
                    plt.text(road_edge_x[0], road_edge_y[0], scenario.map_features[i].road_edge.type, fontdict={'family': 'serif', 'size': 20, 'color': 'black'})
                    if scenario.map_features[i].road_edge.type == 2:
                        plt.scatter(road_edge_x, road_edge_y, c='k')
                    elif scenario.map_features[i].road_edge.type == 3:
                        plt.scatter(road_edge_x, road_edge_y, c='purple')
                        print(scenario.map_features[i].road_edge)
                    else:
                        plt.scatter(road_edge_x, road_edge_y, c='k')
                # Road lines: types 6/7/8 yellow; type 1 white dashes (5 of every 7 points);
                # everything else solid white.
                if str(scenario.map_features[i].road_line) != '':
                    road_line_x = [j.x for j in scenario.map_features[i].road_line.polyline]
                    road_line_y = [j.y for j in scenario.map_features[i].road_line.polyline]
                    if scenario.map_features[i].road_line.type == 7:
                        plt.plot(road_line_x, road_line_y, c='y')
                    elif scenario.map_features[i].road_line.type == 8:
                        plt.plot(road_line_x, road_line_y, c='y')
                    elif scenario.map_features[i].road_line.type == 6:
                        plt.plot(road_line_x, road_line_y, c='y')
                    elif scenario.map_features[i].road_line.type == 1:
                        # NOTE(review): this loop reuses ``i`` and overwrites the map-feature index;
                        # harmless only because nothing later in this iteration reads ``i``.
                        for i in range(int(len(road_line_x) / 7)):
                            plt.plot(road_line_x[i * 7 : 5 + i * 7], road_line_y[i * 7 : 5 + i * 7], color='w')
                    elif scenario.map_features[i].road_line.type == 2:
                        plt.plot(road_line_x, road_line_y, c='w')
                    else:
                        plt.plot(road_line_x, road_line_y, c='w')
            scenario_has_invalid_tracks = False
            for i in range(len(scenario.tracks)):
                traj_x = [center.center_x for center in scenario.tracks[i].states]
                traj_y = [center.center_y for center in scenario.tracks[i].states]
                head = [center.heading for center in scenario.tracks[i].states]  # unused
                valid = [center.valid for center in scenario.tracks[i].states]
                print(valid)
                # SDC: red square at the first state, red valid / magenta invalid states.
                # Other tracks: black square, blue valid / magenta invalid states.
                if i == scenario.sdc_track_index:
                    plt.scatter(traj_x[0], traj_y[0], s=140, c='r', marker='s')
                    plt.scatter([x for x, v in zip(traj_x, valid) if v],
                                [y for y, v in zip(traj_y, valid) if v], s=14, c='r')
                    plt.scatter([x for x, v in zip(traj_x, valid) if not v],
                                [y for y, v in zip(traj_y, valid) if not v], s=14, c='m')
                else:
                    plt.scatter(traj_x[0], traj_y[0], s=140, c='k', marker='s')
                    plt.scatter([x for x, v in zip(traj_x, valid) if v],
                                [y for y, v in zip(traj_y, valid) if v], s=14, c='b')
                    plt.scatter([x for x, v in zip(traj_x, valid) if not v],
                                [y for y, v in zip(traj_y, valid) if not v], s=14, c='m')
                if valid.count(False) > 0:
                    scenario_has_invalid_tracks = True
            # First scenario (in this file) with an invalid state: save, clear, debug, next file.
            if scenario_has_invalid_tracks:
                plt.savefig(f"scenario_{scenario_idx}_{scenario.scenario_id}.png")
                plt.clf()
                breakpoint()
                break
# (facecolor, edgecolor) hex pairs. ``plot_gif`` looks them up as ``colors[code - 1]`` with
# code 1 = valid at every step (blue), 2 = invalid at the first step but valid at the last
# (green), 3 = valid at the first step but invalid at the last (orange), 4 = any other
# partially valid agent (purple), 5 = the AV (red). Code 0 (none of the above) indexes -1
# and so wraps to the last entry (black).
colors = [
    ('#1f77b4', '#1a5a8a'),
    ('#2ca02c', '#217721'),
    ('#ff7f0e', '#cc660b'),
    ('#9467bd', '#6f4a91'),
    ('#d62728', '#a31d1d'),
    ('#000000', '#000000'),
]
@safe_run
def plot_gif():
    """Debug utility: render every processed training scenario as an animated GIF.

    Reads pickles from a hard-coded directory. For each scenario, one PNG per timestep is
    written to ``data/vis/processed/0/``, and the frames are assembled into
    ``data/vis/processed/0/gif/scenario_{idx}_{scenario_id}.gif``. Scenarios whose GIF already
    exists are skipped. Agents are coloured via the module-level ``colors`` table.
    """
    import imageio.v2 as imageio

    data_path = "/u/xiuyu/work/dev4/data/waymo_processed/training"
    frame_dir = "data/vis/processed/0/"
    gif_dir = "data/vis/processed/0/gif"
    os.makedirs(gif_dir, exist_ok=True)
    file_list = os.listdir(data_path)
    for scenario_idx, file in tqdm(enumerate(file_list), leave=False, desc="Scenario"):
        with open(os.path.join(data_path, file), "rb") as f:
            data = pickle.load(f)
        scenario_id = data['scenario_id']
        save_path = os.path.join(gif_dir, f"scenario_{scenario_idx}_{scenario_id}.gif")
        if os.path.exists(save_path):
            tqdm.write(f"Skipped {save_path}.")
            continue

        agent_data = data['agent']
        av_index = agent_data['av_index']
        position = agent_data['position']
        heading = agent_data['heading']
        shape = agent_data['shape']
        category = agent_data['category']
        # A state counts as valid when both coordinates are non-zero.
        valid_mask = (position[..., 0] != 0) & (position[..., 1] != 0)
        num_agent = valid_mask.shape[0]
        num_timestep = position.shape[1]

        # Colour codes (see ``colors``): later assignments win, so the AV overrides the rest.
        is_av = np.arange(num_agent) == int(av_index)
        is_blue = valid_mask.sum(axis=1) == num_timestep
        is_green = ~valid_mask[:, 0] & valid_mask[:, -1]
        is_orange = valid_mask[:, 0] & ~valid_mask[:, -1]
        is_purple = (valid_mask.sum(axis=1) != num_timestep) & (~is_green) & (~is_orange)
        agent_colors = np.zeros((num_agent,))
        agent_colors[is_blue] = 1
        agent_colors[is_green] = 2
        agent_colors[is_orange] = 3
        agent_colors[is_purple] = 4
        agent_colors[is_av] = 5

        # Fixed square footprint per category (1 vehicle, 2 pedestrian, 3 cyclist); this
        # overwrites the stored shape in place.
        for cat, size in ((1, 1.8), (2, 0.5), (3, 1.0)):
            cat_mask = category == cat
            shape[cat_mask, :, 1] = size
            shape[cat_mask, :, 0] = size

        # Create the figure only after the skip check so skipped scenarios do not leak figures.
        fig, ax = plt.subplots()
        fig_paths = []
        try:
            ax.set_axis_off()
            ax.scatter(data['map_point']['position'][:, 0],
                       data['map_point']['position'][:, 1], s=0.2, c='black', edgecolors='none')
            for tid in tqdm(range(num_timestep), leave=False, desc="Timestep"):
                current = valid_mask[:, tid]
                drawn_agents = []
                for x, y, width, length, angle, color_type in zip(
                        position[current, tid, 0], position[current, tid, 1],
                        shape[current, tid, 1], shape[current, tid, 0],
                        heading[current, tid], agent_colors[current]):
                    # NOTE(review): ``/ np.pi * 360`` is twice the usual radians->degrees factor,
                    # and (x, y) is used as the rectangle's corner rather than its centre;
                    # confirm both are intended.
                    agent = plt.Rectangle((x, y), width, length, angle=((angle + np.pi / 2) / np.pi * 360) % 360,
                                          linewidth=0.2,
                                          facecolor=colors[int(color_type) - 1][0],
                                          edgecolor=colors[int(color_type) - 1][1])
                    ax.add_patch(agent)
                    drawn_agents.append(agent)
                ax.set_aspect('equal', adjustable='box')
                fig_path = os.path.join(frame_dir, f"scenario_{scenario_idx}_{scenario_id}_{tid}.png")
                fig.savefig(fig_path, dpi=600)
                fig_paths.append(fig_path)
                # Remove this step's agents so the next frame starts from the bare map.
                for drawn_agent in drawn_agents:
                    drawn_agent.remove()
        finally:
            plt.close(fig)
        images = [imageio.imread(p) for p in tqdm(fig_paths, leave=False, desc="Generate gif ...")]
        imageio.mimsave(save_path, images, duration=0.1)
@safe_run
def plot_map_token(ax: Axes, map_points: npt.NDArray, token_pos: npt.NDArray, token_heading: npt.NDArray, colors: Union[str, npt.NDArray]=None):
    """Draw the map, then one heading arrow and one dot per map token on ``ax``.

    Args:
        ax: target axes.
        map_points: points forwarded to ``plot_map``.
        token_pos: token positions; columns 0 and 1 are x and y.
        token_heading: token headings in radians (used through cos/sin).
        colors: arrow colour(s); a random RGB colour per token when None.
    """
    plot_map(ax, map_points)
    xs = token_pos[:, 0]
    ys = token_pos[:, 1]
    dir_x, dir_y = np.cos(token_heading), np.sin(token_heading)
    arrow_colors = np.random.rand(xs.shape[0], 3) if colors is None else colors
    ax.quiver(xs, ys, dir_x, dir_y, angles='xy', scale_units='xy', scale=0.2, color=arrow_colors, width=0.005,
              headwidth=0.2, headlength=2)
    ax.scatter(xs, ys, color='blue', s=0.2, edgecolors='none')
    ax.axis("equal")
@safe_run
def plot_map(ax: Axes, map_points: npt.NDArray, color='black'):
    """Scatter ``map_points`` (columns 0/1 = x/y) on ``ax`` and fit the limits to their extent."""
    xs = map_points[:, 0]
    ys = map_points[:, 1]
    ax.scatter(xs, ys, s=0.5, c=color, edgecolors='none')
    ax.set_xlim(np.min(xs), np.max(xs))
    ax.set_ylim(np.min(ys), np.max(ys))
@safe_run
def plot_agent(ax: Axes, xy: Sequence[float], heading: float, type: str, state, is_av: bool=False,
               pl2seed_radius: float=25., attr_tokenizer: Attr_Tokenizer=None, enter_index: list=[], **kwargs):
    """Draw one agent as a rotated body box plus a small heading marker.

    The footprint depends on ``type`` ('veh', 'ped' or 'cyc'); other values raise ValueError.
    An optional ``label`` in ``kwargs`` is written next to the agent and attached to the body
    box only; the remaining ``kwargs`` style both patches. For the AV, when an
    ``attr_tokenizer`` is given, a dashed circle of ``pl2seed_radius``, the tokenizer's grid
    points ('Front'/'Back' at its first/last point) and red crosses at ``enter_index`` grid
    points are also drawn. ``state`` is unused.

    Returns:
        ``(patch, center_patch)``: the body box and the heading-marker polygon.
    """
    # (length, width, marker size) per agent type.
    footprint = {
        'veh': (4.3, 1.8, 1.0),
        'ped': (0.5, 0.5, 0.1),
        'cyc': (1.9, 0.5, 0.3),
    }
    if type not in footprint:
        raise ValueError(f"Unsupported agent type {type}")
    length, width, size = footprint[type]

    label = kwargs.get('label', None)
    if label is not None:
        ax.text(xy[0] + 1.5, xy[1] + 1.5, label, fontsize=2, color="darkred", ha="center", va="center")

    # Rotate about the agent's own origin, then move it to its world position.
    to_world = mtransforms.Affine2D().rotate(heading).translate(xy[0], xy[1]) + ax.transData
    patch = FancyBboxPatch([-length / 2, -width / 2], length, width, linewidth=.2, **kwargs)
    patch.set_transform(to_world)

    kwargs['label'] = None
    marker_angles = [0, 2 * np.pi / 3, np.pi, 4 * np.pi / 3]
    marker_pts = np.stack([size * np.cos(marker_angles), size * np.sin(marker_angles)], axis=-1)
    center_patch = Polygon(marker_pts, zorder=10., linewidth=.2, **kwargs)
    center_patch.set_transform(to_world)

    ax.add_patch(patch)
    ax.add_patch(center_patch)

    if is_av and attr_tokenizer is not None:
        ax.add_patch(Circle((xy[0], xy[1]), pl2seed_radius, linewidth=0.5, edgecolor='gray',
                            linestyle='--', facecolor='none'))
        grid = attr_tokenizer.get_grid(torch.tensor(np.array(xy)).float(),
                                       torch.tensor(np.array([heading])).float()).numpy()[0]
        ax.scatter(grid[:, 0], grid[:, 1], s=0.3, c='blue', edgecolors='none')
        for point, text in ((0, 'Front'), (-1, 'Back')):
            ax.text(grid[point, 0], grid[point, 1], text, fontsize=2, color='darkred', ha='center', va='center')
        for cell in enter_index or ():
            ax.plot(grid[int(cell), 0], grid[int(cell), 1], marker='x', color='red', markersize=1)
    return patch, center_patch
@safe_run
def plot_all(map, xs, ys, angles, types, colors, is_avs, pl2seed_radius: float=25.,
             attr_tokenizer: Attr_Tokenizer=None, enter_index: list=None, labels: list=None, **kwargs):
    """Render one frame: the map plus every given agent.

    Args:
        map: map points forwarded to ``plot_map``.
        xs, ys, angles: per-agent position and heading (each angle must support ``.item()``).
        types: per-agent type strings ('veh', 'ped' or 'cyc').
        colors: per-agent face colours.
        is_avs: boolean mask over the agents selecting the AV.
        pl2seed_radius, attr_tokenizer, enter_index: forwarded to ``plot_agent``.
        labels: optional per-agent text labels.
        **kwargs: ``limit_size=(lx, ly)`` crops the view to a window around the AV. If the AV
            is not among the drawn agents, the view is left uncropped instead of failing.
            ``save_path`` writes the PNG to that path.

    Returns:
        A PIL RGB image when no ``save_path`` is given, otherwise None.
    """
    # Create the figure first: plt.subplots_adjust() issued *before* plt.subplots() acts on
    # whatever figure is current (creating a stray one), never on the figure drawn here.
    fig, ax = plt.subplots()
    try:
        ax.set_axis_off()
        plot_map(ax, map)
        if not labels:
            labels = [None] * xs.shape[0]
        for x, y, angle, agent_type, color, label, is_av in zip(xs, ys, angles, types, colors, labels, is_avs):
            assert agent_type in ('veh', 'ped', 'cyc'), f"Unsupported type {agent_type}."
            plot_agent(ax, [x, y], angle.item(), agent_type, None, is_av, facecolor=color, edgecolor='k',
                       label=label, pl2seed_radius=pl2seed_radius, attr_tokenizer=attr_tokenizer,
                       enter_index=enter_index)
        ax.grid(False)
        ax.set_aspect('equal', adjustable='box')
        limit_size = kwargs.get('limit_size', None)
        if limit_size:
            av_xs, av_ys = xs[is_avs], ys[is_avs]
            # The AV can be invisible at some steps; keep the full view then.
            if len(av_xs) > 0:
                cx, cy = float(av_xs[0]), float(av_ys[0])
                lx, ly = limit_size
                ax.set_xlim(cx - lx, cx + lx)
                ax.set_ylim(cy - ly, cy + ly)
        if kwargs.get('save_path', None):
            fig.savefig(kwargs['save_path'], dpi=600, bbox_inches="tight")
            return None
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=600, bbox_inches='tight')
        buf.seek(0)
        return Image.open(buf).convert('RGB')
    finally:
        plt.close(fig)
@safe_run
def plot_file(gt_folder: str,
              folder: Optional[str] = None,
              files: Optional[str] = None,
              save_gif: bool = True,
              batch_idx: Optional[int] = None,
              time_idx: Optional[List[int]] = None,
              limit_size: Optional[List[int]] = None,
              **kwargs,
              ) -> list:
    """Plot every rollout stored in ``idx_*_rollouts.pkl`` prediction files.

    Inputs are resolved as follows:
      * ``files`` is None: every ``idx_*_rollouts.pkl`` inside ``folder`` is plotted;
      * ``folder`` is None: ``files`` is the path of a single rollout file;
      * both given: ``files`` is a file name (or list of names) inside ``folder``.

    The ground truth of scenario ``sid`` is read from ``{gt_folder}/validation/{sid}.pkl``.
    When ``save_gif`` is set, outputs go to ``<parent>/<folder name>_plots``; otherwise nothing
    is written and images are returned in memory.

    Args:
        gt_folder: root of the processed ground-truth pickles.
        folder, files: see above.
        save_gif: save frames/GIFs to disk.
        batch_idx: only plot this scenario of each file.
        time_idx, limit_size, **kwargs: forwarded to ``plot_val``.

    Returns:
        Nested lists indexed ``[file][scenario][rollout]``; each leaf is what ``plot_val``
        returns (a list of per-frame results, or None if plotting failed).
    """
    from dev.metrics.compute_metrics import _unbatch
    shift = 5
    if files is None:
        assert os.path.exists(folder), f'Path {folder} does not exist.'
        files = list(fnmatch.filter(os.listdir(folder), 'idx_*_rollouts.pkl'))
        CONSOLE.log(f'Found {len(files)} rollouts files from {folder}.')
    elif folder is None:
        assert os.path.exists(files), f'Path {files} does not exist.'
        folder = os.path.dirname(files)
        # Keep only the base name: it is re-joined with ``folder`` below.
        files = [os.path.basename(files)]
    elif isinstance(files, str):
        files = [files]
    parent, folder_name = os.path.split(folder.rstrip(os.sep))
    save_path = os.path.join(parent, f'{folder_name}_plots') if save_gif else None

    file_outs = []
    for file in (pbar := tqdm(files, leave=False, desc='Plotting files ...')):
        pbar.set_postfix(file=file)
        with open(os.path.join(folder, file), 'rb') as f:
            preds = pickle.load(f)
        scenario_ids = preds['_scenario_id']
        agent_batch = preds['agent_batch']
        agent_id = _unbatch(preds['agent_id'], agent_batch)
        preds_traj = _unbatch(preds['pred_traj'], agent_batch)
        preds_head = _unbatch(preds['pred_head'], agent_batch)
        preds_type = _unbatch(preds['pred_type'], agent_batch)
        if 'pred_state' in preds:
            preds_state = _unbatch(preds['pred_state'], agent_batch)
        else:
            preds_state = tuple([torch.ones((*traj.shape[:2], traj.shape[2] // shift)) for traj in preds_traj])
        preds_valid = _unbatch(preds['pred_valid'], agent_batch)
        if batch_idx is not None:
            scenario_ids = scenario_ids[batch_idx : batch_idx + 1]
            agent_id = (agent_id[batch_idx],)
            preds_traj = (preds_traj[batch_idx],)
            preds_head = (preds_head[batch_idx],)
            preds_type = (preds_type[batch_idx],)
            preds_state = (preds_state[batch_idx],)
            preds_valid = (preds_valid[batch_idx],)

        scenario_outs = []
        for i, scenario_id in enumerate(scenario_ids):
            # Use the i-th scenario's tensors: the number of agents differs per scenario.
            n_agent, n_rollouts = preds_traj[i].shape[:2]
            if 'av_id' in preds:
                av_index = agent_id[i][:, 0].tolist().index(preds['av_id'])
            else:
                av_index = n_agent - 1
            # The ground truth is shared by all rollouts of a scenario; load it once.
            data_path = os.path.join(gt_folder, 'validation', f'{scenario_id}.pkl')
            with open(data_path, 'rb') as f:
                data = pickle.load(f)
            rollout_outs = []
            for j in range(n_rollouts):
                pred = dict(scenario_id=[scenario_id],
                            pred_traj=preds_traj[i][:, j],
                            pred_head=preds_head[i][:, j],
                            # Upsample states to step resolution and prepend one initial column.
                            pred_state=(
                                torch.cat([torch.zeros(n_agent, 1),
                                           preds_state[i][:, j].repeat_interleave(repeats=shift, dim=-1)],
                                          dim=1)
                            ),
                            pred_type=preds_type[i][:, j],
                            )
                rollout_outs.append(
                    plot_val(data, pred,
                             av_index=av_index,
                             save_path=save_path,
                             save_gif=save_gif,
                             time_idx=time_idx,
                             limit_size=limit_size,
                             **kwargs
                             )
                )
            scenario_outs.append(rollout_outs)
        file_outs.append(scenario_outs)
    return file_outs
@safe_run
def plot_val(data: Union[dict, str], pred: dict, av_index: int, save_path: str, suffix: str='',
             pl2seed_radius: float=75., attr_tokenizer=None, **kwargs):
    """Plot one predicted rollout over its ground-truth map via ``plot_scenario``.

    Args:
        data: ground-truth scenario dict, or a path to its ``.pkl``; only
            ``data['map_point']['position']`` is read.
        pred: rollout dict with ``scenario_id`` (first entry used), ``pred_traj``, ``pred_head``,
            ``pred_state``, ``pred_type`` (indices into ``AGENT_TYPE``) and optionally
            ``agent_labels``.
        av_index: index of the AV among the predicted agents.
        save_path, suffix, pl2seed_radius, attr_tokenizer, **kwargs: forwarded to ``plot_scenario``.

    Returns:
        Whatever ``plot_scenario`` returns.
    """
    if isinstance(data, str):
        assert data.endswith('.pkl'), f'Got invalid data path {data}.'
        assert os.path.exists(data), f'Path {data} does not exist.'
        with open(data, 'rb') as f:
            data = pickle.load(f)

    def _to_numpy(tensor):
        return tensor.cpu().numpy()

    map_point = _to_numpy(data['map_point']['position'])
    scenario_id = pred['scenario_id'][0]
    pred_traj = _to_numpy(pred['pred_traj'])
    pred_type = [AGENT_TYPE[code] for code in pred['pred_type'].tolist()]
    pred_state = _to_numpy(pred['pred_state'])
    pred_head = _to_numpy(pred['pred_head'])
    ids = np.arange(pred_traj.shape[0])
    if 'agent_labels' in pred:
        kwargs['agent_labels'] = pred['agent_labels']
    return plot_scenario(scenario_id, map_point, pred_traj, pred_head, pred_state, pred_type,
                         av_index=av_index, ids=ids, save_path=save_path, suffix=suffix,
                         pl2seed_radius=pl2seed_radius, attr_tokenizer=attr_tokenizer, **kwargs)
@safe_run
def plot_scenario(scenario_id: str,
                  map_data: npt.NDArray,
                  traj: npt.NDArray,
                  heading: npt.NDArray,
                  state: npt.NDArray,
                  types: List[str],
                  av_index: int,
                  color_type: Literal['state', 'type', 'seed', 'insert']='seed',
                  state_type: Sequence[str]=('invalid', 'valid', 'enter', 'exit'),
                  plot_enter: bool=False,
                  suffix: str='',
                  pl2seed_radius: float=25.,
                  attr_tokenizer: Attr_Tokenizer=None,
                  enter_index: Optional[List[list]] = None,
                  save_gif: bool=True,
                  tokenized: bool=False,
                  agent_labels: Optional[List[List[Optional[str]]]] = None,
                  **kwargs):
    """Render a scenario frame by frame with ``plot_all`` and optionally assemble a GIF.

    Args:
        scenario_id: used for output folder and file names.
        map_data: map points forwarded to ``plot_map``.
        traj: positions indexed as ``traj[agent, step, :2]``.
        heading: headings indexed as ``heading[agent, step]``.
        state: integer codes into ``state_type``, indexed as ``state[agent, step]``.
        types: per-agent type strings.
        av_index: index of the AV.
        color_type: 'state' colours by state code; 'seed' colours existing / entered / exited
            agents (agents after ``av_index`` count as entered); 'insert' colours agents before /
            after ``av_index`` differently; any other value leaves agents black. The AV always
            gets its own colour.
        state_type: names of the state codes; must contain 'invalid', 'valid', 'enter', 'exit'.
        plot_enter: also draw agents whose state is 'enter'.
        suffix: unused; kept for call compatibility.
        pl2seed_radius, attr_tokenizer: forwarded to ``plot_agent`` via ``plot_all``.
        enter_index: optional per-step entry indices forwarded to ``plot_all``.
        save_gif: assemble the saved frames into ``{save_path}/gifs/{scenario_id}_{NNN}.gif``.
        tokenized: inputs use token resolution (2 history steps, shift 1) instead of (11, 5).
        agent_labels: optional per-agent label lists, indexed by ``step // shift``.
        **kwargs: ``save_path`` (non-empty: frames are written to a fresh
            ``{save_path}/{scenario_id}_{NNN}`` folder and the previous run's folder is deleted),
            ``time_idx`` (subset of steps to draw), ``limit_size`` (forwarded); ``ids`` is
            accepted but unused.

    Returns:
        One entry per drawn step: a PIL image when frames are not saved to disk, else None.
    """
    import shutil

    num_historical_steps = 11
    shift = 5
    num_agent, num_timestep = traj.shape[:2]
    if tokenized:
        num_historical_steps = 2
        shift = 1

    # Frames are written only for a non-empty save_path; the folder setup below and the
    # per-frame paths in the loop both use this single condition.
    save_path = kwargs.get('save_path') or None
    save_id = None
    frame_dir = None
    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)
        # Next run id: one more than the largest existing "<scenario_id>_<NNN>" sub-folder.
        save_id = int(max([0] + [int(fname.split("_")[-1]) for fname in os.listdir(save_path)
                                 if fname.startswith(scenario_id)
                                 and os.path.isdir(os.path.join(save_path, fname))])) + 1
        frame_dir = f"{save_path}/{scenario_id}_{str(save_id).zfill(3)}"
        os.makedirs(frame_dir, exist_ok=True)
        if save_id > 1:
            # Only the newest run's frames are kept.
            try:
                shutil.rmtree(f"{save_path}/{scenario_id}_{str(save_id - 1).zfill(3)}")
            except OSError:
                pass

    visible_mask = state != state_type.index('invalid')
    if not plot_enter:
        visible_mask &= (state != state_type.index('enter'))

    agent_colors = np.zeros((num_agent, num_timestep, 3))
    agent_palette = sns.color_palette('husl', n_colors=7)
    state_colors = {name: np.array(agent_palette[i]) for i, name in enumerate(state_type)}
    seed_colors = {name: np.array(agent_palette[i]) for i, name in enumerate(['existing', 'entered', 'exited'])}
    if color_type == 'state':
        for t in range(state.shape[1]):
            window = slice(t * shift, (t + 1) * shift)
            for name in ('invalid', 'valid', 'enter', 'exit'):
                agent_colors[state[:, t] == state_type.index(name), window] = state_colors[name]
    if color_type == 'seed':
        agent_colors[:, :] = seed_colors['existing']
        future_state = state[:, num_historical_steps - 1:]
        is_exited = np.any(future_state == state_type.index('exit'), axis=-1)
        is_entered = np.any(future_state == state_type.index('enter'), axis=-1)
        is_entered[av_index + 1:] = True
        agent_colors[is_exited, :] = seed_colors['exited']
        agent_colors[is_entered, :] = seed_colors['entered']
    if color_type == 'insert':
        agent_colors[:, :] = seed_colors['exited']
        agent_colors[av_index + 1:] = seed_colors['existing']
    # The AV uses the last palette colour, which no state/seed colour uses.
    agent_colors[av_index, :] = np.array(agent_palette[-1])
    is_av = np.zeros_like(state[:, 0]).astype(np.bool_)
    is_av[av_index] = True

    timesteps = list(range(num_timestep))
    if kwargs.get('time_idx', None) is not None:
        time_idx = kwargs['time_idx']
        assert set(time_idx).issubset(set(timesteps)), f'Got invalid time_idx: {time_idx=} v.s. {timesteps=}'
        timesteps = sorted(time_idx)
    limit_size = kwargs.get('limit_size', None)
    if limit_size is not None:
        assert len(limit_size) == 2, f'Got invalid `limit_size`: {limit_size=}'

    pil_images = []
    fig_paths = []
    for tid in tqdm(timesteps, leave=False, desc="Plot ..."):
        mask_t = visible_mask[:, tid]
        types_t = [types[i] for i, visible in enumerate(mask_t) if visible]
        labels = []
        if agent_labels:
            labels = [agent_labels[i][tid // shift] for i in range(len(agent_labels)) if mask_t[i]]
        fig_path = None
        if frame_dir is not None:
            fig_path = os.path.join(frame_dir, f"{tid}.png")
            fig_paths.append(fig_path)
        pil_images.append(
            plot_all(map_data, traj[mask_t, tid, 0], traj[mask_t, tid, 1], heading[mask_t, tid], types_t,
                     colors=agent_colors[mask_t, tid],
                     save_path=fig_path,
                     is_avs=is_av[mask_t],
                     pl2seed_radius=pl2seed_radius,
                     attr_tokenizer=attr_tokenizer,
                     enter_index=enter_index[tid] if enter_index else None,
                     labels=labels,
                     limit_size=limit_size,
                     )
        )

    if fig_paths and save_gif:
        os.makedirs(os.path.join(save_path, 'gifs'), exist_ok=True)
        gif_path = f"{save_path}/gifs/{scenario_id}_{str(save_id).zfill(3)}.gif"
        images = [Image.open(p) for p in tqdm(fig_paths, leave=False, desc="Generate gif ...")]
        saved = False
        try:
            images[0].save(gif_path, save_all=True, append_images=images[1:], duration=100, loop=0)
            saved = True
        except Exception as e:
            tqdm.write(f"{e}! Failed to save gif at {gif_path}")
        finally:
            # Release the frame files before their folder is removed below.
            for image in images:
                image.close()
        if saved:
            tqdm.write(f"Saved gif at {gif_path}")
            # The GIF replaces this run's frames and the previous run's GIF (best effort).
            try:
                shutil.rmtree(frame_dir)
            except OSError:
                pass
            try:
                os.remove(f"{save_path}/gifs/{scenario_id}_{str(save_id - 1).zfill(3)}.gif")
            except OSError:
                pass
    return pil_images
def match_token_map(data,
                    map_token_traj_path: str = '/u/xiuyu/work/dev4/dev/tokens/map_traj_token5.pkl',
                    argmin_sample_len: int = 3,
                    noise: bool = False,
                    noise_topk: int = 8):
    """Assign a map-trajectory token to every map segment and fill ``pt_token`` fields.

    The token vocabulary is read from ``map_token_traj_path``. It is a pickle whose
    ``'traj_src'`` array is indexed as ``[token, point, xy]``. From each token,
    ``argmin_sample_len`` evenly spaced points are sampled.

    Each segment in ``data['map_save']['traj_pos']`` is moved into its own local
    frame: translated so its first point is the origin, then rotated by
    ``-traj_theta``. The segment gets the token with the smallest summed squared
    distance to those sample points.

    Args:
        data: dict-like scene container. It must provide:
            - ``data['map_save']`` with ``traj_pos``, ``traj_theta`` and ``pl_idx_list``;
            - ``data['pt_token']`` with ``side`` and ``num_nodes``.
            It is modified in place.
        map_token_traj_path: path of the token vocabulary pickle. It must be a
            trusted local file, because it is unpickled.
        argmin_sample_len: number of points sampled per token. The distance
            broadcast presumably requires this to equal the number of points per
            segment in ``traj_pos`` (TODO: confirm).
        noise: if True, pick a token uniformly at random among the
            ``noise_topk`` closest ones instead of the closest one.
        noise_topk: size of the candidate pool when ``noise`` is enabled.

    Returns:
        The same ``data`` object, with these entries written:
            - ``pt_token``: ``traj_mask``, ``position``, ``orientation``,
              ``height`` and ``token_idx``;
            - ``('pt_token', 'to', 'map_polygon')['edge_index']``.
    """
    with open(map_token_traj_path, 'rb') as f:
        map_token_traj = pickle.load(f)
    token_traj = torch.from_numpy(np.asarray(map_token_traj['traj_src'])).to(torch.float)
    sample_indices = torch.linspace(0, token_traj.shape[1] - 1, steps=argmin_sample_len).long()

    traj_pos = data['map_save']['traj_pos'].to(torch.float)
    traj_theta = data['map_save']['traj_theta'].to(torch.float)
    pl_idx_list = data['map_save']['pl_idx_list']
    device = traj_pos.device
    token_sample_pt = token_traj[:, sample_indices].to(device)
    pl_num = traj_pos.shape[0]

    # Row vectors right-multiplied by [[c, -s], [s, c]] are rotated by -theta,
    # which expresses each segment in its own heading-aligned frame.
    cos, sin = traj_theta.cos(), traj_theta.sin()
    rot_mat = traj_theta.new_zeros(pl_num, 2, 2)
    rot_mat[..., 0, 0] = cos
    rot_mat[..., 0, 1] = -sin
    rot_mat[..., 1, 0] = sin
    rot_mat[..., 1, 1] = cos
    traj_pos_local = torch.bmm(traj_pos - traj_pos[:, 0:1], rot_mat)

    # distance[i, k]: summed squared distance between segment i and token k's samples.
    distance = ((token_sample_pt[None] - traj_pos_local.unsqueeze(1)) ** 2).sum(dim=(-2, -1))
    if noise:
        topk_indices = torch.argsort(distance, dim=1)[:, :noise_topk]
        sample_topk = torch.randint(0, topk_indices.shape[-1], size=(topk_indices.shape[0], 1),
                                    device=topk_indices.device)
        pt_token_id = torch.gather(topk_indices, 1, sample_topk).squeeze(-1)
    else:
        pt_token_id = torch.argmin(distance, dim=1)

    # token2pl[0] is the segment index; token2pl[1] is its polyline index.
    pl_idx_full = pl_idx_list.clone()
    token2pl = torch.stack([torch.arange(len(pl_idx_list), device=device), pl_idx_full.long()])

    # For each polyline (in sorted unique order), count its segments with side == 0, 1, 2.
    side = data['pt_token']['side']
    count_list = []
    for pl in pl_idx_full.unique():
        pt_side = side[token2pl[0, token2pl[1, :] == pl]]
        count_list.append(torch.stack([(pt_side == k).sum() for k in range(3)]))
    if count_list:
        count_nums = torch.stack(count_list, dim=0).float().cpu()
    else:
        count_nums = torch.zeros((0, 3))
    num_polyline = int(count_nums.max().item()) if count_nums.numel() else 0

    # traj_mask[p, s, j] is True for the first count_nums[p, s] slots of side s.
    traj_mask = torch.arange(num_polyline).view(1, 1, -1) < count_nums.unsqueeze(-1)

    pt_token_pos = traj_pos[:, 0, :].clone()
    data['pt_token']['traj_mask'] = traj_mask
    data['pt_token']['position'] = torch.cat(
        [pt_token_pos,
         torch.zeros((data['pt_token']['num_nodes'], 1), device=device, dtype=torch.float)],
        dim=-1)
    data['pt_token']['orientation'] = traj_theta.clone()
    data['pt_token']['height'] = data['pt_token']['position'][:, -1]
    data[('pt_token', 'to', 'map_polygon')] = {}
    data[('pt_token', 'to', 'map_polygon')]['edge_index'] = token2pl
    data['pt_token']['token_idx'] = pt_token_id
    return data
@safe_run
def plot_tokenize(data, save_path: str, use_tokenized_traj: bool = False):
    """Tokenize one scenario and plot it with agents colored by their state token.

    Processing steps:
        1. Build the target with ``WaymoTargetBuilder`` (11 historical and 80 future steps).
        2. Run ``TokenProcessor`` with motion, state and map prediction enabled.
        3. Forward the result to ``plot_scenario`` with ``color_type='state'``.

    Args:
        data: raw scenario as loaded from the processed pickle.
        save_path: output directory passed through to ``plot_scenario``.
        use_tokenized_traj: if True, plot the trajectory reconstructed from the
            motion tokens (``traj_pos`` / ``traj_heading``). If False (the
            default), plot the raw ``position`` / ``heading``.

    Returns:
        The return value of ``plot_scenario``. Returns None if ``safe_run``
        caught an exception.
    """
    # Each state token is expanded over `shift` timesteps below; this is
    # presumably the token processor's motion-token shift (TODO: confirm).
    shift = 5
    token_size = 2048
    pl2seed_radius = 75
    transform = WaymoTargetBuilder(num_historical_steps=11,
                                   num_future_steps=80,
                                   max_num=32,
                                   training=False)
    grid_range = 150.
    grid_interval = 3.
    angle_interval = 3.
    attr_tokenizer = Attr_Tokenizer(grid_range=grid_range,
                                    grid_interval=grid_interval,
                                    radius=pl2seed_radius,
                                    angle_interval=angle_interval)
    token_processor = TokenProcessor(token_size,
                                     training=False,
                                     predict_motion=True,
                                     predict_state=True,
                                     predict_map=True,
                                     state_token={'invalid': 0, 'valid': 1, 'enter': 2, 'exit': 3},
                                     pl2seed_radius=pl2seed_radius)
    CONSOLE.log(f"Loaded token processor with token_size: {token_size}")

    data: HeteroData = transform(data)
    tokenized_data = token_processor(data)
    CONSOLE.log(f"Keys in tokenized data:\n{tokenized_data.keys()}")

    agent_data = tokenized_data['agent']
    av_index = agent_data['av_index']
    raw_traj = agent_data['position'][..., :2].contiguous()
    raw_heading = agent_data['heading']

    if use_tokenized_traj:
        # Drop each token's first point and concatenate the rest of every token,
        # then prepend the first raw frame as the anchor.
        token_traj = agent_data['traj_pos'][..., :2].contiguous()
        traj = torch.cat([raw_traj[:, :1], token_traj[:, :, 1:, :].flatten(1, 2)], dim=1)
        heading = torch.cat([raw_heading[:, :1], agent_data['traj_heading'][:, :, 1:].flatten(1, 2)], dim=1)
    else:
        traj, heading = raw_traj, raw_heading

    # Expand per-token states to per-step states; the first step gets state 0.
    agent_state = agent_data['state_idx'].repeat_interleave(repeats=shift, dim=-1)
    agent_state = torch.cat([torch.zeros_like(agent_state[:, :1]), agent_state], dim=1)
    agent_type = agent_data['type']
    ids = np.arange(raw_traj.shape[0])

    return plot_scenario(
        scenario_id=tokenized_data['scenario_id'],
        map_data=tokenized_data['map_point']['position'].numpy(),
        traj=traj.numpy(),
        heading=heading.numpy(),
        state=agent_state.numpy(),
        types=[AGENT_TYPE[i] for i in agent_type.tolist()],
        av_index=av_index,
        ids=ids,
        save_path=save_path,
        pl2seed_radius=pl2seed_radius,
        attr_tokenizer=attr_tokenizer,
        color_type='state',
    )
def get_metainfos(folder: str):
    """Build a scenario-id index for every rollout pickle in ``folder``.

    The folder is scanned for ``idx_*_rollouts.pkl`` files. Unless
    ``<folder>_metainfos.parquet`` already exists, a table is written to that
    parquet path and to a sibling ``<folder>_metainfos.csv``. The table has one
    row per ``(rollout_file, scenario_id, index)``, where ``index`` is the
    position of the scenario in the file's ``'_scenario_id'`` list.

    Args:
        folder: directory containing the rollout pickles. The pickles are
            unpickled, so the folder must hold trusted, locally produced files.

    Returns:
        None.
    """
    import pandas as pd
    assert os.path.exists(folder), f'Path {folder} does not exist.'
    # os.listdir order is arbitrary; sort so the table row order is reproducible.
    files = sorted(fnmatch.filter(os.listdir(folder), 'idx_*_rollouts.pkl'))
    CONSOLE.log(f'Found {len(files)} rollouts files from {folder}.')
    prefix = os.path.normpath(folder)
    metainfos_path = f'{prefix}_metainfos.parquet'
    csv_path = f'{prefix}_metainfos.csv'
    if os.path.exists(metainfos_path):
        CONSOLE.log(f'File {metainfos_path} already exists!')
        return

    rows = []
    for file in tqdm(files):
        with open(os.path.join(folder, file), 'rb') as f:
            pkl_data = pickle.load(f)
        rows.extend((file, scenario_id, index) for index, scenario_id in enumerate(pkl_data['_scenario_id']))
    df = pd.DataFrame(rows, columns=('rollout_file', 'scenario_id', 'index'))
    df.to_parquet(metainfos_path)
    df.to_csv(csv_path)
    CONSOLE.log(f'Successfully saved to {metainfos_path}.')
    return
def plot_comparison(methods: List[str], rollouts_paths: List[str], gt_folders: List[str],
                    save_path: str, scenario_ids: Optional[List[str]] = None):
    """Stitch per-method rollout snapshots of the same scenarios into one grid image.

    For each method, the ``<rollout_path>_metainfos.parquet`` index is loaded
    (as written by ``get_metainfos``). Every index is then restricted to the
    scenarios present for all methods.

    For each scenario, ``plot_file`` renders the frames at ``plot_time``
    seconds (converted with ``fps``). The images are pasted onto a white
    canvas with one row per method and one column per time, each framed by a
    black border. The result is saved as ``<save_path>/<scenario_id>.png``.

    Args:
        methods: method names, paired by position with ``rollouts_paths`` and
            ``gt_folders``. ``'ours'`` sets the default scenario order and the
            tile size when present; otherwise the first method does.
        rollouts_paths: rollout folder for each method.
        gt_folders: processed ground-truth folder for each method.
        save_path: output directory; created if missing.
        scenario_ids: scenarios to plot. Defaults to all scenarios shared by
            every method. Ids not shared by every method are skipped with a
            message.
    """
    import pandas as pd
    from collections import defaultdict

    fps = 10
    plot_time = [1, 6, 12, 18, 24, 30]  # seconds
    time_idx = [int(time * fps) for time in plot_time]
    limit_size = [75, 60]
    border = 5
    padding_x = 20
    padding_y = 50

    if not (len(methods) == len(rollouts_paths) == len(gt_folders)):
        # zip() below truncates to the shortest list; make the mismatch visible.
        CONSOLE.log(f'Length mismatch: {len(methods)=}, {len(rollouts_paths)=}, {len(gt_folders)=}; '
                    f'extra entries are ignored.')
    os.makedirs(save_path, exist_ok=True)

    metainfos = defaultdict(dict)
    for method, rollout_path in zip(methods, rollouts_paths):
        meta_info_path = f'{os.path.normpath(rollout_path)}_metainfos.parquet'
        metainfos[method]['df'] = pd.read_parquet(meta_info_path)
        CONSOLE.log(f'Loaded {method=} with {len(metainfos[method]["df"]["scenario_id"])=}.')
    if not metainfos:
        CONSOLE.log('No methods to compare.')
        return
    ref_method = 'ours' if 'ours' in metainfos else next(iter(metainfos))

    # Keep only scenarios that every method has rolled out.
    common_scenarios = set.intersection(*(set(info['df']['scenario_id']) for info in metainfos.values()))
    for info in metainfos.values():
        df = info['df']
        info['df'] = df[df['scenario_id'].isin(common_scenarios)]
    CONSOLE.log(f'Filter and get {len(common_scenarios)=}.')

    if scenario_ids is None:
        scenario_ids = metainfos[ref_method]['df']['scenario_id'].tolist()
    CONSOLE.log(f'Plotting {len(scenario_ids)=} ...')

    for scenario_id in (pbar := tqdm(scenario_ids)):
        pbar.set_postfix(scenario_id=scenario_id)
        if scenario_id not in common_scenarios:
            tqdm.write(f'Skip {scenario_id}: not available for every method.')
            continue

        figures = dict()
        for method, rollout_path, gt_folder in zip(methods, rollouts_paths, gt_folders):
            df = metainfos[method]['df']
            row = df.loc[df['scenario_id'] == scenario_id].iloc[0]
            # Presumably plot_file(...)[0][0][0] is the list of PIL images for
            # the requested time_idx frames (TODO: confirm against plot_file).
            figures[method] = plot_file(
                gt_folder=gt_folder,
                files=os.path.join(rollout_path, row['rollout_file']),
                save_gif=False,
                batch_idx=int(row['index']),
                time_idx=time_idx,
                limit_size=limit_size,
                color_type='insert',
            )[0][0][0]

        img_width, img_height = figures[ref_method][0].size
        img_width += 2 * border
        img_height += 2 * border
        n_col = len(time_idx)
        n_row = len(figures)
        W = n_col * img_width + (n_col - 1) * padding_x
        H = n_row * img_height + (n_row - 1) * padding_y
        canvas = Image.new('RGB', (W, H), 'white')
        for i_row, method_figures in enumerate(figures.values()):
            for i_col, method_figure in enumerate(method_figures):
                x = i_col * (img_width + padding_x)
                y = i_row * (img_height + padding_y)
                padded_figure = Image.new('RGB', (img_width, img_height), 'black')
                padded_figure.paste(method_figure, (border, border))
                canvas.paste(padded_figure, (x, y))
        canvas.save(os.path.join(save_path, f'{scenario_id}.png'))
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--data_path', type=str, default='/u/xiuyu/work/dev4/data/waymo_processed')
    parser.add_argument('--tfrecord_dir', type=str, default='validation_tfrecords_splitted')
    parser.add_argument('--save_folder', type=str, default='plot_gt')
    parser.add_argument('--split', type=str, default='validation')
    parser.add_argument('--scenario_id', type=str, default=None)
    parser.add_argument('--plot_tokenize', action='store_true')
    parser.add_argument('--plot_file', action='store_true')
    parser.add_argument('--folder_path', type=str, default=None)
    parser.add_argument('--file_path', type=str, default=None)
    parser.add_argument('--get_metainfos', action='store_true')
    parser.add_argument('--plot_comparison', action='store_true')
    parser.add_argument('--comparison_folder', type=str, default='comparisons')
    args = parser.parse_args()

    if args.plot_tokenize:
        # Honor --scenario_id; fall back to the built-in demo scenario when it is not given.
        scenario_id = args.scenario_id or "74ad7b76d5906d39"
        data_path = os.path.join(args.data_path, args.split, f"{scenario_id}.pkl")
        with open(data_path, "rb") as f:
            data = pickle.load(f)
        data['tfrecord_path'] = os.path.join(args.tfrecord_dir, f'{scenario_id}.tfrecords')
        CONSOLE.log(f"Loaded scenario {scenario_id}")
        save_path = os.path.join(args.data_path, args.save_folder, args.split)
        os.makedirs(save_path, exist_ok=True)
        plot_tokenize(data, save_path)

    if args.plot_file:
        plot_file(args.data_path, folder=args.folder_path, files=args.file_path)

    if args.get_metainfos:
        assert args.folder_path is not None, '`folder_path` should not be None!'
        get_metainfos(args.folder_path)

    if args.plot_comparison:
        # Cluster-specific paths. plot_comparison pairs methods, gt_folders and
        # rollouts_paths by position, so the third entry (validation_cslft) is
        # only used once a third method name is added to `methods`.
        methods = ['ours', 'smart']
        gt_folders = [
            '/baai-cwm-1/baai_cwm_ml/algorithm/xiuyu.yang/work/dev4/data/processed',
            '/baai-cwm-1/baai_cwm_ml/algorithm/xiuyu.yang/work/dev4/data/processed',
            '/baai-cwm-1/baai_cwm_ml/algorithm/xiuyu.yang/work/dev4/data/processed',
        ]
        rollouts_paths = [
            '/baai-cwm-1/baai_cwm_ml/algorithm/xiuyu.yang/work/dev4/output/scalable_smart_long/validation_ours0',
            '/baai-cwm-1/baai_cwm_ml/algorithm/xiuyu.yang/work/dev4/output/scalable_smart_long/validation_smart',
            '/baai-cwm-1/baai_cwm_ml/algorithm/xiuyu.yang/work/dev4/output/scalable_smart_long/validation_cslft',
        ]
        save_path = f'/baai-cwm-1/baai_cwm_ml/algorithm/xiuyu.yang/work/dev4/output/scalable_smart_long/{args.comparison_folder}/'
        os.makedirs(save_path, exist_ok=True)
        scenario_ids = ['72ff3e1540b28431', 'a16c927b1a1cca74', 'a504d55ea6658de7', '639949ea1d16125b']
        plot_comparison(methods, rollouts_paths, gt_folders,
                        save_path=save_path,
                        scenario_ids=scenario_ids)