import math
import os
import torch
import pickle
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import numpy.typing as npt
import fnmatch
import io
import seaborn as sns
import matplotlib.axes as Axes
import matplotlib.transforms as mtransforms
from PIL import Image
from functools import wraps
from typing import Sequence, Union, Optional
from tqdm import tqdm
from typing import List, Literal
from argparse import ArgumentParser
from scipy.ndimage.filters import gaussian_filter
from matplotlib.patches import FancyBboxPatch, Polygon, Rectangle, Circle
from matplotlib.collections import LineCollection
from torch_geometric.data import HeteroData, Dataset
from waymo_open_dataset.protos import scenario_pb2

from dev.utils.func import CONSOLE
from dev.modules.attr_tokenizer import Attr_Tokenizer
from dev.datasets.preprocess import TokenProcessor, cal_polygon_contour, AGENT_TYPE
from dev.datasets.scalable_dataset import WaymoTargetBuilder


# Names exported by `from <this module> import *`; helpers not listed here
# (e.g. plot_raw, plot_gif, plot_map, plot_agent) stay importable by name only.
__all__ = ['plot_occ_grid', 'plot_interact_edge', 'plot_map_edge', 'plot_insert_grid', 'plot_binary_map',
           'plot_map_token', 'plot_prob_seed', 'plot_scenario', 'get_heatmap', 'draw_heatmap', 'plot_val', 'plot_tokenize']


def safe_run(func):
    """Decorate ``func`` so that failures are printed instead of raised.

    The ``DEBUG`` environment variable is read once, at decoration time:

    * ``DEBUG`` unset or ``0``: any ``Exception`` raised by ``func`` is
      printed and the call returns ``None``.
    * ``DEBUG`` set to a non-zero integer: the call is forwarded untouched,
      so exceptions propagate to the caller.
    """
    debug_enabled = bool(int(os.getenv('DEBUG', 0)))

    if debug_enabled:
        @wraps(func)
        def passthrough(*args, **kwargs):
            return func(*args, **kwargs)

        return passthrough

    @wraps(func)
    def guarded(*args, **kwargs):
        try:
            result = func(*args, **kwargs)
        except Exception as err:
            # Best-effort mode: report the error and carry on.
            print(err)
            return None
        return result

    return guarded


@safe_run
def plot_occ_grid(scenario_id, occ, gt_occ=None, save_path='', mode='agent', prefix=''):
    """Save a 3 x 5 panel of occupancy grids, optionally outlined with ground truth.

    Each panel shows ``occ[i, j]`` reshaped to a square ``(n, n)`` image with
    ``n = int(sqrt(occ.shape[-1]))``, for the first 3 indices of the first axis
    (row labels ``n=...``) and the first 5 of the second axis (column labels
    ``t=...``). When ``gt_occ`` is given, cells equal to ``1`` are outlined in
    blue and cells equal to ``-1`` in red. The figure is written to
    ``<save_path>/<prefix><scenario_id>_occ_<mode>.png``.
    """

    def cell_outline_segments(matrix, find_value=1):
        # One closed square (four segments) around every cell equal to `find_value`.
        rows, cols = np.where(matrix == find_value)
        segments = []
        for cx, cy in zip(cols, rows):
            left, right = cx - 0.5, cx + 0.5
            low, high = cy - 0.5, cy + 0.5
            corners = [(left, low), (right, low), (right, high), (left, high)]
            segments.extend([corners[k], corners[(k + 1) % 4]] for k in range(4))
        return segments

    os.makedirs(save_path, exist_ok=True)
    side = int(math.sqrt(occ.shape[-1]))

    plot_n, plot_t = 3, 5
    cells = [(i, j) for i in range(plot_n) for j in range(plot_t)]

    occ_panels = [occ[i, j].reshape(side, side) for i, j in cells]
    gt_panels = []
    if gt_occ is not None:
        gt_panels = [gt_occ[i, j].reshape(side, side) for i, j in cells]

    row_labels = [f'n={k}' for k in range(plot_n)]
    col_labels = [f't={k}' for k in range(plot_t)]

    fig, axes = plt.subplots(plot_n, plot_t, figsize=(9, 6))
    plt.subplots_adjust(wspace=0.1, hspace=0.1)

    for idx, ax in enumerate(axes.flat):
        panel = occ_panels[idx]
        # NOTE: do not set vmin and vmax!
        ax.imshow(panel, cmap='viridis', interpolation='nearest')
        ax.axis('off')

        if gt_panels:
            gt_panel = gt_panels[idx]
            ax.add_collection(
                LineCollection(cell_outline_segments(gt_panel), colors='blue', linewidths=0.5))
            ax.add_collection(
                LineCollection(cell_outline_segments(gt_panel, find_value=-1), colors='red', linewidths=0.5))

        # Black frame around the whole panel (axes themselves are hidden).
        frame = plt.Rectangle((-0.5, -0.5), panel.shape[1], panel.shape[0],
                              linewidth=2, edgecolor='black', facecolor='none')
        ax.add_patch(frame)

    for label, ax in zip(row_labels, axes[:, 0]):
        ax.annotate(label, xy=(-0.1, 0.5), xycoords="axes fraction",
                    fontsize=12, ha="right", va="center", rotation=0)

    for label, ax in zip(col_labels, axes[0, :]):
        ax.annotate(label, xy=(0.5, 1.05), xycoords="axes fraction",
                    fontsize=12, ha="center", va="bottom")

    out_file = os.path.join(save_path, f'{prefix}{scenario_id}_occ_{mode}.png')
    plt.savefig(out_file, dpi=500, bbox_inches='tight')
    plt.close()


@safe_run
def plot_interact_edge(edge_index, scenario_ids, batch_sizes, num_seed, num_step, save_path='interact_edge_map',
                        **kwargs):
    """Plot, per source node of ``edge_index``, which (agent, step) cells it connects to.

    For every unique node in ``edge_index[1]`` a source mask (one cell set) and
    a target mask (one cell per node in ``edge_index[0]`` connected to it) are
    built and handed to ``plot_binary_map``.

    Flattened node indices are decoded as ``index % num_agent`` (node within a
    step) and ``index // num_agent`` (step), where ``num_agent`` is the total
    number of real agents plus ``num_seed`` seed agents per scenario.

    Args:
        edge_index: 2-row index tensor; row 0 is used as targets and row 1 as
            sources (presumably shape (2, num_edge) — TODO confirm).
        scenario_ids: one id per scenario in the batch; used in file suffixes.
        batch_sizes: tensor with the number of real agents of each scenario.
        num_seed: number of seed agents per scenario.
        num_step: number of time steps (columns of the masks).
        save_path: output directory passed to ``plot_binary_map``.
        **kwargs: optional ``av_index`` (tensor, made scenario-relative here),
            ``is_bos`` (tensor, sliced per scenario), ``t`` and ``n`` (used in
            the file suffix); everything left is forwarded to ``plot_binary_map``.
    """

    num_batch = len(scenario_ids)
    # Scenario index of every flattened node: first all real agents, then all
    # seed agents, and that layout tiled once per step.
    batches = torch.cat([
        torch.arange(num_batch).repeat_interleave(repeats=batch_sizes, dim=0),
        torch.arange(num_batch).repeat_interleave(repeats=num_seed, dim=0),
    ], dim=0).repeat(num_step).numpy()

    # Nodes per step = real agents of all scenarios + seeds of all scenarios.
    num_agent = batch_sizes.sum() + num_seed * num_batch
    # Prepend a 0 so that `ptr` holds the start offset of each scenario's real agents.
    batch_sizes = torch.nn.functional.pad(batch_sizes, (1, 0), mode='constant', value=0)
    ptr = torch.cumsum(batch_sizes, dim=0)
    # assume different scenarios and different timesteps have the same number of seed agents
    # NOTE(review): ptr_seed is [0, num_seed, num_seed, ...], i.e. not a cumulative
    # sum — confirm this is intended when num_batch > 2.
    ptr_seed = torch.tensor(np.array([0] + [num_seed] * num_batch), device=ptr.device)

    all_av_index = None
    if 'av_index' in kwargs:
        # Convert the AV indices into per-scenario row indices.
        all_av_index = kwargs.pop('av_index').cpu() - ptr[:-1]

    # Default: no cell is marked as BOS.
    is_bos = np.zeros((batch_sizes.sum(), num_step)).astype(np.bool_)
    if 'is_bos' in kwargs:
        is_bos = kwargs.pop('is_bos').cpu().numpy()

    src_index = torch.unique(edge_index[1])
    for idx, src in enumerate(tqdm(src_index)):

        src_batch = batches[src]

        # Row of the source inside its scenario: real agents come first, seed
        # agents are appended after the scenario's real agents.
        src_row = src % num_agent
        if src_row // batch_sizes.sum() > 0:
            seed_row = src_row % batch_sizes.sum() - ptr_seed[src_batch]
            src_row = batch_sizes[src_batch + 1] + seed_row
        else:
            src_row = src_row - ptr[src_batch]

        # Column = time step of the source node.
        src_col = src // (num_agent)
        src_mask = np.zeros((batch_sizes[src_batch + 1] + num_seed, num_step))
        src_mask[src_row, src_col] = 1

        tgt_mask = np.zeros((src_mask.shape[0], num_step))
        tgt_index = edge_index[0, edge_index[1] == src]
        for tgt in tgt_index:

            tgt_batch = batches[tgt]

            # Same row/column decoding as for the source node.
            tgt_row = tgt % num_agent
            if tgt_row // batch_sizes.sum() > 0:
                seed_row = tgt_row % batch_sizes.sum() - ptr_seed[tgt_batch]
                tgt_row = batch_sizes[tgt_batch + 1] + seed_row
            else:
                tgt_row = tgt_row - ptr[tgt_batch]

            tgt_col = tgt // num_agent
            tgt_mask[tgt_row, tgt_col] = 1
            # Edges are expected to stay inside one scenario.
            assert tgt_batch == src_batch

        # Log sources whose targets span more than one time step.
        selected_step = tgt_mask.sum(axis=0) > 0
        if selected_step.sum() > 1:
            print(f"\nidx={idx}", src.item(), src_row.item(), src_col.item())
            print(selected_step)
            print(edge_index[:, edge_index[1] == src].tolist())

        if all_av_index is not None:
            kwargs['av_index'] = int(all_av_index[src_batch])

        t = kwargs.get('t', src_col)
        n = kwargs.get('n', 0)
        is_bos_batch = is_bos[ptr[src_batch] : ptr[src_batch + 1]]
        plot_binary_map(src_mask, tgt_mask, save_path, suffix=f'_{scenario_ids[src_batch]}_{t:02d}_{n:02d}_{idx:04d}',
                        is_bos=is_bos_batch, **kwargs)


@safe_run
def plot_map_edge(edge_index, pos_a, data, save_path='map_edge_map'):
    """Save one image per agent showing the map tokens it is connected to.

    For every unique agent index in ``edge_index[1]`` the connected map-token
    indices (``edge_index[0]``, wrapped modulo the number of tokens) are drawn
    with ``plot_map_token`` on top of the map points, the agent position
    ``pos_a[i]`` is marked in red, and the figure is written to
    ``<save_path>/map_<i>.png``.
    """
    map_xy = data['map_point']['position'][:, :2].cpu().numpy()
    pt_token = data['pt_token']
    token_xy = pt_token['position'][:, :2].cpu().numpy()
    token_yaw = pt_token['orientation'].cpu().numpy()
    n_token = token_xy.shape[0]

    for agent in tqdm(torch.unique(edge_index[1])):
        agent_xy = pos_a[agent].cpu().numpy()
        linked = edge_index[0, edge_index[1] == agent].cpu().numpy() % n_token

        # NOTE(review): this runs before the figure below exists, so it acts on
        # whatever figure is current at that point, not on the one being saved.
        plt.subplots_adjust(left=0.3, right=0.7, top=0.7, bottom=0.3)
        _, ax = plt.subplots()
        ax.set_axis_off()

        plot_map_token(ax, map_xy, token_xy[linked], token_yaw[linked], colors='blue')
        ax.scatter(agent_xy[0], agent_xy[1], s=0.5, c='red', edgecolors='none')

        os.makedirs(save_path, exist_ok=True)
        out_file = os.path.join(save_path, f'map_{agent}.png')
        plt.savefig(out_file, dpi=600, bbox_inches='tight')
        plt.close()


def get_heatmap(x, y, prob, s=3, bins=1000):
    """Build a Gaussian-smoothed, density-normalised 2D histogram of weighted points.

    Args:
        x, y: point coordinates.
        prob: per-point weights.
        s: sigma of the Gaussian filter applied to the histogram.
        bins: bin specification forwarded to ``np.histogram2d``.

    Returns:
        ``(heatmap, extent)`` where ``heatmap`` is the smoothed histogram
        transposed (so rows follow ``y``) and ``extent`` is
        ``[xmin, xmax, ymin, ymax]`` taken from the bin edges.
    """
    counts, x_edges, y_edges = np.histogram2d(x, y, bins=bins, weights=prob, density=True)
    smoothed = gaussian_filter(counts, sigma=s)
    bounds = [x_edges[0], x_edges[-1], y_edges[0], y_edges[-1]]
    return smoothed.T, bounds


@safe_run
def draw_heatmap(vector, vector_prob, gt_idx):
    """Draw line segments coloured by probability and return the ``pyplot`` module.

    Row ``j`` of ``vector`` supplies ``(x0, y0, x1, y1)`` in its first four
    entries. Segments whose index is in ``gt_idx`` are drawn in blue; all
    others go from light to saturated red as ``vector_prob[j]`` grows. The
    figure is neither saved nor closed here; the caller does that through the
    returned ``plt``.
    """
    _, ax = plt.subplots(figsize=(10, 10))
    probs = vector_prob.cpu().numpy()

    for seg_idx in range(vector.shape[0]):
        if seg_idx in gt_idx:
            line_color = (0, 0, 1)
        else:
            fade = max(0, 0.9 - probs[seg_idx])
            line_color = (0.9, fade, fade)

        x_start, y_start, x_end, y_end = vector[seg_idx, :4]
        ax.plot((x_start, x_end), (y_start, y_end), color=line_color, linewidth=2)

    return plt


@safe_run
def plot_insert_grid(scenario_id, prob, grid, ego_pos, map, save_path='', prefix='', inference=False, indices=None,
                     all_t_in_one=False):
    """Save one probability heat map per time step (per-step variant).

    NOTE: a second ``plot_insert_grid`` with a different signature is defined
    later in this module and replaces this one at import time, so this version
    is not reachable under this name.

    Args:
        prob: float array of shape (num_step, num_grid); ``num_grid`` must be a
            perfect square.
        grid: float array of shape (num_grid, 2); unused here.
        ego_pos: only ``ego_pos.shape[0]`` is read, as the number of steps.
        indices: optional per-step grid index (or iterable of indices) to
            outline in red; ``-1`` entries are skipped.
        all_t_in_one: stop after the first step when true.
    """
    os.makedirs(save_path, exist_ok=True)

    side = int(math.sqrt(prob.shape[1]))

    for step in range(ego_pos.shape[0]):

        # NOTE(review): runs before the figure below is created, so it acts on
        # whatever figure is current at that point, not on the one being saved.
        plt.subplots_adjust(left=0.3, right=0.7, top=0.7, bottom=0.3)
        _, ax = plt.subplots()

        # probability image for this step
        plt.imshow(prob[step].reshape(side, side), cmap='viridis', interpolation='nearest')

        if indices is not None:
            picked = indices[step]
            if isinstance(picked, (int, float, np.int_)):
                picked = [picked]

            for cell in picked:
                if cell == -1:
                    continue

                cell_row = cell // side
                cell_col = cell % side
                ax.add_patch(Rectangle((cell_col - 0.5, cell_row - 0.5), 1, 1,
                                       edgecolor='red', facecolor='none', lw=2))

        ax.grid(False)
        ax.set_aspect('equal', adjustable='box')

        plt.title('Prob of Rel Position Grid')
        out_file = os.path.join(save_path, f'{prefix}{scenario_id}_heat_map_{step}.png')
        plt.savefig(out_file, dpi=300, bbox_inches='tight')
        plt.close()

        if all_t_in_one:
            break


@safe_run
def plot_insert_grid(scenario_id, prob, indices=None, save_path='', prefix='', inference=False):
    """Save a 3 x 5 panel of insert-position probability grids.

    Args:
        prob: float array indexed as ``prob[seed, step]`` with a last dimension
            of ``num_grid`` (a perfect square); documented by the original
            author as shape (num_seed, num_step, num_grid).
        indices: optional array indexed as ``indices[seed, step]`` giving the
            flat grid index to outline in red in each panel.
        save_path: output directory (created if missing).
        prefix: file-name prefix.
        inference: unused.

    The figure is written to ``<save_path>/<prefix><scenario_id>_insert_map.png``.
    """
    os.makedirs(save_path, exist_ok=True)

    side = int(math.sqrt(prob.shape[-1]))

    plot_n, plot_t = 3, 5
    cells = [(i, j) for i in range(plot_n) for j in range(plot_t)]

    panels = [prob[i, j].reshape(side, side) for i, j in cells]
    marked = []
    if indices is not None:
        marked = [indices[i, j] for i, j in cells]

    row_labels = [f'n={k}' for k in range(plot_n)]
    col_labels = [f't={k}' for k in range(plot_t)]

    fig, axes = plt.subplots(plot_n, plot_t, figsize=(9, 6))
    fig.suptitle('Prob of Insert Position Grid')
    plt.subplots_adjust(wspace=0.1, hspace=0.1)

    for idx, ax in enumerate(axes.flat):
        panel = panels[idx]
        ax.imshow(panel, cmap='viridis', interpolation='nearest')
        ax.axis('off')

        if marked:
            cell_row = marked[idx] // side
            cell_col = marked[idx] % side
            ax.add_patch(Rectangle((cell_col - .5, cell_row - .5), 1, 1,
                                   edgecolor='red', facecolor='none', lw=2))

        # Black frame around the whole panel (axes themselves are hidden).
        ax.add_patch(plt.Rectangle((-0.5, -0.5), panel.shape[1], panel.shape[0],
                                   linewidth=2, edgecolor='black', facecolor='none'))

    for label, ax in zip(row_labels, axes[:, 0]):
        ax.annotate(label, xy=(-0.1, 0.5), xycoords="axes fraction",
                    fontsize=12, ha="right", va="center", rotation=0)

    for label, ax in zip(col_labels, axes[0, :]):
        ax.annotate(label, xy=(0.5, 1.05), xycoords="axes fraction",
                    fontsize=12, ha="center", va="bottom")

    # Only the last axes of the top row receives these settings.
    last_top_ax = axes[0, -1]
    last_top_ax.grid(False)
    last_top_ax.set_aspect('equal', adjustable='box')

    out_file = os.path.join(save_path, f'{prefix}{scenario_id}_insert_map.png')
    plt.savefig(out_file, dpi=500, bbox_inches='tight')
    plt.close()


@safe_run
def plot_binary_map(src_mask, tgt_mask, save_path='', suffix='', av_index=None, is_bos=None, **kwargs):
    """Save a side-by-side image of a source mask (green) and a target mask (orange).

    Args:
        src_mask, tgt_mask: 2D arrays rendered with two-colour colormaps.
        save_path: output directory (created if missing).
        suffix: appended to the file name, giving ``map<suffix>.png``.
        av_index: optional row index outlined in red across both panels.
        is_bos: optional boolean array; every true cell is outlined in blue in
            both panels.
        **kwargs: ``t`` and ``n``, when present and not ``None``, are shown in
            the title as ``t=<t> n=<n>``.
    """
    from matplotlib.colors import ListedColormap

    os.makedirs(save_path, exist_ok=True)

    fig, axes = plt.subplots(1, 2, figsize=(10, 8))

    title_parts = [f'{key}={kwargs[key]}' for key in ('t', 'n') if kwargs.get(key, None) is not None]
    plt.title(' '.join(title_parts))

    axes[0].imshow(src_mask, cmap=ListedColormap(['white', 'green']), interpolation='nearest')
    axes[1].imshow(tgt_mask, cmap=ListedColormap(['white', 'orange']), interpolation='nearest')

    if av_index is not None:
        # Full-width red box around the AV row; each axes needs its own patch.
        for panel_ax, mask in zip(axes, (src_mask, tgt_mask)):
            panel_ax.add_patch(Rectangle((-0.5, av_index - 0.5), mask.shape[1], 1,
                                         edgecolor='red', facecolor='none', lw=2))

    if is_bos is not None:
        bos_rows, bos_cols = np.where(is_bos)
        for r, c in zip(bos_rows, bos_cols):
            for panel_ax in (axes[0], axes[1]):
                panel_ax.add_patch(Rectangle((c - 0.5, r - 0.5), 1, 1,
                                             edgecolor='blue', facecolor='none', lw=1))

    n_rows, n_cols = src_mask.shape[0], src_mask.shape[1]
    for panel_ax in axes:
        panel_ax.set_xticks(range(n_cols + 1), minor=False)
        panel_ax.set_yticks(range(n_rows + 1), minor=False)
        panel_ax.grid(which='major', color='gray', linestyle='--', linewidth=0.5)

    plt.savefig(os.path.join(save_path, f'map{suffix}.png'), dpi=300, bbox_inches='tight')
    plt.close()


@safe_run
def plot_prob_seed(scenario_id, prob, save_path, prefix='', indices=None):
    """Save a heat map of ``prob`` with a colour bar, optionally marking chosen cells.

    When ``indices`` is given, for every column ``c`` each value ``r`` in
    ``indices[:, c]`` other than ``-1`` is outlined in red at cell ``(r, c)``.
    The image is written to ``<save_path>/<prefix><scenario_id>_prob_seed.png``.
    """
    os.makedirs(save_path, exist_ok=True)

    plt.figure(figsize=(8, 5))
    plt.imshow(prob, cmap='viridis', aspect='auto')
    plt.colorbar()
    plt.title('Seed Probability')

    if indices is not None:
        n_cols = indices.shape[1]
        for col in range(n_cols):
            column_rows = indices[:, col]
            for row in column_rows:
                if row != -1:
                    marker = Rectangle((col - 0.5, row - 0.5), 1, 1,
                                       edgecolor='red', facecolor='none', lw=2)
                    plt.gca().add_patch(marker)

    plt.tight_layout()
    out_file = os.path.join(save_path, f'{prefix}{scenario_id}_prob_seed.png')
    plt.savefig(out_file, dpi=300, bbox_inches='tight')
    plt.close()


@safe_run
def plot_raw():
    """Debug helper: draw raw Waymo scenarios (map features and tracks) from TFRecords.

    Reads the scenario TFRecord files under the hard-coded ``data_path``, draws
    lanes, road edges, road lines and all agent tracks of every scenario onto
    one large figure, and, for scenarios containing at least one invalid track
    state, saves the figure to the current working directory, clears it and
    drops into the debugger. Only the first file of the directory listing is
    processed (see the trailing ``break``).
    """
    plt.figure(figsize=(30, 30))
    plt.rcParams['axes.facecolor']='white'

    data_path = '/u/xiuyu/work/dev4/data/waymo/scenario/training'
    os.makedirs("data/vis/raw/0/", exist_ok=True)
    file_list = os.listdir(data_path)

    for cnt_file, file in enumerate(file_list):
        file_path = os.path.join(data_path, file)
        dataset = tf.data.TFRecordDataset(file_path, compression_type='')
        for scenario_idx, data in enumerate(dataset):
            scenario = scenario_pb2.Scenario()
            scenario.ParseFromString(bytearray(data.numpy()))
            tqdm.write(f"scenario id: {scenario.scenario_id}")

            # draw maps
            for i in range(len(scenario.map_features)):

                # draw lanes
                # (a sub-message whose string form is empty is treated as "not set")
                if str(scenario.map_features[i].lane) != '':
                    line_x = [z.x for z in scenario.map_features[i].lane.polyline]
                    line_y = [z.y for z in scenario.map_features[i].lane.polyline]
                    plt.scatter(line_x, line_y, c='g', s=5)
                    plt.text(line_x[0], line_y[0], str(scenario.map_features[i].id), fontdict={'family': 'serif', 'size': 20, 'color': 'green'})

                # draw road_edge
                if str(scenario.map_features[i].road_edge) != '':
                    road_edge_x = [polyline.x for polyline in scenario.map_features[i].road_edge.polyline]
                    road_edge_y = [polyline.y for polyline in scenario.map_features[i].road_edge.polyline]
                    plt.scatter(road_edge_x, road_edge_y)
                    plt.text(road_edge_x[0], road_edge_y[0], scenario.map_features[i].road_edge.type, fontdict={'family': 'serif', 'size': 20, 'color': 'black'})
                    # type 3 is drawn in purple (and printed); every other type in black
                    if scenario.map_features[i].road_edge.type == 2:
                        plt.scatter(road_edge_x, road_edge_y, c='k')
                    elif scenario.map_features[i].road_edge.type == 3:
                        plt.scatter(road_edge_x, road_edge_y, c='purple')
                        print(scenario.map_features[i].road_edge)
                    else:
                        plt.scatter(road_edge_x, road_edge_y, c='k')

                # draw road_line
                if str(scenario.map_features[i].road_line) != '':
                    road_line_x = [j.x for j in scenario.map_features[i].road_line.polyline]
                    road_line_y = [j.y for j in scenario.map_features[i].road_line.polyline]
                    # types 6/7/8 are drawn in yellow, everything else in white
                    if scenario.map_features[i].road_line.type == 7:
                        plt.plot(road_line_x, road_line_y, c='y')
                    elif scenario.map_features[i].road_line.type == 8:
                        plt.plot(road_line_x, road_line_y, c='y') 
                    elif scenario.map_features[i].road_line.type == 6:
                        plt.plot(road_line_x, road_line_y, c='y')
                    elif scenario.map_features[i].road_line.type == 1:
                        # type 1 is drawn as dashes: 5 points plotted out of every 7.
                        # NOTE(review): this loop reuses the outer loop variable `i`; the
                        # outer `for` rebinds it on the next iteration.
                        for i in range(int(len(road_line_x) / 7)):
                            plt.plot(road_line_x[i * 7 : 5 + i * 7], road_line_y[i * 7 : 5 + i * 7], color='w')
                    elif scenario.map_features[i].road_line.type == 2:
                        plt.plot(road_line_x, road_line_y, c='w')
                    else:
                        plt.plot(road_line_x, road_line_y, c='w')
            
            # draw tracks
            scenario_has_invalid_tracks = False
            for i in range(len(scenario.tracks)):
                traj_x = [center.center_x for center in scenario.tracks[i].states]
                traj_y = [center.center_y for center in scenario.tracks[i].states]
                head = [center.heading for center in scenario.tracks[i].states]
                valid = [center.valid for center in scenario.tracks[i].states]
                print(valid)
                # The self-driving car's track is red, other tracks blue; invalid
                # states are magenta; the first state gets a large square marker.
                if i == scenario.sdc_track_index:
                    plt.scatter(traj_x[0], traj_y[0], s=140, c='r', marker='s')
                    plt.scatter([x for x, v in zip(traj_x, valid) if v],
                                [y for y, v in zip(traj_y, valid) if v], s=14, c='r')
                    plt.scatter([x for x, v in zip(traj_x, valid) if not v],
                                [y for y, v in zip(traj_y, valid) if not v], s=14, c='m')
                else:
                    plt.scatter(traj_x[0], traj_y[0], s=140, c='k', marker='s')
                    plt.scatter([x for x, v in zip(traj_x, valid) if v],
                                [y for y, v in zip(traj_y, valid) if v], s=14, c='b')
                    plt.scatter([x for x, v in zip(traj_x, valid) if not v],
                                [y for y, v in zip(traj_y, valid) if not v], s=14, c='m')
                if valid.count(False) > 0:
                    scenario_has_invalid_tracks = True
            # Only scenarios with at least one invalid state are saved; the figure is
            # cleared only in that case, so other scenarios keep accumulating on it.
            if scenario_has_invalid_tracks:
                plt.savefig(f"scenario_{scenario_idx}_{scenario.scenario_id}.png")
                plt.clf()
                # Interactive stop for manual inspection.
                breakpoint()
        # Process only the first file of the directory listing.
        break


# (facecolor, edgecolor) hex pairs for agent rectangles. plot_gif indexes this
# list with `int(color_type) - 1`, so colour code 1 maps to the first entry and
# code 0 (via index -1) to the last one.
colors = [
    ('#1f77b4', '#1a5a8a'),  # blue
    ('#2ca02c', '#217721'),  # green
    ('#ff7f0e', '#cc660b'),  # orange
    ('#9467bd', '#6f4a91'),  # purple
    ('#d62728', '#a31d1d'),  # red
    ('#000000', '#000000'),  # black
]

@safe_run
def plot_gif():
    """Render every processed training scenario as an animated GIF.

    For each pickle file under the hard-coded ``data_path`` the map points are
    drawn once, then one PNG frame per time step is written to
    ``data/vis/processed/0/`` with a rectangle per valid agent, and the frames
    are assembled into ``data/vis/processed/0/gif/scenario_<idx>_<id>.gif``.
    Scenarios whose GIF already exists are skipped.

    Agent colour codes (indices into the module-level ``colors`` list, minus 1):
    1 valid at every step, 2 invalid at the first step and valid at the last,
    3 valid at the first step and invalid at the last, 4 any other partially
    valid track, 5 the AV.
    """
    data_path = "/u/xiuyu/work/dev4/data/waymo_processed/training"
    os.makedirs("data/vis/processed/0/gif", exist_ok=True)
    file_list = os.listdir(data_path)

    for scenario_idx, file in tqdm(enumerate(file_list), leave=False, desc="Scenario"):

        file_path = os.path.join(data_path, file)
        # Use a context manager so the file handle is closed deterministically.
        with open(file_path, "rb") as f:
            data = pickle.load(f)
        scenario_id = data['scenario_id']

        save_path = os.path.join("data/vis/processed/0/gif",
                                 f"scenario_{scenario_idx}_{scenario_id}.gif")
        if os.path.exists(save_path):
            tqdm.write(f"Skipped {save_path}.")
            continue

        # Create the figure only for scenarios that are actually rendered, so
        # skipped scenarios do not leave an open figure behind.
        fig, ax = plt.subplots()
        ax.set_axis_off()

        # draw maps
        ax.scatter(data['map_point']['position'][:, 0],
                   data['map_point']['position'][:, 1], s=0.2, c='black', edgecolors='none')

        # draw agents
        agent_data = data['agent']
        av_index = agent_data['av_index']
        position = agent_data['position'] # (num_agent, 91, 3)
        heading = agent_data['heading'] # (num_agent, 91)
        shape = agent_data['shape'] # (num_agent, 91, 3)
        category = agent_data['category'] # (num_agent,)
        # A state counts as valid when both x and y are non-zero.
        valid_mask = (position[..., 0] != 0) & (position[..., 1] != 0) # (num_agent, 91)

        num_agent = valid_mask.shape[0]
        num_timestep = position.shape[1]
        is_av = np.arange(num_agent) == int(av_index)

        is_blue = valid_mask.sum(axis=1) == num_timestep
        is_green = ~valid_mask[:, 0] & valid_mask[:, -1]
        is_orange = valid_mask[:, 0] & ~valid_mask[:, -1]
        is_purple = (valid_mask.sum(axis=1) != num_timestep
                    ) & (~is_green) & (~is_orange)
        # Later assignments win, so the AV always ends up with code 5.
        agent_colors = np.zeros((num_agent,))
        agent_colors[is_blue] = 1
        agent_colors[is_green] = 2
        agent_colors[is_orange] = 3
        agent_colors[is_purple] = 4
        agent_colors[is_av] = 5

        # Override the first two shape channels with one fixed size per category.
        veh_mask = category == 1
        ped_mask = category == 2
        cyc_mask = category == 3
        shape[veh_mask, :, 1] = 1.8
        shape[veh_mask, :, 0] = 1.8
        shape[ped_mask, :, 1] = 0.5
        shape[ped_mask, :, 0] = 0.5
        shape[cyc_mask, :, 1] = 1.0
        shape[cyc_mask, :, 0] = 1.0

        fig_paths = []
        for tid in tqdm(range(num_timestep), leave=False, desc="Timestep"):
            current_valid_mask = valid_mask[:, tid]
            xs = position[current_valid_mask, tid, 0]
            ys = position[current_valid_mask, tid, 1]
            widths = shape[current_valid_mask, tid, 1]
            lengths = shape[current_valid_mask, tid, 0]
            angles = heading[current_valid_mask, tid]
            current_agent_colors = agent_colors[current_valid_mask]

            drawn_agents = []
            # Contours are computed but only used by the commented-out drawing path below.
            contours = cal_polygon_contour(xs, ys, angles, widths, lengths) # (num_agent, 4, 2)
            contours = np.concatenate([contours, contours[:, 0:1]], axis=1) # (num_agent, 5, 2)
            for x, y, width, length, angle, color_type in zip(
                xs, ys, widths, lengths, angles, current_agent_colors):
                # NOTE(review): the angle conversion multiplies by 360 / pi rather than
                # 180 / pi — confirm this is the intended rotation.
                agent = plt.Rectangle((x, y), width, length, angle=((angle + np.pi / 2) / np.pi * 360) % 360,
                                      linewidth=0.2,
                                      facecolor=colors[int(color_type) - 1][0],
                                      edgecolor=colors[int(color_type) - 1][1])
                ax.add_patch(agent)
                drawn_agents.append(agent)
            ax.set_aspect('equal', adjustable='box')
            # for contour, color_type in zip(contours, agent_colors):
            #     drawn_agent = ax.plot(contour[:, 0], contour[:, 1])
            #     drawn_agents.append(drawn_agent)

            fig_path = os.path.join("data/vis/processed/0/",
                                    f"scenario_{scenario_idx}_{scenario_id}_{tid}.png")
            fig.savefig(fig_path, dpi=600)
            fig_paths.append(fig_path)

            # Remove this step's agents so the map can be reused for the next frame.
            for drawn_agent in drawn_agents:
                drawn_agent.remove()

        plt.close(fig)

        # generate gif
        import imageio.v2 as imageio
        images = []
        for fig_path in tqdm(fig_paths, leave=False, desc="Generate gif ..."):
            images.append(imageio.imread(fig_path))
        imageio.mimsave(save_path, images, duration=0.1)


@safe_run
def plot_map_token(ax: Axes, map_points: npt.NDArray, token_pos: npt.NDArray, token_heading: npt.NDArray, colors: Union[str, npt.NDArray]=None):
    """Draw the map points plus one heading arrow and one dot per map token.

    Args:
        ax: axes to draw on.
        map_points: map point coordinates, forwarded to ``plot_map``.
        token_pos: token positions; columns 0 and 1 are used as x and y.
        token_heading: token headings in radians (passed to ``cos`` / ``sin``).
        colors: arrow colour(s); when ``None`` a random RGB colour is drawn
            for every token.
    """
    plot_map(ax, map_points)

    token_x = token_pos[:, 0]
    token_y = token_pos[:, 1]
    dir_x, dir_y = np.cos(token_heading), np.sin(token_heading)

    arrow_colors = colors if colors is not None else np.random.rand(token_x.shape[0], 3)

    ax.quiver(token_x, token_y, dir_x, dir_y,
              angles='xy', scale_units='xy', scale=0.2,
              color=arrow_colors, width=0.005, headwidth=0.2, headlength=2)
    ax.scatter(token_x, token_y, color='blue', s=0.2, edgecolors='none')
    ax.axis("equal")


@safe_run
def plot_map(ax: Axes, map_points: npt.NDArray, color='black'):
    """Scatter the map points on ``ax`` and clamp the view to their bounding box.

    Columns 0 and 1 of ``map_points`` are used as x and y.
    """
    xs, ys = map_points[:, 0], map_points[:, 1]
    ax.scatter(xs, ys, s=0.5, c=color, edgecolors='none')

    x_bounds = (np.min(xs), np.max(xs))
    y_bounds = (np.min(ys), np.max(ys))
    ax.set_xlim(*x_bounds)
    ax.set_ylim(*y_bounds)


@safe_run
def plot_agent(ax: Axes, xy: Sequence[float], heading: float, type: str, state, is_av: bool=False,
               pl2seed_radius: float=25., attr_tokenizer: Attr_Tokenizer=None, enter_index: list=[], **kwargs):
    """Draw one agent as a rotated box with a small heading polygon at its centre.

    Args:
        ax: axes to draw on.
        xy: agent centre ``(x, y)``.
        heading: agent heading in radians.
        type: one of ``'veh'``, ``'ped'``, ``'cyc'``; selects a fixed box size.
        state: unused.
        is_av: when true (and ``attr_tokenizer`` is given) a dashed circle of
            radius ``pl2seed_radius`` and the tokenizer grid are drawn around
            the agent.
        pl2seed_radius: radius of that circle.
        attr_tokenizer: object providing ``get_grid(position, heading)``.
        enter_index: grid indices to mark with a red cross; only used for the
            AV and only when the grid is available (i.e. ``attr_tokenizer`` is
            given). Never mutated here.
        **kwargs: forwarded to both patches; a non-``None`` ``label`` is also
            written as text next to the agent.

    Returns:
        ``(patch, center_patch)``: the box and the centre polygon.

    Raises:
        ValueError: for an unsupported ``type``.
    """

    if type == 'veh':
        length = 4.3
        width = 1.8
        size = 1.0
    elif type == 'ped':
        length = 0.5
        width = 0.5
        size = 0.1
    elif type == 'cyc':
        length = 1.9
        width = 0.5
        size = 0.3
    else:
        raise ValueError(f"Unsupported agent type {type}")

    if kwargs.get('label', None) is not None:
        ax.text(
            xy[0] + 1.5, xy[1] + 1.5,
            kwargs['label'], fontsize=2, color="darkred", ha="center", va="center"
        )

    # Build the box centred on the origin, then rotate/translate it into place.
    patch = FancyBboxPatch([-length / 2, -width / 2], length, width, linewidth=.2, **kwargs)
    transform = (
        mtransforms.Affine2D().rotate(heading).translate(xy[0], xy[1])
        + ax.transData
    )
    patch.set_transform(transform)

    # The centre polygon shares the styling but must not carry the label again.
    kwargs['label'] = None
    angles = [0, 2 * np.pi / 3, np.pi, 4 * np.pi / 3]
    pts = np.stack([size * np.cos(angles), size * np.sin(angles)], axis=-1)
    center_patch = Polygon(pts, zorder=10., linewidth=.2, **kwargs)
    center_patch.set_transform(transform)

    ax.add_patch(patch)
    ax.add_patch(center_patch)

    if is_av and attr_tokenizer is not None:

        circle_patch = Circle(
            (xy[0], xy[1]), pl2seed_radius, linewidth=0.5, edgecolor='gray', linestyle='--', facecolor='none'
        )
        ax.add_patch(circle_patch)

        grid = attr_tokenizer.get_grid(torch.tensor(np.array(xy)).float(),
                                       torch.tensor(np.array([heading])).float()).numpy()[0] # (num_grid, 2)
        ax.scatter(grid[:, 0], grid[:, 1], s=0.3, c='blue', edgecolors='none')
        ax.text(grid[0, 0], grid[0, 1], 'Front', fontsize=2, color='darkred', ha='center', va='center')
        ax.text(grid[-1, 0], grid[-1, 1], 'Back', fontsize=2, color='darkred', ha='center', va='center')

        # Enter markers index into `grid`, so they are drawn only inside this
        # branch; previously a missing tokenizer led to a NameError on `grid`.
        if enter_index:
            for i in enter_index:
                ax.plot(grid[int(i), 0], grid[int(i), 1], marker='x', color='red', markersize=1)

    return patch, center_patch


@safe_run
def plot_all(map, xs, ys, angles, types, colors, is_avs, pl2seed_radius: float=25.,
             attr_tokenizer: Attr_Tokenizer=None, enter_index: list=[], labels: list=[], **kwargs):
    """Draw the map and all agents of one frame; save it or return it as an image.

    Args:
        map: map points forwarded to ``plot_map``.
        xs, ys, angles, types, colors, is_avs: per-agent position, heading,
            type (``'veh'`` / ``'ped'`` / ``'cyc'``), face colour and AV flag,
            iterated in lockstep.
        pl2seed_radius, attr_tokenizer, enter_index: forwarded to ``plot_agent``.
        labels: optional per-agent text labels; empty means no labels.
        **kwargs: ``limit_size=(lx, ly)`` crops the view to a window of
            half-size ``(lx, ly)`` around the agent selected by ``is_avs``;
            ``save_path`` writes the figure to that file.

    Returns:
        ``None`` when ``save_path`` is given, otherwise the rendered frame as
        an RGB ``PIL.Image``.
    """
    # NOTE(review): runs before the figure below is created, so it acts on
    # whatever figure is current at that point, not on the one being rendered.
    plt.subplots_adjust(left=0.3, right=0.7, top=0.7, bottom=0.3)
    _, ax = plt.subplots()
    ax.set_axis_off()

    plot_map(ax, map)

    if not labels:
        labels = [None] * xs.shape[0]

    per_agent = zip(xs, ys, angles, types, colors, labels, is_avs)
    for x, y, angle, agent_type, face_color, label, is_av in per_agent:
        assert agent_type in ('veh', 'ped', 'cyc'), f"Unsupported type {agent_type}."
        plot_agent(ax, [x, y], angle.item(), agent_type, None, is_av,
                   facecolor=face_color, edgecolor='k', label=label,
                   pl2seed_radius=pl2seed_radius, attr_tokenizer=attr_tokenizer,
                   enter_index=enter_index)

    ax.grid(False)
    ax.set_aspect('equal', adjustable='box')

    # ! set plot limit if needed
    if kwargs.get('limit_size', None):
        center_x = float(xs[is_avs])
        center_y = float(ys[is_avs])

        half_w, half_h = kwargs['limit_size']
        ax.set_xlim(center_x - half_w, center_x + half_w)
        ax.set_ylim(center_y - half_h, center_y + half_h)

    # ax.legend(loc='best', frameon=True)

    if kwargs.get('save_path', None):
        rendered = None
        plt.savefig(kwargs['save_path'], dpi=600, bbox_inches="tight")
    else:
        # ! convert to PIL image through an in-memory PNG
        png_buffer = io.BytesIO()
        plt.savefig(png_buffer, format='png', dpi=600, bbox_inches='tight')
        png_buffer.seek(0)
        rendered = Image.open(png_buffer).convert('RGB')

    plt.close()

    return rendered


@safe_run
def plot_file(gt_folder: str,
              folder: Optional[str] = None,
              files: Optional[str] = None,
              save_gif: bool = True,
              batch_idx: Optional[int] = None,
              time_idx: Optional[List[int]] = None,
              limit_size: Optional[List[int]] = None,
              **kwargs,
    ) -> List[Image.Image]:
    """Plot every rollout stored in one or more rollout pickle files.

    Args:
        gt_folder: root folder of the logged data; scenarios are read from
            ``<gt_folder>/validation/<scenario_id>.pkl``.
        folder: directory containing ``idx_*_rollouts.pkl`` files. Used to
            discover files when ``files`` is ``None``.
        files: a single rollout file. With ``folder=None`` it is a path and
            ``folder`` is derived from it; with ``folder`` given it is a file
            name inside ``folder``.
        save_gif: forwarded to ``plot_val``; when true, plots are saved next to
            ``folder`` in ``<folder>_plots``, otherwise no save path is passed.
        batch_idx: when given, only this scenario of each file is plotted.
        time_idx, limit_size, **kwargs: forwarded to ``plot_val``.

    Returns:
        A nested list ``[file][scenario][rollout]`` of whatever ``plot_val``
        returns (the declared annotation is flatter than the actual nesting).
    """

    from dev.metrics.compute_metrics import _unbatch

    # Ratio between trajectory steps and state steps (states are repeated
    # `shift` times below to match the trajectory length).
    shift = 5

    if files is None:
        assert os.path.exists(folder), f'Path {folder} does not exist.'
        files = list(fnmatch.filter(os.listdir(folder), 'idx_*_rollouts.pkl'))
        CONSOLE.log(f'Found {len(files)} rollouts files from {folder}.')

    if folder is None:
        assert os.path.exists(files), f'Path {files} does not exist.'
        folder = os.path.dirname(files)
        # Keep only the file name: it is joined with `folder` again below, and
        # joining a relative path onto its own dirname would duplicate the directory.
        files = [os.path.basename(files)]

    # A bare string would otherwise be iterated character by character.
    if isinstance(files, str):
        files = [files]

    parent, folder_name = os.path.split(folder.rstrip(os.sep))
    if save_gif:
        save_path = os.path.join(parent, f'{folder_name}_plots')
    else:
        save_path = None

    file_outs = []
    for file in (pbar := tqdm(files, leave=False, desc='Plotting files ...')):
        pbar.set_postfix(file=file)

        with open(os.path.join(folder, file), 'rb') as f:
            preds = pickle.load(f)

        # Split the batched predictions into one entry per scenario.
        scenario_ids = preds['_scenario_id']
        agent_batch = preds['agent_batch']
        agent_id = _unbatch(preds['agent_id'], agent_batch)
        preds_traj = _unbatch(preds['pred_traj'], agent_batch)
        preds_head = _unbatch(preds['pred_head'], agent_batch)
        preds_type = _unbatch(preds['pred_type'], agent_batch)
        if 'pred_state' in preds:
            preds_state = _unbatch(preds['pred_state'], agent_batch)
        else:
            # No states stored: treat every step as state 1.
            preds_state = tuple([torch.ones((*traj.shape[:2], traj.shape[2] // shift)) for traj in preds_traj])  # [n_agent, n_rollout, n_step2Hz]
        preds_valid = _unbatch(preds['pred_valid'], agent_batch)

        # ! fetch certain scenario
        if batch_idx is not None:
            scenario_ids = scenario_ids[batch_idx : batch_idx + 1]
            agent_id = (agent_id[batch_idx],)
            preds_traj = (preds_traj[batch_idx],)
            preds_head = (preds_head[batch_idx],)
            preds_type = (preds_type[batch_idx],)
            preds_state = (preds_state[batch_idx],)
            preds_valid = (preds_valid[batch_idx],)

        scenario_outs = []
        for i, scenario_id in enumerate(scenario_ids):
            # Use this scenario's own shape; scenarios in one file can have
            # different agent counts.
            n_agent, n_rollouts = preds_traj[i].shape[:2]

            rollout_outs = []
            for j in range(n_rollouts):  # 1
                pred = dict(scenario_id=[scenario_id],
                            pred_traj=preds_traj[i][:, j],
                            pred_head=preds_head[i][:, j],
                            # Prepend one zero state and upsample states by `shift`.
                            pred_state=(
                                torch.cat([torch.zeros(n_agent, 1), preds_state[i][:, j].repeat_interleave(repeats=shift, dim=-1)],
                                          dim=1)
                            ),
                            pred_type=preds_type[i][:, j],
                    )

                # NOTE: hard code!!! Without an explicit `av_id` the last agent is taken as the AV.
                if 'av_id' in preds:
                    av_index = agent_id[i][:, 0].tolist().index(preds['av_id'])
                else:
                    av_index = n_agent - 1

                # ! load logged data (reloaded per rollout so every plot_val call
                # receives a fresh object)
                data_path = os.path.join(gt_folder, 'validation', f'{scenario_id}.pkl')
                with open(data_path, 'rb') as f:
                    data = pickle.load(f)

                rollout_outs.append(
                    plot_val(data, pred,
                             av_index=av_index,
                             save_path=save_path,
                             save_gif=save_gif,
                             time_idx=time_idx,
                             limit_size=limit_size,
                             **kwargs
                    )
                )

            scenario_outs.append(rollout_outs)
        file_outs.append(scenario_outs)

    return file_outs


@safe_run
def plot_val(data: Union[dict, str], pred: dict, av_index: int, save_path: str, suffix: str='',
             pl2seed_radius: float=75., attr_tokenizer=None, **kwargs):
    """Unpack one rollout prediction and hand it to `plot_scenario`.

    Args:
        data: Either the already-loaded scenario dict or a path to a `.pkl`
            file holding it. Only `data['map_point']['position']` is read.
        pred: Prediction dict. Reads `scenario_id` (first element is used),
            `pred_traj`, `pred_head`, `pred_state` (tensors, converted to
            numpy) and `pred_type` (indices into `AGENT_TYPE`). If it carries
            `agent_labels`, they are forwarded as a keyword argument.
        av_index: Forwarded to `plot_scenario`.
        save_path: Forwarded to `plot_scenario`.
        suffix: Forwarded to `plot_scenario`.
        pl2seed_radius: Forwarded to `plot_scenario`.
        attr_tokenizer: Forwarded to `plot_scenario`.
        **kwargs: Extra keyword arguments forwarded to `plot_scenario`.

    Returns:
        Whatever `plot_scenario` returns.
    """

    def _as_numpy(key):
        # prediction tensors may live on an accelerator; bring them to host memory first
        return pred[key].cpu().numpy()

    if isinstance(data, str):
        assert data.endswith('.pkl'), f'Got invalid data path {data}.'
        assert os.path.exists(data), f'Path {data} does not exist.'
        with open(data, 'rb') as pkl_file:
            data = pickle.load(pkl_file)

    map_positions = data['map_point']['position'].cpu().numpy()

    scenario_id = pred['scenario_id'][0]
    trajectories = _as_numpy('pred_traj')  # (num_agent, num_future_step, 2) per the original note
    type_names = [AGENT_TYPE[type_idx] for type_idx in pred['pred_type'].tolist()]
    states = _as_numpy('pred_state')
    headings = _as_numpy('pred_head')
    agent_ids = np.arange(trajectories.shape[0])

    if 'agent_labels' in pred:
        kwargs['agent_labels'] = pred['agent_labels']

    return plot_scenario(
        scenario_id, map_positions, trajectories, headings, states, type_names,
        av_index=av_index,
        ids=agent_ids,
        save_path=save_path,
        suffix=suffix,
        pl2seed_radius=pl2seed_radius,
        attr_tokenizer=attr_tokenizer,
        **kwargs,
    )


def _next_save_id(save_root: str, scenario_id: str) -> int:
    """Return 1 + the largest integer suffix among `<scenario_id>..._<n>` directories in `save_root`."""
    latest = 0
    for fname in os.listdir(save_root):
        if not (fname.startswith(scenario_id) and os.path.isdir(os.path.join(save_root, fname))):
            continue
        try:
            latest = max(latest, int(fname.split("_")[-1]))
        except ValueError:
            # a directory that shares the scenario prefix but has no numeric run suffix
            continue
    return latest + 1


@safe_run
def plot_scenario(scenario_id: str,
                  map_data: npt.NDArray,
                  traj: npt.NDArray,
                  heading: npt.NDArray,
                  state: npt.NDArray,
                  types: List[str],
                  av_index: int,
                  color_type: Literal['state', 'type', 'seed', 'insert']='seed',
                  state_type: List[str]=['invalid', 'valid', 'enter', 'exit'],
                  plot_enter: bool=False,
                  suffix: str='',
                  pl2seed_radius: float=25.,
                  attr_tokenizer: Attr_Tokenizer=None,
                  enter_index: List[list] = [],
                  save_gif: bool=True,
                  tokenized: bool=False,
                  agent_labels: List[List[Optional[str]]] = [],
                  **kwargs):
    """Render one frame per timestep of a scenario and optionally assemble a GIF.

    Args:
        scenario_id: Used to name the per-run frame directory and the GIF.
        map_data: Forwarded unchanged to `plot_all`.
        traj: Indexed as `traj[agent, step, 0|1]` for x / y.
        heading: Indexed as `heading[agent, step]`.
        state: Per-agent, per-step integer codes that index into `state_type`.
        types: Per-agent type names; the visible subset is forwarded to `plot_all`.
        av_index: Row of the autonomous vehicle; it always gets the last palette colour.
        color_type: Colouring scheme. `'state'`, `'seed'` and `'insert'` have
            dedicated branches; any other value (including `'type'`) leaves
            non-AV colours at zero.
        state_type: Names of the state codes, in code order. Never mutated here.
        plot_enter: When False, steps whose state is `'enter'` are hidden.
        suffix: Accepted but not used by this function.
        pl2seed_radius: Forwarded to `plot_all`.
        attr_tokenizer: Forwarded to `plot_all`.
        enter_index: If non-empty, `enter_index[tid]` is forwarded to `plot_all`.
            Never mutated here.
        save_gif: Assemble the saved frames into `<save_path>/gifs/<run>.gif`.
        tokenized: Use `num_historical_steps=2` and `shift=1` instead of 11 / 5.
        agent_labels: If non-empty, `agent_labels[i][tid // shift]` is the label
            of agent `i` at step `tid`. Never mutated here.
        **kwargs: Recognised keys:
            `save_path` - output root; `None` or `''` disables all disk output.
            `time_idx` - subset of timesteps to draw.
            `limit_size` - two-element value forwarded to `plot_all`.
            `ids` - accepted for backward compatibility; not used for drawing.

    Returns:
        List with the return value of `plot_all` for every drawn timestep.
        Because of `safe_run`, failures are printed and `None` is returned
        unless the `DEBUG` environment variable is set.
    """
    import shutil

    num_historical_steps = 11
    shift = 5
    num_agent, num_timestep = traj.shape[:2]

    if tokenized:
        num_historical_steps = 2
        shift = 1

    # `None` and `''` both mean "do not write to disk". Normalising once keeps the
    # directory setup, the per-frame paths and the GIF step in agreement (an empty
    # string used to skip the setup but still reach the frame-saving code).
    save_path = kwargs.get('save_path') or None
    save_id = None
    run_dir = None
    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)
        save_id = _next_save_id(save_path, scenario_id)
        run_dir = f"{save_path}/{scenario_id}_{str(save_id).zfill(3)}"
        os.makedirs(run_dir, exist_ok=True)

        if save_id > 1:
            # best effort: only the frames of the latest run are kept
            shutil.rmtree(f"{save_path}/{scenario_id}_{str(save_id - 1).zfill(3)}", ignore_errors=True)

    visible_mask = state != state_type.index('invalid')
    if not plot_enter:
        visible_mask &= (state != state_type.index('enter'))

    # agent colors
    agent_colors = np.zeros((num_agent, num_timestep, 3))

    agent_palette = sns.color_palette('husl', n_colors=7)
    state_colors = {name: np.array(agent_palette[i]) for i, name in enumerate(state_type)}
    seed_colors = {seed: np.array(agent_palette[i]) for i, seed in enumerate(['existing', 'entered', 'exited'])}

    if color_type == 'state':
        for t in range(state.shape[1]):
            step_slice = slice(t * shift, (t + 1) * shift)
            for name in ('invalid', 'valid', 'enter', 'exit'):
                agent_colors[state[:, t] == state_type.index(name), step_slice] = state_colors[name]

    if color_type == 'seed':
        agent_colors[:, :] = seed_colors['existing']
        is_exited = np.any(state[:, num_historical_steps - 1:] == state_type.index('exit'), axis=-1)
        is_entered = np.any(state[:, num_historical_steps - 1:] == state_type.index('enter'), axis=-1)
        is_entered[av_index + 1:] = True  # NOTE: hard code, needs improvement
        agent_colors[is_exited, :] = seed_colors['exited']
        agent_colors[is_entered, :] = seed_colors['entered']

    if color_type == 'insert':
        agent_colors[:, :] = seed_colors['exited']
        agent_colors[av_index + 1:] = seed_colors['existing']

    agent_colors[av_index, :] = np.array(agent_palette[-1])
    is_av = np.zeros_like(state[:, 0]).astype(np.bool_)
    is_av[av_index] = True

    # ! get timesteps to plot
    timesteps = list(range(num_timestep))
    if kwargs.get('time_idx', None) is not None:
        time_idx = kwargs['time_idx']
        assert set(time_idx).issubset(set(timesteps)), f'Got invalid time_idx: {time_idx=} v.s. {timesteps=}'
        timesteps = sorted(time_idx)

    # ! get plot limits
    limit_size = kwargs.get('limit_size', None)
    if limit_size is not None:
        assert len(limit_size) == 2, f'Got invalid `limit_size`: {limit_size=}'

    # ! plot all
    pil_images = []
    fig_paths = []
    for tid in tqdm(timesteps, leave=False, desc="Plot ..."):
        mask_t = visible_mask[:, tid]
        types_t = [types[i] for i, mask in enumerate(mask_t) if mask]
        enter_index_t = enter_index[tid] if enter_index else None
        labels = []
        if agent_labels:
            labels = [agent_labels[i][tid // shift] for i in range(len(agent_labels)) if mask_t[i]]

        fig_path = None
        if run_dir is not None:
            fig_path = os.path.join(run_dir, f"{tid}.png")
            fig_paths.append(fig_path)

        pil_images.append(
            plot_all(map_data,
                     traj[mask_t, tid, 0],
                     traj[mask_t, tid, 1],
                     heading[mask_t, tid],
                     types_t,
                     colors=agent_colors[mask_t, tid],
                     save_path=fig_path,
                     is_avs=is_av[mask_t],
                     pl2seed_radius=pl2seed_radius,
                     attr_tokenizer=attr_tokenizer,
                     enter_index=enter_index_t,
                     labels=labels,
                     limit_size=limit_size,
            )
        )

    # generate gif
    if fig_paths and save_gif:
        os.makedirs(os.path.join(save_path, 'gifs'), exist_ok=True)
        gif_path = f"{save_path}/gifs/{scenario_id}_{str(save_id).zfill(3)}.gif"
        frames = []
        saved = False
        try:
            for fig_path in tqdm(fig_paths, leave=False, desc="Generate gif ..."):
                frames.append(Image.open(fig_path))
            try:
                frames[0].save(gif_path, save_all=True, append_images=frames[1:], duration=100, loop=0)
                saved = True
            except Exception as e:
                tqdm.write(f"{e}! Failed to save gif at {gif_path}")
        finally:
            # release the PNG handles before their directory is deleted below
            for frame in frames:
                frame.close()

        if saved:
            tqdm.write(f"Saved gif at {gif_path}")
            # The frames now live inside the gif; drop the PNG directory and the
            # gif of the previous run. Both removals are best effort.
            shutil.rmtree(run_dir, ignore_errors=True)
            try:
                os.remove(f"{save_path}/gifs/{scenario_id}_{str(save_id - 1).zfill(3)}.gif")
            except OSError:
                pass

    return pil_images


def match_token_map(data, map_token_traj_path: str = '/u/xiuyu/work/dev4/dev/tokens/map_traj_token5.pkl',
                    noise: bool = False):
    """Assign every map trajectory its closest map token and fill `data['pt_token']`.

    Each trajectory in `data['map_save']['traj_pos']` is moved into its own local
    frame (first point at the origin, rotated by `-traj_theta`) and compared with
    the token templates at `argmin_sample_len` evenly spaced sample points. The
    token with the smallest summed squared distance is selected.

    Args:
        data: Mapping that provides `data['map_save']` (`traj_pos`, `traj_theta`,
            `pl_idx_list`) and `data['pt_token']` (`side`, `num_nodes`). It is
            modified in place.
        map_token_traj_path: Pickle file containing a dict whose `'traj_src'`
            entry is a numpy array of token templates. Defaults to the path
            that used to be hard-coded.
        noise: When True, pick a token uniformly at random among the 8 closest
            ones instead of the single closest one.

    Returns:
        The same `data` object with these entries written:
        `data['pt_token']['traj_mask' | 'position' | 'orientation' | 'height' | 'token_idx']`
        and `data[('pt_token', 'to', 'map_polygon')]['edge_index']`.
    """
    # init map token
    argmin_sample_len = 3
    noise_topk = 8

    # NOTE: pickle can execute arbitrary code while loading; only use trusted token files.
    with open(map_token_traj_path, 'rb') as f:
        map_token_traj = pickle.load(f)
    traj_src = map_token_traj['traj_src']
    indices = torch.linspace(0, traj_src.shape[1] - 1, steps=argmin_sample_len).long()
    sample_pt = torch.from_numpy(traj_src[:, indices.numpy()]).to(torch.float)

    traj_pos = data['map_save']['traj_pos'].to(torch.float)
    traj_theta = data['map_save']['traj_theta'].to(torch.float)
    pl_idx_list = data['map_save']['pl_idx_list']
    token_sample_pt = sample_pt.to(traj_pos.device)
    pl_num = traj_pos.shape[0]

    pt_token_pos = traj_pos[:, 0, :].clone()
    pt_token_orientation = traj_theta.clone()

    # Row vectors multiplied by this matrix are rotated by -theta, which aligns
    # every trajectory's heading with the +x axis of its local frame.
    cos, sin = traj_theta.cos(), traj_theta.sin()
    rot_mat = traj_theta.new_zeros(pl_num, 2, 2)
    rot_mat[..., 0, 0] = cos
    rot_mat[..., 0, 1] = -sin
    rot_mat[..., 1, 0] = sin
    rot_mat[..., 1, 1] = cos
    traj_pos_local = torch.bmm(traj_pos - traj_pos[:, 0:1], rot_mat)

    # (num_traj, num_token): squared distance summed over sample points and xy
    distance = torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1)) ** 2, dim=(-2, -1))
    if noise:
        topk_indices = torch.argsort(distance, dim=1)[:, :noise_topk]
        sample_topk = torch.randint(0, topk_indices.shape[-1], size=(topk_indices.shape[0], 1),
                                    device=topk_indices.device)
        pt_token_id = torch.gather(topk_indices, 1, sample_topk).squeeze(-1)
    else:
        pt_token_id = torch.argmin(distance, dim=1)

    # row 0: token index, row 1: index of the polyline the token belongs to
    token2pl = torch.stack([torch.arange(len(pl_idx_list), device=traj_pos.device), pl_idx_list.long()])

    # per polyline, count tokens with side code 0 / 1 / 2
    # (the original code named these left / right / center)
    side = data['pt_token']['side']
    count_nums = []
    for pl in pl_idx_list.unique():
        pt = token2pl[0, token2pl[1, :] == pl]
        pl_sides = side[pt]
        count_nums.append(torch.Tensor([float((pl_sides == code).sum()) for code in (0, 1, 2)]))
    count_nums = torch.stack(count_nums, dim=0)
    num_polyline = int(count_nums.max().item())

    # traj_mask[p, s, k] is True for the first count_nums[p, s] slots
    traj_mask = torch.arange(num_polyline)[None, None, :] < count_nums.unsqueeze(-1)

    data['pt_token']['traj_mask'] = traj_mask
    data['pt_token']['position'] = torch.cat([pt_token_pos, torch.zeros((data['pt_token']['num_nodes'], 1),
                                                                        device=traj_pos.device, dtype=torch.float)], dim=-1)
    data['pt_token']['orientation'] = pt_token_orientation
    data['pt_token']['height'] = data['pt_token']['position'][:, -1]
    data[('pt_token', 'to', 'map_polygon')] = {}
    data[('pt_token', 'to', 'map_polygon')]['edge_index'] = token2pl  # (2, num_points)
    data['pt_token']['token_idx'] = pt_token_id
    return data


@safe_run
def plot_tokenize(data, save_path: str):
    """Run the target builder and token processor on `data` and plot the result.

    The scenario is transformed with `WaymoTargetBuilder`, tokenised with
    `TokenProcessor`, and then drawn through `plot_scenario` with
    `color_type='state'`. The raw agent positions and headings are what gets
    plotted; the per-token state indices are expanded by `shift` steps and
    prefixed with one zero column before being passed along.

    Args:
        data: Scenario object accepted by `WaymoTargetBuilder`.
        save_path: Output root forwarded to `plot_scenario`.

    Returns:
        Whatever `plot_scenario` returns.
    """
    shift = 5
    token_size = 2048
    pl2seed_radius = 75

    # transformation
    target_builder = WaymoTargetBuilder(num_historical_steps=11,
                                        num_future_steps=80,
                                        max_num=32,
                                        training=False)

    attr_tokenizer = Attr_Tokenizer(grid_range=150.,
                                    grid_interval=3.,
                                    radius=pl2seed_radius,
                                    angle_interval=3.)

    # tokenization
    state_token = {'invalid': 0, 'valid': 1, 'enter': 2, 'exit': 3}
    token_processor = TokenProcessor(token_size,
                                     training=False,
                                     predict_motion=True,
                                     predict_state=True,
                                     predict_map=True,
                                     state_token=state_token,
                                     pl2seed_radius=pl2seed_radius)
    CONSOLE.log(f"Loaded token processor with token_size: {token_size}")

    # preprocess
    hetero_data: HeteroData = target_builder(data)
    tokenized_data = token_processor(hetero_data)
    CONSOLE.log(f"Keys in tokenized data:\n{tokenized_data.keys()}")

    # plot
    agent = tokenized_data['agent']
    map_point = tokenized_data['map_point']

    av_index = agent['av_index']
    raw_traj = agent['position'][..., :2].contiguous()  # [n_agent, n_step, 2] per the original note
    raw_heading = agent['heading']  # [n_agent, n_step] per the original note

    # NOTE(review): the tokenised trajectory / heading below are assembled but never
    # passed on; the raw values are plotted instead. Kept as in the original.
    token_traj = agent['traj_pos'][..., :2].contiguous()
    token_traj = torch.cat([raw_traj[:, :1], token_traj[:, :, 1:, :].flatten(1, 2)], dim=1)
    token_heading = agent['traj_heading']
    token_heading = torch.cat([raw_heading[:, :1], token_heading[:, :, 1:].flatten(1, 2)], dim=1)

    # one state per token -> one state per step, plus a leading zero column for step 0
    step_state = agent['state_idx'].repeat_interleave(repeats=shift, dim=-1)
    step_state = torch.cat([torch.zeros_like(step_state[:, :1]), step_state], dim=1)
    type_indices = agent['type']
    agent_ids = np.arange(raw_traj.shape[0])

    plot_kwargs = dict(
        scenario_id=tokenized_data['scenario_id'],
        map_data=map_point['position'].numpy(),
        traj=raw_traj.numpy(),
        heading=raw_heading.numpy(),
        state=step_state.numpy(),
        types=[AGENT_TYPE[type_idx] for type_idx in type_indices.tolist()],
        av_index=av_index,
        ids=agent_ids,
        save_path=save_path,
        pl2seed_radius=pl2seed_radius,
        attr_tokenizer=attr_tokenizer,
        color_type='state',
    )
    return plot_scenario(**plot_kwargs)


def get_metainfos(folder: str):
    """Index every rollout file in `folder` by scenario id.

    Scans `folder` for files matching `idx_*_rollouts.pkl`, reads the
    `'_scenario_id'` list of each one and writes a table with the columns
    `rollout_file`, `scenario_id` and `index` (position inside the file) to
    `<folder>_metainfos.parquet` and `<folder>_metainfos.csv` next to the
    folder. Nothing is written when the parquet file already exists.

    Args:
        folder: Directory containing the rollout pickles.

    Raises:
        AssertionError: If `folder` does not exist.
    """

    import pandas as pd

    assert os.path.exists(folder), f'Path {folder} does not exist.'
    files = fnmatch.filter(os.listdir(folder), 'idx_*_rollouts.pkl')
    CONSOLE.log(f'Found {len(files)} rollouts files from {folder}.')

    metainfos_path = f'{os.path.normpath(folder)}_metainfos.parquet'
    csv_path = f'{os.path.normpath(folder)}_metainfos.csv'

    if os.path.exists(metainfos_path):
        CONSOLE.log(f'File {metainfos_path} already exists!')
        return

    data = []
    for file in tqdm(files):
        # close each rollout file as soon as it is parsed instead of leaking the handle
        with open(os.path.join(folder, file), 'rb') as f:
            pkl_data = pickle.load(f)
        data.extend((file, scenario_id, index) for index, scenario_id in enumerate(pkl_data['_scenario_id']))

    df = pd.DataFrame(data, columns=('rollout_file', 'scenario_id', 'index'))
    df.to_parquet(metainfos_path)
    df.to_csv(csv_path)
    CONSOLE.log(f'Successfully saved to {metainfos_path}.')


def plot_comparison(methods: List[str], rollouts_paths: List[str], gt_folders: List[str],
                    save_path: str, scenario_ids: Optional[List[str]] = None):
    """Save one comparison image per scenario: one row per method, one column per time.

    For every method the `<rollout_path>_metainfos.parquet` index (as written by
    `get_metainfos`) is loaded and restricted to the scenarios that all loaded
    methods share. Each scenario is then rendered per method through `plot_file`
    at fixed time indices, and the frames are pasted onto a white canvas saved
    as `<save_path>/<scenario_id>.png`.

    Args:
        methods: Method names. They are paired with `rollouts_paths` and
            `gt_folders` via `zip`, so extra entries in the longer lists are ignored.
        rollouts_paths: Rollout folder of each method.
        gt_folders: Ground-truth folder of each method, forwarded to `plot_file`.
        save_path: Output directory; created if missing.
        scenario_ids: Scenarios to plot. Defaults to every shared scenario of the
            reference method. Ids that are not shared by all methods are skipped.

    Raises:
        ValueError: If no method index could be loaded.
    """
    import pandas as pd
    from collections import defaultdict

    # ! hyperparameter
    fps = 10

    plot_time = [1, 6, 12, 18, 24, 30]
    time_idx = [int(time * fps) for time in plot_time]  # times scaled by fps -> rollout step indices

    limit_size = [75, 60]  # [width, height]

    # ! load metainfos
    metainfos = defaultdict(dict)
    for method, rollout_path in zip(methods, rollouts_paths):
        meta_info_path = f'{os.path.normpath(rollout_path)}_metainfos.parquet'
        metainfos[method]['df'] = pd.read_parquet(meta_info_path)
        CONSOLE.log(f'Loaded {method=} with {len(metainfos[method]["df"]["scenario_id"])=}.')

    if not metainfos:
        raise ValueError('`methods` / `rollouts_paths` must contain at least one entry.')

    # 'ours' stays the reference whenever it is present (previous behaviour);
    # otherwise fall back to the first loaded method instead of failing on the key.
    reference = 'ours' if 'ours' in metainfos else next(iter(metainfos))

    common_scenarios = set(metainfos[reference]['df']['scenario_id'])
    for meta_info in metainfos.values():
        common_scenarios &= set(meta_info['df']['scenario_id'])
    for meta_info in metainfos.values():
        df = meta_info['df']
        meta_info['df'] = df[df['scenario_id'].isin(common_scenarios)]
    CONSOLE.log(f'Filter and get {len(common_scenarios)=}.')

    # ! load data and plot
    if scenario_ids is None:
        scenario_ids = metainfos[reference]['df']['scenario_id'].tolist()
    CONSOLE.log(f'Plotting {len(scenario_ids)=} ...')

    os.makedirs(save_path, exist_ok=True)

    for scenario_id in (pbar := tqdm(scenario_ids)):
        pbar.set_postfix(scenario_id=scenario_id)

        if scenario_id not in common_scenarios:
            # the per-method lookup below would find no row for this id
            CONSOLE.log(f'Skip {scenario_id=}: not available for every method.')
            continue

        figures = dict()
        for method, rollout_path, gt_folder in zip(methods, rollouts_paths, gt_folders):
            df = metainfos[method]['df']
            _df = df.loc[df['scenario_id'] == scenario_id]
            batch_idx = int(_df['index'].tolist()[0])
            rollout_file = _df['rollout_file'].tolist()[0]
            # plot_file output is indexed [file][scenario][rollout]; take the first of each
            figures[method] = plot_file(
                                gt_folder=gt_folder,
                                files=os.path.join(rollout_path, rollout_file),
                                save_gif=False,
                                batch_idx=batch_idx,
                                time_idx=time_idx,
                                limit_size=limit_size,
                                color_type='insert',
                            )[0][0][0]

        # ! plot figures
        border = 5
        padding_x = 20
        padding_y = 50

        img_width, img_height = figures[reference][0].size
        img_width = img_width + 2 * border
        img_height = img_height + 2 * border
        n_col = len(time_idx)
        n_row = len(methods)

        W = n_col * img_width + (n_col - 1) * padding_x
        H = n_row * img_height + (n_row - 1) * padding_y

        canvas = Image.new('RGB', (W, H), 'white')
        for i_row, method_figures in enumerate(figures.values()):
            for i_col, method_figure in enumerate(method_figures):
                x = i_col * (img_width + padding_x)
                y = i_row * (img_height + padding_y)

                # black backing image slightly larger than the frame -> a border of `border` px
                padded_figure = Image.new('RGB', (img_width, img_height), 'black')
                padded_figure.paste(method_figure, (border, border))

                canvas.paste(padded_figure, (x, y))

        canvas.save(
            os.path.join(save_path, f'{scenario_id}.png')
        )


if __name__ == "__main__":
    # Command-line entry point. Each --plot_* / --get_* flag enables one
    # independent action; several flags may be combined in a single run.
    parser = ArgumentParser()
    parser.add_argument('--data_path', type=str, default='/u/xiuyu/work/dev4/data/waymo_processed')
    parser.add_argument('--tfrecord_dir', type=str, default='validation_tfrecords_splitted')
    # plot tokenized data
    parser.add_argument('--save_folder', type=str, default='plot_gt')
    parser.add_argument('--split', type=str, default='validation')
    parser.add_argument('--scenario_id', type=str, default=None)
    parser.add_argument('--plot_tokenize', action='store_true')
    # plot generated rollouts
    parser.add_argument('--plot_file', action='store_true')
    parser.add_argument('--folder_path', type=str, default=None)
    parser.add_argument('--file_path', type=str, default=None)
    # metainfos
    parser.add_argument('--get_metainfos', action='store_true')
    # plot comparison
    parser.add_argument('--plot_comparison', action='store_true')
    parser.add_argument('--comparison_folder', type=str, default='comparisons')
    args = parser.parse_args()

    if args.plot_tokenize:

        # Honour --scenario_id (it used to be parsed but ignored); without it,
        # fall back to the example scenario that was hard-coded before.
        # Another example id kept from the original: "1d60300bc06f4801"
        scenario_id = args.scenario_id or "74ad7b76d5906d39"
        data_path = os.path.join(args.data_path, args.split, f"{scenario_id}.pkl")
        with open(data_path, "rb") as f:
            data = pickle.load(f)
        data['tfrecord_path'] = os.path.join(args.tfrecord_dir, f'{scenario_id}.tfrecords')
        CONSOLE.log(f"Loaded scenario {scenario_id}")

        save_path = os.path.join(args.data_path, args.save_folder, args.split)
        os.makedirs(save_path, exist_ok=True)

        plot_tokenize(data, save_path)

    if args.plot_file:

        plot_file(args.data_path, folder=args.folder_path, files=args.file_path)

    if args.get_metainfos:

        # report a missing argument through argparse so the check also holds under `python -O`
        if args.folder_path is None:
            parser.error('`folder_path` should not be None!')
        get_metainfos(args.folder_path)

    if args.plot_comparison:

        # NOTE: `methods` has two entries while the two path lists have three;
        # plot_comparison pairs them with zip, so the third path of each list is unused.
        methods = ['ours', 'smart']
        gt_folders = [
            '/baai-cwm-1/baai_cwm_ml/algorithm/xiuyu.yang/work/dev4/data/processed',
            '/baai-cwm-1/baai_cwm_ml/algorithm/xiuyu.yang/work/dev4/data/processed',
            '/baai-cwm-1/baai_cwm_ml/algorithm/xiuyu.yang/work/dev4/data/processed',
        ]
        rollouts_paths = [
            '/baai-cwm-1/baai_cwm_ml/algorithm/xiuyu.yang/work/dev4/output/scalable_smart_long/validation_ours0',
            '/baai-cwm-1/baai_cwm_ml/algorithm/xiuyu.yang/work/dev4/output/scalable_smart_long/validation_smart',
            '/baai-cwm-1/baai_cwm_ml/algorithm/xiuyu.yang/work/dev4/output/scalable_smart_long/validation_cslft',
        ]
        save_path = f'/baai-cwm-1/baai_cwm_ml/algorithm/xiuyu.yang/work/dev4/output/scalable_smart_long/{args.comparison_folder}/'
        os.makedirs(save_path, exist_ok=True)

        scenario_ids = ['72ff3e1540b28431','a16c927b1a1cca74','a504d55ea6658de7','639949ea1d16125b']
        plot_comparison(methods, rollouts_paths, gt_folders,
                        save_path=save_path,
                        scenario_ids=scenario_ids)