diff --git a/backups/configs/experiments/ablate_grid_tokens.yaml b/backups/configs/experiments/ablate_grid_tokens.yaml new file mode 100644 index 0000000000000000000000000000000000000000..24bbb28be95f43d145b5b2738b30df6f0e252d58 --- /dev/null +++ b/backups/configs/experiments/ablate_grid_tokens.yaml @@ -0,0 +1,106 @@ +# Config format schema number, the yaml support to valid case source from different dataset +time_info: &time_info + num_historical_steps: 11 + num_future_steps: 80 + use_intention: True + token_size: 2048 + predict_motion: True + predict_state: True + predict_map: True + predict_occ: True + state_token: + invalid: 0 + valid: 1 + enter: 2 + exit: 3 + pl2seed_radius: 75. + disable_grid_token: True + grid_range: 150. # 2 times of pl2seed_radius + grid_interval: 3. + angle_interval: 3. + seed_size: 1 + buffer_size: 128 + max_num: 32 + +Dataset: + root: + train_batch_size: 1 + val_batch_size: 1 + test_batch_size: 1 + shuffle: True + num_workers: 1 + pin_memory: True + persistent_workers: True + train_raw_dir: 'data/waymo_processed/training' + val_raw_dir: 'data/waymo_processed/validation' + test_raw_dir: 'data/waymo_processed/validation' + val_tfrecords_splitted: data/waymo_processed/training/validation_tfrecords_splitted + transform: WaymoTargetBuilder + train_processed_dir: + val_processed_dir: + test_processed_dir: + dataset: 'scalable' + <<: *time_info + +Trainer: + strategy: ddp_find_unused_parameters_false + accelerator: 'gpu' + devices: 1 + max_epochs: 32 + save_ckpt_path: + num_nodes: 1 + mode: + ckpt_path: + precision: 32 + accumulate_grad_batches: 1 + overfit_epochs: 6000 + +Model: + predictor: 'smart' + decoder_type: 'agent_decoder' # choose from ['agent_decoder', 'occ_decoder'] + dataset: 'waymo' + input_dim: 2 + hidden_dim: 128 + output_dim: 2 + output_head: False + num_heads: 8 + <<: *time_info + head_dim: 16 + dropout: 0.1 + num_freq_bands: 64 + lr: 0.0005 + warmup_steps: 0 + total_steps: 32 + predict_map_token: False + num_recurrent_steps_val: 300 + val_open_loop: False + val_close_loop: True + val_insert: False + n_rollout_close_val: 1 + decoder: + <<: *time_info + num_map_layers: 3 + num_agent_layers: 6 + a2a_radius: 60 + pl2pl_radius: 10 + pl2a_radius: 30 + a2sa_radius: 10 + pl2sa_radius: 10 + time_span: 60 + loss_weight: + token_cls_loss: 1 + map_token_loss: 1 + state_cls_loss: 10 + type_cls_loss: 5 + pos_cls_loss: 1 + head_cls_loss: 1 + offset_reg_loss: 5 + shape_reg_loss: .2 + pos_reg_loss: 10 + state_weight: [0.1, 0.1, 0.8] # invalid, valid, exit + seed_state_weight: [0.1, 0.9] # invalid, enter + seed_type_weight: [0.8, 0.1, 0.1] + agent_occ_pos_weight: 100 + pt_occ_pos_weight: 5 + agent_occ_loss: 10 + pt_occ_loss: 10 diff --git a/backups/configs/ours_long_term.yaml b/backups/configs/ours_long_term.yaml new file mode 100644 index 0000000000000000000000000000000000000000..56075a36ee9e1db495a61ec8ab000a1ef6f0abd3 --- /dev/null +++ b/backups/configs/ours_long_term.yaml @@ -0,0 +1,105 @@ +# Config format schema number, the yaml support to valid case source from different dataset +time_info: &time_info + num_historical_steps: 11 + num_future_steps: 80 + use_intention: True + token_size: 2048 + predict_motion: True + predict_state: True + predict_map: True + predict_occ: True + state_token: + invalid: 0 + valid: 1 + enter: 2 + exit: 3 + pl2seed_radius: 75. + grid_range: 150. # 2 times of pl2seed_radius + grid_interval: 3. + angle_interval: 3. 
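+  # NOTE: assuming grid_range is the full side length of the square occupancy
+  # grid, grid_interval 3. gives 150 / 3 = 50 cells per axis, and
+  # angle_interval 3. gives 360 / 3 = 120 heading bins.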
+ seed_size: 1 + buffer_size: 128 + max_num: 32 + +Dataset: + root: + train_batch_size: 1 + val_batch_size: 1 + test_batch_size: 1 + shuffle: True + num_workers: 1 + pin_memory: True + persistent_workers: True + train_raw_dir: 'data/waymo_processed/training' + val_raw_dir: 'data/waymo_processed/validation' + test_raw_dir: 'data/waymo_processed/validation' + val_tfrecords_splitted: data/waymo_processed/training/validation_tfrecords_splitted + transform: WaymoTargetBuilder + train_processed_dir: + val_processed_dir: + test_processed_dir: + dataset: 'scalable' + <<: *time_info + +Trainer: + strategy: ddp_find_unused_parameters_false + accelerator: 'gpu' + devices: 1 + max_epochs: 32 + save_ckpt_path: + num_nodes: 1 + mode: + ckpt_path: + precision: 32 + accumulate_grad_batches: 1 + overfit_epochs: 6000 + +Model: + predictor: 'smart' + decoder_type: 'agent_decoder' # choose from ['agent_decoder', 'occ_decoder'] + dataset: 'waymo' + input_dim: 2 + hidden_dim: 128 + output_dim: 2 + output_head: False + num_heads: 8 + <<: *time_info + head_dim: 16 + dropout: 0.1 + num_freq_bands: 64 + lr: 0.0005 + warmup_steps: 0 + total_steps: 32 + predict_map_token: False + num_recurrent_steps_val: 300 + val_open_loop: False + val_close_loop: True + val_insert: False + n_rollout_close_val: 1 + decoder: + <<: *time_info + num_map_layers: 3 + num_agent_layers: 6 + a2a_radius: 60 + pl2pl_radius: 10 + pl2a_radius: 30 + a2sa_radius: 10 + pl2sa_radius: 10 + time_span: 60 + loss_weight: + token_cls_loss: 1 + map_token_loss: 1 + state_cls_loss: 10 + type_cls_loss: 5 + pos_cls_loss: 1 + head_cls_loss: 1 + offset_reg_loss: 5 + shape_reg_loss: .2 + pos_reg_loss: 10 + state_weight: [0.1, 0.1, 0.8] # invalid, valid, exit + seed_state_weight: [0.1, 0.9] # invalid, enter + seed_type_weight: [0.8, 0.1, 0.1] + agent_occ_pos_weight: 100 + pt_occ_pos_weight: 5 + agent_occ_loss: 10 + pt_occ_loss: 10 diff --git a/backups/configs/ours_standard.yaml b/backups/configs/ours_standard.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3d3012b60a12e7d3c65ba03d71444d3bde1b6a85 --- /dev/null +++ b/backups/configs/ours_standard.yaml @@ -0,0 +1,101 @@ +# Config format schema number, the yaml support to valid case source from different dataset +time_info: &time_info + num_historical_steps: 11 + num_future_steps: 80 + use_intention: True + token_size: 2048 + predict_motion: True + predict_state: True + predict_map: False + predict_occ: True + state_token: + invalid: 0 + valid: 1 + enter: 2 + exit: 3 + pl2seed_radius: 75. + grid_range: 150. # 2 times of pl2seed_radius + grid_interval: 3. + angle_interval: 3. 
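+  # the &time_info anchor above is re-used via YAML merge keys (<<: *time_info)
+  # in the Dataset, Model and Model.decoder sections below; keys set explicitly
+  # in those sections override the merged values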
+ seed_size: 1 + buffer_size: 128 + +Dataset: + root: + train_batch_size: 1 + val_batch_size: 1 + test_batch_size: 1 + shuffle: True + num_workers: 1 + pin_memory: True + persistent_workers: True + train_raw_dir: ["data/waymo_processed/training"] + val_raw_dir: ["data/waymo_processed/validation"] + test_raw_dir: ["data/waymo_processed/validation"] + transform: WaymoTargetBuilder + train_processed_dir: + val_processed_dir: + test_processed_dir: + dataset: "scalable" + <<: *time_info + +Trainer: + strategy: ddp_find_unused_parameters_false + accelerator: "gpu" + devices: 1 + max_epochs: 32 + overfit_epochs: 6000 + save_ckpt_path: + num_nodes: 1 + mode: + ckpt_path: + precision: 32 + accumulate_grad_batches: 1 + +Model: + predictor: "smart" + decoder_type: "agent_decoder" # choose from ['agent_decoder', 'occ_decoder'] + dataset: "waymo" + input_dim: 2 + hidden_dim: 128 + output_dim: 2 + output_head: False + num_heads: 8 + <<: *time_info + head_dim: 16 + dropout: 0.1 + num_freq_bands: 64 + lr: 0.0005 + warmup_steps: 0 + total_steps: 32 + predict_map_token: False + num_recurrent_steps_val: -1 + val_open_loop: True + val_close_loop: False + val_insert: False + decoder: + <<: *time_info + num_map_layers: 3 + num_agent_layers: 6 + a2a_radius: 60 + pl2pl_radius: 10 + pl2a_radius: 30 + a2sa_radius: 10 + pl2sa_radius: 10 + time_span: 60 + loss_weight: + token_cls_loss: 1 + map_token_loss: 1 + state_cls_loss: 10 + type_cls_loss: 5 + pos_cls_loss: 1 + head_cls_loss: 1 + offset_reg_loss: 5 + shape_reg_loss: .2 + state_weight: [0.1, 0.1, 0.8] # invalid, valid, exit + seed_state_weight: [0.1, 0.9] # invalid, enter + seed_type_weight: [0.8, 0.1, 0.1] + agent_occ_pos_weight: 100 + pt_occ_pos_weight: 5 + agent_occ_loss: 10 + pt_occ_loss: 10 diff --git a/backups/configs/ours_standard_decode_occ.yaml b/backups/configs/ours_standard_decode_occ.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f8b4ff3f10ae865a382999e722c2ee5dd9d0b4f2 --- /dev/null +++ b/backups/configs/ours_standard_decode_occ.yaml @@ -0,0 +1,100 @@ +# Config format schema number, the yaml support to valid case source from different dataset +time_info: &time_info + num_historical_steps: 11 + num_future_steps: 80 + use_intention: True + token_size: 2048 + predict_motion: False + predict_state: False + predict_map: False + predict_occ: True + state_token: + invalid: 0 + valid: 1 + enter: 2 + exit: 3 + pl2seed_radius: 75. + grid_range: 150. # 2 times of pl2seed_radius + grid_interval: 3. + angle_interval: 3. 
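+  # occupancy-only ablation: predict_motion, predict_state and predict_map are
+  # all False while predict_occ is True, matching decoder_type 'occ_decoder'
+  # in the Model section below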
+ seed_size: 1 + buffer_size: 128 + +Dataset: + root: + train_batch_size: 1 + val_batch_size: 1 + test_batch_size: 1 + shuffle: True + num_workers: 1 + pin_memory: True + persistent_workers: True + train_raw_dir: ["data/waymo_processed/training"] + val_raw_dir: ["data/waymo_processed/validation"] + test_raw_dir: ["data/waymo_processed/validation"] + transform: WaymoTargetBuilder + train_processed_dir: + val_processed_dir: + test_processed_dir: + dataset: "scalable" + <<: *time_info + +Trainer: + strategy: ddp_find_unused_parameters_false + accelerator: "gpu" + devices: 1 + max_epochs: 32 + overfit_epochs: 6000 + save_ckpt_path: + num_nodes: 1 + mode: + ckpt_path: + precision: 32 + accumulate_grad_batches: 1 + +Model: + predictor: "smart" + decoder_type: "occ_decoder" # choose from ['agent_decoder', 'occ_decoder'] + dataset: "waymo" + input_dim: 2 + hidden_dim: 128 + output_dim: 2 + output_head: False + num_heads: 8 + <<: *time_info + head_dim: 16 + dropout: 0.1 + num_freq_bands: 64 + lr: 0.0005 + warmup_steps: 0 + total_steps: 32 + predict_map_token: False + num_recurrent_steps_val: -1 + val_open_loop: True + val_closed_loop: False + decoder: + <<: *time_info + num_map_layers: 3 + num_agent_layers: 6 + a2a_radius: 60 + pl2pl_radius: 10 + pl2a_radius: 30 + a2sa_radius: 10 + pl2sa_radius: 10 + time_span: 60 + loss_weight: + token_cls_loss: 1 + map_token_loss: 1 + state_cls_loss: 10 + type_cls_loss: 5 + pos_cls_loss: 1 + head_cls_loss: 1 + offset_reg_loss: 5 + shape_reg_loss: .2 + state_weight: [0.1, 0.1, 0.8] # invalid, valid, exit + seed_state_weight: [0.1, 0.9] # invalid, enter + seed_type_weight: [0.8, 0.1, 0.1] + agent_occ_pos_weight: 100 + pt_occ_pos_weight: 5 + agent_occ_loss: 10 + pt_occ_loss: 10 diff --git a/backups/configs/pretrain_scalable_map.yaml b/backups/configs/pretrain_scalable_map.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cba9933c150effc85cd147d413159c2d12067a51 --- /dev/null +++ b/backups/configs/pretrain_scalable_map.yaml @@ -0,0 +1,97 @@ +# Config format schema number, the yaml support to valid case source from different dataset +time_info: &time_info + num_historical_steps: 11 + num_future_steps: 80 + use_intention: True + token_size: 2048 + predict_motion: False + predict_state: False + predict_map: True + predict_occ: False + state_token: + invalid: 0 + valid: 1 + enter: 2 + exit: 3 + pl2seed_radius: 75. + grid_range: 150. # 2 times of pl2seed_radius + grid_interval: 3. + angle_interval: 3. 
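+  # map-token pretraining: only predict_map is True, so the motion, state and
+  # occupancy objectives are disabled; buffer_size is 32 here versus 128 in
+  # the agent-training configs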
+ seed_size: 1 + buffer_size: 32 + +Dataset: + root: + train_batch_size: 1 + val_batch_size: 1 + test_batch_size: 1 + shuffle: True + num_workers: 1 + pin_memory: True + persistent_workers: True + train_raw_dir: ["data/waymo_processed/training"] + val_raw_dir: ["data/waymo_processed/validation"] + test_raw_dir: ["data/waymo_processed/validation"] + transform: WaymoTargetBuilder + train_processed_dir: + val_processed_dir: + test_processed_dir: + dataset: "scalable" + <<: *time_info + +Trainer: + strategy: ddp_find_unused_parameters_false + accelerator: "gpu" + devices: 1 + max_epochs: 32 + overfit_epochs: 6000 + save_ckpt_path: + num_nodes: 1 + mode: + ckpt_path: + precision: 32 + accumulate_grad_batches: 1 + +Model: + mode: "train" + predictor: "smart" + decoder_type: "agent_decoder" + dataset: "waymo" + input_dim: 2 + hidden_dim: 128 + output_dim: 2 + output_head: False + num_heads: 8 + <<: *time_info + head_dim: 16 + dropout: 0.1 + num_freq_bands: 64 + lr: 0.0005 + warmup_steps: 0 + total_steps: 32 + predict_map_token: False + decoder: + <<: *time_info + num_map_layers: 3 + num_agent_layers: 6 + a2a_radius: 60 + pl2pl_radius: 10 + pl2a_radius: 30 + a2sa_radius: 10 + pl2sa_radius: 10 + time_span: 60 + loss_weight: + token_cls_loss: 1 + map_token_loss: 1 + state_cls_loss: 10 + type_cls_loss: 5 + pos_cls_loss: 1 + head_cls_loss: 1 + shape_reg_loss: .2 + state_weight: [0.1, 0.1, 0.8] # invalid, valid, exit + seed_state_weight: [0.1, 0.9] # invalid, enter + seed_type_weight: [0.8, 0.1, 0.1] + agent_occ_pos_weight: 100 + pt_occ_pos_weight: 5 + agent_occ_loss: 10 + pt_occ_loss: 10 diff --git a/backups/configs/smart.yaml b/backups/configs/smart.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9528e64b7afbbd626fc5c5686954e17a6291aa80 --- /dev/null +++ b/backups/configs/smart.yaml @@ -0,0 +1,70 @@ +# Config format schema number, the yaml support to valid case source from different dataset +time_info: &time_info + num_historical_steps: 11 + num_future_steps: 80 + use_intention: True + token_size: 2048 + disable_invalid: True + use_special_motion_token: False + use_state_token: False + only_state: False + +Dataset: + root: + train_batch_size: 1 + val_batch_size: 1 + test_batch_size: 1 + shuffle: True + num_workers: 1 + pin_memory: True + persistent_workers: True + train_raw_dir: ["data/waymo_processed/training"] + val_raw_dir: ["data/waymo_processed/validation"] + test_raw_dir: ["data/waymo_processed/validation"] + transform: WaymoTargetBuilder + train_processed_dir: + val_processed_dir: + test_processed_dir: + dataset: "scalable" + <<: *time_info + +Trainer: + strategy: ddp_find_unused_parameters_false + accelerator: "gpu" + devices: 1 + max_epochs: 32 + overfit_epochs: 5000 + save_ckpt_path: + num_nodes: 1 + mode: + ckpt_path: + precision: 32 + accumulate_grad_batches: 1 + +Model: + mode: "train" + predictor: "smart" + dataset: "waymo" + input_dim: 2 + hidden_dim: 128 + output_dim: 2 + output_head: False + num_heads: 8 + <<: *time_info + head_dim: 16 + dropout: 0.1 + num_freq_bands: 64 + lr: 0.0005 + warmup_steps: 0 + total_steps: 32 + decoder: + <<: *time_info + num_map_layers: 3 + num_agent_layers: 6 + a2a_radius: 60 + pl2pl_radius: 10 + pl2a_radius: 30 + time_span: 30 + loss_weight: + token_cls_loss: 1 + state_cls_loss: 5 diff --git a/backups/data_preprocess.py b/backups/data_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..5d1d49e0e10811650b53a3657a367f18f68e3c72 --- /dev/null +++ b/backups/data_preprocess.py @@ -0,0 +1,916 @@ +# Not a 
contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import signal +import multiprocessing +import os +import numpy as np +import pandas as pd +import tensorflow as tf +import torch +import pickle +import easydict +from functools import partial +from scipy.interpolate import interp1d +from argparse import ArgumentParser +from tqdm import tqdm +from typing import Any, Dict, List, Optional +from waymo_open_dataset.protos import scenario_pb2 + + +MIN_VALID_STEPS = 15 + + +_polygon_types = ['VEHICLE', 'BIKE', 'BUS', 'PEDESTRIAN'] +_polygon_light_type = ['LANE_STATE_STOP', 'LANE_STATE_GO', 'LANE_STATE_CAUTION', 'LANE_STATE_UNKNOWN'] +_point_types = ['DASH_SOLID_YELLOW', 'DASH_SOLID_WHITE', 'DASHED_WHITE', 'DASHED_YELLOW', + 'DOUBLE_SOLID_YELLOW', 'DOUBLE_SOLID_WHITE', 'DOUBLE_DASH_YELLOW', 'DOUBLE_DASH_WHITE', + 'SOLID_YELLOW', 'SOLID_WHITE', 'SOLID_DASH_WHITE', 'SOLID_DASH_YELLOW', 'EDGE', + 'NONE', 'UNKNOWN', 'CROSSWALK', 'CENTERLINE'] +_polygon_to_polygon_types = ['NONE', 'PRED', 'SUCC', 'LEFT', 'RIGHT'] + + +Lane_type_hash = { + 4: "BIKE", + 3: "VEHICLE", + 2: "VEHICLE", + 1: "BUS" +} + +boundary_type_hash = { + 5: "UNKNOWN", + 6: "DASHED_WHITE", + 7: "SOLID_WHITE", + 8: "DOUBLE_DASH_WHITE", + 9: "DASHED_YELLOW", + 10: "DOUBLE_DASH_YELLOW", + 11: "SOLID_YELLOW", + 12: "DOUBLE_SOLID_YELLOW", + 13: "DASH_SOLID_YELLOW", + 14: "UNKNOWN", + 15: "EDGE", + 16: "EDGE" +} + + +def safe_list_index(ls: List[Any], elem: Any) -> Optional[int]: + try: + return ls.index(elem) + except ValueError: + return None + + +# def get_agent_features(df: pd.DataFrame, av_id, num_historical_steps=11, dim=3, num_steps=91) -> Dict[str, Any]: +# if args.disable_invalid: # filter out agents that are unseen during the historical time steps +# historical_df = df[df['timestep'] == num_historical_steps-1] # extract the timestep==10 (current) +# agent_ids = list(historical_df['track_id'].unique()) # these agents are seen at timestep==10 (current) +# df = df[df['track_id'].isin(agent_ids)] # remove other agents +# else: +# agent_ids = list(df['track_id'].unique()) + +# num_agents = len(agent_ids) +# # initialization +# valid_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool) +# current_valid_mask = torch.zeros(num_agents, dtype=torch.bool) +# predict_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool) +# agent_id: List[Optional[str]] = [None] * num_agents +# agent_type = torch.zeros(num_agents, dtype=torch.uint8) +# agent_category = torch.zeros(num_agents, dtype=torch.uint8) +# position = torch.zeros(num_agents, num_steps, dim, dtype=torch.float) +# heading = torch.zeros(num_agents, num_steps, dtype=torch.float) +# velocity = torch.zeros(num_agents, num_steps, dim, dtype=torch.float) +# shape = torch.zeros(num_agents, num_steps, dim, dtype=torch.float) + +# for track_id, track_df in df.groupby('track_id'): +# agent_idx = 
agent_ids.index(track_id) +# all_agent_steps = track_df['timestep'].values +# valid_agent_steps = all_agent_steps[track_df['validity'].astype(np.bool_)].astype(np.int32) +# valid_mask[agent_idx, valid_agent_steps] = True +# current_valid_mask[agent_idx] = valid_mask[agent_idx, num_historical_steps - 1] # current timestep 10 +# if args.disable_invalid: +# predict_mask[agent_idx, valid_agent_steps] = True +# else: +# predict_mask[agent_idx] = True +# predict_mask[agent_idx, :num_historical_steps] = False +# if not current_valid_mask[agent_idx]: +# predict_mask[agent_idx, num_historical_steps:] = False + +# # TODO: why using vector_repr? +# if vector_repr: # a time step t is valid only when both t and t-1 are valid +# valid_mask[agent_idx, 1 : num_historical_steps] = ( +# valid_mask[agent_idx, : num_historical_steps - 1] & +# valid_mask[agent_idx, 1 : num_historical_steps]) +# valid_mask[agent_idx, 0] = False + +# agent_id[agent_idx] = track_id +# agent_type[agent_idx] = _agent_types.index(track_df['object_type'].values[0]) +# agent_category[agent_idx] = track_df['object_category'].values[0] +# position[agent_idx, valid_agent_steps, :3] = torch.from_numpy(np.stack([track_df['position_x'].values[valid_agent_steps], +# track_df['position_y'].values[valid_agent_steps], +# track_df['position_z'].values[valid_agent_steps]], +# axis=-1)).float() +# heading[agent_idx, valid_agent_steps] = torch.from_numpy(track_df['heading'].values[valid_agent_steps]).float() +# velocity[agent_idx, valid_agent_steps, :2] = torch.from_numpy(np.stack([track_df['velocity_x'].values[valid_agent_steps], +# track_df['velocity_y'].values[valid_agent_steps]], +# axis=-1)).float() +# shape[agent_idx, valid_agent_steps, :3] = torch.from_numpy(np.stack([track_df['length'].values[valid_agent_steps], +# track_df['width'].values[valid_agent_steps], +# track_df["height"].values[valid_agent_steps]], +# axis=-1)).float() +# av_idx = agent_id.index(av_id) +# if split == 'test': +# predict_mask[current_valid_mask +# | (agent_category == 2) +# | (agent_category == 3), num_historical_steps:] = True + +# return { +# 'num_nodes': num_agents, +# 'av_index': av_idx, +# 'valid_mask': valid_mask, +# 'predict_mask': predict_mask, +# 'id': agent_id, +# 'type': agent_type, +# 'category': agent_category, +# 'position': position, +# 'heading': heading, +# 'velocity': velocity, +# 'shape': shape +# } + + +def get_agent_features(track_infos: Dict[str, np.ndarray], av_id: int, num_historical_steps: int, num_steps: int) -> Dict[str, Any]: + + agent_idx_to_add = [] + for i in range(len(track_infos['object_id'])): + is_visible = track_infos['valid'][i, num_historical_steps - 1] + valid_steps = np.where(track_infos['valid'][i])[0] + valid_start, valid_end = valid_steps[0], valid_steps[-1] + is_valid = (valid_end - valid_start + 1) >= MIN_VALID_STEPS + + if (is_visible or not args.disable_invalid) and is_valid: + agent_idx_to_add.append(i) + + num_agents = len(agent_idx_to_add) + out_dict = { + 'num_nodes': num_agents, + 'valid_mask': torch.zeros(num_agents, num_steps, dtype=torch.bool), + 'role': torch.zeros(num_agents, 3, dtype=torch.bool), + 'id': torch.zeros(num_agents, dtype=torch.int64) - 1, + 'type': torch.zeros(num_agents, dtype=torch.uint8), + 'category': torch.zeros(num_agents, dtype=torch.uint8), + 'position': torch.zeros(num_agents, num_steps, 3, dtype=torch.float), + 'heading': torch.zeros(num_agents, num_steps, dtype=torch.float), + 'velocity': torch.zeros(num_agents, num_steps, 2, dtype=torch.float), + 'shape': torch.zeros(num_agents, 
num_steps, 3, dtype=torch.float), + } + + for i, idx in enumerate(agent_idx_to_add): + + out_dict['role'][i] = torch.from_numpy(track_infos['role'][idx]) + out_dict['id'][i] = track_infos['object_id'][idx] + out_dict['type'][i] = track_infos['object_type'][idx] + out_dict['category'][i] = idx in track_infos['tracks_to_predict'] + + valid = track_infos["valid"][idx] # [n_step] + states = track_infos["states"][idx] + + object_shape = states[:, 3:6] # [n_step, 3], length, width, height + object_shape = object_shape[valid].mean(axis=0) # [3] + out_dict["shape"][i] = torch.from_numpy(object_shape) + + valid_steps = np.where(valid)[0] + position = states[:, :3] # [n_step, dim], x, y, z + velocity = states[:, 7:9] # [n_step, 2], vx, vy + heading = states[:, 6] # [n_step], heading + + # valid.sum() should > 1: + t_start, t_end = valid_steps[0], valid_steps[-1] + f_pos = interp1d(valid_steps, position[valid], axis=0) + f_vel = interp1d(valid_steps, velocity[valid], axis=0) + f_yaw = interp1d(valid_steps, np.unwrap(heading[valid], axis=0), axis=0) + t_in = np.arange(t_start, t_end + 1) + out_dict["valid_mask"][i, t_start : t_end + 1] = True + out_dict["position"][i, t_start : t_end + 1] = torch.from_numpy(f_pos(t_in)) + out_dict["velocity"][i, t_start : t_end + 1] = torch.from_numpy(f_vel(t_in)) + out_dict["heading"][i, t_start : t_end + 1] = torch.from_numpy(f_yaw(t_in)) + + out_dict['av_idx'] = out_dict['id'].tolist().index(av_id) + + return out_dict + + +def get_map_features(map_infos, tf_current_light, dim=3): + lane_segments = map_infos['lane'] + all_polylines = map_infos["all_polylines"] + crosswalks = map_infos['crosswalk'] + road_edges = map_infos['road_edge'] + road_lines = map_infos['road_line'] + lane_segment_ids = [info["id"] for info in lane_segments] + cross_walk_ids = [info["id"] for info in crosswalks] + road_edge_ids = [info["id"] for info in road_edges] + road_line_ids = [info["id"] for info in road_lines] + polygon_ids = lane_segment_ids + road_edge_ids + road_line_ids + cross_walk_ids + num_polygons = len(lane_segment_ids) + len(road_edge_ids) + len(road_line_ids) + len(cross_walk_ids) + + # initialization + polygon_type = torch.zeros(num_polygons, dtype=torch.uint8) + polygon_light_type = torch.ones(num_polygons, dtype=torch.uint8) * 3 + + # list of (num_of_segments,), each element has shape of (num_of_points_of_current_segment - 1, dim) + point_position: List[Optional[torch.Tensor]] = [None] * num_polygons + point_orientation: List[Optional[torch.Tensor]] = [None] * num_polygons + point_magnitude: List[Optional[torch.Tensor]] = [None] * num_polygons + point_height: List[Optional[torch.Tensor]] = [None] * num_polygons + point_type: List[Optional[torch.Tensor]] = [None] * num_polygons + + for lane_segment in lane_segments: + lane_segment = easydict.EasyDict(lane_segment) + lane_segment_idx = polygon_ids.index(lane_segment.id) + polyline_index = lane_segment.polyline_index # (start index of point in current scenario, end index of point in current scenario) + centerline = all_polylines[polyline_index[0] : polyline_index[1], :] # (num_of_points_of_current_segment, 5) + centerline = torch.from_numpy(centerline).float() + polygon_type[lane_segment_idx] = _polygon_types.index(Lane_type_hash[lane_segment.type]) + + res = tf_current_light[tf_current_light["lane_id"] == str(lane_segment.id)] + if len(res) != 0: + polygon_light_type[lane_segment_idx] = _polygon_light_type.index(res["state"].item()) + + point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0) # 
(num_of_points_of_current_segment - 1, 3) + center_vectors = centerline[1:] - centerline[:-1] # (num_of_points_of_current_segment - 1, 5) + point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0) # (num_of_points_of_current_segment - 1,) + point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1) # (num_of_points_of_current_segment - 1,) + point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0) # (num_of_points_of_current_segment - 1,) + center_type = _point_types.index('CENTERLINE') + point_type[lane_segment_idx] = torch.cat( + [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0) + + for lane_segment in road_edges: + lane_segment = easydict.EasyDict(lane_segment) + lane_segment_idx = polygon_ids.index(lane_segment.id) + polyline_index = lane_segment.polyline_index + centerline = all_polylines[polyline_index[0] : polyline_index[1], :] + centerline = torch.from_numpy(centerline).float() + polygon_type[lane_segment_idx] = _polygon_types.index("VEHICLE") + + point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0) + center_vectors = centerline[1:] - centerline[:-1] + point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0) + point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1) + point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0) + center_type = _point_types.index('EDGE') + point_type[lane_segment_idx] = torch.cat( + [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0) + + for lane_segment in road_lines: + lane_segment = easydict.EasyDict(lane_segment) + lane_segment_idx = polygon_ids.index(lane_segment.id) + polyline_index = lane_segment.polyline_index + centerline = all_polylines[polyline_index[0] : polyline_index[1], :] + centerline = torch.from_numpy(centerline).float() + + polygon_type[lane_segment_idx] = _polygon_types.index("VEHICLE") + + point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0) + center_vectors = centerline[1:] - centerline[:-1] + point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0) + point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1) + point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0) + center_type = _point_types.index(boundary_type_hash[lane_segment.type]) + point_type[lane_segment_idx] = torch.cat( + [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0) + + for crosswalk in crosswalks: + crosswalk = easydict.EasyDict(crosswalk) + lane_segment_idx = polygon_ids.index(crosswalk.id) + polyline_index = crosswalk.polyline_index + centerline = all_polylines[polyline_index[0] : polyline_index[1], :] + centerline = torch.from_numpy(centerline).float() + + polygon_type[lane_segment_idx] = _polygon_types.index("PEDESTRIAN") + + point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0) + center_vectors = centerline[1:] - centerline[:-1] + point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0) + point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1) + point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0) + center_type = 
_point_types.index("CROSSWALK") + point_type[lane_segment_idx] = torch.cat( + [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0) + + # (num_of_segments,), each element represents the number of points of the segment + num_points = torch.tensor([point.size(0) for point in point_position], dtype=torch.long) + # (2, total_num_of_points_of_all_segments), store the point index of segment and its corresponding segment index + # e.g. a scenario has 203 segments, and totally 14039 points: + # tensor([[ 0, 1, 2, ..., 14927, 14928, 14929], + # [ 0, 0, 0, ..., 202, 202, 202]]) => polygon_ids.index(lane_segment.id) + point_to_polygon_edge_index = torch.stack( + [torch.arange(num_points.sum(), dtype=torch.long), + torch.arange(num_polygons, dtype=torch.long).repeat_interleave(num_points)], dim=0) + # list of (num_of_lane_segments,) + polygon_to_polygon_edge_index = [] + # list of (num_of_lane_segments,) + polygon_to_polygon_type = [] + for lane_segment in lane_segments: + lane_segment = easydict.EasyDict(lane_segment) + lane_segment_idx = polygon_ids.index(lane_segment.id) + pred_inds = [] + for pred in lane_segment.entry_lanes: + pred_idx = safe_list_index(polygon_ids, pred) + if pred_idx is not None: + pred_inds.append(pred_idx) + if len(pred_inds) != 0: + polygon_to_polygon_edge_index.append( + torch.stack([torch.tensor(pred_inds, dtype=torch.long), + torch.full((len(pred_inds),), lane_segment_idx, dtype=torch.long)], dim=0)) + polygon_to_polygon_type.append( + torch.full((len(pred_inds),), _polygon_to_polygon_types.index('PRED'), dtype=torch.uint8)) + succ_inds = [] + for succ in lane_segment.exit_lanes: + succ_idx = safe_list_index(polygon_ids, succ) + if succ_idx is not None: + succ_inds.append(succ_idx) + if len(succ_inds) != 0: + polygon_to_polygon_edge_index.append( + torch.stack([torch.tensor(succ_inds, dtype=torch.long), + torch.full((len(succ_inds),), lane_segment_idx, dtype=torch.long)], dim=0)) + polygon_to_polygon_type.append( + torch.full((len(succ_inds),), _polygon_to_polygon_types.index('SUCC'), dtype=torch.uint8)) + if len(lane_segment.left_neighbors) != 0: + left_neighbor_ids = lane_segment.left_neighbors + for left_neighbor_id in left_neighbor_ids: + left_idx = safe_list_index(polygon_ids, left_neighbor_id) + if left_idx is not None: + polygon_to_polygon_edge_index.append( + torch.tensor([[left_idx], [lane_segment_idx]], dtype=torch.long)) + polygon_to_polygon_type.append( + torch.tensor([_polygon_to_polygon_types.index('LEFT')], dtype=torch.uint8)) + if len(lane_segment.right_neighbors) != 0: + right_neighbor_ids = lane_segment.right_neighbors + for right_neighbor_id in right_neighbor_ids: + right_idx = safe_list_index(polygon_ids, right_neighbor_id) + if right_idx is not None: + polygon_to_polygon_edge_index.append( + torch.tensor([[right_idx], [lane_segment_idx]], dtype=torch.long)) + polygon_to_polygon_type.append( + torch.tensor([_polygon_to_polygon_types.index('RIGHT')], dtype=torch.uint8)) + if len(polygon_to_polygon_edge_index) != 0: + polygon_to_polygon_edge_index = torch.cat(polygon_to_polygon_edge_index, dim=1) + polygon_to_polygon_type = torch.cat(polygon_to_polygon_type, dim=0) + else: + polygon_to_polygon_edge_index = torch.tensor([[], []], dtype=torch.long) + polygon_to_polygon_type = torch.tensor([], dtype=torch.uint8) + + map_data = { + 'map_polygon': {}, + 'map_point': {}, + ('map_point', 'to', 'map_polygon'): {}, + ('map_polygon', 'to', 'map_polygon'): {}, + } + map_data['map_polygon']['num_nodes'] = num_polygons # int, number of map 
segments in the scenario + map_data['map_polygon']['type'] = polygon_type # (num_polygons,) type of each polygon + map_data['map_polygon']['light_type'] = polygon_light_type # (num_polygons,) light type of each polygon, 3 means unknown + if len(num_points) == 0: + map_data['map_point']['num_nodes'] = 0 + map_data['map_point']['position'] = torch.tensor([], dtype=torch.float) + map_data['map_point']['orientation'] = torch.tensor([], dtype=torch.float) + map_data['map_point']['magnitude'] = torch.tensor([], dtype=torch.float) + if dim == 3: + map_data['map_point']['height'] = torch.tensor([], dtype=torch.float) + map_data['map_point']['type'] = torch.tensor([], dtype=torch.uint8) + map_data['map_point']['side'] = torch.tensor([], dtype=torch.uint8) + else: + map_data['map_point']['num_nodes'] = num_points.sum().item() # int, number of total points of all segments in the scenario + map_data['map_point']['position'] = torch.cat(point_position, dim=0) # (num_of_total_points_of_all_segments, 3) + map_data['map_point']['orientation'] = torch.cat(point_orientation, dim=0) # (num_of_total_points_of_all_segments,) + map_data['map_point']['magnitude'] = torch.cat(point_magnitude, dim=0) # (num_of_total_points_of_all_segments,) + if dim == 3: + map_data['map_point']['height'] = torch.cat(point_height, dim=0) # (num_of_total_points_of_all_segments,) + map_data['map_point']['type'] = torch.cat(point_type, dim=0) # (num_of_total_points_of_all_segments,) type of point => `_point_types` + map_data['map_point', 'to', 'map_polygon']['edge_index'] = point_to_polygon_edge_index # (2, num_of_total_points_of_all_segments) + map_data['map_polygon', 'to', 'map_polygon']['edge_index'] = polygon_to_polygon_edge_index + map_data['map_polygon', 'to', 'map_polygon']['type'] = polygon_to_polygon_type + + if int(os.getenv('DEBUG_MAP', 1)): + import matplotlib.pyplot as plt + plt.axis('equal') + plt.scatter(map_data['map_point']['position'][:, 0], + map_data['map_point']['position'][:, 1], s=0.2, c='black', edgecolors='none') + plt.savefig("debug.png", dpi=600) + + return map_data + + +# def process_agent(track_info, tracks_to_predict, scenario_id, start_timestamp, end_timestamp): + +# agents_array = track_info["states"].transpose(1, 0, 2) # (num_timesteps, num_agents, 10) e.g. (91, 15, 10) +# object_id = np.array(track_info["object_id"]) # (num_agents,) global id of each agent +# object_type = track_info["object_type"] # (num_agents,) type of each agent, e.g. 'TYPE_VEHICLE' +# id_hash = {object_id[o_idx]: object_type[o_idx] for o_idx in range(len(object_id))} + +# def type_hash(x): +# tp = id_hash[x] +# type_re_hash = { +# "TYPE_VEHICLE": "vehicle", +# "TYPE_PEDESTRIAN": "pedestrian", +# "TYPE_CYCLIST": "cyclist", +# "TYPE_OTHER": "background", +# "TYPE_UNSET": "background" +# } +# return type_re_hash[tp] + +# columns = ['observed', 'track_id', 'object_type', 'object_category', 'timestep', +# 'position_x', 'position_y', 'position_z', 'length', 'width', 'height', 'heading', 'velocity_x', 'velocity_y', +# 'scenario_id', 'start_timestamp', 'end_timestamp', 'num_timestamps', +# 'focal_track_id', 'city', 'validity'] + +# # (num_timesteps, num_agents, 10) e.g. 
(91, 15, 10) +# new_columns = np.ones((agents_array.shape[0], agents_array.shape[1], 11)) +# new_columns[:11, :, 0] = True # observed, 10 timesteps +# new_columns[11:, :, 0] = False # not observed (current + future) +# for index in range(new_columns.shape[0]): +# new_columns[index, :, 4] = int(index) # timestep (0 ~ 90) +# new_columns[..., 1] = object_id +# new_columns[..., 2] = object_id +# new_columns[:, tracks_to_predict['track_index'], 3] = 3 +# new_columns[..., 5] = 11 +# new_columns[..., 6] = int(start_timestamp) # 0 +# new_columns[..., 7] = int(end_timestamp) # 91 +# new_columns[..., 8] = int(91) # 91 +# new_columns[..., 9] = object_id +# new_columns[..., 10] = 10086 +# new_columns = new_columns +# new_agents_array = np.concatenate([new_columns, agents_array], axis=-1) # (num_timesteps, num_agents, 21) e.g. (91, 15, 21) +# # filter out the invalid timestep of agents, reshape to (num_valid_of_timesteps_of_all_agents, 21) e.g. (91, 15, 21) -> (1137, 21) +# if args.disable_invalid: +# new_agents_array = new_agents_array[new_agents_array[..., -1] == 1.0].reshape(-1, new_agents_array.shape[-1]) +# else: +# agent_valid_mask = new_agents_array[..., -1] # (num_timesteps, num_agents) +# agent_mask = np.sum(agent_valid_mask, axis=0) > MIN_VALID_STEPS # NOTE: 10 is a empirical parameter +# new_agents_array = new_agents_array[:, agent_mask] +# new_agents_array = new_agents_array.reshape(-1, new_agents_array.shape[-1]) # (91, 15, 21) -> (1365, 21) +# new_agents_array = new_agents_array[..., [0, 1, 2, 3, 4, 11, 12, 13, 14, 15, 16, 17, 18, 19, 5, 6, 7, 8, 9, 10, 20]] +# new_agents_array = pd.DataFrame(data=new_agents_array, columns=columns) +# new_agents_array["object_type"] = new_agents_array["object_type"].apply(func=type_hash) +# new_agents_array["start_timestamp"] = new_agents_array["start_timestamp"].astype(int) +# new_agents_array["end_timestamp"] = new_agents_array["end_timestamp"].astype(int) +# new_agents_array["num_timestamps"] = new_agents_array["num_timestamps"].astype(int) +# new_agents_array["scenario_id"] = scenario_id + +# return new_agents_array + + +def process_dynamic_map(dynamic_map_infos): + lane_ids = dynamic_map_infos["lane_id"] + tf_lights = [] + for t in range(len(lane_ids)): + lane_id = lane_ids[t] + time = np.ones_like(lane_id) * t + state = dynamic_map_infos["state"][t] + tf_light = np.concatenate([lane_id, time, state], axis=0) + tf_lights.append(tf_light) + tf_lights = np.concatenate(tf_lights, axis=1).transpose(1, 0) + tf_lights = pd.DataFrame(data=tf_lights, columns=["lane_id", "time_step", "state"]) + tf_lights["time_step"] = tf_lights["time_step"].astype("str") + tf_lights["lane_id"] = tf_lights["lane_id"].astype("str") + tf_lights["state"] = tf_lights["state"].astype("str") + tf_lights.loc[tf_lights["state"].str.contains("STOP"), ["state"]] = ( + "LANE_STATE_STOP" + ) + tf_lights.loc[tf_lights["state"].str.contains("GO"), ["state"]] = "LANE_STATE_GO" + tf_lights.loc[tf_lights["state"].str.contains("CAUTION"), ["state"]] = ( + "LANE_STATE_CAUTION" + ) + tf_lights.loc[tf_lights["state"].str.contains("UNKNOWN"), ["state"]] = ( + "LANE_STATE_UNKNOWN" + ) + + return tf_lights + + +polyline_type = { + # for lane + 'TYPE_UNDEFINED': -1, + 'TYPE_FREEWAY': 1, + 'TYPE_SURFACE_STREET': 2, + 'TYPE_BIKE_LANE': 3, + + # for roadline + 'TYPE_UNKNOWN': -1, + 'TYPE_BROKEN_SINGLE_WHITE': 6, + 'TYPE_SOLID_SINGLE_WHITE': 7, + 'TYPE_SOLID_DOUBLE_WHITE': 8, + 'TYPE_BROKEN_SINGLE_YELLOW': 9, + 'TYPE_BROKEN_DOUBLE_YELLOW': 10, + 'TYPE_SOLID_SINGLE_YELLOW': 11, + 
'TYPE_SOLID_DOUBLE_YELLOW': 12, + 'TYPE_PASSING_DOUBLE_YELLOW': 13, + + # for roadedge + 'TYPE_ROAD_EDGE_BOUNDARY': 15, + 'TYPE_ROAD_EDGE_MEDIAN': 16, + + # for stopsign + 'TYPE_STOP_SIGN': 17, + + # for crosswalk + 'TYPE_CROSSWALK': 18, + + # for speed bump + 'TYPE_SPEED_BUMP': 19 +} + +object_type = { + 0: 'TYPE_UNSET', + 1: 'TYPE_VEHICLE', + 2: 'TYPE_PEDESTRIAN', + 3: 'TYPE_CYCLIST', + 4: 'TYPE_OTHER' +} + + +def decode_tracks_from_proto(scenario): + sdc_track_index = scenario.sdc_track_index + track_index_predict = [i.track_index for i in scenario.tracks_to_predict] + object_id_interest = [i for i in scenario.objects_of_interest] + + track_infos = { + 'object_id': [], # {0: unset, 1: vehicle, 2: pedestrian, 3: cyclist, 4: others} + 'object_type': [], + 'states': [], + 'valid': [], + 'role': [], + } + + # tracks mean N number of objects, e.g. len(tracks) = 55 + # each track has 91 states, e.g. len(tracks[0].states) == 91 + # each state has 10 attributes: center_x, center_y, center_z, length, ..., velocity_y, valid + for i, cur_data in enumerate(scenario.tracks): + + step_state = [] + step_valid = [] + + for s in cur_data.states: # n_steps + step_state.append( + [ + s.center_x, + s.center_y, + s.center_z, + s.length, + s.width, + s.height, + s.heading, + s.velocity_x, + s.velocity_y, + ] + ) + step_valid.append(s.valid) + # This angle is normalized to [-pi, pi). The velocity vector in m/s + + track_infos['object_id'].append(cur_data.id) # id of object in this track + track_infos['object_type'].append(cur_data.object_type - 1) + track_infos['states'].append(np.array(step_state, dtype=np.float32)) + track_infos['valid'].append(np.array(step_valid)) + + track_infos['role'].append([False, False, False]) + if i in track_index_predict: + track_infos['role'][-1][2] = True # predict=2 + if cur_data.id in object_id_interest: + track_infos['role'][-1][1] = True # interest=1 + if i == sdc_track_index: + track_infos['role'][-1][0] = True # ego_vehicle=0 + + track_infos['states'] = np.array(track_infos['states'], dtype=np.float32) # (n_agent, n_step, 9) + track_infos['valid'] = np.array(track_infos['valid'], dtype=np.bool_) + track_infos['role'] = np.array(track_infos['role'], dtype=np.bool_) + track_infos['object_id'] = np.array(track_infos['object_id'], dtype=np.int64) + track_infos['object_type'] = np.array(track_infos['object_type'], dtype=np.uint8) + track_infos['tracks_to_predict'] = np.array(track_index_predict, dtype=np.int64) + + return track_infos + + +from collections import defaultdict + +def decode_map_features_from_proto(map_features): + map_infos = { + 'lane': [], + 'road_line': [], + 'road_edge': [], + 'stop_sign': [], + 'crosswalk': [], + 'speed_bump': [], + 'lane_dict': {}, + 'lane2other_dict': {} + } + polylines = [] + + point_cnt = 0 + lane2other_dict = defaultdict(list) + + for cur_data in map_features: + cur_info = {'id': cur_data.id} + + if cur_data.lane.ByteSize() > 0: + cur_info['speed_limit_mph'] = cur_data.lane.speed_limit_mph + cur_info['type'] = cur_data.lane.type + 1 # 0: undefined, 1: freeway, 2: surface_street, 3: bike_lane + cur_info['left_neighbors'] = [lane.feature_id for lane in cur_data.lane.left_neighbors] + + cur_info['right_neighbors'] = [lane.feature_id for lane in cur_data.lane.right_neighbors] + + cur_info['interpolating'] = cur_data.lane.interpolating + cur_info['entry_lanes'] = list(cur_data.lane.entry_lanes) + cur_info['exit_lanes'] = list(cur_data.lane.exit_lanes) + + cur_info['left_boundary_type'] = [x.boundary_type + 5 for x in 
cur_data.lane.left_boundaries] + cur_info['right_boundary_type'] = [x.boundary_type + 5 for x in cur_data.lane.right_boundaries] + + cur_info['left_boundary'] = [x.boundary_feature_id for x in cur_data.lane.left_boundaries] + cur_info['right_boundary'] = [x.boundary_feature_id for x in cur_data.lane.right_boundaries] + cur_info['left_boundary_start_index'] = [lane.lane_start_index for lane in cur_data.lane.left_boundaries] + cur_info['left_boundary_end_index'] = [lane.lane_end_index for lane in cur_data.lane.left_boundaries] + cur_info['right_boundary_start_index'] = [lane.lane_start_index for lane in cur_data.lane.right_boundaries] + cur_info['right_boundary_end_index'] = [lane.lane_end_index for lane in cur_data.lane.right_boundaries] + + lane2other_dict[cur_data.id].extend(cur_info['left_boundary']) + lane2other_dict[cur_data.id].extend(cur_info['right_boundary']) + + global_type = cur_info['type'] + cur_polyline = np.stack( + [np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in cur_data.lane.polyline], + axis=0) + cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1) + if cur_polyline.shape[0] <= 1: + continue + map_infos['lane'].append(cur_info) + map_infos['lane_dict'][cur_data.id] = cur_info + + elif cur_data.road_line.ByteSize() > 0: + cur_info['type'] = cur_data.road_line.type + 5 + + global_type = cur_info['type'] + cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in + cur_data.road_line.polyline], axis=0) + cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1) + if cur_polyline.shape[0] <= 1: + continue + map_infos['road_line'].append(cur_info) # (num_points, 5) + + elif cur_data.road_edge.ByteSize() > 0: + cur_info['type'] = cur_data.road_edge.type + 14 + + global_type = cur_info['type'] + cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in + cur_data.road_edge.polyline], axis=0) + cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1) + if cur_polyline.shape[0] <= 1: + continue + map_infos['road_edge'].append(cur_info) + + elif cur_data.stop_sign.ByteSize() > 0: + cur_info['lane_ids'] = list(cur_data.stop_sign.lane) + for i in cur_info['lane_ids']: + lane2other_dict[i].append(cur_data.id) + point = cur_data.stop_sign.position + cur_info['position'] = np.array([point.x, point.y, point.z]) + + global_type = polyline_type['TYPE_STOP_SIGN'] + cur_polyline = np.array([point.x, point.y, point.z, global_type, cur_data.id]).reshape(1, 5) + if cur_polyline.shape[0] <= 1: + continue + map_infos['stop_sign'].append(cur_info) + elif cur_data.crosswalk.ByteSize() > 0: + global_type = polyline_type['TYPE_CROSSWALK'] + cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in + cur_data.crosswalk.polygon], axis=0) + cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1) + if cur_polyline.shape[0] <= 1: + continue + map_infos['crosswalk'].append(cur_info) + + elif cur_data.speed_bump.ByteSize() > 0: + global_type = polyline_type['TYPE_SPEED_BUMP'] + cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in + cur_data.speed_bump.polygon], axis=0) + cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1) + if cur_polyline.shape[0] <= 1: + continue + map_infos['speed_bump'].append(cur_info) + + else: + continue + polylines.append(cur_polyline) 
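+        # each cur_polyline is an (N_i, 5) float array of [x, y, z, global_type, feature_id];
+        # the running point_cnt below records this feature's slice within the
+        # concatenated map_infos['all_polylines'], which get_map_features later
+        # reads back through cur_info['polyline_index']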
+ cur_info['polyline_index'] = (point_cnt, point_cnt + len(cur_polyline)) # (start index of point in current scenario, end index of point in current scenario) + point_cnt += len(cur_polyline) + + polylines = np.concatenate(polylines, axis=0).astype(np.float32) + map_infos['all_polylines'] = polylines # (num_of_total_points_in_current_scenario, 5) + map_infos['lane2other_dict'] = lane2other_dict + return map_infos + + +def decode_dynamic_map_states_from_proto(dynamic_map_states): + + signal_state = { + 0: 'LANE_STATE_UNKNOWN', + # States for traffic signals with arrows. + 1: 'LANE_STATE_ARROW_STOP', + 2: 'LANE_STATE_ARROW_CAUTION', + 3: 'LANE_STATE_ARROW_GO', + # Standard round traffic signals. + 4: 'LANE_STATE_STOP', + 5: 'LANE_STATE_CAUTION', + 6: 'LANE_STATE_GO', + # Flashing light signals. + 7: 'LANE_STATE_FLASHING_STOP', + 8: 'LANE_STATE_FLASHING_CAUTION' + } + + dynamic_map_infos = { + 'lane_id': [], + 'state': [], + 'stop_point': [] + } + for cur_data in dynamic_map_states: # len(dynamic_map_states) = num_timestamp + lane_id, state, stop_point = [], [], [] + for cur_signal in cur_data.lane_states: # (num_observed_signals) + lane_id.append(cur_signal.lane) + state.append(signal_state[cur_signal.state]) + stop_point.append([cur_signal.stop_point.x, cur_signal.stop_point.y, cur_signal.stop_point.z]) + + dynamic_map_infos['lane_id'].append(np.array([lane_id])) + dynamic_map_infos['state'].append(np.array([state])) + dynamic_map_infos['stop_point'].append(np.array([stop_point])) + + return dynamic_map_infos + + +# def process_single_data(scenario): +# info = {} +# info['scenario_id'] = scenario.scenario_id +# info['timestamps_seconds'] = list(scenario.timestamps_seconds) # list of int of shape (91) +# info['current_time_index'] = scenario.current_time_index # int, 10 +# info['sdc_track_index'] = scenario.sdc_track_index # int +# info['objects_of_interest'] = list(scenario.objects_of_interest) # list, could be empty list + +# info['tracks_to_predict'] = { +# 'track_index': [cur_pred.track_index for cur_pred in scenario.tracks_to_predict], +# 'difficulty': [cur_pred.difficulty for cur_pred in scenario.tracks_to_predict] +# } # for training: suggestion of objects to train on, for val/test: need to be predicted + +# # decode tracks data +# track_infos = decode_tracks_from_proto(scenario.tracks) +# info['tracks_to_predict']['object_type'] = [track_infos['object_type'][cur_idx] for cur_idx in +# info['tracks_to_predict']['track_index']] +# # decode map related data +# map_infos = decode_map_features_from_proto(scenario.map_features) +# dynamic_map_infos = decode_dynamic_map_states_from_proto(scenario.dynamic_map_states) + +# save_infos = { +# 'track_infos': track_infos, +# 'map_infos': map_infos, +# 'dynamic_map_infos': dynamic_map_infos, +# } +# save_infos.update(info) +# return save_infos + + +def wm2argo(file, input_dir, output_dir, existing_files=[], output_dir_tfrecords_splitted=None): + file_path = os.path.join(input_dir, file) + dataset = tf.data.TFRecordDataset(file_path, compression_type='', num_parallel_reads=3) + + for cnt, tf_data in tqdm(enumerate(dataset), leave=False, desc=f'Process {file}...'): + + scenario = scenario_pb2.Scenario() + scenario.ParseFromString(bytearray(tf_data.numpy())) + scenario_id = scenario.scenario_id + tqdm.write(f"idx: {cnt}, scenario_id: {scenario_id} of {file}") + + if f'{scenario_id}.pkl' not in existing_files: + + map_infos = decode_map_features_from_proto(scenario.map_features) + track_infos = decode_tracks_from_proto(scenario) + dynamic_map_infos = 
decode_dynamic_map_states_from_proto(scenario.dynamic_map_states)
+            sdc_track_index = scenario.sdc_track_index  # int
+            av_id = track_infos['object_id'][sdc_track_index]
+            # if len(track_infos['tracks_to_predict']) < 1:
+            #     return
+
+            current_time_index = scenario.current_time_index
+            tf_lights = process_dynamic_map(dynamic_map_infos)
+            # the 'time_step' column is cast to str in process_dynamic_map, so the
+            # current index must be compared as str as well
+            tf_current_light = tf_lights.loc[tf_lights["time_step"] == str(current_time_index)]  # 10 (history) + 1 (current) + 80 (future)
+            map_data = get_map_features(map_infos, tf_current_light)
+
+            # new_agents_array = process_agent(track_infos, tracks_to_predict, scenario_id, 0, 91)  # mtr2argo
+            data = dict()
+            data.update(map_data)
+            data['scenario_id'] = scenario_id
+            data['agent'] = get_agent_features(track_infos, av_id, num_historical_steps=current_time_index + 1, num_steps=91)
+
+            with open(os.path.join(output_dir, f'{scenario_id}.pkl'), "wb+") as f:
+                pickle.dump(data, f)
+
+        if output_dir_tfrecords_splitted is not None:
+            tf_file = os.path.join(output_dir_tfrecords_splitted, f'{scenario_id}.tfrecords')
+            if not os.path.exists(tf_file):
+                with tf.io.TFRecordWriter(tf_file) as file_writer:
+                    file_writer.write(tf_data.numpy())
+
+
+def batch_process9s_transformer(input_dir, output_dir, split, num_workers=2):
+    signal.signal(signal.SIGINT, signal.SIG_IGN)
+
+    output_dir_tfrecords_splitted = None
+    if split == "validation":
+        output_dir_tfrecords_splitted = os.path.join(output_dir, 'validation_tfrecords_splitted')
+        os.makedirs(output_dir_tfrecords_splitted, exist_ok=True)
+
+    input_dir = os.path.join(input_dir, split)
+    output_dir = os.path.join(output_dir, split)
+    os.makedirs(output_dir, exist_ok=True)
+
+    packages = sorted(os.listdir(input_dir))
+    existing_files = sorted(os.listdir(output_dir))
+    func = partial(
+        wm2argo,
+        output_dir=output_dir,
+        input_dir=input_dir,
+        existing_files=existing_files,
+        output_dir_tfrecords_splitted=output_dir_tfrecords_splitted
+    )
+    try:
+        with multiprocessing.Pool(num_workers, maxtasksperchild=10) as p:
+            r = list(tqdm(p.imap_unordered(func, packages), total=len(packages)))
+    except KeyboardInterrupt:
+        p.terminate()
+        p.join()
+
+
+def generate_meta_infos(data_dir):
+    import json
+
+    meta_infos = dict()
+
+    for split in tqdm(['training', 'validation', 'test'], leave=False):
+        if not os.path.exists(os.path.join(data_dir, split)):
+            continue
+
+        split_infos = dict()
+        files = os.listdir(os.path.join(data_dir, split))
+        for file in tqdm(files, leave=False):
+            try:
+                data = pickle.load(open(os.path.join(data_dir, split, file), 'rb'))
+            except Exception as e:
+                tqdm.write(f'Failed to load scenario {file} due to {e}')
+                continue
+            scenario_infos = dict(num_agents=data['agent']['num_nodes'])
+            scenario_id = data['scenario_id']
+            split_infos[scenario_id] = scenario_infos
+
+        meta_infos[split] = split_infos
+
+    with open(os.path.join(data_dir, 'meta_infos.json'), 'w', encoding='utf-8') as f:
+        json.dump(meta_infos, f, indent=4)
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument('--input_dir', type=str, default='data/waymo/')
+    parser.add_argument('--output_dir', type=str, default='data/waymo_processed/')
+    parser.add_argument('--split', type=str, default='validation')
+    parser.add_argument('--no_batch', action='store_true')
+    parser.add_argument('--disable_invalid', action="store_true")
+    parser.add_argument('--generate_meta_infos', action="store_true")
+    args = parser.parse_args()
+
+    if args.generate_meta_infos:
+        generate_meta_infos(args.output_dir)
+
+    elif args.no_batch:
+
+        output_dir_tfrecords_splitted = None
+        if args.split == "validation":
+            output_dir_tfrecords_splitted = os.path.join(args.output_dir, 'validation_tfrecords_splitted')
+            os.makedirs(output_dir_tfrecords_splitted, exist_ok=True)
+
+        input_dir = os.path.join(args.input_dir, args.split)
+        output_dir = os.path.join(args.output_dir, args.split)
+        os.makedirs(output_dir, exist_ok=True)
+
+        files = sorted(os.listdir(input_dir))
+        os.makedirs(args.output_dir, exist_ok=True)
+        for file in tqdm(files, leave=False, desc=f'Process {args.split}...'):
+            # pass the tfrecords dir by keyword: the fourth positional
+            # parameter of wm2argo is `existing_files`, not this path
+            wm2argo(file, input_dir, output_dir, output_dir_tfrecords_splitted=output_dir_tfrecords_splitted)
+
+    else:
+
+        batch_process9s_transformer(args.input_dir, args.output_dir, args.split, num_workers=96)
diff --git a/backups/dev/datasets/preprocess.py b/backups/dev/datasets/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..dffa56e43b1a236b22ee86c0f708af260763b396
--- /dev/null
+++ b/backups/dev/datasets/preprocess.py
@@ -0,0 +1,761 @@
+import os
+import torch
+import math
+import pickle
+import numpy as np
+from torch import nn
+from typing import Dict, Sequence
+from scipy.interpolate import interp1d
+from scipy.spatial.distance import euclidean
+from dev.utils.func import wrap_angle
+
+
+SHIFT = 5
+AGENT_SHAPE = {
+    'vehicle': [4.3, 1.8, 1.],
+    'pedstrain': [0.5, 0.5, 1.],
+    'cyclist': [1.9, 0.5, 1.],
+}
+AGENT_TYPE = ['veh', 'ped', 'cyc', 'seed']
+AGENT_STATE = ['invalid', 'valid', 'enter', 'exit']
+
+
+@torch.no_grad()
+def cal_polygon_contour(pos, head, width_length) -> torch.Tensor:  # [n_agent, n_step, n_target, 4, 2]
+    x, y = pos[..., 0], pos[..., 1]  # [n_agent, n_step, n_target]
+    width, length = width_length[..., 0], width_length[..., 1]  # [n_agent, 1, 1]
+
+    half_cos = 0.5 * head.cos()  # [n_agent, n_step, n_target]
+    half_sin = 0.5 * head.sin()  # [n_agent, n_step, n_target]
+    length_cos = length * half_cos  # [n_agent, n_step, n_target]
+    length_sin = length * half_sin  # [n_agent, n_step, n_target]
+    width_cos = width * half_cos  # [n_agent, n_step, n_target]
+    width_sin = width * half_sin  # [n_agent, n_step, n_target]
+
+    left_front_x = x + length_cos - width_sin
+    left_front_y = y + length_sin + width_cos
+    left_front = torch.stack((left_front_x, left_front_y), dim=-1)
+
+    right_front_x = x + length_cos + width_sin
+    right_front_y = y + length_sin - width_cos
+    right_front = torch.stack((right_front_x, right_front_y), dim=-1)
+
+    right_back_x = x - length_cos + width_sin
+    right_back_y = y - length_sin - width_cos
+    right_back = torch.stack((right_back_x, right_back_y), dim=-1)
+
+    left_back_x = x - length_cos - width_sin
+    left_back_y = y - length_sin + width_cos
+    left_back = torch.stack((left_back_x, left_back_y), dim=-1)
+
+    polygon_contour = torch.stack(
+        (left_front, right_front, right_back, left_back), dim=-2
+    )
+
+    return polygon_contour
+
+
+def interplating_polyline(polylines, heading, distance=0.5, split_distace=5):
+    # Calculate the cumulative distance along the path and up-sample the polyline to 0.5 m spacing
+    dist_along_path_list = [[0]]
+    polylines_list = [[polylines[0]]]
+    for i in range(1, polylines.shape[0]):
+        euclidean_dist = euclidean(polylines[i, :2], polylines[i - 1, :2])
+        # heading difference between consecutive points
+        heading_diff = min(abs(max(heading[i], heading[i - 1]) - min(heading[i], heading[i - 1])),
+                           abs(max(heading[i], heading[i - 1]) - min(heading[i], heading[i - 1]) + math.pi))
+        if heading_diff > math.pi / 4 and euclidean_dist > 3:
+            dist_along_path_list.append([0])
+            polylines_list.append([polylines[i]])
+        elif heading_diff > math.pi / 8 and euclidean_dist
> 3: + dist_along_path_list.append([0]) + polylines_list.append([polylines[i]]) + elif heading_diff > 0.1 and euclidean_dist > 3: + dist_along_path_list.append([0]) + polylines_list.append([polylines[i]]) + elif euclidean_dist > 10: + dist_along_path_list.append([0]) + polylines_list.append([polylines[i]]) + else: + dist_along_path_list[-1].append(dist_along_path_list[-1][-1] + euclidean_dist) + polylines_list[-1].append(polylines[i]) + # plt.plot(polylines[:, 0], polylines[:, 1]) + # plt.savefig('tmp.jpg') + new_x_list = [] + new_y_list = [] + multi_polylines_list = [] + for idx in range(len(dist_along_path_list)): + if len(dist_along_path_list[idx]) < 2: + continue + dist_along_path = np.array(dist_along_path_list[idx]) + polylines_cur = np.array(polylines_list[idx]) + # Create interpolation functions for x and y coordinates + fx = interp1d(dist_along_path, polylines_cur[:, 0]) + fy = interp1d(dist_along_path, polylines_cur[:, 1]) + # fyaw = interp1d(dist_along_path, heading) + + # Create an array of distances at which to interpolate + new_dist_along_path = np.arange(0, dist_along_path[-1], distance) + new_dist_along_path = np.concatenate([new_dist_along_path, dist_along_path[[-1]]]) + # Use the interpolation functions to generate new x and y coordinates + new_x = fx(new_dist_along_path) + new_y = fy(new_dist_along_path) + # new_yaw = fyaw(new_dist_along_path) + new_x_list.append(new_x) + new_y_list.append(new_y) + + # Combine the new x and y coordinates into a single array + new_polylines = np.vstack((new_x, new_y)).T + polyline_size = int(split_distace / distance) + if new_polylines.shape[0] >= (polyline_size + 1): + padding_size = (new_polylines.shape[0] - (polyline_size + 1)) % polyline_size + final_index = (new_polylines.shape[0] - (polyline_size + 1)) // polyline_size + 1 + else: + padding_size = new_polylines.shape[0] + final_index = 0 + multi_polylines = None + new_polylines = torch.from_numpy(new_polylines) + new_heading = torch.atan2(new_polylines[1:, 1] - new_polylines[:-1, 1], + new_polylines[1:, 0] - new_polylines[:-1, 0]) + new_heading = torch.cat([new_heading, new_heading[-1:]], -1)[..., None] + new_polylines = torch.cat([new_polylines, new_heading], -1) + if new_polylines.shape[0] >= (polyline_size + 1): + multi_polylines = new_polylines.unfold(dimension=0, size=polyline_size + 1, step=polyline_size) + multi_polylines = multi_polylines.transpose(1, 2) + multi_polylines = multi_polylines[:, ::5, :] + if padding_size >= 3: + last_polyline = new_polylines[final_index * polyline_size:] + last_polyline = last_polyline[torch.linspace(0, last_polyline.shape[0] - 1, steps=3).long()] + if multi_polylines is not None: + multi_polylines = torch.cat([multi_polylines, last_polyline.unsqueeze(0)], dim=0) + else: + multi_polylines = last_polyline.unsqueeze(0) + if multi_polylines is None: + continue + multi_polylines_list.append(multi_polylines) + if len(multi_polylines_list) > 0: + multi_polylines_list = torch.cat(multi_polylines_list, dim=0) + else: + multi_polylines_list = None + return multi_polylines_list + + +# def interplating_polyline(polylines, heading, distance=0.5, split_distance=5, device='cpu'): +# dist_along_path_list = [[0]] +# polylines_list = [[polylines[0]]] +# for i in range(1, polylines.shape[0]): +# euclidean_dist = torch.norm(polylines[i, :2] - polylines[i - 1, :2]) +# heading_diff = min(abs(max(heading[i], heading[i - 1]) - min(heading[1], heading[i - 1])), +# abs(max(heading[i], heading[i - 1]) - min(heading[1], heading[i - 1]) + torch.pi)) +# if heading_diff > 
torch.pi / 4 and euclidean_dist > 3: +# dist_along_path_list.append([0]) +# polylines_list.append([polylines[i]]) +# elif heading_diff > torch.pi / 8 and euclidean_dist > 3: +# dist_along_path_list.append([0]) +# polylines_list.append([polylines[i]]) +# elif heading_diff > 0.1 and euclidean_dist > 3: +# dist_along_path_list.append([0]) +# polylines_list.append([polylines[i]]) +# elif euclidean_dist > 10: +# dist_along_path_list.append([0]) +# polylines_list.append([polylines[i]]) +# else: +# dist_along_path_list[-1].append(dist_along_path_list[-1][-1] + euclidean_dist) +# polylines_list[-1].append(polylines[i]) + +# new_x_list = [] +# new_y_list = [] +# multi_polylines_list = [] + +# for idx in range(len(dist_along_path_list)): +# if len(dist_along_path_list[idx]) < 2: +# continue + +# dist_along_path = torch.tensor(dist_along_path_list[idx], device=device) +# polylines_cur = torch.stack(polylines_list[idx]) + +# new_dist_along_path = torch.arange(0, dist_along_path[-1], distance) +# new_dist_along_path = torch.cat([new_dist_along_path, dist_along_path[[-1]]]) + +# new_x = torch.interp(new_dist_along_path, dist_along_path, polylines_cur[:, 0]) +# new_y = torch.interp(new_dist_along_path, dist_along_path, polylines_cur[:, 1]) + +# new_x_list.append(new_x) +# new_y_list.append(new_y) + +# new_polylines = torch.stack((new_x, new_y), dim=-1) + +# polyline_size = int(split_distance / distance) +# if new_polylines.shape[0] >= (polyline_size + 1): +# padding_size = (new_polylines.shape[0] - (polyline_size + 1)) % polyline_size +# final_index = (new_polylines.shape[0] - (polyline_size + 1)) // polyline_size + 1 +# else: +# padding_size = new_polylines.shape[0] +# final_index = 0 + +# multi_polylines = None +# new_heading = torch.atan2(new_polylines[1:, 1] - new_polylines[:-1, 1], +# new_polylines[1:, 0] - new_polylines[:-1, 0]) +# new_heading = torch.cat([new_heading, new_heading[-1:]], -1)[..., None] +# new_polylines = torch.cat([new_polylines, new_heading], -1) + +# if new_polylines.shape[0] >= (polyline_size + 1): +# multi_polylines = new_polylines.unfold(dimension=0, size=polyline_size + 1, step=polyline_size) +# multi_polylines = multi_polylines.transpose(1, 2) +# multi_polylines = multi_polylines[:, ::5, :] + +# if padding_size >= 3: +# last_polyline = new_polylines[final_index * polyline_size:] +# last_polyline = last_polyline[torch.linspace(0, last_polyline.shape[0] - 1, steps=3).long()] +# if multi_polylines is not None: +# multi_polylines = torch.cat([multi_polylines, last_polyline.unsqueeze(0)], dim=0) +# else: +# multi_polylines = last_polyline.unsqueeze(0) + +# if multi_polylines is None: +# continue +# multi_polylines_list.append(multi_polylines) + +# if len(multi_polylines_list) > 0: +# multi_polylines_list = torch.cat(multi_polylines_list, dim=0) +# else: +# multi_polylines_list = None + +# return multi_polylines_list + + +def average_distance_vectorized(point_set1, centroids): + dists = np.sqrt(np.sum((point_set1[:, None, :, :] - centroids[None, :, :, :]) ** 2, axis=-1)) + return np.mean(dists, axis=2) + + +def assign_clusters(sub_X, centroids): + distances = average_distance_vectorized(sub_X, centroids) + return np.argmin(distances, axis=1) + + +class TokenProcessor(nn.Module): + + def __init__(self, token_size, + training: bool=False, + predict_motion: bool=False, + predict_state: bool=False, + predict_map: bool=False, + state_token: Dict[str, int]=None, **kwargs): + super().__init__() + + module_dir = os.path.dirname(os.path.dirname(__file__)) + self.agent_token_path = 
os.path.join(module_dir, f'tokens/agent_vocab_555_s2.pkl') + self.map_token_traj_path = os.path.join(module_dir, 'tokens/map_traj_token5.pkl') + assert os.path.exists(self.agent_token_path), f"File {self.agent_token_path} not found." + assert os.path.exists(self.map_token_traj_path), f"File {self.map_token_traj_path} not found." + + self.training = training + self.token_size = token_size + self.disable_invalid = not predict_state + self.predict_motion = predict_motion + self.predict_state = predict_state + self.predict_map = predict_map + + # define new special tokens + self.bos_token_index = token_size + self.eos_token_index = token_size + 1 + self.invalid_token_index = token_size + 2 + self.special_token_index = [] + self._init_token() + + # define agent states + self.invalid_state = int(state_token['invalid']) + self.valid_state = int(state_token['valid']) + self.enter_state = int(state_token['enter']) + self.exit_state = int(state_token['exit']) + + self.pl2seed_radius = kwargs.get('pl2seed_radius', None) + + self.noise = False + self.disturb = False + self.shift = 5 + self.training = False + self.current_step = 10 + + # debugging + self.debug_data = None + + def forward(self, data): + """ + Each pkl data represents a extracted scenario from raw tfrecord data + """ + data['agent']['av_index'] = data['agent']['av_idx'] + data = self._tokenize_agent(data) + # data = self._tokenize_map(data) + del data['city'] + if 'polygon_is_intersection' in data['map_polygon']: + del data['map_polygon']['polygon_is_intersection'] + if 'route_type' in data['map_polygon']: + del data['map_polygon']['route_type'] + + av_index = int(data['agent']['av_idx']) + data['ego_pos'] = data['agent']['token_pos'][[av_index]] + data['ego_heading'] = data['agent']['token_heading'][[av_index]] + + return data + + def _init_token(self): + + agent_token_data = pickle.load(open(self.agent_token_path, 'rb')) + for agent_type, token in agent_token_data['token_all'].items(): + token = torch.tensor(token, dtype=torch.float32) + self.register_buffer(f'agent_token_all_{agent_type}', token, persistent=False) # [n_token, 6, 4, 2] + + map_token_traj = pickle.load(open(self.map_token_traj_path, 'rb'))['traj_src'] + map_token_traj = torch.tensor(map_token_traj, dtype=torch.float32) + self.register_buffer('map_token_traj_src', map_token_traj, persistent=False) # [n_token, 11 * 2] + + # self.trajectory_token = agent_token_data['token'] # (token_size, 4, 2) + # self.trajectory_token_all = agent_token_data['token_all'] # (token_size, shift + 1, 4, 2) + # self.map_token = {'traj_src': map_token_traj['traj_src']} + + @staticmethod + def clean_heading(valid: torch.Tensor, heading: torch.Tensor) -> torch.Tensor: + valid_pairs = valid[:, :-1] & valid[:, 1:] + for i in range(heading.shape[1] - 1): + heading_diff = torch.abs(wrap_angle(heading[:, i] - heading[:, i + 1])) + change_needed = (heading_diff > 1.5) & valid_pairs[:, i] + heading[:, i + 1][change_needed] = heading[:, i][change_needed] + return heading + + def _extrapolate_agent_to_prev_token_step(self, valid, pos, heading, vel) -> Sequence[torch.Tensor]: + # [n_agent], max will give the first True step + first_valid_step = torch.max(valid, dim=1).indices + + for i, t in enumerate(first_valid_step): # extrapolate to previous 5th step. + n_step_to_extrapolate = t % self.shift + if (t == self.current_step) and (not valid[i, self.current_step - self.shift]): + # such that at least one token is valid in the history. 
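# (Editor's note, worked example under the constants above: shift=5, current_step=10.)
# An agent first valid at t=13 gets n_step_to_extrapolate = 13 % 5 = 3, so steps
# 10..12 are back-filled below by integrating vel[i, 13] backwards at 0.1 s per step.
# The branch guarded here covers an agent first valid exactly at t=10 whose t=5 step
# is invalid: it is pushed back a full shift so the 5-step token window ending at
# the current step has a valid start.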
+ n_step_to_extrapolate = self.shift + + if n_step_to_extrapolate > 0: + vel[i, t - n_step_to_extrapolate : t] = vel[i, t] + valid[i, t - n_step_to_extrapolate : t] = True + heading[i, t - n_step_to_extrapolate : t] = heading[i, t] + + for j in range(n_step_to_extrapolate): + pos[i, t - j - 1] = pos[i, t - j] - vel[i, t] * 0.1 + + return valid, pos, heading, vel + + def _get_agent_shape(self, agent_type_masks: dict) -> torch.Tensor: + agent_shape = 0. + for type, type_mask in agent_type_masks.items(): + if type == 'veh': width = 2.; length = 4.8 + if type == 'ped': width = 1.; length = 2. + if type == 'cyc': width = 1.; length = 1. + agent_shape += torch.stack([width * type_mask, length * type_mask], dim=-1) + + return agent_shape + + def _get_token_traj_all(self, agent_type_masks: dict) -> torch.Tensor: + token_traj_all = 0. + for type, type_mask in agent_type_masks.items(): + token_traj_all += type_mask[:, None, None, None, None] * ( + getattr(self, f'agent_token_all_{type}').unsqueeze(0) + ) + return token_traj_all + + def _tokenize_agent(self, data): + + # get raw data + valid_mask = data['agent']['valid_mask'] # [n_agent, n_step] + agent_heading = data['agent']['heading'] # [n_agent, n_step] + agent_pos = data['agent']['position'][..., :2].contiguous() # [n_agent, n_step, 2] + agent_vel = data['agent']['velocity'] # [n_agent, n_step, 2] + agent_type = data['agent']['type'] + agent_category = data['agent']['category'] + + n_agent, n_all_step = valid_mask.shape + + agent_type_masks = { + "veh": agent_type == 0, + "ped": agent_type == 1, + "cyc": agent_type == 2, + } + agent_heading = self.clean_heading(valid_mask, agent_heading) + agent_shape = self._get_agent_shape(agent_type_masks) + token_traj_all = self._get_token_traj_all(agent_type_masks) + valid_mask, agent_pos, agent_heading, agent_vel = self._extrapolate_agent_to_prev_token_step( + valid_mask, agent_pos, agent_heading, agent_vel + ) + token_traj = token_traj_all[:, :, -1, ...] + data['agent']['token_traj_all'] = token_traj_all # [n_agent, n_token, 6, 4, 2] + data['agent']['token_traj'] = token_traj # [n_agent, n_token, 4, 2] + + valid_mask_shift = valid_mask.unfold(1, self.shift + 1, self.shift) + token_valid_mask = valid_mask_shift[:, :, 0] * valid_mask_shift[:, :, -1] + + # vehicle_mask = agent_type == 0 + # cyclist_mask = agent_type == 2 + # ped_mask = agent_type == 1 + + # veh_pos = agent_pos[vehicle_mask, :, :] + # veh_valid_mask = valid_mask[vehicle_mask, :] + # cyc_pos = agent_pos[cyclist_mask, :, :] + # cyc_valid_mask = valid_mask[cyclist_mask, :] + # ped_pos = agent_pos[ped_mask, :, :] + # ped_valid_mask = valid_mask[ped_mask, :] + + # index: [n_agent, n_step] contour: [n_agent, n_step, 4, 2] + token_index, token_contour, token_all = self._match_agent_token( + valid_mask, agent_pos, agent_heading, agent_shape, token_traj, None # token_traj_all + ) + + traj_pos = traj_heading = None + if len(token_all) > 0: + traj_pos = token_all.mean(dim=3) # [n_agent, n_step, 6, 2] + diff_xy = token_all[..., 0, :] - token_all[..., 3, :] + traj_heading = torch.arctan2(diff_xy[..., 1], diff_xy[..., 0]) + token_pos = token_contour.mean(dim=2) # [n_agent, n_step, 2] + diff_xy = token_contour[:, :, 0, :] - token_contour[:, :, 3, :] + token_heading = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0]) + + # token_index: (num_agent, num_timestep // shift) e.g. (49, 18) + # token_contour: (num_agent, num_timestep // shift, contour_dim, feat_dim, 2) e.g. 
(49, 18, 4, 2) + # veh_token_index, veh_token_contour = self._match_agent_token(veh_valid_mask, veh_pos, agent_heading[vehicle_mask], + # 'veh', agent_shape[vehicle_mask]) + # ped_token_index, ped_token_contour = self._match_agent_token(ped_valid_mask, ped_pos, agent_heading[ped_mask], + # 'ped', agent_shape[ped_mask]) + # cyc_token_index, cyc_token_contour = self._match_agent_token(cyc_valid_mask, cyc_pos, agent_heading[cyclist_mask], + # 'cyc', agent_shape[cyclist_mask]) + + # token_index = torch.zeros((agent_pos.shape[0], veh_token_index.shape[1])).to(torch.int64) + # token_index[vehicle_mask] = veh_token_index + # token_index[ped_mask] = ped_token_index + # token_index[cyclist_mask] = cyc_token_index + + # ! compute agent states + bos_index = torch.argmax(token_valid_mask.long(), dim=1) + eos_index = token_valid_mask.shape[1] - 1 - torch.argmax(torch.flip(token_valid_mask.long(), dims=[1]), dim=1) + state_index = torch.ones_like(token_index) # init with all valid + step_index = torch.arange(state_index.shape[1])[None].repeat(state_index.shape[0], 1).to(token_index.device) + state_index[step_index == bos_index[:, None]] = self.enter_state + state_index[step_index == eos_index[:, None]] = self.exit_state + state_index[(step_index < bos_index[:, None]) | (step_index > eos_index[:, None])] = self.invalid_state + # ! IMPORTANT: if the last step is exit token, should convert it back to valid token + state_index[state_index[:, -1] == self.exit_state, -1] = self.valid_state + + # update token attributions according to state tokens + token_valid_mask[state_index == self.enter_state] = False + token_pos[state_index == self.invalid_state] = 0. + token_heading[state_index == self.invalid_state] = 0. + for i in range(self.shift, agent_pos.shape[1], self.shift): + is_bos = state_index[:, i // self.shift - 1] == self.enter_state + token_pos[is_bos, i // self.shift - 1] = agent_pos[is_bos, i].clone() + # token_heading[is_bos, i // self.shift - 1] = agent_heading[is_bos, i].clone() + token_index[state_index == self.invalid_state] = -1 + token_index[state_index == self.enter_state] = -2 + + # acc_token_valid_step = torch.concat([torch.zeros_like(token_valid_mask[:, :1]), + # torch.cumsum(token_valid_mask.int(), dim=1), + # torch.zeros_like(token_valid_mask[:, -1:])], dim=1) + # state_index = torch.ones_like(token_index) # init with all valid + # max_valid_index = torch.argmax(acc_token_valid_step, dim=1) + # for step in range(1, acc_token_valid_step.shape[1] - 1): + + # # replace part of motion tokens with special tokens + # is_bos = (acc_token_valid_step[:, step] == 0) & (acc_token_valid_step[:, step + 1] == 1) + # is_eos = (step == max_valid_index) & (step < acc_token_valid_step.shape[1] - 2) & ~is_bos + # is_invalid = ~token_valid_mask[:, step - 1] & ~is_bos & ~is_eos + + # state_index[is_bos, step - 1] = self.enter_state + # state_index[is_eos, step - 1] = self.exit_state + # state_index[is_invalid, step - 1] = self.invalid_state + + # token_valid_mask[state_index[:, 0] == self.valid_state, 0] = False + # state_index[state_index[:, 0] == self.valid_state, 0] = self.enter_state + + # token_contour = torch.zeros((agent_pos.shape[0], veh_token_contour.shape[1], + # veh_token_contour.shape[2], veh_token_contour.shape[3])) + # token_contour[vehicle_mask] = veh_token_contour + # token_contour[ped_mask] = ped_token_contour + # token_contour[cyclist_mask] = cyc_token_contour + + raw_token_valid_mask = token_valid_mask.clone() + if not self.disable_invalid: + token_valid_mask = 
torch.ones_like(token_valid_mask).bool() + + # apply mask + # apply_mask = raw_token_valid_mask.sum(dim=-1) > 2 + # if self.training and os.getenv('AUG_MASK', False): + # aug_mask = torch.randint(0, 2, (raw_token_valid_mask.shape[0],)).to(raw_token_valid_mask).bool() + # apply_mask &= aug_mask + + # remove invalid agents which are outside the range of pl2inva_radius + # remove_ina_mask = torch.zeros_like(data['agent']['train_mask']) + # if self.pl2seed_radius is not None: + # num_history_token = 1 if self.training else 2 # NOTE: hard code!!! + # av_index = int(data['agent']['av_index']) + # is_invalid = torch.any(state_index[:, :num_history_token] == self.invalid_state, dim=-1) + # ina_bos_mask = (state_index == self.enter_state) & is_invalid[:, None] + # invalid_bos_step = torch.nonzero(ina_bos_mask, as_tuple=False) + # av_bos_pos = token_pos[av_index, invalid_bos_step[:, 1]] # (num_invalid_bos, 2) + # ina_bos_pos = token_pos[invalid_bos_step[:, 0], invalid_bos_step[:, 1]] # (num_invalid_bos, 2) + # distance = torch.sqrt(torch.sum((ina_bos_pos - av_bos_pos) ** 2, dim=-1)) + # remove_ina_mask = (distance > self.pl2seed_radius) | (distance < 0.) + # # apply_mask[invalid_bos_step[remove_ina_mask, 0]] = False + + # data['agent']['remove_ina_mask'] = remove_ina_mask + + # apply_mask[int(data['agent']['av_index'])] = True + # data['agent']['num_nodes'] = apply_mask.sum() + + # av_id = data['agent']['id'][data['agent']['av_index']] + # data['agent']['id'] = [data['agent']['id'][i] for i in range(len(apply_mask)) if apply_mask[i]] + # data['agent']['av_index'] = data['agent']['id'].index(av_id) + # data['agent']['id'] = torch.tensor(data['agent']['id'], dtype=torch.long) + + # agent_keys = ['valid_mask', 'predict_mask', 'type', 'category', 'position', 'heading', 'velocity', 'shape'] + # for key in agent_keys: + # if key in data['agent']: + # data['agent'][key] = data['agent'][key][apply_mask] + + # reset agent shapes + for i in range(n_agent): + bos_shape_index = torch.nonzero(torch.all(data['agent']['shape'][i] != 0., dim=-1))[0] + data['agent']['shape'][i, :] = data['agent']['shape'][i, bos_shape_index] + if torch.any(torch.all(data['agent']['shape'][i] == 0., dim=-1)): + raise ValueError(f"Found invalid shape values.") + + # compute mean height values for each scenario + raw_height = data['agent']['position'][:, self.current_step, 2] + valid_height = raw_token_valid_mask[:, 1].bool() + veh_mean_z = raw_height[agent_type_masks['veh'] & valid_height].mean() + ped_mean_z = raw_height[agent_type_masks['ped'] & valid_height].mean().nan_to_num_(veh_mean_z) # FIXME: hard code + cyc_mean_z = raw_height[agent_type_masks['cyc'] & valid_height].mean().nan_to_num_(veh_mean_z) + + # output + data['agent']['token_idx'] = token_index + data['agent']['state_idx'] = state_index + data['agent']['token_contour'] = token_contour + data['agent']['traj_pos'] = traj_pos + data['agent']['traj_heading'] = traj_heading + data['agent']['token_pos'] = token_pos + data['agent']['token_heading'] = token_heading + data['agent']['agent_valid_mask'] = token_valid_mask # (a, t) + data['agent']['raw_agent_valid_mask'] = raw_token_valid_mask + data['agent']['raw_height'] = dict(veh=veh_mean_z, + ped=ped_mean_z, + cyc=cyc_mean_z) + for type in ['veh', 'ped', 'cyc']: + data['agent'][f'trajectory_token_{type}'] = getattr( + self, f'agent_token_all_{type}') # [n_token, 6, 4, 2] + + return data + + def _match_agent_token(self, valid_mask, pos, heading, shape, token_traj, token_traj_all=None): + """ + Parameters: + valid_mask 
(torch.Tensor): Validity mask for agents over time. Shape: (n_agent, n_step) + pos (torch.Tensor): Positions of agents at each time step. Shape: (n_agent, n_step, 3) + heading (torch.Tensor): Headings of agents at each time step. Shape: (n_agent, n_step) + shape (torch.Tensor): Shape information of agents. Shape: (n_agent, 3) + token_traj (torch.Tensor): Token trajectories for agents. Shape: (n_agent, n_token, 4, 2) + token_traj_all (torch.Tensor): Token trajectories for all agents. Shape: (n_agnet, n_token_all, n_contour, 4, 2) + + Returns: + tuple: Contains token indices and contours for agents. + """ + + n_agent, n_step = valid_mask.shape + + # agent_token_src = self.trajectory_token[category] + # if self.shift <= 2: + # if category == 'veh': + # width = 1.0 + # length = 2.4 + # elif category == 'cyc': + # width = 0.5 + # length = 1.5 + # else: + # width = 0.5 + # length = 0.5 + # else: + # if category == 'veh': + # width = 2.0 + # length = 4.8 + # elif category == 'cyc': + # width = 1.0 + # length = 2.0 + # else: + # width = 1.0 + # length = 1.0 + + _, n_token, token_contour_dim, feat_dim = token_traj.shape + # agent_token_src = agent_token_src.reshape(1, token_num * token_contour_dim, feat_dim).repeat(agent_num, 0) + + token_index_list = [] + token_contour_list = [] + token_all = [] + + prev_heading = heading[:, 0] + prev_pos = pos[:, 0] + prev_token_idx = None + for i in range(self.shift, n_step, self.shift): # [5, 10, 15, ..., 90] + _valid_mask = valid_mask[:, i - self.shift] & valid_mask[:, i] + _invalid_mask = ~_valid_mask + + # transformation + theta = prev_heading + cos, sin = theta.cos(), theta.sin() + rot_mat = theta.new_zeros(n_agent, 2, 2) + rot_mat[:, 0, 0] = cos + rot_mat[:, 0, 1] = sin + rot_mat[:, 1, 0] = -sin + rot_mat[:, 1, 1] = cos + agent_token_world = torch.bmm(token_traj.flatten(1, 2), rot_mat).reshape(*token_traj.shape) + agent_token_world += prev_pos[:, None, None, :] + + cur_contour = cal_polygon_contour(pos[:, i], heading[:, i], shape) # [n_agent, 4, 2] + agent_token_index = torch.argmin( + torch.norm(agent_token_world - cur_contour[:, None, ...], dim=-1).sum(-1), dim=-1 + ) + agent_token_contour = agent_token_world[torch.arange(n_agent), agent_token_index] # [n_agent, 4, 2] + # agent_token_index = torch.from_numpy(np.argmin( + # np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] - agent_token_world.numpy()) ** 2, axis=-1)), axis=2), + # axis=-1)) + + # except for the first timestep TODO + if prev_token_idx is not None and self.noise: + same_idx = prev_token_idx == agent_token_index + same_idx[:] = True + topk_indices = np.argsort( + np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] 
- agent_token_world.numpy()) ** 2, axis=-1)), + axis=2), axis=-1)[:, :5] + sample_topk = np.random.choice(range(0, topk_indices.shape[1]), topk_indices.shape[0]) + agent_token_index[same_idx] = \ + torch.from_numpy(topk_indices[np.arange(topk_indices.shape[0]), sample_topk])[same_idx] + + # update prev_heading + prev_heading = heading[:, i].clone() + diff_xy = agent_token_contour[:, 0] - agent_token_contour[:, 3] + prev_heading[_valid_mask] = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])[_valid_mask] + + # update prev_pos + prev_pos = pos[:, i].clone() + prev_pos[_valid_mask] = agent_token_contour.mean(dim=1)[_valid_mask] + + prev_token_idx = agent_token_index + token_index_list.append(agent_token_index) + token_contour_list.append(agent_token_contour) + + # calculate tokenized trajectory + if token_traj_all is not None: + agent_token_all_world = torch.bmm(token_traj_all.flatten(1, 3), rot_mat).reshape(*token_traj_all.shape) + agent_token_all_world += prev_pos[:, None, None, None, :] + agent_token_all = agent_token_all_world[torch.arange(n_agent), agent_token_index] # [n_agent, 6, 4, 2] + token_all.append(agent_token_all) + + token_index = torch.stack(token_index_list, dim=1) # [n_agent, n_step] + token_contour = torch.stack(token_contour_list, dim=1) # [n_agent, n_step, 4, 2] + if len(token_all) > 0: + token_all = torch.stack(token_all, dim=1) # [n_agent, n_step, 6, 4, 2] + + # sanity check + assert tuple(token_index.shape) == (n_agent, n_step // self.shift), \ + f'Invalid token_index shape, got {token_index.shape}' + assert tuple(token_contour.shape )== (n_agent, n_step // self.shift, token_contour_dim, feat_dim), \ + f'Invalid token_contour shape, got {token_contour.shape}' + + # extra matching + # if not self.training: + # theta = heading[extra_mask, self.current_step - 1] + # prev_pos = pos[extra_mask, self.current_step - 1] + # cur_pos = pos[extra_mask, self.current_step] + # cur_heading = heading[extra_mask, self.current_step] + # cos, sin = theta.cos(), theta.sin() + # rot_mat = theta.new_zeros(extra_mask.sum(), 2, 2) + # rot_mat[:, 0, 0] = cos + # rot_mat[:, 0, 1] = sin + # rot_mat[:, 1, 0] = -sin + # rot_mat[:, 1, 1] = cos + # agent_token_world = torch.bmm(torch.from_numpy(token_last).to(torch.float), rot_mat).reshape( + # extra_mask.sum(), token_num, token_contour_dim, feat_dim) + # agent_token_world += prev_pos[:, None, None, :] + + # cur_contour = cal_polygon_contour(cur_pos[:, 0], cur_pos[:, 1], cur_heading, width, length) + # agent_token_index = torch.from_numpy(np.argmin( + # np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] 
- agent_token_world.numpy()) ** 2, axis=-1)), axis=2), + # axis=-1)) + # token_contour_select = agent_token_world[torch.arange(extra_mask.sum()), agent_token_index] + + # token_index[extra_mask, 1] = agent_token_index + # token_contour[extra_mask, 1] = token_contour_select + + return token_index, token_contour, token_all + + @staticmethod + def _tokenize_map(data): + + data['map_polygon']['type'] = data['map_polygon']['type'].to(torch.uint8) + data['map_point']['type'] = data['map_point']['type'].to(torch.uint8) + pt2pl = data[('map_point', 'to', 'map_polygon')]['edge_index'] + pt_type = data['map_point']['type'].to(torch.uint8) + pt_side = torch.zeros_like(pt_type) + pt_pos = data['map_point']['position'][:, :2] + data['map_point']['orientation'] = wrap_angle(data['map_point']['orientation']) + pt_heading = data['map_point']['orientation'] + split_polyline_type = [] + split_polyline_pos = [] + split_polyline_theta = [] + split_polyline_side = [] + pl_idx_list = [] + split_polygon_type = [] + data['map_point']['type'].unique() + + for i in sorted(np.unique(pt2pl[1])): # number of polygons in the scenario + index = pt2pl[0, pt2pl[1] == i] # index of points which belongs to i-th polygon + polygon_type = data['map_polygon']["type"][i] + cur_side = pt_side[index] + cur_type = pt_type[index] + cur_pos = pt_pos[index] + cur_heading = pt_heading[index] + + for side_val in np.unique(cur_side): + for type_val in np.unique(cur_type): + if type_val == 13: + continue + indices = np.where((cur_side == side_val) & (cur_type == type_val))[0] + if len(indices) <= 2: + continue + split_polyline = interplating_polyline(cur_pos[indices].numpy(), cur_heading[indices].numpy()) + if split_polyline is None: + continue + new_cur_type = cur_type[indices][0] + new_cur_side = cur_side[indices][0] + map_polygon_type = polygon_type.repeat(split_polyline.shape[0]) + new_cur_type = new_cur_type.repeat(split_polyline.shape[0]) + new_cur_side = new_cur_side.repeat(split_polyline.shape[0]) + cur_pl_idx = torch.Tensor([i]) + new_cur_pl_idx = cur_pl_idx.repeat(split_polyline.shape[0]) + split_polyline_pos.append(split_polyline[..., :2]) + split_polyline_theta.append(split_polyline[..., 2]) + split_polyline_type.append(new_cur_type) + split_polyline_side.append(new_cur_side) + pl_idx_list.append(new_cur_pl_idx) + split_polygon_type.append(map_polygon_type) + + split_polyline_pos = torch.cat(split_polyline_pos, dim=0) + split_polyline_theta = torch.cat(split_polyline_theta, dim=0) + split_polyline_type = torch.cat(split_polyline_type, dim=0) + split_polyline_side = torch.cat(split_polyline_side, dim=0) + split_polygon_type = torch.cat(split_polygon_type, dim=0) + pl_idx_list = torch.cat(pl_idx_list, dim=0) + + data['map_save'] = {} + data['pt_token'] = {} + data['map_save']['traj_pos'] = split_polyline_pos + data['map_save']['traj_theta'] = split_polyline_theta[:, 0] # torch.arctan2(vec[:, 1], vec[:, 0]) + data['map_save']['pl_idx_list'] = pl_idx_list + data['pt_token']['type'] = split_polyline_type + data['pt_token']['side'] = split_polyline_side + data['pt_token']['pl_type'] = split_polygon_type + data['pt_token']['num_nodes'] = split_polyline_pos.shape[0] + + return data \ No newline at end of file diff --git a/backups/dev/datasets/scalable_dataset.py b/backups/dev/datasets/scalable_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..15d6d10fd802aa683bf09e20c14b0d2ee426c842 --- /dev/null +++ b/backups/dev/datasets/scalable_dataset.py @@ -0,0 +1,276 @@ +import os +import pickle +import torch 
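# (Editor's sketch, hypothetical usage; not part of the original file.) The
# TokenProcessor defined in dev/datasets/preprocess.py composes with this module
# roughly as:
#   from dev.datasets.preprocess import TokenProcessor
#   processor = TokenProcessor(token_size=2048, predict_motion=True,
#                              predict_state=True, predict_map=True,
#                              state_token={'invalid': 0, 'valid': 1,
#                                           'enter': 2, 'exit': 3},
#                              pl2seed_radius=75.)
#   data = processor(pickle.load(open(pkl_path, 'rb')))  # adds token_idx, state_idx,
#                                                        # token_pos, token_heading, ...
# (`pkl_path` is a placeholder.) In MultiDataset below the agent tokenization is
# commented out; only the map is tokenized, inside WaymoTargetBuilder.__call__ via
# TokenProcessor._tokenize_map.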
+import json +import pytorch_lightning as pl +import pandas as pd +from tqdm import tqdm +from torch_geometric.data import HeteroData, Dataset +from torch_geometric.transforms import BaseTransform +from torch_geometric.loader import DataLoader +from typing import Callable, Dict, List, Optional + +from dev.datasets.preprocess import TokenProcessor + + +class MultiDataset(Dataset): + def __init__(self, + split: str, + raw_dir: List[str] = None, + transform: Optional[Callable] = None, + tfrecord_dir: Optional[str] = None, + token_size=512, + predict_motion: bool=False, + predict_state: bool=False, + predict_map: bool=False, + # state_token: Dict[str, int]=None, + # pl2seed_radius: float=None, + buffer_size: int=128, + logger=None) -> None: + + self.disable_invalid = not predict_state + self.predict_motion = predict_motion + self.predict_state = predict_state + self.predict_map = predict_map + self.logger = logger + if split not in ('train', 'val', 'test'): + raise ValueError(f'{split} is not a valid split') + self.training = split == 'train' + self.buffer_size = buffer_size + self._tfrecord_dir = tfrecord_dir + self.logger.debug('Starting loading dataset') + + raw_dir = os.path.expanduser(os.path.normpath(raw_dir)) + self._raw_files = sorted(os.listdir(raw_dir)) + + # for debugging + if int(os.getenv('OVERFIT', 0)): + # if self.training: + # # scenario_id = ['74ad7b76d5906d39', '13596229fd8cdb7e', '1d73db1fc42be3bf', '1351ea8b8333ddcb'] + # self._raw_files = ['74ad7b76d5906d39.pkl'] + self._raw_files[:9] + # else: + # self._raw_files = self._raw_files[:10] + self._raw_files = self._raw_files[:1] + # self._raw_files = ['1002fdc9826fc6d1.pkl'] + + # load meta infos and do filter + json_path = '/u/xiuyu/work/dev4/data/waymo_processed/meta_infos.json' + label = 'training' if split == 'train' else ('validation' if split == 'val' else split) + self.meta_infos = json.load(open(json_path, 'r', encoding='utf-8'))[label] + self.logger.debug(f"Loaded meta infos from {json_path}") + self.available_scenarios = list(self.meta_infos.keys()) + # self._raw_files = list(tqdm(filter(lambda fn: ( + # scenario_id := fn.removesuffix('.pkl') in self.available_scenarios and + # 8 <= self.meta_infos[scenario_id]['num_agents'] < self.buffer_size + # ), self._raw_files), leave=False)) + df = pd.DataFrame.from_dict(self.meta_infos, orient='index') + available_scenarios_set = set(self.available_scenarios) + df_filtered = df[(df.index.isin(available_scenarios_set)) & (df['num_agents'] >= 8) & (df['num_agents'] < self.buffer_size)] + valid_scenarios = set(df_filtered.index) + self._raw_files = list(tqdm(filter(lambda fn: fn.removesuffix('.pkl') in valid_scenarios, self._raw_files), leave=False)) + if len(self._raw_files) <= 0: + raise RuntimeError(f'Invalid number of data {len(self._raw_files)}!') + self._raw_paths = list(map(lambda fn: os.path.join(raw_dir, fn), self._raw_files)) + + self.logger.debug(f"The number of {split} dataset is {len(self._raw_paths)}") + self.logger.debug(f"The buffer size is {self.buffer_size}") + # self.token_processor = TokenProcessor(token_size, + # training=self.training, + # predict_motion=self.predict_motion, + # predict_state=self.predict_state, + # predict_map=self.predict_map, + # state_token=state_token, + # pl2seed_radius=pl2seed_radius) # 2048 + self.logger.debug(f"The used token size is {token_size}.") + super().__init__(transform=transform, pre_transform=None, pre_filter=None) + + def len(self) -> int: + return len(self._raw_paths) + + def get(self, idx: int): + """ + Load pkl file 
(each represents a 91s scenario for waymo dataset) + """ + with open(self._raw_paths[idx], 'rb') as handle: + data = pickle.load(handle) + + if self._tfrecord_dir is not None: + data['tfrecord_path'] = os.path.join(self._tfrecord_dir, f"{data['scenario_id']}.tfrecords") + + # data = self.token_processor.preprocess(data) + return data + + +class WaymoTargetBuilder(BaseTransform): + + def __init__(self, + num_historical_steps: int, + num_future_steps: int, + max_num: int, + training: bool=False) -> None: + + self.max_num = max_num + self.num_historical_steps = num_historical_steps + self.num_future_steps = num_future_steps + self.step_current = num_historical_steps - 1 + self.training = training + + def _score_trained_agents(self, data): + pos = data['agent']['position'] + av_index = torch.where(data['agent']['role'][:, 0])[0].item() + distance = torch.norm(pos - pos[av_index], dim=-1) + + # we do not believe the perception out of range of 150 meters + data['agent']['valid_mask'] &= distance < 150 + + # we do not predict vehicle too far away from ego car + role_train_mask = data['agent']['role'].any(-1) + extra_train_mask = (distance[:, self.step_current] < 100) & ( + data['agent']['valid_mask'][:, self.step_current + 1 :].sum(-1) >= 5 + ) + + train_mask = extra_train_mask | role_train_mask + if train_mask.sum() > self.max_num: # too many vehicle + _indices = torch.where(extra_train_mask & ~role_train_mask)[0] + selected_indices = _indices[ + torch.randperm(_indices.size(0))[: self.max_num - role_train_mask.sum()] + ] + data['agent']['train_mask'] = role_train_mask + data['agent']['train_mask'][selected_indices] = True + else: + data['agent']['train_mask'] = train_mask # [n_agent] + + return data + + def __call__(self, data) -> HeteroData: + + if self.training: + self._score_trained_agents(data) + + data = TokenProcessor._tokenize_map(data) + # data keys: dict_keys(['scenario_id', 'agent', 'map_polygon', 'map_point', ('map_point', 'to', 'map_polygon'), ('map_polygon', 'to', 'map_polygon'), 'map_save', 'pt_token']) + return HeteroData(data) + + +class MultiDataModule(pl.LightningDataModule): + transforms = { + 'WaymoTargetBuilder': WaymoTargetBuilder, + } + + dataset = { + 'scalable': MultiDataset, + } + + def __init__(self, + root: str, + train_batch_size: int, + val_batch_size: int, + test_batch_size: int, + shuffle: bool = False, + num_workers: int = 0, + pin_memory: bool = True, + persistent_workers: bool = True, + train_raw_dir: Optional[str] = None, + val_raw_dir: Optional[str] = None, + test_raw_dir: Optional[str] = None, + train_processed_dir: Optional[str] = None, + val_processed_dir: Optional[str] = None, + test_processed_dir: Optional[str] = None, + val_tfrecords_splitted: Optional[str] = None, + transform: Optional[str] = None, + dataset: Optional[str] = None, + num_historical_steps: int = 50, + num_future_steps: int = 60, + processor='ntp', + token_size=512, + predict_motion: bool=False, + predict_state: bool=False, + predict_map: bool=False, + state_token: Dict[str, int]=None, + pl2seed_radius: float=None, + max_num: int=32, + buffer_size: int=256, + logger=None, + **kwargs) -> None: + + super(MultiDataModule, self).__init__() + self.root = root + self.dataset_class = dataset + self.train_batch_size = train_batch_size + self.val_batch_size = val_batch_size + self.test_batch_size = test_batch_size + self.shuffle = shuffle + self.num_workers = num_workers + self.pin_memory = pin_memory + self.persistent_workers = persistent_workers and num_workers > 0 + self.train_raw_dir = 
train_raw_dir + self.val_raw_dir = val_raw_dir + self.test_raw_dir = test_raw_dir + self.train_processed_dir = train_processed_dir + self.val_processed_dir = val_processed_dir + self.test_processed_dir = test_processed_dir + self.val_tfrecords_splitted = val_tfrecords_splitted + self.processor = processor + self.token_size = token_size + self.predict_motion = predict_motion + self.predict_state = predict_state + self.predict_map = predict_map + self.state_token = state_token + self.pl2seed_radius = pl2seed_radius + self.buffer_size = buffer_size + self.logger = logger + + self.train_transform = MultiDataModule.transforms[transform](num_historical_steps, + num_future_steps, + max_num=max_num, + training=True) + self.val_transform = MultiDataModule.transforms[transform](num_historical_steps, + num_future_steps, + max_num=max_num, + training=False) + + def setup(self, stage: Optional[str] = None) -> None: + general_params = dict(token_size=self.token_size, + predict_motion=self.predict_motion, + predict_state=self.predict_state, + predict_map=self.predict_map, + buffer_size=self.buffer_size, + logger=self.logger) + + if stage == 'fit' or stage is None: + self.train_dataset = MultiDataModule.dataset[self.dataset_class](split='train', + raw_dir=self.train_raw_dir, + transform=self.train_transform, + **general_params) + self.val_dataset = MultiDataModule.dataset[self.dataset_class](split='val', + raw_dir=self.val_raw_dir, + transform=self.val_transform, + tfrecord_dir=self.val_tfrecords_splitted, + **general_params) + if stage == 'validate': + self.val_dataset = MultiDataModule.dataset[self.dataset_class](split='val', + raw_dir=self.val_raw_dir, + transform=self.val_transform, + tfrecord_dir=self.val_tfrecords_splitted, + **general_params) + if stage == 'test': + self.test_dataset = MultiDataModule.dataset[self.dataset_class](split='test', + raw_dir=self.test_raw_dir, + transform=self.val_transform, + tfrecord_dir=self.val_tfrecords_splitted, + **general_params) + + def train_dataloader(self): + return DataLoader(self.train_dataset, batch_size=self.train_batch_size, shuffle=self.shuffle, + num_workers=self.num_workers, pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers) + + def val_dataloader(self): + return DataLoader(self.val_dataset, batch_size=self.val_batch_size, shuffle=False, + num_workers=self.num_workers, pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers) + + def test_dataloader(self): + return DataLoader(self.test_dataset, batch_size=self.test_batch_size, shuffle=False, + num_workers=self.num_workers, pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers) diff --git a/backups/dev/metrics/box_utils.py b/backups/dev/metrics/box_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..13553c8c58712d8e8c5e55da6a843826f593b6e5 --- /dev/null +++ b/backups/dev/metrics/box_utils.py @@ -0,0 +1,113 @@ +import torch +from torch import Tensor + + +def get_yaw_rotation_2d(yaw): + """ + Gets a 2D rotation matrix given a yaw angle. + + Args: + yaw: torch.Tensor, rotation angle in radians. Can be any shape except empty. + + Returns: + rotation: torch.Tensor with shape [..., 2, 2], where `...` matches input shape. 
+ """ + cos_yaw = torch.cos(yaw) + sin_yaw = torch.sin(yaw) + + rotation = torch.stack([ + torch.stack([cos_yaw, -sin_yaw], dim=-1), + torch.stack([sin_yaw, cos_yaw], dim=-1), + ], dim=-2) # Shape: [..., 2, 2] + + return rotation + + +def get_yaw_rotation(yaw): + """ + Computes a 3D rotation matrix given a yaw angle (rotation around the Z-axis). + + Args: + yaw: torch.Tensor of any shape, representing yaw angles in radians. + + Returns: + rotation: torch.Tensor of shape [input_shape, 3, 3], representing the rotation matrices. + """ + cos_yaw = torch.cos(yaw) + sin_yaw = torch.sin(yaw) + ones = torch.ones_like(yaw) + zeros = torch.zeros_like(yaw) + + return torch.stack([ + torch.stack([cos_yaw, -sin_yaw, zeros], dim=-1), + torch.stack([sin_yaw, cos_yaw, zeros], dim=-1), + torch.stack([zeros, zeros, ones], dim=-1), + ], dim=-2) + + +def get_transform(rotation, translation): + """ + Combines an NxN rotation matrix and an Nx1 translation vector into an (N+1)x(N+1) transformation matrix. + + Args: + rotation: torch.Tensor of shape [..., N, N], representing rotation matrices. + translation: torch.Tensor of shape [..., N], representing translation vectors. + This must have the same dtype as rotation. + + Returns: + transform: torch.Tensor of shape [..., (N+1), (N+1)], representing the transformation matrices. + This has the same dtype as rotation. + """ + # [..., N, 1] + translation_n_1 = translation.unsqueeze(-1) + + # [..., N, N+1] - Combine rotation and translation + transform = torch.cat([rotation, translation_n_1], dim=-1) + + # [..., N] - Create the last row, which is [0, 0, ..., 0, 1] + last_row = torch.zeros_like(translation) + last_row = torch.cat([last_row, torch.ones_like(last_row[..., :1])], dim=-1) + + # [..., N+1, N+1] - Append the last row to form the final transformation matrix + transform = torch.cat([transform, last_row.unsqueeze(-2)], dim=-2) + + return transform + + +def get_upright_3d_box_corners(boxes: Tensor): + """ + Given a set of upright 3D bounding boxes, return its 8 corner points. + + Args: + boxes: torch.Tensor [N, 7]. The inner dims are [center{x,y,z}, length, width, + height, heading]. + + Returns: + corners: torch.Tensor [N, 8, 3]. + """ + center_x, center_y, center_z, length, width, height, heading = boxes.unbind(dim=-1) + + # Compute rotation matrix [N, 3, 3] + rotation = get_yaw_rotation(heading) + + # Translation [N, 3] + translation = torch.stack([center_x, center_y, center_z], dim=-1) + + l2, w2, h2 = length * 0.5, width * 0.5, height * 0.5 + + # Define the 8 corners in local coordinates [N, 8, 3] + corners_local = torch.stack([ + torch.stack([ l2, w2, -h2], dim=-1), + torch.stack([-l2, w2, -h2], dim=-1), + torch.stack([-l2, -w2, -h2], dim=-1), + torch.stack([ l2, -w2, -h2], dim=-1), + torch.stack([ l2, w2, h2], dim=-1), + torch.stack([-l2, w2, h2], dim=-1), + torch.stack([-l2, -w2, h2], dim=-1), + torch.stack([ l2, -w2, h2], dim=-1), + ], dim=1) # Shape: [N, 8, 3] + + # Rotate and translate the corners + corners = torch.einsum('n i j, n k j -> n k i', rotation, corners_local) + translation.unsqueeze(1) + + return corners diff --git a/backups/dev/metrics/compute_metrics.py b/backups/dev/metrics/compute_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..5ce78f861953f9cb3d53375051467446da812fa0 --- /dev/null +++ b/backups/dev/metrics/compute_metrics.py @@ -0,0 +1,1812 @@ +# ! 
Metrics Calculation +import concurrent.futures +import os +import torch +import tensorflow as tf +import collections +import dataclasses +import fnmatch +import json +import pandas as pd +import pickle +import copy +import concurrent +import multiprocessing +from torch_geometric.utils import degree +from functools import partial +from dataclasses import dataclass, field +from tqdm import tqdm +from argparse import ArgumentParser +from torch import Tensor +from google.protobuf import text_format +from torchmetrics import Metric +from typing import Optional, Sequence, List, Dict + +from waymo_open_dataset.utils.sim_agents import submission_specs + +from dev.utils.visualization import safe_run +from dev.utils.func import CONSOLE +from dev.datasets.scalable_dataset import WaymoTargetBuilder +from dev.datasets.preprocess import TokenProcessor, SHIFT, AGENT_STATE +from dev.metrics import trajectory_features, interact_features, map_features, placement_features +from dev.metrics.protos import scenario_pb2, long_metrics_pb2 + + +_METRIC_FIELD_NAMES_BY_BUCKET = { + 'kinematic': [ + 'linear_speed', 'linear_acceleration', + 'angular_speed', 'angular_acceleration', + ], + 'interactive': [ + 'distance_to_nearest_object', 'collision_indication', + 'time_to_collision', + ], + 'map_based': [ + # 'distance_to_road_edge', 'offroad_indication' + ], + 'placement_based': [ + 'num_placement', 'num_removement', + 'distance_placement', 'distance_removement', + ] +} +_METRIC_FIELD_NAMES = ( + _METRIC_FIELD_NAMES_BY_BUCKET['kinematic'] + + _METRIC_FIELD_NAMES_BY_BUCKET['interactive'] + + _METRIC_FIELD_NAMES_BY_BUCKET['map_based'] + + _METRIC_FIELD_NAMES_BY_BUCKET['placement_based'] +) + + +""" Help Functions """ + +def _arg_gather(tensor: Tensor, reference_tensor: Tensor) -> Tensor: + """Finds corresponding indices in `tensor` for each element in `reference_tensor`. + + Args: + tensor: A 1D tensor without repetitions. + reference_tensor: A 1D tensor containing items from `tensor`. + + Returns: + A tensor of indices such that `tensor[indices] == reference_tensor`. + """ + assert tensor.ndim == 1, "tensor must be 1D" + assert reference_tensor.ndim == 1, "reference_tensor must be 1D" + + # Create the comparison matrix + bit_mask = tensor[None, :] == reference_tensor[:, None] # Shape: [len(reference_tensor), len(tensor)] + + # Count the matches along `tensor` dimension + bit_mask_sum = bit_mask.int().sum(dim=1) + + if (bit_mask_sum < 1).any(): + raise ValueError( + 'Some items in `reference_tensor` are missing from `tensor`: ' + f'\n{reference_tensor} \nvs. \n{tensor}.' + ) + + if (bit_mask_sum > 1).any(): + raise ValueError('Some items in `tensor` are repeated.') + + # Compute indices + indices = torch.matmul(bit_mask.int(), torch.arange(tensor.shape[0], dtype=torch.int32)) + return indices + + +def is_valid_sim_agent(track: scenario_pb2.Track) -> bool: # type: ignore + """Checks if the object needs to be resimulated as a sim agent. + + For the Sim Agents challenge, every object that is valid at the + `current_time_index` step (here hardcoded to 10) needs to be resimulated. + + Args: + track: A track proto for a single object. + + Returns: + A boolean flag, True if the object needs to be resimulated, False otherwise. + """ + return track.states[submission_specs.CURRENT_TIME_INDEX].valid + + +def get_sim_agent_ids( + scenario: scenario_pb2.Scenario) -> Sequence[int]: # type: ignore + """Returns the list of object IDs that needs to be resimulated. 
+ + Internally calls `is_valid_sim_agent` to verify the simulation criteria, + i.e. is the object valid at `current_time_index`. + + Args: + scenario: The Scenario proto containing the data. + + Returns: + A list of int IDs, containing all the objects that need to be simulated. + """ + object_ids = [] + for track in scenario.tracks: + if is_valid_sim_agent(track): + object_ids.append(track.id) + return object_ids + + +def get_evaluation_agent_ids( + scenario: scenario_pb2.Scenario) -> Sequence[int]: # type: ignore + # Start with the AV object. + object_ids = {scenario.tracks[scenario.sdc_track_index].id} + # Add the `tracks_to_predict` objects. + for required_prediction in scenario.tracks_to_predict: + object_ids.add(scenario.tracks[required_prediction.track_index].id) + return sorted(object_ids) + + +""" Base Data Classes s""" + +@dataclass(frozen=True) +class ObjectTrajectories: + + x: Tensor + y: Tensor + z: Tensor + heading: Tensor + length: Tensor + width: Tensor + height: Tensor + valid: Tensor + object_id: Tensor + object_type: Tensor + + state: Optional[Tensor] = None + token_pos: Optional[Tensor] = None + token_heading: Optional[Tensor] = None + token_valid: Optional[Tensor] = None + processed_object_id: Optional[Tensor] = None + av_id: Optional[int] = None + processed_av_id: Optional[int] = None + + def slice_time(self, start_index: int = 0, end_index: Optional[int] = None): + return ObjectTrajectories( + x=self.x[..., start_index:end_index], + y=self.y[..., start_index:end_index], + z=self.z[..., start_index:end_index], + heading=self.heading[..., start_index:end_index], + length=self.length[..., start_index:end_index], + width=self.width[..., start_index:end_index], + height=self.height[..., start_index:end_index], + valid=self.valid[..., start_index:end_index], + object_id=self.object_id, + object_type=self.object_type, + + # these properties can only come from processed file + state=self.state, + token_pos=self.token_pos, + token_heading=self.token_heading, + token_valid=self.token_valid, + processed_object_id=self.processed_object_id, + av_id=self.av_id, + processed_av_id=self.processed_av_id, + ) + + def gather_objects(self, object_indices: Tensor): + assert object_indices.ndim == 1, "object_indices must be 1D" + return ObjectTrajectories( + x=torch.index_select(self.x, dim=-2, index=object_indices), + y=torch.index_select(self.y, dim=-2, index=object_indices), + z=torch.index_select(self.z, dim=-2, index=object_indices), + heading=torch.index_select(self.heading, dim=-2, index=object_indices), + length=torch.index_select(self.length, dim=-2, index=object_indices), + width=torch.index_select(self.width, dim=-2, index=object_indices), + height=torch.index_select(self.height, dim=-2, index=object_indices), + valid=torch.index_select(self.valid, dim=-2, index=object_indices), + object_id=torch.index_select(self.object_id, dim=-1, index=object_indices), + object_type=torch.index_select(self.object_type, dim=-1, index=object_indices), + + # these properties can only come from processed file + state=self.state, + token_pos=self.token_pos, + token_heading=self.token_heading, + token_valid=self.token_valid, + processed_object_id=self.processed_object_id, + av_id=self.av_id, + processed_av_id=self.processed_av_id, + ) + + def gather_objects_by_id(self, object_ids: tf.Tensor): + indices = _arg_gather(self.object_id, object_ids) + return self.gather_objects(indices) + + @classmethod + def _get_init_dict_from_processed(cls, scenario: dict): + """Load from processed pkl data""" + 
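# (Editor's note) Expected layout, inferred from the preprocess stage:
# scenario['agent']['position'] is [n_agent, n_step, 3], 'heading' and
# 'valid_mask' are [n_agent, n_step], and 'shape' is [n_agent, n_step, 3],
# read below as (length, width, height); 'id' and 'type' are per-agent
# vectors that become `object_id` and `object_type`.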
position = scenario['agent']['position'] + heading = scenario['agent']['heading'] + shape = scenario['agent']['shape'] + object_ids = scenario['agent']['id'] + object_types = scenario['agent']['type'] + valid = scenario['agent']['valid_mask'] + + init_dict = dict(x=position[..., 0], + y=position[..., 1], + z=position[..., 2], + heading=heading, + length=shape[..., 0], + width=shape[..., 1], + height=shape[..., 2], + valid=valid, + object_id=object_ids, + object_type=object_types) + + return init_dict + + @classmethod + def _get_init_dict_from_raw(cls, + scenario: scenario_pb2.Scenario): # type: ignore + + """Load from tfrecords data""" + states, dimensions, objects = [], [], [] + for track in scenario.tracks: # n_object + # Iterate over a single object's states. + track_states, track_dimensions = [], [] + for state in track.states: # n_timestep + track_states.append((state.center_x, state.center_y, state.center_z, + state.heading, state.valid)) + track_dimensions.append((state.length, state.width, state.height)) + # Adds to the global states. + states.append(list(zip(*track_states))) + dimensions.append(list(zip(*track_dimensions))) + objects.append((track.id, track.object_type)) + + # Unpack and convert to torch tensors. + x, y, z, heading, valid = [torch.tensor(s) for s in zip(*states)] + length, width, height = [torch.tensor(s) for s in zip(*dimensions)] + object_ids, object_types = [torch.tensor(s) for s in zip(*objects)] + + av_id = object_ids[scenario.sdc_track_index] + + init_dict = dict(x=x, y=y, z=z, + heading=heading, + length=length, + width=width, + height=height, + valid=valid, + object_id=object_ids, + object_type=object_types, + av_id=int(av_id)) + + return init_dict + + @classmethod + def from_scenario(cls, + scenario: scenario_pb2.Scenario, # type: ignore + processed_scenario: Optional[dict]=None, + from_where: str='raw'): + + if from_where == 'raw': + init_dict = cls._get_init_dict_from_raw(scenario) + elif from_where == 'processed': + assert processed_scenario is not None, f'`processed_scenario` should be given!' + init_dict = cls._get_init_dict_from_processed(processed_scenario) + else: + raise RuntimeError(f'Invalid source: {from_where}') + + if processed_scenario is not None: + init_dict.update(state=processed_scenario['agent']['state_idx'], + token_pos=processed_scenario['agent']['token_pos'], + token_heading=processed_scenario['agent']['token_heading'], + token_valid=processed_scenario['agent']['raw_agent_valid_mask'], + processed_object_id=processed_scenario['agent']['id'], + processed_av_id=int(processed_scenario['agent']['id'][ + processed_scenario['agent']['av_idx'] + ]), + ) + + return cls(**init_dict) + + +@dataclass +class ScenarioRollouts: + scenario_id: Optional[str] = None + joint_scenes: List[ObjectTrajectories] = field(default_factory=list) + + +""" Conversion Methods """ + +def scenario_to_trajectories( + scenario: scenario_pb2.Scenario, # type: ignore + processed_scenario: Optional[dict]=None, + from_where: Optional[str]='raw', + remove_history: Optional[bool]=False +) -> ObjectTrajectories: + """Converts a WOMD Scenario proto into the `ObjectTrajectories`. + + Returns: + An `ObjectTrajectories` with trajectories copied from data. + """ + trajectories = ObjectTrajectories.from_scenario(scenario, + processed_scenario, + from_where, + ) + # Slice by the required sim agents.
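# (Editor's worked example, hypothetical ids.) If trajectories.object_id is
# tensor([7, 3, 9]) and sim_agent_ids is [3, 9], gather_objects_by_id resolves
# indices [1, 2] via _arg_gather and returns the two tracks reordered to match
# [3, 9]; _arg_gather raises a ValueError if a requested id is missing from
# object_id or if object_id contains repeated entries.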
+ sim_agent_ids = get_sim_agent_ids(scenario) + # CONSOLE.log(f'sim_agent_ids of log scenario: {sim_agent_ids} total: {len(sim_agent_ids)}') + trajectories = trajectories.gather_objects_by_id(torch.tensor(sim_agent_ids)) + + if remove_history: + # Slice in time to only include steps after `current_time_index`. + trajectories = trajectories.slice_time(submission_specs.CURRENT_TIME_INDEX + 1) # 10 + 1 + if trajectories.valid.shape[-1] != submission_specs.N_SIMULATION_STEPS: # 80 simulated steps + raise ValueError( + 'The Scenario used does not include the right number of time steps. ' + f'Expected: {submission_specs.N_SIMULATION_STEPS}, ' + f'Actual: {trajectories.valid.shape[-1]}.') + + return trajectories + + +def _unbatch(src: Tensor, batch: Tensor, dim: int = 0) -> List[Tensor]: + sizes = degree(batch, dtype=torch.long).tolist() + return src.split(sizes, dim) + + +def get_scenario_id_int_tensor(scenario_id: List[str], device: torch.device=torch.device('cpu')) -> torch.Tensor: + scenario_id_int_tensor = [] + for str_id in scenario_id: + int_id = [-1] * 16 # max_len of scenario_id string is 16 + for i, c in enumerate(str_id): + int_id[i] = ord(c) + scenario_id_int_tensor.append( + torch.tensor(int_id, dtype=torch.int32, device=device) + ) + return torch.stack(scenario_id_int_tensor, dim=0) # [n_scenario, 16] + + +def output_to_rollouts(scenario: dict) -> List[ScenarioRollouts]: # n_scenario + # scenario_id: Tensor, # [n_scenario, n_str_length] + # agent_id: Tensor, # [n_agent, n_rollout] + # agent_batch: Tensor, # [n_agent] + # pred_traj: Tensor, # [n_agent, n_rollout, n_step, 2] + # pred_z: Tensor, # [n_agent, n_rollout, n_step] + # pred_head: Tensor, # [n_agent, n_rollout, n_step] + # pred_shape: Tensor, # [n_agent, n_rollout, 3] + # pred_type: Tensor, # [n_agent, n_rollout] + # pred_state: Tensor, # [n_agent, n_rollout, n_step] + scenario_id = scenario['scenario_id'] + av_id = ( + scenario['av_id'] if 'av_id' in scenario else -1 + ) + agent_id = scenario['agent_id'] + agent_batch = scenario['agent_batch'] + pred_traj = scenario['pred_traj'] + pred_z = scenario['pred_z'] + pred_head = scenario['pred_head'] + pred_shape = scenario['pred_shape'] + pred_type = scenario['pred_type'] + pred_state = ( + scenario['pred_state'] if 'pred_state' in scenario else + torch.zeros_like(pred_z).long() + ) + pred_valid = scenario['pred_valid'] + token_pos = scenario['token_pos'] + token_head = scenario['token_head'] + + # CONSOLE.log("Generate scenario rollouts ...") + # CONSOLE.log(f'scenario_id: {scenario_id}') + # CONSOLE.log(f'agent_id: {agent_id.flatten()} total: {agent_id.shape}') + # CONSOLE.log(f'av_id: {av_id}') + # CONSOLE.log(f'agent_batch: {agent_batch} total: {agent_batch.shape}') + # CONSOLE.log(f'pred_traj: {pred_traj.shape}') + # CONSOLE.log(f'pred_z: {pred_z.shape}') + # CONSOLE.log(f'pred_head: {pred_head.shape}') + # CONSOLE.log(f'pred_shape: {pred_shape.shape}') + # CONSOLE.log(f'pred_type: {pred_type.shape}') + # CONSOLE.log(f'pred_state: {pred_state.shape}') + # CONSOLE.log(f'token_pos: {token_pos.shape}') + # CONSOLE.log(f'token_head: {token_head.shape}') + + scenario_id = scenario_id.cpu().numpy() + n_agent, n_rollout, n_step, _ = pred_traj.shape + agent_id = _unbatch(agent_id, agent_batch) + pred_traj = _unbatch(pred_traj, agent_batch) + pred_z = _unbatch(pred_z, agent_batch) + pred_head = _unbatch(pred_head, agent_batch) + pred_shape = _unbatch(pred_shape, agent_batch) + pred_type = _unbatch(pred_type, agent_batch) + pred_state = _unbatch(pred_state, agent_batch) + pred_valid = 
_unbatch(pred_valid, agent_batch) + token_pos = _unbatch(token_pos, agent_batch) + token_head = _unbatch(token_head, agent_batch) + + agent_id = [x.cpu() for x in agent_id] + pred_traj = [x.cpu() for x in pred_traj] + pred_z = [x.cpu() for x in pred_z] + pred_head = [x.cpu() for x in pred_head] + pred_shape = [x[:, :, None].repeat(1, 1, n_step, 1).cpu() for x in pred_shape] + pred_type = [x[:, :, None].repeat(1, 1, n_step, 1).cpu() for x in pred_type] + pred_state = [x.cpu() for x in pred_state] + pred_valid = [x.cpu() for x in pred_valid] + token_pos = [x.cpu() for x in token_pos] + token_head = [x.cpu() for x in token_head] + + n_scenario = scenario_id.shape[0] + scenario_rollouts = [] + for i_scenario in range(n_scenario): + joint_scenes = [] + for i_rollout in range(n_rollout): # 1 + joint_scenes.append( + ObjectTrajectories( + x=pred_traj[i_scenario][:, i_rollout, :, 0], + y=pred_traj[i_scenario][:, i_rollout, :, 1], + z=pred_z[i_scenario][:, i_rollout], + heading=pred_head[i_scenario][:, i_rollout], + length=pred_shape[i_scenario][:, i_rollout, :, 0], + width=pred_shape[i_scenario][:, i_rollout, :, 1], + height=pred_shape[i_scenario][:, i_rollout, :, 2], + valid=pred_valid[i_scenario][:, i_rollout], + state=pred_state[i_scenario][:, i_rollout], + object_id=agent_id[i_scenario][:, i_rollout], + processed_object_id=agent_id[i_scenario][:, i_rollout], + object_type=pred_type[i_scenario][:, i_rollout], + token_pos=token_pos[i_scenario][:, i_rollout, :, :2], + token_heading=token_head[i_scenario][:, i_rollout], + av_id=av_id, + processed_av_id=av_id, + ) + ) + + _str_scenario_id = "".join([chr(x) for x in scenario_id[i_scenario] if x > 0]) + scenario_rollouts.append( + ScenarioRollouts( + joint_scenes=joint_scenes, scenario_id=_str_scenario_id + ) + ) + + # CONSOLE.log(f'n_scenario: {len(scenario_rollouts)}') + # CONSOLE.log(f'n_rollout: {len(scenario_rollouts[0].joint_scenes)}') + # CONSOLE.log(f'x shape: {scenario_rollouts[0].joint_scenes[0].x.shape}') + + return scenario_rollouts + + +""" Compute Metric Features """ + +def _compute_metametric( + config: long_metrics_pb2.SimAgentMetricsConfig, # type: ignore + metrics: long_metrics_pb2.SimAgentMetrics, # type: ignore +): + """Computes the meta-metric aggregation.""" + metametric = 0.0 + for field_name in _METRIC_FIELD_NAMES: + likelihood_field_name = field_name + '_likelihood' + weight = getattr(config, field_name).metametric_weight + metric_score = getattr(metrics, likelihood_field_name) + metametric += weight * metric_score + return metametric + + +@dataclasses.dataclass(frozen=True) +class MetricFeatures: + + object_id: Tensor + valid: Tensor + linear_speed: Tensor + linear_acceleration: Tensor + angular_speed: Tensor + angular_acceleration: Tensor + distance_to_nearest_object: Tensor + collision_per_step: Tensor + time_to_collision: Tensor + distance_to_road_edge: Tensor + offroad_per_step: Tensor + num_placement: Tensor + num_removement: Tensor + distance_placement: Tensor + distance_removement: Tensor + + @classmethod + def from_file(cls, file_path: str): + + if not os.path.exists(file_path): + raise FileNotFoundError(f'Not found file {file_path}') + + with open(file_path, 'rb') as f: + feat_dict = pickle.load(f) + + fields = [field.name for field in dataclasses.fields(cls)] + init_dict = dict() + + for field in fields: + if field in feat_dict: + init_dict[field] = feat_dict[field] + else: + init_dict[field] = None + + return cls(**init_dict) + + def unfold(self, size: int, step: int): + return MetricFeatures( + 
+            object_id=self.object_id,
+            valid=self.valid.unfold(1, size, step),
+            linear_speed=self.linear_speed.unfold(1, size, step),
+            linear_acceleration=self.linear_acceleration.unfold(1, size, step),
+            angular_speed=self.angular_speed.unfold(1, size, step),
+            angular_acceleration=self.angular_acceleration.unfold(1, size, step),
+            distance_to_nearest_object=self.distance_to_nearest_object.unfold(1, size, step),
+            collision_per_step=self.collision_per_step.unfold(1, size, step),
+            time_to_collision=self.time_to_collision.unfold(1, size, step),
+            distance_to_road_edge=self.distance_to_road_edge.unfold(1, size, step),
+            offroad_per_step=self.offroad_per_step.unfold(1, size, step),
+            num_placement=self.num_placement.unfold(1, size // SHIFT, step // SHIFT),
+            num_removement=self.num_removement.unfold(1, size // SHIFT, step // SHIFT),
+            distance_placement=self.distance_placement.unfold(1, size // SHIFT, step // SHIFT),
+            distance_removement=self.distance_removement.unfold(1, size // SHIFT, step // SHIFT),
+        )
+
+
+def compute_metric_features(
+    simulate_trajectories: ObjectTrajectories,
+    evaluate_agent_ids: Optional[Tensor]=None,
+    scenario_log: Optional[scenario_pb2.Scenario]=None,  # type: ignore
+) -> MetricFeatures:
+
+    if evaluate_agent_ids is not None:
+        evaluate_trajectories = simulate_trajectories.gather_objects_by_id(
+            evaluate_agent_ids
+        )
+    else:
+        evaluate_trajectories = simulate_trajectories
+
+    # valid mask
+    validity_mask = evaluate_trajectories.valid
+    validity_mask = validity_mask[:, submission_specs.CURRENT_TIME_INDEX + 1:]
+
+    # ! Kinematics-related features, i.e. speed and acceleration; these need
+    # history steps to be prepended to make the first evaluated step valid.
+    # Resulting `linear_speed` and others: (n_object_to_evaluate, n_future_step)
+    linear_speed, linear_accel, angular_speed, angular_accel = (
+        trajectory_features.compute_kinematic_features(
+            evaluate_trajectories.x,
+            evaluate_trajectories.y,
+            evaluate_trajectories.z,
+            evaluate_trajectories.heading,
+            seconds_per_step=submission_specs.STEP_DURATION_SECONDS))
+    # Removes the data corresponding to the history time interval.
+    linear_speed, linear_accel, angular_speed, angular_accel = (
+        map(lambda t: t[:, submission_specs.CURRENT_TIME_INDEX + 1:],
+            [linear_speed, linear_accel, angular_speed, angular_accel])
+    )
+
+    # ! Distances to nearest objects.
+    # evaluate_object_mask = torch.any(
+    #     evaluate_agent_ids[:, None] == simulated_trajectories.object_id, axis=0
+    # )
+    evaluate_object_mask = torch.ones(len(simulate_trajectories.object_id)).bool()
+    distances_to_objects = interact_features.compute_distance_to_nearest_object(
+        center_x=simulate_trajectories.x,
+        center_y=simulate_trajectories.y,
+        center_z=simulate_trajectories.z,
+        length=simulate_trajectories.length,
+        width=simulate_trajectories.width,
+        height=simulate_trajectories.height,
+        heading=simulate_trajectories.heading,
+        valid=simulate_trajectories.valid,
+        evaluated_object_mask=evaluate_object_mask,
+    )
+    distances_to_objects = (
+        distances_to_objects[:, submission_specs.CURRENT_TIME_INDEX + 1:])
+    is_colliding_per_step = torch.lt(
+        distances_to_objects, interact_features.COLLISION_DISTANCE_THRESHOLD)
+
+    # !
Time to collision + times_to_collision = ( + interact_features.compute_time_to_collision_with_object_in_front( + center_x=simulate_trajectories.x, + center_y=simulate_trajectories.y, + length=simulate_trajectories.length, + width=simulate_trajectories.width, + heading=simulate_trajectories.heading, + valid=simulate_trajectories.valid, + evaluated_object_mask=evaluate_object_mask, + seconds_per_step=submission_specs.STEP_DURATION_SECONDS, + ) + ) + times_to_collision = times_to_collision[:, submission_specs.CURRENT_TIME_INDEX + 1:] + + # ! Distance to road edge + distances_to_road_edge = torch.empty_like(distances_to_objects) + is_offroad_per_step = torch.empty_like(is_colliding_per_step) + if scenario_log is not None: + road_edges = [] + for map_feature in scenario_log.map_features: + if map_feature.HasField('road_edge'): + road_edges.append(map_feature.road_edge.polyline) + distances_to_road_edge = map_features.compute_distance_to_road_edge( + center_x=simulate_trajectories.x, + center_y=simulate_trajectories.y, + center_z=simulate_trajectories.z, + length=simulate_trajectories.length, + width=simulate_trajectories.width, + height=simulate_trajectories.height, + heading=simulate_trajectories.heading, + valid=simulate_trajectories.valid, + evaluated_object_mask=evaluate_object_mask, + road_edge_polylines=road_edges, + ) + distances_to_road_edge = distances_to_road_edge[:, submission_specs.CURRENT_TIME_INDEX + 1:] + is_offroad_per_step = torch.gt( + distances_to_road_edge, map_features.OFFROAD_DISTANCE_THRESHOLD + ) + + # ! Placement + if simulate_trajectories.av_id == simulate_trajectories.processed_av_id == -1: + n_agent, n_step_10hz = linear_speed.shape + num_placement = torch.zeros((n_step_10hz // SHIFT,)) + num_removement = torch.zeros((n_step_10hz // SHIFT,)) + distance_placement = torch.zeros((n_agent, n_step_10hz // SHIFT)) + distance_removement = torch.zeros((n_agent, n_step_10hz // SHIFT)) + + else: + assert simulate_trajectories.av_id == simulate_trajectories.processed_av_id, \ + f"Got duplicated av_id: {simulate_trajectories.av_id} and {simulate_trajectories.processed_av_id}" + num_placement, num_removement = ( + placement_features.compute_num_placement( + state=simulate_trajectories.state, + valid=simulate_trajectories.token_valid, + av_id=simulate_trajectories.processed_av_id, + object_id=simulate_trajectories.processed_object_id, + agent_state=AGENT_STATE, + ) + ) + num_placement = num_placement[submission_specs.CURRENT_TIME_INDEX // SHIFT:] + num_removement = num_removement[submission_specs.CURRENT_TIME_INDEX // SHIFT:] + distance_placement, distance_removement = ( + placement_features.compute_distance_placement( + position=simulate_trajectories.token_pos, + state=simulate_trajectories.state, + valid=simulate_trajectories.valid, + av_id=simulate_trajectories.processed_av_id, + object_id=simulate_trajectories.processed_object_id, + agent_state=AGENT_STATE, + ) + ) + distance_placement = distance_placement[:, submission_specs.CURRENT_TIME_INDEX // SHIFT:] + distance_removement = distance_removement[:, submission_specs.CURRENT_TIME_INDEX // SHIFT:] + # distance_placement = distance_placement[distance_placement > 0] + # distance_removement = distance_removement[distance_removement > 0] + + # print out some results for debugging + # CONSOLE.log(f'trajectory x: {simulate_trajectories.x.shape}, \n{simulate_trajectories.x}') + # CONSOLE.log(f'linear speed: {linear_speed.shape}, \n{linear_speed}') + # CONSOLE.log(f'distances: {distances_to_objects.shape}, \n{distances_to_objects}') + 
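# NOTE: placement features are defined at the token rate: `num_placement` /
+    # `num_removement` count agents entering/leaving the scene per token step,
+    # while `distance_placement` / `distance_removement` measure how far from
+    # the AV those insertions and removals happen; when no AV is available
+    # (av_id == -1) they fall back to the all-zero placeholders above.
+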
# CONSOLE.log(f'time to collision: {times_to_collision.shape}, {times_to_collision}') + + return MetricFeatures( + object_id=simulate_trajectories.object_id, + valid=validity_mask, + # kinematic + linear_speed=linear_speed, + linear_acceleration=linear_accel, + angular_speed=angular_speed, + angular_acceleration=angular_accel, + # interact + distance_to_nearest_object=distances_to_objects, + collision_per_step=is_colliding_per_step, + time_to_collision=times_to_collision, + # map + distance_to_road_edge=distances_to_road_edge, + offroad_per_step=is_offroad_per_step, + # placement + num_placement=num_placement[None, ...], + num_removement=num_removement[None, ...], + distance_placement=distance_placement, + distance_removement=distance_removement, + ) + + +@dataclass(frozen=True) +class LogDistributions: + + linear_speed: Tensor + linear_acceleration: Tensor + angular_speed: Tensor + angular_acceleration: Tensor + distance_to_nearest_object: Tensor + collision_indication: Tensor + time_to_collision: Tensor + distance_to_road_edge: Tensor + num_placement: Tensor + num_removement: Tensor + distance_placement: Tensor + distance_removement: Tensor + offroad_indication: Optional[Tensor] = None + + +""" Compute Metrics """ + +def _assert_and_return_batch_size( + log_samples: Tensor, + sim_samples: Tensor +) -> int: + """Asserts consistency in the tensor shapes and returns batch size. + + Args: + log_samples: A tensor of shape (batch_size, log_sample_size). + sim_samples: A tensor of shape (batch_size, sim_sample_size). + + Returns: + The `batch_size`. + """ + assert log_samples.shape[0] == sim_samples.shape[0], "Log and Sim batch sizes must be equal." + return log_samples.shape[0] + + +def _reduce_average_with_validity( + tensor: Tensor, validity: Tensor) -> Tensor: + """Returns the tensor's average, only selecting valid items. + + Args: + tensor: A float tensor of any shape. + validity: A boolean tensor of the same shape as `tensor`. + + Returns: + A float tensor of shape (1,), containing the average of the valid elements + of `tensor`. + """ + if tensor.shape != validity.shape: + raise ValueError('Shapes of `tensor` and `validity` must be the same.' + f'(Actual: {tensor.shape}, {validity.shape}).') + cond_sum = torch.sum(torch.where(validity, tensor, torch.zeros_like(tensor))) + valid_sum = torch.sum(validity) + if valid_sum == 0: + return torch.tensor(0.) + return cond_sum / valid_sum + + +def histogram_estimate( + config: long_metrics_pb2.SimAgentMetricsConfig.HistogramEstimate, # type: ignore + log_samples: Tensor, + sim_samples: Tensor, +) -> Tensor: + """Computes log-likelihoods of samples based on histograms. + + Args: + config: A configuration dictionary, similar to the one in TensorFlow. + log_samples: A tensor of shape (batch_size, log_sample_size), + containing `log_sample_size` samples from `batch_size` independent + populations. + sim_samples: A tensor of shape (batch_size, sim_sample_size), + containing `sim_sample_size` samples from `batch_size` independent + populations. + + Returns: + A tensor of shape (batch_size, log_sample_size), where each element (i, k) + is the log likelihood of the log sample (i, k) under the sim distribution + (i). + """ + batch_size = _assert_and_return_batch_size(log_samples, sim_samples) + + # We generate `num_bins`+1 edges for the histogram buckets. + edges = torch.linspace( + config.min_val, config.max_val, config.num_bins + 1 + ).float() + + # Clip the samples to avoid errors with histograms. 
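+    # NOTE: clamping maps out-of-range samples into the first/last bucket, so
+    # every sample lands in some bin: e.g. with min_val=0, max_val=1 and
+    # num_bins=2, a sample of 1.7 is clamped to 1.0 and counted in the top bin.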
+    log_samples = torch.clamp(log_samples, config.min_val, config.max_val)
+    sim_samples = torch.clamp(sim_samples, config.min_val, config.max_val)
+
+    # Create the categorical distribution for simulation. `torch.histogram` is
+    # applied per row via `torch.vmap`, directly yielding `probs` of shape
+    # (batch_size, num_bins) as required by `torch.distributions.Categorical`.
+    sim_counts = torch.vmap(lambda x: torch.histogram(x, bins=edges).hist)(sim_samples)
+    sim_counts += config.additive_smoothing_pseudocount
+    distributions = torch.distributions.Categorical(probs=sim_counts)
+
+    # Generate the counts for the log distribution. We reshape the log samples to
+    # (batch_size * log_sample_size, 1), so every log sample is independently
+    # scored.
+    log_values_flat = log_samples.reshape(-1, 1)
+    # Shape of log_counts: (batch_size * log_sample_size, num_bins).
+    log_counts = torch.vmap(lambda x: torch.histogram(x, bins=edges).hist)(log_values_flat)
+    # Identify which bin each sample belongs to and get the log probability of
+    # that bin under the sim distribution.
+    max_log_bin = log_counts.argmax(dim=-1)
+    batched_max_log_bin = max_log_bin.reshape(batch_size, -1)
+
+    # Since the categorical distribution has `batch_size` independent
+    # populations, `Categorical.log_prob` expects this `batch_size` to be in
+    # the last dimension of the sample tensor, so transpose the log bins to
+    # (log_sample_size, batch_size).
+    log_likelihood = distributions.log_prob(batched_max_log_bin.transpose(0, 1))
+
+    # Return log likelihood in the shape (batch_size, log_sample_size)
+    return log_likelihood.transpose(0, 1)
+
+
+def log_likelihood_estimate_timeseries(
+    field: str,
+    feature_config: long_metrics_pb2.SimAgentMetricsConfig.FeatureConfig,  # type: ignore
+    sim_values: Tensor,
+    log_distributions: torch.distributions.Categorical,
+    estimate_method: str='histogram',
+) -> Tensor:
+    """Computes the log-likelihood estimates for a time-series simulated feature.
+
+    Args:
+        field: Name of the feature, used for error messages.
+        feature_config: A time-series compatible `FeatureConfig`.
+        sim_values: A float Tensor with shape (n_objects / n_scenarios, n_segments, n_steps).
+        log_distributions: A `torch.distributions.Categorical` over histogram
+            bins, built from the logged feature.
+        estimate_method: Either 'histogram' or 'bernoulli' (a 2-bin histogram).
+
+    Returns:
+        A tensor of shape (n_objects, n_segments, n_steps) containing the
+        log-likelihood estimates of the simulated feature under the logged
+        distribution of the same feature.
+    """
+    assert sim_values.ndim == 3, f'Expect sim_values.ndim==3, got {sim_values.ndim}, shape {sim_values.shape} for {field}'
+
+    sim_values_flat = sim_values.reshape(-1, 1)  # [n_objects * n_segments * n_steps, 1]
+
+    # if not feature_config.independent_timesteps:
+    #     # If time steps needs to be considered independent, reshape:
+    #     # - `sim_values` as (n_objects, n_rollouts * n_steps)
+    #     # - `log_values` as (n_objects, n_steps)
+    #     # If values in time are instead to be compared per-step, reshape:
+    #     # - `sim_values` as (n_objects * n_steps, n_rollouts)
+    #     # - `log_values` as (n_objects * n_steps, 1)
+    #     sim_values = sim_values.reshape(-1, 1)  # n_rollouts=1
+
+    # if feature_config.independent_timesteps:
+    #     sim_values = sim_values.permute(1, 0, 2).reshape(n_objects, n_rollouts * n_steps)
+    # else:
+    #     sim_values = sim_values.permute(1, 2, 0).reshape(n_objects * n_steps, n_rollouts)
+    #     log_values = log_values.reshape(n_objects * n_steps, 1)
+
+    # ! calculate distributions for simulated features
+    if estimate_method == 'histogram':
+        config = feature_config.histogram
+    elif estimate_method == 'bernoulli':
+        config = (
+            long_metrics_pb2.SimAgentMetricsConfig.HistogramEstimate(
+                min_val=-0.5, max_val=0.5, num_bins=2,
+                additive_smoothing_pseudocount=feature_config.bernoulli.additive_smoothing_pseudocount
+            )
+        )
+        sim_values_flat = sim_values_flat.float()  # cast torch.bool to torch.float32
+
+    # We generate `num_bins`+1 edges for the histogram buckets.
+    edges = torch.linspace(
+        config.min_val, config.max_val, config.num_bins + 1
+    ).float()
+
+    sim_counts = torch.vmap(lambda x: torch.histogram(x, bins=edges).hist)(sim_values_flat)  # [batch_size, num_bins]
+    # Identify which bin each sample belongs to and get the log probability of
+    # that bin under the sim distribution.
+    max_sim_bin = sim_counts.argmax(dim=-1)
+    batched_max_sim_bin = max_sim_bin.reshape(1, -1)  # `batch_size` = 1, follows the log distributions
+
+    sim_likelihood = log_distributions.log_prob(batched_max_sim_bin.transpose(0, 1)).flatten()
+    return sim_likelihood.reshape(*sim_values.shape)  # [n_objects, n_segments, n_steps]
+
+
+def compute_scenario_metrics_for_bundle(
+    config: long_metrics_pb2.SimAgentMetricsConfig,  # type: ignore
+    log_distributions: LogDistributions,
+    scenario_log: Optional[scenario_pb2.Scenario],  # type: ignore
+    scenario_rollouts: ScenarioRollouts,
+) -> long_metrics_pb2.SimAgentMetrics:  # type: ignore
+
+    features_fields = [field.name for field in dataclasses.fields(MetricFeatures)]
+    features_fields.remove('object_id')
+
+    # ! compute simulation features
+    # CONSOLE.log('[on yellow] Compute sim features [/]')
+    sim_features = collections.defaultdict(list)
+    for simulate_trajectories in tqdm(scenario_rollouts.joint_scenes, leave=False, desc='rollouts ...'):  # n_rollout=1
+        rollout_features = compute_metric_features(
+            simulate_trajectories,
+            evaluate_agent_ids=None,
+            scenario_log=scenario_log
+        )
+
+        for field in features_fields:
+            sim_features[field].append(getattr(rollout_features, field))
+
+    for field in features_fields:
+        if sim_features[field][0] is not None:
+            sim_features[field] = torch.concat(sim_features[field], dim=0)  # n_rollout for dim=0
+
+    sim_features = MetricFeatures(
+        **sim_features, object_id=None,
+    )
+    # after unfold: linear_speed shape [n_agent, n_window, window_size],
+    # num_placement shape [n_scenario=1, n_window, window_size]
+    flattened_sim_features = copy.deepcopy(sim_features)
+    sim_features = sim_features.unfold(size=submission_specs.N_SIMULATION_STEPS, step=SHIFT)
+    # CONSOLE.log(f'sim linear_speed feature: {sim_features.linear_speed.shape}')
+    # CONSOLE.log(f'sim num_placement feature: {sim_features.num_placement.shape}')
+
+    ## ! compute metrics
+
+    # !
kinematics-related metrics + linear_speed_log_likelihood = log_likelihood_estimate_timeseries( + field='linear_speed', + feature_config=config.linear_speed, + sim_values=sim_features.linear_speed, + log_distributions=log_distributions.linear_speed, + ) + angular_speed_log_likelihood = log_likelihood_estimate_timeseries( + field='angular_speed', + feature_config=config.angular_speed, + sim_values=sim_features.angular_speed, + log_distributions=log_distributions.angular_speed, + ) + speed_validity, acceleration_validity = ( + trajectory_features.compute_kinematic_validity(flattened_sim_features.valid) + ) + speed_validity = speed_validity.unfold(1, size=submission_specs.N_SIMULATION_STEPS, step=SHIFT) + acceleration_validity = acceleration_validity.unfold(1, size=submission_specs.N_SIMULATION_STEPS, step=SHIFT) + linear_speed_likelihood = torch.exp(_reduce_average_with_validity( + linear_speed_log_likelihood, speed_validity)) + angular_speed_likelihood = torch.exp(_reduce_average_with_validity( + angular_speed_log_likelihood, speed_validity)) + # CONSOLE.log(f'linear_speed_likelihood: {linear_speed_likelihood}') + # CONSOLE.log(f'angular_speed_likelihood: {angular_speed_likelihood}') + + linear_accel_log_likelihood = log_likelihood_estimate_timeseries( + field='linear_acceleration', + feature_config=config.linear_acceleration, + sim_values=sim_features.linear_acceleration, + log_distributions=log_distributions.linear_acceleration, + ) + angular_accel_log_likelihood = log_likelihood_estimate_timeseries( + field='angular_acceleration', + feature_config=config.angular_acceleration, + sim_values=sim_features.angular_acceleration, + log_distributions=log_distributions.angular_acceleration, + ) + linear_accel_likelihood = torch.exp(_reduce_average_with_validity( + linear_accel_log_likelihood, acceleration_validity)) + angular_accel_likelihood = torch.exp(_reduce_average_with_validity( + angular_accel_log_likelihood, acceleration_validity)) + # CONSOLE.log(f'linear_accel_likelihood: {linear_accel_likelihood}') + # CONSOLE.log(f'angular_accel_likelihood: {angular_accel_likelihood}') + + # ! collision and distance to other objects. 
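+    # NOTE: collision is scored as a per-object indication rather than as a
+    # time series: any collision inside the window flips the boolean, which is
+    # then scored with the 'bernoulli' estimate method, i.e. the two-bin
+    # ({False, True}) histogram on [-0.5, 0.5] built in
+    # `log_likelihood_estimate_timeseries` above.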
+
+    sim_collision_indication = torch.any(
+        torch.where(sim_features.valid, sim_features.collision_per_step, False),
+        dim=2)[..., None]  # add a dummy time dimension
+    collision_score = log_likelihood_estimate_timeseries(
+        field='collision_indication',
+        feature_config=config.collision_indication,
+        sim_values=sim_collision_indication,
+        log_distributions=log_distributions.collision_indication,
+        estimate_method='bernoulli',
+    )
+    collision_likelihood = torch.exp(torch.mean(collision_score))
+
+    distance_to_objects_log_likelihood = log_likelihood_estimate_timeseries(
+        field='distance_to_nearest_object',
+        feature_config=config.distance_to_nearest_object,
+        sim_values=sim_features.distance_to_nearest_object,
+        log_distributions=log_distributions.distance_to_nearest_object,
+    )
+    distance_to_objects_likelihood = torch.exp(_reduce_average_with_validity(
+        distance_to_objects_log_likelihood, sim_features.valid))
+    # CONSOLE.log(f'distance_to_objects_likelihood: {distance_to_objects_likelihood}')
+
+    ttc_log_likelihood = log_likelihood_estimate_timeseries(
+        field='time_to_collision',
+        feature_config=config.time_to_collision,
+        sim_values=sim_features.time_to_collision,
+        log_distributions=log_distributions.time_to_collision,
+    )
+    ttc_likelihood = torch.exp(_reduce_average_with_validity(
+        ttc_log_likelihood, sim_features.valid))
+    # CONSOLE.log(f'ttc_likelihood: {ttc_likelihood}')
+
+    # ! offroad and distance to road edge.
+
+    # distance_to_road_edge_log_likelihood = log_likelihood_estimate_timeseries(
+    #     field='distance_to_road_edge',
+    #     sim_values=sim_features.distance_to_road_edge,
+    #     log_distributions=log_distributions.distance_to_road_edge,
+    # )
+    # distance_to_road_edge_likelihood = torch.exp(_reduce_average_with_validity(
+    #     distance_to_road_edge_log_likelihood, sim_features.valid))
+    # CONSOLE.log(f'distance_to_road_edge_likelihood: {distance_to_road_edge_likelihood}')
+
+    # !
placement + + num_placement_log_likelihood = log_likelihood_estimate_timeseries( + field='num_placement', + feature_config=config.num_placement, + sim_values=sim_features.num_placement.float(), + log_distributions=log_distributions.num_placement, + ) + num_placement_likelihood = torch.exp(torch.mean(num_placement_log_likelihood)) + num_removement_log_likelihood = log_likelihood_estimate_timeseries( + field='num_removement', + feature_config=config.num_removement, + sim_values=sim_features.num_removement.float(), + log_distributions=log_distributions.num_removement, + ) + num_removement_likelihood = torch.exp(torch.mean(num_removement_log_likelihood)) + # CONSOLE.log(f'num_placement_likelihood: {num_placement_likelihood}') + # CONSOLE.log(f'num_removement_likelihood: {num_removement_likelihood}') + + # tensor([[0.0013, 0.0078, 0.0194, 0.0373, 0.0628, 0.0938, 0.1232, 0.1470, 0.1701, + # 0.3371]]) + # tensor([[0.0201, 0.0570, 0.0689, 0.0839, 0.1029, 0.1172, 0.1282, 0.1286, 0.1237, + # 0.1695]]) + distance_placement_log_likelihood = log_likelihood_estimate_timeseries( + field='distance_placement', + feature_config=config.distance_placement, + sim_values=sim_features.distance_placement, + log_distributions=log_distributions.distance_placement, + ) + distance_placement_validity = ( + (sim_features.distance_placement > config.distance_placement.histogram.min_val) & + (sim_features.distance_placement < config.distance_placement.histogram.max_val) + ) + distance_placement_likelihood = torch.exp(_reduce_average_with_validity( + distance_placement_log_likelihood, distance_placement_validity)) + distance_removement_log_likelihood = log_likelihood_estimate_timeseries( + field='distance_removement', + feature_config=config.distance_removement, + sim_values=sim_features.distance_removement, + log_distributions=log_distributions.distance_removement, + ) + distance_removement_validity = ( + (sim_features.distance_removement > config.distance_removement.histogram.min_val) & + (sim_features.distance_removement < config.distance_removement.histogram.max_val) + ) + distance_removement_likelihood = torch.exp(_reduce_average_with_validity( + distance_removement_log_likelihood, distance_removement_validity)) + + # ==== Simulated collision and offroad rates ==== + simulated_collision_rate = torch.sum( + sim_collision_indication.long() + ) / torch.sum(torch.ones_like(sim_collision_indication).long()) + # simulated_offroad_rate = tf.reduce_sum( + # # `sim_offroad_indication` shape: (n_samples, n_objects). 
+    #     tf.cast(sim_offroad_indication, tf.int32)
+    # ) / tf.reduce_sum(tf.ones_like(sim_offroad_indication, dtype=tf.int32))
+
+    # ==== Meta metric ====
+    likelihood_metrics = {
+        'linear_speed_likelihood': float(linear_speed_likelihood.numpy()),
+        'linear_acceleration_likelihood': float(linear_accel_likelihood.numpy()),
+        'angular_speed_likelihood': float(angular_speed_likelihood.numpy()),
+        'angular_acceleration_likelihood': float(angular_accel_likelihood.numpy()),
+        'distance_to_nearest_object_likelihood': float(distance_to_objects_likelihood.numpy()),
+        'collision_indication_likelihood': float(collision_likelihood.numpy()),
+        'time_to_collision_likelihood': float(ttc_likelihood.numpy()),
+        # 'distance_to_road_edge_likelihood': float(distance_to_road_edge_likelihood.numpy()),
+        # 'offroad_indication_likelihood': float(offroad_likelihood.numpy()),
+        'num_placement_likelihood': float(num_placement_likelihood.numpy()),
+        'num_removement_likelihood': float(num_removement_likelihood.numpy()),
+        'distance_placement_likelihood': float(distance_placement_likelihood.numpy()),
+        'distance_removement_likelihood': float(distance_removement_likelihood.numpy()),
+    }
+
+    metametric = _compute_metametric(
+        config, long_metrics_pb2.SimAgentMetrics(**likelihood_metrics)
+    )
+    # CONSOLE.log(f'metametric: {metametric}')
+
+    return long_metrics_pb2.SimAgentMetrics(
+        scenario_id=scenario_rollouts.scenario_id,
+        metametric=metametric,
+        simulated_collision_rate=float(simulated_collision_rate.numpy()),
+        # simulated_offroad_rate=simulated_offroad_rate.numpy(),
+        **likelihood_metrics,
+    )
+
+
+""" Log Features """
+
+def _get_log_distributions(
+    field: str,
+    feature_config: long_metrics_pb2.SimAgentMetricsConfig.FeatureConfig,  # type: ignore
+    log_values: Tensor,
+    estimate_method: str = 'histogram',
+) -> torch.distributions.Categorical:
+    """Builds the reference (logged) distribution for a single feature.
+
+    Args:
+        field: Name of the feature; `distance_placement` / `distance_removement`
+            additionally drop out-of-range samples before binning.
+        feature_config: A time-series compatible `FeatureConfig`.
+        log_values: A float Tensor with shape (n_objects, n_steps).
+        estimate_method: Either 'histogram' or 'bernoulli' (a 2-bin histogram).
+
+    Returns:
+        A `torch.distributions.Categorical` over the feature's histogram bins
+        whose `probs` are the smoothed bin counts of the logged values.
+    """
+    assert log_values.ndim == 2, f'Expect log_values.ndim==2, got {log_values.ndim}, shape {log_values.shape} for {field}'
+
+    # [n_objects, n_steps] -> [n_objects * n_steps]
+    log_samples = log_values.reshape(-1)
+
+    # ! estimate
+    if estimate_method == 'histogram':
+        config = feature_config.histogram
+    elif estimate_method == 'bernoulli':
+        config = (
+            long_metrics_pb2.SimAgentMetricsConfig.HistogramEstimate(
+                min_val=-0.5, max_val=0.5, num_bins=2,
+                additive_smoothing_pseudocount=feature_config.bernoulli.additive_smoothing_pseudocount
+            )
+        )
+        log_samples = log_samples.float()  # cast torch.bool to torch.float32
+
+    # We generate `num_bins`+1 edges for the histogram buckets.
+    edges = torch.linspace(
+        config.min_val, config.max_val, config.num_bins + 1
+    ).float()
+
+    if field in ('distance_placement', 'distance_removement'):
+        log_samples = log_samples[(log_samples > config.min_val) & (log_samples < config.max_val)]
+
+    # Clip the samples to avoid errors with histograms. Nonetheless, the min/max
+    # values should be configured to never hit this condition in practice.
+    log_samples = torch.clamp(log_samples, config.min_val, config.max_val)
+
+    # Create the categorical distribution from the logged counts.
+    # `torch.histogram` returns per-bin counts; `unsqueeze` turns them into
+    # `probs` of shape (1, num_bins), i.e. a single batched population, as
+    # required by `torch.distributions.Categorical`.
+    log_counts = torch.histogram(log_samples, bins=edges).hist.unsqueeze(dim=0)  # [1, num_bins]
+    log_counts += config.additive_smoothing_pseudocount
+    distributions = torch.distributions.Categorical(probs=log_counts)
+
+    return distributions
+
+
+class LongMetric(Metric):
+
+    log_features: MetricFeatures
+
+    def __init__(
+        self,
+        prefix: str='',
+        log_features_dir: str='data/waymo_processed/log_features/',
+        config_path: str='dev/metrics/metric_config.textproto',
+    ) -> None:
+        super().__init__()
+        self.prefix = prefix
+        self.metrics_config = self.load_metrics_config(config_path)
+
+        self.use_log = False
+
+        self.field_names = [
+            "metametric",
+            "average_displacement_error",
+            "min_average_displacement_error",
+            "linear_speed_likelihood",
+            "linear_acceleration_likelihood",
+            "angular_speed_likelihood",
+            "angular_acceleration_likelihood",
+            'distance_to_nearest_object_likelihood',
+            'collision_indication_likelihood',
+            'time_to_collision_likelihood',
+            # 'distance_to_road_edge_likelihood',
+            # 'offroad_indication_likelihood',
+            'simulated_collision_rate',
+            # 'simulated_offroad_rate',
+            'num_placement_likelihood',
+            'num_removement_likelihood',
+            'distance_placement_likelihood',
+            'distance_removement_likelihood',
+        ]
+        for k in self.field_names:
+            self.add_state(k, default=torch.tensor(0.), dist_reduce_fx='sum')
+        self.add_state('scenario_counter', default=torch.tensor(0.), dist_reduce_fx='sum')
+        self.add_state('placement_valid_scenario_counter', default=torch.tensor(0.), dist_reduce_fx='sum')
+        self.add_state('removement_valid_scenario_counter', default=torch.tensor(0.), dist_reduce_fx='sum')
+
+        # get log features
+        log_features_path = os.path.join(log_features_dir, 'total_features.pkl')
+        if not os.path.exists(log_features_path):
+            CONSOLE.log('[on yellow] Log features file does not exist, aggregating now ... 
[/]') + log_features = aggregate_log_metric_features(log_features_dir) + else: + log_features = MetricFeatures.from_file(log_features_path) + CONSOLE.log(f'Loaded log features from {log_features_path}') + self.log_features = log_features + + self._compute_distributions() + CONSOLE.log(f"Calculated log distributions:\n{self.log_distributions}") + + def _compute_distributions(self): + self.log_distributions = LogDistributions( + linear_speed = _get_log_distributions('linear_speed', + self.metrics_config.linear_speed, self.log_features.linear_speed, + ), + linear_acceleration = _get_log_distributions('linear_acceleration', + self.metrics_config.linear_acceleration, self.log_features.linear_acceleration, + ), + angular_speed = _get_log_distributions('angular_speed', + self.metrics_config.angular_speed, self.log_features.angular_speed, + ), + angular_acceleration = _get_log_distributions('angular_acceleration', + self.metrics_config.angular_acceleration, self.log_features.angular_acceleration, + ), + distance_to_nearest_object = _get_log_distributions('distance_to_nearest_object', + self.metrics_config.distance_to_nearest_object, self.log_features.distance_to_nearest_object, + ), + collision_indication = _get_log_distributions('collision_indication', + self.metrics_config.collision_indication, + log_collision_indication := torch.any( + torch.where(self.log_features.valid, self.log_features.collision_per_step, False), dim=1 + )[..., None], # add a dummy time dimension + estimate_method = 'bernoulli', + ), + time_to_collision = _get_log_distributions('time_to_collision', + self.metrics_config.time_to_collision, self.log_features.time_to_collision, + ), + distance_to_road_edge = _get_log_distributions('distance_to_road_edge', + self.metrics_config.distance_to_road_edge, self.log_features.distance_to_road_edge, + ), + # dist_offroad_indication = _get_log_distributions( + # 'offroad_indication', + # self.metrics_config.offroad_indication, + # log_offroad_indication := torch.any( + # torch.where(self.log_features.valid, self.log_features.offroad_per_step, False), dim=1 + # ), + # ), + num_placement = _get_log_distributions('num_placement', + self.metrics_config.num_placement, self.log_features.num_placement.float(), + ), + num_removement = _get_log_distributions('num_removement', + self.metrics_config.num_removement, self.log_features.num_removement.float(), + ), + distance_placement = _get_log_distributions('distance_placement', + self.metrics_config.distance_placement, ( + self.log_features.distance_placement[self.log_features.distance_placement > 0])[None, ...], + ), + distance_removement = _get_log_distributions('distance_removement', + self.metrics_config.distance_removement, ( + self.log_features.distance_removement[self.log_features.distance_removement > 0])[None, ...], + ), + ) + + def _compute_scenario_metrics( + self, + scenario_file: Optional[str], + scenario_rollout: ScenarioRollouts, + ) -> long_metrics_pb2.SimAgentMetrics: # type: ignore + + scenario_log = None + if self.use_log and scenario_file is not None: + if not os.path.exists(scenario_file): + raise FileNotFoundError(f"Not found file {scenario_file}") + scenario_log = scenario_pb2.Scenario() + for data in tf.data.TFRecordDataset([scenario_file], compression_type=''): + scenario_log.ParseFromString(bytes(data.numpy())) + break + + return compute_scenario_metrics_for_bundle( + self.metrics_config, self.log_distributions, scenario_log, scenario_rollout + ) + + def compute_metrics(self, outputs: dict) -> 
List[long_metrics_pb2.SimAgentMetrics]:  # type: ignore
+        """
+        `outputs` is a dict directly generated by predict models:
+        >>> outputs = dict(
+        >>>     scenario_id=get_scenario_id_int_tensor(data['scenario_id'], device),
+        >>>     agent_id=agent_id,
+        >>>     agent_batch=agent_batch,
+        >>>     pred_traj=pred_traj,
+        >>>     pred_z=pred_z,
+        >>>     pred_head=pred_head,
+        >>>     pred_shape=pred_shape,
+        >>>     pred_type=pred_type,
+        >>>     pred_state=pred_state,
+        >>> )
+        """
+
+        scenario_rollouts = output_to_rollouts(outputs)
+        log_paths: List[str] = outputs['tfrecord_path']
+
+        pool_scenario_metrics = []
+        for _scenario_file, _scenario_rollout in tqdm(
+                zip(log_paths, scenario_rollouts), leave=False, desc='scenarios ...'):  # n_scenarios
+            pool_scenario_metrics.append(
+                self._compute_scenario_metrics(
+                    _scenario_file, _scenario_rollout,
+                )
+            )
+
+        return pool_scenario_metrics
+
+    def update(
+        self,
+        outputs: Optional[dict]=None,
+        metrics: Optional[List[long_metrics_pb2.SimAgentMetrics]]=None  # type: ignore
+    ) -> None:
+
+        if metrics is None:
+            assert outputs is not None, '`outputs` should not be None!'
+            metrics = self.compute_metrics(outputs)
+
+        for scenario_metrics in metrics:
+            self.scenario_counter += 1
+
+            self.metametric += scenario_metrics.metametric
+            self.average_displacement_error += (
+                scenario_metrics.average_displacement_error
+            )
+            self.min_average_displacement_error += (
+                scenario_metrics.min_average_displacement_error
+            )
+            self.linear_speed_likelihood += scenario_metrics.linear_speed_likelihood
+            self.linear_acceleration_likelihood += (
+                scenario_metrics.linear_acceleration_likelihood
+            )
+            self.angular_speed_likelihood += scenario_metrics.angular_speed_likelihood
+            self.angular_acceleration_likelihood += (
+                scenario_metrics.angular_acceleration_likelihood
+            )
+            self.distance_to_nearest_object_likelihood += (
+                scenario_metrics.distance_to_nearest_object_likelihood
+            )
+            self.collision_indication_likelihood += (
+                scenario_metrics.collision_indication_likelihood
+            )
+            self.time_to_collision_likelihood += (
+                scenario_metrics.time_to_collision_likelihood
+            )
+            # self.distance_to_road_edge_likelihood += (
+            #     scenario_metrics.distance_to_road_edge_likelihood
+            # )
+            # self.offroad_indication_likelihood += (
+            #     scenario_metrics.offroad_indication_likelihood
+            # )
+            self.simulated_collision_rate += scenario_metrics.simulated_collision_rate
+            # self.simulated_offroad_rate += scenario_metrics.simulated_offroad_rate
+
+            self.num_placement_likelihood += (
+                scenario_metrics.num_placement_likelihood
+            )
+            self.num_removement_likelihood += (
+                scenario_metrics.num_removement_likelihood
+            )
+            self.distance_placement_likelihood += (
+                scenario_metrics.distance_placement_likelihood
+            )
+            self.distance_removement_likelihood += (
+                scenario_metrics.distance_removement_likelihood
+            )
+
+            if scenario_metrics.distance_placement_likelihood > 0:
+                self.placement_valid_scenario_counter += 1
+
+            if scenario_metrics.distance_removement_likelihood > 0:
+                self.removement_valid_scenario_counter += 1
+
+    def compute(self) -> Dict[str, Tensor]:
+        metrics_dict = {}
+        for k in self.field_names:
+            # placement/removement distance likelihoods are averaged only over
+            # the scenarios in which they were actually observed.
+            if k == 'distance_placement_likelihood':
+                metrics_dict[k] = getattr(self, k) / self.placement_valid_scenario_counter
+            elif k == 'distance_removement_likelihood':
+                metrics_dict[k] = getattr(self, k) / self.removement_valid_scenario_counter
+            else:
+                metrics_dict[k] = getattr(self, k) / self.scenario_counter
+
+        mean_metrics = long_metrics_pb2.SimAgentMetrics(
+            scenario_id='', **metrics_dict,
+        )
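+        # NOTE: every state above was registered with dist_reduce_fx='sum', so
+        # in distributed runs torchmetrics all-reduces the per-process sums
+        # before `compute()`; the values in `metrics_dict` are therefore
+        # per-scenario means regardless of how scenarios were sharded.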
+ final_metrics = self.aggregate_metrics_to_buckets( + self.metrics_config, mean_metrics + ) + CONSOLE.log(f'final_metrics:\n{final_metrics}') + + out_dict = { + f"{self.prefix}/wosac/realism_meta_metric": final_metrics.realism_meta_metric, + f"{self.prefix}/wosac/kinematic_metrics": final_metrics.kinematic_metrics, + f"{self.prefix}/wosac/interactive_metrics": final_metrics.interactive_metrics, + f"{self.prefix}/wosac/map_based_metrics": final_metrics.map_based_metrics, + f"{self.prefix}/wosac/placement_based_metrics": final_metrics.placement_based_metrics, + f"{self.prefix}/wosac/min_ade": final_metrics.min_ade, + f"{self.prefix}/wosac/scenario_counter": int(self.scenario_counter), + } + for k in self.field_names: + out_dict[f"{self.prefix}/wosac_likelihood/{k}"] = float(metrics_dict[k]) + + return out_dict + + @staticmethod + def aggregate_metrics_to_buckets( + config: long_metrics_pb2.SimAgentMetricsConfig, # type: ignore + metrics: long_metrics_pb2.SimAgentMetrics # type: ignore + ) -> long_metrics_pb2.SimAgentsBucketedMetrics: # type: ignore + """Aggregates metrics into buckets for better readability.""" + bucketed_metrics = {} + for bucket_name, fields_in_bucket in _METRIC_FIELD_NAMES_BY_BUCKET.items(): + weighted_metric, weights_sum = 0.0, 0.0 + for field_name in fields_in_bucket: + likelihood_field_name = field_name + '_likelihood' + weight = getattr(config, field_name).metametric_weight + metric_score = getattr(metrics, likelihood_field_name) + weighted_metric += weight * metric_score + weights_sum += weight + if weights_sum == 0: + weights_sum = 1 # FIXME: hack!!! + # raise ValueError('The bucket\'s weight sum is zero. Check your metrics' + # ' config.') + bucketed_metrics[bucket_name] = weighted_metric / weights_sum + + return long_metrics_pb2.SimAgentsBucketedMetrics( + realism_meta_metric=metrics.metametric, + kinematic_metrics=bucketed_metrics['kinematic'], + interactive_metrics=bucketed_metrics['interactive'], + map_based_metrics=bucketed_metrics['map_based'], + placement_based_metrics=bucketed_metrics['placement_based'], + min_ade=metrics.min_average_displacement_error, + simulated_collision_rate=metrics.simulated_collision_rate, + simulated_offroad_rate=metrics.simulated_offroad_rate, + ) + + @staticmethod + def load_metrics_config(config_path: str = 'dev/metrics/metric_config.textproto', + ) -> long_metrics_pb2.SimAgentMetricsConfig: # type: ignore + config = long_metrics_pb2.SimAgentMetricsConfig() + with open(config_path, 'r') as f: + text_format.Parse(f.read(), config) + return config + + def dumps(self, dir): + from datetime import datetime + + timestamp = datetime.now().strftime("%m_%d_%H%M%S") + + results = self.compute() + path = os.path.join(dir, f'{self.prefix}_{timestamp}.json') + with open(path, 'w', encoding='utf-8') as f: + json.dump(results, f, indent=4) + + CONSOLE.log(f'Saved results to [bold][yellow]{path}') + + +""" Preprocess Methods """ + +def _dump_log_metric_features( + pkl_dir: str, + tfrecords_dir: str, + save_dir: str, + transform: WaymoTargetBuilder, + token_processor: TokenProcessor, + scenario_id: str, + ): + + try: + + tqdm.write(f'Processing scenario {scenario_id}') + save_path = os.path.join(save_dir, f'{scenario_id}.pkl') + if os.path.exists(save_path): + return + + # load gt data + pkl_file = os.path.join(pkl_dir, f'{scenario_id}.pkl') + if not os.path.exists(pkl_file): + raise FileNotFoundError(f"Not found file {pkl_file}") + tfrecord_file = os.path.join(tfrecords_dir, f'{scenario_id}.tfrecords') + if not os.path.exists(tfrecord_file): + 
raise FileNotFoundError(f"Not found file {tfrecord_file}") + + scenario_log = scenario_pb2.Scenario() + for data in tf.data.TFRecordDataset([tfrecord_file], compression_type=''): + scenario_log.ParseFromString(bytes(data.numpy())) + break + + with open(pkl_file, 'rb') as f: + log_data = pickle.load(f) + + # preprocess data + log_data = transform._score_trained_agents(log_data) # get `train_mask` + log_data = token_processor._tokenize_agent(log_data) + + # convert to `JointScene` and compute features + log_trajectories = scenario_to_trajectories(scenario_log, processed_scenario=log_data) + # log_trajectories = ObjectTrajectories.init_from_processed_scenario(data) + + # NOTE: we do not consider the `evaluation_agent_ids` here + # evaluate_agent_ids = torch.tensor( + # get_evaluation_agent_ids(scenario_log) + # ) + evaluate_agent_ids = None + log_features = compute_metric_features( + log_trajectories, evaluate_agent_ids=evaluate_agent_ids, #scenario_log=scenario_log, + ) + + # save to pkl file + with open(save_path, 'wb') as f: + pickle.dump(log_features, f) + + except Exception as e: + CONSOLE.log(f'[on red] Failed to process scenario {scenario_id} due to {e}.[/]') + return + + +def dump_log_metric_features(log_dir, save_dir): + + buffer_size = 128 + + # file loaders + pkl_dir = os.path.join(log_dir, 'validation') + if not os.path.exists(pkl_dir): + raise RuntimeError(f'Not found folder {pkl_dir}') + tfrecords_dir = os.path.join(log_dir, 'validation_tfrecords_splitted') + if not os.path.exists(tfrecords_dir): + raise RuntimeError(f'Not found folder {tfrecords_dir}') + + files = list(fnmatch.filter(os.listdir(pkl_dir), '*.pkl')) + json_path = os.path.join(log_dir, 'meta_infos.json') + meta_infos = json.load(open(json_path, 'r', encoding='utf-8'))['validation'] + CONSOLE.log(f"Loaded meta infos from {json_path}") + available_scenarios = list(meta_infos.keys()) + df = pd.DataFrame.from_dict(meta_infos, orient='index') + available_scenarios_set = set(available_scenarios) + df_filtered = df[(df.index.isin(available_scenarios_set)) & (df['num_agents'] >= 8) & (df['num_agents'] < buffer_size)] + valid_scenarios = set(df_filtered.index) + files = list(tqdm(filter(lambda fn: fn.removesuffix('.pkl') in valid_scenarios, files), leave=False)) + + scenario_ids = list(map(lambda fn: fn.removesuffix('.pkl'), files)) + CONSOLE.log(f'Loaded {len(scenario_ids)} scenarios from validation split.') + + # initialize + transform = WaymoTargetBuilder(num_historical_steps=11, + num_future_steps=80, + max_num=32) + + token_processor = TokenProcessor(token_size=2048, + state_token={'invalid': 0, 'valid': 1, 'enter': 2, 'exit': 3}, + pl2seed_radius=75) + + partial_dump_gt_metric_features = partial( + _dump_log_metric_features, pkl_dir, tfrecords_dir, save_dir, transform, token_processor) + + for scenario_id in tqdm(scenario_ids, leave=False, desc='scenarios ...'): + + partial_dump_gt_metric_features(scenario_id) + + +def batch_dump_log_metric_features(log_dir, save_dir, num_workers=64): + + buffer_size = 128 + + # file loaders + pkl_dir = os.path.join(log_dir, 'validation') + if not os.path.exists(pkl_dir): + raise RuntimeError(f'Not found folder {pkl_dir}') + tfrecords_dir = os.path.join(log_dir, 'validation_tfrecords_splitted') + if not os.path.exists(tfrecords_dir): + raise RuntimeError(f'Not found folder {tfrecords_dir}') + + files = list(fnmatch.filter(os.listdir(pkl_dir), '*.pkl')) + json_path = os.path.join(log_dir, 'meta_infos.json') + meta_infos = json.load(open(json_path, 'r', encoding='utf-8'))['validation'] 
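+    # NOTE: as in the single-process variant above, only validation scenarios
+    # with 8 <= num_agents < buffer_size (=128) survive the filter below; all
+    # other scenarios are skipped entirely.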
+ CONSOLE.log(f"Loaded meta infos from {json_path}") + available_scenarios = list(meta_infos.keys()) + df = pd.DataFrame.from_dict(meta_infos, orient='index') + available_scenarios_set = set(available_scenarios) + df_filtered = df[(df.index.isin(available_scenarios_set)) & (df['num_agents'] >= 8) & (df['num_agents'] < buffer_size)] + valid_scenarios = set(df_filtered.index) + files = list(tqdm(filter(lambda fn: fn.removesuffix('.pkl') in valid_scenarios, files), leave=False)) + + scenario_ids = list(map(lambda fn: fn.removesuffix('.pkl'), files)) + CONSOLE.log(f'Loaded {len(scenario_ids)} scenarios from validation split.') + + # initialize + transform = WaymoTargetBuilder(num_historical_steps=11, + num_future_steps=80, + max_num=32) + + token_processor = TokenProcessor(token_size=2048, + state_token={'invalid': 0, 'valid': 1, 'enter': 2, 'exit': 3}, + pl2seed_radius=75) + + partial_dump_gt_metric_features = partial( + _dump_log_metric_features, pkl_dir, tfrecords_dir, save_dir, transform, token_processor) + + with multiprocessing.Pool(num_workers) as p: + list(tqdm(p.imap_unordered(partial_dump_gt_metric_features, scenario_ids), total=len(scenario_ids))) + + +def aggregate_log_metric_features(load_dir): + + files = list(fnmatch.filter(os.listdir(load_dir), '*.pkl')) + if 'total_features.pkl' in files: + files.remove('total_features.pkl') + CONSOLE.log(f'Loaded {len(files)} scenarios from dumpped log metric features') + + features_fields = [field.name for field in dataclasses.fields(MetricFeatures)] + features_fields.remove('object_id') + + # load and append + total_features = collections.defaultdict(list) + for file in tqdm(files, leave=False, desc='scenario ...'): + + with open(os.path.join(load_dir, file), 'rb') as f: + log_features = pickle.load(f) + + for field in features_fields: + total_features[field].append(getattr(log_features, field)) + + # aggregate + features_info = dict() + for field in (pbar := tqdm(features_fields, leave=False)): + pbar.set_postfix(f=field) + if total_features[field][0] is not None: + total_features[field] = torch.concat(total_features[field], dim=0) # n_agent or n_scenario + features_info[field] = total_features[field].shape + CONSOLE.log(f'Aggregated log features:\n{features_info}') + + # save + save_path = os.path.join(load_dir, 'total_features.pkl') + with open(save_path, 'wb') as f: + pickle.dump(total_features, f) + CONSOLE.log(f'Saved total features to [green]{save_path}.[/]') + + return MetricFeatures(**total_features, object_id=None) + + +def _compute_metrics( + metric: LongMetric, + load_dir: str, + verbose: bool, + rollouts_file: str, +) -> List[long_metrics_pb2.SimAgentMetrics]: # type: ignore + + if verbose: + print(f'Processing {rollouts_file}') + + with open(os.path.join(load_dir, rollouts_file), 'rb') as f: + rollouts = pickle.load(f) + # CONSOLE.log(f'Loaded rollouts from {rollouts_file}') + + return metric.compute_metrics(rollouts) + + +def compute_metrics(load_dir, rollouts_files): + + log_every_n_steps = 100 + + metric = LongMetric('val_close_long') + CONSOLE.log(f'metrics config:\n{metric.metrics_config}') + + i = 0 + for rollouts_file in tqdm(rollouts_files, leave=False, desc='Rollouts files ...'): + + # ! 
compute metrics and update + metric.update( + metrics=_compute_metrics(metric, load_dir, verbose=False, rollouts_file=rollouts_file) + ) + + if i % log_every_n_steps == 0: + CONSOLE.log(f'Step={i}:\n{metric.compute()}') + + i += 1 + + CONSOLE.log(f'[bold][yellow] Compute metrics completed!') + CONSOLE.log(f'[bold][yellow] Final metrics: [/]\n {metric.compute()}') + + +def batch_compute_metrics(load_dir, rollouts_files, num_workers, save_dir=None): + from queue import Queue + from threading import Thread + + if save_dir is None: + save_dir = load_dir + + results_buffer = Queue() + + log_every_n_steps = 20 + + metric = LongMetric('val_close_long') + CONSOLE.log(f'metrics config:\n{metric.metrics_config}') + + def _collect_result(): + while True: + r = results_buffer.get() + if r is None: + break + metric.update(metrics=r) + results_buffer.task_done() + + collector = Thread(target=_collect_result, daemon=True) + collector.start() + + partial_compute_metrics = partial(_compute_metrics, metric, load_dir, True) + + # ! compute metrics in batch + with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor: + # results = list(executor.map(partial_compute_metrics, rollouts_files)) + futures = [executor.submit(partial_compute_metrics, rollouts_file) for rollouts_file in rollouts_files] + # results = [f.result() for f in concurrent.futures.as_completed(futures)] + + for i, future in tqdm(enumerate(concurrent.futures.as_completed(futures)), total=len(futures), leave=False): + results_buffer.put(future.result()) + + if i % log_every_n_steps == 0: + CONSOLE.log(f'Step={i}:\n{metric.compute()}') + metric.dumps(save_dir) + + results_buffer.put(None) + collector.join() + + CONSOLE.log(f'[bold][yellow] Compute metrics completed!') + CONSOLE.log(f'[bold][yellow] Final metrics: [/]\n {metric.compute()}') + + # save results to disk + metric.dumps(save_dir) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument('--dump_log', action='store_true') + parser.add_argument('--dump_sim', action='store_true') + parser.add_argument('--aggregate_log', action='store_true') + parser.add_argument('--num_workers', type=int, default=32) + parser.add_argument('--compute_metric', action='store_true') + parser.add_argument('--log_dir', type=str, default='data/waymo_processed/') + parser.add_argument('--sim_dir', type=str, default=None, required=False) + parser.add_argument('--save_dir', type=str, default='results', required=False) + parser.add_argument('--no_batch', action='store_true') + parser.add_argument('--debug', action='store_true') + parser.add_argument('--debug_batch', action='store_true') + args = parser.parse_args() + + if args.dump_log: + + save_dir = os.path.join(args.log_dir, 'log_features') + os.makedirs(save_dir, exist_ok=True) + + if args.no_batch or args.debug: + dump_log_metric_features(args.log_dir, save_dir) + else: + batch_dump_log_metric_features(args.log_dir, save_dir) + + elif args.aggregate_log: + + load_dir = os.path.join(args.log_dir, 'log_features') + aggregate_log_metric_features(load_dir) + + elif args.compute_metric: + + assert args.sim_dir is not None and os.path.exists(args.sim_dir), \ + f'Folder {args.sim_dir} does not exist!' 
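+        # NOTE: this branch consumes rollout pickles named `idx_*_rollouts.pkl`
+        # from --sim_dir, e.g. (the script name here is illustrative):
+        #   python long_metrics.py --compute_metric --sim_dir output/rollouts \
+        #       --save_dir results --num_workers 32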
+        rollouts_files = list(sorted(fnmatch.filter(os.listdir(args.sim_dir), 'idx_*_rollouts.pkl')))
+        CONSOLE.log(f'Found {len(rollouts_files)} rollouts files.')
+
+        os.makedirs(args.save_dir, exist_ok=True)
+        if args.no_batch:
+            compute_metrics(args.sim_dir, rollouts_files)
+
+        else:
+            multiprocessing.set_start_method('spawn', force=True)
+            batch_compute_metrics(args.sim_dir, rollouts_files, args.num_workers, save_dir=args.save_dir)
+
+    elif args.debug:
+
+        debug_path = 'output/scalable_smart_long/validation_catk/idx_0_0_rollouts.pkl'
+
+        # ! for debugging
+        with open(debug_path, 'rb') as f:
+            rollouts = pickle.load(f)
+        metric = LongMetric('debug')
+        CONSOLE.log(f'metrics config: {metric.metrics_config}')
+
+        metric.update(outputs=rollouts)
+        CONSOLE.log(f'metrics:\n{metric.compute()}')
+
+
+    elif args.debug_batch:
+
+        rollouts_files = ['idx_0_rollouts.pkl'] * 1000
+        CONSOLE.log(f'Found {len(rollouts_files)} rollouts files.')
+
+        sim_dir = 'dev/metrics/'
+
+        os.makedirs(args.save_dir, exist_ok=True)
+        multiprocessing.set_start_method('spawn', force=True)
+        batch_compute_metrics(sim_dir, rollouts_files, args.num_workers, save_dir=args.save_dir)
diff --git a/backups/dev/metrics/geometry_utils.py b/backups/dev/metrics/geometry_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8424b71a95a4d041f925f5a80b76b902c3c31c30
--- /dev/null
+++ b/backups/dev/metrics/geometry_utils.py
@@ -0,0 +1,137 @@
+import torch
+import numpy as np
+from torch import Tensor
+from typing import Tuple
+
+
+NUM_VERTICES_IN_BOX = 4
+
+
+def minkowski_sum_of_box_and_box_points(box1_points: Tensor,
+                                        box2_points: Tensor) -> Tensor:
+    """Batched Minkowski sum of two boxes (counter-clockwise corners in xy)."""
+    point_order_1 = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.long)
+    point_order_2 = torch.tensor([0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long)
+
+    box1_start_idx, downmost_box1_edge_direction = _get_downmost_edge_in_box(
+        box1_points)
+    box2_start_idx, downmost_box2_edge_direction = _get_downmost_edge_in_box(
+        box2_points)
+
+    condition = (cross_product_2d(downmost_box1_edge_direction, downmost_box2_edge_direction) >= 0.)
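+    # NOTE: the Minkowski sum of two convex polygons is itself convex, and its
+    # edge sequence is the union of both polygons' edges sorted by polar angle.
+    # Starting from the downmost vertex of each box, the cross product above
+    # decides which of the two fixed interleavings (`point_order_1` /
+    # `point_order_2`) merges the 4+4 corners in the correct angular order.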
+ condition = condition.repeat(1, 8) + + box1_point_order = torch.where(condition, point_order_2, point_order_1) + box1_point_order = (box1_point_order + box1_start_idx) % NUM_VERTICES_IN_BOX + ordered_box1_points = torch.gather( + box1_points, 1, box1_point_order.unsqueeze(-1).expand(-1, -1, 2)) + + box2_point_order = torch.where(condition, point_order_1, point_order_2) + box2_point_order = (box2_point_order + box2_start_idx) % NUM_VERTICES_IN_BOX + ordered_box2_points = torch.gather( + box2_points, 1, box2_point_order.unsqueeze(-1).expand(-1, -1, 2)) + + minkowski_sum = ordered_box1_points + ordered_box2_points + + return minkowski_sum + + +def signed_distance_from_point_to_convex_polygon(query_points: Tensor, polygon_points: Tensor) -> Tensor: + """Finds the signed distances from query points to convex polygons.""" + tangent_unit_vectors, normal_unit_vectors, edge_lengths = _get_edge_info( + polygon_points) + + query_points = query_points.unsqueeze(1) + vertices_to_query_vectors = query_points - polygon_points + vertices_distances = torch.norm(vertices_to_query_vectors, dim=-1) + + edge_signed_perp_distances = torch.sum(-normal_unit_vectors * vertices_to_query_vectors, dim=-1) + + is_inside = torch.all(edge_signed_perp_distances <= 0, dim=-1) + + projection_along_tangent = torch.sum(tangent_unit_vectors * vertices_to_query_vectors, dim=-1) + projection_along_tangent_proportion = projection_along_tangent / edge_lengths + + is_projection_on_edge = (projection_along_tangent_proportion >= 0.) & ( + projection_along_tangent_proportion <= 1.) + + edge_perp_distances = edge_signed_perp_distances.abs() + edge_distances = torch.where(is_projection_on_edge, edge_perp_distances, torch.tensor(np.inf)) + + edge_and_vertex_distance = torch.cat([edge_distances, vertices_distances], dim=-1) + min_distance = torch.min(edge_and_vertex_distance, dim=-1)[0] + + signed_distances = torch.where(is_inside, -min_distance, min_distance) + + return signed_distances + + +def _get_downmost_edge_in_box(box: Tensor) -> Tuple[Tensor, Tensor]: + """Finds the downmost (lowest y-coordinate) edge in the box.""" + downmost_vertex_idx = torch.argmin(box[..., 1], dim=-1, keepdim=True) + + edge_start_vertex = torch.gather(box, 1, downmost_vertex_idx.unsqueeze(-1).expand(-1, -1, 2)) + edge_end_idx = (downmost_vertex_idx + 1) % NUM_VERTICES_IN_BOX + edge_end_vertex = torch.gather(box, 1, edge_end_idx.unsqueeze(-1).expand(-1, -1, 2)) + + downmost_edge = edge_end_vertex - edge_start_vertex + downmost_edge_length = torch.norm(downmost_edge, dim=-1, keepdim=True) + downmost_edge_direction = downmost_edge / downmost_edge_length + + return downmost_vertex_idx, downmost_edge_direction + + +def cross_product_2d(a: Tensor, b: Tensor) -> Tensor: + return a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0] + + +def dot_product_2d(a: Tensor, b: Tensor) -> Tensor: + return a[..., 0] * b[..., 0] + a[..., 1] * b[..., 1] + + +def _get_edge_info(polygon_points: Tensor): + """ + Computes properties about the edges of a polygon. + + Args: + polygon_points: Tensor containing the vertices of each polygon, with + shape (num_polygons, num_points_per_polygon, 2). Each polygon is assumed + to have an equal number of vertices. + + Returns: + tangent_unit_vectors: A unit vector in (x,y) with the same direction as + the tangent to the edge. Shape: (num_polygons, num_points_per_polygon, 2). + normal_unit_vectors: A unit vector in (x,y) with the same direction as + the normal to the edge. + Shape: (num_polygons, num_points_per_polygon, 2). 
+ edge_lengths: Lengths of the edges. + Shape (num_polygons, num_points_per_polygon). + """ + # Shift the polygon points by 1 position to get the edges. + first_point_in_polygon = polygon_points[:, 0:1, :] # Shape: (num_polygons, 1, 2) + shifted_polygon_points = torch.cat([polygon_points[:, 1:, :], first_point_in_polygon], dim=1) + # Shape: (num_polygons, num_points_per_polygon, 2) + + edge_vectors = shifted_polygon_points - polygon_points # Shape: (num_polygons, num_points_per_polygon, 2) + edge_lengths = torch.norm(edge_vectors, dim=-1) # Shape: (num_polygons, num_points_per_polygon) + + # Avoid division by zero by adding a small epsilon + eps = torch.finfo(edge_lengths.dtype).eps + tangent_unit_vectors = edge_vectors / (edge_lengths[..., None] + eps) # Shape: (num_polygons, num_points_per_polygon, 2) + + normal_unit_vectors = torch.stack( + [-tangent_unit_vectors[..., 1], tangent_unit_vectors[..., 0]], dim=-1 + ) # Shape: (num_polygons, num_points_per_polygon, 2) + + return tangent_unit_vectors, normal_unit_vectors, edge_lengths + + +def rotate_2d_points(xys: Tensor, rotation_yaws: Tensor) -> Tensor: + """Rotates `xys` counterclockwise using the `rotation_yaws`.""" + cos_yaws = torch.cos(rotation_yaws) + sin_yaws = torch.sin(rotation_yaws) + + rotated_x = cos_yaws * xys[..., 0] - sin_yaws * xys[..., 1] + rotated_y = sin_yaws * xys[..., 0] + cos_yaws * xys[..., 1] + + return torch.stack([rotated_x, rotated_y], axis=-1) diff --git a/backups/dev/metrics/interact_features.py b/backups/dev/metrics/interact_features.py new file mode 100644 index 0000000000000000000000000000000000000000..1db982f18987793e4c3a69922e1dcc2441b90c83 --- /dev/null +++ b/backups/dev/metrics/interact_features.py @@ -0,0 +1,220 @@ +import math +import torch +from torch import Tensor + +from dev.metrics import box_utils +from dev.metrics import geometry_utils +from dev.metrics import trajectory_features + + +EXTREMELY_LARGE_DISTANCE = 1e10 +COLLISION_DISTANCE_THRESHOLD = 0.0 +CORNER_ROUNDING_FACTOR = 0.7 +MAX_HEADING_DIFF = math.radians(75.0) +MAX_HEADING_DIFF_FOR_SMALL_OVERLAP = math.radians(10.0) +SMALL_OVERLAP_THRESHOLD = 0.5 +MAXIMUM_TIME_TO_COLLISION = 5.0 + + +def compute_distance_to_nearest_object( + center_x: Tensor, + center_y: Tensor, + center_z: Tensor, + length: Tensor, + width: Tensor, + height: Tensor, + heading: Tensor, + valid: Tensor, + evaluated_object_mask: Tensor, + corner_rounding_factor: float = CORNER_ROUNDING_FACTOR, +) -> Tensor: + """Computes the distance to nearest object for each of the evaluated objects.""" + boxes = torch.stack([center_x, center_y, center_z, length, width, height, heading], dim=-1) + num_objects, num_steps, num_features = boxes.shape + + shrinking_distance = (torch.minimum(boxes[:, :, 3], boxes[:, :, 4]) * corner_rounding_factor / 2.) 
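+    # NOTE: boxes are shrunk here by `shrinking_distance` per side; further
+    # below the same radii are subtracted from the shrunk-box distances, which
+    # is equivalent to inflating each shrunk box by a disk, i.e. measuring
+    # distances between boxes with rounded corners (radius proportional to the
+    # smaller of length and width), softening corner-to-corner near misses.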
+ + boxes = torch.cat([ + boxes[:, :, :3], + boxes[:, :, 3:4] - 2.0 * shrinking_distance[..., None], + boxes[:, :, 4:5] - 2.0 * shrinking_distance[..., None], + boxes[:, :, 5:] + ], dim=2) + + boxes = boxes.reshape(num_objects * num_steps, num_features) + + box_corners = box_utils.get_upright_3d_box_corners(boxes)[:, :4, :2] + box_corners = box_corners.reshape(num_objects, num_steps, 4, 2) + + eval_corners = box_corners[evaluated_object_mask] + num_eval_objects = eval_corners.shape[0] + other_corners = box_corners[~evaluated_object_mask] + all_corners = torch.cat([eval_corners, other_corners], dim=0) + + eval_corners = eval_corners.unsqueeze(1).expand(num_eval_objects, num_objects, num_steps, 4, 2) + all_corners = all_corners.unsqueeze(0).expand(num_eval_objects, num_objects, num_steps, 4, 2) + + eval_corners = eval_corners.reshape(num_eval_objects * num_objects * num_steps, 4, 2) + all_corners = all_corners.reshape(num_eval_objects * num_objects * num_steps, 4, 2) + + neg_all_corners = -1.0 * all_corners + minkowski_sum = geometry_utils.minkowski_sum_of_box_and_box_points( + box1_points=eval_corners, box2_points=neg_all_corners, + ) + + assert minkowski_sum.shape[1:] == (8, 2), f"Shape mismatch: {minkowski_sum.shape}, expected (*, 8, 2)" + signed_distances_flat = ( + geometry_utils.signed_distance_from_point_to_convex_polygon( + query_points=torch.zeros_like(minkowski_sum[:, 0, :]), + polygon_points=minkowski_sum, + ) + ) + + signed_distances = signed_distances_flat.reshape(num_eval_objects, num_objects, num_steps) + + eval_shrinking_distance = shrinking_distance[evaluated_object_mask] + other_shrinking_distance = shrinking_distance[~evaluated_object_mask] + all_shrinking_distance = torch.cat([eval_shrinking_distance, other_shrinking_distance], dim=0) + + signed_distances -= eval_shrinking_distance.unsqueeze(1) + signed_distances -= all_shrinking_distance.unsqueeze(0) + + self_mask = torch.eye(num_eval_objects, num_objects, dtype=torch.float32)[:, :, None] + signed_distances = signed_distances + self_mask * EXTREMELY_LARGE_DISTANCE + + eval_validity = valid[evaluated_object_mask] + other_validity = valid[~evaluated_object_mask] + all_validity = torch.cat([eval_validity, other_validity], dim=0) + + valid_mask = eval_validity.unsqueeze(1) & all_validity.unsqueeze(0) + + signed_distances = torch.where(valid_mask, signed_distances, EXTREMELY_LARGE_DISTANCE) + + return torch.min(signed_distances, dim=1).values + + +def compute_time_to_collision_with_object_in_front( + *, + center_x: Tensor, + center_y: Tensor, + length: Tensor, + width: Tensor, + heading: Tensor, + valid: Tensor, + evaluated_object_mask: Tensor, + seconds_per_step: float, +) -> Tensor: + """Computes the time-to-collision of the evaluated objects.""" + # `speed` shape: (num_objects, num_steps) + speed = trajectory_features.compute_kinematic_features( + x=center_x, + y=center_y, + z=torch.zeros_like(center_x), + heading=heading, + seconds_per_step=seconds_per_step, + )[0] + + boxes = torch.stack([center_x, center_y, length, width, heading, speed], dim=-1) + boxes = boxes.permute(1, 0, 2) # (num_steps, num_objects, 6) + valid = valid.permute(1, 0) + + eval_boxes = boxes[:, evaluated_object_mask] + ego_xy, ego_sizes, ego_yaw, ego_speed = torch.split(eval_boxes, [2, 2, 1, 1], dim=-1) + other_xy, other_sizes, other_yaw, _ = torch.split(boxes, [2, 2, 1, 1], dim=-1) + + yaw_diff = torch.abs(other_yaw[:, None] - ego_yaw[:, :, None]) + yaw_diff_cos = torch.cos(yaw_diff) + yaw_diff_sin = torch.sin(yaw_diff) + + other_long_offset = 
geometry_utils.dot_product_2d( + other_sizes[:, None] / 2.0, torch.abs(torch.cat([yaw_diff_cos, yaw_diff_sin], dim=-1)) + ) + other_lat_offset = geometry_utils.dot_product_2d( + other_sizes[:, None] / 2.0, torch.abs(torch.cat([yaw_diff_sin, yaw_diff_cos], dim=-1)) + ) + + other_relative_xy = geometry_utils.rotate_2d_points( + (other_xy[:, None] - ego_xy[:, :, None]), -ego_yaw + ) + + long_distance = ( + other_relative_xy[..., 0] - ego_sizes[:, :, None, 0] / 2.0 - other_long_offset + ) + lat_overlap = ( + torch.abs(other_relative_xy[..., 1]) - ego_sizes[:, :, None, 1] / 2.0 - other_lat_offset + ) + + following_mask = _get_object_following_mask( + long_distance, lat_overlap, yaw_diff[..., 0] + ) + valid_mask = valid[:, None] & following_mask + + masked_long_distance = ( + long_distance + (1.0 - valid_mask.float()) * EXTREMELY_LARGE_DISTANCE + ) + + box_ahead_index = masked_long_distance.argmin(dim=-1) + distance_to_box_ahead = torch.gather( + masked_long_distance, -1, box_ahead_index.unsqueeze(-1) + ).squeeze(-1) + + speed_broadcast = speed.T[:, None, :].expand_as(masked_long_distance) + box_ahead_speed = torch.gather(speed_broadcast, -1, box_ahead_index.unsqueeze(-1)).squeeze(-1) + + rel_speed = ego_speed[..., 0] - box_ahead_speed + time_to_collision = torch.where( + rel_speed > 0.0, + torch.minimum(distance_to_box_ahead / rel_speed, + torch.tensor(MAXIMUM_TIME_TO_COLLISION)), # the float will be broadcasted automatically + MAXIMUM_TIME_TO_COLLISION, + ) + + return time_to_collision.T + + +def _get_object_following_mask( + longitudinal_distance: Tensor, + lateral_overlap: Tensor, + yaw_diff: Tensor, +) -> Tensor: + """Checks whether objects satisfy criteria for following another object. + + An object on which the criteria are applied is called "ego object" in this + function to disambiguate it from the other objects acting as obstacles. + + An "ego" object is considered to be following another object if they satisfy + conditions on the longitudinal distance, lateral overlap, and yaw alignment + between them. + + Args: + longitudinal_distance: A float Tensor with shape (batch_dim, num_egos, + num_others) containing longitudinal distances from the back side of each + ego box to every other box. + lateral_overlap: A float Tensor with shape (batch_dim, num_egos, num_others) + containing lateral overlaps of other boxes over the trails of ego boxes. + yaw_diff: A float Tensor with shape (batch_dim, num_egos, num_others) + containing absolute yaw differences between egos and other boxes. + + Returns: + A boolean Tensor with shape (batch_dim, num_egos, num_others) indicating for + each ego box if it is following the other boxes. + """ + # Check object is ahead of the ego box's front. + valid_mask = longitudinal_distance > 0.0 + + # Check alignment. + valid_mask = torch.logical_and(valid_mask, yaw_diff <= MAX_HEADING_DIFF) + + # Check object is directly ahead of the ego box. + valid_mask = torch.logical_and(valid_mask, lateral_overlap < 0.0) + + # Check strict alignment if the overlap is small. + # `lateral_overlap` is a signed penetration distance: it is negative when the + # boxes have an actual lateral overlap. 
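+    # When the overlap is only slight (lateral_overlap >= -SMALL_OVERLAP_THRESHOLD),
+    # require the stricter MAX_HEADING_DIFF_FOR_SMALL_OVERLAP alignment; deeply
+    # overlapping trails keep the looser MAX_HEADING_DIFF check from above.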
+    return torch.logical_and(
+        valid_mask,
+        torch.logical_or(
+            lateral_overlap < -SMALL_OVERLAP_THRESHOLD,
+            yaw_diff <= MAX_HEADING_DIFF_FOR_SMALL_OVERLAP,
+        ),
+    )
\ No newline at end of file
diff --git a/backups/dev/metrics/map_features.py b/backups/dev/metrics/map_features.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a16c37ffcf45a5be6ba4942bd28d640a2524a5f
--- /dev/null
+++ b/backups/dev/metrics/map_features.py
@@ -0,0 +1,349 @@
+import torch
+from torch import Tensor
+from typing import Optional, Sequence
+
+from dev.metrics import box_utils
+from dev.metrics import geometry_utils
+from dev.metrics.protos import map_pb2
+
+# Constant distance to apply when distances are invalid. This will avoid the
+# propagation of nans and should be reduced out when taking the minimum anyway.
+EXTREMELY_LARGE_DISTANCE = 1e10
+# Off-road threshold, i.e. the smallest distance away from the road edge at
+# which an object is considered off-road.
+OFFROAD_DISTANCE_THRESHOLD = 0.0
+
+# How close the start and end point of a map feature need to be for the feature
+# to be considered cyclic, in m^2.
+_CYCLIC_MAP_FEATURE_TOLERANCE_M2 = 1.0
+# Scaling factor for vertical distances used when finding the closest segment to
+# a query point. This prevents wrong associations in cases with under- and
+# over-passes.
+_Z_STRETCH_FACTOR = 3.0
+
+_Polyline = Sequence[map_pb2.MapPoint]
+
+
+def compute_distance_to_road_edge(
+    *,
+    center_x: Tensor,
+    center_y: Tensor,
+    center_z: Tensor,
+    length: Tensor,
+    width: Tensor,
+    height: Tensor,
+    heading: Tensor,
+    valid: Tensor,
+    evaluated_object_mask: Tensor,
+    road_edge_polylines: Sequence[_Polyline],
+) -> Tensor:
+    """Computes the distance to the road edge for each of the evaluated objects."""
+    if not road_edge_polylines:
+        raise ValueError('Missing road edges.')
+
+    # Concatenate tensors to have the same convention as `box_utils`.
+    boxes = torch.stack([center_x, center_y, center_z, length, width, height, heading], dim=-1)
+    num_objects, num_steps, num_features = boxes.shape
+    boxes = boxes.reshape(num_objects * num_steps, num_features)
+
+    # Compute box corners using `box_utils`, and take the xyz coords of the bottom corners.
+    box_corners = box_utils.get_upright_3d_box_corners(boxes)[:, :4]
+    box_corners = box_corners.reshape(num_objects, num_steps, 4, 3)
+
+    # Gather objects in the evaluation set.
+    eval_corners = box_corners[evaluated_object_mask]
+    num_eval_objects = eval_corners.shape[0]
+
+    # Flatten query points.
+    flat_eval_corners = eval_corners.reshape(-1, 3)
+
+    # Tensorize road edges.
+    polylines_tensor = _tensorize_polylines(road_edge_polylines)
+    is_polyline_cyclic = _check_polyline_cycles(road_edge_polylines)
+
+    # Compute distances for all query points.
+    corner_distance_to_road_edge = _compute_signed_distance_to_polylines(
+        xyzs=flat_eval_corners, polylines=polylines_tensor,
+        is_polyline_cyclic=is_polyline_cyclic, z_stretch=_Z_STRETCH_FACTOR
+    )
+
+    # Reshape back to (num_evaluated_objects, num_steps, 4).
+    corner_distance_to_road_edge = corner_distance_to_road_edge.reshape(num_eval_objects, num_steps, 4)
+
+    # Reduce to the most off-road corner.
+    signed_distances = torch.max(corner_distance_to_road_edge, dim=-1)[0]
+
+    # Mask out invalid boxes.
+    eval_validity = valid[evaluated_object_mask]
+
+    return torch.where(eval_validity, signed_distances, -EXTREMELY_LARGE_DISTANCE)
+
+
+def _tensorize_polylines(polylines):
+    """Stacks a sequence of polylines into a tensor.
+
+    Args:
+        polylines: A sequence of Polyline objects.
+
+    Returns:
+        A float tensor with shape (num_polylines, max_length, 4) containing xyz
+        coordinates and a validity flag for all points in the polylines. Polylines
+        are padded with zeros up to the length of the longest one.
+    """
+    polyline_tensors = []
+    max_length = 0
+
+    for polyline in polylines:
+        # Skip degenerate polylines.
+        if len(polyline) < 2:
+            continue
+        max_length = max(max_length, len(polyline))
+        polyline_tensors.append(
+            torch.tensor(
+                [[map_point.x, map_point.y, map_point.z, 1.0] for map_point in polyline],
+                dtype=torch.float32
+            )
+        )
+
+    # Pad and stack the polylines.
+    padded_polylines = [
+        torch.cat([p, torch.zeros((max_length - p.shape[0], 4), dtype=torch.float32)], dim=0)
+        for p in polyline_tensors
+    ]
+
+    return torch.stack(padded_polylines, dim=0)
+
+
+def _check_polyline_cycles(polylines):
+    """Checks if the given polylines are cyclic and returns the result as a tensor.
+
+    A polyline counts as cyclic when the squared distance between its first and
+    last points is below the module-level tolerance
+    `_CYCLIC_MAP_FEATURE_TOLERANCE_M2`.
+
+    Args:
+        polylines: A sequence of Polyline objects.
+
+    Returns:
+        A bool tensor with shape (num_polylines) indicating whether each
+        (non-degenerate) polyline is cyclic.
+    """
+    cycles = []
+    for polyline in polylines:
+        # Skip degenerate polylines.
+        if len(polyline) < 2:
+            continue
+        first_point = torch.tensor([polyline[0].x, polyline[0].y, polyline[0].z], dtype=torch.float32)
+        last_point = torch.tensor([polyline[-1].x, polyline[-1].y, polyline[-1].z], dtype=torch.float32)
+        cycles.append(torch.sum((first_point - last_point) ** 2) < _CYCLIC_MAP_FEATURE_TOLERANCE_M2)
+
+    return torch.stack(cycles, dim=0)
+
+
+def _compute_signed_distance_to_polylines(
+    xyzs: Tensor,
+    polylines: Tensor,
+    is_polyline_cyclic: Optional[Tensor] = None,
+    z_stretch: float = 1.0,
+) -> Tensor:
+    """Computes the signed distance to the 2D boundary defined by polylines.
+
+    Negative distances correspond to being inside the boundary (e.g. on the
+    road), positive distances to being outside (e.g. off-road).
+
+    The polylines should be oriented such that the port side is inside the
+    boundary and the starboard side is outside, i.e. counterclockwise winding
+    order.
+
+    The altitudes, i.e. the z-coordinates of query points and polyline segments,
+    are used to pair each query point with the most relevant segment: the one
+    that is closest and at the right altitude. The distances returned are 2D
+    distances in the xy plane.
+
+    Args:
+        xyzs: A float Tensor of shape (num_points, 3) containing xyz coordinates of
+            query points.
+        polylines: Tensor with shape (num_polylines, num_segments+1, 4) containing
+            sequences of xyz coordinates and validity, representing start and end
+            points of consecutive segments.
+        is_polyline_cyclic: A boolean Tensor with shape (num_polylines) indicating
+            whether each polyline is cyclic. If None, all polylines are considered
+            non-cyclic.
+        z_stretch: Factor by which to scale distances over the z axis. This can be
+            done to ensure edge points from the wrong level (e.g. overpasses) are not
+            selected. Defaults to 1.0 (no stretching).
+
+    Returns:
+        A tensor of shape (num_points), containing the signed 2D distance from
+        queried points to the nearest polyline.
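+
+    Example (a minimal sketch; the shapes and values are illustrative only):
+
+        polylines = torch.tensor([[[ 0.,  0., 0., 1.],
+                                   [10.,  0., 0., 1.],
+                                   [10., 10., 0., 1.]]])  # one polyline, two segments
+        xyzs = torch.tensor([[5.,  1., 0.],   # port side (inside)
+                             [5., -1., 0.]])  # starboard side (outside)
+        _compute_signed_distance_to_polylines(xyzs, polylines)
+        # -> tensor([-1., 1.])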
+ """ + num_points = xyzs.shape[0] + assert xyzs.shape == (num_points, 3), f"Expected shape ({num_points}, 3), but got {xyzs.shape}" + num_polylines = polylines.shape[0] + num_segments = polylines.shape[1] - 1 + assert polylines.shape == (num_polylines, num_segments + 1, 4), \ + f"Expected shape ({num_polylines}, {num_segments + 1}, 4), but got {polylines.shape}" + + # shape: (num_polylines, num_segments+1) + is_point_valid = polylines[:, :, 3].bool() + # shape: (num_polylines, num_segments) + is_segment_valid = is_point_valid[:, :-1] & is_point_valid[:, 1:] + + if is_polyline_cyclic is None: + is_polyline_cyclic = torch.zeros(num_polylines, dtype=torch.bool) + else: + assert is_polyline_cyclic.shape == (num_polylines,), \ + f"Expected shape ({num_polylines},), but got {is_polyline_cyclic.shape}" + + # Get distance to each segment. + # shape: (num_points, num_polylines, num_segments, 3) + xyz_starts = polylines[None, :, :-1, :3] + xyz_ends = polylines[None, :, 1:, :3] + start_to_point = xyzs[:, None, None, :3] - xyz_starts + start_to_end = xyz_ends - xyz_starts + + # Relative coordinate of point projection on segment. + # shape: (num_points, num_polylines, num_segments) + numerator = geometry_utils.dot_product_2d( + start_to_point[..., :2], start_to_end[..., :2] + ) + denominator = geometry_utils.dot_product_2d( + start_to_end[..., :2], start_to_end[..., :2] + ) + rel_t = torch.where(denominator != 0, numerator / denominator, torch.zeros_like(numerator)) + + # Negative if point is on port side of segment, positive if point on + # starboard side of segment. + # shape: (num_points, num_polylines, num_segments) + n = torch.sign( + geometry_utils.cross_product_2d( + start_to_point[..., :2], start_to_end[..., :2] + ) + ) + + # Compute the absolute 3D distance to segment. + # The vertical component is scaled by `z-stretch` to increase the separation + # between different road altitudes. + # shape: (num_points, num_polylines, num_segments, 3) + segment_to_point = start_to_point - ( + start_to_end * torch.clamp(rel_t, 0.0, 1.0)[..., None] + ) + stretch_vector = torch.tensor([1.0, 1.0, z_stretch], dtype=torch.float32) + distance_to_segment_3d = torch.norm( + segment_to_point * stretch_vector[None, None, None], + dim=-1, + ) + + # Absolute planar distance to segment. + # shape: (num_points, num_polylines, num_segments) + distance_to_segment_2d = torch.norm(segment_to_point[..., :2], dim=-1) + + # Padded start-to-end segments. + # shape: (num_points, num_polylines, num_segments+2, 2) + start_to_end_padded = torch.cat( + [ + start_to_end[:, :, -1:, :2], + start_to_end[..., :2], + start_to_end[:, :, :1, :2], + ], + dim=-2, + ) + + # shape: (num_points, num_polylines, num_segments+1) + is_locally_convex = torch.gt( + geometry_utils.cross_product_2d( + start_to_end_padded[:, :, :-1], start_to_end_padded[:, :, 1:] + ), + 0., + ) + + # Get shifted versions of `n` and `is_segment_valid`. If the polyline is + # cyclic, the tensors are rolled, else they are padded with their edge value. 
+ # shape: (num_points, num_polylines, num_segments) + n_prior = torch.cat( + [ + torch.where( + is_polyline_cyclic[None, :, None], + n[:, :, -1:], + n[:, :, :1], + ), + n[:, :, :-1], + ], + dim=-1, + ) + n_next = torch.cat( + [ + n[:, :, 1:], + torch.where( + is_polyline_cyclic[None, :, None], + n[:, :, :1], + n[:, :, -1:], + ), + ], + dim=-1, + ) + # shape: (num_polylines, num_segments) + is_prior_segment_valid = torch.cat( + [ + torch.where( + is_polyline_cyclic[:, None], + is_segment_valid[:, -1:], + is_segment_valid[:, :1], + ), + is_segment_valid[:, :-1], + ], + dim=-1, + ) + is_next_segment_valid = torch.cat( + [ + is_segment_valid[:, 1:], + torch.where( + is_polyline_cyclic[:, None], + is_segment_valid[:, :1], + is_segment_valid[:, -1:], + ), + ], + dim=-1, + ) + + # shape: (num_points, num_polylines, num_segments) + sign_if_before = torch.where( + is_locally_convex[:, :, :-1], + torch.maximum(n, n_prior), + torch.minimum(n, n_prior), + ) + sign_if_after = torch.where( + is_locally_convex[:, :, 1:], torch.maximum(n, n_next), torch.minimum(n, n_next) + ) + + # shape: (num_points, num_polylines, num_segments) + sign_to_segment = torch.where( + (rel_t < 0.0) & is_prior_segment_valid, + sign_if_before, + torch.where((rel_t > 1.0) & is_next_segment_valid, sign_if_after, n), + ) + + # Flatten polylines together. + # shape: (num_points, all_segments) + distance_to_segment_3d = distance_to_segment_3d.view(num_points, num_polylines * num_segments) + distance_to_segment_2d = distance_to_segment_2d.view(num_points, num_polylines * num_segments) + sign_to_segment = sign_to_segment.view(num_points, num_polylines * num_segments) + + # Mask out invalid segments. + # shape: (all_segments) + is_segment_valid = is_segment_valid.view(num_polylines * num_segments) + # shape: (num_points, all_segments) + distance_to_segment_3d = torch.where( + is_segment_valid[None], + distance_to_segment_3d, + EXTREMELY_LARGE_DISTANCE, + ) + distance_to_segment_2d = torch.where( + is_segment_valid[None], + distance_to_segment_2d, + EXTREMELY_LARGE_DISTANCE, + ) + + # Get closest segment according to absolute 3D distance and return the + # corresponding signed 2D distance. 
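+    # The z-stretched 3D distance picks the segment at the correct altitude
+    # (e.g. not an overpass above the query point), while the value returned
+    # is the purely planar 2D distance.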
+    # shape: (num_points)
+    closest_segment_index = torch.argmin(distance_to_segment_3d, dim=-1)
+    distance_sign = torch.gather(sign_to_segment, 1, closest_segment_index.unsqueeze(-1)).squeeze(-1)
+    distance_2d = torch.gather(distance_to_segment_2d, 1, closest_segment_index.unsqueeze(-1)).squeeze(-1)
+
+    return distance_sign * distance_2d
diff --git a/backups/dev/metrics/placement_features.py b/backups/dev/metrics/placement_features.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c5fbc62316e7436fb48ca7c8cb9d0c09fd5f4a5
--- /dev/null
+++ b/backups/dev/metrics/placement_features.py
@@ -0,0 +1,48 @@
+import torch
+from torch import Tensor
+from typing import List, Tuple
+
+
+def compute_num_placement(
+    valid: Tensor,  # [n_agent, n_step]
+    state: Tensor,  # [n_agent, n_step]
+    av_id: int,
+    object_id: Tensor,
+    agent_state: List[str],
+) -> Tuple[Tensor, Tensor]:
+    """Counts per-step agent placements (enter) and removements (exit)."""
+    enter_state = agent_state.index('enter')
+    exit_state = agent_state.index('exit')
+
+    av_index = object_id.tolist().index(av_id)
+    state[av_index] = -1  # exclude the SDC (note: mutates `state` in place)
+
+    is_bos = state == enter_state
+    is_eos = state == exit_state
+
+    num_bos = torch.sum(is_bos, dim=0)
+    num_eos = torch.sum(is_eos, dim=0)
+
+    return num_bos, num_eos
+
+
+def compute_distance_placement(
+    position: Tensor,
+    state: Tensor,
+    valid: Tensor,
+    av_id: int,
+    object_id: Tensor,
+    agent_state: List[str],
+) -> Tuple[Tensor, Tensor]:
+    """Computes distances to the SDC at agent placement (enter) and removement (exit)."""
+    enter_state = agent_state.index('enter')
+    exit_state = agent_state.index('exit')
+
+    av_index = object_id.tolist().index(av_id)
+    state[av_index] = -1  # exclude the SDC (note: mutates `state` in place)
+    distance = torch.norm(position - position[av_index : av_index + 1], p=2, dim=-1)
+
+    bos_distance = distance * (state == enter_state)
+    eos_distance = distance * (state == exit_state)
+
+    return bos_distance, eos_distance
diff --git a/backups/dev/metrics/protos/long_metrics_pb2.py b/backups/dev/metrics/protos/long_metrics_pb2.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dad3962520ee1ade83e915690956e46701c842d
--- /dev/null
+++ b/backups/dev/metrics/protos/long_metrics_pb2.py
@@ -0,0 +1,648 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: dev/metrics/protos/long_metrics.proto + +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='dev/metrics/protos/long_metrics.proto', + package='long_metric', + syntax='proto2', + serialized_options=None, + create_key=_descriptor._internal_create_key, + serialized_pb=b'\n%dev/metrics/protos/long_metrics.proto\x12\x0blong_metric\"\xb4\x0c\n\x15SimAgentMetricsConfig\x12\x46\n\x0clinear_speed\x18\x01 \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12M\n\x13linear_acceleration\x18\x02 \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12G\n\rangular_speed\x18\x03 \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12N\n\x14\x61ngular_acceleration\x18\x04 \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12T\n\x1a\x64istance_to_nearest_object\x18\x05 \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12N\n\x14\x63ollision_indication\x18\x06 \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12K\n\x11time_to_collision\x18\x07 \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12O\n\x15\x64istance_to_road_edge\x18\x08 \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12L\n\x12offroad_indication\x18\t \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12G\n\rnum_placement\x18\n \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12H\n\x0enum_removement\x18\x0b \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12L\n\x12\x64istance_placement\x18\x0c \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12M\n\x13\x64istance_removement\x18\r \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x1a\xc0\x02\n\rFeatureConfig\x12I\n\thistogram\x18\x01 \x01(\x0b\x32\x34.long_metric.SimAgentMetricsConfig.HistogramEstimateH\x00\x12R\n\x0ekernel_density\x18\x02 \x01(\x0b\x32\x38.long_metric.SimAgentMetricsConfig.KernelDensityEstimateH\x00\x12I\n\tbernoulli\x18\x03 \x01(\x0b\x32\x34.long_metric.SimAgentMetricsConfig.BernoulliEstimateH\x00\x12\x1d\n\x15independent_timesteps\x18\x04 \x01(\x08\x12\x19\n\x11metametric_weight\x18\x05 \x01(\x02\x42\x0b\n\testimator\x1av\n\x11HistogramEstimate\x12\x0f\n\x07min_val\x18\x01 \x01(\x02\x12\x0f\n\x07max_val\x18\x02 \x01(\x02\x12\x10\n\x08num_bins\x18\x03 \x01(\x05\x12-\n\x1e\x61\x64\x64itive_smoothing_pseudocount\x18\x04 \x01(\x02:\x05\x30.001\x1a*\n\x15KernelDensityEstimate\x12\x11\n\tbandwidth\x18\x01 \x01(\x02\x1a\x42\n\x11\x42\x65rnoulliEstimate\x12-\n\x1e\x61\x64\x64itive_smoothing_pseudocount\x18\x04 \x01(\x02:\x05\x30.001\"\xbf\x05\n\x0fSimAgentMetrics\x12\x13\n\x0bscenario_id\x18\x01 \x01(\t\x12\x12\n\nmetametric\x18\x02 \x01(\x02\x12\"\n\x1a\x61verage_displacement_error\x18\x03 \x01(\x02\x12&\n\x1emin_average_displacement_error\x18\x13 \x01(\x02\x12\x1f\n\x17linear_speed_likelihood\x18\x04 \x01(\x02\x12&\n\x1elinear_acceleration_likelihood\x18\x05 \x01(\x02\x12 \n\x18\x61ngular_speed_likelihood\x18\x06 \x01(\x02\x12\'\n\x1f\x61ngular_acceleration_likelihood\x18\x07 \x01(\x02\x12-\n%distance_to_nearest_object_likelihood\x18\x08 \x01(\x02\x12\'\n\x1f\x63ollision_indication_likelihood\x18\t \x01(\x02\x12$\n\x1ctime_to_collision_likelihood\x18\n 
\x01(\x02\x12(\n distance_to_road_edge_likelihood\x18\x0b \x01(\x02\x12%\n\x1doffroad_indication_likelihood\x18\x0c \x01(\x02\x12 \n\x18num_placement_likelihood\x18\r \x01(\x02\x12!\n\x19num_removement_likelihood\x18\x0e \x01(\x02\x12%\n\x1d\x64istance_placement_likelihood\x18\x0f \x01(\x02\x12&\n\x1e\x64istance_removement_likelihood\x18\x10 \x01(\x02\x12 \n\x18simulated_collision_rate\x18\x11 \x01(\x02\x12\x1e\n\x16simulated_offroad_rate\x18\x12 \x01(\x02\"\xfe\x01\n\x18SimAgentsBucketedMetrics\x12\x1b\n\x13realism_meta_metric\x18\x01 \x01(\x02\x12\x19\n\x11kinematic_metrics\x18\x02 \x01(\x02\x12\x1b\n\x13interactive_metrics\x18\x05 \x01(\x02\x12\x19\n\x11map_based_metrics\x18\x06 \x01(\x02\x12\x1f\n\x17placement_based_metrics\x18\x07 \x01(\x02\x12\x0f\n\x07min_ade\x18\x08 \x01(\x02\x12 \n\x18simulated_collision_rate\x18\t \x01(\x02\x12\x1e\n\x16simulated_offroad_rate\x18\n \x01(\x02' +) + + + + +_SIMAGENTMETRICSCONFIG_FEATURECONFIG = _descriptor.Descriptor( + name='FeatureConfig', + full_name='long_metric.SimAgentMetricsConfig.FeatureConfig', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='histogram', full_name='long_metric.SimAgentMetricsConfig.FeatureConfig.histogram', index=0, + number=1, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='kernel_density', full_name='long_metric.SimAgentMetricsConfig.FeatureConfig.kernel_density', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='bernoulli', full_name='long_metric.SimAgentMetricsConfig.FeatureConfig.bernoulli', index=2, + number=3, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='independent_timesteps', full_name='long_metric.SimAgentMetricsConfig.FeatureConfig.independent_timesteps', index=3, + number=4, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='metametric_weight', full_name='long_metric.SimAgentMetricsConfig.FeatureConfig.metametric_weight', index=4, + number=5, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name='estimator', 
full_name='long_metric.SimAgentMetricsConfig.FeatureConfig.estimator', + index=0, containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[]), + ], + serialized_start=1091, + serialized_end=1411, +) + +_SIMAGENTMETRICSCONFIG_HISTOGRAMESTIMATE = _descriptor.Descriptor( + name='HistogramEstimate', + full_name='long_metric.SimAgentMetricsConfig.HistogramEstimate', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='min_val', full_name='long_metric.SimAgentMetricsConfig.HistogramEstimate.min_val', index=0, + number=1, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='max_val', full_name='long_metric.SimAgentMetricsConfig.HistogramEstimate.max_val', index=1, + number=2, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='num_bins', full_name='long_metric.SimAgentMetricsConfig.HistogramEstimate.num_bins', index=2, + number=3, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='additive_smoothing_pseudocount', full_name='long_metric.SimAgentMetricsConfig.HistogramEstimate.additive_smoothing_pseudocount', index=3, + number=4, type=2, cpp_type=6, label=1, + has_default_value=True, default_value=float(0.001), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1413, + serialized_end=1531, +) + +_SIMAGENTMETRICSCONFIG_KERNELDENSITYESTIMATE = _descriptor.Descriptor( + name='KernelDensityEstimate', + full_name='long_metric.SimAgentMetricsConfig.KernelDensityEstimate', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='bandwidth', full_name='long_metric.SimAgentMetricsConfig.KernelDensityEstimate.bandwidth', index=0, + number=1, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1533, + serialized_end=1575, +) + +_SIMAGENTMETRICSCONFIG_BERNOULLIESTIMATE = _descriptor.Descriptor( + name='BernoulliEstimate', + full_name='long_metric.SimAgentMetricsConfig.BernoulliEstimate', + 
filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='additive_smoothing_pseudocount', full_name='long_metric.SimAgentMetricsConfig.BernoulliEstimate.additive_smoothing_pseudocount', index=0, + number=4, type=2, cpp_type=6, label=1, + has_default_value=True, default_value=float(0.001), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1577, + serialized_end=1643, +) + +_SIMAGENTMETRICSCONFIG = _descriptor.Descriptor( + name='SimAgentMetricsConfig', + full_name='long_metric.SimAgentMetricsConfig', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='linear_speed', full_name='long_metric.SimAgentMetricsConfig.linear_speed', index=0, + number=1, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='linear_acceleration', full_name='long_metric.SimAgentMetricsConfig.linear_acceleration', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='angular_speed', full_name='long_metric.SimAgentMetricsConfig.angular_speed', index=2, + number=3, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='angular_acceleration', full_name='long_metric.SimAgentMetricsConfig.angular_acceleration', index=3, + number=4, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='distance_to_nearest_object', full_name='long_metric.SimAgentMetricsConfig.distance_to_nearest_object', index=4, + number=5, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='collision_indication', full_name='long_metric.SimAgentMetricsConfig.collision_indication', index=5, + number=6, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, 
create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='time_to_collision', full_name='long_metric.SimAgentMetricsConfig.time_to_collision', index=6, + number=7, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='distance_to_road_edge', full_name='long_metric.SimAgentMetricsConfig.distance_to_road_edge', index=7, + number=8, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='offroad_indication', full_name='long_metric.SimAgentMetricsConfig.offroad_indication', index=8, + number=9, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='num_placement', full_name='long_metric.SimAgentMetricsConfig.num_placement', index=9, + number=10, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='num_removement', full_name='long_metric.SimAgentMetricsConfig.num_removement', index=10, + number=11, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='distance_placement', full_name='long_metric.SimAgentMetricsConfig.distance_placement', index=11, + number=12, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='distance_removement', full_name='long_metric.SimAgentMetricsConfig.distance_removement', index=12, + number=13, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[_SIMAGENTMETRICSCONFIG_FEATURECONFIG, _SIMAGENTMETRICSCONFIG_HISTOGRAMESTIMATE, _SIMAGENTMETRICSCONFIG_KERNELDENSITYESTIMATE, _SIMAGENTMETRICSCONFIG_BERNOULLIESTIMATE, ], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=55, + serialized_end=1643, +) + + +_SIMAGENTMETRICS = _descriptor.Descriptor( + name='SimAgentMetrics', + full_name='long_metric.SimAgentMetrics', + filename=None, + file=DESCRIPTOR, + containing_type=None, + 
create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='scenario_id', full_name='long_metric.SimAgentMetrics.scenario_id', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='metametric', full_name='long_metric.SimAgentMetrics.metametric', index=1, + number=2, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='average_displacement_error', full_name='long_metric.SimAgentMetrics.average_displacement_error', index=2, + number=3, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='min_average_displacement_error', full_name='long_metric.SimAgentMetrics.min_average_displacement_error', index=3, + number=19, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='linear_speed_likelihood', full_name='long_metric.SimAgentMetrics.linear_speed_likelihood', index=4, + number=4, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='linear_acceleration_likelihood', full_name='long_metric.SimAgentMetrics.linear_acceleration_likelihood', index=5, + number=5, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='angular_speed_likelihood', full_name='long_metric.SimAgentMetrics.angular_speed_likelihood', index=6, + number=6, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='angular_acceleration_likelihood', full_name='long_metric.SimAgentMetrics.angular_acceleration_likelihood', index=7, + number=7, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='distance_to_nearest_object_likelihood', 
full_name='long_metric.SimAgentMetrics.distance_to_nearest_object_likelihood', index=8, + number=8, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='collision_indication_likelihood', full_name='long_metric.SimAgentMetrics.collision_indication_likelihood', index=9, + number=9, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='time_to_collision_likelihood', full_name='long_metric.SimAgentMetrics.time_to_collision_likelihood', index=10, + number=10, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='distance_to_road_edge_likelihood', full_name='long_metric.SimAgentMetrics.distance_to_road_edge_likelihood', index=11, + number=11, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='offroad_indication_likelihood', full_name='long_metric.SimAgentMetrics.offroad_indication_likelihood', index=12, + number=12, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='num_placement_likelihood', full_name='long_metric.SimAgentMetrics.num_placement_likelihood', index=13, + number=13, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='num_removement_likelihood', full_name='long_metric.SimAgentMetrics.num_removement_likelihood', index=14, + number=14, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='distance_placement_likelihood', full_name='long_metric.SimAgentMetrics.distance_placement_likelihood', index=15, + number=15, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='distance_removement_likelihood', 
full_name='long_metric.SimAgentMetrics.distance_removement_likelihood', index=16, + number=16, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='simulated_collision_rate', full_name='long_metric.SimAgentMetrics.simulated_collision_rate', index=17, + number=17, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='simulated_offroad_rate', full_name='long_metric.SimAgentMetrics.simulated_offroad_rate', index=18, + number=18, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1646, + serialized_end=2349, +) + + +_SIMAGENTSBUCKETEDMETRICS = _descriptor.Descriptor( + name='SimAgentsBucketedMetrics', + full_name='long_metric.SimAgentsBucketedMetrics', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='realism_meta_metric', full_name='long_metric.SimAgentsBucketedMetrics.realism_meta_metric', index=0, + number=1, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='kinematic_metrics', full_name='long_metric.SimAgentsBucketedMetrics.kinematic_metrics', index=1, + number=2, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='interactive_metrics', full_name='long_metric.SimAgentsBucketedMetrics.interactive_metrics', index=2, + number=5, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='map_based_metrics', full_name='long_metric.SimAgentsBucketedMetrics.map_based_metrics', index=3, + number=6, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='placement_based_metrics', full_name='long_metric.SimAgentsBucketedMetrics.placement_based_metrics', index=4, + 
number=7, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='min_ade', full_name='long_metric.SimAgentsBucketedMetrics.min_ade', index=5, + number=8, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='simulated_collision_rate', full_name='long_metric.SimAgentsBucketedMetrics.simulated_collision_rate', index=6, + number=9, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='simulated_offroad_rate', full_name='long_metric.SimAgentsBucketedMetrics.simulated_offroad_rate', index=7, + number=10, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=2352, + serialized_end=2606, +) + +_SIMAGENTMETRICSCONFIG_FEATURECONFIG.fields_by_name['histogram'].message_type = _SIMAGENTMETRICSCONFIG_HISTOGRAMESTIMATE +_SIMAGENTMETRICSCONFIG_FEATURECONFIG.fields_by_name['kernel_density'].message_type = _SIMAGENTMETRICSCONFIG_KERNELDENSITYESTIMATE +_SIMAGENTMETRICSCONFIG_FEATURECONFIG.fields_by_name['bernoulli'].message_type = _SIMAGENTMETRICSCONFIG_BERNOULLIESTIMATE +_SIMAGENTMETRICSCONFIG_FEATURECONFIG.containing_type = _SIMAGENTMETRICSCONFIG +_SIMAGENTMETRICSCONFIG_FEATURECONFIG.oneofs_by_name['estimator'].fields.append( + _SIMAGENTMETRICSCONFIG_FEATURECONFIG.fields_by_name['histogram']) +_SIMAGENTMETRICSCONFIG_FEATURECONFIG.fields_by_name['histogram'].containing_oneof = _SIMAGENTMETRICSCONFIG_FEATURECONFIG.oneofs_by_name['estimator'] +_SIMAGENTMETRICSCONFIG_FEATURECONFIG.oneofs_by_name['estimator'].fields.append( + _SIMAGENTMETRICSCONFIG_FEATURECONFIG.fields_by_name['kernel_density']) +_SIMAGENTMETRICSCONFIG_FEATURECONFIG.fields_by_name['kernel_density'].containing_oneof = _SIMAGENTMETRICSCONFIG_FEATURECONFIG.oneofs_by_name['estimator'] +_SIMAGENTMETRICSCONFIG_FEATURECONFIG.oneofs_by_name['estimator'].fields.append( + _SIMAGENTMETRICSCONFIG_FEATURECONFIG.fields_by_name['bernoulli']) +_SIMAGENTMETRICSCONFIG_FEATURECONFIG.fields_by_name['bernoulli'].containing_oneof = _SIMAGENTMETRICSCONFIG_FEATURECONFIG.oneofs_by_name['estimator'] +_SIMAGENTMETRICSCONFIG_HISTOGRAMESTIMATE.containing_type = _SIMAGENTMETRICSCONFIG +_SIMAGENTMETRICSCONFIG_KERNELDENSITYESTIMATE.containing_type = _SIMAGENTMETRICSCONFIG +_SIMAGENTMETRICSCONFIG_BERNOULLIESTIMATE.containing_type = _SIMAGENTMETRICSCONFIG +_SIMAGENTMETRICSCONFIG.fields_by_name['linear_speed'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG 
+_SIMAGENTMETRICSCONFIG.fields_by_name['linear_acceleration'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG +_SIMAGENTMETRICSCONFIG.fields_by_name['angular_speed'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG +_SIMAGENTMETRICSCONFIG.fields_by_name['angular_acceleration'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG +_SIMAGENTMETRICSCONFIG.fields_by_name['distance_to_nearest_object'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG +_SIMAGENTMETRICSCONFIG.fields_by_name['collision_indication'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG +_SIMAGENTMETRICSCONFIG.fields_by_name['time_to_collision'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG +_SIMAGENTMETRICSCONFIG.fields_by_name['distance_to_road_edge'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG +_SIMAGENTMETRICSCONFIG.fields_by_name['offroad_indication'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG +_SIMAGENTMETRICSCONFIG.fields_by_name['num_placement'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG +_SIMAGENTMETRICSCONFIG.fields_by_name['num_removement'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG +_SIMAGENTMETRICSCONFIG.fields_by_name['distance_placement'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG +_SIMAGENTMETRICSCONFIG.fields_by_name['distance_removement'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG +DESCRIPTOR.message_types_by_name['SimAgentMetricsConfig'] = _SIMAGENTMETRICSCONFIG +DESCRIPTOR.message_types_by_name['SimAgentMetrics'] = _SIMAGENTMETRICS +DESCRIPTOR.message_types_by_name['SimAgentsBucketedMetrics'] = _SIMAGENTSBUCKETEDMETRICS +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +SimAgentMetricsConfig = _reflection.GeneratedProtocolMessageType('SimAgentMetricsConfig', (_message.Message,), { + + 'FeatureConfig' : _reflection.GeneratedProtocolMessageType('FeatureConfig', (_message.Message,), { + 'DESCRIPTOR' : _SIMAGENTMETRICSCONFIG_FEATURECONFIG, + '__module__' : 'dev.metrics.protos.long_metrics_pb2' + # @@protoc_insertion_point(class_scope:long_metric.SimAgentMetricsConfig.FeatureConfig) + }) + , + + 'HistogramEstimate' : _reflection.GeneratedProtocolMessageType('HistogramEstimate', (_message.Message,), { + 'DESCRIPTOR' : _SIMAGENTMETRICSCONFIG_HISTOGRAMESTIMATE, + '__module__' : 'dev.metrics.protos.long_metrics_pb2' + # @@protoc_insertion_point(class_scope:long_metric.SimAgentMetricsConfig.HistogramEstimate) + }) + , + + 'KernelDensityEstimate' : _reflection.GeneratedProtocolMessageType('KernelDensityEstimate', (_message.Message,), { + 'DESCRIPTOR' : _SIMAGENTMETRICSCONFIG_KERNELDENSITYESTIMATE, + '__module__' : 'dev.metrics.protos.long_metrics_pb2' + # @@protoc_insertion_point(class_scope:long_metric.SimAgentMetricsConfig.KernelDensityEstimate) + }) + , + + 'BernoulliEstimate' : _reflection.GeneratedProtocolMessageType('BernoulliEstimate', (_message.Message,), { + 'DESCRIPTOR' : _SIMAGENTMETRICSCONFIG_BERNOULLIESTIMATE, + '__module__' : 'dev.metrics.protos.long_metrics_pb2' + # @@protoc_insertion_point(class_scope:long_metric.SimAgentMetricsConfig.BernoulliEstimate) + }) + , + 'DESCRIPTOR' : _SIMAGENTMETRICSCONFIG, + '__module__' : 'dev.metrics.protos.long_metrics_pb2' + # @@protoc_insertion_point(class_scope:long_metric.SimAgentMetricsConfig) + }) +_sym_db.RegisterMessage(SimAgentMetricsConfig) +_sym_db.RegisterMessage(SimAgentMetricsConfig.FeatureConfig) +_sym_db.RegisterMessage(SimAgentMetricsConfig.HistogramEstimate) +_sym_db.RegisterMessage(SimAgentMetricsConfig.KernelDensityEstimate) 
+_sym_db.RegisterMessage(SimAgentMetricsConfig.BernoulliEstimate) + +SimAgentMetrics = _reflection.GeneratedProtocolMessageType('SimAgentMetrics', (_message.Message,), { + 'DESCRIPTOR' : _SIMAGENTMETRICS, + '__module__' : 'dev.metrics.protos.long_metrics_pb2' + # @@protoc_insertion_point(class_scope:long_metric.SimAgentMetrics) + }) +_sym_db.RegisterMessage(SimAgentMetrics) + +SimAgentsBucketedMetrics = _reflection.GeneratedProtocolMessageType('SimAgentsBucketedMetrics', (_message.Message,), { + 'DESCRIPTOR' : _SIMAGENTSBUCKETEDMETRICS, + '__module__' : 'dev.metrics.protos.long_metrics_pb2' + # @@protoc_insertion_point(class_scope:long_metric.SimAgentsBucketedMetrics) + }) +_sym_db.RegisterMessage(SimAgentsBucketedMetrics) + + +# @@protoc_insertion_point(module_scope) diff --git a/backups/dev/metrics/protos/map_pb2.py b/backups/dev/metrics/protos/map_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..e5ac4fc2ba76c3707206a4f59237ba86bf3913ab --- /dev/null +++ b/backups/dev/metrics/protos/map_pb2.py @@ -0,0 +1,1070 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: dev/metrics/protos/map.proto + +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='dev/metrics/protos/map.proto', + package='long_metric', + syntax='proto2', + serialized_options=None, + create_key=_descriptor._internal_create_key, + serialized_pb=b'\n\x1c\x64\x65v/metrics/protos/map.proto\x12\x0blong_metric\"g\n\x03Map\x12-\n\x0cmap_features\x18\x01 \x03(\x0b\x32\x17.long_metric.MapFeature\x12\x31\n\x0e\x64ynamic_states\x18\x02 \x03(\x0b\x32\x19.long_metric.DynamicState\"c\n\x0c\x44ynamicState\x12\x19\n\x11timestamp_seconds\x18\x01 \x01(\x01\x12\x38\n\x0blane_states\x18\x02 \x03(\x0b\x32#.long_metric.TrafficSignalLaneState\"\xfe\x02\n\x16TrafficSignalLaneState\x12\x0c\n\x04lane\x18\x01 \x01(\x03\x12\x38\n\x05state\x18\x02 \x01(\x0e\x32).long_metric.TrafficSignalLaneState.State\x12)\n\nstop_point\x18\x03 \x01(\x0b\x32\x15.long_metric.MapPoint\"\xf0\x01\n\x05State\x12\x16\n\x12LANE_STATE_UNKNOWN\x10\x00\x12\x19\n\x15LANE_STATE_ARROW_STOP\x10\x01\x12\x1c\n\x18LANE_STATE_ARROW_CAUTION\x10\x02\x12\x17\n\x13LANE_STATE_ARROW_GO\x10\x03\x12\x13\n\x0fLANE_STATE_STOP\x10\x04\x12\x16\n\x12LANE_STATE_CAUTION\x10\x05\x12\x11\n\rLANE_STATE_GO\x10\x06\x12\x1c\n\x18LANE_STATE_FLASHING_STOP\x10\x07\x12\x1f\n\x1bLANE_STATE_FLASHING_CAUTION\x10\x08\"\xdb\x02\n\nMapFeature\x12\n\n\x02id\x18\x01 \x01(\x03\x12\'\n\x04lane\x18\x03 \x01(\x0b\x32\x17.long_metric.LaneCenterH\x00\x12*\n\troad_line\x18\x04 \x01(\x0b\x32\x15.long_metric.RoadLineH\x00\x12*\n\troad_edge\x18\x05 \x01(\x0b\x32\x15.long_metric.RoadEdgeH\x00\x12*\n\tstop_sign\x18\x07 \x01(\x0b\x32\x15.long_metric.StopSignH\x00\x12+\n\tcrosswalk\x18\x08 \x01(\x0b\x32\x16.long_metric.CrosswalkH\x00\x12,\n\nspeed_bump\x18\t \x01(\x0b\x32\x16.long_metric.SpeedBumpH\x00\x12)\n\x08\x64riveway\x18\n \x01(\x0b\x32\x15.long_metric.DrivewayH\x00\x42\x0e\n\x0c\x66\x65\x61ture_data\"+\n\x08MapPoint\x12\t\n\x01x\x18\x01 \x01(\x01\x12\t\n\x01y\x18\x02 \x01(\x01\x12\t\n\x01z\x18\x03 \x01(\x01\"\x9b\x01\n\x0f\x42oundarySegment\x12\x18\n\x10lane_start_index\x18\x01 \x01(\x05\x12\x16\n\x0elane_end_index\x18\x02 
\x01(\x05\x12\x1b\n\x13\x62oundary_feature_id\x18\x03 \x01(\x03\x12\x39\n\rboundary_type\x18\x04 \x01(\x0e\x32\".long_metric.RoadLine.RoadLineType\"\xc0\x01\n\x0cLaneNeighbor\x12\x12\n\nfeature_id\x18\x01 \x01(\x03\x12\x18\n\x10self_start_index\x18\x02 \x01(\x05\x12\x16\n\x0eself_end_index\x18\x03 \x01(\x05\x12\x1c\n\x14neighbor_start_index\x18\x04 \x01(\x05\x12\x1a\n\x12neighbor_end_index\x18\x05 \x01(\x05\x12\x30\n\nboundaries\x18\x06 \x03(\x0b\x32\x1c.long_metric.BoundarySegment\"\xfb\x03\n\nLaneCenter\x12\x17\n\x0fspeed_limit_mph\x18\x01 \x01(\x01\x12.\n\x04type\x18\x02 \x01(\x0e\x32 .long_metric.LaneCenter.LaneType\x12\x15\n\rinterpolating\x18\x03 \x01(\x08\x12\'\n\x08polyline\x18\x08 \x03(\x0b\x32\x15.long_metric.MapPoint\x12\x17\n\x0b\x65ntry_lanes\x18\t \x03(\x03\x42\x02\x10\x01\x12\x16\n\nexit_lanes\x18\n \x03(\x03\x42\x02\x10\x01\x12\x35\n\x0fleft_boundaries\x18\r \x03(\x0b\x32\x1c.long_metric.BoundarySegment\x12\x36\n\x10right_boundaries\x18\x0e \x03(\x0b\x32\x1c.long_metric.BoundarySegment\x12\x31\n\x0eleft_neighbors\x18\x0b \x03(\x0b\x32\x19.long_metric.LaneNeighbor\x12\x32\n\x0fright_neighbors\x18\x0c \x03(\x0b\x32\x19.long_metric.LaneNeighbor\"]\n\x08LaneType\x12\x12\n\x0eTYPE_UNDEFINED\x10\x00\x12\x10\n\x0cTYPE_FREEWAY\x10\x01\x12\x17\n\x13TYPE_SURFACE_STREET\x10\x02\x12\x12\n\x0eTYPE_BIKE_LANE\x10\x03\"\xbf\x01\n\x08RoadEdge\x12\x30\n\x04type\x18\x01 \x01(\x0e\x32\".long_metric.RoadEdge.RoadEdgeType\x12\'\n\x08polyline\x18\x02 \x03(\x0b\x32\x15.long_metric.MapPoint\"X\n\x0cRoadEdgeType\x12\x10\n\x0cTYPE_UNKNOWN\x10\x00\x12\x1b\n\x17TYPE_ROAD_EDGE_BOUNDARY\x10\x01\x12\x19\n\x15TYPE_ROAD_EDGE_MEDIAN\x10\x02\"\xfa\x02\n\x08RoadLine\x12\x30\n\x04type\x18\x01 \x01(\x0e\x32\".long_metric.RoadLine.RoadLineType\x12\'\n\x08polyline\x18\x02 \x03(\x0b\x32\x15.long_metric.MapPoint\"\x92\x02\n\x0cRoadLineType\x12\x10\n\x0cTYPE_UNKNOWN\x10\x00\x12\x1c\n\x18TYPE_BROKEN_SINGLE_WHITE\x10\x01\x12\x1b\n\x17TYPE_SOLID_SINGLE_WHITE\x10\x02\x12\x1b\n\x17TYPE_SOLID_DOUBLE_WHITE\x10\x03\x12\x1d\n\x19TYPE_BROKEN_SINGLE_YELLOW\x10\x04\x12\x1d\n\x19TYPE_BROKEN_DOUBLE_YELLOW\x10\x05\x12\x1c\n\x18TYPE_SOLID_SINGLE_YELLOW\x10\x06\x12\x1c\n\x18TYPE_SOLID_DOUBLE_YELLOW\x10\x07\x12\x1e\n\x1aTYPE_PASSING_DOUBLE_YELLOW\x10\x08\"A\n\x08StopSign\x12\x0c\n\x04lane\x18\x01 \x03(\x03\x12\'\n\x08position\x18\x02 \x01(\x0b\x32\x15.long_metric.MapPoint\"3\n\tCrosswalk\x12&\n\x07polygon\x18\x01 \x03(\x0b\x32\x15.long_metric.MapPoint\"3\n\tSpeedBump\x12&\n\x07polygon\x18\x01 \x03(\x0b\x32\x15.long_metric.MapPoint\"2\n\x08\x44riveway\x12&\n\x07polygon\x18\x01 \x03(\x0b\x32\x15.long_metric.MapPoint' +) + + + +_TRAFFICSIGNALLANESTATE_STATE = _descriptor.EnumDescriptor( + name='State', + full_name='long_metric.TrafficSignalLaneState.State', + filename=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + values=[ + _descriptor.EnumValueDescriptor( + name='LANE_STATE_UNKNOWN', index=0, number=0, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='LANE_STATE_ARROW_STOP', index=1, number=1, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='LANE_STATE_ARROW_CAUTION', index=2, number=2, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='LANE_STATE_ARROW_GO', index=3, number=3, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + 
_descriptor.EnumValueDescriptor( + name='LANE_STATE_STOP', index=4, number=4, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='LANE_STATE_CAUTION', index=5, number=5, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='LANE_STATE_GO', index=6, number=6, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='LANE_STATE_FLASHING_STOP', index=7, number=7, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='LANE_STATE_FLASHING_CAUTION', index=8, number=8, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + ], + containing_type=None, + serialized_options=None, + serialized_start=394, + serialized_end=634, +) +_sym_db.RegisterEnumDescriptor(_TRAFFICSIGNALLANESTATE_STATE) + +_LANECENTER_LANETYPE = _descriptor.EnumDescriptor( + name='LaneType', + full_name='long_metric.LaneCenter.LaneType', + filename=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + values=[ + _descriptor.EnumValueDescriptor( + name='TYPE_UNDEFINED', index=0, number=0, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='TYPE_FREEWAY', index=1, number=1, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='TYPE_SURFACE_STREET', index=2, number=2, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='TYPE_BIKE_LANE', index=3, number=3, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + ], + containing_type=None, + serialized_options=None, + serialized_start=1799, + serialized_end=1892, +) +_sym_db.RegisterEnumDescriptor(_LANECENTER_LANETYPE) + +_ROADEDGE_ROADEDGETYPE = _descriptor.EnumDescriptor( + name='RoadEdgeType', + full_name='long_metric.RoadEdge.RoadEdgeType', + filename=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + values=[ + _descriptor.EnumValueDescriptor( + name='TYPE_UNKNOWN', index=0, number=0, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='TYPE_ROAD_EDGE_BOUNDARY', index=1, number=1, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='TYPE_ROAD_EDGE_MEDIAN', index=2, number=2, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + ], + containing_type=None, + serialized_options=None, + serialized_start=1998, + serialized_end=2086, +) +_sym_db.RegisterEnumDescriptor(_ROADEDGE_ROADEDGETYPE) + +_ROADLINE_ROADLINETYPE = _descriptor.EnumDescriptor( + name='RoadLineType', + full_name='long_metric.RoadLine.RoadLineType', + filename=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + values=[ + _descriptor.EnumValueDescriptor( + name='TYPE_UNKNOWN', index=0, number=0, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='TYPE_BROKEN_SINGLE_WHITE', index=1, number=1, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + 
_descriptor.EnumValueDescriptor( + name='TYPE_SOLID_SINGLE_WHITE', index=2, number=2, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='TYPE_SOLID_DOUBLE_WHITE', index=3, number=3, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='TYPE_BROKEN_SINGLE_YELLOW', index=4, number=4, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='TYPE_BROKEN_DOUBLE_YELLOW', index=5, number=5, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='TYPE_SOLID_SINGLE_YELLOW', index=6, number=6, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='TYPE_SOLID_DOUBLE_YELLOW', index=7, number=7, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='TYPE_PASSING_DOUBLE_YELLOW', index=8, number=8, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + ], + containing_type=None, + serialized_options=None, + serialized_start=2193, + serialized_end=2467, +) +_sym_db.RegisterEnumDescriptor(_ROADLINE_ROADLINETYPE) + + +_MAP = _descriptor.Descriptor( + name='Map', + full_name='long_metric.Map', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='map_features', full_name='long_metric.Map.map_features', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='dynamic_states', full_name='long_metric.Map.dynamic_states', index=1, + number=2, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=45, + serialized_end=148, +) + + +_DYNAMICSTATE = _descriptor.Descriptor( + name='DynamicState', + full_name='long_metric.DynamicState', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='timestamp_seconds', full_name='long_metric.DynamicState.timestamp_seconds', index=0, + number=1, type=1, cpp_type=5, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='lane_states', full_name='long_metric.DynamicState.lane_states', index=1, + number=2, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + 
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=150, + serialized_end=249, +) + + +_TRAFFICSIGNALLANESTATE = _descriptor.Descriptor( + name='TrafficSignalLaneState', + full_name='long_metric.TrafficSignalLaneState', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='lane', full_name='long_metric.TrafficSignalLaneState.lane', index=0, + number=1, type=3, cpp_type=2, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='state', full_name='long_metric.TrafficSignalLaneState.state', index=1, + number=2, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='stop_point', full_name='long_metric.TrafficSignalLaneState.stop_point', index=2, + number=3, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + _TRAFFICSIGNALLANESTATE_STATE, + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=252, + serialized_end=634, +) + + +_MAPFEATURE = _descriptor.Descriptor( + name='MapFeature', + full_name='long_metric.MapFeature', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='id', full_name='long_metric.MapFeature.id', index=0, + number=1, type=3, cpp_type=2, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='lane', full_name='long_metric.MapFeature.lane', index=1, + number=3, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='road_line', full_name='long_metric.MapFeature.road_line', index=2, + number=4, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='road_edge', full_name='long_metric.MapFeature.road_edge', index=3, + number=5, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + 
message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='stop_sign', full_name='long_metric.MapFeature.stop_sign', index=4, + number=7, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='crosswalk', full_name='long_metric.MapFeature.crosswalk', index=5, + number=8, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='speed_bump', full_name='long_metric.MapFeature.speed_bump', index=6, + number=9, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='driveway', full_name='long_metric.MapFeature.driveway', index=7, + number=10, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name='feature_data', full_name='long_metric.MapFeature.feature_data', + index=0, containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[]), + ], + serialized_start=637, + serialized_end=984, +) + + +_MAPPOINT = _descriptor.Descriptor( + name='MapPoint', + full_name='long_metric.MapPoint', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='x', full_name='long_metric.MapPoint.x', index=0, + number=1, type=1, cpp_type=5, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='y', full_name='long_metric.MapPoint.y', index=1, + number=2, type=1, cpp_type=5, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='z', full_name='long_metric.MapPoint.z', index=2, + number=3, type=1, cpp_type=5, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + 
enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=986, + serialized_end=1029, +) + + +_BOUNDARYSEGMENT = _descriptor.Descriptor( + name='BoundarySegment', + full_name='long_metric.BoundarySegment', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='lane_start_index', full_name='long_metric.BoundarySegment.lane_start_index', index=0, + number=1, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='lane_end_index', full_name='long_metric.BoundarySegment.lane_end_index', index=1, + number=2, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='boundary_feature_id', full_name='long_metric.BoundarySegment.boundary_feature_id', index=2, + number=3, type=3, cpp_type=2, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='boundary_type', full_name='long_metric.BoundarySegment.boundary_type', index=3, + number=4, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1032, + serialized_end=1187, +) + + +_LANENEIGHBOR = _descriptor.Descriptor( + name='LaneNeighbor', + full_name='long_metric.LaneNeighbor', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='feature_id', full_name='long_metric.LaneNeighbor.feature_id', index=0, + number=1, type=3, cpp_type=2, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='self_start_index', full_name='long_metric.LaneNeighbor.self_start_index', index=1, + number=2, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='self_end_index', full_name='long_metric.LaneNeighbor.self_end_index', index=2, + number=3, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, 
extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='neighbor_start_index', full_name='long_metric.LaneNeighbor.neighbor_start_index', index=3, + number=4, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='neighbor_end_index', full_name='long_metric.LaneNeighbor.neighbor_end_index', index=4, + number=5, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='boundaries', full_name='long_metric.LaneNeighbor.boundaries', index=5, + number=6, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1190, + serialized_end=1382, +) + + +_LANECENTER = _descriptor.Descriptor( + name='LaneCenter', + full_name='long_metric.LaneCenter', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='speed_limit_mph', full_name='long_metric.LaneCenter.speed_limit_mph', index=0, + number=1, type=1, cpp_type=5, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='type', full_name='long_metric.LaneCenter.type', index=1, + number=2, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='interpolating', full_name='long_metric.LaneCenter.interpolating', index=2, + number=3, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='polyline', full_name='long_metric.LaneCenter.polyline', index=3, + number=8, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='entry_lanes', full_name='long_metric.LaneCenter.entry_lanes', index=4, + number=9, type=3, cpp_type=2, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, 
containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=b'\020\001', file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='exit_lanes', full_name='long_metric.LaneCenter.exit_lanes', index=5, + number=10, type=3, cpp_type=2, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=b'\020\001', file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='left_boundaries', full_name='long_metric.LaneCenter.left_boundaries', index=6, + number=13, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='right_boundaries', full_name='long_metric.LaneCenter.right_boundaries', index=7, + number=14, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='left_neighbors', full_name='long_metric.LaneCenter.left_neighbors', index=8, + number=11, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='right_neighbors', full_name='long_metric.LaneCenter.right_neighbors', index=9, + number=12, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + _LANECENTER_LANETYPE, + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1385, + serialized_end=1892, +) + + +_ROADEDGE = _descriptor.Descriptor( + name='RoadEdge', + full_name='long_metric.RoadEdge', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='type', full_name='long_metric.RoadEdge.type', index=0, + number=1, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='polyline', full_name='long_metric.RoadEdge.polyline', index=1, + number=2, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + _ROADEDGE_ROADEDGETYPE, + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + 
extension_ranges=[], + oneofs=[ + ], + serialized_start=1895, + serialized_end=2086, +) + + +_ROADLINE = _descriptor.Descriptor( + name='RoadLine', + full_name='long_metric.RoadLine', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='type', full_name='long_metric.RoadLine.type', index=0, + number=1, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='polyline', full_name='long_metric.RoadLine.polyline', index=1, + number=2, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + _ROADLINE_ROADLINETYPE, + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=2089, + serialized_end=2467, +) + + +_STOPSIGN = _descriptor.Descriptor( + name='StopSign', + full_name='long_metric.StopSign', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='lane', full_name='long_metric.StopSign.lane', index=0, + number=1, type=3, cpp_type=2, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='position', full_name='long_metric.StopSign.position', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=2469, + serialized_end=2534, +) + + +_CROSSWALK = _descriptor.Descriptor( + name='Crosswalk', + full_name='long_metric.Crosswalk', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='polygon', full_name='long_metric.Crosswalk.polygon', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=2536, + serialized_end=2587, +) + + +_SPEEDBUMP = _descriptor.Descriptor( + name='SpeedBump', + full_name='long_metric.SpeedBump', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + 
_descriptor.FieldDescriptor( + name='polygon', full_name='long_metric.SpeedBump.polygon', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=2589, + serialized_end=2640, +) + + +_DRIVEWAY = _descriptor.Descriptor( + name='Driveway', + full_name='long_metric.Driveway', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='polygon', full_name='long_metric.Driveway.polygon', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=2642, + serialized_end=2692, +) + +_MAP.fields_by_name['map_features'].message_type = _MAPFEATURE +_MAP.fields_by_name['dynamic_states'].message_type = _DYNAMICSTATE +_DYNAMICSTATE.fields_by_name['lane_states'].message_type = _TRAFFICSIGNALLANESTATE +_TRAFFICSIGNALLANESTATE.fields_by_name['state'].enum_type = _TRAFFICSIGNALLANESTATE_STATE +_TRAFFICSIGNALLANESTATE.fields_by_name['stop_point'].message_type = _MAPPOINT +_TRAFFICSIGNALLANESTATE_STATE.containing_type = _TRAFFICSIGNALLANESTATE +_MAPFEATURE.fields_by_name['lane'].message_type = _LANECENTER +_MAPFEATURE.fields_by_name['road_line'].message_type = _ROADLINE +_MAPFEATURE.fields_by_name['road_edge'].message_type = _ROADEDGE +_MAPFEATURE.fields_by_name['stop_sign'].message_type = _STOPSIGN +_MAPFEATURE.fields_by_name['crosswalk'].message_type = _CROSSWALK +_MAPFEATURE.fields_by_name['speed_bump'].message_type = _SPEEDBUMP +_MAPFEATURE.fields_by_name['driveway'].message_type = _DRIVEWAY +_MAPFEATURE.oneofs_by_name['feature_data'].fields.append( + _MAPFEATURE.fields_by_name['lane']) +_MAPFEATURE.fields_by_name['lane'].containing_oneof = _MAPFEATURE.oneofs_by_name['feature_data'] +_MAPFEATURE.oneofs_by_name['feature_data'].fields.append( + _MAPFEATURE.fields_by_name['road_line']) +_MAPFEATURE.fields_by_name['road_line'].containing_oneof = _MAPFEATURE.oneofs_by_name['feature_data'] +_MAPFEATURE.oneofs_by_name['feature_data'].fields.append( + _MAPFEATURE.fields_by_name['road_edge']) +_MAPFEATURE.fields_by_name['road_edge'].containing_oneof = _MAPFEATURE.oneofs_by_name['feature_data'] +_MAPFEATURE.oneofs_by_name['feature_data'].fields.append( + _MAPFEATURE.fields_by_name['stop_sign']) +_MAPFEATURE.fields_by_name['stop_sign'].containing_oneof = _MAPFEATURE.oneofs_by_name['feature_data'] +_MAPFEATURE.oneofs_by_name['feature_data'].fields.append( + _MAPFEATURE.fields_by_name['crosswalk']) +_MAPFEATURE.fields_by_name['crosswalk'].containing_oneof = _MAPFEATURE.oneofs_by_name['feature_data'] +_MAPFEATURE.oneofs_by_name['feature_data'].fields.append( + _MAPFEATURE.fields_by_name['speed_bump']) +_MAPFEATURE.fields_by_name['speed_bump'].containing_oneof = 
_MAPFEATURE.oneofs_by_name['feature_data'] +_MAPFEATURE.oneofs_by_name['feature_data'].fields.append( + _MAPFEATURE.fields_by_name['driveway']) +_MAPFEATURE.fields_by_name['driveway'].containing_oneof = _MAPFEATURE.oneofs_by_name['feature_data'] +_BOUNDARYSEGMENT.fields_by_name['boundary_type'].enum_type = _ROADLINE_ROADLINETYPE +_LANENEIGHBOR.fields_by_name['boundaries'].message_type = _BOUNDARYSEGMENT +_LANECENTER.fields_by_name['type'].enum_type = _LANECENTER_LANETYPE +_LANECENTER.fields_by_name['polyline'].message_type = _MAPPOINT +_LANECENTER.fields_by_name['left_boundaries'].message_type = _BOUNDARYSEGMENT +_LANECENTER.fields_by_name['right_boundaries'].message_type = _BOUNDARYSEGMENT +_LANECENTER.fields_by_name['left_neighbors'].message_type = _LANENEIGHBOR +_LANECENTER.fields_by_name['right_neighbors'].message_type = _LANENEIGHBOR +_LANECENTER_LANETYPE.containing_type = _LANECENTER +_ROADEDGE.fields_by_name['type'].enum_type = _ROADEDGE_ROADEDGETYPE +_ROADEDGE.fields_by_name['polyline'].message_type = _MAPPOINT +_ROADEDGE_ROADEDGETYPE.containing_type = _ROADEDGE +_ROADLINE.fields_by_name['type'].enum_type = _ROADLINE_ROADLINETYPE +_ROADLINE.fields_by_name['polyline'].message_type = _MAPPOINT +_ROADLINE_ROADLINETYPE.containing_type = _ROADLINE +_STOPSIGN.fields_by_name['position'].message_type = _MAPPOINT +_CROSSWALK.fields_by_name['polygon'].message_type = _MAPPOINT +_SPEEDBUMP.fields_by_name['polygon'].message_type = _MAPPOINT +_DRIVEWAY.fields_by_name['polygon'].message_type = _MAPPOINT +DESCRIPTOR.message_types_by_name['Map'] = _MAP +DESCRIPTOR.message_types_by_name['DynamicState'] = _DYNAMICSTATE +DESCRIPTOR.message_types_by_name['TrafficSignalLaneState'] = _TRAFFICSIGNALLANESTATE +DESCRIPTOR.message_types_by_name['MapFeature'] = _MAPFEATURE +DESCRIPTOR.message_types_by_name['MapPoint'] = _MAPPOINT +DESCRIPTOR.message_types_by_name['BoundarySegment'] = _BOUNDARYSEGMENT +DESCRIPTOR.message_types_by_name['LaneNeighbor'] = _LANENEIGHBOR +DESCRIPTOR.message_types_by_name['LaneCenter'] = _LANECENTER +DESCRIPTOR.message_types_by_name['RoadEdge'] = _ROADEDGE +DESCRIPTOR.message_types_by_name['RoadLine'] = _ROADLINE +DESCRIPTOR.message_types_by_name['StopSign'] = _STOPSIGN +DESCRIPTOR.message_types_by_name['Crosswalk'] = _CROSSWALK +DESCRIPTOR.message_types_by_name['SpeedBump'] = _SPEEDBUMP +DESCRIPTOR.message_types_by_name['Driveway'] = _DRIVEWAY +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +Map = _reflection.GeneratedProtocolMessageType('Map', (_message.Message,), { + 'DESCRIPTOR' : _MAP, + '__module__' : 'dev.metrics.protos.map_pb2' + # @@protoc_insertion_point(class_scope:long_metric.Map) + }) +_sym_db.RegisterMessage(Map) + +DynamicState = _reflection.GeneratedProtocolMessageType('DynamicState', (_message.Message,), { + 'DESCRIPTOR' : _DYNAMICSTATE, + '__module__' : 'dev.metrics.protos.map_pb2' + # @@protoc_insertion_point(class_scope:long_metric.DynamicState) + }) +_sym_db.RegisterMessage(DynamicState) + +TrafficSignalLaneState = _reflection.GeneratedProtocolMessageType('TrafficSignalLaneState', (_message.Message,), { + 'DESCRIPTOR' : _TRAFFICSIGNALLANESTATE, + '__module__' : 'dev.metrics.protos.map_pb2' + # @@protoc_insertion_point(class_scope:long_metric.TrafficSignalLaneState) + }) +_sym_db.RegisterMessage(TrafficSignalLaneState) + +MapFeature = _reflection.GeneratedProtocolMessageType('MapFeature', (_message.Message,), { + 'DESCRIPTOR' : _MAPFEATURE, + '__module__' : 'dev.metrics.protos.map_pb2' + # @@protoc_insertion_point(class_scope:long_metric.MapFeature) + }) 
+_sym_db.RegisterMessage(MapFeature) + +MapPoint = _reflection.GeneratedProtocolMessageType('MapPoint', (_message.Message,), { + 'DESCRIPTOR' : _MAPPOINT, + '__module__' : 'dev.metrics.protos.map_pb2' + # @@protoc_insertion_point(class_scope:long_metric.MapPoint) + }) +_sym_db.RegisterMessage(MapPoint) + +BoundarySegment = _reflection.GeneratedProtocolMessageType('BoundarySegment', (_message.Message,), { + 'DESCRIPTOR' : _BOUNDARYSEGMENT, + '__module__' : 'dev.metrics.protos.map_pb2' + # @@protoc_insertion_point(class_scope:long_metric.BoundarySegment) + }) +_sym_db.RegisterMessage(BoundarySegment) + +LaneNeighbor = _reflection.GeneratedProtocolMessageType('LaneNeighbor', (_message.Message,), { + 'DESCRIPTOR' : _LANENEIGHBOR, + '__module__' : 'dev.metrics.protos.map_pb2' + # @@protoc_insertion_point(class_scope:long_metric.LaneNeighbor) + }) +_sym_db.RegisterMessage(LaneNeighbor) + +LaneCenter = _reflection.GeneratedProtocolMessageType('LaneCenter', (_message.Message,), { + 'DESCRIPTOR' : _LANECENTER, + '__module__' : 'dev.metrics.protos.map_pb2' + # @@protoc_insertion_point(class_scope:long_metric.LaneCenter) + }) +_sym_db.RegisterMessage(LaneCenter) + +RoadEdge = _reflection.GeneratedProtocolMessageType('RoadEdge', (_message.Message,), { + 'DESCRIPTOR' : _ROADEDGE, + '__module__' : 'dev.metrics.protos.map_pb2' + # @@protoc_insertion_point(class_scope:long_metric.RoadEdge) + }) +_sym_db.RegisterMessage(RoadEdge) + +RoadLine = _reflection.GeneratedProtocolMessageType('RoadLine', (_message.Message,), { + 'DESCRIPTOR' : _ROADLINE, + '__module__' : 'dev.metrics.protos.map_pb2' + # @@protoc_insertion_point(class_scope:long_metric.RoadLine) + }) +_sym_db.RegisterMessage(RoadLine) + +StopSign = _reflection.GeneratedProtocolMessageType('StopSign', (_message.Message,), { + 'DESCRIPTOR' : _STOPSIGN, + '__module__' : 'dev.metrics.protos.map_pb2' + # @@protoc_insertion_point(class_scope:long_metric.StopSign) + }) +_sym_db.RegisterMessage(StopSign) + +Crosswalk = _reflection.GeneratedProtocolMessageType('Crosswalk', (_message.Message,), { + 'DESCRIPTOR' : _CROSSWALK, + '__module__' : 'dev.metrics.protos.map_pb2' + # @@protoc_insertion_point(class_scope:long_metric.Crosswalk) + }) +_sym_db.RegisterMessage(Crosswalk) + +SpeedBump = _reflection.GeneratedProtocolMessageType('SpeedBump', (_message.Message,), { + 'DESCRIPTOR' : _SPEEDBUMP, + '__module__' : 'dev.metrics.protos.map_pb2' + # @@protoc_insertion_point(class_scope:long_metric.SpeedBump) + }) +_sym_db.RegisterMessage(SpeedBump) + +Driveway = _reflection.GeneratedProtocolMessageType('Driveway', (_message.Message,), { + 'DESCRIPTOR' : _DRIVEWAY, + '__module__' : 'dev.metrics.protos.map_pb2' + # @@protoc_insertion_point(class_scope:long_metric.Driveway) + }) +_sym_db.RegisterMessage(Driveway) + + +_LANECENTER.fields_by_name['entry_lanes']._options = None +_LANECENTER.fields_by_name['exit_lanes']._options = None +# @@protoc_insertion_point(module_scope) diff --git a/backups/dev/metrics/protos/scenario_pb2.py b/backups/dev/metrics/protos/scenario_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..1d42a98487c4f05b63c105c224558b0f0c655547 --- /dev/null +++ b/backups/dev/metrics/protos/scenario_pb2.py @@ -0,0 +1,454 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: dev/metrics/protos/scenario.proto + +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from dev.metrics.protos import map_pb2 as dev_dot_metrics_dot_protos_dot_map__pb2 + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='dev/metrics/protos/scenario.proto', + package='long_metric', + syntax='proto2', + serialized_options=None, + create_key=_descriptor._internal_create_key, + serialized_pb=b'\n!dev/metrics/protos/scenario.proto\x12\x0blong_metric\x1a\x1c\x64\x65v/metrics/protos/map.proto\"\xba\x01\n\x0bObjectState\x12\x10\n\x08\x63\x65nter_x\x18\x02 \x01(\x01\x12\x10\n\x08\x63\x65nter_y\x18\x03 \x01(\x01\x12\x10\n\x08\x63\x65nter_z\x18\x04 \x01(\x01\x12\x0e\n\x06length\x18\x05 \x01(\x02\x12\r\n\x05width\x18\x06 \x01(\x02\x12\x0e\n\x06height\x18\x07 \x01(\x02\x12\x0f\n\x07heading\x18\x08 \x01(\x02\x12\x12\n\nvelocity_x\x18\t \x01(\x02\x12\x12\n\nvelocity_y\x18\n \x01(\x02\x12\r\n\x05valid\x18\x0b \x01(\x08\"\xd8\x01\n\x05Track\x12\n\n\x02id\x18\x01 \x01(\x05\x12\x32\n\x0bobject_type\x18\x02 \x01(\x0e\x32\x1d.long_metric.Track.ObjectType\x12(\n\x06states\x18\x03 \x03(\x0b\x32\x18.long_metric.ObjectState\"e\n\nObjectType\x12\x0e\n\nTYPE_UNSET\x10\x00\x12\x10\n\x0cTYPE_VEHICLE\x10\x01\x12\x13\n\x0fTYPE_PEDESTRIAN\x10\x02\x12\x10\n\x0cTYPE_CYCLIST\x10\x03\x12\x0e\n\nTYPE_OTHER\x10\x04\"K\n\x0f\x44ynamicMapState\x12\x38\n\x0blane_states\x18\x01 \x03(\x0b\x32#.long_metric.TrafficSignalLaneState\"\xa5\x01\n\x12RequiredPrediction\x12\x13\n\x0btrack_index\x18\x01 \x01(\x05\x12\x43\n\ndifficulty\x18\x02 \x01(\x0e\x32/.long_metric.RequiredPrediction.DifficultyLevel\"5\n\x0f\x44ifficultyLevel\x12\x08\n\x04NONE\x10\x00\x12\x0b\n\x07LEVEL_1\x10\x01\x12\x0b\n\x07LEVEL_2\x10\x02\"\xdc\x02\n\x08Scenario\x12\x13\n\x0bscenario_id\x18\x05 \x01(\t\x12\x1a\n\x12timestamps_seconds\x18\x01 \x03(\x01\x12\x1a\n\x12\x63urrent_time_index\x18\n \x01(\x05\x12\"\n\x06tracks\x18\x02 \x03(\x0b\x32\x12.long_metric.Track\x12\x38\n\x12\x64ynamic_map_states\x18\x07 \x03(\x0b\x32\x1c.long_metric.DynamicMapState\x12-\n\x0cmap_features\x18\x08 \x03(\x0b\x32\x17.long_metric.MapFeature\x12\x17\n\x0fsdc_track_index\x18\x06 \x01(\x05\x12\x1b\n\x13objects_of_interest\x18\x04 \x03(\x05\x12:\n\x11tracks_to_predict\x18\x0b \x03(\x0b\x32\x1f.long_metric.RequiredPredictionJ\x04\x08\t\x10\n' + , + dependencies=[dev_dot_metrics_dot_protos_dot_map__pb2.DESCRIPTOR,]) + + + +_TRACK_OBJECTTYPE = _descriptor.EnumDescriptor( + name='ObjectType', + full_name='long_metric.Track.ObjectType', + filename=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + values=[ + _descriptor.EnumValueDescriptor( + name='TYPE_UNSET', index=0, number=0, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='TYPE_VEHICLE', index=1, number=1, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='TYPE_PEDESTRIAN', index=2, number=2, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='TYPE_CYCLIST', index=3, number=3, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='TYPE_OTHER', 
index=4, number=4, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + ], + containing_type=None, + serialized_options=None, + serialized_start=385, + serialized_end=486, +) +_sym_db.RegisterEnumDescriptor(_TRACK_OBJECTTYPE) + +_REQUIREDPREDICTION_DIFFICULTYLEVEL = _descriptor.EnumDescriptor( + name='DifficultyLevel', + full_name='long_metric.RequiredPrediction.DifficultyLevel', + filename=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + values=[ + _descriptor.EnumValueDescriptor( + name='NONE', index=0, number=0, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='LEVEL_1', index=1, number=1, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='LEVEL_2', index=2, number=2, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + ], + containing_type=None, + serialized_options=None, + serialized_start=678, + serialized_end=731, +) +_sym_db.RegisterEnumDescriptor(_REQUIREDPREDICTION_DIFFICULTYLEVEL) + + +_OBJECTSTATE = _descriptor.Descriptor( + name='ObjectState', + full_name='long_metric.ObjectState', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='center_x', full_name='long_metric.ObjectState.center_x', index=0, + number=2, type=1, cpp_type=5, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='center_y', full_name='long_metric.ObjectState.center_y', index=1, + number=3, type=1, cpp_type=5, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='center_z', full_name='long_metric.ObjectState.center_z', index=2, + number=4, type=1, cpp_type=5, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='length', full_name='long_metric.ObjectState.length', index=3, + number=5, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='width', full_name='long_metric.ObjectState.width', index=4, + number=6, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='height', full_name='long_metric.ObjectState.height', index=5, + number=7, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, 
enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='heading', full_name='long_metric.ObjectState.heading', index=6, + number=8, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='velocity_x', full_name='long_metric.ObjectState.velocity_x', index=7, + number=9, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='velocity_y', full_name='long_metric.ObjectState.velocity_y', index=8, + number=10, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='valid', full_name='long_metric.ObjectState.valid', index=9, + number=11, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=81, + serialized_end=267, +) + + +_TRACK = _descriptor.Descriptor( + name='Track', + full_name='long_metric.Track', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='id', full_name='long_metric.Track.id', index=0, + number=1, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='object_type', full_name='long_metric.Track.object_type', index=1, + number=2, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='states', full_name='long_metric.Track.states', index=2, + number=3, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + _TRACK_OBJECTTYPE, + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=270, + serialized_end=486, +) + + +_DYNAMICMAPSTATE = 
_descriptor.Descriptor( + name='DynamicMapState', + full_name='long_metric.DynamicMapState', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='lane_states', full_name='long_metric.DynamicMapState.lane_states', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=488, + serialized_end=563, +) + + +_REQUIREDPREDICTION = _descriptor.Descriptor( + name='RequiredPrediction', + full_name='long_metric.RequiredPrediction', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='track_index', full_name='long_metric.RequiredPrediction.track_index', index=0, + number=1, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='difficulty', full_name='long_metric.RequiredPrediction.difficulty', index=1, + number=2, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + _REQUIREDPREDICTION_DIFFICULTYLEVEL, + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=566, + serialized_end=731, +) + + +_SCENARIO = _descriptor.Descriptor( + name='Scenario', + full_name='long_metric.Scenario', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='scenario_id', full_name='long_metric.Scenario.scenario_id', index=0, + number=5, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='timestamps_seconds', full_name='long_metric.Scenario.timestamps_seconds', index=1, + number=1, type=1, cpp_type=5, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='current_time_index', full_name='long_metric.Scenario.current_time_index', index=2, + number=10, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + 
_descriptor.FieldDescriptor( + name='tracks', full_name='long_metric.Scenario.tracks', index=3, + number=2, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='dynamic_map_states', full_name='long_metric.Scenario.dynamic_map_states', index=4, + number=7, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='map_features', full_name='long_metric.Scenario.map_features', index=5, + number=8, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='sdc_track_index', full_name='long_metric.Scenario.sdc_track_index', index=6, + number=6, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='objects_of_interest', full_name='long_metric.Scenario.objects_of_interest', index=7, + number=4, type=5, cpp_type=1, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='tracks_to_predict', full_name='long_metric.Scenario.tracks_to_predict', index=8, + number=11, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=734, + serialized_end=1082, +) + +_TRACK.fields_by_name['object_type'].enum_type = _TRACK_OBJECTTYPE +_TRACK.fields_by_name['states'].message_type = _OBJECTSTATE +_TRACK_OBJECTTYPE.containing_type = _TRACK +_DYNAMICMAPSTATE.fields_by_name['lane_states'].message_type = dev_dot_metrics_dot_protos_dot_map__pb2._TRAFFICSIGNALLANESTATE +_REQUIREDPREDICTION.fields_by_name['difficulty'].enum_type = _REQUIREDPREDICTION_DIFFICULTYLEVEL +_REQUIREDPREDICTION_DIFFICULTYLEVEL.containing_type = _REQUIREDPREDICTION +_SCENARIO.fields_by_name['tracks'].message_type = _TRACK +_SCENARIO.fields_by_name['dynamic_map_states'].message_type = _DYNAMICMAPSTATE +_SCENARIO.fields_by_name['map_features'].message_type = dev_dot_metrics_dot_protos_dot_map__pb2._MAPFEATURE +_SCENARIO.fields_by_name['tracks_to_predict'].message_type = _REQUIREDPREDICTION +DESCRIPTOR.message_types_by_name['ObjectState'] = _OBJECTSTATE +DESCRIPTOR.message_types_by_name['Track'] = _TRACK 
+DESCRIPTOR.message_types_by_name['DynamicMapState'] = _DYNAMICMAPSTATE +DESCRIPTOR.message_types_by_name['RequiredPrediction'] = _REQUIREDPREDICTION +DESCRIPTOR.message_types_by_name['Scenario'] = _SCENARIO +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +ObjectState = _reflection.GeneratedProtocolMessageType('ObjectState', (_message.Message,), { + 'DESCRIPTOR' : _OBJECTSTATE, + '__module__' : 'dev.metrics.protos.scenario_pb2' + # @@protoc_insertion_point(class_scope:long_metric.ObjectState) + }) +_sym_db.RegisterMessage(ObjectState) + +Track = _reflection.GeneratedProtocolMessageType('Track', (_message.Message,), { + 'DESCRIPTOR' : _TRACK, + '__module__' : 'dev.metrics.protos.scenario_pb2' + # @@protoc_insertion_point(class_scope:long_metric.Track) + }) +_sym_db.RegisterMessage(Track) + +DynamicMapState = _reflection.GeneratedProtocolMessageType('DynamicMapState', (_message.Message,), { + 'DESCRIPTOR' : _DYNAMICMAPSTATE, + '__module__' : 'dev.metrics.protos.scenario_pb2' + # @@protoc_insertion_point(class_scope:long_metric.DynamicMapState) + }) +_sym_db.RegisterMessage(DynamicMapState) + +RequiredPrediction = _reflection.GeneratedProtocolMessageType('RequiredPrediction', (_message.Message,), { + 'DESCRIPTOR' : _REQUIREDPREDICTION, + '__module__' : 'dev.metrics.protos.scenario_pb2' + # @@protoc_insertion_point(class_scope:long_metric.RequiredPrediction) + }) +_sym_db.RegisterMessage(RequiredPrediction) + +Scenario = _reflection.GeneratedProtocolMessageType('Scenario', (_message.Message,), { + 'DESCRIPTOR' : _SCENARIO, + '__module__' : 'dev.metrics.protos.scenario_pb2' + # @@protoc_insertion_point(class_scope:long_metric.Scenario) + }) +_sym_db.RegisterMessage(Scenario) + + +# @@protoc_insertion_point(module_scope) diff --git a/backups/dev/metrics/trajectory_features.py b/backups/dev/metrics/trajectory_features.py new file mode 100644 index 0000000000000000000000000000000000000000..6f2c58be898a0f6e8edaa9eca3d15cc0f3a54839 --- /dev/null +++ b/backups/dev/metrics/trajectory_features.py @@ -0,0 +1,52 @@ +import torch +import numpy as np +from torch import Tensor +from typing import Tuple + + +def _wrap_angle(angle: Tensor) -> Tensor: + return (angle + np.pi) % (2 * np.pi) - np.pi + + +def central_diff(t: Tensor, pad_value: float) -> Tensor: + pad_shape = (*t.shape[:-1], 1) + pad_tensor = torch.full(pad_shape, pad_value, dtype=t.dtype, device=t.device) + diff_t = (t[..., 2:] - t[..., :-2]) / 2 + return torch.cat([pad_tensor, diff_t, pad_tensor], dim=-1) + + +def central_logical_and(t: Tensor, pad_value: bool) -> Tensor: + pad_shape = (*t.shape[:-1], 1) + pad_tensor = torch.full(pad_shape, pad_value, dtype=torch.bool, device=t.device) + diff_t = torch.logical_and(t[..., 2:], t[..., :-2]) + return torch.cat([pad_tensor, diff_t, pad_tensor], dim=-1) + + +def compute_displacement_error(x, y, z, ref_x, ref_y, ref_z) -> Tensor: + return torch.norm( + torch.stack([x, y, z], dim=-1) - torch.stack([ref_x, ref_y, ref_z], dim=-1), + p=2, dim=-1 + ) + + +def compute_kinematic_features( + x: Tensor, + y: Tensor, + z: Tensor, + heading: Tensor, + seconds_per_step: float +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + dpos = central_diff(torch.stack([x, y, z], dim=0), pad_value=np.nan) + linear_speed = torch.norm(dpos, p=2, dim=0) / seconds_per_step + linear_accel = central_diff(linear_speed, pad_value=np.nan) / seconds_per_step + dh_step = _wrap_angle(central_diff(heading, pad_value=np.nan) * 2) / 2 + dh = dh_step / seconds_per_step + d2h_step = _wrap_angle(central_diff(dh_step, pad_value=np.nan) * 2) / 2 + 
d2h = d2h_step / (seconds_per_step ** 2) + return linear_speed, linear_accel, dh, d2h + + +def compute_kinematic_validity(valid: Tensor) -> Tuple[Tensor, Tensor]: + speed_validity = central_logical_and(valid, pad_value=False) + acceleration_validity = central_logical_and(speed_validity, pad_value=False) + return speed_validity, acceleration_validity \ No newline at end of file diff --git a/backups/dev/metrics/val_close_long_metrics.json b/backups/dev/metrics/val_close_long_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..1b29bfbdf7c03abce8c5edd959cb2b2487763197 --- /dev/null +++ b/backups/dev/metrics/val_close_long_metrics.json @@ -0,0 +1,24 @@ +{ + "val_close_long/wosac/realism_meta_metric": 0.6187323331832886, + "val_close_long/wosac/kinematic_metrics": 0.6323384046554565, + "val_close_long/wosac/interactive_metrics": 0.5528579354286194, + "val_close_long/wosac/map_based_metrics": 0.0, + "val_close_long/wosac/placement_based_metrics": 0.6086956858634949, + "val_close_long/wosac/min_ade": 0.0, + "val_close_long/wosac/scenario_counter": 61, + "val_close_long/wosac_likelihood/metametric": 0.6187323331832886, + "val_close_long/wosac_likelihood/average_displacement_error": 0.0, + "val_close_long/wosac_likelihood/min_average_displacement_error": 0.0, + "val_close_long/wosac_likelihood/linear_speed_likelihood": 0.11858943104743958, + "val_close_long/wosac_likelihood/linear_acceleration_likelihood": 0.6093839406967163, + "val_close_long/wosac_likelihood/angular_speed_likelihood": 0.8988037705421448, + "val_close_long/wosac_likelihood/angular_acceleration_likelihood": 0.9025763869285583, + "val_close_long/wosac_likelihood/distance_to_nearest_object_likelihood": 0.10390616208314896, + "val_close_long/wosac_likelihood/collision_indication_likelihood": 0.6108496785163879, + "val_close_long/wosac_likelihood/time_to_collision_likelihood": 0.8568302989006042, + "val_close_long/wosac_likelihood/simulated_collision_rate": 0.030917881056666374, + "val_close_long/wosac_likelihood/num_placement_likelihood": 0.7245867848396301, + "val_close_long/wosac_likelihood/num_removement_likelihood": 0.6228984594345093, + "val_close_long/wosac_likelihood/distance_placement_likelihood": 1.0, + "val_close_long/wosac_likelihood/distance_removement_likelihood": 0.08729743212461472 +} \ No newline at end of file diff --git a/backups/dev/model/smart.py b/backups/dev/model/smart.py new file mode 100644 index 0000000000000000000000000000000000000000..2112540727480a5706b795c90a653c560b6bb14a --- /dev/null +++ b/backups/dev/model/smart.py @@ -0,0 +1,1100 @@ +import os +import contextlib +import pytorch_lightning as pl +import math +import numpy as np +import pickle +import random +import torch +import torch.nn as nn +from tqdm import tqdm +from torch_geometric.data import Batch +from torch_geometric.data import HeteroData +from torch.optim.lr_scheduler import LambdaLR +from collections import defaultdict + +from dev.utils.func import angle_between_2d_vectors +from dev.modules.layers import OccLoss +from dev.modules.attr_tokenizer import Attr_Tokenizer +from dev.modules.smart_decoder import SMARTDecoder +from dev.datasets.preprocess import TokenProcessor +from dev.metrics.compute_metrics import * +from dev.utils.metrics import * +from dev.utils.visualization import * + + +class SMART(pl.LightningModule): + + def __init__(self, model_config, save_path: os.PathLike="", logger=None, **kwargs) -> None: + super(SMART, self).__init__() + self.save_hyperparameters() + self.model_config = model_config + 
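+        # Cache frequently used entries of the Model section of the YAML config
+        # as plain attributes, so the rest of this module can read them directly.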
self.warmup_steps = model_config.warmup_steps + self.lr = model_config.lr + self.total_steps = model_config.total_steps + self.dataset = model_config.dataset + self.input_dim = model_config.input_dim + self.hidden_dim = model_config.hidden_dim + self.output_dim = model_config.output_dim + self.output_head = model_config.output_head + self.num_historical_steps = model_config.num_historical_steps + self.num_future_steps = model_config.decoder.num_future_steps + self.num_freq_bands = model_config.num_freq_bands + self.save_path = save_path + self.vis_map = False + self.noise = True + self.local_logger = logger + self.max_epochs = kwargs.get('max_epochs', 0) + module_dir = os.path.dirname(os.path.dirname(__file__)) + + self.map_token_traj_path = os.path.join(module_dir, 'tokens/map_traj_token5.pkl') + self.init_map_token() + + self.predict_motion = model_config.predict_motion + self.predict_state = model_config.predict_state + self.predict_map = model_config.predict_map + self.predict_occ = model_config.predict_occ + self.pl2seed_radius = model_config.decoder.pl2seed_radius + self.token_size = model_config.decoder.token_size + + # if `disable_grid_token` is True, then we process all locations as + # the continuous values. Besides, no occupancy grid input. + # Also, no need to predict the xy offset. + self.disable_grid_token = getattr(model_config, 'disable_grid_token') \ + if hasattr(model_config, 'disable_grid_token') else False + self.use_grid_token = not self.disable_grid_token + if self.disable_grid_token: + self.predict_occ = False + + self.token_processer = TokenProcessor(self.token_size, + training=self.training, + predict_motion=self.predict_motion, + predict_state=self.predict_state, + predict_map=self.predict_map, + state_token=model_config.state_token, + pl2seed_radius=self.pl2seed_radius) + + self.attr_tokenizer = Attr_Tokenizer(grid_range=self.model_config.grid_range, + grid_interval=self.model_config.grid_interval, + radius=model_config.decoder.pl2seed_radius, + angle_interval=self.model_config.angle_interval) + + # state tokens + self.invalid_state = int(self.model_config.state_token['invalid']) + self.valid_state = int(self.model_config.state_token['valid']) + self.enter_state = int(self.model_config.state_token['enter']) + self.exit_state = int(self.model_config.state_token['exit']) + + self.seed_size = int(model_config.decoder.seed_size) + + self.encoder = SMARTDecoder( + decoder_type=model_config.decoder_type, + dataset=model_config.dataset, + input_dim=model_config.input_dim, + hidden_dim=model_config.hidden_dim, + num_historical_steps=model_config.num_historical_steps, + num_freq_bands=model_config.num_freq_bands, + num_heads=model_config.num_heads, + head_dim=model_config.head_dim, + dropout=model_config.dropout, + num_map_layers=model_config.decoder.num_map_layers, + num_agent_layers=model_config.decoder.num_agent_layers, + pl2pl_radius=model_config.decoder.pl2pl_radius, + pl2a_radius=model_config.decoder.pl2a_radius, + pl2seed_radius=model_config.decoder.pl2seed_radius, + a2a_radius=model_config.decoder.a2a_radius, + a2sa_radius=model_config.decoder.a2sa_radius, + pl2sa_radius=model_config.decoder.pl2sa_radius, + time_span=model_config.decoder.time_span, + map_token={'traj_src': self.map_token['traj_src']}, + token_size=self.token_size, + attr_tokenizer=self.attr_tokenizer, + predict_motion=self.predict_motion, + predict_state=self.predict_state, + predict_map=self.predict_map, + predict_occ=self.predict_occ, + state_token=model_config.state_token, + 
use_grid_token=self.use_grid_token, + seed_size=self.seed_size, + buffer_size=model_config.decoder.buffer_size, + num_recurrent_steps_val=model_config.num_recurrent_steps_val, + loss_weight=model_config.loss_weight, + logger=logger, + ) + self.minADE = minADE(max_guesses=1) + self.minFDE = minFDE(max_guesses=1) + self.TokenCls = TokenCls(max_guesses=1) + self.StateCls = TokenCls(max_guesses=1) + self.StateAccuracy = StateAccuracy(state_token=self.model_config.state_token) + self.GridOverlapRate = GridOverlapRate(num_step=18, + state_token=self.model_config.state_token, + seed_size=self.encoder.agent_encoder.num_seed_feature) + # self.NumInsertAccuracy = NumInsertAccuracy() + self.loss_weight = model_config.loss_weight + + self.token_cls_loss = nn.CrossEntropyLoss(label_smoothing=0.1) + if self.predict_map: + self.map_token_loss = nn.CrossEntropyLoss(label_smoothing=0.1) + if self.predict_state: + self.state_cls_loss = nn.CrossEntropyLoss( + torch.tensor(self.loss_weight['state_weight'])) + self.state_cls_loss_seed = nn.CrossEntropyLoss( + torch.tensor(self.loss_weight['seed_state_weight'])) + self.type_cls_loss_seed = nn.CrossEntropyLoss( + torch.tensor(self.loss_weight['seed_type_weight'])) + self.pos_cls_loss_seed = nn.CrossEntropyLoss(label_smoothing=0.1) + self.head_cls_loss_seed = nn.CrossEntropyLoss() + self.offset_reg_loss_seed = nn.MSELoss() + self.shape_reg_loss_seed = nn.MSELoss() + self.pos_reg_loss_seed = nn.MSELoss() + if self.predict_occ: + self.occ_cls_loss = nn.CrossEntropyLoss() + self.agent_occ_loss_seed = nn.BCEWithLogitsLoss( + pos_weight=torch.tensor([self.loss_weight['agent_occ_pos_weight']])) + self.pt_occ_loss_seed = nn.BCEWithLogitsLoss( + pos_weight=torch.tensor([self.loss_weight['pt_occ_pos_weight']])) + # self.agent_occ_loss_seed = OccLoss() + # self.pt_occ_loss_seed = OccLoss() + # self.agent_occ_loss_seed = nn.BCEWithLogitsLoss() + # self.pt_occ_loss_seed = nn.BCEWithLogitsLoss() + self.rollout_num = 1 + + self.val_open_loop = model_config.val_open_loop + self.val_close_loop = model_config.val_close_loop + self.val_insert = model_config.val_insert or bool(os.getenv('VAL_INSERT')) + self.n_rollout_close_val = model_config.n_rollout_close_val + self.t = kwargs.get('t', 2) + + # for validation / test + self._mode = 'training' + self._long_metrics = None + self._online_metric = False + self._save_validate_reuslts = False + self._plot_rollouts = False + + def set(self, mode: str = 'train'): + self._mode = mode + + if mode == 'validation': + self._online_metric = True + self._save_validate_reuslts = True + self._long_metrics = LongMetric('val_close_long') + + elif mode == 'test': + self._save_validate_reuslts = True + + elif mode == 'plot_rollouts': + self._plot_rollouts = True + + def init_map_token(self): + self.argmin_sample_len = 3 + map_token_traj = pickle.load(open(self.map_token_traj_path, 'rb')) + self.map_token = {'traj_src': map_token_traj['traj_src'], } + traj_end_theta = np.arctan2(self.map_token['traj_src'][:, -1, 1]-self.map_token['traj_src'][:, -2, 1], + self.map_token['traj_src'][:, -1, 0]-self.map_token['traj_src'][:, -2, 0]) + indices = torch.linspace(0, self.map_token['traj_src'].shape[1]-1, steps=self.argmin_sample_len).long() + self.map_token['sample_pt'] = torch.from_numpy(self.map_token['traj_src'][:, indices]).to(torch.float) + self.map_token['traj_end_theta'] = torch.from_numpy(traj_end_theta).to(torch.float) + self.map_token['traj_src'] = torch.from_numpy(self.map_token['traj_src']).to(torch.float) + + def get_agent_inputs(self, data: 
HeteroData): + res = self.encoder.get_agent_inputs(data) + return res + + def forward(self, data: HeteroData): + res = self.encoder(data) + return res + + def maybe_autocast(self, dtype=torch.float16): + enable_autocast = self.device != torch.device("cpu") + + if enable_autocast: + return torch.cuda.amp.autocast(dtype=dtype) + else: + return contextlib.nullcontext() + + def check_inputs(self, data: HeteroData): + inputs = self.get_agent_inputs(data) + next_token_idx_gt = inputs['next_token_idx_gt'] + next_state_idx_gt = inputs['next_state_idx_gt'].clone() + next_token_eval_mask = inputs['next_token_eval_mask'].clone() + raw_agent_valid_mask = inputs['raw_agent_valid_mask'].clone() + + self.StateAccuracy.update(state_idx=next_state_idx_gt, + valid_mask=raw_agent_valid_mask) + + state_token = inputs['state_token'] + grid_index = inputs['grid_index'] + self.GridOverlapRate.update(state_token=state_token, + grid_index=grid_index) + + print(self.StateAccuracy) + print(self.GridOverlapRate) + # self.log('valid_accuracy', self.StateAccuracy.compute()['valid'], prog_bar=True, on_step=True, on_epoch=True, batch_size=1) + # self.log('invalid_accuracy', self.StateAccuracy.compute()['invalid'], prog_bar=True, on_step=True, on_epoch=True, batch_size=1) + + def training_step(self, + data, + batch_idx): + + data = self.token_processer(data) + + data = self.match_token_map(data) + data = self.sample_pt_pred(data) + + # find map tokens for entering agents + data = self._fetch_enterings(data) + + data['batch_size_a'] = data['agent']['ptr'][1:] - data['agent']['ptr'][:-1] + data['batch_size_pl'] = data['pt_token']['ptr'][1:] - data['pt_token']['ptr'][:-1] + if isinstance(data, Batch): + data['agent']['av_index'] += data['agent']['ptr'][:-1] + + if int(os.getenv("CHECK_INPUTS", 0)): + return self.check_inputs(data) + + pred = self(data) + + loss = 0 + + log_params = dict(prog_bar=True, on_step=True, on_epoch=True, batch_size=1, sync_dist=True) + + if pred.get('occ_decoder', False): + + agent_occ = pred['agent_occ'] + agent_occ_gt = pred['agent_occ_gt'] + agent_occ_eval_mask = pred['agent_occ_eval_mask'] + pt_occ = pred['pt_occ'] + pt_occ_gt = pred['pt_occ_gt'] + pt_occ_eval_mask = pred['pt_occ_eval_mask'] + + agent_occ_cls_loss = self.occ_cls_loss(agent_occ[agent_occ_eval_mask], + agent_occ_gt[agent_occ_eval_mask]) + pt_occ_cls_loss = self.occ_cls_loss(pt_occ[pt_occ_eval_mask], + pt_occ_gt[pt_occ_eval_mask]) + self.log('agent_occ_cls_loss', agent_occ_cls_loss, **log_params) + self.log('pt_occ_cls_loss', pt_occ_cls_loss, **log_params) + loss = loss + agent_occ_cls_loss + pt_occ_cls_loss + + # plot + # plot_scenario_ids = ['74ad7b76d5906d39', '1351ea8b8333ddcb', '1352066cc3c0508d', '135436833ce5b9e7', '13570a32432d449', '13577c32a81336fb'] + if random.random() < 4e-5 or os.getenv('DEBUG'): + num_step = pred['num_step'] + num_agent = pred['num_agent'] + num_pt = pred['num_pt'] + with torch.no_grad(): + agent_occ = agent_occ.reshape(num_step, num_agent, -1).transpose(0, 1) + agent_occ_gt = agent_occ_gt.reshape(num_step, num_agent).transpose(0, 1) + agent_occ_gt[agent_occ_gt == -1] = self.encoder.agent_encoder.grid_size // 2 + agent_occ_gt = torch.nn.functional.one_hot(agent_occ_gt, num_classes=self.encoder.agent_encoder.grid_size) + agent_occ = self.attr_tokenizer.pad_square(agent_occ.softmax(-1).detach().cpu().numpy())[0] + agent_occ_gt = self.attr_tokenizer.pad_square(agent_occ_gt.detach().cpu().numpy())[0] + plot_occ_grid(pred['scenario_id'][0], + agent_occ, + gt_occ=agent_occ_gt, + mode='agent', + 
+                              save_path=self.save_path,
+                              prefix=f'training_{self.global_step:06d}_')
+                pt_occ = pt_occ.reshape(num_step, num_pt, -1).transpose(0, 1)
+                pt_occ_gt = pt_occ_gt.reshape(num_step, num_pt).transpose(0, 1)
+                pt_occ_gt[pt_occ_gt == -1] = self.encoder.agent_encoder.grid_size // 2
+                pt_occ_gt = torch.nn.functional.one_hot(pt_occ_gt, num_classes=self.encoder.agent_encoder.grid_size)
+                pt_occ = self.attr_tokenizer.pad_square(pt_occ.sigmoid().detach().cpu().numpy())[0]
+                pt_occ_gt = self.attr_tokenizer.pad_square(pt_occ_gt.detach().cpu().numpy())[0]
+                plot_occ_grid(pred['scenario_id'][0],
+                              pt_occ,
+                              gt_occ=pt_occ_gt,
+                              mode='pt',
+                              save_path=self.save_path,
+                              prefix=f'training_{self.global_step:06d}_')
+
+            return loss
+
+        train_mask = data['agent']['train_mask']
+        # remove_ina_mask = data['agent']['remove_ina_mask']
+
+        # motion token loss
+        if self.predict_motion:
+
+            next_token_idx = pred['next_token_idx']
+            next_token_prob = pred['next_token_prob']  # (a, t, token_size)
+            next_token_idx_gt = pred['next_token_idx_gt']  # (a, t)
+            next_token_eval_mask = pred['next_token_eval_mask']  # (a, t)
+            next_token_eval_mask &= train_mask[:, None]
+
+            token_cls_loss = self.token_cls_loss(next_token_prob[next_token_eval_mask],
+                                                 next_token_idx_gt[next_token_eval_mask]) * self.loss_weight['token_cls_loss']
+            self.log('token_cls_loss', token_cls_loss, **log_params)
+
+            loss = loss + token_cls_loss
+
+            # record motion prediction precision at certain timesteps for certain types of agents
+            with torch.no_grad():
+                agent_state_idx_gt = data['agent']['state_idx']
+                index = torch.nonzero(agent_state_idx_gt == self.enter_state)
+                for i in range(10):
+                    index[:, 1] += 1
+                    index = index[index[:, 1] < agent_state_idx_gt.shape[1] - 1]
+                    prob = next_token_prob[index[:, 0], index[:, 1]]
+                    gt = next_token_idx_gt[index[:, 0], index[:, 1]]
+                    mask = next_token_eval_mask[index[:, 0], index[:, 1]]
+                    step_token_cls_loss = self.token_cls_loss(prob[mask], gt[mask])
+                    self.log(f's{i}', step_token_cls_loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=1, sync_dist=True)
+
+        # state token loss
+        if self.predict_state:
+
+            next_state_idx = pred['next_state_idx']
+            next_state_prob = pred['next_state_prob']
+            next_state_idx_gt = pred['next_state_idx_gt']
+            next_state_eval_mask = pred['next_state_eval_mask']  # (num_agent, num_timestep)
+
+            state_cls_loss = self.state_cls_loss(next_state_prob[next_state_eval_mask],
+                                                 next_state_idx_gt[next_state_eval_mask]) * self.loss_weight['state_cls_loss']
+            if torch.isnan(state_cls_loss):
+                print("Found NaN in state_cls_loss!!!")
+                print(next_state_prob.shape)
+                print(next_state_idx_gt.shape)
+                print(next_state_eval_mask.shape)
+                print(next_state_idx_gt[next_state_eval_mask].shape)
+            self.log('state_cls_loss', state_cls_loss, **log_params)
+
+            loss = loss + state_cls_loss
+
+            next_state_idx_seed = pred['next_state_idx_seed']
+            next_state_prob_seed = pred['next_state_prob_seed']
+            next_state_idx_gt_seed = pred['next_state_idx_gt_seed']
+            next_type_prob_seed = pred['next_type_prob_seed']
+            next_type_idx_gt_seed = pred['next_type_idx_gt_seed']
+            next_shape_seed = pred['next_shape_seed']
+            next_shape_gt_seed = pred['next_shape_gt_seed']
+            next_state_eval_mask_seed = pred['next_state_eval_mask_seed']
+            next_attr_eval_mask_seed = pred['next_attr_eval_mask_seed']
+
+            # when num_seed_gt == 0 this loss term will be NaN
+            state_cls_loss_seed = self.state_cls_loss_seed(next_state_prob_seed[next_state_eval_mask_seed],
+                                                           next_state_idx_gt_seed[next_state_eval_mask_seed]) * self.loss_weight['state_cls_loss']
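+            # Minimal repro of that failure mode (illustrative, not part of the
+            # training path): cross-entropy over an empty selection yields NaN,
+            #   >>> import torch, torch.nn.functional as F
+            #   >>> F.cross_entropy(torch.empty(0, 2), torch.empty(0, dtype=torch.long))
+            #   tensor(nan)
+            # hence the `nan_to_num` guard right below.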
state_cls_loss_seed = torch.nan_to_num(state_cls_loss_seed) + self.log('seed_state_cls_loss', state_cls_loss_seed, **log_params) + + type_cls_loss_seed = self.type_cls_loss_seed(next_type_prob_seed[next_attr_eval_mask_seed], + next_type_idx_gt_seed[next_attr_eval_mask_seed]) * self.loss_weight['type_cls_loss'] + shape_reg_loss_seed = self.shape_reg_loss_seed(next_shape_seed[next_attr_eval_mask_seed], + next_shape_gt_seed[next_attr_eval_mask_seed]) * self.loss_weight['shape_reg_loss'] + type_cls_loss_seed = torch.nan_to_num(type_cls_loss_seed) + shape_reg_loss_seed = torch.nan_to_num(shape_reg_loss_seed) + self.log('seed_type_cls_loss', type_cls_loss_seed, **log_params) + self.log('seed_shape_reg_loss', shape_reg_loss_seed, **log_params) + + loss = loss + state_cls_loss_seed + type_cls_loss_seed + shape_reg_loss_seed + + next_head_rel_prob_seed = pred['next_head_rel_prob_seed'] + next_head_rel_index_gt_seed = pred['next_head_rel_index_gt_seed'] + next_offset_xy_seed = pred['next_offset_xy_seed'] + next_offset_xy_gt_seed = pred['next_offset_xy_gt_seed'] + next_head_eval_mask_seed = pred['next_head_eval_mask_seed'] + + if self.use_grid_token: + next_pos_rel_prob_seed = pred['next_pos_rel_prob_seed'] + next_pos_rel_index_gt_seed = pred['next_pos_rel_index_gt_seed'] + + pos_cls_loss_seed = self.pos_cls_loss_seed(next_pos_rel_prob_seed[next_attr_eval_mask_seed], + next_pos_rel_index_gt_seed[next_attr_eval_mask_seed]) * self.loss_weight['pos_cls_loss'] + offset_reg_loss_seed = self.offset_reg_loss_seed(next_offset_xy_seed[next_head_eval_mask_seed], + next_offset_xy_gt_seed[next_head_eval_mask_seed]) * self.loss_weight['offset_reg_loss'] + pos_cls_loss_seed = torch.nan_to_num(pos_cls_loss_seed) + self.log('seed_pos_cls_loss', pos_cls_loss_seed, **log_params) + self.log('seed_offset_reg_loss', offset_reg_loss_seed, **log_params) + + loss = loss + pos_cls_loss_seed + offset_reg_loss_seed + + else: + next_pos_rel_xy_seed = pred['next_pos_rel_xy_seed'] + next_pos_rel_xy_gt_seed = pred['next_pos_rel_xy_gt_seed'] + pos_reg_loss_seed = self.pos_reg_loss_seed(next_pos_rel_xy_seed[next_attr_eval_mask_seed], + next_pos_rel_xy_gt_seed[next_attr_eval_mask_seed]) * self.loss_weight['pos_reg_loss'] + pos_reg_loss_seed = torch.nan_to_num(pos_reg_loss_seed) + self.log('seed_pos_reg_loss', pos_reg_loss_seed, **log_params) + loss = loss + pos_reg_loss_seed + + head_cls_loss_seed = self.head_cls_loss_seed(next_head_rel_prob_seed[next_head_eval_mask_seed], + next_head_rel_index_gt_seed[next_head_eval_mask_seed]) * self.loss_weight['head_cls_loss'] + self.log('seed_head_cls_loss', head_cls_loss_seed, **log_params) + + loss = loss + head_cls_loss_seed + + # plot_scenario_ids = ['74ad7b76d5906d39', '1351ea8b8333ddcb', '1352066cc3c0508d', '135436833ce5b9e7', '13570a32432d449', '13577c32a81336fb'] + if random.random() < 4e-5 or int(os.getenv('DEBUG', 0)): + with torch.no_grad(): + # plot probability of inserting new agent (agent-timestep) + raw_next_state_prob_seed = pred['raw_next_state_prob_seed'] + plot_prob_seed(pred['scenario_id'][0], + torch.softmax(raw_next_state_prob_seed, dim=-1 + )[..., -1].detach().cpu().numpy(), + self.save_path, + prefix=f'training_{self.global_step:06d}_', + indices=pred['target_indices'].cpu().numpy()) + + # plot heatmap of inserting new agent + if self.use_grid_token: + next_pos_rel_prob_seed = pred['next_pos_rel_prob_seed'] + if next_pos_rel_prob_seed.shape[0] > 0: + next_pos_rel_prob_seed = torch.softmax(next_pos_rel_prob_seed, dim=-1).detach().cpu().numpy() + indices = 
next_pos_rel_index_gt_seed.detach().cpu().numpy() + mask = next_attr_eval_mask_seed.detach().cpu().numpy().astype(np.bool_) + indices[~mask] = -1 + prob, indices = self.attr_tokenizer.pad_square(next_pos_rel_prob_seed, indices) + plot_insert_grid(pred['scenario_id'][0], + prob, + indices=indices, + save_path=self.save_path, + prefix=f'training_{self.global_step:06d}_') + + if self.predict_occ: + + neighbor_agent_grid_idx = pred['neighbor_agent_grid_idx'] + neighbor_agent_grid_index_gt = pred['neighbor_agent_grid_index_gt'] + neighbor_agent_grid_index_eval_mask = pred['neighbor_agent_grid_index_eval_mask'] + neighbor_pt_grid_idx = pred['neighbor_pt_grid_idx'] + neighbor_pt_grid_index_gt = pred['neighbor_pt_grid_index_gt'] + neighbor_pt_grid_index_eval_mask = pred['neighbor_pt_grid_index_eval_mask'] + + neighbor_agent_grid_cls_loss = self.occ_cls_loss(neighbor_agent_grid_idx[neighbor_agent_grid_index_eval_mask], + neighbor_agent_grid_index_gt[neighbor_agent_grid_index_eval_mask]) + neighbor_pt_grid_cls_loss = self.occ_cls_loss(neighbor_pt_grid_idx[neighbor_pt_grid_index_eval_mask], + neighbor_pt_grid_index_gt[neighbor_pt_grid_index_eval_mask]) + # self.log('neighbor_agent_grid_cls_loss', neighbor_agent_grid_cls_loss, **log_params) + # self.log('neighbor_pt_grid_cls_loss', neighbor_pt_grid_cls_loss, **log_params) + # loss = loss + neighbor_agent_grid_cls_loss + neighbor_pt_grid_cls_loss + + grid_agent_occ_seed = pred['grid_agent_occ_seed'] + grid_pt_occ_seed = pred['grid_pt_occ_seed'] + grid_agent_occ_gt_seed = pred['grid_agent_occ_gt_seed'].float() + grid_pt_occ_gt_seed = pred['grid_pt_occ_gt_seed'].float() + grid_agent_occ_eval_mask_seed = pred['grid_agent_occ_eval_mask_seed'] + grid_pt_occ_eval_mask_seed = pred['grid_pt_occ_eval_mask_seed'] + + # plot_scenario_ids = ['74ad7b76d5906d39', '1351ea8b8333ddcb', '1352066cc3c0508d', '135436833ce5b9e7', '13570a32432d449', '13577c32a81336fb'] + if random.random() < 4e-5 or os.getenv('DEBUG'): + with torch.no_grad(): + agent_occ = self.attr_tokenizer.pad_square(grid_agent_occ_seed.sigmoid().detach().cpu().numpy())[0] + agent_occ_gt = self.attr_tokenizer.pad_square(grid_agent_occ_gt_seed.detach().cpu().numpy())[0] + plot_occ_grid(pred['scenario_id'][0], + agent_occ, + gt_occ=agent_occ_gt, + mode='agent', + save_path=self.save_path, + prefix=f'training_{self.global_step:06d}_') + pt_occ = self.attr_tokenizer.pad_square(grid_pt_occ_seed.sigmoid().detach().cpu().numpy())[0] + pt_occ_gt = self.attr_tokenizer.pad_square(grid_pt_occ_gt_seed.detach().cpu().numpy())[0] + plot_occ_grid(pred['scenario_id'][0], + pt_occ, + gt_occ=pt_occ_gt, + mode='pt', + save_path=self.save_path, + prefix=f'training_{self.global_step:06d}_') + + grid_agent_occ_gt_seed[grid_agent_occ_gt_seed == -1] = 0 + if grid_agent_occ_gt_seed.min() < 0 or grid_agent_occ_gt_seed.max() > 1 or \ + grid_pt_occ_gt_seed.min() < 0 or grid_pt_occ_gt_seed.max() > 1: + raise RuntimeError("Occurred invalid values in occ gt") + + agent_occ_loss = self.agent_occ_loss_seed(grid_agent_occ_seed[grid_agent_occ_eval_mask_seed], + grid_agent_occ_gt_seed[grid_agent_occ_eval_mask_seed]) * self.loss_weight['agent_occ_loss'] + pt_occ_loss = self.pt_occ_loss_seed(grid_pt_occ_seed[grid_pt_occ_eval_mask_seed], + grid_pt_occ_gt_seed[grid_pt_occ_eval_mask_seed]) * self.loss_weight['pt_occ_loss'] + + self.log('agent_occ_loss', agent_occ_loss, **log_params) + self.log('pt_occ_loss', pt_occ_loss, **log_params) + loss = loss + agent_occ_loss + pt_occ_loss + + if os.getenv('LOG_TRAIN', False) and (self.predict_motion or 
self.predict_state): + for a in range(next_token_idx.shape[0]): + print(f"agent: {a}") + if self.predict_motion: + print(f"pred motion: {next_token_idx[a, :, 0].tolist()}, \ngt motion: {next_token_idx_gt[a, :].tolist()}") + print(f"train mask: {next_token_eval_mask[a].long().tolist()}") + if self.predict_state: + print(f"pred state: {next_state_idx[a, :, 0].tolist()}, \ngt state: {next_state_idx_gt[a, :].tolist()}") + print(f"train mask: {next_state_eval_mask[a].long().tolist()}") + num_sa = next_state_idx_seed[..., 0].sum(dim=-1).bool().sum() + for sa in range(num_sa): + print(f"seed agent: {sa}") + print(f"seed pred state: {next_state_idx_seed[sa, :, 0].tolist()}, \ngt seed state: {next_state_idx_gt_seed[sa, :].tolist()}") + # if sa < next_pos_rel_seed.shape[0]: + # print(f"pred pos: {next_pos_rel_seed[sa, :, 0].tolist()}, \ngt pos: {next_pos_rel_gt_seed[sa, :, 0].tolist()}") + # print(f"pred head: {next_head_rel_seed[sa].tolist()}, \ngt head: {next_head_rel_gt_seed[sa].tolist()}") + # print(f"seed train mask: {next_state_eval_mask_seed[sa].long().tolist()}") + + # map token loss + if self.predict_map: + + map_next_token_prob = pred['map_next_token_prob'] + map_next_token_idx_gt = pred['map_next_token_idx_gt'] + map_next_token_eval_mask = pred['map_next_token_eval_mask'] + + map_token_loss = self.map_token_loss(map_next_token_prob[map_next_token_eval_mask], map_next_token_idx_gt[map_next_token_eval_mask]) * self.loss_weight['map_token_loss'] + self.log('map_token_loss', map_token_loss, prog_bar=True, on_step=True, on_epoch=True, batch_size=1) + loss = loss + map_token_loss + + allocated = torch.cuda.memory_allocated(device='cuda:0') / (1024 ** 3) + reserved = torch.cuda.memory_reserved(device='cuda:0') / (1024 ** 3) + self.log('allocated', allocated, **log_params) + self.log('reserved', reserved, **log_params) + + return loss + + def validation_step(self, + data, + batch_idx): + + self.debug = int(os.getenv('DEBUG', 0)) + + # ! validation in training process + if ( + self._mode == 'training' and ( + self.current_epoch not in [5, 10, 20, 25, self.max_epochs] or random.random() > 5e-4) and + not self.debug + ): + self.val_open_loop = False + self.val_close_loop = False + return + + if int(os.getenv('NO_VAL', 0)) or int(os.getenv("CHECK_INPUTS", 0)): + return + + # ! check if save exists + if not self._plot_rollouts: + rollouts_path = os.path.join(self.save_path, f'idx_{self.trainer.global_rank}_{batch_idx}_rollouts.pkl') + if os.path.exists(rollouts_path): + tqdm.write(f'Skipped batch {batch_idx}') + return + else: + rollouts_path = os.path.join(self.save_path, f'{data["scenario_id"][0]}.gif') + if os.path.exists(rollouts_path): + tqdm.write(f'Skipped scenario {data["scenario_id"][0]}') + return + + # ! data preparation + data = self.token_processer(data) + + data = self.match_token_map(data) + data = self.sample_pt_pred(data) + + # find map tokens for entering agents + data = self._fetch_enterings(data) + + data['batch_size_a'] = data['agent']['ptr'][1:] - data['agent']['ptr'][:-1] + data['batch_size_pl'] = data['pt_token']['ptr'][1:] - data['pt_token']['ptr'][:-1] + if isinstance(data, Batch): + data['agent']['av_index'] += data['agent']['ptr'][:-1] + + if int(os.getenv('NEAREST_POS', 0)): + pred = self.encoder.predict_nearest_pos(data, rank=self.local_rank) + return + + # if self.insert_agent: + # pred = self.encoder.insert_agent(data) + # return + + # ! 
open-loop validation + if self.val_open_loop or int(os.getenv('OPEN_LOOP', 0)): + + pred = self(data) + + # pred['next_state_prob_seed'] = torch.softmax(pred['next_state_prob_seed'], dim=-1)[..., -1] + # plot_prob_seed(pred, self.save_path, suffix=f'_training') + + loss = 0 + + if self.predict_motion: + + # motion token + next_token_idx = pred['next_token_idx'] + next_token_idx_gt = pred['next_token_idx_gt'] # (num_agent, num_step, 10) + next_token_prob = pred['next_token_prob'] + next_token_eval_mask = pred['next_token_eval_mask'] + + token_cls_loss = self.token_cls_loss(next_token_prob[next_token_eval_mask], next_token_idx_gt[next_token_eval_mask]) + loss = loss + token_cls_loss + + if self.predict_state: + + # state token + next_state_idx = pred['next_state_idx'] + next_state_idx_gt = pred['next_state_idx_gt'] + next_state_prob = pred['next_state_prob'] + next_state_eval_mask = pred['next_state_eval_mask'] + + state_cls_loss = self.state_cls_loss(next_state_prob[next_state_eval_mask], next_state_idx_gt[next_state_eval_mask]) + loss = loss + state_cls_loss + + # seed state token + next_state_idx_seed = pred['next_state_idx_seed'] + next_state_idx_gt_seed = pred['next_state_idx_gt_seed'] + + if self.predict_occ: + + grid_agent_occ_seed = pred['grid_agent_occ_seed'] + grid_pt_occ_seed = pred['grid_pt_occ_seed'] + grid_agent_occ_gt_seed = pred['grid_agent_occ_gt_seed'].float() + grid_pt_occ_gt_seed = pred['grid_pt_occ_gt_seed'].float() + + agent_occ = self.attr_tokenizer.pad_square(grid_agent_occ_seed.sigmoid().detach().cpu().numpy())[0] + agent_occ_gt = self.attr_tokenizer.pad_square(grid_agent_occ_gt_seed.detach().cpu().numpy())[0] + plot_occ_grid(pred['scenario_id'][0], + agent_occ, + gt_occ=agent_occ_gt, + mode='agent', + save_path=self.save_path, + prefix=f'eval_') + pt_occ = self.attr_tokenizer.pad_square(grid_pt_occ_seed.sigmoid().detach().cpu().numpy())[0] + pt_occ_gt = self.attr_tokenizer.pad_square(grid_pt_occ_gt_seed.detach().cpu().numpy())[0] + plot_occ_grid(pred['scenario_id'][0], + pt_occ, + gt_occ=pt_occ_gt, + mode='pt', + save_path=self.save_path, + prefix=f'eval_') + + self.log('val_loss', loss, prog_bar=True, on_step=True, on_epoch=True, batch_size=1, sync_dist=True) + + if self.val_insert: + + pred = self(data) + + next_state_idx_seed = pred['next_state_idx_seed'] + next_state_idx_gt_seed = pred['next_state_idx_gt_seed'] + + self.NumInsertAccuracy.update(next_state_idx_seed=next_state_idx_seed, + next_state_idx_gt_seed=next_state_idx_gt_seed) + + return + + # ! 
close-loop validation + if self.val_close_loop and (self.predict_motion or self.predict_state): + + rollouts = [] + for _ in tqdm(range(self.n_rollout_close_val), leave=False, desc='Rollout ...'): + rollout = self.encoder.inference(data.clone()) + rollouts.append(rollout) + + av_index = int(rollout['ego_index']) + scenario_id = rollout['scenario_id'][0] + + # motion tokens + if self.predict_motion: + + if self._plot_rollouts: # only plot gifs for last 2 epochs for efficiency + plot_val(data, rollout, av_index, self.save_path, pl2seed_radius=self.pl2seed_radius, attr_tokenizer=self.attr_tokenizer) + + # next_token_idx = pred['next_token_idx'][..., None] + # next_token_idx_gt = pred['next_token_idx_gt'][:, 2:] # hard code 2=11//5 + # next_token_eval_mask = pred['next_token_eval_mask'][:, 2:] + + # gt_traj = pred['gt_traj'] + # pred_traj = pred['pred_traj'] + # pred_head = pred['pred_head'] + + # self.TokenCls.update(pred=next_token_idx[next_token_eval_mask], target=next_token_idx_gt[next_token_eval_mask], + # valid_mask=next_token_eval_mask[next_token_eval_mask]) + # self.log('val_token_cls_acc', self.TokenCls, prog_bar=True, on_step=True, on_epoch=True, batch_size=1, sync_dist=True) + + # remove the agents which are unseen at current step + # eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps - 1] + + # self.minADE.update(pred=pred_traj[eval_mask], target=gt_traj[eval_mask], valid_mask=valid_mask[eval_mask]) + # self.minFDE.update(pred=pred_traj[eval_mask], target=gt_traj[eval_mask], valid_mask=valid_mask[eval_mask]) + # print('ade: ', self.minADE.compute(), 'fde: ', self.minFDE.compute()) + + # self.log('val_minADE', self.minADE, prog_bar=True, on_step=True, on_epoch=True, batch_size=1) + # self.log('val_minFDE', self.minFDE, prog_bar=True, on_step=True, on_epoch=True, batch_size=1) + + # state tokens + if self.predict_state: + + if self.use_grid_token: + next_pos_rel_prob_seed = rollout['next_pos_rel_prob_seed'].cpu().numpy() # (s, t, grid_size) + prob, _ = self.attr_tokenizer.pad_square(next_pos_rel_prob_seed) + + if self._plot_rollouts: + if self.use_grid_token: + plot_insert_grid(scenario_id, + prob, + save_path=self.save_path, + prefix=f'inference_') + plot_prob_seed(scenario_id, + rollout['next_state_prob_seed'].cpu().numpy(), + self.save_path, + prefix=f'inference_') + + next_state_idx = rollout['next_state_idx'][..., None] + # next_state_idx_gt = rollout['next_state_idx_gt'][:, 2:] + # next_state_eval_mask = rollout['next_state_eval_mask'][:, 2:] + + # self.StateCls.update(pred=next_state_idx[next_token_eval_mask], target=next_state_idx_gt[next_token_eval_mask], + # valid_mask=next_token_eval_mask[next_token_eval_mask]) + # self.log('val_state_cls_acc', self.TokenCls, prog_bar=True, on_step=True, on_epoch=True, batch_size=1, sync_dist=True) + + self.StateAccuracy.update(state_idx=next_state_idx[..., 0]) + self.log('valid_accuracy', self.StateAccuracy.compute()['valid'], prog_bar=True, on_step=True, on_epoch=True, batch_size=1) + self.log('invalid_accuracy', self.StateAccuracy.compute()['invalid'], prog_bar=True, on_step=True, on_epoch=True, batch_size=1) + self.local_logger.info(rollout['log_message']) + # print(rollout['log_message']) + # print(self.StateAccuracy) + + if self.predict_occ: + + grid_agent_occ_seed = rollout['grid_agent_occ_seed'] + grid_pt_occ_seed = rollout['grid_pt_occ_seed'] + grid_agent_occ_gt_seed = rollout['grid_agent_occ_gt_seed'] + + agent_occ = self.attr_tokenizer.pad_square(grid_agent_occ_seed.sigmoid().cpu().numpy())[0] + agent_occ_gt = 
self.attr_tokenizer.pad_square(grid_agent_occ_gt_seed.sigmoid().cpu().numpy())[0] + if self._plot_rollouts: + plot_occ_grid(scenario_id, + agent_occ, + gt_occ=agent_occ_gt, + mode='agent', + save_path=self.save_path, + prefix=f'inference_') + + if self._online_metric or self._save_validate_reuslts: + + # ! format results + pred_valid, token_pos, token_head = [], [], [] + pred_traj, pred_head, pred_z = [], [], [] + pred_shape, pred_type, pred_state = [], [], [] + agent_id = [] + for rollout in rollouts: + pred_valid.append(rollout['pred_valid']) + token_pos.append(rollout['pos_a']) + token_head.append(rollout['head_a']) + pred_traj.append(rollout['pred_traj']) + pred_head.append(rollout['pred_head']) + pred_z.append(rollout['pred_z']) + pred_shape.append(rollout['eval_shape']) + pred_type.append(rollout['pred_type']) + pred_state.append(rollout['next_state_idx']) + agent_id.append(rollout['agent_id']) + + pred_valid = torch.stack(pred_valid, dim=1) + token_pos = torch.stack(token_pos, dim=1) + token_head = torch.stack(token_head, dim=1) + pred_traj = torch.stack(pred_traj, dim=1) # (n_agent, n_rollout, n_step, 2) + pred_head = torch.stack(pred_head, dim=1) + pred_z = torch.stack(pred_z, dim=1) + pred_shape = torch.stack(pred_shape, dim=1) # [n_agent, n_rollout, 3] + pred_type = torch.stack(pred_type, dim=1) # [n_agent, n_rollout] + pred_state = torch.stack(pred_state, dim=1) # [n_agent, n_rollout, n_step // shift] + agent_id = torch.stack(agent_id, dim=1) # [n_agent, n_rollout] + + agent_batch = torch.zeros((pred_traj.shape[0]), dtype=torch.long) + rollouts = dict( + _scenario_id=data['scenario_id'], + scenario_id=get_scenario_id_int_tensor(data['scenario_id']), + av_id=int(rollouts[0]['agent_id'][rollouts[0]['ego_index']]), # NOTE: hard code!!! + agent_id=agent_id.cpu(), + agent_batch=agent_batch.cpu(), + pred_traj=pred_traj.cpu(), + pred_z=pred_z.cpu(), + pred_head=pred_head.cpu(), + pred_shape=pred_shape.cpu(), + pred_type=pred_type.cpu(), + pred_state=pred_state.cpu(), + pred_valid=pred_valid.cpu(), + token_pos=token_pos.cpu(), + token_head=token_head.cpu(), + tfrecord_path=data['tfrecord_path'], + ) + + if self._save_validate_reuslts: + with open(rollouts_path, 'wb') as f: + pickle.dump(rollouts, f) + + if self._online_metric: + self._long_metrics.update(rollouts) + + def on_validation_start(self): + self.scenario_rollouts = [] + self.batch_metric = defaultdict(list) + + def on_validation_epoch_end(self): + if self.val_close_loop: + + if self._long_metrics is not None: + epoch_long_metrics = self._long_metrics.compute() + if self.global_rank == 0: + epoch_long_metrics['epoch'] = self.current_epoch + self.logger.log_metrics(epoch_long_metrics) + + self._long_metrics.reset() + + self.minADE.reset() + self.minFDE.reset() + self.StateAccuracy.reset() + + def configure_optimizers(self): + optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) + + def lr_lambda(current_step): + if current_step + 1 < self.warmup_steps: + return float(current_step + 1) / float(max(1, self.warmup_steps)) + return max( + 0.0, 0.5 * (1.0 + math.cos(math.pi * (current_step - self.warmup_steps) / float(max(1, self.total_steps - self.warmup_steps)))) + ) + + lr_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) + return [optimizer], [lr_scheduler] + + def load_state_from_file(self, filename, to_cpu=False): + logger = self.local_logger + + if not os.path.isfile(filename): + raise FileNotFoundError + + logger.info('==> Loading parameters from checkpoint %s to %s' % (filename, 'CPU' if to_cpu else 'GPU')) + 
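+        # The loading below is deliberately tolerant: checkpoint keys are filtered
+        # down to those whose shapes match the current model, and `strict=False`
+        # lets runs resume across small head or config changes without crashing;
+        # mismatched and missing keys are only logged.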
loc_type = torch.device('cpu') if to_cpu else None + checkpoint = torch.load(filename, map_location=loc_type) + + version = checkpoint.get("version", None) + if version is not None: + logger.info('==> Checkpoint trained from version: %s' % version) + + + model_state_disk = checkpoint['state_dict'] + logger.info(f'The number of disk ckpt keys: {len(model_state_disk)}') + + model_state = self.state_dict() + model_state_disk_filter = {} + for key, val in model_state_disk.items(): + if key in model_state and model_state_disk[key].shape == model_state[key].shape: + model_state_disk_filter[key] = val + else: + if key not in model_state: + print(f'Ignore key in disk (not found in model): {key}, shape={val.shape}') + else: + print(f'Ignore key in disk (shape does not match): {key}, load_shape={val.shape}, model_shape={model_state[key].shape}') + + model_state_disk = model_state_disk_filter + missing_keys, unexpected_keys = self.load_state_dict(model_state_disk, strict=False) + + logger.info(f'Missing keys: {missing_keys}') + logger.info(f'The number of missing keys: {len(missing_keys)}') + logger.info(f'The number of unexpected keys: {len(unexpected_keys)}') + logger.info('==> Done (total keys %d)' % (len(model_state))) + + epoch = checkpoint.get('epoch', -1) + it = checkpoint.get('it', 0.0) + + return it, epoch + + def match_token_map(self, data): + traj_pos = data['map_save']['traj_pos'].to(torch.float) + traj_theta = data['map_save']['traj_theta'].to(torch.float) + pl_idx_list = data['map_save']['pl_idx_list'] + token_sample_pt = self.map_token['sample_pt'].to(traj_pos.device) + token_src = self.map_token['traj_src'].to(traj_pos.device) + max_traj_len = self.map_token['traj_src'].shape[1] + pl_num = traj_pos.shape[0] + + pt_token_pos = traj_pos[:, 0, :].clone() + pt_token_orientation = traj_theta.clone() + cos, sin = traj_theta.cos(), traj_theta.sin() + rot_mat = traj_theta.new_zeros(pl_num, 2, 2) + rot_mat[..., 0, 0] = cos + rot_mat[..., 0, 1] = -sin + rot_mat[..., 1, 0] = sin + rot_mat[..., 1, 1] = cos + traj_pos_local = torch.bmm((traj_pos - traj_pos[:, 0:1]), rot_mat.view(-1, 2, 2)) + distance = torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1)) ** 2, dim=(-2, -1)) + pt_token_id = torch.argmin(distance, dim=1) + + if self.noise: + topk_indices = torch.argsort(torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1))**2, dim=(-2, -1)), dim=1)[:, :8] + sample_topk = torch.randint(0, topk_indices.shape[-1], size=(topk_indices.shape[0], 1), device=topk_indices.device) + pt_token_id = torch.gather(topk_indices, 1, sample_topk).squeeze(-1) + + # cos, sin = traj_theta.cos(), traj_theta.sin() + # rot_mat = traj_theta.new_zeros(pl_num, 2, 2) + # rot_mat[..., 0, 0] = cos + # rot_mat[..., 0, 1] = sin + # rot_mat[..., 1, 0] = -sin + # rot_mat[..., 1, 1] = cos + # token_src_world = torch.bmm(token_src[None, ...].repeat(pl_num, 1, 1, 1).reshape(pl_num, -1, 2), + # rot_mat.view(-1, 2, 2)).reshape(pl_num, token_src.shape[0], max_traj_len, 2) + traj_pos[:, None, [0], :] + # token_src_world_select = token_src_world.view(-1, 1024, 11, 2)[torch.arange(pt_token_id.view(-1).shape[0]), pt_token_id.view(-1)].view(pl_num, max_traj_len, 2) + + pl_idx_full = pl_idx_list.clone() + token2pl = torch.stack([torch.arange(len(pl_idx_list), device=traj_pos.device), pl_idx_full.long()]) + count_nums = [] + for pl in pl_idx_full.unique(): + pt = token2pl[0, token2pl[1, :] == pl] + left_side = (data['pt_token']['side'][pt] == 0).sum() + right_side = (data['pt_token']['side'][pt] == 1).sum() + center_side = 
(data['pt_token']['side'][pt] == 2).sum() + count_nums.append(torch.Tensor([left_side, right_side, center_side])) + count_nums = torch.stack(count_nums, dim=0) + num_polyline = int(count_nums.max().item()) + traj_mask = torch.zeros((int(len(pl_idx_full.unique())), 3, num_polyline), dtype=bool) + idx_matrix = torch.arange(traj_mask.size(2)).unsqueeze(0).unsqueeze(0) + idx_matrix = idx_matrix.expand(traj_mask.size(0), traj_mask.size(1), -1) + counts_num_expanded = count_nums.unsqueeze(-1) + mask_update = idx_matrix < counts_num_expanded + traj_mask[mask_update] = True + + data['pt_token']['traj_mask'] = traj_mask + data['pt_token']['position'] = torch.cat([pt_token_pos, torch.zeros((data['pt_token']['num_nodes'], 1), + device=traj_pos.device, dtype=torch.float)], dim=-1) + data['pt_token']['orientation'] = pt_token_orientation + data['pt_token']['height'] = data['pt_token']['position'][:, -1] + data[('pt_token', 'to', 'map_polygon')] = {} + data[('pt_token', 'to', 'map_polygon')]['edge_index'] = token2pl # (2, num_points) + data['pt_token']['token_idx'] = pt_token_id + + # data['pt_token']['batch'] = torch.zeros(data['pt_token']['num_nodes'], device=traj_pos.device).long() + # data['pt_token']['ptr'] = torch.tensor([0, data['pt_token']['num_nodes']], device=traj_pos.device).long() + + return data + + def sample_pt_pred(self, data): + traj_mask = data['pt_token']['traj_mask'] + raw_pt_index = torch.arange(1, traj_mask.shape[2]).repeat(traj_mask.shape[0], traj_mask.shape[1], 1) + masked_pt_index = raw_pt_index.view(-1)[torch.randperm(raw_pt_index.numel())[:traj_mask.shape[0] * traj_mask.shape[1] * ((traj_mask.shape[2] - 1) // 3)].reshape(traj_mask.shape[0], traj_mask.shape[1], (traj_mask.shape[2] - 1) // 3)] + masked_pt_index = torch.sort(masked_pt_index, -1)[0] + pt_valid_mask = traj_mask.clone() + pt_valid_mask.scatter_(2, masked_pt_index, False) + pt_pred_mask = traj_mask.clone() + pt_pred_mask.scatter_(2, masked_pt_index, False) + tmp_mask = pt_pred_mask.clone() + tmp_mask[:, :, :] = True + tmp_mask.scatter_(2, masked_pt_index-1, False) + pt_pred_mask.masked_fill_(tmp_mask, False) + pt_pred_mask = pt_pred_mask * torch.roll(traj_mask, shifts=-1, dims=2) + pt_target_mask = torch.roll(pt_pred_mask, shifts=1, dims=2) + + data['pt_token']['pt_valid_mask'] = pt_valid_mask[traj_mask] + data['pt_token']['pt_pred_mask'] = pt_pred_mask[traj_mask] + data['pt_token']['pt_target_mask'] = pt_target_mask[traj_mask] + + return data + + def _fetch_enterings(self, data: HeteroData, plot: bool=False): + data['agent']['grid_token_idx'] = torch.zeros_like(data['agent']['state_idx']).long() + data['agent']['grid_offset_xy'] = torch.zeros_like(data['agent']['token_pos']) + data['agent']['heading_token_idx'] = torch.zeros_like(data['agent']['state_idx']).long() + data['agent']['sort_indices'] = torch.zeros_like(data['agent']['state_idx']).long() + data['agent']['inrange_mask'] = torch.zeros_like(data['agent']['state_idx']).bool() + data['agent']['bos_mask'] = torch.zeros_like(data['agent']['state_idx']).bool() + + data['agent']['pos_xy'] = torch.zeros_like(data['agent']['token_pos']) + if self.predict_occ: + num_step = data['agent']['state_idx'].shape[1] + data['agent']['pt_grid_token_idx'] = torch.zeros_like(data['pt_token']['token_idx'])[None].repeat(num_step, 1).long() + + for b in range(data.num_graphs): + av_index = int(data['agent']['av_index'][b]) + agent_batch_mask = data['agent']['batch'] == b + pt_batch_mask = data['pt_token']['batch'] == b + pt_token_idx = data['pt_token']['token_idx'][pt_batch_mask] + 
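+            # Per graph: re-express agent (and, when predicting occupancy, map-token)
+            # positions in the ego frame and quantize them with
+            # `attr_tokenizer.encode_pos`; an index of -1 marks entries that are
+            # invalid or outside `pl2seed_radius`.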
pt_pos = data['pt_token']['position'][pt_batch_mask] + agent_token_pos = data['agent']['token_pos'][agent_batch_mask] + agent_token_heading = data['agent']['token_heading'][agent_batch_mask] + state_idx = data['agent']['state_idx'][agent_batch_mask] + ego_pos = agent_token_pos[av_index] # NOTE: `av_index` will be added by `ptr` later + ego_heading = agent_token_heading[av_index] + + grid_token_idx = torch.full(state_idx.shape, -1, device=state_idx.device) + offset_xy = torch.zeros_like(agent_token_pos) + sort_indices = torch.zeros_like(grid_token_idx) + pt_grid_token_idx = torch.full((state_idx.shape[1], *pt_token_idx.shape), -1, device=pt_token_idx.device) + + pos_xy = torch.zeros((*state_idx.shape, 2), device=state_idx.device) + + is_bos = [] + is_inrange = [] + for t in range(agent_token_pos.shape[1]): # num_step + + # tokenize position + is_bos_t = state_idx[:, t] == self.enter_state + is_invalid_t = state_idx[:, t] == self.invalid_state + is_inrange_t = ((agent_token_pos[:, t] - ego_pos[[t]]) ** 2).sum(-1).sqrt() <= self.pl2seed_radius + grid_index_t, offset_xy_t = self.attr_tokenizer.encode_pos(x=agent_token_pos[~is_invalid_t & is_inrange_t, t], + y=ego_pos[[t]], + theta_y=ego_heading[[t]]) + grid_token_idx[~is_invalid_t & is_inrange_t, t] = grid_index_t + offset_xy[~is_invalid_t & is_inrange_t, t] = offset_xy_t + + pos_xy[~is_invalid_t & is_inrange_t, t] = agent_token_pos[~is_invalid_t & is_inrange_t, t] - ego_pos[[t]] + + # distance = ((agent_token_pos[:, t] - ego_pos[[t]]) ** 2).sum(-1).sqrt() + head_vector = torch.stack([ego_heading[[t]].cos(), ego_heading[[t]].sin()], dim=-1) + distance = angle_between_2d_vectors(ctr_vector=head_vector, + nbr_vector=agent_token_pos[:, t] - ego_pos[[t]]) + # distance = torch.rand(agent_token_pos.shape[0], device=agent_token_pos.device) + distance[~(is_bos_t & is_inrange_t)] = torch.inf + sort_dist, sort_indice = distance.sort() + sort_indice[torch.isinf(sort_dist)] = av_index + sort_indices[:, t] = sort_indice + + is_bos.append(is_bos_t) + is_inrange.append(is_inrange_t) + + # tokenize pt token + if self.predict_occ: + is_inrange_t = ((pt_pos[:, :2] - ego_pos[None, t]) ** 2).sum(-1).sqrt() <= self.pl2seed_radius + grid_index_t, _ = self.attr_tokenizer.encode_pos(x=pt_pos[is_inrange_t, :2], + y=ego_pos[[t]], + theta_y=ego_heading[[t]]) + + pt_grid_token_idx[t, is_inrange_t] = grid_index_t + + # tokenize heading + rel_heading = agent_token_heading - ego_heading[None, ...] 
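+            # Headings are tokenized relative to the ego heading; `encode_heading`
+            # presumably wraps the difference into [-pi, pi) before binning it into
+            # `angle_interval`-degree buckets (an assumption from the config naming).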
+ heading_token_idx = self.attr_tokenizer.encode_heading(rel_heading) + + data['agent']['grid_token_idx'][agent_batch_mask] = grid_token_idx + data['agent']['grid_offset_xy'][agent_batch_mask] = offset_xy + data['agent']['pos_xy'][agent_batch_mask] = pos_xy + data['agent']['heading_token_idx'][agent_batch_mask] = heading_token_idx + data['agent']['sort_indices'][agent_batch_mask] = sort_indices + data['agent']['inrange_mask'][agent_batch_mask] = torch.stack(is_inrange, dim=1) + data['agent']['bos_mask'][agent_batch_mask] = torch.stack(is_bos, dim=1) + if self.predict_occ: + data['agent']['pt_grid_token_idx'][:, pt_batch_mask] = pt_grid_token_idx + + plot = False + if plot: + scenario_id = data['scenario_id'][b] + dummy_prob = np.zeros((ego_pos.shape[0], self.attr_tokenizer.grid.shape[0])) + .5 + indices = grid_token_idx[:, 1:][state_idx[:, 1:] == self.enter_state].reshape(-1).cpu().numpy() + dummy_prob, indices = self.attr_tokenizer.pad_square(dummy_prob, indices) + # plot_insert_grid(scenario_id, dummy_prob, + # self.attr_tokenizer.grid.cpu().numpy(), + # ego_pos.cpu().numpy(), + # None, + # save_path=os.path.join(self.save_path, 'vis'), + # indices=indices[np.newaxis, ...], + # inference=True, + # all_t_in_one=True) + + enter_index = [grid_token_idx[:, i][state_idx[:, i] == self.enter_state].tolist() + for i in range(agent_token_pos.shape[1])] + agent_labels = [[f'A{i}'] * agent_token_pos.shape[1] for i in range(agent_token_pos.shape[0])] + plot_scenario(scenario_id, + data['map_point']['position'].cpu().numpy(), + agent_token_pos.cpu().numpy(), + agent_token_heading.cpu().numpy(), + state_idx.cpu().numpy(), + types=list(map(lambda i: self.encoder.agent_encoder.agent_type[i], + data['agent']['type'].tolist())), + av_index=av_index, + pl2seed_radius=self.pl2seed_radius, + attr_tokenizer=self.attr_tokenizer, + enter_index=enter_index, + save_gif=False, + save_path=os.path.join(self.save_path, 'vis'), + agent_labels=agent_labels, + tokenized=True) + + return data diff --git a/backups/dev/modules/agent_decoder.py b/backups/dev/modules/agent_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..1a0640a097958f7b95eafcd0ed11b54c34ab998b --- /dev/null +++ b/backups/dev/modules/agent_decoder.py @@ -0,0 +1,2419 @@ +import os +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm +from typing import Dict, Mapping, Optional, Literal +from torch_cluster import radius, radius_graph +from torch_geometric.data import HeteroData, Batch +from torch_geometric.utils import dense_to_sparse, subgraph +from scipy.optimize import linear_sum_assignment + +from dev.modules.attr_tokenizer import Attr_Tokenizer +from dev.modules.layers import * +from dev.utils.visualization import * +from dev.datasets.preprocess import AGENT_SHAPE, AGENT_TYPE +from dev.utils.func import angle_between_2d_vectors, wrap_angle, weight_init + + +class HungarianMatcher(nn.Module): + + def __init__(self, loss_weight: dict, enter_state: int = 0): + super().__init__() + self.enter_state = enter_state + self.cost_state = loss_weight['state_cls_loss'] + self.cost_pos = loss_weight['pos_cls_loss'] + self.cost_head = loss_weight['head_cls_loss'] + self.cost_shape = loss_weight['shape_reg_loss'] + self.seed_state_weight = loss_weight['seed_state_weight'] + self.seed_type_weight = loss_weight['seed_type_weight'] + + @torch.no_grad() + def forward(self, outputs, targets, ptr_pred, ptr_gt, valid_mask=None): + + pred_indices = [] + gt_indices = [] + + for b in 
range(len(ptr_gt) - 1):
+
+            start_pred, end_pred = ptr_pred[b], ptr_pred[b + 1]
+            start_gt, end_gt = ptr_gt[b], ptr_gt[b + 1]
+
+            pos_pred = outputs['pos_pred'][start_pred : end_pred]  # (n, s, l)
+            shape_pred = outputs['shape_pred'][start_pred : end_pred]
+
+            pos_gt = targets['pos_gt'][start_gt : end_gt]
+            shape_gt = targets['shape_gt'][start_gt : end_gt]
+
+            num_pred = pos_pred.shape[0]
+            num_gt = pos_gt.shape[0]
+
+            cost_pos = F.cross_entropy(pos_pred[:, None].repeat(1, num_gt, 1, 1).reshape(-1, pos_pred.shape[-1]),
+                                       pos_gt[None, ...].repeat(num_pred, 1, 1).reshape(-1),
+                                       label_smoothing=0.1, ignore_index=-1, reduction='none'
+                                       ).reshape(num_pred, num_gt, -1)
+            cost_shape = ((shape_pred[:, None] - shape_gt[None, ...]) ** 2).sum(-1)
+
+            C = (
+                self.cost_pos * cost_pos +
+                self.cost_shape * cost_shape
+            )
+
+            C = C.reshape(num_pred, num_gt, -1).cpu().numpy()
+
+            if valid_mask is not None:
+                # in case the seed size is smaller than the maximum number of gt among all steps
+                C[:, ~valid_mask[start_gt : end_gt].cpu().numpy().astype(np.bool_)] = 1 << 15
+
+            _indices = []
+            for t in range(C.shape[-1]):  # num_step
+                _indices.append(linear_sum_assignment(C[..., t]))
+
+            _indices = (
+                torch.as_tensor(np.array([indices_t[0] for indices_t in _indices]) + int(start_pred), dtype=torch.long).transpose(-1, -2),
+                torch.as_tensor(np.array([indices_t[1] for indices_t in _indices]) + int(start_gt), dtype=torch.long).transpose(-1, -2),
+            )
+
+            pred_indices.append(_indices[0])
+            gt_indices.append(_indices[1])
+
+        pred_indices = torch.cat(pred_indices)
+        gt_indices = torch.cat(gt_indices)
+
+        return pred_indices, gt_indices
+
+    def __repr__(self):
+        head = "Matcher " + self.__class__.__name__
+        body = [
+            "cost_state: {}".format(self.cost_state),
+            "cost_pos: {}".format(self.cost_pos),
+            "cost_head: {}".format(self.cost_head),
+            "cost_shape: {}".format(self.cost_shape),
+        ]
+        _repr_indent = 4
+        lines = [head] + [" " * _repr_indent + line for line in body]
+        return "\n".join(lines)
+
+
+class SMARTAgentDecoder(nn.Module):
+
+    def __init__(self,
+                 dataset: str,
+                 input_dim: int,
+                 hidden_dim: int,
+                 num_historical_steps: int,
+                 time_span: Optional[int],
+                 pl2a_radius: float,
+                 pl2seed_radius: float,
+                 a2a_radius: float,
+                 a2sa_radius: float,
+                 pl2sa_radius: float,
+                 num_freq_bands: int,
+                 num_layers: int,
+                 num_heads: int,
+                 head_dim: int,
+                 dropout: float,
+                 token_size: int,
+                 attr_tokenizer: Attr_Tokenizer=None,
+                 predict_motion: bool=False,
+                 predict_state: bool=False,
+                 predict_map: bool=False,
+                 predict_occ: bool=False,
+                 state_token: Dict[str, int]=None,
+                 use_grid_token: bool=True,
+                 seed_size: int=5,
+                 buffer_size: int=32,
+                 num_recurrent_steps_val: int=-1,
+                 loss_weight: dict=None,
+                 logger=None) -> None:
+
+        super(SMARTAgentDecoder, self).__init__()
+        self.dataset = dataset
+        self.input_dim = input_dim
+        self.hidden_dim = hidden_dim
+        self.num_historical_steps = num_historical_steps
+        self.time_span = time_span if time_span is not None else num_historical_steps
+        self.pl2a_radius = pl2a_radius
+        self.pl2seed_radius = pl2seed_radius
+        self.a2a_radius = a2a_radius
+        self.a2sa_radius = a2sa_radius
+        self.pl2sa_radius = pl2sa_radius
+        self.num_freq_bands = num_freq_bands
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.dropout = dropout
+        self.predict_motion = predict_motion
+        self.predict_state = predict_state
+        self.predict_map = predict_map
+        self.predict_occ = predict_occ
+        self.use_grid_token = use_grid_token
+        self.num_recurrent_steps_val = num_recurrent_steps_val
+        self.loss_weight = loss_weight
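+        # The radii copied above (pl2a / a2a / a2sa / pl2sa / pl2seed) mirror the
+        # Model.decoder section of the YAML configs and bound which edges are built
+        # between map tokens, agents, and seed agents.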
+ self.logger = logger + + self.attr_tokenizer = attr_tokenizer + + # state tokens + self.state_type = list(state_token.keys()) + self.state_token = state_token + self.invalid_state = int(state_token['invalid']) + self.valid_state = int(state_token['valid']) + self.enter_state = int(state_token['enter']) + self.exit_state = int(state_token['exit']) + + self.seed_state_type = ['invalid', 'enter'] + self.valid_state_type = ['invalid', 'valid', 'exit'] + + input_dim_x_a = 2 + input_dim_r_t = 4 + input_dim_r_pt2a = 3 + input_dim_r_pt2sa = 3 + input_dim_r_a2a = 3 + input_dim_r_a2sa = 3#4 + input_dim_motion_token = 8 # tokens: (token_size, 4, 2) + input_dim_offset_token = 2 + + self.seed_size = seed_size + self.buffer_size = buffer_size + + # self.agent_type = ['veh', 'ped', 'cyc', 'seed'] + self.type_a_emb = nn.Embedding(len(AGENT_TYPE), hidden_dim) + self.shape_emb = MLPEmbedding(input_dim=3, hidden_dim=hidden_dim) + self.state_a_emb = nn.Embedding(len(self.state_type), hidden_dim) + self.motion_gap = 1. + self.heading_gap = 1. + self.invalid_shape_value = .1 + self.invalid_motion_value = -2. + self.invalid_head_value = -2. + + self.x_a_emb = FourierEmbedding(input_dim=input_dim_x_a, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands) + self.r_t_emb = FourierEmbedding(input_dim=input_dim_r_t, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands) + self.r_pt2a_emb = FourierEmbedding(input_dim=input_dim_r_pt2a, hidden_dim=hidden_dim, + num_freq_bands=num_freq_bands) + self.r_a2a_emb = FourierEmbedding(input_dim=input_dim_r_a2a, hidden_dim=hidden_dim, + num_freq_bands=num_freq_bands) + # self.r_sa2sa_emb = FourierEmbedding(input_dim=input_dim_r_t, hidden_dim=hidden_dim, + # num_freq_bands=num_freq_bands) + self.r_pt2sa_emb = FourierEmbedding(input_dim=input_dim_r_pt2sa, hidden_dim=hidden_dim, + num_freq_bands=num_freq_bands) + self.r_a2sa_emb = FourierEmbedding(input_dim=input_dim_r_a2sa, hidden_dim=hidden_dim, + num_freq_bands=num_freq_bands) + self.token_emb_veh = MLPEmbedding(input_dim=input_dim_motion_token, hidden_dim=hidden_dim) + self.token_emb_ped = MLPEmbedding(input_dim=input_dim_motion_token, hidden_dim=hidden_dim) + self.token_emb_cyc = MLPEmbedding(input_dim=input_dim_motion_token, hidden_dim=hidden_dim) + self.token_emb_grid = MLPEmbedding(input_dim=input_dim_offset_token, hidden_dim=hidden_dim) + self.no_token_emb = nn.Embedding(1, hidden_dim) + self.bos_token_emb = nn.Embedding(1, hidden_dim) + self.invalid_offset_token_emb = nn.Embedding(1, hidden_dim) + + if self.use_grid_token: + num_inputs = 4 + else: + num_inputs = 3 + self.fusion_emb = MLPEmbedding(input_dim=self.hidden_dim * num_inputs, hidden_dim=self.hidden_dim) + + self.t_attn_layers = nn.ModuleList( + [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout, + bipartite=False, has_pos_emb=True) for _ in range(num_layers)] + ) + self.pt2a_attn_layers = nn.ModuleList( + [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout, + bipartite=True, has_pos_emb=True) for _ in range(num_layers)] + ) + self.a2a_attn_layers = nn.ModuleList( + [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout, + bipartite=False, has_pos_emb=True) for _ in range(num_layers)] + ) + + # FIXME: for test! 
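+ # number of attention layers for the seed-agent branch (hard-coded; used by the pt2sa/a2sa/occ2sa stacks below)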
+ self.seed_layers = 3 + # self.sa2sa_attn_layers = nn.ModuleList( + # [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout, + # bipartite=False, has_pos_emb=True) for _ in range(self.seed_layers)] + # ) + self.pt2sa_attn_layers = nn.ModuleList( + [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout, + bipartite=True, has_pos_emb=True) for _ in range(self.seed_layers)] + ) + self.a2sa_attn_layers = nn.ModuleList( + [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout, + bipartite=False, has_pos_emb=True) for _ in range(self.seed_layers)] + ) + self.occ2sa_attn_layers = nn.ModuleList( + [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout, + bipartite=True, has_pos_emb=False) for _ in range(self.seed_layers)] + ) + + self.token_size = token_size # 2048 + # agent motion prediction head + self.token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, + output_dim=self.token_size) + # agent state prediction head + self.state_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, + output_dim=len(self.valid_state_type)) + + self.seed_state_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, + output_dim=len(self.seed_state_type)) + self.seed_type_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, + output_dim=len(AGENT_TYPE) - 1) + self.seed_shape_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, + output_dim=3) + + self.grid_size = self.attr_tokenizer.grid_size + self.angle_size = self.attr_tokenizer.angle_size + + if self.use_grid_token: + self.seed_pos_rel_token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, + output_dim=self.grid_size) + self.seed_offset_xy_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, + output_dim=2) + self.seed_agent_occ_embed = MLPLayer(input_dim=self.grid_size, hidden_dim=hidden_dim, + output_dim=hidden_dim) + else: + self.seed_pos_rel_xy_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, + output_dim=2) + self.seed_heading_rel_token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, + output_dim=self.angle_size) + # self.seed_pt_occ_embed = MLPLayer(input_dim=self.grid_size, hidden_dim=hidden_dim, + # output_dim=hidden_dim) + + if self.predict_occ: + self.grid_agent_occ_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, + output_dim=self.grid_size) + self.grid_pt_occ_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, + output_dim=self.grid_size) + self.grid_index_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, + output_dim=self.grid_size) + + # self.num_seed_feature = 1 + # self.num_seed_feature = self.seed_size + self.num_seed_feature = 10 + + # self.trajectory_token = token_data['token'] # dict('veh', 'ped', 'cyc') (2048, 4, 2) + # self.trajectory_token_traj = token_data['traj'] # (2048, 6, 3) + # self.trajectory_token_all = token_data['token_all'] # (2048, 6, 4, 2) + self.apply(weight_init) + + self.shift = 5 + self.motion_beam_size = 5 + self.insert_beam_size = 10 + self.hist_mask = True + self.temporal_attn_to_invalid = False + self.use_rel = False + self.inference_filter_overlap = True + assert self.num_recurrent_steps_val % self.shift == 0 or self.num_recurrent_steps_val == -1, \ + f"Invalid num_recurrent_steps_val: {num_recurrent_steps_val}." 
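+ # e.g. (our reading of the assert above): with self.shift = 5, one motion token advances 5 timesteps, so num_recurrent_steps_val = 300 corresponds to 300 / 5 = 60 autoregressive decoding iterations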
+ + # seed agent + self.temporal_attn_seed = False + self.seed_attn_to_av = True + self.seed_use_ego_motion = False + + self.matcher = HungarianMatcher(loss_weight=loss_weight, enter_state=self.enter_state) + + def transform_rel(self, token_traj, prev_pos, prev_heading=None): + if prev_heading is None: + diff_xy = prev_pos[:, :, -1, :] - prev_pos[:, :, -2, :] + prev_heading = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0]) + + num_agent, num_step, traj_num, traj_dim = token_traj.shape + cos, sin = prev_heading.cos(), prev_heading.sin() + rot_mat = torch.zeros((num_agent, num_step, 2, 2), device=prev_heading.device) + rot_mat[:, :, 0, 0] = cos + rot_mat[:, :, 0, 1] = -sin + rot_mat[:, :, 1, 0] = sin + rot_mat[:, :, 1, 1] = cos + agent_diff_rel = torch.bmm(token_traj.view(-1, traj_num, 2), rot_mat.view(-1, 2, 2)).view(num_agent, num_step, traj_num, traj_dim) + agent_pred_rel = agent_diff_rel + prev_pos[:, :, -1:, :] + return agent_pred_rel + + def _agent_token_embedding(self, data, agent_token_index, agent_state, agent_offset_token_idx, pos_a, head_a, + inference=False, filter_mask=None, av_index=None): + + if filter_mask is None: + filter_mask = torch.ones_like(agent_state[:, 2], dtype=torch.bool) + + num_agent, num_step, traj_dim = pos_a.shape # traj_dim=2 + agent_type = data['agent']['type'][filter_mask] + veh_mask = agent_type == 0 + ped_mask = agent_type == 1 + cyc_mask = agent_type == 2 + + motion_vector_a, head_vector_a = self._build_vector_a(pos_a, head_a, agent_state) + + + trajectory_token_veh = data['agent']['trajectory_token_veh'] # [n_token, 6, 4, 2] + trajectory_token_ped = data['agent']['trajectory_token_ped'] + trajectory_token_cyc = data['agent']['trajectory_token_cyc'] + agent_token_emb_veh = self.token_emb_veh(trajectory_token_veh[:, -1].flatten(1, 2)) + agent_token_emb_ped = self.token_emb_ped(trajectory_token_ped[:, -1].flatten(1, 2)) + agent_token_emb_cyc = self.token_emb_cyc(trajectory_token_cyc[:, -1].flatten(1, 2)) + + # add bos token embedding + agent_token_emb_veh = torch.cat([agent_token_emb_veh, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())]) + agent_token_emb_ped = torch.cat([agent_token_emb_ped, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())]) + agent_token_emb_cyc = torch.cat([agent_token_emb_cyc, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())]) + + # add invalid token embedding + agent_token_emb_veh = torch.cat([agent_token_emb_veh, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())]) + agent_token_emb_ped = torch.cat([agent_token_emb_ped, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())]) + agent_token_emb_cyc = torch.cat([agent_token_emb_cyc, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())]) + + # additional token embeddings are already added -> -1: invalid, -2: bos + agent_token_emb = torch.zeros((num_agent, num_step, self.hidden_dim), device=pos_a.device) + agent_token_emb[veh_mask] = agent_token_emb_veh[agent_token_index[veh_mask]] + agent_token_emb[ped_mask] = agent_token_emb_ped[agent_token_index[ped_mask]] + agent_token_emb[cyc_mask] = agent_token_emb_cyc[agent_token_index[cyc_mask]] + + # grid embedding + self.grid_token_emb = self.token_emb_grid(self.attr_tokenizer.grid) + self.grid_token_emb = torch.cat([self.grid_token_emb, self.invalid_offset_token_emb(torch.zeros(1, device=pos_a.device).long())]) + offset_token_emb = self.grid_token_emb[agent_offset_token_idx] + + # 'vehicle', 'pedestrian', 'cyclist', 'background' + is_invalid = agent_state == 
self.invalid_state + agent_types = data['agent']['type'].clone()[filter_mask].long().repeat_interleave(repeats=num_step, dim=0) + agent_types[is_invalid.reshape(-1)] = AGENT_TYPE.index('seed') + agent_shapes = data['agent']['shape'].clone()[filter_mask, self.num_historical_steps - 1, :].repeat_interleave(repeats=num_step, dim=0) + agent_shapes[is_invalid.reshape(-1)] = self.invalid_shape_value + + # TODO: fix ego_pos in inference mode + offset_pos = pos_a - pos_a[av_index].repeat_interleave(repeats=data['batch_size_a'], dim=0) + feat_a, categorical_embs = self._build_agent_feature(num_step, pos_a.device, + motion_vector_a, + head_vector_a, + agent_token_emb, + offset_token_emb, + offset_pos=offset_pos, + type=agent_types, + shape=agent_shapes, + state=agent_state, + n=num_agent) + + if inference: + return ( + feat_a, + agent_token_emb, + agent_token_emb_veh, + agent_token_emb_ped, + agent_token_emb_cyc, + categorical_embs, + trajectory_token_veh, + trajectory_token_ped, + trajectory_token_cyc, + ) + + else: + + # seed agent feature + if self.seed_use_ego_motion: + motion_vector_seed = motion_vector_a[av_index].repeat_interleave(repeats=self.num_seed_feature, dim=0) + head_vector_seed = head_vector_a[av_index].repeat_interleave(repeats=self.num_seed_feature, dim=0) + else: + motion_vector_seed = head_vector_seed = None + feat_seed, _ = self._build_agent_feature(num_step, pos_a.device, + motion_vector_seed, + head_vector_seed, + state_index=self.invalid_state, + n=data.num_graphs * self.num_seed_feature) + + feat_a = torch.cat([feat_a, feat_seed], dim=0) # (a + s, t, d) + + return feat_a + + def _build_vector_a(self, pos_a, head_a, state_a): + num_agent = pos_a.shape[0] + + motion_vector_a = torch.cat([pos_a.new_zeros(num_agent, 1, self.input_dim), + pos_a[:, 1:] - pos_a[:, :-1]], dim=1) + + motion_vector_a[state_a == self.invalid_state] = self.invalid_motion_value + + # invalid -> valid + is_last_invalid = (state_a.roll(shifts=1, dims=1) == self.invalid_state) & (state_a != self.invalid_state) + is_last_invalid[:, 0] = state_a[:, 0] == self.enter_state + motion_vector_a[is_last_invalid] = self.motion_gap + + # valid -> invalid + is_last_valid = (state_a.roll(shifts=1, dims=1) != self.invalid_state) & (state_a == self.invalid_state) + is_last_valid[:, 0] = False + motion_vector_a[is_last_valid] = -self.motion_gap + + head_a[state_a == self.invalid_state] = self.invalid_head_value + head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1) + + return motion_vector_a, head_vector_a + + def _build_agent_feature(self, num_step, device, + motion_vector=None, + head_vector=None, + agent_token_emb=None, + agent_grid_emb=None, + offset_pos=None, + type=None, + shape=None, + categorical_embs_a=None, + state=None, + state_index=None, + n=1): + + if agent_token_emb is None: + agent_token_emb = self.no_token_emb(torch.zeros(1, device=device).long())[:, None].repeat(n, num_step, 1) + if state is not None: + agent_token_emb[state == self.enter_state] = self.bos_token_emb(torch.zeros(1, device=device).long()) + + if agent_grid_emb is None: + agent_grid_emb = self.grid_token_emb[None, None, self.grid_size // 2].repeat(n, num_step, 1) + + if motion_vector is None or head_vector is None: + pos_a = torch.zeros((n, num_step, 2), device=device) + head_a = torch.zeros((n, num_step), device=device) + if state is None: + state = torch.full((n, num_step), self.invalid_state, device=device) + motion_vector, head_vector = self._build_vector_a(pos_a, head_a, state) + + if offset_pos is None: + offset_pos =
torch.zeros_like(motion_vector) + + feature_a = torch.stack( + [torch.norm(motion_vector[:, :, :2], p=2, dim=-1), + angle_between_2d_vectors(ctr_vector=head_vector, nbr_vector=motion_vector[:, :, :2]), + # torch.norm(offset_pos[:, :, :2], p=2, dim=-1), + ], dim=-1) + + if categorical_embs_a is None: + if type is None: + type = torch.tensor([AGENT_TYPE.index('seed')], device=device) + if shape is None: + shape = torch.full((1, 3), self.invalid_shape_value, device=device) + + categorical_embs_a = [self.type_a_emb(type.reshape(-1)), self.shape_emb(shape.reshape(-1, shape.shape[-1]))] + + x_a = self.x_a_emb(continuous_inputs=feature_a.view(-1, feature_a.size(-1)), + categorical_embs=categorical_embs_a) + x_a = x_a.view(-1, num_step, self.hidden_dim) # (a, t, d) + + if state is None: + assert state_index is not None, "state_index needs to be set when the state tensor is None!" + state = torch.tensor([state_index], device=device)[:, None].repeat(n, num_step, 1) # do not use `expand` + s_a = self.state_a_emb(state.reshape(-1).long()).reshape(n, num_step, self.hidden_dim) + + feat_a = torch.cat((agent_token_emb, x_a, s_a), dim=-1) + if self.use_grid_token: + feat_a = torch.cat([feat_a, agent_grid_emb], dim=-1) + + feat_a = self.fusion_emb(feat_a) # (a, t, d) + + return feat_a, categorical_embs_a + + def _pad_feat(self, num_graph, av_index, *feats, num_seed_feature=None): + + if num_seed_feature is None: + num_seed_feature = self.num_seed_feature + + padded_feats = tuple() + for i in range(len(feats)): + padded_feats += (torch.cat([feats[i], feats[i][av_index].repeat_interleave( + repeats=num_seed_feature, dim=0)], + dim=0 + ),) + + pad_mask = torch.ones(*padded_feats[0].shape[:2], device=feats[0].device).bool() # (a, t) + pad_mask[-num_graph * num_seed_feature:] = False + + return padded_feats + (pad_mask,) + + # def _build_seed_feat(self, data, pos_a, head_a, state_a, head_vector_a, mask, sort_indices, av_index): + # seed_mask = sort_indices != av_index.repeat_interleave(repeats=data['batch_size_a'], dim=0)[:, None] + # # TODO: fix batch_size!!! + # print(mask.shape, sort_indices.shape, seed_mask.shape) + # mask[-data.num_graphs * self.num_seed_feature:] = seed_mask[:self.num_seed_feature] + + # insert_pos_a = torch.gather(pos_a, dim=0, index=sort_indices[:self.num_seed_feature, :, None].expand(-1, -1, pos_a.shape[-1])) + # pos_a[mask] = insert_pos_a[mask[-self.num_seed_feature:]] + + # state_a[-data.num_graphs * self.num_seed_feature:] = self.enter_state + + # return pos_a, head_a, state_a, head_vector_a, mask + + def _build_temporal_edge(self, data, pos_a, head_a, state_a, head_vector_a, mask, inference_mask=None): + + num_graph = data.num_graphs + num_agent = pos_a.shape[0] + hist_mask = mask.clone() + + if not self.temporal_attn_to_invalid: + is_bos = state_a == self.enter_state + bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0)) + history_invalid_mask = torch.arange(mask.shape[1]).expand(mask.shape[0], mask.shape[1]).to(mask.device) + history_invalid_mask = (history_invalid_mask < bos_index[:, None]) + hist_mask[history_invalid_mask] = False + + if not self.temporal_attn_seed: + hist_mask[-num_graph * self.num_seed_feature:] = False + if inference_mask is not None: + inference_mask[-num_graph * self.num_seed_feature:] = False + else: + # WARNING: if temporal attention to seed agents is enabled, + # we need to fix the pos/head of the seed agents!!! + raise RuntimeError("Wrong settings!") + + pos_t = pos_a.reshape(-1, self.input_dim) # (num_agent * num_step, ...)
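+ # Illustration (assumed torch_geometric behaviour): `dense_to_sparse` below turns a boolean attention mask into an edge list, e.g. + # m = torch.tensor([[True, True], [False, True]]) + # dense_to_sparse(m)[0] # -> tensor([[0, 0, 1], [0, 1, 1]])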
+ head_t = head_a.reshape(-1) + head_vector_t = head_vector_a.reshape(-1, 2) + + # invalid agents never predict motion tokens, so we don't attend to them + is_bos = state_a == self.enter_state + is_bos[-num_graph * self.num_seed_feature:] = False + bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0)) + motion_predict_start_index = torch.clamp(bos_index - self.time_span / self.shift + 1, min=0) + motion_predict_mask = torch.arange(hist_mask.shape[1]).expand(hist_mask.shape[0], -1).to(hist_mask.device) + motion_predict_mask = motion_predict_mask >= motion_predict_start_index[:, None] + hist_mask[~motion_predict_mask] = False + + if self.hist_mask and self.training: + hist_mask[ + torch.arange(mask.shape[0]).unsqueeze(1), torch.randint(0, mask.shape[1], (num_agent, 10))] = False + mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1) + elif inference_mask is not None: + mask_t = hist_mask.unsqueeze(2) & inference_mask.unsqueeze(1) + else: + mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1) + + # mask_t: (num_agent, 18, 18), edge_index_t: (2, num_edge) + edge_index_t = dense_to_sparse(mask_t)[0] + edge_index_t = edge_index_t[:, (edge_index_t[1] - edge_index_t[0] > 0) & + (edge_index_t[1] - edge_index_t[0] <= self.time_span / self.shift)] + rel_pos_t = pos_t[edge_index_t[0]] - pos_t[edge_index_t[1]] + rel_head_t = wrap_angle(head_t[edge_index_t[0]] - head_t[edge_index_t[1]]) + + # handle the invalid steps + is_invalid = state_a == self.invalid_state + is_invalid_t = is_invalid.reshape(-1) + + rel_pos_t[is_invalid_t[edge_index_t[0]] & ~is_invalid_t[edge_index_t[1]]] = -self.motion_gap + rel_pos_t[~is_invalid_t[edge_index_t[0]] & is_invalid_t[edge_index_t[1]]] = self.motion_gap + rel_head_t[is_invalid_t[edge_index_t[0]] & ~is_invalid_t[edge_index_t[1]]] = -self.heading_gap + rel_head_t[~is_invalid_t[edge_index_t[0]] & is_invalid_t[edge_index_t[1]]] = self.heading_gap + + rel_pos_t[is_invalid_t[edge_index_t[0]] & is_invalid_t[edge_index_t[1]]] = self.invalid_motion_value + rel_head_t[is_invalid_t[edge_index_t[0]] & is_invalid_t[edge_index_t[1]]] = self.invalid_head_value + + r_t = torch.stack( + [torch.norm(rel_pos_t[:, :2], p=2, dim=-1), + angle_between_2d_vectors(ctr_vector=head_vector_t[edge_index_t[1]], nbr_vector=rel_pos_t[:, :2]), + rel_head_t, + edge_index_t[0] - edge_index_t[1]], dim=-1) + r_t = self.r_t_emb(continuous_inputs=r_t, categorical_embs=None) + + return edge_index_t, r_t + + def _build_interaction_edge(self, data, pos_a, head_a, state_a, head_vector_a, batch_s, mask, pad_mask=None, inference_mask=None, + av_index=None, seq_mask=None, seq_index=None, grid_index_a=None, **plot_kwargs): + num_graph = data.num_graphs + num_agent, num_step, _ = pos_a.shape + is_training = inference_mask is None + + mask_a = mask.clone() + + if pad_mask is None: + pad_mask = torch.ones_like(state_a).bool() + + pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim) + head_s = head_a.transpose(0, 1).reshape(-1) + head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2) + pad_mask_s = pad_mask.transpose(0, 1).reshape(-1) + if inference_mask is not None: + mask_a = mask_a & inference_mask + mask_s = mask_a.transpose(0, 1).reshape(-1) + + # build agent2agent bilateral connection + edge_index_a2a = radius_graph(x=pos_s[:, :2], r=self.a2a_radius, batch=batch_s, loop=False, + max_num_neighbors=300) + edge_index_a2a = subgraph(subset=mask_s & pad_mask_s, edge_index=edge_index_a2a)[0] + + if int(os.getenv('PLOT_EDGE', 0)): +
plot_interact_edge(edge_index_a2a, data['scenario_id'], data['batch_size_a'].cpu(), self.num_seed_feature, num_step, + av_index=av_index, **plot_kwargs) + + rel_pos_a2a = pos_s[edge_index_a2a[0]] - pos_s[edge_index_a2a[1]] + rel_head_a2a = wrap_angle(head_s[edge_index_a2a[0]] - head_s[edge_index_a2a[1]]) + + # handle the invalid steps + is_invalid = state_a == self.invalid_state + is_invalid_s = is_invalid.transpose(0, 1).reshape(-1) + + rel_pos_a2a[is_invalid_s[edge_index_a2a[0]] & ~is_invalid_s[edge_index_a2a[1]]] = -self.motion_gap + rel_pos_a2a[~is_invalid_s[edge_index_a2a[0]] & is_invalid_s[edge_index_a2a[1]]] = self.motion_gap + rel_head_a2a[is_invalid_s[edge_index_a2a[0]] & ~is_invalid_s[edge_index_a2a[1]]] = -self.heading_gap + rel_head_a2a[~is_invalid_s[edge_index_a2a[0]] & is_invalid_s[edge_index_a2a[1]]] = self.heading_gap + + rel_pos_a2a[is_invalid_s[edge_index_a2a[0]] & is_invalid_s[edge_index_a2a[1]]] = self.invalid_motion_value + rel_head_a2a[is_invalid_s[edge_index_a2a[0]] & is_invalid_s[edge_index_a2a[1]]] = self.invalid_head_value + + r_a2a = torch.stack( + [torch.norm(rel_pos_a2a[:, :2], p=2, dim=-1), + angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_a2a[1]], nbr_vector=rel_pos_a2a[:, :2]), + rel_head_a2a], dim=-1) + r_a2a = self.r_a2a_emb(continuous_inputs=r_a2a, categorical_embs=None) + + # add the edges which connect seed agents + if is_training: + mask_av = torch.ones_like(mask_a).bool() + if not self.seed_attn_to_av: + mask_av[av_index] = False + mask_a &= mask_av + edge_index_seed2a, r_seed2a = self._build_a2sa_edge(data, pos_a, head_a, head_vector_a, batch_s, + mask_a.clone(), ~pad_mask.clone(), inference_mask=inference_mask, + r=self.pl2seed_radius, max_num_neighbors=300, + seq_mask=seq_mask, seq_index=seq_index, grid_index_a=grid_index_a, mode='insert') + + if os.getenv('PLOT_EDGE', False): + plot_interact_edge(edge_index_seed2a, data['scenario_id'], data['batch_size_a'].cpu(), self.num_seed_feature, num_step, + 'interact_edge_map_seed', av_index=av_index, **plot_kwargs) + + edge_index_a2a = torch.cat([edge_index_a2a, edge_index_seed2a], dim=-1) + r_a2a = torch.cat([r_a2a, r_seed2a]) + + return edge_index_a2a, r_a2a, (edge_index_a2a.shape[1], edge_index_seed2a.shape[1]) + + return edge_index_a2a, r_a2a + + def _build_map2agent_edge(self, data, pos_a, head_a, state_a, head_vector_a, batch_s, batch_pl, + mask, pad_mask=None, inference_mask=None, av_index=None, **kwargs): + num_graph = data.num_graphs + num_agent, num_step, _ = pos_a.shape + is_training = inference_mask is None + + mask_pl2a = mask.clone() + + if pad_mask is None: + pad_mask = torch.ones_like(state_a).bool() + + pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim) + head_s = head_a.transpose(0, 1).reshape(-1) + head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2) + pad_mask_s = pad_mask.transpose(0, 1).reshape(-1) + if inference_mask is not None: + mask_pl2a = mask_pl2a & inference_mask + mask_pl2a = mask_pl2a.transpose(0, 1).reshape(-1) + + ori_pos_pl = data['pt_token']['position'][:, :self.input_dim].contiguous() + ori_orient_pl = data['pt_token']['orientation'].contiguous() + pos_pl = ori_pos_pl.repeat(num_step, 1) # not `repeat_interleave` + orient_pl = ori_orient_pl.repeat(num_step) + + # build map2agent directed graph + # edge_index_pl2a = radius(x=pos_s[:, :2], y=pos_pl[:, :2], r=self.pl2a_radius, + # batch_x=batch_s, batch_y=batch_pl, max_num_neighbors=300) + edge_index_pl2a = radius(x=pos_pl[:, :2], y=pos_s[:, :2], r=self.pl2a_radius, + batch_x=batch_pl,
batch_y=batch_s, max_num_neighbors=5) + edge_index_pl2a = edge_index_pl2a[[1, 0]] + edge_index_pl2a = edge_index_pl2a[:, mask_pl2a[edge_index_pl2a[1]] & + pad_mask_s[edge_index_pl2a[1]]] + + rel_pos_pl2a = pos_pl[edge_index_pl2a[0]] - pos_s[edge_index_pl2a[1]] + rel_orient_pl2a = wrap_angle(orient_pl[edge_index_pl2a[0]] - head_s[edge_index_pl2a[1]]) + + # handle the invalid steps + is_invalid = state_a == self.invalid_state + is_invalid_s = is_invalid.transpose(0, 1).reshape(-1) + rel_pos_pl2a[is_invalid_s[edge_index_pl2a[1]]] = self.motion_gap + rel_orient_pl2a[is_invalid_s[edge_index_pl2a[1]]] = self.heading_gap + + r_pl2a = torch.stack( + [torch.norm(rel_pos_pl2a[:, :2], p=2, dim=-1), + angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_pl2a[1]], nbr_vector=rel_pos_pl2a[:, :2]), + rel_orient_pl2a], dim=-1) + r_pl2a = self.r_pt2a_emb(continuous_inputs=r_pl2a, categorical_embs=None) + + # add the edges which connect seed agents + if is_training: + edge_index_pl2seed, r_pl2seed = self._build_map2sa_edge(data, pos_a, head_a, head_vector_a, batch_s, batch_pl, + ~pad_mask.clone(), inference_mask=inference_mask, + r=self.pl2seed_radius, max_num_neighbors=2048, mode='insert') + + # sanity check + # pl2a_index = torch.zeros(pos_a.shape[0], num_step) + # pl2a_r = torch.zeros(pos_a.shape[0], num_step) + # for src_index in torch.unique(edge_index_pl2seed[1]): + # src_row = src_index % pos_a.shape[0] + # src_col = src_index // pos_a.shape[0] + # pl2a_index[src_row, src_col] = edge_index_pl2seed[0, edge_index_pl2seed[1] == src_index].sum() + # pl2a_r[src_row, src_col] = r_pl2seed[edge_index_pl2seed[1] == src_index].sum() + # print(pl2a_index) + # print(pl2a_r) + # exit(1) + + if os.getenv('PLOT_EDGE', False): + plot_interact_edge(edge_index_pl2seed, data['scenario_id'], data['batch_size_a'].cpu(), self.num_seed_feature, num_step, + 'interact_edge_map_seed', av_index=av_index) + + edge_index_pl2a = torch.cat([edge_index_pl2a, edge_index_pl2seed], dim=-1) + r_pl2a = torch.cat([r_pl2a, r_pl2seed]) + + return edge_index_pl2a, r_pl2a, (edge_index_pl2a.shape[1], edge_index_pl2seed.shape[1]) + + return edge_index_pl2a, r_pl2a + + def _build_a2sa_edge(self, data, pos_a, head_a, head_vector_a, batch_s, mask_a, mask_sa, + inference_mask=None, r=None, max_num_neighbors=8, seq_mask=None, seq_index=None, + grid_index_a=None, mode: Literal['insert', 'feature']='feature', **plot_kwargs): + + num_agent, num_step, _ = pos_a.shape + is_training = inference_mask is None + + pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim) + head_s = head_a.transpose(0, 1).reshape(-1) + head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2) + if inference_mask is not None: + mask_a = mask_a & inference_mask + mask_sa = mask_sa & inference_mask + mask_s = mask_a.transpose(0, 1).reshape(-1) + mask_s_sa = mask_sa.transpose(0, 1).reshape(-1) + + # build seed_agent2agent unilateral connection + assert r is not None, "r needs to be specified!" 
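+ # Note (our reading of torch_cluster's `radius`): radius(x, y, r, batch_x, batch_y, ...) returns pairs (i, j) where row i indexes y and column j indexes x with dist(y[i], x[j]) <= r; the `[[1, 0]]` flip below converts this to (source=x node, target=y node) edges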
+ # edge_index_a2sa = radius(x=pos_s[mask_s_sa, :2], y=pos_s[:, :2], r=r, + # batch_x=batch_s[mask_s_sa], batch_y=batch_s, max_num_neighbors=max_num_neighbors) + edge_index_a2sa = radius(x=pos_s[:, :2], y=pos_s[mask_s_sa, :2], r=r, + batch_x=batch_s, batch_y=batch_s[mask_s_sa], max_num_neighbors=max_num_neighbors) + edge_index_a2sa = edge_index_a2sa[[1, 0]] + edge_index_a2sa = edge_index_a2sa[:, ~mask_s_sa[edge_index_a2sa[0]] & mask_s[edge_index_a2sa[0]]] + + # only for seed agent sequence training + if seq_mask is not None: + edge_mask = seq_mask[edge_index_a2sa[1]] + edge_mask = torch.gather(edge_mask, dim=1, index=edge_index_a2sa[0, :, None] % num_agent)[:, 0] + edge_index_a2sa = edge_index_a2sa[:, edge_mask] + + if seq_index is None: + seq_index = torch.zeros(num_agent, device=pos_a.device).long() + if seq_index.dim() == 1: + seq_index = seq_index[:, None].repeat(1, num_step) + seq_index = seq_index.transpose(0, 1).reshape(-1) + assert seq_index.shape[0] == pos_s.shape[0], f"Inconsistent length {seq_index.shape[0]} and {pos_s.shape[0]}!" + + # convert to global index + all_index = torch.arange(pos_s.shape[0], device=pos_a.device).long() + sa_index = all_index[mask_s_sa] + edge_index_a2sa[1] = sa_index[edge_index_a2sa[1]] + + # plot edge index. TODO: currently only supports bs=1 + if os.getenv('PLOT_EDGE_INFERENCE', False) and not is_training: + num_agent, num_step, _ = pos_a.shape + # plot_interact_edge(edge_index_a2sa, data['scenario_id'], data['batch_size_a'].cpu(), 1, num_step, + # 'interact_a2sa_edge_map', **plot_kwargs) + plot_interact_edge(edge_index_a2sa, data['scenario_id'], torch.tensor([num_agent - 1]), 1, num_step, + f"interact_a2sa_edge_map_infer_{plot_kwargs['tag']}", **plot_kwargs) + + rel_pos_a2sa = pos_s[edge_index_a2sa[0]] - pos_s[edge_index_a2sa[1]] + rel_head_a2sa = wrap_angle(head_s[edge_index_a2sa[0]] - head_s[edge_index_a2sa[1]]) + + if mode == 'insert': + + # assert grid_index_a is not None, f"Missing input: grid_index_a!"
+ # grid_index_s = grid_index_a.transpose(0, 1).reshape(-1) + # assert grid_index_s[edge_index_a2sa[0]].min() >= 0, "Found invalid values in grid index" + + # r_a2sa = torch.stack( + # [self.attr_tokenizer.dist[grid_index_s[edge_index_a2sa[0]]], + # self.attr_tokenizer.dir[grid_index_s[edge_index_a2sa[0]]], + # rel_head_a2sa, + # seq_index[edge_index_a2sa[0]] - seq_index[edge_index_a2sa[1]]], dim=-1) + + # r_a2sa = torch.stack( + # [torch.norm(rel_pos_a2sa[:, :2], p=2, dim=-1), + # angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_a2sa[1]], nbr_vector=rel_pos_a2sa[:, :2]), + # rel_head_a2sa, + # seq_index[edge_index_a2sa[0]] - seq_index[edge_index_a2sa[1]]], dim=-1) + r_a2sa = torch.stack( + [torch.norm(rel_pos_a2sa[:, :2], p=2, dim=-1), + angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_a2sa[1]], nbr_vector=rel_pos_a2sa[:, :2]), + rel_head_a2sa], dim=-1) + # TODO: try categorical embs + r_a2sa = self.r_a2sa_emb(continuous_inputs=r_a2sa, categorical_embs=None) + + elif mode == 'feature': + + r_a2sa = torch.stack( + [torch.norm(rel_pos_a2sa[:, :2], p=2, dim=-1), + angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_a2sa[1]], nbr_vector=rel_pos_a2sa[:, :2]), + rel_head_a2sa], dim=-1) + r_a2sa = self.r_a2a_emb(continuous_inputs=r_a2sa, categorical_embs=None) + + else: + raise ValueError(f"Unsupported mode {mode}.") + + return edge_index_a2sa, r_a2sa + + def _build_map2sa_edge(self, data, pos_a, head_a, head_vector_a, batch_s, batch_pl, + mask_sa, inference_mask=None, r=None, max_num_neighbors=32, mode: Literal['insert', 'feature']='feature'): + + _, num_step, _ = pos_a.shape + + mask_pl2sa = torch.ones_like(mask_sa).bool() + if inference_mask is not None: + mask_pl2sa = mask_pl2sa & inference_mask + mask_pl2sa = mask_pl2sa.transpose(0, 1).reshape(-1) + mask_s_sa = mask_sa.transpose(0, 1).reshape(-1) + + pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim) + head_s = head_a.transpose(0, 1).reshape(-1) + head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2) + + ori_pos_pl = data['pt_token']['position'][:, :self.input_dim].contiguous() + ori_orient_pl = data['pt_token']['orientation'].contiguous() + pos_pl = ori_pos_pl.repeat(num_step, 1) # not `repeat_interleave` + orient_pl = ori_orient_pl.repeat(num_step) + + # build map2agent directed graph + assert r is not None, "r needs to be specified!"
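+ # Note: as in `_build_a2sa_edge`, `mode` below only selects which Fourier embedding encodes the relative features ('insert' -> r_pt2sa_emb, 'feature' -> r_pt2a_emb); the edge construction itself is identical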
+ # edge_index_pl2sa = radius(x=pos_s[mask_s_sa, :2], y=pos_pl[:, :2], r=r, + # batch_x=batch_s[mask_s_sa], batch_y=batch_pl, max_num_neighbors=max_num_neighbors) + edge_index_pl2sa = radius(x=pos_pl[:, :2], y=pos_s[mask_s_sa, :2], r=r, + batch_x=batch_pl, batch_y=batch_s[mask_s_sa], max_num_neighbors=max_num_neighbors) + edge_index_pl2sa = edge_index_pl2sa[[1, 0]] + edge_index_pl2sa = edge_index_pl2sa[:, mask_pl2sa[mask_s_sa][edge_index_pl2sa[1]]] + + # convert to global index + all_index = torch.arange(pos_s.shape[0], device=pos_a.device).long() + sa_index = all_index[mask_s_sa] + edge_index_pl2sa[1] = sa_index[edge_index_pl2sa[1]] + + # plot edge map + # if os.getenv('PLOT_EDGE', False): + # plot_map_edge(edge_index_pl2sa, pos_s[:, :2], data, save_path='map2sa_edge_map') + + rel_pos_pl2sa = pos_pl[edge_index_pl2sa[0]] - pos_s[edge_index_pl2sa[1]] + rel_orient_pl2sa = wrap_angle(orient_pl[edge_index_pl2sa[0]] - head_s[edge_index_pl2sa[1]]) + + r_pl2sa = torch.stack( + [torch.norm(rel_pos_pl2sa[:, :2], p=2, dim=-1), + angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_pl2sa[1]], nbr_vector=rel_pos_pl2sa[:, :2]), + rel_orient_pl2sa], dim=-1) + + if mode == 'insert': + r_pl2sa = self.r_pt2sa_emb(continuous_inputs=r_pl2sa, categorical_embs=None) + elif mode == 'feature': + r_pl2sa = self.r_pt2a_emb(continuous_inputs=r_pl2sa, categorical_embs=None) + else: + raise ValueError(f"Unsupported mode {mode}.") + + return edge_index_pl2sa, r_pl2sa + + # def _build_sa2sa_edge(self, data, pos_a, head_a, state_a, head_vector_a, batch_s, mask, inference_mask=None, **plot_kwargs): + + # num_agent = pos_a.shape[0] + + # pos_t = pos_a.transpose(0, 1).reshape(-1, self.input_dim) + # head_t = head_a.reshape(-1) + # head_vector_t = head_vector_a.reshape(-1, 2) + + # if inference_mask is not None: + # mask_t = mask.unsqueeze(2) & inference_mask.unsqueeze(1) + # else: + # mask_t = mask.unsqueeze(2) & mask.unsqueeze(1) + + # edge_index_sa2sa = dense_to_sparse(mask_t)[0] + # edge_index_sa2sa = edge_index_sa2sa[:, edge_index_sa2sa[1] - edge_index_sa2sa[0] > 0] + # rel_pos_t = pos_t[edge_index_sa2sa[0]] - pos_t[edge_index_sa2sa[1]] + # rel_head_t = wrap_angle(head_t[edge_index_sa2sa[0]] - head_t[edge_index_sa2sa[1]]) + + # r_t = torch.stack( + # [torch.norm(rel_pos_t[:, :2], p=2, dim=-1), + # angle_between_2d_vectors(ctr_vector=head_vector_t[edge_index_sa2sa[1]], nbr_vector=rel_pos_t[:, :2]), + # rel_head_t, + # edge_index_sa2sa[0] - edge_index_sa2sa[1]], dim=-1) + # r_sa2sa = self.r_sa2sa_emb(continuous_inputs=r_t, categorical_embs=None) + + # return edge_index_sa2sa, r_sa2sa + + def get_inputs(self, data: HeteroData) -> Dict[str, torch.Tensor]: + + pos_a = data['agent']['token_pos'].clone() + head_a = data['agent']['token_heading'].clone() + agent_token_index = data['agent']['token_idx'].clone() + agent_state_index = data['agent']['state_idx'].clone() + mask = data['agent']['raw_agent_valid_mask'].clone() + + agent_grid_token_idx = data['agent']['grid_token_idx'] + agent_grid_offset_xy = data['agent']['grid_offset_xy'] + agent_head_token_idx = data['agent']['heading_token_idx'] + sort_indices = data['agent']['sort_indices'] + + next_token_index_gt = agent_token_index.roll(shifts=-1, dims=1) + next_state_index_gt = agent_state_index.roll(shifts=-1, dims=1) + + # next token prediction mask + bos_token_index = torch.nonzero(agent_state_index == self.enter_state) + eos_token_index = torch.nonzero(agent_state_index == self.exit_state) + + # mask for motion tokens + next_token_eval_mask = mask.clone()
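+ # In words (our reading): a motion token at step t is supervised only when the agent is valid at t-1, t and t+1; enter (bos) steps are then re-enabled and exit (eos) steps are masked out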
+ next_token_eval_mask = next_token_eval_mask * next_token_eval_mask.roll(shifts=-1, dims=1) * next_token_eval_mask.roll(shifts=1, dims=1) + for bos_token_index_ in bos_token_index: + next_token_eval_mask[bos_token_index_[0], bos_token_index_[1] : bos_token_index_[1] + 1] = 1 + next_token_eval_mask[bos_token_index_[0], bos_token_index_[1] + 1 : bos_token_index_[1] + 2] = \ + mask[bos_token_index_[0], bos_token_index_[1] + 2 : bos_token_index_[1] + 3] + next_token_eval_mask[eos_token_index[:, 0], eos_token_index[:, 1]] = 0 + + # mask for state tokens + next_state_eval_mask = mask.clone() + next_state_eval_mask = next_state_eval_mask * next_state_eval_mask.roll(shifts=-1, dims=1) * next_state_eval_mask.roll(shifts=1, dims=1) + for bos_token_index_ in bos_token_index: + next_state_eval_mask[bos_token_index_[0], :bos_token_index_[1]] = 0 + next_state_eval_mask[bos_token_index_[0], bos_token_index_[1] : bos_token_index_[1] + 1] = 1 + next_state_eval_mask[bos_token_index_[0], bos_token_index_[1] + 1 : bos_token_index_[1] + 2] = \ + mask[bos_token_index_[0], bos_token_index_[1] + 2 : bos_token_index_[1] + 3] + for eos_token_index_ in eos_token_index: + next_state_eval_mask[eos_token_index_[0], eos_token_index_[1] + 1:] = 1 + next_state_eval_mask[eos_token_index_[0], eos_token_index_[1] : eos_token_index_[1] + 1] = \ + mask[eos_token_index_[0], eos_token_index_[1] - 1 : eos_token_index_[1]] + + # the last timestep is the beginning of the sequence (also the input) + next_token_eval_mask[:, 0] = mask[:, 0] * mask[:, 1] + next_state_eval_mask[:, 0] = mask[:, 0] * mask[:, 1] + next_token_eval_mask[:, -1] = 0 + next_state_eval_mask[:, -1] = 0 + + if next_token_index_gt[next_token_eval_mask].min() < 0: + raise RuntimeError("Found invalid motion index.") + + return {'token_pos': pos_a, + 'token_heading': head_a, + 'next_token_idx_gt': next_token_index_gt, + 'next_state_idx_gt': next_state_index_gt, + 'next_token_eval_mask': next_token_eval_mask, + 'raw_agent_valid_mask': data['agent']['raw_agent_valid_mask'], + 'state_token': agent_state_index, + 'grid_index': agent_grid_token_idx, + } + + def _build_seq(self, device, data, num_agent, num_step, av_index, sort_indices): + """ + Args: + sort_indices (torch.Tensor): shape (num_agent, num_step) + """ + ptr = data['agent']['ptr'] + num_graph = len(ptr) - 1 + + # sort_indices = sort_indices[:self.num_seed_feature] + seq_mask = torch.ones(num_graph * self.num_seed_feature, num_step, num_agent + num_graph * self.num_seed_feature, device=device).bool() + seq_mask[..., -num_graph * self.num_seed_feature:] = False + for b in range(num_graph): + batch_sort_indices = sort_indices[ptr[b] : ptr[b + 1]] + for t in range(num_step): + for s in range(self.num_seed_feature): + seq_mask[b * self.num_seed_feature + s, t, batch_sort_indices[s:, t].flatten().long()] = False + if self.seed_attn_to_av: + seq_mask[..., av_index] = True + seq_mask = seq_mask.transpose(0, 1).reshape(-1, num_agent + num_graph * self.num_seed_feature) + + seq_index = torch.cat([torch.zeros(num_agent), (torch.arange(self.num_seed_feature) + 1).repeat(num_graph)]).to(device) + seq_index = seq_index[:, None].repeat(1, num_step) + # 0, 0, 0, ..., 1, 2, 3, ... + for b in range(num_graph): + batch_sort_indices = sort_indices[ptr[b] : ptr[b + 1]] + for t in range(num_step): + for s in range(self.num_seed_feature): + seq_index[batch_sort_indices[s : s + 1, t].flatten().long() + ptr[b], t] = s + 1 + + # 0, 2, 1, ..., N+1, N+2, ...
+ # for b in range(num_graph): + # batch_sort_indices = sort_indices[ptr[b] : ptr[b + 1]] + # batch_agent_valid_mask = data['agent']['inrange_mask'][ptr[b] : ptr[b + 1]] & \ + # data['agent']['raw_agent_valid_mask'][ptr[b] : ptr[b + 1]] & \ + # ~data['agent']['bos_mask'][ptr[b] : ptr[b + 1]] + # batch_agent_valid_mask[av_index[b]] = False + # for t in range(num_step): + # batch_num_valid_agent_t = batch_agent_valid_mask[:, t].sum() + # seq_index[num_agent + b * self.num_seed_feature : num_agent + (b + 1) * self.num_seed_feature, t] += batch_num_valid_agent_t + # random_seq_index = torch.zeros(ptr[b + 1] - ptr[b], device=device) + # random_seq_index[batch_agent_valid_mask[:, t]] = torch.randperm(batch_num_valid_agent_t, device=device).float() + 1 # starts from 1 + # seq_index[ptr[b] : ptr[b + 1], t] = random_seq_index + # for s in range(self.num_seed_feature): + # seq_index[batch_sort_indices[s : s + 1, t].flatten().long() + ptr[b], t] = s + 1 + batch_num_valid_agent_t.float() + + # 0, 0, 0, ..., N+1, N+2, ... + # for b in range(num_graph): + # batch_sort_indices = sort_indices[ptr[b] : ptr[b + 1]] + # batch_agent_valid_mask = data['agent']['inrange_mask'][ptr[b] : ptr[b + 1]] & \ + # data['agent']['raw_agent_valid_mask'][ptr[b] : ptr[b + 1]] & \ + # ~data['agent']['bos_mask'][ptr[b] : ptr[b + 1]] + # batch_agent_valid_mask[av_index[b]] = False + # for t in range(num_step): + # batch_num_valid_agent_t = batch_agent_valid_mask[:, t].sum() + # seq_index[num_agent + b * self.num_seed_feature : num_agent + (b + 1) * self.num_seed_feature, t] += batch_num_valid_agent_t + # for s in range(self.num_seed_feature): + # seq_index[batch_sort_indices[s : s + 1, t].flatten().long() + ptr[b], t] = s + 1 + batch_num_valid_agent_t.float() + + seq_index[av_index] = 0 + + return seq_mask, seq_index + + def _build_occ_gt(self, data, seq_mask, pos_rel_index_gt, pos_rel_index_gt_seed=None, mask_seed=None, + edge_index=None, mode='edge_index'): + """ + Args: + seq_mask (torch.Tensor): shape (num_step * num_seed_feature, num_agent + self.num_seed_feature) + pos_rel_index_gt (torch.Tensor): shape (num_agent, num_step) + pos_rel_index_gt_seed (torch.Tensor): shape (num_seed, num_step) + """ + num_agent = data['agent']['state_idx'].shape[0] + data.num_graphs * self.num_seed_feature + num_step = data['agent']['state_idx'].shape[1] + data['agent']['agent_occ'] = torch.zeros(data.num_graphs * self.num_seed_feature, num_step, self.attr_tokenizer.grid_size, + device=data['agent']['state_idx'].device).long() + data['agent']['map_occ'] = torch.zeros(data.num_graphs, num_step, self.attr_tokenizer.grid_size, + device=data['agent']['state_idx'].device).long() + + if mode == 'edge_index': + + assert edge_index is not None, f"Need edge_index input!" 
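+ # Sketch of the indexing below: node indices in `edge_index` are flattened step-major, i.e. flat = col * num_agent + row, so flat % num_agent recovers the agent row and flat // num_agent the timestep column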
+ for src_index in torch.unique(edge_index[1]): + # decode src + src_row = src_index % num_agent - (num_agent - data.num_graphs * self.num_seed_feature) + src_col = src_index // num_agent + # decode tgt + tgt_indexes = edge_index[0, edge_index[1] == src_index] + tgt_rows = tgt_indexes % num_agent + tgt_cols = tgt_indexes // num_agent + assert tgt_rows.max() < num_agent - data.num_graphs * self.num_seed_feature, f"Invalid {tgt_rows}" + assert torch.unique(tgt_cols).shape[0] == 1 and torch.unique(tgt_cols)[0] == src_col + data['agent']['agent_occ'][src_row, src_col, pos_rel_index_gt[tgt_rows, tgt_cols]] = 1 + + else: + + seq_mask = seq_mask.reshape(num_step, self.num_seed_feature, -1).transpose(0, 1)[..., :-self.num_seed_feature] + for s in range(self.num_seed_feature): + for t in range(num_step): + index = pos_rel_index_gt[seq_mask[s, t], t] + data['agent']['agent_occ'][s, t, index[index != -1]] = 1 + if t > 0 and s < pos_rel_index_gt_seed.shape[0] and mask_seed[s, t - 1]: # insert agents + data['agent']['agent_occ'][s, t, pos_rel_index_gt_seed[s, t - 1]] = -1 + + ptr = data['pt_token']['ptr'] + pt_grid_token_idx = data['agent']['pt_grid_token_idx'] # (t, num_pt) + for b in range(data.num_graphs): + batch_pt_grid_token_idx = pt_grid_token_idx[:, ptr[b] : ptr[b + 1]] + for t in range(num_step): + data['agent']['map_occ'][b, t, batch_pt_grid_token_idx[t][batch_pt_grid_token_idx[t] != -1]] = 1 + data['agent']['map_occ'] = data['agent']['map_occ'].repeat_interleave(repeats=self.num_seed_feature, dim=0) + + def forward(self, + data: HeteroData, + map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + + pos_a = data['agent']['token_pos'].clone() # (a, t, 2) + head_a = data['agent']['token_heading'].clone() # (a, t) + num_agent, num_step, traj_dim = pos_a.shape # e.g. 
(50, 18, 2) + agent_shape = data['agent']['shape'][:, self.num_historical_steps - 1].clone() # (a, 3) + agent_token_index = data['agent']['token_idx'].clone() # (a, t) + agent_state_index = data['agent']['state_idx'].clone() + agent_type_index = data['agent']['type'].clone() + + av_index = data['agent']['av_index'].long() + ego_pos = pos_a[av_index] + ego_head = head_a[av_index] + + _, head_vector_a = self._build_vector_a(pos_a, head_a, agent_state_index) + + agent_grid_token_idx = data['agent']['grid_token_idx'] + agent_grid_offset_xy = data['agent']['grid_offset_xy'] + agent_pos_xy = data['agent']['pos_xy'] + agent_head_token_idx = data['agent']['heading_token_idx'] + sort_indices = data['agent']['sort_indices'] + + device = pos_a.device + + feat_a = self._agent_token_embedding(data, + agent_token_index, + agent_state_index, + agent_grid_token_idx, + pos_a, + head_a, + av_index=av_index) + + raw_feat_a = feat_a[:-data.num_graphs * self.num_seed_feature].clone() + raw_feat_seed = feat_a[-data.num_graphs * self.num_seed_feature:].clone() + + # build masks + mask = data['agent']['raw_agent_valid_mask'].clone() + temporal_mask = mask.clone() + interact_mask = mask.clone() + + is_bos = agent_state_index == self.enter_state + is_eos = agent_state_index == self.exit_state + bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0)) + eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_step - 1)) # not `-1` + + temporal_mask = torch.ones_like(mask) + motion_mask = torch.arange(mask.shape[1]).expand(mask.shape[0], -1).to(device) + motion_mask = (motion_mask > bos_index[:, None]) & (motion_mask <= eos_index[:, None]) + temporal_mask[motion_mask] = mask[motion_mask] + temporal_mask = torch.cat([temporal_mask, torch.ones(data.num_graphs * self.num_seed_feature, *temporal_mask.shape[1:], device=device)]).bool() + + interact_mask[agent_state_index == self.enter_state] = True + interact_mask = torch.cat([interact_mask, torch.ones(data.num_graphs * self.num_seed_feature, *interact_mask.shape[1:], device=device)]).bool() # placeholder + + pos_a_p, head_a_p, state_a_p, head_vector_a_p, grid_index_a_p, pad_mask = \ + self._pad_feat(data.num_graphs, av_index, pos_a, head_a, agent_state_index, head_vector_a, agent_grid_token_idx) + edge_index_t, r_t = self._build_temporal_edge(data, pos_a_p, head_a_p, state_a_p, head_vector_a_p, temporal_mask) + + # placeholder for seed agent + batch_s = torch.cat([ + torch.cat([data['agent']['batch'], torch.arange(data.num_graphs, device=device + ).repeat_interleave(repeats=self.num_seed_feature, dim=0)], dim=0) + + data.num_graphs * t for t in range(num_step) + ], dim=0) + batch_pl = torch.cat([data['pt_token']['batch'] + data.num_graphs * t for t in range(num_step)], dim=0) + + seq_mask, seq_index = self._build_seq(device, data, num_agent, num_step, av_index, sort_indices) + plot_kwargs = dict(is_bos=agent_state_index == self.enter_state) + edge_index_a2a, r_a2a, (na2a, na2sa) = self._build_interaction_edge(data, pos_a_p, head_a_p, state_a_p, head_vector_a_p, batch_s, + interact_mask, pad_mask=pad_mask, av_index=av_index, + seq_mask=seq_mask, seq_index=seq_index, grid_index_a=grid_index_a_p, **plot_kwargs) + + edge_index_pl2a, r_pl2a, (npl2a, npl2sa) = self._build_map2agent_edge(data, pos_a_p, head_a_p, state_a_p, head_vector_a_p, batch_s, batch_pl, + interact_mask, pad_mask=pad_mask, av_index=av_index) + interact_mask = interact_mask[:-data.num_graphs * self.num_seed_feature] + + # pos_a_s, 
head_a_s, state_a_s, head_vector_a_s, mask_a_s = self._build_seed_feat(data, pos_a_p, head_a_p, state_a_p, head_vector_a_p, ~pad_mask, + # sort_indices, av_index=av_index) + # edge_index_sa2sa, r_sa2sa = self._build_sa2sa_edge(data, pos_a_s, head_a_s, state_a_s, head_vector_a_s, batch_s, mask=mask_a_s) + + # for i in range(self.num_layers): + + # feat_a = feat_a.reshape(-1, self.hidden_dim) # (a, t, d) -> (a*t, d) + # feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t) + + # feat_a = feat_a.reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim) + # feat_a = self.pt2a_attn_layers[i](( + # map_enc['x_pt'].repeat_interleave(repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape( + # -1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a) + + # feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a) + # feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1) + + # predict next motions + for i in range(self.num_layers): + + feat_a = feat_a.reshape(-1, self.hidden_dim) # (a, t, d) -> (a*t, d) + feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t) + + feat_a = feat_a.reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim) + feat_a = self.pt2a_attn_layers[i](( + map_enc['x_pt'].repeat_interleave(repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape( + -1, self.hidden_dim), feat_a), r_pl2a[:npl2a], edge_index_pl2a[:, :npl2a]) + + feat_a = self.a2a_attn_layers[i](feat_a, r_a2a[:na2a], edge_index_a2a[:, :na2a]) + feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1) + + feat_ea = feat_a[:-data.num_graphs * self.num_seed_feature] + + # next motion token + next_token_prob = self.token_predict_head(feat_ea) # (a, t, token_size) + next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1) + _, next_token_idx = torch.topk(next_token_prob_softmax, k=10, dim=-1) # (a, t, 10) + + next_token_index_gt = agent_token_index.roll(shifts=-1, dims=1) + + # next state token + next_state_prob = self.state_predict_head(feat_ea) + next_state_idx = next_state_prob.softmax(dim=-1).argmax(dim=-1, keepdim=True) # (a, t, 1) + + next_state_index_gt = agent_state_index.roll(shifts=-1, dims=1) # (invalid, valid, exit) + + # predict next agents: coarse stage + grid_agent_occ_gt_seed = grid_pt_occ_gt_seed = None + if self.use_grid_token: + self._build_occ_gt(data, seq_mask, agent_grid_token_idx.long(), edge_index=edge_index_a2a[:, -na2sa:], mode='edge_index') + grid_agent_occ_gt_seed = data['agent']['agent_occ'] + grid_pt_occ_gt_seed = data['agent']['map_occ'] + + if self.use_grid_token: + occ_embed_a = self.seed_agent_occ_embed(grid_agent_occ_gt_seed.transpose(0, 1).reshape(-1, self.grid_size).float()) + # occ_embed_pt = self.seed_pt_occ_embed(grid_pt_occ_gt_seed.transpose(0, 1).reshape(-1, self.grid_size).float()) + edge_index_occ2sa_src = torch.arange(feat_a.shape[0] * feat_a.shape[1], device=device).long() + edge_index_occ2sa_src = edge_index_occ2sa_src[~pad_mask.transpose(0, 1).reshape(-1)] + edge_index_occ2sa_tgt = torch.arange(occ_embed_a.shape[0], device=device).long() + edge_index_occ2sa = torch.stack([edge_index_occ2sa_tgt, edge_index_occ2sa_src], dim=0) + + feat_sa = torch.cat([raw_feat_a, raw_feat_seed]) + # feat_sa = feat_a + for i in range(self.seed_layers): + + # feat_sa = feat_a.reshape(-1, self.hidden_dim) + # feat_sa = self.sa2sa_attn_layers[i](feat_sa, r_sa2sa, edge_index_sa2sa) + + feat_sa = feat_sa.reshape(-1, num_step, 
self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim) + if self.use_grid_token: + feat_sa = self.occ2sa_attn_layers[i]((occ_embed_a, feat_sa), None, edge_index_occ2sa) + feat_sa = self.pt2sa_attn_layers[i](( + map_enc['x_pt'].repeat_interleave(repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape( + -1, self.hidden_dim), feat_sa), r_pl2a[-npl2sa:], edge_index_pl2a[:, -npl2sa:]) + feat_sa = self.a2sa_attn_layers[i](feat_sa, r_a2a[-na2sa:], edge_index_a2a[:, -na2sa:]) + feat_sa = feat_sa.reshape(num_step, -1, self.hidden_dim).transpose(0, 1) + + feat_seed = feat_sa[-data.num_graphs * self.num_seed_feature:] + + # seed agent + next_state_prob_seed = self.seed_state_predict_head(feat_seed) + raw_next_state_prob_seed = next_state_prob_seed.clone() + next_state_idx_seed = next_state_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True) # (seed_size, t, 1) + + next_type_prob_seed = self.seed_type_predict_head(feat_seed) + next_type_idx_seed = next_type_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True) + + next_type_index_gt = agent_type_index[:, None].repeat(1, num_step).long() + + next_shape_seed = self.seed_shape_predict_head(feat_seed) + + next_shape_gt = agent_shape[:, None].repeat(1, num_step, 1).float() + + if self.use_grid_token: + next_pos_rel_prob_seed = self.seed_pos_rel_token_predict_head(feat_seed) + next_pos_rel_idx_seed = next_pos_rel_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True) + else: + next_pos_rel_prob_seed = self.seed_pos_rel_xy_predict_head(feat_seed) + next_pos_rel_xy_seed = F.tanh(next_pos_rel_prob_seed) + + next_pos_rel_index_gt = agent_grid_token_idx.long() + next_pos_rel_xy_gt = agent_pos_xy.float() / self.pl2seed_radius + + # decode grid index of neighbor agents + if self.use_grid_token: + neighbor_agent_grid_index_gt = grid_index_a_p.transpose(0, 1).reshape(-1)[edge_index_a2a[0, -na2sa:]] + neighbor_pt_grid_index_gt = data['agent']['pt_grid_token_idx'].reshape(-1)[edge_index_pl2a[0, -npl2sa:]] + neighbor_agent_grid_idx = self.grid_index_head(r_a2a[-na2sa:]) + neighbor_pt_grid_idx = self.grid_index_head(r_pl2a[-npl2sa:]) + neighbor_agent_grid_index_eval_mask = torch.zeros_like(neighbor_agent_grid_index_gt).bool() + neighbor_pt_grid_index_eval_mask = torch.zeros_like(neighbor_pt_grid_index_gt).bool() + neighbor_agent_grid_index_eval_mask[torch.randperm(neighbor_agent_grid_index_gt.shape[0])[:180]] = True + neighbor_pt_grid_index_eval_mask[torch.randperm(neighbor_pt_grid_index_gt.shape[0])[:600]] = True + + # occupancy prediction + grid_agent_occ_seed = grid_pt_occ_seed = grid_agent_occ_eval_mask_seed = grid_pt_occ_eval_mask_seed = None + if self.predict_occ: + # grid_occ_embed = self.grid_occ_embed(self.grid_token_emb[:-1]) + grid_agent_occ_seed = self.grid_agent_occ_head(feat_seed) # (s, t, d) + grid_pt_occ_seed = self.grid_pt_occ_head(feat_seed) + + # refine stage + batch_s = torch.cat([data['agent']['batch'] + data.num_graphs * t for t in range(num_step)], dim=0) + batch_pl = torch.cat([data['pt_token']['batch'] + data.num_graphs * t for t in range(num_step)], dim=0) + + mask_sa = torch.zeros_like(agent_state_index).bool() + for t in range(mask_sa.shape[1]): + available_rows = ((agent_state_index[:, t] != self.invalid_state) & + (agent_grid_token_idx[:, t] != -1)).nonzero()[..., 0] + mask_sa[available_rows[torch.randperm(available_rows.shape[0])[:data.num_graphs * 10]], t] = True + mask_sa[agent_state_index == self.enter_state] = True + mask_sa[:, 0] = False # ignore the first step + mask_sa[av_index] = False #
ignore self + + state_sa = torch.full_like(agent_state_index, self.invalid_state).long() + state_sa[mask_sa] = self.enter_state + + sa_indices = torch.nonzero(mask_sa) + pos_sa = pos_a.clone() + head_sa = head_a.clone() + expanded_av_index = av_index.repeat_interleave(repeats=data['batch_size_a'], dim=0) + head_sa[sa_indices[:, 0], sa_indices[:, 1]] = head_a[expanded_av_index[sa_indices[:, 0]], sa_indices[:, 1]] + + motion_vector_sa, head_vector_sa = self._build_vector_a(pos_a, head_sa, state_sa) + motion_vector_sa[mask_sa] = self.motion_gap # fix the case e.g. [0, 0, 1, '1', 0, 1] + + offset_pos = pos_a - data['ego_pos'].repeat_interleave(repeats=data['batch_size_a'], dim=0) + agent_grid_emb = self.grid_token_emb[agent_grid_token_idx.long()] + feat_sa, _ = self._build_agent_feature(num_step, pos_a.device, + motion_vector_sa, + head_vector_sa, + agent_grid_emb=agent_grid_emb, + offset_pos=offset_pos, + type=next_type_index_gt.long(), + shape=next_shape_gt, + state=state_sa, + n=num_agent) + feat_sa[~mask_sa] = raw_feat_a[~mask_sa].clone() + + edge_index_a2sa, r_a2sa = self._build_a2sa_edge(data, pos_a, head_sa, head_vector_sa, batch_s, + interact_mask, mask_sa=mask_sa, r=self.a2sa_radius) + edge_index_pl2sa, r_pl2sa = self._build_map2sa_edge(data, pos_a, head_sa, head_vector_sa, batch_s, batch_pl, + mask_sa=mask_sa, r=self.pl2sa_radius) + + # sanity check + global_index = set(torch.nonzero(mask_sa.transpose(0, 1).reshape(-1).int())[:, 0].tolist()) + a2sa_index = set(edge_index_a2sa[1].tolist()) + pl2sa_index = set(edge_index_pl2sa[1].tolist()) + assert a2sa_index.issubset(global_index) and pl2sa_index.issubset(global_index), "Invalid index!" + + select_mask = torch.zeros_like(mask_sa.view(-1)).bool() + select_mask[torch.unique(edge_index_a2sa[1])] = True + select_mask[torch.unique(edge_index_pl2sa[1])] = True + mask_sa[~select_mask.reshape(num_step, -1).transpose(0, 1)] = False + + for i in range(self.seed_layers): + + feat_sa = feat_sa.transpose(0, 1).reshape(-1, self.hidden_dim) + feat_sa = self.pt2a_attn_layers[i](( + map_enc['x_pt'].repeat_interleave(repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape( + -1, self.hidden_dim), feat_sa), r_pl2sa, edge_index_pl2sa) + + feat_sa = self.a2a_attn_layers[i](feat_sa, r_a2sa, edge_index_a2sa) + feat_sa = feat_sa.reshape(num_step, -1, self.hidden_dim).transpose(0, 1) + + next_head_rel_prob_seed = self.seed_heading_rel_token_predict_head(feat_sa) + next_head_rel_idx_seed = next_head_rel_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True) + + next_head_rel_index_gt_seed = agent_head_token_idx.long() + + next_offset_xy_seed = next_offset_xy_gt_seed = None + if self.use_grid_token: + next_offset_xy_seed = self.seed_offset_xy_predict_head(feat_sa) + next_offset_xy_seed = torch.tanh(next_offset_xy_seed) * 2 # [-2, 2] + + next_offset_xy_gt_seed = agent_grid_offset_xy.float() + + # next token prediction mask + bos_token_index = torch.nonzero(agent_state_index == self.enter_state) + eos_token_index = torch.nonzero(agent_state_index == self.exit_state) + + # mask for motion tokens + next_token_eval_mask = mask.clone() + next_token_eval_mask = next_token_eval_mask * next_token_eval_mask.roll(shifts=-1, dims=1) * next_token_eval_mask.roll(shifts=1, dims=1) + for bos_token_index_ in bos_token_index: + next_token_eval_mask[bos_token_index_[0], bos_token_index_[1] : bos_token_index_[1] + 1] = 1 + next_token_eval_mask[bos_token_index_[0], bos_token_index_[1] + 1 : bos_token_index_[1] + 2] = \ + mask[bos_token_index_[0], 
bos_token_index_[1] + 2 : bos_token_index_[1] + 3] + next_token_eval_mask[eos_token_index[:, 0], eos_token_index[:, 1]] = 0 + + # mask for state tokens + next_state_eval_mask = mask.clone() + next_state_eval_mask = next_state_eval_mask * next_state_eval_mask.roll(shifts=-1, dims=1) * next_state_eval_mask.roll(shifts=1, dims=1) + for bos_token_index_ in bos_token_index: + next_state_eval_mask[bos_token_index_[0], :bos_token_index_[1]] = 0 + next_state_eval_mask[bos_token_index_[0], bos_token_index_[1] : bos_token_index_[1] + 1] = 1 + next_state_eval_mask[bos_token_index_[0], bos_token_index_[1] + 1 : bos_token_index_[1] + 2] = \ + mask[bos_token_index_[0], bos_token_index_[1] + 2 : bos_token_index_[1] + 3] + for eos_token_index_ in eos_token_index: + next_state_eval_mask[eos_token_index_[0], eos_token_index_[1] + 1:] = 1 + next_state_eval_mask[eos_token_index_[0], eos_token_index_[1] : eos_token_index_[1] + 1] = \ + mask[eos_token_index_[0], eos_token_index_[1] - 1 : eos_token_index_[1]] + + next_state_eval_mask_seed = torch.ones_like(next_state_idx_seed[..., 0]) + + # the last timestep is the beginning of the sequence (also the input) + next_token_eval_mask[:, 0] = mask[:, 0] * mask[:, 1] + next_state_eval_mask[:, 0] = mask[:, 0] * mask[:, 1] + next_token_eval_mask[:, -1] = 0 + next_state_eval_mask[:, -1] = 0 + next_state_eval_mask_seed[:, 0] = 0 + + # no invalid motion token will be supervised + if (next_token_index_gt[next_token_eval_mask] < 0).any(): + raise RuntimeError("Found invalid motion index.") + + # seed agents + # is_next_bos = next_state_index_gt.roll(shifts=1, dims=1) == self.enter_state + # is_next_bos[:, 0] = False # we filter out the last timestep + # is_next_bos[av_index] = False + + # num_seed_gt = is_next_bos.sum(dim=0).max() + + # pred_indices = torch.zeros((num_seed_gt, num_step, 1), device=device).long() + # gt_indices = torch.zeros((num_seed_gt, num_step), device=device).long() + # if num_seed_gt > 0: + # outputs = dict(state_pred=next_state_prob_seed, + # pos_pred=next_pos_rel_prob_seed, + # shape_pred=next_shape_seed) + # targets = dict(state_gt=next_state_index_gt.clone(), + # pos_gt=next_pos_rel_index_gt.clone(), + # shape_gt=next_shape_gt.clone()) + + # indices = self.matcher(outputs, targets, + # valid_mask=is_next_bos, + # ptr_gt=data['agent']['ptr'], + # ptr_pred=torch.arange(data.num_graphs + 1, device=device) * self.num_seed_feature) + + # pred_indices = indices[0][..., None].to(device) + # gt_indices = indices[1].to(device) + + pred_indices = [] + gt_indices = [] + agent_ptr = data['agent']['ptr'] + num_seed_gt = 0 + for b in range(data.num_graphs): + batch_sort_indices = sort_indices[agent_ptr[b] : agent_ptr[b + 1]] + batch_num_seed_gt = min(self.num_seed_feature, batch_sort_indices.shape[0]) + num_seed_gt += batch_num_seed_gt + pred_indices.append((torch.arange(batch_num_seed_gt, device=device) + b * self.num_seed_feature + )[:, None, None].repeat(1, num_step, 1).long()) + gt_indices.append(batch_sort_indices[:batch_num_seed_gt] + agent_ptr[b]) + pred_indices = torch.concat(pred_indices) + gt_indices = torch.concat(gt_indices) + + n = pred_indices.shape[0] + + res_pred_indices = [] + for t in range(next_state_idx_seed.shape[1]): + indices_t = torch.arange(next_state_idx_seed.shape[0]).to(device) + selected_pred_mask = torch.zeros_like(indices_t) + selected_pred_mask[pred_indices[:, t]] = 1 + res_pred_indices.append(indices_t[~selected_pred_mask.bool()]) + res_pred_indices = torch.stack(res_pred_indices, dim=1) + padded_pred_indices = 
torch.concat([pred_indices, res_pred_indices[..., None]]) + next_state_idx_seed = torch.gather(next_state_idx_seed, dim=0, index=padded_pred_indices) + next_state_prob_seed = torch.gather(next_state_prob_seed, dim=0, index=padded_pred_indices.expand( + -1, -1, next_state_prob_seed.shape[-1])) + next_state_index_gt_seed = torch.gather(agent_state_index, dim=0, index=gt_indices) + next_state_index_gt_seed = torch.concat([next_state_index_gt_seed, + torch.zeros((next_state_prob_seed.shape[0] - next_state_index_gt_seed.shape[0], next_state_index_gt_seed.shape[1]), device=device)]).long() + seed_enter_mask = next_state_index_gt_seed == self.enter_state + next_state_index_gt_seed = torch.full(next_state_index_gt_seed.shape, self.seed_state_type.index('invalid'), device=device) + next_state_index_gt_seed[seed_enter_mask] = self.seed_state_type.index('enter') + + next_type_idx_seed = torch.gather(next_type_idx_seed, dim=0, index=pred_indices) + next_type_prob_seed = torch.gather(next_type_prob_seed, dim=0, index=pred_indices.expand( + -1, -1, next_type_prob_seed.shape[-1])) + next_type_index_gt_seed = torch.gather(next_type_index_gt, dim=0, index=gt_indices) + + if self.use_grid_token: + next_pos_rel_xy_seed = None + next_pos_rel_prob_seed = torch.gather(next_pos_rel_prob_seed, dim=0, index=pred_indices.expand( + -1, -1, next_pos_rel_prob_seed.shape[-1])) + else: + next_pos_rel_prob_seed = None + next_pos_rel_xy_seed = torch.gather(next_pos_rel_xy_seed, dim=0, index=pred_indices.expand( + -1, -1, next_pos_rel_xy_seed.shape[-1])) + next_pos_rel_index_gt_seed = torch.gather(next_pos_rel_index_gt, dim=0, index=gt_indices) + next_pos_rel_xy_gt_seed = torch.gather(next_pos_rel_xy_gt, dim=0, index=gt_indices[..., None].expand( + -1, -1, next_pos_rel_xy_gt.shape[-1])) + + next_shape_seed = torch.gather(next_shape_seed, dim=0, index=pred_indices.expand( + -1, -1, next_shape_seed.shape[-1])) + next_shape_gt_seed = torch.gather(next_shape_gt, dim=0, index=gt_indices[..., None].expand( + -1, -1, next_shape_gt.shape[-1])) + + next_attr_eval_mask_seed = seed_enter_mask[:n] + next_attr_eval_mask_seed[:, 0] = False # we ignore the first step + next_attr_eval_mask_seed[next_pos_rel_index_gt_seed == self.grid_size // 2] = False + + next_state_eval_mask[av_index] = 0 # we dont predict state for ego agent + + if (torch.any(next_type_index_gt_seed[next_attr_eval_mask_seed] == AGENT_TYPE.index('seed')) \ + or torch.any(torch.all(next_shape_gt_seed[next_attr_eval_mask_seed] == self.invalid_shape_value, dim=-1)) \ + or torch.any(next_pos_rel_index_gt_seed[next_attr_eval_mask_seed] < 0)) and num_seed_gt > 0: + raise ValueError(f"Found invalid gt values in scenario {data['scenario_id'][0]}.") + + next_state_index_gt[next_state_index_gt == self.exit_state] = self.valid_state_type.index('exit') + + # build occ gt + if self.predict_occ: + + # grid_agent_occ_seed = torch.einsum('s t d, g d -> s t g', grid_agent_occ_seed, grid_occ_embed) + # grid_pt_occ_seed = torch.einsum('s t d, g d -> s t g', grid_pt_occ_seed, grid_occ_embed) + + # augmentation + # TODO: add convolution!!! 
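+        # NOTE: a minimal standalone sketch (assumed shapes, not this module's exact
+        # tensors) of how a grid occupancy target such as `grid_agent_occ_gt_seed`
+        # can be assembled: scatter each agent's grid token index into a
+        # (T, grid_size) map, then blank out the first step and the ego-centered
+        # cell at `grid_size // 2`, mirroring the eval masks built below.
+        # def build_grid_occ_gt(agent_grid_idx, grid_size):
+        #     # agent_grid_idx: (num_agent, T) long, -1 marks invalid steps
+        #     num_agent, T = agent_grid_idx.shape
+        #     occ = torch.zeros(T, grid_size, device=agent_grid_idx.device)
+        #     valid = agent_grid_idx >= 0
+        #     t_idx = torch.arange(T, device=agent_grid_idx.device).expand(num_agent, T)
+        #     occ[t_idx[valid], agent_grid_idx[valid]] = 1.
+        #     occ[0] = 0.                  # the first step carries no supervision
+        #     occ[:, grid_size // 2] = 0.  # the center cell is the ego itself
+        #     return occ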
+ # grid_agent_occ_eval_mask_seed = torch.zeros_like(grid_agent_occ_seed).bool() + # grid_pt_occ_eval_mask_seed = torch.zeros_like(grid_agent_occ_seed).bool() + + # gt_mask = grid_agent_occ_gt_seed.bool() + # gt_mask[:, 0] = False # ignore the first step + # gt_mask[..., self.grid_size // 2] = False # ignore self + + # random_weights = torch.rand_like(grid_agent_occ_seed) * gt_mask + # _, topk_indices = random_weights.topk(10, dim=-1) + # grid_agent_occ_eval_mask_seed.scatter_(-1, topk_indices, True) + # grid_agent_occ_eval_mask_seed[~gt_mask] = False + + # random_weights = torch.rand_like(grid_agent_occ_seed) * ~gt_mask + # _, topk_indices = random_weights.topk(10, dim=-1) + # grid_agent_occ_eval_mask_seed.scatter_(-1, topk_indices, True) + + # grid_agent_occ_eval_mask_seed[:, 0] = False + # grid_agent_occ_eval_mask_seed[..., self.grid_size // 2] = False + + # gt_mask = grid_pt_occ_gt_seed.bool() + + # random_weights = torch.rand_like(grid_agent_occ_seed) * gt_mask + # _, topk_indices = random_weights.topk(256, dim=-1) + # grid_pt_occ_eval_mask_seed.scatter_(-1, topk_indices, True) + # grid_pt_occ_eval_mask_seed[~gt_mask] = False + + # random_weights = torch.rand_like(grid_agent_occ_seed) * ~gt_mask + # _, topk_indices = random_weights.topk(256, dim=-1) + # grid_pt_occ_eval_mask_seed.scatter_(-1, topk_indices, True) + + grid_occ_eval_mask_seed = torch.ones_like(grid_agent_occ_seed).bool() + grid_occ_eval_mask_seed[:, 0] = False + grid_occ_eval_mask_seed[..., self.grid_size // 2] = False + grid_agent_occ_eval_mask_seed = grid_pt_occ_eval_mask_seed = grid_occ_eval_mask_seed + + # sanity check + # s = random.randint(0, self.num_seed_feature - 1) + # t = random.randint(0, num_step - 1) + # grid_index = grid_agent_occ_gt_seed[s, t].nonzero()[..., 0] + # check_mask = torch.zeros_like(pad_mask) + # check_mask[av_index + s + 1, t] = 1 + # check_index = check_mask.transpose(0, 1).reshape(-1).nonzero()[..., 0] + # check_agent_index = edge_index_a2a[0, edge_index_a2a[1] == check_index[0]] % (num_agent + self.num_seed_feature) + # if not torch.all(grid_index == next_pos_rel_index_gt[check_agent_index, t].unique().sort()[0]): + # raise RuntimeError(f"Grid index not consistent s={s} t={t} scenario_id={data['scenario_id'][0]}") + + target_indices = pred_indices.clone() + target_indices[~next_attr_eval_mask_seed] = -1 + + return {'x_a': feat_a, + 'ego_pos': ego_pos, + # motion token + 'next_token_idx': next_token_idx, + 'next_token_prob': next_token_prob, + 'next_token_idx_gt': next_token_index_gt, + 'next_token_eval_mask': next_token_eval_mask.bool(), + # state token + 'next_state_idx': next_state_idx, + 'next_state_prob': next_state_prob, + 'next_state_idx_gt': next_state_index_gt, + 'next_state_eval_mask': next_state_eval_mask.bool(), + # seed agent + 'next_state_idx_seed': next_state_idx_seed, + 'next_state_prob_seed': next_state_prob_seed, + 'next_state_idx_gt_seed': next_state_index_gt_seed, + + 'next_type_idx_seed': next_type_idx_seed, + 'next_type_prob_seed': next_type_prob_seed, + 'next_type_idx_gt_seed': next_type_index_gt_seed, + + 'next_pos_rel_prob_seed': next_pos_rel_prob_seed, + 'next_pos_rel_index_gt_seed': next_pos_rel_index_gt_seed, + 'next_pos_rel_xy_seed': next_pos_rel_xy_seed, + 'next_pos_rel_xy_gt_seed': next_pos_rel_xy_gt_seed, + 'next_head_rel_prob_seed': next_head_rel_prob_seed, + 'next_head_rel_index_gt_seed': next_head_rel_index_gt_seed, + 'next_offset_xy_seed': next_offset_xy_seed, + 'next_offset_xy_gt_seed': next_offset_xy_gt_seed, + 'next_shape_seed': next_shape_seed, + 
'next_shape_gt_seed': next_shape_gt_seed, + + 'grid_agent_occ_seed': grid_agent_occ_seed, + 'grid_pt_occ_seed': grid_pt_occ_seed, + 'grid_agent_occ_gt_seed': grid_agent_occ_gt_seed, + 'grid_pt_occ_gt_seed': grid_pt_occ_gt_seed, + 'neighbor_agent_grid_idx': neighbor_agent_grid_idx + if self.use_grid_token else None, + 'neighbor_pt_grid_idx': neighbor_pt_grid_idx + if self.use_grid_token else None, + 'neighbor_agent_grid_index_gt': neighbor_agent_grid_index_gt + if self.use_grid_token else None, + 'neighbor_pt_grid_index_gt': neighbor_pt_grid_index_gt + if self.use_grid_token else None, + + 'target_indices': target_indices[..., 0], + 'raw_next_state_prob_seed': raw_next_state_prob_seed, + + 'next_state_eval_mask_seed': next_state_eval_mask_seed.bool(), + 'next_attr_eval_mask_seed': next_attr_eval_mask_seed.bool(), + 'next_head_eval_mask_seed': mask_sa.bool(), + 'grid_agent_occ_eval_mask_seed': grid_agent_occ_eval_mask_seed + if self.use_grid_token else None, + 'grid_pt_occ_eval_mask_seed': grid_pt_occ_eval_mask_seed + if self.use_grid_token else None, + 'neighbor_agent_grid_index_eval_mask': neighbor_agent_grid_index_eval_mask.bool() + if self.use_grid_token else None, + 'neighbor_pt_grid_index_eval_mask': neighbor_pt_grid_index_eval_mask.bool() + if self.use_grid_token else None, + } + + def inference(self, + data: HeteroData, + map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + + filter_mask = data['agent']['state_idx'][:, (self.num_historical_steps - 1) // self.shift - 1] != self.invalid_state + + seed_step_mask = data['agent']['state_idx'][:, (self.num_historical_steps - 1) // self.shift:] == self.enter_state + # seed_agent_index_per_step = [torch.nonzero(seed_step_mask[:, t]).squeeze(dim=-1) for t in range(seed_step_mask.shape[1])] + + # num_historical_steps=11 + eval_mask = data['agent']['valid_mask'][filter_mask, self.num_historical_steps - 1] + + # agent attributes + agent_id = data['agent']['id'][filter_mask].clone() + agent_valid_mask = data['agent']['raw_agent_valid_mask'][filter_mask].clone() # token_valid_mask + pos_a = data['agent']['token_pos'][filter_mask].clone() # (a, t, 2) + token_a = data['agent']['token_idx'][filter_mask].clone() # (a, t) + state_a = data['agent']['state_idx'][filter_mask].clone() + head_a = data['agent']['token_heading'][filter_mask].clone() + shape_a = data['agent']['shape'][filter_mask].clone() + type_a = data['agent']['type'][filter_mask].clone() + grid_a = data['agent']['grid_token_idx'][filter_mask].clone() + gt_traj = data['agent']['position'][filter_mask, self.num_historical_steps:, :self.input_dim].contiguous() + agent_token_traj_all = data['agent']['token_traj_all'][filter_mask] + + device = pos_a.device + max_agent_id = agent_id.max() # TODO: bs=1 + + if self.num_recurrent_steps_val == -1: + # self.num_recurrent_steps_val = 91 - 11 = 80 + self.num_recurrent_steps_val = data["agent"]['position'].shape[1] - self.num_historical_steps + num_agent, num_ori_step, traj_dim = pos_a.shape + num_infer_step = (self.num_recurrent_steps_val + self.num_historical_steps) // self.shift + if num_infer_step > num_ori_step: + pad_shape = num_agent, num_infer_step - num_ori_step + agent_valid_mask = torch.cat([agent_valid_mask, torch.full(pad_shape, True, device=device)], dim=1) + pos_a = torch.cat([pos_a, torch.zeros((*pad_shape, pos_a.shape[-1]), device=device)], dim=1) + token_a = torch.cat([token_a, torch.full(pad_shape, -1, device=device)], dim=1) + state_a = torch.cat([state_a, torch.full(pad_shape, self.invalid_state, device=device)], 
dim=1) + head_a = torch.cat([head_a, torch.zeros(pad_shape, device=device)], dim=1) + grid_a = torch.cat([grid_a, torch.full(pad_shape, -1, device=device)], dim=1) + + # TODO: support bs > 1 in inference !!! + num_removed_agent = int((~filter_mask[:data['agent']['av_index']]).sum()) + data['batch_size_a'] -= num_removed_agent + av_index = data['agent']['av_index'] - num_removed_agent + + # make future steps to zero + pos_a[:, (self.num_historical_steps - 1) // self.shift:] = 0 + head_a[:, (self.num_historical_steps - 1) // self.shift:] = 0 + token_a[:, (self.num_historical_steps - 1) // self.shift:] = -1 + state_a[:, (self.num_historical_steps - 1) // self.shift:] = 0 + grid_a[:, (self.num_historical_steps - 1) // self.shift:] = -1 + + motion_vector_a, head_vector_a = self._build_vector_a(pos_a, head_a, state_a) + + agent_valid_mask[:, (self.num_historical_steps - 1) // self.shift:] = True + agent_valid_mask[~eval_mask] = False + agent_token_index = data['agent']['token_idx'][filter_mask] + agent_state_index = data['agent']['state_idx'][filter_mask] + + (feat_a, agent_token_emb, agent_token_emb_veh, agent_token_emb_ped, agent_token_emb_cyc, categorical_embs, + trajectory_token_veh, trajectory_token_ped, trajectory_token_cyc) = self._agent_token_embedding( + data, + token_a, + state_a, + grid_a, + pos_a, + head_a, + inference=True, + filter_mask=filter_mask, + av_index=av_index, + ) + raw_feat_a = feat_a.clone() + + veh_mask = type_a == 0 + cyc_mask = type_a == 2 + ped_mask = type_a == 1 + + pred_traj = torch.zeros(pos_a.shape[0], self.num_recurrent_steps_val, 2, device=device) # (a, val_t, 2) + pred_head = torch.zeros(pos_a.shape[0], self.num_recurrent_steps_val, device=device) + pred_type = type_a.clone() + pred_shape = shape_a[:, (self.num_historical_steps - 1) // self.shift - 1] # (a, 3) + pred_state = torch.zeros(pos_a.shape[0], self.num_recurrent_steps_val, device=device) + pred_prob = torch.zeros(pos_a.shape[0], self.num_recurrent_steps_val // self.shift, device=device) # (a, val_t) + + feat_a_t_dict = {} + feat_sa_t_dict = {} + + # build masks (init) + mask = agent_valid_mask.clone() + temporal_mask = mask.clone() + interact_mask = mask.clone() + + # find bos and eos index + is_bos = state_a == self.enter_state + is_eos = state_a == self.exit_state + bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0)) + eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_infer_step - 1)) + + temporal_mask = torch.ones_like(mask) + motion_mask = torch.arange(mask.shape[1]).expand(mask.shape[0], mask.shape[1]).to(mask.device) + motion_mask = (motion_mask > bos_index[:, None]) & (motion_mask <= eos_index[:, None]) + motion_mask[:, self.num_historical_steps // self.shift:] = False + temporal_mask[motion_mask] = mask[motion_mask] + + interact_mask = torch.ones_like(mask) + non_motion_mask = ~motion_mask + non_motion_mask[:, self.num_historical_steps // self.shift:] = False + interact_mask[non_motion_mask] = 0 + interact_mask[state_a == self.enter_state] = 1 + interact_mask[av_index] = 1 + + temporal_mask[:, (self.num_historical_steps - 1) // self.shift:] = 1 + interact_mask[:, (self.num_historical_steps - 1) // self.shift:] = 1 + + self.log_message = "" + num_inserted_agents_total = num_inserted_agents = 0 + next_token_idx_list = [] + next_state_idx_list = [] + grid_agent_occ_list = [] + grid_pt_occ_list = [] + grid_agent_occ_gt_list = [] + next_state_prob_seed_list = [] + next_pos_rel_prob_seed_list = [] + 
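+        # NOTE: quick sanity arithmetic for the rollout below, under the values
+        # stated in the comments above (num_historical_steps=11, shift=5,
+        # num_recurrent_steps_val=80): each motion token covers `shift` raw
+        # timesteps, so the loop runs 80 // 5 = 16 token steps on top of
+        # (11 - 1) // 5 = 2 history token steps, i.e. num_infer_step = (80 + 11) // 5 = 18.
+        # assert self.num_recurrent_steps_val // self.shift == 16
+        # assert (self.num_historical_steps - 1) // self.shift == 2
+        # assert num_infer_step == 18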
agent_labels = [[None] * num_infer_step for _ in range(pos_a.shape[0])] + + # append history motion/state tokens + for i in range((self.num_historical_steps - 1) // self.shift): + next_token_idx_list.append(agent_token_index[:, i : i + 1]) + next_state_idx_list.append(agent_state_index[:, i : i + 1]) + + num_seed_feature = 1 + insert_limit = 10 + + for t in ( + pbar := tqdm(range(self.num_recurrent_steps_val // self.shift), leave=False, desc='Timestep ...') + ): + + # 1. insert agents + num_new_agents = 0 + next_state_prob_seeds = torch.zeros(insert_limit + 1, 1, device=device) + next_pos_rel_prob_seeds = torch.zeros(insert_limit + 1, 1, self.attr_tokenizer.grid_size, device=device) + grid_agent_occ_seeds = torch.zeros(insert_limit + 1, 1, self.attr_tokenizer.grid_size, device=device) + grid_pt_occ_seeds = torch.zeros(insert_limit + 1, 1, self.attr_tokenizer.grid_size, device=device) + grid_agent_occ_gt_seeds = torch.zeros(insert_limit + 1, 1, self.attr_tokenizer.grid_size, device=device) + + valid_state_mask = state_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t] != self.invalid_state # TODO: only supports bs=1 + distance = ((pos_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t, :2] - pos_a[av_index, (self.num_historical_steps - 1) // self.shift - 1 + t, :2]) ** 2).sum(-1).sqrt() + inrange_mask = distance <= self.pl2seed_radius + seq_valid_mask = valid_state_mask & inrange_mask + seq_valid_mask[av_index] = False + res_seq_index = torch.zeros_like(state_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]) + res_seq_index[seq_valid_mask] = torch.randperm(seq_valid_mask.sum(), device=device) + 1 + + if t == 0: + inference_mask = temporal_mask.clone() + inference_mask = torch.cat([inference_mask, torch.ones_like(inference_mask[-1:]).repeat( + num_seed_feature, *([1] * (inference_mask.dim() - 1)))]) + inference_mask[:, (self.num_historical_steps - 1) // self.shift + t:] = False + else: + inference_mask = torch.zeros_like(temporal_mask) + inference_mask = torch.cat([inference_mask, torch.zeros_like(inference_mask[-1:]).repeat( + num_seed_feature, *([1] * (inference_mask.dim() - 1)))]) + inference_mask[:, (self.num_historical_steps - 1) // self.shift + t - 1] = True + + plot_kwargs = dict() + p = 0 + while True: + + p += 1 + if t == 0 or p - 1 >= insert_limit: break + + # rebuild the inference mask since the number of agents has changed + inference_mask = torch.zeros_like(temporal_mask) + inference_mask = torch.cat([inference_mask, torch.zeros_like(inference_mask[-1:]).repeat( + num_seed_feature, *([1] * (inference_mask.dim() - 1)))]) + inference_mask[:, (self.num_historical_steps - 1) // self.shift + t - 1] = True + + # sanity check: make sure seed agents will interact with **all** non-invalid agents + assert torch.all(state_a[:, :(self.num_historical_steps - 1) // self.shift + t][ + ~interact_mask[:, :(self.num_historical_steps - 1) // self.shift + t]] == self.invalid_state) and \ + torch.all(state_a[:, :(self.num_historical_steps - 1) // self.shift + t][ + interact_mask[:, :(self.num_historical_steps - 1) // self.shift + t]] != self.invalid_state), \ + f"Inconsistent interact mask at scenario {data['scenario_id'][0]} t={t}!" 
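+                # NOTE: the assertion above encodes an invariant between `interact_mask`
+                # and `state_a`: an agent takes part in interaction edges exactly when
+                # its state token is not `invalid`. The same check as a standalone
+                # helper (sketch, assuming (num_agent, T) tensors):
+                # def check_interact_mask(state_a, interact_mask, invalid_state):
+                #     return bool(torch.all(state_a[~interact_mask] == invalid_state)
+                #                 and torch.all(state_a[interact_mask] != invalid_state))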
+ + temporal_mask = torch.cat([temporal_mask, torch.ones_like(temporal_mask[:1]).repeat( + num_seed_feature, *([1] * (temporal_mask.dim() - 1)))]).bool() + interact_mask = torch.cat([interact_mask, torch.ones_like(interact_mask[:1]).repeat( + num_seed_feature, *([1] * (interact_mask.dim() - 1)))]).bool() # placeholder + + pos_a_p, head_a_p, state_a_p, head_vector_a_p, grid_index_a_p, pad_mask = \ + self._pad_feat(data.num_graphs, av_index, pos_a, head_a, state_a, head_vector_a, grid_a, num_seed_feature=num_seed_feature) + # sanity check + assert torch.all(~pad_mask[-num_seed_feature:]), "Got wrong with pad mask!" + + batch_s = torch.arange(num_infer_step, device=device).repeat_interleave(num_agent + num_seed_feature) + batch_pl = torch.arange(num_infer_step, device=device).repeat_interleave(data['pt_token']['num_nodes']) + + inference_mask_sa = torch.zeros_like(inference_mask).bool() + inference_mask_sa[:, (self.num_historical_steps - 1) // self.shift - 1 + t] = True + + # 1.1 build seed agent features + if self.seed_use_ego_motion: + motion_vector_seed = motion_vector_a[av_index] + head_vector_seed = head_vector_a[av_index] + else: + motion_vector_seed = head_vector_seed = None + + feat_seed, _ = self._build_agent_feature(num_infer_step, device, + motion_vector_seed, + head_vector_seed, + state_index=self.invalid_state, + n=num_seed_feature) + + if feat_a.shape[1] != feat_seed.shape[1]: + assert t == 0, f"Unmatched timestep {feat_a.shape[1]} and {feat_seed.shape[1]}." + feat_a = torch.cat([feat_a, feat_a[:, -1:].repeat(1, feat_seed.shape[1] - feat_a.shape[1], 1)], dim=1) + + raw_feat_a = feat_a.clone() + feat_a = torch.cat([feat_a, feat_seed], dim=0) + + # 1.2 global feature aggregation + plot_kwargs.update(t=t, n=num_new_agents, tag='global_feature') + # 0, 0, 0, ..., N+1, N+2, ... + seq_index = torch.cat([torch.zeros(pos_a.shape[0] - num_new_agents), torch.arange(num_new_agents + 1) + 1]).to(device) + # 0, 2, 1, ..., N+1, N+2, ... 
+ # seq_index = torch.cat([res_seq_index, torch.arange(num_new_agents + 1, device=device) + 1 + seq_valid_mask.sum()]) + edge_index_a2seed, r_seed2a = self._build_a2sa_edge(data, pos_a_p, head_a_p, head_vector_a_p, batch_s, + interact_mask.clone(), + mask_sa=~pad_mask.clone(), + inference_mask=inference_mask_sa, + r=self.pl2seed_radius, + max_num_neighbors=300, + seq_index=seq_index, + grid_index_a=grid_index_a_p, + mode='insert', **plot_kwargs) + edge_index_pl2seed, r_pl2seed = self._build_map2sa_edge(data, pos_a_p, head_a_p, head_vector_a_p, batch_s, batch_pl, + mask_sa=~pad_mask.clone(), + inference_mask=inference_mask_sa, + r=self.pl2seed_radius, + max_num_neighbors=2048, + mode='insert') + temporal_mask = temporal_mask[:-num_seed_feature] + interact_mask = interact_mask[:-num_seed_feature] + + if self.use_grid_token: + grid_agent_occ_gt_t_1 = torch.zeros((self.grid_size,), device=device).long() + grid_t_1 = grid_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t] + grid_agent_occ_gt_t_1[grid_t_1[grid_t_1 != -1]] = 1 + occ_embed_a = self.seed_agent_occ_embed(grid_agent_occ_gt_t_1.reshape(1, self.grid_size).float()).repeat(num_seed_feature, 1) + edge_index_occ2sa_src = torch.arange(feat_a.shape[0] * feat_a.shape[1], device=device).long() + edge_index_occ2sa_src = edge_index_occ2sa_src[(~pad_mask.transpose(0, 1).reshape(-1)) & (inference_mask_sa.transpose(0, 1).reshape(-1))] + edge_index_occ2sa_tgt = torch.arange(occ_embed_a.shape[0], device=device).long() + edge_index_occ2sa = torch.stack([edge_index_occ2sa_tgt, edge_index_occ2sa_src], dim=0) + + for i in range(self.seed_layers): + + feat_a = feat_a.transpose(0, 1).reshape(-1, self.hidden_dim) + if self.use_grid_token: + feat_a = self.occ2sa_attn_layers[i]((occ_embed_a, feat_a), None, edge_index_occ2sa) + feat_a = self.pt2sa_attn_layers[i](( + map_enc['x_pt'].repeat_interleave(repeats=num_infer_step, dim=0).reshape(-1, num_infer_step, self.hidden_dim).transpose(0, 1).reshape( + -1, self.hidden_dim), feat_a), r_pl2seed, edge_index_pl2seed) + + feat_a = self.a2sa_attn_layers[i](feat_a, r_seed2a, edge_index_a2seed) + feat_a = feat_a.reshape(num_infer_step, -1, self.hidden_dim).transpose(0, 1) + + feat_seed = feat_a[-num_seed_feature:] # (s, t, d) + + ego_pos_t_1 = pos_a[av_index, (self.num_historical_steps - 1) // self.shift - 1 + t] + ego_head_t_1 = head_a[av_index, (self.num_historical_steps - 1) // self.shift - 1 + t] + + # occupancy + if self.predict_occ: + grid_agent_occ_seed = self.grid_agent_occ_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t]) # (num_seed, grid_size) + grid_pt_occ_seed = self.grid_pt_occ_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t]) + + # insert prob + next_state_prob_seed = self.seed_state_predict_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t]) + next_state_idx_seed = next_state_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True) + next_state_idx_seed[next_state_idx_seed == self.seed_state_type.index('invalid')] = self.invalid_state + next_state_idx_seed[next_state_idx_seed == self.seed_state_type.index('enter')] = self.enter_state + if int(os.getenv('DEBUG', 0)): + next_state_idx_seed = torch.full(next_state_idx_seed.shape, self.enter_state, device=device).long() + + # type and shape + next_type_prob_seed = self.seed_type_predict_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t]) + next_type_idx_seed = next_type_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True) + next_shape_seed = 
self.seed_shape_predict_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t]) + + # position + if self.use_grid_token: + next_pos_rel_prob_seed = self.seed_pos_rel_token_predict_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t]) + next_pos_rel_prob_softmax = torch.softmax(next_pos_rel_prob_seed, dim=-1) + # if self.inference_filter_overlap: + # next_pos_rel_prob_softmax[..., grid_agent_occ_gt_t_1.bool()] = 1e-6 # diffuse! + topk_pos_rel_prob, next_pos_rel_idx_seed = torch.topk(next_pos_rel_prob_softmax, k=self.insert_beam_size, dim=-1) + sample_pos_rel_index = torch.multinomial(topk_pos_rel_prob, 1).to(device) + next_pos_rel_idx_seed = next_pos_rel_idx_seed.gather(dim=1, index=sample_pos_rel_index) + next_pos_seed = self.attr_tokenizer.decode_pos(next_pos_rel_idx_seed[..., 0], y=ego_pos_t_1, theta_y=ego_head_t_1) + if self.inference_filter_overlap: + if grid_agent_occ_gt_t_1[next_pos_rel_idx_seed[..., 0]]: # TODO: only support insert num=1 for each iter!!! + feat_a = raw_feat_a.clone() + continue + else: + next_pos_rel_xy_seed = self.seed_pos_rel_xy_predict_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t]) + next_pos_seed = F.tanh(next_pos_rel_xy_seed) * self.pl2seed_radius + ego_pos_t_1 + + if torch.all(next_state_idx_seed == self.invalid_state) or num_new_agents + 1 > insert_limit: + break + + num_new_agent = 1 # TODO: fix this term + num_new_agents += 1 + + # ! 1.5. insert new agents and update attributes + + # append new agent id + agent_id = torch.cat([agent_id, torch.tensor([max_agent_id + 1], device=device, dtype=agent_id.dtype)]) + max_agent_id += 1 + + mask = torch.cat([mask, torch.ones(num_new_agent, num_infer_step, device=mask.device)], dim=0).bool() + temporal_mask = torch.cat([temporal_mask, torch.ones(num_new_agent, num_infer_step, device=temporal_mask.device)], dim=0).bool() + interact_mask = torch.cat([interact_mask, torch.ones(num_new_agent, num_infer_step, device=interact_mask.device)], dim=0).bool() + + # initialize new attributes + new_pos_a = torch.zeros(num_new_agent, num_infer_step, 2, device=device) + new_head_a = torch.zeros(num_new_agent, num_infer_step, device=device) + new_grid_a = torch.full((num_new_agent, num_infer_step), -1, device=device) + new_state_a = torch.full((num_new_agent, num_infer_step), self.invalid_state, device=state_a.device) + new_shape_a = torch.full((num_new_agent, num_infer_step, 3), self.invalid_shape_value, device=device) + new_type_a = torch.full((num_new_agent, num_infer_step), AGENT_TYPE.index('seed'), device=device) + + # add new attributes + new_pos_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t] = next_pos_seed + pos_a = torch.cat([pos_a, new_pos_a], dim=0) + + new_head_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t] = ego_head_t_1 # dummy values + head_a = torch.cat([head_a, new_head_a], dim=0) + + if self.use_grid_token: + new_grid_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t] = next_pos_rel_idx_seed + grid_a = torch.cat([grid_a, new_grid_a], dim=0) + + new_type_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t:] = next_type_idx_seed + new_shape_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t:] = next_shape_seed[:, None] + pred_type = torch.cat([pred_type, new_type_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]]) + pred_shape = torch.cat([pred_shape, new_shape_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]]) + + new_state_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t] 
= next_state_idx_seed # all enter state + state_a = torch.cat([state_a, new_state_a], dim=0) + + mask[-num_new_agent:, :(self.num_historical_steps - 1) // self.shift + t] = 0 + interact_mask[-num_new_agent:, :(self.num_historical_steps - 1) // self.shift - 1 + t] = 0 + + # placeholders in pred_traj, pred_head, pred_state + new_pred_traj = torch.zeros(num_new_agent, self.num_recurrent_steps_val, 2, device=device) + new_pred_head = torch.zeros(num_new_agent, self.num_recurrent_steps_val, device=device) + new_pred_state = torch.zeros(num_new_agent, self.num_recurrent_steps_val, device=device) + + if t > 0: + new_pred_traj[:, (t - 1) * 5 : t * 5] = new_pos_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t, None].repeat(1, 5, 1) + new_pred_head[:, (t - 1) * 5 : t * 5] = new_head_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t, None].repeat(1, 5) + new_pred_state[:, (t - 1) * 5 : t * 5] = next_state_idx_seed.repeat(1, 5) + + pred_traj = torch.cat([pred_traj, new_pred_traj], dim=0) + pred_head = torch.cat([pred_head, new_pred_head], dim=0) + pred_state = torch.cat([pred_state, new_pred_state], dim=0) + + # add new agents' token embeddings + new_agent_token_emb = self.no_token_emb(torch.zeros(1, device=device).long())[None, :].repeat(num_new_agent, num_infer_step, 1) + new_agent_token_emb[:, (self.num_historical_steps - 1) // self.shift - 1 + t] = self.bos_token_emb(torch.zeros(1, device=device).long()) + agent_token_emb = torch.cat([agent_token_emb, new_agent_token_emb]) + next_veh_mask = next_type_idx_seed[..., 0] == AGENT_TYPE.index('veh') + next_ped_mask = next_type_idx_seed[..., 0] == AGENT_TYPE.index('ped') + next_cyc_mask = next_type_idx_seed[..., 0] == AGENT_TYPE.index('cyc') + veh_mask = torch.cat([veh_mask, next_veh_mask]) + ped_mask = torch.cat([ped_mask, next_ped_mask]) + cyc_mask = torch.cat([cyc_mask, next_cyc_mask]) + + # add new agents' trajectory embeddings + new_agent_token_traj_all = torch.zeros((num_new_agent, self.token_size, self.shift + 1, 4, 2), device=device) + new_agent_token_traj_all[next_veh_mask] = trajectory_token_veh[None, ...] + new_agent_token_traj_all[next_ped_mask] = trajectory_token_ped[None, ...] + new_agent_token_traj_all[next_cyc_mask] = trajectory_token_cyc[None, ...] + + agent_token_traj_all = torch.cat([agent_token_traj_all, new_agent_token_traj_all], dim=0) + + new_categorical_embs = [self.type_a_emb(new_type_a.reshape(-1).long()), self.shape_emb(new_shape_a.reshape(-1, 3))] + categorical_embs = [torch.cat([categorical_embs[0], new_categorical_embs[0]], dim=0), + torch.cat([categorical_embs[1], new_categorical_embs[1]], dim=0)] + + new_labels = [None] * num_infer_step + new_labels[(self.num_historical_steps - 1) // self.shift + t] = f'A{num_new_agents}' # the first step after the bos step + agent_labels.append(new_labels) + + # 2. predict headings for seed agents + motion_vector_sa, head_vector_sa = self._build_vector_a(pos_a[-num_new_agent:], + head_a[-num_new_agent:], + state_a[-num_new_agent:]) + # sanity check + assert torch.all(motion_vector_sa[:, :(self.num_historical_steps - 1) // self.shift - 1 + t] == self.invalid_motion_value) and \ + torch.all(motion_vector_sa[:, (self.num_historical_steps - 1) // self.shift - 1 + t] == self.motion_gap), \ + f"Found invalid values in motion_vector_sa at scenario {data['scenario_id'][0]} t={t}!" + + motion_vector_sa[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0. + head_vector_sa[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0. 
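+                # NOTE: `_build_vector_a` is defined elsewhere in this class; a minimal
+                # version consistent with how it is used here (sketch that ignores the
+                # state-dependent sentinel values such as `self.motion_gap` and
+                # `self.invalid_motion_value`): the motion vector is the step-to-step
+                # position difference and the head vector is the unit yaw direction.
+                # def build_vector_a(pos_a, head_a):
+                #     # pos_a: (num_agent, T, 2), head_a: (num_agent, T)
+                #     motion_vector = torch.cat([pos_a.new_zeros(pos_a.shape[0], 1, 2),
+                #                                pos_a[:, 1:] - pos_a[:, :-1]], dim=1)
+                #     head_vector = torch.stack([head_a.cos(), head_a.sin()], dim=-1)
+                #     return motion_vector, head_vector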
+ motion_vector_a = torch.cat([motion_vector_a, motion_vector_sa]) + head_vector_a = torch.cat([head_vector_a, head_vector_sa]) + + new_offset_pos = pos_a[-num_new_agent:] - pos_a[av_index] + new_agent_grid_emb = self.grid_token_emb[new_grid_a] if self.use_grid_token else None + + feat_sa, _ = self._build_agent_feature(num_infer_step, device, + motion_vector_sa, + head_vector_sa, + agent_token_emb=new_agent_token_emb, + agent_grid_emb=new_agent_grid_emb, + offset_pos=new_offset_pos, + categorical_embs_a=new_categorical_embs, + state=new_state_a) + + feat_a = torch.cat([raw_feat_a, feat_sa]) + + batch_s = torch.arange(num_infer_step, device=device).repeat_interleave(num_agent + num_new_agent) + batch_pl = torch.arange(num_infer_step, device=device).repeat_interleave(data['pt_token']['num_nodes']) + + # sanity check + assert pos_a.shape[0] == head_a.shape[0] == head_vector_a.shape[0] == interact_mask.shape[0] == \ + pad_mask.shape[0] == inference_mask_sa.shape[0] == (num_agent + num_new_agent), f"Inconsistent shapes!" + + plot_kwargs.update(tag='heading') + edge_index_a2sa, r_a2sa = self._build_a2sa_edge(data, pos_a, head_a, head_vector_a, batch_s, + interact_mask.clone(), + mask_sa=~pad_mask.clone(), + inference_mask=inference_mask_sa, + r=self.a2sa_radius, + max_num_neighbors=24, + **plot_kwargs) + edge_index_pl2sa, r_pl2sa = self._build_map2sa_edge(data, pos_a, head_a, head_vector_a, batch_s, batch_pl, + mask_sa=~pad_mask.clone(), + inference_mask=inference_mask_sa, + r=self.pl2sa_radius, + max_num_neighbors=128) + + for i in range(self.seed_layers): + + feat_a = feat_a.transpose(0, 1).reshape(-1, self.hidden_dim) + feat_a = self.pt2a_attn_layers[i](( + map_enc['x_pt'].repeat_interleave(repeats=num_infer_step, dim=0).reshape(-1, num_infer_step, self.hidden_dim).transpose(0, 1).reshape( + -1, self.hidden_dim), feat_a), r_pl2sa, edge_index_pl2sa) + + feat_a = self.a2a_attn_layers[i](feat_a, r_a2sa, edge_index_a2sa) + feat_a = feat_a.reshape(num_infer_step, -1, self.hidden_dim).transpose(0, 1) + + next_head_rel_prob_seed = self.seed_heading_rel_token_predict_head(feat_a[-num_new_agent:, (self.num_historical_steps - 1) // self.shift - 1 + t]) + next_head_rel_idx_seed = next_head_rel_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True) + next_head_seed = wrap_angle(self.attr_tokenizer.decode_heading(next_head_rel_idx_seed) + ego_head_t_1) + + head_a[-num_new_agent:, (self.num_historical_steps - 1) // self.shift - 1 + t] = next_head_seed + + if self.use_grid_token: + next_offset_xy_seed = self.seed_offset_xy_predict_head(feat_a[-num_new_agent:, (self.num_historical_steps - 1) // self.shift - 1 + t]) + next_offset_xy_seed = torch.tanh(next_offset_xy_seed) * 2 + + pos_a[-num_new_agent:, (self.num_historical_steps - 1) // self.shift - 1 + t] += next_offset_xy_seed + + # ! finalize new features + motion_vector_sa, head_vector_sa = self._build_vector_a(pos_a[-num_new_agent:], + head_a[-num_new_agent:], + state_a[-num_new_agent:]) + motion_vector_sa[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0. + head_vector_sa[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0. 
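+                # NOTE: sketch of the relative-heading round trip used just above,
+                # assuming angle_interval = 3 degrees (the value set in the configs of
+                # this diff) and a common (-pi, pi] wrap_angle standing in for the one
+                # imported from dev.utils:
+                # def decode_heading_to_world(idx, ego_head, angle_interval=3.):
+                #     rel = (idx * angle_interval - 180.) / 360. * (2 * torch.pi)
+                #     return (rel + ego_head + torch.pi) % (2 * torch.pi) - torch.pi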
+ motion_vector_a[-num_new_agent:] = motion_vector_sa + head_vector_a[-num_new_agent:] = head_vector_sa + + feat_sa, _ = self._build_agent_feature(num_infer_step, device, + motion_vector_sa, + head_vector_sa, + agent_token_emb=new_agent_token_emb, + agent_grid_emb=new_agent_grid_emb, + offset_pos=new_offset_pos, + categorical_embs_a=new_categorical_embs, + state=state_a[-num_new_agent:], + n=num_new_agent) + + feat_a = torch.cat([raw_feat_a, feat_sa]) + raw_feat_a = feat_a.clone() + + num_agent = pos_a.shape[0] + + if self.use_grid_token: + grid_agent_occ_gt_seeds[num_new_agents] = grid_agent_occ_gt_t_1 + grid_agent_occ_seeds[num_new_agents] = grid_agent_occ_seed + grid_pt_occ_seeds[num_new_agents] = grid_pt_occ_seed + next_pos_rel_prob_seeds[num_new_agents] = next_pos_rel_prob_softmax + next_state_prob_seeds[num_new_agents] = next_state_prob_seed.softmax(dim=-1)[:, -1] + + inference_mask = inference_mask[:-num_seed_feature] + next_state_prob_seed_list.append(next_state_prob_seeds) + if self.use_grid_token: + next_pos_rel_prob_seed_list.append(next_pos_rel_prob_seeds) + grid_agent_occ_list.append(grid_agent_occ_seeds) + grid_pt_occ_list.append(grid_pt_occ_seeds) + grid_agent_occ_gt_list.append(grid_agent_occ_gt_seeds) + next_state_idx_list[-1] = torch.cat([next_state_idx_list[-1], torch.full((num_new_agents, 1), self.enter_state, device=device).long()]) + + # 3. predict motions for all agents + feat_a = raw_feat_a + + # rebuild the inference mask since the number of agents has changed + inference_mask = torch.zeros_like(temporal_mask) + inference_mask[:, (self.num_historical_steps - 1) // self.shift + t - 1] = True + + edge_index_t, r_t = self._build_temporal_edge(data, pos_a, head_a, state_a, head_vector_a, temporal_mask, inference_mask.clone()) + + batch_s = torch.arange(num_infer_step, device=device).repeat_interleave(num_agent) + batch_pl = torch.arange(num_infer_step, device=device).repeat_interleave(data['pt_token']['num_nodes']) + + edge_index_a2a, r_a2a = self._build_interaction_edge(data, pos_a, head_a, state_a, head_vector_a, batch_s, + interact_mask, inference_mask=inference_mask, av_index=av_index, **plot_kwargs) + edge_index_pl2a, r_pl2a = self._build_map2agent_edge(data, pos_a, head_a, state_a, head_vector_a, batch_s, batch_pl, + interact_mask, inference_mask=inference_mask, av_index=av_index, **plot_kwargs) + + for i in range(self.num_layers): + + if i in feat_a_t_dict: + feat_a = feat_a_t_dict[i] + + feat_a = feat_a.reshape(-1, self.hidden_dim) + feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t) + + feat_a = feat_a.reshape(-1, num_infer_step, self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim) + feat_a = self.pt2a_attn_layers[i](( + map_enc['x_pt'].repeat_interleave(repeats=num_infer_step, dim=0).reshape(-1, num_infer_step, self.hidden_dim).transpose(0, 1).reshape( + -1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a) + + feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a) + feat_a = feat_a.reshape(num_infer_step, -1, self.hidden_dim).transpose(0, 1) + + if t == 0: + feat_a_t_dict[i + 1] = feat_a + else: + # update agent features at current step + n = feat_a_t_dict[i + 1].shape[0] + feat_a_t_dict[i + 1][:n, (self.num_historical_steps - 1) // self.shift - 1 + t] = feat_a[:n, (self.num_historical_steps - 1) // self.shift - 1 + t] + # add newly inserted agent features (only when t has changed) + if feat_a.shape[0] > n: + m = feat_a.shape[0] - n + feat_a_t_dict[i + 1] = torch.cat([feat_a_t_dict[i + 1], feat_a[-m:]]) + + # next motion token + 
next_token_prob = self.token_predict_head(feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]) + next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1) + topk_token_prob, next_token_idx = torch.topk(next_token_prob_softmax, k=self.motion_beam_size, dim=-1) # both (num_agent, beam_size) e.g. (31, 5) + + # next state token + next_state_prob = self.state_predict_head(feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]) + next_state_idx = next_state_prob.softmax(dim=-1).argmax(dim=-1) + next_state_idx[next_state_idx == self.valid_state_type.index('exit')] = self.exit_state + next_state_idx[av_index] = self.valid_state # force ego_agent to be valid + + # convert the predicted token to a 0.5s (6 timesteps) trajectory + expanded_token_index = next_token_idx[..., None, None, None].expand(-1, -1, 6, 4, 2) + next_token_traj = torch.gather(agent_token_traj_all, 1, expanded_token_index) # (num_agent, beam_size, 6, 4, 2) + + # apply rotation and translation on 'next_token_traj' + theta = head_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t] + cos, sin = theta.cos(), theta.sin() + rot_mat = torch.zeros((num_agent, 2, 2), device=theta.device) + rot_mat[:, 0, 0] = cos + rot_mat[:, 0, 1] = sin + rot_mat[:, 1, 0] = -sin + rot_mat[:, 1, 1] = cos + agent_diff_rel = torch.bmm(next_token_traj.view(-1, 4, 2), + rot_mat[:, None, None, ...].repeat(1, self.motion_beam_size, self.shift + 1, 1, 1).view( + -1, 2, 2)).view(num_agent, self.motion_beam_size, self.shift + 1, 4, 2) + agent_pred_rel = agent_diff_rel + pos_a[:, None, None, None, (self.num_historical_steps - 1) // self.shift - 1 + t, :] + + # sample 1 most probable index of top beam_size tokens, (num_agent, beam_size) -> (num_agent, 1) + # then sample the agent_pred_rel, (num_agent, beam_size, 6, 4, 2) -> (num_agent, 6, 4, 2) + sample_token_index = torch.multinomial(topk_token_prob, 1).to(agent_pred_rel.device) + next_token_idx = next_token_idx.gather(dim=1, index=sample_token_index).squeeze(-1) + agent_pred_rel = agent_pred_rel.gather(dim=1, + index=sample_token_index[..., None, None, None].expand(-1, -1, 6, 4, + 2))[:, 0, ...] 
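+            # NOTE: a compact standalone sketch of the decode step above: gather the
+            # contour templates of the sampled tokens, rotate them by the agent's
+            # current yaw (row-vector convention, matching the rot_mat layout used in
+            # this file), then translate to the agent's position. Shapes are
+            # illustrative assumptions.
+            # def decode_token_traj(templates, token_idx, theta, pos):
+            #     # templates: (V, S, 4, 2), token_idx: (A,), theta: (A,), pos: (A, 2)
+            #     traj = templates[token_idx]                                    # (A, S, 4, 2)
+            #     cos, sin = theta.cos(), theta.sin()
+            #     rot = torch.stack([torch.stack([cos, sin], dim=-1),
+            #                        torch.stack([-sin, cos], dim=-1)], dim=-2)  # (A, 2, 2)
+            #     return torch.einsum('aspj,ajk->aspk', traj, rot) + pos[:, None, None, :]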
+ + # get predicted position and heading of current shifted timesteps + diff_xy = agent_pred_rel[:, 1:, 0, :] - agent_pred_rel[:, 1:, 3, :] + pred_traj[:num_agent, t * 5 : (t + 1) * 5] = agent_pred_rel[:, 1:, ...].clone().mean(dim=2) + pred_head[:num_agent, t * 5 : (t + 1) * 5] = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0]) + pred_state[:num_agent, t * 5 : (t + 1) * 5] = next_state_idx[:, None].repeat(1, 5) + # pred_prob[:num_agent, t] = topk_token_prob.gather(dim=-1, index=sample_token_index)[:, 0] # (num_agent, beam_size) -> (num_agent,) + + # update pos/head/state of current step + pos_a[:, (self.num_historical_steps - 1) // self.shift + t] = agent_pred_rel[:, -1, ...].clone().mean(dim=1) + diff_xy = agent_pred_rel[:, -1, 0, :] - agent_pred_rel[:, -1, 3, :] + theta = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0]) + head_a[:, (self.num_historical_steps - 1) // self.shift + t] = theta + state_a[:, (self.num_historical_steps - 1) // self.shift + t] = next_state_idx + if self.use_grid_token: + grid_a[:, (self.num_historical_steps - 1) // self.shift + t] = self.attr_tokenizer.encode_pos( + x=pos_a[:, (self.num_historical_steps - 1) // self.shift + t], + y=pos_a[av_index, (self.num_historical_steps - 1) // self.shift + t], + theta_y=theta[av_index], + )[0] + + # the case that the current predicted state token is invalid/exit + is_eos = next_state_idx == self.exit_state + is_invalid = next_state_idx == self.invalid_state + + next_token_idx[is_invalid] = -1 + pos_a[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = 0. + head_a[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = 0. + if self.use_grid_token: + grid_a[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = -1 + + mask[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = False # to handle those newly-added agents + interact_mask[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = False + + agent_token_emb[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = self.no_token_emb(torch.zeros(1, device=device).long()) + + type_emb = categorical_embs[0].reshape(num_agent, num_infer_step, -1) + shape_emb = categorical_embs[1].reshape(num_agent, num_infer_step, -1) + type_emb[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = self.type_a_emb(torch.tensor(AGENT_TYPE.index('seed'), device=device).long()) + shape_emb[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = self.shape_emb(torch.full((1, 3), self.invalid_shape_value, device=device)) + categorical_embs = [type_emb.reshape(-1, self.hidden_dim), shape_emb.reshape(-1, self.hidden_dim)] + + # FIXME: need to discuss!!! + # if is_eos.any(): + + # pos_a[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = 0. + # head_a[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = 0. 
+ # mask[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = False # to handle those newly-added agents + # interact_mask[torch.cat([is_eos, torch.zeros(1, device=is_eos.device).bool()]), (self.num_historical_steps - 1) // self.shift + t + 1:] = False + + # agent_token_emb[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = self.no_token_emb(torch.zeros(1, device=device).long()) + + # type_emb = categorical_embs[0].reshape(num_agent, num_infer_step, -1) + # shape_emb = categorical_embs[1].reshape(num_agent, num_infer_step, -1) + # type_emb[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = self.type_a_emb(torch.tensor(AGENT_TYPE.index('seed'), device=device).long()) + # shape_emb[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = self.shape_emb(torch.full((1, 3), self.invalid_shape_value, device=device)) + # categorical_embs = [type_emb.reshape(-1, self.hidden_dim), shape_emb.reshape(-1, self.hidden_dim)] + + # update token embeddings of current step + agent_token_emb[veh_mask, (self.num_historical_steps - 1) // self.shift + t] = agent_token_emb_veh[ + next_token_idx[veh_mask]] + agent_token_emb[ped_mask, (self.num_historical_steps - 1) // self.shift + t] = agent_token_emb_ped[ + next_token_idx[ped_mask]] + agent_token_emb[cyc_mask, (self.num_historical_steps - 1) // self.shift + t] = agent_token_emb_cyc[ + next_token_idx[cyc_mask]] + + # 4. update feat_a (t-1) + motion_vector_a, head_vector_a = self._build_vector_a(pos_a, head_a, state_a) + motion_vector_a[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0. + head_vector_a[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0. + + offset_pos = pos_a - pos_a[av_index] + + x_a = torch.stack( + [torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1), + angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2]), + # torch.norm(offset_pos[:, :, :2], p=2, dim=-1), + ], dim=-1) + + x_a = self.x_a_emb(continuous_inputs=x_a.view(-1, x_a.size(-1)), + categorical_embs=categorical_embs) + x_a = x_a.view(-1, num_infer_step, self.hidden_dim) + + s_a = self.state_a_emb(state_a.reshape(-1).long()).reshape(-1, num_infer_step, self.hidden_dim) + feat_a = torch.cat((agent_token_emb, x_a, s_a), dim=-1) + if self.use_grid_token: + agent_grid_emb = self.grid_token_emb[grid_a] + feat_a = torch.cat([feat_a, agent_grid_emb], dim=-1) + feat_a = self.fusion_emb(feat_a) + + next_token_idx_list.append(next_token_idx[:, None]) + next_state_idx_list.append(next_state_idx[:, None]) + + # get log message + num_inserted_agents_total += num_new_agents + num_inserted_agents += num_new_agents + if num_new_agents > 0: + self.log(t, next_pos_seed, ego_pos_t_1, next_head_seed, ego_head_t_1, next_shape_seed, next_type_idx_seed) + + # pbar + allocated_memory = torch.cuda.memory_allocated() / (1024 ** 3) + pbar.set_postfix(memory=f'{allocated_memory:.2f}GB', + insert=f'{num_inserted_agents_total}/{seed_step_mask.sum()}') + + for i in range(len(next_token_idx_list)): + next_token_idx_list[i] = torch.cat([next_token_idx_list[i], torch.zeros(num_agent - next_token_idx_list[i].shape[0], 1, device=device) - 1], dim=0).long() # -1: invalid motion token + next_state_idx_list[i] = torch.cat([next_state_idx_list[i], torch.zeros(num_agent - next_state_idx_list[i].shape[0], 1, device=device)], dim=0).long() # 0: invalid state token + + # add history attributes + num_agent = pred_traj.shape[0] + num_init_agent = filter_mask.sum() + + pred_traj = torch.cat([torch.zeros(num_agent, 
self.num_historical_steps, *(pred_traj.shape[2:]), device=pred_traj.device), pred_traj], dim=1) + pred_head = torch.cat([torch.zeros(num_agent, self.num_historical_steps, *(pred_head.shape[2:]), device=pred_head.device), pred_head], dim=1) + pred_state = torch.cat([torch.zeros(num_agent, self.num_historical_steps, *(pred_state.shape[2:]), device=pred_state.device), pred_state], dim=1) + + pred_traj[:num_init_agent, 0] = data['agent']['position'][filter_mask, 0, :2] + pred_head[:num_init_agent, 0] = data['agent']['heading'][filter_mask, 0] + pred_state[:num_init_agent, 1 : self.num_historical_steps] = data['agent']['state_idx'][filter_mask, :(self.num_historical_steps - 1) // self.shift].repeat_interleave(repeats=self.shift, dim=1) + + historical_token_idx = data['agent']['token_idx'][filter_mask, :(self.num_historical_steps - 1) // self.shift] + historical_token_idx[historical_token_idx < 0] = 0 + historical_token_traj_all = torch.gather(agent_token_traj_all, 1, historical_token_idx[..., None, None, None].expand(-1, -1, 6, 4, 2)) + init_theta = head_a[:num_init_agent, 0] + cos, sin = init_theta.cos(), init_theta.sin() + rot_mat = torch.zeros((num_init_agent, 2, 2), device=init_theta.device) + rot_mat[:, 0, 0] = cos + rot_mat[:, 0, 1] = sin + rot_mat[:, 1, 0] = -sin + rot_mat[:, 1, 1] = cos + historical_token_traj_all = torch.bmm(historical_token_traj_all.view(-1, 4, 2), + rot_mat[:, None, None, ...].repeat(1, (self.num_historical_steps - 1) // self.shift, self.shift + 1, 1, 1).view( + -1, 2, 2)).view(num_init_agent, (self.num_historical_steps - 1) // self.shift, self.shift + 1, 4, 2) + historical_token_traj_all = historical_token_traj_all + pos_a[:num_init_agent, 0, :][:, None, None, None, ...] + pred_traj[:num_init_agent, 1 : self.num_historical_steps] = historical_token_traj_all[:, :, 1:, ...].clone().mean(dim=3).reshape(num_init_agent, -1, 2) + diff_xy = historical_token_traj_all[..., 1:, 0, :] - historical_token_traj_all[..., 1:, 3, :] + pred_head[:num_init_agent, 1 : self.num_historical_steps] = torch.arctan2(diff_xy[..., 1], diff_xy[..., 0]).reshape(num_init_agent, -1) + + # ! build z and valid + pred_z = torch.zeros_like(pred_traj[..., 0]) # hard code + pred_valid = (pred_state != self.invalid_state) & (pred_state != self.enter_state) + + # ! predefined agent shape + eval_shape = torch.zeros_like(pred_shape) + eval_shape[veh_mask] = torch.tensor(AGENT_SHAPE['vehicle'], device=device)[None, ...] + eval_shape[ped_mask] = torch.tensor(AGENT_SHAPE['pedstrain'], device=device)[None, ...] + eval_shape[cyc_mask] = torch.tensor(AGENT_SHAPE['cyclist'], device=device)[None, ...] + + next_token_idx = torch.cat(next_token_idx_list, dim=-1) + next_state_idx = torch.cat(next_state_idx_list, dim=-1) if len(next_state_idx_list) > 0 else None + + # sanity check + assert torch.all(pos_a[next_state_idx == self.invalid_state] == 0), f'Invalid step should have all zeros position!' + + if self.log_message == "": + self.log_message = "No agents inserted!" 
+ else: + self.log_message += f"\nNumber of total inserted agents: {num_inserted_agents_total}/{seed_step_mask.sum()}" + + return { + 'ego_index': int(av_index), + 'agent_id': agent_id, + # 'valid_mask': agent_valid_mask[:, self.num_historical_steps:], + # 'pos_a': pos_a[:, (self.num_historical_steps - 1) // self.shift:], + # 'head_a': head_a[:, (self.num_historical_steps - 1) // self.shift:], + 'valid_mask': agent_valid_mask, # [n_agent, n_infer_step // shift] + 'pos_a': pos_a, # [n_agent, n_infer_step // shift, 2] + 'head_a': head_a, # [n_agent, n_infer_step // shift] + 'gt_traj': gt_traj, + 'pred_traj': pred_traj, # [n_agent, n_infer_step, 2] + 'pred_head': pred_head, # [n_agent, n_infer_step] + 'pred_type': pred_type, + 'pred_state': pred_state, + 'pred_z': pred_z, + 'pred_shape': pred_shape, + 'eval_shape': eval_shape, + 'pred_valid': pred_valid, + 'next_state_prob_seed': torch.cat(next_state_prob_seed_list, dim=1), + 'next_pos_rel_prob_seed': torch.cat(next_pos_rel_prob_seed_list, dim=1) + if self.use_grid_token else None, + 'next_token_idx': next_token_idx, # [n_agent, n_infer_step // shift] + 'next_state_idx': next_state_idx, # [n_agent, n_infer_step // shift] + 'grid_agent_occ_seed': torch.cat(grid_agent_occ_list, dim=1) + if self.use_grid_token else None, + 'grid_pt_occ_seed': torch.cat(grid_pt_occ_list, dim=1) + if self.use_grid_token else None, + 'grid_agent_occ_gt_seed': torch.cat(grid_agent_occ_gt_list, dim=1) + if self.use_grid_token else None, + 'agent_labels': agent_labels, + 'log_message': self.log_message, + } + + def log(self, t, next_pos_seed, ego_pos, next_head_seed, ego_head, next_shape_seed, next_type_idx_seed): + i = 0 + _repr_indent = 4 + for sa in range(next_pos_seed.shape[0]): + head = f"\n{i} agent {sa} is entering at step {t}" + body = [ + f"rel pos {(next_pos_seed[sa] - ego_pos).tolist()}, pos {next_pos_seed[sa].tolist()}", + f"rel head {wrap_angle(next_head_seed[sa] - ego_head).item()}, head {next_head_seed[sa].item()}", + f"shape {next_shape_seed[sa].tolist()}, type {next_type_idx_seed[sa].item()}", + ] + self.log_message += "\n".join([head] + [" " * _repr_indent + line for line in body]) + i += 1 diff --git a/backups/dev/modules/attr_tokenizer.py b/backups/dev/modules/attr_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..721bd446d6c2f288e774abc2a9af6c203d2253c5 --- /dev/null +++ b/backups/dev/modules/attr_tokenizer.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import numpy as np +from dev.utils.func import wrap_angle, angle_between_2d_vectors + + +class Attr_Tokenizer(nn.Module): + + def __init__(self, grid_range, grid_interval, radius, angle_interval): + super().__init__() + self.grid_range = grid_range + self.grid_interval = grid_interval + self.radius = radius + self.angle_interval = angle_interval + self.heading = torch.pi / 2 + self._prepare_grid() + + self.grid_size = self.grid.shape[0] + self.angle_size = int(360. / self.angle_interval) + + assert torch.all(self.grid[self.grid_size // 2] == 0.) 
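+    # NOTE: the grid geometry implied by the config values used elsewhere in this
+    # diff (grid_range=150., grid_interval=3., pl2seed_radius=75.): a 51 x 51
+    # lattice of 3 m cells is masked down to the cells within 75 m of the origin,
+    # and the origin ends up at index grid_size // 2, which the assert above
+    # checks. A standalone sketch of that count:
+    # num_grid = int(150. / 3.) + 1                       # 51 points per axis
+    # xy = (torch.arange(num_grid) - num_grid // 2) * 3.  # -75 ... 75
+    # gx, gy = torch.meshgrid(xy, xy, indexing='xy')
+    # grid_size = int(((gx ** 2 + gy ** 2).sqrt() <= 75.).sum())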
+ + def _prepare_grid(self): + num_grid = int(self.grid_range / self.grid_interval) + 1 # Do not use '//' + + x = torch.linspace(0, num_grid - 1, steps=num_grid) + y = torch.linspace(0, num_grid - 1, steps=num_grid) + grid_x, grid_y = torch.meshgrid(x, y, indexing='xy') + grid = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1) # (n^2, 2) + grid = grid.reshape(num_grid, num_grid, 2).flip(dims=[0]).reshape(-1, 2) + grid = (grid - x.shape[0] // 2) * self.grid_interval + + distance = (grid ** 2).sum(-1).sqrt() + square_mask = ((distance <= self.radius) & (distance >= 0.)) | (distance == 0.) + self.register_buffer('grid', grid[square_mask]) + self.register_buffer('dist', torch.norm(self.grid, p=2, dim=-1)) + head_vector = torch.stack([torch.tensor(self.heading).cos(), torch.tensor(self.heading).sin()]) + self.register_buffer('dir', angle_between_2d_vectors(ctr_vector=head_vector.unsqueeze(0), + nbr_vector=self.grid)) # (-pi, pi] + + self.num_grid = num_grid + self.square_mask = square_mask.numpy() + + def _apply_rot(self, x, theta): + # x: (b, l, 2) e.g. (num_step, num_agent, 2) + # theta: (b,) e.g. (num_step,) + cos, sin = theta.cos(), theta.sin() + rot_mat = torch.zeros((theta.shape[0], 2, 2), device=theta.device) + rot_mat[:, 0, 0] = cos + rot_mat[:, 0, 1] = sin + rot_mat[:, 1, 0] = -sin + rot_mat[:, 1, 1] = cos + x = torch.bmm(x, rot_mat) + return x + + def pad_square(self, prob, indices=None): + # square_mask: bool array of shape (n^2,) + # prob: float array of shape (num_step, m) + pad_prob = np.zeros((*prob.shape[:-1], self.square_mask.shape[0])) + pad_prob[..., self.square_mask] = prob + + square_indices = np.arange(self.square_mask.shape[0]) + circle_indices = np.concatenate([square_indices[self.square_mask], [-1]]) + if indices is not None: + indices = circle_indices[indices] + + return pad_prob, indices + + def get_grid(self, x, theta=None): + x = x.reshape(-1, 2) + grid = self.grid[None, ...].to(x.device) + if theta is not None: + grid = self._apply_rot(grid, (theta - self.heading).expand(x.shape[0])) + return x[:, None] + grid + + def encode_pos(self, x, y, theta_y=None): + assert x.dim() == y.dim() and x.shape[-1] == 2 and y.shape[-1] == 2, \ + f"Invalid input shape x: {x.shape}, y: {y.shape}." 
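+        # NOTE: round-trip sketch for encode_pos/decode_pos (tokenizer is a
+        # hypothetical Attr_Tokenizer instance; recovery is exact only up to the
+        # returned sub-cell offset, since positions snap to the nearest grid cell):
+        # idx, offset = tokenizer.encode_pos(x, y=ego_pos, theta_y=ego_head)
+        # x_cell = tokenizer.decode_pos(idx, y=ego_pos, theta_y=ego_head)
+        # # x_cell matches x up to `offset` (a fraction of one grid cell)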
+ centered_x = x - y + if theta_y is not None: + centered_x = self._apply_rot(centered_x[:, None], -(theta_y - self.heading).expand(x.shape[0]))[:, 0] + distance = ((centered_x[:, None] - self.grid.to(x.device)[None, ...]) ** 2).sum(-1).sqrt() + index = torch.argmin(distance, dim=-1) + + grid_xy = self.grid[index] + offset_xy = centered_x - grid_xy + + return index.long(), offset_xy + + def decode_pos(self, index, y=None, theta_y=None): + assert torch.all((index >= 0) & (index < self.grid_size)) + centered_x = self.grid.to(index.device)[index.long()] + if y is not None: + if theta_y is not None: + centered_x = self._apply_rot(centered_x[:, None], (theta_y - self.heading).expand(centered_x.shape[0]))[:, 0] + x = centered_x + y + return x.float() + return centered_x.float() + + def encode_heading(self, heading): + heading = (wrap_angle(heading) + torch.pi) / (2 * torch.pi) * 360 + index = heading // self.angle_interval + return index.long() + + def decode_heading(self, index): + assert torch.all(index >= 0) and torch.all(index < (360 / self.angle_interval)) + angles = index * self.angle_interval - 180 + angles = angles / 360 * (2 * torch.pi) + return angles.float() diff --git a/backups/dev/modules/debug.py b/backups/dev/modules/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..81b479fe255074dabdced79dbadae80a5ecabf0a --- /dev/null +++ b/backups/dev/modules/debug.py @@ -0,0 +1,1439 @@ +from typing import Dict, Mapping, Optional +import math +import torch +import torch.nn as nn +from torch_cluster import radius, radius_graph +from torch_geometric.data import HeteroData, Batch +from torch_geometric.utils import dense_to_sparse, subgraph + +from dev.modules.layers import * +from dev.modules.map_decoder import discretize_neighboring +from dev.utils.geometry import angle_between_2d_vectors, wrap_angle +from dev.utils.weight_init import weight_init + + +def cal_polygon_contour(x, y, theta, width, length): + left_front_x = x + 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta) + left_front_y = y + 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta) + left_front = (left_front_x, left_front_y) + + right_front_x = x + 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta) + right_front_y = y + 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta) + right_front = (right_front_x, right_front_y) + + right_back_x = x - 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta) + right_back_y = y - 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta) + right_back = (right_back_x, right_back_y) + + left_back_x = x - 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta) + left_back_y = y - 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta) + left_back = (left_back_x, left_back_y) + polygon_contour = [left_front, right_front, right_back, left_back] + + return polygon_contour + + +class SMARTAgentDecoder(nn.Module): + + def __init__(self, + dataset: str, + input_dim: int, + hidden_dim: int, + num_historical_steps: int, + num_interaction_steps: int, + time_span: Optional[int], + pl2a_radius: float, + pl2seed_radius: float, + a2a_radius: float, + num_freq_bands: int, + num_layers: int, + num_heads: int, + head_dim: int, + dropout: float, + token_data: Dict, + token_size: int, + special_token_index: list=[], + predict_motion: bool=False, + predict_state: bool=False, + predict_map: bool=False, + state_token: Dict[str, int]=None, + seed_size: int=5) -> None: + + super(SMARTAgentDecoder, self).__init__() + 
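+        # NOTE: a hedged sketch of how the Model/decoder settings appearing in the
+        # configs of this diff plausibly map onto this signature; the real call site
+        # and `token_data` live elsewhere in the repo, so the values are
+        # illustrative only:
+        # decoder = SMARTAgentDecoder(
+        #     dataset='waymo', input_dim=2, hidden_dim=128,
+        #     num_historical_steps=11, num_interaction_steps=num_interaction_steps,  # not set in the configs shown
+        #     time_span=60, pl2a_radius=30, pl2seed_radius=75., a2a_radius=60,
+        #     num_freq_bands=64, num_layers=6, num_heads=8, head_dim=16, dropout=0.1,
+        #     token_data=token_data, token_size=2048,
+        #     predict_motion=True, predict_state=True, predict_map=True,
+        #     state_token={'invalid': 0, 'valid': 1, 'enter': 2, 'exit': 3},
+        #     seed_size=1)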
self.dataset = dataset + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.num_historical_steps = num_historical_steps + self.num_interaction_steps = num_interaction_steps + self.time_span = time_span if time_span is not None else num_historical_steps + self.pl2a_radius = pl2a_radius + self.pl2seed_radius = pl2seed_radius + self.a2a_radius = a2a_radius + self.num_freq_bands = num_freq_bands + self.num_layers = num_layers + self.num_heads = num_heads + self.head_dim = head_dim + self.dropout = dropout + self.special_token_index = special_token_index + self.predict_motion = predict_motion + self.predict_state = predict_state + self.predict_map = predict_map + + # state tokens + self.state_type = list(state_token.keys()) + self.state_token = state_token + self.invalid_state = int(state_token['invalid']) + self.valid_state = int(state_token['valid']) + self.enter_state = int(state_token['enter']) + self.exit_state = int(state_token['exit']) + + self.seed_state_type = ['invalid', 'enter'] + self.valid_state_type = ['invalid', 'valid', 'exit'] + + input_dim_x_a = 2 + input_dim_r_t = 4 + input_dim_r_pt2a = 3 + input_dim_r_a2a = 3 + input_dim_token = 8 # tokens: (token_size, 4, 2) + + self.seed_size = seed_size + + self.all_agent_type = ['veh', 'ped', 'cyc', 'background', 'invalid', 'seed'] + self.seed_agent_type = ['veh', 'ped', 'cyc', 'seed'] + self.type_a_emb = nn.Embedding(len(self.all_agent_type), hidden_dim) + self.shape_emb = MLPLayer(3, hidden_dim, hidden_dim) + if self.predict_state: + self.state_a_emb = nn.Embedding(len(self.state_type), hidden_dim) + self.invalid_shape_value = .1 + self.motion_gap = 1. + self.heading_gap = 1. + + self.x_a_emb = FourierEmbedding(input_dim=input_dim_x_a, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands) + self.r_t_emb = FourierEmbedding(input_dim=input_dim_r_t, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands) + self.r_pt2a_emb = FourierEmbedding(input_dim=input_dim_r_pt2a, hidden_dim=hidden_dim, + num_freq_bands=num_freq_bands) + self.r_a2a_emb = FourierEmbedding(input_dim=input_dim_r_a2a, hidden_dim=hidden_dim, + num_freq_bands=num_freq_bands) + self.token_emb_veh = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim) + self.token_emb_ped = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim) + self.token_emb_cyc = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim) + self.no_token_emb = nn.Embedding(1, hidden_dim) + self.bos_token_emb = nn.Embedding(1, hidden_dim) + # FIXME: do we need this??? 
+ self.token_emb_offset = MLPEmbedding(input_dim=2, hidden_dim=hidden_dim) + + num_inputs = 2 + if self.predict_state: + num_inputs = 3 + self.fusion_emb = MLPEmbedding(input_dim=self.hidden_dim * num_inputs, hidden_dim=self.hidden_dim) + + self.t_attn_layers = nn.ModuleList( + [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout, + bipartite=False, has_pos_emb=True) for _ in range(num_layers)] + ) + self.pt2a_attn_layers = nn.ModuleList( + [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout, + bipartite=True, has_pos_emb=True) for _ in range(num_layers)] + ) + self.a2a_attn_layers = nn.ModuleList( + [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout, + bipartite=False, has_pos_emb=True) for _ in range(num_layers)] + ) + self.token_size = token_size # 2048 + # agent motion prediction head + self.token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, + output_dim=self.token_size) + # agent state prediction head + if self.predict_state: + + self.seed_feature = nn.Embedding(self.seed_size, self.hidden_dim) + + self.state_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, + output_dim=len(self.valid_state_type)) + + self.seed_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, + output_dim=hidden_dim) + + self.seed_state_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, + output_dim=len(self.seed_state_type)) + self.seed_type_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, + output_dim=len(self.seed_agent_type)) + # entering token prediction + # FIXME: this is just under test!!! + # self.bos_pl_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, + # output_dim=200) + # self.bos_offset_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, + # output_dim=2601) + self.trajectory_token = token_data['token'] # dict('veh', 'ped', 'cyc') (2048, 4, 2) + self.trajectory_token_traj = token_data['traj'] # (2048, 6, 3) + self.trajectory_token_all = token_data['token_all'] # (2048, 6, 4, 2) + self.apply(weight_init) + + self.shift = 5 + self.beam_size = 5 + self.hist_mask = True + self.temporal_attn_to_invalid = True + self.temporal_attn_seed = False + + # FIXME: This is just under test!!! 
+ # self.mapping_network = MappingNetwork(z_dim=hidden_dim, w_dim=hidden_dim, num_layers=num_layers) + + def transform_rel(self, token_traj, prev_pos, prev_heading=None): + if prev_heading is None: + diff_xy = prev_pos[:, :, -1, :] - prev_pos[:, :, -2, :] + prev_heading = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0]) + + num_agent, num_step, traj_num, traj_dim = token_traj.shape + cos, sin = prev_heading.cos(), prev_heading.sin() + rot_mat = torch.zeros((num_agent, num_step, 2, 2), device=prev_heading.device) + rot_mat[:, :, 0, 0] = cos + rot_mat[:, :, 0, 1] = -sin + rot_mat[:, :, 1, 0] = sin + rot_mat[:, :, 1, 1] = cos + agent_diff_rel = torch.bmm(token_traj.view(-1, traj_num, 2), rot_mat.view(-1, 2, 2)).view(num_agent, num_step, traj_num, traj_dim) + agent_pred_rel = agent_diff_rel + prev_pos[:, :, -1:, :] + return agent_pred_rel + + def agent_token_embedding(self, data, agent_token_index, agent_state, pos_a, head_a, inference=False, + filter_mask=None, av_index=None): + + if filter_mask is None: + filter_mask = torch.ones_like(agent_state[:, 2], dtype=torch.bool) + + num_agent, num_step, traj_dim = pos_a.shape # traj_dim=2 + agent_type = data['agent']['type'][filter_mask] + veh_mask = (agent_type == 0) + ped_mask = (agent_type == 1) + cyc_mask = (agent_type == 2) + + # set the position of invalid agents to the position of ego agent + # note here we only set invalid steps BEFORE the bos token! + # is_invalid = agent_state == self.invalid_state + # is_bos = agent_state == self.enter_state + # bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0)) + # bos_mask = torch.arange(num_step).expand(num_agent, -1).to(agent_state.device) < bos_index[:, None] + # is_invalid[~bos_mask] = False + + # ego_pos_a = pos_a[av_index].clone() + # ego_head_vector_a = head_vector_a[av_index].clone() + # pos_a[is_invalid] = ego_pos_a[None, :].repeat(pos_a.shape[0], 1, 1)[is_invalid] + # head_vector_a[is_invalid] = ego_head_vector_a[None, :].repeat(head_vector_a.shape[0], 1, 1)[is_invalid] + + motion_vector_a, head_vector_a = self.build_vector_a(pos_a, head_a, agent_state) + + trajectory_token_veh = torch.from_numpy(self.trajectory_token['veh']).clone().to(pos_a.device).to(torch.float) + trajectory_token_ped = torch.from_numpy(self.trajectory_token['ped']).clone().to(pos_a.device).to(torch.float) + trajectory_token_cyc = torch.from_numpy(self.trajectory_token['cyc']).clone().to(pos_a.device).to(torch.float) + self.agent_token_emb_veh = self.token_emb_veh(trajectory_token_veh.view(trajectory_token_veh.shape[0], -1)) # (token_size, 8) + self.agent_token_emb_ped = self.token_emb_ped(trajectory_token_ped.view(trajectory_token_ped.shape[0], -1)) + self.agent_token_emb_cyc = self.token_emb_cyc(trajectory_token_cyc.view(trajectory_token_cyc.shape[0], -1)) + + # add bos token embedding + self.agent_token_emb_veh = torch.cat([self.agent_token_emb_veh, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())]) + self.agent_token_emb_ped = torch.cat([self.agent_token_emb_ped, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())]) + self.agent_token_emb_cyc = torch.cat([self.agent_token_emb_cyc, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())]) + + # add invalid token embedding + self.agent_token_emb_veh = torch.cat([self.agent_token_emb_veh, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())]) + self.agent_token_emb_ped = torch.cat([self.agent_token_emb_ped, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())]) + 
self.agent_token_emb_cyc = torch.cat([self.agent_token_emb_cyc, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())]) + + if inference: + agent_token_traj_all = torch.zeros((num_agent, self.token_size, self.shift + 1, 4, 2), device=pos_a.device) + trajectory_token_all_veh = torch.from_numpy(self.trajectory_token_all['veh']).clone().to(pos_a.device).to(torch.float) + trajectory_token_all_ped = torch.from_numpy(self.trajectory_token_all['ped']).clone().to(pos_a.device).to(torch.float) + trajectory_token_all_cyc = torch.from_numpy(self.trajectory_token_all['cyc']).clone().to(pos_a.device).to(torch.float) + agent_token_traj_all[veh_mask] = torch.cat( + [trajectory_token_all_veh[:, :self.shift], trajectory_token_veh[:, None, ...]], dim=1) + agent_token_traj_all[ped_mask] = torch.cat( + [trajectory_token_all_ped[:, :self.shift], trajectory_token_ped[:, None, ...]], dim=1) + agent_token_traj_all[cyc_mask] = torch.cat( + [trajectory_token_all_cyc[:, :self.shift], trajectory_token_cyc[:, None, ...]], dim=1) + + # additional token embeddings are already added -> -1: invalid, -2: bos + agent_token_emb = torch.zeros((num_agent, num_step, self.hidden_dim), device=pos_a.device) + agent_token_emb[veh_mask] = self.agent_token_emb_veh[agent_token_index[veh_mask]] + agent_token_emb[ped_mask] = self.agent_token_emb_ped[agent_token_index[ped_mask]] + agent_token_emb[cyc_mask] = self.agent_token_emb_cyc[agent_token_index[cyc_mask]] + + # 'vehicle', 'pedestrian', 'cyclist', 'background' + is_invalid = (agent_state == self.invalid_state) & (agent_state != self.enter_state) + agent_types = data['agent']['type'][filter_mask].long().repeat_interleave(repeats=num_step, dim=0) + agent_types[is_invalid.reshape(-1)] = self.all_agent_type.index('invalid') + agent_shapes = data['agent']['shape'][filter_mask, self.num_historical_steps - 1, :].repeat_interleave(repeats=num_step, dim=0) + agent_shapes[is_invalid.reshape(-1)] = self.invalid_shape_value + + categorical_embs = [self.type_a_emb(agent_types), self.shape_emb(agent_shapes)] + feature_a = torch.stack( + [torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1), + angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2]), + ], dim=-1) # (num_agent, num_shifted_step, 2) + + x_a = self.x_a_emb(continuous_inputs=feature_a.view(-1, feature_a.size(-1)), + categorical_embs=categorical_embs) + x_a = x_a.view(-1, num_step, self.hidden_dim) # (num_agent, num_step, hidden_dim) + + s_a = self.state_a_emb(agent_state.reshape(-1).long()).reshape(num_agent, num_step, self.hidden_dim) + feat_a = torch.cat((agent_token_emb, x_a, s_a), dim=-1) # (num_agent, num_step, hidden_dim * 3) + feat_a = self.fusion_emb(feat_a) # (num_agent, num_step, hidden_dim) + + # seed agent feature + motion_vector_seed = motion_vector_a[av_index : av_index + 1] + head_vector_seed = head_vector_a[av_index : av_index + 1] + feat_seed = self.build_invalid_agent_feature(num_step, pos_a.device, type_index=self.all_agent_type.index('seed'), + motion_vector=motion_vector_seed, head_vector=head_vector_seed) + + # replace the features of steps before bos of valid agents with the corresponding invalid agent features + # is_bos = agent_state == self.enter_state + # is_eos = agent_state == self.exit_state + # bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0)) + # eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_step - 1)) + # is_before_bos = torch.arange(num_step).expand(num_agent, 
-1).to(agent_state.device) < bos_index[:, None] + # is_after_eos = torch.arange(num_step).expand(num_agent, -1).to(agent_state.device) > eos_index[:, None] + 1 + # feat_ina = self.build_invalid_agent_feature(num_step, pos_a.device) + # feat_a[is_before_bos | is_after_eos] = feat_ina.repeat(num_agent, 1, 1)[is_before_bos | is_after_eos] + + # print("train") + # is_bos = agent_state == self.enter_state + # is_eos = agent_state == self.exit_state + # bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0)) + # eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_step - 1)) + # mask = torch.arange(num_step).expand(num_agent, -1).to(agent_state.device) + # mask = (mask >= bos_index[:, None]) & (mask <= eos_index[:, None] + 1) + # is_invalid[mask] = False + # print(feat_a.sum(dim=-1)[is_invalid]) + + feat_a = torch.cat([feat_a, feat_seed], dim=0) # (num_agent + 1, num_step, hidden_dim) + + # feat_a_sum = feat_a.sum(dim=-1) + # for a in range(num_agent): + # print(f"agent {a}:") + # print(f"state: {agent_state[a, :]}") + # print(f"feat_a_sum: {feat_a_sum[a, :]}") + # exit(1) + + if inference: + return feat_a, head_vector_a, agent_token_traj_all, agent_token_emb, categorical_embs + else: + return feat_a, head_vector_a + + def build_vector_a(self, pos_a, head_a, state_a): + num_agent = pos_a.shape[0] + + motion_vector_a = torch.cat([pos_a.new_zeros(num_agent, 1, self.input_dim), + pos_a[:, 1:] - pos_a[:, :-1]], dim=1) + + # update the relative motion/head vectors + is_bos = state_a == self.enter_state + motion_vector_a[is_bos] = self.motion_gap + + is_last_eos = state_a.roll(shifts=1, dims=1) == self.exit_state + is_last_eos[:, 0] = False + motion_vector_a[is_last_eos] = -self.motion_gap + + head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1) + + return motion_vector_a, head_vector_a + + def build_invalid_agent_feature(self, num_step, device, motion_vector=None, head_vector=None, type_index=None, shape_value=None): + invalid_agent_token_emb = self.no_token_emb(torch.zeros(1, device=device).long())[:, None].repeat(1, num_step, 1) + + if motion_vector is None or head_vector is None: + motion_vector = torch.zeros((1, num_step, 2), device=device) + head_vector = torch.stack([torch.cos(torch.zeros(1, device=device)), torch.sin(torch.zeros(1, device=device))], dim=-1)[:, None, :].repeat(1, num_step, 1) + + feature_ina = torch.stack( + [torch.norm(motion_vector[:, :, :2], p=2, dim=-1), + angle_between_2d_vectors(ctr_vector=head_vector, nbr_vector=motion_vector[:, :, :2]), + ], dim=-1) + + if type_index is None: + type_index = self.all_agent_type.index('invalid') + if shape_value is None: + shape_value = torch.full((1, 3), self.invalid_shape_value, device=device) + + categorical_embs_ina = [self.type_a_emb(torch.tensor([type_index], device=device)), + self.shape_emb(shape_value)] + x_ina = self.x_a_emb(continuous_inputs=feature_ina.view(-1, feature_ina.size(-1)), + categorical_embs=categorical_embs_ina) + x_ina = x_ina.view(-1, num_step, self.hidden_dim) # (1, num_step, hidden_dim) + + s_ina = self.state_a_emb(torch.tensor([self.invalid_state], device=device))[:, None].repeat(1, num_step, 1) # NOTE: do not use `expand` + + feat_ina = torch.cat((invalid_agent_token_emb, x_ina, s_ina), dim=-1) + feat_ina = self.fusion_emb(feat_ina) # (1, num_step, hidden_dim) + + return feat_ina + + def build_temporal_edge(self, pos_a, head_a, head_vector_a, state_a, mask, inference_mask=None, av_index=None): + + num_agent = 
pos_a.shape[0]
+        hist_mask = mask.clone()
+
+        if not self.temporal_attn_to_invalid:
+            hist_mask[state_a == self.invalid_state] = False
+
+        # set the position of invalid agents to the position of ego agent
+        ego_pos_a = pos_a[av_index].clone()  # (num_step, 2)
+        ego_head_a = head_a[av_index].clone()
+        ego_head_vector_a = head_vector_a[av_index].clone()
+        ego_state_a = state_a[av_index].clone()
+        # is_invalid = state_a == self.invalid_state
+        # pos_a[is_invalid] = ego_pos_a[None, :].repeat(pos_a.shape[0], 1, 1)[is_invalid]
+        # head_a[is_invalid] = ego_head_a[None, :].repeat(head_a.shape[0], 1)[is_invalid]
+
+        # add seed agent
+        pos_a = torch.cat([pos_a, ego_pos_a[None]], dim=0)
+        head_a = torch.cat([head_a, ego_head_a[None]], dim=0)
+        state_a = torch.cat([state_a, ego_state_a[None]], dim=0)
+        head_vector_a = torch.cat([head_vector_a, ego_head_vector_a[None]], dim=0)
+        hist_mask = torch.cat([hist_mask, torch.ones_like(hist_mask[0:1])], dim=0).bool()
+        if not self.temporal_attn_seed:
+            hist_mask[-1:] = False
+            if inference_mask is not None:
+                inference_mask[-1:] = False
+
+        pos_t = pos_a.reshape(-1, self.input_dim)  # (num_agent * num_step, ...)
+        head_t = head_a.reshape(-1)
+        head_vector_t = head_vector_a.reshape(-1, 2)
+
+        # invalid agents won't predict any motion token, so we don't attend to them
+        is_bos = state_a == self.enter_state
+        is_bos[-1] = False
+        bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
+        motion_predict_start_index = torch.clamp(bos_index - self.time_span / self.shift + 1, min=0)
+        motion_predict_mask = torch.arange(hist_mask.shape[1]).expand(hist_mask.shape[0], -1).to(hist_mask.device)
+        motion_predict_mask = motion_predict_mask >= motion_predict_start_index[:, None]
+        hist_mask[~motion_predict_mask] = False
+
+        if self.hist_mask and self.training:
+            hist_mask[
+                torch.arange(mask.shape[0]).unsqueeze(1), torch.randint(0, mask.shape[1], (num_agent, 10))] = False
+            mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)
+        elif inference_mask is not None:
+            mask_t = hist_mask.unsqueeze(2) & inference_mask.unsqueeze(1)
+        else:
+            mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)
+
+        # mask_t: (num_agent + 1, num_step, num_step), edge_index_t: (2, num_edge)
+        edge_index_t = dense_to_sparse(mask_t)[0]
+        edge_index_t = edge_index_t[:, (edge_index_t[1] - edge_index_t[0] > 0) &
+                                       (edge_index_t[1] - edge_index_t[0] <= self.time_span / self.shift)]
+        rel_pos_t = pos_t[edge_index_t[0]] - pos_t[edge_index_t[1]]
+        rel_head_t = wrap_angle(head_t[edge_index_t[0]] - head_t[edge_index_t[1]])
+
+        # FIXME relative motion/head for bos/eos token
+        # is_next_bos = state_a.roll(shifts=-1, dims=1) == self.enter_state
+        # is_next_bos[:, -1] = False  # the last step
+        # is_next_bos_t = is_next_bos.reshape(-1)
+        # rel_pos_t[is_next_bos_t[edge_index_t[0]]] = -self.bos_motion
+        # rel_pos_t[is_next_bos_t[edge_index_t[1]]] = self.bos_motion
+        # rel_head_t[is_next_bos_t[edge_index_t[0]]] = -torch.pi
+        # rel_head_t[is_next_bos_t[edge_index_t[1]]] = torch.pi
+
+        # is_last_eos = state_a.roll(shifts=1, dims=1) == self.exit_state
+        # is_last_eos[:, 0] = False  # the first step
+        # is_last_eos_t = is_last_eos.reshape(-1)
+        # rel_pos_t[is_last_eos_t[edge_index_t[0]]] = -self.bos_motion
+        # rel_pos_t[is_last_eos_t[edge_index_t[1]]] = self.bos_motion
+        # rel_head_t[is_last_eos_t[edge_index_t[0]]] = -torch.pi
+        # rel_head_t[is_last_eos_t[edge_index_t[1]]] = torch.pi
+
+        # handle the bos token of ego agent
+        # is_invalid = state_a == self.invalid_state
+        # is_invalid_t =
is_invalid.reshape(-1) + # is_ego_bos = (ego_state_a == self.enter_state)[None, :].expand(num_agent + 1, -1) + # is_ego_bos_t = is_ego_bos.reshape(-1) + # rel_pos_t[is_invalid_t[edge_index_t[0]] & is_ego_bos_t[edge_index_t[0]]] = 0. + # rel_pos_t[is_invalid_t[edge_index_t[1]] & is_ego_bos_t[edge_index_t[1]]] = 0. + # rel_head_t[is_invalid_t[edge_index_t[0]] & is_ego_bos_t[edge_index_t[0]]] = 0. + # rel_head_t[is_invalid_t[edge_index_t[1]] & is_ego_bos_t[edge_index_t[1]]] = 0. + + # handle the invalid steps + is_invalid = state_a == self.invalid_state + is_invalid_t = is_invalid.reshape(-1) + rel_pos_t[is_invalid_t[edge_index_t[0]]] = -self.motion_gap + rel_pos_t[is_invalid_t[edge_index_t[1]]] = self.motion_gap + rel_head_t[is_invalid_t[edge_index_t[0]]] = -self.heading_gap + rel_head_t[is_invalid_t[edge_index_t[1]]] = self.heading_gap + + r_t = torch.stack( + [torch.norm(rel_pos_t[:, :2], p=2, dim=-1), + angle_between_2d_vectors(ctr_vector=head_vector_t[edge_index_t[1]], nbr_vector=rel_pos_t[:, :2]), + rel_head_t, + edge_index_t[0] - edge_index_t[1]], dim=-1) + r_t = self.r_t_emb(continuous_inputs=r_t, categorical_embs=None) + + return edge_index_t, r_t + + def build_interaction_edge(self, pos_a, head_a, head_vector_a, state_a, batch_s, mask_a, inference_mask=None, av_index=None): + num_agent, num_step, _ = pos_a.shape + + pos_a = torch.cat([pos_a, pos_a[av_index][None]], dim=0) + head_a = torch.cat([head_a, head_a[av_index][None]], dim=0) + state_a = torch.cat([state_a, state_a[av_index][None]], dim=0) + head_vector_a = torch.cat([head_vector_a, head_vector_a[av_index][None]], dim=0) + + pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim) + head_s = head_a.transpose(0, 1).reshape(-1) + head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2) + if inference_mask is not None: + mask_a = mask_a & inference_mask + mask_s = mask_a.transpose(0, 1).reshape(-1) + + # seed agent + mask_seed = state_a[av_index] != self.invalid_state + pos_seed = pos_a[av_index] + edge_index_seed2a = radius(x=pos_seed[:, :2], y=pos_s[:, :2], r=self.pl2seed_radius, + batch_x=torch.arange(num_step).to(pos_s.device), batch_y=batch_s, max_num_neighbors=300) + edge_index_seed2a = edge_index_seed2a[:, mask_s[edge_index_seed2a[0]] & mask_seed[edge_index_seed2a[1]]] + + # convert to global index (must be unilateral connection) + edge_index_seed2a[1, :] = (edge_index_seed2a[1, :] + 1) * (num_agent + 1) - 1 + + # build agent2agent bilateral connection + edge_index_a2a = radius_graph(x=pos_s[:, :2], r=self.a2a_radius, batch=batch_s, loop=False, + max_num_neighbors=300) + edge_index_a2a = subgraph(subset=mask_s, edge_index=edge_index_a2a)[0] + + # add the edges which connect seed agents + edge_index_a2a = torch.cat([edge_index_a2a, edge_index_seed2a], dim=-1) + + # set the position of invalid agents to the position of ego agent + # ego_pos_a = pos_a[av_index].clone() + # ego_head_a = head_a[av_index].clone() + # is_invalid = state_a == self.invalid_state + # pos_a[is_invalid] = ego_pos_a[None, :].repeat(pos_a.shape[0], 1, 1)[is_invalid] + # head_a[is_invalid] = ego_head_a[None, :].repeat(head_a.shape[0], 1)[is_invalid] + + rel_pos_a2a = pos_s[edge_index_a2a[0]] - pos_s[edge_index_a2a[1]] + rel_head_a2a = wrap_angle(head_s[edge_index_a2a[0]] - head_s[edge_index_a2a[1]]) + + # relative motion/head for bos/eos token + # is_bos = state_a == self.enter_state + # is_bos_s = is_bos.transpose(0, 1).reshape(-1) + # rel_pos_a2a[is_bos_s[edge_index_a2a[0]]] = -self.bos_motion + # rel_pos_a2a[is_bos_s[edge_index_a2a[1]]] = 
self.bos_motion + # rel_head_a2a[is_bos_s[edge_index_a2a[0]]] = -torch.pi + # rel_head_a2a[is_bos_s[edge_index_a2a[1]]] = torch.pi + + # is_last_eos = state_a.roll(shifts=-1, dims=1) == self.exit_state + # is_last_eos[:, 0] = False # first step + # is_last_eos_s = is_last_eos.transpose(0, 1).reshape(-1) + # rel_pos_a2a[is_last_eos_s[edge_index_a2a[0]]] = -self.bos_motion + # rel_pos_a2a[is_last_eos_s[edge_index_a2a[1]]] = self.bos_motion + # rel_head_a2a[is_last_eos_s[edge_index_a2a[0]]] = -torch.pi + # rel_head_a2a[is_last_eos_s[edge_index_a2a[1]]] = torch.pi + + # handle the bos token of ego agent + # is_invalid = state_a == self.invalid_state + # is_invalid_s = is_invalid.transpose(0, 1).reshape(-1) + # is_ego_bos = (state_a[av_index] == self.enter_state)[None, :].expand(num_agent + 1, -1) + # is_ego_bos_s = is_ego_bos.transpose(0, 1).reshape(-1) + # rel_pos_a2a[is_invalid_s[edge_index_a2a[0]] & is_ego_bos_s[edge_index_a2a[0]]] = 0. + # rel_pos_a2a[is_invalid_s[edge_index_a2a[1]] & is_ego_bos_s[edge_index_a2a[1]]] = 0. + # rel_head_a2a[is_invalid_s[edge_index_a2a[0]] & is_ego_bos_s[edge_index_a2a[0]]] = 0. + # rel_head_a2a[is_invalid_s[edge_index_a2a[1]] & is_ego_bos_s[edge_index_a2a[1]]] = 0. + + # handle the invalid steps + is_invalid = state_a == self.invalid_state + is_invalid_s = is_invalid.transpose(0, 1).reshape(-1) + rel_pos_a2a[is_invalid_s[edge_index_a2a[0]]] = -self.motion_gap + rel_pos_a2a[is_invalid_s[edge_index_a2a[1]]] = self.motion_gap + rel_head_a2a[is_invalid_s[edge_index_a2a[0]]] = -self.heading_gap + rel_head_a2a[is_invalid_s[edge_index_a2a[1]]] = self.heading_gap + + r_a2a = torch.stack( + [torch.norm(rel_pos_a2a[:, :2], p=2, dim=-1), + angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_a2a[1]], nbr_vector=rel_pos_a2a[:, :2]), + rel_head_a2a], dim=-1) + r_a2a = self.r_a2a_emb(continuous_inputs=r_a2a, categorical_embs=None) + + return edge_index_a2a, r_a2a + + def build_map2agent_edge(self, data, num_step, pos_a, head_a, head_vector_a, state_a, batch_s, batch_pl, + mask, inference_mask=None, av_index=None): + + num_agent, num_step, _ = pos_a.shape + + mask_pl2a = mask.clone() + if inference_mask is not None: + mask_pl2a = mask_pl2a & inference_mask + mask_pl2a = mask_pl2a.transpose(0, 1).reshape(-1) + + pos_a = torch.cat([pos_a, pos_a[av_index][None]], dim=0) + state_a = torch.cat([state_a, state_a[av_index][None]], dim=0) + head_a = torch.cat([head_a, head_a[av_index][None]], dim=0) + head_vector_a = torch.cat([head_vector_a, head_vector_a[av_index][None]], dim=0) + + pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim) + head_s = head_a.transpose(0, 1).reshape(-1) + head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2) + + ori_pos_pl = data['pt_token']['position'][:, :self.input_dim].contiguous() + ori_orient_pl = data['pt_token']['orientation'].contiguous() + pos_pl = ori_pos_pl.repeat(num_step, 1) # not `repeat_interleave` + orient_pl = ori_orient_pl.repeat(num_step) + + # seed agent + mask_seed = state_a[av_index] != self.invalid_state + pos_seed = pos_a[av_index] + edge_index_pl2seed = radius(x=pos_seed[:, :2], y=pos_pl[:, :2], r=self.pl2seed_radius, + batch_x=torch.arange(num_step).to(pos_s.device), batch_y=batch_pl, max_num_neighbors=600) + edge_index_pl2seed = edge_index_pl2seed[:, mask_seed[edge_index_pl2seed[1]]] + + # convert to global index + edge_index_pl2seed[1, :] = (edge_index_pl2seed[1, :] + 1) * (num_agent + 1) - 1 + + # build map2agent directed graph + # edge_index_pl2a[0]: pl token; edge_index_pl2a[1]: agent token + 
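+        # (torch_cluster.radius(x, y, ...) finds, for each element of y, its
+        # neighbours in x; row 0 of the returned edge index refers to y and
+        # row 1 to x, which is why pl tokens land in row 0 below)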
edge_index_pl2a = radius(x=pos_s[:, :2], y=pos_pl[:, :2], r=self.pl2a_radius, + batch_x=batch_s, batch_y=batch_pl, max_num_neighbors=300) + # We force invalid agents to interact with **all** (visible in current window) map tokens + # invalid_node_index_a = torch.where(bos_state_s.bool())[0] + # sampled_node_index_m = torch.arange(ori_pos_pl.shape[0]).to(pos_pl.device) + # if kwargs.get('sample_pt_indices', None) is not None: + # sampled_node_index_m = sampled_node_index_m[kwargs['sample_pt_indices'].long()] + # grid_a, grid_b = torch.meshgrid(sampled_node_index_m, invalid_node_index_a, indexing='ij') + # invalid_edge_index_pl2a = torch.stack([grid_a.reshape(-1), grid_b.reshape(-1)], dim=0) + # edge_index_pl2a = torch.concat([edge_index_pl2a, invalid_edge_index_pl2a], dim=-1) + # remove the edges which connect with motion-invalid agents + edge_index_pl2a = edge_index_pl2a[:, mask_pl2a[edge_index_pl2a[1]]] + + # add the edges which connect seed agents with map tokens + edge_index_pl2a = torch.cat([edge_index_pl2a, edge_index_pl2seed], dim=-1) + + # set the position of invalid agents to the position of ego agent + # ego_pos_a = pos_a[av_index].clone() + # ego_head_a = head_a[av_index].clone() + # is_invalid = state_a == self.invalid_state + # pos_a[is_invalid] = ego_pos_a[None, :].repeat(pos_a.shape[0], 1, 1)[is_invalid] + # head_a[is_invalid] = ego_head_a[None, :].repeat(head_a.shape[0], 1)[is_invalid] + + rel_pos_pl2a = pos_pl[edge_index_pl2a[0]] - pos_s[edge_index_pl2a[1]] + rel_orient_pl2a = wrap_angle(orient_pl[edge_index_pl2a[0]] - head_s[edge_index_pl2a[1]]) + + # handle the invalid steps + is_invalid = state_a == self.invalid_state + is_invalid_s = is_invalid.transpose(0, 1).reshape(-1) + rel_pos_pl2a[is_invalid_s[edge_index_pl2a[1]]] = self.motion_gap + rel_orient_pl2a[is_invalid_s[edge_index_pl2a[1]]] = self.heading_gap + + r_pl2a = torch.stack( + [torch.norm(rel_pos_pl2a[:, :2], p=2, dim=-1), + angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_pl2a[1]], nbr_vector=rel_pos_pl2a[:, :2]), + rel_orient_pl2a], dim=-1) + r_pl2a = self.r_pt2a_emb(continuous_inputs=r_pl2a, categorical_embs=None) + + return edge_index_pl2a, r_pl2a + + def get_inputs(self, data: HeteroData) -> Dict[str, torch.Tensor]: + + pos_a = data['agent']['token_pos'] + head_a = data['agent']['token_heading'] + agent_category = data['agent']['category'] + agent_token_index = data['agent']['token_idx'] + agent_state_index = data['agent']['state_idx'] + mask = data['agent']['raw_agent_valid_mask'].clone() + # mask[agent_category != 3] = False + + if not self.predict_state: + agent_state_index = None + + next_token_index_gt = agent_token_index.roll(shifts=-1, dims=1) + next_state_index_gt = agent_state_index.roll(shifts=-1, dims=1) + + if self.predict_state: + next_token_eval_mask = mask.clone() + next_token_eval_mask = next_token_eval_mask * next_token_eval_mask.roll(shifts=1, dims=1) + bos_token_index = torch.nonzero(agent_state_index == 2) + eos_token_index = torch.nonzero(agent_state_index == 3) + next_token_eval_mask[bos_token_index[:, 0], bos_token_index[:, 1]] = 1 + for eos_token_index_ in eos_token_index: + if not next_token_eval_mask[eos_token_index_[0], eos_token_index_[1]]: + next_token_eval_mask[eos_token_index_[0], eos_token_index_[1]:] = 0 + next_token_eval_mask = next_token_eval_mask.roll(shifts=-1, dims=1) + # TODO: next_state_eval_mask !!! 
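+        # In words, the mask built above: a step is supervised only when it
+        # and its predecessor are valid; bos steps are force-enabled; once an
+        # eos step is itself masked out, every later step of that agent is
+        # dropped; the final roll(-1) aligns the mask with next-token targets.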
+ + if next_token_index_gt[next_token_eval_mask].min() < 0: + raise RuntimeError() + + next_token_eval_mask[:, -1] = False + + return {'token_pos': pos_a, + 'token_heading': head_a, + 'agent_category': agent_category, + 'next_token_idx_gt': next_token_index_gt, + 'next_state_idx_gt': next_state_index_gt, + 'next_token_eval_mask': next_token_eval_mask, + 'raw_agent_valid_mask': data['agent']['raw_agent_valid_mask'], + } + + def forward(self, + data: HeteroData, + map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + + pos_a = data['agent']['token_pos'].clone() # (num_agent, num_shifted_step, 2) + head_a = data['agent']['token_heading'].clone() # (num_agent, num_shifted_step) + num_agent, num_step, traj_dim = pos_a.shape # e.g. (50, 18, 2) + agent_category = data['agent']['category'].clone() # (num_agent,) + agent_token_index = data['agent']['token_idx'].clone() # (num_agent, num_step) + agent_state_index = data['agent']['state_idx'].clone() # (num_agent, num_step) + agent_type_index = data['agent']['type'].clone() # (num_agent, num_step) + agent_enter_pl_token_idx = None + agent_enter_offset_token_idx = None + + device = pos_a.device + + seed_step_mask = agent_state_index[:, 1:] == self.enter_state + if torch.any(seed_step_mask.sum(dim=0) > self.seed_size): + print(agent_state_index) + print(agent_state_index.shape) + print(seed_step_mask.long()) + print(seed_step_mask.sum(dim=0)) + raise RuntimeError(f"Seed size {self.seed_size} is too small.") + + # fix pos and head of invalid agents + av_index = int(data['agent']['av_index']) + # ego_pos_a = pos_a[av_index].clone() # (num_shifted_step, 2) + # ego_head_vector_a = head_vector_a[av_index] # (num_shifted_step, 2) + # is_invalid = agent_state_index == self.invalid_state + # pos_a[is_invalid] = ego_pos_a[None, :].expand(pos_a.shape[0], -1, -1)[is_invalid] + # head_vector_a[is_invalid] = ego_head_vector_a[None, :].expand(head_vector_a.shape[0], -1, -1)[is_invalid] + + if not self.predict_state: + agent_state_index = None + + feat_a, head_vector_a = self.agent_token_embedding(data, agent_token_index, agent_state_index, pos_a, head_a, av_index=av_index) + + # build masks + mask = data['agent']['raw_agent_valid_mask'].clone() + temporal_mask = mask.clone() + interact_mask = mask.clone() + if self.predict_state: + + agent_enter_offset_token_idx = data['agent']['neighbor_token_idx'] + agent_enter_pl_token_idx = data['agent']['map_bos_token_idx'] + agent_enter_pl_token_id = data['agent']['map_bos_token_id'] + + is_bos = agent_state_index == self.enter_state + is_eos = agent_state_index == self.exit_state + bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0)) + eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_step - 1)) # not `-1` + + temporal_mask = torch.ones_like(mask) + motion_mask = torch.arange(mask.shape[1]).expand(mask.shape[0], -1).to(device) + motion_mask = (motion_mask > bos_index[:, None]) & (motion_mask <= eos_index[:, None]) + temporal_mask[motion_mask] = mask[motion_mask] + + interact_mask[agent_state_index == self.enter_state] = True + interact_mask = torch.cat([interact_mask, torch.ones_like(interact_mask[:1])]).bool() # placeholder + + edge_index_t, r_t = self.build_temporal_edge(pos_a, head_a, head_vector_a, agent_state_index, temporal_mask, + av_index=av_index) + + # +1: placeholder for seed agent + # if isinstance(data, Batch): + # print(data['agent']['batch'], data.num_graphs) + # batch_s = torch.cat([data['agent']['batch'] + 
data.num_graphs * t for t in range(num_step)], dim=0) + # batch_pl = torch.cat([data['pt_token']['batch'] + data.num_graphs * t for t in range(num_step)], dim=0) + # else: + batch_s = torch.arange(num_step, device=device).repeat_interleave(data['agent']['num_nodes'] + 1) + batch_pl = torch.arange(num_step, device=device).repeat_interleave(data['pt_token']['num_nodes']) + + edge_index_a2a, r_a2a = self.build_interaction_edge(pos_a, head_a, head_vector_a, agent_state_index, batch_s, + interact_mask, av_index=av_index) + + agent_category = torch.cat([agent_category, torch.full(agent_category[-1:].shape, 3, device=device)]) + interact_mask[agent_category != 3] = False + edge_index_pl2a, r_pl2a = self.build_map2agent_edge(data, num_step, pos_a, head_a, head_vector_a, + agent_state_index, batch_s, batch_pl, interact_mask, av_index=av_index) + + # mapping network + # z = torch.randn(num_agent, self.hidden_dim).to(feat_a.device) + # w = self.mapping_network(z) + + for i in range(self.num_layers): + + # feat_a = feat_a + w[:, None] + + feat_a = feat_a.reshape(-1, self.hidden_dim) # (num_agent, num_step, hidden_dim) -> (seq_len, hidden_dim) + feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t) + + feat_a = feat_a.reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim) + feat_a = self.pt2a_attn_layers[i](( + map_enc['x_pt'].repeat_interleave(repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape( + -1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a) + + feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a) + feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1) + + # next motion token + next_token_prob = self.token_predict_head(feat_a[:-1]) # (num_agent, num_step, token_size) + next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1) + _, next_token_idx = torch.topk(next_token_prob_softmax, k=10, dim=-1) # (num_agent, num_step, 10) + + next_token_index_gt = agent_token_index.roll(shifts=-1, dims=1) + + # next state token + next_state_prob = self.state_predict_head(feat_a[:-1]) + next_state_idx = next_state_prob.softmax(dim=-1).argmax(dim=-1, keepdim=True) # (num_agent, num_step, 1) + + next_state_index_gt = agent_state_index.roll(shifts=-1, dims=1) # (invalid, valid, exit) + + # seed agent + feat_seed = self.seed_head(feat_a[-1:]) + self.seed_feature.weight[:, None] + next_state_prob_seed = self.seed_state_predict_head(feat_seed) + next_state_idx_seed = next_state_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True) # (self.seed_size, num_step, 1) + + next_type_prob_seed = self.seed_type_predict_head(feat_seed) + next_type_idx_seed = next_type_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True) + + next_type_index_gt = agent_type_index[:, None].expand(-1, num_step).roll(shifts=-1, dims=1) + + # polygon token for bos token + # next_bos_pl_prob = self.bos_pl_predict_head(feat_a) + # next_bos_pl_prob_softmax = torch.softmax(next_bos_pl_prob, dim=-1) + # _, next_bos_pl_idx = torch.topk(next_bos_pl_prob_softmax, k=1, dim=-1) # (num_agent, num_step, 1) + + # next_bos_pl_index_gt = agent_enter_pl_token_id.roll(shifts=-1, dims=-1) + + # offset token for bos token + # next_bos_offset_prob = self.bos_offset_predict_head(feat_a) + # next_bos_offset_prob_softmax = torch.softmax(next_bos_offset_prob, dim=-1) + # _, next_bos_offset_idx = torch.topk(next_bos_offset_prob_softmax, k=1, dim=-1) + + # next_bos_offset_index_gt = agent_enter_offset_token_idx.roll(shifts=-1, dims=-1) + + # next token prediction 
mask + bos_token_index = torch.nonzero(agent_state_index == self.enter_state) + eos_token_index = torch.nonzero(agent_state_index == self.exit_state) + + # mask for motion tokens + next_token_eval_mask = mask.clone() + next_token_eval_mask = next_token_eval_mask * next_token_eval_mask.roll(shifts=-1, dims=1) * next_token_eval_mask.roll(shifts=1, dims=1) + for bos_token_index_ in bos_token_index: + next_token_eval_mask[bos_token_index_[0], bos_token_index_[1] : bos_token_index_[1] + 1] = 1 + next_token_eval_mask[bos_token_index_[0], bos_token_index_[1] + 1 : bos_token_index_[1] + 2] = \ + mask[bos_token_index_[0], bos_token_index_[1] + 2 : bos_token_index_[1] + 3] + next_token_eval_mask[eos_token_index[:, 0], eos_token_index[:, 1]] = 0 + + # mask for state tokens + next_state_eval_mask = mask.clone() + next_state_eval_mask = next_state_eval_mask * next_state_eval_mask.roll(shifts=-1, dims=1) * next_state_eval_mask.roll(shifts=1, dims=1) + for bos_token_index_ in bos_token_index: + next_state_eval_mask[bos_token_index_[0], :bos_token_index_[1]] = 0 + next_state_eval_mask[bos_token_index_[0], bos_token_index_[1] : bos_token_index_[1] + 1] = 1 + next_state_eval_mask[bos_token_index_[0], bos_token_index_[1] + 1 : bos_token_index_[1] + 2] = \ + mask[bos_token_index_[0], bos_token_index_[1] + 2 : bos_token_index_[1] + 3] + for eos_token_index_ in eos_token_index: + next_state_eval_mask[eos_token_index_[0], eos_token_index_[1] + 1:] = 1 + next_state_eval_mask[eos_token_index_[0], eos_token_index_[1] : eos_token_index_[1] + 1] = \ + mask[eos_token_index_[0], eos_token_index_[1] - 1 : eos_token_index_[1]] + + # seed agents + next_bos_token_index = torch.nonzero(next_state_index_gt == self.enter_state) + next_bos_token_index = next_bos_token_index[next_bos_token_index[:, 1] < num_step - 1] + + next_state_index_gt_seed = torch.full((self.seed_size, num_step), self.seed_state_type.index('invalid'), device=next_state_index_gt.device) + next_type_index_gt_seed = torch.full((self.seed_size, num_step), self.seed_agent_type.index('seed'), device=next_state_index_gt.device) + next_eval_mask_seed = torch.ones_like(next_state_index_gt_seed) + + num_seed = torch.zeros(num_step, device=next_state_index_gt.device).long() + for next_bos_token_index_ in next_bos_token_index: + if num_seed[next_bos_token_index_[1]] < self.seed_size: + next_state_index_gt_seed[num_seed[next_bos_token_index_[1]], next_bos_token_index_[1]] = self.seed_state_type.index('enter') + next_type_index_gt_seed[num_seed[next_bos_token_index_[1]], next_bos_token_index_[1]] = next_type_index_gt[next_bos_token_index_[0], next_bos_token_index_[1]] + num_seed[next_bos_token_index_[1]] += 1 + + # the last timestep is the beginning of the sequence (also the input) + next_token_eval_mask[:, -1] = 0 + next_state_eval_mask[:, -1] = 0 + next_eval_mask_seed[:, -1] = 0 + # next_bos_token_eval_mask[:, -1] = False + + # no invalid motion token will be supervised + if (next_token_index_gt[next_token_eval_mask] < 0).any(): + raise RuntimeError() + + next_state_index_gt[next_state_index_gt == self.exit_state] = self.valid_state_type.index('exit') + + return {'x_a': feat_a, + # motion token + 'next_token_idx': next_token_idx, + 'next_token_prob': next_token_prob, + 'next_token_idx_gt': next_token_index_gt, + 'next_token_eval_mask': next_token_eval_mask.bool(), + # state token + 'next_state_idx': next_state_idx, + 'next_state_prob': next_state_prob, + 'next_state_idx_gt': next_state_index_gt, + 'next_state_eval_mask': next_state_eval_mask.bool(), + # seed agent + 
'next_state_idx_seed': next_state_idx_seed, + 'next_state_prob_seed': next_state_prob_seed, + 'next_state_idx_gt_seed': next_state_index_gt_seed, + 'next_type_idx_seed': next_type_idx_seed, + 'next_type_prob_seed': next_type_prob_seed, + 'next_type_idx_gt_seed': next_type_index_gt_seed, + 'next_eval_mask_seed': next_eval_mask_seed.bool(), + # pl token for bos + # 'next_bos_pl_idx': next_bos_pl_idx, + # 'next_bos_pl_prob': next_bos_pl_prob, + # 'next_bos_pl_index_gt': next_bos_pl_index_gt, + # offset token for bos + # 'next_bos_offset_idx': next_bos_offset_idx, + # 'next_bos_offset_prob': next_bos_offset_prob, + # 'next_bos_offset_index_gt': next_bos_offset_index_gt, + # 'next_bos_token_eval_mask': next_bos_token_eval_mask, + } + + def inference(self, + data: HeteroData, + map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + + start_state_idx = data['agent']['state_idx'][:, (self.num_historical_steps - 1) // self.shift] + filter_mask = (start_state_idx == self.valid_state) | (start_state_idx == self.exit_state) + seed_step_mask = data['agent']['state_idx'][:, (self.num_historical_steps - 1) // self.shift:] == self.enter_state + seed_agent_index_per_step = [torch.nonzero(seed_step_mask[:, t]).squeeze(dim=-1) for t in range(seed_step_mask.shape[1])] + if torch.any(seed_step_mask.sum(dim=0) > self.seed_size): + raise RuntimeError(f"Seed size {self.seed_size} is too small.") + + # num_historical_steps=11 + eval_mask = data['agent']['valid_mask'][filter_mask, self.num_historical_steps - 1] + + if self.predict_state: + eval_mask = torch.ones_like(eval_mask).bool() + + # agent attributes + pos_a = data['agent']['token_pos'][filter_mask].clone() # (num_agent, num_step, 2) + state_a = data['agent']['state_idx'][filter_mask].clone() # (num_agent, num_step) + head_a = data['agent']['token_heading'][filter_mask].clone() # (num_agent, num_step) + gt_traj = data['agent']['position'][filter_mask, self.num_historical_steps:, :self.input_dim].contiguous() + num_agent, num_step, traj_dim = pos_a.shape + + av_index = int(data['agent']['av_index']) + av_index -= (~filter_mask[:av_index]).sum() + + # map attributes + pos_pl = data['pt_token']['position'][:, :2].clone() # (num_pl, 2) + + # make future steps to zero + pos_a[:, (self.num_historical_steps - 1) // self.shift:] = 0 + state_a[:, (self.num_historical_steps - 1) // self.shift:] = 0 + head_a[:, (self.num_historical_steps - 1) // self.shift:] = 0 + + agent_valid_mask = data['agent']['raw_agent_valid_mask'][filter_mask].clone() # token_valid_mask + agent_valid_mask[:, (self.num_historical_steps - 1) // self.shift:] = True + agent_valid_mask[~eval_mask] = False + agent_token_index = data['agent']['token_idx'][filter_mask] + agent_state_index = data['agent']['state_idx'][filter_mask] + agent_type = data['agent']['type'][filter_mask] + agent_category = data['agent']['category'][filter_mask] + + feat_a, head_vector_a, agent_token_traj_all, agent_token_emb, categorical_embs = self.agent_token_embedding(data, + agent_token_index, + agent_state_index, + pos_a, + head_a, + inference=True, + filter_mask=filter_mask, + av_index=av_index, + ) + feat_seed = feat_a[-1:] + feat_a = feat_a[:-1] + + agent_type = data["agent"]["type"][filter_mask] + veh_mask = agent_type == 0 + cyc_mask = agent_type == 2 + ped_mask = agent_type == 1 + + # self.num_recurrent_steps_val = 91 - 11 = 80 + self.num_recurrent_steps_val = data["agent"]['position'].shape[1] - self.num_historical_steps + pred_traj = torch.zeros(pos_a.shape[0], self.num_recurrent_steps_val, 2, 
device=feat_a.device) # (num_agent, 80, 2) + pred_head = torch.zeros(pos_a.shape[0], self.num_recurrent_steps_val, device=feat_a.device) + pred_type = agent_type.clone() + pred_state = torch.zeros(pos_a.shape[0], self.num_recurrent_steps_val, device=feat_a.device) + pred_prob = torch.zeros(pos_a.shape[0], self.num_recurrent_steps_val // self.shift, device=feat_a.device) # (num_agent, 80 // 5 = 16) + next_token_idx_list = [] + next_state_idx_list = [] + next_bos_pl_idx_list = [] + next_bos_offset_idx_list = [] + feat_a_t_dict = {} + feat_sa_t_dict = {} + + # build masks (init) + mask = agent_valid_mask.clone() + temporal_mask = mask.clone() + interact_mask = mask.clone() + if self.predict_state: + + # find bos and eos index + is_bos = agent_state_index == self.enter_state + is_eos = agent_state_index == self.exit_state + bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0)) + eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_step - 1)) + + temporal_mask = torch.ones_like(mask) + motion_mask = torch.arange(mask.shape[1]).expand(mask.shape[0], mask.shape[1]).to(mask.device) + motion_mask = (motion_mask > bos_index[:, None]) & (motion_mask <= eos_index[:, None]) + motion_mask[:, self.num_historical_steps // self.shift:] = False + temporal_mask[motion_mask] = mask[motion_mask] + + interact_mask = torch.ones_like(mask) + non_motion_mask = ~motion_mask + non_motion_mask[:, self.num_historical_steps // self.shift:] = False + interact_mask[non_motion_mask] = False + interact_mask[agent_state_index == self.enter_state] = True + + temporal_mask[:, (self.num_historical_steps - 1) // self.shift:] = True + interact_mask[:, (self.num_historical_steps - 1) // self.shift:] = True + + # mapping network + # z = torch.randn(num_agent, self.hidden_dim).to(feat_a.device) + # w = self.mapping_network(z) + + # we only need to predict 16 next tokens + for t in range(self.num_recurrent_steps_val // self.shift): + + # feat_a = feat_a + w[:, None] + num_agent = pos_a.shape[0] + + if t == 0: + inference_mask = temporal_mask.clone() + inference_mask = torch.cat([inference_mask, torch.ones_like(inference_mask[-1:])]) + inference_mask[:, (self.num_historical_steps - 1) // self.shift + t:] = False + else: + inference_mask = torch.zeros_like(temporal_mask) + inference_mask = torch.cat([inference_mask, torch.zeros_like(inference_mask[-1:])]) + inference_mask[:, max((self.num_historical_steps - 1) // self.shift + t - (self.num_interaction_steps // self.shift), 0) : + (self.num_historical_steps - 1) // self.shift + t] = True + + interact_mask = torch.cat([interact_mask, torch.ones_like(interact_mask[:1])]).bool() # placeholder + + edge_index_t, r_t = self.build_temporal_edge(pos_a, head_a, head_vector_a, state_a, temporal_mask, inference_mask, + av_index=av_index) + + # +1: placeholder for seed agent + batch_s = torch.arange(num_step, device=pos_a.device).repeat_interleave(num_agent + 1) + batch_pl = torch.arange(num_step, device=pos_a.device).repeat_interleave(data['pt_token']['num_nodes']) + + # In the inference stage, we only infer the current stage for recurrent + edge_index_a2a, r_a2a = self.build_interaction_edge(pos_a, head_a, head_vector_a, state_a, batch_s, + interact_mask, inference_mask, av_index=av_index) + edge_index_pl2a, r_pl2a = self.build_map2agent_edge(data, num_step, pos_a, head_a, head_vector_a, state_a, batch_s, batch_pl, + interact_mask, inference_mask, av_index=av_index) + interact_mask = interact_mask[:-1] + + # if 
t > 0: + # feat_a_sum = feat_a.sum(dim=-1) + # for a in range(pos_a.shape[0]): + # t_1 = (self.num_historical_steps - 1) // self.shift + t - 1 + # print(f"agent {a} t_1 {t_1}") + # print(f"token: {next_token_idx[a]}") + # print(f"state: {next_state_idx[a]}") + # print(f"feat_a_sum: {feat_a_sum[a, t_1]}") + + for i in range(self.num_layers): + + if (i in feat_a_t_dict) and (i in feat_sa_t_dict): + feat_a = feat_a_t_dict[i] + feat_seed = feat_sa_t_dict[i] + + feat_a = torch.cat([feat_a, feat_seed], dim=0) + + feat_a = feat_a.reshape(-1, self.hidden_dim) + feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t) + + feat_a = feat_a.reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim) + feat_a = self.pt2a_attn_layers[i](( + map_enc['x_pt'].repeat_interleave(repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape( + -1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a) + + feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a) + feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1) + + feat_seed = feat_a[-1:] # (1, num_step, hidden_dim) + feat_a = feat_a[:-1] # (num_agent, num_step, hidden_dim) + + if t == 0: + feat_a_t_dict[i + 1] = feat_a + feat_sa_t_dict[i + 1] = feat_seed + else: + # update agent features at current step + n = feat_a_t_dict[i + 1].shape[0] + feat_a_t_dict[i + 1][:n, (self.num_historical_steps - 1) // self.shift - 1 + t] = feat_a[:n, (self.num_historical_steps - 1) // self.shift - 1 + t] + # add newly inserted agent features (only when t changed) + if feat_a.shape[0] > n: + m = feat_a.shape[0] - n + feat_a_t_dict[i + 1] = torch.cat([feat_a_t_dict[i + 1], feat_a[-m:]]) + # update seed agent features at current step + feat_sa_t_dict[i + 1][:, (self.num_historical_steps - 1) // self.shift - 1 + t] = feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t] + + # next motion token + next_token_prob = self.token_predict_head(feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]) + next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1) + topk_token_prob, next_token_idx = torch.topk(next_token_prob_softmax, k=self.beam_size, dim=-1) # both (num_agent, beam_size) e.g. 
(31, 5) + + # next state token + next_state_prob = self.state_predict_head(feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]) + next_state_idx = next_state_prob.softmax(dim=-1).argmax(dim=-1) + next_state_idx[next_state_idx == self.valid_state_type.index('exit')] = self.exit_state + + # seed agent + feat_seed = self.seed_head(feat_seed) + self.seed_feature.weight[:, None] + next_state_prob_seed = self.seed_state_predict_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t]) + next_state_idx_seed = next_state_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True) + next_state_idx_seed[next_state_idx_seed == self.seed_state_type.index('enter')] = self.enter_state + + next_type_prob_seed = self.seed_type_predict_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t]) + next_type_idx_seed = next_type_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True) + + # print(f"t: {t}") + # print(next_type_idx_seed[..., 0].tolist()) + + # bos pl prediction + # next_bos_pl_prob = self.bos_pl_predict_head(feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]) + # next_bos_pl_prob_softmax = torch.softmax(next_bos_pl_prob, dim=-1) + # next_bos_pl_idx = torch.argmax(next_bos_pl_prob_softmax, dim=-1) + + # bos offset prediction + # next_bos_offset_prob = self.bos_offset_predict_head(feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]) + # next_bos_offset_prob_softmax = torch.softmax(next_bos_offset_prob, dim=-1) + # next_bos_offset_idx = torch.argmax(next_bos_offset_prob_softmax, dim=-1) + + # convert the predicted token to a 0.5s (6 timesteps) trajectory + expanded_token_index = next_token_idx[..., None, None, None].expand(-1, -1, 6, 4, 2) + next_token_traj = torch.gather(agent_token_traj_all, 1, expanded_token_index) # (num_agent, beam_size, 6, 4, 2) + + # apply rotation and translation on 'next_token_traj' + theta = head_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t] + cos, sin = theta.cos(), theta.sin() + rot_mat = torch.zeros((num_agent, 2, 2), device=theta.device) + rot_mat[:, 0, 0] = cos + rot_mat[:, 0, 1] = sin + rot_mat[:, 1, 0] = -sin + rot_mat[:, 1, 1] = cos + agent_diff_rel = torch.bmm(next_token_traj.view(-1, 4, 2), + rot_mat[:, None, None, ...].repeat(1, self.beam_size, self.shift + 1, 1, 1).view( + -1, 2, 2)).view(num_agent, self.beam_size, self.shift + 1, 4, 2) + agent_pred_rel = agent_diff_rel + pos_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t, :][:, None, None, None, ...] + + # sample 1 most probable index of top beam_size tokens, (num_agent, beam_size) -> (num_agent, 1) + # then sample the agent_pred_rel, (num_agent, beam_size, 6, 4, 2) -> (num_agent, 6, 4, 2) + sample_token_index = torch.multinomial(topk_token_prob, 1).to(agent_pred_rel.device) + next_token_idx = next_token_idx.gather(dim=1, index=sample_token_index).squeeze(-1) + agent_pred_rel = agent_pred_rel.gather(dim=1, + index=sample_token_index[..., None, None, None].expand(-1, -1, 6, 4, + 2))[:, 0, ...] 
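+            # shapes after sampling (e.g. beam_size=5, shift=5):
+            # sample_token_index is (num_agent, 1), drawn in proportion to the
+            # top-k probs; next_token_idx reduces to (num_agent,) and
+            # agent_pred_rel to (num_agent, 6, 4, 2) -- one 0.5s contour
+            # trajectory per agent in world coordinates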
+
+            # get predicted position and heading of current shifted timesteps
+            diff_xy = agent_pred_rel[:, 1:, 0, :] - agent_pred_rel[:, 1:, 3, :]
+            pred_traj[:num_agent, t * 5 : (t + 1) * 5] = agent_pred_rel[:, 1:, ...].clone().mean(dim=2)
+            pred_head[:num_agent, t * 5 : (t + 1) * 5] = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])
+            pred_state[:num_agent, t * 5 : (t + 1) * 5] = next_state_idx[:, None].repeat(1, 5)
+            # pred_prob[:num_agent, t] = topk_token_prob.gather(dim=-1, index=sample_token_index)[:, 0]  # (num_agent, beam_size) -> (num_agent,)
+
+            # update pos/head/state of current step
+            pos_a[:, (self.num_historical_steps - 1) // self.shift + t] = agent_pred_rel[:, -1, ...].clone().mean(dim=1)
+            diff_xy = agent_pred_rel[:, -1, 0, :] - agent_pred_rel[:, -1, 3, :]
+            theta = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])
+            head_a[:, (self.num_historical_steps - 1) // self.shift + t] = theta
+            state_a[:, (self.num_historical_steps - 1) // self.shift + t] = next_state_idx
+
+            # handle the case where the current predicted state token is invalid/exit
+            is_eos = next_state_idx == self.exit_state
+            is_invalid = next_state_idx == self.invalid_state
+
+            next_token_idx[is_invalid] = -1
+            pos_a[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = 0.
+            head_a[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = 0.
+
+            mask[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = False  # to handle those newly-added agents
+            interact_mask[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = False
+
+            agent_token_emb[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = self.no_token_emb(torch.zeros(1, device=pos_a.device).long())
+
+            type_emb = categorical_embs[0].reshape(num_agent, num_step, -1)
+            shape_emb = categorical_embs[1].reshape(num_agent, num_step, -1)
+            type_emb[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = self.type_a_emb(torch.tensor(self.all_agent_type.index('invalid'), device=pos_a.device).long())
+            shape_emb[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = self.shape_emb(torch.full((1, 3), self.invalid_shape_value, device=pos_a.device))
+            categorical_embs = [type_emb.reshape(-1, self.hidden_dim), shape_emb.reshape(-1, self.hidden_dim)]
+
+            # FIXME: need to discuss!!!
+            # if is_eos.any():
+
+            # pos_a[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = 0.
+            # head_a[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = 0.
+ # mask[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = False # to handle those newly-added agents + # interact_mask[torch.cat([is_eos, torch.zeros(1, device=is_eos.device).bool()]), (self.num_historical_steps - 1) // self.shift + t + 1:] = False + + # agent_token_emb[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = self.no_token_emb(torch.zeros(1, device=pos_a.device).long()) + + # type_emb = categorical_embs[0].reshape(num_agent, num_step, -1) + # shape_emb = categorical_embs[1].reshape(num_agent, num_step, -1) + # type_emb[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = self.type_a_emb(torch.tensor(self.all_agent_type.index('invalid'), device=pos_a.device).long()) + # shape_emb[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = self.shape_emb(torch.full((1, 3), self.invalid_shape_value, device=pos_a.device)) + # categorical_embs = [type_emb.reshape(-1, self.hidden_dim), shape_emb.reshape(-1, self.hidden_dim)] + + # for sa in range(next_state_idx_seed.shape[0]): + # if next_state_idx_seed[sa] == self.enter_state: + # print(f"agent {sa} is entering at step {t}") + + # insert new agents (from seed agent) + seed_agent_index_cur_step = seed_agent_index_per_step[t] + num_new_agent = min(len(seed_agent_index_cur_step), next_state_idx_seed.bool().sum()) + new_agent_mask = next_state_idx_seed.bool() + next_state_idx_seed = next_state_idx_seed[new_agent_mask] + next_state_idx_seed = next_state_idx_seed[:num_new_agent] + next_type_idx_seed = next_type_idx_seed[new_agent_mask] + next_type_idx_seed = next_type_idx_seed[:num_new_agent] + selected_agent_index_cur_step = seed_agent_index_cur_step[:num_new_agent] + agent_token_index = torch.cat([agent_token_index, data['agent']['token_idx'][selected_agent_index_cur_step]]) + agent_state_index = torch.cat([agent_state_index, data['agent']['state_idx'][selected_agent_index_cur_step]]) + agent_category = torch.cat([agent_category, data['agent']['category'][selected_agent_index_cur_step]]) + agent_valid_mask = torch.cat([agent_valid_mask, data['agent']['raw_agent_valid_mask'][selected_agent_index_cur_step]]) + gt_traj = torch.cat([gt_traj, data['agent']['position'][selected_agent_index_cur_step, self.num_historical_steps:, :self.input_dim]]) + + # FIXME: under test!!! 
bos token index is -2 + next_state_idx = torch.cat([next_state_idx, next_state_idx_seed], dim=0).long() + next_token_idx = torch.cat([next_token_idx, torch.zeros(num_new_agent, device=next_token_idx.device) - 2], dim=0).long() + mask = torch.cat([mask, torch.ones(num_new_agent, num_step, device=mask.device)], dim=0).bool() + temporal_mask = torch.cat([temporal_mask, torch.ones(num_new_agent, num_step, device=temporal_mask.device)], dim=0).bool() + interact_mask = torch.cat([interact_mask, torch.ones(num_new_agent, num_step, device=interact_mask.device)], dim=0).bool() + + # new_pos_a = ego_pos_a[None].repeat(num_new_agent, 1, 1) + # new_head_a = ego_head_a[None].repeat(num_new_agent, 1) + new_pos_a = torch.zeros(num_new_agent, num_step, 2, device=pos_a.device) + new_head_a = torch.zeros(num_new_agent, num_step, device=pos_a.device) + new_state_a = torch.zeros(num_new_agent, num_step, device=state_a.device) + new_shape_a = torch.full((num_new_agent, num_step, 3), self.invalid_shape_value, device=pos_a.device) + new_type_a = torch.full((num_new_agent, num_step), self.all_agent_type.index('invalid'), device=pos_a.device) + + if num_new_agent > 0: + gt_bos_pos_a = data['agent']['position'][seed_agent_index_cur_step[:num_new_agent], (self.num_historical_steps - 1) // self.shift + t] + new_pos_a[:, (self.num_historical_steps - 1) // self.shift + t] = gt_bos_pos_a[:, :2].clone() + pos_a = torch.cat([pos_a, new_pos_a], dim=0) + + gt_bos_head_a = data['agent']['heading'][seed_agent_index_cur_step[:num_new_agent], (self.num_historical_steps - 1) // self.shift + t] + new_head_a[:, (self.num_historical_steps - 1) // self.shift + t] = gt_bos_head_a.clone() + head_a = torch.cat([head_a, new_head_a], dim=0) + + gt_bos_shape_a = data['agent']['shape'][seed_agent_index_cur_step[:num_new_agent], self.num_historical_steps - 1] + gt_bos_type_a = data['agent']['type'][seed_agent_index_cur_step[:num_new_agent]] + new_shape_a[:, (self.num_historical_steps - 1) // self.shift + t:] = gt_bos_shape_a.clone()[:, None] + new_type_a[:, (self.num_historical_steps - 1) // self.shift + t:] = gt_bos_type_a.clone()[:, None] + # new_type_a[:, (self.num_historical_steps - 1) // self.shift + t] = next_type_idx_seed + pred_type = torch.cat([pred_type, new_type_a[:, (self.num_historical_steps - 1) // self.shift + t]]) + + new_state_a[:, (self.num_historical_steps - 1) // self.shift + t] = self.enter_state + state_a = torch.cat([state_a, new_state_a], dim=0) + + mask[-num_new_agent:, :(self.num_historical_steps - 1) // self.shift + t + 1] = 0 + interact_mask[-num_new_agent:, :(self.num_historical_steps - 1) // self.shift + t] = 0 + + # update all steps + new_pred_traj = torch.zeros(num_new_agent, self.num_recurrent_steps_val, 2, device=pos_a.device) + new_pred_traj[:, t * 5 : (t + 1) * 5] = new_pos_a[:, (self.num_historical_steps - 1) // self.shift + t][:, None].repeat(1, 5, 1) + pred_traj = torch.cat([pred_traj, new_pred_traj], dim=0) + + new_pred_head = torch.zeros(num_new_agent, self.num_recurrent_steps_val, device=pos_a.device) + new_pred_head[:, t * 5 : (t + 1) * 5] = new_head_a[:, (self.num_historical_steps - 1) // self.shift + t][:, None].repeat(1, 5) + pred_head = torch.cat([pred_head, new_pred_head], dim=0) + + new_pred_state = torch.zeros(num_new_agent, self.num_recurrent_steps_val, device=pos_a.device) + new_pred_state[:, t * 5 : (t + 1) * 5] = next_state_idx_seed[:, None].repeat(1, 5) + pred_state = torch.cat([pred_state, new_pred_state], dim=0) + + # handle the position/heading of bos token + # bos_pl_pos = 
pos_pl[next_bos_pl_idx[is_bos].long()] + # bos_offset_pos = discretize_neighboring(neighbor_index=next_bos_offset_idx[is_bos]) + # pos_a[is_bos, (self.num_historical_steps - 1) // self.shift + t] += (bos_pl_pos + bos_offset_pos) + # # headings before bos token remains 0 which align with training process + # head_a[is_bos, (self.num_historical_steps - 1) // self.shift + t] += 0. + + # add new agents token embeddings + agent_token_emb = torch.cat([agent_token_emb, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())[None, :].repeat(num_new_agent, num_step, 1)]) + veh_mask = torch.cat([veh_mask, next_type_idx_seed == self.seed_agent_type.index('veh')]) + ped_mask = torch.cat([ped_mask, next_type_idx_seed == self.seed_agent_type.index('ped')]) + cyc_mask = torch.cat([cyc_mask, next_type_idx_seed == self.seed_agent_type.index('cyc')]) + + # add new agents trajectory embeddings + trajectory_token_veh = torch.from_numpy(self.trajectory_token['veh']).clone().to(pos_a.device).to(torch.float) + trajectory_token_ped = torch.from_numpy(self.trajectory_token['ped']).clone().to(pos_a.device).to(torch.float) + trajectory_token_cyc = torch.from_numpy(self.trajectory_token['cyc']).clone().to(pos_a.device).to(torch.float) + + new_agent_token_traj_all = torch.zeros((num_new_agent, self.token_size, self.shift + 1, 4, 2), device=pos_a.device) + trajectory_token_all_veh = torch.from_numpy(self.trajectory_token_all['veh']).clone().to(pos_a.device).to(torch.float) + trajectory_token_all_ped = torch.from_numpy(self.trajectory_token_all['ped']).clone().to(pos_a.device).to(torch.float) + trajectory_token_all_cyc = torch.from_numpy(self.trajectory_token_all['cyc']).clone().to(pos_a.device).to(torch.float) + new_agent_token_traj_all[next_type_idx_seed == 0] = torch.cat( + [trajectory_token_all_veh[:, :self.shift], trajectory_token_veh[:, None, ...]], dim=1) + new_agent_token_traj_all[next_type_idx_seed == 1] = torch.cat( + [trajectory_token_all_ped[:, :self.shift], trajectory_token_ped[:, None, ...]], dim=1) + new_agent_token_traj_all[next_type_idx_seed == 2] = torch.cat( + [trajectory_token_all_cyc[:, :self.shift], trajectory_token_cyc[:, None, ...]], dim=1) + + agent_token_traj_all = torch.cat([agent_token_traj_all, new_agent_token_traj_all], dim=0) + + # add new agents categorical embeddings + new_categorical_embs = [self.type_a_emb(new_type_a.reshape(-1).long()), self.shape_emb(new_shape_a.reshape(-1, 3))] + categorical_embs = [torch.cat([categorical_embs[0], new_categorical_embs[0]], dim=0), + torch.cat([categorical_embs[1], new_categorical_embs[1]], dim=0)] + + # update token embeddings of current step + agent_token_emb[veh_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_veh[ + next_token_idx[veh_mask]] + agent_token_emb[ped_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_ped[ + next_token_idx[ped_mask]] + agent_token_emb[cyc_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_cyc[ + next_token_idx[cyc_mask]] + + motion_vector_a, head_vector_a = self.build_vector_a(pos_a, head_a, state_a) + + motion_vector_a[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0. + head_vector_a[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0. 
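+            # NOTE: illustrative sketch (hypothetical corner values, not part of the
+            # rollout) of the box-token decode used for `pred_traj`/`pred_head` above:
+            # position is the mean of the 4 corners, heading points from corner 3 to corner 0.
+            # >>> import torch
+            # >>> box = torch.tensor([[2., 1.], [2., -1.], [0., -1.], [0., 1.]])
+            # >>> box.mean(dim=0)                # tensor([1., 0.])
+            # >>> d = box[0] - box[3]            # tensor([2., 0.])
+            # >>> torch.arctan2(d[1], d[0])      # tensor(0.) -> facing +x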
+ x_a = torch.stack( + [torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1), + angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2])], dim=-1) + + x_b = x_a.clone() + x_a = self.x_a_emb(continuous_inputs=x_a.view(-1, x_a.size(-1)), + categorical_embs=categorical_embs) + x_a = x_a.view(-1, num_step, self.hidden_dim) + + s_a = self.state_a_emb(state_a.reshape(-1).long()).reshape(num_agent + num_new_agent, num_step, self.hidden_dim) + feat_a = torch.cat((agent_token_emb, x_a, s_a), dim=-1) + feat_a = self.fusion_emb(feat_a) + + # if t >= 15: + # print(f"inference {t}") + # is_invalid = state_a == self.invalid_state + # is_bos = state_a == self.enter_state + # is_eos = state_a == self.exit_state + # bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0)) + # eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_step - 1)) + # mask = torch.arange(num_step).expand(num_agent + num_new_agent, -1).to(state_a.device) + # mask = (mask >= bos_index[:, None]) & (mask <= eos_index[:, None] + 1) + # is_invalid[mask] = False + # is_invalid[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = False + # print(pos_a[:, :((self.num_historical_steps - 1) // self.shift + 1 + t)]) + # print(state_a[:, :((self.num_historical_steps - 1) // self.shift + 1 + t)]) + # print(pos_a[is_invalid][:, 0]) + # print(head_a[is_invalid]) + # print(categorical_embs[0].sum(dim=-1)[is_invalid.reshape(-1)]) + # print(categorical_embs[1].sum(dim=-1)[is_invalid.reshape(-1)]) + # print(motion_vector_a[is_invalid][:, 0]) + # print(head_vector_a[is_invalid][:, 0]) + # print(x_b.sum(dim=-1)[is_invalid]) + # print(x_a.sum(dim=-1)[is_invalid]) + # for a in range(state_a.shape[0]): + # print(f"agent: {a}") + # print(state_a[a]) + # print(is_invalid[a].long()) + # print(pos_a[a, :, 0]) + # print(motion_vector_a[a, :, 0]) + # print(s_a.sum(dim=-1)[is_invalid]) + # print(feat_a.sum(dim=-1)[is_invalid]) + + # replace the features of steps before bos of valid agents with the corresponding seed agent features + # is_bos = state_a == self.enter_state + # bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(num_step)) + # before_bos_mask = torch.arange(num_step).expand(num_agent + num_new_agent, -1).to(state_a.device) < bos_index[:, None] + # feat_a[before_bos_mask] = feat_seed.repeat(num_agent + num_new_agent, 1, 1)[before_bos_mask] + + # build seed agent features + motion_vector_seed = motion_vector_a[av_index : av_index + 1] + head_vector_seed = head_vector_a[av_index : av_index + 1] + feat_seed = self.build_invalid_agent_feature(num_step, pos_a.device, type_index=self.all_agent_type.index('seed'), + motion_vector=motion_vector_seed, head_vector=head_vector_seed) + # print(f"inference {t}") + # print(feat_seed.sum(dim=-1)) + + next_token_idx_list.append(next_token_idx[:, None]) + next_state_idx_list.append(next_state_idx[:, None]) + # next_bos_pl_idx_list.append(next_bos_pl_idx[:, None]) + # next_bos_offset_idx_list.append(next_bos_offset_idx[:, None]) + + # TODO: check this + # agent_valid_mask[agent_category != 3] = False + + # print("inference") + # is_invalid = state_a == self.invalid_state + # is_bos = state_a == self.enter_state + # is_eos = state_a == self.exit_state + # bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0)) + # eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_step - 1)) + # 
mask = torch.arange(num_step).expand(num_agent, -1).to(state_a.device) + # mask = (mask >= bos_index[:, None]) & (mask <= eos_index[:, None] + 1) + # is_invalid[mask] = False + # print(feat_a.sum(dim=-1)[is_invalid]) + # print(pos_a[is_invalid][: 0]) + # print(head_a[is_invalid]) + # exit(1) + + num_agent = pos_a.shape[0] + for i in range(len(next_token_idx_list)): + next_token_idx_list[i] = torch.cat([next_token_idx_list[i], torch.zeros(num_agent - next_token_idx_list[i].shape[0], 1, device=next_token_idx_list[i].device) - 1], dim=0).long() + next_state_idx_list[i] = torch.cat([next_state_idx_list[i], torch.zeros(num_agent - next_state_idx_list[i].shape[0], 1, device=next_state_idx_list[i].device)], dim=0).long() + + # eval mask + next_token_eval_mask = agent_valid_mask.clone() + next_state_eval_mask = agent_valid_mask.clone() + bos_token_index = torch.nonzero(agent_state_index == self.enter_state) + eos_token_index = torch.nonzero(agent_state_index == self.exit_state) + + next_token_eval_mask[bos_token_index[:, 0], bos_token_index[:, 1]] = 1 + + for bos_token_index_i in bos_token_index: + next_state_eval_mask[bos_token_index_i[0], :bos_token_index_i[1] + 2] = 1 + for eos_token_index_i in eos_token_index: + next_state_eval_mask[eos_token_index_i[0], eos_token_index_i[1]:] = 1 + + # add history attributes + num_agent = pred_traj.shape[0] + num_init_agent = filter_mask.sum() + + pred_traj = torch.cat([pred_traj, torch.zeros(num_agent, self.num_historical_steps - 1, *(pred_traj.shape[2:]), device=pred_traj.device)], dim=1) + pred_head = torch.cat([pred_head, torch.zeros(num_agent, self.num_historical_steps - 1, *(pred_head.shape[2:]), device=pred_head.device)], dim=1) + pred_state = torch.cat([pred_state, torch.zeros(num_agent, self.num_historical_steps - 1, *(pred_state.shape[2:]), device=pred_state.device)], dim=1) + + pred_state[:num_init_agent, :self.num_historical_steps - 1] = data['agent']['state_idx'][filter_mask, :(self.num_historical_steps - 1) // self.shift].repeat_interleave(repeats=self.shift, dim=1) + + historical_token_idx = data['agent']['token_idx'][filter_mask, :(self.num_historical_steps - 1) // self.shift] + historical_token_idx[historical_token_idx < 0] = 0 + historical_token_traj_all = torch.gather(agent_token_traj_all, 1, historical_token_idx[..., None, None, None].expand(-1, -1, 6, 4, 2)) + init_theta = head_a[:num_init_agent, 0] + cos, sin = init_theta.cos(), init_theta.sin() + rot_mat = torch.zeros((num_init_agent, 2, 2), device=init_theta.device) + rot_mat[:, 0, 0] = cos + rot_mat[:, 0, 1] = sin + rot_mat[:, 1, 0] = -sin + rot_mat[:, 1, 1] = cos + historical_token_traj_all = torch.bmm(historical_token_traj_all.view(-1, 4, 2), + rot_mat[:, None, None, ...].repeat(1, (self.num_historical_steps - 1) // self.shift, self.shift + 1, 1, 1).view( + -1, 2, 2)).view(num_init_agent, (self.num_historical_steps - 1) // self.shift, self.shift + 1, 4, 2) + historical_token_traj_all = historical_token_traj_all + pos_a[:num_init_agent, 0, :][:, None, None, None, ...] 
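+        # NOTE: illustrative sketch (not part of the pipeline) of the row-vector
+        # rotation convention used by `rot_mat` above: with R = [[cos, sin], [-sin, cos]],
+        # `v @ R` rotates v counter-clockwise by theta.
+        # >>> import math, torch
+        # >>> t = torch.tensor(math.pi / 2)
+        # >>> R = torch.tensor([[t.cos(), t.sin()], [-t.sin(), t.cos()]])
+        # >>> torch.tensor([[1., 0.]]) @ R       # ~tensor([[0., 1.]])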
+ pred_traj[:num_init_agent, :self.num_historical_steps - 1] = historical_token_traj_all[:, :, 1:, ...].clone().mean(dim=3).reshape(num_init_agent, -1, 2) + diff_xy = historical_token_traj_all[..., 1:, 0, :] - historical_token_traj_all[..., 1:, 3, :] + pred_head[:num_init_agent, :self.num_historical_steps - 1] = torch.arctan2(diff_xy[..., 1], diff_xy[..., 0]).reshape(num_init_agent, -1) + + return { + 'av_index': av_index, + 'valid_mask': agent_valid_mask[:, self.num_historical_steps:], + 'pos_a': pos_a[:, (self.num_historical_steps - 1) // self.shift:], + 'head_a': head_a[:, (self.num_historical_steps - 1) // self.shift:], + 'gt_traj': gt_traj, + 'pred_traj': pred_traj, + 'pred_head': pred_head, + 'pred_type': list(map(lambda i: self.seed_agent_type[i], pred_type.tolist())), + 'pred_state': pred_state, + 'next_token_idx': torch.cat(next_token_idx_list, dim=-1), # (num_agent, num_step) + 'next_token_idx_gt': agent_token_index, + 'next_state_idx': torch.cat(next_state_idx_list, dim=-1) if len(next_state_idx_list) > 0 else None, + 'next_state_idx_gt': agent_state_index, + 'next_token_eval_mask': next_token_eval_mask, + 'next_state_eval_mask': next_state_eval_mask, + # 'next_bos_pl_idx': torch.cat(next_bos_pl_idx_list, dim=-1), + # 'next_bos_offset_idx': torch.cat(next_bos_offset_idx_list, dim=-1), + } diff --git a/backups/dev/modules/layers.py b/backups/dev/modules/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..13cd5f687fd37a363b39ad1d7db542c8fdf58a8a --- /dev/null +++ b/backups/dev/modules/layers.py @@ -0,0 +1,371 @@ + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import List, Optional, Tuple, Union +from torch_geometric.nn.conv import MessagePassing +from torch_geometric.utils import softmax + +from dev.utils.func import weight_init + + +__all__ = ['AttentionLayer', 'FourierEmbedding', 'MLPEmbedding', 'MLPLayer', 'MappingNetwork'] + + +class AttentionLayer(MessagePassing): + + def __init__(self, + hidden_dim: int, + num_heads: int, + head_dim: int, + dropout: float, + bipartite: bool, + has_pos_emb: bool, + **kwargs) -> None: + super(AttentionLayer, self).__init__(aggr='add', node_dim=0, **kwargs) + self.num_heads = num_heads + self.head_dim = head_dim + self.has_pos_emb = has_pos_emb + self.scale = head_dim ** -0.5 + + self.to_q = nn.Linear(hidden_dim, head_dim * num_heads) + self.to_k = nn.Linear(hidden_dim, head_dim * num_heads, bias=False) + self.to_v = nn.Linear(hidden_dim, head_dim * num_heads) + if has_pos_emb: + self.to_k_r = nn.Linear(hidden_dim, head_dim * num_heads, bias=False) + self.to_v_r = nn.Linear(hidden_dim, head_dim * num_heads) + self.to_s = nn.Linear(hidden_dim, head_dim * num_heads) + self.to_g = nn.Linear(head_dim * num_heads + hidden_dim, head_dim * num_heads) + self.to_out = nn.Linear(head_dim * num_heads, hidden_dim) + self.attn_drop = nn.Dropout(dropout) + self.ff_mlp = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim * 4), + nn.ReLU(inplace=True), + nn.Dropout(dropout), + nn.Linear(hidden_dim * 4, hidden_dim), + ) + if bipartite: + self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim) + self.attn_prenorm_x_dst = nn.LayerNorm(hidden_dim) + else: + self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim) + self.attn_prenorm_x_dst = self.attn_prenorm_x_src + if has_pos_emb: + self.attn_prenorm_r = nn.LayerNorm(hidden_dim) + self.attn_postnorm = nn.LayerNorm(hidden_dim) + self.ff_prenorm = nn.LayerNorm(hidden_dim) + self.ff_postnorm = nn.LayerNorm(hidden_dim) + self.apply(weight_init) + + def 
forward(self, + x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + r: Optional[torch.Tensor], + edge_index: torch.Tensor) -> torch.Tensor: + if isinstance(x, torch.Tensor): + x_src = x_dst = self.attn_prenorm_x_src(x) + else: + x_src, x_dst = x + x_src = self.attn_prenorm_x_src(x_src) + x_dst = self.attn_prenorm_x_dst(x_dst) + x = x[1] + if self.has_pos_emb and r is not None: + r = self.attn_prenorm_r(r) + x = x + self.attn_postnorm(self._attn_block(x_src, x_dst, r, edge_index)) + x = x + self.ff_postnorm(self._ff_block(self.ff_prenorm(x))) + return x + + def message(self, + q_i: torch.Tensor, + k_j: torch.Tensor, + v_j: torch.Tensor, + r: Optional[torch.Tensor], + index: torch.Tensor, + ptr: Optional[torch.Tensor]) -> torch.Tensor: + if self.has_pos_emb and r is not None: + k_j = k_j + self.to_k_r(r).view(-1, self.num_heads, self.head_dim) + v_j = v_j + self.to_v_r(r).view(-1, self.num_heads, self.head_dim) + sim = (q_i * k_j).sum(dim=-1) * self.scale + attn = softmax(sim, index, ptr) + self.attention_weight = attn.sum(-1).detach() + attn = self.attn_drop(attn) + return v_j * attn.unsqueeze(-1) + + def update(self, + inputs: torch.Tensor, + x_dst: torch.Tensor) -> torch.Tensor: + inputs = inputs.view(-1, self.num_heads * self.head_dim) + g = torch.sigmoid(self.to_g(torch.cat([inputs, x_dst], dim=-1))) + return inputs + g * (self.to_s(x_dst) - inputs) + + def _attn_block(self, + x_src: torch.Tensor, + x_dst: torch.Tensor, + r: Optional[torch.Tensor], + edge_index: torch.Tensor) -> torch.Tensor: + q = self.to_q(x_dst).view(-1, self.num_heads, self.head_dim) + k = self.to_k(x_src).view(-1, self.num_heads, self.head_dim) + v = self.to_v(x_src).view(-1, self.num_heads, self.head_dim) + agg = self.propagate(edge_index=edge_index, x_dst=x_dst, q=q, k=k, v=v, r=r) + return self.to_out(agg) + + def _ff_block(self, x: torch.Tensor) -> torch.Tensor: + return self.ff_mlp(x) + + +class FourierEmbedding(nn.Module): + + def __init__(self, + input_dim: int, + hidden_dim: int, + num_freq_bands: int) -> None: + super(FourierEmbedding, self).__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.freqs = nn.Embedding(input_dim, num_freq_bands) if input_dim != 0 else None + self.mlps = nn.ModuleList( + [nn.Sequential( + nn.Linear(num_freq_bands * 2 + 1, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(hidden_dim, hidden_dim), + ) + for _ in range(input_dim)]) + self.to_out = nn.Sequential( + nn.LayerNorm(hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(hidden_dim, hidden_dim), + ) + self.apply(weight_init) + + def forward(self, + continuous_inputs: Optional[torch.Tensor] = None, + categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor: + if continuous_inputs is None: + if categorical_embs is not None: + x = torch.stack(categorical_embs).sum(dim=0) + else: + raise ValueError('Both continuous_inputs and categorical_embs are None') + else: + x = continuous_inputs.unsqueeze(-1) * self.freqs.weight * 2 * math.pi + # Warning: if your data are noisy, don't use learnable sinusoidal embedding + x = torch.cat([x.cos(), x.sin(), continuous_inputs.unsqueeze(-1)], dim=-1) + continuous_embs: List[Optional[torch.Tensor]] = [None] * self.input_dim + for i in range(self.input_dim): + continuous_embs[i] = self.mlps[i](x[:, i]) + x = torch.stack(continuous_embs).sum(dim=0) + if categorical_embs is not None: + x = x + torch.stack(categorical_embs).sum(dim=0) + return self.to_out(x) + + +class MLPEmbedding(nn.Module): + def __init__(self, + 
input_dim: int, + hidden_dim: int) -> None: + super(MLPEmbedding, self).__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.mlp = nn.Sequential( + nn.Linear(input_dim, 128), + nn.LayerNorm(128), + nn.ReLU(inplace=True), + nn.Linear(128, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(hidden_dim, hidden_dim)) + self.apply(weight_init) + + def forward(self, + continuous_inputs: Optional[torch.Tensor] = None, + categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor: + if continuous_inputs is None: + if categorical_embs is not None: + x = torch.stack(categorical_embs).sum(dim=0) + else: + raise ValueError('Both continuous_inputs and categorical_embs are None') + else: + x = self.mlp(continuous_inputs) + if categorical_embs is not None: + x = x + torch.stack(categorical_embs).sum(dim=0) + return x + + +class MLPLayer(nn.Module): + + def __init__(self, + input_dim: int, + hidden_dim: int=None, + output_dim: int=None) -> None: + super(MLPLayer, self).__init__() + + if hidden_dim is None: + hidden_dim = output_dim + + self.mlp = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(hidden_dim, output_dim), + ) + self.apply(weight_init) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.mlp(x) + + +class MappingNetwork(nn.Module): + def __init__(self, z_dim, w_dim, layer_dim=None, num_layers=8): + super().__init__() + + if not layer_dim: + layer_dim = w_dim + layer_dims = [z_dim] + [layer_dim] * (num_layers - 1) + [w_dim] + + layers = [] + for i in range(num_layers): + layers.extend([ + nn.Linear(layer_dims[i], layer_dims[i + 1]), + nn.LeakyReLU(), + ]) + self.layers = nn.Sequential(*layers) + + def forward(self, z): + w = self.layers(z) + return w + + +# class FocalLoss: +# def __init__(self, alpha: float=.25, gamma: float=2): +# self.alpha = alpha +# self.gamma = gamma + +# def __call__(self, inputs, targets): +# prob = inputs.sigmoid() +# ce_loss = F.binary_cross_entropy_with_logits(inputs, targets.float(), reduction='none') +# p_t = prob * targets + (1 - prob) * (1 - targets) +# loss = ce_loss * ((1 - p_t) ** self.gamma) + +# if self.alpha >= 0: +# alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets) +# loss = alpha_t * loss + +# return loss.mean() + + +class FocalLoss(nn.Module): + """Focal Loss, as described in https://arxiv.org/abs/1708.02002. + It is essentially an enhancement to cross entropy loss and is + useful for classification tasks when there is a large class imbalance. + x is expected to contain raw, unnormalized scores for each class. + y is expected to contain class labels. + Shape: + - x: (batch_size, C) or (batch_size, C, d1, d2, ..., dK), K > 0. + - y: (batch_size,) or (batch_size, d1, d2, ..., dK), K > 0. + """ + + def __init__( + self, + alpha: Optional[torch.Tensor] = None, + gamma: float = 0.0, + reduction: str = "mean", + ignore_index: int = -100, + ): + """Constructor. + Args: + alpha (Tensor, optional): Weights for each class. Defaults to None. + gamma (float, optional): A constant, as described in the paper. + Defaults to 0. + reduction (str, optional): 'mean', 'sum' or 'none'. + Defaults to 'mean'. + ignore_index (int, optional): class label to ignore. + Defaults to -100. 
+ """ + if reduction not in ("mean", "sum", "none"): + raise ValueError('Reduction must be one of: "mean", "sum", "none".') + + super().__init__() + self.alpha = alpha + self.gamma = gamma + self.ignore_index = ignore_index + self.reduction = reduction + + self.nll_loss = nn.NLLLoss( + weight=alpha, reduction="none", ignore_index=ignore_index + ) + + def __repr__(self): + arg_keys = ["alpha", "gamma", "ignore_index", "reduction"] + arg_vals = [self.__dict__[k] for k in arg_keys] + arg_strs = [f"{k}={v}" for k, v in zip(arg_keys, arg_vals)] + arg_str = ", ".join(arg_strs) + return f"{type(self).__name__}({arg_str})" + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + if x.ndim > 2: + # (N, C, d1, d2, ..., dK) --> (N * d1 * ... * dK, C) + c = x.shape[1] + x = x.permute(0, *range(2, x.ndim), 1).reshape(-1, c) + # (N, d1, d2, ..., dK) --> (N * d1 * ... * dK,) + y = y.view(-1) + + unignored_mask = y != self.ignore_index + y = y[unignored_mask] + if len(y) == 0: + return 0.0 + x = x[unignored_mask] + + # compute weighted cross entropy term: -alpha * log(pt) + # (alpha is already part of self.nll_loss) + log_p = F.log_softmax(x, dim=-1) + ce = self.nll_loss(log_p, y) + + # get true class column from each row + all_rows = torch.arange(len(x)) + log_pt = log_p[all_rows, y] + + # compute focal term: (1 - pt)^gamma + pt = log_pt.exp() + focal_term = (1 - pt) ** self.gamma + + # the full loss: -alpha * ((1 - pt)^gamma) * log(pt) + loss = focal_term * ce + + if self.reduction == "mean": + loss = loss.mean() + elif self.reduction == "sum": + loss = loss.sum() + + return loss + + +class OccLoss(nn.Module): + + # geo_scal_loss + def __init__(self): + super().__init__() + + def forward(self, pred, target, mask=None): + + nonempty_probs = torch.sigmoid(pred) + empty_probs = 1 - nonempty_probs + + if mask is None: + mask = torch.ones_like(target).bool() + + nonempty_target = target == 1 + nonempty_target = nonempty_target[mask].float() + nonempty_probs = nonempty_probs[mask] + empty_probs = empty_probs[mask] + + intersection = (nonempty_target * nonempty_probs).sum() + precision = intersection / nonempty_probs.sum() + recall = intersection / nonempty_target.sum() + spec = ((1 - nonempty_target) * (empty_probs)).sum() / (1 - nonempty_target).sum() + + return ( + F.binary_cross_entropy(precision, torch.ones_like(precision)) + + F.binary_cross_entropy(recall, torch.ones_like(recall)) + + F.binary_cross_entropy(spec, torch.ones_like(spec)) + ) diff --git a/backups/dev/modules/map_decoder.py b/backups/dev/modules/map_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..97b033a0ffdb5b538ad923058ce9e4843f4ed18d --- /dev/null +++ b/backups/dev/modules/map_decoder.py @@ -0,0 +1,130 @@ +from typing import Dict +import torch +import torch.nn as nn +from torch_cluster import radius_graph +from torch_geometric.data import Batch +from torch_geometric.data import HeteroData +from torch_geometric.utils import subgraph + +from dev.modules.layers import MLPLayer, AttentionLayer, FourierEmbedding, MLPEmbedding +from dev.utils.func import weight_init, wrap_angle, angle_between_2d_vectors + + +class SMARTMapDecoder(nn.Module): + + def __init__(self, + dataset: str, + input_dim: int, + hidden_dim: int, + num_historical_steps: int, + pl2pl_radius: float, + num_freq_bands: int, + num_layers: int, + num_heads: int, + head_dim: int, + dropout: float, + map_token) -> None: + + super(SMARTMapDecoder, self).__init__() + self.dataset = dataset + self.input_dim = input_dim + 
self.hidden_dim = hidden_dim + self.num_historical_steps = num_historical_steps + self.pl2pl_radius = pl2pl_radius + self.num_freq_bands = num_freq_bands + self.num_layers = num_layers + self.num_heads = num_heads + self.head_dim = head_dim + self.dropout = dropout + + if input_dim == 2: + input_dim_r_pt2pt = 3 + elif input_dim == 3: + input_dim_r_pt2pt = 4 + else: + raise ValueError('{} is not a valid dimension'.format(input_dim)) + + self.type_pt_emb = nn.Embedding(17, hidden_dim) + self.side_pt_emb = nn.Embedding(4, hidden_dim) + self.polygon_type_emb = nn.Embedding(4, hidden_dim) + self.light_pl_emb = nn.Embedding(4, hidden_dim) + + self.r_pt2pt_emb = FourierEmbedding(input_dim=input_dim_r_pt2pt, hidden_dim=hidden_dim, + num_freq_bands=num_freq_bands) + self.pt2pt_layers = nn.ModuleList( + [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout, + bipartite=False, has_pos_emb=True) for _ in range(num_layers)] + ) + self.token_size = 1024 + self.token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, + output_dim=self.token_size) + input_dim_token = 22 + self.token_emb = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim) + self.map_token = map_token + self.apply(weight_init) + self.mask_pt = False + + def maybe_autocast(self, dtype=torch.float32): + return torch.cuda.amp.autocast(dtype=dtype) + + def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]: + pt_valid_mask = data['pt_token']['pt_valid_mask'] + pt_pred_mask = data['pt_token']['pt_pred_mask'] + pt_target_mask = data['pt_token']['pt_target_mask'] + mask_s = pt_valid_mask + + pos_pt = data['pt_token']['position'][:, :self.input_dim].contiguous() + orient_pt = data['pt_token']['orientation'].contiguous() + orient_vector_pt = torch.stack([orient_pt.cos(), orient_pt.sin()], dim=-1) + token_sample_pt = self.map_token['traj_src'].to(pos_pt.device).to(torch.float) + pt_token_emb_src = self.token_emb(token_sample_pt.view(token_sample_pt.shape[0], -1)) + pt_token_emb = pt_token_emb_src[data['pt_token']['token_idx']] + + x_pt = pt_token_emb + + token2pl = data[('pt_token', 'to', 'map_polygon')]['edge_index'] + token_light_type = data['map_polygon']['light_type'][token2pl[1]] + x_pt_categorical_embs = [self.type_pt_emb(data['pt_token']['type'].long()), + self.polygon_type_emb(data['pt_token']['pl_type'].long()), + self.light_pl_emb(token_light_type.long()),] + x_pt = x_pt + torch.stack(x_pt_categorical_embs).sum(dim=0) + edge_index_pt2pt = radius_graph(x=pos_pt[:, :2], r=self.pl2pl_radius, + batch=data['pt_token']['batch'] if isinstance(data, Batch) else None, + loop=False, max_num_neighbors=100) + if self.mask_pt: + edge_index_pt2pt = subgraph(subset=mask_s, edge_index=edge_index_pt2pt)[0] + rel_pos_pt2pt = pos_pt[edge_index_pt2pt[0]] - pos_pt[edge_index_pt2pt[1]] + rel_orient_pt2pt = wrap_angle(orient_pt[edge_index_pt2pt[0]] - orient_pt[edge_index_pt2pt[1]]) + if self.input_dim == 2: + r_pt2pt = torch.stack( + [torch.norm(rel_pos_pt2pt[:, :2], p=2, dim=-1), + angle_between_2d_vectors(ctr_vector=orient_vector_pt[edge_index_pt2pt[1]], + nbr_vector=rel_pos_pt2pt[:, :2]), + rel_orient_pt2pt], dim=-1) + elif self.input_dim == 3: + r_pt2pt = torch.stack( + [torch.norm(rel_pos_pt2pt[:, :2], p=2, dim=-1), + angle_between_2d_vectors(ctr_vector=orient_vector_pt[edge_index_pt2pt[1]], + nbr_vector=rel_pos_pt2pt[:, :2]), + rel_pos_pt2pt[:, -1], + rel_orient_pt2pt], dim=-1) + else: + raise ValueError('{} is not a valid dimension'.format(self.input_dim)) + + # layers + 
r_pt2pt = self.r_pt2pt_emb(continuous_inputs=r_pt2pt, categorical_embs=None) + for i in range(self.num_layers): + x_pt = self.pt2pt_layers[i](x_pt, r_pt2pt, edge_index_pt2pt) + + next_token_prob = self.token_predict_head(x_pt[pt_pred_mask]) + next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1) + _, next_token_idx = torch.topk(next_token_prob_softmax, k=10, dim=-1) + next_token_index_gt = data['pt_token']['token_idx'][pt_target_mask] + + return { + 'x_pt': x_pt, + 'map_next_token_idx': next_token_idx, + 'map_next_token_prob': next_token_prob, + 'map_next_token_idx_gt': next_token_index_gt, + 'map_next_token_eval_mask': pt_pred_mask[pt_pred_mask] + } diff --git a/backups/dev/modules/occ_decoder.py b/backups/dev/modules/occ_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..ca46881461122ffc721a7bbda960dee2c8902b92 --- /dev/null +++ b/backups/dev/modules/occ_decoder.py @@ -0,0 +1,927 @@ +import os +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Dict, Mapping, Optional, Literal +from torch_cluster import radius, radius_graph +from torch_geometric.data import HeteroData, Batch +from torch_geometric.utils import dense_to_sparse, subgraph +from scipy.optimize import linear_sum_assignment + +from dev.modules.attr_tokenizer import Attr_Tokenizer +from dev.modules.layers import * +from dev.utils.visualization import * +from dev.utils.func import angle_between_2d_vectors, wrap_angle, weight_init + + +class SMARTOccDecoder(nn.Module): + + def __init__(self, + dataset: str, + input_dim: int, + hidden_dim: int, + num_historical_steps: int, + time_span: Optional[int], + pl2a_radius: float, + pl2seed_radius: float, + a2a_radius: float, + a2sa_radius: float, + pl2sa_radius: float, + num_freq_bands: int, + num_layers: int, + num_heads: int, + head_dim: int, + dropout: float, + token_data: Dict, + token_size: int, + special_token_index: list=[], + attr_tokenizer: Attr_Tokenizer=None, + predict_motion: bool=False, + predict_state: bool=False, + predict_map: bool=False, + predict_occ: bool=False, + state_token: Dict[str, int]=None, + seed_size: int=5, + buffer_size: int=32, + loss_weight: dict=None, + logger=None) -> None: + + super(SMARTOccDecoder, self).__init__() + self.dataset = dataset + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.num_historical_steps = num_historical_steps + self.time_span = time_span if time_span is not None else num_historical_steps + self.pl2a_radius = pl2a_radius + self.pl2seed_radius = pl2seed_radius + self.a2a_radius = a2a_radius + self.a2sa_radius = a2sa_radius + self.pl2sa_radius = pl2sa_radius + self.num_freq_bands = num_freq_bands + self.num_layers = num_layers + self.num_heads = num_heads + self.head_dim = head_dim + self.dropout = dropout + self.special_token_index = special_token_index + self.predict_motion = predict_motion + self.predict_state = predict_state + self.predict_map = predict_map + self.predict_occ = predict_occ + self.loss_weight = loss_weight + self.logger = logger + + self.attr_tokenizer = attr_tokenizer + + # state tokens + self.state_type = list(state_token.keys()) + self.state_token = state_token + self.invalid_state = int(state_token['invalid']) + self.valid_state = int(state_token['valid']) + self.enter_state = int(state_token['enter']) + self.exit_state = int(state_token['exit']) + + self.seed_state_type = ['invalid', 'enter'] + self.valid_state_type = ['invalid', 'valid', 'exit'] + + input_dim_r_pt2a = 3 + input_dim_r_a2a = 3 + + 
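+        # NOTE: the four agent states drive the masking logic throughout this
+        # decoder: 'enter' marks an agent's first valid step (bos) and 'exit' its
+        # last (eos). E.g. an agent spawning at shifted step 2 and leaving at
+        # step 5 has the state row [invalid, invalid, enter, valid, valid, exit].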
self.seed_size = seed_size + self.buffer_size = buffer_size + + self.agent_type = ['veh', 'ped', 'cyc', 'seed'] + self.type_a_emb = nn.Embedding(len(self.agent_type), hidden_dim) + self.shape_emb = MLPEmbedding(input_dim=3, hidden_dim=hidden_dim) + self.state_a_emb = nn.Embedding(len(self.state_type), hidden_dim) + self.motion_gap = 1. + self.heading_gap = 1. + self.invalid_shape_value = .1 + self.invalid_motion_value = -2. + self.invalid_head_value = -2. + + self.r_pt2a_emb = FourierEmbedding(input_dim=input_dim_r_pt2a, hidden_dim=hidden_dim, + num_freq_bands=num_freq_bands) + self.r_a2a_emb = FourierEmbedding(input_dim=input_dim_r_a2a, hidden_dim=hidden_dim, + num_freq_bands=num_freq_bands) + + self.token_size = token_size # 2048 + self.grid_size = self.attr_tokenizer.grid_size + self.angle_size = self.attr_tokenizer.angle_size + self.agent_limit = 3 + self.pt_limit = 10 + self.grid_agent_occ_head = MLPLayer(input_dim=hidden_dim, hidden_dim=self.grid_size, + output_dim=self.agent_limit * self.grid_size) + self.grid_pt_occ_head = MLPLayer(input_dim=hidden_dim, hidden_dim=self.grid_size, + output_dim=self.pt_limit * self.grid_size) + + # self.num_seed_feature = 1 + # self.num_seed_feature = self.seed_size + self.num_seed_feature = 10 + + self.trajectory_token = token_data['token'] # dict('veh', 'ped', 'cyc') (2048, 4, 2) + self.trajectory_token_traj = token_data['traj'] # (2048, 6, 3) + self.trajectory_token_all = token_data['token_all'] # (2048, 6, 4, 2) + self.apply(weight_init) + + self.shift = 5 + self.beam_size = 5 + self.hist_mask = True + self.temporal_attn_to_invalid = False + self.use_rel = False + + # seed agent + self.temporal_attn_seed = False + self.seed_attn_to_av = True + self.seed_use_ego_motion = False + + def transform_rel(self, token_traj, prev_pos, prev_heading=None): + if prev_heading is None: + diff_xy = prev_pos[:, :, -1, :] - prev_pos[:, :, -2, :] + prev_heading = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0]) + + num_agent, num_step, traj_num, traj_dim = token_traj.shape + cos, sin = prev_heading.cos(), prev_heading.sin() + rot_mat = torch.zeros((num_agent, num_step, 2, 2), device=prev_heading.device) + rot_mat[:, :, 0, 0] = cos + rot_mat[:, :, 0, 1] = -sin + rot_mat[:, :, 1, 0] = sin + rot_mat[:, :, 1, 1] = cos + agent_diff_rel = torch.bmm(token_traj.view(-1, traj_num, 2), rot_mat.view(-1, 2, 2)).view(num_agent, num_step, traj_num, traj_dim) + agent_pred_rel = agent_diff_rel + prev_pos[:, :, -1:, :] + return agent_pred_rel + + def _agent_token_embedding(self, data, agent_token_index, agent_state, agent_offset_token_idx, pos_a, head_a, + inference=False, filter_mask=None, av_index=None): + + if filter_mask is None: + filter_mask = torch.ones_like(agent_state[:, 2], dtype=torch.bool) + + num_agent, num_step, traj_dim = pos_a.shape # traj_dim=2 + agent_type = data['agent']['type'][filter_mask] + veh_mask = (agent_type == 0) + ped_mask = (agent_type == 1) + cyc_mask = (agent_type == 2) + + motion_vector_a, head_vector_a = self._build_vector_a(pos_a, head_a, agent_state) + + trajectory_token_veh = torch.from_numpy(self.trajectory_token['veh']).clone().to(pos_a.device).to(torch.float) + trajectory_token_ped = torch.from_numpy(self.trajectory_token['ped']).clone().to(pos_a.device).to(torch.float) + trajectory_token_cyc = torch.from_numpy(self.trajectory_token['cyc']).clone().to(pos_a.device).to(torch.float) + self.agent_token_emb_veh = self.token_emb_veh(trajectory_token_veh.view(trajectory_token_veh.shape[0], -1)) # (token_size, 8) + self.agent_token_emb_ped = 
self.token_emb_ped(trajectory_token_ped.view(trajectory_token_ped.shape[0], -1)) + self.agent_token_emb_cyc = self.token_emb_cyc(trajectory_token_cyc.view(trajectory_token_cyc.shape[0], -1)) + + # add bos token embedding + self.agent_token_emb_veh = torch.cat([self.agent_token_emb_veh, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())]) + self.agent_token_emb_ped = torch.cat([self.agent_token_emb_ped, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())]) + self.agent_token_emb_cyc = torch.cat([self.agent_token_emb_cyc, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())]) + + # add invalid token embedding + self.agent_token_emb_veh = torch.cat([self.agent_token_emb_veh, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())]) + self.agent_token_emb_ped = torch.cat([self.agent_token_emb_ped, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())]) + self.agent_token_emb_cyc = torch.cat([self.agent_token_emb_cyc, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())]) + + # self.grid_token_emb = self.token_emb_grid(torch.stack([self.attr_tokenizer.dist, + # self.attr_tokenizer.dir], dim=-1).to(pos_a.device)) + self.grid_token_emb = self.token_emb_grid(self.attr_tokenizer.grid) + self.grid_token_emb = torch.cat([self.grid_token_emb, self.invalid_offset_token_emb(torch.zeros(1, device=pos_a.device).long())]) + + if inference: + agent_token_traj_all = torch.zeros((num_agent, self.token_size, self.shift + 1, 4, 2), device=pos_a.device) + trajectory_token_all_veh = torch.from_numpy(self.trajectory_token_all['veh']).clone().to(pos_a.device).to(torch.float) + trajectory_token_all_ped = torch.from_numpy(self.trajectory_token_all['ped']).clone().to(pos_a.device).to(torch.float) + trajectory_token_all_cyc = torch.from_numpy(self.trajectory_token_all['cyc']).clone().to(pos_a.device).to(torch.float) + agent_token_traj_all[veh_mask] = torch.cat( + [trajectory_token_all_veh[:, :self.shift], trajectory_token_veh[:, None, ...]], dim=1) + agent_token_traj_all[ped_mask] = torch.cat( + [trajectory_token_all_ped[:, :self.shift], trajectory_token_ped[:, None, ...]], dim=1) + agent_token_traj_all[cyc_mask] = torch.cat( + [trajectory_token_all_cyc[:, :self.shift], trajectory_token_cyc[:, None, ...]], dim=1) + + # additional token embeddings are already added -> -1: invalid, -2: bos + agent_token_emb = torch.zeros((num_agent, num_step, self.hidden_dim), device=pos_a.device) + agent_token_emb[veh_mask] = self.agent_token_emb_veh[agent_token_index[veh_mask]] + agent_token_emb[ped_mask] = self.agent_token_emb_ped[agent_token_index[ped_mask]] + agent_token_emb[cyc_mask] = self.agent_token_emb_cyc[agent_token_index[cyc_mask]] + + offset_token_emb = self.grid_token_emb[agent_offset_token_idx] + + # 'vehicle', 'pedestrian', 'cyclist', 'background' + is_invalid = agent_state == self.invalid_state + agent_types = data['agent']['type'].clone()[filter_mask].long().repeat_interleave(repeats=num_step, dim=0) + agent_types[is_invalid.reshape(-1)] = self.agent_type.index('seed') + agent_shapes = data['agent']['shape'].clone()[filter_mask, self.num_historical_steps - 1, :].repeat_interleave(repeats=num_step, dim=0) + agent_shapes[is_invalid.reshape(-1)] = self.invalid_shape_value + + # TODO: fix ego_pos in inference mode + offset_pos = pos_a - pos_a[av_index].repeat_interleave(repeats=data['batch_size_a'], dim=0) + feat_a, categorical_embs = self._build_agent_feature(num_step, pos_a.device, + motion_vector_a, + head_vector_a, + agent_token_emb, + offset_token_emb, + 
+                                                              offset_pos=offset_pos,
+                                                              type=agent_types,
+                                                              shape=agent_shapes,
+                                                              state=agent_state,
+                                                              n=num_agent)
+
+        if inference:
+            return feat_a, agent_token_traj_all, agent_token_emb, categorical_embs
+        else:
+            # seed agent feature
+            if self.seed_use_ego_motion:
+                motion_vector_seed = motion_vector_a[av_index].repeat_interleave(repeats=self.num_seed_feature, dim=0)
+                head_vector_seed = head_vector_a[av_index].repeat_interleave(repeats=self.num_seed_feature, dim=0)
+            else:
+                motion_vector_seed = head_vector_seed = None
+            feat_seed, _ = self._build_agent_feature(num_step, pos_a.device,
+                                                     motion_vector_seed,
+                                                     head_vector_seed,
+                                                     state_index=self.invalid_state,
+                                                     n=data.num_graphs * self.num_seed_feature)
+
+            feat_a = torch.cat([feat_a, feat_seed], dim=0)  # (a + n, t, d)
+
+        return feat_a
+
+    def _build_vector_a(self, pos_a, head_a, state_a):
+        num_agent = pos_a.shape[0]
+
+        motion_vector_a = torch.cat([pos_a.new_zeros(num_agent, 1, self.input_dim),
+                                     pos_a[:, 1:] - pos_a[:, :-1]], dim=1)
+
+        motion_vector_a[state_a == self.invalid_state] = self.invalid_motion_value
+
+        # invalid -> valid
+        is_last_invalid = (state_a.roll(shifts=1, dims=1) == self.invalid_state) & (state_a != self.invalid_state)
+        is_last_invalid[:, 0] = state_a[:, 0] == self.enter_state
+        motion_vector_a[is_last_invalid] = self.motion_gap
+
+        # valid -> invalid
+        is_last_valid = (state_a.roll(shifts=1, dims=1) != self.invalid_state) & (state_a == self.invalid_state)
+        is_last_valid[:, 0] = False
+        motion_vector_a[is_last_valid] = -self.motion_gap
+
+        # mark invalid steps with the sentinel heading value
+        head_a[state_a == self.invalid_state] = self.invalid_head_value
+        head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1)
+
+        return motion_vector_a, head_vector_a
+
+    def _build_agent_feature(self, num_step, device,
+                             motion_vector=None,
+                             head_vector=None,
+                             agent_token_emb=None,
+                             agent_grid_emb=None,
+                             offset_pos=None,
+                             type=None,
+                             shape=None,
+                             categorical_embs_a=None,
+                             state=None,
+                             state_index=None,
+                             n=1):
+
+        if agent_token_emb is None:
+            agent_token_emb = self.no_token_emb(torch.zeros(1, device=device).long())[:, None].repeat(n, num_step, 1)
+            if state is not None:
+                agent_token_emb[state == self.enter_state] = self.bos_token_emb(torch.zeros(1, device=device).long())
+
+        if agent_grid_emb is None:
+            agent_grid_emb = self.grid_token_emb[None, None, self.grid_size // 2].repeat(n, num_step, 1)
+
+        if motion_vector is None or head_vector is None:
+            pos_a = torch.zeros((n, num_step, 2), device=device)
+            head_a = torch.zeros((n, num_step), device=device)
+            if state is None:
+                state = torch.full((n, num_step), self.invalid_state, device=device)
+            motion_vector, head_vector = self._build_vector_a(pos_a, head_a, state)
+
+        if offset_pos is None:
+            offset_pos = torch.zeros_like(motion_vector)
+
+        feature_a = torch.stack(
+            [torch.norm(motion_vector[:, :, :2], p=2, dim=-1),
+             angle_between_2d_vectors(ctr_vector=head_vector, nbr_vector=motion_vector[:, :, :2]),
+             # torch.norm(offset_pos[:, :, :2], p=2, dim=-1),
+             ], dim=-1)
+
+        if categorical_embs_a is None:
+            if type is None:
+                type = torch.tensor([self.agent_type.index('seed')], device=device)
+            if shape is None:
+                shape = torch.full((1, 3), self.invalid_shape_value, device=device)
+
+            categorical_embs_a = [self.type_a_emb(type.reshape(-1)), self.shape_emb(shape.reshape(-1, shape.shape[-1]))]
+
+        x_a = self.x_a_emb(continuous_inputs=feature_a.view(-1, feature_a.size(-1)),
+                           categorical_embs=categorical_embs_a)
+        x_a = x_a.view(-1, num_step, self.hidden_dim)  # (a, t, d)
+
+        if state is None:
+            assert state_index is not None, "state index needs to be set when state tensor is None!"
f"state index need to be set when state tensor is None!" + state = torch.tensor([state_index], device=device)[:, None].repeat(n, num_step, 1) # do not use `expand` + s_a = self.state_a_emb(state.reshape(-1).long()).reshape(n, num_step, self.hidden_dim) + + feat_a = torch.cat((agent_token_emb, x_a, s_a, agent_grid_emb), dim=-1) + feat_a = self.fusion_emb(feat_a) # (a, t, d) + + return feat_a, categorical_embs_a + + def _pad_feat(self, num_graph, av_index, *feats, num_seed_feature=None): + + if num_seed_feature is None: + num_seed_feature = self.num_seed_feature + + padded_feats = tuple() + for i in range(len(feats)): + padded_feats += (torch.cat([feats[i], feats[i][av_index].repeat_interleave( + repeats=num_seed_feature, dim=0)], + dim=0 + ),) + + pad_mask = torch.ones(*padded_feats[0].shape[:2], device=feats[0].device).bool() # (a, t) + pad_mask[-num_graph * num_seed_feature:] = False + + return padded_feats + (pad_mask,) + + def _build_seed_feat(self, data, pos_a, head_a, state_a, head_vector_a, mask, sort_indices, av_index): + seed_mask = sort_indices != av_index.repeat_interleave(repeats=data['batch_size_a'], dim=0)[:, None] + # TODO: fix batch_size!!! + print(mask.shape, sort_indices.shape, seed_mask.shape) + mask[-data.num_graphs * self.num_seed_feature:] = seed_mask[:self.num_seed_feature] + + insert_pos_a = torch.gather(pos_a, dim=0, index=sort_indices[:self.num_seed_feature, :, None].expand(-1, -1, pos_a.shape[-1])) + pos_a[mask] = insert_pos_a[mask[-self.num_seed_feature:]] + + state_a[-data.num_graphs * self.num_seed_feature:] = self.enter_state + + return pos_a, head_a, state_a, head_vector_a, mask + + def _build_temporal_edge(self, data, pos_a, head_a, state_a, head_vector_a, mask, inference_mask=None): + + num_graph = data.num_graphs + num_agent = pos_a.shape[0] + hist_mask = mask.clone() + + if not self.temporal_attn_to_invalid: + is_bos = state_a == self.enter_state + bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0)) + history_invalid_mask = torch.arange(mask.shape[1]).expand(mask.shape[0], mask.shape[1]).to(mask.device) + history_invalid_mask = (history_invalid_mask < bos_index[:, None]) + hist_mask[history_invalid_mask] = False + + if not self.temporal_attn_seed: + hist_mask[-num_graph * self.num_seed_feature:] = False + if inference_mask is not None: + inference_mask[-num_graph * self.num_seed_feature:] = False + else: + # WARNING: if use temporal attn to seed + # we need to fix the pos/head of seed!!! + raise RuntimeError("Wrong settings!") + + pos_t = pos_a.reshape(-1, self.input_dim) # (num_agent * num_step, ...) 
+        head_t = head_a.reshape(-1)
+        head_vector_t = head_vector_a.reshape(-1, 2)
+
+        # invalid agents never predict motion tokens, so we do not attend to them
+        is_bos = state_a == self.enter_state
+        is_bos[-num_graph * self.num_seed_feature:] = False
+        bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
+        motion_predict_start_index = torch.clamp(bos_index - self.time_span / self.shift + 1, min=0)
+        motion_predict_mask = torch.arange(hist_mask.shape[1]).expand(hist_mask.shape[0], -1).to(hist_mask.device)
+        motion_predict_mask = motion_predict_mask >= motion_predict_start_index[:, None]
+        hist_mask[~motion_predict_mask] = False
+
+        if self.hist_mask and self.training:
+            hist_mask[
+                torch.arange(mask.shape[0]).unsqueeze(1), torch.randint(0, mask.shape[1], (num_agent, 10))] = False
+            mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)
+        elif inference_mask is not None:
+            mask_t = hist_mask.unsqueeze(2) & inference_mask.unsqueeze(1)
+        else:
+            mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)
+
+        # mask_t: (num_agent, 18, 18), edge_index_t: (2, num_edge)
+        edge_index_t = dense_to_sparse(mask_t)[0]
+        edge_index_t = edge_index_t[:, (edge_index_t[1] - edge_index_t[0] > 0) &
+                                       (edge_index_t[1] - edge_index_t[0] <= self.time_span / self.shift)]
+        rel_pos_t = pos_t[edge_index_t[0]] - pos_t[edge_index_t[1]]
+        rel_head_t = wrap_angle(head_t[edge_index_t[0]] - head_t[edge_index_t[1]])
+
+        # handle the invalid steps
+        is_invalid = state_a == self.invalid_state
+        is_invalid_t = is_invalid.reshape(-1)
+
+        rel_pos_t[is_invalid_t[edge_index_t[0]] & ~is_invalid_t[edge_index_t[1]]] = -self.motion_gap
+        rel_pos_t[~is_invalid_t[edge_index_t[0]] & is_invalid_t[edge_index_t[1]]] = self.motion_gap
+        rel_head_t[is_invalid_t[edge_index_t[0]] & ~is_invalid_t[edge_index_t[1]]] = -self.heading_gap
+        rel_head_t[~is_invalid_t[edge_index_t[0]] & is_invalid_t[edge_index_t[1]]] = self.heading_gap
+
+        rel_pos_t[is_invalid_t[edge_index_t[0]] & is_invalid_t[edge_index_t[1]]] = self.invalid_motion_value
+        rel_head_t[is_invalid_t[edge_index_t[0]] & is_invalid_t[edge_index_t[1]]] = self.invalid_head_value
+
+        r_t = torch.stack(
+            [torch.norm(rel_pos_t[:, :2], p=2, dim=-1),
+             angle_between_2d_vectors(ctr_vector=head_vector_t[edge_index_t[1]], nbr_vector=rel_pos_t[:, :2]),
+             rel_head_t,
+             edge_index_t[0] - edge_index_t[1]], dim=-1)
+        r_t = self.r_t_emb(continuous_inputs=r_t, categorical_embs=None)
+
+        return edge_index_t, r_t
+
+    def _build_interaction_edge(self, data, pos_a, head_a, state_a, head_vector_a, batch_s, mask, pad_mask=None, inference_mask=None,
+                                av_index=None, seq_mask=None, seq_index=None, grid_index_a=None, **plot_kwargs):
+        num_graph = data.num_graphs
+        num_agent, num_step, _ = pos_a.shape
+        is_training = inference_mask is None
+
+        mask_a = mask.clone()
+
+        if pad_mask is None:
+            pad_mask = torch.ones_like(state_a).bool()
+
+        pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
+        head_s = head_a.transpose(0, 1).reshape(-1)
+        head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)
+        pad_mask_s = pad_mask.transpose(0, 1).reshape(-1)
+        if inference_mask is not None:
+            mask_a = mask_a & inference_mask
+        mask_s = mask_a.transpose(0, 1).reshape(-1)
+
+        # build agent2agent bilateral connection
+        edge_index_a2a = radius_graph(x=pos_s[:, :2], r=self.a2a_radius, batch=batch_s, loop=False,
+                                      max_num_neighbors=300)
+        edge_index_a2a = subgraph(subset=mask_s & pad_mask_s, edge_index=edge_index_a2a)[0]
+
+        if os.getenv('PLOT_EDGE', False):
+            plot_interact_edge(edge_index_a2a, data['scenario_id'], data['batch_size_a'].cpu(), self.num_seed_feature, num_step,
+                               av_index=av_index, **plot_kwargs)
+
+        rel_pos_a2a = pos_s[edge_index_a2a[0]] - pos_s[edge_index_a2a[1]]
+        rel_head_a2a = wrap_angle(head_s[edge_index_a2a[0]] - head_s[edge_index_a2a[1]])
+
+        # handle the invalid steps
+        is_invalid = state_a == self.invalid_state
+        is_invalid_s = is_invalid.transpose(0, 1).reshape(-1)
+
+        rel_pos_a2a[is_invalid_s[edge_index_a2a[0]] & ~is_invalid_s[edge_index_a2a[1]]] = -self.motion_gap
+        rel_pos_a2a[~is_invalid_s[edge_index_a2a[0]] & is_invalid_s[edge_index_a2a[1]]] = self.motion_gap
+        rel_head_a2a[is_invalid_s[edge_index_a2a[0]] & ~is_invalid_s[edge_index_a2a[1]]] = -self.heading_gap
+        rel_head_a2a[~is_invalid_s[edge_index_a2a[0]] & is_invalid_s[edge_index_a2a[1]]] = self.heading_gap
+
+        rel_pos_a2a[is_invalid_s[edge_index_a2a[0]] & is_invalid_s[edge_index_a2a[1]]] = self.invalid_motion_value
+        rel_head_a2a[is_invalid_s[edge_index_a2a[0]] & is_invalid_s[edge_index_a2a[1]]] = self.invalid_head_value
+
+        r_a2a = torch.stack(
+            [torch.norm(rel_pos_a2a[:, :2], p=2, dim=-1),
+             angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_a2a[1]], nbr_vector=rel_pos_a2a[:, :2]),
+             rel_head_a2a,
+             torch.zeros_like(edge_index_a2a[0])], dim=-1)
+        r_a2a = self.r_a2a_emb(continuous_inputs=r_a2a, categorical_embs=None)
+
+        # add the edges which connect seed agents
+        if is_training:
+            mask_av = torch.ones_like(mask_a).bool()
+            if not self.seed_attn_to_av:
+                mask_av[av_index] = False
+            mask_a &= mask_av
+            edge_index_seed2a, r_seed2a = self._build_a2sa_edge(data, pos_a, head_a, head_vector_a, batch_s,
+                                                                mask_a.clone(), ~pad_mask.clone(), inference_mask=inference_mask,
+                                                                r=self.pl2seed_radius, max_num_neighbors=300,
+                                                                seq_mask=seq_mask, seq_index=seq_index, grid_index_a=grid_index_a, mode='grid')
+
+            if os.getenv('PLOT_EDGE', False):
+                plot_interact_edge(edge_index_seed2a, data['scenario_id'], data['batch_size_a'].cpu(), self.num_seed_feature, num_step,
+                                   'interact_edge_map_seed', av_index=av_index, **plot_kwargs)
+
+            edge_index_a2a = torch.cat([edge_index_a2a, edge_index_seed2a], dim=-1)
+            r_a2a = torch.cat([r_a2a, r_seed2a])
+
+            return edge_index_a2a, r_a2a, (edge_index_a2a.shape[1], edge_index_seed2a.shape[1])  #, nearest_dict
+
+        return edge_index_a2a, r_a2a
+
+    def _build_map2agent_edge(self, data, pos_a, head_a, state_a, head_vector_a, batch_s, batch_pl,
+                              mask, pad_mask=None, inference_mask=None, av_index=None, **kwargs):
+        num_graph = data.num_graphs
+        num_agent, num_step, _ = pos_a.shape
+        is_training = inference_mask is None
+
+        mask_pl2a = mask.clone()
+
+        if pad_mask is None:
+            pad_mask = torch.ones_like(state_a).bool()
+
+        pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
+        head_s = head_a.transpose(0, 1).reshape(-1)
+        head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)
+        pad_mask_s = pad_mask.transpose(0, 1).reshape(-1)
+        if inference_mask is not None:
+            mask_pl2a = mask_pl2a & inference_mask
+        mask_pl2a = mask_pl2a.transpose(0, 1).reshape(-1)
+
+        ori_pos_pl = data['pt_token']['position'][:, :self.input_dim].contiguous()
+        ori_orient_pl = data['pt_token']['orientation'].contiguous()
+        pos_pl = ori_pos_pl.repeat(num_step, 1)  # not `repeat_interleave`
+        orient_pl = ori_orient_pl.repeat(num_step)
+
+        # build map2agent directed graph
+        # edge_index_pl2a = radius(x=pos_s[:, :2], y=pos_pl[:, :2], r=self.pl2a_radius,
+        #                          batch_x=batch_s, batch_y=batch_pl, max_num_neighbors=300)
+        edge_index_pl2a = radius(x=pos_pl[:, :2], y=pos_s[:,
:2], r=self.pl2a_radius, + batch_x=batch_pl, batch_y=batch_s, max_num_neighbors=5) + edge_index_pl2a = edge_index_pl2a[[1, 0]] + edge_index_pl2a = edge_index_pl2a[:, mask_pl2a[edge_index_pl2a[1]] & + pad_mask_s[edge_index_pl2a[1]]] + + rel_pos_pl2a = pos_pl[edge_index_pl2a[0]] - pos_s[edge_index_pl2a[1]] + rel_orient_pl2a = wrap_angle(orient_pl[edge_index_pl2a[0]] - head_s[edge_index_pl2a[1]]) + + # handle the invalid steps + is_invalid = state_a == self.invalid_state + is_invalid_s = is_invalid.transpose(0, 1).reshape(-1) + rel_pos_pl2a[is_invalid_s[edge_index_pl2a[1]]] = self.motion_gap + rel_orient_pl2a[is_invalid_s[edge_index_pl2a[1]]] = self.heading_gap + + r_pl2a = torch.stack( + [torch.norm(rel_pos_pl2a[:, :2], p=2, dim=-1), + angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_pl2a[1]], nbr_vector=rel_pos_pl2a[:, :2]), + rel_orient_pl2a], dim=-1) + r_pl2a = self.r_pt2a_emb(continuous_inputs=r_pl2a, categorical_embs=None) + + # add the edges which connect seed agents + if is_training: + edge_index_pl2seed, r_pl2seed = self._build_map2sa_edge(data, pos_a, head_a, head_vector_a, batch_s, batch_pl, + ~pad_mask.clone(), inference_mask=inference_mask, + r=self.pl2seed_radius, max_num_neighbors=2048, mode='grid') + + # sanity check + # pl2a_index = torch.zeros(pos_a.shape[0], num_step) + # pl2a_r = torch.zeros(pos_a.shape[0], num_step) + # for src_index in torch.unique(edge_index_pl2seed[1]): + # src_row = src_index % pos_a.shape[0] + # src_col = src_index // pos_a.shape[0] + # pl2a_index[src_row, src_col] = edge_index_pl2seed[0, edge_index_pl2seed[1] == src_index].sum() + # pl2a_r[src_row, src_col] = r_pl2seed[edge_index_pl2seed[1] == src_index].sum() + # print(pl2a_index) + # print(pl2a_r) + # exit(1) + + if os.getenv('PLOT_EDGE', False): + plot_interact_edge(edge_index_pl2seed, data['scenario_id'], data['batch_size_a'].cpu(), self.num_seed_feature, num_step, + 'interact_edge_map_seed', av_index=av_index) + + edge_index_pl2a = torch.cat([edge_index_pl2a, edge_index_pl2seed], dim=-1) + r_pl2a = torch.cat([r_pl2a, r_pl2seed]) + + return edge_index_pl2a, r_pl2a, (edge_index_pl2a.shape[1], edge_index_pl2seed.shape[1]) + + return edge_index_pl2a, r_pl2a + + def _build_a2sa_edge(self, data, pos_a, head_a, head_vector_a, batch_s, mask_a, mask_sa, + inference_mask=None, r=None, max_num_neighbors=8, seq_mask=None, seq_index=None, + grid_index_a=None, mode: Literal['grid', 'heading']='heading', **plot_kwargs): + + num_agent, num_step, _ = pos_a.shape + is_training = inference_mask is None + + pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim) + head_s = head_a.transpose(0, 1).reshape(-1) + head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2) + if inference_mask is not None: + mask_a = mask_a & inference_mask + mask_sa = mask_sa & inference_mask + mask_s = mask_a.transpose(0, 1).reshape(-1) + mask_s_sa = mask_sa.transpose(0, 1).reshape(-1) + + # build seed_agent2agent unilateral connection + assert r is not None, "r needs to be specified!" 
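+        # NOTE: minimal sketch (not part of the pipeline) of `torch_cluster.radius`
+        # semantics: it searches x for neighbours of each query in y and returns
+        # pairs with row 0 indexing y and row 1 indexing x, hence the `[[1, 0]]`
+        # flip below to put neighbour agents in row 0 and seed queries in row 1.
+        # >>> import torch
+        # >>> from torch_cluster import radius
+        # >>> x = torch.tensor([[0., 0.], [5., 0.]]); y = torch.tensor([[0.1, 0.]])
+        # >>> radius(x, y, r=1.0)                # tensor([[0], [0]]): y[0] <- x[0]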
+ # edge_index_a2sa = radius(x=pos_s[mask_s_sa, :2], y=pos_s[:, :2], r=r,
+ # batch_x=batch_s[mask_s_sa], batch_y=batch_s, max_num_neighbors=max_num_neighbors)
+ edge_index_a2sa = radius(x=pos_s[:, :2], y=pos_s[mask_s_sa, :2], r=r,
+ batch_x=batch_s, batch_y=batch_s[mask_s_sa], max_num_neighbors=max_num_neighbors)
+ edge_index_a2sa = edge_index_a2sa[[1, 0]]
+ edge_index_a2sa = edge_index_a2sa[:, ~mask_s_sa[edge_index_a2sa[0]] & mask_s[edge_index_a2sa[0]]]
+
+ # only for seed agent sequence training
+ if seq_mask is not None:
+ edge_mask = seq_mask[edge_index_a2sa[1]]
+ edge_mask = torch.gather(edge_mask, dim=1, index=edge_index_a2sa[0, :, None] % num_agent)[:, 0]
+ edge_index_a2sa = edge_index_a2sa[:, edge_mask]
+
+ if seq_index is None:
+ seq_index = torch.zeros(num_agent, device=pos_a.device).long()
+ if seq_index.dim() == 1:
+ seq_index = seq_index[:, None].repeat(1, num_step)
+ seq_index = seq_index.transpose(0, 1).reshape(-1)
+ assert seq_index.shape[0] == pos_s.shape[0], f"Inconsistent length {seq_index.shape[0]} and {pos_s.shape[0]}!"
+
+ # convert to global index
+ all_index = torch.arange(pos_s.shape[0], device=pos_a.device).long()
+ sa_index = all_index[mask_s_sa]
+ edge_index_a2sa[1] = sa_index[edge_index_a2sa[1]]
+
+ # plot edge index TODO: currently only supports bs=1
+ if os.getenv('PLOT_EDGE_INFERENCE', False) and not is_training:
+ num_agent, num_step, _ = pos_a.shape
+ # plot_interact_edge(edge_index_a2sa, data['scenario_id'], data['batch_size_a'].cpu(), 1, num_step,
+ # 'interact_a2sa_edge_map', **plot_kwargs)
+ plot_interact_edge(edge_index_a2sa, data['scenario_id'], torch.tensor([num_agent - 1]), 1, num_step,
+ f"interact_a2sa_edge_map_infer_{plot_kwargs['tag']}", **plot_kwargs)
+
+ rel_pos_a2sa = pos_s[edge_index_a2sa[0]] - pos_s[edge_index_a2sa[1]]
+ rel_head_a2sa = wrap_angle(head_s[edge_index_a2sa[0]] - head_s[edge_index_a2sa[1]])
+
+ r_a2sa = torch.stack(
+ [torch.norm(rel_pos_a2sa[:, :2], p=2, dim=-1),
+ angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_a2sa[1]], nbr_vector=rel_pos_a2sa[:, :2]),
+ rel_head_a2sa,
+ seq_index[edge_index_a2sa[0]] - seq_index[edge_index_a2sa[1]]], dim=-1)
+ r_a2sa = self.r_a2sa_emb(continuous_inputs=r_a2sa, categorical_embs=None)
+
+ return edge_index_a2sa, r_a2sa
+
+ def _build_map2sa_edge(self, data, pos_a, head_a, head_vector_a, batch_s, batch_pl,
+ mask_sa, inference_mask=None, r=None, max_num_neighbors=32, mode: Literal['grid', 'heading']='heading'):
+
+ _, num_step, _ = pos_a.shape
+
+ mask_pl2sa = torch.ones_like(mask_sa).bool()
+ if inference_mask is not None:
+ mask_pl2sa = mask_pl2sa & inference_mask
+ mask_pl2sa = mask_pl2sa.transpose(0, 1).reshape(-1)
+ mask_s_sa = mask_sa.transpose(0, 1).reshape(-1)
+
+ pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
+ head_s = head_a.transpose(0, 1).reshape(-1)
+ head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)
+
+ ori_pos_pl = data['pt_token']['position'][:, :self.input_dim].contiguous()
+ ori_orient_pl = data['pt_token']['orientation'].contiguous()
+ pos_pl = ori_pos_pl.repeat(num_step, 1) # not `repeat_interleave`
+ orient_pl = ori_orient_pl.repeat(num_step)
+
+ # build map2sa directed graph
+ assert r is not None, "r needs to be specified!"
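+ # Added illustration (hypothetical tensors): the map tokens above use
+ # `repeat` rather than `repeat_interleave` so the flattening is step-major
+ # and stays consistent with `pos_s`; with num_step = 2:
+ #   torch.tensor([0, 1]).repeat(2)             # -> tensor([0, 1, 0, 1])
+ #   torch.tensor([0, 1]).repeat_interleave(2)  # -> tensor([0, 0, 1, 1])
+ # so map token i at step t sits at flat index t * num_pt + i, matching
+ # `pos_s`, which is flattened via `transpose(0, 1).reshape(-1, ...)`.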
+ # edge_index_pl2sa = radius(x=pos_s[mask_s_sa, :2], y=pos_pl[:, :2], r=r,
+ # batch_x=batch_s[mask_s_sa], batch_y=batch_pl, max_num_neighbors=max_num_neighbors)
+ edge_index_pl2sa = radius(x=pos_pl[:, :2], y=pos_s[mask_s_sa, :2], r=r,
+ batch_x=batch_pl, batch_y=batch_s[mask_s_sa], max_num_neighbors=max_num_neighbors)
+ edge_index_pl2sa = edge_index_pl2sa[[1, 0]]
+ edge_index_pl2sa = edge_index_pl2sa[:, mask_pl2sa[mask_s_sa][edge_index_pl2sa[1]]]
+
+ # convert to global index
+ all_index = torch.arange(pos_s.shape[0], device=pos_a.device).long()
+ sa_index = all_index[mask_s_sa]
+ edge_index_pl2sa[1] = sa_index[edge_index_pl2sa[1]]
+
+ # plot edge map
+ # if os.getenv('PLOT_EDGE', False):
+ # plot_map_edge(edge_index_pl2sa, pos_s[:, :2], data, save_path='map2sa_edge_map')
+
+ rel_pos_pl2sa = pos_pl[edge_index_pl2sa[0]] - pos_s[edge_index_pl2sa[1]]
+ rel_orient_pl2sa = wrap_angle(orient_pl[edge_index_pl2sa[0]] - head_s[edge_index_pl2sa[1]])
+
+ r_pl2sa = torch.stack(
+ [torch.norm(rel_pos_pl2sa[:, :2], p=2, dim=-1),
+ angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_pl2sa[1]], nbr_vector=rel_pos_pl2sa[:, :2]),
+ rel_orient_pl2sa], dim=-1)
+ r_pl2sa = self.r_pt2sa_emb(continuous_inputs=r_pl2sa, categorical_embs=None)
+
+ return edge_index_pl2sa, r_pl2sa
+
+ def _build_sa2sa_edge(self, data, pos_a, head_a, state_a, head_vector_a, batch_s, mask, inference_mask=None, **plot_kwargs):
+
+ num_agent = pos_a.shape[0]
+
+ # agent-major flattening (a * num_step + t) to match the indices produced
+ # by `dense_to_sparse` on the (num_agent, num_step, num_step) mask below
+ pos_t = pos_a.reshape(-1, self.input_dim)
+ head_t = head_a.reshape(-1)
+ head_vector_t = head_vector_a.reshape(-1, 2)
+
+ if inference_mask is not None:
+ mask_t = mask.unsqueeze(2) & inference_mask.unsqueeze(1)
+ else:
+ mask_t = mask.unsqueeze(2) & mask.unsqueeze(1)
+
+ edge_index_sa2sa = dense_to_sparse(mask_t)[0]
+ edge_index_sa2sa = edge_index_sa2sa[:, edge_index_sa2sa[1] - edge_index_sa2sa[0] > 0]
+ rel_pos_t = pos_t[edge_index_sa2sa[0]] - pos_t[edge_index_sa2sa[1]]
+ rel_head_t = wrap_angle(head_t[edge_index_sa2sa[0]] - head_t[edge_index_sa2sa[1]])
+
+ r_t = torch.stack(
+ [torch.norm(rel_pos_t[:, :2], p=2, dim=-1),
+ angle_between_2d_vectors(ctr_vector=head_vector_t[edge_index_sa2sa[1]], nbr_vector=rel_pos_t[:, :2]),
+ rel_head_t,
+ edge_index_sa2sa[0] - edge_index_sa2sa[1]], dim=-1)
+ r_sa2sa = self.r_sa2sa_emb(continuous_inputs=r_t, categorical_embs=None)
+
+ return edge_index_sa2sa, r_sa2sa
+
+ def _build_seq(self, device, num_agent, num_step, av_index, sort_indices):
+ """
+ Args:
+ sort_indices (torch.Tensor): shape (num_agent, num_step)
+ """
+ # sort_indices = sort_indices[:self.num_seed_feature]
+ seq_mask = torch.ones(self.num_seed_feature, num_step, num_agent + self.num_seed_feature, device=device).bool()
+ seq_mask[..., -self.num_seed_feature:] = False
+ for t in range(num_step):
+ for s in range(self.num_seed_feature):
+ seq_mask[s, t, sort_indices[s:, t].flatten().long()] = False
+ if self.seed_attn_to_av:
+ seq_mask[..., av_index] = True
+ seq_mask = seq_mask.transpose(0, 1).reshape(-1, num_agent + self.num_seed_feature)
+
+ seq_index = torch.cat([torch.zeros(num_agent), torch.arange(self.num_seed_feature) + 1]).to(device)
+ seq_index = seq_index[:, None].repeat(1, num_step)
+ for t in range(num_step):
+ for s in range(self.num_seed_feature):
+ seq_index[sort_indices[s : s + 1, t].flatten().long(), t] = s + 1
+ seq_index[av_index] = 0
+
+ return seq_mask, seq_index
+
+ def _build_occ_gt(self, data, seq_mask, pos_rel_index_gt, pos_rel_index_gt_seed, mask_seed,
+ edge_index=None, mode='edge_index'):
+ """
+ Args:
+ seq_mask 
(torch.Tensor): shape (num_step * num_seed_feature, num_agent + self.num_seed_feature) + pos_rel_index_gt (torch.Tensor): shape (num_agent, num_step) + pos_rel_index_gt_seed (torch.Tensor): shape (num_seed, num_step) + """ + num_agent = data['agent']['state_idx'].shape[0] + self.num_seed_feature + num_step = data['agent']['state_idx'].shape[1] + data['agent']['agent_occ'] = torch.zeros(data.num_graphs * self.num_seed_feature, num_step, self.attr_tokenizer.grid_size, + device=data['agent']['state_idx'].device).long() + data['agent']['map_occ'] = torch.zeros(data.num_graphs, num_step, self.attr_tokenizer.grid_size, + device=data['agent']['state_idx'].device).long() + + if mode == 'edge_index': + + assert edge_index is not None, f"Need edge_index input!" + for src_index in torch.unique(edge_index[1]): + # decode src + src_row = src_index % num_agent - (num_agent - self.num_seed_feature) + src_col = src_index // num_agent + # decode tgt + tgt_indexes = edge_index[0, edge_index[1] == src_index] + tgt_rows = tgt_indexes % num_agent + tgt_cols = tgt_indexes // num_agent + assert tgt_rows.max() < num_agent - self.num_seed_feature, f"Invalid {tgt_rows}" + assert torch.unique(tgt_cols).shape[0] == 1 and torch.unique(tgt_cols)[0] == src_col + data['agent']['agent_occ'][src_row, src_col, pos_rel_index_gt[tgt_rows, tgt_cols]] = 1 + + else: + + seq_mask = seq_mask.reshape(num_step, self.num_seed_feature, -1).transpose(0, 1)[..., :-self.num_seed_feature] + for s in range(self.num_seed_feature): + for t in range(num_step): + index = pos_rel_index_gt[seq_mask[s, t], t] + data['agent']['agent_occ'][s, t, index[index != -1]] = 1 + if t > 0 and s < pos_rel_index_gt_seed.shape[0] and mask_seed[s, t - 1]: # insert agents + data['agent']['agent_occ'][s, t, pos_rel_index_gt_seed[s, t - 1]] = -1 + + # TODO: fix batch_size!!! + pt_grid_token_idx = data['agent']['pt_grid_token_idx'] # (t, num_pt) + for t in range(num_step): + data['agent']['map_occ'][:, t, pt_grid_token_idx[t][pt_grid_token_idx[t] != -1]] = 1 + data['agent']['map_occ'] = data['agent']['map_occ'].repeat_interleave(repeats=self.num_seed_feature, dim=0) + + def forward(self, + data: HeteroData, + map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + + pos_a = data['agent']['token_pos'].clone() # (a, t, 2) + head_a = data['agent']['token_heading'].clone() # (a, t) + num_agent, num_step, traj_dim = pos_a.shape # e.g. 
(50, 18, 2)
+ num_pt = data['pt_token']['position'].shape[0]
+ agent_category = data['agent']['category'].clone() # (a,)
+ agent_shape = data['agent']['shape'][:, self.num_historical_steps - 1].clone() # (a, 3)
+ agent_token_index = data['agent']['token_idx'].clone() # (a, t)
+ agent_state_index = data['agent']['state_idx'].clone()
+ agent_type_index = data['agent']['type'].clone()
+
+ av_index = data['agent']['av_index'].long()
+ ego_pos = pos_a[av_index]
+ ego_head = head_a[av_index]
+
+ _, head_vector_a = self._build_vector_a(pos_a, head_a, agent_state_index)
+
+ agent_grid_token_idx = data['agent']['grid_token_idx']
+ agent_grid_offset_xy = data['agent']['grid_offset_xy']
+ agent_head_token_idx = data['agent']['heading_token_idx']
+ sort_indices = data['agent']['sort_indices']
+ pt_grid_token_idx = data['agent']['pt_grid_token_idx']
+
+ ori_pos_pl = data['pt_token']['position'][:, :self.input_dim].contiguous()
+ ori_orient_pl = data['pt_token']['orientation'].contiguous()
+ pos_pl = ori_pos_pl.repeat(num_step, 1)
+ orient_pl = ori_orient_pl.repeat(num_step)
+
+ # build relative 3d descriptors
+ pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
+ head_s = head_a.transpose(0, 1).reshape(-1)
+
+ ego_pos_a = ego_pos.repeat_interleave(repeats=data['batch_size_a'], dim=0)
+ ego_head_a = ego_head.repeat_interleave(repeats=data['batch_size_a'], dim=0)
+ ego_pos_s = ego_pos_a.transpose(0, 1).reshape(-1, self.input_dim)
+ ego_head_s = ego_head_a.transpose(0, 1).reshape(-1)
+ rel_pos_a2a = pos_s - ego_pos_s
+ rel_head_a2a = head_s - ego_head_s
+
+ ego_pos_pl = ego_pos.repeat_interleave(repeats=data['batch_size_pl'], dim=0)
+ ego_head_pl = ego_head.repeat_interleave(repeats=data['batch_size_pl'], dim=0)
+ ego_pos_s = ego_pos_pl.transpose(0, 1).reshape(-1, self.input_dim)
+ ego_head_s = ego_head_pl.transpose(0, 1).reshape(-1)
+ rel_pos_pl2a = pos_pl - ego_pos_s
+ rel_head_pl2a = orient_pl - ego_head_s
+
+ # relative encodings
+ ego_head_vector_a = head_vector_a[av_index].repeat_interleave(repeats=data['batch_size_a'], dim=0)
+ ego_head_vector_s = ego_head_vector_a.transpose(0, 1).reshape(-1, 2)
+ r_a2a = torch.stack(
+ [torch.norm(rel_pos_a2a[:, :2], p=2, dim=-1),
+ angle_between_2d_vectors(ctr_vector=ego_head_vector_s, nbr_vector=rel_pos_a2a[:, :2]),
+ rel_head_a2a], dim=-1)
+ r_a2a = self.r_a2a_emb(continuous_inputs=r_a2a, categorical_embs=None) # [N, hidden_dim]
+
+ ego_head_vector_a = head_vector_a[av_index].repeat_interleave(repeats=data['batch_size_pl'], dim=0)
+ ego_head_vector_s = ego_head_vector_a.transpose(0, 1).reshape(-1, 2)
+ r_pl2a = torch.stack(
+ [torch.norm(rel_pos_pl2a[:, :2], p=2, dim=-1),
+ angle_between_2d_vectors(ctr_vector=ego_head_vector_s, nbr_vector=rel_pos_pl2a[:, :2]),
+ rel_head_pl2a], dim=-1)
+ r_pl2a = self.r_pt2a_emb(continuous_inputs=r_pl2a, categorical_embs=None) # [M, d]
+
+ r_a2a = r_a2a.reshape(num_step, num_agent, -1).transpose(0, 1)
+ r_pl2a = r_pl2a.reshape(num_step, num_pt, -1).transpose(0, 1)
+ select_agent = torch.randperm(num_agent)[:self.agent_limit]
+ select_pt = torch.randperm(num_pt)[:self.pt_limit]
+ r_a2a = r_a2a[select_agent]
+ r_pl2a = r_pl2a[select_pt]
+
+ # aggregate to global feature
+ r_a2a = r_a2a.mean(dim=0) # [t, d]
+ r_pl2a = r_pl2a.mean(dim=0)
+
+ # decode grid index of neighbor agents
+ agent_occ = self.grid_agent_occ_head(r_a2a) # [t, grid_size]
+ pt_occ = self.grid_pt_occ_head(r_pl2a)
+
+ # 1. 
+ # agent_occ_gt = torch.zeros_like(agent_occ).long() + # pt_occ_gt = torch.zeros_like(pt_occ).long() + + # for t in range(num_step): + # agent_occ_gt[t, agent_grid_token_idx[:, t][agent_grid_token_idx[:, t] != -1]] = 1 + # pt_occ_gt[t, pt_grid_token_idx[t][pt_grid_token_idx[t] != -1]] = 1 + + # agent_occ_gt[:, self.grid_size // 2] = 0 + # pt_occ_gt[:, self.grid_size // 2] = 0 + + # agent_occ_eval_mask = torch.ones_like(agent_occ_gt) + # agent_occ_eval_mask[0] = 0 + # agent_occ_eval_mask[:, self.grid_size // 2] = 0 + # pt_occ_eval_mask = torch.ones_like(pt_occ_gt) + # pt_occ_eval_mask[0] = 0 + # pt_occ_eval_mask[:, self.grid_size // 2] = 0 + + # 2. + # agent_occ_gt = agent_grid_token_idx.transpose(0, 1).reshape(-1) + # pt_occ_gt = pt_grid_token_idx.reshape(-1) + + # agent_occ_eval_mask = torch.zeros_like(agent_occ_gt) + # agent_occ_eval_mask[torch.randperm(agent_occ_gt.shape[0])[:(num_step * 10)]] = 1 + # agent_occ_eval_mask[agent_occ_gt == -1] = 0 + + # pt_occ_eval_mask = torch.zeros_like(pt_occ_gt) + # pt_occ_eval_mask[torch.randperm(pt_occ_gt.shape[0])[:(num_step * 300)]] = 1 + # pt_occ_eval_mask[pt_occ_gt == -1] = 0 + + # 3. + agent_occ = agent_occ.reshape(num_step, self.agent_limit, -1) + pt_occ = pt_occ.reshape(num_step, self.pt_limit, -1) + agent_occ_gt = agent_grid_token_idx[select_agent].transpose(0, 1) + pt_occ_gt = pt_grid_token_idx[:, select_pt] + agent_occ_eval_mask = agent_occ_gt != -1 + pt_occ_eval_mask = pt_occ_gt != -1 + + agent_occ = agent_occ[:, :agent_occ_gt.shape[1]] + pt_occ = pt_occ[:, :pt_occ_gt.shape[1]] + + return {'occ_decoder': True, + 'num_step': num_step, + 'num_agent': self.agent_limit, # num_agent + 'num_pt': self.pt_limit, # num_pt + 'agent_occ': agent_occ, + 'agent_occ_gt': agent_occ_gt, + 'agent_occ_eval_mask': agent_occ_eval_mask.bool(), + 'pt_occ': pt_occ, + 'pt_occ_gt': pt_occ_gt, + 'pt_occ_eval_mask': pt_occ_eval_mask.bool(), + } + + def inference(self, *args, **kwargs): + return self(*args, **kwargs) + diff --git a/backups/dev/modules/smart_decoder.py b/backups/dev/modules/smart_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..34cbf79af76ab2b284a3cae719d1ec2e92b2bf55 --- /dev/null +++ b/backups/dev/modules/smart_decoder.py @@ -0,0 +1,137 @@ +from typing import Dict, Optional +import torch +import torch.nn as nn +from torch_geometric.data import HeteroData +from dev.modules.attr_tokenizer import Attr_Tokenizer +from dev.modules.agent_decoder import SMARTAgentDecoder +from dev.modules.occ_decoder import SMARTOccDecoder +from dev.modules.map_decoder import SMARTMapDecoder + + +DECODER = {'agent_decoder': SMARTAgentDecoder, + 'occ_decoder': SMARTOccDecoder} + + +class SMARTDecoder(nn.Module): + + def __init__(self, + decoder_type: str, + dataset: str, + input_dim: int, + hidden_dim: int, + num_historical_steps: int, + pl2pl_radius: float, + time_span: Optional[int], + pl2a_radius: float, + pl2seed_radius: float, + a2a_radius: float, + a2sa_radius: float, + pl2sa_radius: float, + num_freq_bands: int, + num_map_layers: int, + num_agent_layers: int, + num_heads: int, + head_dim: int, + dropout: float, + map_token: Dict, + token_size=512, + attr_tokenizer: Attr_Tokenizer=None, + predict_motion: bool=False, + predict_state: bool=False, + predict_map: bool=False, + predict_occ: bool=False, + use_grid_token: bool=False, + state_token: Dict[str, int]=None, + seed_size: int=5, + buffer_size: int=32, + num_recurrent_steps_val: int=-1, + loss_weight: dict=None, + logger=None) -> None: + + super(SMARTDecoder, self).__init__() + + 
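+ # Descriptive note (added): SMARTDecoder composes two sub-modules, a map
+ # decoder that encodes the tokenized road graph and an agent (or occupancy)
+ # decoder, selected via `decoder_type`, that consumes the resulting map
+ # features (see `forward` / `inference` below).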
self.map_encoder = SMARTMapDecoder(
+ dataset=dataset,
+ input_dim=input_dim,
+ hidden_dim=hidden_dim,
+ num_historical_steps=num_historical_steps,
+ pl2pl_radius=pl2pl_radius,
+ num_freq_bands=num_freq_bands,
+ num_layers=num_map_layers,
+ num_heads=num_heads,
+ head_dim=head_dim,
+ dropout=dropout,
+ map_token=map_token,
+ )
+
+ assert decoder_type in list(DECODER.keys()), f"Unsupported decoder type: {decoder_type}"
+ self.agent_encoder = DECODER[decoder_type](
+ dataset=dataset,
+ input_dim=input_dim,
+ hidden_dim=hidden_dim,
+ num_historical_steps=num_historical_steps,
+ time_span=time_span,
+ pl2a_radius=pl2a_radius,
+ pl2seed_radius=pl2seed_radius,
+ a2a_radius=a2a_radius,
+ a2sa_radius=a2sa_radius,
+ pl2sa_radius=pl2sa_radius,
+ num_freq_bands=num_freq_bands,
+ num_layers=num_agent_layers,
+ num_heads=num_heads,
+ head_dim=head_dim,
+ dropout=dropout,
+ token_size=token_size,
+ attr_tokenizer=attr_tokenizer,
+ predict_motion=predict_motion,
+ predict_state=predict_state,
+ predict_map=predict_map,
+ predict_occ=predict_occ,
+ state_token=state_token,
+ use_grid_token=use_grid_token,
+ seed_size=seed_size,
+ buffer_size=buffer_size,
+ num_recurrent_steps_val=num_recurrent_steps_val,
+ loss_weight=loss_weight,
+ logger=logger,
+ )
+ self.map_enc = None
+ self.predict_motion = predict_motion
+ self.predict_state = predict_state
+ self.predict_map = predict_map
+ self.predict_occ = predict_occ
+ self.data_keys = ["agent_valid_mask", "category", "valid_mask", "av_index", "scenario_id", "shape"]
+
+ def get_agent_inputs(self, data: HeteroData) -> Dict[str, torch.Tensor]:
+ return self.agent_encoder.get_inputs(data)
+
+ def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]:
+ map_enc = self.map_encoder(data)
+
+ agent_enc = {}
+ if self.predict_motion or self.predict_state or self.predict_occ:
+ agent_enc = self.agent_encoder(data, map_enc)
+
+ return {**map_enc, **agent_enc, **{k: data[k] for k in self.data_keys}}
+
+ def inference(self, data: HeteroData) -> Dict[str, torch.Tensor]:
+ map_enc = self.map_encoder(data)
+
+ agent_enc = {}
+ if self.predict_motion or self.predict_state or self.predict_occ:
+ agent_enc = self.agent_encoder.inference(data, map_enc)
+
+ return {**map_enc, **agent_enc, **{k: data[k] for k in self.data_keys}}
+
+ def inference_no_map(self, data: HeteroData, map_enc) -> Dict[str, torch.Tensor]:
+ agent_enc = self.agent_encoder.inference(data, map_enc)
+ return {**map_enc, **agent_enc}
+
+ def insert_agent(self, data: HeteroData) -> Dict[str, torch.Tensor]:
+ map_enc = self.map_encoder(data)
+ agent_enc = self.agent_encoder.insert(data, map_enc)
+ return {**map_enc, **agent_enc, **{k: data[k] for k in self.data_keys}}
+
+ def predict_nearest_pos(self, data: HeteroData, rank) -> Dict[str, torch.Tensor]:
+ map_enc = self.map_encoder(data)
+ self.agent_encoder.predict_nearest_pos(data, map_enc, rank)
diff --git a/backups/dev/utils/cluster_reader.py b/backups/dev/utils/cluster_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0ef689d750196227038d28ce0dda99dcee9fab9
--- /dev/null
+++ b/backups/dev/utils/cluster_reader.py
@@ -0,0 +1,45 @@
+import io
+import pickle
+import pandas as pd
+import json
+
+
+class LoadScenarioFromCeph:
+ def __init__(self):
+ from petrel_client.client import Client
+ self.file_client = Client('~/petreloss.conf')
+
+ def list(self, dir_path):
+ return list(self.file_client.list(dir_path))
+
+ def save(self, data, url):
+ self.file_client.put(url, pickle.dumps(data))
+
+ def read_correct_csv(self, 
scenario_path):
+ output = pd.read_csv(io.StringIO(self.file_client.get(scenario_path).decode('utf-8')), engine="python")
+ return output
+
+ def contains(self, url):
+ return self.file_client.contains(url)
+
+ def read_string(self, csv_url):
+ from io import StringIO
+ df = pd.read_csv(StringIO(str(self.file_client.get(csv_url), 'utf-8')), sep=r'\s+', low_memory=False)
+ return df
+
+ def read(self, scenario_path):
+ with io.BytesIO(self.file_client.get(scenario_path)) as f:
+ datas = pickle.load(f)
+ return datas
+
+ def read_json(self, path):
+ with io.BytesIO(self.file_client.get(path)) as f:
+ data = json.load(f)
+ return data
+
+ def read_csv(self, scenario_path):
+ return pickle.loads(self.file_client.get(scenario_path))
+
+ def read_model(self, model_path):
+ with io.BytesIO(self.file_client.get(model_path)) as f:
+ pass
diff --git a/backups/dev/utils/func.py b/backups/dev/utils/func.py
new file mode 100644
index 0000000000000000000000000000000000000000..0161701911b22283914e50700b16fed5e8bd975d
--- /dev/null
+++ b/backups/dev/utils/func.py
@@ -0,0 +1,260 @@
+import logging
+import time
+import os
+import yaml
+import easydict
+import math
+import torch
+import torch.nn as nn
+from rich.console import Console
+from typing import Any, List, Optional, Mapping
+
+from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
+
+
+CONSOLE = Console(width=128)
+
+
+def check_nan_inf(t, s):
+ assert not torch.isinf(t).any(), f"{s} is inf, {t}"
+ assert not torch.isnan(t).any(), f"{s} is nan, {t}"
+
+
+def safe_list_index(ls: List[Any], elem: Any) -> Optional[int]:
+ try:
+ return ls.index(elem)
+ except ValueError:
+ return None
+
+
+def angle_between_2d_vectors(
+ ctr_vector: torch.Tensor,
+ nbr_vector: torch.Tensor) -> torch.Tensor:
+ return torch.atan2(ctr_vector[..., 0] * nbr_vector[..., 1] - ctr_vector[..., 1] * nbr_vector[..., 0],
+ (ctr_vector[..., :2] * nbr_vector[..., :2]).sum(dim=-1))
+
+
+def angle_between_3d_vectors(
+ ctr_vector: torch.Tensor,
+ nbr_vector: torch.Tensor) -> torch.Tensor:
+ return torch.atan2(torch.cross(ctr_vector, nbr_vector, dim=-1).norm(p=2, dim=-1),
+ (ctr_vector * nbr_vector).sum(dim=-1))
+
+
+def side_to_directed_lineseg(
+ query_point: torch.Tensor,
+ start_point: torch.Tensor,
+ end_point: torch.Tensor) -> str:
+ cond = ((end_point[0] - start_point[0]) * (query_point[1] - start_point[1]) -
+ (end_point[1] - start_point[1]) * (query_point[0] - start_point[0]))
+ if cond > 0:
+ return 'LEFT'
+ elif cond < 0:
+ return 'RIGHT'
+ else:
+ return 'CENTER'
+
+
+def wrap_angle(
+ angle: torch.Tensor,
+ min_val: float = -math.pi,
+ max_val: float = math.pi) -> torch.Tensor:
+ return min_val + (angle + max_val) % (max_val - min_val)
+
+
+def load_config_act(path):
+ """Load a yaml config file as an EasyDict."""
+ with open(path, 'r') as f:
+ cfg = yaml.load(f, Loader=yaml.FullLoader)
+ return easydict.EasyDict(cfg)
+
+
+def load_config_init(path):
+ """Load a yaml config file from `init/configs`."""
+ path = os.path.join('init/configs', f'{path}.yaml')
+ with open(path, 'r') as f:
+ cfg = yaml.load(f, Loader=yaml.FullLoader)
+ return cfg
+
+
+class Logging:
+
+ def make_log_dir(self, dirname='logs'):
+ now_dir = os.path.dirname(__file__)
+ path = os.path.join(now_dir, dirname)
+ path = os.path.normpath(path)
+ if not os.path.exists(path):
+ os.mkdir(path)
+ return path
+
+ def get_log_filename(self):
+ filename = "{}.log".format(time.strftime("%Y-%m-%d-%H%M%S", time.localtime()))
+ filename = os.path.join(self.make_log_dir(), filename)
+ filename = os.path.normpath(filename)
+ return 
filename + + def log(self, level='DEBUG', name="simagent"): + logger = logging.getLogger(name) + level = getattr(logging, level) + logger.setLevel(level) + if not logger.handlers: + sh = logging.StreamHandler() + fh = logging.FileHandler(filename=self.get_log_filename(), mode='a',encoding="utf-8") + fmt = logging.Formatter("%(asctime)s-%(levelname)s-%(filename)s-Line:%(lineno)d-Message:%(message)s") + sh.setFormatter(fmt=fmt) + fh.setFormatter(fmt=fmt) + logger.addHandler(sh) + logger.addHandler(fh) + return logger + + def add_log(self, logger, level='DEBUG'): + level = getattr(logging, level) + logger.setLevel(level) + if not logger.handlers: + sh = logging.StreamHandler() + fh = logging.FileHandler(filename=self.get_log_filename(), mode='a',encoding="utf-8") + fmt = logging.Formatter("%(asctime)s-%(levelname)s-%(filename)s-Line:%(lineno)d-Message:%(message)s") + sh.setFormatter(fmt=fmt) + fh.setFormatter(fmt=fmt) + logger.addHandler(sh) + logger.addHandler(fh) + return logger + + +# Adapted from 'CatK' +class RankedLogger(logging.LoggerAdapter): + """A multi-GPU-friendly python command line logger.""" + + def __init__( + self, + name: str = __name__, + rank_zero_only: bool = False, + extra: Optional[Mapping[str, object]] = None, + ) -> None: + """Initializes a multi-GPU-friendly python command line logger that logs on all processes + with their rank prefixed in the log message. + + :param name: The name of the logger. Default is ``__name__``. + :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. + :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. + """ + logger = logging.getLogger(name) + super().__init__(logger=logger, extra=extra) + self.rank_zero_only = rank_zero_only + + def log( + self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs + ) -> None: + """Delegate a log call to the underlying logger, after prefixing its message with the rank + of the process it's being logged from. If `'rank'` is provided, then the log will only + occur on that rank/process. + + :param level: The level to log at. Look at `logging.__init__.py` for more information. + :param msg: The message to log. + :param rank: The rank to log at. + :param args: Additional args to pass to the underlying logging function. + :param kwargs: Any additional keyword args to pass to the underlying logging function. 
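+
+ Example (illustrative usage, assuming the standard `logging` module is configured)::
+
+ log = RankedLogger(__name__, rank_zero_only=True)
+ log.log(logging.INFO, "validation started")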
+ """ + if self.isEnabledFor(level): + msg, kwargs = self.process(msg, kwargs) + current_rank = getattr(rank_zero_only, "rank", None) + if current_rank is None: + raise RuntimeError( + "The `rank_zero_only.rank` needs to be set before use" + ) + msg = rank_prefixed_message(msg, current_rank) + if self.rank_zero_only: + if current_rank == 0: + self.logger.log(level, msg, *args, **kwargs) + else: + if rank is None: + self.logger.log(level, msg, *args, **kwargs) + elif current_rank == rank: + self.logger.log(level, msg, *args, **kwargs) + + + +def weight_init(m: nn.Module) -> None: + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + fan_in = m.in_channels / m.groups + fan_out = m.out_channels / m.groups + bound = (6.0 / (fan_in + fan_out)) ** 0.5 + nn.init.uniform_(m.weight, -bound, bound) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Embedding): + nn.init.normal_(m.weight, mean=0.0, std=0.02) + elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.MultiheadAttention): + if m.in_proj_weight is not None: + fan_in = m.embed_dim + fan_out = m.embed_dim + bound = (6.0 / (fan_in + fan_out)) ** 0.5 + nn.init.uniform_(m.in_proj_weight, -bound, bound) + else: + nn.init.xavier_uniform_(m.q_proj_weight) + nn.init.xavier_uniform_(m.k_proj_weight) + nn.init.xavier_uniform_(m.v_proj_weight) + if m.in_proj_bias is not None: + nn.init.zeros_(m.in_proj_bias) + nn.init.xavier_uniform_(m.out_proj.weight) + if m.out_proj.bias is not None: + nn.init.zeros_(m.out_proj.bias) + if m.bias_k is not None: + nn.init.normal_(m.bias_k, mean=0.0, std=0.02) + if m.bias_v is not None: + nn.init.normal_(m.bias_v, mean=0.0, std=0.02) + elif isinstance(m, (nn.LSTM, nn.LSTMCell)): + for name, param in m.named_parameters(): + if 'weight_ih' in name: + for ih in param.chunk(4, 0): + nn.init.xavier_uniform_(ih) + elif 'weight_hh' in name: + for hh in param.chunk(4, 0): + nn.init.orthogonal_(hh) + elif 'weight_hr' in name: + nn.init.xavier_uniform_(param) + elif 'bias_ih' in name: + nn.init.zeros_(param) + elif 'bias_hh' in name: + nn.init.zeros_(param) + nn.init.ones_(param.chunk(4, 0)[1]) + elif isinstance(m, (nn.GRU, nn.GRUCell)): + for name, param in m.named_parameters(): + if 'weight_ih' in name: + for ih in param.chunk(3, 0): + nn.init.xavier_uniform_(ih) + elif 'weight_hh' in name: + for hh in param.chunk(3, 0): + nn.init.orthogonal_(hh) + elif 'bias_ih' in name: + nn.init.zeros_(param) + elif 'bias_hh' in name: + nn.init.zeros_(param) + + +def pos2posemb(pos, num_pos_feats=128, temperature=10000): + + scale = 2 * math.pi + pos = pos * scale + dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device) + dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats) + + D = pos.shape[-1] + pos_dims = [] + for i in range(D): + pos_dim_i = pos[..., i, None] / dim_t + pos_dim_i = torch.stack((pos_dim_i[..., 0::2].sin(), pos_dim_i[..., 1::2].cos()), dim=-1).flatten(-2) + pos_dims.append(pos_dim_i) + posemb = torch.cat(pos_dims, dim=-1) + + return posemb \ No newline at end of file diff --git a/backups/dev/utils/graph.py b/backups/dev/utils/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..3ae3c1854f506e7037f39fe1dc6f39c336456147 --- /dev/null +++ 
b/backups/dev/utils/graph.py @@ -0,0 +1,89 @@ + +import torch +from typing import List, Optional, Tuple, Union +from torch_geometric.utils import coalesce +from torch_geometric.utils import degree + + +def add_edges( + from_edge_index: torch.Tensor, + to_edge_index: torch.Tensor, + from_edge_attr: Optional[torch.Tensor] = None, + to_edge_attr: Optional[torch.Tensor] = None, + replace: bool = True) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + from_edge_index = from_edge_index.to(device=to_edge_index.device, dtype=to_edge_index.dtype) + mask = ((to_edge_index[0].unsqueeze(-1) == from_edge_index[0].unsqueeze(0)) & + (to_edge_index[1].unsqueeze(-1) == from_edge_index[1].unsqueeze(0))) + if replace: + to_mask = mask.any(dim=1) + if from_edge_attr is not None and to_edge_attr is not None: + from_edge_attr = from_edge_attr.to(device=to_edge_attr.device, dtype=to_edge_attr.dtype) + to_edge_attr = torch.cat([to_edge_attr[~to_mask], from_edge_attr], dim=0) + to_edge_index = torch.cat([to_edge_index[:, ~to_mask], from_edge_index], dim=1) + else: + from_mask = mask.any(dim=0) + if from_edge_attr is not None and to_edge_attr is not None: + from_edge_attr = from_edge_attr.to(device=to_edge_attr.device, dtype=to_edge_attr.dtype) + to_edge_attr = torch.cat([to_edge_attr, from_edge_attr[~from_mask]], dim=0) + to_edge_index = torch.cat([to_edge_index, from_edge_index[:, ~from_mask]], dim=1) + return to_edge_index, to_edge_attr + + +def merge_edges( + edge_indices: List[torch.Tensor], + edge_attrs: Optional[List[torch.Tensor]] = None, + reduce: str = 'add') -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + edge_index = torch.cat(edge_indices, dim=1) + if edge_attrs is not None: + edge_attr = torch.cat(edge_attrs, dim=0) + else: + edge_attr = None + return coalesce(edge_index=edge_index, edge_attr=edge_attr, reduce=reduce) + + +def complete_graph( + num_nodes: Union[int, Tuple[int, int]], + ptr: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + loop: bool = False, + device: Optional[Union[torch.device, str]] = None) -> torch.Tensor: + if ptr is None: + if isinstance(num_nodes, int): + num_src, num_dst = num_nodes, num_nodes + else: + num_src, num_dst = num_nodes + edge_index = torch.cartesian_prod(torch.arange(num_src, dtype=torch.long, device=device), + torch.arange(num_dst, dtype=torch.long, device=device)).t() + else: + if isinstance(ptr, torch.Tensor): + ptr_src, ptr_dst = ptr, ptr + num_src_batch = num_dst_batch = ptr[1:] - ptr[:-1] + else: + ptr_src, ptr_dst = ptr + num_src_batch = ptr_src[1:] - ptr_src[:-1] + num_dst_batch = ptr_dst[1:] - ptr_dst[:-1] + edge_index = torch.cat( + [torch.cartesian_prod(torch.arange(num_src, dtype=torch.long, device=device), + torch.arange(num_dst, dtype=torch.long, device=device)) + p + for num_src, num_dst, p in zip(num_src_batch, num_dst_batch, torch.stack([ptr_src, ptr_dst], dim=1))], + dim=0) + edge_index = edge_index.t() + if isinstance(num_nodes, int) and not loop: + edge_index = edge_index[:, edge_index[0] != edge_index[1]] + return edge_index.contiguous() + + +def bipartite_dense_to_sparse(adj: torch.Tensor) -> torch.Tensor: + index = adj.nonzero(as_tuple=True) + if len(index) == 3: + batch_src = index[0] * adj.size(1) + batch_dst = index[0] * adj.size(2) + index = (batch_src + index[1], batch_dst + index[2]) + return torch.stack(index, dim=0) + + +def unbatch( + src: torch.Tensor, + batch: torch.Tensor, + dim: int = 0) -> List[torch.Tensor]: + sizes = degree(batch, dtype=torch.long).tolist() + return src.split(sizes, dim) diff --git 
a/backups/dev/utils/metrics.py b/backups/dev/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..461f70cb1b04e431117d677382f55e3f48e35a10 --- /dev/null +++ b/backups/dev/utils/metrics.py @@ -0,0 +1,692 @@ +import torch +import os +import itertools +import multiprocessing as mp +import torch.nn.functional as F +from pathlib import Path +from torch.nn import CrossEntropyLoss +from torch_scatter import gather_csr +from torch_scatter import segment_csr +from torchmetrics import Metric +from typing import Optional, Tuple, Dict, List + + +__all__ = ['minADE', 'minFDE', 'TokenCls', 'StateAccuracy', 'GridOverlapRate'] + + +class CustomCrossEntropyLoss(CrossEntropyLoss): + + def __init__(self, label_smoothing=0.0, reduction='mean'): + super(CustomCrossEntropyLoss, self).__init__() + self.label_smoothing = label_smoothing + self.reduction = reduction + + def forward(self, input, target): + num_classes = input.size(1) + + log_probs = F.log_softmax(input, dim=1) + + with torch.no_grad(): + smooth_target = torch.zeros_like(log_probs).scatter_(1, target.unsqueeze(1), 1) + smooth_target = smooth_target * (1 - self.label_smoothing) + self.label_smoothing / num_classes + + loss = -torch.sum(log_probs * smooth_target, dim=1) + + if self.reduction == 'mean': + return loss.mean() + elif self.reduction == 'sum': + return loss.sum() + else: + return loss + + +def topk( + max_guesses: int, + pred: torch.Tensor, + prob: Optional[torch.Tensor] = None, + ptr: Optional[torch.Tensor] = None, + joint: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + max_guesses = min(max_guesses, pred.size(1)) + if max_guesses == pred.size(1): + if prob is not None: + prob = prob / prob.sum(dim=-1, keepdim=True) + else: + prob = pred.new_ones((pred.size(0), max_guesses)) / max_guesses + return pred, prob + else: + if prob is not None: + if joint: + if ptr is None: + inds_topk = torch.topk((prob / prob.sum(dim=-1, keepdim=True)).mean(dim=0, keepdim=True), + k=max_guesses, dim=-1, largest=True, sorted=True)[1] + inds_topk = inds_topk.repeat(pred.size(0), 1) + else: + inds_topk = torch.topk(segment_csr(src=prob / prob.sum(dim=-1, keepdim=True), indptr=ptr, + reduce='mean'), + k=max_guesses, dim=-1, largest=True, sorted=True)[1] + inds_topk = gather_csr(src=inds_topk, indptr=ptr) + else: + inds_topk = torch.topk(prob, k=max_guesses, dim=-1, largest=True, sorted=True)[1] + pred_topk = pred[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk] + prob_topk = prob[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk] + prob_topk = prob_topk / prob_topk.sum(dim=-1, keepdim=True) + else: + pred_topk = pred[:, :max_guesses] + prob_topk = pred.new_ones((pred.size(0), max_guesses)) / max_guesses + return pred_topk, prob_topk + + +def topkind( + max_guesses: int, + pred: torch.Tensor, + prob: Optional[torch.Tensor] = None, + ptr: Optional[torch.Tensor] = None, + joint: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + max_guesses = min(max_guesses, pred.size(1)) + if max_guesses == pred.size(1): + if prob is not None: + prob = prob / prob.sum(dim=-1, keepdim=True) + else: + prob = pred.new_ones((pred.size(0), max_guesses)) / max_guesses + return pred, prob, None + else: + if prob is not None: + if joint: + if ptr is None: + inds_topk = torch.topk((prob / prob.sum(dim=-1, keepdim=True)).mean(dim=0, keepdim=True), + k=max_guesses, dim=-1, largest=True, sorted=True)[1] + inds_topk = inds_topk.repeat(pred.size(0), 1) + else: + inds_topk = 
torch.topk(segment_csr(src=prob / prob.sum(dim=-1, keepdim=True), indptr=ptr,
+ reduce='mean'),
+ k=max_guesses, dim=-1, largest=True, sorted=True)[1]
+ inds_topk = gather_csr(src=inds_topk, indptr=ptr)
+ else:
+ inds_topk = torch.topk(prob, k=max_guesses, dim=-1, largest=True, sorted=True)[1]
+ pred_topk = pred[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
+ prob_topk = prob[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
+ prob_topk = prob_topk / prob_topk.sum(dim=-1, keepdim=True)
+ else:
+ pred_topk = pred[:, :max_guesses]
+ prob_topk = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
+ # without probabilities the first `max_guesses` modes are kept as-is
+ inds_topk = torch.arange(max_guesses, device=pred.device).unsqueeze(0).expand(pred.size(0), -1)
+ return pred_topk, prob_topk, inds_topk
+
+
+def valid_filter(
+ pred: torch.Tensor,
+ target: torch.Tensor,
+ prob: Optional[torch.Tensor] = None,
+ valid_mask: Optional[torch.Tensor] = None,
+ ptr: Optional[torch.Tensor] = None,
+ keep_invalid_final_step: bool = True) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
+ torch.Tensor, torch.Tensor]:
+ if valid_mask is None:
+ valid_mask = target.new_ones(target.size()[:-1], dtype=torch.bool)
+ if keep_invalid_final_step:
+ filter_mask = valid_mask.any(dim=-1)
+ else:
+ filter_mask = valid_mask[:, -1]
+ pred = pred[filter_mask]
+ target = target[filter_mask]
+ if prob is not None:
+ prob = prob[filter_mask]
+ valid_mask = valid_mask[filter_mask]
+ if ptr is not None:
+ num_nodes_batch = segment_csr(src=filter_mask.long(), indptr=ptr, reduce='sum')
+ ptr = num_nodes_batch.new_zeros((num_nodes_batch.size(0) + 1,))
+ torch.cumsum(num_nodes_batch, dim=0, out=ptr[1:])
+ else:
+ ptr = target.new_tensor([0, target.size(0)])
+ return pred, target, prob, valid_mask, ptr
+
+
+def new_batch_nms(pred_trajs, dist_thresh, num_ret_modes=6):
+ """
+
+ Args:
+ pred_trajs (batch_size, num_modes, num_timestamps, 7)
+ dist_thresh (float):
+ num_ret_modes (int, optional): Defaults to 6. 
+ + Returns: + ret_trajs (batch_size, num_ret_modes, num_timestamps, 5) + ret_scores (batch_size, num_ret_modes) + ret_idxs (batch_size, num_ret_modes) + """ + batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape + pred_goals = pred_trajs[:, :, -1, :] + dist = (pred_goals[:, :, None, 0:2] - pred_goals[:, None, :, 0:2]).norm(dim=-1) + nearby_neighbor = dist < dist_thresh + pred_scores = nearby_neighbor.sum(dim=-1) / num_modes + + sorted_idxs = pred_scores.argsort(dim=-1, descending=True) + bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes) + sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs] + sorted_pred_trajs = pred_trajs[bs_idxs_full, sorted_idxs] # (batch_size, num_modes, num_timestamps, 7) + sorted_pred_goals = sorted_pred_trajs[:, :, -1, :] # (batch_size, num_modes, 7) + + dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1) + point_cover_mask = (dist < dist_thresh) + + point_val = sorted_pred_scores.clone() # (batch_size, N) + point_val_selected = torch.zeros_like(point_val) # (batch_size, N) + + ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long() + ret_trajs = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim) + ret_scores = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes) + bs_idxs = torch.arange(batch_size).type_as(ret_idxs) + + for k in range(num_ret_modes): + cur_idx = point_val.argmax(dim=-1) # (batch_size) + ret_idxs[:, k] = cur_idx + + new_cover_mask = point_cover_mask[bs_idxs, cur_idx] # (batch_size, N) + point_val = point_val * (~new_cover_mask).float() # (batch_size, N) + point_val_selected[bs_idxs, cur_idx] = -1 + point_val += point_val_selected + + ret_trajs[:, k] = sorted_pred_trajs[bs_idxs, cur_idx] + ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx] + + bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes) + + ret_idxs = sorted_idxs[bs_idxs, ret_idxs] + return ret_trajs, ret_scores, ret_idxs + + +def batch_nms(pred_trajs, pred_scores, + dist_thresh, num_ret_modes=6, + mode='static', speed=None): + """ + + Args: + pred_trajs (batch_size, num_modes, num_timestamps, 7) + pred_scores (batch_size, num_modes): + dist_thresh (float): + num_ret_modes (int, optional): Defaults to 6. 
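+ mode (str, optional): 'static' uses a single euclidean distance
+ threshold; 'speed' uses separate longitudinal/lateral thresholds
+ (4 / 0.5, in the trajectories' units). (documented from the body)
+ speed (optional): accepted but currently unused.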
+ + Returns: + ret_trajs (batch_size, num_ret_modes, num_timestamps, 5) + ret_scores (batch_size, num_ret_modes) + ret_idxs (batch_size, num_ret_modes) + """ + batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape + + sorted_idxs = pred_scores.argsort(dim=-1, descending=True) + bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes) + sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs] + sorted_pred_trajs = pred_trajs[bs_idxs_full, sorted_idxs] # (batch_size, num_modes, num_timestamps, 7) + sorted_pred_goals = sorted_pred_trajs[:, :, -1, :] # (batch_size, num_modes, 7) + + if mode == "speed": + scale = torch.ones(batch_size).to(sorted_pred_goals.device) + lon_dist_thresh = 4 * scale + lat_dist_thresh = 0.5 * scale + lon_dist = (sorted_pred_goals[:, :, None, [0]] - sorted_pred_goals[:, None, :, [0]]).norm(dim=-1) + lat_dist = (sorted_pred_goals[:, :, None, [1]] - sorted_pred_goals[:, None, :, [1]]).norm(dim=-1) + point_cover_mask = (lon_dist < lon_dist_thresh[:, None, None]) & (lat_dist < lat_dist_thresh[:, None, None]) + else: + dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1) + point_cover_mask = (dist < dist_thresh) + + point_val = sorted_pred_scores.clone() # (batch_size, N) + point_val_selected = torch.zeros_like(point_val) # (batch_size, N) + + ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long() + ret_trajs = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim) + ret_scores = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes) + bs_idxs = torch.arange(batch_size).type_as(ret_idxs) + + for k in range(num_ret_modes): + cur_idx = point_val.argmax(dim=-1) # (batch_size) + ret_idxs[:, k] = cur_idx + + new_cover_mask = point_cover_mask[bs_idxs, cur_idx] # (batch_size, N) + point_val = point_val * (~new_cover_mask).float() # (batch_size, N) + point_val_selected[bs_idxs, cur_idx] = -1 + point_val += point_val_selected + + ret_trajs[:, k] = sorted_pred_trajs[bs_idxs, cur_idx] + ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx] + + bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes) + + ret_idxs = sorted_idxs[bs_idxs, ret_idxs] + return ret_trajs, ret_scores, ret_idxs + + +def batch_nms_token(pred_trajs, pred_scores, + dist_thresh, num_ret_modes=6, + mode='static', speed=None): + """ + Args: + pred_trajs (batch_size, num_modes, num_timestamps, 7) + pred_scores (batch_size, num_modes): + dist_thresh (float): + num_ret_modes (int, optional): Defaults to 6. 
+ + Returns: + ret_trajs (batch_size, num_ret_modes, num_timestamps, 5) + ret_scores (batch_size, num_ret_modes) + ret_idxs (batch_size, num_ret_modes) + """ + batch_size, num_modes, num_feat_dim = pred_trajs.shape + + sorted_idxs = pred_scores.argsort(dim=-1, descending=True) + bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes) + sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs] + sorted_pred_goals = pred_trajs[bs_idxs_full, sorted_idxs] # (batch_size, num_modes, num_timestamps, 7) + + if mode == "nearby": + dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1) + values, indices = torch.topk(dist, 5, dim=-1, largest=False) + thresh_hold = values[..., -1] + point_cover_mask = dist < thresh_hold[..., None] + else: + dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1) + point_cover_mask = (dist < dist_thresh) + + point_val = sorted_pred_scores.clone() # (batch_size, N) + point_val_selected = torch.zeros_like(point_val) # (batch_size, N) + + ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long() + ret_goals = sorted_pred_goals.new_zeros(batch_size, num_ret_modes, num_feat_dim) + ret_scores = sorted_pred_goals.new_zeros(batch_size, num_ret_modes) + bs_idxs = torch.arange(batch_size).type_as(ret_idxs) + + for k in range(num_ret_modes): + cur_idx = point_val.argmax(dim=-1) # (batch_size) + ret_idxs[:, k] = cur_idx + + new_cover_mask = point_cover_mask[bs_idxs, cur_idx] # (batch_size, N) + point_val = point_val * (~new_cover_mask).float() # (batch_size, N) + point_val_selected[bs_idxs, cur_idx] = -1 + point_val += point_val_selected + + ret_goals[:, k] = sorted_pred_goals[bs_idxs, cur_idx] + ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx] + + bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes) + + ret_idxs = sorted_idxs[bs_idxs, ret_idxs] + return ret_goals, ret_scores, ret_idxs + + +class TokenCls(Metric): + + def __init__(self, + max_guesses: int = 6, + **kwargs) -> None: + super(TokenCls, self).__init__(**kwargs) + self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') + self.max_guesses = max_guesses + + def update(self, + pred: torch.Tensor, + target: torch.Tensor, + valid_mask: Optional[torch.Tensor] = None) -> None: + target = target[..., None] + acc = (pred[:, :self.max_guesses] == target).any(dim=1) * valid_mask + self.sum += acc.sum() + self.count += valid_mask.sum() + + def compute(self) -> torch.Tensor: + return self.sum / self.count + + +class minMultiFDE(Metric): + + def __init__(self, + max_guesses: int = 6, + **kwargs) -> None: + super(minMultiFDE, self).__init__(**kwargs) + self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') + self.max_guesses = max_guesses + + def update(self, + pred: torch.Tensor, + target: torch.Tensor, + prob: Optional[torch.Tensor] = None, + valid_mask: Optional[torch.Tensor] = None, + keep_invalid_final_step: bool = True) -> None: + pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step) + pred_topk, _ = topk(self.max_guesses, pred, prob) + inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1) + self.sum += torch.norm(pred_topk[torch.arange(pred.size(0)), :, inds_last] - + 
target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2), + p=2, dim=-1).min(dim=-1)[0].sum() + self.count += pred.size(0) + + def compute(self) -> torch.Tensor: + return self.sum / self.count + + +class minFDE(Metric): + + def __init__(self, + max_guesses: int = 6, + **kwargs) -> None: + super(minFDE, self).__init__(**kwargs) + self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') + self.max_guesses = max_guesses + self.eval_timestep = 70 + + def update(self, + pred: torch.Tensor, + target: torch.Tensor, + prob: Optional[torch.Tensor] = None, + valid_mask: Optional[torch.Tensor] = None, + keep_invalid_final_step: bool = True) -> None: + eval_timestep = min(self.eval_timestep, pred.shape[1]) - 1 + self.sum += ((torch.norm(pred[:, eval_timestep-1:eval_timestep] - target[:, eval_timestep-1:eval_timestep], p=2, dim=-1) * + valid_mask[:, eval_timestep-1].unsqueeze(1)).sum(dim=-1)).sum() + self.count += valid_mask[:, eval_timestep-1].sum() + + def compute(self) -> torch.Tensor: + return self.sum / self.count + + +class minMultiADE(Metric): + + def __init__(self, + max_guesses: int = 6, + **kwargs) -> None: + super(minMultiADE, self).__init__(**kwargs) + self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') + self.max_guesses = max_guesses + + def update(self, + pred: torch.Tensor, + target: torch.Tensor, + prob: Optional[torch.Tensor] = None, + valid_mask: Optional[torch.Tensor] = None, + keep_invalid_final_step: bool = True, + min_criterion: str = 'FDE') -> None: + pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step) + pred_topk, _ = topk(self.max_guesses, pred, prob) + if min_criterion == 'FDE': + inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1) + inds_best = torch.norm( + pred_topk[torch.arange(pred.size(0)), :, inds_last] - + target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2), p=2, dim=-1).argmin(dim=-1) + self.sum += ((torch.norm(pred_topk[torch.arange(pred.size(0)), inds_best] - target, p=2, dim=-1) * + valid_mask).sum(dim=-1) / valid_mask.sum(dim=-1)).sum() + elif min_criterion == 'ADE': + self.sum += ((torch.norm(pred_topk - target.unsqueeze(1), p=2, dim=-1) * + valid_mask.unsqueeze(1)).sum(dim=-1).min(dim=-1)[0] / valid_mask.sum(dim=-1)).sum() + else: + raise ValueError('{} is not a valid criterion'.format(min_criterion)) + self.count += pred.size(0) + + def compute(self) -> torch.Tensor: + return self.sum / self.count + + +class minADE(Metric): + + def __init__(self, + max_guesses: int = 6, + **kwargs) -> None: + super(minADE, self).__init__(**kwargs) + self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') + self.max_guesses = max_guesses + self.eval_timestep = 70 + + def update(self, + pred: torch.Tensor, + target: torch.Tensor, + prob: Optional[torch.Tensor] = None, + valid_mask: Optional[torch.Tensor] = None, + keep_invalid_final_step: bool = True, + min_criterion: str = 'ADE') -> None: + # pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step) + # pred_topk, _ = topk(self.max_guesses, pred, prob) + # if min_criterion == 'FDE': + # inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1) + # 
inds_best = torch.norm( + # pred[torch.arange(pred.size(0)), :, inds_last] - + # target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2), p=2, dim=-1).argmin(dim=-1) + # self.sum += ((torch.norm(pred[torch.arange(pred.size(0)), inds_best] - target, p=2, dim=-1) * + # valid_mask).sum(dim=-1) / valid_mask.sum(dim=-1)).sum() + # elif min_criterion == 'ADE': + # self.sum += ((torch.norm(pred - target.unsqueeze(1), p=2, dim=-1) * + # valid_mask.unsqueeze(1)).sum(dim=-1).min(dim=-1)[0] / valid_mask.sum(dim=-1)).sum() + # else: + # raise ValueError('{} is not a valid criterion'.format(min_criterion)) + eval_timestep = min(self.eval_timestep, pred.shape[1]) + self.sum += ((torch.norm(pred[:, :eval_timestep] - target[:, :eval_timestep], p=2, dim=-1) * valid_mask[:, :eval_timestep]).sum(dim=-1) / pred.shape[1]).sum() + self.count += valid_mask[:, :eval_timestep].any(dim=-1).sum() + + def compute(self) -> torch.Tensor: + return self.sum / self.count + + +class AverageMeter(Metric): + + def __init__(self, **kwargs) -> None: + super(AverageMeter, self).__init__(**kwargs) + self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') + + def update(self, val: torch.Tensor) -> None: + self.sum += val.sum() + self.count += val.numel() + + def compute(self) -> torch.Tensor: + return self.sum / self.count + + +class StateAccuracy(Metric): + + def __init__(self, state_token: Dict[str, int], **kwargs) -> None: + super().__init__(**kwargs) + self.invalid_state = int(state_token['invalid']) + self.valid_state = int(state_token['valid']) + self.enter_state = int(state_token['enter']) + self.exit_state = int(state_token['exit']) + + self.add_state('valid', default=torch.tensor(0), dist_reduce_fx='sum') + self.add_state('valid_count', default=torch.tensor(0), dist_reduce_fx='sum') + self.add_state('invalid', default=torch.tensor(0), dist_reduce_fx='sum') + self.add_state('invalid_count', default=torch.tensor(0), dist_reduce_fx='sum') + + def update(self, + state_idx: torch.Tensor, + valid_mask: Optional[torch.Tensor] = None) -> None: + + num_agent, num_step = state_idx.shape + + # check the evaluation outputs + for a in range(num_agent): + bos_idx = torch.where(state_idx[a] == self.enter_state)[0] + eos_idx = torch.where(state_idx[a] == self.exit_state)[0] + bos = 0 + eos = num_step - 1 + if len(bos_idx) > 0: + bos = bos_idx[0] + self.invalid += (state_idx[a, :bos] == self.invalid_state).sum() + self.invalid_count += len(state_idx[a, :bos]) + if len(eos_idx) > 0: + eos = eos_idx[0] + self.invalid += (state_idx[a, eos + 1:] == self.invalid_state).sum() + self.invalid_count += len(state_idx[a, eos + 1:]) + self.valid += (state_idx[a, bos + 1 : eos] == self.valid_state).sum() + self.valid_count += len(state_idx[a, bos + 1 : eos]) + + # check the tokenization + if valid_mask is not None: + + state_idx = state_idx.roll(shifts=1, dims=1) + + for a in range(num_agent): + bos_idx = torch.where(state_idx[a] == self.enter_state)[0] + eos_idx = torch.where(state_idx[a] == self.exit_state)[0] + bos = 0 + eos = num_step - 1 + if len(bos_idx) > 0: + bos = bos_idx[0] + self.invalid += (valid_mask[a, :bos] == 0).sum() + self.invalid_count += len(valid_mask[a, :bos]) + if len(eos_idx) > 0: + eos = eos_idx[-1] + self.invalid += (valid_mask[a, eos + 1:] != 0).sum() + self.invalid_count += len(valid_mask[a, eos + 1:]) + self.invalid += (((state_idx[a, bos : eos + 1] > 0) != valid_mask[a, bos : eos + 1])[valid_mask[a, bos : eos + 1] == 0]).sum() + 
self.invalid_count += (valid_mask[a, bos : eos + 1] == 0).sum() + self.valid += (((state_idx[a, bos : eos + 1] > 0) != valid_mask[a, bos : eos + 1])[valid_mask[a, bos : eos + 1] == 1]).sum() + self.valid_count += (valid_mask[a, bos : eos + 1] == 1).sum() + + def compute(self) -> Dict[str, torch.Tensor]: + return {'valid': self.valid / self.valid_count, + 'invalid': self.invalid / self.invalid_count, + } + + def __repr__(self): + head = "Results of " + self.__class__.__name__ + results = self.compute() + body = [ + "valid: {}".format(results['valid']), + "invalid: {}".format(results['invalid']), + ] + _repr_indent = 4 + lines = [head] + [" " * _repr_indent + line for line in body] + return "\n".join(lines) + + +class GridOverlapRate(Metric): + + def __init__(self, num_step, state_token, seed_size, **kwargs) -> None: + super().__init__(**kwargs) + self.num_step = num_step + self.enter_state = int(state_token['enter']) + self.seed_size = seed_size + self.add_state('num_overlap_t', default=torch.zeros(num_step).long(), dist_reduce_fx='sum') + self.add_state('num_insert_agent_t', default=torch.zeros(num_step).long(), dist_reduce_fx='sum') + self.add_state('num_total_agent_t', default=torch.zeros(num_step).long(), dist_reduce_fx='sum') + self.add_state('num_exceed_seed_t', default=torch.zeros(num_step).long(), dist_reduce_fx='sum') + + def update(self, + state_token: torch.Tensor, + grid_index: torch.Tensor) -> None: + + for t in range(self.num_step): + inrange_mask_t = grid_index[:, t] != -1 + insert_mask_t = (state_token[:, t] == self.enter_state) & inrange_mask_t + self.num_total_agent_t[t] += inrange_mask_t.sum() + self.num_insert_agent_t[t] += insert_mask_t.sum() + self.num_exceed_seed_t[t] += int(insert_mask_t.sum() >= self.seed_size) + + occupied_grids = set(grid_index[:, t][(grid_index[:, t] != -1) & (state_token[:, t] != self.enter_state)].tolist()) + to_inserted_grids = grid_index[:, t][(grid_index[:, t] != -1) & (state_token[:, t] == self.enter_state)].tolist() + while to_inserted_grids: + grid_index_t_i = to_inserted_grids.pop() + if grid_index_t_i in occupied_grids: + self.num_overlap_t[t] += 1 + occupied_grids.add(grid_index_t_i) + + def compute(self) -> Dict[str, torch.Tensor]: + overlap_rate_t = self.num_overlap_t / self.num_insert_agent_t + overlap_rate_t.nan_to_num_() + return {'num_overlap_t': self.num_overlap_t, + 'num_insert_agent_t': self.num_insert_agent_t, + 'num_total_agent_t': self.num_total_agent_t, + 'overlap_rate_t': overlap_rate_t, + 'num_exceed_seed_t': self.num_exceed_seed_t, + } + + def __repr__(self): + head = "Results of " + self.__class__.__name__ + results = self.compute() + body = [ + "num_overlap_t: {}".format(results['num_overlap_t'].tolist()), + "num_insert_agent_t: {}".format(results['num_insert_agent_t'].tolist()), + "num_total_agent_t: {}".format(results['num_total_agent_t'].tolist()), + "overlap_rate_t: {}".format(results['overlap_rate_t'].tolist()), + "num_exceed_seed_t: {}".format(results['num_exceed_seed_t'].tolist()), + ] + _repr_indent = 4 + lines = [head] + [" " * _repr_indent + line for line in body] + return "\n".join(lines) + + +class NumInsertAccuracy(Metric): + + def __init__(self, state_token: Dict[str, int], **kwargs) -> None: + super().__init__(**kwargs) + self.invalid_state = int(state_token['invalid']) + self.valid_state = int(state_token['valid']) + self.enter_state = int(state_token['enter']) + self.exit_state = int(state_token['exit']) + + self.add_state('valid', default=torch.tensor(0), dist_reduce_fx='sum') + 
self.add_state('valid_count', default=torch.tensor(0), dist_reduce_fx='sum') + self.add_state('invalid', default=torch.tensor(0), dist_reduce_fx='sum') + self.add_state('invalid_count', default=torch.tensor(0), dist_reduce_fx='sum') + + def update(self, + state_idx: torch.Tensor, + valid_mask: Optional[torch.Tensor] = None) -> None: + + num_agent, num_step = state_idx.shape + + # check the evaluation outputs + for a in range(num_agent): + bos_idx = torch.where(state_idx[a] == self.enter_state)[0] + eos_idx = torch.where(state_idx[a] == self.exit_state)[0] + bos = 0 + eos = num_step - 1 + if len(bos_idx) > 0: + bos = bos_idx[0] + self.invalid += (state_idx[a, :bos] == self.invalid_state).sum() + self.invalid_count += len(state_idx[a, :bos]) + if len(eos_idx) > 0: + eos = eos_idx[0] + self.invalid += (state_idx[a, eos + 1:] == self.invalid_state).sum() + self.invalid_count += len(state_idx[a, eos + 1:]) + self.valid += (state_idx[a, bos + 1 : eos] == self.valid_state).sum() + self.valid_count += len(state_idx[a, bos + 1 : eos]) + + # check the tokenization + if valid_mask is not None: + + state_idx = state_idx.roll(shifts=1, dims=1) + + for a in range(num_agent): + bos_idx = torch.where(state_idx[a] == self.enter_state)[0] + eos_idx = torch.where(state_idx[a] == self.exit_state)[0] + bos = 0 + eos = num_step - 1 + if len(bos_idx) > 0: + bos = bos_idx[0] + self.invalid += (valid_mask[a, :bos] == 0).sum() + self.invalid_count += len(valid_mask[a, :bos]) + if len(eos_idx) > 0: + eos = eos_idx[-1] + self.invalid += (valid_mask[a, eos + 1:] != 0).sum() + self.invalid_count += len(valid_mask[a, eos + 1:]) + self.invalid += (((state_idx[a, bos : eos + 1] > 0) != valid_mask[a, bos : eos + 1])[valid_mask[a, bos : eos + 1] == 0]).sum() + self.invalid_count += (valid_mask[a, bos : eos + 1] == 0).sum() + self.valid += (((state_idx[a, bos : eos + 1] > 0) != valid_mask[a, bos : eos + 1])[valid_mask[a, bos : eos + 1] == 1]).sum() + self.valid_count += (valid_mask[a, bos : eos + 1] == 1).sum() + + def compute(self) -> Dict[str, torch.Tensor]: + return {'valid': self.valid / self.valid_count, + 'invalid': self.invalid / self.invalid_count, + } + + def __repr__(self): + head = "Results of " + self.__class__.__name__ + results = self.compute() + body = [ + "valid: {}".format(results['valid']), + "invalid: {}".format(results['invalid']), + ] + _repr_indent = 4 + lines = [head] + [" " * _repr_indent + line for line in body] + return "\n".join(lines) diff --git a/backups/dev/utils/visualization.py b/backups/dev/utils/visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..f754369b26c6f04478fbf2f83cc8375d1140d671 --- /dev/null +++ b/backups/dev/utils/visualization.py @@ -0,0 +1,1145 @@ +import math +import os +import torch +import pickle +import matplotlib.pyplot as plt +import tensorflow as tf +import numpy as np +import numpy.typing as npt +import fnmatch +import seaborn as sns +import matplotlib.axes as Axes +import matplotlib.transforms as mtransforms +from PIL import Image +from functools import wraps +from typing import Sequence, Union, Optional +from tqdm import tqdm +from typing import List, Literal +from argparse import ArgumentParser +from scipy.ndimage.filters import gaussian_filter +from matplotlib.patches import FancyBboxPatch, Polygon, Rectangle, Circle +from matplotlib.collections import LineCollection +from torch_geometric.data import HeteroData, Dataset +from waymo_open_dataset.protos import scenario_pb2 + +from dev.utils.func import CONSOLE +from 
dev.modules.attr_tokenizer import Attr_Tokenizer +from dev.datasets.preprocess import TokenProcessor, cal_polygon_contour, AGENT_TYPE +from dev.datasets.scalable_dataset import WaymoTargetBuilder + + +__all__ = ['plot_occ_grid', 'plot_interact_edge', 'plot_map_edge', 'plot_insert_grid', 'plot_binary_map', + 'plot_map_token', 'plot_prob_seed', 'plot_scenario', 'get_heatmap', 'draw_heatmap', 'plot_val', 'plot_tokenize'] + + +def safe_run(func): + + @wraps(func) + def wrapper1(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + print(e) + return + + @wraps(func) + def wrapper2(*args, **kwargs): + return func(*args, **kwargs) + + if int(os.getenv('DEBUG', 0)): + return wrapper2 + else: + return wrapper1 + + +@safe_run +def plot_occ_grid(scenario_id, occ, gt_occ=None, save_path='', mode='agent', prefix=''): + + def generate_box_edges(matrix, find_value=1): + y, x = np.where(matrix == find_value) + edges = [] + + for xi, yi in zip(x, y): + edges.append([(xi - 0.5, yi - 0.5), (xi + 0.5, yi - 0.5)]) + edges.append([(xi + 0.5, yi - 0.5), (xi + 0.5, yi + 0.5)]) + edges.append([(xi + 0.5, yi + 0.5), (xi - 0.5, yi + 0.5)]) + edges.append([(xi - 0.5, yi + 0.5), (xi - 0.5, yi - 0.5)]) + + return edges + + os.makedirs(save_path, exist_ok=True) + n = int(math.sqrt(occ.shape[-1])) + + plot_n = 3 + plot_t = 5 + + occ_list = [] + for i in range(plot_n): + for j in range(plot_t): + occ_list.append(occ[i, j].reshape(n, n)) + + occ_gt_list = [] + if gt_occ is not None: + for i in range(plot_n): + for j in range(plot_t): + occ_gt_list.append(gt_occ[i, j].reshape(n, n)) + + row_labels = [f'n={n}' for n in range(plot_n)] + col_labels = [f't={t}' for t in range(plot_t)] + + fig, axes = plt.subplots(plot_n, plot_t, figsize=(9, 6)) + plt.subplots_adjust(wspace=0.1, hspace=0.1) + + for i, ax in enumerate(axes.flat): + # NOTE: do not set vmin and vmax!
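+ # each panel is deliberately left to auto-scale its own color range; a shared vmin/vmax would compress low-probability occupancy grids into a near-uniform color and hide their structure.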
+ ax.imshow(occ_list[i], cmap='viridis', interpolation='nearest') + ax.axis('off') + + if occ_gt_list: + gt_edges = generate_box_edges(occ_gt_list[i]) + gts = LineCollection(gt_edges, colors='blue', linewidths=0.5) + ax.add_collection(gts) + insert_edges = generate_box_edges(occ_gt_list[i], find_value=-1) + inserts = LineCollection(insert_edges, colors='red', linewidths=0.5) + ax.add_collection(inserts) + + ax.add_patch(plt.Rectangle((-0.5, -0.5), occ_list[i].shape[1], occ_list[i].shape[0], + linewidth=2, edgecolor='black', facecolor='none')) + + for i, ax in enumerate(axes[:, 0]): + ax.annotate(row_labels[i], xy=(-0.1, 0.5), xycoords="axes fraction", + fontsize=12, ha="right", va="center", rotation=0) + + for j, ax in enumerate(axes[0, :]): + ax.annotate(col_labels[j], xy=(0.5, 1.05), xycoords="axes fraction", + fontsize=12, ha="center", va="bottom") + + plt.savefig(os.path.join(save_path, f'{prefix}{scenario_id}_occ_{mode}.png'), dpi=500, bbox_inches='tight') + plt.close() + + +@safe_run +def plot_interact_edge(edge_index, scenario_ids, batch_sizes, num_seed, num_step, save_path='interact_edge_map', + **kwargs): + + num_batch = len(scenario_ids) + batches = torch.cat([ + torch.arange(num_batch).repeat_interleave(repeats=batch_sizes, dim=0), + torch.arange(num_batch).repeat_interleave(repeats=num_seed, dim=0), + ], dim=0).repeat(num_step).numpy() + + num_agent = batch_sizes.sum() + num_seed * num_batch + batch_sizes = torch.nn.functional.pad(batch_sizes, (1, 0), mode='constant', value=0) + ptr = torch.cumsum(batch_sizes, dim=0) + # assume different scenarios and different timesteps have the same number of seed agents + ptr_seed = torch.tensor(np.array([0] + [num_seed] * num_batch), device=ptr.device) + + all_av_index = None + if 'av_index' in kwargs: + all_av_index = kwargs.pop('av_index').cpu() - ptr[:-1] + + is_bos = np.zeros((batch_sizes.sum(), num_step)).astype(np.bool_) + if 'is_bos' in kwargs: + is_bos = kwargs.pop('is_bos').cpu().numpy() + + src_index = torch.unique(edge_index[1]) + for idx, src in enumerate(tqdm(src_index)): + + src_batch = batches[src] + + src_row = src % num_agent + if src_row // batch_sizes.sum() > 0: + seed_row = src_row % batch_sizes.sum() - ptr_seed[src_batch] + src_row = batch_sizes[src_batch + 1] + seed_row + else: + src_row = src_row - ptr[src_batch] + + src_col = src // num_agent + src_mask = np.zeros((batch_sizes[src_batch + 1] + num_seed, num_step)) + src_mask[src_row, src_col] = 1 + + tgt_mask = np.zeros((src_mask.shape[0], num_step)) + tgt_index = edge_index[0, edge_index[1] == src] + for tgt in tgt_index: + + tgt_batch = batches[tgt] + + tgt_row = tgt % num_agent + if tgt_row // batch_sizes.sum() > 0: + seed_row = tgt_row % batch_sizes.sum() - ptr_seed[tgt_batch] + tgt_row = batch_sizes[tgt_batch + 1] + seed_row + else: + tgt_row = tgt_row - ptr[tgt_batch] + + tgt_col = tgt // num_agent + tgt_mask[tgt_row, tgt_col] = 1 + assert tgt_batch == src_batch + + selected_step = tgt_mask.sum(axis=0) > 0 + if selected_step.sum() > 1: + print(f"\nidx={idx}", src.item(), src_row.item(), src_col.item()) + print(selected_step) + print(edge_index[:, edge_index[1] == src].tolist()) + + if all_av_index is not None: + kwargs['av_index'] = int(all_av_index[src_batch]) + + t = kwargs.get('t', src_col) + n = kwargs.get('n', 0) + is_bos_batch = is_bos[ptr[src_batch] : ptr[src_batch + 1]] + plot_binary_map(src_mask, tgt_mask, save_path, suffix=f'_{scenario_ids[src_batch]}_{t:02d}_{n:02d}_{idx:04d}', + is_bos=is_bos_batch, **kwargs) + + +@safe_run +def
plot_map_edge(edge_index, pos_a, data, save_path='map_edge_map'): + + map_points = data['map_point']['position'][:, :2].cpu().numpy() + token_pos = data['pt_token']['position'][:, :2].cpu().numpy() + token_heading = data['pt_token']['orientation'].cpu().numpy() + num_pt = token_pos.shape[0] + + agent_index = torch.unique(edge_index[1]) + for i in tqdm(agent_index): + xy = pos_a[i].cpu().numpy() + pt_index = edge_index[0, edge_index[1] == i].cpu().numpy() + pt_index = pt_index % num_pt + + plt.subplots_adjust(left=0.3, right=0.7, top=0.7, bottom=0.3) + _, ax = plt.subplots() + ax.set_axis_off() + + plot_map_token(ax, map_points, token_pos[pt_index], token_heading[pt_index], colors='blue') + + ax.scatter(xy[0], xy[1], s=0.5, c='red', edgecolors='none') + + os.makedirs(save_path, exist_ok=True) + plt.savefig(os.path.join(save_path, f'map_{i}.png'), dpi=600, bbox_inches='tight') + plt.close() + + +def get_heatmap(x, y, prob, s=3, bins=1000): + heatmap, xedges, yedges = np.histogram2d(x, y, bins=bins, weights=prob, density=True) + + heatmap = gaussian_filter(heatmap, sigma=s) + + extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]] + return heatmap.T, extent + + +@safe_run +def draw_heatmap(vector, vector_prob, gt_idx): + fig, ax = plt.subplots(figsize=(10, 10)) + vector_prob = vector_prob.cpu().numpy() + + for j in range(vector.shape[0]): + if j in gt_idx: + color = (0, 0, 1) + else: + grey_scale = max(0, 0.9 - vector_prob[j]) + color = (0.9, grey_scale, grey_scale) + + # if lane[j, k, -1] == 0: continue + x0, y0, x1, y1 = vector[j, :4] + ax.plot((x0, x1), (y0, y1), color=color, linewidth=2) + + return plt + + +# per-timestep variant, kept under a distinct name so it is not shadowed by the seed-wise plot_insert_grid defined later +@safe_run +def plot_insert_grid_per_step(scenario_id, prob, grid, ego_pos, map, save_path='', prefix='', inference=False, indices=None, + all_t_in_one=False): + + """ + prob: float array of shape (num_step, num_grid) + grid: float array of shape (num_grid, 2) + """ + + os.makedirs(save_path, exist_ok=True) + + n = int(math.sqrt(prob.shape[1])) + + # grid = grid[:, np.newaxis] + ego_pos[np.newaxis, ...]
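+ # the flattened grid is assumed row-major: cell index k maps to (row k // n, col k % n), matching the reshape and rectangle placement below.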
+ for t in range(ego_pos.shape[0]): + + plt.subplots_adjust(left=0.3, right=0.7, top=0.7, bottom=0.3) + _, ax = plt.subplots() + + # plot probability + prob_t = prob[t].reshape(n, n) + plt.imshow(prob_t, cmap='viridis', interpolation='nearest') + + if indices is not None: + indice = indices[t] + + if isinstance(indice, (int, float, np.int_)): + indice = [indice] + + for _indice in indice: + if _indice == -1: continue + + row = _indice // n + col = _indice % n + + rect = Rectangle((col - 0.5, row - 0.5), 1, 1, edgecolor='red', facecolor='none', lw=2) + ax.add_patch(rect) + + ax.grid(False) + ax.set_aspect('equal', adjustable='box') + + plt.title('Prob of Rel Position Grid') + plt.savefig(os.path.join(save_path, f'{prefix}{scenario_id}_heat_map_{t}.png'), dpi=300, bbox_inches='tight') + plt.close() + + if all_t_in_one: + break + + +@safe_run +def plot_insert_grid(scenario_id, prob, indices=None, save_path='', prefix='', inference=False): + + """ + prob: float array of shape (num_seed, num_step, num_grid) + """ + + os.makedirs(save_path, exist_ok=True) + + n = int(math.sqrt(prob.shape[-1])) + + plot_n = 3 + plot_t = 5 + + prob_list = [] + for i in range(plot_n): + for j in range(plot_t): + prob_list.append(prob[i, j].reshape(n, n)) + + indice_list = [] + if indices is not None: + for i in range(plot_n): + for j in range(plot_t): + indice_list.append(indices[i, j]) + + row_labels = [f'n={n}' for n in range(plot_n)] + col_labels = [f't={t}' for t in range(plot_t)] + + fig, axes = plt.subplots(plot_n, plot_t, figsize=(9, 6)) + fig.suptitle('Prob of Insert Position Grid') + plt.subplots_adjust(wspace=0.1, hspace=0.1) + + for i, ax in enumerate(axes.flat): + ax.imshow(prob_list[i], cmap='viridis', interpolation='nearest') + ax.axis('off') + + if indice_list: + row = indice_list[i] // n + col = indice_list[i] % n + rect = Rectangle((col - .5, row - .5), 1, 1, edgecolor='red', facecolor='none', lw=2) + ax.add_patch(rect) + + ax.add_patch(plt.Rectangle((-0.5, -0.5), prob_list[i].shape[1], prob_list[i].shape[0], + linewidth=2, edgecolor='black', facecolor='none')) + + for i, ax in enumerate(axes[:, 0]): + ax.annotate(row_labels[i], xy=(-0.1, 0.5), xycoords="axes fraction", + fontsize=12, ha="right", va="center", rotation=0) + + for j, ax in enumerate(axes[0, :]): + ax.annotate(col_labels[j], xy=(0.5, 1.05), xycoords="axes fraction", + fontsize=12, ha="center", va="bottom") + + ax.grid(False) + ax.set_aspect('equal', adjustable='box') + + plt.savefig(os.path.join(save_path, f'{prefix}{scenario_id}_insert_map.png'), dpi=500, bbox_inches='tight') + plt.close() + + +@safe_run +def plot_binary_map(src_mask, tgt_mask, save_path='', suffix='', av_index=None, is_bos=None, **kwargs): + + from matplotlib.colors import ListedColormap + os.makedirs(save_path, exist_ok=True) + + fig, axes = plt.subplots(1, 2, figsize=(10, 8)) + + title = [] + if kwargs.get('t', None) is not None: + t = kwargs['t'] + title.append(f't={t}') + + if kwargs.get('n', None) is not None: + n = kwargs['n'] + title.append(f'n={n}') + + plt.title(' '.join(title)) + + cmap = ListedColormap(['white', 'green']) + axes[0].imshow(src_mask, cmap=cmap, interpolation='nearest') + + cmap = ListedColormap(['white', 'orange']) + axes[1].imshow(tgt_mask, cmap=cmap, interpolation='nearest') + + if av_index is not None: + rect = Rectangle((-0.5, av_index - 0.5), src_mask.shape[1], 1, edgecolor='red', facecolor='none', lw=2) + axes[0].add_patch(rect) + rect = Rectangle((-0.5, av_index - 0.5), tgt_mask.shape[1], 1,
edgecolor='red', facecolor='none', lw=2) + axes[1].add_patch(rect) + + if is_bos is not None: + rows, cols = np.where(is_bos) + for row, col in zip(rows, cols): + rect = Rectangle((col - 0.5, row - 0.5), 1, 1, edgecolor='blue', facecolor='none', lw=1) + axes[0].add_patch(rect) + rect = Rectangle((col - 0.5, row - 0.5), 1, 1, edgecolor='blue', facecolor='none', lw=1) + axes[1].add_patch(rect) + + for ax in axes: + ax.set_xticks(range(src_mask.shape[1] + 1), minor=False) + ax.set_yticks(range(src_mask.shape[0] + 1), minor=False) + ax.grid(which='major', color='gray', linestyle='--', linewidth=0.5) + + plt.savefig(os.path.join(save_path, f'map{suffix}.png'), dpi=300, bbox_inches='tight') + plt.close() + + +@safe_run +def plot_prob_seed(scenario_id, prob, save_path, prefix='', indices=None): + + os.makedirs(save_path, exist_ok=True) + + plt.figure(figsize=(8, 5)) + plt.imshow(prob, cmap='viridis', aspect='auto') + plt.colorbar() + + plt.title('Seed Probability') + + if indices is not None: + + for col in range(indices.shape[1]): + for row in indices[:, col]: + + if row == -1: continue + + rect = Rectangle((col - 0.5, row - 0.5), 1, 1, edgecolor='red', facecolor='none', lw=2) + plt.gca().add_patch(rect) + + plt.tight_layout() + plt.savefig(os.path.join(save_path, f'{prefix}{scenario_id}_prob_seed.png'), dpi=300, bbox_inches='tight') + plt.close() + + +@safe_run +def plot_raw(): + plt.figure(figsize=(30, 30)) + plt.rcParams['axes.facecolor']='white' + + data_path = '/u/xiuyu/work/dev4/data/waymo/scenario/training' + os.makedirs("data/vis/raw/0/", exist_ok=True) + file_list = os.listdir(data_path) + + for cnt_file, file in enumerate(file_list): + file_path = os.path.join(data_path, file) + dataset = tf.data.TFRecordDataset(file_path, compression_type='') + for scenario_idx, data in enumerate(dataset): + scenario = scenario_pb2.Scenario() + scenario.ParseFromString(bytearray(data.numpy())) + tqdm.write(f"scenario id: {scenario.scenario_id}") + + # draw maps + for i in range(len(scenario.map_features)): + + # draw lanes + if str(scenario.map_features[i].lane) != '': + line_x = [z.x for z in scenario.map_features[i].lane.polyline] + line_y = [z.y for z in scenario.map_features[i].lane.polyline] + plt.scatter(line_x, line_y, c='g', s=5) + plt.text(line_x[0], line_y[0], str(scenario.map_features[i].id), fontdict={'family': 'serif', 'size': 20, 'color': 'green'}) + + # draw road_edge + if str(scenario.map_features[i].road_edge) != '': + road_edge_x = [polyline.x for polyline in scenario.map_features[i].road_edge.polyline] + road_edge_y = [polyline.y for polyline in scenario.map_features[i].road_edge.polyline] + plt.scatter(road_edge_x, road_edge_y) + plt.text(road_edge_x[0], road_edge_y[0], scenario.map_features[i].road_edge.type, fontdict={'family': 'serif', 'size': 20, 'color': 'black'}) + if scenario.map_features[i].road_edge.type == 2: + plt.scatter(road_edge_x, road_edge_y, c='k') + elif scenario.map_features[i].road_edge.type == 3: + plt.scatter(road_edge_x, road_edge_y, c='purple') + print(scenario.map_features[i].road_edge) + else: + plt.scatter(road_edge_x, road_edge_y, c='k') + + # draw road_line + if str(scenario.map_features[i].road_line) != '': + road_line_x = [j.x for j in scenario.map_features[i].road_line.polyline] + road_line_y = [j.y for j in scenario.map_features[i].road_line.polyline] + if scenario.map_features[i].road_line.type == 7: + plt.plot(road_line_x, road_line_y, c='y') + elif scenario.map_features[i].road_line.type == 8: + plt.plot(road_line_x, road_line_y, c='y') + elif 
scenario.map_features[i].road_line.type == 6: + plt.plot(road_line_x, road_line_y, c='y') + elif scenario.map_features[i].road_line.type == 1: + # dashed line: draw 5-point segments every 7 points (index k, not i, so the map-feature index is preserved) + for k in range(int(len(road_line_x) / 7)): + plt.plot(road_line_x[k * 7 : 5 + k * 7], road_line_y[k * 7 : 5 + k * 7], color='w') + elif scenario.map_features[i].road_line.type == 2: + plt.plot(road_line_x, road_line_y, c='w') + else: + plt.plot(road_line_x, road_line_y, c='w') + + # draw tracks + scenario_has_invalid_tracks = False + for i in range(len(scenario.tracks)): + traj_x = [center.center_x for center in scenario.tracks[i].states] + traj_y = [center.center_y for center in scenario.tracks[i].states] + head = [center.heading for center in scenario.tracks[i].states] + valid = [center.valid for center in scenario.tracks[i].states] + print(valid) + if i == scenario.sdc_track_index: + plt.scatter(traj_x[0], traj_y[0], s=140, c='r', marker='s') + plt.scatter([x for x, v in zip(traj_x, valid) if v], + [y for y, v in zip(traj_y, valid) if v], s=14, c='r') + plt.scatter([x for x, v in zip(traj_x, valid) if not v], + [y for y, v in zip(traj_y, valid) if not v], s=14, c='m') + else: + plt.scatter(traj_x[0], traj_y[0], s=140, c='k', marker='s') + plt.scatter([x for x, v in zip(traj_x, valid) if v], + [y for y, v in zip(traj_y, valid) if v], s=14, c='b') + plt.scatter([x for x, v in zip(traj_x, valid) if not v], + [y for y, v in zip(traj_y, valid) if not v], s=14, c='m') + if valid.count(False) > 0: + scenario_has_invalid_tracks = True + if scenario_has_invalid_tracks: + plt.savefig(f"scenario_{scenario_idx}_{scenario.scenario_id}.png") + plt.clf() + breakpoint() + break + + +colors = [ + ('#1f77b4', '#1a5a8a'), # blue + ('#2ca02c', '#217721'), # green + ('#ff7f0e', '#cc660b'), # orange + ('#9467bd', '#6f4a91'), # purple + ('#d62728', '#a31d1d'), # red + ('#000000', '#000000'), # black +] + +@safe_run +def plot_gif(): + data_path = "/u/xiuyu/work/dev4/data/waymo_processed/training" + os.makedirs("data/vis/processed/0/gif", exist_ok=True) + file_list = os.listdir(data_path) + + for scenario_idx, file in tqdm(enumerate(file_list), leave=False, desc="Scenario"): + + fig, ax = plt.subplots() + ax.set_axis_off() + + file_path = os.path.join(data_path, file) + data = pickle.load(open(file_path, "rb")) + scenario_id = data['scenario_id'] + + save_path = os.path.join("data/vis/processed/0/gif", + f"scenario_{scenario_idx}_{scenario_id}.gif") + if os.path.exists(save_path): + tqdm.write(f"Skipped {save_path}.") + continue + + # draw maps + ax.scatter(data['map_point']['position'][:, 0], + data['map_point']['position'][:, 1], s=0.2, c='black', edgecolors='none') + + # draw agents + agent_data = data['agent'] + av_index = agent_data['av_index'] + position = agent_data['position'] # (num_agent, 91, 3) + heading = agent_data['heading'] # (num_agent, 91) + shape = agent_data['shape'] # (num_agent, 91, 3) + category = agent_data['category'] # (num_agent,) + valid_mask = (position[..., 0] != 0) & (position[..., 1] != 0) # (num_agent, 91) + + num_agent = valid_mask.shape[0] + num_timestep = position.shape[1] + is_av = np.arange(num_agent) == int(av_index) + + is_blue = valid_mask.sum(axis=1) == num_timestep + is_green = ~valid_mask[:, 0] & valid_mask[:, -1] + is_orange = valid_mask[:, 0] & ~valid_mask[:, -1] + is_purple = (valid_mask.sum(axis=1) != num_timestep + ) & (~is_green) & (~is_orange) + agent_colors = np.zeros((num_agent,)) + agent_colors[is_blue] = 1 + agent_colors[is_green] = 2 + agent_colors[is_orange] = 3 + agent_colors[is_purple] = 4 + agent_colors[is_av]
= 5 + + veh_mask = category == 1 + ped_mask = category == 2 + cyc_mask = category == 3 + shape[veh_mask, :, 1] = 1.8 + shape[veh_mask, :, 0] = 1.8 + shape[ped_mask, :, 1] = 0.5 + shape[ped_mask, :, 0] = 0.5 + shape[cyc_mask, :, 1] = 1.0 + shape[cyc_mask, :, 0] = 1.0 + + fig_paths = [] + for tid in tqdm(range(num_timestep), leave=False, desc="Timestep"): + current_valid_mask = valid_mask[:, tid] + xs = position[current_valid_mask, tid, 0] + ys = position[current_valid_mask, tid, 1] + widths = shape[current_valid_mask, tid, 1] + lengths = shape[current_valid_mask, tid, 0] + angles = heading[current_valid_mask, tid] + current_agent_colors = agent_colors[current_valid_mask] + + drawn_agents = [] + contours = cal_polygon_contour(xs, ys, angles, widths, lengths) # (num_agent, 4, 2) + contours = np.concatenate([contours, contours[:, 0:1]], axis=1) # (num_agent, 5, 2) + for x, y, width, length, angle, color_type in zip( + xs, ys, widths, lengths, angles, current_agent_colors): + agent = plt.Rectangle((x, y), width, length, angle=((angle + np.pi / 2) / np.pi * 360) % 360, + linewidth=0.2, + facecolor=colors[int(color_type) - 1][0], + edgecolor=colors[int(color_type) - 1][1]) + ax.add_patch(agent) + drawn_agents.append(agent) + plt.gca().set_aspect('equal', adjustable='box') + # for contour, color_type in zip(contours, agent_colors): + # drawn_agent = ax.plot(contour[:, 0], contour[:, 1]) + # drawn_agents.append(drawn_agent) + + fig_path = os.path.join("data/vis/processed/0/", + f"scenario_{scenario_idx}_{scenario_id}_{tid}.png") + plt.savefig(fig_path, dpi=600) + fig_paths.append(fig_path) + + for drawn_agent in drawn_agents: + drawn_agent.remove() + + plt.close() + + # generate gif + import imageio.v2 as imageio + images = [] + for fig_path in tqdm(fig_paths, leave=False, desc="Generate gif ..."): + images.append(imageio.imread(fig_path)) + imageio.mimsave(save_path, images, duration=0.1) + + +@safe_run +def plot_map_token(ax: Axes, map_points: npt.NDArray, token_pos: npt.NDArray, token_heading: npt.NDArray, colors: Union[str, npt.NDArray]=None): + + plot_map(ax, map_points) + + x, y = token_pos[:, 0], token_pos[:, 1] + u = np.cos(token_heading) + v = np.sin(token_heading) + + if colors is None: + colors = np.random.rand(x.shape[0], 3) + ax.quiver(x, y, u, v, angles='xy', scale_units='xy', scale=0.2, color=colors, width=0.005, + headwidth=0.2, headlength=2) + ax.scatter(x, y, color='blue', s=0.2, edgecolors='none') + ax.axis("equal") + + +@safe_run +def plot_map(ax: Axes, map_points: npt.NDArray, color='black'): + ax.scatter(map_points[:, 0], map_points[:, 1], s=0.2, c=color, edgecolors='none') + + xmin = np.min(map_points[:, 0]) + xmax = np.max(map_points[:, 0]) + ymin = np.min(map_points[:, 1]) + ymax = np.max(map_points[:, 1]) + ax.set_xlim(xmin, xmax) + ax.set_ylim(ymin, ymax) + + +@safe_run +def plot_agent(ax: Axes, xy: Sequence[float], heading: float, type: str, state, is_av: bool=False, + pl2seed_radius: float=25., attr_tokenizer: Attr_Tokenizer=None, enter_index: list=[], **kwargs): + + if type == 'veh': + length = 4.3 + width = 1.8 + size = 1.0 + elif type == 'ped': + length = 0.5 + width = 0.5 + size = 0.1 + elif type == 'cyc': + length = 1.9 + width = 0.5 + size = 0.3 + else: + raise ValueError(f"Unsupported agent type {type}") + + if kwargs.get('label', None) is not None: + ax.text( + xy[0] + 1.5, xy[1] + 1.5, + kwargs['label'], fontsize=2, color="darkred", ha="center", va="center" + ) + + patch = FancyBboxPatch([-length / 2, -width / 2], length, width, linewidth=.2, **kwargs) + 
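+ # build the box in the agent's local frame, then rotate by its heading and translate to xy; composing with ax.transData keeps the patch in data coordinates.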
transform = ( + mtransforms.Affine2D().rotate(heading).translate(xy[0], xy[1]) + + ax.transData + ) + patch.set_transform(transform) + + kwargs['label'] = None + angles = [0, 2 * np.pi / 3, np.pi, 4 * np.pi / 3] + pts = np.stack([size * np.cos(angles), size * np.sin(angles)], axis=-1) + center_patch = Polygon(pts, zorder=10., linewidth=.2, **kwargs) + center_patch.set_transform(transform) + + ax.add_patch(patch) + ax.add_patch(center_patch) + + if is_av: + + if attr_tokenizer is not None: + + circle_patch = Circle( + (xy[0], xy[1]), pl2seed_radius, linewidth=0.5, edgecolor='gray', linestyle='--', facecolor='none' + ) + ax.add_patch(circle_patch) + + grid = attr_tokenizer.get_grid(torch.tensor(np.array(xy)).float(), + torch.tensor(np.array([heading])).float()).numpy()[0] # (num_grid, 2) + ax.scatter(grid[:, 0], grid[:, 1], s=0.3, c='blue', edgecolors='none') + ax.text(grid[0, 0], grid[0, 1], 'Front', fontsize=2, color='darkred', ha='center', va='center') + ax.text(grid[-1, 0], grid[-1, 1], 'Back', fontsize=2, color='darkred', ha='center', va='center') + + if enter_index: + for i in enter_index: + ax.plot(grid[int(i), 0], grid[int(i), 1], marker='x', color='red', markersize=1) + + return patch, center_patch + + +@safe_run +def plot_all(map, xs, ys, angles, types, colors, is_avs, pl2seed_radius: float=25., + attr_tokenizer: Attr_Tokenizer=None, enter_index: list=[], labels: list=[], **kwargs): + + plt.subplots_adjust(left=0.3, right=0.7, top=0.7, bottom=0.3) + _, ax = plt.subplots() + ax.set_axis_off() + + plot_map(ax, map) + + if not labels: + labels = [None] * xs.shape[0] + + for x, y, angle, type, color, label, is_av in zip(xs, ys, angles, types, colors, labels, is_avs): + assert type in ('veh', 'ped', 'cyc'), f"Unsupported type {type}." + plot_agent(ax, [x, y], angle.item(), type, None, is_av, facecolor=color, edgecolor='k', label=label, + pl2seed_radius=pl2seed_radius, attr_tokenizer=attr_tokenizer, enter_index=enter_index) + + ax.grid(False) + ax.set_aspect('equal', adjustable='box') + + # ax.legend(loc='best', frameon=True) + + if kwargs.get('save_path', None): + plt.savefig(kwargs['save_path'], dpi=600, bbox_inches="tight") + + plt.close() + + return ax + + +@safe_run +def plot_file(gt_folder: str, + folder: Optional[str] = None, + files: Optional[str] = None): + from dev.metrics.compute_metrics import _unbatch + + if files is None: + assert os.path.exists(folder), f'Path {folder} does not exist.' + files = list(fnmatch.filter(os.listdir(folder), 'idx_*_rollouts.pkl')) + CONSOLE.log(f'Found {len(files)} rollouts files from {folder}.') + + + if folder is None: + assert os.path.exists(files), f'Path {files} does not exist.' 
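+ # a single rollout file was passed: recover its parent folder so the loop below treats the folder and single-file cases uniformly.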
+ folder = os.path.dirname(files) + files = [files] + + parent, folder_name = os.path.split(folder.rstrip(os.sep)) + save_path = os.path.join(parent, f'{folder_name}_plots') + + for file in (pbar := tqdm(files, leave=False, desc='Plotting files ...')): + pbar.set_postfix(file=file) + + with open(os.path.join(folder, file), 'rb') as f: + preds = pickle.load(f) + + scenario_ids = preds['_scenario_id'] + agent_batch = preds['agent_batch'] + agent_id = _unbatch(preds['agent_id'], agent_batch) + preds_traj = _unbatch(preds['pred_traj'], agent_batch) + preds_head = _unbatch(preds['pred_head'], agent_batch) + preds_type = _unbatch(preds['pred_type'], agent_batch) + preds_state = _unbatch(preds['pred_state'], agent_batch) + preds_valid = _unbatch(preds['pred_valid'], agent_batch) + + for i, scenario_id in enumerate(scenario_ids): + n_rollouts = preds_traj[0].shape[1] + + for j in range(n_rollouts): # 1 + pred = dict(scenario_id=[scenario_id], + pred_traj=preds_traj[i][:, j], + pred_head=preds_head[i][:, j], + pred_state=preds_state[i][:, j], + pred_type=preds_type[i][:, j], + ) + av_index = agent_id[i][:, 0].tolist().index(preds['av_id']) # NOTE: hard code!!! + + data_path = os.path.join(gt_folder, 'validation', f'{scenario_id}.pkl') + with open(data_path, 'rb') as f: + data = pickle.load(f) + plot_val(data, pred, av_index=av_index, save_path=save_path) + + +@safe_run +def plot_val(data: Union[dict, str], pred: dict, av_index: int, save_path: str, suffix: str='', + pl2seed_radius: float=75., attr_tokenizer=None, **kwargs): + + if isinstance(data, str): + assert data.endswith('.pkl'), f'Got invalid data path {data}.' + assert os.path.exists(data), f'Path {data} does not exist.' + with open(data, 'rb') as f: + data = pickle.load(f) + + map_point = data['map_point']['position'].cpu().numpy() + + scenario_id = pred['scenario_id'][0] + pred_traj = pred['pred_traj'].cpu().numpy() # (num_agent, num_future_step, 2) + pred_type = list(map(lambda i: AGENT_TYPE[i], pred['pred_type'].tolist())) + pred_state = pred['pred_state'].cpu().numpy() + pred_head = pred['pred_head'].cpu().numpy() + ids = np.arange(pred_traj.shape[0]) + + if 'agent_labels' in pred: + kwargs.update(agent_labels=pred['agent_labels']) + + plot_scenario(scenario_id, map_point, pred_traj, pred_head, pred_state, pred_type, + av_index=av_index, ids=ids, save_path=save_path, suffix=suffix, + pl2seed_radius=pl2seed_radius, attr_tokenizer=attr_tokenizer, **kwargs) + + +@safe_run +def plot_scenario(scenario_id: str, + map_data: npt.NDArray, + traj: npt.NDArray, + heading: npt.NDArray, + state: npt.NDArray, + types: List[str], + av_index: int, + color_type: Literal['state', 'type', 'seed']='seed', + state_type: List[str]=['invalid', 'valid', 'enter', 'exit'], + plot_enter: bool=False, + suffix: str='', + pl2seed_radius: float=25., + attr_tokenizer: Attr_Tokenizer=None, + enter_index: List[list] = [], + save_gif: bool=True, + tokenized: bool=False, + agent_labels: List[List[Optional[str]]] = [], + **kwargs): + + num_historical_steps = 11 + shift = 5 + num_agent, num_timestep = traj.shape[:2] + + if tokenized: + num_historical_steps = 2 + shift = 1 + + if 'save_path' in kwargs and kwargs['save_path'] != '': + os.makedirs(kwargs['save_path'], exist_ok=True) + save_id = int(max([0] + list(map(lambda fname: int(fname.split("_")[-1]), + filter(lambda fname: fname.startswith(scenario_id) + and os.path.isdir(os.path.join(kwargs['save_path'], fname)), + os.listdir(kwargs['save_path'])))))) + 1 + 
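+ # save_id is 1 + the largest run index already present for this scenario (directories named '<scenario_id>_<index>' under save_path), so each rerun writes to a fresh folder.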
os.makedirs(f"{kwargs['save_path']}/{scenario_id}_{str(save_id).zfill(3)}", exist_ok=True) + + if save_id > 1: + try: + import shutil + shutil.rmtree(f"{kwargs['save_path']}/{scenario_id}_{str(save_id - 1).zfill(3)}") + except Exception: + pass + + visible_mask = state != state_type.index('invalid') + if not plot_enter: + visible_mask &= (state != state_type.index('enter')) + + last_valid_step = visible_mask.shape[1] - 1 - torch.argmax(torch.Tensor(visible_mask).flip(dims=[1]).long(), dim=1) + ids = None + if 'ids' in kwargs: + ids = kwargs['ids'] + last_valid_step = {int(ids[i]): int(last_valid_step[i]) for i in range(len(ids))} + + # agent colors + agent_colors = np.zeros((num_agent, num_timestep, 3)) + + agent_palette = sns.color_palette('husl', n_colors=7) + state_colors = {state: np.array(agent_palette[i]) for i, state in enumerate(state_type)} + seed_colors = {seed: np.array(agent_palette[i]) for i, seed in enumerate(['existing', 'entered', 'exited'])} + + if color_type == 'state': + for t in range(state.shape[1]): + agent_colors[state[:, t] == state_type.index('invalid'), t * shift : (t + 1) * shift] = state_colors['invalid'] + agent_colors[state[:, t] == state_type.index('valid'), t * shift : (t + 1) * shift] = state_colors['valid'] + agent_colors[state[:, t] == state_type.index('enter'), t * shift : (t + 1) * shift] = state_colors['enter'] + agent_colors[state[:, t] == state_type.index('exit'), t * shift : (t + 1) * shift] = state_colors['exit'] + + if color_type == 'seed': + agent_colors[:, :] = seed_colors['existing'] + is_exited = np.any(state[:, num_historical_steps - 1:] == state_type.index('exit'), axis=-1) + is_entered = np.any(state[:, num_historical_steps - 1:] == state_type.index('enter'), axis=-1) + is_entered[av_index + 1:] = True # NOTE: hard-coded, needs improvement + agent_colors[is_exited, :] = seed_colors['exited'] + agent_colors[is_entered, :] = seed_colors['entered'] + + agent_colors[av_index, :] = np.array(agent_palette[-1]) + is_av = np.zeros_like(state[:, 0]).astype(np.bool_) + is_av[av_index] = True + + # draw agents + fig_paths = [] + for tid in tqdm(range(num_timestep), leave=False, desc="Plot ..."): + mask_t = visible_mask[:, tid] + xs = traj[mask_t, tid, 0] + ys = traj[mask_t, tid, 1] + angles = heading[mask_t, tid] + colors = agent_colors[mask_t, tid] + types_t = [types[i] for i, mask in enumerate(mask_t) if mask] + if ids is not None: + ids_t = ids[mask_t] + is_av_t = is_av[mask_t] + enter_index_t = enter_index[tid] if enter_index else None + labels = [] + if agent_labels: + labels = [agent_labels[i][tid // shift] for i in range(len(agent_labels)) if mask_t[i]] + + fig_path = None + if 'save_path' in kwargs: + save_path = kwargs['save_path'] + fig_path = os.path.join(f"{save_path}/{scenario_id}_{str(save_id).zfill(3)}", f"{tid}.png") + fig_paths.append(fig_path) + + plot_all(map_data, xs, ys, angles, types_t, colors=colors, save_path=fig_path, is_avs=is_av_t, + pl2seed_radius=pl2seed_radius, attr_tokenizer=attr_tokenizer, enter_index=enter_index_t, labels=labels) + + # generate gif + if fig_paths and save_gif: + os.makedirs(os.path.join(save_path, 'gifs'), exist_ok=True) + images = [] + gif_path = f"{save_path}/gifs/{scenario_id}_{str(save_id).zfill(3)}.gif" + for fig_path in tqdm(fig_paths, leave=False, desc="Generate gif ..."): + images.append(Image.open(fig_path)) + try: + images[0].save(gif_path, save_all=True, append_images=images[1:], duration=100, loop=0) + tqdm.write(f"Saved gif at {gif_path}") + try: + import shutil + 
shutil.rmtree(f"{save_path}/{scenario_id}_{str(save_id).zfill(3)}") + os.remove(f"{save_path}/gifs/{scenario_id}_{str(save_id - 1).zfill(3)}.gif") + except: + pass + except Exception as e: + tqdm.write(f"{e}! Failed to save gif at {gif_path}") + + +def match_token_map(data): + + # init map token + argmin_sample_len = 3 + map_token_traj_path = '/u/xiuyu/work/dev4/dev/tokens/map_traj_token5.pkl' + + map_token_traj = pickle.load(open(map_token_traj_path, 'rb')) + map_token = {'traj_src': map_token_traj['traj_src'], } + traj_end_theta = np.arctan2(map_token['traj_src'][:, -1, 1] - map_token['traj_src'][:, -2, 1], + map_token['traj_src'][:, -1, 0] - map_token['traj_src'][:, -2, 0]) + indices = torch.linspace(0, map_token['traj_src'].shape[1]-1, steps=argmin_sample_len).long() + map_token['sample_pt'] = torch.from_numpy(map_token['traj_src'][:, indices]).to(torch.float) + map_token['traj_end_theta'] = torch.from_numpy(traj_end_theta).to(torch.float) + map_token['traj_src'] = torch.from_numpy(map_token['traj_src']).to(torch.float) + + traj_pos = data['map_save']['traj_pos'].to(torch.float) + traj_theta = data['map_save']['traj_theta'].to(torch.float) + pl_idx_list = data['map_save']['pl_idx_list'] + token_sample_pt = map_token['sample_pt'].to(traj_pos.device) + token_src = map_token['traj_src'].to(traj_pos.device) + max_traj_len = map_token['traj_src'].shape[1] + pl_num = traj_pos.shape[0] + + pt_token_pos = traj_pos[:, 0, :].clone() + pt_token_orientation = traj_theta.clone() + cos, sin = traj_theta.cos(), traj_theta.sin() + rot_mat = traj_theta.new_zeros(pl_num, 2, 2) + rot_mat[..., 0, 0] = cos + rot_mat[..., 0, 1] = -sin + rot_mat[..., 1, 0] = sin + rot_mat[..., 1, 1] = cos + traj_pos_local = torch.bmm((traj_pos - traj_pos[:, 0:1]), rot_mat.view(-1, 2, 2)) + distance = torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1)) ** 2, dim=(-2, -1)) + pt_token_id = torch.argmin(distance, dim=1) + + noise = False + if noise: + topk_indices = torch.argsort(torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1)) ** 2, dim=(-2, -1)), dim=1)[:, :8] + sample_topk = torch.randint(0, topk_indices.shape[-1], size=(topk_indices.shape[0], 1), device=topk_indices.device) + pt_token_id = torch.gather(topk_indices, 1, sample_topk).squeeze(-1) + + # cos, sin = traj_theta.cos(), traj_theta.sin() + # rot_mat = traj_theta.new_zeros(pl_num, 2, 2) + # rot_mat[..., 0, 0] = cos + # rot_mat[..., 0, 1] = sin + # rot_mat[..., 1, 0] = -sin + # rot_mat[..., 1, 1] = cos + # token_src_world = torch.bmm(token_src[None, ...].repeat(pl_num, 1, 1, 1).reshape(pl_num, -1, 2), + # rot_mat.view(-1, 2, 2)).reshape(pl_num, token_src.shape[0], max_traj_len, 2) + traj_pos[:, None, [0], :] + # token_src_world_select = token_src_world.view(-1, 1024, 11, 2)[torch.arange(pt_token_id.view(-1).shape[0]), pt_token_id.view(-1)].view(pl_num, max_traj_len, 2) + + pl_idx_full = pl_idx_list.clone() + token2pl = torch.stack([torch.arange(len(pl_idx_list), device=traj_pos.device), pl_idx_full.long()]) + count_nums = [] + for pl in pl_idx_full.unique(): + pt = token2pl[0, token2pl[1, :] == pl] + left_side = (data['pt_token']['side'][pt] == 0).sum() + right_side = (data['pt_token']['side'][pt] == 1).sum() + center_side = (data['pt_token']['side'][pt] == 2).sum() + count_nums.append(torch.Tensor([left_side, right_side, center_side])) + count_nums = torch.stack(count_nums, dim=0) + num_polyline = int(count_nums.max().item()) + traj_mask = torch.zeros((int(len(pl_idx_full.unique())), 3, num_polyline), dtype=bool) + idx_matrix = 
torch.arange(traj_mask.size(2)).unsqueeze(0).unsqueeze(0) + idx_matrix = idx_matrix.expand(traj_mask.size(0), traj_mask.size(1), -1) + counts_num_expanded = count_nums.unsqueeze(-1) + mask_update = idx_matrix < counts_num_expanded + traj_mask[mask_update] = True + + data['pt_token']['traj_mask'] = traj_mask + data['pt_token']['position'] = torch.cat([pt_token_pos, torch.zeros((data['pt_token']['num_nodes'], 1), + device=traj_pos.device, dtype=torch.float)], dim=-1) + data['pt_token']['orientation'] = pt_token_orientation + data['pt_token']['height'] = data['pt_token']['position'][:, -1] + data[('pt_token', 'to', 'map_polygon')] = {} + data[('pt_token', 'to', 'map_polygon')]['edge_index'] = token2pl # (2, num_points) + data['pt_token']['token_idx'] = pt_token_id + return data + + +@safe_run +def plot_tokenize(data, save_path: str): + + shift = 5 + token_size = 2048 + pl2seed_radius = 75 + + # transformation + transform = WaymoTargetBuilder(num_historical_steps=11, + num_future_steps=80, + max_num=32, + training=False) + + grid_range = 150. + grid_interval = 3. + angle_interval = 3. + attr_tokenizer = Attr_Tokenizer(grid_range=grid_range, + grid_interval=grid_interval, + radius=pl2seed_radius, + angle_interval=angle_interval) + + # tokenization + token_processor = TokenProcessor(token_size, + training=False, + predict_motion=True, + predict_state=True, + predict_map=True, + state_token={'invalid': 0, 'valid': 1, 'enter': 2, 'exit': 3}, + pl2seed_radius=pl2seed_radius) + CONSOLE.log(f"Loaded token processor with token_size: {token_size}") + + # preprocess + data: HeteroData = transform(data) + tokenized_data = token_processor(data) + CONSOLE.log(f"Keys in tokenized data:\n{tokenized_data.keys()}") + + # plot + agent_data = tokenized_data['agent'] + map_data = tokenized_data['map_point'] + # CONSOLE.log(f"Keys in agent data:\n{agent_data.keys()}") + + av_index = agent_data['av_index'] + raw_traj = agent_data['position'][..., :2].contiguous() # [n_agent, n_step, 2] + raw_heading = agent_data['heading'] # [n_agent, n_step] + + traj = agent_data['traj_pos'][..., :2].contiguous() # [n_agent, n_step, 6, 2] + traj = traj[:, :, 1:, :].flatten(1, 2) + traj = torch.cat([raw_traj[:, :1], traj], dim=1) + heading = agent_data['traj_heading'] # [n_agent, n_step, 6] + heading = heading[:, :, 1:].flatten(1, 2) + heading = torch.cat([raw_heading[:, :1], heading], dim=1) + + agent_state = agent_data['state_idx'].repeat_interleave(repeats=shift, dim=-1) + agent_state = torch.cat([torch.zeros_like(agent_state[:, :1]), agent_state], dim=1) + agent_type = agent_data['type'] + ids = np.arange(raw_traj.shape[0]) + + plot_scenario(scenario_id=tokenized_data['scenario_id'], + map_data=tokenized_data['map_point']['position'].numpy(), + traj=raw_traj.numpy(), + heading=raw_heading.numpy(), + state=agent_state.numpy(), + types=list(map(lambda i: AGENT_TYPE[i], agent_type.tolist())), + av_index=av_index, + ids=ids, + save_path=save_path, + pl2seed_radius=pl2seed_radius, + attr_tokenizer=attr_tokenizer, + color_type='state', + ) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument('--data_path', type=str, default='/u/xiuyu/work/dev4/data/waymo_processed') + parser.add_argument('--tfrecord_dir', type=str, default='validation_tfrecords_splitted') + # plot tokenized data + parser.add_argument('--save_folder', type=str, default='plot_gt') + parser.add_argument('--split', type=str, default='validation') + parser.add_argument('--scenario_id', type=str, default=None) + 
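+ # example invocation (paths assumed): PYTHONPATH=. python dev/utils/visualization.py --split validation --plot_tokenize + # note: __main__ below currently hard-codes scenario_id, so --scenario_id is parsed but unused.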
parser.add_argument('--plot_tokenize', action='store_true') + # plot generated rollouts + parser.add_argument('--plot_file', action='store_true') + parser.add_argument('--folder_path', type=str, default=None) + parser.add_argument('--file_path', type=str, default=None) + args = parser.parse_args() + + if args.plot_tokenize: + + scenario_id = "74ad7b76d5906d39" + # scenario_id = "1d60300bc06f4801" + data_path = os.path.join(args.data_path, args.split, f"{scenario_id}.pkl") + data = pickle.load(open(data_path, "rb")) + data['tfrecord_path'] = os.path.join(args.tfrecord_dir, f'{scenario_id}.tfrecords') + CONSOLE.log(f"Loaded scenario {scenario_id}") + + save_path = os.path.join(args.data_path, args.save_folder, args.split) + os.makedirs(save_path, exist_ok=True) + + plot_tokenize(data, save_path) + + if args.plot_file: + + plot_file(args.data_path, folder=args.folder_path, files=args.file_path) diff --git a/backups/environment.yml b/backups/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..a5a6a39e2f1227a14555869bb90498038e626cf2 --- /dev/null +++ b/backups/environment.yml @@ -0,0 +1,326 @@ +name: traj +channels: + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=conda_forge + - _openmp_mutex=4.5=2_gnu + - ca-certificates=2025.1.31=hbcca054_0 + - ld_impl_linux-64=2.38=h1181459_1 + - libffi=3.4.4=h6a678d5_1 + - libgcc=14.2.0=h77fa898_1 + - libgcc-ng=14.2.0=h69a702a_1 + - libgomp=14.2.0=h77fa898_1 + - libstdcxx-ng=11.2.0=h1234567_1 + - ncdu=1.16=h0f457ee_0 + - ncurses=6.4=h6a678d5_0 + - openssl=3.4.0=h7b32b05_1 + - pip=24.2=py39h06a4308_0 + - python=3.9.19=h955ad1f_1 + - readline=8.2=h5eee18b_0 + - sqlite=3.45.3=h5eee18b_0 + - tk=8.6.14=h39e8969_0 + - wheel=0.43.0=py39h06a4308_0 + - xz=5.4.6=h5eee18b_1 + - zlib=1.2.13=h5eee18b_1 + - pip: + - absl-py==1.4.0 + - addict==2.4.0 + - aiohappyeyeballs==2.4.0 + - aiohttp==3.10.5 + - aiosignal==1.3.1 + - anyio==4.4.0 + - appdirs==1.4.4 + - argon2-cffi==23.1.0 + - argon2-cffi-bindings==21.2.0 + - array-record==0.5.1 + - arrow==1.3.0 + - asttokens==2.4.1 + - astunparse==1.6.3 + - async-lru==2.0.4 + - async-timeout==4.0.3 + - attrs==24.2.0 + - av==12.3.0 + - babel==2.16.0 + - beautifulsoup4==4.12.3 + - bidict==0.23.1 + - bleach==6.1.0 + - blinker==1.8.2 + - cachetools==5.5.0 + - certifi==2024.7.4 + - cffi==1.17.0 + - chardet==5.2.0 + - charset-normalizer==3.3.2 + - click==8.1.7 + - cloudpickle==3.0.0 + - colorlog==6.8.2 + - comet-ml==3.45.0 + - comm==0.2.2 + - configargparse==1.7 + - configobj==5.0.8 + - contourpy==1.3.0 + - cryptography==43.0.0 + - cycler==0.12.1 + - dacite==1.8.1 + - dash==2.17.1 + - dash-core-components==2.0.0 + - dash-html-components==2.0.0 + - dash-table==5.0.0 + - dask==2023.3.1 + - dataclass-array==1.5.1 + - debugpy==1.8.5 + - decorator==5.1.1 + - defusedxml==0.7.1 + - descartes==1.1.0 + - dm-tree==0.1.8 + - docker-pycreds==0.4.0 + - docstring-parser==0.16 + - dulwich==0.22.1 + - easydict==1.13 + - einops==0.8.0 + - einsum==0.3.0 + - embreex==2.17.7.post5 + - etils==1.5.2 + - eval-type-backport==0.2.0 + - everett==3.1.0 + - exceptiongroup==1.2.2 + - executing==2.0.1 + - fastjsonschema==2.20.0 + - filelock==3.15.4 + - fire==0.6.0 + - flask==3.0.3 + - flatbuffers==24.3.25 + - fonttools==4.53.1 + - fqdn==1.5.1 + - frozenlist==1.4.1 + - fsspec==2024.6.1 + - gast==0.4.0 + - gdown==5.2.0 + - gitdb==4.0.11 + - gitpython==3.1.43 + - google-auth==2.16.2 + - google-auth-oauthlib==1.0.0 + - google-pasta==0.2.0 + - grpcio==1.66.1 + - h11==0.14.0 + - h5py==3.11.0 + - httpcore==1.0.5 + - 
httpx==0.27.2 + - idna==3.8 + - imageio==2.35.1 + - immutabledict==2.2.0 + - importlib-metadata==8.4.0 + - importlib-resources==6.4.4 + - ipykernel==6.29.5 + - ipython==8.18.1 + - ipywidgets==8.1.5 + - isoduration==20.11.0 + - itsdangerous==2.2.0 + - jax==0.4.30 + - jaxlib==0.4.30 + - jaxtyping==0.2.33 + - jedi==0.19.1 + - jinja2==3.1.4 + - joblib==1.4.2 + - json5==0.9.25 + - jsonpointer==3.0.0 + - jsonschema==4.23.0 + - jsonschema-specifications==2023.12.1 + - jupyter-client==8.6.2 + - jupyter-core==5.7.2 + - jupyter-events==0.10.0 + - jupyter-lsp==2.2.5 + - jupyter-server==2.14.2 + - jupyter-server-terminals==0.5.3 + - jupyterlab==4.2.5 + - jupyterlab-pygments==0.3.0 + - jupyterlab-server==2.27.3 + - jupyterlab-widgets==3.0.13 + - keras==2.12.0 + - kiwisolver==1.4.5 + - lark==1.2.2 + - lazy-loader==0.4 + - libclang==18.1.1 + - lightning-utilities==0.11.6 + - locket==1.0.0 + - lxml==5.3.0 + - manifold3d==2.5.1 + - mapbox-earcut==1.0.2 + - markdown==3.7 + - markdown-it-py==3.0.0 + - markupsafe==2.1.5 + - matplotlib==3.9.2 + - matplotlib-inline==0.1.7 + - mdurl==0.1.2 + - mediapy==1.2.2 + - mistune==3.0.2 + - ml-dtypes==0.4.0 + - mpmath==1.3.0 + - msgpack==1.0.8 + - msgpack-numpy==0.4.8 + - multidict==6.0.5 + - namex==0.0.8 + - nbclient==0.10.0 + - nbconvert==7.16.4 + - nbformat==5.10.4 + - nerfacc==0.5.2 + - nerfstudio==0.3.4 + - nest-asyncio==1.6.0 + - networkx==3.2.1 + - ninja==1.11.1.1 + - nodeenv==1.9.1 + - notebook-shim==0.2.4 + - numpy==1.23.0 + - nuscenes-devkit==1.1.11 + - nvidia-cublas-cu12==12.1.3.1 + - nvidia-cuda-cupti-cu12==12.1.105 + - nvidia-cuda-nvrtc-cu12==12.1.105 + - nvidia-cuda-runtime-cu12==12.1.105 + - nvidia-cudnn-cu12==9.1.0.70 + - nvidia-cufft-cu12==11.0.2.54 + - nvidia-curand-cu12==10.3.2.106 + - nvidia-cusolver-cu12==11.4.5.107 + - nvidia-cusparse-cu12==12.1.0.106 + - nvidia-nccl-cu12==2.20.5 + - nvidia-nvjitlink-cu12==12.6.20 + - nvidia-nvtx-cu12==12.1.105 + - oauthlib==3.2.2 + - open3d==0.18.0 + - opencv-python==4.6.0.66 + - openexr==1.3.9 + - opt-einsum==3.3.0 + - optree==0.12.1 + - overrides==7.7.0 + - packaging==24.1 + - pandas==1.5.3 + - pandocfilters==1.5.1 + - parso==0.8.4 + - partd==1.4.2 + - pexpect==4.9.0 + - pillow==9.2.0 + - platformdirs==4.2.2 + - plotly==5.13.1 + - prometheus-client==0.20.0 + - promise==2.3 + - prompt-toolkit==3.0.47 + - protobuf==3.20.3 + - psutil==6.0.0 + - ptyprocess==0.7.0 + - pure-eval==0.2.3 + - pyarrow==10.0.0 + - pyasn1==0.6.0 + - pyasn1-modules==0.4.0 + - pycocotools==2.0.8 + - pycollada==0.8 + - pycparser==2.22 + - pygments==2.18.0 + - pyliblzfse==0.4.1 + - pymeshlab==2023.12.post1 + - pyngrok==7.2.0 + - pyparsing==3.1.4 + - pyquaternion==0.9.9 + - pysocks==1.7.1 + - python-box==6.1.0 + - python-dateutil==2.9.0.post0 + - python-engineio==4.9.1 + - python-json-logger==2.0.7 + - python-socketio==5.11.3 + - pytorch-lightning==2.4.0 + - pytz==2024.1 + - pywavelets==1.6.0 + - pyyaml==6.0.2 + - pyzmq==26.2.0 + - rawpy==0.22.0 + - referencing==0.35.1 + - requests==2.32.3 + - requests-oauthlib==2.0.0 + - requests-toolbelt==1.0.0 + - retrying==1.3.4 + - rfc3339-validator==0.1.4 + - rfc3986-validator==0.1.1 + - rich==13.8.0 + - rpds-py==0.20.0 + - rsa==4.9 + - rtree==1.3.0 + - scikit-image==0.20.0 + - scikit-learn==1.2.2 + - scipy==1.9.1 + - seaborn==0.13.2 + - semantic-version==2.10.0 + - send2trash==1.8.3 + - sentry-sdk==2.13.0 + - setproctitle==1.3.3 + - setuptools==67.6.0 + - shapely==1.8.5.post1 + - shtab==1.7.1 + - simple-websocket==1.0.0 + - simplejson==3.19.3 + - six==1.16.0 + - smmap==5.0.1 + - sniffio==1.3.1 + - 
soupsieve==2.6 + - splines==0.3.0 + - stack-data==0.6.3 + - svg-path==6.3 + - sympy==1.13.2 + - tenacity==9.0.0 + - tensorboard==2.12.3 + - tensorboard-data-server==0.7.2 + - tensorflow==2.12.0 + - tensorflow-addons==0.23.0 + - tensorflow-datasets==4.9.3 + - tensorflow-estimator==2.12.0 + - tensorflow-graphics==2021.12.3 + - tensorflow-io-gcs-filesystem==0.37.1 + - tensorflow-metadata==1.15.0 + - tensorflow-probability==0.19.0 + - termcolor==2.4.0 + - terminado==0.18.1 + - threadpoolctl==3.5.0 + - tifffile==2024.8.28 + - timm==0.6.7 + - tinycss2==1.3.0 + - toml==0.10.2 + - tomli==2.0.1 + - toolz==0.12.1 + - torch==2.4.0 + - torch-cluster==1.6.3+pt24cu121 + - torch-fidelity==0.3.0 + - torch-geometric==2.5.3 + - torch-scatter==2.1.2+pt24cu121 + - torch-sparse==0.6.18+pt24cu121 + - torchmetrics==1.4.1 + - torchvision==0.19.0 + - tornado==6.4.1 + - tqdm==4.66.5 + - traitlets==5.14.3 + - trimesh==4.4.7 + - triton==3.0.0 + - typeguard==2.13.3 + - types-python-dateutil==2.9.0.20240821 + - typing-extensions==4.12.2 + - tyro==0.8.10 + - tzdata==2024.1 + - uri-template==1.3.0 + - urllib3==2.2.2 + - vhacdx==0.0.8.post1 + - viser==0.1.3 + - visu3d==1.5.1 + - wandb==0.17.8 + - waymo-open-dataset-tf-2-12-0==1.6.4 + - wcwidth==0.2.13 + - webcolors==24.8.0 + - webencodings==0.5.1 + - websocket-client==1.8.0 + - websockets==13.0.1 + - werkzeug==3.0.4 + - widgetsnbextension==4.0.13 + - wrapt==1.14.1 + - wsproto==1.2.0 + - wurlitzer==3.1.1 + - xatlas==0.0.9 + - xxhash==3.5.0 + - yarl==1.9.11 + - yourdfpy==0.0.56 + - zipp==3.20.1 +prefix: /u/xiuyu/anaconda3/envs/traffic diff --git a/backups/run.py b/backups/run.py new file mode 100644 index 0000000000000000000000000000000000000000..c0c714a58bd4232b8b214fa701a84459aee7aa81 --- /dev/null +++ b/backups/run.py @@ -0,0 +1,181 @@ +import pytorch_lightning as pl +import os +import shutil +import fnmatch +import torch +from argparse import ArgumentParser +from pytorch_lightning.callbacks import LearningRateMonitor +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.strategies import DDPStrategy +from pytorch_lightning.loggers import WandbLogger + +from dev.utils.func import RankedLogger, load_config_act, CONSOLE +from dev.datasets.scalable_dataset import MultiDataModule +from dev.model.smart import SMART + + +def backup(source_dir, backup_dir): + """ + Back up the source directory (code and configs) to a backup directory. + """ + + if os.path.exists(backup_dir): + return + os.makedirs(backup_dir, exist_ok=False) + + # Helper function to check if a path matches exclude patterns + def should_exclude(path): + for pattern in exclude_patterns: + if fnmatch.fnmatch(os.path.basename(path), pattern): + return True + return False + + # Iterate through the files and directories in source_dir + for root, dirs, files in os.walk(source_dir): + # Skip excluded directories + dirs[:] = [d for d in dirs if not should_exclude(d)] + + # Determine the relative path and destination path + rel_path = os.path.relpath(root, source_dir) + dest_dir = os.path.join(backup_dir, rel_path) + os.makedirs(dest_dir, exist_ok=True) + + # Copy all relevant files + for file in files: + if any(fnmatch.fnmatch(file, pattern) for pattern in include_patterns): + shutil.copy2(os.path.join(root, file), os.path.join(dest_dir, file)) + + logger.info(f"Backup completed. 
Files saved to: {backup_dir}") + + +if __name__ == '__main__': + pl.seed_everything(2024, workers=True) + torch.set_printoptions(precision=3) + + parser = ArgumentParser() + parser.add_argument('--config', type=str, default='configs/ours_long_term.yaml') + parser.add_argument('--pretrain_ckpt', type=str, default=None, + help='Path to any pretrained model, will only load its parameters.' + ) + parser.add_argument('--ckpt_path', type=str, default=None, + help='Path to any trained model, will load all the states.' + ) + parser.add_argument('--save_ckpt_path', type=str, default='output/debug', + help='Path to save the checkpoints in training mode' + ) + parser.add_argument('--save_path', type=str, default=None, + help='Path to save the inference results in validation and test mode.' + ) + parser.add_argument('--wandb', action='store_true', + help='Whether to use wandb logger in training.' + ) + parser.add_argument('--devices', type=int, default=1) + parser.add_argument('--train', action='store_true') + parser.add_argument('--validate', action='store_true') + parser.add_argument('--test', action='store_true') + parser.add_argument('--plot_rollouts', action='store_true') + args = parser.parse_args() + + if not (args.train or args.validate or args.test or args.plot_rollouts): + raise RuntimeError(f"Got invalid action, should be one of ['train', 'validate', 'test', 'plot_rollouts']") + + # ! setup logger + logger = RankedLogger(__name__, rank_zero_only=True) + + # ! backup codes + exclude_patterns = ['*output*', '*logs', 'wandb', 'data', '*debug*', '*backup*', 'interact_*', '*edge_map*', '__pycache__'] + include_patterns = ['*.py', '*.json', '*.yaml', '*.yml', '*.sh'] + backup(os.getcwd(), os.path.join(args.save_ckpt_path, 'backups')) + + config = load_config_act(args.config) + + wandb_logger = None + if args.wandb and not int(os.getenv('DEBUG', 0)): + # squeue -O username,state,nodelist,gres,minmemory,numcpus,name + wandb_logger = WandbLogger(project='simagent') + + trainer_config = config.Trainer + max_epochs = trainer_config.max_epochs + + # ! setup datamodule and model + datamodule = MultiDataModule(**vars(config.Dataset), logger=logger) + model = SMART(config.Model, save_path=args.save_ckpt_path, logger=logger, max_epochs=max_epochs) + if args.pretrain_ckpt: + model.load_state_from_file(filename=args.pretrain_ckpt) + strategy = DDPStrategy(find_unused_parameters=True, gradient_as_bucket_view=True) + logger.info(f'Build model: {model.__class__.__name__} datamodule: {datamodule.__class__.__name__}') + + # ! checkpoint configuration + every_n_epochs = 1 + if int(os.getenv('OVERFIT', 0)): + max_epochs = trainer_config.overfit_epochs + every_n_epochs = 100 + + if int(os.getenv('CHECK_INPUTS', 0)): + max_epochs = 1 + + check_val_every_n_epoch = 1 # save checkpoints for each epoch + model_checkpoint = ModelCheckpoint(dirpath=args.save_ckpt_path, + filename='{epoch:02d}', + save_top_k=5, + monitor='epoch', + mode='max', + save_last=True, + every_n_train_steps=1000, + save_on_train_epoch_end=True) + + # ! 
setup trainer + lr_monitor = LearningRateMonitor(logging_interval='epoch') + trainer = pl.Trainer(accelerator=trainer_config.accelerator, devices=args.devices if args.devices is not None else trainer_config.devices, + strategy=strategy, logger=wandb_logger, + accumulate_grad_batches=trainer_config.accumulate_grad_batches, + num_nodes=trainer_config.num_nodes, + callbacks=[model_checkpoint, lr_monitor], + max_epochs=max_epochs, + num_sanity_val_steps=0, + check_val_every_n_epoch=check_val_every_n_epoch, + log_every_n_steps=1, + gradient_clip_val=0.5) + logger.info(f'Build trainer: {trainer.__class__.__name__}') + + # ! run + if args.train: + + logger.info('Start training ...') + trainer.fit(model, datamodule, ckpt_path=args.ckpt_path) + + # NOTE: both the validation and test modes run on the validation split. + # For validation, we enable online metric computation and dump the results; + # for test, we disable it and only dump the inference results. + else: + + if args.save_path is not None: + save_path = args.save_path + else: + assert args.ckpt_path is not None and os.path.exists(args.ckpt_path), \ + f'Path {args.ckpt_path} does not exist!' + save_path = os.path.join(os.path.dirname(args.ckpt_path), 'validation') + os.makedirs(save_path, exist_ok=True) + CONSOLE.log(f'Results will be saved to [yellow]{save_path}[/]') + + model.save_path = save_path + + if not args.ckpt_path: + CONSOLE.log('[yellow] Warning: no checkpoint will be loaded in validation! [/]') + + if args.validate: + + CONSOLE.log('[on blue] Start validating ... [/]') + model.set(mode='validation') + + elif args.test: + + CONSOLE.log('[on blue] Start testing ... [/]') + model.set(mode='test') + + elif args.plot_rollouts: + + CONSOLE.log('[on blue] Start generating ... [/]') + model.set(mode='plot_rollouts') + + trainer.validate(model, datamodule, ckpt_path=args.ckpt_path) diff --git a/backups/scripts/aggregate_log_metric_features.sh b/backups/scripts/aggregate_log_metric_features.sh new file mode 100644 index 0000000000000000000000000000000000000000..52c198bb0bd15be83f4bf6b013aa2ada455a999a --- /dev/null +++ b/backups/scripts/aggregate_log_metric_features.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +export TF_CPP_MIN_LOG_LEVEL='2' +export PYTHONPATH='.' + +# dump all features +echo 'Start dumping all log features ...' +python dev/metrics/compute_metrics.py --dump_log --no_batch + +sleep 20 + +# aggregate features +echo 'Start aggregating log features ...' +python dev/metrics/compute_metrics.py --aggregate_log + +echo 'Done!'
diff --git a/backups/scripts/aggregate_log_metric_features.sh b/backups/scripts/aggregate_log_metric_features.sh
new file mode 100644
index 0000000000000000000000000000000000000000..52c198bb0bd15be83f4bf6b013aa2ada455a999a
--- /dev/null
+++ b/backups/scripts/aggregate_log_metric_features.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+
+export TF_CPP_MIN_LOG_LEVEL='2'
+export PYTHONPATH='.'
+
+# dump all features
+echo 'Start dumping all log features ...'
+python dev/metrics/compute_metrics.py --dump_log --no_batch
+
+sleep 20
+
+# aggregate features
+echo 'Start aggregating log features ...'
+python dev/metrics/compute_metrics.py --aggregate_log
+
+echo 'Done!'
\ No newline at end of file
diff --git a/backups/scripts/c128.sh b/backups/scripts/c128.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6c75bd9206e9352417d37e8c8ee29daef953cc66
--- /dev/null
+++ b/backups/scripts/c128.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+
+#SBATCH --job-name c128          # Job name
+### Logging
+#SBATCH --output=%j.out          # Stdout (%j expands to jobId)
+#SBATCH --error=%j.err           # Stderr (%j expands to jobId)
+### Node info
+#SBATCH --nodes=1                # Single node or multi node
+#SBATCH --time 100:00:00         # Max time (hh:mm:ss)
+#SBATCH --gres=gpu:0             # GPUs per node
+#SBATCH --mem=128G               # Recommend 32G per GPU
+#SBATCH --ntasks-per-node=1      # Tasks per node
+#SBATCH --cpus-per-task=64       # Recommend 8 per GPU
diff --git a/backups/scripts/c64.sh b/backups/scripts/c64.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6c75bd9206e9352417d37e8c8ee29daef953cc66
--- /dev/null
+++ b/backups/scripts/c64.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+
+#SBATCH --job-name c64           # Job name
+### Logging
+#SBATCH --output=%j.out          # Stdout (%j expands to jobId)
+#SBATCH --error=%j.err           # Stderr (%j expands to jobId)
+### Node info
+#SBATCH --nodes=1                # Single node or multi node
+#SBATCH --time 100:00:00         # Max time (hh:mm:ss)
+#SBATCH --gres=gpu:0             # GPUs per node
+#SBATCH --mem=128G               # Recommend 32G per GPU
+#SBATCH --ntasks-per-node=1      # Tasks per node
+#SBATCH --cpus-per-task=64       # Recommend 8 per GPU
diff --git a/backups/scripts/compute_metrics.sh b/backups/scripts/compute_metrics.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5d225e1c3988dea24f9f5b125c9e722b7e475c77
--- /dev/null
+++ b/backups/scripts/compute_metrics.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+
+export TORCH_LOGS='0'
+export TF_CPP_MIN_LOG_LEVEL='2'
+export PYTHONPATH='.'
+
+NUM_WORKERS=$1
+SIM_DIR=$2
+
+echo 'Start running ...'
+python dev/metrics/compute_metrics.py --compute_metric --num_workers "$NUM_WORKERS" --sim_dir "$SIM_DIR" "${@:3}"
+
+echo 'Done!'
\ No newline at end of file
diff --git a/backups/scripts/data_preprocess.sh b/backups/scripts/data_preprocess.sh
new file mode 100644
index 0000000000000000000000000000000000000000..42178f95571c3a63d1f37f4b0191c3e8610af41a
--- /dev/null
+++ b/backups/scripts/data_preprocess.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+
+# env
+source ~/anaconda3/etc/profile.d/conda.sh
+conda config --append envs_dirs ~/.conda/envs
+conda activate traj
+
+echo "Start running ..."
+
+# run preprocessing
+cd ~/work/dev6/thirdparty/dev4
+PYTHONPATH='..':$PYTHONPATH python3 data_preprocess.py --split validation
diff --git a/backups/scripts/data_preprocess_loop.sh b/backups/scripts/data_preprocess_loop.sh
new file mode 100644
index 0000000000000000000000000000000000000000..309ca2a0284b261e8d30a30229d31e727299a65d
--- /dev/null
+++ b/backups/scripts/data_preprocess_loop.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+
+RED='\033[0;31m'
+NC='\033[0m'
+
+cd /u/xiuyu/work/dev4/
+
+trap "echo -e \"${RED}Stopping script...${NC}\"; kill -- -$$" SIGINT
+
+while true; do
+    echo -e "${RED}Start running ...${NC}"
+    PYTHONPATH='.':$PYTHONPATH setsid python data_preprocess.py --split training &
+    PID=$!
+
+    sleep 1200
+
+    echo -e "${RED}Sending SIGINT to process group $PID...${NC}"
+    PGID=$(ps -o pgid= -p $PID | tail -n 1 | tr -d ' ')
+    kill -- -$PGID
+    wait $PID
+
+    sleep 10
+done
\ No newline at end of file
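The loop script above restarts preprocessing every 20 minutes by signalling the whole
process group that setsid created. For reference, a minimal Python sketch of the same
run-for-a-window-then-restart pattern (the command and timings are illustrative, not
part of the repo):

    import os
    import signal
    import subprocess
    import time

    CMD = ['python', 'data_preprocess.py', '--split', 'training']  # illustrative

    while True:
        # start_new_session=True puts the child in its own process group (like setsid)
        proc = subprocess.Popen(CMD, start_new_session=True)
        time.sleep(1200)                                 # same 20-minute run window
        os.killpg(os.getpgid(proc.pid), signal.SIGINT)   # kill -- -PGID equivalent
        proc.wait()
        time.sleep(10)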
diff --git a/backups/scripts/debug.py b/backups/scripts/debug.py
new file mode 100644
index 0000000000000000000000000000000000000000..26c0092b6d24dfc34a2f33d3068f43358825030d
--- /dev/null
+++ b/backups/scripts/debug.py
@@ -0,0 +1,17 @@
+import torch
+from torch import nn
+import torch.optim as optim
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+embedding = nn.Embedding(180, 128).to(device)
+gt = torch.randint(0, 2, (180, 2048)).to(device)
+head = nn.Linear(128, 2048).to(device)
+optimizer = optim.Adam([embedding.weight, head.weight])
+
+while True:
+    pred = head(embedding.weight).sigmoid()
+    loss = nn.MSELoss()(pred, gt.float())
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
diff --git a/backups/scripts/debug_map.py b/backups/scripts/debug_map.py
new file mode 100644
index 0000000000000000000000000000000000000000..a996c2d9638ad9edad9404b51534ed6cab3f17ad
--- /dev/null
+++ b/backups/scripts/debug_map.py
@@ -0,0 +1,204 @@
+import os
+import pickle
+import torch
+import matplotlib.pyplot as plt
+import numpy as np
+from tqdm import tqdm
+from argparse import ArgumentParser
+from dev.datasets.preprocess import TokenProcessor
+from dev.transforms.target_builder import WaymoTargetBuilder
+
+
+colors = [
+    ('#1f77b4', '#1a5a8a'),  # blue
+    ('#2ca02c', '#217721'),  # green
+    ('#ff7f0e', '#cc660b'),  # orange
+    ('#9467bd', '#6f4a91'),  # purple
+    ('#d62728', '#a31d1d'),  # red
+    ('#000000', '#000000'),  # black
+]
+
+
+def draw_tokenize(tokenize_data, token_processor: TokenProcessor, index, posfix):
+    print("Drawing tokenized data ...")
+    shift = 5
+    token_size = 2048
+
+    traj_token = token_processor.trajectory_token["veh"]
+    traj_token_all = token_processor.trajectory_token_all["veh"]
+
+    fig, ax = plt.subplots()
+    plt.subplots_adjust(left=0.3, right=0.7, top=0.7, bottom=0.3)
+    ax.set_axis_off()
+
+    scenario_id = tokenize_data['scenario_id']
+    ax.scatter(tokenize_data["map_point"]["position"][:, 0],
+               tokenize_data["map_point"]["position"][:, 1], s=0.2, c='black', edgecolors='none')
+
+    index = np.array(index).astype(np.int32)
+    agent_data = tokenize_data["agent"]
+    token_index = agent_data["token_idx"][index]
+    token_valid_mask = agent_data["agent_valid_mask"][index]
+
+    num_agent, num_token = token_index.shape
+    tokens = traj_token[token_index.view(-1)].reshape(num_agent, num_token, 4, 2)
+    tokens_all = traj_token_all[token_index.view(-1)].reshape(num_agent, num_token, 6, 4, 2)
+
+    position = agent_data['position'][index, :, :2]  # (num_agent, 91, 2)
+    heading = agent_data['heading'][index]  # (num_agent, 91)
+    valid_mask = (position[..., 0] != 0) & (position[..., 1] != 0)  # (num_agent, 91)
+    # TODO: fix this
+    if args.smart:
+        for shifted_tid in range(token_valid_mask.shape[1]):
+            valid_mask[:, shifted_tid * shift : (shifted_tid + 1) * shift] = token_valid_mask[:, shifted_tid : shifted_tid + 1].repeat(1, shift)
+    else:
+        for shifted_tid in range(token_index.shape[1]):
+            valid_mask[:, shifted_tid * shift : (shifted_tid + 1) * shift] = token_index[:, shifted_tid : shifted_tid + 1] != token_size + 2
+    last_valid_step = valid_mask.shape[1] - 1 - torch.argmax(valid_mask.flip(dims=[1]).long(), dim=1)
+    last_valid_step = {int(index[i]): int(last_valid_step[i]) for i in range(len(index))}
+
+    _, token_num, token_contour_dim, feat_dim = tokens.shape
+    tokens_src =
tokens.reshape(num_agent, token_num * token_contour_dim, feat_dim) + tokens_all_src = tokens_all.reshape(num_agent, token_num * 6 * token_contour_dim, feat_dim) + prev_heading = heading[:, 0] + prev_pos = position[:, 0] + + fig_paths = [] + agent_colors = np.zeros((num_agent, position.shape[1])) + shape = np.zeros((num_agent, position.shape[1], 2)) + 3. + for tid in tqdm(range(shift, position.shape[1], shift), leave=False, desc="Token ..."): + cos, sin = prev_heading.cos(), prev_heading.sin() + rot_mat = prev_heading.new_zeros(num_agent, 2, 2) + rot_mat[:, 0, 0] = cos + rot_mat[:, 0, 1] = sin + rot_mat[:, 1, 0] = -sin + rot_mat[:, 1, 1] = cos + tokens_world = torch.bmm(torch.from_numpy(tokens_src).float(), rot_mat).reshape(num_agent, + token_num, + token_contour_dim, + feat_dim) + tokens_all_world = torch.bmm(torch.from_numpy(tokens_all_src).float(), rot_mat).reshape(num_agent, + token_num, + 6, + token_contour_dim, + feat_dim) + tokens_world += prev_pos[:, None, None, :2] + tokens_all_world += prev_pos[:, None, None, None, :2] + tokens_select = tokens_world[:, tid // shift - 1] # (num_agent, token_contour_dim, feat_dim) + tokens_all_select = tokens_all_world[:, tid // shift - 1] # (num_agent, 6, token_contour_dim, feat_dim) + + diff_xy = tokens_select[:, 0, :] - tokens_select[:, 3, :] + prev_heading = heading[:, tid].clone() + # prev_heading[valid_mask[:, tid - shift]] = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])[ + # valid_mask[:, tid - shift]] + prev_pos = position[:, tid].clone() + # prev_pos[valid_mask[:, tid - shift]] = tokens_select.mean(dim=1)[valid_mask[:, tid - shift]] + + # NOTE tokens_pos equals to tokens_all_pos[:, -1] + tokens_pos = tokens_select.mean(dim=1) # (num_agent, 2) + tokens_all_pos = tokens_all_select.mean(dim=2) # (num_agent, 6, 2) + + # colors + cur_token_index = token_index[:, tid // shift - 1] + is_bos = cur_token_index == token_size + is_eos = cur_token_index == token_size + 1 + is_invalid = cur_token_index == token_size + 2 + is_valid = ~is_bos & ~is_eos & ~is_invalid + agent_colors[is_valid, tid - shift : tid] = 1 + agent_colors[is_bos, tid - shift : tid] = 2 + agent_colors[is_eos, tid - shift : tid] = 3 + agent_colors[is_invalid, tid - shift : tid] = 4 + + for i in tqdm(range(shift), leave=False, desc="Timestep ..."): + global_tid = tid - shift + i + cur_valid_mask = valid_mask[:, tid - shift] # only when the last tokenized timestep is valid the current shifts trajectory is valid + xs = tokens_all_pos[cur_valid_mask, i, 0] + ys = tokens_all_pos[cur_valid_mask, i, 1] + widths = shape[cur_valid_mask, global_tid, 1] + lengths = shape[cur_valid_mask, global_tid, 0] + angles = heading[cur_valid_mask, global_tid] + cur_agent_colors = agent_colors[cur_valid_mask, global_tid] + current_index = index[cur_valid_mask] + + drawn_agents = [] + drawn_texts = [] + for x, y, width, length, angle, color_type, id in zip( + xs, ys, widths, lengths, angles, cur_agent_colors, current_index): + if x < 3000: continue + agent = plt.Rectangle((x, y), width, length, # angle=((angle + np.pi / 2) / np.pi * 360) % 360, + linewidth=0.2, + facecolor=colors[int(color_type) - 1][0], + edgecolor=colors[int(color_type) - 1][1]) + ax.add_patch(agent) + text = plt.text(x-4, y-4, f"{str(id)}:{str(global_tid)}", fontdict={'family': 'serif', 'size': 3, 'color': 'red'}) + + if global_tid != last_valid_step[id]: + drawn_agents.append(agent) + drawn_texts.append(text) + + # draw timestep to be tokenized + if global_tid % shift == 0: + tokenize_agent = plt.Rectangle((x, y), width, length, # 
angle=((angle + np.pi / 2) / np.pi * 360) % 360,
+                                                  linewidth=0.2, fill=False,
+                                                  edgecolor=colors[int(color_type) - 1][1])
+                    ax.add_patch(tokenize_agent)
+
+            plt.gca().set_aspect('equal', adjustable='box')
+
+            fig_path = f"debug/tokenize/steps/{scenario_id}_{global_tid}.png"
+            plt.savefig(fig_path, dpi=600, bbox_inches="tight")
+            fig_paths.append(fig_path)
+
+            for drawn_agent, drawn_text in zip(drawn_agents, drawn_texts):
+                drawn_agent.remove()
+                drawn_text.remove()
+
+    plt.close()
+
+    # generate gif
+    import imageio.v2 as imageio
+    images = []
+    for fig_path in tqdm(fig_paths, leave=False, desc="Generate gif ..."):
+        images.append(imageio.imread(fig_path))
+    imageio.mimsave(f"debug/tokenize/{scenario_id}_tokenize_{posfix}.gif", images, duration=0.1)
+
+
+def main(data):
+
+    token_size = 2048
+
+    os.makedirs("debug/tokenize/steps/", exist_ok=True)
+    scenario_id = data["scenario_id"]
+
+    selected_agents_index = [1, 21, 35, 36, 46]
+
+    # raw data
+    if not os.path.exists(f"debug/tokenize/{scenario_id}_raw.gif"):
+        draw_raw(data, selected_agents_index)
+
+    # tokenization
+    token_processor = TokenProcessor(token_size, disable_invalid=args.smart)
+    print(f"Loaded token processor with token_size: {token_size}")
+    data = token_processor.preprocess(data)
+
+    # tokenized data
+    posfix = "smart" if args.smart else "ours"
+    # if not os.path.exists(f"debug/tokenize/{scenario_id}_tokenize_{posfix}.gif"):
+    draw_tokenize(data, token_processor, selected_agents_index, posfix)
+
+    target_builder = WaymoTargetBuilder(num_historical_steps=11, num_future_steps=80)
+    data = target_builder(data)
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser(description="Testing script parameters")
+    parser.add_argument("--smart", action="store_true")
+    parser.add_argument("--data_path", type=str, default="/u/xiuyu/work/dev4/data/waymo_processed/training")
+    args = parser.parse_args()
+
+    scenario_id = "74ad7b76d5906d39"
+    data_path = os.path.join(args.data_path, f"{scenario_id}.pkl")
+    data = pickle.load(open(data_path, "rb"))
+    print(f"Loaded scenario {scenario_id}")
+
+    main(data)
\ No newline at end of file
diff --git a/backups/scripts/g2.sh b/backups/scripts/g2.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ad594c9d6fd2fd14f8810eea43fc2aa17cf6f499
--- /dev/null
+++ b/backups/scripts/g2.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+
+#SBATCH --job-name g2            # Job name
+### Logging
+#SBATCH --output=%j.out          # Stdout (%j expands to jobId)
+#SBATCH --error=%j.err           # Stderr (%j expands to jobId)
+### Node info
+#SBATCH --nodes=1                # Single node or multi node
+#SBATCH --nodelist=sota-2
+#SBATCH --time 24:00:00          # Max time (hh:mm:ss)
+#SBATCH --gres=gpu:2             # GPUs per node
+#SBATCH --mem=96G                # Recommend 32G per GPU
+#SBATCH --ntasks-per-node=1      # Tasks per node
+#SBATCH --cpus-per-task=16       # Recommend 8 per GPU
+
+export NCCL_DEBUG=INFO
+export REQUESTS_CA_BUNDLE="/etc/ssl/certs/ca-certificates.crt"
+export HTTPS_PROXY="https://192.168.0.10:443/"
+export https_proxy="https://192.168.0.10:443/"
+
+export TEST_VAL_TRAIN=False
+export TEST_VAL_PRED=True
+export WANDB=True
+
+sleep 86400
+
+cd /u/xiuyu/work/dev4
+PYTHONPATH=".":$PYTHONPATH python3 train.py \
+    --devices 2 \
+    --config configs/train/train_scalable_with_state.yaml \
+    --save_ckpt_path output/seed_1k_pure_seed_150_3_emb_head_3_debug \
+    --pretrain_ckpt output/ours_map_pretrain/epoch=31.ckpt
+
+PYTHONPATH=".":$PYTHONPATH python val.py \
+    --config configs/validation/val_scalable_with_state.yaml \
+    --save_path output/seed_debug \
+    --pretrain_ckpt
output/seed_1k_pure_seed_150_3_emb_head_3/last.ckpt \ No newline at end of file diff --git a/backups/scripts/g4.sh b/backups/scripts/g4.sh new file mode 100644 index 0000000000000000000000000000000000000000..484b70500f810e7b2e0b77317ab3b311e57c2924 --- /dev/null +++ b/backups/scripts/g4.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +#SBATCH --job-name g4 # Job name +### Logging +#SBATCH --output=%j.out # Stdout (%j expands to jobId) +#SBATCH --error=%j.err # Stderr (%j expands to jobId) +### Node info +#SBATCH --nodes=1 # Single node or multi node +#SBATCH --nodelist=sota-1 +#SBATCH --time 72:00:00 # Max time (hh:mm:ss) +#SBATCH --gres=gpu:4 # GPUs per node +#SBATCH --mem=128G # Recommend 32G per GPU +#SBATCH --ntasks-per-node=1 # Tasks per node +#SBATCH --cpus-per-task=32 # Recommend 8 per GPU + +export REQUESTS_CA_BUNDLE="/etc/ssl/certs/ca-certificates.crt" +export HTTPS_PROXY="https://192.168.0.10:443/" +export https_proxy="https://192.168.0.10:443/" + +export TEST_VAL_TRAIN=0 +export TEST_VAL_PRED=1 +export WANDB=1 + +sleep 604800 + +cd /u/xiuyu/work/dev4 +PYTHONPATH=".":$PYTHONPATH python3 train.py \ + --devices 4 \ + --config configs/train/train_scalable_with_state.yaml \ + --save_ckpt_path output/seq_1k_10_150_3_3_encode_occ_separate_offsets \ + --pretrain_ckpt output/pretrain_scalable_map/epoch=31.ckpt + +PYTHONPATH=".":$PYTHONPATH python val.py \ + --config configs/ours_long_term.yaml \ + --ckpt_path output/seq_5k_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long/epoch=31.ckpt \ No newline at end of file diff --git a/backups/scripts/g8.sh b/backups/scripts/g8.sh new file mode 100644 index 0000000000000000000000000000000000000000..cb475f5c9c726557efc68b48a6d55dc1040498c5 --- /dev/null +++ b/backups/scripts/g8.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +#SBATCH --job-name g8 # Job name +### Logging +#SBATCH --output=%j.out # Stdout (%j expands to jobId) +#SBATCH --error=%j.err # Stderr (%j expands to jobId) +### Node info +#SBATCH --nodes=1 # Single node or multi node +#SBATCH --nodelist=sota-6 +#SBATCH --time 120:00:00 # Max time (hh:mm:ss) +#SBATCH --gres=gpu:8 # GPUs per node +#SBATCH --mem=256G # Recommend 32G per GPU +#SBATCH --ntasks-per-node=1 # Tasks per node +#SBATCH --cpus-per-task=32 # Recommend 8 per GPU + +export REQUESTS_CA_BUNDLE="/etc/ssl/certs/ca-certificates.crt" +export HTTPS_PROXY="https://192.168.0.10:443/" +export https_proxy="https://192.168.0.10:443/" + +export TEST_VAL_TRAIN=0 +export TEST_VAL_PRED=1 +export WANDB=1 + +sleep 864000 + +cd /u/xiuyu/work/dev4 +PYTHONPATH=".":$PYTHONPATH python3 train.py \ + --devices 8 \ + --config configs/train/train_scalable_long_term.yaml \ + --save_ckpt_path output/seq_5k_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long \ + --pretrain_ckpt output/pretrain_scalable_map/epoch=31.ckpt + +PYTHONPATH=".":$PYTHONPATH python3 train.py \ + --devices 8 \ + --config configs/ours_long_term.yaml \ + --save_ckpt_path output2/seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid + +PYTHONPATH=".":$PYTHONPATH python3 train.py \ + --config configs/ours_long_term.yaml \ + --save_ckpt_path output2/debug + +PYTHONPATH=".":$PYTHONPATH python val.py \ + --config configs/ours_long_term.yaml \ + --ckpt_path output2/bug_seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long/last.ckpt diff --git a/backups/scripts/hf_model.py b/backups/scripts/hf_model.py new file mode 100644 index 0000000000000000000000000000000000000000..0e850e09d44a8b9a754d877af7d43e3d4b8bdee4 --- /dev/null +++ 
b/backups/scripts/hf_model.py
@@ -0,0 +1,111 @@
+import argparse
+import os
+from huggingface_hub import upload_folder, upload_file, hf_hub_download
+from rich.console import Console
+from rich.panel import Panel
+from rich import box, style
+from rich.table import Table
+
+CONSOLE = Console(width=120)
+
+
+def upload():
+
+    if args.folder_path:
+
+        try:
+            if token is not None:
+                upload_folder(repo_id=args.repo_id, folder_path=args.folder_path, ignore_patterns=ignore_patterns, token=token)
+            else:
+                upload_folder(repo_id=args.repo_id, folder_path=args.folder_path, ignore_patterns=ignore_patterns)
+            table = Table(title=None, show_header=False, box=box.MINIMAL, title_style=style.Style(bold=True))
+            table.add_row(f"Model id {args.repo_id}", str(args.folder_path))
+            CONSOLE.print(Panel(table, title="[bold][green]:tada: Upload completed! DO NOT forget to specify the model id in methods! :tada:[/bold]", expand=False))
+
+        except Exception as e:
+            CONSOLE.print(f"[bold][yellow]Upload failed due to {e}.")
+            raise e
+
+    if args.file_path:
+
+        try:
+            if token is not None:
+                upload_file(
+                    path_or_fileobj=args.file_path,
+                    path_in_repo=os.path.basename(args.file_path),
+                    repo_id=args.repo_id,
+                    repo_type='model',
+                    token=token
+                )
+            else:
+                upload_file(
+                    path_or_fileobj=args.file_path,
+                    path_in_repo=os.path.basename(args.file_path),
+                    repo_id=args.repo_id,
+                    repo_type='model',
+                )
+            table = Table(title=None, show_header=False, box=box.MINIMAL, title_style=style.Style(bold=True))
+            table.add_row(f"Model id {args.repo_id}", str(args.file_path))
+            CONSOLE.print(Panel(table, title="[bold][green]:tada: Upload completed! :tada:[/bold]", expand=False))
+
+        except Exception as e:
+            CONSOLE.print(f"[bold][yellow]Upload failed due to {e}.")
+            raise e
+
+
+def download():
+
+    try:
+        if token is not None:
+            ckpt_path = hf_hub_download(
+                repo_id=args.repo_id,
+                filename=args.file_path,
+                token=token
+            )
+        else:
+            ckpt_path = hf_hub_download(
+                repo_id=args.repo_id,
+                filename=args.file_path,
+            )
+        table = Table(title=None, show_header=False, box=box.MINIMAL, title_style=style.Style(bold=True))
+        table.add_row(f"Model id {args.repo_id}", str(args.file_path))
+        CONSOLE.print(Panel(table, title=f"[bold][green]:tada: Download completed to {ckpt_path}! :tada:[/bold]", expand=False))
+
+        if args.save_path is not None:
+            os.makedirs(args.save_path, exist_ok=True)
+            import shutil
+            shutil.copy(ckpt_path, os.path.join(args.save_path, args.file_path))
+
+    except Exception as e:
+        CONSOLE.print(f"[bold][yellow]Download failed due to {e}.")
+        raise e
+
+    return ckpt_path
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--repo_id", type=str, default=None, required=True)
+    parser.add_argument("--upload", action="store_true")
+    parser.add_argument("--download", action="store_true")
+    parser.add_argument("--folder_path", type=str, default=None, required=False)
+    parser.add_argument("--file_path", type=str, default=None, required=False)
+    parser.add_argument("--save_path", type=str, default=None, required=False)
+    parser.add_argument("--token", type=str, default=None, required=False)
+    args = parser.parse_args()
+
+    token = args.token or os.getenv("hf_token", None)
+    ignore_patterns = ["**/optimizer.bin", "**/random_states*", "**/scaler.pt", "**/scheduler.bin"]
+
+    if not (args.folder_path or args.file_path):
+        raise RuntimeError('Please choose either a folder path or a file path!')
+
+    if len(args.repo_id.split('/')) != 2:
+        raise RuntimeError(f'Invalid repo_id: {args.repo_id}, please use the [user-id]/[repo-name] format')
+    CONSOLE.log(f"Use repo: [bold][yellow] {args.repo_id}")
+
+    if args.upload:
+        upload()
+
+    if args.download:
+        download()
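For reference, a minimal sketch of the underlying huggingface_hub calls this helper
wraps (the repo id and file names below are placeholders, not real artifacts):

    from huggingface_hub import hf_hub_download, upload_file

    # download a checkpoint into the local HF cache and get its path back
    ckpt_path = hf_hub_download(repo_id='some-user/simagent', filename='last.ckpt')

    # push a local checkpoint to the root of the model repo
    upload_file(path_or_fileobj='output/debug/last.ckpt',
                path_in_repo='last.ckpt',
                repo_id='some-user/simagent',
                repo_type='model')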
:tada:[/bold]", expand=False)) + + if args.save_path is not None: + os.makedirs(args.save_path, exist_ok=True) + import shutil + shutil.copy(ckpt_path, os.path.join(args.save_path, args.file_path)) + + except Exception as e: + CONSOLE.print(f"[bold][yellow]:tada: Download failed due to {e}.") + raise e + + return ckpt_path + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--repo_id", type=str, default=None, required=True) + parser.add_argument("--upload", action="store_true") + parser.add_argument("--download", action="store_true") + parser.add_argument("--folder_path", type=str, default=None, required=False) + parser.add_argument("--file_path", type=str, default=None, required=False) + parser.add_argument("--save_path", type=str, default=None, required=False) + parser.add_argument("--token", type=str, default=None, required=False) + args = parser.parse_args() + + token = args.token or os.getenv("hf_token", None) + ignore_patterns = ["**/optimizer.bin", "**/random_states*", "**/scaler.pt", "**/scheduler.bin"] + + if not (args.folder_path or args.file_path): + raise RuntimeError(f'Choose either folder path or file path please!') + + if len(args.repo_id.split('/')) != 2: + raise RuntimeError(f'Invalid repo_id: {args.repo_id}, please use in [use-id]/[repo-name] format') + CONSOLE.log(f"Use repo: [bold][yellow] {args.repo_id}") + + if args.upload: + upload() + + if args.download: + download() diff --git a/backups/scripts/pretrain_map.sh b/backups/scripts/pretrain_map.sh new file mode 100644 index 0000000000000000000000000000000000000000..fb332f8be4c0467b94a4cc2de069970e5e2fe8b3 --- /dev/null +++ b/backups/scripts/pretrain_map.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +mkdir -p job_out + +#SBATCH --job-name YOUR_JOB_NAME # Job name +### Logging +#SBATCH --output=job_out/%j.out # Stdout (%j expands to jobId) +#SBATCH --error=job_out/%j.err # Stderr (%j expands to jobId) +### Node info +#SBATCH --nodes=1 # Single node or multi node +#SBATCH --nodelist=sota-6 +#SBATCH --time 20:00:00 # Max time (hh:mm:ss) +#SBATCH --gres=gpu:4 # GPUs per node +#SBATCH --mem=256G # Recommend 32G per GPU +#SBATCH --ntasks-per-node=4 # Tasks per node +#SBATCH --cpus-per-task=256 # Recommend 8 per GPU +### Whatever your job needs to do + +export REQUESTS_CA_BUNDLE="/etc/ssl/certs/ca-certificates.crt" +export HTTPS_PROXY="https://192.168.0.10:443/" +export https_proxy="https://192.168.0.10:443/" + +export TEST_VAL_PRED=True +export WANDB=True + +cd /u/xiuyu/work/dev4 +PYTHONPATH=".":$PYTHONPATH python3 train.py --config configs/train/pretrain_scalable_map.yaml --save_ckpt_path output/ours_map_pretrain diff --git a/backups/scripts/run_eval.sh b/backups/scripts/run_eval.sh new file mode 100644 index 0000000000000000000000000000000000000000..494ec55ac99385b482053ea4669656b333c12684 --- /dev/null +++ b/backups/scripts/run_eval.sh @@ -0,0 +1,20 @@ +#! 
/bin/bash + +# env +export REQUESTS_CA_BUNDLE="/etc/ssl/certs/ca-certificates.crt" +export HTTPS_PROXY="https://192.168.0.10:443/" +export https_proxy="https://192.168.0.10:443/" + +export WANDB=1 + +# args +DEVICES=$1 +CONFIG='configs/ours_long_term.yaml' +# CKPT_PATH='output/scalable_smart_long/last.ckpt' +CKPT_PATH='output2/seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long/last.ckpt' + +# run +PYTHONPATH=".":$PYTHONPATH python3 run.py \ + --devices $DEVICES \ + --config $CONFIG \ + --ckpt_path $CKPT_PATH ${@:2} diff --git a/backups/scripts/run_train.sh b/backups/scripts/run_train.sh new file mode 100644 index 0000000000000000000000000000000000000000..9e8b1a2bbcc420b4ff769f7a53d077395d27aeba --- /dev/null +++ b/backups/scripts/run_train.sh @@ -0,0 +1,20 @@ +#! /bin/bash + +# env +export REQUESTS_CA_BUNDLE="/etc/ssl/certs/ca-certificates.crt" +export HTTPS_PROXY="https://192.168.0.10:443/" +export https_proxy="https://192.168.0.10:443/" + +export WANDB=1 + +# args +DEVICES=$1 +CONFIG='configs/ours_long_term.yaml' +SAVE_CKPT_PATH='output/scalable_smart_long' + +# run +PYTHONPATH=".":$PYTHONPATH python3 run.py \ + --train \ + --devices $DEVICES \ + --config $CONFIG \ + --save_ckpt_path $SAVE_CKPT_PATH diff --git a/backups/scripts/test.py b/backups/scripts/test.py new file mode 100644 index 0000000000000000000000000000000000000000..883c4975017364e61d171467b1a92c61b386967e --- /dev/null +++ b/backups/scripts/test.py @@ -0,0 +1,15 @@ +import torch + + +def b(a): + a += 3 + print(a) + a[a==4] += 3 + print(a) + return a + + +a = torch.ones(10, 2).cuda() +print(a) +b(a) +print(a) diff --git a/backups/scripts/traj_clustering.py b/backups/scripts/traj_clustering.py new file mode 100644 index 0000000000000000000000000000000000000000..1f1ade102d32329cb1363ec546b6c770be3b1d58 --- /dev/null +++ b/backups/scripts/traj_clustering.py @@ -0,0 +1,295 @@ +""" Adapted from NVLabs/CatK """ + +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
+ +from typing import Optional, Tuple + +import torch +from omegaconf import DictConfig +from torch import Tensor +from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal + + +@torch.no_grad() +def cal_polygon_contour( + pos: Tensor, # [n_agent, n_step, n_target, 2] + head: Tensor, # [n_agent, n_step, n_target] + width_length: Tensor, # [n_agent, 1, 1, 2] +) -> Tensor: # [n_agent, n_step, n_target, 4, 2] + x, y = pos[..., 0], pos[..., 1] # [n_agent, n_step, n_target] + width, length = width_length[..., 0], width_length[..., 1] # [n_agent, 1 ,1] + + half_cos = 0.5 * head.cos() # [n_agent, n_step, n_target] + half_sin = 0.5 * head.sin() # [n_agent, n_step, n_target] + length_cos = length * half_cos # [n_agent, n_step, n_target] + length_sin = length * half_sin # [n_agent, n_step, n_target] + width_cos = width * half_cos # [n_agent, n_step, n_target] + width_sin = width * half_sin # [n_agent, n_step, n_target] + + left_front_x = x + length_cos - width_sin + left_front_y = y + length_sin + width_cos + left_front = torch.stack((left_front_x, left_front_y), dim=-1) + + right_front_x = x + length_cos + width_sin + right_front_y = y + length_sin - width_cos + right_front = torch.stack((right_front_x, right_front_y), dim=-1) + + right_back_x = x - length_cos + width_sin + right_back_y = y - length_sin - width_cos + right_back = torch.stack((right_back_x, right_back_y), dim=-1) + + left_back_x = x - length_cos - width_sin + left_back_y = y - length_sin + width_cos + left_back = torch.stack((left_back_x, left_back_y), dim=-1) + + polygon_contour = torch.stack( + (left_front, right_front, right_back, left_back), dim=-2 + ) + + return polygon_contour + + +def transform_to_global( + pos_local: Tensor, # [n_agent, n_step, 2] + head_local: Optional[Tensor], # [n_agent, n_step] + pos_now: Tensor, # [n_agent, 2] + head_now: Tensor, # [n_agent] +) -> Tuple[Tensor, Optional[Tensor]]: + cos, sin = head_now.cos(), head_now.sin() + rot_mat = torch.zeros((head_now.shape[0], 2, 2), device=head_now.device) + rot_mat[:, 0, 0] = cos + rot_mat[:, 0, 1] = sin + rot_mat[:, 1, 0] = -sin + rot_mat[:, 1, 1] = cos + + pos_global = torch.bmm(pos_local, rot_mat) # [n_agent, n_step, 2]*[n_agent, 2, 2] + pos_global = pos_global + pos_now.unsqueeze(1) + if head_local is None: + head_global = None + else: + head_global = head_local + head_now.unsqueeze(1) + return pos_global, head_global + + +def transform_to_local( + pos_global: Tensor, # [n_agent, n_step, 2] + head_global: Optional[Tensor], # [n_agent, n_step] + pos_now: Tensor, # [n_agent, 2] + head_now: Tensor, # [n_agent] +) -> Tuple[Tensor, Optional[Tensor]]: + cos, sin = head_now.cos(), head_now.sin() + rot_mat = torch.zeros((head_now.shape[0], 2, 2), device=head_now.device) + rot_mat[:, 0, 0] = cos + rot_mat[:, 0, 1] = -sin + rot_mat[:, 1, 0] = sin + rot_mat[:, 1, 1] = cos + + pos_local = pos_global - pos_now.unsqueeze(1) + pos_local = torch.bmm(pos_local, rot_mat) # [n_agent, n_step, 2]*[n_agent, 2, 2] + if head_global is None: + head_local = None + else: + head_local = head_global - head_now.unsqueeze(1) + return pos_local, head_local + + +def sample_next_token_traj( + token_traj: Tensor, # [n_agent, n_token, 4, 2] + token_traj_all: Tensor, # [n_agent, n_token, 6, 4, 2] + sampling_scheme: DictConfig, + # ! for most-likely sampling + next_token_logits: Tensor, # [n_agent, n_token], with grad + # ! 
for nearest-pos sampling, sampling near to GT + pos_now: Tensor, # [n_agent, 2] + head_now: Tensor, # [n_agent] + pos_next_gt: Tensor, # [n_agent, 2] + head_next_gt: Tensor, # [n_agent] + valid_next_gt: Tensor, # [n_agent] + token_agent_shape: Tensor, # [n_agent, 2] +) -> Tuple[Tensor, Tensor]: + """ + Returns: + next_token_traj_all: [n_agent, 6, 4, 2], local coord + next_token_idx: [n_agent], without grad + """ + range_a = torch.arange(next_token_logits.shape[0]) + next_token_logits = next_token_logits.detach() + + if ( + sampling_scheme.criterium == "topk_prob" + or sampling_scheme.criterium == "topk_prob_sampled_with_dist" + ): + topk_logits, topk_indices = torch.topk( + next_token_logits, sampling_scheme.num_k, dim=-1, sorted=False + ) + if sampling_scheme.criterium == "topk_prob_sampled_with_dist": + #! gt_contour: [n_agent, 4, 2] in global coord + gt_contour = cal_polygon_contour( + pos_next_gt, head_next_gt, token_agent_shape + ) + gt_contour = gt_contour.unsqueeze(1) # [n_agent, 1, 4, 2] + token_world_sample = token_traj[range_a.unsqueeze(1), topk_indices] + token_world_sample = transform_to_global( + pos_local=token_world_sample.flatten(1, 2), + head_local=None, + pos_now=pos_now, # [n_agent, 2] + head_now=head_now, # [n_agent] + )[0].view(*token_world_sample.shape) + + # dist: [n_agent, n_token] + dist = torch.norm(token_world_sample - gt_contour, dim=-1).mean(-1) + topk_logits = topk_logits.masked_fill( + valid_next_gt.unsqueeze(1), 0.0 + ) - 1.0 * dist.masked_fill(~valid_next_gt.unsqueeze(1), 0.0) + elif sampling_scheme.criterium == "topk_dist_sampled_with_prob": + #! gt_contour: [n_agent, 4, 2] in global coord + gt_contour = cal_polygon_contour(pos_next_gt, head_next_gt, token_agent_shape) + gt_contour = gt_contour.unsqueeze(1) # [n_agent, 1, 4, 2] + token_world_sample = transform_to_global( + pos_local=token_traj.flatten(1, 2), # [n_agent, n_token*4, 2] + head_local=None, + pos_now=pos_now, # [n_agent, 2] + head_now=head_now, # [n_agent] + )[0].view(*token_traj.shape) + + _invalid = ~valid_next_gt + # dist: [n_agent, n_token] + dist = torch.norm(token_world_sample - gt_contour, dim=-1).mean(-1) + _logits = -1.0 * dist.masked_fill(_invalid.unsqueeze(1), 0.0) + + if _invalid.any(): + _logits[_invalid] = next_token_logits[_invalid] + _, topk_indices = torch.topk( + _logits, sampling_scheme.num_k, dim=-1, sorted=False + ) # [n_agent, K] + topk_logits = next_token_logits[range_a.unsqueeze(1), topk_indices] + + else: + raise ValueError(f"Invalid criterium: {sampling_scheme.criterium}") + + # topk_logits, topk_indices: [n_agent, K] + topk_logits = topk_logits / sampling_scheme.temp + samples = Categorical(logits=topk_logits).sample() # [n_agent] in K + next_token_idx = topk_indices[range_a, samples] + next_token_traj_all = token_traj_all[range_a, next_token_idx] + + return next_token_idx, next_token_traj_all + + +def sample_next_gmm_traj( + token_traj: Tensor, # [n_agent, n_token, 4, 2] + token_traj_all: Tensor, # [n_agent, n_token, 6, 4, 2] + sampling_scheme: DictConfig, + # ! for most-likely sampling + ego_mask: Tensor, # [n_agent], bool, ego_mask.sum()==n_batch + ego_next_logits: Tensor, # [n_batch, n_k_ego_gmm] + ego_next_poses: Tensor, # [n_batch, n_k_ego_gmm, 3] + ego_next_cov: Tensor, # [2], one for pos, one for heading. + # ! 
for nearest-pos sampling, sampling near to GT
+    pos_now: Tensor,  # [n_agent, 2]
+    head_now: Tensor,  # [n_agent]
+    pos_next_gt: Tensor,  # [n_agent, 2]
+    head_next_gt: Tensor,  # [n_agent]
+    valid_next_gt: Tensor,  # [n_agent]
+    token_agent_shape: Tensor,  # [n_agent, 2]
+    next_token_idx: Tensor,  # [n_agent]
+) -> Tuple[Tensor, Tensor]:
+    """
+    Returns:
+        next_token_traj_all: [n_agent, 6, 4, 2], local coord
+        next_token_idx: [n_agent], without grad
+    """
+    n_agent = token_traj.shape[0]
+    n_batch = ego_next_logits.shape[0]
+    next_token_traj_all = token_traj_all[torch.arange(n_agent), next_token_idx]
+
+    # ! sample only the ego-vehicle
+    assert (
+        sampling_scheme.criterium == "topk_prob"
+        or sampling_scheme.criterium == "topk_prob_sampled_with_dist"
+    )
+    topk_logits, topk_indices = torch.topk(
+        ego_next_logits, sampling_scheme.num_k, dim=-1, sorted=False
+    )  # [n_batch, k], [n_batch, k]
+    ego_pose_topk = ego_next_poses[
+        torch.arange(n_batch).unsqueeze(1), topk_indices
+    ]  # [n_batch, k, 3]
+
+    if sampling_scheme.criterium == "topk_prob_sampled_with_dist":
+        # update topk_logits
+        gt_contour = cal_polygon_contour(
+            pos_next_gt[ego_mask],
+            head_next_gt[ego_mask],
+            token_agent_shape[ego_mask],
+        )  # [n_batch, 4, 2] in global coord
+        gt_contour = gt_contour.unsqueeze(1)  # [n_batch, 1, 4, 2]
+
+        ego_pos_global, ego_head_global = transform_to_global(
+            pos_local=ego_pose_topk[:, :, :2],  # [n_batch, k, 2]
+            head_local=ego_pose_topk[:, :, -1],  # [n_batch, k]
+            pos_now=pos_now[ego_mask],  # [n_batch, 2]
+            head_now=head_now[ego_mask],  # [n_batch]
+        )
+        ego_contour = cal_polygon_contour(
+            ego_pos_global,  # [n_batch, k, 2]
+            ego_head_global,  # [n_batch, k]
+            token_agent_shape[ego_mask].unsqueeze(1),
+        )  # [n_batch, k, 4, 2] in global coord
+
+        dist = torch.norm(ego_contour - gt_contour, dim=-1).mean(-1)  # [n_batch, k]
+        topk_logits = topk_logits.masked_fill(
+            valid_next_gt[ego_mask].unsqueeze(1), 0.0
+        ) - 1.0 * dist.masked_fill(~valid_next_gt[ego_mask].unsqueeze(1), 0.0)
+
+    topk_logits = topk_logits / sampling_scheme.temp_mode  # [n_batch, k]
+    ego_pose_topk = torch.cat(
+        [
+            ego_pose_topk[..., :2],
+            ego_pose_topk[..., [-1]].cos(),
+            ego_pose_topk[..., [-1]].sin(),
+        ],
+        dim=-1,
+    )
+    cov = (
+        (ego_next_cov * sampling_scheme.temp_cov)
+        .repeat_interleave(2)[None, None, :]
+        .expand(*ego_pose_topk.shape)
+    )  # [n_batch, k, 4]
+    gmm = MixtureSameFamily(
+        Categorical(logits=topk_logits), Independent(Normal(ego_pose_topk, cov), 1)
+    )
+    ego_sample = gmm.sample()  # [n_batch, 4]
+
+    ego_contour_local = cal_polygon_contour(
+        ego_sample[:, :2],  # [n_batch, 2]
+        torch.arctan2(ego_sample[:, -1], ego_sample[:, -2]),  # [n_batch]
+        token_agent_shape[ego_mask],  # [n_batch, 2]
+    )  # [n_batch, 4, 2] in local coord
+
+    ego_token_local = token_traj[ego_mask]  # [n_batch, n_token, 4, 2]
+
+    dist = torch.norm(ego_contour_local.unsqueeze(1) - ego_token_local, dim=-1).mean(
+        -1
+    )  # [n_batch, n_token]
+    next_token_idx[ego_mask] = dist.argmin(-1)
+
+    # ego_contour_local: [n_batch, 4, 2] in local coord
+    ego_contour_start = next_token_traj_all[ego_mask][:, 0]  # [n_batch, 4, 2]
+    n_step = next_token_traj_all.shape[1]
+    diff = (ego_contour_local - ego_contour_start) / (n_step - 1)
+    ego_token_interp = [ego_contour_start + diff * i for i in range(n_step)]
+    # [n_batch, 6, 4, 2]
+    next_token_traj_all[ego_mask] = torch.stack(ego_token_interp, dim=1)
+
+    return next_token_idx, next_token_traj_all
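A quick sanity check for cal_polygon_contour defined above (the leading batch
dimensions broadcast, so a single [n_agent, 2] / [n_agent] call behaves the same way
as the [n_agent, n_step, n_target, ...] shapes in the docstring):

    import torch

    pos = torch.tensor([[0.0, 0.0]])   # one agent at the origin
    head = torch.tensor([0.0])         # heading along +x
    wl = torch.tensor([[2.0, 4.0]])    # (width, length) = (2 m, 4 m)
    contour = cal_polygon_contour(pos, head, wl)
    # corners in (left_front, right_front, right_back, left_back) order:
    # tensor([[[ 2.,  1.], [ 2., -1.], [-2., -1.], [-2.,  1.]]])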
diff --git a/backups/scripts/wosca_config.json b/backups/scripts/wosca_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/backups/thirdparty/SMART/__init__.py b/backups/thirdparty/SMART/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/backups/thirdparty/SMART/__init__.py
@@ -0,0 +1 @@
+
diff --git a/backups/thirdparty/SMART/configs/train/train_scalable.yaml b/backups/thirdparty/SMART/configs/train/train_scalable.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6a67e27aaea2b3854163c8a40cd98b2bd4ee942c
--- /dev/null
+++ b/backups/thirdparty/SMART/configs/train/train_scalable.yaml
@@ -0,0 +1,62 @@
+# Config format schema; this YAML supports case sources from different datasets
+time_info: &time_info
+    num_historical_steps: 11
+    num_future_steps: 80
+    use_intention: True
+    token_size: 2048
+
+Dataset:
+    root:
+    train_batch_size: 1
+    val_batch_size: 1
+    test_batch_size: 1
+    shuffle: True
+    num_workers: 1
+    pin_memory: True
+    persistent_workers: True
+    train_raw_dir: ["data/valid_demo"]
+    val_raw_dir: ["data/valid_demo"]
+    test_raw_dir:
+    transform: WaymoTargetBuilder
+    train_processed_dir:
+    val_processed_dir:
+    test_processed_dir:
+    dataset: "scalable"
+    <<: *time_info
+
+Trainer:
+    strategy: ddp_find_unused_parameters_false
+    accelerator: "gpu"
+    devices: 1
+    max_epochs: 32
+    save_ckpt_path:
+    num_nodes: 1
+    mode:
+    ckpt_path:
+    precision: 32
+    accumulate_grad_batches: 1
+
+Model:
+    mode: "train"
+    predictor: "smart"
+    dataset: "waymo"
+    input_dim: 2
+    hidden_dim: 128
+    output_dim: 2
+    output_head: False
+    num_heads: 8
+    <<: *time_info
+    head_dim: 16
+    dropout: 0.1
+    num_freq_bands: 64
+    lr: 0.0005
+    warmup_steps: 0
+    total_steps: 32
+    decoder:
+        <<: *time_info
+        num_map_layers: 3
+        num_agent_layers: 6
+        a2a_radius: 60
+        pl2pl_radius: 10
+        pl2a_radius: 30
+        time_span: 30
diff --git a/backups/thirdparty/SMART/configs/validation/validation_scalable.yaml b/backups/thirdparty/SMART/configs/validation/validation_scalable.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ca93b84c150b66aacdd238348e4589352517a325
--- /dev/null
+++ b/backups/thirdparty/SMART/configs/validation/validation_scalable.yaml
@@ -0,0 +1,60 @@
+# Config format schema; this YAML supports case sources from different datasets
+time_info: &time_info
+    num_historical_steps: 11
+    num_future_steps: 80
+    token_size: 2048
+
+Dataset:
+    root:
+    batch_size: 1
+    shuffle: True
+    num_workers: 1
+    pin_memory: True
+    persistent_workers: True
+    train_raw_dir:
+    val_raw_dir: ["data/valid_demo"]
+    test_raw_dir:
+    TargetBuilder: WaymoTargetBuilder
+    train_processed_dir:
+    val_processed_dir:
+    test_processed_dir:
+    dataset: "scalable"
+    <<: *time_info
+
+Trainer:
+    strategy: ddp_find_unused_parameters_false
+    accelerator: "gpu"
+    devices: 1
+    max_epochs: 32
+    save_ckpt_path:
+    num_nodes: 1
+    mode:
+    ckpt_path:
+    precision: 32
+    accumulate_grad_batches: 1
+
+Model:
+    mode: "validation"
+    predictor: "smart"
+    dataset: "waymo"
+    input_dim: 2
+    hidden_dim: 128
+    output_dim: 2
+    output_head: False
+    num_heads: 8
+    <<: *time_info
+    head_dim: 16
+    dropout: 0.1
+    num_freq_bands: 64
+    lr: 0.0005
+    warmup_steps: 0
+    total_steps: 32
+    decoder:
+        <<: *time_info
+        num_map_layers: 3
+        num_agent_layers: 6
+        a2a_radius: 60
+        pl2pl_radius: 10
+        pl2a_radius: 30
+        time_span: 30
+
diff --git a/backups/thirdparty/SMART/data_preprocess.py b/backups/thirdparty/SMART/data_preprocess.py
new file mode 100644
index
0000000000000000000000000000000000000000..e55b1d3f5c9e74954309692a169d15e5c042e534 --- /dev/null +++ b/backups/thirdparty/SMART/data_preprocess.py @@ -0,0 +1,714 @@ +import numpy as np +import pandas as pd +import os +import torch +import pickle +from tqdm import tqdm +from typing import Any, Dict, List, Optional +import easydict + +predict_unseen_agents = False +vector_repr = True +root = '' +split = 'train' +raw_dir = os.path.join(root, split, 'raw') +_raw_dir = raw_dir + +if os.path.isdir(_raw_dir): + _raw_file_names = [name for name in os.listdir(_raw_dir)] +else: + _raw_file_names = [] + +processed_dir = os.path.join(root, split, 'processed') +_processed_dir = processed_dir +if os.path.isdir(_processed_dir): + _processed_file_names = [name for name in os.listdir(_processed_dir) if + name.endswith(('pkl', 'pickle'))] +else: + _processed_file_names = [] + +_agent_types = ['vehicle', 'pedestrian', 'cyclist', 'background'] +_polygon_types = ['VEHICLE', 'BIKE', 'BUS', 'PEDESTRIAN'] +_polygon_light_type = ['LANE_STATE_STOP', 'LANE_STATE_GO', 'LANE_STATE_CAUTION', 'LANE_STATE_UNKNOWN'] +_point_types = ['DASH_SOLID_YELLOW', 'DASH_SOLID_WHITE', 'DASHED_WHITE', 'DASHED_YELLOW', + 'DOUBLE_SOLID_YELLOW', 'DOUBLE_SOLID_WHITE', 'DOUBLE_DASH_YELLOW', 'DOUBLE_DASH_WHITE', + 'SOLID_YELLOW', 'SOLID_WHITE', 'SOLID_DASH_WHITE', 'SOLID_DASH_YELLOW', 'EDGE', + 'NONE', 'UNKNOWN', 'CROSSWALK', 'CENTERLINE'] +_point_sides = ['LEFT', 'RIGHT', 'CENTER'] +_polygon_to_polygon_types = ['NONE', 'PRED', 'SUCC', 'LEFT', 'RIGHT'] +_polygon_is_intersections = [True, False, None] + + +Lane_type_hash = { + 4: "BIKE", + 3: "VEHICLE", + 2: "VEHICLE", + 1: "BUS" +} + +boundary_type_hash = { + 5: "UNKNOWN", + 6: "DASHED_WHITE", + 7: "SOLID_WHITE", + 8: "DOUBLE_DASH_WHITE", + 9: "DASHED_YELLOW", + 10: "DOUBLE_DASH_YELLOW", + 11: "SOLID_YELLOW", + 12: "DOUBLE_SOLID_YELLOW", + 13: "DASH_SOLID_YELLOW", + 14: "UNKNOWN", + 15: "EDGE", + 16: "EDGE" +} + + +def safe_list_index(ls: List[Any], elem: Any) -> Optional[int]: + try: + return ls.index(elem) + except ValueError: + return None + + +def get_agent_features(df: pd.DataFrame, av_id, num_historical_steps=10, dim=3, num_steps=91) -> Dict[str, Any]: + if not predict_unseen_agents: # filter out agents that are unseen during the historical time steps + historical_df = df[df['timestep'] == num_historical_steps-1] + agent_ids = list(historical_df['track_id'].unique()) + df = df[df['track_id'].isin(agent_ids)] + else: + agent_ids = list(df['track_id'].unique()) + + num_agents = len(agent_ids) + # initialization + valid_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool) + current_valid_mask = torch.zeros(num_agents, dtype=torch.bool) + predict_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool) + agent_id: List[Optional[str]] = [None] * num_agents + agent_type = torch.zeros(num_agents, dtype=torch.uint8) + agent_category = torch.zeros(num_agents, dtype=torch.uint8) + position = torch.zeros(num_agents, num_steps, dim, dtype=torch.float) + heading = torch.zeros(num_agents, num_steps, dtype=torch.float) + velocity = torch.zeros(num_agents, num_steps, dim, dtype=torch.float) + shape = torch.zeros(num_agents, num_steps, dim, dtype=torch.float) + + for track_id, track_df in df.groupby('track_id'): + agent_idx = agent_ids.index(track_id) + agent_steps = track_df['timestep'].values + + valid_mask[agent_idx, agent_steps] = True + current_valid_mask[agent_idx] = valid_mask[agent_idx, num_historical_steps - 1] + predict_mask[agent_idx, agent_steps] = True + if vector_repr: # a time 
step t is valid only when both t and t-1 are valid + valid_mask[agent_idx, 1: num_historical_steps] = ( + valid_mask[agent_idx, :num_historical_steps - 1] & + valid_mask[agent_idx, 1: num_historical_steps]) + valid_mask[agent_idx, 0] = False + predict_mask[agent_idx, :num_historical_steps] = False + if not current_valid_mask[agent_idx]: + predict_mask[agent_idx, num_historical_steps:] = False + + agent_id[agent_idx] = track_id + agent_type[agent_idx] = _agent_types.index(track_df['object_type'].values[0]) + agent_category[agent_idx] = track_df['object_category'].values[0] + position[agent_idx, agent_steps, :3] = torch.from_numpy(np.stack([track_df['position_x'].values, + track_df['position_y'].values, + track_df['position_z'].values], + axis=-1)).float() + heading[agent_idx, agent_steps] = torch.from_numpy(track_df['heading'].values).float() + velocity[agent_idx, agent_steps, :2] = torch.from_numpy(np.stack([track_df['velocity_x'].values, + track_df['velocity_y'].values], + axis=-1)).float() + shape[agent_idx, agent_steps, :3] = torch.from_numpy(np.stack([track_df['length'].values, + track_df['width'].values, + track_df["height"].values], + axis=-1)).float() + av_idx = agent_id.index(av_id) + if split == 'test': + predict_mask[current_valid_mask + | (agent_category == 2) + | (agent_category == 3), num_historical_steps:] = True + + return { + 'num_nodes': num_agents, + 'av_index': av_idx, + 'valid_mask': valid_mask, + 'predict_mask': predict_mask, + 'id': agent_id, + 'type': agent_type, + 'category': agent_category, + 'position': position, + 'heading': heading, + 'velocity': velocity, + 'shape': shape + } + + +def get_map_features(map_infos, tf_current_light, dim=3): + lane_segments = map_infos['lane'] + all_polylines = map_infos["all_polylines"] + crosswalks = map_infos['crosswalk'] + road_edges = map_infos['road_edge'] + road_lines = map_infos['road_line'] + lane_segment_ids = [info["id"] for info in lane_segments] + cross_walk_ids = [info["id"] for info in crosswalks] + road_edge_ids = [info["id"] for info in road_edges] + road_line_ids = [info["id"] for info in road_lines] + polygon_ids = lane_segment_ids + road_edge_ids + road_line_ids + cross_walk_ids + num_polygons = len(lane_segment_ids) + len(road_edge_ids) + len(road_line_ids) + len(cross_walk_ids) + + # initialization + polygon_type = torch.zeros(num_polygons, dtype=torch.uint8) + polygon_light_type = torch.ones(num_polygons, dtype=torch.uint8) * 3 + + point_position: List[Optional[torch.Tensor]] = [None] * num_polygons + point_orientation: List[Optional[torch.Tensor]] = [None] * num_polygons + point_magnitude: List[Optional[torch.Tensor]] = [None] * num_polygons + point_height: List[Optional[torch.Tensor]] = [None] * num_polygons + point_type: List[Optional[torch.Tensor]] = [None] * num_polygons + + for lane_segment in lane_segments: + lane_segment = easydict.EasyDict(lane_segment) + lane_segment_idx = polygon_ids.index(lane_segment.id) + polyline_index = lane_segment.polyline_index + centerline = all_polylines[polyline_index[0]:polyline_index[1], :] + centerline = torch.from_numpy(centerline).float() + polygon_type[lane_segment_idx] = _polygon_types.index(Lane_type_hash[lane_segment.type]) + + res = tf_current_light[tf_current_light["lane_id"] == str(lane_segment.id)] + if len(res) != 0: + polygon_light_type[lane_segment_idx] = _polygon_light_type.index(res["state"].item()) + + point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0) + center_vectors = centerline[1:] - centerline[:-1] + 
point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0) + point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1) + point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0) + center_type = _point_types.index('CENTERLINE') + point_type[lane_segment_idx] = torch.cat( + [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0) + + for lane_segment in road_edges: + lane_segment = easydict.EasyDict(lane_segment) + lane_segment_idx = polygon_ids.index(lane_segment.id) + polyline_index = lane_segment.polyline_index + centerline = all_polylines[polyline_index[0]:polyline_index[1], :] + centerline = torch.from_numpy(centerline).float() + polygon_type[lane_segment_idx] = _polygon_types.index("VEHICLE") + + point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0) + center_vectors = centerline[1:] - centerline[:-1] + point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0) + point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1) + point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0) + center_type = _point_types.index('EDGE') + point_type[lane_segment_idx] = torch.cat( + [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0) + + for lane_segment in road_lines: + lane_segment = easydict.EasyDict(lane_segment) + lane_segment_idx = polygon_ids.index(lane_segment.id) + polyline_index = lane_segment.polyline_index + centerline = all_polylines[polyline_index[0]:polyline_index[1], :] + centerline = torch.from_numpy(centerline).float() + + polygon_type[lane_segment_idx] = _polygon_types.index("VEHICLE") + + point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0) + center_vectors = centerline[1:] - centerline[:-1] + point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0) + point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1) + point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0) + center_type = _point_types.index(boundary_type_hash[lane_segment.type]) + point_type[lane_segment_idx] = torch.cat( + [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0) + + for crosswalk in crosswalks: + crosswalk = easydict.EasyDict(crosswalk) + lane_segment_idx = polygon_ids.index(crosswalk.id) + polyline_index = crosswalk.polyline_index + centerline = all_polylines[polyline_index[0]:polyline_index[1], :] + centerline = torch.from_numpy(centerline).float() + + polygon_type[lane_segment_idx] = _polygon_types.index("PEDESTRIAN") + + point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0) + center_vectors = centerline[1:] - centerline[:-1] + point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0) + point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1) + point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0) + center_type = _point_types.index("CROSSWALK") + point_type[lane_segment_idx] = torch.cat( + [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0) + + num_points = torch.tensor([point.size(0) for point in point_position], dtype=torch.long) + point_to_polygon_edge_index = torch.stack( + 
[torch.arange(num_points.sum(), dtype=torch.long), + torch.arange(num_polygons, dtype=torch.long).repeat_interleave(num_points)], dim=0) + polygon_to_polygon_edge_index = [] + polygon_to_polygon_type = [] + for lane_segment in lane_segments: + lane_segment = easydict.EasyDict(lane_segment) + lane_segment_idx = polygon_ids.index(lane_segment.id) + pred_inds = [] + for pred in lane_segment.entry_lanes: + pred_idx = safe_list_index(polygon_ids, pred) + if pred_idx is not None: + pred_inds.append(pred_idx) + if len(pred_inds) != 0: + polygon_to_polygon_edge_index.append( + torch.stack([torch.tensor(pred_inds, dtype=torch.long), + torch.full((len(pred_inds),), lane_segment_idx, dtype=torch.long)], dim=0)) + polygon_to_polygon_type.append( + torch.full((len(pred_inds),), _polygon_to_polygon_types.index('PRED'), dtype=torch.uint8)) + succ_inds = [] + for succ in lane_segment.exit_lanes: + succ_idx = safe_list_index(polygon_ids, succ) + if succ_idx is not None: + succ_inds.append(succ_idx) + if len(succ_inds) != 0: + polygon_to_polygon_edge_index.append( + torch.stack([torch.tensor(succ_inds, dtype=torch.long), + torch.full((len(succ_inds),), lane_segment_idx, dtype=torch.long)], dim=0)) + polygon_to_polygon_type.append( + torch.full((len(succ_inds),), _polygon_to_polygon_types.index('SUCC'), dtype=torch.uint8)) + if len(lane_segment.left_neighbors) != 0: + left_neighbor_ids = lane_segment.left_neighbors + for left_neighbor_id in left_neighbor_ids: + left_idx = safe_list_index(polygon_ids, left_neighbor_id) + if left_idx is not None: + polygon_to_polygon_edge_index.append( + torch.tensor([[left_idx], [lane_segment_idx]], dtype=torch.long)) + polygon_to_polygon_type.append( + torch.tensor([_polygon_to_polygon_types.index('LEFT')], dtype=torch.uint8)) + if len(lane_segment.right_neighbors) != 0: + right_neighbor_ids = lane_segment.right_neighbors + for right_neighbor_id in right_neighbor_ids: + right_idx = safe_list_index(polygon_ids, right_neighbor_id) + if right_idx is not None: + polygon_to_polygon_edge_index.append( + torch.tensor([[right_idx], [lane_segment_idx]], dtype=torch.long)) + polygon_to_polygon_type.append( + torch.tensor([_polygon_to_polygon_types.index('RIGHT')], dtype=torch.uint8)) + if len(polygon_to_polygon_edge_index) != 0: + polygon_to_polygon_edge_index = torch.cat(polygon_to_polygon_edge_index, dim=1) + polygon_to_polygon_type = torch.cat(polygon_to_polygon_type, dim=0) + else: + polygon_to_polygon_edge_index = torch.tensor([[], []], dtype=torch.long) + polygon_to_polygon_type = torch.tensor([], dtype=torch.uint8) + + map_data = { + 'map_polygon': {}, + 'map_point': {}, + ('map_point', 'to', 'map_polygon'): {}, + ('map_polygon', 'to', 'map_polygon'): {}, + } + map_data['map_polygon']['num_nodes'] = num_polygons + map_data['map_polygon']['type'] = polygon_type + map_data['map_polygon']['light_type'] = polygon_light_type + if len(num_points) == 0: + map_data['map_point']['num_nodes'] = 0 + map_data['map_point']['position'] = torch.tensor([], dtype=torch.float) + map_data['map_point']['orientation'] = torch.tensor([], dtype=torch.float) + map_data['map_point']['magnitude'] = torch.tensor([], dtype=torch.float) + if dim == 3: + map_data['map_point']['height'] = torch.tensor([], dtype=torch.float) + map_data['map_point']['type'] = torch.tensor([], dtype=torch.uint8) + map_data['map_point']['side'] = torch.tensor([], dtype=torch.uint8) + else: + map_data['map_point']['num_nodes'] = num_points.sum().item() + map_data['map_point']['position'] = torch.cat(point_position, dim=0) + 
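+        # the per-polygon point lists are concatenated in polygon order below, matching
+        # the flat indexing assumed by point_to_polygon_edge_index built above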
map_data['map_point']['orientation'] = torch.cat(point_orientation, dim=0) + map_data['map_point']['magnitude'] = torch.cat(point_magnitude, dim=0) + if dim == 3: + map_data['map_point']['height'] = torch.cat(point_height, dim=0) + map_data['map_point']['type'] = torch.cat(point_type, dim=0) + map_data['map_point', 'to', 'map_polygon']['edge_index'] = point_to_polygon_edge_index + map_data['map_polygon', 'to', 'map_polygon']['edge_index'] = polygon_to_polygon_edge_index + map_data['map_polygon', 'to', 'map_polygon']['type'] = polygon_to_polygon_type + # import matplotlib.pyplot as plt + # plt.axis('equal') + # plt.scatter(map_data['map_point']['position'][:, 0], + # map_data['map_point']['position'][:, 1], s=0.2, c='black', edgecolors='none') + # plt.show(dpi=600) + return map_data + + +def process_agent(track_info, tracks_to_predict, sdc_track_index, scenario_id, start_timestamp, end_timestamp): + agents_array = track_info["trajs"].transpose(1, 0, 2) + object_id = np.array(track_info["object_id"]) + object_type = track_info["object_type"] + id_hash = {object_id[o_idx]: object_type[o_idx] for o_idx in range(len(object_id))} + def type_hash(x): + tp = id_hash[x] + type_re_hash = { + "TYPE_VEHICLE": "vehicle", + "TYPE_PEDESTRIAN": "pedestrian", + "TYPE_CYCLIST": "cyclist", + "TYPE_OTHER": "background", + "TYPE_UNSET": "background" + } + return type_re_hash[tp] + + columns = ['observed', 'track_id', 'object_type', 'object_category', 'timestep', + 'position_x', 'position_y', 'position_z', 'length', 'width', 'height', 'heading', 'velocity_x', 'velocity_y', + 'scenario_id', 'start_timestamp', 'end_timestamp', 'num_timestamps', + 'focal_track_id', 'city'] + new_columns = np.ones((agents_array.shape[0], agents_array.shape[1], 11)) + new_columns[:11, :, 0] = True + new_columns[11:, :, 0] = False + for index in range(new_columns.shape[0]): + new_columns[index, :, 4] = int(index) + new_columns[..., 1] = object_id + new_columns[..., 2] = object_id + new_columns[:, tracks_to_predict["track_index"], 3] = 3 + new_columns[..., 5] = 11 + new_columns[..., 6] = int(start_timestamp) + new_columns[..., 7] = int(end_timestamp) + new_columns[..., 8] = int(91) + new_columns[..., 9] = object_id + new_columns[..., 10] = 10086 + new_columns = new_columns + new_agents_array = np.concatenate([new_columns, agents_array], axis=-1) + new_agents_array = new_agents_array[new_agents_array[..., -1] == 1.0].reshape(-1, new_agents_array.shape[-1]) + new_agents_array = new_agents_array[..., [0, 1, 2, 3, 4, 11, 12, 13, 14, 15, 16, 17, 18, 19, 5, 6, 7, 8, 9, 10]] + new_agents_array = pd.DataFrame(data=new_agents_array, columns=columns) + new_agents_array["object_type"] = new_agents_array["object_type"].apply(func=type_hash) + new_agents_array["start_timestamp"] = new_agents_array["start_timestamp"].astype(int) + new_agents_array["end_timestamp"] = new_agents_array["end_timestamp"].astype(int) + new_agents_array["num_timestamps"] = new_agents_array["num_timestamps"].astype(int) + new_agents_array["scenario_id"] = scenario_id + return new_agents_array + + +def process_dynamic_map(dynamic_map_infos): + lane_ids = dynamic_map_infos["lane_id"] + tf_lights = [] + for t in range(len(lane_ids)): + lane_id = lane_ids[t] + time = np.ones_like(lane_id) * t + state = dynamic_map_infos["state"][t] + tf_light = np.concatenate([lane_id, time, state], axis=0) + tf_lights.append(tf_light) + tf_lights = np.concatenate(tf_lights, axis=1).transpose(1, 0) + tf_lights = pd.DataFrame(data=tf_lights, columns=["lane_id", "time_step", "state"]) + 
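+    # cast everything to str so the substring matches below can collapse detailed
+    # signal states (e.g. LANE_STATE_ARROW_STOP, LANE_STATE_FLASHING_STOP) onto the
+    # coarse STOP / GO / CAUTION labels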
tf_lights["time_step"] = tf_lights["time_step"].astype("str") + tf_lights["lane_id"] = tf_lights["lane_id"].astype("str") + tf_lights["state"] = tf_lights["state"].astype("str") + tf_lights.loc[tf_lights["state"].str.contains("STOP"), ["state"] ] = 'LANE_STATE_STOP' + tf_lights.loc[tf_lights["state"].str.contains("GO"), ["state"] ] = 'LANE_STATE_GO' + tf_lights.loc[tf_lights["state"].str.contains("CAUTION"), ["state"] ] = 'LANE_STATE_CAUTION' + return tf_lights + + +polyline_type = { + # for lane + 'TYPE_UNDEFINED': -1, + 'TYPE_FREEWAY': 1, + 'TYPE_SURFACE_STREET': 2, + 'TYPE_BIKE_LANE': 3, + + # for roadline + 'TYPE_UNKNOWN': -1, + 'TYPE_BROKEN_SINGLE_WHITE': 6, + 'TYPE_SOLID_SINGLE_WHITE': 7, + 'TYPE_SOLID_DOUBLE_WHITE': 8, + 'TYPE_BROKEN_SINGLE_YELLOW': 9, + 'TYPE_BROKEN_DOUBLE_YELLOW': 10, + 'TYPE_SOLID_SINGLE_YELLOW': 11, + 'TYPE_SOLID_DOUBLE_YELLOW': 12, + 'TYPE_PASSING_DOUBLE_YELLOW': 13, + + # for roadedge + 'TYPE_ROAD_EDGE_BOUNDARY': 15, + 'TYPE_ROAD_EDGE_MEDIAN': 16, + + # for stopsign + 'TYPE_STOP_SIGN': 17, + + # for crosswalk + 'TYPE_CROSSWALK': 18, + + # for speed bump + 'TYPE_SPEED_BUMP': 19 +} + +object_type = { + 0: 'TYPE_UNSET', + 1: 'TYPE_VEHICLE', + 2: 'TYPE_PEDESTRIAN', + 3: 'TYPE_CYCLIST', + 4: 'TYPE_OTHER' +} + + +signal_state = { + 0: 'LANE_STATE_UNKNOWN', + + # // States for traffic signals with arrows. + 1: 'LANE_STATE_ARROW_STOP', + 2: 'LANE_STATE_ARROW_CAUTION', + 3: 'LANE_STATE_ARROW_GO', + + # // Standard round traffic signals. + 4: 'LANE_STATE_STOP', + 5: 'LANE_STATE_CAUTION', + 6: 'LANE_STATE_GO', + + # // Flashing light signals. + 7: 'LANE_STATE_FLASHING_STOP', + 8: 'LANE_STATE_FLASHING_CAUTION' +} + +signal_state_to_id = {} +for key, val in signal_state.items(): + signal_state_to_id[val] = key + + +def decode_tracks_from_proto(tracks): + track_infos = { + 'object_id': [], # {0: unset, 1: vehicle, 2: pedestrian, 3: cyclist, 4: others} + 'object_type': [], + 'trajs': [] + } + for cur_data in tracks: # number of objects + cur_traj = [np.array([x.center_x, x.center_y, x.center_z, x.length, x.width, x.height, x.heading, + x.velocity_x, x.velocity_y, x.valid], dtype=np.float32) for x in cur_data.states] + cur_traj = np.stack(cur_traj, axis=0) # (num_timestamp, 10) + + track_infos['object_id'].append(cur_data.id) + track_infos['object_type'].append(object_type[cur_data.object_type]) + track_infos['trajs'].append(cur_traj) + + track_infos['trajs'] = np.stack(track_infos['trajs'], axis=0) # (num_objects, num_timestamp, 9) + return track_infos + + +from collections import defaultdict + + +def decode_map_features_from_proto(map_features): + map_infos = { + 'lane': [], + 'road_line': [], + 'road_edge': [], + 'stop_sign': [], + 'crosswalk': [], + 'speed_bump': [], + 'lane_dict': {}, + 'lane2other_dict': {} + } + polylines = [] + + point_cnt = 0 + lane2other_dict = defaultdict(list) + + for cur_data in map_features: + cur_info = {'id': cur_data.id} + + if cur_data.lane.ByteSize() > 0: + cur_info['speed_limit_mph'] = cur_data.lane.speed_limit_mph + cur_info['type'] = cur_data.lane.type + 1 # 0: undefined, 1: freeway, 2: surface_street, 3: bike_lane + cur_info['left_neighbors'] = [lane.feature_id for lane in cur_data.lane.left_neighbors] + + cur_info['right_neighbors'] = [lane.feature_id for lane in cur_data.lane.right_neighbors] + + cur_info['interpolating'] = cur_data.lane.interpolating + cur_info['entry_lanes'] = list(cur_data.lane.entry_lanes) + cur_info['exit_lanes'] = list(cur_data.lane.exit_lanes) + + cur_info['left_boundary_type'] = [x.boundary_type + 5 for x in 
+            cur_info['right_boundary_type'] = [x.boundary_type + 5 for x in cur_data.lane.right_boundaries]
+
+            cur_info['left_boundary'] = [x.boundary_feature_id for x in cur_data.lane.left_boundaries]
+            cur_info['right_boundary'] = [x.boundary_feature_id for x in cur_data.lane.right_boundaries]
+            cur_info['left_boundary_start_index'] = [lane.lane_start_index for lane in cur_data.lane.left_boundaries]
+            cur_info['left_boundary_end_index'] = [lane.lane_end_index for lane in cur_data.lane.left_boundaries]
+            cur_info['right_boundary_start_index'] = [lane.lane_start_index for lane in cur_data.lane.right_boundaries]
+            cur_info['right_boundary_end_index'] = [lane.lane_end_index for lane in cur_data.lane.right_boundaries]
+
+            lane2other_dict[cur_data.id].extend(cur_info['left_boundary'])
+            lane2other_dict[cur_data.id].extend(cur_info['right_boundary'])
+
+            global_type = cur_info['type']
+            cur_polyline = np.stack(
+                [np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in cur_data.lane.polyline],
+                axis=0)
+            cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1)
+            if cur_polyline.shape[0] <= 1:
+                continue
+            map_infos['lane'].append(cur_info)
+            map_infos['lane_dict'][cur_data.id] = cur_info
+
+        elif cur_data.road_line.ByteSize() > 0:
+            cur_info['type'] = cur_data.road_line.type + 5
+
+            global_type = cur_info['type']
+            cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in
+                                     cur_data.road_line.polyline], axis=0)
+            cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1)
+            if cur_polyline.shape[0] <= 1:
+                continue
+            map_infos['road_line'].append(cur_info)
+
+        elif cur_data.road_edge.ByteSize() > 0:
+            cur_info['type'] = cur_data.road_edge.type + 14
+
+            global_type = cur_info['type']
+            cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in
+                                     cur_data.road_edge.polyline], axis=0)
+            cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1)
+            if cur_polyline.shape[0] <= 1:
+                continue
+            map_infos['road_edge'].append(cur_info)
+
+        elif cur_data.stop_sign.ByteSize() > 0:
+            cur_info['lane_ids'] = list(cur_data.stop_sign.lane)
+            for i in cur_info['lane_ids']:
+                lane2other_dict[i].append(cur_data.id)
+            point = cur_data.stop_sign.position
+            cur_info['position'] = np.array([point.x, point.y, point.z])
+
+            global_type = polyline_type['TYPE_STOP_SIGN']
+            # A stop sign is a single point, so its polyline always has shape (1, 5);
+            # the `shape[0] <= 1` skip used for the multi-point feature types would
+            # drop every stop sign here, hence no length check in this branch.
+            cur_polyline = np.array([point.x, point.y, point.z, global_type, cur_data.id]).reshape(1, 5)
+            map_infos['stop_sign'].append(cur_info)
+        elif cur_data.crosswalk.ByteSize() > 0:
+            global_type = polyline_type['TYPE_CROSSWALK']
+            cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in
+                                     cur_data.crosswalk.polygon], axis=0)
+            cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1)
+            if cur_polyline.shape[0] <= 1:
+                continue
+            map_infos['crosswalk'].append(cur_info)
+
+        elif cur_data.speed_bump.ByteSize() > 0:
+            global_type = polyline_type['TYPE_SPEED_BUMP']
+            cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in
+                                     cur_data.speed_bump.polygon], axis=0)
+            cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1)
+            if cur_polyline.shape[0] <= 1:
+                continue
+            map_infos['speed_bump'].append(cur_info)
+
+        else:
+            # print(cur_data)
+            continue
+
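+        # Every kept feature records a half-open [start, end) span into the flat
+        # `all_polylines` array assembled below, so downstream code can slice its
+        # points back out by index without re-walking the proto.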
+        polylines.append(cur_polyline)
+        cur_info['polyline_index'] = (point_cnt, point_cnt + len(cur_polyline))
+        point_cnt += len(cur_polyline)
+
+    # try:
+    polylines = np.concatenate(polylines, axis=0).astype(np.float32)
+    # except:
+    #     polylines = np.zeros((0, 8), dtype=np.float32)
+    #     print('Empty polylines: ')
+    map_infos['all_polylines'] = polylines
+    map_infos['lane2other_dict'] = lane2other_dict
+    return map_infos
+
+
+def decode_dynamic_map_states_from_proto(dynamic_map_states):
+    dynamic_map_infos = {
+        'lane_id': [],
+        'state': [],
+        'stop_point': []
+    }
+    for cur_data in dynamic_map_states:  # (num_timestamp)
+        lane_id, state, stop_point = [], [], []
+        for cur_signal in cur_data.lane_states:  # (num_observed_signals)
+            lane_id.append(cur_signal.lane)
+            state.append(signal_state[cur_signal.state])
+            stop_point.append([cur_signal.stop_point.x, cur_signal.stop_point.y, cur_signal.stop_point.z])
+
+        dynamic_map_infos['lane_id'].append(np.array([lane_id]))
+        dynamic_map_infos['state'].append(np.array([state]))
+        dynamic_map_infos['stop_point'].append(np.array([stop_point]))
+
+    return dynamic_map_infos
+
+
+def process_single_data(scenario):
+    info = {}
+    info['scenario_id'] = scenario.scenario_id
+    info['timestamps_seconds'] = list(scenario.timestamps_seconds)  # list of float, length 91
+    info['current_time_index'] = scenario.current_time_index  # int, 10
+    info['sdc_track_index'] = scenario.sdc_track_index  # int
+    info['objects_of_interest'] = list(scenario.objects_of_interest)  # list, could be empty list
+
+    info['tracks_to_predict'] = {
+        'track_index': [cur_pred.track_index for cur_pred in scenario.tracks_to_predict],
+        'difficulty': [cur_pred.difficulty for cur_pred in scenario.tracks_to_predict]
+    }  # for training: suggestion of objects to train on, for val/test: need to be predicted
+
+    track_infos = decode_tracks_from_proto(scenario.tracks)
+    info['tracks_to_predict']['object_type'] = [track_infos['object_type'][cur_idx] for cur_idx in
+                                                info['tracks_to_predict']['track_index']]
+
+    # decode map related data
+    map_infos = decode_map_features_from_proto(scenario.map_features)
+    dynamic_map_infos = decode_dynamic_map_states_from_proto(scenario.dynamic_map_states)
+
+    save_infos = {
+        'track_infos': track_infos,
+        'dynamic_map_infos': dynamic_map_infos,
+        'map_infos': map_infos
+    }
+    save_infos.update(info)
+    return save_infos
+
+import tensorflow as tf
+from waymo_open_dataset.protos import scenario_pb2
+
+
+def wm2argo(file, dir_name, output_dir):
+    file_path = os.path.join(dir_name, file)
+    dataset = tf.data.TFRecordDataset(file_path, compression_type='', num_parallel_reads=3)
+    for cnt, data in enumerate(dataset):
+        print(cnt)
+        scenario = scenario_pb2.Scenario()
+        scenario.ParseFromString(bytearray(data.numpy()))
+        save_infos = process_single_data(scenario)  # decode the scenario proto into MTR-style dicts
+        map_info = save_infos["map_infos"]
+        track_info = save_infos['track_infos']
+        scenario_id = save_infos['scenario_id']
+        tracks_to_predict = save_infos['tracks_to_predict']
+        sdc_track_index = save_infos['sdc_track_index']
+        av_id = track_info["object_id"][sdc_track_index]
+        if len(tracks_to_predict["track_index"]) < 1:
+            continue  # skip scenarios without prediction targets and keep processing the rest of the shard
+        dynamic_map_infos = save_infos["dynamic_map_infos"]
+        tf_lights = process_dynamic_map(dynamic_map_infos)
+        tf_current_light = tf_lights.loc[tf_lights["time_step"] == "11"]
+        map_data = get_map_features(map_info, tf_current_light)
+        new_agents_array = process_agent(track_info, tracks_to_predict, sdc_track_index, scenario_id, 0, 91)  # mtr2argo
+        data = dict()
+        data['scenario_id'] =
new_agents_array['scenario_id'].values[0] + data['city'] = new_agents_array['city'].values[0] + data['agent'] = get_agent_features(new_agents_array, av_id, num_historical_steps=11) + data.update(map_data) + with open(os.path.join(output_dir, scenario_id + '.pkl'), "wb+") as f: + pickle.dump(data, f) + + +def batch_process9s_transformer(dir_name, output_dir, num_workers=2): + from functools import partial + import multiprocessing + packages = os.listdir(dir_name) + func = partial( + wm2argo, output_dir=output_dir, dir_name=dir_name) + with multiprocessing.Pool(num_workers) as p: + list(tqdm(p.imap(func, packages), total=len(packages))) + + +from argparse import ArgumentParser + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument('--input_dir', type=str, default='data/waymo/scenario/training') + parser.add_argument('--output_dir', type=str, default='data/waymo_processed/training') + args = parser.parse_args() + files = os.listdir(args.input_dir) + for file in tqdm(files): + wm2argo(file, args.input_dir, args.output_dir) + # batch_process9s_transformer(args.input_dir, args.output_dir, num_workers="ur_cpu_count") diff --git a/backups/thirdparty/SMART/environment.yml b/backups/thirdparty/SMART/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..83aa94561250b5ba60039bc5aadbe71f1b206849 --- /dev/null +++ b/backups/thirdparty/SMART/environment.yml @@ -0,0 +1,71 @@ +name: smart +channels: + - pytorch + - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free + - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main + - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ + - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ + - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - blas=1.0=mkl + - brotli-python=1.0.9=py39h6a678d5_8 + - bzip2=1.0.8=h5eee18b_6 + - ca-certificates=2024.9.24=h06a4308_0 + - certifi=2024.8.30=py39h06a4308_0 + - charset-normalizer=3.3.2=pyhd3eb1b0_0 + - cudatoolkit=11.3.1=h2bc3f7f_2 + - ffmpeg=4.3=hf484d3e_0 + - freetype=2.12.1=h4a9f257_0 + - gmp=6.2.1=h295c915_3 + - gnutls=3.6.15=he1e5248_0 + - idna=3.7=py39h06a4308_0 + - intel-openmp=2023.1.0=hdb19cb5_46306 + - jpeg=9e=h5eee18b_3 + - lame=3.100=h7b6447c_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.40=h12ee557_0 + - lerc=3.0=h295c915_0 + - libdeflate=1.17=h5eee18b_1 + - libffi=3.4.4=h6a678d5_1 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libiconv=1.14=0 + - libidn2=2.3.4=h5eee18b_0 + - libpng=1.6.39=h5eee18b_0 + - libstdcxx-ng=11.2.0=h1234567_1 + - libtasn1=4.19.0=h5eee18b_0 + - libtiff=4.5.1=h6a678d5_0 + - libunistring=0.9.10=h27cfd23_0 + - libwebp-base=1.3.2=h5eee18b_1 + - lz4-c=1.9.4=h6a678d5_1 + - mkl=2023.1.0=h213fc3f_46344 + - mkl-service=2.4.0=py39h5eee18b_1 + - mkl_fft=1.3.10=py39h5eee18b_0 + - mkl_random=1.2.7=py39h1128e8f_0 + - ncurses=6.4=h6a678d5_0 + - nettle=3.7.3=hbbd107a_1 + - openh264=2.1.1=h4ff587b_0 + - openjpeg=2.5.2=he7f1fd0_0 + - openssl=3.0.15=h5eee18b_0 + - pillow=10.4.0=py39h5eee18b_0 + - pip=24.2=py39h06a4308_0 + - pysocks=1.7.1=py39h06a4308_0 + - python=3.9.19=h955ad1f_1 + - pytorch=1.12.1=py3.9_cuda11.3_cudnn8.3.2_0 + - pytorch-mutex=1.0=cuda + - readline=8.2=h5eee18b_0 + - requests=2.32.3=py39h06a4308_0 + - setuptools=75.1.0=py39h06a4308_0 + - sqlite=3.45.3=h5eee18b_0 + - tbb=2021.8.0=hdb19cb5_0 + - tk=8.6.14=h39e8969_0 + - torchvision=0.13.1=py39_cu113 + - typing_extensions=4.11.0=py39h06a4308_0 + - 
urllib3=2.2.3=py39h06a4308_0
+  - wheel=0.44.0=py39h06a4308_0
+  - xz=5.4.6=h5eee18b_1
+  - zlib=1.2.13=h5eee18b_1
+  - zstd=1.5.6=hc292b87_0
diff --git a/backups/thirdparty/SMART/scripts/install_pyg.sh b/backups/thirdparty/SMART/scripts/install_pyg.sh
new file mode 100644
index 0000000000000000000000000000000000000000..78a727c7c643568f76a2e2b7c693763ccaaef23e
--- /dev/null
+++ b/backups/thirdparty/SMART/scripts/install_pyg.sh
@@ -0,0 +1,10 @@
+mkdir pyg_depend && cd pyg_depend
+wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_cluster-1.6.0%2Bpt112cu113-cp39-cp39-linux_x86_64.whl
+wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_scatter-2.1.0%2Bpt112cu113-cp39-cp39-linux_x86_64.whl
+wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_sparse-0.6.16%2Bpt112cu113-cp39-cp39-linux_x86_64.whl
+wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_spline_conv-1.2.1%2Bpt112cu113-cp39-cp39-linux_x86_64.whl
+python3 -m pip install torch_cluster-1.6.0+pt112cu113-cp39-cp39-linux_x86_64.whl
+python3 -m pip install torch_scatter-2.1.0+pt112cu113-cp39-cp39-linux_x86_64.whl
+python3 -m pip install torch_sparse-0.6.16+pt112cu113-cp39-cp39-linux_x86_64.whl
+python3 -m pip install torch_spline_conv-1.2.1+pt112cu113-cp39-cp39-linux_x86_64.whl
+python3 -m pip install torch_geometric
diff --git a/backups/thirdparty/SMART/scripts/traj_clstering.py b/backups/thirdparty/SMART/scripts/traj_clstering.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9c39147fc9b6a3751c910cd191f77bdef349b58
--- /dev/null
+++ b/backups/thirdparty/SMART/scripts/traj_clstering.py
@@ -0,0 +1,150 @@
+from smart.utils.geometry import wrap_angle
+import numpy as np
+
+
+def average_distance_vectorized(point_set1, centroids):
+    dists = np.sqrt(np.sum((point_set1[:, None, :, :] - centroids[None, :, :, :])**2, axis=-1))
+    return np.mean(dists, axis=2)
+
+
+def assign_clusters(sub_X, centroids):
+    distances = average_distance_vectorized(sub_X, centroids)
+    return np.argmin(distances, axis=1)
+
+
+def Kdisk_cluster(X, N=256, tol=0.035, width=0, length=0, a_pos=None):
+    S = []
+    ret_traj_list = []
+    while len(S) < N:
+        num_all = X.shape[0]
+        # randomly pick a candidate cluster center
+        choice_index = np.random.choice(num_all)
+        x0 = X[choice_index]
+        if x0[0, 0] < -10 or x0[0, 0] > 50 or x0[0, 1] > 10 or x0[0, 1] < -10:
+            continue
+        res_mask = np.sum((X - x0)**2, axis=(1, 2))/4 > (tol**2)
+        del_mask = np.sum((X - x0)**2, axis=(1, 2))/4 <= (tol**2)
+        if cal_mean_heading:  # module-level flag set in the __main__ block below
+            del_contour = X[del_mask]
+            diff_xy = del_contour[:, 0, :] - del_contour[:, 3, :]
+            del_heading = np.arctan2(diff_xy[:, 1], diff_xy[:, 0]).mean()
+            x0 = cal_polygon_contour(x0.mean(0)[0], x0.mean(0)[1], del_heading, width, length)
+            del_traj = a_pos[del_mask]
+            ret_traj = del_traj.mean(0)[None, ...]
+            if abs(ret_traj[0, 1, 0] - ret_traj[0, 0, 0]) > 1 and ret_traj[0, 1, 0] < 0:
+                print('warning: averaged trajectory jumps backwards:', ret_traj)
+        else:
+            x0 = x0[None, ...]
+            ret_traj = a_pos[choice_index][None, ...]
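+        # Greedy K-disk step: every sample whose mean squared corner distance to the
+        # sampled contour falls within tol**2 (the /4 above averages over the 4 box
+        # corners) is absorbed into this cluster and removed from the candidate pool
+        # below; x0 becomes the cluster's representative token.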
+ X = X[res_mask] + a_pos = a_pos[res_mask] + S.append(x0) + ret_traj_list.append(ret_traj) + centroids = np.concatenate(S, axis=0) + ret_traj = np.concatenate(ret_traj_list, axis=0) + + # closest_dist_sq = np.sum((X - centroids[0])**2, axis=(1, 2)) + + # for k in range(1, K): + # new_dist_sq = np.sum((X - centroids[k - 1])**2, axis=(1, 2)) + # closest_dist_sq = np.minimum(closest_dist_sq, new_dist_sq) + # probabilities = closest_dist_sq / np.sum(closest_dist_sq) + # centroids[k] = X[np.random.choice(N, p=probabilities)] + + return centroids, ret_traj + + +def cal_polygon_contour(x, y, theta, width, length): + + left_front_x = x + 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta) + left_front_y = y + 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta) + left_front = np.column_stack((left_front_x, left_front_y)) + + right_front_x = x + 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta) + right_front_y = y + 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta) + right_front = np.column_stack((right_front_x, right_front_y)) + + right_back_x = x - 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta) + right_back_y = y - 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta) + right_back = np.column_stack((right_back_x, right_back_y)) + + left_back_x = x - 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta) + left_back_y = y - 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta) + left_back = np.column_stack((left_back_x, left_back_y)) + + polygon_contour = np.concatenate((left_front[:, None, :], right_front[:, None, :], right_back[:, None, :], left_back[:, None, :]), axis=1) + + return polygon_contour + + +if __name__ == '__main__': + shift = 5 # motion token time dimension + num_cluster = 6 # vocabulary size + cal_mean_heading = True + data = { + "veh": np.random.rand(1000, 6, 3), + "cyc": np.random.rand(1000, 6, 3), + "ped": np.random.rand(1000, 6, 3) + } + # Collect the trajectories of all traffic participants from the raw data [NumAgent, shift+1, [relative_x, relative_y, relative_theta]] + nms_res = {} + res = {'token': {}, 'traj': {}, 'token_all': {}} + for k, v in data.items(): + # if k != 'veh': + # continue + a_pos = v + print(a_pos.shape) + # a_pos = a_pos[:, shift:1+shift, :] + cal_num = min(int(1e6), a_pos.shape[0]) + a_pos = a_pos[np.random.choice(a_pos.shape[0], cal_num, replace=False)] + a_pos[:, :, -1] = wrap_angle(a_pos[:, :, -1]) + print(a_pos.shape) + if shift <= 2: + if k == 'veh': + width = 1.0 + length = 2.4 + elif k == 'cyc': + width = 0.5 + length = 1.5 + else: + width = 0.5 + length = 0.5 + else: + if k == 'veh': + width = 2.0 + length = 4.8 + elif k == 'cyc': + width = 1.0 + length = 2.0 + else: + width = 1.0 + length = 1.0 + contour = cal_polygon_contour(a_pos[:, shift, 0], a_pos[:, shift, 1], a_pos[:, shift, 2], width, length) + + # plt.figure(figsize=(10, 10)) + # for rect in contour: + # rect_closed = np.vstack([rect, rect[0]]) + # plt.plot(rect_closed[:, 0], rect_closed[:, 1], linewidth=0.1) + + # plt.title("Plot of 256 Rectangles") + # plt.xlabel("x") + # plt.ylabel("y") + # plt.axis('equal') + # plt.savefig(f'src_{k}_new.jpg', dpi=300) + + if k == 'veh': + tol = 0.05 + elif k == 'cyc': + tol = 0.004 + else: + tol = 0.004 + centroids, ret_traj = Kdisk_cluster(contour, num_cluster, tol, width, length, a_pos[:, :shift+1]) + # plt.figure(figsize=(10, 10)) + contour = cal_polygon_contour(ret_traj[:, :, 0].reshape(num_cluster*(shift+1)), + ret_traj[:, :, 1].reshape(num_cluster*(shift+1)), + ret_traj[:, :, 
2].reshape(num_cluster*(shift+1)), width, length) + + res['token_all'][k] = contour.reshape(num_cluster, (shift+1), 4, 2) + res['token'][k] = centroids + res['traj'][k] = ret_traj diff --git a/backups/thirdparty/SMART/smart/__init__.py b/backups/thirdparty/SMART/smart/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/backups/thirdparty/SMART/smart/datamodules/__init__.py b/backups/thirdparty/SMART/smart/datamodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..daa6e8e7fbf2247239b02650b2fa86be89662cf6 --- /dev/null +++ b/backups/thirdparty/SMART/smart/datamodules/__init__.py @@ -0,0 +1 @@ +from smart.datamodules.scalable_datamodule import MultiDataModule diff --git a/backups/thirdparty/SMART/smart/datamodules/scalable_datamodule.py b/backups/thirdparty/SMART/smart/datamodules/scalable_datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..6174dd9893f154fabd47708685779125e10101ee --- /dev/null +++ b/backups/thirdparty/SMART/smart/datamodules/scalable_datamodule.py @@ -0,0 +1,90 @@ +from typing import Optional + +import pytorch_lightning as pl +from torch_geometric.loader import DataLoader +from smart.datasets.scalable_dataset import MultiDataset +from smart.transforms import WaymoTargetBuilder + + +class MultiDataModule(pl.LightningDataModule): + transforms = { + "WaymoTargetBuilder": WaymoTargetBuilder, + } + + dataset = { + "scalable": MultiDataset, + } + + def __init__(self, + root: str, + train_batch_size: int, + val_batch_size: int, + test_batch_size: int, + shuffle: bool = False, + num_workers: int = 0, + pin_memory: bool = True, + persistent_workers: bool = True, + train_raw_dir: Optional[str] = None, + val_raw_dir: Optional[str] = None, + test_raw_dir: Optional[str] = None, + train_processed_dir: Optional[str] = None, + val_processed_dir: Optional[str] = None, + test_processed_dir: Optional[str] = None, + transform: Optional[str] = None, + dataset: Optional[str] = None, + num_historical_steps: int = 50, + num_future_steps: int = 60, + processor='ntp', + use_intention=False, + token_size=512, + **kwargs) -> None: + super(MultiDataModule, self).__init__() + self.root = root + self.dataset_class = dataset + self.train_batch_size = train_batch_size + self.val_batch_size = val_batch_size + self.test_batch_size = test_batch_size + self.shuffle = shuffle + self.num_workers = num_workers + self.pin_memory = pin_memory + self.persistent_workers = persistent_workers and num_workers > 0 + self.train_raw_dir = train_raw_dir + self.val_raw_dir = val_raw_dir + self.test_raw_dir = test_raw_dir + self.train_processed_dir = train_processed_dir + self.val_processed_dir = val_processed_dir + self.test_processed_dir = test_processed_dir + self.processor = processor + self.use_intention = use_intention + self.token_size = token_size + + train_transform = MultiDataModule.transforms[transform](num_historical_steps, num_future_steps, "train") + val_transform = MultiDataModule.transforms[transform](num_historical_steps, num_future_steps, "val") + test_transform = MultiDataModule.transforms[transform](num_historical_steps, num_future_steps) + + self.train_transform = train_transform + self.val_transform = val_transform + self.test_transform = test_transform + + def setup(self, stage: Optional[str] = None) -> None: + self.train_dataset = MultiDataModule.dataset[self.dataset_class](self.root, 'train', processed_dir=self.train_processed_dir, + 
raw_dir=self.train_raw_dir, processor=self.processor, transform=self.train_transform, token_size=self.token_size) + self.val_dataset = MultiDataModule.dataset[self.dataset_class](None, 'val', processed_dir=self.val_processed_dir, + raw_dir=self.val_raw_dir, processor=self.processor, transform=self.val_transform, token_size=self.token_size) + self.test_dataset = MultiDataModule.dataset[self.dataset_class](None, 'test', processed_dir=self.test_processed_dir, + raw_dir=self.test_raw_dir, processor=self.processor, transform=self.test_transform, token_size=self.token_size) + + def train_dataloader(self): + return DataLoader(self.train_dataset, batch_size=self.train_batch_size, shuffle=self.shuffle, + num_workers=self.num_workers, pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers) + + def val_dataloader(self): + return DataLoader(self.val_dataset, batch_size=self.val_batch_size, shuffle=False, + num_workers=self.num_workers, pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers) + + def test_dataloader(self): + return DataLoader(self.test_dataset, batch_size=self.test_batch_size, shuffle=False, + num_workers=self.num_workers, pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers) diff --git a/backups/thirdparty/SMART/smart/datasets/__init__.py b/backups/thirdparty/SMART/smart/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..35a6c6079c2671c0d524ab892f3cf3ae0a3db82d --- /dev/null +++ b/backups/thirdparty/SMART/smart/datasets/__init__.py @@ -0,0 +1 @@ +from smart.datasets.scalable_dataset import MultiDataset diff --git a/backups/thirdparty/SMART/smart/datasets/preprocess.py b/backups/thirdparty/SMART/smart/datasets/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..58716d352dd28f723a9d08917837347a74811bf0 --- /dev/null +++ b/backups/thirdparty/SMART/smart/datasets/preprocess.py @@ -0,0 +1,468 @@ +import torch +import numpy as np +from scipy.interpolate import interp1d +from scipy.spatial.distance import euclidean +import math +import pickle +from smart.utils import wrap_angle +import os + +def cal_polygon_contour(x, y, theta, width, length): + left_front_x = x + 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta) + left_front_y = y + 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta) + left_front = np.column_stack((left_front_x, left_front_y)) + + right_front_x = x + 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta) + right_front_y = y + 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta) + right_front = np.column_stack((right_front_x, right_front_y)) + + right_back_x = x - 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta) + right_back_y = y - 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta) + right_back = np.column_stack((right_back_x, right_back_y)) + + left_back_x = x - 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta) + left_back_y = y - 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta) + left_back = np.column_stack((left_back_x, left_back_y)) + + polygon_contour = np.concatenate( + (left_front[:, None, :], right_front[:, None, :], right_back[:, None, :], left_back[:, None, :]), axis=1) + + return polygon_contour + + +def interplating_polyline(polylines, heading, distance=0.5, split_distace=5): + # Calculate the cumulative distance along the path, up-sample the polyline to 0.5 meter + dist_along_path_list = [[0]] + polylines_list = [[polylines[0]]] + for i in range(1, polylines.shape[0]): + 
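+        # Walk the polyline point by point: a heading jump above 0.1 rad across a
+        # gap larger than 3 m, or any gap larger than 10 m, starts a new segment;
+        # otherwise the running arc length just accumulates for interpolation below.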
+        euclidean_dist = euclidean(polylines[i, :2], polylines[i - 1, :2])
+        heading_diff = min(abs(max(heading[i], heading[i - 1]) - min(heading[i], heading[i - 1])),
+                           abs(max(heading[i], heading[i - 1]) - min(heading[i], heading[i - 1]) + math.pi))
+        if heading_diff > math.pi / 4 and euclidean_dist > 3:
+            dist_along_path_list.append([0])
+            polylines_list.append([polylines[i]])
+        elif heading_diff > math.pi / 8 and euclidean_dist > 3:
+            dist_along_path_list.append([0])
+            polylines_list.append([polylines[i]])
+        elif heading_diff > 0.1 and euclidean_dist > 3:
+            dist_along_path_list.append([0])
+            polylines_list.append([polylines[i]])
+        elif euclidean_dist > 10:
+            dist_along_path_list.append([0])
+            polylines_list.append([polylines[i]])
+        else:
+            dist_along_path_list[-1].append(dist_along_path_list[-1][-1] + euclidean_dist)
+            polylines_list[-1].append(polylines[i])
+    # plt.plot(polylines[:, 0], polylines[:, 1])
+    # plt.savefig('tmp.jpg')
+    new_x_list = []
+    new_y_list = []
+    multi_polylines_list = []
+    for idx in range(len(dist_along_path_list)):
+        if len(dist_along_path_list[idx]) < 2:
+            continue
+        dist_along_path = np.array(dist_along_path_list[idx])
+        polylines_cur = np.array(polylines_list[idx])
+        # Create interpolation functions for x and y coordinates
+        fx = interp1d(dist_along_path, polylines_cur[:, 0])
+        fy = interp1d(dist_along_path, polylines_cur[:, 1])
+        # fyaw = interp1d(dist_along_path, heading)
+
+        # Create an array of distances at which to interpolate
+        new_dist_along_path = np.arange(0, dist_along_path[-1], distance)
+        new_dist_along_path = np.concatenate([new_dist_along_path, dist_along_path[[-1]]])
+        # Use the interpolation functions to generate new x and y coordinates
+        new_x = fx(new_dist_along_path)
+        new_y = fy(new_dist_along_path)
+        # new_yaw = fyaw(new_dist_along_path)
+        new_x_list.append(new_x)
+        new_y_list.append(new_y)
+
+        # Combine the new x and y coordinates into a single array
+        new_polylines = np.vstack((new_x, new_y)).T
+        polyline_size = int(split_distace / distance)
+        if new_polylines.shape[0] >= (polyline_size + 1):
+            padding_size = (new_polylines.shape[0] - (polyline_size + 1)) % polyline_size
+            final_index = (new_polylines.shape[0] - (polyline_size + 1)) // polyline_size + 1
+        else:
+            padding_size = new_polylines.shape[0]
+            final_index = 0
+        multi_polylines = None
+        new_polylines = torch.from_numpy(new_polylines)
+        new_heading = torch.atan2(new_polylines[1:, 1] - new_polylines[:-1, 1],
+                                  new_polylines[1:, 0] - new_polylines[:-1, 0])
+        new_heading = torch.cat([new_heading, new_heading[-1:]], -1)[..., None]
+        new_polylines = torch.cat([new_polylines, new_heading], -1)
+        if new_polylines.shape[0] >= (polyline_size + 1):
+            multi_polylines = new_polylines.unfold(dimension=0, size=polyline_size + 1, step=polyline_size)
+            multi_polylines = multi_polylines.transpose(1, 2)
+            multi_polylines = multi_polylines[:, ::5, :]
+        if padding_size >= 3:
+            last_polyline = new_polylines[final_index * polyline_size:]
+            last_polyline = last_polyline[torch.linspace(0, last_polyline.shape[0] - 1, steps=3).long()]
+            if multi_polylines is not None:
+                multi_polylines = torch.cat([multi_polylines, last_polyline.unsqueeze(0)], dim=0)
+            else:
+                multi_polylines = last_polyline.unsqueeze(0)
+        if multi_polylines is None:
+            continue
+        multi_polylines_list.append(multi_polylines)
+    if len(multi_polylines_list) > 0:
+        multi_polylines_list = torch.cat(multi_polylines_list, dim=0)
+    else:
+        multi_polylines_list = None
+    return multi_polylines_list
+
+
+def average_distance_vectorized(point_set1,
centroids): + dists = np.sqrt(np.sum((point_set1[:, None, :, :] - centroids[None, :, :, :]) ** 2, axis=-1)) + return np.mean(dists, axis=2) + + +def assign_clusters(sub_X, centroids): + distances = average_distance_vectorized(sub_X, centroids) + return np.argmin(distances, axis=1) + + +class TokenProcessor: + + def __init__(self, token_size): + module_dir = os.path.dirname(os.path.dirname(__file__)) + self.agent_token_path = os.path.join(module_dir, f'tokens/cluster_frame_5_{token_size}.pkl') + self.map_token_traj_path = os.path.join(module_dir, 'tokens/map_traj_token5.pkl') + self.noise = False + self.disturb = False + self.shift = 5 + self.get_trajectory_token() + self.training = False + self.current_step = 10 + + def preprocess(self, data): + data = self.tokenize_agent(data) + data = self.tokenize_map(data) + del data['city'] + if 'polygon_is_intersection' in data['map_polygon']: + del data['map_polygon']['polygon_is_intersection'] + if 'route_type' in data['map_polygon']: + del data['map_polygon']['route_type'] + return data + + def get_trajectory_token(self): + agent_token_data = pickle.load(open(self.agent_token_path, 'rb')) + map_token_traj = pickle.load(open(self.map_token_traj_path, 'rb')) + self.trajectory_token = agent_token_data['token'] + self.trajectory_token_all = agent_token_data['token_all'] + self.map_token = {'traj_src': map_token_traj['traj_src'], } + self.token_last = {} + for k, v in self.trajectory_token_all.items(): + token_last = torch.from_numpy(v[:, -2:]).to(torch.float) + diff_xy = token_last[:, 0, 0] - token_last[:, 0, 3] + theta = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0]) + cos, sin = theta.cos(), theta.sin() + rot_mat = theta.new_zeros(token_last.shape[0], 2, 2) + rot_mat[:, 0, 0] = cos + rot_mat[:, 0, 1] = -sin + rot_mat[:, 1, 0] = sin + rot_mat[:, 1, 1] = cos + agent_token = torch.bmm(token_last[:, 1], rot_mat) + agent_token -= token_last[:, 0].mean(1)[:, None, :] + self.token_last[k] = agent_token.numpy() + + def clean_heading(self, data): + heading = data['agent']['heading'] + valid = data['agent']['valid_mask'] + pi = torch.tensor(torch.pi) + n_vehicles, n_frames = heading.shape + + heading_diff_raw = heading[:, :-1] - heading[:, 1:] + heading_diff = torch.remainder(heading_diff_raw + pi, 2 * pi) - pi + heading_diff[heading_diff > pi] -= 2 * pi + heading_diff[heading_diff < -pi] += 2 * pi + + valid_pairs = valid[:, :-1] & valid[:, 1:] + + for i in range(n_frames - 1): + change_needed = (torch.abs(heading_diff[:, i:i + 1]) > 1.0) & valid_pairs[:, i:i + 1] + + heading[:, i + 1][change_needed.squeeze()] = heading[:, i][change_needed.squeeze()] + + if i < n_frames - 2: + heading_diff_raw = heading[:, i + 1] - heading[:, i + 2] + heading_diff[:, i + 1] = torch.remainder(heading_diff_raw + pi, 2 * pi) - pi + heading_diff[heading_diff[:, i + 1] > pi] -= 2 * pi + heading_diff[heading_diff[:, i + 1] < -pi] += 2 * pi + + def tokenize_agent(self, data): + if data['agent']["velocity"].shape[1] == 90: + print(data['scenario_id'], data['agent']["velocity"].shape) + interplote_mask = (data['agent']['valid_mask'][:, self.current_step] == False) * ( + data['agent']['position'][:, self.current_step, 0] != 0) + if data['agent']["velocity"].shape[-1] == 2: + data['agent']["velocity"] = torch.cat([data['agent']["velocity"], + torch.zeros(data['agent']["velocity"].shape[0], + data['agent']["velocity"].shape[1], 1)], dim=-1) + vel = data['agent']["velocity"][interplote_mask, self.current_step] + data['agent']['position'][interplote_mask, self.current_step - 1, :3] = 
data['agent']['position'][ + interplote_mask, self.current_step, + :3] - vel * 0.1 + data['agent']['valid_mask'][interplote_mask, self.current_step - 1:self.current_step + 1] = True + data['agent']['heading'][interplote_mask, self.current_step - 1] = data['agent']['heading'][ + interplote_mask, self.current_step] + data['agent']["velocity"][interplote_mask, self.current_step - 1] = data['agent']["velocity"][ + interplote_mask, self.current_step] + + data['agent']['type'] = data['agent']['type'].to(torch.uint8) + + self.clean_heading(data) + matching_extra_mask = (data['agent']['valid_mask'][:, self.current_step] == True) * ( + data['agent']['valid_mask'][:, self.current_step - 5] == False) + + interplote_mask_first = (data['agent']['valid_mask'][:, 0] == False) * (data['agent']['position'][:, 0, 0] != 0) + data['agent']['valid_mask'][interplote_mask_first, 0] = True + + agent_pos = data['agent']['position'][:, :, :2] + valid_mask = data['agent']['valid_mask'] + + valid_mask_shift = valid_mask.unfold(1, self.shift + 1, self.shift) + token_valid_mask = valid_mask_shift[:, :, 0] * valid_mask_shift[:, :, -1] + agent_type = data['agent']['type'] + agent_category = data['agent']['category'] + agent_heading = data['agent']['heading'] + vehicle_mask = agent_type == 0 + cyclist_mask = agent_type == 2 + ped_mask = agent_type == 1 + + veh_pos = agent_pos[vehicle_mask, :, :] + veh_valid_mask = valid_mask[vehicle_mask, :] + cyc_pos = agent_pos[cyclist_mask, :, :] + cyc_valid_mask = valid_mask[cyclist_mask, :] + ped_pos = agent_pos[ped_mask, :, :] + ped_valid_mask = valid_mask[ped_mask, :] + + veh_token_index, veh_token_contour = self.match_token(veh_pos, veh_valid_mask, agent_heading[vehicle_mask], + 'veh', agent_category[vehicle_mask], + matching_extra_mask[vehicle_mask]) + ped_token_index, ped_token_contour = self.match_token(ped_pos, ped_valid_mask, agent_heading[ped_mask], 'ped', + agent_category[ped_mask], matching_extra_mask[ped_mask]) + cyc_token_index, cyc_token_contour = self.match_token(cyc_pos, cyc_valid_mask, agent_heading[cyclist_mask], + 'cyc', agent_category[cyclist_mask], + matching_extra_mask[cyclist_mask]) + + token_index = torch.zeros((agent_pos.shape[0], veh_token_index.shape[1])).to(torch.int64) + token_index[vehicle_mask] = veh_token_index + token_index[ped_mask] = ped_token_index + token_index[cyclist_mask] = cyc_token_index + + token_contour = torch.zeros((agent_pos.shape[0], veh_token_contour.shape[1], + veh_token_contour.shape[2], veh_token_contour.shape[3])) + token_contour[vehicle_mask] = veh_token_contour + token_contour[ped_mask] = ped_token_contour + token_contour[cyclist_mask] = cyc_token_contour + + trajectory_token_veh = torch.from_numpy(self.trajectory_token['veh']).clone().to(torch.float) + trajectory_token_ped = torch.from_numpy(self.trajectory_token['ped']).clone().to(torch.float) + trajectory_token_cyc = torch.from_numpy(self.trajectory_token['cyc']).clone().to(torch.float) + + agent_token_traj = torch.zeros((agent_pos.shape[0], trajectory_token_veh.shape[0], 4, 2)) + agent_token_traj[vehicle_mask] = trajectory_token_veh + agent_token_traj[ped_mask] = trajectory_token_ped + agent_token_traj[cyclist_mask] = trajectory_token_cyc + + if not self.training: + token_valid_mask[matching_extra_mask, 1] = True + + data['agent']['token_idx'] = token_index + data['agent']['token_contour'] = token_contour + token_pos = token_contour.mean(dim=2) + data['agent']['token_pos'] = token_pos + diff_xy = token_contour[:, :, 0, :] - token_contour[:, :, 3, :] + 
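+        # Token heading is recovered from the matched contour itself: the vector from
+        # the back-left corner (index 3) to the front-left corner (index 0) points
+        # along the long axis of the box.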
data['agent']['token_heading'] = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0]) + data['agent']['agent_valid_mask'] = token_valid_mask + + vel = torch.cat([token_pos.new_zeros(data['agent']['num_nodes'], 1, 2), + ((token_pos[:, 1:] - token_pos[:, :-1]) / (0.1 * self.shift))], dim=1) + vel_valid_mask = torch.cat([torch.zeros(token_valid_mask.shape[0], 1, dtype=torch.bool), + (token_valid_mask * token_valid_mask.roll(shifts=1, dims=1))[:, 1:]], dim=1) + vel[~vel_valid_mask] = 0 + vel[data['agent']['valid_mask'][:, self.current_step], 1] = data['agent']['velocity'][ + data['agent']['valid_mask'][:, self.current_step], + self.current_step, :2] + + data['agent']['token_velocity'] = vel + + return data + + def match_token(self, pos, valid_mask, heading, category, agent_category, extra_mask): + agent_token_src = self.trajectory_token[category] + token_last = self.token_last[category] + if self.shift <= 2: + if category == 'veh': + width = 1.0 + length = 2.4 + elif category == 'cyc': + width = 0.5 + length = 1.5 + else: + width = 0.5 + length = 0.5 + else: + if category == 'veh': + width = 2.0 + length = 4.8 + elif category == 'cyc': + width = 1.0 + length = 2.0 + else: + width = 1.0 + length = 1.0 + + prev_heading = heading[:, 0] + prev_pos = pos[:, 0] + agent_num, num_step, feat_dim = pos.shape + token_num, token_contour_dim, feat_dim = agent_token_src.shape + agent_token_src = agent_token_src.reshape(1, token_num * token_contour_dim, feat_dim).repeat(agent_num, 0) + token_last = token_last.reshape(1, token_num * token_contour_dim, feat_dim).repeat(extra_mask.sum(), 0) + token_index_list = [] + token_contour_list = [] + prev_token_idx = None + + for i in range(self.shift, pos.shape[1], self.shift): + theta = prev_heading + cur_heading = heading[:, i] + cur_pos = pos[:, i] + cos, sin = theta.cos(), theta.sin() + rot_mat = theta.new_zeros(agent_num, 2, 2) + rot_mat[:, 0, 0] = cos + rot_mat[:, 0, 1] = sin + rot_mat[:, 1, 0] = -sin + rot_mat[:, 1, 1] = cos + agent_token_world = torch.bmm(torch.from_numpy(agent_token_src).to(torch.float), rot_mat).reshape(agent_num, + token_num, + token_contour_dim, + feat_dim) + agent_token_world += prev_pos[:, None, None, :] + + cur_contour = cal_polygon_contour(cur_pos[:, 0], cur_pos[:, 1], cur_heading, width, length) + agent_token_index = torch.from_numpy(np.argmin( + np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] - agent_token_world.numpy()) ** 2, axis=-1)), axis=2), + axis=-1)) + if prev_token_idx is not None and self.noise: + same_idx = prev_token_idx == agent_token_index + same_idx[:] = True + topk_indices = np.argsort( + np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] 
- agent_token_world.numpy()) ** 2, axis=-1)), + axis=2), axis=-1)[:, :5] + sample_topk = np.random.choice(range(0, topk_indices.shape[1]), topk_indices.shape[0]) + agent_token_index[same_idx] = \ + torch.from_numpy(topk_indices[np.arange(topk_indices.shape[0]), sample_topk])[same_idx] + + token_contour_select = agent_token_world[torch.arange(agent_num), agent_token_index] + + diff_xy = token_contour_select[:, 0, :] - token_contour_select[:, 3, :] + + prev_heading = heading[:, i].clone() + prev_heading[valid_mask[:, i - self.shift]] = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])[ + valid_mask[:, i - self.shift]] + + prev_pos = pos[:, i].clone() + prev_pos[valid_mask[:, i - self.shift]] = token_contour_select.mean(dim=1)[valid_mask[:, i - self.shift]] + prev_token_idx = agent_token_index + token_index_list.append(agent_token_index[:, None]) + token_contour_list.append(token_contour_select[:, None, ...]) + + token_index = torch.cat(token_index_list, dim=1) + token_contour = torch.cat(token_contour_list, dim=1) + + # extra matching + if not self.training: + theta = heading[extra_mask, self.current_step - 1] + prev_pos = pos[extra_mask, self.current_step - 1] + cur_pos = pos[extra_mask, self.current_step] + cur_heading = heading[extra_mask, self.current_step] + cos, sin = theta.cos(), theta.sin() + rot_mat = theta.new_zeros(extra_mask.sum(), 2, 2) + rot_mat[:, 0, 0] = cos + rot_mat[:, 0, 1] = sin + rot_mat[:, 1, 0] = -sin + rot_mat[:, 1, 1] = cos + agent_token_world = torch.bmm(torch.from_numpy(token_last).to(torch.float), rot_mat).reshape( + extra_mask.sum(), token_num, token_contour_dim, feat_dim) + agent_token_world += prev_pos[:, None, None, :] + + cur_contour = cal_polygon_contour(cur_pos[:, 0], cur_pos[:, 1], cur_heading, width, length) + agent_token_index = torch.from_numpy(np.argmin( + np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] 
- agent_token_world.numpy()) ** 2, axis=-1)), axis=2), + axis=-1)) + token_contour_select = agent_token_world[torch.arange(extra_mask.sum()), agent_token_index] + + token_index[extra_mask, 1] = agent_token_index + token_contour[extra_mask, 1] = token_contour_select + + return token_index, token_contour + + def tokenize_map(self, data): + data['map_polygon']['type'] = data['map_polygon']['type'].to(torch.uint8) + data['map_point']['type'] = data['map_point']['type'].to(torch.uint8) + pt2pl = data[('map_point', 'to', 'map_polygon')]['edge_index'] + pt_type = data['map_point']['type'].to(torch.uint8) + pt_side = torch.zeros_like(pt_type) + pt_pos = data['map_point']['position'][:, :2] + data['map_point']['orientation'] = wrap_angle(data['map_point']['orientation']) + pt_heading = data['map_point']['orientation'] + split_polyline_type = [] + split_polyline_pos = [] + split_polyline_theta = [] + split_polyline_side = [] + pl_idx_list = [] + split_polygon_type = [] + data['map_point']['type'].unique() + + for i in sorted(np.unique(pt2pl[1])): + index = pt2pl[0, pt2pl[1] == i] + polygon_type = data['map_polygon']["type"][i] + cur_side = pt_side[index] + cur_type = pt_type[index] + cur_pos = pt_pos[index] + cur_heading = pt_heading[index] + + for side_val in np.unique(cur_side): + for type_val in np.unique(cur_type): + if type_val == 13: + continue + indices = np.where((cur_side == side_val) & (cur_type == type_val))[0] + if len(indices) <= 2: + continue + split_polyline = interplating_polyline(cur_pos[indices].numpy(), cur_heading[indices].numpy()) + if split_polyline is None: + continue + new_cur_type = cur_type[indices][0] + new_cur_side = cur_side[indices][0] + map_polygon_type = polygon_type.repeat(split_polyline.shape[0]) + new_cur_type = new_cur_type.repeat(split_polyline.shape[0]) + new_cur_side = new_cur_side.repeat(split_polyline.shape[0]) + cur_pl_idx = torch.Tensor([i]) + new_cur_pl_idx = cur_pl_idx.repeat(split_polyline.shape[0]) + split_polyline_pos.append(split_polyline[..., :2]) + split_polyline_theta.append(split_polyline[..., 2]) + split_polyline_type.append(new_cur_type) + split_polyline_side.append(new_cur_side) + pl_idx_list.append(new_cur_pl_idx) + split_polygon_type.append(map_polygon_type) + + split_polyline_pos = torch.cat(split_polyline_pos, dim=0) + split_polyline_theta = torch.cat(split_polyline_theta, dim=0) + split_polyline_type = torch.cat(split_polyline_type, dim=0) + split_polyline_side = torch.cat(split_polyline_side, dim=0) + split_polygon_type = torch.cat(split_polygon_type, dim=0) + pl_idx_list = torch.cat(pl_idx_list, dim=0) + vec = split_polyline_pos[:, 1, :] - split_polyline_pos[:, 0, :] + data['map_save'] = {} + data['pt_token'] = {} + data['map_save']['traj_pos'] = split_polyline_pos + data['map_save']['traj_theta'] = split_polyline_theta[:, 0] # torch.arctan2(vec[:, 1], vec[:, 0]) + data['map_save']['pl_idx_list'] = pl_idx_list + data['pt_token']['type'] = split_polyline_type + data['pt_token']['side'] = split_polyline_side + data['pt_token']['pl_type'] = split_polygon_type + data['pt_token']['num_nodes'] = split_polyline_pos.shape[0] + return data \ No newline at end of file diff --git a/backups/thirdparty/SMART/smart/datasets/scalable_dataset.py b/backups/thirdparty/SMART/smart/datasets/scalable_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b92e595be27b0e401b41534247afc7cc766be546 --- /dev/null +++ b/backups/thirdparty/SMART/smart/datasets/scalable_dataset.py @@ -0,0 +1,91 @@ +import os +import pickle +from typing 
import Callable, List, Optional, Tuple, Union
+import pandas as pd
+from torch_geometric.data import Dataset
+from smart.utils.log import Logging
+import numpy as np
+from .preprocess import TokenProcessor
+
+
+def distance(point1, point2):
+    return np.sqrt((point2[0] - point1[0])**2 + (point2[1] - point1[1])**2)
+
+
+class MultiDataset(Dataset):
+    def __init__(self,
+                 root: str,
+                 split: str,
+                 raw_dir: List[str] = None,
+                 processed_dir: List[str] = None,
+                 transform: Optional[Callable] = None,
+                 dim: int = 3,
+                 num_historical_steps: int = 50,
+                 num_future_steps: int = 60,
+                 predict_unseen_agents: bool = False,
+                 vector_repr: bool = True,
+                 cluster: bool = False,
+                 processor=None,
+                 use_intention=False,
+                 token_size=512) -> None:
+        self.logger = Logging().log(level='DEBUG')
+        self.root = root
+        self.well_done = [0]
+        if split not in ('train', 'val', 'test'):
+            raise ValueError(f'{split} is not a valid split')
+        self.split = split
+        self.training = split == 'train'
+        self.logger.debug("Starting loading dataset")
+        self._raw_file_names = []
+        self._raw_paths = []
+        self._raw_file_dataset = []
+        if raw_dir is not None:
+            self._raw_dir = raw_dir
+            for raw_dir in self._raw_dir:
+                raw_dir = os.path.expanduser(os.path.normpath(raw_dir))
+                dataset = "waymo"
+                file_list = os.listdir(raw_dir)
+                self._raw_file_names.extend(file_list)
+                self._raw_paths.extend([os.path.join(raw_dir, f) for f in file_list])
+                self._raw_file_dataset.extend([dataset for _ in range(len(file_list))])
+        if self.root is not None:
+            split_datainfo = os.path.join(root, "split_datainfo.pkl")
+            with open(split_datainfo, 'rb') as f:
+                split_datainfo = pickle.load(f)
+            if split == "test":
+                split = "val"
+            self._processed_file_names = split_datainfo[split]
+        self.dim = dim
+        self.num_historical_steps = num_historical_steps
+        self._num_samples = len(self._processed_file_names) - 1 if processed_dir is not None else len(self._raw_file_names)
+        self.logger.debug("The number of {} dataset is ".format(split) + str(self._num_samples))
+        self.token_processor = TokenProcessor(token_size)  # use the configured token vocabulary size
+        super(MultiDataset, self).__init__(root=root, transform=transform, pre_transform=None, pre_filter=None)
+
+    @property
+    def raw_dir(self) -> str:
+        return self._raw_dir
+
+    @property
+    def raw_paths(self) -> List[str]:
+        return self._raw_paths
+
+    @property
+    def raw_file_names(self) -> Union[str, List[str], Tuple]:
+        return self._raw_file_names
+
+    @property
+    def processed_file_names(self) -> Union[str, List[str], Tuple]:
+        return self._processed_file_names
+
+    def len(self) -> int:
+        return self._num_samples
+
+    def generate_ref_token(self):
+        pass
+
+    def get(self, idx: int):
+        with open(self.raw_paths[idx], 'rb') as handle:
+            data = pickle.load(handle)
+        data = self.token_processor.preprocess(data)
+        return data
diff --git a/backups/thirdparty/SMART/smart/layers/__init__.py b/backups/thirdparty/SMART/smart/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..52c66a2fa5c78330117cff31ef10ea11e584e9d4
--- /dev/null
+++ b/backups/thirdparty/SMART/smart/layers/__init__.py
@@ -0,0 +1,4 @@
+
+from smart.layers.attention_layer import AttentionLayer
+from smart.layers.fourier_embedding import FourierEmbedding, MLPEmbedding
+from smart.layers.mlp_layer import MLPLayer
diff --git a/backups/thirdparty/SMART/smart/layers/attention_layer.py b/backups/thirdparty/SMART/smart/layers/attention_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f182a12f230d9bc1d7e7248caf402a2bd75a5eb
--- /dev/null
+++
b/backups/thirdparty/SMART/smart/layers/attention_layer.py @@ -0,0 +1,109 @@ + +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch_geometric.nn.conv import MessagePassing +from torch_geometric.utils import softmax + +from smart.utils import weight_init + + +class AttentionLayer(MessagePassing): + + def __init__(self, + hidden_dim: int, + num_heads: int, + head_dim: int, + dropout: float, + bipartite: bool, + has_pos_emb: bool, + **kwargs) -> None: + super(AttentionLayer, self).__init__(aggr='add', node_dim=0, **kwargs) + self.num_heads = num_heads + self.head_dim = head_dim + self.has_pos_emb = has_pos_emb + self.scale = head_dim ** -0.5 + + self.to_q = nn.Linear(hidden_dim, head_dim * num_heads) + self.to_k = nn.Linear(hidden_dim, head_dim * num_heads, bias=False) + self.to_v = nn.Linear(hidden_dim, head_dim * num_heads) + if has_pos_emb: + self.to_k_r = nn.Linear(hidden_dim, head_dim * num_heads, bias=False) + self.to_v_r = nn.Linear(hidden_dim, head_dim * num_heads) + self.to_s = nn.Linear(hidden_dim, head_dim * num_heads) + self.to_g = nn.Linear(head_dim * num_heads + hidden_dim, head_dim * num_heads) + self.to_out = nn.Linear(head_dim * num_heads, hidden_dim) + self.attn_drop = nn.Dropout(dropout) + self.ff_mlp = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim * 4), + nn.ReLU(inplace=True), + nn.Dropout(dropout), + nn.Linear(hidden_dim * 4, hidden_dim), + ) + if bipartite: + self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim) + self.attn_prenorm_x_dst = nn.LayerNorm(hidden_dim) + else: + self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim) + self.attn_prenorm_x_dst = self.attn_prenorm_x_src + if has_pos_emb: + self.attn_prenorm_r = nn.LayerNorm(hidden_dim) + self.attn_postnorm = nn.LayerNorm(hidden_dim) + self.ff_prenorm = nn.LayerNorm(hidden_dim) + self.ff_postnorm = nn.LayerNorm(hidden_dim) + self.apply(weight_init) + + def forward(self, + x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + r: Optional[torch.Tensor], + edge_index: torch.Tensor) -> torch.Tensor: + if isinstance(x, torch.Tensor): + x_src = x_dst = self.attn_prenorm_x_src(x) + else: + x_src, x_dst = x + x_src = self.attn_prenorm_x_src(x_src) + x_dst = self.attn_prenorm_x_dst(x_dst) + x = x[1] + if self.has_pos_emb and r is not None: + r = self.attn_prenorm_r(r) + x = x + self.attn_postnorm(self._attn_block(x_src, x_dst, r, edge_index)) + x = x + self.ff_postnorm(self._ff_block(self.ff_prenorm(x))) + return x + + def message(self, + q_i: torch.Tensor, + k_j: torch.Tensor, + v_j: torch.Tensor, + r: Optional[torch.Tensor], + index: torch.Tensor, + ptr: Optional[torch.Tensor]) -> torch.Tensor: + if self.has_pos_emb and r is not None: + k_j = k_j + self.to_k_r(r).view(-1, self.num_heads, self.head_dim) + v_j = v_j + self.to_v_r(r).view(-1, self.num_heads, self.head_dim) + sim = (q_i * k_j).sum(dim=-1) * self.scale + attn = softmax(sim, index, ptr) + self.attention_weight = attn.sum(-1).detach() + attn = self.attn_drop(attn) + return v_j * attn.unsqueeze(-1) + + def update(self, + inputs: torch.Tensor, + x_dst: torch.Tensor) -> torch.Tensor: + inputs = inputs.view(-1, self.num_heads * self.head_dim) + g = torch.sigmoid(self.to_g(torch.cat([inputs, x_dst], dim=-1))) + return inputs + g * (self.to_s(x_dst) - inputs) + + def _attn_block(self, + x_src: torch.Tensor, + x_dst: torch.Tensor, + r: Optional[torch.Tensor], + edge_index: torch.Tensor) -> torch.Tensor: + q = self.to_q(x_dst).view(-1, self.num_heads, self.head_dim) + k = self.to_k(x_src).view(-1, self.num_heads, 
self.head_dim) + v = self.to_v(x_src).view(-1, self.num_heads, self.head_dim) + agg = self.propagate(edge_index=edge_index, x_dst=x_dst, q=q, k=k, v=v, r=r) + return self.to_out(agg) + + def _ff_block(self, x: torch.Tensor) -> torch.Tensor: + return self.ff_mlp(x) diff --git a/backups/thirdparty/SMART/smart/layers/fourier_embedding.py b/backups/thirdparty/SMART/smart/layers/fourier_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..a89db66edbcfbf83da41cd1cc1afd1d2050a452d --- /dev/null +++ b/backups/thirdparty/SMART/smart/layers/fourier_embedding.py @@ -0,0 +1,85 @@ +import math +from typing import List, Optional +import torch +import torch.nn as nn + +from smart.utils import weight_init + + +class FourierEmbedding(nn.Module): + + def __init__(self, + input_dim: int, + hidden_dim: int, + num_freq_bands: int) -> None: + super(FourierEmbedding, self).__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.freqs = nn.Embedding(input_dim, num_freq_bands) if input_dim != 0 else None + self.mlps = nn.ModuleList( + [nn.Sequential( + nn.Linear(num_freq_bands * 2 + 1, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(hidden_dim, hidden_dim), + ) + for _ in range(input_dim)]) + self.to_out = nn.Sequential( + nn.LayerNorm(hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(hidden_dim, hidden_dim), + ) + self.apply(weight_init) + + def forward(self, + continuous_inputs: Optional[torch.Tensor] = None, + categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor: + if continuous_inputs is None: + if categorical_embs is not None: + x = torch.stack(categorical_embs).sum(dim=0) + else: + raise ValueError('Both continuous_inputs and categorical_embs are None') + else: + x = continuous_inputs.unsqueeze(-1) * self.freqs.weight * 2 * math.pi + # Warning: if your data are noisy, don't use learnable sinusoidal embedding + x = torch.cat([x.cos(), x.sin(), continuous_inputs.unsqueeze(-1)], dim=-1) + continuous_embs: List[Optional[torch.Tensor]] = [None] * self.input_dim + for i in range(self.input_dim): + continuous_embs[i] = self.mlps[i](x[:, i]) + x = torch.stack(continuous_embs).sum(dim=0) + if categorical_embs is not None: + x = x + torch.stack(categorical_embs).sum(dim=0) + return self.to_out(x) + + +class MLPEmbedding(nn.Module): + def __init__(self, + input_dim: int, + hidden_dim: int) -> None: + super(MLPEmbedding, self).__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.mlp = nn.Sequential( + nn.Linear(input_dim, 128), + nn.LayerNorm(128), + nn.ReLU(inplace=True), + nn.Linear(128, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(hidden_dim, hidden_dim)) + self.apply(weight_init) + + def forward(self, + continuous_inputs: Optional[torch.Tensor] = None, + categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor: + if continuous_inputs is None: + if categorical_embs is not None: + x = torch.stack(categorical_embs).sum(dim=0) + else: + raise ValueError('Both continuous_inputs and categorical_embs are None') + else: + x = self.mlp(continuous_inputs) + if categorical_embs is not None: + x = x + torch.stack(categorical_embs).sum(dim=0) + return x diff --git a/backups/thirdparty/SMART/smart/layers/mlp_layer.py b/backups/thirdparty/SMART/smart/layers/mlp_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..453b42a90e4431271a9ad5dc6f97d7718724103e --- /dev/null +++ b/backups/thirdparty/SMART/smart/layers/mlp_layer.py @@ -0,0 +1,24 
@@ + +import torch +import torch.nn as nn + +from smart.utils import weight_init + + +class MLPLayer(nn.Module): + + def __init__(self, + input_dim: int, + hidden_dim: int, + output_dim: int) -> None: + super(MLPLayer, self).__init__() + self.mlp = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(hidden_dim, output_dim), + ) + self.apply(weight_init) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.mlp(x) diff --git a/backups/thirdparty/SMART/smart/metrics/__init__.py b/backups/thirdparty/SMART/smart/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e99201f9739de81e9809e9baa2bc8a7b18c993ee --- /dev/null +++ b/backups/thirdparty/SMART/smart/metrics/__init__.py @@ -0,0 +1,5 @@ + +from smart.metrics.average_meter import AverageMeter +from smart.metrics.min_ade import minADE +from smart.metrics.min_fde import minFDE +from smart.metrics.next_token_cls import TokenCls diff --git a/backups/thirdparty/SMART/smart/metrics/average_meter.py b/backups/thirdparty/SMART/smart/metrics/average_meter.py new file mode 100644 index 0000000000000000000000000000000000000000..7487fff3d9da76fd135af4efc3580190e1008597 --- /dev/null +++ b/backups/thirdparty/SMART/smart/metrics/average_meter.py @@ -0,0 +1,18 @@ + +import torch +from torchmetrics import Metric + + +class AverageMeter(Metric): + + def __init__(self, **kwargs) -> None: + super(AverageMeter, self).__init__(**kwargs) + self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') + + def update(self, val: torch.Tensor) -> None: + self.sum += val.sum() + self.count += val.numel() + + def compute(self) -> torch.Tensor: + return self.sum / self.count diff --git a/backups/thirdparty/SMART/smart/metrics/min_ade.py b/backups/thirdparty/SMART/smart/metrics/min_ade.py new file mode 100644 index 0000000000000000000000000000000000000000..13b10090999300ef61e02f025f53f3e758148308 --- /dev/null +++ b/backups/thirdparty/SMART/smart/metrics/min_ade.py @@ -0,0 +1,85 @@ + +from typing import Optional + +import torch +from torchmetrics import Metric + +from smart.metrics.utils import topk +from smart.metrics.utils import valid_filter + + +class minMultiADE(Metric): + + def __init__(self, + max_guesses: int = 6, + **kwargs) -> None: + super(minMultiADE, self).__init__(**kwargs) + self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') + self.max_guesses = max_guesses + + def update(self, + pred: torch.Tensor, + target: torch.Tensor, + prob: Optional[torch.Tensor] = None, + valid_mask: Optional[torch.Tensor] = None, + keep_invalid_final_step: bool = True, + min_criterion: str = 'FDE') -> None: + pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step) + pred_topk, _ = topk(self.max_guesses, pred, prob) + if min_criterion == 'FDE': + inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1) + inds_best = torch.norm( + pred_topk[torch.arange(pred.size(0)), :, inds_last] - + target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2), p=2, dim=-1).argmin(dim=-1) + self.sum += ((torch.norm(pred_topk[torch.arange(pred.size(0)), inds_best] - target, p=2, dim=-1) * + valid_mask).sum(dim=-1) / valid_mask.sum(dim=-1)).sum() + elif min_criterion == 'ADE': + self.sum += ((torch.norm(pred_topk - 
target.unsqueeze(1), p=2, dim=-1) * + valid_mask.unsqueeze(1)).sum(dim=-1).min(dim=-1)[0] / valid_mask.sum(dim=-1)).sum() + else: + raise ValueError('{} is not a valid criterion'.format(min_criterion)) + self.count += pred.size(0) + + def compute(self) -> torch.Tensor: + return self.sum / self.count + + +class minADE(Metric): + + def __init__(self, + max_guesses: int = 6, + **kwargs) -> None: + super(minADE, self).__init__(**kwargs) + self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') + self.max_guesses = max_guesses + self.eval_timestep = 70 + + def update(self, + pred: torch.Tensor, + target: torch.Tensor, + prob: Optional[torch.Tensor] = None, + valid_mask: Optional[torch.Tensor] = None, + keep_invalid_final_step: bool = True, + min_criterion: str = 'ADE') -> None: + # pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step) + # pred_topk, _ = topk(self.max_guesses, pred, prob) + # if min_criterion == 'FDE': + # inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1) + # inds_best = torch.norm( + # pred[torch.arange(pred.size(0)), :, inds_last] - + # target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2), p=2, dim=-1).argmin(dim=-1) + # self.sum += ((torch.norm(pred[torch.arange(pred.size(0)), inds_best] - target, p=2, dim=-1) * + # valid_mask).sum(dim=-1) / valid_mask.sum(dim=-1)).sum() + # elif min_criterion == 'ADE': + # self.sum += ((torch.norm(pred - target.unsqueeze(1), p=2, dim=-1) * + # valid_mask.unsqueeze(1)).sum(dim=-1).min(dim=-1)[0] / valid_mask.sum(dim=-1)).sum() + # else: + # raise ValueError('{} is not a valid criterion'.format(min_criterion)) + eval_timestep = min(self.eval_timestep, pred.shape[1]) + self.sum += ((torch.norm(pred[:, :eval_timestep] - target[:, :eval_timestep], p=2, dim=-1) * valid_mask[:, :eval_timestep]).sum(dim=-1) / pred.shape[1]).sum() + self.count += valid_mask[:, :eval_timestep].any(dim=-1).sum() + + def compute(self) -> torch.Tensor: + return self.sum / self.count diff --git a/backups/thirdparty/SMART/smart/metrics/min_fde.py b/backups/thirdparty/SMART/smart/metrics/min_fde.py new file mode 100644 index 0000000000000000000000000000000000000000..c60d9d7f5b2dc507d5a82dc51d575db8b076cd40 --- /dev/null +++ b/backups/thirdparty/SMART/smart/metrics/min_fde.py @@ -0,0 +1,61 @@ +from typing import Optional + +import torch +from torchmetrics import Metric + +from smart.metrics.utils import topk +from smart.metrics.utils import valid_filter + + +class minMultiFDE(Metric): + + def __init__(self, + max_guesses: int = 6, + **kwargs) -> None: + super(minMultiFDE, self).__init__(**kwargs) + self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') + self.max_guesses = max_guesses + + def update(self, + pred: torch.Tensor, + target: torch.Tensor, + prob: Optional[torch.Tensor] = None, + valid_mask: Optional[torch.Tensor] = None, + keep_invalid_final_step: bool = True) -> None: + pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step) + pred_topk, _ = topk(self.max_guesses, pred, prob) + inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1) + self.sum += torch.norm(pred_topk[torch.arange(pred.size(0)), :, inds_last] - + target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2), 
+ p=2, dim=-1).min(dim=-1)[0].sum() + self.count += pred.size(0) + + def compute(self) -> torch.Tensor: + return self.sum / self.count + + +class minFDE(Metric): + + def __init__(self, + max_guesses: int = 6, + **kwargs) -> None: + super(minFDE, self).__init__(**kwargs) + self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') + self.max_guesses = max_guesses + self.eval_timestep = 70 + + def update(self, + pred: torch.Tensor, + target: torch.Tensor, + prob: Optional[torch.Tensor] = None, + valid_mask: Optional[torch.Tensor] = None, + keep_invalid_final_step: bool = True) -> None: + eval_timestep = min(self.eval_timestep, pred.shape[1]) - 1 + self.sum += ((torch.norm(pred[:, eval_timestep-1:eval_timestep] - target[:, eval_timestep-1:eval_timestep], p=2, dim=-1) * + valid_mask[:, eval_timestep-1].unsqueeze(1)).sum(dim=-1)).sum() + self.count += valid_mask[:, eval_timestep-1].sum() + + def compute(self) -> torch.Tensor: + return self.sum / self.count diff --git a/backups/thirdparty/SMART/smart/metrics/next_token_cls.py b/backups/thirdparty/SMART/smart/metrics/next_token_cls.py new file mode 100644 index 0000000000000000000000000000000000000000..72c1a3a22897bca15eb6647ad91b30106de24070 --- /dev/null +++ b/backups/thirdparty/SMART/smart/metrics/next_token_cls.py @@ -0,0 +1,30 @@ +from typing import Optional + +import torch +from torchmetrics import Metric + +from smart.metrics.utils import topk +from smart.metrics.utils import valid_filter + + +class TokenCls(Metric): + + def __init__(self, + max_guesses: int = 6, + **kwargs) -> None: + super(TokenCls, self).__init__(**kwargs) + self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') + self.max_guesses = max_guesses + + def update(self, + pred: torch.Tensor, + target: torch.Tensor, + valid_mask: Optional[torch.Tensor] = None) -> None: + target = target[..., None] + acc = (pred[:, :self.max_guesses] == target).any(dim=1) * valid_mask + self.sum += acc.sum() + self.count += valid_mask.sum() + + def compute(self) -> torch.Tensor: + return self.sum / self.count diff --git a/backups/thirdparty/SMART/smart/metrics/utils.py b/backups/thirdparty/SMART/smart/metrics/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8c610fe4ae7cccffcdba79c03ed9056e722cb41c --- /dev/null +++ b/backups/thirdparty/SMART/smart/metrics/utils.py @@ -0,0 +1,278 @@ +from typing import Optional, Tuple + +import torch +from torch_scatter import gather_csr +from torch_scatter import segment_csr + + +def topk( + max_guesses: int, + pred: torch.Tensor, + prob: Optional[torch.Tensor] = None, + ptr: Optional[torch.Tensor] = None, + joint: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + max_guesses = min(max_guesses, pred.size(1)) + if max_guesses == pred.size(1): + if prob is not None: + prob = prob / prob.sum(dim=-1, keepdim=True) + else: + prob = pred.new_ones((pred.size(0), max_guesses)) / max_guesses + return pred, prob + else: + if prob is not None: + if joint: + if ptr is None: + inds_topk = torch.topk((prob / prob.sum(dim=-1, keepdim=True)).mean(dim=0, keepdim=True), + k=max_guesses, dim=-1, largest=True, sorted=True)[1] + inds_topk = inds_topk.repeat(pred.size(0), 1) + else: + inds_topk = torch.topk(segment_csr(src=prob / prob.sum(dim=-1, keepdim=True), indptr=ptr, + reduce='mean'), + k=max_guesses, dim=-1, largest=True, sorted=True)[1] + inds_topk = 
gather_csr(src=inds_topk, indptr=ptr) + else: + inds_topk = torch.topk(prob, k=max_guesses, dim=-1, largest=True, sorted=True)[1] + pred_topk = pred[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk] + prob_topk = prob[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk] + prob_topk = prob_topk / prob_topk.sum(dim=-1, keepdim=True) + else: + pred_topk = pred[:, :max_guesses] + prob_topk = pred.new_ones((pred.size(0), max_guesses)) / max_guesses + return pred_topk, prob_topk + + +def topkind( + max_guesses: int, + pred: torch.Tensor, + prob: Optional[torch.Tensor] = None, + ptr: Optional[torch.Tensor] = None, + joint: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + max_guesses = min(max_guesses, pred.size(1)) + if max_guesses == pred.size(1): + if prob is not None: + prob = prob / prob.sum(dim=-1, keepdim=True) + else: + prob = pred.new_ones((pred.size(0), max_guesses)) / max_guesses + return pred, prob, None + else: + if prob is not None: + if joint: + if ptr is None: + inds_topk = torch.topk((prob / prob.sum(dim=-1, keepdim=True)).mean(dim=0, keepdim=True), + k=max_guesses, dim=-1, largest=True, sorted=True)[1] + inds_topk = inds_topk.repeat(pred.size(0), 1) + else: + inds_topk = torch.topk(segment_csr(src=prob / prob.sum(dim=-1, keepdim=True), indptr=ptr, + reduce='mean'), + k=max_guesses, dim=-1, largest=True, sorted=True)[1] + inds_topk = gather_csr(src=inds_topk, indptr=ptr) + else: + inds_topk = torch.topk(prob, k=max_guesses, dim=-1, largest=True, sorted=True)[1] + pred_topk = pred[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk] + prob_topk = prob[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk] + prob_topk = prob_topk / prob_topk.sum(dim=-1, keepdim=True) + else: + pred_topk = pred[:, :max_guesses] + prob_topk = pred.new_ones((pred.size(0), max_guesses)) / max_guesses + return pred_topk, prob_topk, inds_topk + + +def valid_filter( + pred: torch.Tensor, + target: torch.Tensor, + prob: Optional[torch.Tensor] = None, + valid_mask: Optional[torch.Tensor] = None, + ptr: Optional[torch.Tensor] = None, + keep_invalid_final_step: bool = True) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], + torch.Tensor, torch.Tensor]: + if valid_mask is None: + valid_mask = target.new_ones(target.size()[:-1], dtype=torch.bool) + if keep_invalid_final_step: + filter_mask = valid_mask.any(dim=-1) + else: + filter_mask = valid_mask[:, -1] + pred = pred[filter_mask] + target = target[filter_mask] + if prob is not None: + prob = prob[filter_mask] + valid_mask = valid_mask[filter_mask] + if ptr is not None: + num_nodes_batch = segment_csr(src=filter_mask.long(), indptr=ptr, reduce='sum') + ptr = num_nodes_batch.new_zeros((num_nodes_batch.size(0) + 1,)) + torch.cumsum(num_nodes_batch, dim=0, out=ptr[1:]) + else: + ptr = target.new_tensor([0, target.size(0)]) + return pred, target, prob, valid_mask, ptr + + +def new_batch_nms(pred_trajs, dist_thresh, num_ret_modes=6): + """ + + Args: + pred_trajs (batch_size, num_modes, num_timestamps, 7) + pred_scores (batch_size, num_modes): + dist_thresh (float): + num_ret_modes (int, optional): Defaults to 6. 
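Note: unlike batch_nms below, this variant takes no pred_scores argument;
mode scores are derived from endpoint density, i.e. the fraction of other
modes whose goal endpoint lies within dist_thresh of this mode's goal.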
+ + Returns: + ret_trajs (batch_size, num_ret_modes, num_timestamps, 5) + ret_scores (batch_size, num_ret_modes) + ret_idxs (batch_size, num_ret_modes) + """ + batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape + pred_goals = pred_trajs[:, :, -1, :] + dist = (pred_goals[:, :, None, 0:2] - pred_goals[:, None, :, 0:2]).norm(dim=-1) + nearby_neighbor = dist < dist_thresh + pred_scores = nearby_neighbor.sum(dim=-1) / num_modes + + sorted_idxs = pred_scores.argsort(dim=-1, descending=True) + bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes) + sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs] + sorted_pred_trajs = pred_trajs[bs_idxs_full, sorted_idxs] # (batch_size, num_modes, num_timestamps, 7) + sorted_pred_goals = sorted_pred_trajs[:, :, -1, :] # (batch_size, num_modes, 7) + + dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1) + point_cover_mask = (dist < dist_thresh) + + point_val = sorted_pred_scores.clone() # (batch_size, N) + point_val_selected = torch.zeros_like(point_val) # (batch_size, N) + + ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long() + ret_trajs = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim) + ret_scores = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes) + bs_idxs = torch.arange(batch_size).type_as(ret_idxs) + + for k in range(num_ret_modes): + cur_idx = point_val.argmax(dim=-1) # (batch_size) + ret_idxs[:, k] = cur_idx + + new_cover_mask = point_cover_mask[bs_idxs, cur_idx] # (batch_size, N) + point_val = point_val * (~new_cover_mask).float() # (batch_size, N) + point_val_selected[bs_idxs, cur_idx] = -1 + point_val += point_val_selected + + ret_trajs[:, k] = sorted_pred_trajs[bs_idxs, cur_idx] + ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx] + + bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes) + + ret_idxs = sorted_idxs[bs_idxs, ret_idxs] + return ret_trajs, ret_scores, ret_idxs + + +def batch_nms(pred_trajs, pred_scores, + dist_thresh, num_ret_modes=6, + mode='static', speed=None): + """ + + Args: + pred_trajs (batch_size, num_modes, num_timestamps, 7) + pred_scores (batch_size, num_modes): + dist_thresh (float): + num_ret_modes (int, optional): Defaults to 6. 
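mode (str, optional): 'static' uses the scalar dist_thresh; 'speed' applies
    separate longitudinal (4 m) and lateral (0.5 m) endpoint thresholds.
    Defaults to 'static'.
speed (optional): accepted for interface compatibility but not read by the
    current implementation.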
+ + Returns: + ret_trajs (batch_size, num_ret_modes, num_timestamps, 5) + ret_scores (batch_size, num_ret_modes) + ret_idxs (batch_size, num_ret_modes) + """ + batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape + + sorted_idxs = pred_scores.argsort(dim=-1, descending=True) + bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes) + sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs] + sorted_pred_trajs = pred_trajs[bs_idxs_full, sorted_idxs] # (batch_size, num_modes, num_timestamps, 7) + sorted_pred_goals = sorted_pred_trajs[:, :, -1, :] # (batch_size, num_modes, 7) + + if mode == "speed": + scale = torch.ones(batch_size).to(sorted_pred_goals.device) + lon_dist_thresh = 4 * scale + lat_dist_thresh = 0.5 * scale + lon_dist = (sorted_pred_goals[:, :, None, [0]] - sorted_pred_goals[:, None, :, [0]]).norm(dim=-1) + lat_dist = (sorted_pred_goals[:, :, None, [1]] - sorted_pred_goals[:, None, :, [1]]).norm(dim=-1) + point_cover_mask = (lon_dist < lon_dist_thresh[:, None, None]) & (lat_dist < lat_dist_thresh[:, None, None]) + else: + dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1) + point_cover_mask = (dist < dist_thresh) + + point_val = sorted_pred_scores.clone() # (batch_size, N) + point_val_selected = torch.zeros_like(point_val) # (batch_size, N) + + ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long() + ret_trajs = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim) + ret_scores = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes) + bs_idxs = torch.arange(batch_size).type_as(ret_idxs) + + for k in range(num_ret_modes): + cur_idx = point_val.argmax(dim=-1) # (batch_size) + ret_idxs[:, k] = cur_idx + + new_cover_mask = point_cover_mask[bs_idxs, cur_idx] # (batch_size, N) + point_val = point_val * (~new_cover_mask).float() # (batch_size, N) + point_val_selected[bs_idxs, cur_idx] = -1 + point_val += point_val_selected + + ret_trajs[:, k] = sorted_pred_trajs[bs_idxs, cur_idx] + ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx] + + bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes) + + ret_idxs = sorted_idxs[bs_idxs, ret_idxs] + return ret_trajs, ret_scores, ret_idxs + + +def batch_nms_token(pred_trajs, pred_scores, + dist_thresh, num_ret_modes=6, + mode='static', speed=None): + """ + Args: + pred_trajs (batch_size, num_modes, num_timestamps, 7) + pred_scores (batch_size, num_modes): + dist_thresh (float): + num_ret_modes (int, optional): Defaults to 6. 
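Note: despite the shape comment above, pred_trajs here is unpacked as
    (batch_size, num_modes, num_feat_dim) goal tokens; there is no time
    dimension in this variant.
mode (str, optional): 'static' uses dist_thresh; 'nearby' instead covers each
    goal's five nearest goals (itself included). Defaults to 'static'.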
+ + Returns: + ret_trajs (batch_size, num_ret_modes, num_timestamps, 5) + ret_scores (batch_size, num_ret_modes) + ret_idxs (batch_size, num_ret_modes) + """ + batch_size, num_modes, num_feat_dim = pred_trajs.shape + + sorted_idxs = pred_scores.argsort(dim=-1, descending=True) + bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes) + sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs] + sorted_pred_goals = pred_trajs[bs_idxs_full, sorted_idxs] # (batch_size, num_modes, num_timestamps, 7) + + if mode == "nearby": + dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1) + values, indices = torch.topk(dist, 5, dim=-1, largest=False) + thresh_hold = values[..., -1] + point_cover_mask = dist < thresh_hold[..., None] + else: + dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1) + point_cover_mask = (dist < dist_thresh) + + point_val = sorted_pred_scores.clone() # (batch_size, N) + point_val_selected = torch.zeros_like(point_val) # (batch_size, N) + + ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long() + ret_goals = sorted_pred_goals.new_zeros(batch_size, num_ret_modes, num_feat_dim) + ret_scores = sorted_pred_goals.new_zeros(batch_size, num_ret_modes) + bs_idxs = torch.arange(batch_size).type_as(ret_idxs) + + for k in range(num_ret_modes): + cur_idx = point_val.argmax(dim=-1) # (batch_size) + ret_idxs[:, k] = cur_idx + + new_cover_mask = point_cover_mask[bs_idxs, cur_idx] # (batch_size, N) + point_val = point_val * (~new_cover_mask).float() # (batch_size, N) + point_val_selected[bs_idxs, cur_idx] = -1 + point_val += point_val_selected + + ret_goals[:, k] = sorted_pred_goals[bs_idxs, cur_idx] + ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx] + + bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes) + + ret_idxs = sorted_idxs[bs_idxs, ret_idxs] + return ret_goals, ret_scores, ret_idxs diff --git a/backups/thirdparty/SMART/smart/model/__init__.py b/backups/thirdparty/SMART/smart/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7e27dd99277f1ee19826a8acc4a2db3730d6fa1f --- /dev/null +++ b/backups/thirdparty/SMART/smart/model/__init__.py @@ -0,0 +1 @@ +from smart.model.smart import SMART diff --git a/backups/thirdparty/SMART/smart/model/smart.py b/backups/thirdparty/SMART/smart/model/smart.py new file mode 100644 index 0000000000000000000000000000000000000000..515b5564e2e362d49406a1b67f27fa7996300a92 --- /dev/null +++ b/backups/thirdparty/SMART/smart/model/smart.py @@ -0,0 +1,341 @@ +import contextlib +import pytorch_lightning as pl +import torch +import torch.nn as nn +from torch_geometric.data import Batch +from torch_geometric.data import HeteroData +from smart.metrics import minADE +from smart.metrics import minFDE +from smart.metrics import TokenCls +from smart.modules import SMARTDecoder +from torch.optim.lr_scheduler import LambdaLR +import math +import numpy as np +import pickle +from collections import defaultdict +import os +from waymo_open_dataset.protos import sim_agents_submission_pb2 + + +def cal_polygon_contour(x, y, theta, width, length): + left_front_x = x + 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta) + left_front_y = y + 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta) + left_front = (left_front_x, left_front_y) + + right_front_x = x + 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta) + right_front_y = y + 0.5 * length 
* math.sin(theta) - 0.5 * width * math.cos(theta) + right_front = (right_front_x, right_front_y) + + right_back_x = x - 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta) + right_back_y = y - 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta) + right_back = (right_back_x, right_back_y) + + left_back_x = x - 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta) + left_back_y = y - 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta) + left_back = (left_back_x, left_back_y) + polygon_contour = [left_front, right_front, right_back, left_back] + + return polygon_contour + + +def joint_scene_from_states(states, object_ids) -> sim_agents_submission_pb2.JointScene: + states = states.numpy() + simulated_trajectories = [] + for i_object in range(len(object_ids)): + simulated_trajectories.append(sim_agents_submission_pb2.SimulatedTrajectory( + center_x=states[i_object, :, 0], center_y=states[i_object, :, 1], + center_z=states[i_object, :, 2], heading=states[i_object, :, 3], + object_id=object_ids[i_object].item() + )) + return sim_agents_submission_pb2.JointScene(simulated_trajectories=simulated_trajectories) + + +class SMART(pl.LightningModule): + + def __init__(self, model_config) -> None: + super(SMART, self).__init__() + self.save_hyperparameters() + self.model_config = model_config + self.warmup_steps = model_config.warmup_steps + self.lr = model_config.lr + self.total_steps = model_config.total_steps + self.dataset = model_config.dataset + self.input_dim = model_config.input_dim + self.hidden_dim = model_config.hidden_dim + self.output_dim = model_config.output_dim + self.output_head = model_config.output_head + self.num_historical_steps = model_config.num_historical_steps + self.num_future_steps = model_config.decoder.num_future_steps + self.num_freq_bands = model_config.num_freq_bands + self.vis_map = False + self.noise = True + module_dir = os.path.dirname(os.path.dirname(__file__)) + self.map_token_traj_path = os.path.join(module_dir, 'tokens/map_traj_token5.pkl') + self.init_map_token() + self.token_path = os.path.join(module_dir, 'tokens/cluster_frame_5_2048.pkl') + token_data = self.get_trajectory_token() + self.encoder = SMARTDecoder( + dataset=model_config.dataset, + input_dim=model_config.input_dim, + hidden_dim=model_config.hidden_dim, + num_historical_steps=model_config.num_historical_steps, + num_freq_bands=model_config.num_freq_bands, + num_heads=model_config.num_heads, + head_dim=model_config.head_dim, + dropout=model_config.dropout, + num_map_layers=model_config.decoder.num_map_layers, + num_agent_layers=model_config.decoder.num_agent_layers, + pl2pl_radius=model_config.decoder.pl2pl_radius, + pl2a_radius=model_config.decoder.pl2a_radius, + a2a_radius=model_config.decoder.a2a_radius, + time_span=model_config.decoder.time_span, + map_token={'traj_src': self.map_token['traj_src']}, + token_data=token_data, + token_size=model_config.decoder.token_size + ) + self.minADE = minADE(max_guesses=1) + self.minFDE = minFDE(max_guesses=1) + self.TokenCls = TokenCls(max_guesses=1) + + self.test_predictions = dict() + self.cls_loss = nn.CrossEntropyLoss(label_smoothing=0.1) + self.map_cls_loss = nn.CrossEntropyLoss(label_smoothing=0.1) + self.inference_token = False + self.rollout_num = 1 + + def get_trajectory_token(self): + token_data = pickle.load(open(self.token_path, 'rb')) + self.trajectory_token = token_data['token'] + self.trajectory_token_traj = token_data['traj'] + self.trajectory_token_all = token_data['token_all'] + return token_data + + 
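    # Annotation (not part of the original file): get_trajectory_token() above
    # assumes the pickle stores per-class motion-token vocabularies keyed by
    # 'veh' / 'ped' / 'cyc'. The shapes below are inferred from how the tensors
    # are consumed in the agent decoder, not from documentation -- a hedged
    # sketch for inspecting the file:
    #
    #   token_data = pickle.load(open('tokens/cluster_frame_5_2048.pkl', 'rb'))
    #   for key in ('token', 'traj', 'token_all'):
    #       # token[cls]: (token_size, 4, 2) end-pose box contours;
    #       # token_all[cls]: (token_size, shift + 1, 4, 2) per-step contours.
    #       print(key, {cls: v.shape for cls, v in token_data[key].items()})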
def init_map_token(self): + self.argmin_sample_len = 3 + map_token_traj = pickle.load(open(self.map_token_traj_path, 'rb')) + self.map_token = {'traj_src': map_token_traj['traj_src'], } + traj_end_theta = np.arctan2(self.map_token['traj_src'][:, -1, 1]-self.map_token['traj_src'][:, -2, 1], + self.map_token['traj_src'][:, -1, 0]-self.map_token['traj_src'][:, -2, 0]) + indices = torch.linspace(0, self.map_token['traj_src'].shape[1]-1, steps=self.argmin_sample_len).long() + self.map_token['sample_pt'] = torch.from_numpy(self.map_token['traj_src'][:, indices]).to(torch.float) + self.map_token['traj_end_theta'] = torch.from_numpy(traj_end_theta).to(torch.float) + self.map_token['traj_src'] = torch.from_numpy(self.map_token['traj_src']).to(torch.float) + + def forward(self, data: HeteroData): + res = self.encoder(data) + return res + + def inference(self, data: HeteroData): + res = self.encoder.inference(data) + return res + + def maybe_autocast(self, dtype=torch.float16): + enable_autocast = self.device != torch.device("cpu") + + if enable_autocast: + return torch.cuda.amp.autocast(dtype=dtype) + else: + return contextlib.nullcontext() + + def training_step(self, + data, + batch_idx): + data = self.match_token_map(data) + data = self.sample_pt_pred(data) + if isinstance(data, Batch): + data['agent']['av_index'] += data['agent']['ptr'][:-1] + pred = self(data) + next_token_prob = pred['next_token_prob'] + next_token_idx_gt = pred['next_token_idx_gt'] + next_token_eval_mask = pred['next_token_eval_mask'] + cls_loss = self.cls_loss(next_token_prob[next_token_eval_mask], next_token_idx_gt[next_token_eval_mask]) + loss = cls_loss + self.log('train_loss', loss, prog_bar=True, on_step=True, on_epoch=True, batch_size=1) + self.log('cls_loss', cls_loss, prog_bar=True, on_step=True, on_epoch=True, batch_size=1) + return loss + + def validation_step(self, + data, + batch_idx): + data = self.match_token_map(data) + data = self.sample_pt_pred(data) + if isinstance(data, Batch): + data['agent']['av_index'] += data['agent']['ptr'][:-1] + pred = self(data) + next_token_idx = pred['next_token_idx'] + next_token_idx_gt = pred['next_token_idx_gt'] + next_token_eval_mask = pred['next_token_eval_mask'] + next_token_prob = pred['next_token_prob'] + cls_loss = self.cls_loss(next_token_prob[next_token_eval_mask], next_token_idx_gt[next_token_eval_mask]) + loss = cls_loss + self.TokenCls.update(pred=next_token_idx[next_token_eval_mask], target=next_token_idx_gt[next_token_eval_mask], + valid_mask=next_token_eval_mask[next_token_eval_mask]) + self.log('val_cls_acc', self.TokenCls, prog_bar=True, on_step=False, on_epoch=True, batch_size=1, sync_dist=True) + self.log('val_loss', loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=1, sync_dist=True) + + eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps-1] # * (data['agent']['category'] == 3) + if self.inference_token: + pred = self.inference(data) + pos_a = pred['pos_a'] + gt = pred['gt'] + valid_mask = data['agent']['valid_mask'][:, self.num_historical_steps:] + pred_traj = pred['pred_traj'] + # next_token_idx = pred['next_token_idx'][..., None] + # next_token_idx_gt = pred['next_token_idx_gt'][:, 2:] + # next_token_eval_mask = pred['next_token_eval_mask'][:, 2:] + # next_token_eval_mask[:, 1:] = False + # self.TokenCls.update(pred=next_token_idx[next_token_eval_mask], target=next_token_idx_gt[next_token_eval_mask], + # valid_mask=next_token_eval_mask[next_token_eval_mask]) + # self.log('val_inference_cls_acc', self.TokenCls, prog_bar=True, 
on_step=False, on_epoch=True, batch_size=1, sync_dist=True) + eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps-1] + + self.minADE.update(pred=pred_traj[eval_mask], target=gt[eval_mask], valid_mask=valid_mask[eval_mask]) + self.minFDE.update(pred=pred_traj[eval_mask], target=gt[eval_mask], valid_mask=valid_mask[eval_mask]) + # print('ade: ', self.minADE.compute(), 'fde: ', self.minFDE.compute()) + + self.log('val_minADE', self.minADE, prog_bar=True, on_step=False, on_epoch=True, batch_size=1) + self.log('val_minFDE', self.minFDE, prog_bar=True, on_step=False, on_epoch=True, batch_size=1) + + def on_validation_start(self): + self.gt = [] + self.pred = [] + self.scenario_rollouts = [] + self.batch_metric = defaultdict(list) + + def configure_optimizers(self): + optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) + + def lr_lambda(current_step): + if current_step + 1 < self.warmup_steps: + return float(current_step + 1) / float(max(1, self.warmup_steps)) + return max( + 0.0, 0.5 * (1.0 + math.cos(math.pi * (current_step - self.warmup_steps) / float(max(1, self.total_steps - self.warmup_steps)))) + ) + + lr_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) + return [optimizer], [lr_scheduler] + + def load_params_from_file(self, filename, logger, to_cpu=False): + if not os.path.isfile(filename): + raise FileNotFoundError + + logger.info('==> Loading parameters from checkpoint %s to %s' % (filename, 'CPU' if to_cpu else 'GPU')) + loc_type = torch.device('cpu') if to_cpu else None + checkpoint = torch.load(filename, map_location=loc_type) + model_state_disk = checkpoint['state_dict'] + + version = checkpoint.get("version", None) + if version is not None: + logger.info('==> Checkpoint trained from version: %s' % version) + + logger.info(f'The number of disk ckpt keys: {len(model_state_disk)}') + model_state = self.state_dict() + model_state_disk_filter = {} + for key, val in model_state_disk.items(): + if key in model_state and model_state_disk[key].shape == model_state[key].shape: + model_state_disk_filter[key] = val + else: + if key not in model_state: + print(f'Ignore key in disk (not found in model): {key}, shape={val.shape}') + else: + print(f'Ignore key in disk (shape does not match): {key}, load_shape={val.shape}, model_shape={model_state[key].shape}') + + model_state_disk = model_state_disk_filter + + missing_keys, unexpected_keys = self.load_state_dict(model_state_disk, strict=False) + + logger.info(f'Missing keys: {missing_keys}') + logger.info(f'The number of missing keys: {len(missing_keys)}') + logger.info(f'The number of unexpected keys: {len(unexpected_keys)}') + logger.info('==> Done (total keys %d)' % (len(model_state))) + + epoch = checkpoint.get('epoch', -1) + it = checkpoint.get('it', 0.0) + + return it, epoch + + def match_token_map(self, data): + traj_pos = data['map_save']['traj_pos'].to(torch.float) + traj_theta = data['map_save']['traj_theta'].to(torch.float) + pl_idx_list = data['map_save']['pl_idx_list'] + token_sample_pt = self.map_token['sample_pt'].to(traj_pos.device) + token_src = self.map_token['traj_src'].to(traj_pos.device) + max_traj_len = self.map_token['traj_src'].shape[1] + pl_num = traj_pos.shape[0] + + pt_token_pos = traj_pos[:, 0, :].clone() + pt_token_orientation = traj_theta.clone() + cos, sin = traj_theta.cos(), traj_theta.sin() + rot_mat = traj_theta.new_zeros(pl_num, 2, 2) + rot_mat[..., 0, 0] = cos + rot_mat[..., 0, 1] = -sin + rot_mat[..., 1, 0] = sin + rot_mat[..., 1, 1] = cos + traj_pos_local = torch.bmm((traj_pos - 
traj_pos[:, 0:1]), rot_mat.view(-1, 2, 2)) + distance = torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1))**2, dim=(-2, -1)) + pt_token_id = torch.argmin(distance, dim=1) + + if self.noise: + topk_indices = torch.argsort(torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1))**2, dim=(-2, -1)), dim=1)[:, :8] + sample_topk = torch.randint(0, topk_indices.shape[-1], size=(topk_indices.shape[0], 1), device=topk_indices.device) + pt_token_id = torch.gather(topk_indices, 1, sample_topk).squeeze(-1) + + cos, sin = traj_theta.cos(), traj_theta.sin() + rot_mat = traj_theta.new_zeros(pl_num, 2, 2) + rot_mat[..., 0, 0] = cos + rot_mat[..., 0, 1] = sin + rot_mat[..., 1, 0] = -sin + rot_mat[..., 1, 1] = cos + token_src_world = torch.bmm(token_src[None, ...].repeat(pl_num, 1, 1, 1).reshape(pl_num, -1, 2), + rot_mat.view(-1, 2, 2)).reshape(pl_num, token_src.shape[0], max_traj_len, 2) + traj_pos[:, None, [0], :] + token_src_world_select = token_src_world.view(-1, 1024, 11, 2)[torch.arange(pt_token_id.view(-1).shape[0]), pt_token_id.view(-1)].view(pl_num, max_traj_len, 2) + + pl_idx_full = pl_idx_list.clone() + token2pl = torch.stack([torch.arange(len(pl_idx_list), device=traj_pos.device), pl_idx_full.long()]) + count_nums = [] + for pl in pl_idx_full.unique(): + pt = token2pl[0, token2pl[1, :] == pl] + left_side = (data['pt_token']['side'][pt] == 0).sum() + right_side = (data['pt_token']['side'][pt] == 1).sum() + center_side = (data['pt_token']['side'][pt] == 2).sum() + count_nums.append(torch.Tensor([left_side, right_side, center_side])) + count_nums = torch.stack(count_nums, dim=0) + num_polyline = int(count_nums.max().item()) + traj_mask = torch.zeros((int(len(pl_idx_full.unique())), 3, num_polyline), dtype=bool) + idx_matrix = torch.arange(traj_mask.size(2)).unsqueeze(0).unsqueeze(0) + idx_matrix = idx_matrix.expand(traj_mask.size(0), traj_mask.size(1), -1) # + counts_num_expanded = count_nums.unsqueeze(-1) + mask_update = idx_matrix < counts_num_expanded + traj_mask[mask_update] = True + + data['pt_token']['traj_mask'] = traj_mask + data['pt_token']['position'] = torch.cat([pt_token_pos, torch.zeros((data['pt_token']['num_nodes'], 1), + device=traj_pos.device, dtype=torch.float)], dim=-1) + data['pt_token']['orientation'] = pt_token_orientation + data['pt_token']['height'] = data['pt_token']['position'][:, -1] + data[('pt_token', 'to', 'map_polygon')] = {} + data[('pt_token', 'to', 'map_polygon')]['edge_index'] = token2pl + data['pt_token']['token_idx'] = pt_token_id + return data + + def sample_pt_pred(self, data): + traj_mask = data['pt_token']['traj_mask'] + raw_pt_index = torch.arange(1, traj_mask.shape[2]).repeat(traj_mask.shape[0], traj_mask.shape[1], 1) + masked_pt_index = raw_pt_index.view(-1)[torch.randperm(raw_pt_index.numel())[:traj_mask.shape[0]*traj_mask.shape[1]*((traj_mask.shape[2]-1)//3)].reshape(traj_mask.shape[0], traj_mask.shape[1], (traj_mask.shape[2]-1)//3)] + masked_pt_index = torch.sort(masked_pt_index, -1)[0] + pt_valid_mask = traj_mask.clone() + pt_valid_mask.scatter_(2, masked_pt_index, False) + pt_pred_mask = traj_mask.clone() + pt_pred_mask.scatter_(2, masked_pt_index, False) + tmp_mask = pt_pred_mask.clone() + tmp_mask[:, :, :] = True + tmp_mask.scatter_(2, masked_pt_index-1, False) + pt_pred_mask.masked_fill_(tmp_mask, False) + pt_pred_mask = pt_pred_mask * torch.roll(traj_mask, shifts=-1, dims=2) + pt_target_mask = torch.roll(pt_pred_mask, shifts=1, dims=2) + + data['pt_token']['pt_valid_mask'] = pt_valid_mask[traj_mask] + 
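        # Annotation: sample_pt_pred implements masked next-token prediction
        # for map tokens. Roughly one third of the points per (polyline, side)
        # are masked at random; pt_valid_mask keeps the visible tokens,
        # pt_pred_mask marks the surviving immediate predecessors of masked
        # points (the positions asked to predict a next token), and
        # pt_target_mask (pt_pred_mask rolled one step forward) marks the
        # masked positions that serve as ground truth.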
data['pt_token']['pt_pred_mask'] = pt_pred_mask[traj_mask] + data['pt_token']['pt_target_mask'] = pt_target_mask[traj_mask] + + return data diff --git a/backups/thirdparty/SMART/smart/modules/__init__.py b/backups/thirdparty/SMART/smart/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..63ea9cc8e3a69633807b50375fc2796d4dee3e27 --- /dev/null +++ b/backups/thirdparty/SMART/smart/modules/__init__.py @@ -0,0 +1,3 @@ +from smart.modules.smart_decoder import SMARTDecoder +from smart.modules.map_decoder import SMARTMapDecoder +from smart.modules.agent_decoder import SMARTAgentDecoder diff --git a/backups/thirdparty/SMART/smart/modules/agent_decoder.py b/backups/thirdparty/SMART/smart/modules/agent_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e95dba39e3fded5ec921886a487ad065fb0ecc7b --- /dev/null +++ b/backups/thirdparty/SMART/smart/modules/agent_decoder.py @@ -0,0 +1,509 @@ +import pickle +from typing import Dict, Mapping, Optional +import torch +import torch.nn as nn +from smart.layers import MLPLayer +from smart.layers.attention_layer import AttentionLayer +from smart.layers.fourier_embedding import FourierEmbedding, MLPEmbedding +from torch_cluster import radius, radius_graph +from torch_geometric.data import Batch, HeteroData +from torch_geometric.utils import dense_to_sparse, subgraph +from smart.utils import angle_between_2d_vectors, weight_init, wrap_angle +import math + + +def cal_polygon_contour(x, y, theta, width, length): + left_front_x = x + 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta) + left_front_y = y + 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta) + left_front = (left_front_x, left_front_y) + + right_front_x = x + 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta) + right_front_y = y + 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta) + right_front = (right_front_x, right_front_y) + + right_back_x = x - 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta) + right_back_y = y - 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta) + right_back = (right_back_x, right_back_y) + + left_back_x = x - 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta) + left_back_y = y - 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta) + left_back = (left_back_x, left_back_y) + polygon_contour = [left_front, right_front, right_back, left_back] + + return polygon_contour + + +class SMARTAgentDecoder(nn.Module): + + def __init__(self, + dataset: str, + input_dim: int, + hidden_dim: int, + num_historical_steps: int, + time_span: Optional[int], + pl2a_radius: float, + a2a_radius: float, + num_freq_bands: int, + num_layers: int, + num_heads: int, + head_dim: int, + dropout: float, + token_data: Dict, + token_size=512) -> None: + super(SMARTAgentDecoder, self).__init__() + self.dataset = dataset + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.num_historical_steps = num_historical_steps + self.time_span = time_span if time_span is not None else num_historical_steps + self.pl2a_radius = pl2a_radius + self.a2a_radius = a2a_radius + self.num_freq_bands = num_freq_bands + self.num_layers = num_layers + self.num_heads = num_heads + self.head_dim = head_dim + self.dropout = dropout + + input_dim_x_a = 2 + input_dim_r_t = 4 + input_dim_r_pt2a = 3 + input_dim_r_a2a = 3 + input_dim_token = 8 + + self.type_a_emb = nn.Embedding(4, hidden_dim) + self.shape_emb = MLPLayer(3, hidden_dim, hidden_dim) + + self.x_a_emb = 
FourierEmbedding(input_dim=input_dim_x_a, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands) + self.r_t_emb = FourierEmbedding(input_dim=input_dim_r_t, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands) + self.r_pt2a_emb = FourierEmbedding(input_dim=input_dim_r_pt2a, hidden_dim=hidden_dim, + num_freq_bands=num_freq_bands) + self.r_a2a_emb = FourierEmbedding(input_dim=input_dim_r_a2a, hidden_dim=hidden_dim, + num_freq_bands=num_freq_bands) + self.token_emb_veh = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim) + self.token_emb_ped = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim) + self.token_emb_cyc = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim) + self.fusion_emb = MLPEmbedding(input_dim=self.hidden_dim * 2, hidden_dim=self.hidden_dim) + + self.t_attn_layers = nn.ModuleList( + [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout, + bipartite=False, has_pos_emb=True) for _ in range(num_layers)] + ) + self.pt2a_attn_layers = nn.ModuleList( + [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout, + bipartite=True, has_pos_emb=True) for _ in range(num_layers)] + ) + self.a2a_attn_layers = nn.ModuleList( + [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout, + bipartite=False, has_pos_emb=True) for _ in range(num_layers)] + ) + self.token_size = token_size + self.token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, + output_dim=self.token_size) + self.trajectory_token = token_data['token'] + self.trajectory_token_traj = token_data['traj'] + self.trajectory_token_all = token_data['token_all'] + self.apply(weight_init) + self.shift = 5 + self.beam_size = 5 + self.hist_mask = True + + def transform_rel(self, token_traj, prev_pos, prev_heading=None): + if prev_heading is None: + diff_xy = prev_pos[:, :, -1, :] - prev_pos[:, :, -2, :] + prev_heading = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0]) + + num_agent, num_step, traj_num, traj_dim = token_traj.shape + cos, sin = prev_heading.cos(), prev_heading.sin() + rot_mat = torch.zeros((num_agent, num_step, 2, 2), device=prev_heading.device) + rot_mat[:, :, 0, 0] = cos + rot_mat[:, :, 0, 1] = -sin + rot_mat[:, :, 1, 0] = sin + rot_mat[:, :, 1, 1] = cos + agent_diff_rel = torch.bmm(token_traj.view(-1, traj_num, 2), rot_mat.view(-1, 2, 2)).view(num_agent, num_step, traj_num, traj_dim) + agent_pred_rel = agent_diff_rel + prev_pos[:, :, -1:, :] + return agent_pred_rel + + def agent_token_embedding(self, data, agent_category, agent_token_index, pos_a, head_vector_a, inference=False): + num_agent, num_step, traj_dim = pos_a.shape + motion_vector_a = torch.cat([pos_a.new_zeros(data['agent']['num_nodes'], 1, self.input_dim), + pos_a[:, 1:] - pos_a[:, :-1]], dim=1) + + agent_type = data['agent']['type'] + veh_mask = (agent_type == 0) + cyc_mask = (agent_type == 2) + ped_mask = (agent_type == 1) + trajectory_token_veh = torch.from_numpy(self.trajectory_token['veh']).clone().to(pos_a.device).to(torch.float) + self.agent_token_emb_veh = self.token_emb_veh(trajectory_token_veh.view(trajectory_token_veh.shape[0], -1)) + trajectory_token_ped = torch.from_numpy(self.trajectory_token['ped']).clone().to(pos_a.device).to(torch.float) + self.agent_token_emb_ped = self.token_emb_ped(trajectory_token_ped.view(trajectory_token_ped.shape[0], -1)) + trajectory_token_cyc = torch.from_numpy(self.trajectory_token['cyc']).clone().to(pos_a.device).to(torch.float) + 
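        # Annotation: the (token_size, 4, 2) contour templates are flattened to
        # 8-dim vectors and embedded by class-specific MLPEmbedding heads
        # (token_emb_veh / token_emb_ped / token_emb_cyc, input_dim_token = 8
        # in __init__), so each agent class carries its own learned lookup over
        # the shared token vocabulary.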
self.agent_token_emb_cyc = self.token_emb_cyc(trajectory_token_cyc.view(trajectory_token_cyc.shape[0], -1)) + + if inference: + agent_token_traj_all = torch.zeros((num_agent, self.token_size, self.shift + 1, 4, 2), device=pos_a.device) + trajectory_token_all_veh = torch.from_numpy(self.trajectory_token_all['veh']).clone().to(pos_a.device).to( + torch.float) + trajectory_token_all_ped = torch.from_numpy(self.trajectory_token_all['ped']).clone().to(pos_a.device).to( + torch.float) + trajectory_token_all_cyc = torch.from_numpy(self.trajectory_token_all['cyc']).clone().to(pos_a.device).to( + torch.float) + agent_token_traj_all[veh_mask] = torch.cat( + [trajectory_token_all_veh[:, :self.shift], trajectory_token_veh[:, None, ...]], dim=1) + agent_token_traj_all[ped_mask] = torch.cat( + [trajectory_token_all_ped[:, :self.shift], trajectory_token_ped[:, None, ...]], dim=1) + agent_token_traj_all[cyc_mask] = torch.cat( + [trajectory_token_all_cyc[:, :self.shift], trajectory_token_cyc[:, None, ...]], dim=1) + + agent_token_emb = torch.zeros((num_agent, num_step, self.hidden_dim), device=pos_a.device) + agent_token_emb[veh_mask] = self.agent_token_emb_veh[agent_token_index[veh_mask]] + agent_token_emb[ped_mask] = self.agent_token_emb_ped[agent_token_index[ped_mask]] + agent_token_emb[cyc_mask] = self.agent_token_emb_cyc[agent_token_index[cyc_mask]] + + agent_token_traj = torch.zeros((num_agent, num_step, self.token_size, 4, 2), device=pos_a.device) + agent_token_traj[veh_mask] = trajectory_token_veh + agent_token_traj[ped_mask] = trajectory_token_ped + agent_token_traj[cyc_mask] = trajectory_token_cyc + + vel = data['agent']['token_velocity'] + + categorical_embs = [ + self.type_a_emb(data['agent']['type'].long()).repeat_interleave(repeats=num_step, + dim=0), + + self.shape_emb(data['agent']['shape'][:, self.num_historical_steps - 1, :]).repeat_interleave( + repeats=num_step, + dim=0) + ] + feature_a = torch.stack( + [torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1), + angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2]), + ], dim=-1) + + x_a = self.x_a_emb(continuous_inputs=feature_a.view(-1, feature_a.size(-1)), + categorical_embs=categorical_embs) + x_a = x_a.view(-1, num_step, self.hidden_dim) + + feat_a = torch.cat((agent_token_emb, x_a), dim=-1) + feat_a = self.fusion_emb(feat_a) + + if inference: + return feat_a, agent_token_traj, agent_token_traj_all, agent_token_emb, categorical_embs + else: + return feat_a, agent_token_traj + + def agent_predict_next(self, data, agent_category, feat_a): + num_agent, num_step, traj_dim = data['agent']['token_pos'].shape + agent_type = data['agent']['type'] + veh_mask = (agent_type == 0) # * agent_category==3 + cyc_mask = (agent_type == 2) # * agent_category==3 + ped_mask = (agent_type == 1) # * agent_category==3 + token_res = torch.zeros((num_agent, num_step, self.token_size), device=agent_category.device) + token_res[veh_mask] = self.token_predict_head(feat_a[veh_mask]) + token_res[cyc_mask] = self.token_predict_cyc_head(feat_a[cyc_mask]) + token_res[ped_mask] = self.token_predict_walker_head(feat_a[ped_mask]) + return token_res + + def agent_predict_next_inf(self, data, agent_category, feat_a): + num_agent, traj_dim = feat_a.shape + agent_type = data['agent']['type'] + + veh_mask = (agent_type == 0) # * agent_category==3 + cyc_mask = (agent_type == 2) # * agent_category==3 + ped_mask = (agent_type == 1) # * agent_category==3 + + token_res = torch.zeros((num_agent, self.token_size), device=agent_category.device) + 
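        # NOTE (annotation): agent_predict_next / agent_predict_next_inf
        # reference self.token_predict_cyc_head and
        # self.token_predict_walker_head, which are never created in __init__
        # (only token_predict_head is). forward() and inference() call
        # token_predict_head directly, so these helpers look like dead code
        # and would raise AttributeError if invoked as written.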
token_res[veh_mask] = self.token_predict_head(feat_a[veh_mask]) + token_res[cyc_mask] = self.token_predict_cyc_head(feat_a[cyc_mask]) + token_res[ped_mask] = self.token_predict_walker_head(feat_a[ped_mask]) + + return token_res + + def build_temporal_edge(self, pos_a, head_a, head_vector_a, num_agent, mask, inference_mask=None): + pos_t = pos_a.reshape(-1, self.input_dim) + head_t = head_a.reshape(-1) + head_vector_t = head_vector_a.reshape(-1, 2) + hist_mask = mask.clone() + + if self.hist_mask and self.training: + hist_mask[ + torch.arange(mask.shape[0]).unsqueeze(1), torch.randint(0, mask.shape[1], (num_agent, 10))] = False + mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1) + elif inference_mask is not None: + mask_t = hist_mask.unsqueeze(2) & inference_mask.unsqueeze(1) + else: + mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1) + + edge_index_t = dense_to_sparse(mask_t)[0] + edge_index_t = edge_index_t[:, edge_index_t[1] > edge_index_t[0]] + edge_index_t = edge_index_t[:, edge_index_t[1] - edge_index_t[0] <= self.time_span / self.shift] + rel_pos_t = pos_t[edge_index_t[0]] - pos_t[edge_index_t[1]] + rel_head_t = wrap_angle(head_t[edge_index_t[0]] - head_t[edge_index_t[1]]) + r_t = torch.stack( + [torch.norm(rel_pos_t[:, :2], p=2, dim=-1), + angle_between_2d_vectors(ctr_vector=head_vector_t[edge_index_t[1]], nbr_vector=rel_pos_t[:, :2]), + rel_head_t, + edge_index_t[0] - edge_index_t[1]], dim=-1) + r_t = self.r_t_emb(continuous_inputs=r_t, categorical_embs=None) + return edge_index_t, r_t + + def build_interaction_edge(self, pos_a, head_a, head_vector_a, batch_s, mask_s): + pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim) + head_s = head_a.transpose(0, 1).reshape(-1) + head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2) + edge_index_a2a = radius_graph(x=pos_s[:, :2], r=self.a2a_radius, batch=batch_s, loop=False, + max_num_neighbors=300) + edge_index_a2a = subgraph(subset=mask_s, edge_index=edge_index_a2a)[0] + rel_pos_a2a = pos_s[edge_index_a2a[0]] - pos_s[edge_index_a2a[1]] + rel_head_a2a = wrap_angle(head_s[edge_index_a2a[0]] - head_s[edge_index_a2a[1]]) + r_a2a = torch.stack( + [torch.norm(rel_pos_a2a[:, :2], p=2, dim=-1), + angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_a2a[1]], nbr_vector=rel_pos_a2a[:, :2]), + rel_head_a2a], dim=-1) + r_a2a = self.r_a2a_emb(continuous_inputs=r_a2a, categorical_embs=None) + return edge_index_a2a, r_a2a + + def build_map2agent_edge(self, data, num_step, agent_category, pos_a, head_a, head_vector_a, mask, + batch_s, batch_pl): + mask_pl2a = mask.clone() + mask_pl2a = mask_pl2a.transpose(0, 1).reshape(-1) + pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim) + head_s = head_a.transpose(0, 1).reshape(-1) + head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2) + pos_pl = data['pt_token']['position'][:, :self.input_dim].contiguous() + orient_pl = data['pt_token']['orientation'].contiguous() + pos_pl = pos_pl.repeat(num_step, 1) + orient_pl = orient_pl.repeat(num_step) + edge_index_pl2a = radius(x=pos_s[:, :2], y=pos_pl[:, :2], r=self.pl2a_radius, + batch_x=batch_s, batch_y=batch_pl, max_num_neighbors=300) + edge_index_pl2a = edge_index_pl2a[:, mask_pl2a[edge_index_pl2a[1]]] + rel_pos_pl2a = pos_pl[edge_index_pl2a[0]] - pos_s[edge_index_pl2a[1]] + rel_orient_pl2a = wrap_angle(orient_pl[edge_index_pl2a[0]] - head_s[edge_index_pl2a[1]]) + r_pl2a = torch.stack( + [torch.norm(rel_pos_pl2a[:, :2], p=2, dim=-1), + angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_pl2a[1]], 
nbr_vector=rel_pos_pl2a[:, :2]), + rel_orient_pl2a], dim=-1) + r_pl2a = self.r_pt2a_emb(continuous_inputs=r_pl2a, categorical_embs=None) + return edge_index_pl2a, r_pl2a + + def forward(self, + data: HeteroData, + map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + pos_a = data['agent']['token_pos'] + head_a = data['agent']['token_heading'] + head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1) + num_agent, num_step, traj_dim = pos_a.shape + agent_category = data['agent']['category'] + agent_token_index = data['agent']['token_idx'] + feat_a, agent_token_traj = self.agent_token_embedding(data, agent_category, agent_token_index, + pos_a, head_vector_a) + + agent_valid_mask = data['agent']['agent_valid_mask'].clone() + # eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps - 1] + # agent_valid_mask[~eval_mask] = False + mask = agent_valid_mask + edge_index_t, r_t = self.build_temporal_edge(pos_a, head_a, head_vector_a, num_agent, mask) + + if isinstance(data, Batch): + batch_s = torch.cat([data['agent']['batch'] + data.num_graphs * t + for t in range(num_step)], dim=0) + batch_pl = torch.cat([data['pt_token']['batch'] + data.num_graphs * t + for t in range(num_step)], dim=0) + else: + batch_s = torch.arange(num_step, + device=pos_a.device).repeat_interleave(data['agent']['num_nodes']) + batch_pl = torch.arange(num_step, + device=pos_a.device).repeat_interleave(data['pt_token']['num_nodes']) + + mask_s = mask.transpose(0, 1).reshape(-1) + edge_index_a2a, r_a2a = self.build_interaction_edge(pos_a, head_a, head_vector_a, batch_s, mask_s) + mask[agent_category != 3] = False + edge_index_pl2a, r_pl2a = self.build_map2agent_edge(data, num_step, agent_category, pos_a, head_a, + head_vector_a, mask, batch_s, batch_pl) + + for i in range(self.num_layers): + feat_a = feat_a.reshape(-1, self.hidden_dim) + feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t) + feat_a = feat_a.reshape(-1, num_step, + self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim) + feat_a = self.pt2a_attn_layers[i]((map_enc['x_pt'].repeat_interleave( + repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape( + -1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a) + feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a) + feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1) + + num_agent, num_step, hidden_dim, traj_num, traj_dim = agent_token_traj.shape + next_token_prob = self.token_predict_head(feat_a) + next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1) + _, next_token_idx = torch.topk(next_token_prob_softmax, k=10, dim=-1) + + next_token_index_gt = agent_token_index.roll(shifts=-1, dims=1) + next_token_eval_mask = mask.clone() + next_token_eval_mask = next_token_eval_mask * next_token_eval_mask.roll(shifts=-1, dims=1) * next_token_eval_mask.roll(shifts=1, dims=1) + next_token_eval_mask[:, -1] = False + + return {'x_a': feat_a, + 'next_token_idx': next_token_idx, + 'next_token_prob': next_token_prob, + 'next_token_idx_gt': next_token_index_gt, + 'next_token_eval_mask': next_token_eval_mask, + } + + def inference(self, + data: HeteroData, + map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps - 1] + pos_a = data['agent']['token_pos'].clone() + head_a = data['agent']['token_heading'].clone() + num_agent, num_step, traj_dim = pos_a.shape + pos_a[:, (self.num_historical_steps - 1) // self.shift:] = 0 + head_a[:, 
(self.num_historical_steps - 1) // self.shift:] = 0 + head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1) + + agent_valid_mask = data['agent']['agent_valid_mask'].clone() + agent_valid_mask[:, (self.num_historical_steps - 1) // self.shift:] = True + agent_valid_mask[~eval_mask] = False + agent_token_index = data['agent']['token_idx'] + agent_category = data['agent']['category'] + feat_a, agent_token_traj, agent_token_traj_all, agent_token_emb, categorical_embs = self.agent_token_embedding( + data, + agent_category, + agent_token_index, + pos_a, + head_vector_a, + inference=True) + + agent_type = data["agent"]["type"] + veh_mask = (agent_type == 0) # * agent_category==3 + cyc_mask = (agent_type == 2) # * agent_category==3 + ped_mask = (agent_type == 1) # * agent_category==3 + av_mask = data["agent"]["av_index"] + + self.num_recurrent_steps_val = data["agent"]['position'].shape[1]-self.num_historical_steps + pred_traj = torch.zeros(data["agent"].num_nodes, self.num_recurrent_steps_val, 2, device=feat_a.device) + pred_head = torch.zeros(data["agent"].num_nodes, self.num_recurrent_steps_val, device=feat_a.device) + pred_prob = torch.zeros(data["agent"].num_nodes, self.num_recurrent_steps_val // self.shift, device=feat_a.device) + next_token_idx_list = [] + mask = agent_valid_mask.clone() + feat_a_t_dict = {} + for t in range(self.num_recurrent_steps_val // self.shift): + if t == 0: + inference_mask = mask.clone() + inference_mask[:, (self.num_historical_steps - 1) // self.shift + t:] = False + else: + inference_mask = torch.zeros_like(mask) + inference_mask[:, (self.num_historical_steps - 1) // self.shift + t - 1] = True + edge_index_t, r_t = self.build_temporal_edge(pos_a, head_a, head_vector_a, num_agent, mask, inference_mask) + if isinstance(data, Batch): + batch_s = torch.cat([data['agent']['batch'] + data.num_graphs * t + for t in range(num_step)], dim=0) + batch_pl = torch.cat([data['pt_token']['batch'] + data.num_graphs * t + for t in range(num_step)], dim=0) + else: + batch_s = torch.arange(num_step, + device=pos_a.device).repeat_interleave(data['agent']['num_nodes']) + batch_pl = torch.arange(num_step, + device=pos_a.device).repeat_interleave(data['pt_token']['num_nodes']) + # In the inference stage, we only infer the current stage for recurrent + edge_index_pl2a, r_pl2a = self.build_map2agent_edge(data, num_step, agent_category, pos_a, head_a, + head_vector_a, + inference_mask, batch_s, + batch_pl) + mask_s = inference_mask.transpose(0, 1).reshape(-1) + edge_index_a2a, r_a2a = self.build_interaction_edge(pos_a, head_a, head_vector_a, + batch_s, mask_s) + + for i in range(self.num_layers): + if i in feat_a_t_dict: + feat_a = feat_a_t_dict[i] + feat_a = feat_a.reshape(-1, self.hidden_dim) + feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t) + feat_a = feat_a.reshape(-1, num_step, + self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim) + feat_a = self.pt2a_attn_layers[i]((map_enc['x_pt'].repeat_interleave( + repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape( + -1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a) + feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a) + feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1) + + if i+1 not in feat_a_t_dict: + feat_a_t_dict[i+1] = feat_a + else: + feat_a_t_dict[i+1][:, (self.num_historical_steps - 1) // self.shift - 1 + t] = feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t] + + next_token_prob = self.token_predict_head(feat_a[:, 
(self.num_historical_steps - 1) // self.shift - 1 + t]) + + next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1) + + topk_prob, next_token_idx = torch.topk(next_token_prob_softmax, k=self.beam_size, dim=-1) + + expanded_index = next_token_idx[..., None, None, None].expand(-1, -1, 6, 4, 2) + next_token_traj = torch.gather(agent_token_traj_all, 1, expanded_index) + + theta = head_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t] + cos, sin = theta.cos(), theta.sin() + rot_mat = torch.zeros((num_agent, 2, 2), device=theta.device) + rot_mat[:, 0, 0] = cos + rot_mat[:, 0, 1] = sin + rot_mat[:, 1, 0] = -sin + rot_mat[:, 1, 1] = cos + agent_diff_rel = torch.bmm(next_token_traj.view(-1, 4, 2), + rot_mat[:, None, None, ...].repeat(1, self.beam_size, self.shift + 1, 1, 1).view( + -1, 2, 2)).view(num_agent, self.beam_size, self.shift + 1, 4, 2) + agent_pred_rel = agent_diff_rel + pos_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t, :][:, None, None, None, ...] + + sample_index = torch.multinomial(topk_prob, 1).to(agent_pred_rel.device) + agent_pred_rel = agent_pred_rel.gather(dim=1, + index=sample_index[..., None, None, None].expand(-1, -1, 6, 4, + 2))[:, 0, ...] + pred_prob[:, t] = topk_prob.gather(dim=-1, index=sample_index)[:, 0] + pred_traj[:, t * 5:(t + 1) * 5] = agent_pred_rel[:, 1:, ...].clone().mean(dim=2) + diff_xy = agent_pred_rel[:, 1:, 0, :] - agent_pred_rel[:, 1:, 3, :] + pred_head[:, t * 5:(t + 1) * 5] = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0]) + + pos_a[:, (self.num_historical_steps - 1) // self.shift + t] = agent_pred_rel[:, -1, ...].clone().mean(dim=1) + diff_xy = agent_pred_rel[:, -1, 0, :] - agent_pred_rel[:, -1, 3, :] + theta = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0]) + head_a[:, (self.num_historical_steps - 1) // self.shift + t] = theta + next_token_idx = next_token_idx.gather(dim=1, index=sample_index) + next_token_idx = next_token_idx.squeeze(-1) + next_token_idx_list.append(next_token_idx[:, None]) + agent_token_emb[veh_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_veh[ + next_token_idx[veh_mask]] + agent_token_emb[ped_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_ped[ + next_token_idx[ped_mask]] + agent_token_emb[cyc_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_cyc[ + next_token_idx[cyc_mask]] + motion_vector_a = torch.cat([pos_a.new_zeros(data['agent']['num_nodes'], 1, self.input_dim), + pos_a[:, 1:] - pos_a[:, :-1]], dim=1) + + head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1) + + vel = motion_vector_a.clone() / (0.1 * self.shift) + vel[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0 + motion_vector_a[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0 + x_a = torch.stack( + [torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1), + angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2])], dim=-1) + + x_a = self.x_a_emb(continuous_inputs=x_a.view(-1, x_a.size(-1)), + categorical_embs=categorical_embs) + x_a = x_a.view(-1, num_step, self.hidden_dim) + + feat_a = torch.cat((agent_token_emb, x_a), dim=-1) + feat_a = self.fusion_emb(feat_a) + + agent_valid_mask[agent_category != 3] = False + + return { + 'pos_a': pos_a[:, (self.num_historical_steps - 1) // self.shift:], + 'head_a': head_a[:, (self.num_historical_steps - 1) // self.shift:], + 'gt': data['agent']['position'][:, self.num_historical_steps:, :self.input_dim].contiguous(), + 'valid_mask': 
agent_valid_mask[:, self.num_historical_steps:],
+ 'pred_traj': pred_traj,
+ 'pred_head': pred_head,
+ 'next_token_idx': torch.cat(next_token_idx_list, dim=-1),
+ 'next_token_idx_gt': agent_token_index.roll(shifts=-1, dims=1),
+ 'next_token_eval_mask': data['agent']['agent_valid_mask'],
+ 'pred_prob': pred_prob,
+ 'vel': vel
+ }
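The rollout above decodes each next motion token by keeping the `beam_size` most probable tokens and then sampling one of them with `torch.multinomial`. A minimal standalone sketch of that top-k-then-sample step, with toy logits and a 4-entry vocabulary (not part of the diff):

```python
# Toy illustration of the top-k-then-sample decoding used in the rollout.
import torch

logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])          # [num_agents, token_size]
probs = torch.softmax(logits, dim=-1)
topk_prob, topk_idx = torch.topk(probs, k=2, dim=-1)     # keep the beam_size best
sample = torch.multinomial(topk_prob, 1)                 # renormalizes internally
next_token_idx = topk_idx.gather(dim=-1, index=sample)   # map back to vocab ids
print(next_token_idx)
```

Sampling within the top-k, rather than taking the argmax, keeps rollouts diverse while still excluding low-probability tokens.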
diff --git a/backups/thirdparty/SMART/smart/modules/map_decoder.py b/backups/thirdparty/SMART/smart/modules/map_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..4eb49eca90edb971ead9f3d31cf2f6886e6a4914
--- /dev/null
+++ b/backups/thirdparty/SMART/smart/modules/map_decoder.py
@@ -0,0 +1,139 @@
+import os.path
+from typing import Dict
+import torch
+import torch.nn as nn
+from torch_cluster import radius_graph
+from torch_geometric.data import Batch
+from torch_geometric.data import HeteroData
+from torch_geometric.utils import dense_to_sparse, subgraph
+from smart.utils.nan_checker import check_nan_inf
+from smart.layers.attention_layer import AttentionLayer
+from smart.layers import MLPLayer
+from smart.layers.fourier_embedding import FourierEmbedding, MLPEmbedding
+from smart.utils import angle_between_2d_vectors
+from smart.utils import merge_edges
+from smart.utils import weight_init
+from smart.utils import wrap_angle
+import pickle
+
+
+class SMARTMapDecoder(nn.Module):
+
+ def __init__(self,
+ dataset: str,
+ input_dim: int,
+ hidden_dim: int,
+ num_historical_steps: int,
+ pl2pl_radius: float,
+ num_freq_bands: int,
+ num_layers: int,
+ num_heads: int,
+ head_dim: int,
+ dropout: float,
+ map_token) -> None:
+ super(SMARTMapDecoder, self).__init__()
+ self.dataset = dataset
+ self.input_dim = input_dim
+ self.hidden_dim = hidden_dim
+ self.num_historical_steps = num_historical_steps
+ self.pl2pl_radius = pl2pl_radius
+ self.num_freq_bands = num_freq_bands
+ self.num_layers = num_layers
+ self.num_heads = num_heads
+ self.head_dim = head_dim
+ self.dropout = dropout
+
+ if input_dim == 2:
+ input_dim_r_pt2pt = 3
+ elif input_dim == 3:
+ input_dim_r_pt2pt = 4
+ else:
+ raise ValueError('{} is not a valid dimension'.format(input_dim))
+
+ self.type_pt_emb = nn.Embedding(17, hidden_dim)
+ self.side_pt_emb = nn.Embedding(4, hidden_dim)
+ self.polygon_type_emb = nn.Embedding(4, hidden_dim)
+ self.light_pl_emb = nn.Embedding(4, hidden_dim)
+
+ self.r_pt2pt_emb = FourierEmbedding(input_dim=input_dim_r_pt2pt, hidden_dim=hidden_dim,
+ num_freq_bands=num_freq_bands)
+ self.pt2pt_layers = nn.ModuleList(
+ [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
+ bipartite=False, has_pos_emb=True) for _ in range(num_layers)]
+ )
+ self.token_size = 1024
+ self.token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
+ output_dim=self.token_size)
+ input_dim_token = 22
+ self.token_emb = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
+ self.map_token = map_token
+ self.apply(weight_init)
+ self.mask_pt = False
+
+ def maybe_autocast(self, dtype=torch.float32):
+ return torch.cuda.amp.autocast(dtype=dtype)
+
+ def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]:
+ pt_valid_mask = data['pt_token']['pt_valid_mask']
+ pt_pred_mask = data['pt_token']['pt_pred_mask']
+ pt_target_mask = data['pt_token']['pt_target_mask']
+ mask_s = pt_valid_mask
+
+ pos_pt = data['pt_token']['position'][:, :self.input_dim].contiguous()
+ orient_pt = data['pt_token']['orientation'].contiguous()
+ orient_vector_pt = torch.stack([orient_pt.cos(), orient_pt.sin()], dim=-1)
+ token_sample_pt = self.map_token['traj_src'].to(pos_pt.device).to(torch.float)
+ pt_token_emb_src = self.token_emb(token_sample_pt.view(token_sample_pt.shape[0], -1))
+ pt_token_emb = pt_token_emb_src[data['pt_token']['token_idx']]
+
+ if self.input_dim == 2:
+ x_pt = pt_token_emb
+ elif self.input_dim == 3:
+ x_pt = pt_token_emb # identical to the 2D case: the token embedding already covers both
+ else:
+ raise ValueError('{} is not a valid dimension'.format(self.input_dim))
+
+ token2pl = data[('pt_token', 'to', 'map_polygon')]['edge_index']
+ token_light_type = data['map_polygon']['light_type'][token2pl[1]]
+ x_pt_categorical_embs = [self.type_pt_emb(data['pt_token']['type'].long()),
+ self.polygon_type_emb(data['pt_token']['pl_type'].long()),
+ self.light_pl_emb(token_light_type.long()),]
+ x_pt = x_pt + torch.stack(x_pt_categorical_embs).sum(dim=0)
+ edge_index_pt2pt = radius_graph(x=pos_pt[:, :2], r=self.pl2pl_radius,
+ batch=data['pt_token']['batch'] if isinstance(data, Batch) else None,
+ loop=False, max_num_neighbors=100)
+ if self.mask_pt:
+ edge_index_pt2pt = subgraph(subset=mask_s, edge_index=edge_index_pt2pt)[0]
+ rel_pos_pt2pt = pos_pt[edge_index_pt2pt[0]] - pos_pt[edge_index_pt2pt[1]]
+ rel_orient_pt2pt = wrap_angle(orient_pt[edge_index_pt2pt[0]] - orient_pt[edge_index_pt2pt[1]])
+ if self.input_dim == 2:
+ r_pt2pt = torch.stack(
+ [torch.norm(rel_pos_pt2pt[:, :2], p=2, dim=-1),
+ angle_between_2d_vectors(ctr_vector=orient_vector_pt[edge_index_pt2pt[1]],
+ nbr_vector=rel_pos_pt2pt[:, :2]),
+ rel_orient_pt2pt], dim=-1)
+ elif self.input_dim == 3:
+ r_pt2pt = torch.stack(
+ [torch.norm(rel_pos_pt2pt[:, :2], p=2, dim=-1),
+ angle_between_2d_vectors(ctr_vector=orient_vector_pt[edge_index_pt2pt[1]],
+ nbr_vector=rel_pos_pt2pt[:, :2]),
+ rel_pos_pt2pt[:, -1],
+ rel_orient_pt2pt], dim=-1)
+ else:
+ raise ValueError('{} is not a valid dimension'.format(self.input_dim))
+ r_pt2pt = self.r_pt2pt_emb(continuous_inputs=r_pt2pt, categorical_embs=None)
+ for i in range(self.num_layers):
+ x_pt = self.pt2pt_layers[i](x_pt, r_pt2pt, edge_index_pt2pt)
+
+ next_token_prob = self.token_predict_head(x_pt[pt_pred_mask])
+ next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)
+ _, next_token_idx = torch.topk(next_token_prob_softmax, k=10, dim=-1)
+ next_token_index_gt = data['pt_token']['token_idx'][pt_target_mask]
+
+ return {
+ 'x_pt': x_pt,
+ 'map_next_token_idx': next_token_idx,
+ 'map_next_token_prob': next_token_prob,
+ 'map_next_token_idx_gt': next_token_index_gt,
+ 'map_next_token_eval_mask': pt_pred_mask[pt_pred_mask]
+ }
diff --git a/backups/thirdparty/SMART/smart/modules/smart_decoder.py b/backups/thirdparty/SMART/smart/modules/smart_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b12d8e144e10e3e931cfb4f480981ff799603fa
--- /dev/null
+++ b/backups/thirdparty/SMART/smart/modules/smart_decoder.py
@@ -0,0 +1,74 @@
+from typing import Dict, Optional
+import torch
+import torch.nn as nn
+from torch_geometric.data import HeteroData
+from smart.modules.agent_decoder import SMARTAgentDecoder
+from smart.modules.map_decoder import SMARTMapDecoder
+
+
+class SMARTDecoder(nn.Module):
+
+ def __init__(self,
+ dataset: str,
+ input_dim: int,
+ hidden_dim: int,
+ num_historical_steps: int,
+ pl2pl_radius: float,
+ time_span: Optional[int],
+ pl2a_radius: float,
+ a2a_radius: float,
+ num_freq_bands: int,
+ num_map_layers: int,
+ num_agent_layers: int,
+ num_heads: int,
+ head_dim: int,
+ dropout: float,
+ map_token: Dict,
+ token_data: Dict,
+ use_intention=False,
+ token_size=512) -> None:
+ super(SMARTDecoder,
self).__init__() + self.map_encoder = SMARTMapDecoder( + dataset=dataset, + input_dim=input_dim, + hidden_dim=hidden_dim, + num_historical_steps=num_historical_steps, + pl2pl_radius=pl2pl_radius, + num_freq_bands=num_freq_bands, + num_layers=num_map_layers, + num_heads=num_heads, + head_dim=head_dim, + dropout=dropout, + map_token=map_token + ) + self.agent_encoder = SMARTAgentDecoder( + dataset=dataset, + input_dim=input_dim, + hidden_dim=hidden_dim, + num_historical_steps=num_historical_steps, + time_span=time_span, + pl2a_radius=pl2a_radius, + a2a_radius=a2a_radius, + num_freq_bands=num_freq_bands, + num_layers=num_agent_layers, + num_heads=num_heads, + head_dim=head_dim, + dropout=dropout, + token_size=token_size, + token_data=token_data + ) + self.map_enc = None + + def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]: + map_enc = self.map_encoder(data) + agent_enc = self.agent_encoder(data, map_enc) + return {**map_enc, **agent_enc} + + def inference(self, data: HeteroData) -> Dict[str, torch.Tensor]: + map_enc = self.map_encoder(data) + agent_enc = self.agent_encoder.inference(data, map_enc) + return {**map_enc, **agent_enc} + + def inference_no_map(self, data: HeteroData, map_enc) -> Dict[str, torch.Tensor]: + agent_enc = self.agent_encoder.inference(data, map_enc) + return {**map_enc, **agent_enc} diff --git a/backups/thirdparty/SMART/smart/preprocess/__init__.py b/backups/thirdparty/SMART/smart/preprocess/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/backups/thirdparty/SMART/smart/preprocess/preprocess.py b/backups/thirdparty/SMART/smart/preprocess/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..665b4a3d555845ef066c40a80f3eac5664bb643b --- /dev/null +++ b/backups/thirdparty/SMART/smart/preprocess/preprocess.py @@ -0,0 +1,110 @@ +import numpy as np +import pandas as pd +import os +import torch +from typing import Any, Dict, List, Optional + +predict_unseen_agents = False +vector_repr = True +_agent_types = ['vehicle', 'pedestrian', 'cyclist', 'background'] +_polygon_types = ['VEHICLE', 'BIKE', 'BUS', 'PEDESTRIAN'] +_polygon_light_type = ['LANE_STATE_STOP', 'LANE_STATE_GO', 'LANE_STATE_CAUTION', 'LANE_STATE_UNKNOWN'] +_point_types = ['DASH_SOLID_YELLOW', 'DASH_SOLID_WHITE', 'DASHED_WHITE', 'DASHED_YELLOW', + 'DOUBLE_SOLID_YELLOW', 'DOUBLE_SOLID_WHITE', 'DOUBLE_DASH_YELLOW', 'DOUBLE_DASH_WHITE', + 'SOLID_YELLOW', 'SOLID_WHITE', 'SOLID_DASH_WHITE', 'SOLID_DASH_YELLOW', 'EDGE', + 'NONE', 'UNKNOWN', 'CROSSWALK', 'CENTERLINE'] +_point_sides = ['LEFT', 'RIGHT', 'CENTER'] +_polygon_to_polygon_types = ['NONE', 'PRED', 'SUCC', 'LEFT', 'RIGHT'] +_polygon_is_intersections = [True, False, None] + + +Lane_type_hash = { + 4: "BIKE", + 3: "VEHICLE", + 2: "VEHICLE", + 1: "BUS" +} + +boundary_type_hash = { + 5: "UNKNOWN", + 6: "DASHED_WHITE", + 7: "SOLID_WHITE", + 8: "DOUBLE_DASH_WHITE", + 9: "DASHED_YELLOW", + 10: "DOUBLE_DASH_YELLOW", + 11: "SOLID_YELLOW", + 12: "DOUBLE_SOLID_YELLOW", + 13: "DASH_SOLID_YELLOW", + 14: "UNKNOWN", + 15: "EDGE", + 16: "EDGE" +} + + +def get_agent_features(df: pd.DataFrame, av_id, num_historical_steps=10, dim=3, num_steps=91) -> Dict[str, Any]: + if not predict_unseen_agents: # filter out agents that are unseen during the historical time steps + historical_df = df[df['timestep'] == num_historical_steps-1] + agent_ids = list(historical_df['track_id'].unique()) + df = df[df['track_id'].isin(agent_ids)] + else: + agent_ids = 
list(df['track_id'].unique())
+
+ num_agents = len(agent_ids)
+ # initialization
+ valid_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool)
+ current_valid_mask = torch.zeros(num_agents, dtype=torch.bool)
+ predict_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool)
+ agent_id: List[Optional[str]] = [None] * num_agents
+ agent_type = torch.zeros(num_agents, dtype=torch.uint8)
+ agent_category = torch.zeros(num_agents, dtype=torch.uint8)
+ position = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
+ heading = torch.zeros(num_agents, num_steps, dtype=torch.float)
+ velocity = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
+ shape = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
+
+ for track_id, track_df in df.groupby('track_id'):
+ agent_idx = agent_ids.index(track_id)
+ agent_steps = track_df['timestep'].values
+
+ valid_mask[agent_idx, agent_steps] = True
+ current_valid_mask[agent_idx] = valid_mask[agent_idx, num_historical_steps - 1]
+ predict_mask[agent_idx, agent_steps] = True
+ if vector_repr: # a time step t is valid only when both t and t-1 are valid
+ valid_mask[agent_idx, 1: num_historical_steps] = (
+ valid_mask[agent_idx, :num_historical_steps - 1] &
+ valid_mask[agent_idx, 1: num_historical_steps])
+ valid_mask[agent_idx, 0] = False
+ predict_mask[agent_idx, :num_historical_steps] = False
+ if not current_valid_mask[agent_idx]:
+ predict_mask[agent_idx, num_historical_steps:] = False
+
+ agent_id[agent_idx] = track_id
+ agent_type[agent_idx] = _agent_types.index(track_df['object_type'].values[0])
+ agent_category[agent_idx] = track_df['object_category'].values[0]
+ position[agent_idx, agent_steps, :3] = torch.from_numpy(np.stack([track_df['position_x'].values,
+ track_df['position_y'].values,
+ track_df['position_z'].values],
+ axis=-1)).float()
+ heading[agent_idx, agent_steps] = torch.from_numpy(track_df['heading'].values).float()
+ velocity[agent_idx, agent_steps, :2] = torch.from_numpy(np.stack([track_df['velocity_x'].values,
+ track_df['velocity_y'].values],
+ axis=-1)).float()
+ shape[agent_idx, agent_steps, :3] = torch.from_numpy(np.stack([track_df['length'].values,
+ track_df['width'].values,
+ track_df["height"].values],
+ axis=-1)).float()
+ av_idx = agent_id.index(av_id)
+
+ return {
+ 'num_nodes': num_agents,
+ 'av_index': av_idx,
+ 'valid_mask': valid_mask,
+ 'predict_mask': predict_mask,
+ 'id': agent_id,
+ 'type': agent_type,
+ 'category': agent_category,
+ 'position': position,
+ 'heading': heading,
+ 'velocity': velocity,
+ 'shape': shape
+ }
\ No newline at end of file
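The `vector_repr` branch above keeps a time step valid only when the previous step is also valid, so that a per-step displacement vector can always be formed. A small self-contained check of that masking rule, on a toy mask (not part of the diff):

```python
# Toy check of the vector-repr validity rule from get_agent_features.
import torch

valid = torch.tensor([True, True, False, True, True])
num_historical_steps = 5
valid[1:num_historical_steps] = (valid[:num_historical_steps - 1] &
                                 valid[1:num_historical_steps])
valid[0] = False
print(valid)  # tensor([False,  True, False, False,  True])
```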
diff --git a/backups/thirdparty/SMART/smart/tokens/__init__.py b/backups/thirdparty/SMART/smart/tokens/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/backups/thirdparty/SMART/smart/transforms/__init__.py b/backups/thirdparty/SMART/smart/transforms/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c73255af6e11d74c9ac037a894f7b8a582ab900a
--- /dev/null
+++ b/backups/thirdparty/SMART/smart/transforms/__init__.py
@@ -0,0 +1 @@
+from smart.transforms.target_builder import WaymoTargetBuilder
diff --git a/backups/thirdparty/SMART/smart/transforms/target_builder.py b/backups/thirdparty/SMART/smart/transforms/target_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..f20bab06a8417b4dfa4866f141e2e8f180a10869
--- /dev/null
+++ b/backups/thirdparty/SMART/smart/transforms/target_builder.py
@@ -0,0 +1,152 @@
+
+import numpy as np
+import torch
+from torch_geometric.data import HeteroData
+from torch_geometric.transforms import BaseTransform
+from smart.utils import wrap_angle
+from smart.utils.log import Logging
+
+
+def to_16(data):
+ if isinstance(data, dict):
+ for key, value in data.items():
+ new_value = to_16(value)
+ data[key] = new_value
+ if isinstance(data, torch.Tensor):
+ if data.dtype == torch.float32:
+ data = data.to(torch.float16)
+ return data
+
+
+def tofloat32(data):
+ for name in data:
+ value = data[name]
+ if isinstance(value, dict):
+ value = tofloat32(value)
+ elif isinstance(value, torch.Tensor) and value.dtype == torch.float64:
+ value = value.to(torch.float32)
+ data[name] = value
+ return data
+
+
+class WaymoTargetBuilder(BaseTransform):
+
+ def __init__(self,
+ num_historical_steps: int,
+ num_future_steps: int,
+ mode="train") -> None:
+ self.num_historical_steps = num_historical_steps
+ self.num_future_steps = num_future_steps
+ self.mode = mode
+ self.num_features = 3
+ self.augment = False
+ self.logger = Logging().log(level='DEBUG')
+
+ def score_ego_agent(self, agent):
+ av_index = agent['av_index']
+ agent["category"][av_index] = 5
+ return agent
+
+ def clip(self, agent, max_num=32):
+ av_index = agent["av_index"]
+ valid = agent['valid_mask']
+ ego_pos = agent["position"][av_index]
+ obstacle_mask = agent['type'] == 3
+ distance = torch.norm(agent["position"][:, self.num_historical_steps-1, :2] - ego_pos[self.num_historical_steps-1, :2], dim=-1) # keep only the max_num agents closest to the ego vehicle
+ distance[obstacle_mask] = 10e5
+ sort_idx = distance.sort()[1]
+ mask = torch.zeros(valid.shape[0])
+ mask[sort_idx[:max_num]] = 1
+ mask = mask.to(torch.bool)
+ mask[av_index] = True
+ new_av_index = mask[:av_index].sum()
+ agent["num_nodes"] = int(mask.sum())
+ agent["av_index"] = int(new_av_index)
+ excluded = ["num_nodes", "av_index", "ego"]
+ for key, val in agent.items():
+ if key in excluded:
+ continue
+ if key == "id":
+ val = list(np.array(val)[mask])
+ agent[key] = val
+ continue
+ if len(val.size()) > 1:
+ agent[key] = val[mask, ...]
+ else:
+ agent[key] = val[mask]
+ return agent
+
+ def score_nearby_vehicle(self, agent, max_num=10):
+ av_index = agent['av_index']
+ agent["category"] = torch.zeros_like(agent["category"])
+ obstacle_mask = agent['type'] == 3
+ pos = agent["position"][av_index, self.num_historical_steps, :2]
+ distance = torch.norm(agent["position"][:, self.num_historical_steps, :2] - pos, dim=-1)
+ distance[obstacle_mask] = 10e5
+ sort_idx = distance.sort()[1]
+ nearby_mask = torch.zeros(distance.shape[0])
+ nearby_mask[sort_idx[1:max_num]] = 1
+ nearby_mask = nearby_mask.bool()
+ agent["category"][nearby_mask] = 3
+ agent["category"][obstacle_mask] = 0
+
+ def score_trained_vehicle(self, agent, max_num=10, min_distance=0):
+ av_index = agent['av_index']
+ agent["category"] = torch.zeros_like(agent["category"])
+ pos = agent["position"][av_index, self.num_historical_steps, :2]
+ distance = torch.norm(agent["position"][:, self.num_historical_steps, :2] - pos, dim=-1)
+ distance_all_time = torch.norm(agent["position"][:, :, :2] - agent["position"][av_index, :, :2], dim=-1)
+ in_range_mask = distance_all_time < 150 # we do not trust perception beyond a range of 150 meters
+ agent["valid_mask"] = agent["valid_mask"] * in_range_mask
+ # we do not predict vehicles too far away from the ego car
+ closest_vehicle = distance < 100
+ valid = agent['valid_mask']
+ valid_current = valid[:, (self.num_historical_steps):]
+ valid_counts = valid_current.sum(1)
+ counts_vehicle = valid_counts >= 1
+ no_background = agent['type'] != 3
+ vehicle2pred = closest_vehicle & counts_vehicle & no_background
+ if vehicle2pred.sum() > max_num:
+ # too many stationary vehicles, so prefer moving vehicles for training as much as possible
+ true_indices = torch.nonzero(vehicle2pred).squeeze(1)
+ selected_indices = true_indices[torch.randperm(true_indices.size(0))[:max_num]]
+ vehicle2pred.fill_(False)
+ vehicle2pred[selected_indices] = True
+ agent["category"][vehicle2pred] = 3
+
+ def rotate_agents(self, position, heading, num_nodes, num_historical_steps, num_future_steps):
+ origin = position[:, num_historical_steps - 1]
+ theta = heading[:, num_historical_steps - 1]
+ cos, sin = theta.cos(), theta.sin()
+ rot_mat = theta.new_zeros(num_nodes, 2, 2)
+ rot_mat[:, 0, 0] = cos
+ rot_mat[:, 0, 1] = -sin
+ rot_mat[:, 1, 0] = sin
+ rot_mat[:, 1, 1] = cos
+ target = origin.new_zeros(num_nodes, num_future_steps, 4)
+ target[..., :2] = torch.bmm(position[:, num_historical_steps:, :2] -
+ origin[:, :2].unsqueeze(1), rot_mat)
+ his = origin.new_zeros(num_nodes, num_historical_steps, 4)
+ his[..., :2] = torch.bmm(position[:, :num_historical_steps, :2] -
+ origin[:, :2].unsqueeze(1), rot_mat)
+ if position.size(2) == 3:
+ target[..., 2] = (position[:, num_historical_steps:, 2] -
+ origin[:, 2].unsqueeze(-1))
+ his[..., 2] = (position[:, :num_historical_steps, 2] -
+ origin[:, 2].unsqueeze(-1))
+ target[..., 3] = wrap_angle(heading[:, num_historical_steps:] -
+ theta.unsqueeze(-1))
+ his[..., 3] = wrap_angle(heading[:, :num_historical_steps] -
+ theta.unsqueeze(-1))
+ else:
+ target[..., 2] = wrap_angle(heading[:, num_historical_steps:] -
+ theta.unsqueeze(-1))
+ his[..., 2] = wrap_angle(heading[:, :num_historical_steps] -
+ theta.unsqueeze(-1))
+ return his, target
+
+ def __call__(self, data) -> HeteroData:
+ agent = data["agent"]
+ self.score_ego_agent(agent)
+ self.score_trained_vehicle(agent, max_num=32)
+ return HeteroData(data)
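`rotate_agents` above expresses positions in each agent's frame at the last historical step: it subtracts the origin and multiplies row vectors by `rot_mat`, which for a heading θ applies the inverse rotation R(-θ). A quick numeric sanity check under that convention (standalone, not part of the diff):

```python
# Sanity check: a displacement along the agent's heading maps to +x locally.
import math
import torch

theta = math.pi / 2                                   # agent heading: 90 deg
cos, sin = math.cos(theta), math.sin(theta)
rot_mat = torch.tensor([[cos, -sin],
                        [sin, cos]])                  # same layout as rotate_agents
delta = torch.tensor([[0.0, 2.0]])                    # 2 m along the heading
local = delta @ rot_mat                               # row-vector convention
print(local)  # ~tensor([[2., 0.]]) up to float error
```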
diff --git a/backups/thirdparty/SMART/smart/utils/__init__.py b/backups/thirdparty/SMART/smart/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8ef25cf735714f4cb5b51da5bf6678af9c736b7
--- /dev/null
+++ b/backups/thirdparty/SMART/smart/utils/__init__.py
@@ -0,0 +1,12 @@
+
+from smart.utils.geometry import angle_between_2d_vectors
+from smart.utils.geometry import angle_between_3d_vectors
+from smart.utils.geometry import side_to_directed_lineseg
+from smart.utils.geometry import wrap_angle
+from smart.utils.graph import add_edges
+from smart.utils.graph import bipartite_dense_to_sparse
+from smart.utils.graph import complete_graph
+from smart.utils.graph import merge_edges
+from smart.utils.graph import unbatch
+from smart.utils.list import safe_list_index
+from smart.utils.weight_init import weight_init
diff --git a/backups/thirdparty/SMART/smart/utils/cluster_reader.py b/backups/thirdparty/SMART/smart/utils/cluster_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0ef689d750196227038d28ce0dda99dcee9fab9
--- /dev/null
+++ b/backups/thirdparty/SMART/smart/utils/cluster_reader.py
@@ -0,0 +1,45 @@
+import io
+import pickle
+import pandas as pd
+import json
+
+
+class LoadScenarioFromCeph:
+ def __init__(self):
+ from petrel_client.client import Client
+ self.file_client = Client('~/petreloss.conf')
+
+ def list(self, dir_path):
+ return list(self.file_client.list(dir_path))
+
+ def save(self, data, url):
+ self.file_client.put(url, pickle.dumps(data))
+
+ def read_correct_csv(self, scenario_path):
+ output = pd.read_csv(io.StringIO(self.file_client.get(scenario_path).decode('utf-8')), engine="python")
+ return output
+
+ def contains(self, url):
+ return self.file_client.contains(url)
+
+ def read_string(self, csv_url):
+ from io import StringIO
+ df = pd.read_csv(StringIO(str(self.file_client.get(csv_url), 'utf-8')), sep=r'\s+', low_memory=False)
+ return df
+
+ def read(self, scenario_path):
+ with io.BytesIO(self.file_client.get(scenario_path)) as f:
+ datas = pickle.load(f)
+ return datas
+
+ def read_json(self, path):
+ with io.BytesIO(self.file_client.get(path)) as f:
+ data = json.load(f)
+ return data
+
+ def read_csv(self, scenario_path):
+ return pickle.loads(self.file_client.get(scenario_path))
+
+ def read_model(self, model_path):
+ with io.BytesIO(self.file_client.get(model_path)) as f:
+ pass # stub: model deserialization is left unimplemented here
diff --git a/backups/thirdparty/SMART/smart/utils/config.py b/backups/thirdparty/SMART/smart/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..77ba340ecf67a9146bf2671e691025b5f446ac5e
--- /dev/null
+++ b/backups/thirdparty/SMART/smart/utils/config.py
@@ -0,0 +1,18 @@
+import os
+import yaml
+import easydict
+
+
+def load_config_act(path):
+ """ load config file"""
+ with open(path, 'r') as f:
+ cfg = yaml.load(f, Loader=yaml.FullLoader)
+ return easydict.EasyDict(cfg)
+
+
+def load_config_init(path):
+ """ load config file"""
+ path = os.path.join('init/configs', f'{path}.yaml')
+ with open(path, 'r') as f:
+ cfg = yaml.load(f, Loader=yaml.FullLoader)
+ return cfg
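`load_config_act` returns an `EasyDict`, so the YAML hierarchy (the `Model`/`Dataset`/`Trainer` groups these configs use) is accessible by attribute. A hypothetical usage sketch; the path is the default from train.py:

```python
# Hypothetical usage of load_config_act with attribute-style access.
from smart.utils.config import load_config_act

config = load_config_act('configs/train/train_scalable.yaml')
print(config.Model.predictor)   # selects the predictor class in train.py
print(config.Trainer.max_epochs)
```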
diff --git a/backups/thirdparty/SMART/smart/utils/geometry.py b/backups/thirdparty/SMART/smart/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..90ba479c53598e41d3ef8c6bcb40bd672e4fa182
--- /dev/null
+++ b/backups/thirdparty/SMART/smart/utils/geometry.py
@@ -0,0 +1,39 @@
+
+import math
+
+import torch
+
+
+def angle_between_2d_vectors(
+ ctr_vector: torch.Tensor,
+ nbr_vector: torch.Tensor) -> torch.Tensor:
+ return torch.atan2(ctr_vector[..., 0] * nbr_vector[..., 1] - ctr_vector[..., 1] * nbr_vector[..., 0],
+ (ctr_vector[..., :2] * nbr_vector[..., :2]).sum(dim=-1))
+
+
+def angle_between_3d_vectors(
+ ctr_vector: torch.Tensor,
+ nbr_vector: torch.Tensor) -> torch.Tensor:
+ return torch.atan2(torch.cross(ctr_vector, nbr_vector, dim=-1).norm(p=2, dim=-1),
+ (ctr_vector * nbr_vector).sum(dim=-1))
+
+
+def side_to_directed_lineseg(
+ query_point: torch.Tensor,
+ start_point: torch.Tensor,
+ end_point: torch.Tensor) -> str:
+ cond = ((end_point[0] - start_point[0]) * (query_point[1] - start_point[1]) -
+ (end_point[1] - start_point[1]) * (query_point[0] - start_point[0]))
+ if cond > 0:
+ return 'LEFT'
+ elif cond < 0:
+ return 'RIGHT'
+ else:
+ return 'CENTER'
+
+
+def wrap_angle(
+ angle: torch.Tensor,
+ min_val: float = -math.pi,
+ max_val: float = math.pi) -> torch.Tensor:
+ return min_val + (angle + max_val) % (max_val - min_val)
diff --git a/backups/thirdparty/SMART/smart/utils/graph.py b/backups/thirdparty/SMART/smart/utils/graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..8db7d41de4215aee4664d1888bc5ddb77e090f4f
--- /dev/null
+++ b/backups/thirdparty/SMART/smart/utils/graph.py
@@ -0,0 +1,90 @@
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch_geometric.utils import coalesce
+from torch_geometric.utils import degree
+
+
+def add_edges(
+ from_edge_index: torch.Tensor,
+ to_edge_index: torch.Tensor,
+ from_edge_attr: Optional[torch.Tensor] = None,
+ to_edge_attr: Optional[torch.Tensor] = None,
+ replace: bool = True) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ from_edge_index = from_edge_index.to(device=to_edge_index.device, dtype=to_edge_index.dtype)
+ mask = ((to_edge_index[0].unsqueeze(-1) == from_edge_index[0].unsqueeze(0)) &
+ (to_edge_index[1].unsqueeze(-1) == from_edge_index[1].unsqueeze(0)))
+ if replace:
+ to_mask = mask.any(dim=1)
+ if from_edge_attr is not None and to_edge_attr is not None:
+ from_edge_attr = from_edge_attr.to(device=to_edge_attr.device, dtype=to_edge_attr.dtype)
+ to_edge_attr = torch.cat([to_edge_attr[~to_mask], from_edge_attr], dim=0)
+ to_edge_index = torch.cat([to_edge_index[:, ~to_mask], from_edge_index], dim=1)
+ else:
+ from_mask = mask.any(dim=0)
+ if from_edge_attr is not None and to_edge_attr is not None:
+ from_edge_attr = from_edge_attr.to(device=to_edge_attr.device, dtype=to_edge_attr.dtype)
+ to_edge_attr = torch.cat([to_edge_attr, from_edge_attr[~from_mask]], dim=0)
+ to_edge_index = torch.cat([to_edge_index, from_edge_index[:, ~from_mask]], dim=1)
+ return to_edge_index, to_edge_attr
+
+
+def merge_edges(
+ edge_indices: List[torch.Tensor],
+ edge_attrs: Optional[List[torch.Tensor]] = None,
+ reduce: str = 'add') -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ edge_index = torch.cat(edge_indices, dim=1)
+ if edge_attrs is not None:
+ edge_attr = torch.cat(edge_attrs, dim=0)
+ else:
+ edge_attr = None
+ return coalesce(edge_index=edge_index, edge_attr=edge_attr, reduce=reduce)
+
+
+def complete_graph(
+ num_nodes: Union[int, Tuple[int, int]],
+ ptr: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
+ loop: bool = False,
+ device: Optional[Union[torch.device, str]] = None) -> torch.Tensor:
+ if ptr is None:
+ if isinstance(num_nodes, int):
+ num_src, num_dst = num_nodes, num_nodes
+ else:
+ num_src, num_dst = num_nodes
+ edge_index = torch.cartesian_prod(torch.arange(num_src, dtype=torch.long, device=device),
+ torch.arange(num_dst, dtype=torch.long, device=device)).t()
+ else:
+ if isinstance(ptr, torch.Tensor):
+ ptr_src, ptr_dst = ptr, ptr
+ num_src_batch = num_dst_batch = ptr[1:] - ptr[:-1]
+ else:
+ ptr_src, ptr_dst = ptr
+ num_src_batch = ptr_src[1:] - ptr_src[:-1]
+ num_dst_batch = ptr_dst[1:] - ptr_dst[:-1]
+ edge_index = torch.cat(
+ [torch.cartesian_prod(torch.arange(num_src, dtype=torch.long, device=device),
+ torch.arange(num_dst, dtype=torch.long, device=device)) + p
+ for num_src, num_dst, p in zip(num_src_batch, num_dst_batch, torch.stack([ptr_src, ptr_dst], dim=1))],
+ dim=0)
+ edge_index = edge_index.t()
+ if isinstance(num_nodes, int) and not loop:
+ edge_index = edge_index[:, edge_index[0] != edge_index[1]]
+ return edge_index.contiguous()
+
+
+def bipartite_dense_to_sparse(adj: torch.Tensor) -> torch.Tensor:
+ index = adj.nonzero(as_tuple=True)
+ if len(index) == 3:
+ batch_src = index[0] * adj.size(1)
+ batch_dst = index[0] * adj.size(2)
+ index = (batch_src + index[1], batch_dst + index[2])
+ return torch.stack(index, dim=0)
+
+
+def unbatch(
+ src: torch.Tensor,
+ batch: torch.Tensor,
+ dim: int = 0) -> List[torch.Tensor]:
+ sizes = degree(batch, dtype=torch.long).tolist()
+ return src.split(sizes, dim)
diff --git a/backups/thirdparty/SMART/smart/utils/list.py b/backups/thirdparty/SMART/smart/utils/list.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3dc3b1b22af4f491bbfc2370410e07af3b44fa4
--- /dev/null
+++ b/backups/thirdparty/SMART/smart/utils/list.py
@@ -0,0 +1,9 @@
+
+from typing import Any, List, Optional
+
+
+def safe_list_index(ls: List[Any], elem: Any) -> Optional[int]:
+ try:
+ return ls.index(elem)
+ except ValueError:
+ return None
diff --git a/backups/thirdparty/SMART/smart/utils/log.py b/backups/thirdparty/SMART/smart/utils/log.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba53d0caaa9e8df4d08408887e814fe27b24bcda
--- /dev/null
+++ b/backups/thirdparty/SMART/smart/utils/log.py
@@ -0,0 +1,56 @@
+import logging
+import time
+import os
+
+
+class Logging:
+
+ def make_log_dir(self, dirname='logs'):
+ now_dir = os.path.dirname(__file__)
+ path = os.path.join(now_dir, dirname)
+ path = os.path.normpath(path)
+ if not os.path.exists(path):
+ os.mkdir(path)
+ return path
+
+ def get_log_filename(self):
+ filename = "{}.log".format(time.strftime("%Y-%m-%d",time.localtime()))
+ filename = os.path.join(self.make_log_dir(), filename)
+ filename = os.path.normpath(filename)
+ return filename
+
+ def log(self, level='DEBUG', name="simagent"):
+ logger = logging.getLogger(name)
+ level = getattr(logging, level)
+ logger.setLevel(level)
+ if not logger.handlers:
+ sh = logging.StreamHandler()
+ fh = logging.FileHandler(filename=self.get_log_filename(), mode='a',encoding="utf-8")
+ fmt = logging.Formatter("%(asctime)s-%(levelname)s-%(filename)s-Line:%(lineno)d-Message:%(message)s")
+ sh.setFormatter(fmt=fmt)
+ fh.setFormatter(fmt=fmt)
+ logger.addHandler(sh)
+ logger.addHandler(fh)
+ return logger
+
+ def add_log(self, logger, level='DEBUG'):
+ level = getattr(logging, level)
+ logger.setLevel(level)
+ if not logger.handlers:
+ sh = logging.StreamHandler()
+ fh = logging.FileHandler(filename=self.get_log_filename(), mode='a',encoding="utf-8")
+ fmt = logging.Formatter("%(asctime)s-%(levelname)s-%(filename)s-Line:%(lineno)d-Message:%(message)s")
+ sh.setFormatter(fmt=fmt)
+ fh.setFormatter(fmt=fmt)
+ logger.addHandler(sh)
+ logger.addHandler(fh)
+ return logger
+
+
+if __name__ == '__main__':
+ logger = Logging().log(level='INFO')
+ logger.debug("1111111111111111111111") # emit a log record via the logger
+ logger.info("222222222222222222222222")
+ logger.error("sample error message")
+
logger.warning("3333333333333333333333333333") + logger.critical("44444444444444444444444444") diff --git a/backups/thirdparty/SMART/smart/utils/nan_checker.py b/backups/thirdparty/SMART/smart/utils/nan_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..8b5f52c9fc52e323d0af319592b6161cb473fc5b --- /dev/null +++ b/backups/thirdparty/SMART/smart/utils/nan_checker.py @@ -0,0 +1,5 @@ +import torch + +def check_nan_inf(t, s): + assert not torch.isinf(t).any(), f"{s} is inf, {t}" + assert not torch.isnan(t).any(), f"{s} is nan, {t}" \ No newline at end of file diff --git a/backups/thirdparty/SMART/smart/utils/weight_init.py b/backups/thirdparty/SMART/smart/utils/weight_init.py new file mode 100644 index 0000000000000000000000000000000000000000..39286a1a5b29d5af63d81ce9e1fd557cb18e3e93 --- /dev/null +++ b/backups/thirdparty/SMART/smart/utils/weight_init.py @@ -0,0 +1,70 @@ + +import torch.nn as nn + + +def weight_init(m: nn.Module) -> None: + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + fan_in = m.in_channels / m.groups + fan_out = m.out_channels / m.groups + bound = (6.0 / (fan_in + fan_out)) ** 0.5 + nn.init.uniform_(m.weight, -bound, bound) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Embedding): + nn.init.normal_(m.weight, mean=0.0, std=0.02) + elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.MultiheadAttention): + if m.in_proj_weight is not None: + fan_in = m.embed_dim + fan_out = m.embed_dim + bound = (6.0 / (fan_in + fan_out)) ** 0.5 + nn.init.uniform_(m.in_proj_weight, -bound, bound) + else: + nn.init.xavier_uniform_(m.q_proj_weight) + nn.init.xavier_uniform_(m.k_proj_weight) + nn.init.xavier_uniform_(m.v_proj_weight) + if m.in_proj_bias is not None: + nn.init.zeros_(m.in_proj_bias) + nn.init.xavier_uniform_(m.out_proj.weight) + if m.out_proj.bias is not None: + nn.init.zeros_(m.out_proj.bias) + if m.bias_k is not None: + nn.init.normal_(m.bias_k, mean=0.0, std=0.02) + if m.bias_v is not None: + nn.init.normal_(m.bias_v, mean=0.0, std=0.02) + elif isinstance(m, (nn.LSTM, nn.LSTMCell)): + for name, param in m.named_parameters(): + if 'weight_ih' in name: + for ih in param.chunk(4, 0): + nn.init.xavier_uniform_(ih) + elif 'weight_hh' in name: + for hh in param.chunk(4, 0): + nn.init.orthogonal_(hh) + elif 'weight_hr' in name: + nn.init.xavier_uniform_(param) + elif 'bias_ih' in name: + nn.init.zeros_(param) + elif 'bias_hh' in name: + nn.init.zeros_(param) + nn.init.ones_(param.chunk(4, 0)[1]) + elif isinstance(m, (nn.GRU, nn.GRUCell)): + for name, param in m.named_parameters(): + if 'weight_ih' in name: + for ih in param.chunk(3, 0): + nn.init.xavier_uniform_(ih) + elif 'weight_hh' in name: + for hh in param.chunk(3, 0): + nn.init.orthogonal_(hh) + elif 'bias_ih' in name: + nn.init.zeros_(param) + elif 'bias_hh' in name: + nn.init.zeros_(param) diff --git a/backups/thirdparty/SMART/train.py b/backups/thirdparty/SMART/train.py new file mode 100644 index 0000000000000000000000000000000000000000..a573a517cbd979f1faf36d5d72564fa632f8c722 --- /dev/null +++ b/backups/thirdparty/SMART/train.py @@ -0,0 +1,56 @@ + +from argparse import ArgumentParser +import pytorch_lightning as pl +from pytorch_lightning.callbacks import 
LearningRateMonitor +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.strategies import DDPStrategy +from smart.utils.config import load_config_act +from smart.datamodules import MultiDataModule +from smart.model import SMART +from smart.utils.log import Logging + + +if __name__ == '__main__': + parser = ArgumentParser() + Predictor_hash = {"smart": SMART, } + parser.add_argument('--config', type=str, default='configs/train/train_scalable.yaml') + parser.add_argument('--pretrain_ckpt', type=str, default="") + parser.add_argument('--ckpt_path', type=str, default="") + parser.add_argument('--save_ckpt_path', type=str, default="") + args = parser.parse_args() + config = load_config_act(args.config) + Predictor = Predictor_hash[config.Model.predictor] + strategy = DDPStrategy(find_unused_parameters=True, gradient_as_bucket_view=True) + Data_config = config.Dataset + datamodule = MultiDataModule(**vars(Data_config)) + + if args.pretrain_ckpt == "": + model = Predictor(config.Model) + else: + logger = Logging().log(level='DEBUG') + model = Predictor(config.Model) + model.load_params_from_file(filename=args.pretrain_ckpt, + logger=logger) + trainer_config = config.Trainer + model_checkpoint = ModelCheckpoint(dirpath=args.save_ckpt_path, + filename="{epoch:02d}", + monitor='val_cls_acc', + every_n_epochs=1, + save_top_k=5, + mode='max') + lr_monitor = LearningRateMonitor(logging_interval='epoch') + trainer = pl.Trainer(accelerator=trainer_config.accelerator, devices=trainer_config.devices, + strategy=strategy, + accumulate_grad_batches=trainer_config.accumulate_grad_batches, + num_nodes=trainer_config.num_nodes, + callbacks=[model_checkpoint, lr_monitor], + max_epochs=trainer_config.max_epochs, + num_sanity_val_steps=0, + gradient_clip_val=0.5) + if args.ckpt_path == "": + trainer.fit(model, + datamodule) + else: + trainer.fit(model, + datamodule, + ckpt_path=args.ckpt_path) diff --git a/backups/thirdparty/SMART/val.py b/backups/thirdparty/SMART/val.py new file mode 100644 index 0000000000000000000000000000000000000000..9c3caa163933f6b22b0afc4dd784711839eaf337 --- /dev/null +++ b/backups/thirdparty/SMART/val.py @@ -0,0 +1,43 @@ + +from argparse import ArgumentParser +import pytorch_lightning as pl +from torch_geometric.loader import DataLoader +from smart.datasets.scalable_dataset import MultiDataset +from smart.model import SMART +from smart.transforms import WaymoTargetBuilder +from smart.utils.config import load_config_act +from smart.utils.log import Logging + +if __name__ == '__main__': + pl.seed_everything(2, workers=True) + parser = ArgumentParser() + parser.add_argument('--config', type=str, default="configs/validation/validation_scalable.yaml") + parser.add_argument('--pretrain_ckpt', type=str, default="") + parser.add_argument('--ckpt_path', type=str, default="") + parser.add_argument('--save_ckpt_path', type=str, default="") + args = parser.parse_args() + config = load_config_act(args.config) + + data_config = config.Dataset + val_dataset = { + "scalable": MultiDataset, + }[data_config.dataset](root=data_config.root, split='val', + raw_dir=data_config.val_raw_dir, + processed_dir=data_config.val_processed_dir, + transform=WaymoTargetBuilder(config.Model.num_historical_steps, config.Model.decoder.num_future_steps)) + dataloader = DataLoader(val_dataset, batch_size=data_config.batch_size, shuffle=False, num_workers=data_config.num_workers, + pin_memory=data_config.pin_memory, persistent_workers=True if data_config.num_workers > 0 else False) + Predictor = 
SMART + if args.pretrain_ckpt == "": + model = Predictor(config.Model) + else: + logger = Logging().log(level='DEBUG') + model = Predictor(config.Model) + model.load_params_from_file(filename=args.pretrain_ckpt, + logger=logger) + + trainer_config = config.Trainer + trainer = pl.Trainer(accelerator=trainer_config.accelerator, + devices=trainer_config.devices, + strategy='ddp', num_sanity_val_steps=0) + trainer.validate(model, dataloader) diff --git a/backups/thirdparty/catk/configs/__init__.py b/backups/thirdparty/catk/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/backups/thirdparty/catk/configs/callbacks/default.yaml b/backups/thirdparty/catk/configs/callbacks/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c27325c56971cd5763e0c5e0c2dfb42dc4809f3b --- /dev/null +++ b/backups/thirdparty/catk/configs/callbacks/default.yaml @@ -0,0 +1,14 @@ +defaults: + - model_checkpoint + - model_summary + - learning_rate_monitor + - _self_ + +model_checkpoint: + dirpath: ${paths.output_dir}/checkpoints + filename: "epoch_{epoch:03d}" + save_last: link + auto_insert_metric_name: false + +model_summary: + max_depth: -1 diff --git a/backups/thirdparty/catk/configs/callbacks/learning_rate_monitor.yaml b/backups/thirdparty/catk/configs/callbacks/learning_rate_monitor.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ab7ca31850a84002a1372e018ecced5b2fecaa9a --- /dev/null +++ b/backups/thirdparty/catk/configs/callbacks/learning_rate_monitor.yaml @@ -0,0 +1,5 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.LearningRateMonitor.html + +learning_rate_monitor: + _target_: lightning.pytorch.callbacks.LearningRateMonitor + logging_interval: epoch diff --git a/backups/thirdparty/catk/configs/callbacks/model_checkpoint.yaml b/backups/thirdparty/catk/configs/callbacks/model_checkpoint.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b68587379cf1adb0186aef6ef85a31d3322b9836 --- /dev/null +++ b/backups/thirdparty/catk/configs/callbacks/model_checkpoint.yaml @@ -0,0 +1,17 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html + +model_checkpoint: + _target_: lightning.pytorch.callbacks.ModelCheckpoint + dirpath: null # directory to save the model file + filename: null # checkpoint filename + monitor: null # name of the logged metric which determines when model is improving + verbose: false # verbosity mode + save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt + save_top_k: 1 # save k best models (determined by above metric) + mode: "min" # "max" means higher metric value is better, can be also "min" + auto_insert_metric_name: true # when True, the checkpoints filenames will contain the metric name + save_weights_only: false # if True, then only the model’s weights will be saved + every_n_train_steps: null # number of training steps between checkpoints + train_time_interval: null # checkpoints are monitored at the specified time interval + every_n_epochs: 1 # number of epochs between checkpoints + save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation diff --git a/backups/thirdparty/catk/configs/callbacks/model_summary.yaml b/backups/thirdparty/catk/configs/callbacks/model_summary.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..b75981d8cd5d73f61088d80495dc540274bca3d1 --- /dev/null +++ b/backups/thirdparty/catk/configs/callbacks/model_summary.yaml @@ -0,0 +1,5 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html + +model_summary: + _target_: lightning.pytorch.callbacks.RichModelSummary + max_depth: 1 # the maximum depth of layer nesting that the summary will include diff --git a/backups/thirdparty/catk/configs/experiment/clsft.yaml b/backups/thirdparty/catk/configs/experiment/clsft.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f8e9101bbf9ec46c924aa628df4774b3fe72c79a --- /dev/null +++ b/backups/thirdparty/catk/configs/experiment/clsft.yaml @@ -0,0 +1,42 @@ +# @package _global_ + +defaults: + # - override /trainer: ddp + - override /model: smart + +model: + model_config: + lr: 5e-5 + lr_min_ratio: 0.05 + token_processor: + map_token_sampling: # open-loop + num_k: 1 # for k nearest neighbors + temp: 1.0 # uniform sampling + agent_token_sampling: # closed-loop + num_k: 1 # for k nearest neighbors + temp: 1.0 + training_rollout_sampling: + criterium: topk_prob_sampled_with_dist # {topk_dist_sampled_with_prob, topk_prob, topk_prob_sampled_with_dist} + num_k: 32 # for k nearest neighbors, set to -1 to turn-off closed-loop training + temp: 1e-5 # catk = topk_prob_sampled_with_dist with temp=1e-5 + training_loss: + use_gt_raw: true + gt_thresh_scale_length: -1 # {"veh": 4.8, "cyc": 2.0, "ped": 1.0} + label_smoothing: 0.0 + rollout_as_gt: false + finetune: true + +ckpt_path: BC_PRETRAINED_MODEL.ckpt + +trainer: + limit_train_batches: 1.0 + limit_val_batches: 50 + check_val_every_n_epoch: 1 + +data: + train_batch_size: 10 + val_batch_size: 10 + test_batch_size: 10 + num_workers: 10 + +action: finetune \ No newline at end of file diff --git a/backups/thirdparty/catk/configs/experiment/ego_gmm_clsft.yaml b/backups/thirdparty/catk/configs/experiment/ego_gmm_clsft.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e3c00093531867e66bf04c892917894285bf64e4 --- /dev/null +++ b/backups/thirdparty/catk/configs/experiment/ego_gmm_clsft.yaml @@ -0,0 +1,43 @@ +# @package _global_ + +defaults: + # - override /trainer: ddp + - override /model: ego_gmm + +model: + model_config: + lr: 1e-4 + lr_min_ratio: 0.05 + token_processor: + map_token_sampling: # open-loop + num_k: 1 # for k nearest neighbors + temp: 1.0 # uniform sampling + agent_token_sampling: # closed-loop + num_k: 1 # for k nearest neighbors + temp: 1.0 + validation_rollout_sampling: + criterium: topk_prob # {topk_prob, topk_prob_sampled_with_dist} + num_k: 3 # for k most likely + temp_mode: 1e-3 + temp_cov: 1e-3 + training_rollout_sampling: + criterium: topk_prob_sampled_with_dist # {topk_prob, topk_prob_sampled_with_dist} + num_k: 3 # for k nearest neighbors, set to -1 to turn-off closed-loop training + temp_mode: 1e-3 + temp_cov: 1e-3 + finetune: true + +ckpt_path: BC_PRETRAINED_MODEL.ckpt + +trainer: + limit_train_batches: 1.0 + limit_val_batches: 0.1 + check_val_every_n_epoch: 1 + +data: + train_batch_size: 10 + val_batch_size: 10 + test_batch_size: 10 + num_workers: 10 + +action: finetune diff --git a/backups/thirdparty/catk/configs/experiment/ego_gmm_local_val.yaml b/backups/thirdparty/catk/configs/experiment/ego_gmm_local_val.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b537a3b02677b12b0c9299d834f7277216a4ee50 --- /dev/null +++ b/backups/thirdparty/catk/configs/experiment/ego_gmm_local_val.yaml @@ -0,0 
+1,31 @@ +# @package _global_ + +defaults: + - override /model: ego_gmm + +model: + model_config: + n_vis_batch: 0 + n_vis_scenario: 0 + n_vis_rollout: 0 + n_batch_wosac_metric: 100 + val_open_loop: false + val_closed_loop: true + validation_rollout_sampling: + criterium: topk_prob # {topk_prob, topk_prob_sampled_with_dist} + num_k: 3 # for k most likely + temp_mode: 1e-3 + temp_cov: 1e-3 + +ckpt_path: YOUR_MODEL.ckpt + +trainer: + limit_train_batches: 1.0 + limit_val_batches: 60 + check_val_every_n_epoch: 1 + +data: + train_batch_size: 16 + val_batch_size: 16 + test_batch_size: 16 + num_workers: 8 \ No newline at end of file diff --git a/backups/thirdparty/catk/configs/experiment/ego_gmm_pre_bc.yaml b/backups/thirdparty/catk/configs/experiment/ego_gmm_pre_bc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..580b6d2579fc7bd5bb441641de7e39451bb28b14 --- /dev/null +++ b/backups/thirdparty/catk/configs/experiment/ego_gmm_pre_bc.yaml @@ -0,0 +1,41 @@ +# @package _global_ + +defaults: + # - override /trainer: ddp + - override /model: ego_gmm + +model: + model_config: + lr: 5e-4 + token_processor: + map_token_sampling: # open-loop + num_k: 1 # for k nearest neighbors + temp: 1.0 # uniform sampling + agent_token_sampling: # closed-loop + num_k: 1 # for k nearest neighbors + temp: 1.0 + validation_rollout_sampling: + criterium: topk_prob # {topk_prob, topk_prob_sampled_with_dist} + num_k: 3 # for k most likely + temp_mode: 1e-3 + temp_cov: 1e-3 + training_rollout_sampling: + criterium: topk_prob # {topk_prob, topk_prob_sampled_with_dist} + num_k: -1 # for k nearest neighbors, set to -1 to turn-off closed-loop training + temp_mode: 1e-3 + temp_cov: 1e-3 + +ckpt_path: null +# ckpt_path: CKPT_FOR_RESUME.ckpt # to resume training + +trainer: + limit_train_batches: 1.0 + limit_val_batches: 0.1 + check_val_every_n_epoch: 1 + max_epochs: 64 + +data: + train_batch_size: 10 + val_batch_size: 10 + test_batch_size: 10 + num_workers: 10 \ No newline at end of file diff --git a/backups/thirdparty/catk/configs/experiment/local_val.yaml b/backups/thirdparty/catk/configs/experiment/local_val.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b18eef6ec4ed0b462e54cd07d45ef20d82ead3b0 --- /dev/null +++ b/backups/thirdparty/catk/configs/experiment/local_val.yaml @@ -0,0 +1,33 @@ +# @package _global_ + +defaults: + - override /model: smart + +model: + model_config: + n_vis_batch: 0 + n_vis_scenario: 0 + n_vis_rollout: 0 + n_batch_wosac_metric: 100 + val_open_loop: false + val_closed_loop: true + validation_rollout_sampling: + criterium: topk_prob + num_k: 64 # for k most likely + temp: 1.0 + +ckpt_path: YOUR_MODEL.ckpt + +trainer: + limit_train_batches: 1.0 + limit_val_batches: 60 + limit_test_batches: 1.0 + check_val_every_n_epoch: 1 + +data: + train_batch_size: 16 + val_batch_size: 16 + test_batch_size: 16 + num_workers: 8 + +action: validate \ No newline at end of file diff --git a/backups/thirdparty/catk/configs/experiment/pre_bc.yaml b/backups/thirdparty/catk/configs/experiment/pre_bc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f4664bda3fcf3ce1aa40bb80c0d562594648484c --- /dev/null +++ b/backups/thirdparty/catk/configs/experiment/pre_bc.yaml @@ -0,0 +1,32 @@ +# @package _global_ + +defaults: + # - override /trainer: ddp + - override /model: smart + +model: + model_config: + lr: 5e-4 + lr_min_ratio: 1e-2 + token_processor: + map_token_sampling: # open-loop + num_k: 1 # for k nearest neighbors + temp: 1.0 # uniform sampling + 
agent_token_sampling: # closed-loop + num_k: 1 # for k nearest neighbors + temp: 1.0 + +ckpt_path: null +# ckpt_path: CKPT_FOR_RESUME.ckpt # to resume training + +trainer: + limit_train_batches: 1.0 + limit_val_batches: 0.1 + check_val_every_n_epoch: 1 + max_epochs: 64 + +data: + train_batch_size: 10 + val_batch_size: 10 + test_batch_size: 10 + num_workers: 10 \ No newline at end of file diff --git a/backups/thirdparty/catk/configs/experiment/wosac_sub.yaml b/backups/thirdparty/catk/configs/experiment/wosac_sub.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0e429e2dfe7352e919c59ea33a72076fb759c29e --- /dev/null +++ b/backups/thirdparty/catk/configs/experiment/wosac_sub.yaml @@ -0,0 +1,42 @@ +# @package _global_ + +defaults: + # - override /trainer: ddp + - override /model: smart + +model: + model_config: + n_vis_batch: 0 + n_vis_scenario: 0 + n_vis_rollout: 0 + n_batch_wosac_metric: 0 + val_open_loop: false + val_closed_loop: true + validation_rollout_sampling: + criterium: topk_prob + num_k: 64 # for k most likely + temp: 1.0 + wosac_submission: + is_active: true + method_name: "SMART-tiny-CLSFT" + authors: [Anonymous] + affiliation: YOUR_AFFILIATION + description: YOUR_DESCRIPTION + method_link: YOUR_METHOD_LINK + account_name: YOUR_ACCOUNT_NAME + +ckpt_path: YOUR_MODEL.ckpt + +trainer: + limit_train_batches: 1.0 + limit_val_batches: 1.0 + limit_test_batches: 1.0 + check_val_every_n_epoch: 1 + +data: + train_batch_size: 16 + val_batch_size: 16 + test_batch_size: 16 + num_workers: 16 + shuffle: false + pin_memory: false \ No newline at end of file diff --git a/backups/thirdparty/catk/configs/hydra/default.yaml b/backups/thirdparty/catk/configs/hydra/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a61e9b3a39a2fe08adf8c9e8ff8dc2abb9a4ef2f --- /dev/null +++ b/backups/thirdparty/catk/configs/hydra/default.yaml @@ -0,0 +1,19 @@ +# https://hydra.cc/docs/configure_hydra/intro/ + +# enable color logging +defaults: + - override hydra_logging: colorlog + - override job_logging: colorlog + +# output directory, generated dynamically on each run +run: + dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} +sweep: + dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} + subdir: ${hydra.job.num} + +job_logging: + handlers: + file: + # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242 + filename: ${hydra.runtime.output_dir}/${task_name}.log diff --git a/backups/thirdparty/catk/configs/logger/wandb.yaml b/backups/thirdparty/catk/configs/logger/wandb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a4d4fa1ebf1cf1e3af7cadf28d5310782e87c9c0 --- /dev/null +++ b/backups/thirdparty/catk/configs/logger/wandb.yaml @@ -0,0 +1,17 @@ +# https://wandb.ai + +wandb: + _target_: lightning.pytorch.loggers.wandb.WandbLogger + name: ${task_name} + save_dir: "${paths.output_dir}" + offline: False + id: null # pass correct id to resume experiment! 
+ anonymous: null # enable anonymous logging + project: clsft-catk + log_model: False # upload lightning ckpts + prefix: "" # a string to put at the beginning of metric keys + entity: YOUR_ENTITY + group: "" + tags: [] + job_type: "" + resume: allow diff --git a/backups/thirdparty/catk/configs/model/ego_gmm.yaml b/backups/thirdparty/catk/configs/model/ego_gmm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c9f214d2e49ce088339640b607f8646b3633cdc8 --- /dev/null +++ b/backups/thirdparty/catk/configs/model/ego_gmm.yaml @@ -0,0 +1,55 @@ +_target_: src.smart.model.ego_gmm_smart.EgoGMMSMART +model_config: + lr: 0.0005 + lr_warmup_steps: 0 + lr_total_steps: ${trainer.max_epochs} + lr_min_ratio: 0.05 + n_rollout_closed_val: 32 + n_batch_wosac_metric: 10 + n_vis_batch: 2 + n_vis_scenario: 5 + n_vis_rollout: 5 + val_closed_loop: true + token_processor: + map_token_file: "map_traj_token5.pkl" + agent_token_file: "cluster_frame_5_2048_remove_duplicate.pkl" + map_token_sampling: # open-loop + num_k: 1 # for k nearest neighbors + temp: 1.0 # uniform sampling + agent_token_sampling: # closed-loop + num_k: 1 # for k nearest neighbors + temp: 1.0 + validation_rollout_sampling: + criterium: topk_prob # {topk_prob, topk_prob_sampled_with_dist} + num_k: 3 # for k most likely + temp_mode: 1e-3 + temp_cov: 1e-3 + training_rollout_sampling: + criterium: topk_prob # {topk_prob, topk_prob_sampled_with_dist} + num_k: -1 # for k nearest neighbors, set to -1 to turn-off closed-loop training + temp_mode: 1e-3 + temp_cov: 1e-3 + decoder: + hidden_dim: 128 + num_freq_bands: 64 + num_heads: 4 + head_dim: 8 + dropout: 0.1 + hist_drop_prob: 0.1 + num_map_layers: 2 + num_agent_layers: 4 + pl2pl_radius: 10 + pl2a_radius: 30 + a2a_radius: 60 + time_span: 30 + num_historical_steps: 11 + num_future_steps: 80 + k_ego_gmm: 16 + cov_ego_gmm: [1.0, 0.1] + cov_learnable: false + training_loss: + use_gt_raw: true + gt_thresh_scale_length: -1.0 # {"veh": 4.8, "cyc": 2.0, "ped": 1.0} + hard_assignment: false + rollout_as_gt: false + finetune: false \ No newline at end of file diff --git a/backups/thirdparty/catk/configs/model/smart.yaml b/backups/thirdparty/catk/configs/model/smart.yaml new file mode 100644 index 0000000000000000000000000000000000000000..61258eabd3c596e55b6b5017ee37c65c3f9838b6 --- /dev/null +++ b/backups/thirdparty/catk/configs/model/smart.yaml @@ -0,0 +1,59 @@ +_target_: src.smart.model.smart.SMART +model_config: + lr: 0.0005 + lr_warmup_steps: 0 + lr_total_steps: ${trainer.max_epochs} + lr_min_ratio: 0.05 + n_rollout_closed_val: 32 + n_batch_wosac_metric: 10 + n_vis_batch: 2 + n_vis_scenario: 5 + n_vis_rollout: 5 + val_open_loop: true + val_closed_loop: true + token_processor: + map_token_file: "map_traj_token5.pkl" + agent_token_file: "agent_vocab_555_s2.pkl" + map_token_sampling: # open-loop + num_k: 1 # for k nearest neighbors + temp: 1.0 # uniform sampling + agent_token_sampling: # closed-loop + num_k: 1 # for k nearest neighbors + temp: 1.0 + validation_rollout_sampling: + criterium: topk_prob + num_k: 5 # for k most likely + temp: 1.0 + training_rollout_sampling: + criterium: topk_prob # {topk_dist_sampled_with_prob, topk_prob, topk_prob_sampled_with_dist} + num_k: -1 # for k nearest neighbors, set to -1 to turn-off closed-loop training + temp: 1.0 + decoder: + hidden_dim: 128 + num_freq_bands: 64 + num_heads: 8 + head_dim: 16 + dropout: 0.1 + hist_drop_prob: 0.1 + num_map_layers: 3 + num_agent_layers: 6 + pl2pl_radius: 10 + pl2a_radius: 30 + a2a_radius: 60 + time_span: 30 + 
num_historical_steps: 11 + num_future_steps: 80 + wosac_submission: + is_active: false + method_name: "SMART-tiny-CLSFT" + authors: [Anonymous] + affiliation: YOUR_AFFILIATION + description: YOUR_DESCRIPTION + method_link: YOUR_METHOD_LINK + account_name: YOUR_ACCOUNT_NAME + training_loss: + use_gt_raw: true + gt_thresh_scale_length: -1.0 # {"veh": 4.8, "cyc": 2.0, "ped": 1.0} + label_smoothing: 0.1 + rollout_as_gt: false + finetune: false \ No newline at end of file diff --git a/backups/thirdparty/catk/configs/model/smart_mini_3M.yaml b/backups/thirdparty/catk/configs/model/smart_mini_3M.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cc9203113cab0df7142193fb0aa00a4c2ca45ad8 --- /dev/null +++ b/backups/thirdparty/catk/configs/model/smart_mini_3M.yaml @@ -0,0 +1,9 @@ +defaults: + - smart + +model_config: + decoder: + num_heads: 4 + head_dim: 8 + num_map_layers: 2 + num_agent_layers: 4 \ No newline at end of file diff --git a/backups/thirdparty/catk/configs/paths/default.yaml b/backups/thirdparty/catk/configs/paths/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..57ed2eba492622160be40bb3917b65d6eee36dec --- /dev/null +++ b/backups/thirdparty/catk/configs/paths/default.yaml @@ -0,0 +1,16 @@ +# path to root directory +# this requires PROJECT_ROOT environment variable to exist +# you can replace it with "." if you want the root to be the current working directory +# root_dir: ${oc.env:PROJECT_ROOT} +root_dir: ${hydra:runtime.cwd} + +# path to logging directory +log_dir: ${paths.root_dir}/logs/ + +# path to output directory, created dynamically by hydra +# path generation pattern is specified in `configs/hydra/default.yaml` +# use it to store all files generated during the run, like ckpts and metrics +output_dir: ${hydra:runtime.output_dir} + +# path to data directory +cache_root: /scratch/cache/SMART diff --git a/backups/thirdparty/catk/configs/run.yaml b/backups/thirdparty/catk/configs/run.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d5e0525584304dba298187fef1f604de25e0faba --- /dev/null +++ b/backups/thirdparty/catk/configs/run.yaml @@ -0,0 +1,29 @@ +# @package _global_ + +# specify here default configuration +# order of defaults determines the order in which configs override each other +defaults: + - _self_ + - data: waymo + - model: smart + - callbacks: default + - logger: wandb + - trainer: default + - paths: default + - hydra: default + + # experiment configs allow for version control of specific hyperparameters + # e.g. 
best hyperparameters for given model and datamodule + - experiment: pre_bc + +action: fit # fit, finetune, validate, test + +# task name, determines output directory path +task_name: "debug_open_source" + +# simply provide checkpoint path to resume training +ckpt_path: null +train_log_dir: null + +# seed for random number generators in pytorch, numpy and python.random +seed: 817 diff --git a/backups/thirdparty/catk/configs/trainer/ddp.yaml b/backups/thirdparty/catk/configs/trainer/ddp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4af13b2b44e4bb53f77093c64064049088ac8543 --- /dev/null +++ b/backups/thirdparty/catk/configs/trainer/ddp.yaml @@ -0,0 +1,13 @@ +defaults: + - default + +strategy: + _target_: lightning.pytorch.strategies.DDPStrategy + find_unused_parameters: false + gradient_as_bucket_view: true + +accelerator: gpu +devices: -1 +num_nodes: 1 +sync_batchnorm: true +log_every_n_steps: 20 diff --git a/backups/thirdparty/catk/configs/trainer/default.yaml b/backups/thirdparty/catk/configs/trainer/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..76f9d0d23a3c91ada6b6079dfb2cf1cb4f2f3a3a --- /dev/null +++ b/backups/thirdparty/catk/configs/trainer/default.yaml @@ -0,0 +1,27 @@ +_target_: lightning.pytorch.trainer.Trainer + +default_root_dir: ${paths.output_dir} + +limit_train_batches: 5 +limit_val_batches: 5 +limit_test_batches: 1.0 + +# max_steps: 25000 +# val_check_interval: 0.5 + +max_epochs: 32 + +accelerator: gpu +devices: -1 + +precision: 32-true +check_val_every_n_epoch: 1 + +# set True to ensure deterministic results +# makes training slower but gives more reproducibility than just setting seeds +deterministic: false +gradient_clip_val: 0.5 +num_sanity_val_steps: 0 +accumulate_grad_batches: 1 +log_every_n_steps: 1 +strategy: auto diff --git a/backups/thirdparty/catk/scripts/cache_womd.sh b/backups/thirdparty/catk/scripts/cache_womd.sh new file mode 100644 index 0000000000000000000000000000000000000000..ddcc9015df1d9212b0bf35447187e9ce193882be --- /dev/null +++ b/backups/thirdparty/catk/scripts/cache_womd.sh @@ -0,0 +1,15 @@ +#!/bin/sh +export LOGLEVEL=INFO +export HYDRA_FULL_ERROR=1 +export TF_CPP_MIN_LOG_LEVEL=2 + +DATA_SPLIT=validation # training, validation, testing + +source ~/miniconda3/etc/profile.d/conda.sh +conda activate catk +python \ + -m src.data_preprocess \ + --split $DATA_SPLIT \ + --num_workers 12 \ + --input_dir /scratch/data/womd/uncompressed/scenario \ + --output_dir /scratch/cache/SMART \ No newline at end of file diff --git a/backups/thirdparty/catk/scripts/local_val.sh b/backups/thirdparty/catk/scripts/local_val.sh new file mode 100644 index 0000000000000000000000000000000000000000..f8eba7c67519f19d0ec98aa30b70eba95562ada4 --- /dev/null +++ b/backups/thirdparty/catk/scripts/local_val.sh @@ -0,0 +1,23 @@ +#!/bin/sh +export LOGLEVEL=INFO +export HYDRA_FULL_ERROR=1 +export TF_CPP_MIN_LOG_LEVEL=2 + +MY_EXPERIMENT="local_val" +VAL_K=48 +MY_TASK_NAME=$MY_EXPERIMENT-K$VAL_K"-debug" + +source ~/miniconda3/etc/profile.d/conda.sh +conda activate catk +# local_val runs on single GPU +python \ + -m src.run \ + experiment=$MY_EXPERIMENT \ + trainer=default \ + model.model_config.validation_rollout_sampling.num_k=$VAL_K \ + trainer.accelerator=gpu \ + trainer.devices=1 \ + trainer.strategy=auto \ + task_name=$MY_TASK_NAME + +echo "bash local_val.sh done!"
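For orientation: the override `model.model_config.validation_rollout_sampling.num_k=$VAL_K` above restricts sampling to the K most likely motion tokens under the `topk_prob` criterium from the model configs. A minimal sketch of what such a sampler computes, with hypothetical names (the repo's actual sampler lives in the model code, not in this script):

import torch

def sample_topk_prob(logits: torch.Tensor, num_k: int, temp: float = 1.0) -> torch.Tensor:
    # keep the num_k most likely tokens, renormalize with temperature, then sample
    topk_logits, topk_idx = logits.topk(num_k, dim=-1)        # [..., num_k]
    probs = torch.softmax(topk_logits / temp, dim=-1)         # renormalized over the top-k
    choice = torch.multinomial(probs.flatten(0, -2), num_samples=1)
    choice = choice.view(*probs.shape[:-1], 1)
    return topk_idx.gather(-1, choice).squeeze(-1)            # sampled token indices

# e.g. logits over a 2048-token vocabulary for a batch of agents:
# sample_topk_prob(torch.randn(16, 2048), num_k=48)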
\ No newline at end of file diff --git a/backups/thirdparty/catk/scripts/train.sh b/backups/thirdparty/catk/scripts/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..21ac61ec27479e8b678e32f6adaee818791d41ca --- /dev/null +++ b/backups/thirdparty/catk/scripts/train.sh @@ -0,0 +1,29 @@ +#!/bin/sh +export LOGLEVEL=INFO +export HYDRA_FULL_ERROR=1 +export TF_CPP_MIN_LOG_LEVEL=2 +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +MY_EXPERIMENT="pre_bc" +MY_TASK_NAME=$MY_EXPERIMENT"-debug" + +source ~/miniconda3/etc/profile.d/conda.sh +conda activate catk +torchrun \ + -m src.run \ + experiment=$MY_EXPERIMENT \ + task_name=$MY_TASK_NAME + +# ! below is for training with ddp +# torchrun \ +# --rdzv_id $SLURM_JOB_ID \ +# --rdzv_backend c10d \ +# --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ +# --nnodes $NUM_NODES \ +# --nproc_per_node gpu \ +# -m src.run \ +# experiment=$MY_EXPERIMENT \ +# trainer=ddp \ +# task_name=$MY_TASK_NAME + +echo "bash train.sh done!" \ No newline at end of file diff --git a/backups/thirdparty/catk/scripts/wosac_sub.sh b/backups/thirdparty/catk/scripts/wosac_sub.sh new file mode 100644 index 0000000000000000000000000000000000000000..fb4c2694703e6446ce406cfb72695c03941506f9 --- /dev/null +++ b/backups/thirdparty/catk/scripts/wosac_sub.sh @@ -0,0 +1,31 @@ +#!/bin/sh +export LOGLEVEL=INFO +export HYDRA_FULL_ERROR=1 +export TF_CPP_MIN_LOG_LEVEL=2 + +ACTION=validate # validate, test +MY_EXPERIMENT="wosac_sub" +MY_TASK_NAME=$MY_EXPERIMENT-$ACTION"-debug" + +source ~/miniconda3/etc/profile.d/conda.sh +conda activate catk +python \ + -m src.run \ + experiment=$MY_EXPERIMENT \ + action=$ACTION \ + task_name=$MY_TASK_NAME + +# below is for training with ddp +# torchrun \ +# --rdzv_id $SLURM_JOB_ID \ +# --rdzv_backend c10d \ +# --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ +# --nnodes $NUM_NODES \ +# --nproc_per_node gpu \ +# -m src.run \ +# experiment=$MY_EXPERIMENT \ +# trainer=ddp \ +# action=$ACTION \ +# task_name=$MY_TASK_NAME + +echo bash $ACTION done! \ No newline at end of file diff --git a/backups/thirdparty/catk/src/__init__.py b/backups/thirdparty/catk/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/backups/thirdparty/catk/src/data_preprocess.py b/backups/thirdparty/catk/src/data_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..436979012e3e057a39acad7f13b325797ffff3fb --- /dev/null +++ b/backups/thirdparty/catk/src/data_preprocess.py @@ -0,0 +1,521 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
+ +import multiprocessing +import pickle +from argparse import ArgumentParser +from functools import partial +from pathlib import Path +from typing import Any, Dict, List, Optional + +import numpy as np +import pandas as pd +import tensorflow as tf +import torch +from scipy.interpolate import interp1d +from tqdm import tqdm +from waymo_open_dataset.protos import scenario_pb2 + +from src.smart.utils.geometry import wrap_angle +from src.smart.utils.preprocess import get_polylines_from_polygon, preprocess_map + +# agent_types = {0: "vehicle", 1: "pedestrian", 2: "cyclist"} +# agent_roles = {0: "ego_vehicle", 1: "interest", 2: "predict"} +# polyline_type = { +# # for lane +# "TYPE_FREEWAY": 0, +# "TYPE_SURFACE_STREET": 1, +# "TYPE_STOP_SIGN": 2, +# "TYPE_BIKE_LANE": 3, +# # for roadedge +# "TYPE_ROAD_EDGE_BOUNDARY": 4, +# "TYPE_ROAD_EDGE_MEDIAN": 5, +# # for roadline +# "BROKEN": 6, +# "SOLID_SINGLE": 7, +# "DOUBLE": 8, +# # for crosswalk, speed bump and drive way +# "TYPE_CROSSWALK": 9, +# } +_polygon_types = ["lane", "road_edge", "road_line", "crosswalk"] +_polygon_light_type = [ + "NO_LANE_STATE", + "LANE_STATE_UNKNOWN", + "LANE_STATE_STOP", + "LANE_STATE_GO", + "LANE_STATE_CAUTION", +] + + +def get_agent_features( + track_infos: Dict[str, np.ndarray], split, num_historical_steps, num_steps +) -> Dict[str, Any]: + """ + track_infos: + object_id (100,) int64 + object_type (100,) uint8 + states (100, 91, 9) float32 + valid (100, 91) bool + role (100, 3) bool + """ + + idx_agents_to_add = [] + for i in range(len(track_infos["object_id"])): + add_agent = track_infos["valid"][i, num_historical_steps - 1] + + if add_agent: + idx_agents_to_add.append(i) + + num_agents = len(idx_agents_to_add) + out_dict = { + "num_nodes": num_agents, + "valid_mask": torch.zeros([num_agents, num_steps], dtype=torch.bool), + "role": torch.zeros([num_agents, 3], dtype=torch.bool), + "id": torch.zeros(num_agents, dtype=torch.int64) - 1, + "type": torch.zeros(num_agents, dtype=torch.uint8), + "position": torch.zeros([num_agents, num_steps, 3], dtype=torch.float32), + "heading": torch.zeros([num_agents, num_steps], dtype=torch.float32), + "velocity": torch.zeros([num_agents, num_steps, 2], dtype=torch.float32), + "shape": torch.zeros([num_agents, 3], dtype=torch.float32), + } + + for i, idx in enumerate(idx_agents_to_add): + + out_dict["role"][i] = torch.from_numpy(track_infos["role"][idx]) + out_dict["id"][i] = track_infos["object_id"][idx] + out_dict["type"][i] = track_infos["object_type"][idx] + + valid = track_infos["valid"][idx] # [n_step] + states = track_infos["states"][idx] + + object_shape = states[:, 3:6] # [n_step, 3], length, width, height + object_shape = object_shape[valid].mean(axis=0) # [3] + out_dict["shape"][i] = torch.from_numpy(object_shape) + + valid_steps = np.where(valid)[0] + position = states[:, :3] # [n_step, dim], x, y, z + velocity = states[:, 7:9] # [n_step, 2], vx, vy + heading = states[:, 6] # [n_step], heading + if valid.sum() > 1: + t_start, t_end = valid_steps[0], valid_steps[-1] + f_pos = interp1d(valid_steps, position[valid], axis=0) + f_vel = interp1d(valid_steps, velocity[valid], axis=0) + f_yaw = interp1d(valid_steps, np.unwrap(heading[valid], axis=0), axis=0) + t_in = np.arange(t_start, t_end + 1) + out_dict["valid_mask"][i, t_start : t_end + 1] = True + out_dict["position"][i, t_start : t_end + 1] = torch.from_numpy(f_pos(t_in)) + out_dict["velocity"][i, t_start : t_end + 1] = torch.from_numpy(f_vel(t_in)) + out_dict["heading"][i, t_start : t_end + 1] = 
torch.from_numpy(f_yaw(t_in)) + else: + t = valid_steps[0] + out_dict["valid_mask"][i, t] = True + out_dict["position"][i, t] = torch.from_numpy(position[t]) + out_dict["velocity"][i, t] = torch.from_numpy(velocity[t]) + out_dict["heading"][i, t] = torch.tensor(heading[t]) + + return out_dict + + +def get_map_features(map_infos, tf_current_light, dim=2): + polygon_ids = [x["id"] for k in _polygon_types for x in map_infos[k]] + num_polygons = len(polygon_ids) + + # initialization + polygon_type = torch.zeros(num_polygons, dtype=torch.uint8) + polygon_light_type = torch.zeros(num_polygons, dtype=torch.uint8) + point_position: List[Optional[torch.Tensor]] = [None] * num_polygons + # point_orientation: List[Optional[torch.Tensor]] = [None] * num_polygons + point_type: List[Optional[torch.Tensor]] = [None] * num_polygons + + for _key in _polygon_types: + for _seg in map_infos[_key]: + _idx = polygon_ids.index(_seg["id"]) + centerline = map_infos["all_polylines"][ + _seg["polyline_index"][0] : _seg["polyline_index"][1] + ] + centerline = torch.from_numpy(centerline).float() + polygon_type[_idx] = _polygon_types.index(_key) + + point_position[_idx] = centerline[:-1, :dim] + center_vectors = centerline[1:] - centerline[:-1] + # point_orientation[_idx] = torch.cat( + # [torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0 + # ) + point_type[_idx] = torch.full( + (len(center_vectors),), _seg["type"], dtype=torch.uint8 + ) + + if _key == "lane": + res = tf_current_light[tf_current_light["lane_id"] == _seg["id"]] + if len(res) != 0: + polygon_light_type[_idx] = _polygon_light_type.index( + res["state"].item() + ) + + num_points = torch.tensor( + [point.size(0) for point in point_position], dtype=torch.long + ) + point_to_polygon_edge_index = torch.stack( + [ + torch.arange(num_points.sum(), dtype=torch.long), + torch.arange(num_polygons, dtype=torch.long).repeat_interleave(num_points), + ], + dim=0, + ) + + map_data = { + "map_polygon": {}, + "map_point": {}, + ("map_point", "to", "map_polygon"): {}, + } + map_data["map_polygon"]["num_nodes"] = num_polygons + map_data["map_polygon"]["type"] = polygon_type + map_data["map_polygon"]["light_type"] = polygon_light_type + if len(num_points) == 0: + map_data["map_point"]["num_nodes"] = 0 + map_data["map_point"]["position"] = torch.tensor([], dtype=torch.float) + # map_data["map_point"]["orientation"] = torch.tensor([], dtype=torch.float) + map_data["map_point"]["type"] = torch.tensor([], dtype=torch.uint8) + else: + map_data["map_point"]["num_nodes"] = num_points.sum().item() + map_data["map_point"]["position"] = torch.cat(point_position, dim=0) + # map_data["map_point"]["orientation"] = wrap_angle( + # torch.cat(point_orientation, dim=0) + # ) + map_data["map_point"]["type"] = torch.cat(point_type, dim=0) + map_data["map_point", "to", "map_polygon"][ + "edge_index" + ] = point_to_polygon_edge_index + return map_data + + +def process_dynamic_map(dynamic_map_infos): + lane_ids = dynamic_map_infos["lane_id"] + tf_lights = [] + for t in range(len(lane_ids)): + lane_id = lane_ids[t] + time = np.ones_like(lane_id) * t + state = dynamic_map_infos["state"][t] + tf_light = np.concatenate([lane_id, time, state], axis=0) + tf_lights.append(tf_light) + tf_lights = np.concatenate(tf_lights, axis=1).transpose(1, 0) + tf_lights = pd.DataFrame(data=tf_lights, columns=["lane_id", "time_step", "state"]) + tf_lights["time_step"] = tf_lights["time_step"].astype("int") + tf_lights["lane_id"] = tf_lights["lane_id"].astype("int") + tf_lights["state"] = 
tf_lights["state"].astype("str") + tf_lights.loc[tf_lights["state"].str.contains("STOP"), ["state"]] = ( + "LANE_STATE_STOP" + ) + tf_lights.loc[tf_lights["state"].str.contains("GO"), ["state"]] = "LANE_STATE_GO" + tf_lights.loc[tf_lights["state"].str.contains("CAUTION"), ["state"]] = ( + "LANE_STATE_CAUTION" + ) + tf_lights.loc[tf_lights["state"].str.contains("UNKNOWN"), ["state"]] = ( + "LANE_STATE_UNKNOWN" + ) + return tf_lights + + +def decode_tracks_from_proto(scenario): + sdc_track_index = scenario.sdc_track_index + track_index_predict = [i.track_index for i in scenario.tracks_to_predict] + object_id_interest = [i for i in scenario.objects_of_interest] + + track_infos = { + "object_id": [], + "object_type": [], + "states": [], + "valid": [], + "role": [], + } + for i, cur_data in enumerate(scenario.tracks): # number of objects + + step_state = [] + step_valid = [] + for s in cur_data.states: + step_state.append( + [ + s.center_x, + s.center_y, + s.center_z, + s.length, + s.width, + s.height, + s.heading, + s.velocity_x, + s.velocity_y, + ] + ) + step_valid.append(s.valid) + # This angle is normalized to [-pi, pi). The velocity vector in m/s + + track_infos["object_id"].append(cur_data.id) + track_infos["object_type"].append(cur_data.object_type - 1) + track_infos["states"].append(np.array(step_state, dtype=np.float32)) + track_infos["valid"].append(np.array(step_valid)) + + track_infos["role"].append([False, False, False]) + if i in track_index_predict: + track_infos["role"][-1][2] = True # predict=2 + if cur_data.id in object_id_interest: + track_infos["role"][-1][1] = True # interest=1 + if i == sdc_track_index: # ego_vehicle=0 + track_infos["role"][-1][0] = True + + track_infos["states"] = np.array(track_infos["states"], dtype=np.float32) + track_infos["valid"] = np.array(track_infos["valid"], dtype=bool) + track_infos["role"] = np.array(track_infos["role"], dtype=bool) + track_infos["object_id"] = np.array(track_infos["object_id"], dtype=np.int64) + track_infos["object_type"] = np.array(track_infos["object_type"], dtype=np.uint8) + return track_infos + + +def decode_map_features_from_proto(map_features): + map_infos = {"lane": [], "road_edge": [], "road_line": [], "crosswalk": []} + polylines = [] + point_cnt = 0 + for mf in map_features: + feature_data_type = mf.WhichOneof("feature_data") + # pip install waymo-open-dataset-tf-2-6-0==1.4.9, not updated, should be driveway + if feature_data_type is None: + continue + + feature = getattr(mf, feature_data_type) + if feature_data_type == "lane": + if len(feature.polyline) > 1: + cur_info = {"id": mf.id} + if feature.type == 0: # UNDEFINED + cur_info["type"] = 1 + elif feature.type == 1: # FREEWAY + cur_info["type"] = 0 + elif feature.type == 2: # SURFACE_STREET + cur_info["type"] = 1 + elif feature.type == 3: # BIKE_LANE + cur_info["type"] = 3 + + cur_polyline = np.stack( + [ + np.array([p.x, p.y, p.z, cur_info["type"], cur_info["id"]]) + for p in feature.polyline + ], + axis=0, + ) + + cur_info["polyline_index"] = (point_cnt, point_cnt + len(cur_polyline)) + map_infos["lane"].append(cur_info) + polylines.append(cur_polyline) + point_cnt += len(cur_polyline) + + elif feature_data_type == "road_edge": + if len(feature.polyline) > 1: + cur_info = {"id": mf.id} + # assert feature.type > 0 + cur_info["type"] = feature.type + 3 + + cur_polyline = np.stack( + [ + np.array([p.x, p.y, p.z, cur_info["type"], cur_info["id"]]) + for p in feature.polyline + ], + axis=0, + ) + + cur_info["polyline_index"] = (point_cnt, point_cnt + 
len(cur_polyline)) + map_infos["road_edge"].append(cur_info) + polylines.append(cur_polyline) + point_cnt += len(cur_polyline) + + elif feature_data_type == "road_line": + if len(feature.polyline) > 1: + cur_info = {"id": mf.id} + # there is no UNKNOWN = 0 + # BROKEN_SINGLE_WHITE = 1 + # SOLID_SINGLE_WHITE = 2 + # SOLID_DOUBLE_WHITE = 3 + # BROKEN_SINGLE_YELLOW = 4 + # BROKEN_DOUBLE_YELLOW = 5 + # SOLID_SINGLE_YELLOW = 6 + # SOLID_DOUBLE_YELLOW = 7 + # PASSING_DOUBLE_YELLOW = 8 + # assert feature.type > 0 # no UNKNOWN = 0 + if feature.type in [1, 4, 5]: + cur_info["type"] = 6 # BROKEN + elif feature.type in [2, 6]: + cur_info["type"] = 7 # SOLID_SINGLE + else: + cur_info["type"] = 8 # DOUBLE + + cur_polyline = np.stack( + [ + np.array([p.x, p.y, p.z, cur_info["type"], cur_info["id"]]) + for p in feature.polyline + ], + axis=0, + ) + + cur_info["polyline_index"] = (point_cnt, point_cnt + len(cur_polyline)) + map_infos["road_line"].append(cur_info) + polylines.append(cur_polyline) + point_cnt += len(cur_polyline) + + elif feature_data_type in ["speed_bump", "driveway", "crosswalk"]: + xyz = np.array([[p.x, p.y, p.z] for p in feature.polygon]) + polygon_idx = np.linspace(0, xyz.shape[0], 4, endpoint=False, dtype=int) + pl_polygon = get_polylines_from_polygon(xyz[polygon_idx]) + cur_info = {"id": mf.id, "type": 9} + + cur_polyline = np.stack( + [ + np.array([p[0], p[1], p[2], cur_info["type"], cur_info["id"]]) + for p in pl_polygon + ], + axis=0, + ) + + cur_info["polyline_index"] = (point_cnt, point_cnt + len(cur_polyline)) + map_infos["crosswalk"].append(cur_info) + polylines.append(cur_polyline) + point_cnt += len(cur_polyline) + + for mf in map_features: + feature_data_type = mf.WhichOneof("feature_data") + if feature_data_type == "stop_sign": + feature = mf.stop_sign + for l_id in feature.lane: + # override FREEWAY/SURFACE_STREET with stop sign lane + # BIKE_LANE remains unchanged + is_found = False + for _i in range(len(map_infos["lane"])): + if map_infos["lane"][_i]["id"] == l_id: + is_found = True + if map_infos["lane"][_i]["type"] < 2: + map_infos["lane"][_i]["type"] = 2 + # not necessary found, some stop sign lanes are for lane with length 1 + # assert is_found + + try: + polylines = np.concatenate(polylines, axis=0).astype(np.float32) + except: + polylines = np.zeros((0, 8), dtype=np.float32) + print("Empty polylines.") + map_infos["all_polylines"] = polylines + return map_infos + + +def decode_dynamic_map_states_from_proto(dynamic_map_states): + signal_state = { + 0: "LANE_STATE_UNKNOWN", + # States for traffic signals with arrows. + 1: "LANE_STATE_ARROW_STOP", + 2: "LANE_STATE_ARROW_CAUTION", + 3: "LANE_STATE_ARROW_GO", + # Standard round traffic signals. + 4: "LANE_STATE_STOP", + 5: "LANE_STATE_CAUTION", + 6: "LANE_STATE_GO", + # Flashing light signals. 
+ 7: "LANE_STATE_FLASHING_STOP", + 8: "LANE_STATE_FLASHING_CAUTION", + } + + dynamic_map_infos = {"lane_id": [], "state": []} + for cur_data in dynamic_map_states: # (num_timestamp) + lane_id, state = [], [] + for cur_signal in cur_data.lane_states: # (num_observed_signals) + lane_id.append(cur_signal.lane) + state.append(signal_state[cur_signal.state]) + + dynamic_map_infos["lane_id"].append(np.array([lane_id])) + dynamic_map_infos["state"].append(np.array([state])) + + return dynamic_map_infos + + +def wm2argo(file_path, split, output_dir, output_dir_tfrecords_splitted): + dataset = tf.data.TFRecordDataset( + file_path, compression_type="", num_parallel_reads=3 + ) + for tf_data in dataset: + tf_data = tf_data.numpy() + scenario = scenario_pb2.Scenario() + scenario.ParseFromString(bytes(tf_data)) + + track_infos = decode_tracks_from_proto(scenario) + map_infos = decode_map_features_from_proto(scenario.map_features) + dynamic_map_infos = decode_dynamic_map_states_from_proto( + scenario.dynamic_map_states + ) + + current_time_index = scenario.current_time_index + scenario_id = scenario.scenario_id + tf_lights = process_dynamic_map(dynamic_map_infos) + tf_current_light = tf_lights.loc[tf_lights["time_step"] == current_time_index] + map_data = get_map_features(map_infos, tf_current_light) + + data = preprocess_map(map_data) + data["agent"] = get_agent_features( + track_infos, + split=split, + num_historical_steps=current_time_index + 1, + num_steps=91, + ) + + data["scenario_id"] = scenario_id + with open(output_dir / f"{scenario_id}.pkl", "wb+") as f: + pickle.dump(data, f) + + if output_dir_tfrecords_splitted is not None: + file_name = output_dir_tfrecords_splitted / f"{scenario_id}.tfrecords" + with tf.io.TFRecordWriter(file_name.as_posix()) as file_writer: + file_writer.write(tf_data) + + +def batch_process9s_transformer(input_dir, output_dir, split, num_workers): + output_dir = Path(output_dir) + output_dir_tfrecords_splitted = None + if split == "validation": + output_dir_tfrecords_splitted = output_dir / "validation_tfrecords_splitted" + output_dir_tfrecords_splitted.mkdir(exist_ok=True, parents=True) + output_dir = output_dir / split + output_dir.mkdir(exist_ok=True, parents=True) + + input_dir = Path(input_dir) / split + packages = sorted([p.as_posix() for p in input_dir.glob("*")]) + func = partial( + wm2argo, + split=split, + output_dir=output_dir, + output_dir_tfrecords_splitted=output_dir_tfrecords_splitted, + ) + + with multiprocessing.Pool(num_workers) as p: + r = list(tqdm(p.imap_unordered(func, packages), total=len(packages))) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument( + "--input_dir", + type=str, + default="/root/workspace/data/womd/uncompressed/scenario", + ) + parser.add_argument( + "--output_dir", type=str, default="/root/workspace/data/SMART_new" + ) + parser.add_argument("--split", type=str, default="validation") + parser.add_argument("--num_workers", type=int, default=2) + args = parser.parse_args() + + batch_process9s_transformer( + args.input_dir, args.output_dir, args.split, num_workers=args.num_workers + ) diff --git a/backups/thirdparty/catk/src/run.py b/backups/thirdparty/catk/src/run.py new file mode 100644 index 0000000000000000000000000000000000000000..b1c685115d317c864dbd2c825c1817fa2c632a03 --- /dev/null +++ b/backups/thirdparty/catk/src/run.py @@ -0,0 +1,109 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and 
subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +from typing import List + +import hydra +import lightning as L +import torch +import wandb +from lightning import Callback, LightningDataModule, LightningModule, Trainer +from lightning.pytorch.loggers import Logger +from lightning.pytorch.loggers.wandb import WandbLogger +from omegaconf import DictConfig + +from src.utils import ( + RankedLogger, + instantiate_callbacks, + instantiate_loggers, + log_hyperparameters, + print_config_tree, +) + +log = RankedLogger(__name__, rank_zero_only=True) + +torch.set_float32_matmul_precision("high") + + +def run(cfg: DictConfig) -> None: + if cfg.get("seed"): + L.seed_everything(cfg.seed, workers=True) + + log.info(f"Instantiating datamodule <{cfg.data._target_}>") + datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) + + log.info(f"Instantiating model <{cfg.model._target_}>") + model: LightningModule = hydra.utils.instantiate(cfg.model, _recursive_=False) + + log.info("Instantiating callbacks...") + callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks")) + + log.info(f"Instantiating loggers...") + logger: List[Logger] = instantiate_loggers(cfg.get("logger")) + # setup model watching + for _logger in logger: + if isinstance(_logger, WandbLogger): + _logger.watch(model, log="all") + + log.info(f"Instantiating trainer <{cfg.trainer._target_}>") + trainer: Trainer = hydra.utils.instantiate( + cfg.trainer, callbacks=callbacks, logger=logger + ) + + log.info("Logging hyperparameters!") + log_hyperparameters( + { + "cfg": cfg, + "datamodule": datamodule, + "model": model, + "callbacks": callbacks, + "logger": logger, + "trainer": trainer, + } + ) + + log.info(f"Resuming from ckpt: cfg.ckpt_path={cfg.ckpt_path}") + if cfg.action == "fit": + log.info("Starting training!") + trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) + elif cfg.action == "finetune": + log.info("Starting finetuning!") + model.load_state_dict(torch.load(cfg.ckpt_path)["state_dict"], strict=False) + trainer.fit(model=model, datamodule=datamodule) + elif cfg.action == "validate": + log.info("Starting validating!") + trainer.validate( + model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path") + ) + elif cfg.action == "test": + log.info("Starting testing!") + trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) + + +@hydra.main(config_path="../configs/", config_name="run.yaml", version_base=None) +def main(cfg: DictConfig) -> None: + torch.set_printoptions(precision=3) + + log.info("Printing config tree with Rich! 
") + print_config_tree(cfg, resolve=True, save_to_file=True) + + run(cfg) # train/val/test the model + + log.info("Closing wandb!") + wandb.finish() + log.info(f"Output dir: {cfg.paths.output_dir}") + + +if __name__ == "__main__": + main() + log.info("run.py DONE!!!") diff --git a/backups/thirdparty/catk/src/smart/__init__.py b/backups/thirdparty/catk/src/smart/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/backups/thirdparty/catk/src/smart/datamodules/__init__.py b/backups/thirdparty/catk/src/smart/datamodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..741105702425e25de489959fa0ec2e37bdae0b57 --- /dev/null +++ b/backups/thirdparty/catk/src/smart/datamodules/__init__.py @@ -0,0 +1 @@ +from src.smart.datamodules.scalable_datamodule import MultiDataModule diff --git a/backups/thirdparty/catk/src/smart/datamodules/scalable_datamodule.py b/backups/thirdparty/catk/src/smart/datamodules/scalable_datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..8495dd30237465284088c47091d47e96140d4917 --- /dev/null +++ b/backups/thirdparty/catk/src/smart/datamodules/scalable_datamodule.py @@ -0,0 +1,108 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
+ +from typing import Optional + +from lightning import LightningDataModule +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from torch_geometric.loader import DataLoader + +from src.smart.datasets import MultiDataset + +from .target_builder import WaymoTargetBuilderTrain, WaymoTargetBuilderVal + + +class MultiDataModule(LightningDataModule): + def __init__( + self, + train_batch_size: int, + val_batch_size: int, + test_batch_size: int, + train_raw_dir: str, + val_raw_dir: str, + test_raw_dir: str, + val_tfrecords_splitted: str, + shuffle: bool, + num_workers: int, + pin_memory: bool, + persistent_workers: bool, + train_max_num: int, + ) -> None: + super(MultiDataModule, self).__init__() + self.train_batch_size = train_batch_size + self.val_batch_size = val_batch_size + self.test_batch_size = test_batch_size + self.shuffle = shuffle + self.num_workers = num_workers + self.pin_memory = pin_memory + self.persistent_workers = persistent_workers and num_workers > 0 + self.train_raw_dir = train_raw_dir + self.val_raw_dir = val_raw_dir + self.test_raw_dir = test_raw_dir + self.val_tfrecords_splitted = val_tfrecords_splitted + + self.train_transform = WaymoTargetBuilderTrain(train_max_num) + self.val_transform = WaymoTargetBuilderVal() + self.test_transform = WaymoTargetBuilderVal() + + def setup(self, stage: Optional[str] = None) -> None: + if stage == "fit" or stage is None: + self.train_dataset = MultiDataset(self.train_raw_dir, self.train_transform) + self.val_dataset = MultiDataset( + self.val_raw_dir, + self.val_transform, + tfrecord_dir=self.val_tfrecords_splitted, + ) + elif stage == "validate": + self.val_dataset = MultiDataset( + self.val_raw_dir, + self.val_transform, + tfrecord_dir=self.val_tfrecords_splitted, + ) + elif stage == "test": + self.test_dataset = MultiDataset(self.test_raw_dir, self.test_transform) + else: + raise ValueError(f"{stage} should be one of [fit, validate, test]") + + def train_dataloader(self) -> TRAIN_DATALOADERS: + return DataLoader( + self.train_dataset, + batch_size=self.train_batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers, + drop_last=False, + ) + + def val_dataloader(self) -> EVAL_DATALOADERS: + return DataLoader( + self.val_dataset, + batch_size=self.val_batch_size, + shuffle=False, + num_workers=self.num_workers, + pin_memory=self.pin_memory, # False + persistent_workers=self.persistent_workers, + drop_last=False, + ) + + def test_dataloader(self) -> EVAL_DATALOADERS: + return DataLoader( + self.test_dataset, + batch_size=self.test_batch_size, + shuffle=False, + num_workers=self.num_workers, # 0 + pin_memory=self.pin_memory, # False + persistent_workers=self.persistent_workers, + drop_last=False, + ) diff --git a/backups/thirdparty/catk/src/smart/datamodules/target_builder.py b/backups/thirdparty/catk/src/smart/datamodules/target_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..9dc7ce11e9b91185710294c2024ef066b5e2461f --- /dev/null +++ b/backups/thirdparty/catk/src/smart/datamodules/target_builder.py @@ -0,0 +1,58 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import torch +from torch_geometric.data import HeteroData +from torch_geometric.transforms import BaseTransform + + +class WaymoTargetBuilderTrain(BaseTransform): + def __init__(self, max_num: int) -> None: + super(WaymoTargetBuilderTrain, self).__init__() + self.step_current = 10 + self.max_num = max_num + + def __call__(self, data) -> HeteroData: + pos = data["agent"]["position"] + av_index = torch.where(data["agent"]["role"][:, 0])[0].item() + distance = torch.norm(pos - pos[av_index], dim=-1) + + # we do not trust perception beyond a range of 150 meters + data["agent"]["valid_mask"] = data["agent"]["valid_mask"] & (distance < 150) + + # we do not predict vehicles too far away from the ego car + role_train_mask = data["agent"]["role"].any(-1) + extra_train_mask = (distance[:, self.step_current] < 100) & ( + data["agent"]["valid_mask"][:, self.step_current + 1 :].sum(-1) >= 5 + ) + + train_mask = extra_train_mask | role_train_mask + if train_mask.sum() > self.max_num: # too many vehicles + _indices = torch.where(extra_train_mask & ~role_train_mask)[0] + selected_indices = _indices[ + torch.randperm(_indices.size(0))[: self.max_num - role_train_mask.sum()] + ] + data["agent"]["train_mask"] = role_train_mask + data["agent"]["train_mask"][selected_indices] = True + else: + data["agent"]["train_mask"] = train_mask # [n_agent] + + return HeteroData(data) + + +class WaymoTargetBuilderVal(BaseTransform): + def __init__(self) -> None: + super(WaymoTargetBuilderVal, self).__init__() + + def __call__(self, data) -> HeteroData: + return HeteroData(data) diff --git a/backups/thirdparty/catk/src/smart/datasets/__init__.py b/backups/thirdparty/catk/src/smart/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fbc3e8e8fd8d0a248905f9b0b5f06e90a07bb86c --- /dev/null +++ b/backups/thirdparty/catk/src/smart/datasets/__init__.py @@ -0,0 +1 @@ +from src.smart.datasets.scalable_dataset import MultiDataset diff --git a/backups/thirdparty/catk/src/smart/datasets/scalable_dataset.py b/backups/thirdparty/catk/src/smart/datasets/scalable_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d13cbf52b4d7949e7eb53b45c2113584f013d5ca --- /dev/null +++ b/backups/thirdparty/catk/src/smart/datasets/scalable_dataset.py @@ -0,0 +1,58 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited.
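`wm2argo` earlier writes one pickle per scenario, and the `MultiDataset` below simply globs a directory of those pickles and unpickles one sample per `get` call. A rough way to inspect a single preprocessed sample (the file name is a placeholder; map keys depend on `preprocess_map`):

import pickle

with open("data/waymo_processed/validation/<scenario_id>.pkl", "rb") as f:
    sample = pickle.load(f)

print(sample["scenario_id"])
print(sample["agent"]["position"].shape)    # torch.Size([n_agent, 91, 3])
print(sample["agent"]["valid_mask"].shape)  # torch.Size([n_agent, 91])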
+ +import pickle +from pathlib import Path +from typing import Callable, List, Optional + +from torch_geometric.data import Dataset + +from src.utils import RankedLogger + +log = RankedLogger(__name__, rank_zero_only=True) + + +class MultiDataset(Dataset): + def __init__( + self, + raw_dir: str, + transform: Callable, + tfrecord_dir: Optional[str] = None, + ) -> None: + raw_dir = Path(raw_dir) + self._raw_paths = [p.as_posix() for p in sorted(raw_dir.glob("*"))] + self._num_samples = len(self._raw_paths) + + self._tfrecord_dir = Path(tfrecord_dir) if tfrecord_dir is not None else None + + log.info("Length of {} dataset is ".format(raw_dir) + str(self._num_samples)) + super(MultiDataset, self).__init__( + transform=transform, pre_transform=None, pre_filter=None + ) + + @property + def raw_paths(self) -> List[str]: + return self._raw_paths + + def len(self) -> int: + return self._num_samples + + def get(self, idx: int): + with open(self.raw_paths[idx], "rb") as handle: + data = pickle.load(handle) + + if self._tfrecord_dir is not None: + data["tfrecord_path"] = ( + self._tfrecord_dir / (data["scenario_id"] + ".tfrecords") + ).as_posix() + return data diff --git a/backups/thirdparty/catk/src/smart/layers/__init__.py b/backups/thirdparty/catk/src/smart/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..388bae5858900af2002e96e88a4aabeff8809c8f --- /dev/null +++ b/backups/thirdparty/catk/src/smart/layers/__init__.py @@ -0,0 +1,3 @@ +from src.smart.layers.attention_layer import AttentionLayer +from src.smart.layers.fourier_embedding import FourierEmbedding, MLPEmbedding +from src.smart.layers.mlp_layer import MLPLayer diff --git a/backups/thirdparty/catk/src/smart/layers/attention_layer.py b/backups/thirdparty/catk/src/smart/layers/attention_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..e85d76aa9772749ab5e45885533713a1f24783c9 --- /dev/null +++ b/backups/thirdparty/catk/src/smart/layers/attention_layer.py @@ -0,0 +1,114 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch_geometric.nn.conv import MessagePassing +from torch_geometric.utils import softmax + +from src.smart.utils import weight_init + + +class AttentionLayer(MessagePassing): + + def __init__( + self, + hidden_dim: int, + num_heads: int, + head_dim: int, + dropout: float, + bipartite: bool, + has_pos_emb: bool, + **kwargs + ) -> None: + super(AttentionLayer, self).__init__(aggr="add", node_dim=0, **kwargs) + self.num_heads = num_heads + self.head_dim = head_dim + self.has_pos_emb = has_pos_emb + self.scale = head_dim**-0.5 + + self.to_q = nn.Linear(hidden_dim, head_dim * num_heads) + self.to_k = nn.Linear(hidden_dim, head_dim * num_heads, bias=False) + self.to_v = nn.Linear(hidden_dim, head_dim * num_heads) + if has_pos_emb: + self.to_k_r = nn.Linear(hidden_dim, head_dim * num_heads, bias=False) + self.to_v_r = nn.Linear(hidden_dim, head_dim * num_heads) + self.to_s = nn.Linear(hidden_dim, head_dim * num_heads) + self.to_g = nn.Linear(head_dim * num_heads + hidden_dim, head_dim * num_heads) + self.to_out = nn.Linear(head_dim * num_heads, hidden_dim) + self.attn_drop = nn.Dropout(dropout) + self.ff_mlp = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim * 4), + nn.ReLU(inplace=True), + nn.Dropout(dropout), + nn.Linear(hidden_dim * 4, hidden_dim), + ) + if bipartite: + self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim) + self.attn_prenorm_x_dst = nn.LayerNorm(hidden_dim) + else: + self.attn_prenorm_x_src = 
nn.LayerNorm(hidden_dim) + self.attn_prenorm_x_dst = self.attn_prenorm_x_src + if has_pos_emb: + self.attn_prenorm_r = nn.LayerNorm(hidden_dim) + self.attn_postnorm = nn.LayerNorm(hidden_dim) + self.ff_prenorm = nn.LayerNorm(hidden_dim) + self.ff_postnorm = nn.LayerNorm(hidden_dim) + self.apply(weight_init) + + def forward( + self, + x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + r: Optional[torch.Tensor], + edge_index: torch.Tensor, + ) -> torch.Tensor: + if isinstance(x, torch.Tensor): + x_src = x_dst = self.attn_prenorm_x_src(x) + else: + x_src, x_dst = x + x_src = self.attn_prenorm_x_src(x_src) + x_dst = self.attn_prenorm_x_dst(x_dst) + x = x[1] + if self.has_pos_emb and r is not None: + r = self.attn_prenorm_r(r) + x = x + self.attn_postnorm(self._attn_block(x_src, x_dst, r, edge_index)) + x = x + self.ff_postnorm(self._ff_block(self.ff_prenorm(x))) + return x + + def message( + self, + q_i: torch.Tensor, + k_j: torch.Tensor, + v_j: torch.Tensor, + r: Optional[torch.Tensor], + index: torch.Tensor, + ptr: Optional[torch.Tensor], + ) -> torch.Tensor: + if self.has_pos_emb and r is not None: + k_j = k_j + self.to_k_r(r).view(-1, self.num_heads, self.head_dim) + v_j = v_j + self.to_v_r(r).view(-1, self.num_heads, self.head_dim) + sim = (q_i * k_j).sum(dim=-1) * self.scale + attn = softmax(sim, index, ptr) + self.attention_weight = attn.sum(-1).detach() + attn = self.attn_drop(attn) + return v_j * attn.unsqueeze(-1) + + def update(self, inputs: torch.Tensor, x_dst: torch.Tensor) -> torch.Tensor: + inputs = inputs.view(-1, self.num_heads * self.head_dim) + g = torch.sigmoid(self.to_g(torch.cat([inputs, x_dst], dim=-1))) + return inputs + g * (self.to_s(x_dst) - inputs) + + def _attn_block( + self, + x_src: torch.Tensor, + x_dst: torch.Tensor, + r: Optional[torch.Tensor], + edge_index: torch.Tensor, + ) -> torch.Tensor: + q = self.to_q(x_dst).view(-1, self.num_heads, self.head_dim) + k = self.to_k(x_src).view(-1, self.num_heads, self.head_dim) + v = self.to_v(x_src).view(-1, self.num_heads, self.head_dim) + agg = self.propagate(edge_index=edge_index, x_dst=x_dst, q=q, k=k, v=v, r=r) + return self.to_out(agg) + + def _ff_block(self, x: torch.Tensor) -> torch.Tensor: + return self.ff_mlp(x) diff --git a/backups/thirdparty/catk/src/smart/layers/fourier_embedding.py b/backups/thirdparty/catk/src/smart/layers/fourier_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..b4d85140b4f6fd3416caf218b0f51086a60483e0 --- /dev/null +++ b/backups/thirdparty/catk/src/smart/layers/fourier_embedding.py @@ -0,0 +1,89 @@ +import math +from typing import List, Optional + +import torch +import torch.nn as nn + +from src.smart.utils import weight_init + + +class FourierEmbedding(nn.Module): + + def __init__(self, input_dim: int, hidden_dim: int, num_freq_bands: int) -> None: + super(FourierEmbedding, self).__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.freqs = nn.Embedding(input_dim, num_freq_bands) if input_dim != 0 else None + self.mlps = nn.ModuleList( + [ + nn.Sequential( + nn.Linear(num_freq_bands * 2 + 1, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(hidden_dim, hidden_dim), + ) + for _ in range(input_dim) + ] + ) + self.to_out = nn.Sequential( + nn.LayerNorm(hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(hidden_dim, hidden_dim), + ) + self.apply(weight_init) + + def forward( + self, + continuous_inputs: Optional[torch.Tensor] = None, + categorical_embs: Optional[List[torch.Tensor]] = None, + ) -> 
torch.Tensor: + if continuous_inputs is None: + if categorical_embs is not None: + x = torch.stack(categorical_embs).sum(dim=0) + else: + raise ValueError("Both continuous_inputs and categorical_embs are None") + else: + x = continuous_inputs.unsqueeze(-1) * self.freqs.weight * 2 * math.pi + # Warning: if your data are noisy, don't use learnable sinusoidal embedding + x = torch.cat([x.cos(), x.sin(), continuous_inputs.unsqueeze(-1)], dim=-1) + continuous_embs: List[Optional[torch.Tensor]] = [None] * self.input_dim + for i in range(self.input_dim): + continuous_embs[i] = self.mlps[i](x[:, i]) + x = torch.stack(continuous_embs).sum(dim=0) + if categorical_embs is not None: + x = x + torch.stack(categorical_embs).sum(dim=0) + return self.to_out(x) + + +class MLPEmbedding(nn.Module): + def __init__(self, input_dim: int, hidden_dim: int) -> None: + super(MLPEmbedding, self).__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.mlp = nn.Sequential( + nn.Linear(input_dim, 128), + nn.LayerNorm(128), + nn.ReLU(inplace=True), + nn.Linear(128, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(hidden_dim, hidden_dim), + ) + self.apply(weight_init) + + def forward( + self, + continuous_inputs: Optional[torch.Tensor] = None, + categorical_embs: Optional[List[torch.Tensor]] = None, + ) -> torch.Tensor: + if continuous_inputs is None: + if categorical_embs is not None: + x = torch.stack(categorical_embs).sum(dim=0) + else: + raise ValueError("Both continuous_inputs and categorical_embs are None") + else: + x = self.mlp(continuous_inputs) + if categorical_embs is not None: + x = x + torch.stack(categorical_embs).sum(dim=0) + return x diff --git a/backups/thirdparty/catk/src/smart/layers/mlp_layer.py b/backups/thirdparty/catk/src/smart/layers/mlp_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..5f5f815a2ee9c5fa4721dc9ad591d16032cbe237 --- /dev/null +++ b/backups/thirdparty/catk/src/smart/layers/mlp_layer.py @@ -0,0 +1,20 @@ +import torch +import torch.nn as nn + +from src.smart.utils import weight_init + + +class MLPLayer(nn.Module): + + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int) -> None: + super(MLPLayer, self).__init__() + self.mlp = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(hidden_dim, output_dim), + ) + self.apply(weight_init) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.mlp(x) diff --git a/backups/thirdparty/catk/src/smart/metrics/__init__.py b/backups/thirdparty/catk/src/smart/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6dd57e290271e6923fb600ccac7acf9962b56b0 --- /dev/null +++ b/backups/thirdparty/catk/src/smart/metrics/__init__.py @@ -0,0 +1,20 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. 
Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +from src.smart.metrics.cross_entropy import CrossEntropy +from src.smart.metrics.ego_nll import EgoNLL +from src.smart.metrics.gmm_ade import GMMADE +from src.smart.metrics.min_ade import minADE +from src.smart.metrics.next_token_cls import TokenCls +from src.smart.metrics.wosac_metrics import WOSACMetrics +from src.smart.metrics.wosac_submission import WOSACSubmission diff --git a/backups/thirdparty/catk/src/smart/metrics/cross_entropy.py b/backups/thirdparty/catk/src/smart/metrics/cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..a5c32ef117c2bfc1631529211cb09fa85842b4c8 --- /dev/null +++ b/backups/thirdparty/catk/src/smart/metrics/cross_entropy.py @@ -0,0 +1,117 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +from typing import Optional + +import torch +from torch import Tensor, tensor +from torch.nn.functional import cross_entropy +from torchmetrics.metric import Metric + +from .utils import get_euclidean_targets, get_prob_targets + + +class CrossEntropy(Metric): + + is_differentiable = True + higher_is_better = False + full_state_update = False + + def __init__( + self, + use_gt_raw: bool, + gt_thresh_scale_length: float, # {"veh": 4.8, "cyc": 2.0, "ped": 1.0} + label_smoothing: float, + rollout_as_gt: bool, + ) -> None: + super().__init__() + self.use_gt_raw = use_gt_raw + self.gt_thresh_scale_length = gt_thresh_scale_length + self.label_smoothing = label_smoothing + self.rollout_as_gt = rollout_as_gt + self.add_state("loss_sum", default=tensor(0.0), dist_reduce_fx="sum") + self.add_state("count", default=tensor(0.0), dist_reduce_fx="sum") + + def update( + self, + # ! action that goes from [(10->15), ..., (85->90)] + next_token_logits: Tensor, # [n_agent, 16, n_token] + next_token_valid: Tensor, # [n_agent, 16] + # ! for step {5, 10, ..., 90} and act [(0->5), (5->10), ..., (85->90)] + pred_pos: Tensor, # [n_agent, 18, 2] + pred_head: Tensor, # [n_agent, 18] + pred_valid: Tensor, # [n_agent, 18] + # ! for step {5, 10, ..., 90} + gt_pos_raw: Tensor, # [n_agent, 18, 2] + gt_head_raw: Tensor, # [n_agent, 18] + gt_valid_raw: Tensor, # [n_agent, 18] + # or use the tokenized gt + gt_pos: Tensor, # [n_agent, 18, 2] + gt_head: Tensor, # [n_agent, 18] + gt_valid: Tensor, # [n_agent, 18] + # ! for tokenization + token_agent_shape: Tensor, # [n_agent, 2] + token_traj: Tensor, # [n_agent, n_token, 4, 2] + # ! for filtering interesting agents for training + train_mask: Optional[Tensor] = None, # [n_agent] + # ! for rollout_as_gt + next_token_action: Optional[Tensor] = None, # [n_agent, 16, 3] + **kwargs, + ) -> None: + # !
use raw or tokenized GT + if self.use_gt_raw: + gt_pos = gt_pos_raw + gt_head = gt_head_raw + gt_valid = gt_valid_raw + + # ! GT is valid if it's close to the rollout. + if self.gt_thresh_scale_length > 0: + dist = torch.norm(pred_pos - gt_pos, dim=-1) # [n_agent, n_step] + _thresh = token_agent_shape[:, 1] * self.gt_thresh_scale_length # [n_agent] + gt_valid = gt_valid & (dist < _thresh.unsqueeze(1)) # [n_agent, n_step] + + # ! get prob_targets + euclidean_target, euclidean_target_valid = get_euclidean_targets( + pred_pos=pred_pos, + pred_head=pred_head, + pred_valid=pred_valid, + gt_pos=gt_pos, + gt_head=gt_head, + gt_valid=gt_valid, + ) + if self.rollout_as_gt and (next_token_action is not None): + euclidean_target = next_token_action + + prob_target = get_prob_targets( + target=euclidean_target, # [n_agent, n_step, 3] x,y,yaw in local + token_agent_shape=token_agent_shape, # [n_agent, 2] + token_traj=token_traj, # [n_agent, n_token, 4, 2] + ) # [n_agent, n_step, n_token] prob, last dim sum up to 1 + + loss = cross_entropy( + next_token_logits.transpose(1, 2), # [n_agent, n_token, n_step], logits + prob_target.transpose(1, 2), # [n_agent, n_token, n_step], prob + reduction="none", + label_smoothing=self.label_smoothing, + ) # [n_agent, n_step=16] + + # ! weighting final loss [n_agent, n_step] + loss_weighting_mask = next_token_valid & euclidean_target_valid + if self.training: + loss_weighting_mask &= train_mask.unsqueeze(1) # [n_agent, n_step] + + self.loss_sum += (loss * loss_weighting_mask).sum() + self.count += (loss_weighting_mask > 0).sum() + + def compute(self) -> Tensor: + return self.loss_sum / self.count diff --git a/backups/thirdparty/catk/src/smart/metrics/ego_nll.py b/backups/thirdparty/catk/src/smart/metrics/ego_nll.py new file mode 100644 index 0000000000000000000000000000000000000000..c0d9f3a5d1b499a5bda68f049622d4d15cb1f4a8 --- /dev/null +++ b/backups/thirdparty/catk/src/smart/metrics/ego_nll.py @@ -0,0 +1,141 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
+ +from typing import Optional + +import torch +from torch import Tensor, tensor +from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal +from torchmetrics.metric import Metric + +from .utils import get_euclidean_targets + + +class EgoNLL(Metric): + + is_differentiable = True + higher_is_better = False + full_state_update = False + + def __init__( + self, + use_gt_raw: bool, + gt_thresh_scale_length: float, # {"veh": 4.8, "cyc": 2.0, "ped": 1.0} + hard_assignment: bool, + rollout_as_gt: bool, + ) -> None: + super().__init__() + self.use_gt_raw = use_gt_raw + self.gt_thresh_scale_length = gt_thresh_scale_length + self.hard_assignment = hard_assignment + self.rollout_as_gt = rollout_as_gt + self.add_state("loss_sum", default=tensor(0.0), dist_reduce_fx="sum") + self.add_state("count", default=tensor(0.0), dist_reduce_fx="sum") + + def update( + self, + # ! action that goes from [(10->15), ..., (85->90)] + ego_next_logits: Tensor, # [n_batch, 16, n_k_ego_gmm] + ego_next_poses: Tensor, # [n_batch, 16, n_k_ego_gmm, 3] + ego_next_valid: Tensor, # [n_batch, 16] + ego_next_cov: Tensor, # [2], one for pos, one for heading. + # ! for step {5, 10, ..., 90} and act [(0->5), (5->10), ..., (85->90)] + pred_pos: Tensor, # [n_batch, 18, 2] + pred_head: Tensor, # [n_batch, 18] + pred_valid: Tensor, # [n_batch, 18] + # ! for step {5, 10, ..., 90} + gt_pos_raw: Tensor, # [n_batch, 18, 2] + gt_head_raw: Tensor, # [n_batch, 18] + gt_valid_raw: Tensor, # [n_batch, 18] + # or use the tokenized gt + gt_pos: Tensor, # [n_batch, 18, 2] + gt_head: Tensor, # [n_batch, 18] + gt_valid: Tensor, # [n_batch, 18] + token_agent_shape: Tensor, # [n_agent, 2] + # ! for rollout_as_gt + next_token_action: Optional[Tensor] = None, # [n_batch, 16, 3] + **kwargs, + ) -> None: + # ! use raw or tokenized GT + if self.use_gt_raw: + gt_pos = gt_pos_raw + gt_head = gt_head_raw + gt_valid = gt_valid_raw + + # ! GT is valid if it's close to the rollout. + if self.gt_thresh_scale_length > 0: + dist = torch.norm(pred_pos - gt_pos, dim=-1) # [n_agent, n_step] + _thresh = token_agent_shape[:, 1] * self.gt_thresh_scale_length # [n_agent] + gt_valid = gt_valid & (dist < _thresh.unsqueeze(1)) # [n_agent, n_step] + + # ! get prob_targets + target, target_valid = get_euclidean_targets( + pred_pos=pred_pos, + pred_head=pred_head, + pred_valid=pred_valid, + gt_pos=gt_pos, + gt_head=gt_head, + gt_valid=gt_valid, + ) + if self.rollout_as_gt and (next_token_action is not None): + target = next_token_action + + # ! 
transform yaw angle to unit vector + ego_next_poses = torch.cat( + [ + ego_next_poses[..., :2], + ego_next_poses[..., [-1]].cos(), + ego_next_poses[..., [-1]].sin(), + ], + dim=-1, + ) + ego_next_poses = ego_next_poses.flatten(0, 1) # [n_batch*n_step, K, 4] + cov = ego_next_cov.repeat_interleave(2)[None, None, :].expand( + *ego_next_poses.shape + ) # [n_batch*n_step, K, 4] + + n_batch, n_step = target_valid.shape + target = torch.cat( + [target[..., :2], target[..., [-1]].cos(), target[..., [-1]].sin()], dim=-1 + ) # [n_batch, n_step, 4] + target = target.flatten(0, 1) # [n_batch*n_step, 4] + + ego_next_logits = ego_next_logits.flatten(0, 1) # [n_batch*n_step, K] + if self.hard_assignment: + idx_hard_assign = ( + (ego_next_poses - target.unsqueeze(1))[..., :2].norm(dim=-1).argmin(-1) + ) + n_batch_step = idx_hard_assign.shape[0] + ego_next_poses = ego_next_poses[ + torch.arange(n_batch_step), idx_hard_assign + ].unsqueeze(1) + cov = cov[torch.arange(n_batch_step), idx_hard_assign].unsqueeze(1) + ego_next_logits = ego_next_logits[ + torch.arange(n_batch_step), idx_hard_assign + ].unsqueeze(1) + + gmm = MixtureSameFamily( + Categorical(logits=ego_next_logits), + Independent(Normal(ego_next_poses, cov), 1), + ) + + loss = -gmm.log_prob(target) # [n_batch*n_step] + loss = loss.view(n_batch, n_step) # [n_batch, n_step] + + loss_weighting_mask = target_valid & ego_next_valid # [n_batch, n_step] + + self.loss_sum += (loss * loss_weighting_mask).sum() + self.count += (loss_weighting_mask > 0).sum() + + def compute(self) -> Tensor: + return self.loss_sum / self.count diff --git a/backups/thirdparty/catk/src/smart/metrics/gmm_ade.py b/backups/thirdparty/catk/src/smart/metrics/gmm_ade.py new file mode 100644 index 0000000000000000000000000000000000000000..3d6622b2c627f77e4639e2950de998284cebf19b --- /dev/null +++ b/backups/thirdparty/catk/src/smart/metrics/gmm_ade.py @@ -0,0 +1,47 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
+ +import torch +from torch import Tensor, tensor +from torchmetrics import Metric + + +class GMMADE(Metric): + + def __init__(self) -> None: + super(GMMADE, self).__init__() + self.add_state("sum", default=tensor(0.0), dist_reduce_fx="sum") + self.add_state("count", default=tensor(0.0), dist_reduce_fx="sum") + + def update( + self, + logits: Tensor, # [n_agent, n_step, n_k] + pred: Tensor, # [n_agent, n_step, n_k, 2] + target: Tensor, # [n_agent, n_step, 2] + valid: Tensor, # [n_agent, n_step] + ) -> None: + n_agent, n_step, _ = logits.shape + idx_max = logits.argmax(-1) # [n_agent, n_step] + pred_max = pred[ + torch.arange(n_agent).unsqueeze(1), + torch.arange(n_step).unsqueeze(0), + idx_max, + ] # [n_agent, n_step, 2] + + dist = torch.norm(pred_max - target, p=2, dim=-1) # [n_agent, n_step] + dist = ((dist * valid).sum(-1)) / (valid.sum(-1) + 1e-6) # [n_agent] + self.sum += dist.sum() + self.count += valid.any(-1).sum() + + def compute(self) -> torch.Tensor: + return self.sum / self.count diff --git a/backups/thirdparty/catk/src/smart/metrics/min_ade.py b/backups/thirdparty/catk/src/smart/metrics/min_ade.py new file mode 100644 index 0000000000000000000000000000000000000000..d455286cc2d36685f389d413e3a738840914c051 --- /dev/null +++ b/backups/thirdparty/catk/src/smart/metrics/min_ade.py @@ -0,0 +1,42 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
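+
+# --- Annotation (added in this review pass, not part of the upstream file) ---
+# minADE selects, per agent, the rollout whose summed displacement over the
+# valid steps is smallest, then normalizes by the number of valid steps.
+# Note the ordering: the min over rollouts is taken on the *sum* of
+# distances, so a single rollout has to be good over the whole horizon.
+# Sketch (shapes follow the comments in update() below):
+#
+#   metric = minADE()
+#   pred = torch.randn(3, 32, 80, 2)  # [n_agent, n_rollout, n_step, 2]
+#   target = torch.randn(3, 80, 2)    # [n_agent, n_step, 2]
+#   valid = torch.ones(3, 80, dtype=torch.bool)
+#   metric.update(pred=pred, target=target, target_valid=valid)
+#   print(metric.compute())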
+
+import torch
+from torch import Tensor, tensor
+from torchmetrics import Metric
+
+
+class minADE(Metric):
+
+    def __init__(self) -> None:
+        super(minADE, self).__init__()
+        self.add_state("sum", default=tensor(0.0), dist_reduce_fx="sum")
+        self.add_state("count", default=tensor(0.0), dist_reduce_fx="sum")
+
+    def update(
+        self,
+        pred: Tensor,  # [n_agent, n_rollout, n_step, 2]
+        target: Tensor,  # [n_agent, n_step, 2]
+        target_valid: Tensor,  # [n_agent, n_step]
+    ) -> None:
+
+        # [n_agent, n_rollout, n_step]
+        dist = torch.norm(pred - target.unsqueeze(1), p=2, dim=-1)
+        dist = (dist * target_valid.unsqueeze(1)).sum(-1).min(-1).values  # [n_agent]
+
+        dist = dist / (target_valid.sum(-1) + 1e-6)  # [n_agent]
+        self.sum += dist.sum()
+        self.count += target_valid.any(-1).sum()
+
+    def compute(self) -> torch.Tensor:
+        return self.sum / self.count
diff --git a/backups/thirdparty/catk/src/smart/metrics/next_token_cls.py b/backups/thirdparty/catk/src/smart/metrics/next_token_cls.py
new file mode 100644
index 0000000000000000000000000000000000000000..746dd466950435c596446818bff8d6080180e14c
--- /dev/null
+++ b/backups/thirdparty/catk/src/smart/metrics/next_token_cls.py
@@ -0,0 +1,41 @@
+# Not a contribution
+# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as
+# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions:
+# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import torch
+from torchmetrics import Metric
+
+
+class TokenCls(Metric):
+
+    def __init__(self, max_guesses: int = 6, **kwargs) -> None:
+        super(TokenCls, self).__init__(**kwargs)
+        self.add_state("sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
+        self.add_state("count", default=torch.tensor(0.0), dist_reduce_fx="sum")
+        self.max_guesses = max_guesses
+
+    def update(
+        self,
+        pred: torch.Tensor,  # next_token_logits: [n_agent, 16, n_token]
+        pred_valid: torch.Tensor,  # next_token_valid: [n_agent, 16]
+        target: torch.Tensor,  # next_token_idx_gt: [n_agent, 16]
+        target_valid: torch.Tensor,  # [n_agent, 16]
+    ) -> None:
+        target = target[..., None]
+        acc = (torch.topk(pred, k=self.max_guesses, dim=-1)[1] == target).any(dim=-1)
+        valid_mask = pred_valid & target_valid
+        acc = acc * valid_mask
+        self.sum += acc.sum()
+        self.count += valid_mask.sum()
+
+    def compute(self) -> torch.Tensor:
+        return self.sum / self.count
diff --git a/backups/thirdparty/catk/src/smart/metrics/utils.py b/backups/thirdparty/catk/src/smart/metrics/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bafbb5b64f3ca45dc0543b70df91b0e7de798bba
--- /dev/null
+++ b/backups/thirdparty/catk/src/smart/metrics/utils.py
@@ -0,0 +1,84 @@
+# Not a contribution
+# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as
+# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions:
+# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+from typing import Tuple
+
+import torch
+from torch import Tensor
+from torch.nn.functional import one_hot
+
+from src.smart.utils import cal_polygon_contour, transform_to_local, wrap_angle
+
+
+@torch.no_grad()
+def get_prob_targets(
+    target: Tensor,  # [n_agent, n_step, 3] x,y,yaw in local coord
+    token_agent_shape: Tensor,  # [n_agent, 2]
+    token_traj: Tensor,  # [n_agent, n_token, 4, 2]
+) -> Tensor:  # [n_agent, n_step, n_token] prob, last dim sums up to 1
+    # ! tokenize to index, then compute prob
+    contour = cal_polygon_contour(
+        target[..., :2],  # [n_agent, n_step, 2]
+        target[..., 2],  # [n_agent, n_step]
+        token_agent_shape[:, None, :],  # [n_agent, 1, 1, 2]
+    )  # [n_agent, n_step, 4, 2] in local coord
+
+    # [n_agent, n_step, 1, 4, 2] - [n_agent, 1, n_token, 4, 2]
+    target_token_index = (
+        torch.norm(contour.unsqueeze(2) - token_traj[:, None, :, :, :], dim=-1)
+        .sum(-1)
+        .argmin(-1)
+    )  # [n_agent, n_step]
+
+    # [n_agent, n_step, n_token] bool
+    prob_target = one_hot(target_token_index, num_classes=token_traj.shape[1])
+    prob_target = prob_target.to(target.dtype)
+    return prob_target
+
+
+@torch.no_grad()
+def get_euclidean_targets(
+    pred_pos: Tensor,  # [n_agent, 18, 2]
+    pred_head: Tensor,  # [n_agent, 18]
+    pred_valid: Tensor,  # [n_agent, 18]
+    gt_pos: Tensor,  # [n_agent, 18, 2]
+    gt_head: Tensor,  # [n_agent, 18]
+    gt_valid: Tensor,  # [n_agent, 18]
+) -> Tuple[Tensor, Tensor]:
+    """
+    Return: action that goes from [(10->15), ..., (85->90)]
+    target: [n_agent, 16, 3], x,y,yaw
+    target_valid: [n_agent, 16]
+    """
+    gt_last_pos = gt_pos.roll(shifts=-1, dims=1).flatten(0, 1)
+    gt_last_head = gt_head.roll(shifts=-1, dims=1).flatten(0, 1)
+    gt_last_valid = gt_valid.roll(shifts=-1, dims=1)  # [n_agent, 18]
+    gt_last_valid[:, -1:] = False  # the rolled-around last step is invalid
+
+    target_pos, target_head = transform_to_local(
+        pos_global=gt_last_pos.unsqueeze(1),  # [n_agent*18, 1, 2]
+        head_global=gt_last_head.unsqueeze(1),  # [n_agent*18, 1]
+        pos_now=pred_pos.flatten(0, 1),  # [n_agent*18, 2]
+        head_now=pred_head.flatten(0, 1),  # [n_agent*18]
+    )
+    target_valid = pred_valid & gt_last_valid  # [n_agent, 18]
+
+    target_pos = target_pos.squeeze(1).view(gt_pos.shape)  # [n_agent, 18, 2]
+    target_head = wrap_angle(target_head)
+    target_head = target_head.squeeze(1).view(gt_head.shape)  # [n_agent, 18]
+    target = torch.cat((target_pos, target_head.unsqueeze(-1)), dim=-1)
+
+    # truncate [(5->10), ..., (90->5)] to [(10->15), ..., (85->90)]
+    target = target[:, 1:-1]  # [n_agent, 16, 3], x,y,yaw
+    target_valid = target_valid[:, 1:-1]  # [n_agent, 16]
+    return target, target_valid
diff --git a/backups/thirdparty/catk/src/smart/metrics/wosac_metrics.py b/backups/thirdparty/catk/src/smart/metrics/wosac_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdd2c88dfe1c621dd8b1c66e7ce6c2fd01170ff5
--- /dev/null
+++ b/backups/thirdparty/catk/src/smart/metrics/wosac_metrics.py
@@ -0,0 +1,185 @@
+# Not a contribution
+# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as
+# NVIDIA-proprietary are not a contribution and
subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import itertools +import multiprocessing as mp +import os +from pathlib import Path +from typing import Dict, List + +import tensorflow as tf +import waymo_open_dataset.wdl_limited.sim_agents_metrics.metrics as wosac_metrics +from google.protobuf import text_format +from torch import Tensor, tensor +from torchmetrics import Metric +from waymo_open_dataset.protos import ( + scenario_pb2, + sim_agents_metrics_pb2, + sim_agents_submission_pb2, +) + + +class WOSACMetrics(Metric): + """ + validation metrics based on ground truth trajectory, using waymo_open_dataset api + """ + + def __init__(self, prefix: str, ego_only: bool = False) -> None: + super().__init__() + self.is_mp_init = False + self.prefix = prefix + self.ego_only = ego_only + self.wosac_config = self.load_metrics_config() + + self.field_names = [ + "metametric", + "average_displacement_error", + "linear_speed_likelihood", + "linear_acceleration_likelihood", + "angular_speed_likelihood", + "angular_acceleration_likelihood", + "distance_to_nearest_object_likelihood", + "collision_indication_likelihood", + "time_to_collision_likelihood", + "distance_to_road_edge_likelihood", + "offroad_indication_likelihood", + "min_average_displacement_error", + "simulated_collision_rate", + "simulated_offroad_rate", + ] + for k in self.field_names: + self.add_state(k, default=tensor(0.0), dist_reduce_fx="sum") + self.add_state("scenario_counter", default=tensor(0.0), dist_reduce_fx="sum") + tf.config.set_visible_devices([], "GPU") + + @staticmethod + def _compute_scenario_metrics( + config, scenario_file, scenario_rollout, ego_only + ) -> sim_agents_metrics_pb2.SimAgentMetrics: + scenario = scenario_pb2.Scenario() + for data in tf.data.TFRecordDataset([scenario_file], compression_type=""): + scenario.ParseFromString(bytes(data.numpy())) + break + if ego_only: + for i in range(len(scenario.tracks)): + if i != scenario.sdc_track_index: + for t in range(91): + scenario.tracks[i].states[t].valid = False + while len(scenario.tracks_to_predict) > 1: + scenario.tracks_to_predict.pop() + scenario.tracks_to_predict[0].track_index = scenario.sdc_track_index + + return wosac_metrics.compute_scenario_metrics_for_bundle( + config, scenario, scenario_rollout + ) + + def update( + self, + scenario_files: List[str], + scenario_rollouts: List[sim_agents_submission_pb2.ScenarioRollouts], + ) -> None: + + if os.environ.get("CUDA_VISIBLE_DEVICES", "") in ["", "0"]: + if not self.is_mp_init: + self.is_mp_init = True + mp.set_start_method("forkserver", force=True) + with mp.Pool(processes=len(scenario_rollouts)) as pool: + pool_scenario_metrics = pool.starmap( + self._compute_scenario_metrics, + zip( + itertools.repeat(self.wosac_config), + scenario_files, + scenario_rollouts, + itertools.repeat(self.ego_only), + ), + ) + pool.close() + pool.join() + else: + pool_scenario_metrics = [] + for _scenario, _scenario_rollout in zip(scenario_files, scenario_rollouts): + 
pool_scenario_metrics.append( + self._compute_scenario_metrics( + self.wosac_config, _scenario, _scenario_rollout, self.ego_only + ) + ) + + for scenario_metrics in pool_scenario_metrics: + self.scenario_counter += 1 + self.metametric += scenario_metrics.metametric + self.average_displacement_error += ( + scenario_metrics.average_displacement_error + ) + self.linear_speed_likelihood += scenario_metrics.linear_speed_likelihood + self.linear_acceleration_likelihood += ( + scenario_metrics.linear_acceleration_likelihood + ) + self.angular_speed_likelihood += scenario_metrics.angular_speed_likelihood + self.angular_acceleration_likelihood += ( + scenario_metrics.angular_acceleration_likelihood + ) + self.distance_to_nearest_object_likelihood += ( + scenario_metrics.distance_to_nearest_object_likelihood + ) + self.collision_indication_likelihood += ( + scenario_metrics.collision_indication_likelihood + ) + self.time_to_collision_likelihood += ( + scenario_metrics.time_to_collision_likelihood + ) + self.distance_to_road_edge_likelihood += ( + scenario_metrics.distance_to_road_edge_likelihood + ) + self.offroad_indication_likelihood += ( + scenario_metrics.offroad_indication_likelihood + ) + self.min_average_displacement_error += ( + scenario_metrics.min_average_displacement_error + ) + self.simulated_collision_rate += scenario_metrics.simulated_collision_rate + self.simulated_offroad_rate += scenario_metrics.simulated_offroad_rate + + def compute(self) -> Dict[str, Tensor]: + metrics_dict = {} + for k in self.field_names: + metrics_dict[k] = getattr(self, k) / self.scenario_counter + + mean_metrics = sim_agents_metrics_pb2.SimAgentMetrics( + scenario_id="", **metrics_dict + ) + final_metrics = wosac_metrics.aggregate_metrics_to_buckets( + self.wosac_config, mean_metrics + ) + + out_dict = { + f"{self.prefix}/wosac/realism_meta_metric": final_metrics.realism_meta_metric, + f"{self.prefix}/wosac/kinematic_metrics": final_metrics.kinematic_metrics, + f"{self.prefix}/wosac/interactive_metrics": final_metrics.interactive_metrics, + f"{self.prefix}/wosac/map_based_metrics": final_metrics.map_based_metrics, + f"{self.prefix}/wosac/min_ade": final_metrics.min_ade, + f"{self.prefix}/wosac/scenario_counter": self.scenario_counter, + } + for k in self.field_names: + out_dict[f"{self.prefix}/wosac_likelihood/{k}"] = metrics_dict[k] + + return out_dict + + @staticmethod + def load_metrics_config() -> sim_agents_metrics_pb2.SimAgentMetricsConfig: + config_path = ( + Path(wosac_metrics.__file__).parent / "challenge_2024_config.textproto" + ) + with open(config_path, "r") as f: + config = sim_agents_metrics_pb2.SimAgentMetricsConfig() + text_format.Parse(f.read(), config) + return config diff --git a/backups/thirdparty/catk/src/smart/metrics/wosac_submission.py b/backups/thirdparty/catk/src/smart/metrics/wosac_submission.py new file mode 100644 index 0000000000000000000000000000000000000000..4a275027eb5433fd96f8980604a396cba024526b --- /dev/null +++ b/backups/thirdparty/catk/src/smart/metrics/wosac_submission.py @@ -0,0 +1,140 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import tarfile +from pathlib import Path +from typing import Dict, List + +import hydra +from omegaconf import ListConfig +from torch import Tensor +from torchmetrics.metric import Metric +from waymo_open_dataset.protos import sim_agents_submission_pb2 + +from src.utils import RankedLogger +from src.utils.wosac_utils import get_scenario_id_int_tensor + +log = RankedLogger(__name__, rank_zero_only=False) + + +class WOSACSubmission(Metric): + def __init__( + self, + is_active: bool, + method_name: str, + authors: ListConfig[str], + affiliation: str, + description: str, + method_link: str, + account_name: str, + ) -> None: + super().__init__() + self.is_active = is_active + if self.is_active: + self.method_name = method_name + self.authors = authors + self.affiliation = affiliation + self.description = description + self.method_link = method_link + self.account_name = account_name + self.buffer_scenario_rollouts = [] + self.i_file = 0 + self.submission_dir = ( + hydra.core.hydra_config.HydraConfig.get().runtime.output_dir + ) + self.submission_dir = Path(self.submission_dir) / "wosac_submission" + self.submission_dir.mkdir(exist_ok=True) + self.submission_scenario_id = [] + + self.data_keys = [ + "scenario_id", + "agent_id", + "agent_batch", + "pred_traj", + "pred_z", + "pred_head", + ] + for k in self.data_keys: + self.add_state(k, default=[], dist_reduce_fx="cat") + + def update( + self, + scenario_id: List[str], + agent_id: List[List[float]], + agent_batch: Tensor, + pred_traj: Tensor, + pred_z: Tensor, + pred_head: Tensor, + global_rank: int, + ) -> None: + _device = pred_traj.device + self.agent_id.append(agent_id) + self.scenario_id.append(get_scenario_id_int_tensor(scenario_id, _device)) + self.pred_traj.append(pred_traj) + self.pred_z.append(pred_z) + self.pred_head.append(pred_head) + + batch_size = len(scenario_id) + self.agent_batch.append(agent_batch + batch_size * global_rank) + + def compute(self) -> Dict[str, Tensor]: + return {k: getattr(self, k) for k in self.data_keys} + + def aggregate_rollouts( + self, scenario_rollouts: List[sim_agents_submission_pb2.ScenarioRollouts] + ) -> None: + for rollout in scenario_rollouts: + if rollout.scenario_id not in self.submission_scenario_id: + self.submission_scenario_id.append(rollout.scenario_id) + self.buffer_scenario_rollouts.append(rollout) + if len(self.buffer_scenario_rollouts) > 300: + self._save_shard() + + def save_sub_file(self) -> None: + self._save_shard() + self.i_file = 0 + tar_file_name = self.submission_dir.as_posix() + ".tar.gz" + + log.info(f"Saving wosac submission files to {tar_file_name}") + + shard_files = sorted([p.as_posix() for p in self.submission_dir.glob("*")]) + with tarfile.open(tar_file_name, "w:gz") as tar: + for output_filename in shard_files: + tar.add( + output_filename, + arcname=output_filename + f"-of-{len(shard_files):05d}", + ) + log.info(f"DONE: Saved wosac submission files to {tar_file_name}") + + def _save_shard(self) -> None: + shard_submission = sim_agents_submission_pb2.SimAgentsChallengeSubmission( + 
scenario_rollouts=self.buffer_scenario_rollouts, + submission_type=sim_agents_submission_pb2.SimAgentsChallengeSubmission.SIM_AGENTS_SUBMISSION, + account_name=self.account_name, + unique_method_name=self.method_name, + authors=self.authors, + affiliation=self.affiliation, + description=self.description, + method_link=self.method_link, + uses_lidar_data=False, + uses_camera_data=False, + uses_public_model_pretraining=False, + num_model_parameters="7M", + acknowledge_complies_with_closed_loop_requirement=True, + ) + output_filename = self.submission_dir / f"submission.binproto-{self.i_file:05d}" + log.info(f"Saving wosac submission files to {output_filename}") + with open(output_filename, "wb") as f: + f.write(shard_submission.SerializeToString()) + self.i_file += 1 + self.buffer_scenario_rollouts = [] diff --git a/backups/thirdparty/catk/src/smart/model/__init__.py b/backups/thirdparty/catk/src/smart/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/backups/thirdparty/catk/src/smart/model/ego_gmm_smart.py b/backups/thirdparty/catk/src/smart/model/ego_gmm_smart.py new file mode 100644 index 0000000000000000000000000000000000000000..478d4e1dfe385b115283ce8d7c2877f8c342bee7 --- /dev/null +++ b/backups/thirdparty/catk/src/smart/model/ego_gmm_smart.py @@ -0,0 +1,246 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
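+
+# --- Annotation (added in this review pass, not part of the upstream file) ---
+# EgoGMMSMART wires TokenProcessor -> EgoGMMSMARTDecoder -> EgoNLL. Rough
+# shape of the training flow implemented below (identifiers as in this file):
+#
+#   tokenized_map, tokenized_agent = self.token_processor(data)
+#   pred = self.encoder(tokenized_map, tokenized_agent)  # teacher-forced pass
+#   loss = self.training_loss(**pred, token_agent_shape=...)
+#
+# When training_rollout_sampling.num_k > 0 the decoder is instead rolled out
+# closed-loop via self.encoder.inference(...), so the NLL is computed on
+# on-policy rollouts rather than teacher-forced features.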
+
+import math
+from pathlib import Path
+
+import hydra
+import torch
+from lightning import LightningModule
+from torch.optim.lr_scheduler import LambdaLR
+
+from src.smart.metrics import GMMADE, EgoNLL, WOSACMetrics, minADE
+from src.smart.metrics.utils import get_euclidean_targets
+from src.smart.modules.ego_gmm_smart_decoder import EgoGMMSMARTDecoder
+from src.smart.tokens.token_processor import TokenProcessor
+from src.smart.utils.finetune import set_model_for_finetuning
+from src.utils.vis_waymo import VisWaymo
+from src.utils.wosac_utils import get_scenario_id_int_tensor, get_scenario_rollouts
+
+
+class EgoGMMSMART(LightningModule):
+
+    def __init__(self, model_config) -> None:
+        super(EgoGMMSMART, self).__init__()
+        self.save_hyperparameters()
+        self.lr = model_config.lr
+        self.lr_warmup_steps = model_config.lr_warmup_steps
+        self.lr_total_steps = model_config.lr_total_steps
+        self.lr_min_ratio = model_config.lr_min_ratio
+        self.num_historical_steps = model_config.decoder.num_historical_steps
+        self.log_epoch = -1
+        self.val_closed_loop = model_config.val_closed_loop
+        self.token_processor = TokenProcessor(**model_config.token_processor)
+
+        self.encoder = EgoGMMSMARTDecoder(**model_config.decoder)
+        set_model_for_finetuning(self.encoder, model_config.finetune)
+
+        self.minADE = minADE()
+        self.wosac_metrics = WOSACMetrics("val_closed", ego_only=True)
+        self.gmm_ade_pos = GMMADE()
+        self.gmm_ade_head = GMMADE()
+        self.training_loss = EgoNLL(**model_config.training_loss)
+
+        self.n_rollout_closed_val = model_config.n_rollout_closed_val
+        self.n_vis_batch = model_config.n_vis_batch
+        self.n_vis_scenario = model_config.n_vis_scenario
+        self.n_vis_rollout = model_config.n_vis_rollout
+        self.n_batch_wosac_metric = model_config.n_batch_wosac_metric
+
+        self.video_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
+        self.video_dir = Path(self.video_dir) / "videos"
+
+        self.training_rollout_sampling = model_config.training_rollout_sampling
+        self.validation_rollout_sampling = model_config.validation_rollout_sampling
+
+    def training_step(self, data, batch_idx):
+        tokenized_map, tokenized_agent = self.token_processor(data)
+        if self.training_rollout_sampling.num_k <= 0:
+            pred = self.encoder(tokenized_map, tokenized_agent)
+        else:
+            pred = self.encoder.inference(
+                tokenized_map,
+                tokenized_agent,
+                sampling_scheme=self.training_rollout_sampling,
+            )
+
+        loss = self.training_loss(
+            **pred,
+            token_agent_shape=tokenized_agent["token_agent_shape"][
+                tokenized_agent["ego_mask"]
+            ],  # [n_agent, 2]
+            current_epoch=self.current_epoch,
+        )
+        self.log("train/loss", loss, on_step=True, batch_size=1)
+
+        return loss
+
+    def validation_step(self, data, batch_idx):
+        tokenized_map, tokenized_agent = self.token_processor(data)
+
+        # ! open-loop validation
+        pred = self.encoder(tokenized_map, tokenized_agent)
+        loss = self.training_loss(
+            **pred,
+            token_agent_shape=tokenized_agent["token_agent_shape"][
+                tokenized_agent["ego_mask"]
+            ],
+        )
+        self.log("val_open/loss", loss, on_epoch=True, sync_dist=True, batch_size=1)
+
+        bc_target, bc_target_valid = get_euclidean_targets(
+            pred_pos=pred["gt_pos_raw"],
+            pred_head=pred["gt_head_raw"],
+            pred_valid=pred["gt_valid_raw"],
+            gt_pos=pred["gt_pos_raw"],
+            gt_head=pred["gt_head_raw"],
+            gt_valid=pred["gt_valid_raw"],
+        )  # bc_target: [n_agent, 16, 3], x,y,yaw. bc_target_valid: [n_agent, 16]
+        self.gmm_ade_pos.update(
+            logits=pred["ego_next_logits"],  # [n_agent, 16, n_k_ego_gmm]
+            pred=pred["ego_next_poses"][..., :2],  # [n_agent, 16, n_k_ego_gmm, 2]
+            target=bc_target[..., :2],  # [n_agent, 16, 2]
+            valid=bc_target_valid & pred["ego_next_valid"],  # [n_agent, 16]
+        )
+        bc_target_head = torch.stack(
+            [bc_target[..., -1].cos(), bc_target[..., -1].sin()], dim=-1
+        )  # [n_agent, 16, 2]
+        ego_next_heads = torch.stack(
+            [
+                pred["ego_next_poses"][..., -1].cos(),
+                pred["ego_next_poses"][..., -1].sin(),
+            ],
+            dim=-1,
+        )  # [n_agent, 16, n_k_ego_gmm, 2]
+        self.gmm_ade_head.update(
+            logits=pred["ego_next_logits"],  # [n_agent, 16, n_k_ego_gmm]
+            pred=ego_next_heads,  # [n_agent, 16, n_k_ego_gmm, 2]
+            target=bc_target_head,  # [n_agent, 16, 2]
+            valid=bc_target_valid & pred["ego_next_valid"],  # [n_agent, 16]
+        )
+        self.log(
+            "val_open/gmm_ade_pos",
+            self.gmm_ade_pos,
+            on_epoch=True,
+            sync_dist=True,
+            batch_size=1,
+        )
+
+        self.log(
+            "val_open/gmm_ade_head",
+            self.gmm_ade_head,
+            on_epoch=True,
+            sync_dist=True,
+            batch_size=1,
+        )
+
+        # ! closed-loop validation
+        if self.val_closed_loop:
+            pred_traj, pred_z, pred_head = [], [], []
+            for _ in range(self.n_rollout_closed_val):
+                pred = self.encoder.inference(
+                    tokenized_map, tokenized_agent, self.validation_rollout_sampling
+                )
+                pred_traj.append(pred["pred_traj_10hz"])
+                pred_z.append(pred["pred_z_10hz"])
+                pred_head.append(pred["pred_head_10hz"])
+
+            pred_traj = torch.stack(pred_traj, dim=1)  # [n_ag, n_rollout, n_step, 2]
+            pred_z = torch.stack(pred_z, dim=1)  # [n_ag, n_rollout, n_step]
+            pred_head = torch.stack(pred_head, dim=1)  # [n_ag, n_rollout, n_step]
+
+            # ! WOSAC
+            self.minADE.update(
+                pred=pred_traj[tokenized_agent["ego_mask"]],
+                target=data["agent"]["position"][
+                    :, self.num_historical_steps :, : pred_traj.shape[-1]
+                ][tokenized_agent["ego_mask"]],
+                target_valid=data["agent"]["valid_mask"][
+                    :, self.num_historical_steps :
+                ][tokenized_agent["ego_mask"]],
+            )
+
+            # WOSAC metrics
+            if batch_idx < self.n_batch_wosac_metric:
+                device = pred_traj.device
+                scenario_rollouts = get_scenario_rollouts(
+                    scenario_id=get_scenario_id_int_tensor(data["scenario_id"], device),
+                    agent_id=data["agent"]["id"],
+                    agent_batch=data["agent"]["batch"],
+                    pred_traj=pred_traj,
+                    pred_z=pred_z,
+                    pred_head=pred_head,
+                )
+                self.wosac_metrics.update(data["tfrecord_path"], scenario_rollouts)
+
+            # !
visualization + if self.global_rank == 0 and batch_idx < self.n_vis_batch: + device = pred_traj.device + scenario_rollouts = get_scenario_rollouts( + scenario_id=get_scenario_id_int_tensor(data["scenario_id"], device), + agent_id=data["agent"]["id"][tokenized_agent["ego_mask"]], + agent_batch=data["agent"]["batch"][tokenized_agent["ego_mask"]], + pred_traj=pred_traj[tokenized_agent["ego_mask"]], + pred_z=pred_z[tokenized_agent["ego_mask"]], + pred_head=pred_head[tokenized_agent["ego_mask"]], + ) + for _i_sc in range(self.n_vis_scenario): + _vis = VisWaymo( + scenario_path=data["tfrecord_path"][_i_sc], + save_dir=self.video_dir + / f"batch_{batch_idx:02d}-scenario_{_i_sc:02d}", + ) + _vis.save_video_scenario_rollout( + scenario_rollouts[_i_sc], self.n_vis_rollout + ) + for _path in _vis.video_paths: + self.logger.log_video("/".join(_path.split("/")[-3:]), [_path]) + + def on_validation_epoch_end(self): + if self.val_closed_loop: + epoch_wosac_metrics = self.wosac_metrics.compute() + epoch_wosac_metrics["val_closed/ADE"] = self.minADE.compute() + if self.global_rank == 0: + epoch_wosac_metrics["epoch"] = ( + self.log_epoch if self.log_epoch >= 0 else self.current_epoch + ) + self.logger.log_metrics(epoch_wosac_metrics) + + self.wosac_metrics.reset() + self.minADE.reset() + + def configure_optimizers(self): + optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) + + def lr_lambda(current_step): + current_step = self.current_epoch + 1 + if current_step < self.lr_warmup_steps: + return ( + self.lr_min_ratio + + (1 - self.lr_min_ratio) * current_step / self.lr_warmup_steps + ) + return self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * ( + 1.0 + + math.cos( + math.pi + * min( + 1.0, + (current_step - self.lr_warmup_steps) + / (self.lr_total_steps - self.lr_warmup_steps), + ) + ) + ) + + lr_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) + return [optimizer], [lr_scheduler] diff --git a/backups/thirdparty/catk/src/smart/model/smart.py b/backups/thirdparty/catk/src/smart/model/smart.py new file mode 100644 index 0000000000000000000000000000000000000000..35b54c0772ea35fa16efef1785a463cf1b186bbb --- /dev/null +++ b/backups/thirdparty/catk/src/smart/model/smart.py @@ -0,0 +1,284 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
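+
+# --- Annotation (added in this review pass, not part of the upstream file) ---
+# Both SMART and EgoGMMSMART use the same warmup + cosine schedule in
+# configure_optimizers(). With r = lr_min_ratio, w = lr_warmup_steps,
+# T = lr_total_steps and s = current_epoch + 1, the LR multiplier is
+#
+#   s <  w:  r + (1 - r) * s / w
+#   s >= w:  r + 0.5 * (1 - r) * (1 + cos(pi * min(1, (s - w) / (T - w))))
+#
+# i.e. it ramps up to 1 during warmup, then decays to r and stays there.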
+ +import math +from pathlib import Path + +import hydra +import torch +from lightning import LightningModule +from torch.optim.lr_scheduler import LambdaLR + +from src.smart.metrics import ( + CrossEntropy, + TokenCls, + WOSACMetrics, + WOSACSubmission, + minADE, +) +from src.smart.modules.smart_decoder import SMARTDecoder +from src.smart.tokens.token_processor import TokenProcessor +from src.smart.utils.finetune import set_model_for_finetuning +from src.utils.vis_waymo import VisWaymo +from src.utils.wosac_utils import get_scenario_id_int_tensor, get_scenario_rollouts + + +class SMART(LightningModule): + + def __init__(self, model_config) -> None: + super(SMART, self).__init__() + self.save_hyperparameters() + self.lr = model_config.lr + self.lr_warmup_steps = model_config.lr_warmup_steps + self.lr_total_steps = model_config.lr_total_steps + self.lr_min_ratio = model_config.lr_min_ratio + self.num_historical_steps = model_config.decoder.num_historical_steps + self.log_epoch = -1 + self.val_open_loop = model_config.val_open_loop + self.val_closed_loop = model_config.val_closed_loop + self.token_processor = TokenProcessor(**model_config.token_processor) + + self.encoder = SMARTDecoder( + **model_config.decoder, n_token_agent=self.token_processor.n_token_agent + ) + set_model_for_finetuning(self.encoder, model_config.finetune) + + self.minADE = minADE() + self.TokenCls = TokenCls(max_guesses=5) + self.wosac_metrics = WOSACMetrics("val_closed") + self.wosac_submission = WOSACSubmission(**model_config.wosac_submission) + self.training_loss = CrossEntropy(**model_config.training_loss) + + self.n_rollout_closed_val = model_config.n_rollout_closed_val + self.n_vis_batch = model_config.n_vis_batch + self.n_vis_scenario = model_config.n_vis_scenario + self.n_vis_rollout = model_config.n_vis_rollout + self.n_batch_wosac_metric = model_config.n_batch_wosac_metric + + self.video_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir + self.video_dir = Path(self.video_dir) / "videos" + + self.training_rollout_sampling = model_config.training_rollout_sampling + self.validation_rollout_sampling = model_config.validation_rollout_sampling + + def training_step(self, data, batch_idx): + tokenized_map, tokenized_agent = self.token_processor(data) + if self.training_rollout_sampling.num_k <= 0: + pred = self.encoder(tokenized_map, tokenized_agent) + else: + pred = self.encoder.inference( + tokenized_map, + tokenized_agent, + sampling_scheme=self.training_rollout_sampling, + ) + + loss = self.training_loss( + **pred, + token_agent_shape=tokenized_agent["token_agent_shape"], # [n_agent, 2] + token_traj=tokenized_agent["token_traj"], # [n_agent, n_token, 4, 2] + train_mask=data["agent"]["train_mask"], # [n_agent] + current_epoch=self.current_epoch, + ) + self.log("train/loss", loss, on_step=True, batch_size=1) + + return loss + + def validation_step(self, data, batch_idx): + tokenized_map, tokenized_agent = self.token_processor(data) + + # ! 
open-loop validation
+        if self.val_open_loop:
+            pred = self.encoder(tokenized_map, tokenized_agent)
+            loss = self.training_loss(
+                **pred,
+                token_agent_shape=tokenized_agent["token_agent_shape"],  # [n_agent, 2]
+                token_traj=tokenized_agent["token_traj"],  # [n_agent, n_token, 4, 2]
+            )
+
+            self.TokenCls.update(
+                # action that goes from [(10->15), ..., (85->90)]
+                pred=pred["next_token_logits"],  # [n_agent, 16, n_token]
+                pred_valid=pred["next_token_valid"],  # [n_agent, 16]
+                target=tokenized_agent["gt_idx"][:, 2:],
+                target_valid=tokenized_agent["valid_mask"][:, 2:],
+            )
+            self.log(
+                "val_open/acc",
+                self.TokenCls,
+                on_epoch=True,
+                sync_dist=True,
+                batch_size=1,
+            )
+            self.log("val_open/loss", loss, on_epoch=True, sync_dist=True, batch_size=1)
+
+        # ! closed-loop validation
+        if self.val_closed_loop:
+            pred_traj, pred_z, pred_head = [], [], []
+            for _ in range(self.n_rollout_closed_val):
+                pred = self.encoder.inference(
+                    tokenized_map, tokenized_agent, self.validation_rollout_sampling
+                )
+                pred_traj.append(pred["pred_traj_10hz"])
+                pred_z.append(pred["pred_z_10hz"])
+                pred_head.append(pred["pred_head_10hz"])
+
+            pred_traj = torch.stack(pred_traj, dim=1)  # [n_ag, n_rollout, n_step, 2]
+            pred_z = torch.stack(pred_z, dim=1)  # [n_ag, n_rollout, n_step]
+            pred_head = torch.stack(pred_head, dim=1)  # [n_ag, n_rollout, n_step]
+
+            # ! WOSAC
+            scenario_rollouts = None
+            if self.wosac_submission.is_active:  # ! save WOSAC submission
+                self.wosac_submission.update(
+                    scenario_id=data["scenario_id"],
+                    agent_id=data["agent"]["id"],
+                    agent_batch=data["agent"]["batch"],
+                    pred_traj=pred_traj,
+                    pred_z=pred_z,
+                    pred_head=pred_head,
+                    global_rank=self.global_rank,
+                )
+                _gpu_dict_sync = self.wosac_submission.compute()
+                if self.global_rank == 0:
+                    for k in _gpu_dict_sync.keys():  # single gpu fix
+                        if type(_gpu_dict_sync[k]) is list:
+                            _gpu_dict_sync[k] = _gpu_dict_sync[k][0]
+                    scenario_rollouts = get_scenario_rollouts(**_gpu_dict_sync)
+                    self.wosac_submission.aggregate_rollouts(scenario_rollouts)
+                self.wosac_submission.reset()
+
+            else:  # ! compute metrics, disabled when saving the WOSAC submission
+                self.minADE.update(
+                    pred=pred_traj,
+                    target=data["agent"]["position"][
+                        :, self.num_historical_steps :, : pred_traj.shape[-1]
+                    ],
+                    target_valid=data["agent"]["valid_mask"][
+                        :, self.num_historical_steps :
+                    ],
+                )
+
+                # WOSAC metrics
+                if batch_idx < self.n_batch_wosac_metric:
+                    device = pred_traj.device
+                    scenario_rollouts = get_scenario_rollouts(
+                        scenario_id=get_scenario_id_int_tensor(
+                            data["scenario_id"], device
+                        ),
+                        agent_id=data["agent"]["id"],
+                        agent_batch=data["agent"]["batch"],
+                        pred_traj=pred_traj,
+                        pred_z=pred_z,
+                        pred_head=pred_head,
+                    )
+                    self.wosac_metrics.update(data["tfrecord_path"], scenario_rollouts)
+
+            # ! visualization
+            if self.global_rank == 0 and batch_idx < self.n_vis_batch:
+                if scenario_rollouts is not None:
+                    for _i_sc in range(self.n_vis_scenario):
+                        _vis = VisWaymo(
+                            scenario_path=data["tfrecord_path"][_i_sc],
+                            save_dir=self.video_dir
+                            / f"batch_{batch_idx:02d}-scenario_{_i_sc:02d}",
+                        )
+                        _vis.save_video_scenario_rollout(
+                            scenario_rollouts[_i_sc], self.n_vis_rollout
+                        )
+                        for _path in _vis.video_paths:
+                            self.logger.log_video(
+                                "/".join(_path.split("/")[-3:]), [_path]
+                            )
+
+    def on_validation_epoch_end(self):
+        if self.val_closed_loop:
+            if not self.wosac_submission.is_active:
+                epoch_wosac_metrics = self.wosac_metrics.compute()
+                epoch_wosac_metrics["val_closed/ADE"] = self.minADE.compute()
+                if self.global_rank == 0:
+                    epoch_wosac_metrics["epoch"] = (
+                        self.log_epoch if self.log_epoch >= 0 else self.current_epoch
+                    )
+                    self.logger.log_metrics(epoch_wosac_metrics)
+
+                self.wosac_metrics.reset()
+                self.minADE.reset()
+
+            if self.global_rank == 0:
+                if self.wosac_submission.is_active:
+                    self.wosac_submission.save_sub_file()
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
+
+        def lr_lambda(current_step):
+            current_step = self.current_epoch + 1
+            if current_step < self.lr_warmup_steps:
+                return (
+                    self.lr_min_ratio
+                    + (1 - self.lr_min_ratio) * current_step / self.lr_warmup_steps
+                )
+            return self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * (
+                1.0
+                + math.cos(
+                    math.pi
+                    * min(
+                        1.0,
+                        (current_step - self.lr_warmup_steps)
+                        / (self.lr_total_steps - self.lr_warmup_steps),
+                    )
+                )
+            )
+
+        lr_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
+        return [optimizer], [lr_scheduler]
+
+    def test_step(self, data, batch_idx):
+        tokenized_map, tokenized_agent = self.token_processor(data)
+
+        # ! only closed-loop validation
+        pred_traj, pred_z, pred_head = [], [], []
+        for _ in range(self.n_rollout_closed_val):
+            pred = self.encoder.inference(
+                tokenized_map, tokenized_agent, self.validation_rollout_sampling
+            )
+            pred_traj.append(pred["pred_traj_10hz"])
+            pred_z.append(pred["pred_z_10hz"])
+            pred_head.append(pred["pred_head_10hz"])
+
+        pred_traj = torch.stack(pred_traj, dim=1)  # [n_ag, n_rollout, n_step, 2]
+        pred_z = torch.stack(pred_z, dim=1)  # [n_ag, n_rollout, n_step]
+        pred_head = torch.stack(pred_head, dim=1)  # [n_ag, n_rollout, n_step]
+
+        # !
WOSAC submission save + self.wosac_submission.update( + scenario_id=data["scenario_id"], + agent_id=data["agent"]["id"], + agent_batch=data["agent"]["batch"], + pred_traj=pred_traj, + pred_z=pred_z, + pred_head=pred_head, + global_rank=self.global_rank, + ) + _gpu_dict_sync = self.wosac_submission.compute() + if self.global_rank == 0: + for k in _gpu_dict_sync.keys(): # single gpu fix + if type(_gpu_dict_sync[k]) is list: + _gpu_dict_sync[k] = _gpu_dict_sync[k][0] + scenario_rollouts = get_scenario_rollouts(**_gpu_dict_sync) + self.wosac_submission.aggregate_rollouts(scenario_rollouts) + self.wosac_submission.reset() + + def on_test_epoch_end(self): + if self.global_rank == 0: + self.wosac_submission.save_sub_file() diff --git a/backups/thirdparty/catk/src/smart/modules/__init__.py b/backups/thirdparty/catk/src/smart/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/backups/thirdparty/catk/src/smart/modules/agent_decoder.py b/backups/thirdparty/catk/src/smart/modules/agent_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a025235b82b18db773c4fac5a00df3587172fe7f --- /dev/null +++ b/backups/thirdparty/catk/src/smart/modules/agent_decoder.py @@ -0,0 +1,746 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
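+
+# --- Annotation (added in this review pass, not part of the upstream file) ---
+# SMARTAgentDecoder runs, per layer, three attention blocks over 2 Hz token
+# steps (shift = 5 at 10 Hz):
+#   1. t_attn_layers    -- temporal self-attention within each agent,
+#   2. pt2a_attn_layers -- map-token -> agent cross-attention,
+#   3. a2a_attn_layers  -- agent <-> agent self-attention per step.
+# Edges come from torch_cluster.radius / radius_graph with pl2a_radius and
+# a2a_radius; relative pose features are Fourier-embedded before attention.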
+ +from typing import Dict, Optional + +import torch +import torch.nn as nn +from omegaconf import DictConfig +from torch_cluster import radius, radius_graph +from torch_geometric.utils import dense_to_sparse, subgraph + +from src.smart.layers import MLPLayer +from src.smart.layers.attention_layer import AttentionLayer +from src.smart.layers.fourier_embedding import FourierEmbedding, MLPEmbedding +from src.smart.utils import ( + angle_between_2d_vectors, + sample_next_token_traj, + transform_to_global, + weight_init, + wrap_angle, +) + + +class SMARTAgentDecoder(nn.Module): + + def __init__( + self, + hidden_dim: int, + num_historical_steps: int, + num_future_steps: int, + time_span: Optional[int], + pl2a_radius: float, + a2a_radius: float, + num_freq_bands: int, + num_layers: int, + num_heads: int, + head_dim: int, + dropout: float, + hist_drop_prob: float, + n_token_agent: int, + ) -> None: + super(SMARTAgentDecoder, self).__init__() + self.hidden_dim = hidden_dim + self.num_historical_steps = num_historical_steps + self.num_future_steps = num_future_steps + self.time_span = time_span if time_span is not None else num_historical_steps + self.pl2a_radius = pl2a_radius + self.a2a_radius = a2a_radius + self.num_layers = num_layers + self.shift = 5 + self.hist_drop_prob = hist_drop_prob + + input_dim_x_a = 2 + input_dim_r_t = 4 + input_dim_r_pt2a = 3 + input_dim_r_a2a = 3 + input_dim_token = 8 + + self.type_a_emb = nn.Embedding(3, hidden_dim) + self.shape_emb = MLPLayer(3, hidden_dim, hidden_dim) + + self.x_a_emb = FourierEmbedding( + input_dim=input_dim_x_a, + hidden_dim=hidden_dim, + num_freq_bands=num_freq_bands, + ) + self.r_t_emb = FourierEmbedding( + input_dim=input_dim_r_t, + hidden_dim=hidden_dim, + num_freq_bands=num_freq_bands, + ) + self.r_pt2a_emb = FourierEmbedding( + input_dim=input_dim_r_pt2a, + hidden_dim=hidden_dim, + num_freq_bands=num_freq_bands, + ) + self.r_a2a_emb = FourierEmbedding( + input_dim=input_dim_r_a2a, + hidden_dim=hidden_dim, + num_freq_bands=num_freq_bands, + ) + self.token_emb_veh = MLPEmbedding( + input_dim=input_dim_token, hidden_dim=hidden_dim + ) + self.token_emb_ped = MLPEmbedding( + input_dim=input_dim_token, hidden_dim=hidden_dim + ) + self.token_emb_cyc = MLPEmbedding( + input_dim=input_dim_token, hidden_dim=hidden_dim + ) + self.fusion_emb = MLPEmbedding( + input_dim=self.hidden_dim * 2, hidden_dim=self.hidden_dim + ) + + self.t_attn_layers = nn.ModuleList( + [ + AttentionLayer( + hidden_dim=hidden_dim, + num_heads=num_heads, + head_dim=head_dim, + dropout=dropout, + bipartite=False, + has_pos_emb=True, + ) + for _ in range(num_layers) + ] + ) + self.pt2a_attn_layers = nn.ModuleList( + [ + AttentionLayer( + hidden_dim=hidden_dim, + num_heads=num_heads, + head_dim=head_dim, + dropout=dropout, + bipartite=True, + has_pos_emb=True, + ) + for _ in range(num_layers) + ] + ) + self.a2a_attn_layers = nn.ModuleList( + [ + AttentionLayer( + hidden_dim=hidden_dim, + num_heads=num_heads, + head_dim=head_dim, + dropout=dropout, + bipartite=False, + has_pos_emb=True, + ) + for _ in range(num_layers) + ] + ) + self.token_predict_head = MLPLayer( + input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=n_token_agent + ) + self.apply(weight_init) + + def agent_token_embedding( + self, + agent_token_index, # [n_agent, n_step] + trajectory_token_veh, # [n_token, 8] + trajectory_token_ped, # [n_token, 8] + trajectory_token_cyc, # [n_token, 8] + pos_a, # [n_agent, n_step, 2] + head_vector_a, # [n_agent, n_step, 2] + agent_type, # [n_agent] + agent_shape, # [n_agent, 
3] + inference=False, + ): + n_agent, n_step, traj_dim = pos_a.shape + _device = pos_a.device + + veh_mask = agent_type == 0 + ped_mask = agent_type == 1 + cyc_mask = agent_type == 2 + # [n_token, hidden_dim] + agent_token_emb_veh = self.token_emb_veh(trajectory_token_veh) + agent_token_emb_ped = self.token_emb_ped(trajectory_token_ped) + agent_token_emb_cyc = self.token_emb_cyc(trajectory_token_cyc) + agent_token_emb = torch.zeros( + (n_agent, n_step, self.hidden_dim), device=_device, dtype=pos_a.dtype + ) + agent_token_emb[veh_mask] = agent_token_emb_veh[agent_token_index[veh_mask]] + agent_token_emb[ped_mask] = agent_token_emb_ped[agent_token_index[ped_mask]] + agent_token_emb[cyc_mask] = agent_token_emb_cyc[agent_token_index[cyc_mask]] + + motion_vector_a = torch.cat( + [ + pos_a.new_zeros(agent_token_index.shape[0], 1, traj_dim), + pos_a[:, 1:] - pos_a[:, :-1], + ], + dim=1, + ) # [n_agent, n_step, 2] + feature_a = torch.stack( + [ + torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1), + angle_between_2d_vectors( + ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2] + ), + ], + dim=-1, + ) # [n_agent, n_step, 2] + categorical_embs = [ + self.type_a_emb(agent_type.long()), + self.shape_emb(agent_shape), + ] # List of len=2, shape [n_agent, hidden_dim] + + x_a = self.x_a_emb( + continuous_inputs=feature_a.view(-1, feature_a.size(-1)), + categorical_embs=[ + v.repeat_interleave(repeats=n_step, dim=0) for v in categorical_embs + ], + ) # [n_agent*n_step, hidden_dim] + x_a = x_a.view(-1, n_step, self.hidden_dim) # [n_agent, n_step, hidden_dim] + + feat_a = torch.cat((agent_token_emb, x_a), dim=-1) + feat_a = self.fusion_emb(feat_a) + + if inference: + return ( + feat_a, # [n_agent, n_step, hidden_dim] + agent_token_emb, # [n_agent, n_step, hidden_dim] + agent_token_emb_veh, # [n_agent, hidden_dim] + agent_token_emb_ped, # [n_agent, hidden_dim] + agent_token_emb_cyc, # [n_agent, hidden_dim] + veh_mask, # [n_agent] + ped_mask, # [n_agent] + cyc_mask, # [n_agent] + categorical_embs, # List of len=2, shape [n_agent, hidden_dim] + ) + else: + return feat_a # [n_agent, n_step, hidden_dim] + + def build_temporal_edge( + self, + pos_a, # [n_agent, n_step, 2] + head_a, # [n_agent, n_step] + head_vector_a, # [n_agent, n_step, 2], + mask, # [n_agent, n_step] + inference_mask=None, # [n_agent, n_step] + ): + pos_t = pos_a.flatten(0, 1) + head_t = head_a.flatten(0, 1) + head_vector_t = head_vector_a.flatten(0, 1) + + if self.hist_drop_prob > 0 and self.training: + _mask_keep = torch.bernoulli( + torch.ones_like(mask) * (1 - self.hist_drop_prob) + ).bool() + mask = mask & _mask_keep + + if inference_mask is not None: + mask_t = mask.unsqueeze(2) & inference_mask.unsqueeze(1) + else: + mask_t = mask.unsqueeze(2) & mask.unsqueeze(1) + + edge_index_t = dense_to_sparse(mask_t)[0] + edge_index_t = edge_index_t[:, edge_index_t[1] > edge_index_t[0]] + edge_index_t = edge_index_t[ + :, edge_index_t[1] - edge_index_t[0] <= self.time_span / self.shift + ] + rel_pos_t = pos_t[edge_index_t[0]] - pos_t[edge_index_t[1]] + rel_pos_t = rel_pos_t[:, :2] + rel_head_t = wrap_angle(head_t[edge_index_t[0]] - head_t[edge_index_t[1]]) + r_t = torch.stack( + [ + torch.norm(rel_pos_t, p=2, dim=-1), + angle_between_2d_vectors( + ctr_vector=head_vector_t[edge_index_t[1]], nbr_vector=rel_pos_t + ), + rel_head_t, + edge_index_t[0] - edge_index_t[1], + ], + dim=-1, + ) + r_t = self.r_t_emb(continuous_inputs=r_t, categorical_embs=None) + return edge_index_t, r_t + + def build_interaction_edge( + self, + pos_a, # [n_agent, 
n_step, 2] + head_a, # [n_agent, n_step] + head_vector_a, # [n_agent, n_step, 2] + batch_s, # [n_agent*n_step] + mask, # [n_agent, n_step] + ): + mask = mask.transpose(0, 1).reshape(-1) + pos_s = pos_a.transpose(0, 1).flatten(0, 1) + head_s = head_a.transpose(0, 1).reshape(-1) + head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2) + edge_index_a2a = radius_graph( + x=pos_s[:, :2], + r=self.a2a_radius, + batch=batch_s, + loop=False, + max_num_neighbors=300, + ) + edge_index_a2a = subgraph(subset=mask, edge_index=edge_index_a2a)[0] + rel_pos_a2a = pos_s[edge_index_a2a[0]] - pos_s[edge_index_a2a[1]] + rel_head_a2a = wrap_angle(head_s[edge_index_a2a[0]] - head_s[edge_index_a2a[1]]) + r_a2a = torch.stack( + [ + torch.norm(rel_pos_a2a[:, :2], p=2, dim=-1), + angle_between_2d_vectors( + ctr_vector=head_vector_s[edge_index_a2a[1]], + nbr_vector=rel_pos_a2a[:, :2], + ), + rel_head_a2a, + ], + dim=-1, + ) + r_a2a = self.r_a2a_emb(continuous_inputs=r_a2a, categorical_embs=None) + return edge_index_a2a, r_a2a + + def build_map2agent_edge( + self, + pos_pl, # [n_pl, 2] + orient_pl, # [n_pl] + pos_a, # [n_agent, n_step, 2] + head_a, # [n_agent, n_step] + head_vector_a, # [n_agent, n_step, 2] + mask, # [n_agent, n_step] + batch_s, # [n_agent*n_step] + batch_pl, # [n_pl*n_step] + ): + n_step = pos_a.shape[1] + mask_pl2a = mask.transpose(0, 1).reshape(-1) + pos_s = pos_a.transpose(0, 1).flatten(0, 1) + head_s = head_a.transpose(0, 1).reshape(-1) + head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2) + pos_pl = pos_pl.repeat(n_step, 1) + orient_pl = orient_pl.repeat(n_step) + edge_index_pl2a = radius( + x=pos_s[:, :2], + y=pos_pl[:, :2], + r=self.pl2a_radius, + batch_x=batch_s, + batch_y=batch_pl, + max_num_neighbors=300, + ) + edge_index_pl2a = edge_index_pl2a[:, mask_pl2a[edge_index_pl2a[1]]] + rel_pos_pl2a = pos_pl[edge_index_pl2a[0]] - pos_s[edge_index_pl2a[1]] + rel_orient_pl2a = wrap_angle( + orient_pl[edge_index_pl2a[0]] - head_s[edge_index_pl2a[1]] + ) + r_pl2a = torch.stack( + [ + torch.norm(rel_pos_pl2a[:, :2], p=2, dim=-1), + angle_between_2d_vectors( + ctr_vector=head_vector_s[edge_index_pl2a[1]], + nbr_vector=rel_pos_pl2a[:, :2], + ), + rel_orient_pl2a, + ], + dim=-1, + ) + r_pl2a = self.r_pt2a_emb(continuous_inputs=r_pl2a, categorical_embs=None) + return edge_index_pl2a, r_pl2a + + def forward( + self, + tokenized_agent: Dict[str, torch.Tensor], + map_feature: Dict[str, torch.Tensor], + ) -> Dict[str, torch.Tensor]: + mask = tokenized_agent["valid_mask"] + pos_a = tokenized_agent["sampled_pos"] + head_a = tokenized_agent["sampled_heading"] + head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1) + n_agent, n_step = head_a.shape + + # ! get agent token embeddings + feat_a = self.agent_token_embedding( + agent_token_index=tokenized_agent["sampled_idx"], # [n_ag, n_step] + trajectory_token_veh=tokenized_agent["trajectory_token_veh"], + trajectory_token_ped=tokenized_agent["trajectory_token_ped"], + trajectory_token_cyc=tokenized_agent["trajectory_token_cyc"], + pos_a=pos_a, # [n_agent, n_step, 2] + head_vector_a=head_vector_a, # [n_agent, n_step, 2] + agent_type=tokenized_agent["type"], # [n_agent] + agent_shape=tokenized_agent["shape"], # [n_agent, 3] + ) # feat_a: [n_agent, n_step, hidden_dim] + + # ! 
build temporal, interaction and map2agent edges + edge_index_t, r_t = self.build_temporal_edge( + pos_a=pos_a, # [n_agent, n_step, 2] + head_a=head_a, # [n_agent, n_step] + head_vector_a=head_vector_a, # [n_agent, n_step, 2] + mask=mask, # [n_agent, n_step] + ) # edge_index_t: [2, n_edge_t], r_t: [n_edge_t, hidden_dim] + + batch_s = torch.cat( + [ + tokenized_agent["batch"] + tokenized_agent["num_graphs"] * t + for t in range(n_step) + ], + dim=0, + ) # [n_agent*n_step] + batch_pl = torch.cat( + [ + map_feature["batch"] + tokenized_agent["num_graphs"] * t + for t in range(n_step) + ], + dim=0, + ) # [n_pl*n_step] + + edge_index_a2a, r_a2a = self.build_interaction_edge( + pos_a=pos_a, # [n_agent, n_step, 2] + head_a=head_a, # [n_agent, n_step] + head_vector_a=head_vector_a, # [n_agent, n_step, 2] + batch_s=batch_s, # [n_agent*n_step] + mask=mask, # [n_agent, n_step] + ) # edge_index_a2a: [2, n_edge_a2a], r_a2a: [n_edge_a2a, hidden_dim] + + edge_index_pl2a, r_pl2a = self.build_map2agent_edge( + pos_pl=map_feature["position"], # [n_pl, 2] + orient_pl=map_feature["orientation"], # [n_pl] + pos_a=pos_a, # [n_agent, n_step, 2] + head_a=head_a, # [n_agent, n_step] + head_vector_a=head_vector_a, # [n_agent, n_step, 2] + mask=mask, # [n_agent, n_step] + batch_s=batch_s, # [n_agent*n_step] + batch_pl=batch_pl, # [n_pl*n_step] + ) + + # ! attention layers + # [n_step*n_pl, hidden_dim] + feat_map = ( + map_feature["pt_token"].unsqueeze(0).expand(n_step, -1, -1).flatten(0, 1) + ) + + for i in range(self.num_layers): + feat_a = feat_a.flatten(0, 1) # [n_agent*n_step, hidden_dim] + feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t) + # [n_step*n_agent, hidden_dim] + feat_a = feat_a.view(n_agent, n_step, -1).transpose(0, 1).flatten(0, 1) + feat_a = self.pt2a_attn_layers[i]( + (feat_map, feat_a), r_pl2a, edge_index_pl2a + ) + feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a) + feat_a = feat_a.view(n_step, n_agent, -1).transpose(0, 1) + + # ! 
final mlp to get outputs + next_token_logits = self.token_predict_head(feat_a) + + return { + # action that goes from [(10->15), ..., (85->90)] + "next_token_logits": next_token_logits[:, 1:-1], # [n_agent, 16, n_token] + "next_token_valid": tokenized_agent["valid_mask"][:, 1:-1], # [n_agent, 16] + # for step {5, 10, ..., 90} and act [(0->5), (5->10), ..., (85->90)] + "pred_pos": tokenized_agent["sampled_pos"], # [n_agent, 18, 2] + "pred_head": tokenized_agent["sampled_heading"], # [n_agent, 18] + "pred_valid": tokenized_agent["valid_mask"], # [n_agent, 18] + # for step {5, 10, ..., 90} + "gt_pos_raw": tokenized_agent["gt_pos_raw"], # [n_agent, 18, 2] + "gt_head_raw": tokenized_agent["gt_head_raw"], # [n_agent, 18] + "gt_valid_raw": tokenized_agent["gt_valid_raw"], # [n_agent, 18] + # or use the tokenized gt + "gt_pos": tokenized_agent["gt_pos"], # [n_agent, 18, 2] + "gt_head": tokenized_agent["gt_heading"], # [n_agent, 18] + "gt_valid": tokenized_agent["valid_mask"], # [n_agent, 18] + } + + def inference( + self, + tokenized_agent: Dict[str, torch.Tensor], + map_feature: Dict[str, torch.Tensor], + sampling_scheme: DictConfig, + ) -> Dict[str, torch.Tensor]: + n_agent = tokenized_agent["valid_mask"].shape[0] + n_step_future_10hz = self.num_future_steps # 80 + n_step_future_2hz = n_step_future_10hz // self.shift # 16 + step_current_10hz = self.num_historical_steps - 1 # 10 + step_current_2hz = step_current_10hz // self.shift # 2 + + pos_a = tokenized_agent["gt_pos"][:, :step_current_2hz].clone() + head_a = tokenized_agent["gt_heading"][:, :step_current_2hz].clone() + head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1) + pred_idx = tokenized_agent["gt_idx"].clone() + ( + feat_a, # [n_agent, step_current_2hz, hidden_dim] + agent_token_emb, # [n_agent, step_current_2hz, hidden_dim] + agent_token_emb_veh, # [n_agent, hidden_dim] + agent_token_emb_ped, # [n_agent, hidden_dim] + agent_token_emb_cyc, # [n_agent, hidden_dim] + veh_mask, # [n_agent] + ped_mask, # [n_agent] + cyc_mask, # [n_agent] + categorical_embs, # List of len=2, shape [n_agent, hidden_dim] + ) = self.agent_token_embedding( + agent_token_index=tokenized_agent["gt_idx"][:, :step_current_2hz], + trajectory_token_veh=tokenized_agent["trajectory_token_veh"], + trajectory_token_ped=tokenized_agent["trajectory_token_ped"], + trajectory_token_cyc=tokenized_agent["trajectory_token_cyc"], + pos_a=pos_a, + head_vector_a=head_vector_a, + agent_type=tokenized_agent["type"], + agent_shape=tokenized_agent["shape"], + inference=True, + ) + + if not self.training: + pred_traj_10hz = torch.zeros( + [n_agent, n_step_future_10hz, 2], dtype=pos_a.dtype, device=pos_a.device + ) + pred_head_10hz = torch.zeros( + [n_agent, n_step_future_10hz], dtype=pos_a.dtype, device=pos_a.device + ) + + pred_valid = tokenized_agent["valid_mask"].clone() + next_token_logits_list = [] + next_token_action_list = [] + feat_a_t_dict = {} + for t in range(n_step_future_2hz): # 0 -> 15 + t_now = step_current_2hz - 1 + t # 1 -> 16 + n_step = t_now + 1 # 2 -> 17 + + if t == 0: # init + hist_step = step_current_2hz + batch_s = torch.cat( + [ + tokenized_agent["batch"] + tokenized_agent["num_graphs"] * t + for t in range(hist_step) + ], + dim=0, + ) + batch_pl = torch.cat( + [ + map_feature["batch"] + tokenized_agent["num_graphs"] * t + for t in range(hist_step) + ], + dim=0, + ) + inference_mask = pred_valid[:, :n_step] + edge_index_t, r_t = self.build_temporal_edge( + pos_a=pos_a, + head_a=head_a, + head_vector_a=head_vector_a, + mask=pred_valid[:, :n_step], + 
                )
+            else:
+                hist_step = 1
+                batch_s = tokenized_agent["batch"]
+                batch_pl = map_feature["batch"]
+                inference_mask = pred_valid[:, :n_step].clone()
+                inference_mask[:, :-1] = False
+                edge_index_t, r_t = self.build_temporal_edge(
+                    pos_a=pos_a,
+                    head_a=head_a,
+                    head_vector_a=head_vector_a,
+                    mask=pred_valid[:, :n_step],
+                    inference_mask=inference_mask,
+                )
+                edge_index_t[1] = (edge_index_t[1] + 1) // n_step - 1
+
+            # In the inference stage, we only infer the current step for the recurrent rollout.
+            edge_index_pl2a, r_pl2a = self.build_map2agent_edge(
+                pos_pl=map_feature["position"],  # [n_pl, 2]
+                orient_pl=map_feature["orientation"],  # [n_pl]
+                pos_a=pos_a[:, -hist_step:],  # [n_agent, hist_step, 2]
+                head_a=head_a[:, -hist_step:],  # [n_agent, hist_step]
+                head_vector_a=head_vector_a[:, -hist_step:],  # [n_agent, hist_step, 2]
+                mask=inference_mask[:, -hist_step:],  # [n_agent, hist_step]
+                batch_s=batch_s,  # [n_agent*hist_step]
+                batch_pl=batch_pl,  # [n_pl*hist_step]
+            )
+            edge_index_a2a, r_a2a = self.build_interaction_edge(
+                pos_a=pos_a[:, -hist_step:],  # [n_agent, hist_step, 2]
+                head_a=head_a[:, -hist_step:],  # [n_agent, hist_step]
+                head_vector_a=head_vector_a[:, -hist_step:],  # [n_agent, hist_step, 2]
+                batch_s=batch_s,  # [n_agent*hist_step]
+                mask=inference_mask[:, -hist_step:],  # [n_agent, hist_step]
+            )
+
+            # ! attention layers
+            for i in range(self.num_layers):
+                # [n_agent, n_step, hidden_dim]
+                _feat_temporal = feat_a if i == 0 else feat_a_t_dict[i]
+
+                if t == 0:  # init, process hist_step together
+                    _feat_temporal = self.t_attn_layers[i](
+                        _feat_temporal.flatten(0, 1), r_t, edge_index_t
+                    ).view(n_agent, n_step, -1)
+                    _feat_temporal = _feat_temporal.transpose(0, 1).flatten(0, 1)
+
+                    # [hist_step*n_pl, hidden_dim]
+                    _feat_map = (
+                        map_feature["pt_token"]
+                        .unsqueeze(0)
+                        .expand(hist_step, -1, -1)
+                        .flatten(0, 1)
+                    )
+
+                    _feat_temporal = self.pt2a_attn_layers[i](
+                        (_feat_map, _feat_temporal), r_pl2a, edge_index_pl2a
+                    )
+                    _feat_temporal = self.a2a_attn_layers[i](
+                        _feat_temporal, r_a2a, edge_index_a2a
+                    )
+                    _feat_temporal = _feat_temporal.view(n_step, n_agent, -1).transpose(
+                        0, 1
+                    )
+                    feat_a_now = _feat_temporal[:, -1]  # [n_agent, hidden_dim]
+
+                    if i + 1 < self.num_layers:
+                        feat_a_t_dict[i + 1] = _feat_temporal
+
+                else:  # process one step
+                    feat_a_now = self.t_attn_layers[i](
+                        (_feat_temporal.flatten(0, 1), _feat_temporal[:, -1]),
+                        r_t,
+                        edge_index_t,
+                    )
+                    # * gives the same results as below, but is more efficient
+                    # feat_a_now = self.t_attn_layers[i](
+                    #     _feat_temporal.flatten(0, 1), r_t, edge_index_t
+                    # ).view(n_agent, n_step, -1)[:, -1]
+
+                    feat_a_now = self.pt2a_attn_layers[i](
+                        (map_feature["pt_token"], feat_a_now), r_pl2a, edge_index_pl2a
+                    )
+                    feat_a_now = self.a2a_attn_layers[i](
+                        feat_a_now, r_a2a, edge_index_a2a
+                    )
+
+                # [n_agent, n_step, hidden_dim]
+                if i + 1 < self.num_layers:
+                    feat_a_t_dict[i + 1] = torch.cat(
+                        (feat_a_t_dict[i + 1], feat_a_now.unsqueeze(1)), dim=1
+                    )
+
+            # ! get outputs
+            next_token_logits = self.token_predict_head(feat_a_now)
+            next_token_logits_list.append(next_token_logits)  # [n_agent, n_token]
+
+            next_token_idx, next_token_traj_all = sample_next_token_traj(
+                token_traj=tokenized_agent["token_traj"],
+                token_traj_all=tokenized_agent["token_traj_all"],
+                sampling_scheme=sampling_scheme,
+                # ! for most-likely sampling
+                next_token_logits=next_token_logits,
+                # ! for nearest-pos sampling
+                pos_now=pos_a[:, t_now],  # [n_agent, 2]
+                head_now=head_a[:, t_now],  # [n_agent]
+                pos_next_gt=tokenized_agent["gt_pos_raw"][:, n_step],  # [n_agent, 2]
+                head_next_gt=tokenized_agent["gt_head_raw"][:, n_step],  # [n_agent]
+                valid_next_gt=tokenized_agent["gt_valid_raw"][:, n_step],  # [n_agent]
+                token_agent_shape=tokenized_agent["token_agent_shape"],  # [n_token, 2]
+            )  # next_token_idx: [n_agent], next_token_traj_all: [n_agent, 6, 4, 2]
+
+            diff_xy = next_token_traj_all[:, -1, 0] - next_token_traj_all[:, -1, 3]
+            next_token_action_list.append(
+                torch.cat(
+                    [
+                        next_token_traj_all[:, -1].mean(1),  # [n_agent, 2]
+                        torch.arctan2(diff_xy[:, [1]], diff_xy[:, [0]]),  # [n_agent, 1]
+                    ],
+                    dim=-1,
+                )  # [n_agent, 3]
+            )
+
+            token_traj_global = transform_to_global(
+                pos_local=next_token_traj_all.flatten(1, 2),  # [n_agent, 6*4, 2]
+                head_local=None,
+                pos_now=pos_a[:, t_now],  # [n_agent, 2]
+                head_now=head_a[:, t_now],  # [n_agent]
+            )[0].view(*next_token_traj_all.shape)
+
+            if not self.training:
+                pred_traj_10hz[:, t * 5 : (t + 1) * 5] = token_traj_global[:, 1:].mean(
+                    2
+                )
+                diff_xy = token_traj_global[:, 1:, 0] - token_traj_global[:, 1:, 3]
+                pred_head_10hz[:, t * 5 : (t + 1) * 5] = torch.arctan2(
+                    diff_xy[:, :, 1], diff_xy[:, :, 0]
+                )
+
+            # ! get pos_a_next and head_a_next, spawn unseen agents
+            pos_a_next = token_traj_global[:, -1].mean(dim=1)
+            diff_xy_next = token_traj_global[:, -1, 0] - token_traj_global[:, -1, 3]
+            head_a_next = torch.arctan2(diff_xy_next[:, 1], diff_xy_next[:, 0])
+            pred_idx[:, n_step] = next_token_idx
+
+            # ! update tensors for next step
+            pred_valid[:, n_step] = pred_valid[:, t_now]
+            # pred_valid[:, n_step] = pred_valid[:, t_now] | mask_spawn
+            pos_a = torch.cat([pos_a, pos_a_next.unsqueeze(1)], dim=1)
+            head_a = torch.cat([head_a, head_a_next.unsqueeze(1)], dim=1)
+            head_vector_a_next = torch.stack(
+                [head_a_next.cos(), head_a_next.sin()], dim=-1
+            )
+            head_vector_a = torch.cat(
+                [head_vector_a, head_vector_a_next.unsqueeze(1)], dim=1
+            )
+
+            # ! get agent_token_emb_next
+            agent_token_emb_next = torch.zeros_like(agent_token_emb[:, 0])
+            agent_token_emb_next[veh_mask] = agent_token_emb_veh[
+                next_token_idx[veh_mask]
+            ]
+            agent_token_emb_next[ped_mask] = agent_token_emb_ped[
+                next_token_idx[ped_mask]
+            ]
+            agent_token_emb_next[cyc_mask] = agent_token_emb_cyc[
+                next_token_idx[cyc_mask]
+            ]
+            agent_token_emb = torch.cat(
+                [agent_token_emb, agent_token_emb_next.unsqueeze(1)], dim=1
+            )
+
+            # !
get feat_a_next + motion_vector_a = pos_a[:, -1] - pos_a[:, -2] # [n_agent, 2] + x_a = torch.stack( + [ + torch.norm(motion_vector_a, p=2, dim=-1), + angle_between_2d_vectors( + ctr_vector=head_vector_a[:, -1], nbr_vector=motion_vector_a + ), + ], + dim=-1, + ) + # [n_agent, hidden_dim] + x_a = self.x_a_emb(continuous_inputs=x_a, categorical_embs=categorical_embs) + # [n_agent, 1, 2*hidden_dim] + feat_a_next = torch.cat((agent_token_emb_next, x_a), dim=-1).unsqueeze(1) + feat_a_next = self.fusion_emb(feat_a_next) + feat_a = torch.cat([feat_a, feat_a_next], dim=1) + + out_dict = { + # action that goes from [(10->15), ..., (85->90)] + "next_token_logits": torch.stack(next_token_logits_list, dim=1), + "next_token_valid": pred_valid[:, 1:-1], # [n_agent, 16] + # for step {5, 10, ..., 90} and act [(0->5), (5->10), ..., (85->90)] + "pred_pos": pos_a, # [n_agent, 18, 2] + "pred_head": head_a, # [n_agent, 18] + "pred_valid": pred_valid, # [n_agent, 18] + "pred_idx": pred_idx, # [n_agent, 18] + # for step {5, 10, ..., 90} + "gt_pos_raw": tokenized_agent["gt_pos_raw"], # [n_agent, 18, 2] + "gt_head_raw": tokenized_agent["gt_head_raw"], # [n_agent, 18] + "gt_valid_raw": tokenized_agent["gt_valid_raw"], # [n_agent, 18] + # or use the tokenized gt + "gt_pos": tokenized_agent["gt_pos"], # [n_agent, 18, 2] + "gt_head": tokenized_agent["gt_heading"], # [n_agent, 18] + "gt_valid": tokenized_agent["valid_mask"], # [n_agent, 18] + # for shifting proxy targets by lr + "next_token_action": torch.stack(next_token_action_list, dim=1), + } + + if not self.training: # 10hz predictions for wosac evaluation and submission + out_dict["pred_traj_10hz"] = pred_traj_10hz + out_dict["pred_head_10hz"] = pred_head_10hz + pred_z = tokenized_agent["gt_z_raw"].unsqueeze(1) # [n_agent, 1] + out_dict["pred_z_10hz"] = pred_z.expand(-1, pred_traj_10hz.shape[1]) + + return out_dict diff --git a/backups/thirdparty/catk/src/smart/modules/ego_gmm_agent_decoder.py b/backups/thirdparty/catk/src/smart/modules/ego_gmm_agent_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..2c56d0b3afa58fb15f4640219edd36c828329a74 --- /dev/null +++ b/backups/thirdparty/catk/src/smart/modules/ego_gmm_agent_decoder.py @@ -0,0 +1,775 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
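# ---------------------------------------------------------------------------
# The rollout loop above (and the ego-GMM variant that follows) repeatedly
# recovers an agent's heading from a 4-corner box contour via
# `contour[:, 0] - contour[:, 3]`. A minimal self-contained sketch of that
# trick; the corner ordering [front-left, front-right, rear-right, rear-left]
# is an assumption here (`cal_polygon_contour` defines the real order):
import torch

def heading_from_contour(contour: torch.Tensor) -> torch.Tensor:
    # contour: [..., 4, 2] box corners; front-left minus rear-left points
    # along the agent's longitudinal axis.
    diff_xy = contour[..., 0, :] - contour[..., 3, :]
    return torch.arctan2(diff_xy[..., 1], diff_xy[..., 0])

# e.g. under the assumed ordering, a box rotated by 0.3 rad yields ~0.3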
+ +from typing import Dict, Optional + +import torch +import torch.nn as nn +from omegaconf import DictConfig, ListConfig +from torch_cluster import radius, radius_graph +from torch_geometric.utils import dense_to_sparse, subgraph + +from src.smart.layers import MLPLayer +from src.smart.layers.attention_layer import AttentionLayer +from src.smart.layers.fourier_embedding import FourierEmbedding, MLPEmbedding +from src.smart.utils import ( + angle_between_2d_vectors, + sample_next_gmm_traj, + transform_to_global, + weight_init, + wrap_angle, +) + + +class EgoGMMAgentDecoder(nn.Module): + + def __init__( + self, + hidden_dim: int, + num_historical_steps: int, + num_future_steps: int, + time_span: Optional[int], + pl2a_radius: float, + a2a_radius: float, + num_freq_bands: int, + num_layers: int, + num_heads: int, + head_dim: int, + dropout: float, + hist_drop_prob: float, + k_ego_gmm: int, + cov_ego_gmm: ListConfig[float], + cov_learnable: bool, + ) -> None: + super(EgoGMMAgentDecoder, self).__init__() + self.hidden_dim = hidden_dim + self.num_historical_steps = num_historical_steps + self.num_future_steps = num_future_steps + self.time_span = time_span if time_span is not None else num_historical_steps + self.pl2a_radius = pl2a_radius + self.a2a_radius = a2a_radius + self.num_layers = num_layers + self.shift = 5 + self.hist_drop_prob = hist_drop_prob + + input_dim_x_a = 2 + input_dim_r_t = 4 + input_dim_r_pt2a = 3 + input_dim_r_a2a = 3 + input_dim_token = 8 + + self.type_a_emb = nn.Embedding(3, hidden_dim) + self.shape_emb = MLPLayer(3, hidden_dim, hidden_dim) + + self.x_a_emb = FourierEmbedding( + input_dim=input_dim_x_a, + hidden_dim=hidden_dim, + num_freq_bands=num_freq_bands, + ) + self.r_t_emb = FourierEmbedding( + input_dim=input_dim_r_t, + hidden_dim=hidden_dim, + num_freq_bands=num_freq_bands, + ) + self.r_pt2a_emb = FourierEmbedding( + input_dim=input_dim_r_pt2a, + hidden_dim=hidden_dim, + num_freq_bands=num_freq_bands, + ) + self.r_a2a_emb = FourierEmbedding( + input_dim=input_dim_r_a2a, + hidden_dim=hidden_dim, + num_freq_bands=num_freq_bands, + ) + self.token_emb_veh = MLPEmbedding( + input_dim=input_dim_token, hidden_dim=hidden_dim + ) + self.token_emb_ped = MLPEmbedding( + input_dim=input_dim_token, hidden_dim=hidden_dim + ) + self.token_emb_cyc = MLPEmbedding( + input_dim=input_dim_token, hidden_dim=hidden_dim + ) + self.fusion_emb = MLPEmbedding( + input_dim=self.hidden_dim * 2, hidden_dim=self.hidden_dim + ) + + self.t_attn_layers = nn.ModuleList( + [ + AttentionLayer( + hidden_dim=hidden_dim, + num_heads=num_heads, + head_dim=head_dim, + dropout=dropout, + bipartite=False, + has_pos_emb=True, + ) + for _ in range(num_layers) + ] + ) + self.pt2a_attn_layers = nn.ModuleList( + [ + AttentionLayer( + hidden_dim=hidden_dim, + num_heads=num_heads, + head_dim=head_dim, + dropout=dropout, + bipartite=True, + has_pos_emb=True, + ) + for _ in range(num_layers) + ] + ) + self.a2a_attn_layers = nn.ModuleList( + [ + AttentionLayer( + hidden_dim=hidden_dim, + num_heads=num_heads, + head_dim=head_dim, + dropout=dropout, + bipartite=False, + has_pos_emb=True, + ) + for _ in range(num_layers) + ] + ) + self.gmm_logits_head = MLPLayer( + input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=k_ego_gmm + ) + self.gmm_pose_head = MLPLayer( + input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=k_ego_gmm * 3 + ) + + self.gmm_cov = torch.nn.Parameter( + torch.tensor(cov_ego_gmm, dtype=torch.float32), requires_grad=cov_learnable + ) + + self.apply(weight_init) + + def agent_token_embedding( + 
self, + agent_token_index, # [n_agent, n_step] + trajectory_token_veh, # [n_token, 8] + trajectory_token_ped, # [n_token, 8] + trajectory_token_cyc, # [n_token, 8] + pos_a, # [n_agent, n_step, 2] + head_vector_a, # [n_agent, n_step, 2] + agent_type, # [n_agent] + agent_shape, # [n_agent, 3] + inference=False, + ): + n_agent, n_step, traj_dim = pos_a.shape + _device = pos_a.device + + veh_mask = agent_type == 0 + ped_mask = agent_type == 1 + cyc_mask = agent_type == 2 + # [n_token, hidden_dim] + agent_token_emb_veh = self.token_emb_veh(trajectory_token_veh) + agent_token_emb_ped = self.token_emb_ped(trajectory_token_ped) + agent_token_emb_cyc = self.token_emb_cyc(trajectory_token_cyc) + agent_token_emb = torch.zeros( + (n_agent, n_step, self.hidden_dim), device=_device, dtype=pos_a.dtype + ) + agent_token_emb[veh_mask] = agent_token_emb_veh[agent_token_index[veh_mask]] + agent_token_emb[ped_mask] = agent_token_emb_ped[agent_token_index[ped_mask]] + agent_token_emb[cyc_mask] = agent_token_emb_cyc[agent_token_index[cyc_mask]] + + motion_vector_a = torch.cat( + [ + pos_a.new_zeros(agent_token_index.shape[0], 1, traj_dim), + pos_a[:, 1:] - pos_a[:, :-1], + ], + dim=1, + ) # [n_agent, n_step, 2] + feature_a = torch.stack( + [ + torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1), + angle_between_2d_vectors( + ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2] + ), + ], + dim=-1, + ) # [n_agent, n_step, 2] + categorical_embs = [ + self.type_a_emb(agent_type.long()), + self.shape_emb(agent_shape), + ] # List of len=2, shape [n_agent, hidden_dim] + + x_a = self.x_a_emb( + continuous_inputs=feature_a.view(-1, feature_a.size(-1)), + categorical_embs=[ + v.repeat_interleave(repeats=n_step, dim=0) for v in categorical_embs + ], + ) # [n_agent*n_step, hidden_dim] + x_a = x_a.view(-1, n_step, self.hidden_dim) # [n_agent, n_step, hidden_dim] + + feat_a = torch.cat((agent_token_emb, x_a), dim=-1) + feat_a = self.fusion_emb(feat_a) + + if inference: + return ( + feat_a, # [n_agent, n_step, hidden_dim] + agent_token_emb, # [n_agent, n_step, hidden_dim] + agent_token_emb_veh, # [n_agent, hidden_dim] + agent_token_emb_ped, # [n_agent, hidden_dim] + agent_token_emb_cyc, # [n_agent, hidden_dim] + veh_mask, # [n_agent] + ped_mask, # [n_agent] + cyc_mask, # [n_agent] + categorical_embs, # List of len=2, shape [n_agent, hidden_dim] + ) + else: + return feat_a # [n_agent, n_step, hidden_dim] + + def build_temporal_edge( + self, + pos_a, # [n_agent, n_step, 2] + head_a, # [n_agent, n_step] + head_vector_a, # [n_agent, n_step, 2], + mask, # [n_agent, n_step] + inference_mask=None, # [n_agent, n_step] + ): + pos_t = pos_a.flatten(0, 1) + head_t = head_a.flatten(0, 1) + head_vector_t = head_vector_a.flatten(0, 1) + + if self.hist_drop_prob > 0 and self.training: + _mask_keep = torch.bernoulli( + torch.ones_like(mask) * (1 - self.hist_drop_prob) + ).bool() + mask = mask & _mask_keep + + if inference_mask is not None: + mask_t = mask.unsqueeze(2) & inference_mask.unsqueeze(1) + else: + mask_t = mask.unsqueeze(2) & mask.unsqueeze(1) + + edge_index_t = dense_to_sparse(mask_t)[0] + edge_index_t = edge_index_t[:, edge_index_t[1] > edge_index_t[0]] + edge_index_t = edge_index_t[ + :, edge_index_t[1] - edge_index_t[0] <= self.time_span / self.shift + ] + rel_pos_t = pos_t[edge_index_t[0]] - pos_t[edge_index_t[1]] + rel_pos_t = rel_pos_t[:, :2] + rel_head_t = wrap_angle(head_t[edge_index_t[0]] - head_t[edge_index_t[1]]) + r_t = torch.stack( + [ + torch.norm(rel_pos_t, p=2, dim=-1), + angle_between_2d_vectors( + 
ctr_vector=head_vector_t[edge_index_t[1]], nbr_vector=rel_pos_t + ), + rel_head_t, + edge_index_t[0] - edge_index_t[1], + ], + dim=-1, + ) + r_t = self.r_t_emb(continuous_inputs=r_t, categorical_embs=None) + return edge_index_t, r_t + + def build_interaction_edge( + self, + pos_a, # [n_agent, n_step, 2] + head_a, # [n_agent, n_step] + head_vector_a, # [n_agent, n_step, 2] + batch_s, # [n_agent*n_step] + mask, # [n_agent, n_step] + ): + mask = mask.transpose(0, 1).reshape(-1) + pos_s = pos_a.transpose(0, 1).flatten(0, 1) + head_s = head_a.transpose(0, 1).reshape(-1) + head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2) + edge_index_a2a = radius_graph( + x=pos_s[:, :2], + r=self.a2a_radius, + batch=batch_s, + loop=False, + max_num_neighbors=300, + ) + edge_index_a2a = subgraph(subset=mask, edge_index=edge_index_a2a)[0] + rel_pos_a2a = pos_s[edge_index_a2a[0]] - pos_s[edge_index_a2a[1]] + rel_head_a2a = wrap_angle(head_s[edge_index_a2a[0]] - head_s[edge_index_a2a[1]]) + r_a2a = torch.stack( + [ + torch.norm(rel_pos_a2a[:, :2], p=2, dim=-1), + angle_between_2d_vectors( + ctr_vector=head_vector_s[edge_index_a2a[1]], + nbr_vector=rel_pos_a2a[:, :2], + ), + rel_head_a2a, + ], + dim=-1, + ) + r_a2a = self.r_a2a_emb(continuous_inputs=r_a2a, categorical_embs=None) + return edge_index_a2a, r_a2a + + def build_map2agent_edge( + self, + pos_pl, # [n_pl, 2] + orient_pl, # [n_pl] + pos_a, # [n_agent, n_step, 2] + head_a, # [n_agent, n_step] + head_vector_a, # [n_agent, n_step, 2] + mask, # [n_agent, n_step] + batch_s, # [n_agent*n_step] + batch_pl, # [n_pl*n_step] + ): + n_step = pos_a.shape[1] + mask_pl2a = mask.transpose(0, 1).reshape(-1) + pos_s = pos_a.transpose(0, 1).flatten(0, 1) + head_s = head_a.transpose(0, 1).reshape(-1) + head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2) + pos_pl = pos_pl.repeat(n_step, 1) + orient_pl = orient_pl.repeat(n_step) + edge_index_pl2a = radius( + x=pos_s[:, :2], + y=pos_pl[:, :2], + r=self.pl2a_radius, + batch_x=batch_s, + batch_y=batch_pl, + max_num_neighbors=300, + ) + edge_index_pl2a = edge_index_pl2a[:, mask_pl2a[edge_index_pl2a[1]]] + rel_pos_pl2a = pos_pl[edge_index_pl2a[0]] - pos_s[edge_index_pl2a[1]] + rel_orient_pl2a = wrap_angle( + orient_pl[edge_index_pl2a[0]] - head_s[edge_index_pl2a[1]] + ) + r_pl2a = torch.stack( + [ + torch.norm(rel_pos_pl2a[:, :2], p=2, dim=-1), + angle_between_2d_vectors( + ctr_vector=head_vector_s[edge_index_pl2a[1]], + nbr_vector=rel_pos_pl2a[:, :2], + ), + rel_orient_pl2a, + ], + dim=-1, + ) + r_pl2a = self.r_pt2a_emb(continuous_inputs=r_pl2a, categorical_embs=None) + return edge_index_pl2a, r_pl2a + + def forward( + self, + tokenized_agent: Dict[str, torch.Tensor], + map_feature: Dict[str, torch.Tensor], + ) -> Dict[str, torch.Tensor]: + mask = tokenized_agent["valid_mask"] + pos_a = tokenized_agent["sampled_pos"] + head_a = tokenized_agent["sampled_heading"] + head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1) + n_agent, n_step = head_a.shape + + # ! 
get agent token embeddings + feat_a = self.agent_token_embedding( + agent_token_index=tokenized_agent["sampled_idx"], # [n_ag, n_step] + trajectory_token_veh=tokenized_agent["trajectory_token_veh"], + trajectory_token_ped=tokenized_agent["trajectory_token_ped"], + trajectory_token_cyc=tokenized_agent["trajectory_token_cyc"], + pos_a=pos_a, # [n_agent, n_step, 2] + head_vector_a=head_vector_a, # [n_agent, n_step, 2] + agent_type=tokenized_agent["type"], # [n_agent] + agent_shape=tokenized_agent["shape"], # [n_agent, 3] + ) # feat_a: [n_agent, n_step, hidden_dim] + + # ! build temporal, interaction and map2agent edges + edge_index_t, r_t = self.build_temporal_edge( + pos_a=pos_a, # [n_agent, n_step, 2] + head_a=head_a, # [n_agent, n_step] + head_vector_a=head_vector_a, # [n_agent, n_step, 2] + mask=mask, # [n_agent, n_step] + ) # edge_index_t: [2, n_edge_t], r_t: [n_edge_t, hidden_dim] + + batch_s = torch.cat( + [ + tokenized_agent["batch"] + tokenized_agent["num_graphs"] * t + for t in range(n_step) + ], + dim=0, + ) # [n_agent*n_step] + batch_pl = torch.cat( + [ + map_feature["batch"] + tokenized_agent["num_graphs"] * t + for t in range(n_step) + ], + dim=0, + ) # [n_pl*n_step] + + edge_index_a2a, r_a2a = self.build_interaction_edge( + pos_a=pos_a, # [n_agent, n_step, 2] + head_a=head_a, # [n_agent, n_step] + head_vector_a=head_vector_a, # [n_agent, n_step, 2] + batch_s=batch_s, # [n_agent*n_step] + mask=mask, # [n_agent, n_step] + ) # edge_index_a2a: [2, n_edge_a2a], r_a2a: [n_edge_a2a, hidden_dim] + + edge_index_pl2a, r_pl2a = self.build_map2agent_edge( + pos_pl=map_feature["position"], # [n_pl, 2] + orient_pl=map_feature["orientation"], # [n_pl] + pos_a=pos_a, # [n_agent, n_step, 2] + head_a=head_a, # [n_agent, n_step] + head_vector_a=head_vector_a, # [n_agent, n_step, 2] + mask=mask, # [n_agent, n_step] + batch_s=batch_s, # [n_agent*n_step] + batch_pl=batch_pl, # [n_pl*n_step] + ) + + # ! attention layers + # [n_step*n_pl, hidden_dim] + feat_map = ( + map_feature["pt_token"].unsqueeze(0).expand(n_step, -1, -1).flatten(0, 1) + ) + + for i in range(self.num_layers): + feat_a = feat_a.flatten(0, 1) # [n_agent*n_step, hidden_dim] + feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t) + # [n_step*n_agent, hidden_dim] + feat_a = feat_a.view(n_agent, n_step, -1).transpose(0, 1).flatten(0, 1) + feat_a = self.pt2a_attn_layers[i]( + (feat_map, feat_a), r_pl2a, edge_index_pl2a + ) + feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a) + feat_a = feat_a.view(n_step, n_agent, -1).transpose(0, 1) + + # ! final mlp to get outputs + ego_mask = tokenized_agent["ego_mask"] + feat_a = feat_a[ego_mask] + ego_next_logits = self.gmm_logits_head(feat_a) + ego_next_poses = self.gmm_pose_head(feat_a).view(*ego_next_logits.shape, 3) + return { + # action that goes from [(10->15), ..., (85->90)] + "ego_next_logits": ego_next_logits[:, 1:-1], # [n_batch, 16, n_k_ego_gmm] + "ego_next_poses": ego_next_poses[:, 1:-1], # [n_batch, 16, n_k_ego_gmm, 3] + "ego_next_valid": tokenized_agent["valid_mask"][ego_mask][:, 1:-1], + "ego_next_cov": self.gmm_cov, # [2], one for pos, one for heading. 
+ # for step {5, 10, ..., 90} and act [(0->5), (5->10), ..., (85->90)] + "pred_pos": tokenized_agent["sampled_pos"][ego_mask], # [n_batch, 18, 2] + "pred_head": tokenized_agent["sampled_heading"][ego_mask], # [n_batch, 18] + "pred_valid": tokenized_agent["valid_mask"][ego_mask], # [n_batch, 18] + # for step {5, 10, ..., 90} + "gt_pos_raw": tokenized_agent["gt_pos_raw"][ego_mask], # [n_batch, 18, 2] + "gt_head_raw": tokenized_agent["gt_head_raw"][ego_mask], # [n_batch, 18] + "gt_valid_raw": tokenized_agent["gt_valid_raw"][ego_mask], # [n_batch, 18] + # or use the tokenized gt + "gt_pos": tokenized_agent["gt_pos"][ego_mask], # [n_batch, 18, 2] + "gt_head": tokenized_agent["gt_heading"][ego_mask], # [n_batch, 18] + "gt_valid": tokenized_agent["valid_mask"][ego_mask], # [n_batch, 18] + } + + def inference( + self, + tokenized_agent: Dict[str, torch.Tensor], + map_feature: Dict[str, torch.Tensor], + sampling_scheme: DictConfig, + ) -> Dict[str, torch.Tensor]: + n_agent = tokenized_agent["valid_mask"].shape[0] + n_step_future_10hz = self.num_future_steps # 80 + n_step_future_2hz = n_step_future_10hz // self.shift # 16 + step_current_10hz = self.num_historical_steps - 1 # 10 + step_current_2hz = step_current_10hz // self.shift # 2 + ego_mask = tokenized_agent["ego_mask"] + + pos_a = tokenized_agent["gt_pos"][:, :step_current_2hz].clone() + head_a = tokenized_agent["gt_heading"][:, :step_current_2hz].clone() + head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1) + pred_idx = tokenized_agent["gt_idx"].clone() + ( + feat_a, # [n_agent, step_current_2hz, hidden_dim] + agent_token_emb, # [n_agent, step_current_2hz, hidden_dim] + agent_token_emb_veh, # [n_agent, hidden_dim] + agent_token_emb_ped, # [n_agent, hidden_dim] + agent_token_emb_cyc, # [n_agent, hidden_dim] + veh_mask, # [n_agent] + ped_mask, # [n_agent] + cyc_mask, # [n_agent] + categorical_embs, # List of len=2, shape [n_agent, hidden_dim] + ) = self.agent_token_embedding( + agent_token_index=tokenized_agent["gt_idx"][:, :step_current_2hz], + trajectory_token_veh=tokenized_agent["trajectory_token_veh"], + trajectory_token_ped=tokenized_agent["trajectory_token_ped"], + trajectory_token_cyc=tokenized_agent["trajectory_token_cyc"], + pos_a=pos_a, + head_vector_a=head_vector_a, + agent_type=tokenized_agent["type"], + agent_shape=tokenized_agent["shape"], + inference=True, + ) + + if not self.training: + pred_traj_10hz = torch.zeros( + [n_agent, n_step_future_10hz, 2], dtype=pos_a.dtype, device=pos_a.device + ) + pred_head_10hz = torch.zeros( + [n_agent, n_step_future_10hz], dtype=pos_a.dtype, device=pos_a.device + ) + + pred_valid = tokenized_agent["valid_mask"].clone() + ego_next_logits_list = [] + ego_next_poses_list = [] + next_token_action_list = [] + feat_a_t_dict = {} + for t in range(n_step_future_2hz): # 0 -> 15 + t_now = step_current_2hz - 1 + t # 1 -> 16 + n_step = t_now + 1 # 2 -> 17 + + if t == 0: # init + hist_step = step_current_2hz + batch_s = torch.cat( + [ + tokenized_agent["batch"] + tokenized_agent["num_graphs"] * t + for t in range(hist_step) + ], + dim=0, + ) + batch_pl = torch.cat( + [ + map_feature["batch"] + tokenized_agent["num_graphs"] * t + for t in range(hist_step) + ], + dim=0, + ) + inference_mask = pred_valid[:, :n_step] + edge_index_t, r_t = self.build_temporal_edge( + pos_a=pos_a, + head_a=head_a, + head_vector_a=head_vector_a, + mask=pred_valid[:, :n_step], + ) + else: + hist_step = 1 + batch_s = tokenized_agent["batch"] + batch_pl = map_feature["batch"] + inference_mask = pred_valid[:, 
:n_step].clone()
+                inference_mask[:, :-1] = False
+                edge_index_t, r_t = self.build_temporal_edge(
+                    pos_a=pos_a,
+                    head_a=head_a,
+                    head_vector_a=head_vector_a,
+                    mask=pred_valid[:, :n_step],
+                    inference_mask=inference_mask,
+                )
+                edge_index_t[1] = (edge_index_t[1] + 1) // n_step - 1
+
+            # In the inference stage, we only infer the current step for the recurrent rollout.
+            edge_index_pl2a, r_pl2a = self.build_map2agent_edge(
+                pos_pl=map_feature["position"],  # [n_pl, 2]
+                orient_pl=map_feature["orientation"],  # [n_pl]
+                pos_a=pos_a[:, -hist_step:],  # [n_agent, hist_step, 2]
+                head_a=head_a[:, -hist_step:],  # [n_agent, hist_step]
+                head_vector_a=head_vector_a[:, -hist_step:],  # [n_agent, hist_step, 2]
+                mask=inference_mask[:, -hist_step:],  # [n_agent, hist_step]
+                batch_s=batch_s,  # [n_agent*hist_step]
+                batch_pl=batch_pl,  # [n_pl*hist_step]
+            )
+            edge_index_a2a, r_a2a = self.build_interaction_edge(
+                pos_a=pos_a[:, -hist_step:],  # [n_agent, hist_step, 2]
+                head_a=head_a[:, -hist_step:],  # [n_agent, hist_step]
+                head_vector_a=head_vector_a[:, -hist_step:],  # [n_agent, hist_step, 2]
+                batch_s=batch_s,  # [n_agent*hist_step]
+                mask=inference_mask[:, -hist_step:],  # [n_agent, hist_step]
+            )
+
+            # ! attention layers
+            for i in range(self.num_layers):
+                # [n_agent, n_step, hidden_dim]
+                _feat_temporal = feat_a if i == 0 else feat_a_t_dict[i]
+
+                if t == 0:  # init, process hist_step together
+                    _feat_temporal = self.t_attn_layers[i](
+                        _feat_temporal.flatten(0, 1), r_t, edge_index_t
+                    ).view(n_agent, n_step, -1)
+                    _feat_temporal = _feat_temporal.transpose(0, 1).flatten(0, 1)
+
+                    # [hist_step*n_pl, hidden_dim]
+                    _feat_map = (
+                        map_feature["pt_token"]
+                        .unsqueeze(0)
+                        .expand(hist_step, -1, -1)
+                        .flatten(0, 1)
+                    )
+
+                    _feat_temporal = self.pt2a_attn_layers[i](
+                        (_feat_map, _feat_temporal), r_pl2a, edge_index_pl2a
+                    )
+                    _feat_temporal = self.a2a_attn_layers[i](
+                        _feat_temporal, r_a2a, edge_index_a2a
+                    )
+                    _feat_temporal = _feat_temporal.view(n_step, n_agent, -1).transpose(
+                        0, 1
+                    )
+                    feat_a_now = _feat_temporal[:, -1]  # [n_agent, hidden_dim]
+
+                    if i + 1 < self.num_layers:
+                        feat_a_t_dict[i + 1] = _feat_temporal
+
+                else:  # process one step
+                    feat_a_now = self.t_attn_layers[i](
+                        (_feat_temporal.flatten(0, 1), _feat_temporal[:, -1]),
+                        r_t,
+                        edge_index_t,
+                    )
+                    # * gives the same results as below, but is more efficient
+                    # feat_a_now = self.t_attn_layers[i](
+                    #     _feat_temporal.flatten(0, 1), r_t, edge_index_t
+                    # ).view(n_agent, n_step, -1)[:, -1]
+
+                    feat_a_now = self.pt2a_attn_layers[i](
+                        (map_feature["pt_token"], feat_a_now), r_pl2a, edge_index_pl2a
+                    )
+                    feat_a_now = self.a2a_attn_layers[i](
+                        feat_a_now, r_a2a, edge_index_a2a
+                    )
+
+                # [n_agent, n_step, hidden_dim]
+                if i + 1 < self.num_layers:
+                    feat_a_t_dict[i + 1] = torch.cat(
+                        (feat_a_t_dict[i + 1], feat_a_now.unsqueeze(1)), dim=1
+                    )
+
+            # ! get outputs
+            feat_a_now = feat_a_now[ego_mask]
+            ego_next_logits = self.gmm_logits_head(feat_a_now)
+            ego_next_poses = self.gmm_pose_head(feat_a_now).view(
+                *ego_next_logits.shape, 3
+            )
+
+            ego_next_logits_list.append(ego_next_logits)  # [n_batch, n_k_ego_gmm]
+            ego_next_poses_list.append(ego_next_poses)  # [n_batch, n_k_ego_gmm, 3]
+
+            next_token_idx, next_token_traj_all = sample_next_gmm_traj(
+                token_traj=tokenized_agent["token_traj"],
+                token_traj_all=tokenized_agent["token_traj_all"],
+                sampling_scheme=sampling_scheme,
+                # ! for most-likely sampling
+                ego_mask=ego_mask,  # [n_agent]
+                ego_next_logits=ego_next_logits,  # [n_batch, n_k_ego_gmm]
+                ego_next_poses=ego_next_poses,  # [n_batch, n_k_ego_gmm, 3]
+                ego_next_cov=self.gmm_cov,  # [2], one for pos, one for heading.
+                # ! for nearest-pos sampling
+                pos_now=pos_a[:, t_now],  # [n_agent, 2]
+                head_now=head_a[:, t_now],  # [n_agent]
+                pos_next_gt=tokenized_agent["gt_pos_raw"][:, n_step],  # [n_agent, 2]
+                head_next_gt=tokenized_agent["gt_head_raw"][:, n_step],  # [n_agent]
+                valid_next_gt=tokenized_agent["gt_valid_raw"][:, n_step],  # [n_agent]
+                token_agent_shape=tokenized_agent["token_agent_shape"],  # [n_token, 2]
+                next_token_idx=tokenized_agent["gt_idx"][:, n_step].clone(),
+            )  # next_token_idx: [n_agent], next_token_traj_all: [n_agent, 6, 4, 2]
+
+            diff_xy = next_token_traj_all[:, -1, 0] - next_token_traj_all[:, -1, 3]
+            next_token_action_list.append(
+                torch.cat(
+                    [
+                        next_token_traj_all[:, -1].mean(1),  # [n_agent, 2]
+                        torch.arctan2(diff_xy[:, [1]], diff_xy[:, [0]]),  # [n_agent, 1]
+                    ],
+                    dim=-1,
+                )[ego_mask]
+            )  # [n_batch, 3]
+
+            token_traj_global = transform_to_global(
+                pos_local=next_token_traj_all.flatten(1, 2),  # [n_agent, 6*4, 2]
+                head_local=None,
+                pos_now=pos_a[:, t_now],  # [n_agent, 2]
+                head_now=head_a[:, t_now],  # [n_agent]
+            )[0].view(*next_token_traj_all.shape)
+
+            if not self.training:
+                pred_traj_10hz[:, t * 5 : (t + 1) * 5] = token_traj_global[:, 1:].mean(
+                    2
+                )
+                diff_xy = token_traj_global[:, 1:, 0] - token_traj_global[:, 1:, 3]
+                pred_head_10hz[:, t * 5 : (t + 1) * 5] = torch.arctan2(
+                    diff_xy[:, :, 1], diff_xy[:, :, 0]
+                )
+
+            # ! get pos_a_next and head_a_next, spawn unseen agents
+            pos_a_next = token_traj_global[:, -1].mean(dim=1)
+            diff_xy_next = token_traj_global[:, -1, 0] - token_traj_global[:, -1, 3]
+            head_a_next = torch.arctan2(diff_xy_next[:, 1], diff_xy_next[:, 0])
+            pred_idx[:, n_step] = next_token_idx
+
+            # ! update tensors for next step
+            pred_valid[:, n_step][ego_mask] = True
+            pos_a = torch.cat([pos_a, pos_a_next.unsqueeze(1)], dim=1)
+            head_a = torch.cat([head_a, head_a_next.unsqueeze(1)], dim=1)
+            head_vector_a_next = torch.stack(
+                [head_a_next.cos(), head_a_next.sin()], dim=-1
+            )
+            head_vector_a = torch.cat(
+                [head_vector_a, head_vector_a_next.unsqueeze(1)], dim=1
+            )
+
+            # ! get agent_token_emb_next
+            agent_token_emb_next = torch.zeros_like(agent_token_emb[:, 0])
+            agent_token_emb_next[veh_mask] = agent_token_emb_veh[
+                next_token_idx[veh_mask]
+            ]
+            agent_token_emb_next[ped_mask] = agent_token_emb_ped[
+                next_token_idx[ped_mask]
+            ]
+            agent_token_emb_next[cyc_mask] = agent_token_emb_cyc[
+                next_token_idx[cyc_mask]
+            ]
+            agent_token_emb = torch.cat(
+                [agent_token_emb, agent_token_emb_next.unsqueeze(1)], dim=1
+            )
+
+            # !
get feat_a_next + motion_vector_a = pos_a[:, -1] - pos_a[:, -2] # [n_agent, 2] + x_a = torch.stack( + [ + torch.norm(motion_vector_a, p=2, dim=-1), + angle_between_2d_vectors( + ctr_vector=head_vector_a[:, -1], nbr_vector=motion_vector_a + ), + ], + dim=-1, + ) + # [n_agent, hidden_dim] + x_a = self.x_a_emb(continuous_inputs=x_a, categorical_embs=categorical_embs) + # [n_agent, 1, 2*hidden_dim] + feat_a_next = torch.cat((agent_token_emb_next, x_a), dim=-1).unsqueeze(1) + feat_a_next = self.fusion_emb(feat_a_next) + feat_a = torch.cat([feat_a, feat_a_next], dim=1) + + out_dict = { + # action that goes from [(10->15), ..., (85->90)] + # [n_batch, 16, n_k_ego_gmm] + "ego_next_logits": torch.stack(ego_next_logits_list, dim=1), + # [n_batch, 16, n_k_ego_gmm, 3] + "ego_next_poses": torch.stack(ego_next_poses_list, dim=1), + "ego_next_valid": pred_valid[ego_mask][:, 1:-1], # [n_batch, 16] + "ego_next_cov": self.gmm_cov, # [2], one for pos, one for heading. + # for step {5, 10, ..., 90} and act [(0->5), (5->10), ..., (85->90)] + "pred_pos": pos_a[ego_mask], # [n_batch, 18, 2] + "pred_head": head_a[ego_mask], # [n_batch, 18] + "pred_valid": pred_valid[ego_mask], # [n_batch, 18] + # "pred_idx": pred_idx, # [n_batch, 18] + # for step {5, 10, ..., 90} + "gt_pos_raw": tokenized_agent["gt_pos_raw"][ego_mask], # [n_batch, 18, 2] + "gt_head_raw": tokenized_agent["gt_head_raw"][ego_mask], # [n_batch, 18] + "gt_valid_raw": tokenized_agent["gt_valid_raw"][ego_mask], # [n_batch, 18] + # or use the tokenized gt + "gt_pos": tokenized_agent["gt_pos"][ego_mask], # [n_batch, 18, 2] + "gt_head": tokenized_agent["gt_heading"][ego_mask], # [n_batch, 18] + "gt_valid": tokenized_agent["valid_mask"][ego_mask], # [n_batch, 18] + # for shifting proxy targets by lr, [n_batch, 16, 3] + "next_token_action": torch.stack(next_token_action_list, dim=1), + } + + if not self.training: # 10hz predictions for wosac evaluation and submission + out_dict["pred_traj_10hz"] = pred_traj_10hz + out_dict["pred_head_10hz"] = pred_head_10hz + pred_z = tokenized_agent["gt_z_raw"].unsqueeze(1) # [n_agent, 1] + out_dict["pred_z_10hz"] = pred_z.expand(-1, pred_traj_10hz.shape[1]) + + return out_dict diff --git a/backups/thirdparty/catk/src/smart/modules/ego_gmm_smart_decoder.py b/backups/thirdparty/catk/src/smart/modules/ego_gmm_smart_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..cf0462aba418f8dd7ab010a370c17a9076f0472c --- /dev/null +++ b/backups/thirdparty/catk/src/smart/modules/ego_gmm_smart_decoder.py @@ -0,0 +1,91 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
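# ---------------------------------------------------------------------------
# EgoGMMAgentDecoder above parameterizes the ego action as a K-component GMM:
# `gmm_logits_head` emits mixture logits [n_batch, K], `gmm_pose_head` emits
# component means [n_batch, K, 3] (x, y, heading), and `gmm_cov` holds one
# shared covariance value for position and one for heading. A hedged sketch
# of drawing a single action from such a head (illustrative only; the actual
# rollout-time sampling lives in `sample_next_gmm_traj`, and treating
# `gmm_cov` as variances is an assumption):
import torch
from torch.distributions import Categorical, Normal

def sample_ego_action(
    logits: torch.Tensor,  # [n_batch, K]
    poses: torch.Tensor,   # [n_batch, K, 3]
    cov: torch.Tensor,     # [2], assumed [pos_var, head_var]
) -> torch.Tensor:
    k = Categorical(logits=logits).sample()             # [n_batch], component id
    mu = poses[torch.arange(poses.shape[0]), k]         # [n_batch, 3], chosen mean
    std = torch.stack([cov[0], cov[0], cov[1]]).sqrt()  # diag std over (x, y, heading)
    return Normal(mu, std).sample()                     # [n_batch, 3]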
+ +from typing import Dict, Optional + +import torch.nn as nn +from omegaconf import DictConfig, ListConfig +from torch import Tensor + +from .ego_gmm_agent_decoder import EgoGMMAgentDecoder +from .map_decoder import SMARTMapDecoder + + +class EgoGMMSMARTDecoder(nn.Module): + + def __init__( + self, + hidden_dim: int, + num_historical_steps: int, + num_future_steps: int, + pl2pl_radius: float, + time_span: Optional[int], + pl2a_radius: float, + a2a_radius: float, + num_freq_bands: int, + num_map_layers: int, + num_agent_layers: int, + num_heads: int, + head_dim: int, + dropout: float, + hist_drop_prob: float, + k_ego_gmm: int = -1, + cov_ego_gmm: ListConfig[float] = [1.0, 0.1], + cov_learnable: bool = False, + ) -> None: + super(EgoGMMSMARTDecoder, self).__init__() + self.map_encoder = SMARTMapDecoder( + hidden_dim=hidden_dim, + pl2pl_radius=pl2pl_radius, + num_freq_bands=num_freq_bands, + num_layers=num_map_layers, + num_heads=num_heads, + head_dim=head_dim, + dropout=dropout, + ) + self.agent_encoder = EgoGMMAgentDecoder( + hidden_dim=hidden_dim, + num_historical_steps=num_historical_steps, + num_future_steps=num_future_steps, + time_span=time_span, + pl2a_radius=pl2a_radius, + a2a_radius=a2a_radius, + num_freq_bands=num_freq_bands, + num_layers=num_agent_layers, + num_heads=num_heads, + head_dim=head_dim, + dropout=dropout, + hist_drop_prob=hist_drop_prob, + k_ego_gmm=k_ego_gmm, + cov_ego_gmm=cov_ego_gmm, + cov_learnable=cov_learnable, + ) + + def forward( + self, tokenized_map: Dict[str, Tensor], tokenized_agent: Dict[str, Tensor] + ) -> Dict[str, Tensor]: + map_feature = self.map_encoder(tokenized_map) + pred_dict = self.agent_encoder(tokenized_agent, map_feature) + return pred_dict + + def inference( + self, + tokenized_map: Dict[str, Tensor], + tokenized_agent: Dict[str, Tensor], + sampling_scheme: DictConfig, + ) -> Dict[str, Tensor]: + map_feature = self.map_encoder(tokenized_map) + pred_dict = self.agent_encoder.inference( + tokenized_agent, map_feature, sampling_scheme + ) + return pred_dict diff --git a/backups/thirdparty/catk/src/smart/modules/map_decoder.py b/backups/thirdparty/catk/src/smart/modules/map_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..bbb7f15e3d366484133a8021c5ad9d7275aa2653 --- /dev/null +++ b/backups/thirdparty/catk/src/smart/modules/map_decoder.py @@ -0,0 +1,113 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
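# ---------------------------------------------------------------------------
# A note on the edge features used throughout these decoders: the a2a, pl2a,
# and pt2pt edge builders (the temporal edges add a fourth feature, the step
# offset) all describe a neighbor-to-center relation as the triple
# [distance, signed angle of the offset in the center's heading frame,
# relative orientation] before Fourier-embedding it. A minimal sketch of that
# descriptor, reusing the repo's own helpers:
import math

import torch

from src.smart.utils import angle_between_2d_vectors, wrap_angle

def rel_edge_features(
    pos_src: torch.Tensor,   # [n_edge, 2]
    head_src: torch.Tensor,  # [n_edge]
    pos_dst: torch.Tensor,   # [n_edge, 2], dst is the attending (center) node
    head_dst: torch.Tensor,  # [n_edge]
) -> torch.Tensor:
    rel_pos = pos_src - pos_dst
    head_vector_dst = torch.stack([head_dst.cos(), head_dst.sin()], dim=-1)
    return torch.stack(
        [
            torch.norm(rel_pos, p=2, dim=-1),
            angle_between_2d_vectors(ctr_vector=head_vector_dst, nbr_vector=rel_pos),
            wrap_angle(head_src - head_dst),
        ],
        dim=-1,
    )  # [n_edge, 3]

# sanity check: a neighbor 1 m ahead-left at 45 deg of an east-facing center
# node gives distance ~1, offset angle pi/4, relative orientation 0
_r = rel_edge_features(
    pos_src=torch.tensor([[0.7071, 0.7071]]),
    head_src=torch.tensor([0.0]),
    pos_dst=torch.tensor([[0.0, 0.0]]),
    head_dst=torch.tensor([0.0]),
)
assert torch.allclose(_r, torch.tensor([[1.0, math.pi / 4, 0.0]]), atol=1e-4)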
+ +from typing import Dict + +import torch +import torch.nn as nn +from torch_cluster import radius_graph + +from src.smart.layers.attention_layer import AttentionLayer +from src.smart.layers.fourier_embedding import FourierEmbedding, MLPEmbedding +from src.smart.utils import angle_between_2d_vectors, weight_init, wrap_angle + + +class SMARTMapDecoder(nn.Module): + + def __init__( + self, + hidden_dim: int, + pl2pl_radius: float, + num_freq_bands: int, + num_layers: int, + num_heads: int, + head_dim: int, + dropout: float, + ) -> None: + super(SMARTMapDecoder, self).__init__() + self.pl2pl_radius = pl2pl_radius + self.num_layers = num_layers + + self.type_pt_emb = nn.Embedding(10, hidden_dim) + self.polygon_type_emb = nn.Embedding(4, hidden_dim) + self.light_pl_emb = nn.Embedding(5, hidden_dim) + + input_dim_r_pt2pt = 3 + self.r_pt2pt_emb = FourierEmbedding( + input_dim=input_dim_r_pt2pt, + hidden_dim=hidden_dim, + num_freq_bands=num_freq_bands, + ) + self.pt2pt_layers = nn.ModuleList( + [ + AttentionLayer( + hidden_dim=hidden_dim, + num_heads=num_heads, + head_dim=head_dim, + dropout=dropout, + bipartite=False, + has_pos_emb=True, + ) + for _ in range(num_layers) + ] + ) + + # map_token_traj_src: [n_token, 11, 2].flatten(0,1) + self.token_emb = MLPEmbedding(input_dim=22, hidden_dim=hidden_dim) + self.apply(weight_init) + + def forward(self, tokenized_map: Dict) -> Dict[str, torch.Tensor]: + pos_pt = tokenized_map["position"] + orient_pt = tokenized_map["orientation"] + orient_vector_pt = torch.stack([orient_pt.cos(), orient_pt.sin()], dim=-1) + pt_token_emb_src = self.token_emb(tokenized_map["token_traj_src"]) + x_pt = pt_token_emb_src[tokenized_map["token_idx"]] + + x_pt_categorical_embs = [ + self.type_pt_emb(tokenized_map["type"]), + self.polygon_type_emb(tokenized_map["pl_type"]), + self.light_pl_emb(tokenized_map["light_type"]), + ] + x_pt = x_pt + torch.stack(x_pt_categorical_embs).sum(dim=0) + edge_index_pt2pt = radius_graph( + x=pos_pt, + r=self.pl2pl_radius, + batch=tokenized_map["batch"], + loop=False, + max_num_neighbors=100, + ) + rel_pos_pt2pt = pos_pt[edge_index_pt2pt[0]] - pos_pt[edge_index_pt2pt[1]] + rel_orient_pt2pt = wrap_angle( + orient_pt[edge_index_pt2pt[0]] - orient_pt[edge_index_pt2pt[1]] + ) + r_pt2pt = torch.stack( + [ + torch.norm(rel_pos_pt2pt[:, :2], p=2, dim=-1), + angle_between_2d_vectors( + ctr_vector=orient_vector_pt[edge_index_pt2pt[1]], + nbr_vector=rel_pos_pt2pt[:, :2], + ), + rel_orient_pt2pt, + ], + dim=-1, + ) + r_pt2pt = self.r_pt2pt_emb(continuous_inputs=r_pt2pt, categorical_embs=None) + for i in range(self.num_layers): + x_pt = self.pt2pt_layers[i](x_pt, r_pt2pt, edge_index_pt2pt) + + return { + "pt_token": x_pt, + "position": pos_pt, + "orientation": orient_pt, + "batch": tokenized_map["batch"], + } diff --git a/backups/thirdparty/catk/src/smart/modules/smart_decoder.py b/backups/thirdparty/catk/src/smart/modules/smart_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..daceb487e9d10e5c5a2a3a38338118090c28c2bc --- /dev/null +++ b/backups/thirdparty/catk/src/smart/modules/smart_decoder.py @@ -0,0 +1,87 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +from typing import Dict, Optional + +import torch.nn as nn +from omegaconf import DictConfig +from torch import Tensor + +from .agent_decoder import SMARTAgentDecoder +from .map_decoder import SMARTMapDecoder + + +class SMARTDecoder(nn.Module): + + def __init__( + self, + hidden_dim: int, + num_historical_steps: int, + num_future_steps: int, + pl2pl_radius: float, + time_span: Optional[int], + pl2a_radius: float, + a2a_radius: float, + num_freq_bands: int, + num_map_layers: int, + num_agent_layers: int, + num_heads: int, + head_dim: int, + dropout: float, + hist_drop_prob: float, + n_token_agent: int, + ) -> None: + super(SMARTDecoder, self).__init__() + self.map_encoder = SMARTMapDecoder( + hidden_dim=hidden_dim, + pl2pl_radius=pl2pl_radius, + num_freq_bands=num_freq_bands, + num_layers=num_map_layers, + num_heads=num_heads, + head_dim=head_dim, + dropout=dropout, + ) + self.agent_encoder = SMARTAgentDecoder( + hidden_dim=hidden_dim, + num_historical_steps=num_historical_steps, + num_future_steps=num_future_steps, + time_span=time_span, + pl2a_radius=pl2a_radius, + a2a_radius=a2a_radius, + num_freq_bands=num_freq_bands, + num_layers=num_agent_layers, + num_heads=num_heads, + head_dim=head_dim, + dropout=dropout, + hist_drop_prob=hist_drop_prob, + n_token_agent=n_token_agent, + ) + + def forward( + self, tokenized_map: Dict[str, Tensor], tokenized_agent: Dict[str, Tensor] + ) -> Dict[str, Tensor]: + map_feature = self.map_encoder(tokenized_map) + pred_dict = self.agent_encoder(tokenized_agent, map_feature) + return pred_dict + + def inference( + self, + tokenized_map: Dict[str, Tensor], + tokenized_agent: Dict[str, Tensor], + sampling_scheme: DictConfig, + ) -> Dict[str, Tensor]: + map_feature = self.map_encoder(tokenized_map) + pred_dict = self.agent_encoder.inference( + tokenized_agent, map_feature, sampling_scheme + ) + return pred_dict diff --git a/backups/thirdparty/catk/src/smart/tokens/__init__.py b/backups/thirdparty/catk/src/smart/tokens/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/backups/thirdparty/catk/src/smart/tokens/token_processor.py b/backups/thirdparty/catk/src/smart/tokens/token_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..705cf362f7f64ecdc02137e24e3589ea7c84a149 --- /dev/null +++ b/backups/thirdparty/catk/src/smart/tokens/token_processor.py @@ -0,0 +1,377 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. 
Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import pickle +from typing import Dict, Tuple + +import torch +from omegaconf import DictConfig +from torch import Tensor +from torch.distributions import Categorical +from torch_geometric.data import HeteroData + +from src.smart.utils import ( + cal_polygon_contour, + transform_to_global, + transform_to_local, + wrap_angle, +) + + +class TokenProcessor(torch.nn.Module): + + def __init__( + self, + map_token_file: str, + agent_token_file: str, + map_token_sampling: DictConfig, + agent_token_sampling: DictConfig, + ) -> None: + super(TokenProcessor, self).__init__() + self.map_token_sampling = map_token_sampling + self.agent_token_sampling = agent_token_sampling + self.shift = 5 + + module_dir = os.path.dirname(__file__) + self.init_agent_token(os.path.join(module_dir, agent_token_file)) + self.init_map_token(os.path.join(module_dir, map_token_file)) + self.n_token_agent = self.agent_token_all_veh.shape[0] + + @torch.no_grad() + def forward(self, data: HeteroData) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]: + tokenized_map = self.tokenize_map(data) + tokenized_agent = self.tokenize_agent(data) + return tokenized_map, tokenized_agent + + def init_map_token(self, map_token_traj_path, argmin_sample_len=3) -> None: + map_token_traj = pickle.load(open(map_token_traj_path, "rb"))["traj_src"] + indices = torch.linspace( + 0, map_token_traj.shape[1] - 1, steps=argmin_sample_len + ).long() + + self.register_buffer( + "map_token_traj_src", + torch.tensor(map_token_traj, dtype=torch.float32).flatten(1, 2), + persistent=False, + ) # [n_token, 11*2] + + self.register_buffer( + "map_token_sample_pt", + torch.tensor(map_token_traj[:, indices], dtype=torch.float32).unsqueeze(0), + persistent=False, + ) # [1, n_token, 3, 2] + + def init_agent_token(self, agent_token_path) -> None: + agent_token_data = pickle.load(open(agent_token_path, "rb")) + for k, v in agent_token_data["token_all"].items(): + v = torch.tensor(v, dtype=torch.float32) + # [n_token, 6, 4, 2], countour, 10 hz + self.register_buffer(f"agent_token_all_{k}", v, persistent=False) + + def tokenize_map(self, data: HeteroData) -> Dict[str, Tensor]: + traj_pos = data["map_save"]["traj_pos"] # [n_pl, 3, 2] + traj_theta = data["map_save"]["traj_theta"] # [n_pl] + + traj_pos_local, _ = transform_to_local( + pos_global=traj_pos, # [n_pl, 3, 2] + head_global=None, # [n_pl, 1] + pos_now=traj_pos[:, 0], # [n_pl, 2] + head_now=traj_theta, # [n_pl] + ) + # [1, n_token, 3, 2] - [n_pl, 1, 3, 2] + dist = torch.sum( + (self.map_token_sample_pt - traj_pos_local.unsqueeze(1)) ** 2, + dim=(-2, -1), + ) # [n_pl, n_token] + + if self.training and (self.map_token_sampling.num_k > 1): + topk_dists, topk_indices = torch.topk( + dist, + self.map_token_sampling.num_k, + dim=-1, + largest=False, + sorted=False, + ) # [n_pl, K] + + topk_logits = (-1e-6 - topk_dists) / self.map_token_sampling.temp + _samples = Categorical(logits=topk_logits).sample() # [n_pl] in K + token_idx = topk_indices[torch.arange(len(_samples)), _samples].contiguous() + else: + token_idx = torch.argmin(dist, dim=-1) + + tokenized_map = { + "position": traj_pos[:, 0].contiguous(), # [n_pl, 2] + "orientation": traj_theta, # [n_pl] + "token_idx": token_idx, # [n_pl] + "token_traj_src": self.map_token_traj_src, # [n_token, 11*2] + "type": data["pt_token"]["type"].long(), # [n_pl] + 
"pl_type": data["pt_token"]["pl_type"].long(), # [n_pl] + "light_type": data["pt_token"]["light_type"].long(), # [n_pl] + "batch": data["pt_token"]["batch"], # [n_pl] + } + return tokenized_map + + def tokenize_agent(self, data: HeteroData) -> Dict[str, Tensor]: + """ + Args: data["agent"]: Dict + "valid_mask": [n_agent, n_step], bool + "role": [n_agent, 3], bool + "id": [n_agent], int64 + "type": [n_agent], uint8 + "position": [n_agent, n_step, 3], float32 + "heading": [n_agent, n_step], float32 + "velocity": [n_agent, n_step, 2], float32 + "shape": [n_agent, 3], float32 + """ + # ! collate width/length, traj tokens for current batch + agent_shape, token_traj_all, token_traj = self._get_agent_shape_and_token_traj( + data["agent"]["type"] + ) + + # ! get raw trajectory data + valid = data["agent"]["valid_mask"] # [n_agent, n_step] + heading = data["agent"]["heading"] # [n_agent, n_step] + pos = data["agent"]["position"][..., :2].contiguous() # [n_agent, n_step, 2] + vel = data["agent"]["velocity"] # [n_agent, n_step, 2] + + # ! agent, specifically vehicle's heading can be 180 degree off. We fix it here. + heading = self._clean_heading(valid, heading) + # ! extrapolate to previous 5th step. + valid, pos, heading, vel = self._extrapolate_agent_to_prev_token_step( + valid, pos, heading, vel + ) + + # ! prepare output dict + tokenized_agent = { + "num_graphs": data.num_graphs, + "type": data["agent"]["type"], + "shape": data["agent"]["shape"], + "ego_mask": data["agent"]["role"][:, 0], # [n_agent] + "token_agent_shape": agent_shape, # [n_agent, 2] + "batch": data["agent"]["batch"], + "token_traj_all": token_traj_all, # [n_agent, n_token, 6, 4, 2] + "token_traj": token_traj, # [n_agent, n_token, 4, 2] + # for step {5, 10, ..., 90} + "gt_pos_raw": pos[:, self.shift :: self.shift], # [n_agent, n_step=18, 2] + "gt_head_raw": heading[:, self.shift :: self.shift], # [n_agent, n_step=18] + "gt_valid_raw": valid[:, self.shift :: self.shift], # [n_agent, n_step=18] + } + # [n_token, 8] + for k in ["veh", "ped", "cyc"]: + tokenized_agent[f"trajectory_token_{k}"] = getattr( + self, f"agent_token_all_{k}" + )[:, -1].flatten(1, 2) + + # ! match token for each agent + if not self.training: + # [n_agent] + tokenized_agent["gt_z_raw"] = data["agent"]["position"][:, 10, 2] + + token_dict = self._match_agent_token( + valid=valid, + pos=pos, + heading=heading, + agent_shape=agent_shape, + token_traj=token_traj, + ) + tokenized_agent.update(token_dict) + return tokenized_agent + + def _match_agent_token( + self, + valid: Tensor, # [n_agent, n_step] + pos: Tensor, # [n_agent, n_step, 2] + heading: Tensor, # [n_agent, n_step] + agent_shape: Tensor, # [n_agent, 2] + token_traj: Tensor, # [n_agent, n_token, 4, 2] + ) -> Dict[str, Tensor]: + """n_step_token=n_step//5 + n_step_token=18 for train with BC. + n_step_token=2 for val/test and train with closed-loop rollout. + Returns: Dict + # ! action that goes from [(0->5), (5->10), ..., (85->90)] + "valid_mask": [n_agent, n_step_token] + "gt_idx": [n_agent, n_step_token] + # ! at step [5, 10, 15, ..., 90] + "gt_pos": [n_agent, n_step_token, 2] + "gt_heading": [n_agent, n_step_token] + # ! 
noisy sampling for training data augmentation + "sampled_idx": [n_agent, n_step_token] + "sampled_pos": [n_agent, n_step_token, 2] + "sampled_heading": [n_agent, n_step_token] + """ + num_k = self.agent_token_sampling.num_k if self.training else 1 + n_agent, n_step = valid.shape + range_a = torch.arange(n_agent) + + prev_pos, prev_head = pos[:, 0], heading[:, 0] # [n_agent, 2], [n_agent] + prev_pos_sample, prev_head_sample = pos[:, 0], heading[:, 0] + + out_dict = { + "valid_mask": [], + "gt_idx": [], + "gt_pos": [], + "gt_heading": [], + "sampled_idx": [], + "sampled_pos": [], + "sampled_heading": [], + } + + for i in range(self.shift, n_step, self.shift): # [5, 10, 15, ..., 90] + _valid_mask = valid[:, i - self.shift] & valid[:, i] # [n_agent] + _invalid_mask = ~_valid_mask + out_dict["valid_mask"].append(_valid_mask) + + #! gt_contour: [n_agent, 4, 2] in global coord + gt_contour = cal_polygon_contour(pos[:, i], heading[:, i], agent_shape) + gt_contour = gt_contour.unsqueeze(1) # [n_agent, 1, 4, 2] + + # ! tokenize without sampling + token_world_gt = transform_to_global( + pos_local=token_traj.flatten(1, 2), # [n_agent, n_token*4, 2] + head_local=None, + pos_now=prev_pos, # [n_agent, 2] + head_now=prev_head, # [n_agent] + )[0].view(*token_traj.shape) + token_idx_gt = torch.argmin( + torch.norm(token_world_gt - gt_contour, dim=-1).sum(-1), dim=-1 + ) # [n_agent] + # [n_agent, 4, 2] + token_contour_gt = token_world_gt[range_a, token_idx_gt] + + # udpate prev_pos, prev_head + prev_head = heading[:, i].clone() + dxy = token_contour_gt[:, 0] - token_contour_gt[:, 3] + prev_head[_valid_mask] = torch.arctan2(dxy[:, 1], dxy[:, 0])[_valid_mask] + prev_pos = pos[:, i].clone() + prev_pos[_valid_mask] = token_contour_gt.mean(1)[_valid_mask] + # add to output dict + out_dict["gt_idx"].append(token_idx_gt) + out_dict["gt_pos"].append( + prev_pos.masked_fill(_invalid_mask.unsqueeze(1), 0) + ) + out_dict["gt_heading"].append(prev_head.masked_fill(_invalid_mask, 0)) + + # ! 
tokenize from sampled rollout state + if num_k == 1: # K=1 means no sampling + out_dict["sampled_idx"].append(out_dict["gt_idx"][-1]) + out_dict["sampled_pos"].append(out_dict["gt_pos"][-1]) + out_dict["sampled_heading"].append(out_dict["gt_heading"][-1]) + else: + # contour: [n_agent, n_token, 4, 2], 2HZ, global coord + token_world_sample = transform_to_global( + pos_local=token_traj.flatten(1, 2), # [n_agent, n_token*4, 2] + head_local=None, + pos_now=prev_pos_sample, # [n_agent, 2] + head_now=prev_head_sample, # [n_agent] + )[0].view(*token_traj.shape) + + # dist: [n_agent, n_token] + dist = torch.norm(token_world_sample - gt_contour, dim=-1).mean(-1) + topk_dists, topk_indices = torch.topk( + dist, num_k, dim=-1, largest=False, sorted=False + ) # [n_agent, K] + + topk_logits = (-1.0 * topk_dists) / self.agent_token_sampling.temp + _samples = Categorical(logits=topk_logits).sample() # [n_agent] in K + token_idx_sample = topk_indices[range_a, _samples] + token_contour_sample = token_world_sample[range_a, token_idx_sample] + + # udpate prev_pos_sample, prev_head_sample + prev_head_sample = heading[:, i].clone() + dxy = token_contour_sample[:, 0] - token_contour_sample[:, 3] + prev_head_sample[_valid_mask] = torch.arctan2(dxy[:, 1], dxy[:, 0])[ + _valid_mask + ] + prev_pos_sample = pos[:, i].clone() + prev_pos_sample[_valid_mask] = token_contour_sample.mean(1)[_valid_mask] + # add to output dict + out_dict["sampled_idx"].append(token_idx_sample) + out_dict["sampled_pos"].append( + prev_pos_sample.masked_fill(_invalid_mask.unsqueeze(1), 0.0) + ) + out_dict["sampled_heading"].append( + prev_head_sample.masked_fill(_invalid_mask, 0.0) + ) + out_dict = {k: torch.stack(v, dim=1) for k, v in out_dict.items()} + return out_dict + + @staticmethod + def _clean_heading(valid: Tensor, heading: Tensor) -> Tensor: + valid_pairs = valid[:, :-1] & valid[:, 1:] + for i in range(heading.shape[1] - 1): + heading_diff = torch.abs(wrap_angle(heading[:, i] - heading[:, i + 1])) + change_needed = (heading_diff > 1.5) & valid_pairs[:, i] + heading[:, i + 1][change_needed] = heading[:, i][change_needed] + return heading + + def _extrapolate_agent_to_prev_token_step( + self, + valid: Tensor, # [n_agent, n_step] + pos: Tensor, # [n_agent, n_step, 2] + heading: Tensor, # [n_agent, n_step] + vel: Tensor, # [n_agent, n_step, 2] + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + # [n_agent], max will give the first True step + first_valid_step = torch.max(valid, dim=1).indices + + for i, t in enumerate(first_valid_step): # extrapolate to previous 5th step. + n_step_to_extrapolate = t % self.shift + if (t == 10) and (not valid[i, 10 - self.shift]): + # such that at least one token is valid in the history. 
+ n_step_to_extrapolate = self.shift + + if n_step_to_extrapolate > 0: + vel[i, t - n_step_to_extrapolate : t] = vel[i, t] + valid[i, t - n_step_to_extrapolate : t] = True + heading[i, t - n_step_to_extrapolate : t] = heading[i, t] + + for j in range(n_step_to_extrapolate): + pos[i, t - j - 1] = pos[i, t - j] - vel[i, t] * 0.1 + + return valid, pos, heading, vel + + def _get_agent_shape_and_token_traj( + self, agent_type: Tensor + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + """ + agent_shape: [n_agent, 2] + token_traj_all: [n_agent, n_token, 6, 4, 2] + token_traj: [n_agent, n_token, 4, 2] + """ + agent_type_masks = { + "veh": agent_type == 0, + "ped": agent_type == 1, + "cyc": agent_type == 2, + } + agent_shape = 0.0 + token_traj_all = 0.0 + for k, mask in agent_type_masks.items(): + if k == "veh": + width = 2.0 + length = 4.8 + elif k == "cyc": + width = 1.0 + length = 2.0 + else: + width = 1.0 + length = 1.0 + agent_shape += torch.stack([width * mask, length * mask], dim=-1) + + token_traj_all += mask[:, None, None, None, None] * ( + getattr(self, f"agent_token_all_{k}").unsqueeze(0) + ) + + token_traj = token_traj_all[:, :, -1, :, :].contiguous() + return agent_shape, token_traj_all, token_traj diff --git a/backups/thirdparty/catk/src/smart/tokens/traj_clustering.py b/backups/thirdparty/catk/src/smart/tokens/traj_clustering.py new file mode 100644 index 0000000000000000000000000000000000000000..afac56f8973b8337fb9314cc1cc7bfde2a3e03ce --- /dev/null +++ b/backups/thirdparty/catk/src/smart/tokens/traj_clustering.py @@ -0,0 +1,185 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
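# ---------------------------------------------------------------------------
# The script below builds the 2048-entry motion-token vocabulary by greedy
# K-disk clustering: pick a seed trajectory, average every remaining
# trajectory whose end contour lies within `tol` of the seed into one token,
# drop the covered set, and repeat. A toy 1-D version of the same covering
# loop (illustrative assumption: scalar samples and absolute distance instead
# of mean contour-corner distance):
import torch

def kdisk_cluster_1d(x: torch.Tensor, n_tokens: int, tol: float) -> torch.Tensor:
    # x: [n_samples]; returns [n_tokens] disk centers (means of covered sets).
    centers = []
    for i in range(n_tokens):
        seed = x[0] if i == 0 else x[torch.randint(0, x.shape[0], (1,)).item()]
        inside = (x - seed).abs() <= tol  # the K-disk of radius tol around seed
        centers.append(x[inside].mean())
        x = x[~inside]  # assumes enough samples remain for later iterations
    return torch.stack(centers)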
+ +import pickle +from pathlib import Path + +import lightning as L +import torch +from torch_geometric.data import HeteroData +from torch_geometric.loader import DataLoader +from tqdm import tqdm + +from src.smart.datasets import MultiDataset +from src.smart.tokens.token_processor import TokenProcessor +from src.smart.utils import cal_polygon_contour, transform_to_local, wrap_angle + + +def Kdisk_cluster( + X, # [n_trajs, 4, 2], bbox of the last point of the segment + N, # int + tol, # float + a_pos, # [n_trajs, 6, 3], the complete segment + cal_mean_heading=True, +): + n_total = X.shape[0] + ret_traj_list = [] + + for i in range(N): + if i == 0: + choice_index = 0 # always include [0, 0, 0] + else: + choice_index = torch.randint(0, X.shape[0], (1,)).item() + x0 = X[choice_index] + # res_mask = torch.sum((X - x0) ** 2, dim=[1, 2]) / 4.0 > (tol**2) + res_mask = torch.norm(X - x0, dim=-1).mean(-1) > tol + if cal_mean_heading: + ret_traj = a_pos[~res_mask].mean(0, keepdim=True) + else: + ret_traj = a_pos[[choice_index]] + X = X[res_mask] + a_pos = a_pos[res_mask] + ret_traj_list.append(ret_traj) + + remain = X.shape[0] * 100.0 / n_total + n_inside = (~res_mask).sum().item() + print(f"{i=}, {remain=:.2f}%, {n_inside=}") + + return torch.cat(ret_traj_list, dim=0) # [N, 6, 3] + + +if __name__ == "__main__": + L.seed_everything(seed=2, workers=True) + n_trajs = 2048 * 100 # 2e5 + load_data_from_file = True + data_cache_path = Path("/root/.cache/SMART") + out_file_name = "agent_vocab_555_s2.pkl" + tol_dist = [0.05, 0.05, 0.05] # veh, ped, cyc + + # ! don't change these params + shift = 5 # motion token time dimension + num_cluster = 2048 # vocabulary size + n_step = 91 + data_file_path = data_cache_path / "kdisk_trajs.pkl" + if load_data_from_file: + with open(data_file_path, "rb") as f: + data = pickle.load(f) + else: + trajs = [ + torch.zeros([1, 6, 3], dtype=torch.float32), # veh + torch.zeros([1, 6, 3], dtype=torch.float32), # ped + torch.zeros([1, 6, 3], dtype=torch.float32), # cyc + ] + dataloader = DataLoader( + dataset=MultiDataset( + raw_dir=data_cache_path / "training", transform=lambda x: HeteroData(x) + ), + batch_size=8, + shuffle=False, + num_workers=8, + drop_last=False, + ) + + with tqdm( + total=len(dataloader), + desc=f"n_trajs={n_trajs}", + postfix={"n_veh": 0, "n_ped": 0, "n_cyc": 0}, + ) as pbar: + + for data in dataloader: + valid_mask = data["agent"]["valid_mask"] + data["agent"]["heading"] = TokenProcessor._clean_heading( + valid_mask, data["agent"]["heading"] + ) + + for i_ag in range(valid_mask.shape[0]): + if valid_mask[i_ag, :].sum() < 30: + continue + for t in range(0, n_step - shift, shift): + if valid_mask[i_ag, t] and valid_mask[i_ag, t + shift]: + _type = data["agent"]["type"][i_ag] + if trajs[_type].shape[0] < n_trajs: + pos = data["agent"]["position"][ + i_ag, t : t + shift + 1, :2 + ] + head = data["agent"]["heading"][i_ag, t : t + shift + 1] + pos, head = transform_to_local( + pos_global=pos.unsqueeze(0), # [1, 6, 2] + head_global=head.unsqueeze(0), # [1, 6] + pos_now=pos[[0]], # [1, 2] + head_now=head[[0]], # [1] + ) + head = wrap_angle(head) + to_add = torch.cat([pos, head.unsqueeze(-1)], dim=-1) + + if not ( + ( + (trajs[_type] - to_add).abs().sum([1, 2]) < 1e-2 + ).any() + ): + trajs[_type] = torch.cat( + [trajs[_type], to_add], dim=0 + ) + pbar.update(1) + pbar.set_postfix( + n_veh=trajs[0].shape[0], + n_ped=trajs[1].shape[0], + n_cyc=trajs[2].shape[0], + ) + if ( + trajs[0].shape[0] == n_trajs + and trajs[1].shape[0] == n_trajs + and trajs[2].shape[0] == 
n_trajs + ): + break + + # [n_trajs, shift+1, [relative_x, relative_y, relative_theta]] + data = {"veh": trajs[0], "ped": trajs[1], "cyc": trajs[2]} + + with open(data_file_path, "wb") as f: + pickle.dump(data, f) + + res = {"token_all": {}} + + for k, v in data.items(): + if k == "veh": + width_length = torch.tensor([2.0, 4.8]) + elif k == "ped": + width_length = torch.tensor([1.0, 1.0]) + elif k == "cyc": + width_length = torch.tensor([1.0, 2.0]) + width_length = width_length.unsqueeze(0) # [1, 2] + + contour = cal_polygon_contour( + pos=v[:, -1, :2], head=v[:, -1, 2], width_length=width_length + ) # [n_trajs, 4, 2] + + if k == "veh": + tol = tol_dist[0] + elif k == "ped": + tol = tol_dist[1] + elif k == "cyc": + tol = tol_dist[2] + print(k, tol) + ret_traj = Kdisk_cluster(X=contour, N=num_cluster, tol=tol, a_pos=v) + ret_traj[:, :, -1] = wrap_angle(ret_traj[:, :, -1]) + + contour = cal_polygon_contour( + pos=ret_traj[:, :, :2], # [N, 6, 2] + head=ret_traj[:, :, 2], # [N, 6] + width_length=width_length.unsqueeze(0), + ) + res["token_all"][k] = contour.numpy() + + with open(Path(__file__).resolve().parent / out_file_name, "wb") as f: + pickle.dump(res, f) diff --git a/backups/thirdparty/catk/src/smart/utils/__init__.py b/backups/thirdparty/catk/src/smart/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..113c28b0fbb565b9bb068ede116cec377ee6a049 --- /dev/null +++ b/backups/thirdparty/catk/src/smart/utils/__init__.py @@ -0,0 +1,22 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +from src.smart.utils.geometry import angle_between_2d_vectors, wrap_angle +from src.smart.utils.rollout import ( + cal_polygon_contour, + sample_next_gmm_traj, + sample_next_token_traj, + transform_to_global, + transform_to_local, +) +from src.smart.utils.weight_init import weight_init diff --git a/backups/thirdparty/catk/src/smart/utils/finetune.py b/backups/thirdparty/catk/src/smart/utils/finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..186c2fc46f8b9e6a410f732fa917fb9aee826f01 --- /dev/null +++ b/backups/thirdparty/catk/src/smart/utils/finetune.py @@ -0,0 +1,46 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. 
Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import torch
+
+from src.utils import RankedLogger
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def set_model_for_finetuning(model: torch.nn.Module, finetune: bool) -> None:
+    def _unfreeze(module: torch.nn.Module) -> None:
+        for p in module.parameters():
+            p.requires_grad = True
+
+    if finetune:
+        for p in model.parameters():
+            p.requires_grad = False
+
+        try:
+            _unfreeze(model.agent_encoder.token_predict_head)
+            log.info("Unfreezing token_predict_head")
+        except AttributeError:  # this decoder variant has no token_predict_head
+            log.info("No token_predict_head in model.agent_encoder")
+
+        try:
+            _unfreeze(model.agent_encoder.gmm_logits_head)
+            _unfreeze(model.agent_encoder.gmm_pose_head)
+            # _unfreeze(model.agent_encoder.gmm_gmm_covpose_head)
+            log.info("Unfreezing gmm heads")
+        except AttributeError:  # this decoder variant has no gmm heads
+            log.info("No gmm_logits_head in model.agent_encoder")
+
+        _unfreeze(model.agent_encoder.t_attn_layers)
+        _unfreeze(model.agent_encoder.pt2a_attn_layers)
+        _unfreeze(model.agent_encoder.a2a_attn_layers)
diff --git a/backups/thirdparty/catk/src/smart/utils/geometry.py b/backups/thirdparty/catk/src/smart/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..949a7c986d58899c89bbb3cfdcb144677137f99c
--- /dev/null
+++ b/backups/thirdparty/catk/src/smart/utils/geometry.py
@@ -0,0 +1,32 @@
+# Not a contribution
+# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as
+# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions:
+# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import math
+
+import torch
+
+
+def angle_between_2d_vectors(
+    ctr_vector: torch.Tensor, nbr_vector: torch.Tensor
+) -> torch.Tensor:
+    return torch.atan2(
+        ctr_vector[..., 0] * nbr_vector[..., 1]
+        - ctr_vector[..., 1] * nbr_vector[..., 0],
+        (ctr_vector[..., :2] * nbr_vector[..., :2]).sum(dim=-1),
+    )
+
+
+def wrap_angle(
+    angle: torch.Tensor, min_val: float = -math.pi, max_val: float = math.pi
+) -> torch.Tensor:
+    return min_val + (angle + max_val) % (max_val - min_val)
diff --git a/backups/thirdparty/catk/src/smart/utils/preprocess.py b/backups/thirdparty/catk/src/smart/utils/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..2340aa6735ca7a3284be88694e47e24277b71e86
--- /dev/null
+++ b/backups/thirdparty/catk/src/smart/utils/preprocess.py
@@ -0,0 +1,177 @@
+# Not a contribution
+# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as
+# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions:
+# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +from typing import Any, Dict + +import numpy as np +import torch +from scipy.interpolate import interp1d + + +def get_polylines_from_polygon(polygon: np.ndarray) -> np.ndarray: + # polygon: [4, 3] + l1 = np.linalg.norm(polygon[1, :2] - polygon[0, :2]) + l2 = np.linalg.norm(polygon[2, :2] - polygon[1, :2]) + + def _pl_interp_start_end(start: np.ndarray, end: np.ndarray) -> np.ndarray: + length = np.linalg.norm(start - end) + unit_vec = (end - start) / length + pl = [] + for i in range(int(length) + 1): # 4.5 -> 5 [0,1,2,3,4] + x, y, z = start + unit_vec * i + pl.append([x, y, z]) + pl.append([end[0], end[1], end[2]]) + return np.array(pl) + + if l1 > l2: + pl1 = _pl_interp_start_end(polygon[0], polygon[1]) + pl2 = _pl_interp_start_end(polygon[2], polygon[3]) + else: + pl1 = _pl_interp_start_end(polygon[0], polygon[3]) + pl2 = _pl_interp_start_end(polygon[2], polygon[1]) + return np.concatenate([pl1, pl1[::-1], pl2, pl2[::-1]], axis=0) + + +def _interplating_polyline(polylines, distance=0.5, split_distace=5): + # Calculate the cumulative distance along the path, up-sample the polyline to 0.5 meter + dist_along_path_list = [] + polylines_list = [] + euclidean_dists = np.linalg.norm(polylines[1:, :2] - polylines[:-1, :2], axis=-1) + euclidean_dists = np.concatenate([[0], euclidean_dists]) + breakpoints = np.where(euclidean_dists > 3)[0] + breakpoints = np.concatenate([[0], breakpoints, [polylines.shape[0]]]) + for i in range(1, breakpoints.shape[0]): + start = breakpoints[i - 1] + end = breakpoints[i] + dist_along_path_list.append( + np.cumsum(euclidean_dists[start:end]) - euclidean_dists[start] + ) + polylines_list.append(polylines[start:end]) + + multi_polylines_list = [] + for idx in range(len(dist_along_path_list)): + if len(dist_along_path_list[idx]) < 2: + continue + dist_along_path = dist_along_path_list[idx] + polylines_cur = polylines_list[idx] + # Create interpolation functions for x and y coordinates + fxy = interp1d(dist_along_path, polylines_cur, axis=0) + + # Create an array of distances at which to interpolate + new_dist_along_path = np.arange(0, dist_along_path[-1], distance) + new_dist_along_path = np.concatenate( + [new_dist_along_path, dist_along_path[[-1]]] + ) + + # Combine the new x and y coordinates into a single array + new_polylines = fxy(new_dist_along_path) + polyline_size = int(split_distace / distance) + if new_polylines.shape[0] >= (polyline_size + 1): + padding_size = ( + new_polylines.shape[0] - (polyline_size + 1) + ) % polyline_size + final_index = ( + new_polylines.shape[0] - (polyline_size + 1) + ) // polyline_size + 1 + else: + padding_size = new_polylines.shape[0] + final_index = 0 + multi_polylines = None + new_polylines = torch.from_numpy(new_polylines) + new_heading = torch.atan2( + new_polylines[1:, 1] - new_polylines[:-1, 1], + new_polylines[1:, 0] - new_polylines[:-1, 0], + ) + new_heading = torch.cat([new_heading, new_heading[-1:]], -1)[..., None] + new_polylines = torch.cat([new_polylines, new_heading], -1) + if new_polylines.shape[0] >= (polyline_size + 1): + multi_polylines = 
new_polylines.unfold( + dimension=0, size=polyline_size + 1, step=polyline_size + ) + multi_polylines = multi_polylines.transpose(1, 2) + multi_polylines = multi_polylines[:, ::5, :] + if padding_size >= 3: + last_polyline = new_polylines[final_index * polyline_size :] + last_polyline = last_polyline[ + torch.linspace(0, last_polyline.shape[0] - 1, steps=3).long() + ] + if multi_polylines is not None: + multi_polylines = torch.cat( + [multi_polylines, last_polyline.unsqueeze(0)], dim=0 + ) + else: + multi_polylines = last_polyline.unsqueeze(0) + if multi_polylines is None: + continue + multi_polylines_list.append(multi_polylines) + if len(multi_polylines_list) > 0: + multi_polylines_list = torch.cat(multi_polylines_list, dim=0).to(torch.float32) + else: + multi_polylines_list = None + return multi_polylines_list + + +def preprocess_map(map_data: Dict[str, Any]) -> Dict[str, Any]: + pt2pl = map_data[("map_point", "to", "map_polygon")]["edge_index"] + split_polyline_type = [] + split_polyline_pos = [] + split_polyline_theta = [] + split_polygon_type = [] + split_light_type = [] + + for i in sorted(torch.unique(pt2pl[1])): + index = pt2pl[0, pt2pl[1] == i] + if len(index) <= 2: + continue + + polygon_type = map_data["map_polygon"]["type"][i] + light_type = map_data["map_polygon"]["light_type"][i] + cur_type = map_data["map_point"]["type"][index] + cur_pos = map_data["map_point"]["position"][index, :2] + + # assert len(np.unique(cur_type)) == 1 + + split_polyline = _interplating_polyline(cur_pos.numpy()) + if split_polyline is None: + continue + split_polyline_pos.append(split_polyline[..., :2]) + split_polyline_theta.append(split_polyline[..., 2]) + split_polyline_type.append(cur_type[0].repeat(split_polyline.shape[0])) + split_polygon_type.append(polygon_type.repeat(split_polyline.shape[0])) + split_light_type.append(light_type.repeat(split_polyline.shape[0])) + + data = {} + if len(split_polyline_pos) == 0: # add dummy empty map + data["map_save"] = { + # 6e4 such that it's within the range of float16. + "traj_pos": torch.zeros([1, 3, 2], dtype=torch.float32) + 6e4, + "traj_theta": torch.zeros([1], dtype=torch.float32), + } + data["pt_token"] = { + "type": torch.tensor([0], dtype=torch.uint8), + "pl_type": torch.tensor([0], dtype=torch.uint8), + "light_type": torch.tensor([0], dtype=torch.uint8), + "num_nodes": 1, + } + else: + data["map_save"] = { + "traj_pos": torch.cat(split_polyline_pos, dim=0), # [num_nodes, 3, 2] + "traj_theta": torch.cat(split_polyline_theta, dim=0)[:, 0], # [num_nodes] + } + data["pt_token"] = { + "type": torch.cat(split_polyline_type, dim=0), # [num_nodes], uint8 + "pl_type": torch.cat(split_polygon_type, dim=0), # [num_nodes], uint8 + "light_type": torch.cat(split_light_type, dim=0), # [num_nodes], uint8 + "num_nodes": data["map_save"]["traj_pos"].shape[0], + } + return data diff --git a/backups/thirdparty/catk/src/smart/utils/rollout.py b/backups/thirdparty/catk/src/smart/utils/rollout.py new file mode 100644 index 0000000000000000000000000000000000000000..4a062287eabdb5d68bafb0887e43756cc2be9aa1 --- /dev/null +++ b/backups/thirdparty/catk/src/smart/utils/rollout.py @@ -0,0 +1,293 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +from typing import Optional, Tuple + +import torch +from omegaconf import DictConfig +from torch import Tensor +from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal + + +@torch.no_grad() +def cal_polygon_contour( + pos: Tensor, # [n_agent, n_step, n_target, 2] + head: Tensor, # [n_agent, n_step, n_target] + width_length: Tensor, # [n_agent, 1, 1, 2] +) -> Tensor: # [n_agent, n_step, n_target, 4, 2] + x, y = pos[..., 0], pos[..., 1] # [n_agent, n_step, n_target] + width, length = width_length[..., 0], width_length[..., 1] # [n_agent, 1 ,1] + + half_cos = 0.5 * head.cos() # [n_agent, n_step, n_target] + half_sin = 0.5 * head.sin() # [n_agent, n_step, n_target] + length_cos = length * half_cos # [n_agent, n_step, n_target] + length_sin = length * half_sin # [n_agent, n_step, n_target] + width_cos = width * half_cos # [n_agent, n_step, n_target] + width_sin = width * half_sin # [n_agent, n_step, n_target] + + left_front_x = x + length_cos - width_sin + left_front_y = y + length_sin + width_cos + left_front = torch.stack((left_front_x, left_front_y), dim=-1) + + right_front_x = x + length_cos + width_sin + right_front_y = y + length_sin - width_cos + right_front = torch.stack((right_front_x, right_front_y), dim=-1) + + right_back_x = x - length_cos + width_sin + right_back_y = y - length_sin - width_cos + right_back = torch.stack((right_back_x, right_back_y), dim=-1) + + left_back_x = x - length_cos - width_sin + left_back_y = y - length_sin + width_cos + left_back = torch.stack((left_back_x, left_back_y), dim=-1) + + polygon_contour = torch.stack( + (left_front, right_front, right_back, left_back), dim=-2 + ) + + return polygon_contour + + +def transform_to_global( + pos_local: Tensor, # [n_agent, n_step, 2] + head_local: Optional[Tensor], # [n_agent, n_step] + pos_now: Tensor, # [n_agent, 2] + head_now: Tensor, # [n_agent] +) -> Tuple[Tensor, Optional[Tensor]]: + cos, sin = head_now.cos(), head_now.sin() + rot_mat = torch.zeros((head_now.shape[0], 2, 2), device=head_now.device) + rot_mat[:, 0, 0] = cos + rot_mat[:, 0, 1] = sin + rot_mat[:, 1, 0] = -sin + rot_mat[:, 1, 1] = cos + + pos_global = torch.bmm(pos_local, rot_mat) # [n_agent, n_step, 2]*[n_agent, 2, 2] + pos_global = pos_global + pos_now.unsqueeze(1) + if head_local is None: + head_global = None + else: + head_global = head_local + head_now.unsqueeze(1) + return pos_global, head_global + + +def transform_to_local( + pos_global: Tensor, # [n_agent, n_step, 2] + head_global: Optional[Tensor], # [n_agent, n_step] + pos_now: Tensor, # [n_agent, 2] + head_now: Tensor, # [n_agent] +) -> Tuple[Tensor, Optional[Tensor]]: + cos, sin = head_now.cos(), head_now.sin() + rot_mat = torch.zeros((head_now.shape[0], 2, 2), device=head_now.device) + rot_mat[:, 0, 0] = cos + rot_mat[:, 0, 1] = -sin + rot_mat[:, 1, 0] = sin + rot_mat[:, 1, 1] = cos + + pos_local = pos_global - pos_now.unsqueeze(1) + pos_local = torch.bmm(pos_local, rot_mat) # [n_agent, n_step, 2]*[n_agent, 2, 2] + if head_global is None: + head_local = None + else: + head_local = head_global - 
head_now.unsqueeze(1) + return pos_local, head_local + + +def sample_next_token_traj( + token_traj: Tensor, # [n_agent, n_token, 4, 2] + token_traj_all: Tensor, # [n_agent, n_token, 6, 4, 2] + sampling_scheme: DictConfig, + # ! for most-likely sampling + next_token_logits: Tensor, # [n_agent, n_token], with grad + # ! for nearest-pos sampling, sampling near to GT + pos_now: Tensor, # [n_agent, 2] + head_now: Tensor, # [n_agent] + pos_next_gt: Tensor, # [n_agent, 2] + head_next_gt: Tensor, # [n_agent] + valid_next_gt: Tensor, # [n_agent] + token_agent_shape: Tensor, # [n_agent, 2] +) -> Tuple[Tensor, Tensor]: + """ + Returns: + next_token_traj_all: [n_agent, 6, 4, 2], local coord + next_token_idx: [n_agent], without grad + """ + range_a = torch.arange(next_token_logits.shape[0]) + next_token_logits = next_token_logits.detach() + + if ( + sampling_scheme.criterium == "topk_prob" + or sampling_scheme.criterium == "topk_prob_sampled_with_dist" + ): + topk_logits, topk_indices = torch.topk( + next_token_logits, sampling_scheme.num_k, dim=-1, sorted=False + ) + if sampling_scheme.criterium == "topk_prob_sampled_with_dist": + #! gt_contour: [n_agent, 4, 2] in global coord + gt_contour = cal_polygon_contour( + pos_next_gt, head_next_gt, token_agent_shape + ) + gt_contour = gt_contour.unsqueeze(1) # [n_agent, 1, 4, 2] + token_world_sample = token_traj[range_a.unsqueeze(1), topk_indices] + token_world_sample = transform_to_global( + pos_local=token_world_sample.flatten(1, 2), + head_local=None, + pos_now=pos_now, # [n_agent, 2] + head_now=head_now, # [n_agent] + )[0].view(*token_world_sample.shape) + + # dist: [n_agent, n_token] + dist = torch.norm(token_world_sample - gt_contour, dim=-1).mean(-1) + topk_logits = topk_logits.masked_fill( + valid_next_gt.unsqueeze(1), 0.0 + ) - 1.0 * dist.masked_fill(~valid_next_gt.unsqueeze(1), 0.0) + elif sampling_scheme.criterium == "topk_dist_sampled_with_prob": + #! gt_contour: [n_agent, 4, 2] in global coord + gt_contour = cal_polygon_contour(pos_next_gt, head_next_gt, token_agent_shape) + gt_contour = gt_contour.unsqueeze(1) # [n_agent, 1, 4, 2] + token_world_sample = transform_to_global( + pos_local=token_traj.flatten(1, 2), # [n_agent, n_token*4, 2] + head_local=None, + pos_now=pos_now, # [n_agent, 2] + head_now=head_now, # [n_agent] + )[0].view(*token_traj.shape) + + _invalid = ~valid_next_gt + # dist: [n_agent, n_token] + dist = torch.norm(token_world_sample - gt_contour, dim=-1).mean(-1) + _logits = -1.0 * dist.masked_fill(_invalid.unsqueeze(1), 0.0) + + if _invalid.any(): + _logits[_invalid] = next_token_logits[_invalid] + _, topk_indices = torch.topk( + _logits, sampling_scheme.num_k, dim=-1, sorted=False + ) # [n_agent, K] + topk_logits = next_token_logits[range_a.unsqueeze(1), topk_indices] + + else: + raise ValueError(f"Invalid criterium: {sampling_scheme.criterium}") + + # topk_logits, topk_indices: [n_agent, K] + topk_logits = topk_logits / sampling_scheme.temp + samples = Categorical(logits=topk_logits).sample() # [n_agent] in K + next_token_idx = topk_indices[range_a, samples] + next_token_traj_all = token_traj_all[range_a, next_token_idx] + + return next_token_idx, next_token_traj_all + + +def sample_next_gmm_traj( + token_traj: Tensor, # [n_agent, n_token, 4, 2] + token_traj_all: Tensor, # [n_agent, n_token, 6, 4, 2] + sampling_scheme: DictConfig, + # ! 
for most-likely sampling
+    ego_mask: Tensor,  # [n_agent], bool, ego_mask.sum()==n_batch
+    ego_next_logits: Tensor,  # [n_batch, n_k_ego_gmm]
+    ego_next_poses: Tensor,  # [n_batch, n_k_ego_gmm, 3]
+    ego_next_cov: Tensor,  # [2], one for pos, one for heading.
+    # ! for nearest-pos sampling, sampling near to GT
+    pos_now: Tensor,  # [n_agent, 2]
+    head_now: Tensor,  # [n_agent]
+    pos_next_gt: Tensor,  # [n_agent, 2]
+    head_next_gt: Tensor,  # [n_agent]
+    valid_next_gt: Tensor,  # [n_agent]
+    token_agent_shape: Tensor,  # [n_agent, 2]
+    next_token_idx: Tensor,  # [n_agent]
+) -> Tuple[Tensor, Tensor]:
+    """
+    Returns:
+        next_token_traj_all: [n_agent, 6, 4, 2], local coord
+        next_token_idx: [n_agent], without grad
+    """
+    n_agent = token_traj.shape[0]
+    n_batch = ego_next_logits.shape[0]
+    next_token_traj_all = token_traj_all[torch.arange(n_agent), next_token_idx]
+
+    # ! sample only the ego-vehicle
+    assert (
+        sampling_scheme.criterium == "topk_prob"
+        or sampling_scheme.criterium == "topk_prob_sampled_with_dist"
+    )
+    topk_logits, topk_indices = torch.topk(
+        ego_next_logits, sampling_scheme.num_k, dim=-1, sorted=False
+    )  # [n_batch, k], [n_batch, k]
+    ego_pose_topk = ego_next_poses[
+        torch.arange(n_batch).unsqueeze(1), topk_indices
+    ]  # [n_batch, k, 3]
+
+    if sampling_scheme.criterium == "topk_prob_sampled_with_dist":
+        # update topk_logits
+        gt_contour = cal_polygon_contour(
+            pos_next_gt[ego_mask],
+            head_next_gt[ego_mask],
+            token_agent_shape[ego_mask],
+        )  # [n_batch, 4, 2] in global coord
+        gt_contour = gt_contour.unsqueeze(1)  # [n_batch, 1, 4, 2]
+
+        ego_pos_global, ego_head_global = transform_to_global(
+            pos_local=ego_pose_topk[:, :, :2],  # [n_batch, k, 2]
+            head_local=ego_pose_topk[:, :, -1],  # [n_batch, k]
+            pos_now=pos_now[ego_mask],  # [n_batch, 2]
+            head_now=head_now[ego_mask],  # [n_batch]
+        )
+        ego_contour = cal_polygon_contour(
+            ego_pos_global,  # [n_batch, k, 2]
+            ego_head_global,  # [n_batch, k]
+            token_agent_shape[ego_mask].unsqueeze(1),
+        )  # [n_batch, k, 4, 2] in global coord
+
+        dist = torch.norm(ego_contour - gt_contour, dim=-1).mean(-1)  # [n_batch, k]
+        topk_logits = topk_logits.masked_fill(
+            valid_next_gt[ego_mask].unsqueeze(1), 0.0
+        ) - 1.0 * dist.masked_fill(~valid_next_gt[ego_mask].unsqueeze(1), 0.0)
+
+    topk_logits = topk_logits / sampling_scheme.temp_mode  # [n_batch, k]
+    ego_pose_topk = torch.cat(
+        [
+            ego_pose_topk[..., :2],
+            ego_pose_topk[..., [-1]].cos(),
+            ego_pose_topk[..., [-1]].sin(),
+        ],
+        dim=-1,
+    )
+    cov = (
+        (ego_next_cov * sampling_scheme.temp_cov)
+        .repeat_interleave(2)[None, None, :]
+        .expand(*ego_pose_topk.shape)
+    )  # [n_batch, k, 4]
+    gmm = MixtureSameFamily(
+        Categorical(logits=topk_logits), Independent(Normal(ego_pose_topk, cov), 1)
+    )
+    ego_sample = gmm.sample()  # [n_batch, 4]
+
+    ego_contour_local = cal_polygon_contour(
+        ego_sample[:, :2],  # [n_batch, 2]
+        torch.arctan2(ego_sample[:, -1], ego_sample[:, -2]),  # [n_batch]
+        token_agent_shape[ego_mask],  # [n_batch, 2]
+    )  # [n_batch, 4, 2] in local coord
+
+    ego_token_local = token_traj[ego_mask]  # [n_batch, n_token, 4, 2]
+
+    dist = torch.norm(ego_contour_local.unsqueeze(1) - ego_token_local, dim=-1).mean(
+        -1
+    )  # [n_batch, n_token]
+    next_token_idx[ego_mask] = dist.argmin(-1)
+
+    # interpolate from the current contour to the sampled contour
+    # ego_contour_local: [n_batch, 4, 2] in local coord
+    ego_contour_start = next_token_traj_all[ego_mask][:, 0]  # [n_batch, 4, 2]
+    n_step = next_token_traj_all.shape[1]
+    diff = (ego_contour_local - ego_contour_start) / (n_step - 1)
+    ego_token_interp = [ego_contour_start + diff * i for i in range(n_step)]
+
# [n_batch, 6, 4, 2] + next_token_traj_all[ego_mask] = torch.stack(ego_token_interp, dim=1) + + return next_token_idx, next_token_traj_all diff --git a/backups/thirdparty/catk/src/smart/utils/weight_init.py b/backups/thirdparty/catk/src/smart/utils/weight_init.py new file mode 100644 index 0000000000000000000000000000000000000000..a507cf3017ec9612939e51edad22b7ec59a4d0b7 --- /dev/null +++ b/backups/thirdparty/catk/src/smart/utils/weight_init.py @@ -0,0 +1,69 @@ +import torch.nn as nn + + +def weight_init(m: nn.Module) -> None: + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + fan_in = m.in_channels / m.groups + fan_out = m.out_channels / m.groups + bound = (6.0 / (fan_in + fan_out)) ** 0.5 + nn.init.uniform_(m.weight, -bound, bound) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Embedding): + nn.init.normal_(m.weight, mean=0.0, std=0.02) + elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.MultiheadAttention): + if m.in_proj_weight is not None: + fan_in = m.embed_dim + fan_out = m.embed_dim + bound = (6.0 / (fan_in + fan_out)) ** 0.5 + nn.init.uniform_(m.in_proj_weight, -bound, bound) + else: + nn.init.xavier_uniform_(m.q_proj_weight) + nn.init.xavier_uniform_(m.k_proj_weight) + nn.init.xavier_uniform_(m.v_proj_weight) + if m.in_proj_bias is not None: + nn.init.zeros_(m.in_proj_bias) + nn.init.xavier_uniform_(m.out_proj.weight) + if m.out_proj.bias is not None: + nn.init.zeros_(m.out_proj.bias) + if m.bias_k is not None: + nn.init.normal_(m.bias_k, mean=0.0, std=0.02) + if m.bias_v is not None: + nn.init.normal_(m.bias_v, mean=0.0, std=0.02) + elif isinstance(m, (nn.LSTM, nn.LSTMCell)): + for name, param in m.named_parameters(): + if "weight_ih" in name: + for ih in param.chunk(4, 0): + nn.init.xavier_uniform_(ih) + elif "weight_hh" in name: + for hh in param.chunk(4, 0): + nn.init.orthogonal_(hh) + elif "weight_hr" in name: + nn.init.xavier_uniform_(param) + elif "bias_ih" in name: + nn.init.zeros_(param) + elif "bias_hh" in name: + nn.init.zeros_(param) + nn.init.ones_(param.chunk(4, 0)[1]) + elif isinstance(m, (nn.GRU, nn.GRUCell)): + for name, param in m.named_parameters(): + if "weight_ih" in name: + for ih in param.chunk(3, 0): + nn.init.xavier_uniform_(ih) + elif "weight_hh" in name: + for hh in param.chunk(3, 0): + nn.init.orthogonal_(hh) + elif "bias_ih" in name: + nn.init.zeros_(param) + elif "bias_hh" in name: + nn.init.zeros_(param) diff --git a/backups/thirdparty/catk/src/utils/__init__.py b/backups/thirdparty/catk/src/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..86975931033d4afeb3f4d7f69541554f8574e9b6 --- /dev/null +++ b/backups/thirdparty/catk/src/utils/__init__.py @@ -0,0 +1,17 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +from src.utils.instantiators import instantiate_callbacks, instantiate_loggers +from src.utils.logging_utils import log_hyperparameters +from src.utils.pylogger import RankedLogger +from src.utils.rich_utils import print_config_tree diff --git a/backups/thirdparty/catk/src/utils/instantiators.py b/backups/thirdparty/catk/src/utils/instantiators.py new file mode 100644 index 0000000000000000000000000000000000000000..e2f3f0a859d32ce9909947def19160b24f46785f --- /dev/null +++ b/backups/thirdparty/catk/src/utils/instantiators.py @@ -0,0 +1,69 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +from typing import List, Optional + +import hydra +from lightning import Callback +from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig + +from . import pylogger + +log = pylogger.RankedLogger(__name__, rank_zero_only=True) + + +def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: + """Instantiates callbacks from config. + + :param callbacks_cfg: A DictConfig object containing callback configurations. + :return: A list of instantiated callbacks. + """ + callbacks: List[Callback] = [] + + if not callbacks_cfg: + log.warning("No callback configs found! Skipping..") + return callbacks + + if not isinstance(callbacks_cfg, DictConfig): + raise TypeError("Callbacks config must be a DictConfig!") + + for _, cb_conf in callbacks_cfg.items(): + if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: + log.info(f"Instantiating callback <{cb_conf._target_}>") + callbacks.append(hydra.utils.instantiate(cb_conf)) + + return callbacks + + +def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: + """Instantiates loggers from config. + + :param logger_cfg: A DictConfig object containing logger configurations. + :return: A list of instantiated loggers. + """ + logger: List[Logger] = [] + + if not logger_cfg: + log.warning("No logger configs found! 
Skipping...") + return logger + + if not isinstance(logger_cfg, DictConfig): + raise TypeError("Logger config must be a DictConfig!") + + for _, lg_conf in logger_cfg.items(): + if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: + log.info(f"Instantiating logger <{lg_conf._target_}>") + logger.append(hydra.utils.instantiate(lg_conf)) + + return logger diff --git a/backups/thirdparty/catk/src/utils/logging_utils.py b/backups/thirdparty/catk/src/utils/logging_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c714d48174da2a94a1af23d49bfcb05a90337c23 --- /dev/null +++ b/backups/thirdparty/catk/src/utils/logging_utils.py @@ -0,0 +1,70 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +from typing import Any, Dict + +from lightning_utilities.core.rank_zero import rank_zero_only +from omegaconf import OmegaConf + +from . import pylogger + +log = pylogger.RankedLogger(__name__, rank_zero_only=True) + + +@rank_zero_only +def log_hyperparameters(object_dict: Dict[str, Any]) -> None: + """Controls which config parts are saved by Lightning loggers. + + Additionally saves: + - Number of model parameters + + :param object_dict: A dictionary containing the following objects: + - `"cfg"`: A DictConfig object containing the main config. + - `"model"`: The Lightning model. + - `"trainer"`: The Lightning trainer. + """ + hparams = {} + + cfg = OmegaConf.to_container(object_dict["cfg"]) + model = object_dict["model"] + trainer = object_dict["trainer"] + + if not trainer.logger: + log.warning("Logger not found! 
Skipping hyperparameter logging...") + return + + hparams["model"] = cfg["model"] + + # save number of model parameters + hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) + hparams["model/params/trainable"] = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + hparams["model/params/non_trainable"] = sum( + p.numel() for p in model.parameters() if not p.requires_grad + ) + + hparams["data"] = cfg["data"] + hparams["trainer"] = cfg["trainer"] + + hparams["callbacks"] = cfg.get("callbacks") + hparams["extras"] = cfg.get("extras") + + hparams["task_name"] = cfg.get("task_name") + hparams["ckpt_path"] = cfg.get("ckpt_path") + hparams["seed"] = cfg.get("seed") + hparams["train_job_id"] = cfg.get("train_job_id") + + # send hparams to all loggers + for logger in trainer.loggers: + logger.log_hyperparams(hparams) diff --git a/backups/thirdparty/catk/src/utils/pylogger.py b/backups/thirdparty/catk/src/utils/pylogger.py new file mode 100644 index 0000000000000000000000000000000000000000..8734c2d5bea9150259f6afea90872887a8f06704 --- /dev/null +++ b/backups/thirdparty/catk/src/utils/pylogger.py @@ -0,0 +1,68 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import logging +from typing import Mapping, Optional + +from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only + + +class RankedLogger(logging.LoggerAdapter): + """A multi-GPU-friendly python command line logger.""" + + def __init__( + self, + name: str = __name__, + rank_zero_only: bool = False, + extra: Optional[Mapping[str, object]] = None, + ) -> None: + """Initializes a multi-GPU-friendly python command line logger that logs on all processes + with their rank prefixed in the log message. + + :param name: The name of the logger. Default is ``__name__``. + :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. + :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. + """ + logger = logging.getLogger(name) + super().__init__(logger=logger, extra=extra) + self.rank_zero_only = rank_zero_only + + def log( + self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs + ) -> None: + """Delegate a log call to the underlying logger, after prefixing its message with the rank + of the process it's being logged from. If `'rank'` is provided, then the log will only + occur on that rank/process. + + :param level: The level to log at. Look at `logging.__init__.py` for more information. + :param msg: The message to log. + :param rank: The rank to log at. + :param args: Additional args to pass to the underlying logging function. + :param kwargs: Any additional keyword args to pass to the underlying logging function. 
+ """ + if self.isEnabledFor(level): + msg, kwargs = self.process(msg, kwargs) + current_rank = getattr(rank_zero_only, "rank", None) + if current_rank is None: + raise RuntimeError( + "The `rank_zero_only.rank` needs to be set before use" + ) + msg = rank_prefixed_message(msg, current_rank) + if self.rank_zero_only: + if current_rank == 0: + self.logger.log(level, msg, *args, **kwargs) + else: + if rank is None: + self.logger.log(level, msg, *args, **kwargs) + elif current_rank == rank: + self.logger.log(level, msg, *args, **kwargs) diff --git a/backups/thirdparty/catk/src/utils/rich_utils.py b/backups/thirdparty/catk/src/utils/rich_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2b1f226a0bce840a4e00e4cfa1086f571b3fb818 --- /dev/null +++ b/backups/thirdparty/catk/src/utils/rich_utils.py @@ -0,0 +1,89 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +from pathlib import Path +from typing import Sequence + +import rich +import rich.syntax +import rich.tree +from lightning_utilities.core.rank_zero import rank_zero_only +from omegaconf import DictConfig, OmegaConf + +from . import pylogger + +log = pylogger.RankedLogger(__name__, rank_zero_only=True) + + +@rank_zero_only +def print_config_tree( + cfg: DictConfig, + print_order: Sequence[str] = ( + "data", + "model", + "callbacks", + "logger", + "trainer", + "paths", + "extras", + ), + resolve: bool = False, + save_to_file: bool = False, +) -> None: + """Prints the contents of a DictConfig as a tree structure using the Rich library. + + :param cfg: A DictConfig composed by Hydra. + :param print_order: Determines in what order config components are printed. Default is ``("data", "model", + "callbacks", "logger", "trainer", "paths", "extras")``. + :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. + :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``. + """ + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + queue = [] + + # add fields from `print_order` to queue + for field in print_order: + ( + queue.append(field) + if field in cfg + else log.warning( + f"Field '{field}' not found in config. Skipping '{field}' config printing..." 
+ ) + ) + + # add all the other fields to queue (not specified in `print_order`) + for field in cfg: + if field not in queue: + queue.append(field) + + # generate config tree from queue + for field in queue: + branch = tree.add(field, style=style, guide_style=style) + + config_group = cfg[field] + if isinstance(config_group, DictConfig): + branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) + else: + branch_content = str(config_group) + + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + # print config tree + rich.print(tree) + + # save config tree to file + if save_to_file: + with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: + rich.print(tree, file=file) diff --git a/backups/thirdparty/catk/src/utils/video_recorder.py b/backups/thirdparty/catk/src/utils/video_recorder.py new file mode 100644 index 0000000000000000000000000000000000000000..7a20a7c1d36e4a8c851b0106bb00cc0aeb5a6cd2 --- /dev/null +++ b/backups/thirdparty/catk/src/utils/video_recorder.py @@ -0,0 +1,142 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import distutils.spawn +import distutils.version +import os +import os.path +import pkgutil +import subprocess + +import numpy as np + + +class ImageEncoder(object): + def __init__(self, output_path, frame_shape, frames_per_sec, output_frames_per_sec): + self.proc = None + self.output_path = output_path + # Frame shape should be lines-first, so w and h are swapped + h, w, pixfmt = frame_shape + if pixfmt != 3 and pixfmt != 4: + raise RuntimeError( + "Your frame has shape {}, but we require (w,h,3) or (w,h,4), i.e., RGB values for a w-by-h image, with an optional alpha channel.".format( + frame_shape + ) + ) + self.wh = (w, h) + self.includes_alpha = pixfmt == 4 + self.frame_shape = frame_shape + self.frames_per_sec = frames_per_sec + self.output_frames_per_sec = output_frames_per_sec + + if distutils.spawn.find_executable("avconv") is not None: + self.backend = "avconv" + elif distutils.spawn.find_executable("ffmpeg") is not None: + self.backend = "ffmpeg" + elif pkgutil.find_loader("imageio_ffmpeg"): + raise RuntimeError + # import imageio_ffmpeg + # self.backend = imageio_ffmpeg.get_ffmpeg_exe() + else: + raise RuntimeError( + """Found neither the ffmpeg nor avconv executables. On OS X, you can install ffmpeg via `brew install ffmpeg`. On most Ubuntu variants, `sudo apt-get install ffmpeg` should do it. On Ubuntu 14.04, however, you'll need to install avconv with `sudo apt-get install libav-tools`. 
Alternatively, please install imageio-ffmpeg with `pip install imageio-ffmpeg`""" + ) + + self.start() + + @property + def version_info(self): + return { + "backend": self.backend, + "version": str( + subprocess.check_output( + [self.backend, "-version"], stderr=subprocess.STDOUT + ) + ), + "cmdline": self.cmdline, + } + + def start(self): + self.cmdline = ( + self.backend, + "-nostats", + "-loglevel", + "error", # suppress warnings + "-y", + # input + "-f", + "rawvideo", + "-s:v", + "{}x{}".format(*self.wh), + "-pix_fmt", + ("rgb32" if self.includes_alpha else "rgb24"), + "-framerate", + "%d" % self.frames_per_sec, + "-i", + "-", # this used to be /dev/stdin, which is not Windows-friendly + # output + "-vf", + "scale=trunc(iw/2)*2:trunc(ih/2)*2", + "-vcodec", + "libx264", + "-pix_fmt", + "yuv420p", + "-r", + "%d" % self.output_frames_per_sec, + self.output_path, + ) + + # print('Starting %s with "%s"', self.backend, " ".join(self.cmdline)) + if hasattr(os, "setsid"): # setsid not present on Windows + self.proc = subprocess.Popen( + self.cmdline, stdin=subprocess.PIPE, preexec_fn=os.setsid + ) + else: + self.proc = subprocess.Popen(self.cmdline, stdin=subprocess.PIPE) + + def capture_frame(self, frame): + if not isinstance(frame, (np.ndarray, np.generic)): + raise RuntimeError( + "Wrong type {} for {} (must be np.ndarray or np.generic)".format( + type(frame), frame + ) + ) + if frame.shape != self.frame_shape: + raise RuntimeError( + "Your frame has shape {}, but the VideoRecorder is configured for shape {}.".format( + frame.shape, self.frame_shape + ) + ) + if frame.dtype != np.uint8: + raise RuntimeError( + "Your frame has data type {}, but we require uint8 (i.e. RGB values from 0-255).".format( + frame.dtype + ) + ) + + try: + if distutils.version.LooseVersion( + np.__version__ + ) >= distutils.version.LooseVersion("1.9.0"): + self.proc.stdin.write(frame.tobytes()) + else: + self.proc.stdin.write(frame.tostring()) + except Exception as e: + stdout, stderr = self.proc.communicate() + print("VideoRecorder encoder failed: %s", stderr) + + def close(self): + self.proc.stdin.close() + ret = self.proc.wait() + if ret != 0: + print("VideoRecorder encoder exited with status {}".format(ret)) diff --git a/backups/thirdparty/catk/src/utils/vis_waymo.py b/backups/thirdparty/catk/src/utils/vis_waymo.py new file mode 100644 index 0000000000000000000000000000000000000000..a527ebf90c8521a4229daeb3db8c3c88f0a87f5c --- /dev/null +++ b/backups/thirdparty/catk/src/utils/vis_waymo.py @@ -0,0 +1,550 @@ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
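+
+# Minimal usage sketch for the VisWaymo renderer defined below. The scenario
+# path and output directory are hypothetical, but the call sequence follows
+# this file's own surface: the constructor renders gt.mp4, and
+# save_video_scenario_rollout() adds one mp4 per rollout of a ScenarioRollouts
+# proto (e.g. built by src.utils.wosac_utils.get_scenario_rollouts):
+#
+#   from pathlib import Path
+#   vis = VisWaymo(
+#       scenario_path="data/womd/validation.tfrecord-00000",  # hypothetical path
+#       save_dir=Path("viz/scenario_0000"),
+#   )
+#   vis.save_video_scenario_rollout(scenario_rollout, n_vis_rollout=2)
+#   print(vis.video_paths)  # ["viz/scenario_0000/gt.mp4", ".../rollout_00.mp4", ...]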
+ +from copy import deepcopy +from pathlib import Path +from typing import List, Tuple + +import cv2 +import numpy as np +import tensorflow as tf +from waymo_open_dataset.protos import scenario_pb2, sim_agents_submission_pb2 + +from .video_recorder import ImageEncoder + +COLOR_BLACK = (0, 0, 0) +COLOR_WHITE = (255, 255, 255) +COLOR_RED = (255, 0, 0) +COLOR_GREEN = (0, 255, 0) +COLOR_CYAN = (0, 255, 255) +COLOR_MAGENTA = (255, 0, 255) +COLOR_YELLOW = (255, 255, 0) +COLOR_VIOLET = (170, 0, 255) +COLOR_BUTTER = (252, 233, 79) +COLOR_ORANGE = (209, 92, 0) +COLOR_CHOCOLATE = (143, 89, 2) +COLOR_CHAMELEON = (78, 154, 6) +COLOR_SKY_BLUE_0 = (114, 159, 207) +COLOR_SKY_BLUE_1 = (32, 74, 135) +COLOR_PLUM = (92, 53, 102) +COLOR_SCARLET_RED = (164, 0, 0) +COLOR_ALUMINIUM_0 = (238, 238, 236) +COLOR_ALUMINIUM_1 = (211, 215, 207) +COLOR_ALUMINIUM_2 = (66, 62, 64) + + +class VisWaymo: + def __init__( + self, + scenario_path: str, + save_dir: Path, + px_per_m: float = 10.0, + video_size: int = 960, + n_step: int = 91, + step_current: int = 10, + vis_ghost_gt: bool = True, + ) -> None: + self.px_per_m = px_per_m + self.video_size = video_size + self.n_step = n_step + self.step_current = step_current + self.px_agent2bottom = video_size // 2 + self.vis_ghost_gt = vis_ghost_gt + + # colors + self.lane_style = [ + (COLOR_WHITE, 6), # FREEWAY = 0 + (COLOR_ALUMINIUM_2, 6), # SURFACE_STREET = 1 + (COLOR_ORANGE, 6), # STOP_SIGN = 2 + (COLOR_CHOCOLATE, 6), # BIKE_LANE = 3 + (COLOR_SKY_BLUE_1, 4), # TYPE_ROAD_EDGE_BOUNDARY = 4 + (COLOR_PLUM, 4), # TYPE_ROAD_EDGE_MEDIAN = 5 + (COLOR_BUTTER, 2), # BROKEN = 6 + (COLOR_MAGENTA, 2), # SOLID_SINGLE = 7 + (COLOR_SCARLET_RED, 2), # DOUBLE = 8 + (COLOR_CHAMELEON, 4), # SPEED_BUMP = 9 + (COLOR_SKY_BLUE_0, 4), # CROSSWALK = 10 + ] + + self.tl_style = [ + COLOR_ALUMINIUM_1, # STATE_UNKNOWN = 0; + COLOR_RED, # STOP = 1; + COLOR_YELLOW, # CAUTION = 2; + COLOR_GREEN, # GO = 3; + COLOR_VIOLET, # FLASHING = 4; + ] + # sdc=0, interest=1, predict=2 + self.agent_role_style = [COLOR_CYAN, COLOR_CHAMELEON, COLOR_MAGENTA] + + self.agent_cmd_txt = [ + "STATIONARY", # STATIONARY = 0; + "STRAIGHT", # STRAIGHT = 1; + "STRAIGHT_LEFT", # STRAIGHT_LEFT = 2; + "STRAIGHT_RIGHT", # STRAIGHT_RIGHT = 3; + "LEFT_U_TURN", # LEFT_U_TURN = 4; + "LEFT_TURN", # LEFT_TURN = 5; + "RIGHT_U_TURN", # RIGHT_U_TURN = 6; + "RIGHT_TURN", # RIGHT_TURN = 7; + ] + + # load tfrecord scenario + scenario = scenario_pb2.Scenario() + for data in tf.data.TFRecordDataset([scenario_path], compression_type=""): + scenario.ParseFromString(bytes(data.numpy())) + break + + # make output dir + self.save_dir = save_dir + self.save_dir.mkdir(exist_ok=True, parents=True) + + # draw gt + mp_xyz, mp_id, mp_type = get_map_features(scenario.map_features) + + tl_lane_state, tl_lane_id = get_traffic_light_features( + scenario.dynamic_map_states + ) + ag_valid, ag_xy, ag_yaw, ag_size, ag_role, ag_id = get_agent_features( + scenario, step_current=step_current + ) + self.ag_id2size = dict(zip(ag_id, ag_size)) + self.ag_id2role = dict(zip(ag_id, ag_role)) + + raster_map, self.top_left_px = self._register_map(mp_xyz, self.px_per_m) + self._draw_map(raster_map, mp_xyz, mp_type) + + im_gt_maps = [raster_map.copy() for _ in range(n_step)] + self._draw_traffic_lights(im_gt_maps, tl_lane_state, tl_lane_id, mp_xyz, mp_id) + + # save gt video and get paths for wandb logging + im_gt = deepcopy(im_gt_maps) + self._draw_agents(im_gt, ag_valid, ag_xy, ag_yaw, ag_size, ag_role) + + gt_video_path = (self.save_dir / "gt.mp4").as_posix() + 
save_images_to_mp4(im_gt, gt_video_path) + self.video_paths = [gt_video_path] + + # prepare images for drawing prediction on top + self.im_gt_blended = [] + if self.vis_ghost_gt: + im_gt_agents = [np.zeros_like(raster_map) for _ in range(n_step)] + self._draw_agents(im_gt_agents, ag_valid, ag_xy, ag_yaw, ag_size, ag_role) + for i in range(n_step): + self.im_gt_blended.append( + cv2.addWeighted(im_gt_agents[i], 0.6, im_gt_maps[i], 1, 0) + ) + else: + for i in range(n_step): + if i <= 10: + self.im_gt_blended.append(deepcopy(im_gt[i])) + else: + self.im_gt_blended.append(deepcopy(im_gt_maps[i])) + + @staticmethod + def _register_map( + mp_xyz: List[np.ndarray], px_per_m: float, edge_px: int = 100 + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Args: + mp_xyz: len=n_pl, list of np array [n_pl_node, 3] + px_per_m: float + + Returns: + raster_map: empty image + top_left_px + """ + xmin = min([arr[:, 0].min() for arr in mp_xyz]) + xmax = max([arr[:, 0].max() for arr in mp_xyz]) + ymin = min([arr[:, 1].min() for arr in mp_xyz]) + ymax = max([arr[:, 1].max() for arr in mp_xyz]) + map_boundary = np.array([xmin, xmax, ymin, ymax]) + + # y axis is inverted in pixel coordinate + xmin, xmax, ymax, ymin = (map_boundary * px_per_m).astype(np.int64) + ymax *= -1 + ymin *= -1 + xmin -= edge_px + ymin -= edge_px + xmax += edge_px + ymax += edge_px + + raster_map = np.zeros([ymax - ymin, xmax - xmin, 3], dtype=np.uint8) + top_left_px = np.array([xmin, ymin], dtype=np.float32) + return raster_map, top_left_px + + def _draw_map( + self, raster_map: np.ndarray, mp_xyz: List[np.ndarray], mp_type: np.ndarray + ) -> None: + """ + Args: numpy arrays + mp_xyz: len=n_pl, list of np array [n_pl_node, 3] + mp_type: [n_pl], int + + Returns: + draw on raster_map + """ + for i, _type in enumerate(mp_type): + color, thickness = self.lane_style[_type] + cv2.polylines( + raster_map, + [self._to_pixel(mp_xyz[i][:, :2])], + isClosed=False, + color=color, + thickness=thickness, + lineType=cv2.LINE_AA, + ) + + def _draw_traffic_lights( + self, + input_images: List[np.ndarray], + tl_lane_state: List[np.ndarray], + tl_lane_id: List[np.ndarray], + mp_xyz: List[np.ndarray], + mp_id: np.ndarray, + ) -> None: + for step_t, step_image in enumerate(input_images): + if step_t < len(tl_lane_state): + for i_tl, _state in enumerate(tl_lane_state[step_t]): + _lane_id = tl_lane_id[step_t][i_tl] + _lane_idx = np.argwhere(mp_id == _lane_id).item() + pos = self._to_pixel(mp_xyz[_lane_idx][:, :2]) + cv2.polylines( + step_image, + [pos], + isClosed=False, + color=self.tl_style[_state], + thickness=8, + lineType=cv2.LINE_AA, + ) + if _state >= 1 and _state <= 3: + cv2.drawMarker( + step_image, + pos[-1], + color=self.tl_style[_state], + markerType=cv2.MARKER_TILTED_CROSS, + markerSize=10, + thickness=6, + ) + + def _draw_agents( + self, + input_images: List[np.ndarray], + ag_valid: np.ndarray, # [n_ag, n_step], bool + ag_xy: np.ndarray, # [n_ag, n_step, 2], (x,y) + ag_yaw: np.ndarray, # [n_ag, n_step, 1], [-pi, pi] + ag_size: np.ndarray, # [n_ag, 3], [length, width, height] + ag_role: np.ndarray, # [n_ag, 3], one_hot [sdc=0, interest=1, predict=2] + ) -> None: + for step_t, step_image in enumerate(input_images): + if step_t < ag_valid.shape[1]: + _valid = ag_valid[:, step_t] # [n_ag] + _pos = ag_xy[:, step_t] # [n_ag, 2] + _yaw = ag_yaw[:, step_t] # [n_ag, 1] + + bbox_gt = self._to_pixel( + self._get_agent_bbox(_valid, _pos, _yaw, ag_size) + ) + heading_start = self._to_pixel(_pos[_valid]) + _yaw = _yaw[:, 0][_valid] + heading_end = self._to_pixel( + 
_pos[_valid] + 1.5 * np.stack([np.cos(_yaw), np.sin(_yaw)], axis=-1) + ) + _role = ag_role[_valid] + for i in range(_role.shape[0]): + if not _role[i].any(): + color = COLOR_ALUMINIUM_0 + else: + color = self.agent_role_style[np.where(_role[i])[0].min()] + cv2.fillConvexPoly(step_image, bbox_gt[i], color=color) + cv2.arrowedLine( + step_image, + heading_start[i], + heading_end[i], + color=COLOR_BLACK, + thickness=4, + line_type=cv2.LINE_AA, + tipLength=0.6, + ) + + def save_video_scenario_rollout( + self, + scenario_rollout: sim_agents_submission_pb2.ScenarioRollouts, + n_vis_rollout: int, + ): + for i_rollout in range(n_vis_rollout): + images = deepcopy(self.im_gt_blended) + ag_valid, ag_xy, ag_yaw, ag_size, ag_role = self._get_features_from_trajs( + scenario_rollout.joint_scenes[i_rollout].simulated_trajectories + ) + self._draw_agents( + images[self.step_current + 1 :], + ag_valid, + ag_xy, + ag_yaw, + ag_size, + ag_role, + ) + _video_path = (self.save_dir / f"rollout_{i_rollout:02d}.mp4").as_posix() + self.video_paths.append(_video_path) + save_images_to_mp4(images, _video_path) + + def _get_features_from_trajs( + self, trajs: List[sim_agents_submission_pb2.SimulatedTrajectory] + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """ + ag_valid: [n_ag, n_step], bool + ag_xy: [n_ag, n_step, 2], (x,y) + ag_yaw: [n_ag, n_step, 1], [-pi, pi] + ag_size: [n_ag, 3], [length, width, height] + ag_role: [n_ag, 3], one_hot [sdc=0, interest=1, predict=2] + """ + n_ag = len(trajs) + n_step = len(trajs[0].center_x) + ag_valid = np.ones([n_ag, n_step], dtype=bool) + ag_xy = np.zeros([n_ag, n_step, 2], dtype=np.float32) + ag_yaw = np.zeros([n_ag, n_step, 1], dtype=np.float32) + ag_size = np.zeros([n_ag, 3], dtype=np.float32) + ag_role = np.zeros([n_ag, 3], dtype=bool) + + for i_ag, _traj in enumerate(trajs): + ag_xy[i_ag] = np.stack([_traj.center_x, _traj.center_y], axis=-1) + ag_yaw[i_ag, :, 0] = _traj.heading + ag_size[i_ag] = self.ag_id2size[_traj.object_id] + ag_role[i_ag] = self.ag_id2role[_traj.object_id] + + return ag_valid, ag_xy, ag_yaw, ag_size, ag_role + + def _to_pixel(self, pos: np.ndarray) -> np.ndarray: + pos = pos * self.px_per_m + pos[..., 0] = pos[..., 0] - self.top_left_px[0] + pos[..., 1] = -pos[..., 1] - self.top_left_px[1] + return np.round(pos).astype(np.int32) + + @staticmethod + def _get_agent_bbox( + agent_valid: np.ndarray, + agent_pos: np.ndarray, + agent_yaw: np.ndarray, + agent_size: np.ndarray, + ) -> np.ndarray: + yaw = agent_yaw[agent_valid] # n, 1 + cos_yaw = np.cos(yaw) + sin_yaw = np.sin(yaw) + v_forward = np.concatenate([cos_yaw, sin_yaw], axis=-1) # n,2 + v_right = np.concatenate([sin_yaw, -cos_yaw], axis=-1) + + offset_forward = 0.5 * agent_size[agent_valid, 0:1] * v_forward # [n, 2] + offset_right = 0.5 * agent_size[agent_valid, 1:2] * v_right # [n, 2] + + vertex_offset = np.stack( + [ + -offset_forward + offset_right, + offset_forward + offset_right, + offset_forward - offset_right, + -offset_forward - offset_right, + ], + axis=1, + ) # n,4,2 + + agent_pos = agent_pos[agent_valid] + bbox = agent_pos[:, None, :].repeat(4, 1) + vertex_offset # n,4,2 + return bbox + + +def save_images_to_mp4(images: List[np.ndarray], out_path: str, fps=20) -> None: + encoder = ImageEncoder(out_path, images[0].shape, fps, fps) + for im in images: + encoder.capture_frame(im) + encoder.close() + encoder = None + + +def get_agent_features( + scenario: scenario_pb2.Scenario, step_current: int +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, 
np.ndarray]:
+    """
+    ag_valid: [n_ag, n_step], bool
+    ag_xy: [n_ag, n_step, 2], (x,y)
+    ag_yaw: [n_ag, n_step, 1], [-pi, pi]
+    ag_size: [n_ag, 3], [length, width, height]
+    ag_role: [n_ag, 3], one_hot [sdc=0, interest=1, predict=2]
+    ag_id: [n_ag], int
+    """
+    tracks = scenario.tracks
+    sdc_track_index = scenario.sdc_track_index
+    track_index_predict = [i.track_index for i in scenario.tracks_to_predict]
+    object_id_interest = [i for i in scenario.objects_of_interest]
+
+    ag_valid, ag_xy, ag_yaw, ag_size, ag_role, ag_id = [], [], [], [], [], []
+    for i, _track in enumerate(tracks):
+        # [VEHICLE=1, PEDESTRIAN=2, CYCLIST=3] -> [0,1,2]
+        # ag_type.append(_track.object_type - 1)
+        if _track.states[step_current].valid:
+            ag_id.append(_track.id)
+            step_valid, step_xy, step_yaw = [], [], []
+            for s in _track.states:
+                step_valid.append(s.valid)
+                step_xy.append([s.center_x, s.center_y])
+                step_yaw.append([s.heading])
+
+            ag_valid.append(step_valid)
+            ag_xy.append(step_xy)
+            ag_yaw.append(step_yaw)
+
+            ag_size.append(
+                [
+                    _track.states[step_current].length,
+                    _track.states[step_current].width,
+                    _track.states[step_current].height,
+                ]
+            )
+
+            ag_role.append([False, False, False])
+            if i in track_index_predict:
+                ag_role[-1][2] = True
+            if _track.id in object_id_interest:
+                ag_role[-1][1] = True
+            if i == sdc_track_index:
+                ag_role[-1][0] = True
+
+    ag_valid = np.array(ag_valid)
+    ag_xy = np.array(ag_xy)
+    ag_yaw = np.array(ag_yaw)
+    ag_size = np.array(ag_size)
+    ag_role = np.array(ag_role)
+    ag_id = np.array(ag_id)
+    return ag_valid, ag_xy, ag_yaw, ag_size, ag_role, ag_id
+
+
+def get_traffic_light_features(
+    tl_features,
+) -> Tuple[List[np.ndarray], List[np.ndarray]]:
+    """n_tl is not constant for each timestep
+    tl_lane_state: len=n_step, list of array [n_tl]
+    tl_lane_id: len=n_step, list of array [n_tl]
+    """
+    tl_lane_state, tl_lane_id, tl_stop_point = [], [], []
+    for _step_tl in tl_features:
+        step_tl_lane_state, step_tl_lane_id, step_tl_stop_point = [], [], []
+        for _tl in _step_tl.lane_states:  # modify LANE_STATE
+            if _tl.state == 0:  # UNKNOWN = 0;
+                tl_state = 0  # UNKNOWN = 0;
+            elif _tl.state in [1, 4]:  # ARROW_STOP = 1; STOP = 4;
+                tl_state = 1  # STOP = 1;
+            elif _tl.state in [2, 5]:  # ARROW_CAUTION = 2; CAUTION = 5;
+                tl_state = 2  # CAUTION = 2;
+            elif _tl.state in [3, 6]:  # ARROW_GO = 3; GO = 6;
+                tl_state = 3  # GO = 3;
+            elif _tl.state in [7, 8]:  # FLASHING_STOP = 7; FLASHING_CAUTION = 8;
+                tl_state = 4  # FLASHING = 4;
+            else:
+                raise ValueError(f"unknown traffic light state: {_tl.state}")
+
+            step_tl_lane_state.append(tl_state)
+            step_tl_lane_id.append(_tl.lane)
+
+        tl_lane_state.append(np.array(step_tl_lane_state))
+        tl_lane_id.append(np.array(step_tl_lane_id))
+    return tl_lane_state, tl_lane_id
+
+
+def get_map_features(
+    map_features,
+) -> Tuple[List[np.ndarray], np.ndarray, np.ndarray]:
+    mp_xyz, mp_id, mp_type = [], [], []
+    for mf in map_features:
+        feature_data_type = mf.WhichOneof("feature_data")
+        # with waymo-open-dataset-tf-2-6-0==1.4.9 the proto is outdated, so
+        # unknown feature types (should be driveway) come back as None
+        if feature_data_type is None:
+            continue
+        feature = getattr(mf, feature_data_type)
+        if feature_data_type == "lane":
+            if feature.type == 0:  # UNDEFINED
+                mp_type.append(1)
+            elif feature.type == 1:  # FREEWAY
+                mp_type.append(0)
+            elif feature.type == 2:  # SURFACE_STREET
+                mp_type.append(1)
+            elif feature.type == 3:  # BIKE_LANE
+                mp_type.append(3)
+            mp_id.append(mf.id)
+            mp_xyz.append([[p.x, p.y, p.z] for p in feature.polyline][::2])
+        elif feature_data_type == "stop_sign":
+            for l_id in feature.lane:
+                # override
+
+
+def get_map_features(
+    map_features,
+) -> Tuple[List[np.ndarray], np.ndarray, np.ndarray]:
+    mp_xyz, mp_id, mp_type = [], [], []
+    for mf in map_features:
+        feature_data_type = mf.WhichOneof("feature_data")
+        # With waymo-open-dataset-tf-2-6-0==1.4.9 the `driveway` feature is not
+        # defined yet, so WhichOneof returns None for such entries; skip them.
+        if feature_data_type is None:
+            continue
+        feature = getattr(mf, feature_data_type)
+        if feature_data_type == "lane":
+            if feature.type == 0:  # UNDEFINED
+                mp_type.append(1)
+            elif feature.type == 1:  # FREEWAY
+                mp_type.append(0)
+            elif feature.type == 2:  # SURFACE_STREET
+                mp_type.append(1)
+            elif feature.type == 3:  # BIKE_LANE
+                mp_type.append(3)
+            mp_id.append(mf.id)
+            mp_xyz.append([[p.x, p.y, p.z] for p in feature.polyline][::2])
+        elif feature_data_type == "stop_sign":
+            for l_id in feature.lane:
+                # override FREEWAY/SURFACE_STREET with stop sign lane
+                # BIKE_LANE remains unchanged
+                idx_lane = mp_id.index(l_id)
+                if mp_type[idx_lane] < 2:
+                    mp_type[idx_lane] = 2
+        elif feature_data_type == "road_edge":
+            assert feature.type > 0  # no UNKNOWN = 0
+            mp_id.append(mf.id)
+            mp_type.append(feature.type + 3)  # [1, 2] -> [4, 5]
+            mp_xyz.append([[p.x, p.y, p.z] for p in feature.polyline][::2])
+        elif feature_data_type == "road_line":
+            assert feature.type > 0  # no UNKNOWN = 0
+            # BROKEN_SINGLE_WHITE = 1
+            # SOLID_SINGLE_WHITE = 2
+            # SOLID_DOUBLE_WHITE = 3
+            # BROKEN_SINGLE_YELLOW = 4
+            # BROKEN_DOUBLE_YELLOW = 5
+            # SOLID_SINGLE_YELLOW = 6
+            # SOLID_DOUBLE_YELLOW = 7
+            # PASSING_DOUBLE_YELLOW = 8
+            if feature.type in [1, 4, 5]:
+                feature_type_new = 6  # BROKEN
+            elif feature.type in [2, 6]:
+                feature_type_new = 7  # SOLID_SINGLE
+            else:
+                feature_type_new = 8  # DOUBLE
+            mp_id.append(mf.id)
+            mp_type.append(feature_type_new)
+            mp_xyz.append([[p.x, p.y, p.z] for p in feature.polyline][::2])
+        elif feature_data_type in ["speed_bump", "driveway", "crosswalk"]:
+            xyz = np.array([[p.x, p.y, p.z] for p in feature.polygon])
+            polygon_idx = np.linspace(0, xyz.shape[0], 4, endpoint=False, dtype=int)
+            pl_polygon = _get_polylines_from_polygon(xyz[polygon_idx])
+            mp_xyz.extend(pl_polygon)
+            mp_id.extend([mf.id] * len(pl_polygon))
+            pl_type = 9 if feature_data_type in ["speed_bump", "driveway"] else 10
+            mp_type.extend([pl_type] * len(pl_polygon))
+        else:
+            raise ValueError(f"unexpected feature_data type: {feature_data_type}")
+
+    mp_id = np.array(mp_id)  # [n_pl]
+    mp_type = np.array(mp_type)  # [n_pl]
+    mp_xyz = [np.stack(line) for line in mp_xyz]  # len=n_pl, list of [n_pl_node, 3]
+    return mp_xyz, mp_id, mp_type
+
+
+def _get_polylines_from_polygon(polygon: np.ndarray) -> List[List[List[float]]]:
+    # polygon: [4, 3]
+    l1 = np.linalg.norm(polygon[1, :2] - polygon[0, :2])
+    l2 = np.linalg.norm(polygon[2, :2] - polygon[1, :2])
+
+    def _pl_interp_start_end(start: np.ndarray, end: np.ndarray) -> List[List]:
+        length = np.linalg.norm(start - end)
+        unit_vec = (end - start) / length
+        pl = []
+        for i in range(int(length) + 1):  # e.g. length 4.5 -> points at 0,1,2,3,4
+            x, y, z = start + unit_vec * i
+            pl.append([x, y, z])
+        pl.append([end[0], end[1], end[2]])
+        return pl
+
+    if l1 > l2:
+        pl1 = _pl_interp_start_end(polygon[0], polygon[1])
+        pl2 = _pl_interp_start_end(polygon[2], polygon[3])
+    else:
+        pl1 = _pl_interp_start_end(polygon[0], polygon[3])
+        pl2 = _pl_interp_start_end(polygon[2], polygon[1])
+    return [pl1, pl1[::-1], pl2, pl2[::-1]]
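A quick worked example of `_get_polylines_from_polygon` above (an editor's sketch, not part of the original diff; it assumes the function is in scope). A 4 m by 2 m rectangle is resampled along its two longer edges at roughly 1 m spacing, and each edge is emitted in both directions, so four polylines come back:

import numpy as np

rect = np.array(
    [
        [0.0, 0.0, 0.0],
        [4.0, 0.0, 0.0],  # edge 0->1 has length 4 (the longer side)
        [4.0, 2.0, 0.0],
        [0.0, 2.0, 0.0],
    ]
)
polylines = _get_polylines_from_polygon(rect)
assert len(polylines) == 4
# polylines[0] walks x = 0, 1, 2, 3, 4 along y = 0 (plus the repeated endpoint);
# polylines[1] is the same edge reversed.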
diff --git a/backups/thirdparty/catk/src/utils/wosac_utils.py b/backups/thirdparty/catk/src/utils/wosac_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d796b4f7ecb812a08709a4f5d1930a5bfe0bd14b
--- /dev/null
+++ b/backups/thirdparty/catk/src/utils/wosac_utils.py
@@ -0,0 +1,87 @@
+# Not a contribution
+# Changes made by NVIDIA CORPORATION & AFFILIATES enabling or otherwise documented as
+# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions:
+# SPDX-FileCopyrightText: Copyright (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+from typing import List
+
+import torch
+from torch import Tensor
+from torch_geometric.utils import degree
+from waymo_open_dataset.protos import sim_agents_submission_pb2
+
+
+def _unbatch(src: Tensor, batch: Tensor, dim: int = 0) -> List[Tensor]:
+    sizes = degree(batch, dtype=torch.long).tolist()
+    return src.split(sizes, dim)
+
+
+def get_scenario_rollouts(
+    scenario_id: Tensor,  # [n_scenario, n_str_length]
+    agent_id: Tensor,  # [n_agent]
+    agent_batch: Tensor,  # [n_agent]
+    pred_traj: Tensor,  # [n_agent, n_rollout, n_step, 2]
+    pred_z: Tensor,  # [n_agent, n_rollout, n_step]
+    pred_head: Tensor,  # [n_agent, n_rollout, n_step]
+) -> List[sim_agents_submission_pb2.ScenarioRollouts]:
+    scenario_id = scenario_id.cpu().numpy()
+    agent_id = _unbatch(agent_id, agent_batch)
+    pred_traj = _unbatch(pred_traj, agent_batch)
+    pred_z = _unbatch(pred_z, agent_batch)
+    pred_head = _unbatch(pred_head, agent_batch)
+    agent_id = [x.cpu().numpy() for x in agent_id]
+    pred_traj = [x.cpu().numpy() for x in pred_traj]
+    pred_z = [x.cpu().numpy() for x in pred_z]
+    pred_head = [x.cpu().numpy() for x in pred_head]
+
+    n_scenario = scenario_id.shape[0]
+    n_rollout = pred_traj[0].shape[1]
+    scenario_rollouts = []
+    for i_scenario in range(n_scenario):
+        joint_scenes = []
+        for i_rollout in range(n_rollout):
+            simulated_trajectories = []
+            for i_agent in range(len(agent_id[i_scenario])):
+                simulated_trajectories.append(
+                    sim_agents_submission_pb2.SimulatedTrajectory(
+                        center_x=pred_traj[i_scenario][i_agent, i_rollout, :, 0],
+                        center_y=pred_traj[i_scenario][i_agent, i_rollout, :, 1],
+                        center_z=pred_z[i_scenario][i_agent, i_rollout],
+                        heading=pred_head[i_scenario][i_agent, i_rollout],
+                        object_id=agent_id[i_scenario][i_agent],
+                    )
+                )
+            joint_scenes.append(
+                sim_agents_submission_pb2.JointScene(
+                    simulated_trajectories=simulated_trajectories
+                )
+            )
+
+        _str_scenario_id = "".join([chr(x) for x in scenario_id[i_scenario] if x > 0])
+        scenario_rollouts.append(
+            sim_agents_submission_pb2.ScenarioRollouts(
+                joint_scenes=joint_scenes, scenario_id=_str_scenario_id
+            )
+        )
+
+    return scenario_rollouts
+
+
+def get_scenario_id_int_tensor(scenario_id: List[str], device: torch.device) -> Tensor:
+    scenario_id_int_tensor = []
+    for str_id in scenario_id:
+        int_id = [-1] * 16  # max_len of scenario_id string is 16
+        for i, c in enumerate(str_id):
+            int_id[i] = ord(c)
+        scenario_id_int_tensor.append(
+            torch.tensor(int_id, dtype=torch.int32, device=device)
+        )
+    return torch.stack(scenario_id_int_tensor, dim=0)  # [n_scenario, 16]
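To make the packing conventions above concrete, a short round-trip sketch (editor's illustration, not part of the original diff; it assumes `get_scenario_id_int_tensor` from this file is in scope). Scenario ids are padded with -1 up to 16 ints and `get_scenario_rollouts` recovers the string by keeping only positive codes; `_unbatch` splits flat per-agent tensors back into per-scenario chunks via the PyG batch vector:

import torch
from torch_geometric.utils import degree

ids = get_scenario_id_int_tensor(["abc123"], device=torch.device("cpu"))
decoded = "".join(chr(x) for x in ids[0].tolist() if x > 0)
assert decoded == "abc123"  # same decode rule as in get_scenario_rollouts

batch = torch.tensor([0, 0, 0, 1, 1])  # 3 agents in scenario 0, 2 in scenario 1
sizes = degree(batch, dtype=torch.long).tolist()  # [3, 2], as in _unbatch
chunks = torch.arange(5).split(sizes, dim=0)  # (tensor([0, 1, 2]), tensor([3, 4]))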
diff --git a/backups/train.py b/backups/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..476f4a55fb494bd2bce2e530b276255ba5103ec1
--- /dev/null
+++ b/backups/train.py
@@ -0,0 +1,123 @@
+import pytorch_lightning as pl
+import os
+import shutil
+import fnmatch
+import torch
+from argparse import ArgumentParser
+from pytorch_lightning.callbacks import LearningRateMonitor
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.strategies import DDPStrategy
+from pytorch_lightning.loggers import WandbLogger
+
+from dev.utils.func import Logging, load_config_act
+from dev.datasets.scalable_dataset import MultiDataModule
+from dev.model.smart import SMART
+
+
+def backup(source_dir, backup_dir, exclude_patterns, include_patterns):
+    """
+    Back up the source directory (code and configs) to a backup directory.
+    Directories whose names match `exclude_patterns` are pruned; only files
+    matching `include_patterns` are copied.
+    """
+
+    if os.path.exists(backup_dir):
+        return
+    os.makedirs(backup_dir, exist_ok=False)
+
+    # Helper function to check if a path matches exclude patterns
+    def should_exclude(path):
+        for pattern in exclude_patterns:
+            if fnmatch.fnmatch(os.path.basename(path), pattern):
+                return True
+        return False
+
+    # Iterate through the files and directories in source_dir
+    for root, dirs, files in os.walk(source_dir):
+        # Skip excluded directories
+        dirs[:] = [d for d in dirs if not should_exclude(d)]
+
+        # Determine the relative path and destination path
+        rel_path = os.path.relpath(root, source_dir)
+        dest_dir = os.path.join(backup_dir, rel_path)
+        os.makedirs(dest_dir, exist_ok=True)
+
+        # Copy all relevant files
+        for file in files:
+            if any(fnmatch.fnmatch(file, pattern) for pattern in include_patterns):
+                shutil.copy2(os.path.join(root, file), os.path.join(dest_dir, file))
+
+    print(f"Backup completed. Files saved to: {backup_dir}")
+
+
+if __name__ == '__main__':
+    pl.seed_everything(2024, workers=True)
+    torch.set_printoptions(precision=3)
+
+    parser = ArgumentParser()
+    parser.add_argument('--config', type=str, default='configs/ours_long_term.yaml')
+    parser.add_argument('--pretrain_ckpt', type=str, default='')
+    parser.add_argument('--ckpt_path', type=str, default='')
+    parser.add_argument('--save_ckpt_path', type=str, default="output/debug")
+    parser.add_argument('--devices', type=int, default=1)
+    args = parser.parse_args()
+
+    # backup codes
+    exclude_patterns = ['*output*', '*logs', 'wandb', 'data', '*debug*', '*backup*', 'interact_*', '*edge_map*', '__pycache__']
+    include_patterns = ['*.py', '*.json', '*.yaml', '*.yml', '*.sh']
+    backup(os.getcwd(), os.path.join(args.save_ckpt_path, 'backups'), exclude_patterns, include_patterns)
+
+    logger = Logging().log(level='DEBUG')
+    config = load_config_act(args.config)
+    Predictor_hash = {'smart': SMART}
+    Predictor = Predictor_hash[config.Model.predictor]
+    strategy = DDPStrategy(find_unused_parameters=True, gradient_as_bucket_view=True)
+    Data_config = config.Dataset
+    datamodule = MultiDataModule(**vars(Data_config), logger=logger)
+
+    wandb_logger = None
+    if int(os.getenv('WANDB', 0)) and not int(os.getenv('DEBUG', 0)):
+        wandb_logger = WandbLogger(project='simagent')
+
+    trainer_config = config.Trainer
+    max_epochs = trainer_config.max_epochs
+
+    model = Predictor(config.Model, save_path=args.save_ckpt_path, logger=logger, max_epochs=max_epochs)
+    if args.pretrain_ckpt != '':
+        model.load_params_from_file(filename=args.pretrain_ckpt)
+
+    every_n_epochs = 1
+    if int(os.getenv('OVERFIT', 0)):
+        max_epochs = trainer_config.overfit_epochs
+        every_n_epochs = 100
+
+    if int(os.getenv('CHECK_INPUTS', 0)):
+        max_epochs = 1
+
+    check_val_every_n_epoch = 1  # run validation after every training epoch
+    model_checkpoint = ModelCheckpoint(dirpath=args.save_ckpt_path,
+                                       filename='{epoch:02d}',
+                                       save_top_k=5,
+                                       monitor='epoch',
+                                       mode='max',
+                                       save_last=True,
+                                       every_n_train_steps=1000,
+                                       save_on_train_epoch_end=True)
+    lr_monitor = LearningRateMonitor(logging_interval='epoch')
+    trainer = pl.Trainer(accelerator=trainer_config.accelerator,
+                         devices=args.devices if args.devices is not None else trainer_config.devices,
+                         strategy=strategy,
+                         logger=wandb_logger,
+                         accumulate_grad_batches=trainer_config.accumulate_grad_batches,
+                         num_nodes=trainer_config.num_nodes,
+                         callbacks=[model_checkpoint, lr_monitor],
+                         max_epochs=max_epochs,
+                         num_sanity_val_steps=0,
+                         check_val_every_n_epoch=check_val_every_n_epoch,
+                         log_every_n_steps=1,
+                         gradient_clip_val=0.5)
+
+    if args.ckpt_path == '':
+        trainer.fit(model, datamodule)
+    else:
+        trainer.fit(model, datamodule, ckpt_path=args.ckpt_path)
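The `backup()` helper in train.py above copies a file only when its name matches one of the include patterns, and prunes any directory whose name matches an exclude pattern. A minimal standalone sketch of that filtering (editor's example with made-up pattern lists, not part of the original diff):

import fnmatch

include_patterns = ['*.py', '*.yaml']        # hypothetical values
exclude_patterns = ['wandb', '__pycache__']  # hypothetical values

def keep_file(name: str) -> bool:
    return any(fnmatch.fnmatch(name, p) for p in include_patterns)

def skip_dir(name: str) -> bool:
    return any(fnmatch.fnmatch(name, p) for p in exclude_patterns)

assert keep_file('train.py') and not keep_file('last.ckpt')
assert skip_dir('__pycache__') and not skip_dir('dev')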
diff --git a/backups/val.py b/backups/val.py
new file mode 100644
index 0000000000000000000000000000000000000000..370573c7c7b162df967b64c1e3252c8829212cdf
--- /dev/null
+++ b/backups/val.py
@@ -0,0 +1,68 @@
+import os
+import pytorch_lightning as pl
+from argparse import ArgumentParser
+from rich.console import Console
+from torch_geometric.loader import DataLoader
+
+from dev.datasets.scalable_dataset import MultiDataset, WaymoTargetBuilder
+from dev.utils.func import load_config_act, Logging
+from dev.model.smart import SMART
+
+CONSOLE = Console(width=120)
+
+
+if __name__ == '__main__':
+    parser = ArgumentParser()
+    parser.add_argument('--seed', type=int, default=2024)
+    parser.add_argument('--config', type=str, default='configs/train/train_scalable_with_state.yaml')
+    parser.add_argument('--ckpt_path', type=str, default="")
+    parser.add_argument('--insert_agent', action='store_true')
+    parser.add_argument('--t', type=int, default=2)
+    parser.add_argument('--save_path', type=str, default=None)
+    args = parser.parse_args()
+    pl.seed_everything(args.seed, workers=True)
+    config = load_config_act(args.config)
+    logger = Logging().log(level='DEBUG')
+
+    data_config = config.Dataset
+    val_dataset = MultiDataset(split='val',
+                               raw_dir=data_config.val_raw_dir,
+                               token_size=data_config.token_size,
+                               transform=WaymoTargetBuilder(
+                                   config.Model.num_historical_steps,
+                                   config.Model.decoder.num_future_steps,
+                                   max_num=data_config.max_num,
+                                   training=False),
+                               tfrecord_dir=data_config.val_tfrecords_splitted,
+                               predict_motion=config.Model.predict_motion,
+                               predict_state=config.Model.predict_state,
+                               predict_map=config.Model.predict_map,
+                               buffer_size=config.Model.buffer_size,
+                               logger=logger,
+                               )
+    dataloader = DataLoader(val_dataset,
+                            shuffle=False,
+                            num_workers=data_config.num_workers,
+                            pin_memory=data_config.pin_memory,
+                            persistent_workers=True if data_config.num_workers > 0 else False
+                            )
+
+    if args.save_path is not None:
+        save_path = args.save_path
+    else:
+        assert args.ckpt_path != "" and os.path.exists(args.ckpt_path), f"Path {args.ckpt_path} does not exist!"
+        save_path = os.path.join(os.path.dirname(args.ckpt_path), 'val')
+    CONSOLE.log(f"Results will be saved to [yellow]{save_path}[/]")
+    os.makedirs(save_path, exist_ok=True)
+
+    model = SMART(config.Model, save_path=save_path, logger=logger, insert_agent=args.insert_agent, t=args.t)
+    CONSOLE.log(f"Validating with checkpoint [yellow]{args.ckpt_path}[/]")
+
+    trainer_config = config.Trainer
+    trainer = pl.Trainer(accelerator=trainer_config.accelerator,
+                         devices=trainer_config.devices,
+                         strategy='ddp', num_sanity_val_steps=0)
+    trainer.validate(model, dataloader, ckpt_path=args.ckpt_path)
+
+    CONSOLE.log("Validation done!")
diff --git a/epoch=04.ckpt b/epoch=04.ckpt
new file mode 100644
index 0000000000000000000000000000000000000000..d9fc006ae2848b5929420375ee526beeb09e5b25
--- /dev/null
+++ b/epoch=04.ckpt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:907a672b22392f4f992bea92ea7ec2e9b644aa51641c987546d840cd862b061b
+size 116282362
diff --git a/last.ckpt b/last.ckpt
index eb39f4937e150821d6e644705cf25da26fdad46f..509a69803d7df91ae1e61b1a0e7ebd88d0809114 100644
--- a/last.ckpt
+++ b/last.ckpt
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:5aa49600f5998d2514ac0f6b5518f230c40d05cc6c0a0d96743b6650c6ea9a3b
-size 135739030
+oid sha256:25cfa41243cbea3b0d15c1414b767ec564ecca2656f0f2d78b13180d00b5c2da
+size 116282362
diff --git a/training_006705_fd50c7db8a208383_prob_seed.png b/training_006705_fd50c7db8a208383_prob_seed.png
new file mode 100644
index 0000000000000000000000000000000000000000..fbf2ec65bff0ddd59203de35c571b37392444bc4
Binary files /dev/null and b/training_006705_fd50c7db8a208383_prob_seed.png differ
diff --git a/training_006706_6e0fe3923492e547_prob_seed.png b/training_006706_6e0fe3923492e547_prob_seed.png
new file mode 100644
index 0000000000000000000000000000000000000000..449e41407f2576b72fecb3f78ff4ca884dcc0afd
Binary files /dev/null and b/training_006706_6e0fe3923492e547_prob_seed.png differ
diff --git a/training_006706_91ffcd2d1333dac9_prob_seed.png b/training_006706_91ffcd2d1333dac9_prob_seed.png
new file mode 100644
index 0000000000000000000000000000000000000000..6bb013d3198c81f88a3b23ad289e6fadc896abbc
Binary files /dev/null and b/training_006706_91ffcd2d1333dac9_prob_seed.png differ
diff --git a/training_006706_9b35a1ed11f6bc1b_prob_seed.png b/training_006706_9b35a1ed11f6bc1b_prob_seed.png
new file mode 100644
index 0000000000000000000000000000000000000000..fe147a65be9a3094bc9419e9fd76ce49bb791621
Binary files /dev/null and b/training_006706_9b35a1ed11f6bc1b_prob_seed.png differ
diff --git a/training_006706_be750663404dd7ad_prob_seed.png b/training_006706_be750663404dd7ad_prob_seed.png
new file mode 100644
index 0000000000000000000000000000000000000000..88e352ca708c04563bccefda29827b6febfbcb07
Binary files /dev/null and b/training_006706_be750663404dd7ad_prob_seed.png differ
diff --git a/training_006706_d93898a06063eaea_prob_seed.png b/training_006706_d93898a06063eaea_prob_seed.png
new file mode 100644
index 0000000000000000000000000000000000000000..11b200bffff82b306ef850b56edf243594e3f858
Binary files /dev/null and b/training_006706_d93898a06063eaea_prob_seed.png differ
diff --git a/training_006706_df9a06abeea201c7_prob_seed.png b/training_006706_df9a06abeea201c7_prob_seed.png
new file mode 100644
index 0000000000000000000000000000000000000000..8970a6a25f356a9752128b44b6a7254ecf084bc5
Binary files /dev/null and b/training_006706_df9a06abeea201c7_prob_seed.png differ
diff --git a/training_006706_f1396c9c46d4f3ac_prob_seed.png
b/training_006706_f1396c9c46d4f3ac_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..f9bbc671c6577ef6e0dfef93c7020466fc838155 Binary files /dev/null and b/training_006706_f1396c9c46d4f3ac_prob_seed.png differ diff --git a/training_008526_655431f3855a6fd5_prob_seed.png b/training_008526_655431f3855a6fd5_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..95b82efe1749d08ebdbc0202748f15880850a321 Binary files /dev/null and b/training_008526_655431f3855a6fd5_prob_seed.png differ diff --git a/training_008527_1f6e7b6e8a49fc99_prob_seed.png b/training_008527_1f6e7b6e8a49fc99_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..bd2d3c17b072193071a7ab3e19830b4675c2ded7 Binary files /dev/null and b/training_008527_1f6e7b6e8a49fc99_prob_seed.png differ diff --git a/training_008527_2555a11e21949028_prob_seed.png b/training_008527_2555a11e21949028_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..96d98cbf67bff1f83571413b9fccf5b78224f28c Binary files /dev/null and b/training_008527_2555a11e21949028_prob_seed.png differ diff --git a/training_008527_2ed417a749e11301_prob_seed.png b/training_008527_2ed417a749e11301_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..37967244890cb72da13f581eb4fa43bd828f023d Binary files /dev/null and b/training_008527_2ed417a749e11301_prob_seed.png differ diff --git a/training_008527_48950bd55b1e8386_prob_seed.png b/training_008527_48950bd55b1e8386_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..be2f4942defb1930b4b5ee0a58e51f70d87e9ed1 Binary files /dev/null and b/training_008527_48950bd55b1e8386_prob_seed.png differ diff --git a/training_008527_5b2fc08dffce0e8f_prob_seed.png b/training_008527_5b2fc08dffce0e8f_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..8fd63335011b9a3af98e2f4c8fed38eb649cb1fa Binary files /dev/null and b/training_008527_5b2fc08dffce0e8f_prob_seed.png differ diff --git a/training_008527_7a838885022a6cb2_prob_seed.png b/training_008527_7a838885022a6cb2_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..bca671ba2ea7a9568855d7450aa4783507293844 Binary files /dev/null and b/training_008527_7a838885022a6cb2_prob_seed.png differ diff --git a/training_008527_967643a6fcf27a9d_prob_seed.png b/training_008527_967643a6fcf27a9d_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..68661cf23b2019f7e1480badfe302059237b1a8f Binary files /dev/null and b/training_008527_967643a6fcf27a9d_prob_seed.png differ diff --git a/training_047694_acd6ae25520dcff_prob_seed.png b/training_047694_acd6ae25520dcff_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..b830a670a91335b31e1e91136601c6f0a5f9ab1f Binary files /dev/null and b/training_047694_acd6ae25520dcff_prob_seed.png differ diff --git a/training_047695_190811854323ee00_prob_seed.png b/training_047695_190811854323ee00_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..9ced82761fbcc3789e5c4f6f2b62dd7af702e1de Binary files /dev/null and b/training_047695_190811854323ee00_prob_seed.png differ diff --git a/training_047695_3e4e99f21af59388_prob_seed.png b/training_047695_3e4e99f21af59388_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..a73b5989d29c4ccd88bf5fedc211879b6dbf9ae2 Binary files /dev/null and b/training_047695_3e4e99f21af59388_prob_seed.png 
differ diff --git a/training_047695_924a92fff4bd3d56_prob_seed.png b/training_047695_924a92fff4bd3d56_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..16fe8ab3b0d5e1c05c45b49a441740c3f80973a4 Binary files /dev/null and b/training_047695_924a92fff4bd3d56_prob_seed.png differ diff --git a/training_047695_97c5d6b46b72bf9c_prob_seed.png b/training_047695_97c5d6b46b72bf9c_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..590d3832467455dd35c06f9d9b71d6d1c9ab58ec Binary files /dev/null and b/training_047695_97c5d6b46b72bf9c_prob_seed.png differ diff --git a/training_047695_a7de975481c9e13_prob_seed.png b/training_047695_a7de975481c9e13_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..2bc55763e49b38a3d98ddc1bc765d4394f607b85 Binary files /dev/null and b/training_047695_a7de975481c9e13_prob_seed.png differ diff --git a/training_047695_c9e0b36660057828_prob_seed.png b/training_047695_c9e0b36660057828_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..00fed5b45b06333ecd86668d55bb2eed3a3bc06c Binary files /dev/null and b/training_047695_c9e0b36660057828_prob_seed.png differ diff --git a/training_047695_f00511a5615a3607_prob_seed.png b/training_047695_f00511a5615a3607_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..240c00bd7f9dfce5aecfef6913169d45ad94a2aa Binary files /dev/null and b/training_047695_f00511a5615a3607_prob_seed.png differ diff --git a/training_051704_b1536e866752a65d_prob_seed.png b/training_051704_b1536e866752a65d_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..9cc930b9bcd826e36f606122d73d256c3304cf87 Binary files /dev/null and b/training_051704_b1536e866752a65d_prob_seed.png differ diff --git a/training_051705_205031af415d73c5_prob_seed.png b/training_051705_205031af415d73c5_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..8a361f206bbaad94fd9ce1e5be85cb19c835803e Binary files /dev/null and b/training_051705_205031af415d73c5_prob_seed.png differ diff --git a/training_051705_21087f965c829484_prob_seed.png b/training_051705_21087f965c829484_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..368d38d1555e136f743c566f2391612f50766d5d Binary files /dev/null and b/training_051705_21087f965c829484_prob_seed.png differ diff --git a/training_051705_31df6ea77e6a2cae_prob_seed.png b/training_051705_31df6ea77e6a2cae_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..09bad75d6d0fc77dc7cb0631fca13f66a84360fc Binary files /dev/null and b/training_051705_31df6ea77e6a2cae_prob_seed.png differ diff --git a/training_051705_3703b9b3e0d2a2b3_prob_seed.png b/training_051705_3703b9b3e0d2a2b3_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..525db42f990bd6a72c7bd0f1f525621c3b3c8bb1 Binary files /dev/null and b/training_051705_3703b9b3e0d2a2b3_prob_seed.png differ diff --git a/training_051705_38592fdfb110611a_prob_seed.png b/training_051705_38592fdfb110611a_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..0676718327f50e97a14c063b6f188cbc06be9ea8 Binary files /dev/null and b/training_051705_38592fdfb110611a_prob_seed.png differ diff --git a/training_051705_d1f20058e3a2d025_prob_seed.png b/training_051705_d1f20058e3a2d025_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..6d23ac415f6939f03b5e063653335b248cf6bd28 Binary 
files /dev/null and b/training_051705_d1f20058e3a2d025_prob_seed.png differ diff --git a/training_051705_d81a4dadcbabf42_prob_seed.png b/training_051705_d81a4dadcbabf42_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..5b4b5541e43b78478716c3ea640afebecb538795 Binary files /dev/null and b/training_051705_d81a4dadcbabf42_prob_seed.png differ diff --git a/training_059797_d2c79d2871f76459_prob_seed.png b/training_059797_d2c79d2871f76459_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..5e911065a3917890ae24bc3534f1c8a32d15f786 Binary files /dev/null and b/training_059797_d2c79d2871f76459_prob_seed.png differ diff --git a/training_059798_418ba5cd626c3230_prob_seed.png b/training_059798_418ba5cd626c3230_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..b75cb65659635a008323b99ff9b11174a227af60 Binary files /dev/null and b/training_059798_418ba5cd626c3230_prob_seed.png differ diff --git a/training_059798_6c7c423976824c37_prob_seed.png b/training_059798_6c7c423976824c37_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..f70a82525b41066cad31bada366cb3168219d05d Binary files /dev/null and b/training_059798_6c7c423976824c37_prob_seed.png differ diff --git a/training_059798_8e41018efe98a1f4_prob_seed.png b/training_059798_8e41018efe98a1f4_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..dec192dce7be2df80c572a9a9c45d17d8809b29a Binary files /dev/null and b/training_059798_8e41018efe98a1f4_prob_seed.png differ diff --git a/training_059798_b6fba06e1bf49d22_prob_seed.png b/training_059798_b6fba06e1bf49d22_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..6e9eedc6a5c64e058fc2b2aa41567048b2812ee7 Binary files /dev/null and b/training_059798_b6fba06e1bf49d22_prob_seed.png differ diff --git a/training_059798_ba274a542fadaf64_prob_seed.png b/training_059798_ba274a542fadaf64_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..fc010ae2b3f1b20647fa9e579e9f0248a8e06b19 Binary files /dev/null and b/training_059798_ba274a542fadaf64_prob_seed.png differ diff --git a/training_059798_be1faeb7d85167c1_prob_seed.png b/training_059798_be1faeb7d85167c1_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..e438a2d304008d4271d7b22c028c60ec18b4c1ff Binary files /dev/null and b/training_059798_be1faeb7d85167c1_prob_seed.png differ diff --git a/training_059798_c6df42ce25636f3a_prob_seed.png b/training_059798_c6df42ce25636f3a_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..3141496c6a9d12e38cd271213d4985101f1a6e46 Binary files /dev/null and b/training_059798_c6df42ce25636f3a_prob_seed.png differ diff --git a/training_063511_43fe95e4c4b0de6_prob_seed.png b/training_063511_43fe95e4c4b0de6_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..9fadd48c148465e180172b08acc7a49c165b44cc Binary files /dev/null and b/training_063511_43fe95e4c4b0de6_prob_seed.png differ diff --git a/training_063512_21b130ca1988c1d3_prob_seed.png b/training_063512_21b130ca1988c1d3_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..eeb3b2c268d0e6103fbc07a7dca881c17c0727e7 Binary files /dev/null and b/training_063512_21b130ca1988c1d3_prob_seed.png differ diff --git a/training_063512_2f3e494646f6534e_prob_seed.png b/training_063512_2f3e494646f6534e_prob_seed.png new file mode 100644 index 
0000000000000000000000000000000000000000..e8985665a651b3e391723fcb4472162e9f9539fb Binary files /dev/null and b/training_063512_2f3e494646f6534e_prob_seed.png differ diff --git a/training_063512_5904f1fc063346d7_prob_seed.png b/training_063512_5904f1fc063346d7_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..93712ef6f14077c05d91c29f6a5db3054848ed8d Binary files /dev/null and b/training_063512_5904f1fc063346d7_prob_seed.png differ diff --git a/training_063512_62edc99f60a73428_prob_seed.png b/training_063512_62edc99f60a73428_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..9acfbc946396c1bdc8f2ac19520f353599d5ec1b Binary files /dev/null and b/training_063512_62edc99f60a73428_prob_seed.png differ diff --git a/training_063512_9b759a0ea34abc67_prob_seed.png b/training_063512_9b759a0ea34abc67_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..e6143e4fa76c0a9136eb03568af9c25c9312aa50 Binary files /dev/null and b/training_063512_9b759a0ea34abc67_prob_seed.png differ diff --git a/training_063512_c3666d88cff01f8a_prob_seed.png b/training_063512_c3666d88cff01f8a_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..9c17ffbec47857a19ba676fd909eb513bcd1776b Binary files /dev/null and b/training_063512_c3666d88cff01f8a_prob_seed.png differ diff --git a/training_063512_dc48c34f595066f4_prob_seed.png b/training_063512_dc48c34f595066f4_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..f73348d374eaa5575778f4aa190db0c3385d0114 Binary files /dev/null and b/training_063512_dc48c34f595066f4_prob_seed.png differ diff --git a/training_112483_7a62a7bf1ae6cb3c_prob_seed.png b/training_112483_7a62a7bf1ae6cb3c_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..117b248a9f41bdd802e0b2a09202b9cb48cafe34 Binary files /dev/null and b/training_112483_7a62a7bf1ae6cb3c_prob_seed.png differ diff --git a/training_112484_32e6ab48038e7208_prob_seed.png b/training_112484_32e6ab48038e7208_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..2af07ae7807f3f05ea425ae9b9217d97e33715af Binary files /dev/null and b/training_112484_32e6ab48038e7208_prob_seed.png differ diff --git a/training_112484_3c6051432178e8e8_prob_seed.png b/training_112484_3c6051432178e8e8_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..6295e59c1a8dfa2567f3b87bc082bde6d0fde97c Binary files /dev/null and b/training_112484_3c6051432178e8e8_prob_seed.png differ diff --git a/training_112484_cfe540c9cb2074cc_prob_seed.png b/training_112484_cfe540c9cb2074cc_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..0aa7df4a04a5032c17c3010c394cf04117381e99 Binary files /dev/null and b/training_112484_cfe540c9cb2074cc_prob_seed.png differ diff --git a/training_112484_d1763b4592937a65_prob_seed.png b/training_112484_d1763b4592937a65_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..dfd79fbda8759daccd30140129e9b027f682fb11 Binary files /dev/null and b/training_112484_d1763b4592937a65_prob_seed.png differ diff --git a/training_112484_e0fa1b3673916e54_prob_seed.png b/training_112484_e0fa1b3673916e54_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..b9a2dc05b9d2643fae2088ef435204117dc65e5c Binary files /dev/null and b/training_112484_e0fa1b3673916e54_prob_seed.png differ diff --git a/training_112484_e32ab26b9c1b217_prob_seed.png 
b/training_112484_e32ab26b9c1b217_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..34e15ebf830daa3acc4140e5b50b85dfb68f99bb Binary files /dev/null and b/training_112484_e32ab26b9c1b217_prob_seed.png differ diff --git a/training_112484_e6de5b41954001d0_prob_seed.png b/training_112484_e6de5b41954001d0_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..90264b2459efd210168be18ece4d8861cc174c3d Binary files /dev/null and b/training_112484_e6de5b41954001d0_prob_seed.png differ diff --git a/training_128225_a844f073478d6d73_prob_seed.png b/training_128225_a844f073478d6d73_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..9799acb05a88f334eb5399e53d68f2457434aca0 Binary files /dev/null and b/training_128225_a844f073478d6d73_prob_seed.png differ diff --git a/training_128226_1963977e2131f55c_prob_seed.png b/training_128226_1963977e2131f55c_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..6cfcce392e36280860792ef428b1f3dec1147d2e Binary files /dev/null and b/training_128226_1963977e2131f55c_prob_seed.png differ diff --git a/training_128226_4a7b82d581b3bb4d_prob_seed.png b/training_128226_4a7b82d581b3bb4d_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..0176fa7129f2d50f53834b9118e432f8f13b44aa Binary files /dev/null and b/training_128226_4a7b82d581b3bb4d_prob_seed.png differ diff --git a/training_128226_6fdd3b0ae58299b_prob_seed.png b/training_128226_6fdd3b0ae58299b_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..5e1495aaf63426870a389b345bdc6adf54a4e8ca Binary files /dev/null and b/training_128226_6fdd3b0ae58299b_prob_seed.png differ diff --git a/training_128226_a56382d0993a466a_prob_seed.png b/training_128226_a56382d0993a466a_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..8cbddb99cc28d5d0dfcbd170b91b97f4fca354ae Binary files /dev/null and b/training_128226_a56382d0993a466a_prob_seed.png differ diff --git a/training_128226_bbb3b4a5cef6707d_prob_seed.png b/training_128226_bbb3b4a5cef6707d_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..732bc772b16e4eae250989035f1d7f6a3e0e53c7 Binary files /dev/null and b/training_128226_bbb3b4a5cef6707d_prob_seed.png differ diff --git a/training_128226_f409b2a80d9d6bdb_prob_seed.png b/training_128226_f409b2a80d9d6bdb_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..b25b6ea2eebeea2e9ce1d981395e184de8deec94 Binary files /dev/null and b/training_128226_f409b2a80d9d6bdb_prob_seed.png differ diff --git a/training_128226_fe599a27075967b8_prob_seed.png b/training_128226_fe599a27075967b8_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..740798c67f448b75823cdffcad2ef68b8908a08e Binary files /dev/null and b/training_128226_fe599a27075967b8_prob_seed.png differ diff --git a/training_185044_e68831c6bf85bbca_prob_seed.png b/training_185044_e68831c6bf85bbca_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..2660d8b3139ad46ec3ca66c2b0f9277eb7012320 Binary files /dev/null and b/training_185044_e68831c6bf85bbca_prob_seed.png differ diff --git a/training_185045_1b49dfd40704e78c_prob_seed.png b/training_185045_1b49dfd40704e78c_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..7bf9ce8ca2d1cce3cf1bb2d81b0bf71d545e5462 Binary files /dev/null and b/training_185045_1b49dfd40704e78c_prob_seed.png 
differ diff --git a/training_185045_1bb7325da08154fa_prob_seed.png b/training_185045_1bb7325da08154fa_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..76da5e6ef60b15b4885c6667e2f7077da163dd95 Binary files /dev/null and b/training_185045_1bb7325da08154fa_prob_seed.png differ diff --git a/training_185045_1be4d39fec374733_prob_seed.png b/training_185045_1be4d39fec374733_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..80a6bccda1cc054e8fc1d8633cff3c187e90cc0e Binary files /dev/null and b/training_185045_1be4d39fec374733_prob_seed.png differ diff --git a/training_185045_8ebc0218acec4fa3_prob_seed.png b/training_185045_8ebc0218acec4fa3_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..f0d5337251b3aa75df17abc14a6ecca9a9ac9b8e Binary files /dev/null and b/training_185045_8ebc0218acec4fa3_prob_seed.png differ diff --git a/training_185045_9d9d41bb53c3742c_prob_seed.png b/training_185045_9d9d41bb53c3742c_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..8e33b45c9f2e4fa6f8be404c7fbd5af9314f7a25 Binary files /dev/null and b/training_185045_9d9d41bb53c3742c_prob_seed.png differ diff --git a/training_185045_b3c966d607c2c392_prob_seed.png b/training_185045_b3c966d607c2c392_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..246653a83589784ba3e999fa039fa5ae4f1f80e4 Binary files /dev/null and b/training_185045_b3c966d607c2c392_prob_seed.png differ diff --git a/training_185045_b605d684f24288e9_prob_seed.png b/training_185045_b605d684f24288e9_prob_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..ef1616dc5900a651e1e2e39e2ebd95bc310bd8fb Binary files /dev/null and b/training_185045_b605d684f24288e9_prob_seed.png differ