|
import os |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from tqdm import tqdm |
|
from typing import Dict, Mapping, Optional, Literal |
|
from torch_cluster import radius, radius_graph |
|
from torch_geometric.data import HeteroData, Batch |
|
from torch_geometric.utils import dense_to_sparse, subgraph |
|
from scipy.optimize import linear_sum_assignment |
|
|
|
from dev.modules.attr_tokenizer import Attr_Tokenizer |
|
from dev.modules.layers import * |
|
from dev.utils.visualization import * |
|
from dev.datasets.preprocess import AGENT_SHAPE, AGENT_TYPE |
|
from dev.utils.func import angle_between_2d_vectors, wrap_angle, weight_init |
|
|
|
|
|
class HungarianMatcher(nn.Module): |
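    """Bipartite matcher between predicted seed-agent slots and ground-truth agents.

    For every timestep, a cost combining position-token cross-entropy and shape L2
    error is minimized with scipy's linear_sum_assignment; matched indices are
    offset back into the batched (ptr-based) index space.
    """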
|
|
|
def __init__(self, loss_weight: dict, enter_state: int = 0): |
|
super().__init__() |
|
self.enter_state = enter_state |
|
self.cost_state = loss_weight['state_cls_loss'] |
|
self.cost_pos = loss_weight['pos_cls_loss'] |
|
self.cost_head = loss_weight['head_cls_loss'] |
|
self.cost_shape = loss_weight['shape_reg_loss'] |
|
self.seed_state_weight = loss_weight['seed_state_weight'] |
|
self.seed_type_weight = loss_weight['seed_type_weight'] |
|
|
|
@torch.no_grad() |
|
def forward(self, outputs, targets, ptr_pred, ptr_gt, valid_mask=None): |
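        # Expected shapes (as used below): outputs['pos_pred'] is (N_pred, T, C) logits,
        # targets['pos_gt'] is (N_gt, T) token indices, and ptr_pred / ptr_gt are
        # cumulative per-graph offsets; returns (pred, gt) index tensors of shape (M, T).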
|
|
|
pred_indices = [] |
|
gt_indices = [] |
|
|
|
for b in range(len(ptr_gt) - 1): |
|
|
|
start_pred, end_pred = ptr_pred[b], ptr_pred[b + 1] |
|
start_gt, end_gt = ptr_gt[b], ptr_gt[b + 1] |
|
|
|
pos_pred = outputs['pos_pred'][start_pred : end_pred] |
|
shape_pred = outputs['shape_pred'][start_pred : end_pred] |
|
|
|
pos_gt = targets['pos_gt'][start_gt : end_gt] |
|
shape_gt = targets['shape_gt'][start_gt : end_gt] |
|
|
|
num_pred = pos_pred.shape[0] |
|
num_gt = pos_gt.shape[0] |
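
            # Pairwise matching cost: cross-entropy between every predicted position-token
            # distribution and every ground-truth position token, kept per timestep.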
|
|
|
cost_pos = F.cross_entropy(pos_pred[:, None].repeat(1, num_gt, 1, 1).reshape(-1, pos_pred.shape[-1]), |
|
pos_gt[None, ...].repeat(num_pred, 1, 1).reshape(-1), |
|
label_smoothing=0.1, ignore_index=-1, reduction='none' |
|
).reshape(num_pred, num_gt, -1) |
|
cost_shape = ((shape_pred[:, None] - shape_gt[None, ...]) ** 2).sum(-1) |
|
|
|
C = ( |
|
self.cost_pos * cost_pos |
|
+ self.cost_shape * cost_shape |
|
) |
|
|
|
C = C.reshape(num_pred, num_gt, -1).cpu().numpy() |
|
|
|
if valid_mask is not None: |
|
|
|
C[:, ~valid_mask[start_gt : end_gt].cpu().numpy().astype(np.bool_)] = 1 << 15 |
|
|
|
_indices = [] |
|
for t in range(C.shape[-1]): |
|
_indices.append(linear_sum_assignment(C[..., t])) |
|
|
|
_indices = ( |
|
torch.as_tensor(np.array([indices_t[0] for indices_t in _indices]) + int(start_pred), dtype=torch.long).transpose(-1, -2), |
|
torch.as_tensor(np.array([indices_t[1] for indices_t in _indices]) + int(start_gt), dtype=torch.long).transpose(-1, -2), |
|
) |
|
|
|
pred_indices.append(_indices[0]) |
|
gt_indices.append(_indices[1]) |
|
|
|
pred_indices = torch.cat(pred_indices) |
|
gt_indices = torch.cat(gt_indices) |
|
|
|
return pred_indices, gt_indices |
|
|
|
def __repr__(self): |
|
head = "Matcher " + self.__class__.__name__ |
|
        body = [

            "cost_state: {}".format(self.cost_state),

            "cost_pos: {}".format(self.cost_pos),

            "cost_head: {}".format(self.cost_head),

            "cost_shape: {}".format(self.cost_shape),

        ]
|
_repr_indent = 4 |
|
lines = [head] + [" " * _repr_indent + line for line in body] |
|
return "\n".join(lines) |
|
|
|
|
|
class SMARTAgentDecoder(nn.Module): |
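    """Autoregressive SMART-style agent decoder.

    Predicts next motion tokens and agent states, and additionally decodes a set of
    'seed' features into newly-entering agents (grid position token, heading, type,
    shape); a HungarianMatcher is kept for matching them against ground-truth entries.
    """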
|
|
|
def __init__(self, |
|
dataset: str, |
|
input_dim: int, |
|
hidden_dim: int, |
|
num_historical_steps: int, |
|
time_span: Optional[int], |
|
pl2a_radius: float, |
|
pl2seed_radius: float, |
|
a2a_radius: float, |
|
a2sa_radius: float, |
|
pl2sa_radius: float, |
|
num_freq_bands: int, |
|
num_layers: int, |
|
num_heads: int, |
|
head_dim: int, |
|
dropout: float, |
|
token_size: int, |
|
attr_tokenizer: Attr_Tokenizer=None, |
|
predict_motion: bool=False, |
|
predict_state: bool=False, |
|
predict_map: bool=False, |
|
predict_occ: bool=False, |
|
state_token: Dict[str, int]=None, |
|
use_grid_token: bool=True, |
|
use_head_token: bool=True, |
|
use_state_token: bool=True, |
|
disable_insertion: bool=False, |
|
seed_size: int=5, |
|
buffer_size: int=32, |
|
num_recurrent_steps_val: int=-1, |
|
loss_weight: dict=None, |
|
logger=None) -> None: |
|
|
|
super(SMARTAgentDecoder, self).__init__() |
|
self.dataset = dataset |
|
self.input_dim = input_dim |
|
self.hidden_dim = hidden_dim |
|
self.num_historical_steps = num_historical_steps |
|
self.time_span = time_span if time_span is not None else num_historical_steps |
|
self.pl2a_radius = pl2a_radius |
|
self.pl2seed_radius = pl2seed_radius |
|
self.a2a_radius = a2a_radius |
|
self.a2sa_radius = a2sa_radius |
|
self.pl2sa_radius = pl2sa_radius |
|
self.num_freq_bands = num_freq_bands |
|
self.num_layers = num_layers |
|
self.num_heads = num_heads |
|
self.head_dim = head_dim |
|
self.dropout = dropout |
|
self.predict_motion = predict_motion |
|
self.predict_state = predict_state |
|
self.predict_map = predict_map |
|
self.predict_occ = predict_occ |
|
self.use_grid_token = use_grid_token |
|
self.use_head_token = use_head_token |
|
self.use_state_token = use_state_token |
|
self.disable_insertion = disable_insertion |
|
self.num_recurrent_steps_val = num_recurrent_steps_val |
|
self.loss_weight = loss_weight |
|
self.logger = logger |
|
|
|
self.attr_tokenizer = attr_tokenizer |
|
|
|
|
|
self.state_type = list(state_token.keys()) |
|
self.state_token = state_token |
|
self.invalid_state = int(state_token['invalid']) |
|
self.valid_state = int(state_token['valid']) |
|
self.enter_state = int(state_token['enter']) |
|
self.exit_state = int(state_token['exit']) |
|
|
|
self.seed_state_type = ['invalid', 'enter'] |
|
self.valid_state_type = ['invalid', 'valid', 'exit'] |
|
|
|
input_dim_x_a = 2 |
|
input_dim_r_t = 4 |
|
input_dim_r_pt2a = 3 |
|
input_dim_r_pt2sa = 3 |
|
input_dim_r_a2a = 3 |
|
input_dim_r_a2sa = 3 |
|
input_dim_motion_token = 8 |
|
input_dim_offset_token = 2 |
|
|
|
self.seed_size = seed_size |
|
self.buffer_size = buffer_size |
|
|
|
|
|
self.type_a_emb = nn.Embedding(len(AGENT_TYPE), hidden_dim) |
|
self.shape_emb = MLPEmbedding(input_dim=3, hidden_dim=hidden_dim) |
|
self.state_a_emb = nn.Embedding(len(self.state_type), hidden_dim) |
|
self.motion_gap = 1. |
|
self.heading_gap = 1. |
|
self.invalid_shape_value = .1 |
|
self.invalid_motion_value = -2. |
|
self.invalid_head_value = -2. |
|
|
|
self.x_a_emb = FourierEmbedding(input_dim=input_dim_x_a, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands) |
|
self.r_t_emb = FourierEmbedding(input_dim=input_dim_r_t, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands) |
|
self.r_pt2a_emb = FourierEmbedding(input_dim=input_dim_r_pt2a, hidden_dim=hidden_dim, |
|
num_freq_bands=num_freq_bands) |
|
self.r_a2a_emb = FourierEmbedding(input_dim=input_dim_r_a2a, hidden_dim=hidden_dim, |
|
num_freq_bands=num_freq_bands) |
|
|
|
|
|
self.r_pt2sa_emb = FourierEmbedding(input_dim=input_dim_r_pt2sa, hidden_dim=hidden_dim, |
|
num_freq_bands=num_freq_bands) |
|
self.r_a2sa_emb = FourierEmbedding(input_dim=input_dim_r_a2sa, hidden_dim=hidden_dim, |
|
num_freq_bands=num_freq_bands) |
|
self.token_emb_veh = MLPEmbedding(input_dim=input_dim_motion_token, hidden_dim=hidden_dim) |
|
self.token_emb_ped = MLPEmbedding(input_dim=input_dim_motion_token, hidden_dim=hidden_dim) |
|
self.token_emb_cyc = MLPEmbedding(input_dim=input_dim_motion_token, hidden_dim=hidden_dim) |
|
self.token_emb_grid = MLPEmbedding(input_dim=input_dim_offset_token, hidden_dim=hidden_dim) |
|
self.no_token_emb = nn.Embedding(1, hidden_dim) |
|
self.bos_token_emb = nn.Embedding(1, hidden_dim) |
|
self.invalid_offset_token_emb = nn.Embedding(1, hidden_dim) |
|
|
|
if self.use_grid_token: |
|
num_inputs = 4 |
|
else: |
|
num_inputs = 3 |
|
self.fusion_emb = MLPEmbedding(input_dim=self.hidden_dim * num_inputs, hidden_dim=self.hidden_dim) |
|
|
|
self.t_attn_layers = nn.ModuleList( |
|
[AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout, |
|
bipartite=False, has_pos_emb=True) for _ in range(num_layers)] |
|
) |
|
self.pt2a_attn_layers = nn.ModuleList( |
|
[AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout, |
|
bipartite=True, has_pos_emb=True) for _ in range(num_layers)] |
|
) |
|
self.a2a_attn_layers = nn.ModuleList( |
|
[AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout, |
|
bipartite=False, has_pos_emb=True) for _ in range(num_layers)] |
|
) |
|
|
|
|
|
self.seed_layers = 3 |
|
|
|
|
|
|
|
|
|
self.pt2sa_attn_layers = nn.ModuleList( |
|
[AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout, |
|
bipartite=True, has_pos_emb=True) for _ in range(self.seed_layers)] |
|
) |
|
self.a2sa_attn_layers = nn.ModuleList( |
|
[AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout, |
|
bipartite=False, has_pos_emb=True) for _ in range(self.seed_layers)] |
|
) |
|
self.occ2sa_attn_layers = nn.ModuleList( |
|
[AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout, |
|
bipartite=True, has_pos_emb=False) for _ in range(self.seed_layers)] |
|
) |
|
|
|
self.token_size = token_size |
|
|
|
self.token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, |
|
output_dim=self.token_size) |
|
|
|
self.state_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, |
|
output_dim=len(self.valid_state_type)) |
|
|
|
self.seed_state_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, |
|
output_dim=len(self.seed_state_type)) |
|
self.seed_type_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, |
|
output_dim=len(AGENT_TYPE) - 1) |
|
self.seed_shape_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, |
|
output_dim=3) |
|
|
|
self.grid_size = self.attr_tokenizer.grid_size |
|
self.angle_size = self.attr_tokenizer.angle_size |
|
|
|
if self.use_grid_token: |
|
self.seed_pos_rel_token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, |
|
output_dim=self.grid_size) |
|
self.seed_offset_xy_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, |
|
output_dim=2) |
|
self.seed_agent_occ_embed = MLPLayer(input_dim=self.grid_size, hidden_dim=hidden_dim, |
|
output_dim=hidden_dim) |
|
else: |
|
self.seed_pos_rel_xy_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, |
|
output_dim=2) |
|
if self.use_head_token: |
|
self.seed_heading_rel_token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, |
|
output_dim=self.angle_size) |
|
else: |
|
self.seed_heading_rel_theta_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, |
|
output_dim=1) |
|
|
|
|
|
|
|
if self.predict_occ: |
|
self.grid_agent_occ_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, |
|
output_dim=self.grid_size) |
|
self.grid_pt_occ_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, |
|
output_dim=self.grid_size) |
|
self.grid_index_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, |
|
output_dim=self.grid_size) |
|
|
|
|
|
|
|
self.num_seed_feature = 10 |
|
|
|
|
|
|
|
|
|
self.apply(weight_init) |
|
|
|
self.shift = 5 |
|
self.motion_beam_size = 5 |
|
self.insert_beam_size = 10 |
|
self.hist_mask = True |
|
self.temporal_attn_to_invalid = False |
|
self.use_rel = False |
|
self.inference_filter_overlap = True |
|
assert self.num_recurrent_steps_val % self.shift == 0 or self.num_recurrent_steps_val == -1, \ |
|
f"Invalid num_recurrent_steps_val: {num_recurrent_steps_val}." |
|
|
|
|
|
self.temporal_attn_seed = False |
|
self.seed_attn_to_av = True |
|
self.seed_use_ego_motion = False |
|
|
|
self.matcher = HungarianMatcher(loss_weight=loss_weight, enter_state=self.enter_state) |
|
|
|
def transform_rel(self, token_traj, prev_pos, prev_heading=None): |
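        # Rotate token trajectories by the previous heading (estimated from the last two
        # positions when not given) and translate them onto the last observed position.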
|
if prev_heading is None: |
|
diff_xy = prev_pos[:, :, -1, :] - prev_pos[:, :, -2, :] |
|
prev_heading = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0]) |
|
|
|
num_agent, num_step, traj_num, traj_dim = token_traj.shape |
|
cos, sin = prev_heading.cos(), prev_heading.sin() |
|
rot_mat = torch.zeros((num_agent, num_step, 2, 2), device=prev_heading.device) |
|
rot_mat[:, :, 0, 0] = cos |
|
rot_mat[:, :, 0, 1] = -sin |
|
rot_mat[:, :, 1, 0] = sin |
|
rot_mat[:, :, 1, 1] = cos |
|
agent_diff_rel = torch.bmm(token_traj.view(-1, traj_num, 2), rot_mat.view(-1, 2, 2)).view(num_agent, num_step, traj_num, traj_dim) |
|
agent_pred_rel = agent_diff_rel + prev_pos[:, :, -1:, :] |
|
return agent_pred_rel |
|
|
|
def _agent_token_embedding(self, data, agent_token_index, agent_state, agent_offset_token_idx, pos_a, head_a, |
|
inference=False, filter_mask=None, av_index=None): |
|
|
|
if filter_mask is None: |
|
filter_mask = torch.ones_like(agent_state[:, 2], dtype=torch.bool) |
|
|
|
num_agent, num_step, traj_dim = pos_a.shape |
|
agent_type = data['agent']['type'][filter_mask] |
|
veh_mask = agent_type == 0 |
|
ped_mask = agent_type == 1 |
|
cyc_mask = agent_type == 2 |
|
|
|
motion_vector_a, head_vector_a = self._build_vector_a(pos_a, head_a, agent_state) |
|
|
|
|
|
trajectory_token_veh = data['agent']['trajectory_token_veh'] |
|
trajectory_token_ped = data['agent']['trajectory_token_ped'] |
|
trajectory_token_cyc = data['agent']['trajectory_token_cyc'] |
|
agent_token_emb_veh = self.token_emb_veh(trajectory_token_veh[:, -1].flatten(1, 2)) |
|
agent_token_emb_ped = self.token_emb_ped(trajectory_token_ped[:, -1].flatten(1, 2)) |
|
agent_token_emb_cyc = self.token_emb_cyc(trajectory_token_cyc[:, -1].flatten(1, 2)) |
|
|
|
|
|
agent_token_emb_veh = torch.cat([agent_token_emb_veh, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())]) |
|
agent_token_emb_ped = torch.cat([agent_token_emb_ped, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())]) |
|
agent_token_emb_cyc = torch.cat([agent_token_emb_cyc, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())]) |
|
|
|
|
|
agent_token_emb_veh = torch.cat([agent_token_emb_veh, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())]) |
|
agent_token_emb_ped = torch.cat([agent_token_emb_ped, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())]) |
|
agent_token_emb_cyc = torch.cat([agent_token_emb_cyc, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())]) |
|
|
|
|
|
agent_token_emb = torch.zeros((num_agent, num_step, self.hidden_dim), device=pos_a.device) |
|
agent_token_emb[veh_mask] = agent_token_emb_veh[agent_token_index[veh_mask]] |
|
agent_token_emb[ped_mask] = agent_token_emb_ped[agent_token_index[ped_mask]] |
|
agent_token_emb[cyc_mask] = agent_token_emb_cyc[agent_token_index[cyc_mask]] |
|
|
|
|
|
self.grid_token_emb = self.token_emb_grid(self.attr_tokenizer.grid) |
|
self.grid_token_emb = torch.cat([self.grid_token_emb, self.invalid_offset_token_emb(torch.zeros(1, device=pos_a.device).long())]) |
|
offset_token_emb = self.grid_token_emb[agent_offset_token_idx] |
|
|
|
|
|
is_invalid = agent_state == self.invalid_state |
|
agent_types = data['agent']['type'].clone()[filter_mask].long().repeat_interleave(repeats=num_step, dim=0) |
|
agent_types[is_invalid.reshape(-1)] = AGENT_TYPE.index('seed') |
|
agent_shapes = data['agent']['shape'].clone()[filter_mask, self.num_historical_steps - 1, :].repeat_interleave(repeats=num_step, dim=0) |
|
agent_shapes[is_invalid.reshape(-1)] = self.invalid_shape_value |
|
|
|
|
|
offset_pos = pos_a - pos_a[av_index].repeat_interleave(repeats=data['batch_size_a'], dim=0) |
|
feat_a, categorical_embs = self._build_agent_feature(num_step, pos_a.device, |
|
motion_vector_a, |
|
head_vector_a, |
|
agent_token_emb, |
|
offset_token_emb, |
|
offset_pos=offset_pos, |
|
type=agent_types, |
|
shape=agent_shapes, |
|
state=agent_state, |
|
n=num_agent) |
|
|
|
if inference: |
|
return ( |
|
feat_a, |
|
agent_token_emb, |
|
agent_token_emb_veh, |
|
agent_token_emb_ped, |
|
agent_token_emb_cyc, |
|
categorical_embs, |
|
trajectory_token_veh, |
|
trajectory_token_ped, |
|
trajectory_token_cyc, |
|
) |
|
|
|
else: |
|
|
|
|
|
if self.seed_use_ego_motion: |
|
motion_vector_seed = motion_vector_a[av_index].repeat_interleave(repeats=self.num_seed_feature, dim=0) |
|
head_vector_seed = head_vector_a[av_index].repeat_interleave(repeats=self.num_seed_feature, dim=0) |
|
else: |
|
motion_vector_seed = head_vector_seed = None |
|
feat_seed, _ = self._build_agent_feature(num_step, pos_a.device, |
|
motion_vector_seed, |
|
head_vector_seed, |
|
state_index=self.invalid_state, |
|
n=data.num_graphs * self.num_seed_feature) |
|
|
|
feat_a = torch.cat([feat_a, feat_seed], dim=0) |
|
|
|
return feat_a |
|
|
|
def _build_vector_a(self, pos_a, head_a, state_a): |
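        # Per-step motion vectors, with sentinel values for invalid steps and for
        # transitions into/out of the invalid state (agent removal/insertion).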
|
num_agent = pos_a.shape[0] |
|
|
|
motion_vector_a = torch.cat([pos_a.new_zeros(num_agent, 1, self.input_dim), |
|
pos_a[:, 1:] - pos_a[:, :-1]], dim=1) |
|
|
|
motion_vector_a[state_a == self.invalid_state] = self.invalid_motion_value |
|
|
|
|
|
is_last_invalid = (state_a.roll(shifts=1, dims=1) == self.invalid_state) & (state_a != self.invalid_state) |
|
is_last_invalid[:, 0] = state_a[:, 0] == self.enter_state |
|
motion_vector_a[is_last_invalid] = self.motion_gap |
|
|
|
|
|
is_last_valid = (state_a.roll(shifts=1, dims=1) != self.invalid_state) & (state_a == self.invalid_state) |
|
is_last_valid[:, 0] = False |
|
motion_vector_a[is_last_valid] = -self.motion_gap |
|
|
|
        head_a[state_a == self.invalid_state] = self.invalid_head_value
|
head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1) |
|
|
|
return motion_vector_a, head_vector_a |
|
|
|
def _build_agent_feature(self, num_step, device, |
|
motion_vector=None, |
|
head_vector=None, |
|
agent_token_emb=None, |
|
agent_grid_emb=None, |
|
offset_pos=None, |
|
type=None, |
|
shape=None, |
|
categorical_embs_a=None, |
|
state=None, |
|
state_index=None, |
|
n=1): |
|
|
|
if agent_token_emb is None: |
|
agent_token_emb = self.no_token_emb(torch.zeros(1, device=device).long())[:, None].repeat(n, num_step, 1) |
|
if state is not None: |
|
agent_token_emb[state == self.enter_state] = self.bos_token_emb(torch.zeros(1, device=device).long()) |
|
|
|
if agent_grid_emb is None: |
|
agent_grid_emb = self.grid_token_emb[None, None, self.grid_size // 2].repeat(n, num_step, 1) |
|
|
|
if motion_vector is None or head_vector is None: |
|
pos_a = torch.zeros((n, num_step, 2), device=device) |
|
head_a = torch.zeros((n, num_step), device=device) |
|
if state is None: |
|
state = torch.full((n, num_step), self.invalid_state, device=device) |
|
motion_vector, head_vector = self._build_vector_a(pos_a, head_a, state) |
|
|
|
if offset_pos is None: |
|
offset_pos = torch.zeros_like(motion_vector) |
|
|
|
        feature_a = torch.stack(

            [torch.norm(motion_vector[:, :, :2], p=2, dim=-1),

             angle_between_2d_vectors(ctr_vector=head_vector, nbr_vector=motion_vector[:, :, :2])], dim=-1)
|
|
|
if categorical_embs_a is None: |
|
if type is None: |
|
type = torch.tensor([AGENT_TYPE.index('seed')], device=device) |
|
if shape is None: |
|
shape = torch.full((1, 3), self.invalid_shape_value, device=device) |
|
|
|
categorical_embs_a = [self.type_a_emb(type.reshape(-1)), self.shape_emb(shape.reshape(-1, shape.shape[-1]))] |
|
|
|
x_a = self.x_a_emb(continuous_inputs=feature_a.view(-1, feature_a.size(-1)), |
|
categorical_embs=categorical_embs_a) |
|
x_a = x_a.view(-1, num_step, self.hidden_dim) |
|
|
|
if state is None: |
|
            assert state_index is not None, "state_index must be set when the state tensor is None!"
|
state = torch.tensor([state_index], device=device)[:, None].repeat(n, num_step, 1) |
|
s_a = self.state_a_emb(state.reshape(-1).long()).reshape(n, num_step, self.hidden_dim) |
|
|
|
feat_a = torch.cat((agent_token_emb, x_a, s_a), dim=-1) |
|
if self.use_grid_token: |
|
feat_a = torch.cat([feat_a, agent_grid_emb], dim=-1) |
|
|
|
feat_a = self.fusion_emb(feat_a) |
|
|
|
return feat_a, categorical_embs_a |
|
|
|
def _pad_feat(self, num_graph, av_index, *feats, num_seed_feature=None): |
|
|
|
if num_seed_feature is None: |
|
num_seed_feature = self.num_seed_feature |
|
|
|
padded_feats = tuple() |
|
for i in range(len(feats)): |
|
padded_feats += (torch.cat([feats[i], feats[i][av_index].repeat_interleave( |
|
repeats=num_seed_feature, dim=0)], |
|
dim=0 |
|
),) |
|
|
|
pad_mask = torch.ones(*padded_feats[0].shape[:2], device=feats[0].device).bool() |
|
pad_mask[-num_graph * num_seed_feature:] = False |
|
|
|
return padded_feats + (pad_mask,) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _build_temporal_edge(self, data, pos_a, head_a, state_a, head_vector_a, mask, inference_mask=None): |
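        # Temporal self-attention edges: each agent step attends to earlier steps of the
        # same agent, up to time_span / shift shifted steps back.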
|
|
|
num_graph = data.num_graphs |
|
num_agent = pos_a.shape[0] |
|
hist_mask = mask.clone() |
|
|
|
if not self.temporal_attn_to_invalid: |
|
is_bos = state_a == self.enter_state |
|
bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0)) |
|
history_invalid_mask = torch.arange(mask.shape[1]).expand(mask.shape[0], mask.shape[1]).to(mask.device) |
|
history_invalid_mask = (history_invalid_mask < bos_index[:, None]) |
|
hist_mask[history_invalid_mask] = False |
|
|
|
if not self.temporal_attn_seed: |
|
hist_mask[-num_graph * self.num_seed_feature:] = False |
|
if inference_mask is not None: |
|
inference_mask[-num_graph * self.num_seed_feature:] = False |
|
else: |
|
|
|
|
|
raise RuntimeError("Wrong settings!") |
|
|
|
pos_t = pos_a.reshape(-1, self.input_dim) |
|
head_t = head_a.reshape(-1) |
|
head_vector_t = head_vector_a.reshape(-1, 2) |
|
|
|
|
|
is_bos = state_a == self.enter_state |
|
is_bos[-num_graph * self.num_seed_feature:] = False |
|
bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0)) |
|
motion_predict_start_index = torch.clamp(bos_index - self.time_span / self.shift + 1, min=0) |
|
motion_predict_mask = torch.arange(hist_mask.shape[1]).expand(hist_mask.shape[0], -1).to(hist_mask.device) |
|
motion_predict_mask = motion_predict_mask >= motion_predict_start_index[:, None] |
|
hist_mask[~motion_predict_mask] = False |
|
|
|
if self.hist_mask and self.training: |
|
hist_mask[ |
|
torch.arange(mask.shape[0]).unsqueeze(1), torch.randint(0, mask.shape[1], (num_agent, 10))] = False |
|
mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1) |
|
elif inference_mask is not None: |
|
mask_t = hist_mask.unsqueeze(2) & inference_mask.unsqueeze(1) |
|
else: |
|
mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1) |
|
|
|
|
|
edge_index_t = dense_to_sparse(mask_t)[0] |
|
edge_index_t = edge_index_t[:, (edge_index_t[1] - edge_index_t[0] > 0) & |
|
(edge_index_t[1] - edge_index_t[0] <= self.time_span / self.shift)] |
|
rel_pos_t = pos_t[edge_index_t[0]] - pos_t[edge_index_t[1]] |
|
rel_head_t = wrap_angle(head_t[edge_index_t[0]] - head_t[edge_index_t[1]]) |
|
|
|
|
|
is_invalid = state_a == self.invalid_state |
|
is_invalid_t = is_invalid.reshape(-1) |
|
|
|
rel_pos_t[is_invalid_t[edge_index_t[0]] & ~is_invalid_t[edge_index_t[1]]] = -self.motion_gap |
|
rel_pos_t[~is_invalid_t[edge_index_t[0]] & is_invalid_t[edge_index_t[1]]] = self.motion_gap |
|
rel_head_t[is_invalid_t[edge_index_t[0]] & ~is_invalid_t[edge_index_t[1]]] = -self.heading_gap |
|
        rel_head_t[~is_invalid_t[edge_index_t[0]] & is_invalid_t[edge_index_t[1]]] = self.heading_gap
|
|
|
rel_pos_t[is_invalid_t[edge_index_t[0]] & is_invalid_t[edge_index_t[1]]] = self.invalid_motion_value |
|
rel_head_t[is_invalid_t[edge_index_t[0]] & is_invalid_t[edge_index_t[1]]] = self.invalid_head_value |
|
|
|
r_t = torch.stack( |
|
[torch.norm(rel_pos_t[:, :2], p=2, dim=-1), |
|
angle_between_2d_vectors(ctr_vector=head_vector_t[edge_index_t[1]], nbr_vector=rel_pos_t[:, :2]), |
|
rel_head_t, |
|
edge_index_t[0] - edge_index_t[1]], dim=-1) |
|
r_t = self.r_t_emb(continuous_inputs=r_t, categorical_embs=None) |
|
|
|
return edge_index_t, r_t |
|
|
|
def _build_interaction_edge(self, data, pos_a, head_a, state_a, head_vector_a, batch_s, mask, pad_mask=None, inference_mask=None, |
|
av_index=None, seq_mask=None, seq_index=None, grid_index_a=None, **plot_kwargs): |
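        # Agent-to-agent edges within a2a_radius at every timestep; during training,
        # agent-to-seed edges within pl2seed_radius are appended so seed features can
        # attend to the current scene.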
|
num_graph = data.num_graphs |
|
num_agent, num_step, _ = pos_a.shape |
|
is_training = inference_mask is None |
|
|
|
mask_a = mask.clone() |
|
|
|
if pad_mask is None: |
|
pad_mask = torch.ones_like(state_a).bool() |
|
|
|
pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim) |
|
head_s = head_a.transpose(0, 1).reshape(-1) |
|
head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2) |
|
pad_mask_s = pad_mask.transpose(0, 1).reshape(-1) |
|
if inference_mask is not None: |
|
mask_a = mask_a & inference_mask |
|
mask_s = mask_a.transpose(0, 1).reshape(-1) |
|
|
|
|
|
edge_index_a2a = radius_graph(x=pos_s[:, :2], r=self.a2a_radius, batch=batch_s, loop=False, |
|
max_num_neighbors=300) |
|
edge_index_a2a = subgraph(subset=mask_s & pad_mask_s, edge_index=edge_index_a2a)[0] |
|
|
|
if int(os.getenv('PLOT_EDGE', 0)): |
|
plot_interact_edge(edge_index_a2a, data['scenario_id'], data['batch_size_a'].cpu(), self.num_seed_feature, num_step, |
|
av_index=av_index, **plot_kwargs) |
|
|
|
rel_pos_a2a = pos_s[edge_index_a2a[0]] - pos_s[edge_index_a2a[1]] |
|
rel_head_a2a = wrap_angle(head_s[edge_index_a2a[0]] - head_s[edge_index_a2a[1]]) |
|
|
|
|
|
is_invalid = state_a == self.invalid_state |
|
is_invalid_s = is_invalid.transpose(0, 1).reshape(-1) |
|
|
|
rel_pos_a2a[is_invalid_s[edge_index_a2a[0]] & ~is_invalid_s[edge_index_a2a[1]]] = -self.motion_gap |
|
rel_pos_a2a[~is_invalid_s[edge_index_a2a[0]] & is_invalid_s[edge_index_a2a[1]]] = self.motion_gap |
|
rel_head_a2a[is_invalid_s[edge_index_a2a[0]] & ~is_invalid_s[edge_index_a2a[1]]] = -self.heading_gap |
|
        rel_head_a2a[~is_invalid_s[edge_index_a2a[0]] & is_invalid_s[edge_index_a2a[1]]] = self.heading_gap
|
|
|
rel_pos_a2a[is_invalid_s[edge_index_a2a[0]] & is_invalid_s[edge_index_a2a[1]]] = self.invalid_motion_value |
|
rel_head_a2a[is_invalid_s[edge_index_a2a[0]] & is_invalid_s[edge_index_a2a[1]]] = self.invalid_head_value |
|
|
|
r_a2a = torch.stack( |
|
[torch.norm(rel_pos_a2a[:, :2], p=2, dim=-1), |
|
angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_a2a[1]], nbr_vector=rel_pos_a2a[:, :2]), |
|
rel_head_a2a], dim=-1) |
|
r_a2a = self.r_a2a_emb(continuous_inputs=r_a2a, categorical_embs=None) |
|
|
|
|
|
if is_training: |
|
mask_av = torch.ones_like(mask_a).bool() |
|
if not self.seed_attn_to_av: |
|
mask_av[av_index] = False |
|
mask_a &= mask_av |
|
edge_index_seed2a, r_seed2a = self._build_a2sa_edge(data, pos_a, head_a, head_vector_a, batch_s, |
|
mask_a.clone(), ~pad_mask.clone(), inference_mask=inference_mask, |
|
r=self.pl2seed_radius, max_num_neighbors=300, |
|
seq_mask=seq_mask, seq_index=seq_index, grid_index_a=grid_index_a, mode='insert') |
|
|
|
            if int(os.getenv('PLOT_EDGE', 0)):
|
plot_interact_edge(edge_index_seed2a, data['scenario_id'], data['batch_size_a'].cpu(), self.num_seed_feature, num_step, |
|
'interact_edge_map_seed', av_index=av_index, **plot_kwargs) |
|
|
|
            num_a2a = edge_index_a2a.shape[1]

            edge_index_a2a = torch.cat([edge_index_a2a, edge_index_seed2a], dim=-1)

            r_a2a = torch.cat([r_a2a, r_seed2a])



            return edge_index_a2a, r_a2a, (num_a2a, edge_index_seed2a.shape[1])
|
|
|
return edge_index_a2a, r_a2a |
|
|
|
def _build_map2agent_edge(self, data, pos_a, head_a, state_a, head_vector_a, batch_s, batch_pl, |
|
mask, pad_mask=None, inference_mask=None, av_index=None, **kwargs): |
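        # Map-to-agent edges within pl2a_radius at every timestep; during training,
        # map-to-seed edges within pl2seed_radius are appended.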
|
num_graph = data.num_graphs |
|
num_agent, num_step, _ = pos_a.shape |
|
is_training = inference_mask is None |
|
|
|
mask_pl2a = mask.clone() |
|
|
|
if pad_mask is None: |
|
pad_mask = torch.ones_like(state_a).bool() |
|
|
|
pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim) |
|
head_s = head_a.transpose(0, 1).reshape(-1) |
|
head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2) |
|
pad_mask_s = pad_mask.transpose(0, 1).reshape(-1) |
|
if inference_mask is not None: |
|
mask_pl2a = mask_pl2a & inference_mask |
|
mask_pl2a = mask_pl2a.transpose(0, 1).reshape(-1) |
|
|
|
ori_pos_pl = data['pt_token']['position'][:, :self.input_dim].contiguous() |
|
ori_orient_pl = data['pt_token']['orientation'].contiguous() |
|
pos_pl = ori_pos_pl.repeat(num_step, 1) |
|
orient_pl = ori_orient_pl.repeat(num_step) |
|
|
|
|
|
|
|
|
|
edge_index_pl2a = radius(x=pos_pl[:, :2], y=pos_s[:, :2], r=self.pl2a_radius, |
|
batch_x=batch_pl, batch_y=batch_s, max_num_neighbors=5) |
|
edge_index_pl2a = edge_index_pl2a[[1, 0]] |
|
edge_index_pl2a = edge_index_pl2a[:, mask_pl2a[edge_index_pl2a[1]] & |
|
pad_mask_s[edge_index_pl2a[1]]] |
|
|
|
rel_pos_pl2a = pos_pl[edge_index_pl2a[0]] - pos_s[edge_index_pl2a[1]] |
|
rel_orient_pl2a = wrap_angle(orient_pl[edge_index_pl2a[0]] - head_s[edge_index_pl2a[1]]) |
|
|
|
|
|
is_invalid = state_a == self.invalid_state |
|
is_invalid_s = is_invalid.transpose(0, 1).reshape(-1) |
|
rel_pos_pl2a[is_invalid_s[edge_index_pl2a[1]]] = self.motion_gap |
|
rel_orient_pl2a[is_invalid_s[edge_index_pl2a[1]]] = self.heading_gap |
|
|
|
r_pl2a = torch.stack( |
|
[torch.norm(rel_pos_pl2a[:, :2], p=2, dim=-1), |
|
angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_pl2a[1]], nbr_vector=rel_pos_pl2a[:, :2]), |
|
rel_orient_pl2a], dim=-1) |
|
r_pl2a = self.r_pt2a_emb(continuous_inputs=r_pl2a, categorical_embs=None) |
|
|
|
|
|
if is_training: |
|
edge_index_pl2seed, r_pl2seed = self._build_map2sa_edge(data, pos_a, head_a, head_vector_a, batch_s, batch_pl, |
|
~pad_mask.clone(), inference_mask=inference_mask, |
|
r=self.pl2seed_radius, max_num_neighbors=2048, mode='insert') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
            if int(os.getenv('PLOT_EDGE', 0)):
|
plot_interact_edge(edge_index_pl2seed, data['scenario_id'], data['batch_size_a'].cpu(), self.num_seed_feature, num_step, |
|
'interact_edge_map_seed', av_index=av_index) |
|
|
|
            num_pl2a = edge_index_pl2a.shape[1]

            edge_index_pl2a = torch.cat([edge_index_pl2a, edge_index_pl2seed], dim=-1)

            r_pl2a = torch.cat([r_pl2a, r_pl2seed])



            return edge_index_pl2a, r_pl2a, (num_pl2a, edge_index_pl2seed.shape[1])
|
|
|
return edge_index_pl2a, r_pl2a |
|
|
|
def _build_a2sa_edge(self, data, pos_a, head_a, head_vector_a, batch_s, mask_a, mask_sa, |
|
inference_mask=None, r=None, max_num_neighbors=8, seq_mask=None, seq_index=None, |
|
grid_index_a=None, mode: Literal['insert', 'feature']='feature', **plot_kwargs): |
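        # Edges from regular agents (sources) to seed agents (targets) within radius r;
        # seq_mask / seq_index optionally enforce the insertion ordering among seeds.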
|
|
|
num_agent, num_step, _ = pos_a.shape |
|
is_training = inference_mask is None |
|
|
|
pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim) |
|
head_s = head_a.transpose(0, 1).reshape(-1) |
|
head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2) |
|
if inference_mask is not None: |
|
mask_a = mask_a & inference_mask |
|
mask_sa = mask_sa & inference_mask |
|
mask_s = mask_a.transpose(0, 1).reshape(-1) |
|
mask_s_sa = mask_sa.transpose(0, 1).reshape(-1) |
|
|
|
|
|
assert r is not None, "r needs to be specified!" |
|
|
|
|
|
edge_index_a2sa = radius(x=pos_s[:, :2], y=pos_s[mask_s_sa, :2], r=r, |
|
batch_x=batch_s, batch_y=batch_s[mask_s_sa], max_num_neighbors=max_num_neighbors) |
|
edge_index_a2sa = edge_index_a2sa[[1, 0]] |
|
edge_index_a2sa = edge_index_a2sa[:, ~mask_s_sa[edge_index_a2sa[0]] & mask_s[edge_index_a2sa[0]]] |
|
|
|
|
|
if seq_mask is not None: |
|
edge_mask = seq_mask[edge_index_a2sa[1]] |
|
edge_mask = torch.gather(edge_mask, dim=1, index=edge_index_a2sa[0, :, None] % num_agent)[:, 0] |
|
edge_index_a2sa = edge_index_a2sa[:, edge_mask] |
|
|
|
if seq_index is None: |
|
seq_index = torch.zeros(num_agent, device=pos_a.device).long() |
|
if seq_index.dim() == 1: |
|
seq_index = seq_index[:, None].repeat(1, num_step) |
|
seq_index = seq_index.transpose(0, 1).reshape(-1) |
|
        assert seq_index.shape[0] == pos_s.shape[0], f"Inconsistent lengths {seq_index.shape[0]} and {pos_s.shape[0]}!"
|
|
|
|
|
all_index = torch.arange(pos_s.shape[0], device=pos_a.device).long() |
|
sa_index = all_index[mask_s_sa] |
|
edge_index_a2sa[1] = sa_index[edge_index_a2sa[1]] |
|
|
|
|
|
        if int(os.getenv('PLOT_EDGE_INFERENCE', 0)) and not is_training:
|
num_agent, num_step, _ = pos_a.shape |
|
|
|
|
|
plot_interact_edge(edge_index_a2sa, data['scenario_id'], torch.tensor([num_agent - 1]), 1, num_step, |
|
f"interact_a2sa_edge_map_infer_{plot_kwargs['tag']}", **plot_kwargs) |
|
|
|
rel_pos_a2sa = pos_s[edge_index_a2sa[0]] - pos_s[edge_index_a2sa[1]] |
|
rel_head_a2sa = wrap_angle(head_s[edge_index_a2sa[0]] - head_s[edge_index_a2sa[1]]) |
|
|
|
if mode == 'insert': |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
r_a2sa = torch.stack( |
|
[torch.norm(rel_pos_a2sa[:, :2], p=2, dim=-1), |
|
angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_a2sa[1]], nbr_vector=rel_pos_a2sa[:, :2]), |
|
rel_head_a2sa], dim=-1) |
|
|
|
r_a2sa = self.r_a2sa_emb(continuous_inputs=r_a2sa, categorical_embs=None) |
|
|
|
elif mode == 'feature': |
|
|
|
r_a2sa = torch.stack( |
|
[torch.norm(rel_pos_a2sa[:, :2], p=2, dim=-1), |
|
angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_a2sa[1]], nbr_vector=rel_pos_a2sa[:, :2]), |
|
rel_head_a2sa], dim=-1) |
|
r_a2sa = self.r_a2a_emb(continuous_inputs=r_a2sa, categorical_embs=None) |
|
|
|
else: |
|
raise ValueError(f"Unsupport mode {mode}.") |
|
|
|
return edge_index_a2sa, r_a2sa |
|
|
|
def _build_map2sa_edge(self, data, pos_a, head_a, head_vector_a, batch_s, batch_pl, |
|
mask_sa, inference_mask=None, r=None, max_num_neighbors=32, mode: Literal['insert', 'feature']='feature'): |
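        # Edges from map (pt) tokens to seed agents within radius r.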
|
|
|
_, num_step, _ = pos_a.shape |
|
|
|
mask_pl2sa = torch.ones_like(mask_sa).bool() |
|
if inference_mask is not None: |
|
mask_pl2sa = mask_pl2sa & inference_mask |
|
mask_pl2sa = mask_pl2sa.transpose(0, 1).reshape(-1) |
|
mask_s_sa = mask_sa.transpose(0, 1).reshape(-1) |
|
|
|
pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim) |
|
head_s = head_a.transpose(0, 1).reshape(-1) |
|
head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2) |
|
|
|
ori_pos_pl = data['pt_token']['position'][:, :self.input_dim].contiguous() |
|
ori_orient_pl = data['pt_token']['orientation'].contiguous() |
|
pos_pl = ori_pos_pl.repeat(num_step, 1) |
|
orient_pl = ori_orient_pl.repeat(num_step) |
|
|
|
|
|
assert r is not None, "r needs to be specified!" |
|
|
|
|
|
edge_index_pl2sa = radius(x=pos_pl[:, :2], y=pos_s[mask_s_sa, :2], r=r, |
|
batch_x=batch_pl, batch_y=batch_s[mask_s_sa], max_num_neighbors=max_num_neighbors) |
|
edge_index_pl2sa = edge_index_pl2sa[[1, 0]] |
|
edge_index_pl2sa = edge_index_pl2sa[:, mask_pl2sa[mask_s_sa][edge_index_pl2sa[1]]] |
|
|
|
|
|
all_index = torch.arange(pos_s.shape[0], device=pos_a.device).long() |
|
sa_index = all_index[mask_s_sa] |
|
edge_index_pl2sa[1] = sa_index[edge_index_pl2sa[1]] |
|
|
|
|
|
|
|
|
|
|
|
rel_pos_pl2sa = pos_pl[edge_index_pl2sa[0]] - pos_s[edge_index_pl2sa[1]] |
|
rel_orient_pl2sa = wrap_angle(orient_pl[edge_index_pl2sa[0]] - head_s[edge_index_pl2sa[1]]) |
|
|
|
r_pl2sa = torch.stack( |
|
[torch.norm(rel_pos_pl2sa[:, :2], p=2, dim=-1), |
|
angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_pl2sa[1]], nbr_vector=rel_pos_pl2sa[:, :2]), |
|
rel_orient_pl2sa], dim=-1) |
|
|
|
if mode == 'insert': |
|
r_pl2sa = self.r_pt2sa_emb(continuous_inputs=r_pl2sa, categorical_embs=None) |
|
elif mode == 'feature': |
|
r_pl2sa = self.r_pt2a_emb(continuous_inputs=r_pl2sa, categorical_embs=None) |
|
else: |
|
raise ValueError(f"Unsupport mode {mode}.") |
|
|
|
return edge_index_pl2sa, r_pl2sa |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_inputs(self, data: HeteroData) -> Dict[str, torch.Tensor]: |
|
|
|
pos_a = data['agent']['token_pos'].clone() |
|
head_a = data['agent']['token_heading'].clone() |
|
agent_token_index = data['agent']['token_idx'].clone() |
|
agent_state_index = data['agent']['state_idx'].clone() |
|
mask = data['agent']['raw_agent_valid_mask'].clone() |
|
|
|
agent_grid_token_idx = data['agent']['grid_token_idx'] |
|
agent_grid_offset_xy = data['agent']['grid_offset_xy'] |
|
agent_head_token_idx = data['agent']['heading_token_idx'] |
|
sort_indices = data['agent']['sort_indices'] |
|
|
|
next_token_index_gt = agent_token_index.roll(shifts=-1, dims=1) |
|
        next_state_index_gt = agent_state_index.roll(shifts=-1, dims=1)
|
|
|
|
|
bos_token_index = torch.nonzero(agent_state_index == self.enter_state) |
|
eos_token_index = torch.nonzero(agent_state_index == self.exit_state) |
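
        # A motion token is supervised only where the previous, current, and next steps
        # are all valid; BOS (enter) steps are force-enabled and EOS (exit) steps are
        # disabled below.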
|
|
|
|
|
next_token_eval_mask = mask.clone() |
|
next_token_eval_mask = next_token_eval_mask * next_token_eval_mask.roll(shifts=-1, dims=1) * next_token_eval_mask.roll(shifts=1, dims=1) |
|
for bos_token_index_ in bos_token_index: |
|
next_token_eval_mask[bos_token_index_[0], bos_token_index_[1] : bos_token_index_[1] + 1] = 1 |
|
next_token_eval_mask[bos_token_index_[0], bos_token_index_[1] + 1 : bos_token_index_[1] + 2] = \ |
|
mask[bos_token_index_[0], bos_token_index_[1] + 2 : bos_token_index_[1] + 3] |
|
next_token_eval_mask[eos_token_index[:, 0], eos_token_index[:, 1]] = 0 |
|
|
|
|
|
next_state_eval_mask = mask.clone() |
|
next_state_eval_mask = next_state_eval_mask * next_state_eval_mask.roll(shifts=-1, dims=1) * next_state_eval_mask.roll(shifts=1, dims=1) |
|
for bos_token_index_ in bos_token_index: |
|
next_state_eval_mask[bos_token_index_[0], :bos_token_index_[1]] = 0 |
|
next_state_eval_mask[bos_token_index_[0], bos_token_index_[1] : bos_token_index_[1] + 1] = 1 |
|
next_state_eval_mask[bos_token_index_[0], bos_token_index_[1] + 1 : bos_token_index_[1] + 2] = \ |
|
mask[bos_token_index_[0], bos_token_index_[1] + 2 : bos_token_index_[1] + 3] |
|
for eos_token_index_ in eos_token_index: |
|
next_state_eval_mask[eos_token_index_[0], eos_token_index_[1] + 1:] = 1 |
|
next_state_eval_mask[eos_token_index_[0], eos_token_index_[1] : eos_token_index_[1] + 1] = \ |
|
mask[eos_token_index_[0], eos_token_index_[1] - 1 : eos_token_index_[1]] |
|
|
|
|
|
next_token_eval_mask[:, 0] = mask[:, 0] * mask[:, 1] |
|
next_state_eval_mask[:, 0] = mask[:, 0] * mask[:, 1] |
|
next_token_eval_mask[:, -1] = 0 |
|
next_state_eval_mask[:, -1] = 0 |
|
|
|
if next_token_index_gt[next_token_eval_mask].min() < 0: |
|
            raise RuntimeError("Found invalid motion index.")
|
|
|
return {'token_pos': pos_a, |
|
'token_heading': head_a, |
|
'next_token_idx_gt': next_token_index_gt, |
|
'next_state_idx_gt': next_state_index_gt, |
|
'next_token_eval_mask': next_token_eval_mask, |
|
'raw_agent_valid_mask': data['agent']['raw_agent_valid_mask'], |
|
'state_token': agent_state_index, |
|
'grid_index': agent_grid_token_idx, |
|
} |
|
|
|
def _build_seq(self, device, data, num_agent, num_step, av_index, sort_indices): |
|
""" |
|
Args: |
|
            sort_indices (torch.Tensor): shape (num_agent, num_step)
|
""" |
|
ptr = data['agent']['ptr'] |
|
num_graph = len(ptr) - 1 |
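
        # seq_mask[s, t, :] is the attention mask of seed slot s at step t: a slot may
        # not attend to agents that appear at or after its own position in sort_indices.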
|
|
|
|
|
seq_mask = torch.ones(num_graph * self.num_seed_feature, num_step, num_agent + num_graph * self.num_seed_feature, device=device).bool() |
|
seq_mask[..., -num_graph * self.num_seed_feature:] = False |
|
for b in range(num_graph): |
|
batch_sort_indices = sort_indices[ptr[b] : ptr[b + 1]] |
|
for t in range(num_step): |
|
for s in range(self.num_seed_feature): |
|
                    seq_mask[b * self.num_seed_feature + s, t, batch_sort_indices[s:, t].flatten().long() + ptr[b]] = False
|
if self.seed_attn_to_av: |
|
seq_mask[..., av_index] = True |
|
seq_mask = seq_mask.transpose(0, 1).reshape(-1, num_agent + num_graph * self.num_seed_feature) |
|
|
|
seq_index = torch.cat([torch.zeros(num_agent), (torch.arange(self.num_seed_feature) + 1).repeat(num_graph)]).to(device) |
|
seq_index = seq_index[:, None].repeat(1, num_step) |
|
|
|
for b in range(num_graph): |
|
batch_sort_indices = sort_indices[ptr[b] : ptr[b + 1]] |
|
for t in range(num_step): |
|
for s in range(self.num_seed_feature): |
|
seq_index[batch_sort_indices[s : s + 1, t].flatten().long() + ptr[b], t] = s + 1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
seq_index[av_index] = 0 |
|
|
|
return seq_mask, seq_index |
|
|
|
def _build_occ_gt(self, data, seq_mask, pos_rel_index_gt, pos_rel_index_gt_seed=None, mask_seed=None, |
|
edge_index=None, mode='edge_index'): |
|
""" |
|
Args: |
|
            seq_mask (torch.Tensor): shape (num_step * num_graph * num_seed_feature, num_agent + num_graph * num_seed_feature)
|
pos_rel_index_gt (torch.Tensor): shape (num_agent, num_step) |
|
pos_rel_index_gt_seed (torch.Tensor): shape (num_seed, num_step) |
|
""" |
|
num_agent = data['agent']['state_idx'].shape[0] + data.num_graphs * self.num_seed_feature |
|
num_step = data['agent']['state_idx'].shape[1] |
|
data['agent']['agent_occ'] = torch.zeros(data.num_graphs * self.num_seed_feature, num_step, self.attr_tokenizer.grid_size, |
|
device=data['agent']['state_idx'].device).long() |
|
data['agent']['map_occ'] = torch.zeros(data.num_graphs, num_step, self.attr_tokenizer.grid_size, |
|
device=data['agent']['state_idx'].device).long() |
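
        # 'edge_index' mode rasterizes agent occupancy from actual seed->agent edges;
        # otherwise it is built from the sequence mask, with each seed's own target
        # cell marked -1.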
|
|
|
if mode == 'edge_index': |
|
|
|
            assert edge_index is not None, "edge_index input is required in 'edge_index' mode!"
|
for src_index in torch.unique(edge_index[1]): |
|
|
|
src_row = src_index % num_agent - (num_agent - data.num_graphs * self.num_seed_feature) |
|
src_col = src_index // num_agent |
|
|
|
tgt_indexes = edge_index[0, edge_index[1] == src_index] |
|
tgt_rows = tgt_indexes % num_agent |
|
tgt_cols = tgt_indexes // num_agent |
|
assert tgt_rows.max() < num_agent - data.num_graphs * self.num_seed_feature, f"Invalid {tgt_rows}" |
|
assert torch.unique(tgt_cols).shape[0] == 1 and torch.unique(tgt_cols)[0] == src_col |
|
data['agent']['agent_occ'][src_row, src_col, pos_rel_index_gt[tgt_rows, tgt_cols]] = 1 |
|
|
|
else: |
|
|
|
seq_mask = seq_mask.reshape(num_step, self.num_seed_feature, -1).transpose(0, 1)[..., :-self.num_seed_feature] |
|
for s in range(self.num_seed_feature): |
|
for t in range(num_step): |
|
index = pos_rel_index_gt[seq_mask[s, t], t] |
|
data['agent']['agent_occ'][s, t, index[index != -1]] = 1 |
|
if t > 0 and s < pos_rel_index_gt_seed.shape[0] and mask_seed[s, t - 1]: |
|
data['agent']['agent_occ'][s, t, pos_rel_index_gt_seed[s, t - 1]] = -1 |
|
|
|
ptr = data['pt_token']['ptr'] |
|
pt_grid_token_idx = data['agent']['pt_grid_token_idx'] |
|
for b in range(data.num_graphs): |
|
batch_pt_grid_token_idx = pt_grid_token_idx[:, ptr[b] : ptr[b + 1]] |
|
for t in range(num_step): |
|
data['agent']['map_occ'][b, t, batch_pt_grid_token_idx[t][batch_pt_grid_token_idx[t] != -1]] = 1 |
|
data['agent']['map_occ'] = data['agent']['map_occ'].repeat_interleave(repeats=self.num_seed_feature, dim=0) |
|
|
|
def forward(self, |
|
data: HeteroData, |
|
map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]: |
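
        # Training-time decoding: embed agent tokens, run temporal / map-to-agent /
        # agent-to-agent attention, decode next motion token and state, then decode the
        # seed-agent insertion attributes.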
|
|
|
pos_a = data['agent']['token_pos'].clone() |
|
head_a = data['agent']['token_heading'].clone() |
|
num_agent, num_step, traj_dim = pos_a.shape |
|
agent_shape = data['agent']['shape'][:, self.num_historical_steps - 1].clone() |
|
agent_token_index = data['agent']['token_idx'].clone() |
|
agent_state_index = data['agent']['state_idx'].clone() |
|
agent_type_index = data['agent']['type'].clone() |
|
|
|
av_index = data['agent']['av_index'].long() |
|
ego_pos = pos_a[av_index] |
|
ego_head = head_a[av_index] |
|
|
|
_, head_vector_a = self._build_vector_a(pos_a, head_a, agent_state_index) |
|
|
|
agent_grid_token_idx = data['agent']['grid_token_idx'] |
|
agent_grid_offset_xy = data['agent']['grid_offset_xy'] |
|
agent_head_token_idx = data['agent']['heading_token_idx'] |
|
agent_pos_xy = data['agent']['pos_xy'] |
|
agent_heading_theta = data['agent']['heading_theta'] |
|
sort_indices = data['agent']['sort_indices'] |
|
|
|
device = pos_a.device |
|
|
|
feat_a = self._agent_token_embedding(data, |
|
agent_token_index, |
|
agent_state_index, |
|
agent_grid_token_idx, |
|
pos_a, |
|
head_a, |
|
av_index=av_index) |
|
|
|
raw_feat_a = feat_a[:-data.num_graphs * self.num_seed_feature].clone() |
|
raw_feat_seed = feat_a[-data.num_graphs * self.num_seed_feature:].clone() |
|
|
|
|
|
mask = data['agent']['raw_agent_valid_mask'].clone() |
|
temporal_mask = mask.clone() |
|
interact_mask = mask.clone() |
|
|
|
is_bos = agent_state_index == self.enter_state |
|
is_eos = agent_state_index == self.exit_state |
|
bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0)) |
|
eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_step - 1)) |
|
|
|
temporal_mask = torch.ones_like(mask) |
|
motion_mask = torch.arange(mask.shape[1]).expand(mask.shape[0], -1).to(device) |
|
motion_mask = (motion_mask > bos_index[:, None]) & (motion_mask <= eos_index[:, None]) |
|
temporal_mask[motion_mask] = mask[motion_mask] |
|
temporal_mask = torch.cat([temporal_mask, torch.ones(data.num_graphs * self.num_seed_feature, *temporal_mask.shape[1:], device=device)]).bool() |
|
|
|
interact_mask[agent_state_index == self.enter_state] = True |
|
interact_mask = torch.cat([interact_mask, torch.ones(data.num_graphs * self.num_seed_feature, *interact_mask.shape[1:], device=device)]).bool() |
|
|
|
pos_a_p, head_a_p, state_a_p, head_vector_a_p, grid_index_a_p, pad_mask = \ |
|
self._pad_feat(data.num_graphs, av_index, pos_a, head_a, agent_state_index, head_vector_a, agent_grid_token_idx) |
|
edge_index_t, r_t = self._build_temporal_edge(data, pos_a_p, head_a_p, state_a_p, head_vector_a_p, temporal_mask) |
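
        # Time-major batch vector: every timestep is offset into its own graph so that
        # radius queries only connect nodes within the same scene and timestep.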
|
|
|
|
|
batch_s = torch.cat([ |
|
torch.cat([data['agent']['batch'], torch.arange(data.num_graphs, device=device |
|
).repeat_interleave(repeats=self.num_seed_feature, dim=0)], dim=0) |
|
+ data.num_graphs * t for t in range(num_step) |
|
], dim=0) |
|
batch_pl = torch.cat([data['pt_token']['batch'] + data.num_graphs * t for t in range(num_step)], dim=0) |
|
|
|
seq_mask, seq_index = self._build_seq(device, data, num_agent, num_step, av_index, sort_indices) |
|
plot_kwargs = dict(is_bos=agent_state_index == self.enter_state) |
|
edge_index_a2a, r_a2a, (na2a, na2sa) = self._build_interaction_edge(data, pos_a_p, head_a_p, state_a_p, head_vector_a_p, batch_s, |
|
interact_mask, pad_mask=pad_mask, av_index=av_index, |
|
seq_mask=seq_mask, seq_index=seq_index, grid_index_a=grid_index_a_p, **plot_kwargs) |
|
|
|
edge_index_pl2a, r_pl2a, (npl2a, npl2sa) = self._build_map2agent_edge(data, pos_a_p, head_a_p, state_a_p, head_vector_a_p, batch_s, batch_pl, |
|
interact_mask, pad_mask=pad_mask, av_index=av_index) |
|
interact_mask = interact_mask[:-data.num_graphs * self.num_seed_feature] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for i in range(self.num_layers): |
|
|
|
feat_a = feat_a.reshape(-1, self.hidden_dim) |
|
feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t) |
|
|
|
feat_a = feat_a.reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim) |
|
feat_a = self.pt2a_attn_layers[i](( |
|
map_enc['x_pt'].repeat_interleave(repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape( |
|
-1, self.hidden_dim), feat_a), r_pl2a[:npl2a], edge_index_pl2a[:, :npl2a]) |
|
|
|
feat_a = self.a2a_attn_layers[i](feat_a, r_a2a[:na2a], edge_index_a2a[:, :na2a]) |
|
feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1) |
|
|
|
feat_ea = feat_a[:-data.num_graphs * self.num_seed_feature] |
|
|
|
|
|
next_token_prob = self.token_predict_head(feat_ea) |
|
next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1) |
|
_, next_token_idx = torch.topk(next_token_prob_softmax, k=10, dim=-1) |
|
|
|
next_token_index_gt = agent_token_index.roll(shifts=-1, dims=1) |
|
|
|
|
|
next_state_prob = self.state_predict_head(feat_ea) |
|
next_state_idx = next_state_prob.softmax(dim=-1).argmax(dim=-1, keepdim=True) |
|
|
|
next_state_index_gt = agent_state_index.roll(shifts=-1, dims=1) |
|
|
|
|
|
grid_agent_occ_gt_seed = grid_pt_occ_gt_seed = None |
|
if self.use_grid_token: |
|
self._build_occ_gt(data, seq_mask, agent_grid_token_idx.long(), edge_index=edge_index_a2a[:, -na2sa:], mode='edge_index') |
|
grid_agent_occ_gt_seed = data['agent']['agent_occ'] |
|
grid_pt_occ_gt_seed = data['agent']['map_occ'] |
|
|
|
if self.use_grid_token: |
|
occ_embed_a = self.seed_agent_occ_embed(grid_agent_occ_gt_seed.transpose(0, 1).reshape(-1, self.grid_size).float()) |
|
|
|
edge_index_occ2sa_src = torch.arange(feat_a.shape[0] * feat_a.shape[1], device=device).long() |
|
edge_index_occ2sa_src = edge_index_occ2sa_src[~pad_mask.transpose(0, 1).reshape(-1)] |
|
edge_index_occ2sa_tgt = torch.arange(occ_embed_a.shape[0], device=device).long() |
|
edge_index_occ2sa = torch.stack([edge_index_occ2sa_tgt, edge_index_occ2sa_src], dim=0) |
|
|
|
feat_sa = torch.cat([raw_feat_a, raw_feat_seed]) |
|
|
|
for i in range(self.seed_layers): |
|
|
|
|
|
|
|
|
|
feat_sa = feat_sa.reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim) |
|
if self.use_grid_token: |
|
feat_sa = self.occ2sa_attn_layers[i]((occ_embed_a, feat_sa), None, edge_index_occ2sa) |
|
feat_sa = self.pt2sa_attn_layers[i](( |
|
map_enc['x_pt'].repeat_interleave(repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape( |
|
-1, self.hidden_dim), feat_sa), r_pl2a[-npl2sa:], edge_index_pl2a[:, -npl2sa:]) |
|
feat_sa = self.a2sa_attn_layers[i](feat_sa, r_a2a[-na2sa:], edge_index_a2a[:, -na2sa:]) |
|
feat_sa = feat_sa.reshape(num_step, -1, self.hidden_dim).transpose(0, 1) |
|
|
|
feat_seed = feat_sa[-data.num_graphs * self.num_seed_feature:] |
|
|
|
|
|
next_state_prob_seed = self.seed_state_predict_head(feat_seed) |
|
raw_next_state_prob_seed = next_state_prob_seed.clone() |
|
next_state_idx_seed = next_state_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True) |
|
|
|
next_type_prob_seed = self.seed_type_predict_head(feat_seed) |
|
next_type_idx_seed = next_type_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True) |
|
|
|
next_type_index_gt = agent_type_index[:, None].repeat(1, num_step).long() |
|
|
|
next_shape_seed = self.seed_shape_predict_head(feat_seed) |
|
|
|
next_shape_gt = agent_shape[:, None].repeat(1, num_step, 1).float() |
|
|
|
if self.use_grid_token: |
|
next_pos_rel_prob_seed = self.seed_pos_rel_token_predict_head(feat_seed) |
|
next_pos_rel_idx_seed = next_pos_rel_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True) |
|
else: |
|
next_pos_rel_prob_seed = self.seed_pos_rel_xy_predict_head(feat_seed) |
|
next_pos_rel_xy_seed = torch.tanh(next_pos_rel_prob_seed) |
|
|
|
next_pos_rel_index_gt = agent_grid_token_idx.long() |
|
next_pos_rel_xy_gt = agent_pos_xy.float() / self.pl2seed_radius |
|
|
|
|
|
if self.use_grid_token: |
|
neighbor_agent_grid_index_gt = grid_index_a_p.transpose(0, 1).reshape(-1)[edge_index_a2a[0, -na2sa:]] |
|
neighbor_pt_grid_index_gt = data['agent']['pt_grid_token_idx'].reshape(-1)[edge_index_pl2a[0, -npl2sa:]] |
|
neighbor_agent_grid_idx = self.grid_index_head(r_a2a[-na2sa:]) |
|
neighbor_pt_grid_idx = self.grid_index_head(r_pl2a[-npl2sa:]) |
|
neighbor_agent_grid_index_eval_mask = torch.zeros_like(neighbor_agent_grid_index_gt).bool() |
|
neighbor_pt_grid_index_eval_mask = torch.zeros_like(neighbor_pt_grid_index_gt).bool() |
|
neighbor_agent_grid_index_eval_mask[torch.randperm(neighbor_agent_grid_index_gt.shape[0])[:180]] = True |
|
neighbor_pt_grid_index_eval_mask[torch.randperm(neighbor_pt_grid_index_gt.shape[0])[:600]] = True |
|
|
|
|
|
grid_agent_occ_seed = grid_pt_occ_seed = grid_agent_occ_eval_mask_seed = grid_pt_occ_eval_mask_seed = None |
|
if self.predict_occ: |
|
|
|
grid_agent_occ_seed = self.grid_agent_occ_head(feat_seed) |
|
grid_pt_occ_seed = self.grid_pt_occ_head(feat_seed) |
|
|
|
|
|
batch_s = torch.cat([data['agent']['batch'] + data.num_graphs * t for t in range(num_step)], dim=0) |
|
batch_pl = torch.cat([data['pt_token']['batch'] + data.num_graphs * t for t in range(num_step)], dim=0) |
|
|
|
mask_sa = torch.zeros_like(agent_state_index).bool() |
|
for t in range(mask_sa.shape[1]): |
|
            available_rows = ((agent_state_index[:, t] != self.invalid_state) &

                              (agent_grid_token_idx[:, t] != -1)).nonzero()[..., 0]

            mask_sa[available_rows[torch.randperm(available_rows.shape[0])[:data.num_graphs * 10]], t] = True
|
mask_sa[agent_state_index == self.enter_state] = True |
|
mask_sa[:, 0] = False |
|
mask_sa[av_index] = False |
|
|
|
state_sa = torch.full_like(agent_state_index, self.invalid_state).long() |
|
state_sa[mask_sa] = self.enter_state |
|
|
|
sa_indices = torch.nonzero(mask_sa) |
|
pos_sa = pos_a.clone() |
|
head_sa = head_a.clone() |
|
expanded_av_index = av_index.repeat_interleave(repeats=data['batch_size_a'], dim=0) |
|
head_sa[sa_indices[:, 0], sa_indices[:, 1]] = head_a[expanded_av_index[sa_indices[:, 0]], sa_indices[:, 1]] |
|
|
|
motion_vector_sa, head_vector_sa = self._build_vector_a(pos_a, head_sa, state_sa) |
|
motion_vector_sa[mask_sa] = self.motion_gap |
|
|
|
offset_pos = pos_a - data['ego_pos'].repeat_interleave(repeats=data['batch_size_a'], dim=0) |
|
agent_grid_emb = self.grid_token_emb[agent_grid_token_idx.long()] |
|
feat_sa, _ = self._build_agent_feature(num_step, pos_a.device, |
|
motion_vector_sa, |
|
head_vector_sa, |
|
agent_grid_emb=agent_grid_emb, |
|
offset_pos=offset_pos, |
|
type=next_type_index_gt.long(), |
|
shape=next_shape_gt, |
|
state=state_sa, |
|
n=num_agent) |
|
feat_sa[~mask_sa] = raw_feat_a[~mask_sa].clone() |
|
|
|
edge_index_a2sa, r_a2sa = self._build_a2sa_edge(data, pos_a, head_sa, head_vector_sa, batch_s, |
|
interact_mask, mask_sa=mask_sa, r=self.a2sa_radius) |
|
edge_index_pl2sa, r_pl2sa = self._build_map2sa_edge(data, pos_a, head_sa, head_vector_sa, batch_s, batch_pl, |
|
mask_sa=mask_sa, r=self.pl2sa_radius) |
|
|
|
|
|
global_index = set(torch.nonzero(mask_sa.transpose(0, 1).reshape(-1).int())[:, 0].tolist()) |
|
a2sa_index = set(edge_index_a2sa[1].tolist()) |
|
pl2sa_index = set(edge_index_pl2sa[1].tolist()) |
|
assert a2sa_index.issubset(global_index) and pl2sa_index.issubset(global_index), "Invalid index!" |
|
|
|
select_mask = torch.zeros_like(mask_sa.view(-1)).bool() |
|
select_mask[torch.unique(edge_index_a2sa[1])] = True |
|
select_mask[torch.unique(edge_index_pl2sa[1])] = True |
|
mask_sa[~select_mask.reshape(num_step, -1).transpose(0, 1)] = False |
|
|
|
for i in range(self.seed_layers): |
|
|
|
feat_sa = feat_sa.transpose(0, 1).reshape(-1, self.hidden_dim) |
|
feat_sa = self.pt2a_attn_layers[i](( |
|
map_enc['x_pt'].repeat_interleave(repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape( |
|
-1, self.hidden_dim), feat_sa), r_pl2sa, edge_index_pl2sa) |
|
|
|
feat_sa = self.a2a_attn_layers[i](feat_sa, r_a2sa, edge_index_a2sa) |
|
feat_sa = feat_sa.reshape(num_step, -1, self.hidden_dim).transpose(0, 1) |
|
|
|
if self.use_head_token: |
|
next_head_rel_theta_seed = None |
|
next_head_rel_prob_seed = self.seed_heading_rel_token_predict_head(feat_sa) |
|
next_head_rel_idx_seed = next_head_rel_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True) |
|
else: |
|
next_head_rel_prob_seed = None |
|
next_head_rel_theta_seed = self.seed_heading_rel_theta_predict_head(feat_sa) |
|
next_head_rel_theta_seed = torch.tanh(next_head_rel_theta_seed)[..., 0] |
|
|
|
next_head_rel_index_gt_seed = agent_head_token_idx.long() |
|
next_head_rel_theta_gt_seed = agent_heading_theta.float() / torch.pi |
|
|
|
next_offset_xy_seed = None |
|
if self.use_grid_token: |
|
next_offset_xy_seed = self.seed_offset_xy_predict_head(feat_sa) |
|
next_offset_xy_seed = torch.tanh(next_offset_xy_seed) * 2 |
|
|
|
next_offset_xy_gt_seed = agent_grid_offset_xy.float() |
|
|
|
|
|
bos_token_index = torch.nonzero(agent_state_index == self.enter_state) |
|
eos_token_index = torch.nonzero(agent_state_index == self.exit_state) |
|
|
|
|
|
next_token_eval_mask = mask.clone() |
|
next_token_eval_mask = next_token_eval_mask * next_token_eval_mask.roll(shifts=-1, dims=1) * next_token_eval_mask.roll(shifts=1, dims=1) |
|
for bos_token_index_ in bos_token_index: |
|
next_token_eval_mask[bos_token_index_[0], bos_token_index_[1] : bos_token_index_[1] + 1] = 1 |
|
next_token_eval_mask[bos_token_index_[0], bos_token_index_[1] + 1 : bos_token_index_[1] + 2] = \ |
|
mask[bos_token_index_[0], bos_token_index_[1] + 2 : bos_token_index_[1] + 3] |
|
next_token_eval_mask[eos_token_index[:, 0], eos_token_index[:, 1]] = 0 |
|
|
|
|
|
next_state_eval_mask = mask.clone() |
|
next_state_eval_mask = next_state_eval_mask * next_state_eval_mask.roll(shifts=-1, dims=1) * next_state_eval_mask.roll(shifts=1, dims=1) |
|
for bos_token_index_ in bos_token_index: |
|
next_state_eval_mask[bos_token_index_[0], :bos_token_index_[1]] = 0 |
|
next_state_eval_mask[bos_token_index_[0], bos_token_index_[1] : bos_token_index_[1] + 1] = 1 |
|
next_state_eval_mask[bos_token_index_[0], bos_token_index_[1] + 1 : bos_token_index_[1] + 2] = \ |
|
mask[bos_token_index_[0], bos_token_index_[1] + 2 : bos_token_index_[1] + 3] |
|
for eos_token_index_ in eos_token_index: |
|
next_state_eval_mask[eos_token_index_[0], eos_token_index_[1] + 1:] = 1 |
|
next_state_eval_mask[eos_token_index_[0], eos_token_index_[1] : eos_token_index_[1] + 1] = \ |
|
mask[eos_token_index_[0], eos_token_index_[1] - 1 : eos_token_index_[1]] |
|
|
|
next_state_eval_mask_seed = torch.ones_like(next_state_idx_seed[..., 0]) |
|
|
|
|
|
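        # Boundary handling: the first step needs a valid successor, the last step has no
        # target, and the seed state at step 0 is never supervised.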
next_token_eval_mask[:, 0] = mask[:, 0] * mask[:, 1] |
|
next_state_eval_mask[:, 0] = mask[:, 0] * mask[:, 1] |
|
next_token_eval_mask[:, -1] = 0 |
|
next_state_eval_mask[:, -1] = 0 |
|
next_state_eval_mask_seed[:, 0] = 0 |
|
|
|
|
|
if (next_token_index_gt[next_token_eval_mask] < 0).any(): |
|
            raise RuntimeError("Found invalid motion index.")
|
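        # Greedy matching of seed features to ground-truth new agents: per graph, the
        # first num_seed_feature agents (in sorted order) become the regression targets.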
pred_indices = [] |
|
gt_indices = [] |
|
agent_ptr = data['agent']['ptr'] |
|
num_seed_gt = 0 |
|
for b in range(data.num_graphs): |
|
batch_sort_indices = sort_indices[agent_ptr[b] : agent_ptr[b + 1]] |
|
batch_num_seed_gt = min(self.num_seed_feature, batch_sort_indices.shape[0]) |
|
num_seed_gt += batch_num_seed_gt |
|
pred_indices.append((torch.arange(batch_num_seed_gt, device=device) + b * self.num_seed_feature |
|
)[:, None, None].repeat(1, num_step, 1).long()) |
|
gt_indices.append(batch_sort_indices[:batch_num_seed_gt] + agent_ptr[b]) |
|
pred_indices = torch.concat(pred_indices) |
|
gt_indices = torch.concat(gt_indices) |
|
|
|
n = pred_indices.shape[0] |
|
|
|
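        # Pad the matched indices with the remaining unmatched seed slots so the gathers
        # below cover every slot; unmatched slots receive an 'invalid' gt state.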
res_pred_indices = [] |
|
for t in range(next_state_idx_seed.shape[1]): |
|
indices_t = torch.arange(next_state_idx_seed.shape[0]).to(device) |
|
selected_pred_mask = torch.zeros_like(indices_t) |
|
selected_pred_mask[pred_indices[:, t]] = 1 |
|
res_pred_indices.append(indices_t[~selected_pred_mask.bool()]) |
|
res_pred_indices = torch.stack(res_pred_indices, dim=1) |
|
padded_pred_indices = torch.concat([pred_indices, res_pred_indices[..., None]]) |
|
next_state_idx_seed = torch.gather(next_state_idx_seed, dim=0, index=padded_pred_indices) |
|
next_state_prob_seed = torch.gather(next_state_prob_seed, dim=0, index=padded_pred_indices.expand( |
|
-1, -1, next_state_prob_seed.shape[-1])) |
|
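        # Relabel ground-truth seed states into the binary {invalid, enter} scheme.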
next_state_index_gt_seed = torch.gather(agent_state_index, dim=0, index=gt_indices) |
|
next_state_index_gt_seed = torch.concat([next_state_index_gt_seed, |
|
torch.zeros((next_state_prob_seed.shape[0] - next_state_index_gt_seed.shape[0], next_state_index_gt_seed.shape[1]), device=device)]).long() |
|
seed_enter_mask = next_state_index_gt_seed == self.enter_state |
|
next_state_index_gt_seed = torch.full(next_state_index_gt_seed.shape, self.seed_state_type.index('invalid'), device=device) |
|
next_state_index_gt_seed[seed_enter_mask] = self.seed_state_type.index('enter') |
|
|
|
next_type_idx_seed = torch.gather(next_type_idx_seed, dim=0, index=pred_indices) |
|
next_type_prob_seed = torch.gather(next_type_prob_seed, dim=0, index=pred_indices.expand( |
|
-1, -1, next_type_prob_seed.shape[-1])) |
|
next_type_index_gt_seed = torch.gather(next_type_index_gt, dim=0, index=gt_indices) |
|
|
|
if self.use_grid_token: |
|
next_pos_rel_xy_seed = None |
|
next_pos_rel_prob_seed = torch.gather(next_pos_rel_prob_seed, dim=0, index=pred_indices.expand( |
|
-1, -1, next_pos_rel_prob_seed.shape[-1])) |
|
else: |
|
next_pos_rel_prob_seed = None |
|
next_pos_rel_xy_seed = torch.gather(next_pos_rel_xy_seed, dim=0, index=pred_indices.expand( |
|
-1, -1, next_pos_rel_xy_seed.shape[-1])) |
|
next_pos_rel_index_gt_seed = torch.gather(next_pos_rel_index_gt, dim=0, index=gt_indices) |
|
next_pos_rel_xy_gt_seed = torch.gather(next_pos_rel_xy_gt, dim=0, index=gt_indices[..., None].expand( |
|
-1, -1, next_pos_rel_xy_gt.shape[-1])) |
|
|
|
next_shape_seed = torch.gather(next_shape_seed, dim=0, index=pred_indices.expand( |
|
-1, -1, next_shape_seed.shape[-1])) |
|
next_shape_gt_seed = torch.gather(next_shape_gt, dim=0, index=gt_indices[..., None].expand( |
|
-1, -1, next_shape_gt.shape[-1])) |
|
|
|
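        # Attributes (type/shape/position) are only supervised on matched slots whose gt
        # state is 'enter', excluding step 0 and the ego-centered grid cell.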
next_attr_eval_mask_seed = seed_enter_mask[:n] |
|
next_attr_eval_mask_seed[:, 0] = False |
|
next_attr_eval_mask_seed[next_pos_rel_index_gt_seed == self.grid_size // 2] = False |
|
|
|
next_state_eval_mask[av_index] = 0 |
|
|
|
if (torch.any(next_type_index_gt_seed[next_attr_eval_mask_seed] == AGENT_TYPE.index('seed')) \ |
|
or torch.any(torch.all(next_shape_gt_seed[next_attr_eval_mask_seed] == self.invalid_shape_value, dim=-1)) \ |
|
or torch.any(next_pos_rel_index_gt_seed[next_attr_eval_mask_seed] < 0)) and num_seed_gt > 0: |
|
raise ValueError(f"Found invalid gt values in scenario {data['scenario_id'][0]}.") |
|
|
|
next_state_index_gt[next_state_index_gt == self.exit_state] = self.valid_state_type.index('exit') |
|
|
|
|
|
if self.predict_occ: |
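            # (The occupancy branch computes grid_agent_occ_seed / grid_pt_occ_seed and
            # their gt counterparts consumed by the return dict below.)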
|
|
grid_occ_eval_mask_seed = torch.ones_like(grid_agent_occ_seed).bool() |
|
grid_occ_eval_mask_seed[:, 0] = False |
|
grid_occ_eval_mask_seed[..., self.grid_size // 2] = False |
|
grid_agent_occ_eval_mask_seed = grid_pt_occ_eval_mask_seed = grid_occ_eval_mask_seed |
|
|
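        # Mark which seed slots correspond to real inserted agents; -1 elsewhere.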
target_indices = pred_indices.clone() |
|
target_indices[~next_attr_eval_mask_seed] = -1 |
|
|
|
return {'x_a': feat_a, |
|
'ego_pos': ego_pos, |
|
|
|
'next_token_idx': next_token_idx, |
|
'next_token_prob': next_token_prob, |
|
'next_token_idx_gt': next_token_index_gt, |
|
'next_token_eval_mask': next_token_eval_mask.bool(), |
|
|
|
'next_state_idx': next_state_idx, |
|
'next_state_prob': next_state_prob, |
|
'next_state_idx_gt': next_state_index_gt, |
|
'next_state_eval_mask': next_state_eval_mask.bool(), |
|
|
|
'next_state_idx_seed': next_state_idx_seed, |
|
'next_state_prob_seed': next_state_prob_seed, |
|
'next_state_idx_gt_seed': next_state_index_gt_seed, |
|
|
|
'next_type_idx_seed': next_type_idx_seed, |
|
'next_type_prob_seed': next_type_prob_seed, |
|
'next_type_idx_gt_seed': next_type_index_gt_seed, |
|
|
|
'next_pos_rel_prob_seed': next_pos_rel_prob_seed, |
|
'next_pos_rel_index_gt_seed': next_pos_rel_index_gt_seed, |
|
'next_pos_rel_xy_seed': next_pos_rel_xy_seed, |
|
'next_pos_rel_xy_gt_seed': next_pos_rel_xy_gt_seed, |
|
'next_head_rel_prob_seed': next_head_rel_prob_seed, |
|
'next_head_rel_index_gt_seed': next_head_rel_index_gt_seed, |
|
'next_head_rel_theta_seed': next_head_rel_theta_seed, |
|
'next_head_rel_theta_gt_seed': next_head_rel_theta_gt_seed, |
|
'next_offset_xy_seed': next_offset_xy_seed, |
|
'next_offset_xy_gt_seed': next_offset_xy_gt_seed, |
|
'next_shape_seed': next_shape_seed, |
|
'next_shape_gt_seed': next_shape_gt_seed, |
|
|
|
'grid_agent_occ_seed': grid_agent_occ_seed, |
|
'grid_pt_occ_seed': grid_pt_occ_seed, |
|
'grid_agent_occ_gt_seed': grid_agent_occ_gt_seed, |
|
'grid_pt_occ_gt_seed': grid_pt_occ_gt_seed, |
|
'neighbor_agent_grid_idx': neighbor_agent_grid_idx |
|
if self.use_grid_token else None, |
|
'neighbor_pt_grid_idx': neighbor_pt_grid_idx |
|
if self.use_grid_token else None, |
|
'neighbor_agent_grid_index_gt': neighbor_agent_grid_index_gt |
|
if self.use_grid_token else None, |
|
'neighbor_pt_grid_index_gt': neighbor_pt_grid_index_gt |
|
if self.use_grid_token else None, |
|
|
|
'target_indices': target_indices[..., 0], |
|
'raw_next_state_prob_seed': raw_next_state_prob_seed, |
|
|
|
'next_state_eval_mask_seed': next_state_eval_mask_seed.bool(), |
|
'next_attr_eval_mask_seed': next_attr_eval_mask_seed.bool(), |
|
'next_head_eval_mask_seed': mask_sa.bool(), |
|
'grid_agent_occ_eval_mask_seed': grid_agent_occ_eval_mask_seed |
|
if self.use_grid_token else None, |
|
'grid_pt_occ_eval_mask_seed': grid_pt_occ_eval_mask_seed |
|
if self.use_grid_token else None, |
|
'neighbor_agent_grid_index_eval_mask': neighbor_agent_grid_index_eval_mask.bool() |
|
if self.use_grid_token else None, |
|
'neighbor_pt_grid_index_eval_mask': neighbor_pt_grid_index_eval_mask.bool() |
|
if self.use_grid_token else None, |
|
} |
|
|
|
def inference(self, |
|
data: HeteroData, |
|
map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]: |
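
        """Autoregressive rollout: decode motion/state tokens step by step while a
        seed feature proposes new agents to insert around the ego vehicle."""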
|
|
|
filter_mask = data['agent']['state_idx'][:, (self.num_historical_steps - 1) // self.shift - 1] != self.invalid_state |
|
|
|
seed_step_mask = data['agent']['state_idx'][:, (self.num_historical_steps - 1) // self.shift:] == self.enter_state |
|
|
|
|
|
|
|
eval_mask = data['agent']['valid_mask'][filter_mask, self.num_historical_steps - 1] |
|
|
|
|
|
agent_id = data['agent']['id'][filter_mask].clone() |
|
agent_valid_mask = data['agent']['raw_agent_valid_mask'][filter_mask].clone() |
|
pos_a = data['agent']['token_pos'][filter_mask].clone() |
|
token_a = data['agent']['token_idx'][filter_mask].clone() |
|
state_a = data['agent']['state_idx'][filter_mask].clone() |
|
head_a = data['agent']['token_heading'][filter_mask].clone() |
|
shape_a = data['agent']['shape'][filter_mask].clone() |
|
type_a = data['agent']['type'][filter_mask].clone() |
|
grid_a = data['agent']['grid_token_idx'][filter_mask].clone() |
|
gt_traj = data['agent']['position'][filter_mask, self.num_historical_steps:, :self.input_dim].contiguous() |
|
agent_token_traj_all = data['agent']['token_traj_all'][filter_mask] |
|
|
|
device = pos_a.device |
|
max_agent_id = agent_id.max() |
|
|
|
if self.num_recurrent_steps_val == -1: |
|
|
|
self.num_recurrent_steps_val = data["agent"]['position'].shape[1] - self.num_historical_steps |
|
num_agent, num_ori_step, traj_dim = pos_a.shape |
|
num_infer_step = (self.num_recurrent_steps_val + self.num_historical_steps) // self.shift |
|
if num_infer_step > num_ori_step: |
|
pad_shape = num_agent, num_infer_step - num_ori_step |
|
agent_valid_mask = torch.cat([agent_valid_mask, torch.full(pad_shape, True, device=device)], dim=1) |
|
pos_a = torch.cat([pos_a, torch.zeros((*pad_shape, pos_a.shape[-1]), device=device)], dim=1) |
|
token_a = torch.cat([token_a, torch.full(pad_shape, -1, device=device)], dim=1) |
|
state_a = torch.cat([state_a, torch.full(pad_shape, self.invalid_state, device=device)], dim=1) |
|
head_a = torch.cat([head_a, torch.zeros(pad_shape, device=device)], dim=1) |
|
grid_a = torch.cat([grid_a, torch.full(pad_shape, -1, device=device)], dim=1) |
|
|
|
|
|
num_removed_agent = int((~filter_mask[:data['agent']['av_index']]).sum()) |
|
data['batch_size_a'] -= num_removed_agent |
|
av_index = data['agent']['av_index'] - num_removed_agent |
|
|
|
|
|
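        # Wipe everything after the history window; these steps are filled autoregressively.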
pos_a[:, (self.num_historical_steps - 1) // self.shift:] = 0 |
|
head_a[:, (self.num_historical_steps - 1) // self.shift:] = 0 |
|
token_a[:, (self.num_historical_steps - 1) // self.shift:] = -1 |
|
state_a[:, (self.num_historical_steps - 1) // self.shift:] = 0 |
|
grid_a[:, (self.num_historical_steps - 1) // self.shift:] = -1 |
|
|
|
motion_vector_a, head_vector_a = self._build_vector_a(pos_a, head_a, state_a) |
|
|
|
agent_valid_mask[:, (self.num_historical_steps - 1) // self.shift:] = True |
|
agent_valid_mask[~eval_mask] = False |
|
agent_token_index = data['agent']['token_idx'][filter_mask] |
|
agent_state_index = data['agent']['state_idx'][filter_mask] |
|
|
|
(feat_a, agent_token_emb, agent_token_emb_veh, agent_token_emb_ped, agent_token_emb_cyc, categorical_embs, |
|
trajectory_token_veh, trajectory_token_ped, trajectory_token_cyc) = self._agent_token_embedding( |
|
data, |
|
token_a, |
|
state_a, |
|
grid_a, |
|
pos_a, |
|
head_a, |
|
inference=True, |
|
filter_mask=filter_mask, |
|
av_index=av_index, |
|
) |
|
raw_feat_a = feat_a.clone() |
|
|
|
veh_mask = type_a == 0 |
|
cyc_mask = type_a == 2 |
|
ped_mask = type_a == 1 |
|
|
|
pred_traj = torch.zeros(pos_a.shape[0], self.num_recurrent_steps_val, 2, device=device) |
|
pred_head = torch.zeros(pos_a.shape[0], self.num_recurrent_steps_val, device=device) |
|
pred_type = type_a.clone() |
|
pred_shape = shape_a[:, (self.num_historical_steps - 1) // self.shift - 1] |
|
pred_state = torch.zeros(pos_a.shape[0], self.num_recurrent_steps_val, device=device) |
|
pred_prob = torch.zeros(pos_a.shape[0], self.num_recurrent_steps_val // self.shift, device=device) |
|
|
|
feat_a_t_dict = {} |
|
feat_sa_t_dict = {} |
|
|
|
|
|
mask = agent_valid_mask.clone() |
|
temporal_mask = mask.clone() |
|
interact_mask = mask.clone() |
|
|
|
|
|
is_bos = state_a == self.enter_state |
|
is_eos = state_a == self.exit_state |
|
        bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0, device=device))

        eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_infer_step - 1, device=device))
|
|
|
temporal_mask = torch.ones_like(mask) |
|
motion_mask = torch.arange(mask.shape[1]).expand(mask.shape[0], mask.shape[1]).to(mask.device) |
|
motion_mask = (motion_mask > bos_index[:, None]) & (motion_mask <= eos_index[:, None]) |
|
motion_mask[:, self.num_historical_steps // self.shift:] = False |
|
temporal_mask[motion_mask] = mask[motion_mask] |
|
|
|
interact_mask = torch.ones_like(mask) |
|
non_motion_mask = ~motion_mask |
|
non_motion_mask[:, self.num_historical_steps // self.shift:] = False |
|
interact_mask[non_motion_mask] = 0 |
|
interact_mask[state_a == self.enter_state] = 1 |
|
interact_mask[av_index] = 1 |
|
|
|
temporal_mask[:, (self.num_historical_steps - 1) // self.shift:] = 1 |
|
interact_mask[:, (self.num_historical_steps - 1) // self.shift:] = 1 |
|
|
|
self.log_message = "" |
|
num_inserted_agents_total = num_inserted_agents = 0 |
|
next_token_idx_list = [] |
|
next_state_idx_list = [] |
|
grid_agent_occ_list = [] |
|
grid_pt_occ_list = [] |
|
grid_agent_occ_gt_list = [] |
|
next_state_prob_seed_list = [] |
|
next_pos_rel_prob_seed_list = [] |
|
agent_labels = [[None] * num_infer_step for _ in range(pos_a.shape[0])] |
|
|
|
|
|
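        # Replay the tokenized history into the output lists before the rollout starts.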
for i in range((self.num_historical_steps - 1) // self.shift): |
|
next_token_idx_list.append(agent_token_index[:, i : i + 1]) |
|
next_state_idx_list.append(agent_state_index[:, i : i + 1]) |
|
|
|
num_seed_feature = 1 |
|
insert_limit = 10 |
|
|
|
for t in ( |
|
pbar := tqdm(range(self.num_recurrent_steps_val // self.shift), leave=False, desc='Timestep ...') |
|
): |
|
|
|
|
|
num_new_agents = 0 |
|
            next_state_prob_seeds = torch.zeros(insert_limit + 1, 1, device=device)

            next_pos_rel_prob_seeds = torch.zeros(insert_limit + 1, 1, self.attr_tokenizer.grid_size, device=device)

            grid_agent_occ_seeds = torch.zeros(insert_limit + 1, 1, self.attr_tokenizer.grid_size, device=device)

            grid_pt_occ_seeds = torch.zeros(insert_limit + 1, 1, self.attr_tokenizer.grid_size, device=device)

            grid_agent_occ_gt_seeds = torch.zeros(insert_limit + 1, 1, self.attr_tokenizer.grid_size, device=device)
|
|
|
valid_state_mask = state_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t] != self.invalid_state |
|
distance = ((pos_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t, :2] - pos_a[av_index, (self.num_historical_steps - 1) // self.shift - 1 + t, :2]) ** 2).sum(-1).sqrt() |
|
inrange_mask = distance <= self.pl2seed_radius |
|
seq_valid_mask = valid_state_mask & inrange_mask |
|
seq_valid_mask[av_index] = False |
|
res_seq_index = torch.zeros_like(state_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]) |
|
res_seq_index[seq_valid_mask] = torch.randperm(seq_valid_mask.sum(), device=device) + 1 |
|
|
|
if t == 0: |
|
inference_mask = temporal_mask.clone() |
|
inference_mask = torch.cat([inference_mask, torch.ones_like(inference_mask[-1:]).repeat( |
|
num_seed_feature, *([1] * (inference_mask.dim() - 1)))]) |
|
inference_mask[:, (self.num_historical_steps - 1) // self.shift + t:] = False |
|
else: |
|
inference_mask = torch.zeros_like(temporal_mask) |
|
inference_mask = torch.cat([inference_mask, torch.zeros_like(inference_mask[-1:]).repeat( |
|
num_seed_feature, *([1] * (inference_mask.dim() - 1)))]) |
|
inference_mask[:, (self.num_historical_steps - 1) // self.shift + t - 1] = True |
|
|
|
plot_kwargs = dict() |
|
p = 0 |
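
            # Insertion loop: propose up to insert_limit new agents at this timestep via
            # the seed feature; exit when the seed predicts 'invalid' (or immediately at t == 0).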
|
while True: |
|
|
|
p += 1 |
|
if t == 0 or p - 1 >= insert_limit or self.disable_insertion: break |
|
|
|
|
|
inference_mask = torch.zeros_like(temporal_mask) |
|
inference_mask = torch.cat([inference_mask, torch.zeros_like(inference_mask[-1:]).repeat( |
|
num_seed_feature, *([1] * (inference_mask.dim() - 1)))]) |
|
inference_mask[:, (self.num_historical_steps - 1) // self.shift + t - 1] = True |
|
|
|
|
|
assert torch.all(state_a[:, :(self.num_historical_steps - 1) // self.shift + t][ |
|
~interact_mask[:, :(self.num_historical_steps - 1) // self.shift + t]] == self.invalid_state) and \ |
|
torch.all(state_a[:, :(self.num_historical_steps - 1) // self.shift + t][ |
|
interact_mask[:, :(self.num_historical_steps - 1) // self.shift + t]] != self.invalid_state), \ |
|
f"Got wrong with interact mask at scenario {data['scenario_id'][0]} t={t}!" |
|
|
|
temporal_mask = torch.cat([temporal_mask, torch.ones_like(temporal_mask[:1]).repeat( |
|
num_seed_feature, *([1] * (temporal_mask.dim() - 1)))]).bool() |
|
interact_mask = torch.cat([interact_mask, torch.ones_like(interact_mask[:1]).repeat( |
|
num_seed_feature, *([1] * (interact_mask.dim() - 1)))]).bool() |
|
|
|
pos_a_p, head_a_p, state_a_p, head_vector_a_p, grid_index_a_p, pad_mask = \ |
|
self._pad_feat(data.num_graphs, av_index, pos_a, head_a, state_a, head_vector_a, grid_a, num_seed_feature=num_seed_feature) |
|
|
|
assert torch.all(~pad_mask[-num_seed_feature:]), "Got wrong with pad mask!" |
|
|
|
batch_s = torch.arange(num_infer_step, device=device).repeat_interleave(num_agent + num_seed_feature) |
|
batch_pl = torch.arange(num_infer_step, device=device).repeat_interleave(data['pt_token']['num_nodes']) |
|
|
|
inference_mask_sa = torch.zeros_like(inference_mask).bool() |
|
inference_mask_sa[:, (self.num_historical_steps - 1) // self.shift - 1 + t] = True |
|
|
|
|
|
if self.seed_use_ego_motion: |
|
motion_vector_seed = motion_vector_a[av_index] |
|
head_vector_seed = head_vector_a[av_index] |
|
else: |
|
motion_vector_seed = head_vector_seed = None |
|
|
|
feat_seed, _ = self._build_agent_feature(num_infer_step, device, |
|
motion_vector_seed, |
|
head_vector_seed, |
|
state_index=self.invalid_state, |
|
n=num_seed_feature) |
|
|
|
if feat_a.shape[1] != feat_seed.shape[1]: |
|
assert t == 0, f"Unmatched timestep {feat_a.shape[1]} and {feat_seed.shape[1]}." |
|
feat_a = torch.cat([feat_a, feat_a[:, -1:].repeat(1, feat_seed.shape[1] - feat_a.shape[1], 1)], dim=1) |
|
|
|
raw_feat_a = feat_a.clone() |
|
feat_a = torch.cat([feat_a, feat_seed], dim=0) |
|
|
|
|
|
plot_kwargs.update(t=t, n=num_new_agents, tag='global_feature') |
|
|
|
seq_index = torch.cat([torch.zeros(pos_a.shape[0] - num_new_agents), torch.arange(num_new_agents + 1) + 1]).to(device) |
|
|
|
|
|
edge_index_a2seed, r_seed2a = self._build_a2sa_edge(data, pos_a_p, head_a_p, head_vector_a_p, batch_s, |
|
interact_mask.clone(), |
|
mask_sa=~pad_mask.clone(), |
|
inference_mask=inference_mask_sa, |
|
r=self.pl2seed_radius, |
|
max_num_neighbors=300, |
|
seq_index=seq_index, |
|
grid_index_a=grid_index_a_p, |
|
mode='insert', **plot_kwargs) |
|
edge_index_pl2seed, r_pl2seed = self._build_map2sa_edge(data, pos_a_p, head_a_p, head_vector_a_p, batch_s, batch_pl, |
|
mask_sa=~pad_mask.clone(), |
|
inference_mask=inference_mask_sa, |
|
r=self.pl2seed_radius, |
|
max_num_neighbors=2048, |
|
mode='insert') |
|
temporal_mask = temporal_mask[:-num_seed_feature] |
|
interact_mask = interact_mask[:-num_seed_feature] |
|
|
|
if self.use_grid_token: |
|
grid_agent_occ_gt_t_1 = torch.zeros((self.grid_size,), device=device).long() |
|
grid_t_1 = grid_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t] |
|
grid_agent_occ_gt_t_1[grid_t_1[grid_t_1 != -1]] = 1 |
|
occ_embed_a = self.seed_agent_occ_embed(grid_agent_occ_gt_t_1.reshape(1, self.grid_size).float()).repeat(num_seed_feature, 1) |
|
edge_index_occ2sa_src = torch.arange(feat_a.shape[0] * feat_a.shape[1], device=device).long() |
|
edge_index_occ2sa_src = edge_index_occ2sa_src[(~pad_mask.transpose(0, 1).reshape(-1)) & (inference_mask_sa.transpose(0, 1).reshape(-1))] |
|
edge_index_occ2sa_tgt = torch.arange(occ_embed_a.shape[0], device=device).long() |
|
edge_index_occ2sa = torch.stack([edge_index_occ2sa_tgt, edge_index_occ2sa_src], dim=0) |
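                    # Note: row 0 of edge_index_occ2sa indexes the occupancy embeddings
                    # (message sources) and row 1 the flattened seed-agent features (targets);
                    # the _src/_tgt suffixes above are swapped relative to the usual
                    # PyG (source, target) convention.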
|
|
|
for i in range(self.seed_layers): |
|
|
|
feat_a = feat_a.transpose(0, 1).reshape(-1, self.hidden_dim) |
|
if self.use_grid_token: |
|
feat_a = self.occ2sa_attn_layers[i]((occ_embed_a, feat_a), None, edge_index_occ2sa) |
|
feat_a = self.pt2sa_attn_layers[i](( |
|
map_enc['x_pt'].repeat_interleave(repeats=num_infer_step, dim=0).reshape(-1, num_infer_step, self.hidden_dim).transpose(0, 1).reshape( |
|
-1, self.hidden_dim), feat_a), r_pl2seed, edge_index_pl2seed) |
|
|
|
feat_a = self.a2sa_attn_layers[i](feat_a, r_seed2a, edge_index_a2seed) |
|
feat_a = feat_a.reshape(num_infer_step, -1, self.hidden_dim).transpose(0, 1) |
|
|
|
feat_seed = feat_a[-num_seed_feature:] |
|
|
|
ego_pos_t_1 = pos_a[av_index, (self.num_historical_steps - 1) // self.shift - 1 + t] |
|
ego_head_t_1 = head_a[av_index, (self.num_historical_steps - 1) // self.shift - 1 + t] |
|
|
|
|
|
if self.predict_occ: |
|
grid_agent_occ_seed = self.grid_agent_occ_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t]) |
|
grid_pt_occ_seed = self.grid_pt_occ_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t]) |
|
|
|
|
|
next_state_prob_seed = self.seed_state_predict_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t]) |
|
next_state_idx_seed = next_state_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True) |
|
next_state_idx_seed[next_state_idx_seed == self.seed_state_type.index('invalid')] = self.invalid_state |
|
next_state_idx_seed[next_state_idx_seed == self.seed_state_type.index('enter')] = self.enter_state |
|
if int(os.getenv('DEBUG', 0)): |
|
next_state_idx_seed = torch.full(next_state_idx_seed.shape, self.enter_state, device=device).long() |
|
|
|
|
|
next_type_prob_seed = self.seed_type_predict_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t]) |
|
next_type_idx_seed = next_type_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True) |
|
next_shape_seed = self.seed_shape_predict_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t]) |
|
|
|
|
|
if self.use_grid_token: |
|
next_pos_rel_prob_seed = self.seed_pos_rel_token_predict_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t]) |
|
next_pos_rel_prob_softmax = torch.softmax(next_pos_rel_prob_seed, dim=-1) |
|
|
|
|
|
topk_pos_rel_prob, next_pos_rel_idx_seed = torch.topk(next_pos_rel_prob_softmax, k=self.insert_beam_size, dim=-1) |
|
sample_pos_rel_index = torch.multinomial(topk_pos_rel_prob, 1).to(device) |
|
next_pos_rel_idx_seed = next_pos_rel_idx_seed.gather(dim=1, index=sample_pos_rel_index) |
|
next_pos_seed = self.attr_tokenizer.decode_pos(next_pos_rel_idx_seed[..., 0], y=ego_pos_t_1, theta_y=ego_head_t_1) |
|
if self.inference_filter_overlap: |
|
if grid_agent_occ_gt_t_1[next_pos_rel_idx_seed[..., 0]]: |
|
feat_a = raw_feat_a.clone() |
|
continue |
|
else: |
|
next_pos_rel_xy_seed = self.seed_pos_rel_xy_predict_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t]) |
|
next_pos_seed = torch.tanh(next_pos_rel_xy_seed) * self.pl2seed_radius + ego_pos_t_1 |
|
|
|
if torch.all(next_state_idx_seed == self.invalid_state) or num_new_agents + 1 > insert_limit: |
|
break |
|
|
|
num_new_agent = 1 |
|
num_new_agents += 1 |
|
|
|
|
|
|
|
|
|
agent_id = torch.cat([agent_id, torch.tensor([max_agent_id + 1], device=device, dtype=agent_id.dtype)]) |
|
max_agent_id += 1 |
|
|
|
mask = torch.cat([mask, torch.ones(num_new_agent, num_infer_step, device=mask.device)], dim=0).bool() |
|
temporal_mask = torch.cat([temporal_mask, torch.ones(num_new_agent, num_infer_step, device=temporal_mask.device)], dim=0).bool() |
|
interact_mask = torch.cat([interact_mask, torch.ones(num_new_agent, num_infer_step, device=interact_mask.device)], dim=0).bool() |
|
|
|
|
|
new_pos_a = torch.zeros(num_new_agent, num_infer_step, 2, device=device) |
|
new_head_a = torch.zeros(num_new_agent, num_infer_step, device=device) |
|
new_grid_a = torch.full((num_new_agent, num_infer_step), -1, device=device) |
|
new_state_a = torch.full((num_new_agent, num_infer_step), self.invalid_state, device=state_a.device) |
|
new_shape_a = torch.full((num_new_agent, num_infer_step, 3), self.invalid_shape_value, device=device) |
|
new_type_a = torch.full((num_new_agent, num_infer_step), AGENT_TYPE.index('seed'), device=device) |
|
|
|
|
|
new_pos_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t] = next_pos_seed |
|
pos_a = torch.cat([pos_a, new_pos_a], dim=0) |
|
|
|
new_head_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t] = ego_head_t_1 |
|
head_a = torch.cat([head_a, new_head_a], dim=0) |
|
|
|
if self.use_grid_token: |
|
new_grid_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t] = next_pos_rel_idx_seed |
|
grid_a = torch.cat([grid_a, new_grid_a], dim=0) |
|
|
|
new_type_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t:] = next_type_idx_seed |
|
new_shape_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t:] = next_shape_seed[:, None] |
|
pred_type = torch.cat([pred_type, new_type_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]]) |
|
pred_shape = torch.cat([pred_shape, new_shape_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]]) |
|
|
|
new_state_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t] = next_state_idx_seed |
|
state_a = torch.cat([state_a, new_state_a], dim=0) |
|
|
|
mask[-num_new_agent:, :(self.num_historical_steps - 1) // self.shift + t] = 0 |
|
interact_mask[-num_new_agent:, :(self.num_historical_steps - 1) // self.shift - 1 + t] = 0 |
|
|
|
|
|
new_pred_traj = torch.zeros(num_new_agent, self.num_recurrent_steps_val, 2, device=device) |
|
new_pred_head = torch.zeros(num_new_agent, self.num_recurrent_steps_val, device=device) |
|
new_pred_state = torch.zeros(num_new_agent, self.num_recurrent_steps_val, device=device) |
|
|
|
if t > 0: |
|
                    new_pred_traj[:, (t - 1) * self.shift : t * self.shift] = new_pos_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t, None].repeat(1, self.shift, 1)

                    new_pred_head[:, (t - 1) * self.shift : t * self.shift] = new_head_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t, None].repeat(1, self.shift)

                    new_pred_state[:, (t - 1) * self.shift : t * self.shift] = next_state_idx_seed.repeat(1, self.shift)
|
|
|
pred_traj = torch.cat([pred_traj, new_pred_traj], dim=0) |
|
pred_head = torch.cat([pred_head, new_pred_head], dim=0) |
|
pred_state = torch.cat([pred_state, new_pred_state], dim=0) |
|
|
|
|
|
new_agent_token_emb = self.no_token_emb(torch.zeros(1, device=device).long())[None, :].repeat(num_new_agent, num_infer_step, 1) |
|
new_agent_token_emb[:, (self.num_historical_steps - 1) // self.shift - 1 + t] = self.bos_token_emb(torch.zeros(1, device=device).long()) |
|
agent_token_emb = torch.cat([agent_token_emb, new_agent_token_emb]) |
|
next_veh_mask = next_type_idx_seed[..., 0] == AGENT_TYPE.index('veh') |
|
next_ped_mask = next_type_idx_seed[..., 0] == AGENT_TYPE.index('ped') |
|
next_cyc_mask = next_type_idx_seed[..., 0] == AGENT_TYPE.index('cyc') |
|
veh_mask = torch.cat([veh_mask, next_veh_mask]) |
|
ped_mask = torch.cat([ped_mask, next_ped_mask]) |
|
cyc_mask = torch.cat([cyc_mask, next_cyc_mask]) |
|
|
|
|
|
new_agent_token_traj_all = torch.zeros((num_new_agent, self.token_size, self.shift + 1, 4, 2), device=device) |
|
new_agent_token_traj_all[next_veh_mask] = trajectory_token_veh[None, ...] |
|
new_agent_token_traj_all[next_ped_mask] = trajectory_token_ped[None, ...] |
|
new_agent_token_traj_all[next_cyc_mask] = trajectory_token_cyc[None, ...] |
|
|
|
agent_token_traj_all = torch.cat([agent_token_traj_all, new_agent_token_traj_all], dim=0) |
|
|
|
new_categorical_embs = [self.type_a_emb(new_type_a.reshape(-1).long()), self.shape_emb(new_shape_a.reshape(-1, 3))] |
|
categorical_embs = [torch.cat([categorical_embs[0], new_categorical_embs[0]], dim=0), |
|
torch.cat([categorical_embs[1], new_categorical_embs[1]], dim=0)] |
|
|
|
new_labels = [None] * num_infer_step |
|
new_labels[(self.num_historical_steps - 1) // self.shift + t] = f'A{num_new_agents}' |
|
agent_labels.append(new_labels) |
|
|
|
|
|
motion_vector_sa, head_vector_sa = self._build_vector_a(pos_a[-num_new_agent:], |
|
head_a[-num_new_agent:], |
|
state_a[-num_new_agent:]) |
|
|
|
                assert torch.all(motion_vector_sa[:, :(self.num_historical_steps - 1) // self.shift - 1 + t] == self.invalid_motion_value) and \

                    torch.all(motion_vector_sa[:, (self.num_historical_steps - 1) // self.shift - 1 + t] == self.motion_gap), \

                    f"Found invalid values in motion_vector_a at scenario {data['scenario_id'][0]} t={t}!"
|
|
|
motion_vector_sa[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0. |
|
head_vector_sa[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0. |
|
motion_vector_a = torch.cat([motion_vector_a, motion_vector_sa]) |
|
head_vector_a = torch.cat([head_vector_a, head_vector_sa]) |
|
|
|
new_offset_pos = pos_a[-num_new_agent:] - pos_a[av_index] |
|
new_agent_grid_emb = self.grid_token_emb[new_grid_a] if self.use_grid_token else None |
|
|
|
feat_sa, _ = self._build_agent_feature(num_infer_step, device, |
|
motion_vector_sa, |
|
head_vector_sa, |
|
agent_token_emb=new_agent_token_emb, |
|
agent_grid_emb=new_agent_grid_emb, |
|
offset_pos=new_offset_pos, |
|
categorical_embs_a=new_categorical_embs, |
|
state=new_state_a) |
|
|
|
feat_a = torch.cat([raw_feat_a, feat_sa]) |
|
|
|
batch_s = torch.arange(num_infer_step, device=device).repeat_interleave(num_agent + num_new_agent) |
|
batch_pl = torch.arange(num_infer_step, device=device).repeat_interleave(data['pt_token']['num_nodes']) |
|
|
|
|
|
assert pos_a.shape[0] == head_a.shape[0] == head_vector_a.shape[0] == interact_mask.shape[0] == \ |
|
pad_mask.shape[0] == inference_mask_sa.shape[0] == (num_agent + num_new_agent), f"Inconsistent shapes!" |
|
|
|
plot_kwargs.update(tag='heading') |
|
edge_index_a2sa, r_a2sa = self._build_a2sa_edge(data, pos_a, head_a, head_vector_a, batch_s, |
|
interact_mask.clone(), |
|
mask_sa=~pad_mask.clone(), |
|
inference_mask=inference_mask_sa, |
|
r=self.a2sa_radius, |
|
max_num_neighbors=24, |
|
**plot_kwargs) |
|
edge_index_pl2sa, r_pl2sa = self._build_map2sa_edge(data, pos_a, head_a, head_vector_a, batch_s, batch_pl, |
|
mask_sa=~pad_mask.clone(), |
|
inference_mask=inference_mask_sa, |
|
r=self.pl2sa_radius, |
|
max_num_neighbors=128) |
|
|
|
for i in range(self.seed_layers): |
|
|
|
feat_a = feat_a.transpose(0, 1).reshape(-1, self.hidden_dim) |
|
feat_a = self.pt2a_attn_layers[i](( |
|
map_enc['x_pt'].repeat_interleave(repeats=num_infer_step, dim=0).reshape(-1, num_infer_step, self.hidden_dim).transpose(0, 1).reshape( |
|
-1, self.hidden_dim), feat_a), r_pl2sa, edge_index_pl2sa) |
|
|
|
feat_a = self.a2a_attn_layers[i](feat_a, r_a2sa, edge_index_a2sa) |
|
feat_a = feat_a.reshape(num_infer_step, -1, self.hidden_dim).transpose(0, 1) |
|
|
|
if self.use_head_token: |
|
next_head_rel_prob_seed = self.seed_heading_rel_token_predict_head(feat_a[-num_new_agent:, (self.num_historical_steps - 1) // self.shift - 1 + t]) |
|
next_head_rel_idx_seed = next_head_rel_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True) |
|
next_head_seed = wrap_angle(self.attr_tokenizer.decode_heading(next_head_rel_idx_seed) + ego_head_t_1) |
|
else: |
|
next_head_rel_theta_seed = self.seed_heading_rel_theta_predict_head(feat_a[-num_new_agent:, (self.num_historical_steps - 1) // self.shift - 1 + t]) |
|
next_head_seed = torch.tanh(next_head_rel_theta_seed) * torch.pi + ego_head_t_1 |
|
|
|
head_a[-num_new_agent:, (self.num_historical_steps - 1) // self.shift - 1 + t] = next_head_seed |
|
|
|
if self.use_grid_token: |
|
next_offset_xy_seed = self.seed_offset_xy_predict_head(feat_a[-num_new_agent:, (self.num_historical_steps - 1) // self.shift - 1 + t]) |
|
next_offset_xy_seed = torch.tanh(next_offset_xy_seed) * 2 |
|
|
|
pos_a[-num_new_agent:, (self.num_historical_steps - 1) // self.shift - 1 + t] += next_offset_xy_seed |
|
|
|
|
|
motion_vector_sa, head_vector_sa = self._build_vector_a(pos_a[-num_new_agent:], |
|
head_a[-num_new_agent:], |
|
state_a[-num_new_agent:]) |
|
motion_vector_sa[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0. |
|
head_vector_sa[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0. |
|
motion_vector_a[-num_new_agent:] = motion_vector_sa |
|
                head_vector_a[-num_new_agent:] = head_vector_sa
|
|
|
feat_sa, _ = self._build_agent_feature(num_infer_step, device, |
|
motion_vector_sa, |
|
head_vector_sa, |
|
agent_token_emb=new_agent_token_emb, |
|
agent_grid_emb=new_agent_grid_emb, |
|
offset_pos=new_offset_pos, |
|
categorical_embs_a=new_categorical_embs, |
|
state=state_a[-num_new_agent:], |
|
n=num_new_agent) |
|
|
|
feat_a = torch.cat([raw_feat_a, feat_sa]) |
|
raw_feat_a = feat_a.clone() |
|
|
|
num_agent = pos_a.shape[0] |
|
|
|
if self.use_grid_token: |
|
grid_agent_occ_gt_seeds[num_new_agents] = grid_agent_occ_gt_t_1 |
|
grid_agent_occ_seeds[num_new_agents] = grid_agent_occ_seed |
|
grid_pt_occ_seeds[num_new_agents] = grid_pt_occ_seed |
|
next_pos_rel_prob_seeds[num_new_agents] = next_pos_rel_prob_softmax |
|
next_state_prob_seeds[num_new_agents] = next_state_prob_seed.softmax(dim=-1)[:, -1] |
|
|
|
inference_mask = inference_mask[:-num_seed_feature] |
|
next_state_prob_seed_list.append(next_state_prob_seeds) |
|
if self.use_grid_token: |
|
next_pos_rel_prob_seed_list.append(next_pos_rel_prob_seeds) |
|
grid_agent_occ_list.append(grid_agent_occ_seeds) |
|
grid_pt_occ_list.append(grid_pt_occ_seeds) |
|
grid_agent_occ_gt_list.append(grid_agent_occ_gt_seeds) |
|
next_state_idx_list[-1] = torch.cat([next_state_idx_list[-1], torch.full((num_new_agents, 1), self.enter_state, device=device).long()]) |
|
|
|
|
|
feat_a = raw_feat_a |
|
|
|
|
|
inference_mask = torch.zeros_like(temporal_mask) |
|
inference_mask[:, (self.num_historical_steps - 1) // self.shift + t - 1] = True |
|
|
|
edge_index_t, r_t = self._build_temporal_edge(data, pos_a, head_a, state_a, head_vector_a, temporal_mask, inference_mask.clone()) |
|
|
|
batch_s = torch.arange(num_infer_step, device=device).repeat_interleave(num_agent) |
|
batch_pl = torch.arange(num_infer_step, device=device).repeat_interleave(data['pt_token']['num_nodes']) |
|
|
|
edge_index_a2a, r_a2a = self._build_interaction_edge(data, pos_a, head_a, state_a, head_vector_a, batch_s, |
|
interact_mask, inference_mask=inference_mask, av_index=av_index, **plot_kwargs) |
|
edge_index_pl2a, r_pl2a = self._build_map2agent_edge(data, pos_a, head_a, state_a, head_vector_a, batch_s, batch_pl, |
|
interact_mask, inference_mask=inference_mask, av_index=av_index, **plot_kwargs) |
|
|
|
for i in range(self.num_layers): |
|
|
|
if i in feat_a_t_dict: |
|
feat_a = feat_a_t_dict[i] |
|
|
|
feat_a = feat_a.reshape(-1, self.hidden_dim) |
|
feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t) |
|
|
|
feat_a = feat_a.reshape(-1, num_infer_step, self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim) |
|
feat_a = self.pt2a_attn_layers[i](( |
|
map_enc['x_pt'].repeat_interleave(repeats=num_infer_step, dim=0).reshape(-1, num_infer_step, self.hidden_dim).transpose(0, 1).reshape( |
|
-1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a) |
|
|
|
feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a) |
|
feat_a = feat_a.reshape(num_infer_step, -1, self.hidden_dim).transpose(0, 1) |
|
|
|
if t == 0: |
|
feat_a_t_dict[i + 1] = feat_a |
|
else: |
|
|
|
n = feat_a_t_dict[i + 1].shape[0] |
|
feat_a_t_dict[i + 1][:n, (self.num_historical_steps - 1) // self.shift - 1 + t] = feat_a[:n, (self.num_historical_steps - 1) // self.shift - 1 + t] |
|
|
|
if feat_a.shape[0] > n: |
|
m = feat_a.shape[0] - n |
|
feat_a_t_dict[i + 1] = torch.cat([feat_a_t_dict[i + 1], feat_a[-m:]]) |
|
|
|
|
|
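            # Decode the next motion token for every tracked agent (top-k, then sample).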
next_token_prob = self.token_predict_head(feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]) |
|
next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1) |
|
topk_token_prob, next_token_idx = torch.topk(next_token_prob_softmax, k=self.motion_beam_size, dim=-1) |
|
|
|
|
|
next_state_prob = self.state_predict_head(feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]) |
|
next_state_idx = next_state_prob.softmax(dim=-1).argmax(dim=-1) |
|
next_state_idx[next_state_idx == self.valid_state_type.index('exit')] = self.exit_state |
|
next_state_idx[av_index] = self.valid_state |
|
if not self.use_state_token: |
|
next_state_idx[next_state_idx == self.exit_state] = self.valid_state |
|
|
|
|
|
            expanded_token_index = next_token_idx[..., None, None, None].expand(-1, -1, self.shift + 1, 4, 2)
|
next_token_traj = torch.gather(agent_token_traj_all, 1, expanded_token_index) |
|
|
|
|
|
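            # Rotate the token trajectory templates from the agent frame into the world
            # frame using each agent's current heading.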
theta = head_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t] |
|
cos, sin = theta.cos(), theta.sin() |
|
rot_mat = torch.zeros((num_agent, 2, 2), device=theta.device) |
|
rot_mat[:, 0, 0] = cos |
|
rot_mat[:, 0, 1] = sin |
|
rot_mat[:, 1, 0] = -sin |
|
rot_mat[:, 1, 1] = cos |
|
agent_diff_rel = torch.bmm(next_token_traj.view(-1, 4, 2), |
|
rot_mat[:, None, None, ...].repeat(1, self.motion_beam_size, self.shift + 1, 1, 1).view( |
|
-1, 2, 2)).view(num_agent, self.motion_beam_size, self.shift + 1, 4, 2) |
|
agent_pred_rel = agent_diff_rel + pos_a[:, None, None, None, (self.num_historical_steps - 1) // self.shift - 1 + t, :] |
|
|
|
|
|
|
|
sample_token_index = torch.multinomial(topk_token_prob, 1).to(agent_pred_rel.device) |
|
next_token_idx = next_token_idx.gather(dim=1, index=sample_token_index).squeeze(-1) |
|
            agent_pred_rel = agent_pred_rel.gather(dim=1,

                                                   index=sample_token_index[..., None, None, None].expand(

                                                       -1, -1, self.shift + 1, 4, 2))[:, 0, ...]
|
|
|
|
|
            diff_xy = agent_pred_rel[:, 1:, 0, :] - agent_pred_rel[:, 1:, 3, :]

            pred_traj[:num_agent, t * self.shift : (t + 1) * self.shift] = agent_pred_rel[:, 1:, ...].clone().mean(dim=2)

            pred_head[:num_agent, t * self.shift : (t + 1) * self.shift] = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])

            pred_state[:num_agent, t * self.shift : (t + 1) * self.shift] = next_state_idx[:, None].repeat(1, self.shift)
|
|
|
|
|
|
|
pos_a[:, (self.num_historical_steps - 1) // self.shift + t] = agent_pred_rel[:, -1, ...].clone().mean(dim=1) |
|
diff_xy = agent_pred_rel[:, -1, 0, :] - agent_pred_rel[:, -1, 3, :] |
|
theta = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0]) |
|
head_a[:, (self.num_historical_steps - 1) // self.shift + t] = theta |
|
state_a[:, (self.num_historical_steps - 1) // self.shift + t] = next_state_idx |
|
if self.use_grid_token: |
|
grid_a[:, (self.num_historical_steps - 1) // self.shift + t] = self.attr_tokenizer.encode_pos( |
|
x=pos_a[:, (self.num_historical_steps - 1) // self.shift + t], |
|
y=pos_a[av_index, (self.num_historical_steps - 1) // self.shift + t], |
|
theta_y=theta[av_index], |
|
)[0] |
|
|
|
|
|
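            # Cleanup for agents predicted invalid: zero their pose, drop them from the
            # masks, and reset their token/type/shape embeddings.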
is_eos = next_state_idx == self.exit_state |
|
is_invalid = next_state_idx == self.invalid_state |
|
|
|
next_token_idx[is_invalid] = -1 |
|
pos_a[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = 0. |
|
head_a[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = 0. |
|
if self.use_grid_token: |
|
grid_a[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = -1 |
|
|
|
mask[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = False |
|
interact_mask[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = False |
|
|
|
agent_token_emb[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = self.no_token_emb(torch.zeros(1, device=device).long()) |
|
|
|
type_emb = categorical_embs[0].reshape(num_agent, num_infer_step, -1) |
|
shape_emb = categorical_embs[1].reshape(num_agent, num_infer_step, -1) |
|
type_emb[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = self.type_a_emb(torch.tensor(AGENT_TYPE.index('seed'), device=device).long()) |
|
shape_emb[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = self.shape_emb(torch.full((1, 3), self.invalid_shape_value, device=device)) |
|
categorical_embs = [type_emb.reshape(-1, self.hidden_dim), shape_emb.reshape(-1, self.hidden_dim)] |
|
|
agent_token_emb[veh_mask, (self.num_historical_steps - 1) // self.shift + t] = agent_token_emb_veh[ |
|
next_token_idx[veh_mask]] |
|
agent_token_emb[ped_mask, (self.num_historical_steps - 1) // self.shift + t] = agent_token_emb_ped[ |
|
next_token_idx[ped_mask]] |
|
agent_token_emb[cyc_mask, (self.num_historical_steps - 1) // self.shift + t] = agent_token_emb_cyc[ |
|
next_token_idx[cyc_mask]] |
|
|
|
|
|
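            # Rebuild motion/heading vectors and the fused agent features for the next step.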
motion_vector_a, head_vector_a = self._build_vector_a(pos_a, head_a, state_a) |
|
motion_vector_a[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0. |
|
head_vector_a[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0. |
|
|
|
offset_pos = pos_a - pos_a[av_index] |
|
|
|
x_a = torch.stack( |
|
[torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1), |
|
angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2]), |
|
|
|
], dim=-1) |
|
|
|
x_a = self.x_a_emb(continuous_inputs=x_a.view(-1, x_a.size(-1)), |
|
categorical_embs=categorical_embs) |
|
x_a = x_a.view(-1, num_infer_step, self.hidden_dim) |
|
|
|
s_a = self.state_a_emb(state_a.reshape(-1).long()).reshape(-1, num_infer_step, self.hidden_dim) |
|
feat_a = torch.cat((agent_token_emb, x_a, s_a), dim=-1) |
|
if self.use_grid_token: |
|
agent_grid_emb = self.grid_token_emb[grid_a] |
|
feat_a = torch.cat([feat_a, agent_grid_emb], dim=-1) |
|
feat_a = self.fusion_emb(feat_a) |
|
raw_feat_a = feat_a.clone() |
|
|
|
next_token_idx_list.append(next_token_idx[:, None]) |
|
next_state_idx_list.append(next_state_idx[:, None]) |
|
|
|
|
|
num_inserted_agents_total += num_new_agents |
|
num_inserted_agents += num_new_agents |
|
if num_new_agents > 0: |
|
self.log(t, next_pos_seed, ego_pos_t_1, next_head_seed, ego_head_t_1, next_shape_seed, next_type_idx_seed) |
|
|
|
|
|
allocated_memory = torch.cuda.memory_allocated() / (1024 ** 3) |
|
pbar.set_postfix(memory=f'{allocated_memory:.2f}GB', |
|
insert=f'{num_inserted_agents_total}/{seed_step_mask.sum()}') |
|
|
|
for i in range(len(next_token_idx_list)): |
|
next_token_idx_list[i] = torch.cat([next_token_idx_list[i], torch.zeros(num_agent - next_token_idx_list[i].shape[0], 1, device=device) - 1], dim=0).long() |
|
next_state_idx_list[i] = torch.cat([next_state_idx_list[i], torch.zeros(num_agent - next_state_idx_list[i].shape[0], 1, device=device)], dim=0).long() |
|
|
|
|
|
num_agent = pred_traj.shape[0] |
|
num_init_agent = filter_mask.sum() |
|
|
|
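        # Prepend the history window and reconstruct historical trajectories from the
        # ground-truth motion tokens.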
pred_traj = torch.cat([torch.zeros(num_agent, self.num_historical_steps, *(pred_traj.shape[2:]), device=pred_traj.device), pred_traj], dim=1) |
|
pred_head = torch.cat([torch.zeros(num_agent, self.num_historical_steps, *(pred_head.shape[2:]), device=pred_head.device), pred_head], dim=1) |
|
pred_state = torch.cat([torch.zeros(num_agent, self.num_historical_steps, *(pred_state.shape[2:]), device=pred_state.device), pred_state], dim=1) |
|
|
|
pred_traj[:num_init_agent, 0] = data['agent']['position'][filter_mask, 0, :2] |
|
pred_head[:num_init_agent, 0] = data['agent']['heading'][filter_mask, 0] |
|
pred_state[:num_init_agent, 1 : self.num_historical_steps] = data['agent']['state_idx'][filter_mask, :(self.num_historical_steps - 1) // self.shift].repeat_interleave(repeats=self.shift, dim=1) |
|
|
|
historical_token_idx = data['agent']['token_idx'][filter_mask, :(self.num_historical_steps - 1) // self.shift] |
|
historical_token_idx[historical_token_idx < 0] = 0 |
|
        historical_token_traj_all = torch.gather(agent_token_traj_all, 1, historical_token_idx[..., None, None, None].expand(-1, -1, self.shift + 1, 4, 2))
|
init_theta = head_a[:num_init_agent, 0] |
|
cos, sin = init_theta.cos(), init_theta.sin() |
|
rot_mat = torch.zeros((num_init_agent, 2, 2), device=init_theta.device) |
|
rot_mat[:, 0, 0] = cos |
|
rot_mat[:, 0, 1] = sin |
|
rot_mat[:, 1, 0] = -sin |
|
rot_mat[:, 1, 1] = cos |
|
historical_token_traj_all = torch.bmm(historical_token_traj_all.view(-1, 4, 2), |
|
rot_mat[:, None, None, ...].repeat(1, (self.num_historical_steps - 1) // self.shift, self.shift + 1, 1, 1).view( |
|
-1, 2, 2)).view(num_init_agent, (self.num_historical_steps - 1) // self.shift, self.shift + 1, 4, 2) |
|
historical_token_traj_all = historical_token_traj_all + pos_a[:num_init_agent, 0, :][:, None, None, None, ...] |
|
pred_traj[:num_init_agent, 1 : self.num_historical_steps] = historical_token_traj_all[:, :, 1:, ...].clone().mean(dim=3).reshape(num_init_agent, -1, 2) |
|
diff_xy = historical_token_traj_all[..., 1:, 0, :] - historical_token_traj_all[..., 1:, 3, :] |
|
pred_head[:num_init_agent, 1 : self.num_historical_steps] = torch.arctan2(diff_xy[..., 1], diff_xy[..., 0]).reshape(num_init_agent, -1) |
|
|
|
|
|
pred_z = torch.zeros_like(pred_traj[..., 0]) |
|
pred_valid = (pred_state != self.invalid_state) & (pred_state != self.enter_state) |
|
|
|
|
|
eval_shape = torch.zeros_like(pred_shape) |
|
eval_shape[veh_mask] = torch.tensor(AGENT_SHAPE['vehicle'], device=device)[None, ...] |
|
eval_shape[ped_mask] = torch.tensor(AGENT_SHAPE['pedstrain'], device=device)[None, ...] |
|
eval_shape[cyc_mask] = torch.tensor(AGENT_SHAPE['cyclist'], device=device)[None, ...] |
|
|
|
next_token_idx = torch.cat(next_token_idx_list, dim=-1) |
|
next_state_idx = torch.cat(next_state_idx_list, dim=-1) if len(next_state_idx_list) > 0 else None |
|
|
|
|
|
        assert torch.all(pos_a[next_state_idx == self.invalid_state] == 0), 'Invalid steps must have all-zero positions!'
|
|
|
if self.log_message == "": |
|
self.log_message = "No agents inserted!" |
|
else: |
|
self.log_message += f"\nNumber of total inserted agents: {num_inserted_agents_total}/{seed_step_mask.sum()}" |
|
|
|
return { |
|
'ego_index': int(av_index), |
|
'agent_id': agent_id, |
|
|
|
|
|
|
|
'valid_mask': agent_valid_mask, |
|
'pos_a': pos_a, |
|
'head_a': head_a, |
|
'gt_traj': gt_traj, |
|
'pred_traj': pred_traj, |
|
'pred_head': pred_head, |
|
'pred_type': pred_type, |
|
'pred_state': pred_state, |
|
'pred_z': pred_z, |
|
'pred_shape': pred_shape, |
|
'eval_shape': eval_shape, |
|
'pred_valid': pred_valid, |
|
'next_state_prob_seed': torch.cat(next_state_prob_seed_list, dim=1), |
|
'next_pos_rel_prob_seed': torch.cat(next_pos_rel_prob_seed_list, dim=1) |
|
if self.use_grid_token else None, |
|
'next_token_idx': next_token_idx, |
|
'next_state_idx': next_state_idx, |
|
'grid_agent_occ_seed': torch.cat(grid_agent_occ_list, dim=1) |
|
if self.use_grid_token else None, |
|
'grid_pt_occ_seed': torch.cat(grid_pt_occ_list, dim=1) |
|
if self.use_grid_token else None, |
|
'grid_agent_occ_gt_seed': torch.cat(grid_agent_occ_gt_list, dim=1) |
|
if self.use_grid_token else None, |
|
'agent_labels': agent_labels, |
|
'log_message': self.log_message, |
|
} |
|
|
|
    def log(self, t, next_pos_seed, ego_pos, next_head_seed, ego_head, next_shape_seed, next_type_idx_seed):

        _repr_indent = 4

        for sa in range(next_pos_seed.shape[0]):

            head = f"\nagent {sa} is entering at step {t}"

            body = [

                f"rel pos {(next_pos_seed[sa] - ego_pos).tolist()}, pos {next_pos_seed[sa].tolist()}",

                f"rel head {wrap_angle(next_head_seed[sa] - ego_head).item()}, head {next_head_seed[sa].item()}",

                f"shape {next_shape_seed[sa].tolist()}, type {next_type_idx_seed[sa].item()}",

            ]

            self.log_message += "\n".join([head] + [" " * _repr_indent + line for line in body])
|
|