# longsim-base / backups / dev / modules / agent_decoder.py
# Uploaded by gzzyyxy via huggingface_hub ("Upload folder using huggingface_hub"),
# commit d37e5d1 (verified).
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from typing import Dict, Mapping, Optional, Literal
from torch_cluster import radius, radius_graph
from torch_geometric.data import HeteroData, Batch
from torch_geometric.utils import dense_to_sparse, subgraph
from scipy.optimize import linear_sum_assignment
from dev.modules.attr_tokenizer import Attr_Tokenizer
from dev.modules.layers import *
from dev.utils.visualization import *
from dev.datasets.preprocess import AGENT_SHAPE, AGENT_TYPE
from dev.utils.func import angle_between_2d_vectors, wrap_angle, weight_init
class HungarianMatcher(nn.Module):
    """Per-step bipartite matching between predicted seed agents and ground-truth agents.

    For every graph in the batch and every time step, a cost matrix between that
    graph's predictions and its ground-truth agents is built from a position-token
    cross-entropy cost and a squared-error shape cost, then solved with
    ``scipy.optimize.linear_sum_assignment``.

    Args:
        loss_weight: weights keyed by loss name. ``'pos_cls_loss'`` and
            ``'shape_reg_loss'`` scale the two cost terms used by :meth:`forward`;
            ``'state_cls_loss'``, ``'head_cls_loss'``, ``'seed_state_weight'`` and
            ``'seed_type_weight'`` are required and stored but not used for matching.
        enter_state: stored on the instance; not used by :meth:`forward`.
    """

    def __init__(self, loss_weight: dict, enter_state: int = 0):
        super().__init__()
        self.enter_state = enter_state
        self.cost_state = loss_weight['state_cls_loss']
        self.cost_pos = loss_weight['pos_cls_loss']
        self.cost_head = loss_weight['head_cls_loss']
        self.cost_shape = loss_weight['shape_reg_loss']
        self.seed_state_weight = loss_weight['seed_state_weight']
        self.seed_type_weight = loss_weight['seed_type_weight']

    @torch.no_grad()
    def forward(self, outputs, targets, ptr_pred, ptr_gt, valid_mask=None):
        """Match predictions to ground truth independently for every graph and step.

        Args:
            outputs: dict with ``'pos_pred'`` (position-token logits indexed as
                ``(num_pred, num_step, num_classes)``) and ``'shape_pred'``.
            targets: dict with ``'pos_gt'`` (position-token labels
                ``(num_gt, num_step)``; label ``-1`` contributes zero cost) and
                ``'shape_gt'``. Shapes are presumably ``(num, num_step, 3)`` so the
                squared error broadcasts to ``(num_pred, num_gt, num_step)`` — TODO confirm.
            ptr_pred / ptr_gt: CSR-style offsets; graph ``b`` owns rows
                ``ptr[b]:ptr[b + 1]`` of the prediction / ground-truth tensors.
            valid_mask: optional boolean mask aligned with ``targets['pos_gt']``;
                ground-truth entries where it is False get a prohibitive cost so the
                assignment avoids them when possible.

        Returns:
            ``(pred_indices, gt_indices)``: long tensors of shape
            ``(num_matches, num_step)`` with global row indices into the
            prediction / ground-truth tensors, concatenated over graphs.
        """
        pred_indices = []
        gt_indices = []
        for b in range(len(ptr_gt) - 1):
            start_pred, end_pred = int(ptr_pred[b]), int(ptr_pred[b + 1])
            start_gt, end_gt = int(ptr_gt[b]), int(ptr_gt[b + 1])
            pos_pred = outputs['pos_pred'][start_pred:end_pred]  # (n, s, l)
            shape_pred = outputs['shape_pred'][start_pred:end_pred]
            pos_gt = targets['pos_gt'][start_gt:end_gt]  # (m, s)
            shape_gt = targets['shape_gt'][start_gt:end_gt]
            num_pred = pos_pred.shape[0]
            num_gt = pos_gt.shape[0]
            num_step = pos_gt.shape[1]
            num_cls = pos_pred.shape[-1]

            # Pairwise position cost: every prediction scored against every gt label.
            cost_pos = F.cross_entropy(
                pos_pred[:, None].repeat(1, num_gt, 1, 1).reshape(-1, num_cls),
                pos_gt[None, ...].repeat(num_pred, 1, 1).reshape(-1),
                label_smoothing=0.1, ignore_index=-1, reduction='none',
            ).reshape(num_pred, num_gt, num_step)
            cost_shape = ((shape_pred[:, None] - shape_gt[None, ...]) ** 2).sum(-1)
            C = self.cost_pos * cost_pos + self.cost_shape * cost_shape
            # An explicit step count keeps the reshape valid when a graph has no
            # predictions or no ground truth (`-1` is ambiguous for empty tensors).
            C = C.reshape(num_pred, num_gt, num_step).cpu().numpy()
            if valid_mask is not None:
                # The seed set can be smaller than the largest number of gt agents
                # over all steps; penalise invalid gt entries so they are avoided.
                gt_valid = valid_mask[start_gt:end_gt].cpu().numpy().astype(np.bool_)
                C[:, ~gt_valid] = 1 << 15

            # One independent assignment per time step.
            assignments = [linear_sum_assignment(C[..., t]) for t in range(num_step)]
            rows = np.array([rows_t for rows_t, _ in assignments]) + start_pred  # (s, k)
            cols = np.array([cols_t for _, cols_t in assignments]) + start_gt  # (s, k)
            pred_indices.append(torch.as_tensor(rows, dtype=torch.long).transpose(-1, -2))
            gt_indices.append(torch.as_tensor(cols, dtype=torch.long).transpose(-1, -2))
        return torch.cat(pred_indices), torch.cat(gt_indices)

    def __repr__(self):
        head = "Matcher " + self.__class__.__name__
        body = [
            "cost_state: {}".format(self.cost_state),
            "cost_pos: {}".format(self.cost_pos),
            "cost_head: {}".format(self.cost_head),
            "cost_shape: {}".format(self.cost_shape),
        ]
        _repr_indent = 4
        lines = [head] + [" " * _repr_indent + line for line in body]
        return "\n".join(lines)
class SMARTAgentDecoder(nn.Module):
def __init__(self,
             dataset: str,
             input_dim: int,
             hidden_dim: int,
             num_historical_steps: int,
             time_span: Optional[int],
             pl2a_radius: float,
             pl2seed_radius: float,
             a2a_radius: float,
             a2sa_radius: float,
             pl2sa_radius: float,
             num_freq_bands: int,
             num_layers: int,
             num_heads: int,
             head_dim: int,
             dropout: float,
             token_size: int,
             attr_tokenizer: Attr_Tokenizer=None,
             predict_motion: bool=False,
             predict_state: bool=False,
             predict_map: bool=False,
             predict_occ: bool=False,
             state_token: Dict[str, int]=None,
             use_grid_token: bool=True,
             use_head_token: bool=True,
             use_state_token: bool=True,
             disable_insertion: bool=False,
             seed_size: int=5,
             buffer_size: int=32,
             num_recurrent_steps_val: int=-1,
             loss_weight: dict=None,
             logger=None) -> None:
    """Build the embeddings, attention stacks and prediction heads of the agent decoder.

    Most arguments are stored verbatim as attributes. Notable ones:

    Args:
        input_dim: coordinate width; presumably 2 (x, y) — TODO confirm.
        time_span: temporal attention window; ``None`` falls back to
            ``num_historical_steps``.
        pl2a_radius / pl2seed_radius / a2a_radius / a2sa_radius / pl2sa_radius:
            neighbourhood radii used when building the graph edges.
        num_layers: depth of the temporal / map->agent / agent<->agent attention stacks.
        token_size: number of motion tokens (output width of the token head).
        attr_tokenizer: must expose ``grid_size`` and ``angle_size`` (read here);
            effectively required despite the ``None`` default.
        state_token: mapping that must contain ``'invalid'``, ``'valid'``,
            ``'enter'`` and ``'exit'``; effectively required despite the default.
        use_grid_token: if True a grid-token embedding is fused into the agent
            feature and grid-classification heads are created; otherwise an
            xy-regression head is used.
        use_head_token: choose a heading classification head (``angle_size``
            classes) instead of a scalar heading regression head.
        predict_occ: create the occupancy / grid-index heads.
        num_recurrent_steps_val: ``-1`` or a multiple of ``self.shift`` (5).
        loss_weight: forwarded to :class:`HungarianMatcher`, which requires its keys.
    """
    super(SMARTAgentDecoder, self).__init__()
    self.dataset = dataset
    self.input_dim = input_dim
    self.hidden_dim = hidden_dim
    self.num_historical_steps = num_historical_steps
    self.time_span = time_span if time_span is not None else num_historical_steps
    self.pl2a_radius = pl2a_radius
    self.pl2seed_radius = pl2seed_radius
    self.a2a_radius = a2a_radius
    self.a2sa_radius = a2sa_radius
    self.pl2sa_radius = pl2sa_radius
    self.num_freq_bands = num_freq_bands
    self.num_layers = num_layers
    self.num_heads = num_heads
    self.head_dim = head_dim
    self.dropout = dropout
    self.predict_motion = predict_motion
    self.predict_state = predict_state
    self.predict_map = predict_map
    self.predict_occ = predict_occ
    self.use_grid_token = use_grid_token
    self.use_head_token = use_head_token
    self.use_state_token = use_state_token
    self.disable_insertion = disable_insertion
    self.num_recurrent_steps_val = num_recurrent_steps_val
    self.loss_weight = loss_weight
    self.logger = logger
    self.attr_tokenizer = attr_tokenizer
    # state tokens
    # NOTE(review): `state_token` defaults to None but is dereferenced
    # unconditionally here, so callers must always pass it.
    self.state_type = list(state_token.keys())
    self.state_token = state_token
    self.invalid_state = int(state_token['invalid'])
    self.valid_state = int(state_token['valid'])
    self.enter_state = int(state_token['enter'])
    self.exit_state = int(state_token['exit'])
    # Class vocabularies of the seed-state head and the state head; their
    # lengths set the output widths of those heads below.
    self.seed_state_type = ['invalid', 'enter']
    self.valid_state_type = ['invalid', 'valid', 'exit']
    # Input widths of the Fourier / MLP embeddings created below.
    input_dim_x_a = 2
    input_dim_r_t = 4
    input_dim_r_pt2a = 3
    input_dim_r_pt2sa = 3
    input_dim_r_a2a = 3
    input_dim_r_a2sa = 3#4
    input_dim_motion_token = 8 # tokens: (token_size, 4, 2)
    input_dim_offset_token = 2
    self.seed_size = seed_size
    self.buffer_size = buffer_size
    # self.agent_type = ['veh', 'ped', 'cyc', 'seed']
    self.type_a_emb = nn.Embedding(len(AGENT_TYPE), hidden_dim)
    self.shape_emb = MLPEmbedding(input_dim=3, hidden_dim=hidden_dim)
    self.state_a_emb = nn.Embedding(len(self.state_type), hidden_dim)
    # Sentinel values written into features at invalid steps / state transitions.
    self.motion_gap = 1.
    self.heading_gap = 1.
    self.invalid_shape_value = .1
    self.invalid_motion_value = -2.
    self.invalid_head_value = -2.
    # Fourier embeddings of per-agent features and of relative poses on each edge type.
    self.x_a_emb = FourierEmbedding(input_dim=input_dim_x_a, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands)
    self.r_t_emb = FourierEmbedding(input_dim=input_dim_r_t, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands)
    self.r_pt2a_emb = FourierEmbedding(input_dim=input_dim_r_pt2a, hidden_dim=hidden_dim,
                                       num_freq_bands=num_freq_bands)
    self.r_a2a_emb = FourierEmbedding(input_dim=input_dim_r_a2a, hidden_dim=hidden_dim,
                                      num_freq_bands=num_freq_bands)
    # self.r_sa2sa_emb = FourierEmbedding(input_dim=input_dim_r_t, hidden_dim=hidden_dim,
    # num_freq_bands=num_freq_bands)
    self.r_pt2sa_emb = FourierEmbedding(input_dim=input_dim_r_pt2sa, hidden_dim=hidden_dim,
                                        num_freq_bands=num_freq_bands)
    self.r_a2sa_emb = FourierEmbedding(input_dim=input_dim_r_a2sa, hidden_dim=hidden_dim,
                                       num_freq_bands=num_freq_bands)
    # Per-agent-type motion-token embeddings plus the grid-offset token embedding.
    self.token_emb_veh = MLPEmbedding(input_dim=input_dim_motion_token, hidden_dim=hidden_dim)
    self.token_emb_ped = MLPEmbedding(input_dim=input_dim_motion_token, hidden_dim=hidden_dim)
    self.token_emb_cyc = MLPEmbedding(input_dim=input_dim_motion_token, hidden_dim=hidden_dim)
    self.token_emb_grid = MLPEmbedding(input_dim=input_dim_offset_token, hidden_dim=hidden_dim)
    # Learned placeholders: "no token", begin-of-sequence, and invalid grid offset.
    self.no_token_emb = nn.Embedding(1, hidden_dim)
    self.bos_token_emb = nn.Embedding(1, hidden_dim)
    self.invalid_offset_token_emb = nn.Embedding(1, hidden_dim)
    # Number of hidden_dim-wide embeddings concatenated before fusion
    # (token, x_a, state, and optionally grid — see `_build_agent_feature`).
    if self.use_grid_token:
        num_inputs = 4
    else:
        num_inputs = 3
    self.fusion_emb = MLPEmbedding(input_dim=self.hidden_dim * num_inputs, hidden_dim=self.hidden_dim)
    # Main decoder stacks: temporal, map->agent (bipartite) and agent<->agent attention.
    self.t_attn_layers = nn.ModuleList(
        [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
                        bipartite=False, has_pos_emb=True) for _ in range(num_layers)]
    )
    self.pt2a_attn_layers = nn.ModuleList(
        [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
                        bipartite=True, has_pos_emb=True) for _ in range(num_layers)]
    )
    self.a2a_attn_layers = nn.ModuleList(
        [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
                        bipartite=False, has_pos_emb=True) for _ in range(num_layers)]
    )
    # FIXME: for test!
    # Depth of the seed-agent stacks is hard-coded rather than taken from num_layers.
    self.seed_layers = 3
    # self.sa2sa_attn_layers = nn.ModuleList(
    # [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
    # bipartite=False, has_pos_emb=True) for _ in range(self.seed_layers)]
    # )
    self.pt2sa_attn_layers = nn.ModuleList(
        [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
                        bipartite=True, has_pos_emb=True) for _ in range(self.seed_layers)]
    )
    self.a2sa_attn_layers = nn.ModuleList(
        [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
                        bipartite=False, has_pos_emb=True) for _ in range(self.seed_layers)]
    )
    self.occ2sa_attn_layers = nn.ModuleList(
        [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
                        bipartite=True, has_pos_emb=False) for _ in range(self.seed_layers)]
    )
    self.token_size = token_size # 2048
    # agent motion prediction head
    self.token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
                                       output_dim=self.token_size)
    # agent state prediction head
    self.state_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
                                       output_dim=len(self.valid_state_type))
    # Seed-agent heads: state, type, shape (3 values) and pose.
    self.seed_state_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
                                            output_dim=len(self.seed_state_type))
    # presumably excludes the 'seed' entry of AGENT_TYPE — TODO confirm
    self.seed_type_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
                                           output_dim=len(AGENT_TYPE) - 1)
    self.seed_shape_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
                                            output_dim=3)
    # NOTE(review): attr_tokenizer defaults to None but is dereferenced here.
    self.grid_size = self.attr_tokenizer.grid_size
    self.angle_size = self.attr_tokenizer.angle_size
    if self.use_grid_token:
        # position as a grid-cell class plus an xy offset
        self.seed_pos_rel_token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
                                                        output_dim=self.grid_size)
        self.seed_offset_xy_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
                                                    output_dim=2)
        self.seed_agent_occ_embed = MLPLayer(input_dim=self.grid_size, hidden_dim=hidden_dim,
                                             output_dim=hidden_dim)
    else:
        # position regressed directly as xy
        self.seed_pos_rel_xy_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
                                                     output_dim=2)
    if self.use_head_token:
        self.seed_heading_rel_token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
                                                            output_dim=self.angle_size)
    else:
        self.seed_heading_rel_theta_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
                                                            output_dim=1)
    # self.seed_pt_occ_embed = MLPLayer(input_dim=self.grid_size, hidden_dim=hidden_dim,
    # output_dim=hidden_dim)
    if self.predict_occ:
        self.grid_agent_occ_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
                                            output_dim=self.grid_size)
        self.grid_pt_occ_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
                                         output_dim=self.grid_size)
        self.grid_index_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
                                        output_dim=self.grid_size)
    # self.num_seed_feature = 1
    # self.num_seed_feature = self.seed_size
    # Number of seed-agent rows appended per graph (hard-coded).
    self.num_seed_feature = 10
    # self.trajectory_token = token_data['token'] # dict('veh', 'ped', 'cyc') (2048, 4, 2)
    # self.trajectory_token_traj = token_data['traj'] # (2048, 6, 3)
    # self.trajectory_token_all = token_data['token_all'] # (2048, 6, 4, 2)
    self.apply(weight_init)
    # Hard-coded decoding / masking settings.
    self.shift = 5
    self.motion_beam_size = 5
    self.insert_beam_size = 10
    self.hist_mask = True
    self.temporal_attn_to_invalid = False
    self.use_rel = False
    self.inference_filter_overlap = True
    assert self.num_recurrent_steps_val % self.shift == 0 or self.num_recurrent_steps_val == -1, \
        f"Invalid num_recurrent_steps_val: {num_recurrent_steps_val}."
    # seed agent
    self.temporal_attn_seed = False
    self.seed_attn_to_av = True
    self.seed_use_ego_motion = False
    self.matcher = HungarianMatcher(loss_weight=loss_weight, enter_state=self.enter_state)
def transform_rel(self, token_traj, prev_pos, prev_heading=None):
    """Rotate token offsets by the previous heading and anchor them at the last position.

    Args:
        token_traj: token points ``(num_agent, num_step, traj_num, 2)``; the
            ``view(-1, traj_num, 2)`` below requires a last dimension of 2.
        prev_pos: preceding positions ``(num_agent, num_step, k, 2)``; index ``-1``
            is the anchor point (``k >= 2`` is needed when ``prev_heading`` is None).
        prev_heading: optional headings ``(num_agent, num_step)`` in radians; when
            omitted they are estimated from the last two points of ``prev_pos``.

    Returns:
        Tensor shaped like ``token_traj``: every row vector right-multiplied by
        ``[[cos, -sin], [sin, cos]]`` and offset by ``prev_pos[:, :, -1]``.
        NOTE(review): for row vectors this matrix rotates by ``-heading``; confirm
        it matches the token frame convention.
    """
    if prev_heading is None:
        last_step = prev_pos[:, :, -1, :] - prev_pos[:, :, -2, :]
        prev_heading = torch.arctan2(last_step[:, :, 1], last_step[:, :, 0])
    n_agent, n_step, n_point, point_dim = token_traj.shape
    c = prev_heading.cos()
    s = prev_heading.sin()
    rot = torch.zeros((n_agent, n_step, 2, 2), device=prev_heading.device)
    rot[..., 0, 0] = c
    rot[..., 1, 1] = c
    rot[..., 0, 1] = -s
    rot[..., 1, 0] = s
    flat_points = token_traj.view(-1, n_point, 2)
    rotated = torch.bmm(flat_points, rot.view(-1, 2, 2))
    anchor = prev_pos[:, :, -1:, :]
    return rotated.view(n_agent, n_step, n_point, point_dim) + anchor
def _agent_token_embedding(self, data, agent_token_index, agent_state, agent_offset_token_idx, pos_a, head_a,
                           inference=False, filter_mask=None, av_index=None):
    """Embed agent motion tokens, grid-offset tokens and attributes into per-step features.

    Args:
        data: batch providing ``data['agent']`` fields (``type``, ``shape``,
            ``trajectory_token_veh/ped/cyc``), ``data['batch_size_a']`` and
            ``data.num_graphs``.
        agent_token_index: ``(num_agent, num_step)`` indices into the per-type token
            tables; after the rows appended below, ``-2`` selects the bos embedding
            and ``-1`` the no-token (invalid) embedding.
        agent_state: ``(num_agent, num_step)`` state ids.
        agent_offset_token_idx: indices into the grid-token embedding table, whose
            last row is the invalid-offset embedding.
        pos_a: ``(num_agent, num_step, 2)`` positions.
        head_a: headings, presumably ``(num_agent, num_step)`` radians — TODO confirm.
        inference: if True, also return the embedding tables; otherwise append
            seed-agent features.
        filter_mask: per-agent boolean selection applied to ``data['agent']`` fields;
            defaults to all True.
        av_index: index of the reference (ego) agent per graph. It appears to be
            required, since it is used unconditionally to compute ``offset_pos``.

    Returns:
        Training: ``feat_a`` of shape
        ``(num_agent + data.num_graphs * num_seed_feature, num_step, hidden_dim)``.
        Inference: ``(feat_a, agent_token_emb, agent_token_emb_veh, agent_token_emb_ped,
        agent_token_emb_cyc, categorical_embs, trajectory_token_veh,
        trajectory_token_ped, trajectory_token_cyc)``.

    Side effects:
        Assigns ``self.grid_token_emb``, which ``_build_agent_feature`` reads as its
        default grid embedding.
    """
    if filter_mask is None:
        filter_mask = torch.ones_like(agent_state[:, 2], dtype=torch.bool)
    num_agent, num_step, traj_dim = pos_a.shape # traj_dim=2
    agent_type = data['agent']['type'][filter_mask]
    veh_mask = agent_type == 0
    ped_mask = agent_type == 1
    cyc_mask = agent_type == 2
    motion_vector_a, head_vector_a = self._build_vector_a(pos_a, head_a, agent_state)
    # Embed the last frame of every motion token, one table per agent type.
    trajectory_token_veh = data['agent']['trajectory_token_veh'] # [n_token, 6, 4, 2]
    trajectory_token_ped = data['agent']['trajectory_token_ped']
    trajectory_token_cyc = data['agent']['trajectory_token_cyc']
    agent_token_emb_veh = self.token_emb_veh(trajectory_token_veh[:, -1].flatten(1, 2))
    agent_token_emb_ped = self.token_emb_ped(trajectory_token_ped[:, -1].flatten(1, 2))
    agent_token_emb_cyc = self.token_emb_cyc(trajectory_token_cyc[:, -1].flatten(1, 2))
    # add bos token embedding
    agent_token_emb_veh = torch.cat([agent_token_emb_veh, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())])
    agent_token_emb_ped = torch.cat([agent_token_emb_ped, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())])
    agent_token_emb_cyc = torch.cat([agent_token_emb_cyc, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())])
    # add invalid token embedding
    agent_token_emb_veh = torch.cat([agent_token_emb_veh, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())])
    agent_token_emb_ped = torch.cat([agent_token_emb_ped, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())])
    agent_token_emb_cyc = torch.cat([agent_token_emb_cyc, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())])
    # additional token embeddings are already added -> -1: invalid, -2: bos
    agent_token_emb = torch.zeros((num_agent, num_step, self.hidden_dim), device=pos_a.device)
    agent_token_emb[veh_mask] = agent_token_emb_veh[agent_token_index[veh_mask]]
    agent_token_emb[ped_mask] = agent_token_emb_ped[agent_token_index[ped_mask]]
    agent_token_emb[cyc_mask] = agent_token_emb_cyc[agent_token_index[cyc_mask]]
    # grid embedding (stored on self; the last row is the invalid-offset embedding)
    self.grid_token_emb = self.token_emb_grid(self.attr_tokenizer.grid)
    self.grid_token_emb = torch.cat([self.grid_token_emb, self.invalid_offset_token_emb(torch.zeros(1, device=pos_a.device).long())])
    offset_token_emb = self.grid_token_emb[agent_offset_token_idx]
    # 'vehicle', 'pedestrian', 'cyclist', 'background'
    # Invalid steps are embedded with the 'seed' type and the placeholder shape.
    is_invalid = agent_state == self.invalid_state
    agent_types = data['agent']['type'].clone()[filter_mask].long().repeat_interleave(repeats=num_step, dim=0)
    agent_types[is_invalid.reshape(-1)] = AGENT_TYPE.index('seed')
    agent_shapes = data['agent']['shape'].clone()[filter_mask, self.num_historical_steps - 1, :].repeat_interleave(repeats=num_step, dim=0)
    agent_shapes[is_invalid.reshape(-1)] = self.invalid_shape_value
    # TODO: fix ego_pos in inference mode
    # Position relative to the graph's reference agent (passed on, but not
    # currently part of the feature built by `_build_agent_feature`).
    offset_pos = pos_a - pos_a[av_index].repeat_interleave(repeats=data['batch_size_a'], dim=0)
    feat_a, categorical_embs = self._build_agent_feature(num_step, pos_a.device,
                                                         motion_vector_a,
                                                         head_vector_a,
                                                         agent_token_emb,
                                                         offset_token_emb,
                                                         offset_pos=offset_pos,
                                                         type=agent_types,
                                                         shape=agent_shapes,
                                                         state=agent_state,
                                                         n=num_agent)
    if inference:
        return (
            feat_a,
            agent_token_emb,
            agent_token_emb_veh,
            agent_token_emb_ped,
            agent_token_emb_cyc,
            categorical_embs,
            trajectory_token_veh,
            trajectory_token_ped,
            trajectory_token_cyc,
        )
    else:
        # seed agent feature
        if self.seed_use_ego_motion:
            motion_vector_seed = motion_vector_a[av_index].repeat_interleave(repeats=self.num_seed_feature, dim=0)
            head_vector_seed = head_vector_a[av_index].repeat_interleave(repeats=self.num_seed_feature, dim=0)
        else:
            motion_vector_seed = head_vector_seed = None
        feat_seed, _ = self._build_agent_feature(num_step, pos_a.device,
                                                 motion_vector_seed,
                                                 head_vector_seed,
                                                 state_index=self.invalid_state,
                                                 n=data.num_graphs * self.num_seed_feature)
        feat_a = torch.cat([feat_a, feat_seed], dim=0) # (a + s, t, d)
        return feat_a
def _build_vector_a(self, pos_a, head_a, state_a):
    """Build per-step motion vectors and heading unit vectors for agents.

    Args:
        pos_a: positions ``(num_agent, num_step, dim)``.
        head_a: headings in radians ``(num_agent, num_step)``. Not modified.
        state_a: per-step state ids ``(num_agent, num_step)``.

    Returns:
        ``(motion_vector_a, head_vector_a)``.

        ``motion_vector_a`` has the shape of ``pos_a`` and holds the displacement
        from the previous step (zeros at step 0), overwritten with sentinels:
        ``invalid_motion_value`` on invalid steps, ``+motion_gap`` on the first
        non-invalid step after an invalid one (or on an ``enter`` step at index 0),
        and ``-motion_gap`` on the first invalid step after a non-invalid one.

        ``head_vector_a`` is ``(num_agent, num_step, 2)`` = ``(cos, sin)`` of the
        heading, where invalid steps use ``invalid_head_value`` as heading.
    """
    # Displacement from the previous step; step 0 has no predecessor.
    motion_vector_a = torch.cat([torch.zeros_like(pos_a[:, :1]),
                                 pos_a[:, 1:] - pos_a[:, :-1]], dim=1)
    is_invalid = state_a == self.invalid_state
    motion_vector_a[is_invalid] = self.invalid_motion_value
    # `roll` wraps the last step around to index 0, so column 0 is patched below.
    prev_state = state_a.roll(shifts=1, dims=1)
    # invalid -> valid
    is_last_invalid = (prev_state == self.invalid_state) & ~is_invalid
    is_last_invalid[:, 0] = state_a[:, 0] == self.enter_state
    motion_vector_a[is_last_invalid] = self.motion_gap
    # valid -> invalid
    is_last_valid = (prev_state != self.invalid_state) & is_invalid
    is_last_valid[:, 0] = False
    motion_vector_a[is_last_valid] = -self.motion_gap
    # Invalid steps get a sentinel heading. This is done out-of-place so the
    # caller's `head_a` is left untouched.
    head_a = head_a.masked_fill(is_invalid, self.invalid_head_value)
    head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1)
    return motion_vector_a, head_vector_a
def _build_agent_feature(self, num_step, device,
                         motion_vector=None,
                         head_vector=None,
                         agent_token_emb=None,
                         agent_grid_emb=None,
                         offset_pos=None,
                         type=None,
                         shape=None,
                         categorical_embs_a=None,
                         state=None,
                         state_index=None,
                         n=1):
    """Fuse motion, categorical, token, state (and grid) embeddings into per-step features.

    Any input left as None is replaced by a placeholder. This is how seed-agent
    features are built from nothing but ``n`` and ``num_step``.

    Args:
        num_step: number of time steps.
        device: device for placeholder tensors.
        motion_vector / head_vector: per-step motion and heading vectors as returned
            by ``_build_vector_a``. If either is None, both are rebuilt from all-zero
            positions/headings and ``state``; ``state`` is first filled with
            ``invalid_state`` if it is None.
        agent_token_emb: ``(n, num_step, hidden_dim)``. Defaults to the no-token
            embedding, with the bos embedding at steps where ``state == enter_state``.
        agent_grid_emb: defaults to row ``grid_size // 2`` of ``self.grid_token_emb``
            (that attribute is assigned in ``_agent_token_embedding``).
        offset_pos: defaulted but not used; its feature term is commented out.
        type / shape: agent type ids and sizes. They default to the 'seed' type and
            ``invalid_shape_value``, and are ignored when ``categorical_embs_a`` is given.
            NOTE: ``type`` shadows the builtin but is part of the keyword interface.
        categorical_embs_a: precomputed categorical embeddings to reuse.
        state: per-step state ids ``(n, num_step)``.
        state_index: state id used for all steps, only if ``state`` is still None
            when the state embedding is computed.
        n: number of rows (agents).

    Returns:
        ``(feat_a, categorical_embs_a)``, where ``feat_a`` is the fused
        ``(n, num_step, hidden_dim)`` feature.
    """
    if agent_token_emb is None:
        agent_token_emb = self.no_token_emb(torch.zeros(1, device=device).long())[:, None].repeat(n, num_step, 1)
        if state is not None:
            agent_token_emb[state == self.enter_state] = self.bos_token_emb(torch.zeros(1, device=device).long())
    if agent_grid_emb is None:
        agent_grid_emb = self.grid_token_emb[None, None, self.grid_size // 2].repeat(n, num_step, 1)
    if motion_vector is None or head_vector is None:
        pos_a = torch.zeros((n, num_step, 2), device=device)
        head_a = torch.zeros((n, num_step), device=device)
        if state is None:
            state = torch.full((n, num_step), self.invalid_state, device=device)
        motion_vector, head_vector = self._build_vector_a(pos_a, head_a, state)
    if offset_pos is None:
        offset_pos = torch.zeros_like(motion_vector)
    # Continuous per-step features: speed and motion direction relative to heading.
    feature_a = torch.stack(
        [torch.norm(motion_vector[:, :, :2], p=2, dim=-1),
         angle_between_2d_vectors(ctr_vector=head_vector, nbr_vector=motion_vector[:, :, :2]),
         # torch.norm(offset_pos[:, :, :2], p=2, dim=-1),
         ], dim=-1)
    if categorical_embs_a is None:
        if type is None:
            type = torch.tensor([AGENT_TYPE.index('seed')], device=device)
        if shape is None:
            shape = torch.full((1, 3), self.invalid_shape_value, device=device)
        categorical_embs_a = [self.type_a_emb(type.reshape(-1)), self.shape_emb(shape.reshape(-1, shape.shape[-1]))]
    x_a = self.x_a_emb(continuous_inputs=feature_a.view(-1, feature_a.size(-1)),
                       categorical_embs=categorical_embs_a)
    x_a = x_a.view(-1, num_step, self.hidden_dim) # (a, t, d)
    if state is None:
        assert state_index is not None, f"state index need to be set when state tensor is None!"
        state = torch.tensor([state_index], device=device)[:, None].repeat(n, num_step, 1) # do not use `expand`
    s_a = self.state_a_emb(state.reshape(-1).long()).reshape(n, num_step, self.hidden_dim)
    # Concatenate token, attribute, state (and grid) embeddings, then fuse.
    feat_a = torch.cat((agent_token_emb, x_a, s_a), dim=-1)
    if self.use_grid_token:
        feat_a = torch.cat([feat_a, agent_grid_emb], dim=-1)
    feat_a = self.fusion_emb(feat_a) # (a, t, d)
    return feat_a, categorical_embs_a
def _pad_feat(self, num_graph, av_index, *feats, num_seed_feature=None):
    """Append seed-agent rows to each feature tensor and return a padding mask.

    For each tensor in ``feats``, ``num_seed_feature`` copies of every row selected
    by ``av_index`` are appended after the original rows.

    Args:
        num_graph: number of graphs; the last ``num_graph * num_seed_feature`` rows
            are marked as padding. Presumably ``len(av_index) == num_graph`` — TODO confirm.
        av_index: row index of the anchor agent per graph.
        *feats: tensors sharing the leading agent dimension; the first must be at
            least 2-D ``(agents, steps, ...)``.
        num_seed_feature: copies per anchor row; defaults to ``self.num_seed_feature``.

    Returns:
        The padded tensors followed by ``pad_mask`` of shape ``(num_rows, num_step)``:
        True for original rows, False for appended seed rows.
    """
    if num_seed_feature is None:
        num_seed_feature = self.num_seed_feature
    padded_feats = tuple(
        torch.cat([feat, feat[av_index].repeat_interleave(repeats=num_seed_feature, dim=0)], dim=0)
        for feat in feats
    )
    pad_mask = torch.ones(*padded_feats[0].shape[:2], device=feats[0].device).bool()  # (a, t)
    # Slice from an explicit start row: `pad_mask[-0:]` would select every row
    # (marking all of them as padding) when nothing is appended.
    num_pad = num_graph * num_seed_feature
    pad_mask[max(pad_mask.shape[0] - num_pad, 0):] = False
    return padded_feats + (pad_mask,)
# def _build_seed_feat(self, data, pos_a, head_a, state_a, head_vector_a, mask, sort_indices, av_index):
# seed_mask = sort_indices != av_index.repeat_interleave(repeats=data['batch_size_a'], dim=0)[:, None]
# # TODO: fix batch_size!!!
# print(mask.shape, sort_indices.shape, seed_mask.shape)
# mask[-data.num_graphs * self.num_seed_feature:] = seed_mask[:self.num_seed_feature]
# insert_pos_a = torch.gather(pos_a, dim=0, index=sort_indices[:self.num_seed_feature, :, None].expand(-1, -1, pos_a.shape[-1]))
# pos_a[mask] = insert_pos_a[mask[-self.num_seed_feature:]]
# state_a[-data.num_graphs * self.num_seed_feature:] = self.enter_state
# return pos_a, head_a, state_a, head_vector_a, mask
def _build_temporal_edge(self, data, pos_a, head_a, state_a, head_vector_a, mask, inference_mask=None):
    """Build temporal (same agent, earlier step -> later step) edges and their embeddings.

    Node indices are flattened agent-major (``agent * num_step + step``). An edge
    ``src -> dst`` links two steps of one agent with
    ``0 < dst - src <= time_span / shift``.

    Args:
        data: batch object; only ``data.num_graphs`` is read.
        pos_a / head_a / head_vector_a: per-step positions
            ``(num_agent, num_step, input_dim)``, headings ``(num_agent, num_step)``
            and heading unit vectors ``(num_agent, num_step, 2)``.
        state_a: per-step state ids ``(num_agent, num_step)``.
        mask: per-step validity ``(num_agent, num_step)``; not modified.
        inference_mask: optional ``(num_agent, num_step)`` mask restricting the
            destination steps. Its seed rows (the last
            ``num_graphs * num_seed_feature``) are set to False in place.

    Returns:
        ``(edge_index_t, r_t)``. ``r_t`` embeds
        ``[|rel_pos|, angle(head[dst], rel_pos), rel_head, src - dst]``, where
        invalid endpoints are replaced by gap / invalid sentinel values.

    Raises:
        RuntimeError: if ``self.temporal_attn_seed`` is enabled (unsupported).
    """
    num_graph = data.num_graphs
    num_agent = pos_a.shape[0]
    hist_mask = mask.clone()
    if not self.temporal_attn_to_invalid:
        # Never attend to steps before an agent's first `enter` step.
        is_bos = state_a == self.enter_state
        bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
        history_invalid_mask = torch.arange(mask.shape[1]).expand(mask.shape[0], mask.shape[1]).to(mask.device)
        history_invalid_mask = (history_invalid_mask < bos_index[:, None])
        hist_mask[history_invalid_mask] = False
    if not self.temporal_attn_seed:
        # Seed rows (appended last) take no part in temporal attention.
        hist_mask[-num_graph * self.num_seed_feature:] = False
        if inference_mask is not None:
            inference_mask[-num_graph * self.num_seed_feature:] = False
    else:
        # WARNING: if use temporal attn to seed
        # we need to fix the pos/head of seed!!!
        raise RuntimeError("Wrong settings!")
    pos_t = pos_a.reshape(-1, self.input_dim)  # (num_agent * num_step, ...)
    head_t = head_a.reshape(-1)
    head_vector_t = head_vector_a.reshape(-1, 2)
    # for those invalid agents won't predict any motion token, we don't attend to them
    is_bos = state_a == self.enter_state
    is_bos[-num_graph * self.num_seed_feature:] = False
    bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
    motion_predict_start_index = torch.clamp(bos_index - self.time_span / self.shift + 1, min=0)
    motion_predict_mask = torch.arange(hist_mask.shape[1]).expand(hist_mask.shape[0], -1).to(hist_mask.device)
    motion_predict_mask = motion_predict_mask >= motion_predict_start_index[:, None]
    hist_mask[~motion_predict_mask] = False
    if self.hist_mask and self.training:
        # Augmentation: randomly drop (up to) 10 steps per agent.
        hist_mask[
            torch.arange(mask.shape[0]).unsqueeze(1), torch.randint(0, mask.shape[1], (num_agent, 10))] = False
        mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)
    elif inference_mask is not None:
        mask_t = hist_mask.unsqueeze(2) & inference_mask.unsqueeze(1)
    else:
        mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)
    # mask_t: (num_agent, num_step, num_step), edge_index_t: (2, num_edge)
    edge_index_t = dense_to_sparse(mask_t)[0]
    # keep only forward-in-time edges inside the temporal window
    edge_index_t = edge_index_t[:, (edge_index_t[1] - edge_index_t[0] > 0) &
                                   (edge_index_t[1] - edge_index_t[0] <= self.time_span / self.shift)]
    rel_pos_t = pos_t[edge_index_t[0]] - pos_t[edge_index_t[1]]
    rel_head_t = wrap_angle(head_t[edge_index_t[0]] - head_t[edge_index_t[1]])
    # handle the invalid steps
    is_invalid = state_a == self.invalid_state
    is_invalid_t = is_invalid.reshape(-1)
    src_invalid = is_invalid_t[edge_index_t[0]]
    dst_invalid = is_invalid_t[edge_index_t[1]]
    # invalid -> valid
    rel_pos_t[src_invalid & ~dst_invalid] = -self.motion_gap
    rel_head_t[src_invalid & ~dst_invalid] = -self.heading_gap
    # valid -> invalid: position and heading use the same src/dst mask
    rel_pos_t[~src_invalid & dst_invalid] = self.motion_gap
    rel_head_t[~src_invalid & dst_invalid] = self.heading_gap
    # invalid -> invalid
    rel_pos_t[src_invalid & dst_invalid] = self.invalid_motion_value
    rel_head_t[src_invalid & dst_invalid] = self.invalid_head_value
    r_t = torch.stack(
        [torch.norm(rel_pos_t[:, :2], p=2, dim=-1),
         angle_between_2d_vectors(ctr_vector=head_vector_t[edge_index_t[1]], nbr_vector=rel_pos_t[:, :2]),
         rel_head_t,
         edge_index_t[0] - edge_index_t[1]], dim=-1)
    r_t = self.r_t_emb(continuous_inputs=r_t, categorical_embs=None)
    return edge_index_t, r_t
def _build_interaction_edge(self, data, pos_a, head_a, state_a, head_vector_a, batch_s, mask, pad_mask=None, inference_mask=None,
                            av_index=None, seq_mask=None, seq_index=None, grid_index_a=None, **plot_kwargs):
    """Build agent<->agent edges (plus seed->agent edges when training) with embeddings.

    Nodes are flattened step-major (``step * num_agent + agent``). Agent pairs within
    ``a2a_radius`` are connected via ``radius_graph`` over ``batch_s``. ``batch_s`` is
    presumably distinct per (graph, step) so edges stay within one step — TODO confirm.
    Only edges whose endpoints pass ``mask`` (and ``inference_mask``) and ``pad_mask``
    are kept.

    Args:
        pad_mask: True for real agents, False for appended seed rows; defaults to all True.
        inference_mask: if given, the call is treated as inference and no seed edges
            are added; if None, seed edges from ``_build_a2sa_edge`` are appended.
        av_index, seq_mask, seq_index, grid_index_a: forwarded to the seed-edge builder
            and/or plotting.
        **plot_kwargs: forwarded to ``plot_interact_edge`` when the ``PLOT_EDGE``
            environment variable is a non-zero integer.

    Returns:
        Inference: ``(edge_index_a2a, r_a2a)``.
        Training: ``(edge_index_a2a, r_a2a, (num_total_edges, num_seed_edges))``.
        ``r_a2a`` embeds ``[|rel_pos|, angle(head[dst], rel_pos), rel_head]``.
    """
    _, num_step, _ = pos_a.shape
    is_training = inference_mask is None
    plot_edges = int(os.getenv('PLOT_EDGE', 0))
    mask_a = mask.clone()
    if pad_mask is None:
        pad_mask = torch.ones_like(state_a).bool()
    # Step-major flattening: node index = step * num_agent + agent.
    pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
    head_s = head_a.transpose(0, 1).reshape(-1)
    head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)
    pad_mask_s = pad_mask.transpose(0, 1).reshape(-1)
    if inference_mask is not None:
        mask_a = mask_a & inference_mask
    mask_s = mask_a.transpose(0, 1).reshape(-1)
    # build agent2agent bilateral connection
    edge_index_a2a = radius_graph(x=pos_s[:, :2], r=self.a2a_radius, batch=batch_s, loop=False,
                                  max_num_neighbors=300)
    edge_index_a2a = subgraph(subset=mask_s & pad_mask_s, edge_index=edge_index_a2a)[0]
    if plot_edges:
        plot_interact_edge(edge_index_a2a, data['scenario_id'], data['batch_size_a'].cpu(), self.num_seed_feature, num_step,
                           av_index=av_index, **plot_kwargs)
    rel_pos_a2a = pos_s[edge_index_a2a[0]] - pos_s[edge_index_a2a[1]]
    rel_head_a2a = wrap_angle(head_s[edge_index_a2a[0]] - head_s[edge_index_a2a[1]])
    # handle the invalid steps
    is_invalid = state_a == self.invalid_state
    is_invalid_s = is_invalid.transpose(0, 1).reshape(-1)
    src_invalid = is_invalid_s[edge_index_a2a[0]]
    dst_invalid = is_invalid_s[edge_index_a2a[1]]
    # invalid -> valid
    rel_pos_a2a[src_invalid & ~dst_invalid] = -self.motion_gap
    rel_head_a2a[src_invalid & ~dst_invalid] = -self.heading_gap
    # valid -> invalid: position and heading use the same src/dst mask
    rel_pos_a2a[~src_invalid & dst_invalid] = self.motion_gap
    rel_head_a2a[~src_invalid & dst_invalid] = self.heading_gap
    # invalid -> invalid
    rel_pos_a2a[src_invalid & dst_invalid] = self.invalid_motion_value
    rel_head_a2a[src_invalid & dst_invalid] = self.invalid_head_value
    r_a2a = torch.stack(
        [torch.norm(rel_pos_a2a[:, :2], p=2, dim=-1),
         angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_a2a[1]], nbr_vector=rel_pos_a2a[:, :2]),
         rel_head_a2a], dim=-1)
    r_a2a = self.r_a2a_emb(continuous_inputs=r_a2a, categorical_embs=None)
    # add the edges which connect seed agents
    if is_training:
        mask_av = torch.ones_like(mask_a).bool()
        if not self.seed_attn_to_av:
            mask_av[av_index] = False
        mask_a &= mask_av
        edge_index_seed2a, r_seed2a = self._build_a2sa_edge(data, pos_a, head_a, head_vector_a, batch_s,
                                                            mask_a.clone(), ~pad_mask.clone(), inference_mask=inference_mask,
                                                            r=self.pl2seed_radius, max_num_neighbors=300,
                                                            seq_mask=seq_mask, seq_index=seq_index, grid_index_a=grid_index_a, mode='insert')
        # Parsed as an integer (like the check above); the raw string "0" is truthy.
        if plot_edges:
            plot_interact_edge(edge_index_seed2a, data['scenario_id'], data['batch_size_a'].cpu(), self.num_seed_feature, num_step,
                               'interact_edge_map_seed', av_index=av_index, **plot_kwargs)
        edge_index_a2a = torch.cat([edge_index_a2a, edge_index_seed2a], dim=-1)
        r_a2a = torch.cat([r_a2a, r_seed2a])
        return edge_index_a2a, r_a2a, (edge_index_a2a.shape[1], edge_index_seed2a.shape[1])
    return edge_index_a2a, r_a2a
def _build_map2agent_edge(self, data, pos_a, head_a, state_a, head_vector_a, batch_s, batch_pl,
                          mask, pad_mask=None, inference_mask=None, av_index=None, **kwargs):
    """Build map-token -> agent edges (plus map -> seed edges when training) with embeddings.

    Agent nodes are flattened step-major (``step * num_agent + agent``). Map tokens
    (``data['pt_token']``) are tiled once per step so they follow the same layout.
    ``batch_pl`` is presumably aligned with that tiled layout — TODO confirm.

    Args:
        mask: per-step agent validity ``(num_agent, num_step)``; combined with
            ``inference_mask`` when given.
        pad_mask: True for real agents, False for appended seed rows; defaults to all True.
        inference_mask: if given, the call is treated as inference and no seed edges
            are added; if None, edges from ``_build_map2sa_edge`` are appended.
        av_index: forwarded to plotting.

    Returns:
        Inference: ``(edge_index_pl2a, r_pl2a)`` with ``edge_index_pl2a = [map, agent]``.
        Training: ``(edge_index_pl2a, r_pl2a, (num_total_edges, num_seed_edges))``.
        ``r_pl2a`` embeds ``[|rel_pos|, angle(head[agent], rel_pos), rel_orient]``.
    """
    _, num_step, _ = pos_a.shape
    is_training = inference_mask is None
    mask_pl2a = mask.clone()
    if pad_mask is None:
        pad_mask = torch.ones_like(state_a).bool()
    # Step-major flattening: node index = step * num_agent + agent.
    pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
    head_s = head_a.transpose(0, 1).reshape(-1)
    head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)
    pad_mask_s = pad_mask.transpose(0, 1).reshape(-1)
    if inference_mask is not None:
        mask_pl2a = mask_pl2a & inference_mask
    mask_pl2a = mask_pl2a.transpose(0, 1).reshape(-1)
    ori_pos_pl = data['pt_token']['position'][:, :self.input_dim].contiguous()
    ori_orient_pl = data['pt_token']['orientation'].contiguous()
    # Tile the whole map once per step (`repeat`, not `repeat_interleave`) so map
    # node indices follow the same step-major layout as the agent nodes.
    pos_pl = ori_pos_pl.repeat(num_step, 1)
    orient_pl = ori_orient_pl.repeat(num_step)
    # build map2agent directed graph: for every agent step (`y`) find up to 5 map
    # points (`x`) within range. `radius` returns [y_idx, x_idx], which is flipped
    # to [map, agent].
    edge_index_pl2a = radius(x=pos_pl[:, :2], y=pos_s[:, :2], r=self.pl2a_radius,
                             batch_x=batch_pl, batch_y=batch_s, max_num_neighbors=5)
    edge_index_pl2a = edge_index_pl2a[[1, 0]]
    edge_index_pl2a = edge_index_pl2a[:, mask_pl2a[edge_index_pl2a[1]] &
                                         pad_mask_s[edge_index_pl2a[1]]]
    rel_pos_pl2a = pos_pl[edge_index_pl2a[0]] - pos_s[edge_index_pl2a[1]]
    rel_orient_pl2a = wrap_angle(orient_pl[edge_index_pl2a[0]] - head_s[edge_index_pl2a[1]])
    # handle the invalid steps: invalid agents get gap sentinels instead of real poses
    is_invalid = state_a == self.invalid_state
    is_invalid_s = is_invalid.transpose(0, 1).reshape(-1)
    rel_pos_pl2a[is_invalid_s[edge_index_pl2a[1]]] = self.motion_gap
    rel_orient_pl2a[is_invalid_s[edge_index_pl2a[1]]] = self.heading_gap
    r_pl2a = torch.stack(
        [torch.norm(rel_pos_pl2a[:, :2], p=2, dim=-1),
         angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_pl2a[1]], nbr_vector=rel_pos_pl2a[:, :2]),
         rel_orient_pl2a], dim=-1)
    r_pl2a = self.r_pt2a_emb(continuous_inputs=r_pl2a, categorical_embs=None)
    # add the edges which connect seed agents
    if is_training:
        edge_index_pl2seed, r_pl2seed = self._build_map2sa_edge(data, pos_a, head_a, head_vector_a, batch_s, batch_pl,
                                                                ~pad_mask.clone(), inference_mask=inference_mask,
                                                                r=self.pl2seed_radius, max_num_neighbors=2048, mode='insert')
        # Parse PLOT_EDGE as an integer: the raw string "0" is truthy and would
        # otherwise enable plotting.
        if int(os.getenv('PLOT_EDGE', 0)):
            plot_interact_edge(edge_index_pl2seed, data['scenario_id'], data['batch_size_a'].cpu(), self.num_seed_feature, num_step,
                               'interact_edge_map_seed', av_index=av_index)
        edge_index_pl2a = torch.cat([edge_index_pl2a, edge_index_pl2seed], dim=-1)
        r_pl2a = torch.cat([r_pl2a, r_pl2seed])
        return edge_index_pl2a, r_pl2a, (edge_index_pl2a.shape[1], edge_index_pl2seed.shape[1])
    return edge_index_pl2a, r_pl2a
def _build_a2sa_edge(self, data, pos_a, head_a, head_vector_a, batch_s, mask_a, mask_sa,
                     inference_mask=None, r=None, max_num_neighbors=8, seq_mask=None, seq_index=None,
                     grid_index_a=None, mode: Literal['insert', 'feature']='feature', **plot_kwargs):
    """Build directed edges from (non-seed) agent nodes to seed-agent nodes within radius ``r``.

    Per-agent, per-step inputs are flattened step-major, i.e. flattened node
    ``t * num_agent + a`` is agent ``a`` at step ``t``.

    Args:
        data: batch object; only read for plotting (``data['scenario_id']``).
        pos_a, head_a, head_vector_a: agent position / heading / heading unit
            vector per agent and step; ``pos_a`` is unpacked as
            ``(num_agent, num_step, _)``.
        batch_s: graph id of every flattened node, forwarded to ``radius``.
        mask_a: nodes allowed to act as edge sources.
        mask_sa: nodes that are seed agents (edge targets); seed nodes are
            never used as sources.
        inference_mask: if given, AND-ed into both masks. Plotting (env var
            ``PLOT_EDGE_INFERENCE``) only happens when it is given.
        r: search radius (required).
        max_num_neighbors: forwarded to ``radius``.
        seq_mask: optional mask whose rows are indexed by the target's position
            among the seed nodes and whose columns are indexed by
            ``source % num_agent``; edges where it is False are dropped.
        seq_index: optional ``(num_agent,)`` or ``(num_agent, num_step)``
            sequence index; currently it is only validated for length.
        grid_index_a: unused.
        mode: ``'insert'`` embeds edge features with ``self.r_a2sa_emb``,
            ``'feature'`` with ``self.r_a2a_emb``.
        **plot_kwargs: forwarded to the plotting helper (must contain ``'tag'``
            when plotting is enabled).

    Returns:
        ``(edge_index, r)``: ``edge_index[0]`` holds global source node
        indices, ``edge_index[1]`` global seed node indices, and ``r`` the
        embedded (distance, bearing, relative heading) edge features.

    Raises:
        ValueError: if ``mode`` is not supported.
    """
    num_agent, num_step, _ = pos_a.shape
    is_training = inference_mask is None
    # step-major flattening: node index = step * num_agent + agent
    pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
    head_s = head_a.transpose(0, 1).reshape(-1)
    head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)
    if inference_mask is not None:
        mask_a = mask_a & inference_mask
        mask_sa = mask_sa & inference_mask
    mask_s = mask_a.transpose(0, 1).reshape(-1)
    mask_s_sa = mask_sa.transpose(0, 1).reshape(-1)

    # build seed_agent2agent unilateral connection
    assert r is not None, "r needs to be specified!"
    edge_index_a2sa = radius(x=pos_s[:, :2], y=pos_s[mask_s_sa, :2], r=r,
                             batch_x=batch_s, batch_y=batch_s[mask_s_sa], max_num_neighbors=max_num_neighbors)
    # after the flip, row 0 indexes all nodes (sources) and row 1 indexes the seed subset (targets)
    edge_index_a2sa = edge_index_a2sa[[1, 0]]
    # sources must be valid and must not be seed nodes themselves
    edge_index_a2sa = edge_index_a2sa[:, ~mask_s_sa[edge_index_a2sa[0]] & mask_s[edge_index_a2sa[0]]]

    # only for seed agent sequence training
    if seq_mask is not None:
        edge_mask = seq_mask[edge_index_a2sa[1]]
        edge_mask = torch.gather(edge_mask, dim=1, index=edge_index_a2sa[0, :, None] % num_agent)[:, 0]
        edge_index_a2sa = edge_index_a2sa[:, edge_mask]

    # seq_index is not used in the edge features at the moment; it is only checked for consistency
    if seq_index is None:
        seq_index = torch.zeros(num_agent, device=pos_a.device).long()
    if seq_index.dim() == 1:
        seq_index = seq_index[:, None].repeat(1, num_step)
    seq_index = seq_index.transpose(0, 1).reshape(-1)
    assert seq_index.shape[0] == pos_s.shape[0], f"Inconsistent length {seq_index.shape[0]} and {pos_s.shape[0]}!"

    # convert target indices from positions in the seed subset to global node indices
    sa_index = torch.arange(pos_s.shape[0], device=pos_a.device).long()[mask_s_sa]
    edge_index_a2sa[1] = sa_index[edge_index_a2sa[1]]

    # plot edge index TODO: now only support bs=1
    if os.getenv('PLOT_EDGE_INFERENCE', False) and not is_training:
        plot_interact_edge(edge_index_a2sa, data['scenario_id'], torch.tensor([num_agent - 1]), 1, num_step,
                           f"interact_a2sa_edge_map_infer_{plot_kwargs['tag']}", **plot_kwargs)

    rel_pos_a2sa = pos_s[edge_index_a2sa[0]] - pos_s[edge_index_a2sa[1]]
    rel_head_a2sa = wrap_angle(head_s[edge_index_a2sa[0]] - head_s[edge_index_a2sa[1]])
    # both modes use the same (distance, bearing w.r.t. target heading, relative heading) features
    r_a2sa = torch.stack(
        [torch.norm(rel_pos_a2sa[:, :2], p=2, dim=-1),
         angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_a2sa[1]], nbr_vector=rel_pos_a2sa[:, :2]),
         rel_head_a2sa], dim=-1)
    if mode == 'insert':
        # TODO: try categorical embs
        r_a2sa = self.r_a2sa_emb(continuous_inputs=r_a2sa, categorical_embs=None)
    elif mode == 'feature':
        r_a2sa = self.r_a2a_emb(continuous_inputs=r_a2sa, categorical_embs=None)
    else:
        raise ValueError(f"Unsupport mode {mode}.")
    return edge_index_a2sa, r_a2sa
def _build_map2sa_edge(self, data, pos_a, head_a, head_vector_a, batch_s, batch_pl,
                       mask_sa, inference_mask=None, r=None, max_num_neighbors=32, mode: Literal['insert', 'feature']='feature'):
    """Connect map points to seed-agent nodes lying within radius ``r``.

    Map points are tiled once per step (``repeat``, not ``repeat_interleave``)
    so that they follow the same step-major layout as the flattened agent nodes.

    Args:
        data: batch providing ``data['pt_token']['position']`` and
            ``data['pt_token']['orientation']``.
        pos_a, head_a, head_vector_a: agent position / heading / heading unit
            vector per agent and step; ``pos_a`` is unpacked as ``(_, num_step, _)``.
        batch_s, batch_pl: graph ids of the flattened agent nodes and of the
            tiled map points, forwarded to ``radius``.
        mask_sa: which (agent, step) nodes are seed agents (edge targets).
        inference_mask: optional mask; edges whose target falls outside it are dropped.
        r: search radius (required).
        max_num_neighbors: forwarded to ``radius``.
        mode: ``'insert'`` embeds with ``self.r_pt2sa_emb``, ``'feature'`` with
            ``self.r_pt2a_emb``.

    Returns:
        ``(edge_index, r)`` where ``edge_index[0]`` indexes the tiled map
        points, ``edge_index[1]`` holds global seed node indices, and ``r`` the
        embedded (distance, bearing, relative orientation) edge features.

    Raises:
        ValueError: for an unsupported ``mode``.
    """
    _, steps, _ = pos_a.shape

    # target-side filter from the optional inference mask, flattened step-major
    keep_target = torch.ones_like(mask_sa).bool()
    if inference_mask is not None:
        keep_target = keep_target & inference_mask
    keep_target = keep_target.transpose(0, 1).reshape(-1)
    is_seed = mask_sa.transpose(0, 1).reshape(-1)

    # step-major flattening of agent states: node = step * num_agent + agent
    node_pos = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
    node_head = head_a.transpose(0, 1).reshape(-1)
    node_head_vec = head_vector_a.transpose(0, 1).reshape(-1, 2)

    # tile map points once per step to line up with the step-major agent layout
    pt_pos = data['pt_token']['position'][:, :self.input_dim].contiguous().repeat(steps, 1)  # not `repeat_interleave`
    pt_orient = data['pt_token']['orientation'].contiguous().repeat(steps)

    # build map2agent directed graph
    assert r is not None, "r needs to be specified!"
    dst_local, src_pt = radius(x=pt_pos[:, :2], y=node_pos[is_seed, :2], r=r,
                               batch_x=batch_pl, batch_y=batch_s[is_seed], max_num_neighbors=max_num_neighbors)
    keep = keep_target[is_seed][dst_local]
    src_pt = src_pt[keep]
    dst_local = dst_local[keep]
    # map positions inside the seed subset back to global node indices
    seed_global = torch.arange(node_pos.shape[0], device=pos_a.device).long()[is_seed]
    edge_index_pl2sa = torch.stack([src_pt, seed_global[dst_local]], dim=0)

    src, dst = edge_index_pl2sa[0], edge_index_pl2sa[1]
    rel_pos = pt_pos[src] - node_pos[dst]
    rel_orient = wrap_angle(pt_orient[src] - node_head[dst])
    dist = torch.norm(rel_pos[:, :2], p=2, dim=-1)
    bearing = angle_between_2d_vectors(ctr_vector=node_head_vec[dst], nbr_vector=rel_pos[:, :2])
    edge_feat = torch.stack([dist, bearing, rel_orient], dim=-1)

    if mode not in ('insert', 'feature'):
        raise ValueError(f"Unsupport mode {mode}.")
    embed = self.r_pt2sa_emb if mode == 'insert' else self.r_pt2a_emb
    return edge_index_pl2sa, embed(continuous_inputs=edge_feat, categorical_embs=None)
# def _build_sa2sa_edge(self, data, pos_a, head_a, state_a, head_vector_a, batch_s, mask, inference_mask=None, **plot_kwargs):
# num_agent = pos_a.shape[0]
# pos_t = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
# head_t = head_a.reshape(-1)
# head_vector_t = head_vector_a.reshape(-1, 2)
# if inference_mask is not None:
# mask_t = mask.unsqueeze(2) & inference_mask.unsqueeze(1)
# else:
# mask_t = mask.unsqueeze(2) & mask.unsqueeze(1)
# edge_index_sa2sa = dense_to_sparse(mask_t)[0]
# edge_index_sa2sa = edge_index_sa2sa[:, edge_index_sa2sa[1] - edge_index_sa2sa[0] > 0]
# rel_pos_t = pos_t[edge_index_sa2sa[0]] - pos_t[edge_index_sa2sa[1]]
# rel_head_t = wrap_angle(head_t[edge_index_sa2sa[0]] - head_t[edge_index_sa2sa[1]])
# r_t = torch.stack(
# [torch.norm(rel_pos_t[:, :2], p=2, dim=-1),
# angle_between_2d_vectors(ctr_vector=head_vector_t[edge_index_sa2sa[1]], nbr_vector=rel_pos_t[:, :2]),
# rel_head_t,
# edge_index_sa2sa[0] - edge_index_sa2sa[1]], dim=-1)
# r_sa2sa = self.r_sa2sa_emb(continuous_inputs=r_t, categorical_embs=None)
# return edge_index_sa2sa, r_sa2sa
def get_inputs(self, data: HeteroData) -> Dict[str, torch.Tensor]:
    """Collect teacher-forcing targets and supervision masks for next-token prediction.

    Args:
        data: batch holding per-agent tensors under ``data['agent']``
            (``token_pos``, ``token_heading``, ``token_idx``, ``state_idx``,
            ``raw_agent_valid_mask``, ``grid_token_idx``); token, state and
            mask tensors are indexed as ``(agent, step)``.

    Returns:
        Dict with the current tokens, the next-step ground truth (motion
        tokens and states rolled one step back along dim 1), and
        ``next_token_eval_mask`` marking which (agent, step) motion
        predictions are supervised.

    Raises:
        RuntimeError: if a supervised position has a negative (invalid)
            motion token index.
    """
    pos_a = data['agent']['token_pos'].clone()
    head_a = data['agent']['token_heading'].clone()
    agent_token_index = data['agent']['token_idx'].clone()
    agent_state_index = data['agent']['state_idx'].clone()
    mask = data['agent']['raw_agent_valid_mask'].clone()
    agent_grid_token_idx = data['agent']['grid_token_idx']
    num_step = mask.shape[1]

    # the target at step t is the token / state observed at step t + 1
    next_token_index_gt = agent_token_index.roll(shifts=-1, dims=1)
    next_state_index_gt = agent_state_index.roll(shifts=-1, dims=1)

    # next token prediction mask
    bos_token_index = torch.nonzero(agent_state_index == self.enter_state)
    eos_token_index = torch.nonzero(agent_state_index == self.exit_state)

    # mask for motion tokens: valid at t - 1, t and t + 1
    next_token_eval_mask = mask * mask.roll(shifts=-1, dims=1) * mask.roll(shifts=1, dims=1)
    for agent, step in bos_token_index.tolist():
        next_token_eval_mask[agent, step] = 1
        # the step after entering is supervised iff the agent is still valid at step + 2;
        # when step + 2 is out of range the slot is the last step, which is zeroed below
        if step + 2 < num_step:
            next_token_eval_mask[agent, step + 1] = mask[agent, step + 2]
    next_token_eval_mask[eos_token_index[:, 0], eos_token_index[:, 1]] = 0

    # mask for state tokens
    # NOTE(review): computed but not part of the returned dict -- confirm whether callers need it
    next_state_eval_mask = mask * mask.roll(shifts=-1, dims=1) * mask.roll(shifts=1, dims=1)
    for agent, step in bos_token_index.tolist():
        next_state_eval_mask[agent, :step] = 0
        next_state_eval_mask[agent, step] = 1
        if step + 2 < num_step:
            next_state_eval_mask[agent, step + 1] = mask[agent, step + 2]
    for agent, step in eos_token_index.tolist():
        next_state_eval_mask[agent, step + 1:] = 1
        # supervised at the exit step iff valid just before it; step 0 is overwritten below
        if step > 0:
            next_state_eval_mask[agent, step] = mask[agent, step - 1]

    # the first timestep is the beginning of the sequence (also the input)
    next_token_eval_mask[:, 0] = mask[:, 0] * mask[:, 1]
    next_state_eval_mask[:, 0] = mask[:, 0] * mask[:, 1]
    # there is no ground truth after the last step
    next_token_eval_mask[:, -1] = 0
    next_state_eval_mask[:, -1] = 0

    # no invalid motion token will be supervised (`.any()` also handles an empty selection)
    if (next_token_index_gt[next_token_eval_mask] < 0).any():
        raise RuntimeError("Found invalid motion index.")

    return {'token_pos': pos_a,
            'token_heading': head_a,
            'next_token_idx_gt': next_token_index_gt,
            'next_state_idx_gt': next_state_index_gt,
            'next_token_eval_mask': next_token_eval_mask,
            'raw_agent_valid_mask': data['agent']['raw_agent_valid_mask'],
            'state_token': agent_state_index,
            'grid_index': agent_grid_token_idx,
            }
def _build_seq(self, device, data, num_agent, num_step, av_index, sort_indices):
    """Build the seed-agent attention mask and the per-node sequence index.

    Seed ``s`` of graph ``b`` may not attend to the agents ranked ``s`` or
    later in that graph's ordering at each step, nor to any seed placeholder.

    Args:
        device: device for the created tensors.
        data: batch providing the agent pointer ``data['agent']['ptr']``.
        num_agent: number of real agents in the batch (seed placeholders excluded).
        num_step: number of time steps.
        av_index: global indices of the ego agents.
        sort_indices (torch.Tensor): shape (num_agent, num_step); rows
            ``ptr[b]:ptr[b + 1]`` hold graph-local agent indices (offset by
            ``ptr[b]`` to obtain global indices) in insertion order per step.

    Returns:
        seq_mask: bool tensor of shape
            ``(num_step * num_graph * num_seed_feature, num_agent + num_graph * num_seed_feature)``,
            rows ordered step-major.
        seq_index: float tensor of shape
            ``(num_agent + num_graph * num_seed_feature, num_step)``: ``s + 1``
            for the agent ranked ``s`` (< num_seed_feature), ``1..num_seed_feature``
            for the seed placeholders, ``0`` otherwise and for ``av_index``.
    """
    ptr = data['agent']['ptr']
    num_graph = len(ptr) - 1
    num_seed = self.num_seed_feature
    num_total = num_agent + num_graph * num_seed

    # seed rows start out seeing every real agent and none of the seed placeholders
    seq_mask = torch.ones(num_graph * num_seed, num_step, num_total, device=device).bool()
    seq_mask[..., num_agent:] = False

    # seed placeholders are numbered 1..num_seed within every graph; real agents start at 0
    seq_index = torch.cat([torch.zeros(num_agent), (torch.arange(num_seed) + 1).repeat(num_graph)]).to(device)
    seq_index = seq_index[:, None].repeat(1, num_step)

    for b in range(num_graph):
        # sort indices are graph-local: shift them to global agent indices
        batch_sort_indices = sort_indices[ptr[b]: ptr[b + 1]].long() + ptr[b]
        for t in range(num_step):
            for s in range(num_seed):
                # seed s must not see the agent it predicts nor any later-ranked agent
                seq_mask[b * num_seed + s, t, batch_sort_indices[s:, t]] = False
                # the agent ranked s becomes element s + 1 of the sequence (0, 0, ..., 1, 2, 3, ...)
                seq_index[batch_sort_indices[s: s + 1, t], t] = s + 1

    if self.seed_attn_to_av:
        seq_mask[..., av_index] = True
    seq_mask = seq_mask.transpose(0, 1).reshape(-1, num_total)
    seq_index[av_index] = 0
    return seq_mask, seq_index
def _build_occ_gt(self, data, seq_mask, pos_rel_index_gt, pos_rel_index_gt_seed=None, mask_seed=None,
                  edge_index=None, mode='edge_index'):
    """Build grid-occupancy ground truth for seed agents and the map.

    Stores ``data['agent']['agent_occ']`` with shape
    ``(num_graphs * num_seed_feature, num_step, grid_size)`` and
    ``data['agent']['map_occ']`` (the per-graph map occupancy repeated once
    per seed, same shape). Occupied cells are ``1``; in the non-edge mode a
    cell may be set to ``-1`` for a seed entry selected by ``mask_seed``.

    Args:
        seq_mask (torch.Tensor): shape (num_step * num_seed_feature, num_agent + self.num_seed_feature);
            only used when ``mode != 'edge_index'``. That path reshapes with
            ``num_seed_feature`` alone, i.e. it assumes a single graph.
        pos_rel_index_gt (torch.Tensor): shape (num_agent, num_step); grid cell
            per agent and step, ``-1`` entries are ignored.
        pos_rel_index_gt_seed (torch.Tensor): shape (num_seed, num_step); optional.
        mask_seed: optional mask selecting valid ``pos_rel_index_gt_seed`` entries.
        edge_index: (2, E) edges from agent nodes (row 0) to seed nodes (row 1)
            in step-major flattened indexing over agents followed by seeds;
            required when ``mode == 'edge_index'``.
        mode: ``'edge_index'`` marks the cells of each seed's neighbours; any
            other value uses ``seq_mask``.
    """
    num_real_agent = data['agent']['state_idx'].shape[0]
    num_seed_total = data.num_graphs * self.num_seed_feature
    num_agent = num_real_agent + num_seed_total
    num_step = data['agent']['state_idx'].shape[1]
    device = data['agent']['state_idx'].device
    grid_size = self.attr_tokenizer.grid_size
    agent_occ = torch.zeros(num_seed_total, num_step, grid_size, device=device).long()
    map_occ = torch.zeros(data.num_graphs, num_step, grid_size, device=device).long()

    if mode == 'edge_index':
        assert edge_index is not None, f"Need edge_index input!"
        tgt_indexes, src_indexes = edge_index[0], edge_index[1]
        # decode seed nodes (row 1): seeds occupy the last rows of every step block
        src_rows = src_indexes % num_agent - num_real_agent
        src_cols = src_indexes // num_agent
        # decode neighbour nodes (row 0)
        tgt_rows = tgt_indexes % num_agent
        tgt_cols = tgt_indexes // num_agent
        assert (tgt_rows < num_real_agent).all(), f"Invalid {tgt_rows}"
        assert (tgt_cols == src_cols).all()
        # every write stores the same value, so one vectorized scatter replaces the per-seed loop
        cells = pos_rel_index_gt[tgt_rows, tgt_cols]
        # -1 means "no grid cell" and must not mark the last cell via negative indexing
        has_cell = cells != -1
        agent_occ[src_rows[has_cell], src_cols[has_cell], cells[has_cell]] = 1
    else:
        seq_mask = seq_mask.reshape(num_step, self.num_seed_feature, -1).transpose(0, 1)[..., :-self.num_seed_feature]
        has_seed_gt = pos_rel_index_gt_seed is not None and mask_seed is not None
        for s in range(self.num_seed_feature):
            for t in range(num_step):
                index = pos_rel_index_gt[seq_mask[s, t], t]
                agent_occ[s, t, index[index != -1]] = 1
                if (has_seed_gt and t > 0 and s < pos_rel_index_gt_seed.shape[0]
                        and mask_seed[s, t - 1]):  # insert agents
                    agent_occ[s, t, pos_rel_index_gt_seed[s, t - 1]] = -1

    ptr = data['pt_token']['ptr']
    pt_grid_token_idx = data['agent']['pt_grid_token_idx']  # (t, num_pt)
    for b in range(data.num_graphs):
        batch_pt_grid_token_idx = pt_grid_token_idx[:, ptr[b]: ptr[b + 1]]
        for t in range(num_step):
            step_cells = batch_pt_grid_token_idx[t]
            map_occ[b, t, step_cells[step_cells != -1]] = 1

    data['agent']['agent_occ'] = agent_occ
    data['agent']['map_occ'] = map_occ.repeat_interleave(repeats=self.num_seed_feature, dim=0)
def forward(self,
data: HeteroData,
map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
pos_a = data['agent']['token_pos'].clone() # (a, t, 2)
head_a = data['agent']['token_heading'].clone() # (a, t)
num_agent, num_step, traj_dim = pos_a.shape # e.g. (50, 18, 2)
agent_shape = data['agent']['shape'][:, self.num_historical_steps - 1].clone() # (a, 3)
agent_token_index = data['agent']['token_idx'].clone() # (a, t)
agent_state_index = data['agent']['state_idx'].clone()
agent_type_index = data['agent']['type'].clone()
av_index = data['agent']['av_index'].long()
ego_pos = pos_a[av_index]
ego_head = head_a[av_index]
_, head_vector_a = self._build_vector_a(pos_a, head_a, agent_state_index)
agent_grid_token_idx = data['agent']['grid_token_idx']
agent_grid_offset_xy = data['agent']['grid_offset_xy']
agent_head_token_idx = data['agent']['heading_token_idx']
agent_pos_xy = data['agent']['pos_xy']
agent_heading_theta = data['agent']['heading_theta']
sort_indices = data['agent']['sort_indices']
device = pos_a.device
feat_a = self._agent_token_embedding(data,
agent_token_index,
agent_state_index,
agent_grid_token_idx,
pos_a,
head_a,
av_index=av_index)
raw_feat_a = feat_a[:-data.num_graphs * self.num_seed_feature].clone()
raw_feat_seed = feat_a[-data.num_graphs * self.num_seed_feature:].clone()
# build masks
mask = data['agent']['raw_agent_valid_mask'].clone()
temporal_mask = mask.clone()
interact_mask = mask.clone()
is_bos = agent_state_index == self.enter_state
is_eos = agent_state_index == self.exit_state
bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_step - 1)) # not `-1`
temporal_mask = torch.ones_like(mask)
motion_mask = torch.arange(mask.shape[1]).expand(mask.shape[0], -1).to(device)
motion_mask = (motion_mask > bos_index[:, None]) & (motion_mask <= eos_index[:, None])
temporal_mask[motion_mask] = mask[motion_mask]
temporal_mask = torch.cat([temporal_mask, torch.ones(data.num_graphs * self.num_seed_feature, *temporal_mask.shape[1:], device=device)]).bool()
interact_mask[agent_state_index == self.enter_state] = True
interact_mask = torch.cat([interact_mask, torch.ones(data.num_graphs * self.num_seed_feature, *interact_mask.shape[1:], device=device)]).bool() # placeholder
pos_a_p, head_a_p, state_a_p, head_vector_a_p, grid_index_a_p, pad_mask = \
self._pad_feat(data.num_graphs, av_index, pos_a, head_a, agent_state_index, head_vector_a, agent_grid_token_idx)
edge_index_t, r_t = self._build_temporal_edge(data, pos_a_p, head_a_p, state_a_p, head_vector_a_p, temporal_mask)
# placeholder for seed agent
batch_s = torch.cat([
torch.cat([data['agent']['batch'], torch.arange(data.num_graphs, device=device
).repeat_interleave(repeats=self.num_seed_feature, dim=0)], dim=0)
+ data.num_graphs * t for t in range(num_step)
], dim=0)
batch_pl = torch.cat([data['pt_token']['batch'] + data.num_graphs * t for t in range(num_step)], dim=0)
seq_mask, seq_index = self._build_seq(device, data, num_agent, num_step, av_index, sort_indices)
plot_kwargs = dict(is_bos=agent_state_index == self.enter_state)
edge_index_a2a, r_a2a, (na2a, na2sa) = self._build_interaction_edge(data, pos_a_p, head_a_p, state_a_p, head_vector_a_p, batch_s,
interact_mask, pad_mask=pad_mask, av_index=av_index,
seq_mask=seq_mask, seq_index=seq_index, grid_index_a=grid_index_a_p, **plot_kwargs)
edge_index_pl2a, r_pl2a, (npl2a, npl2sa) = self._build_map2agent_edge(data, pos_a_p, head_a_p, state_a_p, head_vector_a_p, batch_s, batch_pl,
interact_mask, pad_mask=pad_mask, av_index=av_index)
interact_mask = interact_mask[:-data.num_graphs * self.num_seed_feature]
# pos_a_s, head_a_s, state_a_s, head_vector_a_s, mask_a_s = self._build_seed_feat(data, pos_a_p, head_a_p, state_a_p, head_vector_a_p, ~pad_mask,
# sort_indices, av_index=av_index)
# edge_index_sa2sa, r_sa2sa = self._build_sa2sa_edge(data, pos_a_s, head_a_s, state_a_s, head_vector_a_s, batch_s, mask=mask_a_s)
# for i in range(self.num_layers):
# feat_a = feat_a.reshape(-1, self.hidden_dim) # (a, t, d) -> (a*t, d)
# feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t)
# feat_a = feat_a.reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim)
# feat_a = self.pt2a_attn_layers[i]((
# map_enc['x_pt'].repeat_interleave(repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(
# -1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a)
# feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a)
# feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1)
# predict next motions
for i in range(self.num_layers):
feat_a = feat_a.reshape(-1, self.hidden_dim) # (a, t, d) -> (a*t, d)
feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t)
feat_a = feat_a.reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim)
feat_a = self.pt2a_attn_layers[i]((
map_enc['x_pt'].repeat_interleave(repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(
-1, self.hidden_dim), feat_a), r_pl2a[:npl2a], edge_index_pl2a[:, :npl2a])
feat_a = self.a2a_attn_layers[i](feat_a, r_a2a[:na2a], edge_index_a2a[:, :na2a])
feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1)
feat_ea = feat_a[:-data.num_graphs * self.num_seed_feature]
# next motion token
next_token_prob = self.token_predict_head(feat_ea) # (a, t, token_size)
next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)
_, next_token_idx = torch.topk(next_token_prob_softmax, k=10, dim=-1) # (a, t, 10)
next_token_index_gt = agent_token_index.roll(shifts=-1, dims=1)
# next state token
next_state_prob = self.state_predict_head(feat_ea)
next_state_idx = next_state_prob.softmax(dim=-1).argmax(dim=-1, keepdim=True) # (a, t, 1)
next_state_index_gt = agent_state_index.roll(shifts=-1, dims=1) # (invalid, valid, exit)
# predict next agents: coarse stage
grid_agent_occ_gt_seed = grid_pt_occ_gt_seed = None
if self.use_grid_token:
self._build_occ_gt(data, seq_mask, agent_grid_token_idx.long(), edge_index=edge_index_a2a[:, -na2sa:], mode='edge_index')
grid_agent_occ_gt_seed = data['agent']['agent_occ']
grid_pt_occ_gt_seed = data['agent']['map_occ']
if self.use_grid_token:
occ_embed_a = self.seed_agent_occ_embed(grid_agent_occ_gt_seed.transpose(0, 1).reshape(-1, self.grid_size).float())
# occ_embed_pt = self.seed_pt_occ_embed(grid_pt_occ_gt_seed.transpose(0, 1).reshape(-1, self.grid_size).float())
edge_index_occ2sa_src = torch.arange(feat_a.shape[0] * feat_a.shape[1], device=device).long()
edge_index_occ2sa_src = edge_index_occ2sa_src[~pad_mask.transpose(0, 1).reshape(-1)]
edge_index_occ2sa_tgt = torch.arange(occ_embed_a.shape[0], device=device).long()
edge_index_occ2sa = torch.stack([edge_index_occ2sa_tgt, edge_index_occ2sa_src], dim=0)
feat_sa = torch.cat([raw_feat_a, raw_feat_seed])
# feat_sa = feat_a
for i in range(self.seed_layers):
# feat_sa = feat_a.reshape(-1, self.hidden_dim)
# feat_sa = self.sa2sa_attn_layers[i](feat_sa, r_sa2sa, edge_index_sa2sa)
feat_sa = feat_sa.reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim)
if self.use_grid_token:
feat_sa = self.occ2sa_attn_layers[i]((occ_embed_a, feat_sa), None, edge_index_occ2sa)
feat_sa = self.pt2sa_attn_layers[i]((
map_enc['x_pt'].repeat_interleave(repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(
-1, self.hidden_dim), feat_sa), r_pl2a[-npl2sa:], edge_index_pl2a[:, -npl2sa:])
feat_sa = self.a2sa_attn_layers[i](feat_sa, r_a2a[-na2sa:], edge_index_a2a[:, -na2sa:])
feat_sa = feat_sa.reshape(num_step, -1, self.hidden_dim).transpose(0, 1)
feat_seed = feat_sa[-data.num_graphs * self.num_seed_feature:]
# seed agent
next_state_prob_seed = self.seed_state_predict_head(feat_seed)
raw_next_state_prob_seed = next_state_prob_seed.clone()
next_state_idx_seed = next_state_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True) # (seed_size, t, 1)
next_type_prob_seed = self.seed_type_predict_head(feat_seed)
next_type_idx_seed = next_type_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True)
next_type_index_gt = agent_type_index[:, None].repeat(1, num_step).long()
next_shape_seed = self.seed_shape_predict_head(feat_seed)
next_shape_gt = agent_shape[:, None].repeat(1, num_step, 1).float()
if self.use_grid_token:
next_pos_rel_prob_seed = self.seed_pos_rel_token_predict_head(feat_seed)
next_pos_rel_idx_seed = next_pos_rel_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True)
else:
next_pos_rel_prob_seed = self.seed_pos_rel_xy_predict_head(feat_seed)
next_pos_rel_xy_seed = torch.tanh(next_pos_rel_prob_seed)
next_pos_rel_index_gt = agent_grid_token_idx.long()
next_pos_rel_xy_gt = agent_pos_xy.float() / self.pl2seed_radius
# decode grid index of neighbor agents
if self.use_grid_token:
neighbor_agent_grid_index_gt = grid_index_a_p.transpose(0, 1).reshape(-1)[edge_index_a2a[0, -na2sa:]]
neighbor_pt_grid_index_gt = data['agent']['pt_grid_token_idx'].reshape(-1)[edge_index_pl2a[0, -npl2sa:]]
neighbor_agent_grid_idx = self.grid_index_head(r_a2a[-na2sa:])
neighbor_pt_grid_idx = self.grid_index_head(r_pl2a[-npl2sa:])
neighbor_agent_grid_index_eval_mask = torch.zeros_like(neighbor_agent_grid_index_gt).bool()
neighbor_pt_grid_index_eval_mask = torch.zeros_like(neighbor_pt_grid_index_gt).bool()
neighbor_agent_grid_index_eval_mask[torch.randperm(neighbor_agent_grid_index_gt.shape[0])[:180]] = True
neighbor_pt_grid_index_eval_mask[torch.randperm(neighbor_pt_grid_index_gt.shape[0])[:600]] = True
# occupancy prediction
grid_agent_occ_seed = grid_pt_occ_seed = grid_agent_occ_eval_mask_seed = grid_pt_occ_eval_mask_seed = None
if self.predict_occ:
# grid_occ_embed = self.grid_occ_embed(self.grid_token_emb[:-1])
grid_agent_occ_seed = self.grid_agent_occ_head(feat_seed) # (s, t, d)
grid_pt_occ_seed = self.grid_pt_occ_head(feat_seed)
# refine stage
batch_s = torch.cat([data['agent']['batch'] + data.num_graphs * t for t in range(num_step)], dim=0)
batch_pl = torch.cat([data['pt_token']['batch'] + data.num_graphs * t for t in range(num_step)], dim=0)
mask_sa = torch.zeros_like(agent_state_index).bool()
for t in range(mask_sa.shape[1]):
availabel_rows = ((agent_state_index[:, t] != self.invalid_state) &
(agent_grid_token_idx[:, t] != -1)).nonzero()[..., 0]
mask_sa[availabel_rows[torch.randperm(availabel_rows.shape[0])[:data.num_graphs * 10]], t] = True
mask_sa[agent_state_index == self.enter_state] = True
mask_sa[:, 0] = False # ignore the first step
mask_sa[av_index] = False # ignore self
state_sa = torch.full_like(agent_state_index, self.invalid_state).long()
state_sa[mask_sa] = self.enter_state
sa_indices = torch.nonzero(mask_sa)
pos_sa = pos_a.clone()
head_sa = head_a.clone()
expanded_av_index = av_index.repeat_interleave(repeats=data['batch_size_a'], dim=0)
head_sa[sa_indices[:, 0], sa_indices[:, 1]] = head_a[expanded_av_index[sa_indices[:, 0]], sa_indices[:, 1]]
motion_vector_sa, head_vector_sa = self._build_vector_a(pos_a, head_sa, state_sa)
motion_vector_sa[mask_sa] = self.motion_gap # fix the case e.g. [0, 0, 1, '1', 0, 1]
offset_pos = pos_a - data['ego_pos'].repeat_interleave(repeats=data['batch_size_a'], dim=0)
agent_grid_emb = self.grid_token_emb[agent_grid_token_idx.long()]
feat_sa, _ = self._build_agent_feature(num_step, pos_a.device,
motion_vector_sa,
head_vector_sa,
agent_grid_emb=agent_grid_emb,
offset_pos=offset_pos,
type=next_type_index_gt.long(),
shape=next_shape_gt,
state=state_sa,
n=num_agent)
feat_sa[~mask_sa] = raw_feat_a[~mask_sa].clone()
edge_index_a2sa, r_a2sa = self._build_a2sa_edge(data, pos_a, head_sa, head_vector_sa, batch_s,
interact_mask, mask_sa=mask_sa, r=self.a2sa_radius)
edge_index_pl2sa, r_pl2sa = self._build_map2sa_edge(data, pos_a, head_sa, head_vector_sa, batch_s, batch_pl,
mask_sa=mask_sa, r=self.pl2sa_radius)
# sanity check
global_index = set(torch.nonzero(mask_sa.transpose(0, 1).reshape(-1).int())[:, 0].tolist())
a2sa_index = set(edge_index_a2sa[1].tolist())
pl2sa_index = set(edge_index_pl2sa[1].tolist())
assert a2sa_index.issubset(global_index) and pl2sa_index.issubset(global_index), "Invalid index!"
select_mask = torch.zeros_like(mask_sa.view(-1)).bool()
select_mask[torch.unique(edge_index_a2sa[1])] = True
select_mask[torch.unique(edge_index_pl2sa[1])] = True
mask_sa[~select_mask.reshape(num_step, -1).transpose(0, 1)] = False
for i in range(self.seed_layers):
feat_sa = feat_sa.transpose(0, 1).reshape(-1, self.hidden_dim)
feat_sa = self.pt2a_attn_layers[i]((
map_enc['x_pt'].repeat_interleave(repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(
-1, self.hidden_dim), feat_sa), r_pl2sa, edge_index_pl2sa)
feat_sa = self.a2a_attn_layers[i](feat_sa, r_a2sa, edge_index_a2sa)
feat_sa = feat_sa.reshape(num_step, -1, self.hidden_dim).transpose(0, 1)
if self.use_head_token:
next_head_rel_theta_seed = None
next_head_rel_prob_seed = self.seed_heading_rel_token_predict_head(feat_sa)
next_head_rel_idx_seed = next_head_rel_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True)
else:
next_head_rel_prob_seed = None
next_head_rel_theta_seed = self.seed_heading_rel_theta_predict_head(feat_sa)
next_head_rel_theta_seed = torch.tanh(next_head_rel_theta_seed)[..., 0]
next_head_rel_index_gt_seed = agent_head_token_idx.long()
next_head_rel_theta_gt_seed = agent_heading_theta.float() / torch.pi # [-1, 1]
next_offset_xy_seed = None
if self.use_grid_token:
next_offset_xy_seed = self.seed_offset_xy_predict_head(feat_sa)
next_offset_xy_seed = torch.tanh(next_offset_xy_seed) * 2 # [-2, 2]
next_offset_xy_gt_seed = agent_grid_offset_xy.float()
# next token prediction mask
bos_token_index = torch.nonzero(agent_state_index == self.enter_state)
eos_token_index = torch.nonzero(agent_state_index == self.exit_state)
# mask for motion tokens
next_token_eval_mask = mask.clone()
next_token_eval_mask = next_token_eval_mask * next_token_eval_mask.roll(shifts=-1, dims=1) * next_token_eval_mask.roll(shifts=1, dims=1)
for bos_token_index_ in bos_token_index:
next_token_eval_mask[bos_token_index_[0], bos_token_index_[1] : bos_token_index_[1] + 1] = 1
next_token_eval_mask[bos_token_index_[0], bos_token_index_[1] + 1 : bos_token_index_[1] + 2] = \
mask[bos_token_index_[0], bos_token_index_[1] + 2 : bos_token_index_[1] + 3]
next_token_eval_mask[eos_token_index[:, 0], eos_token_index[:, 1]] = 0
# mask for state tokens
next_state_eval_mask = mask.clone()
next_state_eval_mask = next_state_eval_mask * next_state_eval_mask.roll(shifts=-1, dims=1) * next_state_eval_mask.roll(shifts=1, dims=1)
for bos_token_index_ in bos_token_index:
next_state_eval_mask[bos_token_index_[0], :bos_token_index_[1]] = 0
next_state_eval_mask[bos_token_index_[0], bos_token_index_[1] : bos_token_index_[1] + 1] = 1
next_state_eval_mask[bos_token_index_[0], bos_token_index_[1] + 1 : bos_token_index_[1] + 2] = \
mask[bos_token_index_[0], bos_token_index_[1] + 2 : bos_token_index_[1] + 3]
for eos_token_index_ in eos_token_index:
next_state_eval_mask[eos_token_index_[0], eos_token_index_[1] + 1:] = 1
next_state_eval_mask[eos_token_index_[0], eos_token_index_[1] : eos_token_index_[1] + 1] = \
mask[eos_token_index_[0], eos_token_index_[1] - 1 : eos_token_index_[1]]
next_state_eval_mask_seed = torch.ones_like(next_state_idx_seed[..., 0])
# the last timestep is the beginning of the sequence (also the input)
next_token_eval_mask[:, 0] = mask[:, 0] * mask[:, 1]
next_state_eval_mask[:, 0] = mask[:, 0] * mask[:, 1]
next_token_eval_mask[:, -1] = 0
next_state_eval_mask[:, -1] = 0
next_state_eval_mask_seed[:, 0] = 0
# no invalid motion token will be supervised
if (next_token_index_gt[next_token_eval_mask] < 0).any():
raise RuntimeError("Found invalid motion index.")
# seed agents
# is_next_bos = next_state_index_gt.roll(shifts=1, dims=1) == self.enter_state
# is_next_bos[:, 0] = False # we filter out the last timestep
# is_next_bos[av_index] = False
# num_seed_gt = is_next_bos.sum(dim=0).max()
# pred_indices = torch.zeros((num_seed_gt, num_step, 1), device=device).long()
# gt_indices = torch.zeros((num_seed_gt, num_step), device=device).long()
# if num_seed_gt > 0:
# outputs = dict(state_pred=next_state_prob_seed,
# pos_pred=next_pos_rel_prob_seed,
# shape_pred=next_shape_seed)
# targets = dict(state_gt=next_state_index_gt.clone(),
# pos_gt=next_pos_rel_index_gt.clone(),
# shape_gt=next_shape_gt.clone())
# indices = self.matcher(outputs, targets,
# valid_mask=is_next_bos,
# ptr_gt=data['agent']['ptr'],
# ptr_pred=torch.arange(data.num_graphs + 1, device=device) * self.num_seed_feature)
# pred_indices = indices[0][..., None].to(device)
# gt_indices = indices[1].to(device)
pred_indices = []
gt_indices = []
agent_ptr = data['agent']['ptr']
num_seed_gt = 0
for b in range(data.num_graphs):
batch_sort_indices = sort_indices[agent_ptr[b] : agent_ptr[b + 1]]
batch_num_seed_gt = min(self.num_seed_feature, batch_sort_indices.shape[0])
num_seed_gt += batch_num_seed_gt
pred_indices.append((torch.arange(batch_num_seed_gt, device=device) + b * self.num_seed_feature
)[:, None, None].repeat(1, num_step, 1).long())
gt_indices.append(batch_sort_indices[:batch_num_seed_gt] + agent_ptr[b])
pred_indices = torch.concat(pred_indices)
gt_indices = torch.concat(gt_indices)
n = pred_indices.shape[0]
res_pred_indices = []
for t in range(next_state_idx_seed.shape[1]):
indices_t = torch.arange(next_state_idx_seed.shape[0]).to(device)
selected_pred_mask = torch.zeros_like(indices_t)
selected_pred_mask[pred_indices[:, t]] = 1
res_pred_indices.append(indices_t[~selected_pred_mask.bool()])
res_pred_indices = torch.stack(res_pred_indices, dim=1)
padded_pred_indices = torch.concat([pred_indices, res_pred_indices[..., None]])
next_state_idx_seed = torch.gather(next_state_idx_seed, dim=0, index=padded_pred_indices)
next_state_prob_seed = torch.gather(next_state_prob_seed, dim=0, index=padded_pred_indices.expand(
-1, -1, next_state_prob_seed.shape[-1]))
next_state_index_gt_seed = torch.gather(agent_state_index, dim=0, index=gt_indices)
next_state_index_gt_seed = torch.concat([next_state_index_gt_seed,
torch.zeros((next_state_prob_seed.shape[0] - next_state_index_gt_seed.shape[0], next_state_index_gt_seed.shape[1]), device=device)]).long()
seed_enter_mask = next_state_index_gt_seed == self.enter_state
next_state_index_gt_seed = torch.full(next_state_index_gt_seed.shape, self.seed_state_type.index('invalid'), device=device)
next_state_index_gt_seed[seed_enter_mask] = self.seed_state_type.index('enter')
next_type_idx_seed = torch.gather(next_type_idx_seed, dim=0, index=pred_indices)
next_type_prob_seed = torch.gather(next_type_prob_seed, dim=0, index=pred_indices.expand(
-1, -1, next_type_prob_seed.shape[-1]))
next_type_index_gt_seed = torch.gather(next_type_index_gt, dim=0, index=gt_indices)
if self.use_grid_token:
next_pos_rel_xy_seed = None
next_pos_rel_prob_seed = torch.gather(next_pos_rel_prob_seed, dim=0, index=pred_indices.expand(
-1, -1, next_pos_rel_prob_seed.shape[-1]))
else:
next_pos_rel_prob_seed = None
next_pos_rel_xy_seed = torch.gather(next_pos_rel_xy_seed, dim=0, index=pred_indices.expand(
-1, -1, next_pos_rel_xy_seed.shape[-1]))
next_pos_rel_index_gt_seed = torch.gather(next_pos_rel_index_gt, dim=0, index=gt_indices)
next_pos_rel_xy_gt_seed = torch.gather(next_pos_rel_xy_gt, dim=0, index=gt_indices[..., None].expand(
-1, -1, next_pos_rel_xy_gt.shape[-1]))
next_shape_seed = torch.gather(next_shape_seed, dim=0, index=pred_indices.expand(
-1, -1, next_shape_seed.shape[-1]))
next_shape_gt_seed = torch.gather(next_shape_gt, dim=0, index=gt_indices[..., None].expand(
-1, -1, next_shape_gt.shape[-1]))
next_attr_eval_mask_seed = seed_enter_mask[:n]
next_attr_eval_mask_seed[:, 0] = False # we ignore the first step
next_attr_eval_mask_seed[next_pos_rel_index_gt_seed == self.grid_size // 2] = False
next_state_eval_mask[av_index] = 0 # we dont predict state for ego agent
if (torch.any(next_type_index_gt_seed[next_attr_eval_mask_seed] == AGENT_TYPE.index('seed')) \
or torch.any(torch.all(next_shape_gt_seed[next_attr_eval_mask_seed] == self.invalid_shape_value, dim=-1)) \
or torch.any(next_pos_rel_index_gt_seed[next_attr_eval_mask_seed] < 0)) and num_seed_gt > 0:
raise ValueError(f"Found invalid gt values in scenario {data['scenario_id'][0]}.")
next_state_index_gt[next_state_index_gt == self.exit_state] = self.valid_state_type.index('exit')
# build occ gt
if self.predict_occ:
# grid_agent_occ_seed = torch.einsum('s t d, g d -> s t g', grid_agent_occ_seed, grid_occ_embed)
# grid_pt_occ_seed = torch.einsum('s t d, g d -> s t g', grid_pt_occ_seed, grid_occ_embed)
# augmentation
# TODO: add convolution!!!
# grid_agent_occ_eval_mask_seed = torch.zeros_like(grid_agent_occ_seed).bool()
# grid_pt_occ_eval_mask_seed = torch.zeros_like(grid_agent_occ_seed).bool()
# gt_mask = grid_agent_occ_gt_seed.bool()
# gt_mask[:, 0] = False # ignore the first step
# gt_mask[..., self.grid_size // 2] = False # ignore self
# random_weights = torch.rand_like(grid_agent_occ_seed) * gt_mask
# _, topk_indices = random_weights.topk(10, dim=-1)
# grid_agent_occ_eval_mask_seed.scatter_(-1, topk_indices, True)
# grid_agent_occ_eval_mask_seed[~gt_mask] = False
# random_weights = torch.rand_like(grid_agent_occ_seed) * ~gt_mask
# _, topk_indices = random_weights.topk(10, dim=-1)
# grid_agent_occ_eval_mask_seed.scatter_(-1, topk_indices, True)
# grid_agent_occ_eval_mask_seed[:, 0] = False
# grid_agent_occ_eval_mask_seed[..., self.grid_size // 2] = False
# gt_mask = grid_pt_occ_gt_seed.bool()
# random_weights = torch.rand_like(grid_agent_occ_seed) * gt_mask
# _, topk_indices = random_weights.topk(256, dim=-1)
# grid_pt_occ_eval_mask_seed.scatter_(-1, topk_indices, True)
# grid_pt_occ_eval_mask_seed[~gt_mask] = False
# random_weights = torch.rand_like(grid_agent_occ_seed) * ~gt_mask
# _, topk_indices = random_weights.topk(256, dim=-1)
# grid_pt_occ_eval_mask_seed.scatter_(-1, topk_indices, True)
grid_occ_eval_mask_seed = torch.ones_like(grid_agent_occ_seed).bool()
grid_occ_eval_mask_seed[:, 0] = False
grid_occ_eval_mask_seed[..., self.grid_size // 2] = False
grid_agent_occ_eval_mask_seed = grid_pt_occ_eval_mask_seed = grid_occ_eval_mask_seed
# sanity check
# s = random.randint(0, self.num_seed_feature - 1)
# t = random.randint(0, num_step - 1)
# grid_index = grid_agent_occ_gt_seed[s, t].nonzero()[..., 0]
# check_mask = torch.zeros_like(pad_mask)
# check_mask[av_index + s + 1, t] = 1
# check_index = check_mask.transpose(0, 1).reshape(-1).nonzero()[..., 0]
# check_agent_index = edge_index_a2a[0, edge_index_a2a[1] == check_index[0]] % (num_agent + self.num_seed_feature)
# if not torch.all(grid_index == next_pos_rel_index_gt[check_agent_index, t].unique().sort()[0]):
# raise RuntimeError(f"Grid index not consistent s={s} t={t} scenario_id={data['scenario_id'][0]}")
target_indices = pred_indices.clone()
target_indices[~next_attr_eval_mask_seed] = -1
return {'x_a': feat_a,
'ego_pos': ego_pos,
# motion token
'next_token_idx': next_token_idx,
'next_token_prob': next_token_prob,
'next_token_idx_gt': next_token_index_gt,
'next_token_eval_mask': next_token_eval_mask.bool(),
# state token
'next_state_idx': next_state_idx,
'next_state_prob': next_state_prob,
'next_state_idx_gt': next_state_index_gt,
'next_state_eval_mask': next_state_eval_mask.bool(),
# seed agent
'next_state_idx_seed': next_state_idx_seed,
'next_state_prob_seed': next_state_prob_seed,
'next_state_idx_gt_seed': next_state_index_gt_seed,
'next_type_idx_seed': next_type_idx_seed,
'next_type_prob_seed': next_type_prob_seed,
'next_type_idx_gt_seed': next_type_index_gt_seed,
'next_pos_rel_prob_seed': next_pos_rel_prob_seed,
'next_pos_rel_index_gt_seed': next_pos_rel_index_gt_seed,
'next_pos_rel_xy_seed': next_pos_rel_xy_seed,
'next_pos_rel_xy_gt_seed': next_pos_rel_xy_gt_seed,
'next_head_rel_prob_seed': next_head_rel_prob_seed,
'next_head_rel_index_gt_seed': next_head_rel_index_gt_seed,
'next_head_rel_theta_seed': next_head_rel_theta_seed,
'next_head_rel_theta_gt_seed': next_head_rel_theta_gt_seed,
'next_offset_xy_seed': next_offset_xy_seed,
'next_offset_xy_gt_seed': next_offset_xy_gt_seed,
'next_shape_seed': next_shape_seed,
'next_shape_gt_seed': next_shape_gt_seed,
'grid_agent_occ_seed': grid_agent_occ_seed,
'grid_pt_occ_seed': grid_pt_occ_seed,
'grid_agent_occ_gt_seed': grid_agent_occ_gt_seed,
'grid_pt_occ_gt_seed': grid_pt_occ_gt_seed,
'neighbor_agent_grid_idx': neighbor_agent_grid_idx
if self.use_grid_token else None,
'neighbor_pt_grid_idx': neighbor_pt_grid_idx
if self.use_grid_token else None,
'neighbor_agent_grid_index_gt': neighbor_agent_grid_index_gt
if self.use_grid_token else None,
'neighbor_pt_grid_index_gt': neighbor_pt_grid_index_gt
if self.use_grid_token else None,
'target_indices': target_indices[..., 0],
'raw_next_state_prob_seed': raw_next_state_prob_seed,
'next_state_eval_mask_seed': next_state_eval_mask_seed.bool(),
'next_attr_eval_mask_seed': next_attr_eval_mask_seed.bool(),
'next_head_eval_mask_seed': mask_sa.bool(),
'grid_agent_occ_eval_mask_seed': grid_agent_occ_eval_mask_seed
if self.use_grid_token else None,
'grid_pt_occ_eval_mask_seed': grid_pt_occ_eval_mask_seed
if self.use_grid_token else None,
'neighbor_agent_grid_index_eval_mask': neighbor_agent_grid_index_eval_mask.bool()
if self.use_grid_token else None,
'neighbor_pt_grid_index_eval_mask': neighbor_pt_grid_index_eval_mask.bool()
if self.use_grid_token else None,
}
def inference(self,
data: HeteroData,
map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
filter_mask = data['agent']['state_idx'][:, (self.num_historical_steps - 1) // self.shift - 1] != self.invalid_state
seed_step_mask = data['agent']['state_idx'][:, (self.num_historical_steps - 1) // self.shift:] == self.enter_state
# seed_agent_index_per_step = [torch.nonzero(seed_step_mask[:, t]).squeeze(dim=-1) for t in range(seed_step_mask.shape[1])]
# num_historical_steps=11
eval_mask = data['agent']['valid_mask'][filter_mask, self.num_historical_steps - 1]
# agent attributes
agent_id = data['agent']['id'][filter_mask].clone()
agent_valid_mask = data['agent']['raw_agent_valid_mask'][filter_mask].clone() # token_valid_mask
pos_a = data['agent']['token_pos'][filter_mask].clone() # (a, t, 2)
token_a = data['agent']['token_idx'][filter_mask].clone() # (a, t)
state_a = data['agent']['state_idx'][filter_mask].clone()
head_a = data['agent']['token_heading'][filter_mask].clone()
shape_a = data['agent']['shape'][filter_mask].clone()
type_a = data['agent']['type'][filter_mask].clone()
grid_a = data['agent']['grid_token_idx'][filter_mask].clone()
gt_traj = data['agent']['position'][filter_mask, self.num_historical_steps:, :self.input_dim].contiguous()
agent_token_traj_all = data['agent']['token_traj_all'][filter_mask]
device = pos_a.device
max_agent_id = agent_id.max() # TODO: bs=1
if self.num_recurrent_steps_val == -1:
# self.num_recurrent_steps_val = 91 - 11 = 80
self.num_recurrent_steps_val = data["agent"]['position'].shape[1] - self.num_historical_steps
num_agent, num_ori_step, traj_dim = pos_a.shape
num_infer_step = (self.num_recurrent_steps_val + self.num_historical_steps) // self.shift
if num_infer_step > num_ori_step:
pad_shape = num_agent, num_infer_step - num_ori_step
agent_valid_mask = torch.cat([agent_valid_mask, torch.full(pad_shape, True, device=device)], dim=1)
pos_a = torch.cat([pos_a, torch.zeros((*pad_shape, pos_a.shape[-1]), device=device)], dim=1)
token_a = torch.cat([token_a, torch.full(pad_shape, -1, device=device)], dim=1)
state_a = torch.cat([state_a, torch.full(pad_shape, self.invalid_state, device=device)], dim=1)
head_a = torch.cat([head_a, torch.zeros(pad_shape, device=device)], dim=1)
grid_a = torch.cat([grid_a, torch.full(pad_shape, -1, device=device)], dim=1)
# TODO: support bs > 1 in inference !!!
num_removed_agent = int((~filter_mask[:data['agent']['av_index']]).sum())
data['batch_size_a'] -= num_removed_agent
av_index = data['agent']['av_index'] - num_removed_agent
# make future steps to zero
pos_a[:, (self.num_historical_steps - 1) // self.shift:] = 0
head_a[:, (self.num_historical_steps - 1) // self.shift:] = 0
token_a[:, (self.num_historical_steps - 1) // self.shift:] = -1
state_a[:, (self.num_historical_steps - 1) // self.shift:] = 0
grid_a[:, (self.num_historical_steps - 1) // self.shift:] = -1
motion_vector_a, head_vector_a = self._build_vector_a(pos_a, head_a, state_a)
agent_valid_mask[:, (self.num_historical_steps - 1) // self.shift:] = True
agent_valid_mask[~eval_mask] = False
agent_token_index = data['agent']['token_idx'][filter_mask]
agent_state_index = data['agent']['state_idx'][filter_mask]
(feat_a, agent_token_emb, agent_token_emb_veh, agent_token_emb_ped, agent_token_emb_cyc, categorical_embs,
trajectory_token_veh, trajectory_token_ped, trajectory_token_cyc) = self._agent_token_embedding(
data,
token_a,
state_a,
grid_a,
pos_a,
head_a,
inference=True,
filter_mask=filter_mask,
av_index=av_index,
)
raw_feat_a = feat_a.clone()
veh_mask = type_a == 0
cyc_mask = type_a == 2
ped_mask = type_a == 1
pred_traj = torch.zeros(pos_a.shape[0], self.num_recurrent_steps_val, 2, device=device) # (a, val_t, 2)
pred_head = torch.zeros(pos_a.shape[0], self.num_recurrent_steps_val, device=device)
pred_type = type_a.clone()
pred_shape = shape_a[:, (self.num_historical_steps - 1) // self.shift - 1] # (a, 3)
pred_state = torch.zeros(pos_a.shape[0], self.num_recurrent_steps_val, device=device)
pred_prob = torch.zeros(pos_a.shape[0], self.num_recurrent_steps_val // self.shift, device=device) # (a, val_t)
feat_a_t_dict = {}
feat_sa_t_dict = {}
# build masks (init)
mask = agent_valid_mask.clone()
temporal_mask = mask.clone()
interact_mask = mask.clone()
# find bos and eos index
is_bos = state_a == self.enter_state
is_eos = state_a == self.exit_state
bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_infer_step - 1))
temporal_mask = torch.ones_like(mask)
motion_mask = torch.arange(mask.shape[1]).expand(mask.shape[0], mask.shape[1]).to(mask.device)
motion_mask = (motion_mask > bos_index[:, None]) & (motion_mask <= eos_index[:, None])
motion_mask[:, self.num_historical_steps // self.shift:] = False
temporal_mask[motion_mask] = mask[motion_mask]
interact_mask = torch.ones_like(mask)
non_motion_mask = ~motion_mask
non_motion_mask[:, self.num_historical_steps // self.shift:] = False
interact_mask[non_motion_mask] = 0
interact_mask[state_a == self.enter_state] = 1
interact_mask[av_index] = 1
temporal_mask[:, (self.num_historical_steps - 1) // self.shift:] = 1
interact_mask[:, (self.num_historical_steps - 1) // self.shift:] = 1
self.log_message = ""
num_inserted_agents_total = num_inserted_agents = 0
next_token_idx_list = []
next_state_idx_list = []
grid_agent_occ_list = []
grid_pt_occ_list = []
grid_agent_occ_gt_list = []
next_state_prob_seed_list = []
next_pos_rel_prob_seed_list = []
agent_labels = [[None] * num_infer_step for _ in range(pos_a.shape[0])]
# append history motion/state tokens
for i in range((self.num_historical_steps - 1) // self.shift):
next_token_idx_list.append(agent_token_index[:, i : i + 1])
next_state_idx_list.append(agent_state_index[:, i : i + 1])
num_seed_feature = 1
insert_limit = 10
for t in (
pbar := tqdm(range(self.num_recurrent_steps_val // self.shift), leave=False, desc='Timestep ...')
):
# 1. insert agents
num_new_agents = 0
next_state_prob_seeds = torch.zeros(10 + 1, 1, device=device)
next_pos_rel_prob_seeds = torch.zeros(10 + 1, 1, self.attr_tokenizer.grid_size, device=device)
grid_agent_occ_seeds = torch.zeros(10 + 1, 1, self.attr_tokenizer.grid_size, device=device)
grid_pt_occ_seeds = torch.zeros(10 + 1, 1, self.attr_tokenizer.grid_size, device=device)
grid_agent_occ_gt_seeds = torch.zeros(10 + 1, 1, self.attr_tokenizer.grid_size, device=device)
valid_state_mask = state_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t] != self.invalid_state # TODO: only support bs=1
distance = ((pos_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t, :2] - pos_a[av_index, (self.num_historical_steps - 1) // self.shift - 1 + t, :2]) ** 2).sum(-1).sqrt()
inrange_mask = distance <= self.pl2seed_radius
seq_valid_mask = valid_state_mask & inrange_mask
seq_valid_mask[av_index] = False
res_seq_index = torch.zeros_like(state_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
res_seq_index[seq_valid_mask] = torch.randperm(seq_valid_mask.sum(), device=device) + 1
if t == 0:
inference_mask = temporal_mask.clone()
inference_mask = torch.cat([inference_mask, torch.ones_like(inference_mask[-1:]).repeat(
num_seed_feature, *([1] * (inference_mask.dim() - 1)))])
inference_mask[:, (self.num_historical_steps - 1) // self.shift + t:] = False
else:
inference_mask = torch.zeros_like(temporal_mask)
inference_mask = torch.cat([inference_mask, torch.zeros_like(inference_mask[-1:]).repeat(
num_seed_feature, *([1] * (inference_mask.dim() - 1)))])
inference_mask[:, (self.num_historical_steps - 1) // self.shift + t - 1] = True
plot_kwargs = dict()
p = 0
while True:
p += 1
if t == 0 or p - 1 >= insert_limit or self.disable_insertion: break
# rebuild inference mask since number of agents have changed
inference_mask = torch.zeros_like(temporal_mask)
inference_mask = torch.cat([inference_mask, torch.zeros_like(inference_mask[-1:]).repeat(
num_seed_feature, *([1] * (inference_mask.dim() - 1)))])
inference_mask[:, (self.num_historical_steps - 1) // self.shift + t - 1] = True
# sanity check: make sure seed agents will interact with **all** non-invalid agents
assert torch.all(state_a[:, :(self.num_historical_steps - 1) // self.shift + t][
~interact_mask[:, :(self.num_historical_steps - 1) // self.shift + t]] == self.invalid_state) and \
torch.all(state_a[:, :(self.num_historical_steps - 1) // self.shift + t][
interact_mask[:, :(self.num_historical_steps - 1) // self.shift + t]] != self.invalid_state), \
f"Got wrong with interact mask at scenario {data['scenario_id'][0]} t={t}!"
temporal_mask = torch.cat([temporal_mask, torch.ones_like(temporal_mask[:1]).repeat(
num_seed_feature, *([1] * (temporal_mask.dim() - 1)))]).bool()
interact_mask = torch.cat([interact_mask, torch.ones_like(interact_mask[:1]).repeat(
num_seed_feature, *([1] * (interact_mask.dim() - 1)))]).bool() # placeholder
pos_a_p, head_a_p, state_a_p, head_vector_a_p, grid_index_a_p, pad_mask = \
self._pad_feat(data.num_graphs, av_index, pos_a, head_a, state_a, head_vector_a, grid_a, num_seed_feature=num_seed_feature)
# sanity check
assert torch.all(~pad_mask[-num_seed_feature:]), "Got wrong with pad mask!"
batch_s = torch.arange(num_infer_step, device=device).repeat_interleave(num_agent + num_seed_feature)
batch_pl = torch.arange(num_infer_step, device=device).repeat_interleave(data['pt_token']['num_nodes'])
inference_mask_sa = torch.zeros_like(inference_mask).bool()
inference_mask_sa[:, (self.num_historical_steps - 1) // self.shift - 1 + t] = True
# 1.1 build seed agent features
if self.seed_use_ego_motion:
motion_vector_seed = motion_vector_a[av_index]
head_vector_seed = head_vector_a[av_index]
else:
motion_vector_seed = head_vector_seed = None
feat_seed, _ = self._build_agent_feature(num_infer_step, device,
motion_vector_seed,
head_vector_seed,
state_index=self.invalid_state,
n=num_seed_feature)
if feat_a.shape[1] != feat_seed.shape[1]:
assert t == 0, f"Unmatched timestep {feat_a.shape[1]} and {feat_seed.shape[1]}."
feat_a = torch.cat([feat_a, feat_a[:, -1:].repeat(1, feat_seed.shape[1] - feat_a.shape[1], 1)], dim=1)
raw_feat_a = feat_a.clone()
feat_a = torch.cat([feat_a, feat_seed], dim=0)
# 1.2 global feature aggregation
plot_kwargs.update(t=t, n=num_new_agents, tag='global_feature')
# 0, 0, 0, ..., N+1, N+2, ...
seq_index = torch.cat([torch.zeros(pos_a.shape[0] - num_new_agents), torch.arange(num_new_agents + 1) + 1]).to(device)
# 0, 2, 1, ..., N+1, N+2, ...
# seq_index = torch.cat([res_seq_index, torch.arange(num_new_agents + 1, device=device) + 1 + seq_valid_mask.sum()])
edge_index_a2seed, r_seed2a = self._build_a2sa_edge(data, pos_a_p, head_a_p, head_vector_a_p, batch_s,
interact_mask.clone(),
mask_sa=~pad_mask.clone(),
inference_mask=inference_mask_sa,
r=self.pl2seed_radius,
max_num_neighbors=300,
seq_index=seq_index,
grid_index_a=grid_index_a_p,
mode='insert', **plot_kwargs)
edge_index_pl2seed, r_pl2seed = self._build_map2sa_edge(data, pos_a_p, head_a_p, head_vector_a_p, batch_s, batch_pl,
mask_sa=~pad_mask.clone(),
inference_mask=inference_mask_sa,
r=self.pl2seed_radius,
max_num_neighbors=2048,
mode='insert')
temporal_mask = temporal_mask[:-num_seed_feature]
interact_mask = interact_mask[:-num_seed_feature]
if self.use_grid_token:
grid_agent_occ_gt_t_1 = torch.zeros((self.grid_size,), device=device).long()
grid_t_1 = grid_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]
grid_agent_occ_gt_t_1[grid_t_1[grid_t_1 != -1]] = 1
occ_embed_a = self.seed_agent_occ_embed(grid_agent_occ_gt_t_1.reshape(1, self.grid_size).float()).repeat(num_seed_feature, 1)
edge_index_occ2sa_src = torch.arange(feat_a.shape[0] * feat_a.shape[1], device=device).long()
edge_index_occ2sa_src = edge_index_occ2sa_src[(~pad_mask.transpose(0, 1).reshape(-1)) & (inference_mask_sa.transpose(0, 1).reshape(-1))]
edge_index_occ2sa_tgt = torch.arange(occ_embed_a.shape[0], device=device).long()
edge_index_occ2sa = torch.stack([edge_index_occ2sa_tgt, edge_index_occ2sa_src], dim=0)
for i in range(self.seed_layers):
feat_a = feat_a.transpose(0, 1).reshape(-1, self.hidden_dim)
if self.use_grid_token:
feat_a = self.occ2sa_attn_layers[i]((occ_embed_a, feat_a), None, edge_index_occ2sa)
feat_a = self.pt2sa_attn_layers[i]((
map_enc['x_pt'].repeat_interleave(repeats=num_infer_step, dim=0).reshape(-1, num_infer_step, self.hidden_dim).transpose(0, 1).reshape(
-1, self.hidden_dim), feat_a), r_pl2seed, edge_index_pl2seed)
feat_a = self.a2sa_attn_layers[i](feat_a, r_seed2a, edge_index_a2seed)
feat_a = feat_a.reshape(num_infer_step, -1, self.hidden_dim).transpose(0, 1)
feat_seed = feat_a[-num_seed_feature:] # (s, t, d)
ego_pos_t_1 = pos_a[av_index, (self.num_historical_steps - 1) // self.shift - 1 + t]
ego_head_t_1 = head_a[av_index, (self.num_historical_steps - 1) // self.shift - 1 + t]
# occupancy
if self.predict_occ:
grid_agent_occ_seed = self.grid_agent_occ_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t]) # (num_seed, grid_size)
grid_pt_occ_seed = self.grid_pt_occ_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
# insert prob
next_state_prob_seed = self.seed_state_predict_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
next_state_idx_seed = next_state_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True)
next_state_idx_seed[next_state_idx_seed == self.seed_state_type.index('invalid')] = self.invalid_state
next_state_idx_seed[next_state_idx_seed == self.seed_state_type.index('enter')] = self.enter_state
if int(os.getenv('DEBUG', 0)):
next_state_idx_seed = torch.full(next_state_idx_seed.shape, self.enter_state, device=device).long()
# type and shape
next_type_prob_seed = self.seed_type_predict_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
next_type_idx_seed = next_type_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True)
next_shape_seed = self.seed_shape_predict_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
# position
if self.use_grid_token:
next_pos_rel_prob_seed = self.seed_pos_rel_token_predict_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
next_pos_rel_prob_softmax = torch.softmax(next_pos_rel_prob_seed, dim=-1)
# if self.inference_filter_overlap:
# next_pos_rel_prob_softmax[..., grid_agent_occ_gt_t_1.bool()] = 1e-6 # diffuse!
topk_pos_rel_prob, next_pos_rel_idx_seed = torch.topk(next_pos_rel_prob_softmax, k=self.insert_beam_size, dim=-1)
sample_pos_rel_index = torch.multinomial(topk_pos_rel_prob, 1).to(device)
next_pos_rel_idx_seed = next_pos_rel_idx_seed.gather(dim=1, index=sample_pos_rel_index)
next_pos_seed = self.attr_tokenizer.decode_pos(next_pos_rel_idx_seed[..., 0], y=ego_pos_t_1, theta_y=ego_head_t_1)
if self.inference_filter_overlap:
if grid_agent_occ_gt_t_1[next_pos_rel_idx_seed[..., 0]]: # TODO: only support insert num=1 for each iter!!!
feat_a = raw_feat_a.clone()
continue
else:
next_pos_rel_xy_seed = self.seed_pos_rel_xy_predict_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
next_pos_seed = torch.tanh(next_pos_rel_xy_seed) * self.pl2seed_radius + ego_pos_t_1
if torch.all(next_state_idx_seed == self.invalid_state) or num_new_agents + 1 > insert_limit:
break
num_new_agent = 1 # TODO: fix this term
num_new_agents += 1
# ! 1.5. insert new agents and update attributes
# append new agent id
agent_id = torch.cat([agent_id, torch.tensor([max_agent_id + 1], device=device, dtype=agent_id.dtype)])
max_agent_id += 1
mask = torch.cat([mask, torch.ones(num_new_agent, num_infer_step, device=mask.device)], dim=0).bool()
temporal_mask = torch.cat([temporal_mask, torch.ones(num_new_agent, num_infer_step, device=temporal_mask.device)], dim=0).bool()
interact_mask = torch.cat([interact_mask, torch.ones(num_new_agent, num_infer_step, device=interact_mask.device)], dim=0).bool()
# initialize new attributes
new_pos_a = torch.zeros(num_new_agent, num_infer_step, 2, device=device)
new_head_a = torch.zeros(num_new_agent, num_infer_step, device=device)
new_grid_a = torch.full((num_new_agent, num_infer_step), -1, device=device)
new_state_a = torch.full((num_new_agent, num_infer_step), self.invalid_state, device=state_a.device)
new_shape_a = torch.full((num_new_agent, num_infer_step, 3), self.invalid_shape_value, device=device)
new_type_a = torch.full((num_new_agent, num_infer_step), AGENT_TYPE.index('seed'), device=device)
# add new attributes
new_pos_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t] = next_pos_seed
pos_a = torch.cat([pos_a, new_pos_a], dim=0)
new_head_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t] = ego_head_t_1 # dummy values
head_a = torch.cat([head_a, new_head_a], dim=0)
if self.use_grid_token:
new_grid_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t] = next_pos_rel_idx_seed
grid_a = torch.cat([grid_a, new_grid_a], dim=0)
new_type_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t:] = next_type_idx_seed
new_shape_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t:] = next_shape_seed[:, None]
pred_type = torch.cat([pred_type, new_type_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]])
pred_shape = torch.cat([pred_shape, new_shape_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]])
new_state_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t] = next_state_idx_seed # all enter state
state_a = torch.cat([state_a, new_state_a], dim=0)
mask[-num_new_agent:, :(self.num_historical_steps - 1) // self.shift + t] = 0
interact_mask[-num_new_agent:, :(self.num_historical_steps - 1) // self.shift - 1 + t] = 0
# placeholdersin pred_traj, pred_head, pred_state
new_pred_traj = torch.zeros(num_new_agent, self.num_recurrent_steps_val, 2, device=device)
new_pred_head = torch.zeros(num_new_agent, self.num_recurrent_steps_val, device=device)
new_pred_state = torch.zeros(num_new_agent, self.num_recurrent_steps_val, device=device)
if t > 0:
new_pred_traj[:, (t - 1) * 5 : t * 5] = new_pos_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t, None].repeat(1, 5, 1)
new_pred_head[:, (t - 1) * 5 : t * 5] = new_head_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t, None].repeat(1, 5)
new_pred_state[:, (t - 1) * 5 : t * 5] = next_state_idx_seed.repeat(1, 5)
pred_traj = torch.cat([pred_traj, new_pred_traj], dim=0)
pred_head = torch.cat([pred_head, new_pred_head], dim=0)
pred_state = torch.cat([pred_state, new_pred_state], dim=0)
# add new agents token embeddings
new_agent_token_emb = self.no_token_emb(torch.zeros(1, device=device).long())[None, :].repeat(num_new_agent, num_infer_step, 1)
new_agent_token_emb[:, (self.num_historical_steps - 1) // self.shift - 1 + t] = self.bos_token_emb(torch.zeros(1, device=device).long())
agent_token_emb = torch.cat([agent_token_emb, new_agent_token_emb])
next_veh_mask = next_type_idx_seed[..., 0] == AGENT_TYPE.index('veh')
next_ped_mask = next_type_idx_seed[..., 0] == AGENT_TYPE.index('ped')
next_cyc_mask = next_type_idx_seed[..., 0] == AGENT_TYPE.index('cyc')
veh_mask = torch.cat([veh_mask, next_veh_mask])
ped_mask = torch.cat([ped_mask, next_ped_mask])
cyc_mask = torch.cat([cyc_mask, next_cyc_mask])
# add new agents trajectory embeddings
new_agent_token_traj_all = torch.zeros((num_new_agent, self.token_size, self.shift + 1, 4, 2), device=device)
new_agent_token_traj_all[next_veh_mask] = trajectory_token_veh[None, ...]
new_agent_token_traj_all[next_ped_mask] = trajectory_token_ped[None, ...]
new_agent_token_traj_all[next_cyc_mask] = trajectory_token_cyc[None, ...]
agent_token_traj_all = torch.cat([agent_token_traj_all, new_agent_token_traj_all], dim=0)
new_categorical_embs = [self.type_a_emb(new_type_a.reshape(-1).long()), self.shape_emb(new_shape_a.reshape(-1, 3))]
categorical_embs = [torch.cat([categorical_embs[0], new_categorical_embs[0]], dim=0),
torch.cat([categorical_embs[1], new_categorical_embs[1]], dim=0)]
new_labels = [None] * num_infer_step
new_labels[(self.num_historical_steps - 1) // self.shift + t] = f'A{num_new_agents}' # the first step after bos step!
agent_labels.append(new_labels)
# 2. predict headings for seed agents
motion_vector_sa, head_vector_sa = self._build_vector_a(pos_a[-num_new_agent:],
head_a[-num_new_agent:],
state_a[-num_new_agent:])
# sanity check
assert torch.all(motion_vector_sa[:, :(self.num_historical_steps - 1) // self.shift - 1 + t] == self.invalid_motion_value) and \
torch.all(motion_vector_sa[:, (self.num_historical_steps - 1) // self.shift - 1 + t] == self.motion_gap), \
f"Found invalid values in motion_vectect_a at scenario {data['scenario_id'][0]} t={t}!"
motion_vector_sa[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0.
head_vector_sa[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0.
motion_vector_a = torch.cat([motion_vector_a, motion_vector_sa])
head_vector_a = torch.cat([head_vector_a, head_vector_sa])
new_offset_pos = pos_a[-num_new_agent:] - pos_a[av_index]
new_agent_grid_emb = self.grid_token_emb[new_grid_a] if self.use_grid_token else None
feat_sa, _ = self._build_agent_feature(num_infer_step, device,
motion_vector_sa,
head_vector_sa,
agent_token_emb=new_agent_token_emb,
agent_grid_emb=new_agent_grid_emb,
offset_pos=new_offset_pos,
categorical_embs_a=new_categorical_embs,
state=new_state_a)
feat_a = torch.cat([raw_feat_a, feat_sa])
batch_s = torch.arange(num_infer_step, device=device).repeat_interleave(num_agent + num_new_agent)
batch_pl = torch.arange(num_infer_step, device=device).repeat_interleave(data['pt_token']['num_nodes'])
# sanity check
assert pos_a.shape[0] == head_a.shape[0] == head_vector_a.shape[0] == interact_mask.shape[0] == \
pad_mask.shape[0] == inference_mask_sa.shape[0] == (num_agent + num_new_agent), f"Inconsistent shapes!"
plot_kwargs.update(tag='heading')
edge_index_a2sa, r_a2sa = self._build_a2sa_edge(data, pos_a, head_a, head_vector_a, batch_s,
interact_mask.clone(),
mask_sa=~pad_mask.clone(),
inference_mask=inference_mask_sa,
r=self.a2sa_radius,
max_num_neighbors=24,
**plot_kwargs)
edge_index_pl2sa, r_pl2sa = self._build_map2sa_edge(data, pos_a, head_a, head_vector_a, batch_s, batch_pl,
mask_sa=~pad_mask.clone(),
inference_mask=inference_mask_sa,
r=self.pl2sa_radius,
max_num_neighbors=128)
for i in range(self.seed_layers):
feat_a = feat_a.transpose(0, 1).reshape(-1, self.hidden_dim)
feat_a = self.pt2a_attn_layers[i]((
map_enc['x_pt'].repeat_interleave(repeats=num_infer_step, dim=0).reshape(-1, num_infer_step, self.hidden_dim).transpose(0, 1).reshape(
-1, self.hidden_dim), feat_a), r_pl2sa, edge_index_pl2sa)
feat_a = self.a2a_attn_layers[i](feat_a, r_a2sa, edge_index_a2sa)
feat_a = feat_a.reshape(num_infer_step, -1, self.hidden_dim).transpose(0, 1)
if self.use_head_token:
next_head_rel_prob_seed = self.seed_heading_rel_token_predict_head(feat_a[-num_new_agent:, (self.num_historical_steps - 1) // self.shift - 1 + t])
next_head_rel_idx_seed = next_head_rel_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True)
next_head_seed = wrap_angle(self.attr_tokenizer.decode_heading(next_head_rel_idx_seed) + ego_head_t_1)
else:
next_head_rel_theta_seed = self.seed_heading_rel_theta_predict_head(feat_a[-num_new_agent:, (self.num_historical_steps - 1) // self.shift - 1 + t])
next_head_seed = torch.tanh(next_head_rel_theta_seed) * torch.pi + ego_head_t_1
head_a[-num_new_agent:, (self.num_historical_steps - 1) // self.shift - 1 + t] = next_head_seed
if self.use_grid_token:
next_offset_xy_seed = self.seed_offset_xy_predict_head(feat_a[-num_new_agent:, (self.num_historical_steps - 1) // self.shift - 1 + t])
next_offset_xy_seed = torch.tanh(next_offset_xy_seed) * 2
pos_a[-num_new_agent:, (self.num_historical_steps - 1) // self.shift - 1 + t] += next_offset_xy_seed
# ! finalize new features
motion_vector_sa, head_vector_sa = self._build_vector_a(pos_a[-num_new_agent:],
head_a[-num_new_agent:],
state_a[-num_new_agent:])
motion_vector_sa[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0.
head_vector_sa[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0.
motion_vector_a[-num_new_agent:] = motion_vector_sa
head_vector_a[-num_new_agents:] = head_vector_sa
feat_sa, _ = self._build_agent_feature(num_infer_step, device,
motion_vector_sa,
head_vector_sa,
agent_token_emb=new_agent_token_emb,
agent_grid_emb=new_agent_grid_emb,
offset_pos=new_offset_pos,
categorical_embs_a=new_categorical_embs,
state=state_a[-num_new_agent:],
n=num_new_agent)
feat_a = torch.cat([raw_feat_a, feat_sa])
raw_feat_a = feat_a.clone()
num_agent = pos_a.shape[0]
if self.use_grid_token:
grid_agent_occ_gt_seeds[num_new_agents] = grid_agent_occ_gt_t_1
grid_agent_occ_seeds[num_new_agents] = grid_agent_occ_seed
grid_pt_occ_seeds[num_new_agents] = grid_pt_occ_seed
next_pos_rel_prob_seeds[num_new_agents] = next_pos_rel_prob_softmax
next_state_prob_seeds[num_new_agents] = next_state_prob_seed.softmax(dim=-1)[:, -1]
inference_mask = inference_mask[:-num_seed_feature]
next_state_prob_seed_list.append(next_state_prob_seeds)
if self.use_grid_token:
next_pos_rel_prob_seed_list.append(next_pos_rel_prob_seeds)
grid_agent_occ_list.append(grid_agent_occ_seeds)
grid_pt_occ_list.append(grid_pt_occ_seeds)
grid_agent_occ_gt_list.append(grid_agent_occ_gt_seeds)
next_state_idx_list[-1] = torch.cat([next_state_idx_list[-1], torch.full((num_new_agents, 1), self.enter_state, device=device).long()])
# 3. predict motions for all agents
feat_a = raw_feat_a
# rebuild inference mask since number of agents have changed
inference_mask = torch.zeros_like(temporal_mask)
inference_mask[:, (self.num_historical_steps - 1) // self.shift + t - 1] = True
edge_index_t, r_t = self._build_temporal_edge(data, pos_a, head_a, state_a, head_vector_a, temporal_mask, inference_mask.clone())
batch_s = torch.arange(num_infer_step, device=device).repeat_interleave(num_agent)
batch_pl = torch.arange(num_infer_step, device=device).repeat_interleave(data['pt_token']['num_nodes'])
edge_index_a2a, r_a2a = self._build_interaction_edge(data, pos_a, head_a, state_a, head_vector_a, batch_s,
interact_mask, inference_mask=inference_mask, av_index=av_index, **plot_kwargs)
edge_index_pl2a, r_pl2a = self._build_map2agent_edge(data, pos_a, head_a, state_a, head_vector_a, batch_s, batch_pl,
interact_mask, inference_mask=inference_mask, av_index=av_index, **plot_kwargs)
for i in range(self.num_layers):
if i in feat_a_t_dict:
feat_a = feat_a_t_dict[i]
feat_a = feat_a.reshape(-1, self.hidden_dim)
feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t)
feat_a = feat_a.reshape(-1, num_infer_step, self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim)
feat_a = self.pt2a_attn_layers[i]((
map_enc['x_pt'].repeat_interleave(repeats=num_infer_step, dim=0).reshape(-1, num_infer_step, self.hidden_dim).transpose(0, 1).reshape(
-1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a)
feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a)
feat_a = feat_a.reshape(num_infer_step, -1, self.hidden_dim).transpose(0, 1)
if t == 0:
feat_a_t_dict[i + 1] = feat_a
else:
# update agent features at current step
n = feat_a_t_dict[i + 1].shape[0]
feat_a_t_dict[i + 1][:n, (self.num_historical_steps - 1) // self.shift - 1 + t] = feat_a[:n, (self.num_historical_steps - 1) // self.shift - 1 + t]
# add newly inserted agent features (only when t changed)
if feat_a.shape[0] > n:
m = feat_a.shape[0] - n
feat_a_t_dict[i + 1] = torch.cat([feat_a_t_dict[i + 1], feat_a[-m:]])
# next motion token
next_token_prob = self.token_predict_head(feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)
topk_token_prob, next_token_idx = torch.topk(next_token_prob_softmax, k=self.motion_beam_size, dim=-1) # both (num_agent, beam_size) e.g. (31, 5)
# next state token
next_state_prob = self.state_predict_head(feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
next_state_idx = next_state_prob.softmax(dim=-1).argmax(dim=-1)
next_state_idx[next_state_idx == self.valid_state_type.index('exit')] = self.exit_state
next_state_idx[av_index] = self.valid_state # force ego_agent to be valid
if not self.use_state_token:
next_state_idx[next_state_idx == self.exit_state] = self.valid_state
# convert the predicted token to a 0.5s (6 timesteps) trajectory
expanded_token_index = next_token_idx[..., None, None, None].expand(-1, -1, 6, 4, 2)
next_token_traj = torch.gather(agent_token_traj_all, 1, expanded_token_index) # (num_agent, beam_size, 6, 4, 2)
# apply rotation and translation on 'next_token_traj'
theta = head_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]
cos, sin = theta.cos(), theta.sin()
rot_mat = torch.zeros((num_agent, 2, 2), device=theta.device)
rot_mat[:, 0, 0] = cos
rot_mat[:, 0, 1] = sin
rot_mat[:, 1, 0] = -sin
rot_mat[:, 1, 1] = cos
agent_diff_rel = torch.bmm(next_token_traj.view(-1, 4, 2),
rot_mat[:, None, None, ...].repeat(1, self.motion_beam_size, self.shift + 1, 1, 1).view(
-1, 2, 2)).view(num_agent, self.motion_beam_size, self.shift + 1, 4, 2)
agent_pred_rel = agent_diff_rel + pos_a[:, None, None, None, (self.num_historical_steps - 1) // self.shift - 1 + t, :]
# sample 1 most probable index of top beam_size tokens, (num_agent, beam_size) -> (num_agent, 1)
# then sample the agent_pred_rel, (num_agent, beam_size, 6, 4, 2) -> (num_agent, 6, 4, 2)
sample_token_index = torch.multinomial(topk_token_prob, 1).to(agent_pred_rel.device)
next_token_idx = next_token_idx.gather(dim=1, index=sample_token_index).squeeze(-1)
agent_pred_rel = agent_pred_rel.gather(dim=1,
index=sample_token_index[..., None, None, None].expand(-1, -1, 6, 4,
2))[:, 0, ...]
# get predicted position and heading of current shifted timesteps
diff_xy = agent_pred_rel[:, 1:, 0, :] - agent_pred_rel[:, 1:, 3, :]
pred_traj[:num_agent, t * 5 : (t + 1) * 5] = agent_pred_rel[:, 1:, ...].clone().mean(dim=2)
pred_head[:num_agent, t * 5 : (t + 1) * 5] = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])
pred_state[:num_agent, t * 5 : (t + 1) * 5] = next_state_idx[:, None].repeat(1, 5)
# pred_prob[:num_agent, t] = topk_token_prob.gather(dim=-1, index=sample_token_index)[:, 0] # (num_agent, beam_size) -> (num_agent,)
# update pos/head/state of current step
pos_a[:, (self.num_historical_steps - 1) // self.shift + t] = agent_pred_rel[:, -1, ...].clone().mean(dim=1)
diff_xy = agent_pred_rel[:, -1, 0, :] - agent_pred_rel[:, -1, 3, :]
theta = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])
head_a[:, (self.num_historical_steps - 1) // self.shift + t] = theta
state_a[:, (self.num_historical_steps - 1) // self.shift + t] = next_state_idx
if self.use_grid_token:
grid_a[:, (self.num_historical_steps - 1) // self.shift + t] = self.attr_tokenizer.encode_pos(
x=pos_a[:, (self.num_historical_steps - 1) // self.shift + t],
y=pos_a[av_index, (self.num_historical_steps - 1) // self.shift + t],
theta_y=theta[av_index],
)[0]
# the case that the current predicted state token is invalid/exit
is_eos = next_state_idx == self.exit_state
is_invalid = next_state_idx == self.invalid_state
next_token_idx[is_invalid] = -1
pos_a[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = 0.
head_a[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = 0.
if self.use_grid_token:
grid_a[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = -1
mask[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = False # to handle those newly-added agents
interact_mask[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = False
agent_token_emb[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = self.no_token_emb(torch.zeros(1, device=device).long())
type_emb = categorical_embs[0].reshape(num_agent, num_infer_step, -1)
shape_emb = categorical_embs[1].reshape(num_agent, num_infer_step, -1)
type_emb[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = self.type_a_emb(torch.tensor(AGENT_TYPE.index('seed'), device=device).long())
shape_emb[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = self.shape_emb(torch.full((1, 3), self.invalid_shape_value, device=device))
categorical_embs = [type_emb.reshape(-1, self.hidden_dim), shape_emb.reshape(-1, self.hidden_dim)]
# FIXME: need to discuss!!!
# if is_eos.any():
# pos_a[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = 0.
# head_a[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = 0.
# mask[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = False # to handle those newly-added agents
# interact_mask[torch.cat([is_eos, torch.zeros(1, device=is_eos.device).bool()]), (self.num_historical_steps - 1) // self.shift + t + 1:] = False
# agent_token_emb[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = self.no_token_emb(torch.zeros(1, device=device).long())
# type_emb = categorical_embs[0].reshape(num_agent, num_infer_step, -1)
# shape_emb = categorical_embs[1].reshape(num_agent, num_infer_step, -1)
# type_emb[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = self.type_a_emb(torch.tensor(AGENT_TYPE.index('seed'), device=device).long())
# shape_emb[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = self.shape_emb(torch.full((1, 3), self.invalid_shape_value, device=device))
# categorical_embs = [type_emb.reshape(-1, self.hidden_dim), shape_emb.reshape(-1, self.hidden_dim)]
# update token embeddings of current step
agent_token_emb[veh_mask, (self.num_historical_steps - 1) // self.shift + t] = agent_token_emb_veh[
next_token_idx[veh_mask]]
agent_token_emb[ped_mask, (self.num_historical_steps - 1) // self.shift + t] = agent_token_emb_ped[
next_token_idx[ped_mask]]
agent_token_emb[cyc_mask, (self.num_historical_steps - 1) // self.shift + t] = agent_token_emb_cyc[
next_token_idx[cyc_mask]]
# 4. update feat_a (t-1)
motion_vector_a, head_vector_a = self._build_vector_a(pos_a, head_a, state_a)
motion_vector_a[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0.
head_vector_a[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0.
offset_pos = pos_a - pos_a[av_index]
x_a = torch.stack(
[torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1),
angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2]),
# torch.norm(offset_pos[:, :, :2], p=2, dim=-1),
], dim=-1)
x_a = self.x_a_emb(continuous_inputs=x_a.view(-1, x_a.size(-1)),
categorical_embs=categorical_embs)
x_a = x_a.view(-1, num_infer_step, self.hidden_dim)
s_a = self.state_a_emb(state_a.reshape(-1).long()).reshape(-1, num_infer_step, self.hidden_dim)
feat_a = torch.cat((agent_token_emb, x_a, s_a), dim=-1)
if self.use_grid_token:
agent_grid_emb = self.grid_token_emb[grid_a]
feat_a = torch.cat([feat_a, agent_grid_emb], dim=-1)
feat_a = self.fusion_emb(feat_a)
raw_feat_a = feat_a.clone() # ! IMPORANT: need to update `raw_feat_a`
next_token_idx_list.append(next_token_idx[:, None])
next_state_idx_list.append(next_state_idx[:, None])
# get log message
num_inserted_agents_total += num_new_agents
num_inserted_agents += num_new_agents
if num_new_agents > 0:
self.log(t, next_pos_seed, ego_pos_t_1, next_head_seed, ego_head_t_1, next_shape_seed, next_type_idx_seed)
# pbar
allocated_memory = torch.cuda.memory_allocated() / (1024 ** 3)
pbar.set_postfix(memory=f'{allocated_memory:.2f}GB',
insert=f'{num_inserted_agents_total}/{seed_step_mask.sum()}')
for i in range(len(next_token_idx_list)):
next_token_idx_list[i] = torch.cat([next_token_idx_list[i], torch.zeros(num_agent - next_token_idx_list[i].shape[0], 1, device=device) - 1], dim=0).long() # -1: invalid motion token
next_state_idx_list[i] = torch.cat([next_state_idx_list[i], torch.zeros(num_agent - next_state_idx_list[i].shape[0], 1, device=device)], dim=0).long() # 0: invalid state token
# add history attributes
num_agent = pred_traj.shape[0]
num_init_agent = filter_mask.sum()
pred_traj = torch.cat([torch.zeros(num_agent, self.num_historical_steps, *(pred_traj.shape[2:]), device=pred_traj.device), pred_traj], dim=1)
pred_head = torch.cat([torch.zeros(num_agent, self.num_historical_steps, *(pred_head.shape[2:]), device=pred_head.device), pred_head], dim=1)
pred_state = torch.cat([torch.zeros(num_agent, self.num_historical_steps, *(pred_state.shape[2:]), device=pred_state.device), pred_state], dim=1)
pred_traj[:num_init_agent, 0] = data['agent']['position'][filter_mask, 0, :2]
pred_head[:num_init_agent, 0] = data['agent']['heading'][filter_mask, 0]
pred_state[:num_init_agent, 1 : self.num_historical_steps] = data['agent']['state_idx'][filter_mask, :(self.num_historical_steps - 1) // self.shift].repeat_interleave(repeats=self.shift, dim=1)
historical_token_idx = data['agent']['token_idx'][filter_mask, :(self.num_historical_steps - 1) // self.shift]
historical_token_idx[historical_token_idx < 0] = 0
historical_token_traj_all = torch.gather(agent_token_traj_all, 1, historical_token_idx[..., None, None, None].expand(-1, -1, 6, 4, 2))
init_theta = head_a[:num_init_agent, 0]
cos, sin = init_theta.cos(), init_theta.sin()
rot_mat = torch.zeros((num_init_agent, 2, 2), device=init_theta.device)
rot_mat[:, 0, 0] = cos
rot_mat[:, 0, 1] = sin
rot_mat[:, 1, 0] = -sin
rot_mat[:, 1, 1] = cos
historical_token_traj_all = torch.bmm(historical_token_traj_all.view(-1, 4, 2),
rot_mat[:, None, None, ...].repeat(1, (self.num_historical_steps - 1) // self.shift, self.shift + 1, 1, 1).view(
-1, 2, 2)).view(num_init_agent, (self.num_historical_steps - 1) // self.shift, self.shift + 1, 4, 2)
historical_token_traj_all = historical_token_traj_all + pos_a[:num_init_agent, 0, :][:, None, None, None, ...]
pred_traj[:num_init_agent, 1 : self.num_historical_steps] = historical_token_traj_all[:, :, 1:, ...].clone().mean(dim=3).reshape(num_init_agent, -1, 2)
diff_xy = historical_token_traj_all[..., 1:, 0, :] - historical_token_traj_all[..., 1:, 3, :]
pred_head[:num_init_agent, 1 : self.num_historical_steps] = torch.arctan2(diff_xy[..., 1], diff_xy[..., 0]).reshape(num_init_agent, -1)
# ! build z and valid
pred_z = torch.zeros_like(pred_traj[..., 0]) # hard code
pred_valid = (pred_state != self.invalid_state) & (pred_state != self.enter_state)
# ! predefined agent shape
eval_shape = torch.zeros_like(pred_shape)
eval_shape[veh_mask] = torch.tensor(AGENT_SHAPE['vehicle'], device=device)[None, ...]
eval_shape[ped_mask] = torch.tensor(AGENT_SHAPE['pedstrain'], device=device)[None, ...]
eval_shape[cyc_mask] = torch.tensor(AGENT_SHAPE['cyclist'], device=device)[None, ...]
next_token_idx = torch.cat(next_token_idx_list, dim=-1)
next_state_idx = torch.cat(next_state_idx_list, dim=-1) if len(next_state_idx_list) > 0 else None
# sanity check
assert torch.all(pos_a[next_state_idx == self.invalid_state] == 0), f'Invalid step should have all zeros position!'
if self.log_message == "":
self.log_message = "No agents inserted!"
else:
self.log_message += f"\nNumber of total inserted agents: {num_inserted_agents_total}/{seed_step_mask.sum()}"
return {
'ego_index': int(av_index),
'agent_id': agent_id,
# 'valid_mask': agent_valid_mask[:, self.num_historical_steps:],
# 'pos_a': pos_a[:, (self.num_historical_steps - 1) // self.shift:],
# 'head_a': head_a[:, (self.num_historical_steps - 1) // self.shift:],
'valid_mask': agent_valid_mask, # [n_agent, n_infer_step // shift]
'pos_a': pos_a, # [n_agent, n_infer_step // shift, 2]
'head_a': head_a, # [n_agent, n_infer_step // shift]
'gt_traj': gt_traj,
'pred_traj': pred_traj, # [n_agent, n_infer_step, 2]
'pred_head': pred_head, # [n_agent, n_infer_step]
'pred_type': pred_type,
'pred_state': pred_state,
'pred_z': pred_z,
'pred_shape': pred_shape,
'eval_shape': eval_shape,
'pred_valid': pred_valid,
'next_state_prob_seed': torch.cat(next_state_prob_seed_list, dim=1),
'next_pos_rel_prob_seed': torch.cat(next_pos_rel_prob_seed_list, dim=1)
if self.use_grid_token else None,
'next_token_idx': next_token_idx, # [n_agent, n_infer_step // shift]
'next_state_idx': next_state_idx, # [n_agent, n_infer_step // shift]
'grid_agent_occ_seed': torch.cat(grid_agent_occ_list, dim=1)
if self.use_grid_token else None,
'grid_pt_occ_seed': torch.cat(grid_pt_occ_list, dim=1)
if self.use_grid_token else None,
'grid_agent_occ_gt_seed': torch.cat(grid_agent_occ_gt_list, dim=1)
if self.use_grid_token else None,
'agent_labels': agent_labels,
'log_message': self.log_message,
}
def log(self, t, next_pos_seed, ego_pos, next_head_seed, ego_head, next_shape_seed, next_type_idx_seed):
    """Append a readable record of every newly inserted agent to ``self.log_message``.

    One entry is appended per row ``sa`` of ``next_pos_seed``. Each entry is a header
    line (starting with a newline) followed by three lines indented by four spaces:
    relative/absolute position, relative/absolute heading, and shape/type.

    Args:
        t: step index; it only appears in the header text.
        next_pos_seed: per-agent positions, indexed by agent and dumped via ``.tolist()``.
        ego_pos: reference position subtracted from each agent position.
        next_head_seed: per-agent headings. Each row is read with ``.item()``, so it
            must hold a single element.
        ego_head: reference heading. It is subtracted from each agent heading, and the
            difference is wrapped with ``wrap_angle`` and read with ``.item()``.
        next_shape_seed: per-agent shapes, dumped via ``.tolist()``.
        next_type_idx_seed: per-agent type indices, each read with ``.item()``.
    """
    indent = " " * 4
    num_seed = next_pos_seed.shape[0]
    for sa in range(num_seed):
        pos_sa = next_pos_seed[sa]
        head_sa = next_head_seed[sa]
        # The agent index appears twice in the header; this is the established log format.
        header = f"\n{sa} agent {sa} is entering at step {t}"
        pos_line = f"rel pos {(pos_sa - ego_pos).tolist()}, pos {pos_sa.tolist()}"
        head_line = f"rel head {wrap_angle(head_sa - ego_head).item()}, head {head_sa.item()}"
        attr_line = f"shape {next_shape_seed[sa].tolist()}, type {next_type_idx_seed[sa].item()}"
        entry_lines = [header, indent + pos_line, indent + head_line, indent + attr_line]
        self.log_message += "\n".join(entry_lines)