from typing import Dict, Mapping, Optional
import math
import torch
import torch.nn as nn
from torch_cluster import radius, radius_graph
from torch_geometric.data import HeteroData, Batch
from torch_geometric.utils import dense_to_sparse, subgraph
from dev.modules.layers import *
from dev.modules.map_decoder import discretize_neighboring
from dev.utils.geometry import angle_between_2d_vectors, wrap_angle
from dev.utils.weight_init import weight_init
def cal_polygon_contour(x, y, theta, width, length):
left_front_x = x + 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta)
left_front_y = y + 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta)
left_front = (left_front_x, left_front_y)
right_front_x = x + 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta)
right_front_y = y + 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta)
right_front = (right_front_x, right_front_y)
right_back_x = x - 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta)
right_back_y = y - 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta)
right_back = (right_back_x, right_back_y)
left_back_x = x - 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta)
left_back_y = y - 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta)
left_back = (left_back_x, left_back_y)
polygon_contour = [left_front, right_front, right_back, left_back]
return polygon_contour
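# Worked example (a sketch for intuition, not part of the original code):
# cal_polygon_contour(0.0, 0.0, 0.0, width=1.0, length=1.0) returns the four corners of an
# axis-aligned unit box in the order [left_front, right_front, right_back, left_back] ->
# [(0.5, 0.5), (0.5, -0.5), (-0.5, -0.5), (-0.5, 0.5)].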
class SMARTAgentDecoder(nn.Module):
def __init__(self,
dataset: str,
input_dim: int,
hidden_dim: int,
num_historical_steps: int,
num_interaction_steps: int,
time_span: Optional[int],
pl2a_radius: float,
pl2seed_radius: float,
a2a_radius: float,
num_freq_bands: int,
num_layers: int,
num_heads: int,
head_dim: int,
dropout: float,
token_data: Dict,
token_size: int,
special_token_index: list=[],
predict_motion: bool=False,
predict_state: bool=False,
predict_map: bool=False,
                 state_token: Optional[Dict[str, int]] = None,
seed_size: int=5) -> None:
super(SMARTAgentDecoder, self).__init__()
self.dataset = dataset
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.num_historical_steps = num_historical_steps
self.num_interaction_steps = num_interaction_steps
self.time_span = time_span if time_span is not None else num_historical_steps
self.pl2a_radius = pl2a_radius
self.pl2seed_radius = pl2seed_radius
self.a2a_radius = a2a_radius
self.num_freq_bands = num_freq_bands
self.num_layers = num_layers
self.num_heads = num_heads
self.head_dim = head_dim
self.dropout = dropout
self.special_token_index = special_token_index
self.predict_motion = predict_motion
self.predict_state = predict_state
self.predict_map = predict_map
# state tokens
self.state_type = list(state_token.keys())
self.state_token = state_token
self.invalid_state = int(state_token['invalid'])
self.valid_state = int(state_token['valid'])
self.enter_state = int(state_token['enter'])
self.exit_state = int(state_token['exit'])
self.seed_state_type = ['invalid', 'enter']
self.valid_state_type = ['invalid', 'valid', 'exit']
input_dim_x_a = 2
input_dim_r_t = 4
input_dim_r_pt2a = 3
input_dim_r_a2a = 3
input_dim_token = 8 # tokens: (token_size, 4, 2)
self.seed_size = seed_size
self.all_agent_type = ['veh', 'ped', 'cyc', 'background', 'invalid', 'seed']
self.seed_agent_type = ['veh', 'ped', 'cyc', 'seed']
self.type_a_emb = nn.Embedding(len(self.all_agent_type), hidden_dim)
self.shape_emb = MLPLayer(3, hidden_dim, hidden_dim)
if self.predict_state:
self.state_a_emb = nn.Embedding(len(self.state_type), hidden_dim)
self.invalid_shape_value = .1
self.motion_gap = 1.
self.heading_gap = 1.
self.x_a_emb = FourierEmbedding(input_dim=input_dim_x_a, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands)
self.r_t_emb = FourierEmbedding(input_dim=input_dim_r_t, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands)
self.r_pt2a_emb = FourierEmbedding(input_dim=input_dim_r_pt2a, hidden_dim=hidden_dim,
num_freq_bands=num_freq_bands)
self.r_a2a_emb = FourierEmbedding(input_dim=input_dim_r_a2a, hidden_dim=hidden_dim,
num_freq_bands=num_freq_bands)
self.token_emb_veh = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
self.token_emb_ped = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
self.token_emb_cyc = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
self.no_token_emb = nn.Embedding(1, hidden_dim)
self.bos_token_emb = nn.Embedding(1, hidden_dim)
# FIXME: do we need this???
self.token_emb_offset = MLPEmbedding(input_dim=2, hidden_dim=hidden_dim)
num_inputs = 2
if self.predict_state:
num_inputs = 3
self.fusion_emb = MLPEmbedding(input_dim=self.hidden_dim * num_inputs, hidden_dim=self.hidden_dim)
self.t_attn_layers = nn.ModuleList(
[AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
bipartite=False, has_pos_emb=True) for _ in range(num_layers)]
)
self.pt2a_attn_layers = nn.ModuleList(
[AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
bipartite=True, has_pos_emb=True) for _ in range(num_layers)]
)
self.a2a_attn_layers = nn.ModuleList(
[AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
bipartite=False, has_pos_emb=True) for _ in range(num_layers)]
)
self.token_size = token_size # 2048
# agent motion prediction head
self.token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
output_dim=self.token_size)
# agent state prediction head
if self.predict_state:
self.seed_feature = nn.Embedding(self.seed_size, self.hidden_dim)
self.state_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
output_dim=len(self.valid_state_type))
self.seed_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
output_dim=hidden_dim)
self.seed_state_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
output_dim=len(self.seed_state_type))
self.seed_type_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
output_dim=len(self.seed_agent_type))
# entering token prediction
# FIXME: this is just under test!!!
# self.bos_pl_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
# output_dim=200)
# self.bos_offset_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
# output_dim=2601)
self.trajectory_token = token_data['token'] # dict('veh', 'ped', 'cyc') (2048, 4, 2)
self.trajectory_token_traj = token_data['traj'] # (2048, 6, 3)
self.trajectory_token_all = token_data['token_all'] # (2048, 6, 4, 2)
self.apply(weight_init)
self.shift = 5
self.beam_size = 5
self.hist_mask = True
self.temporal_attn_to_invalid = True
self.temporal_attn_seed = False
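        # NOTE (sketch): `shift` is the number of base timesteps covered by one motion token
        # (roughly 0.5 s per token, cf. the "0.5s (6 timesteps)" comment in `inference`);
        # `beam_size` is the top-k width used when sampling the next motion token.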
# FIXME: This is just under test!!!
# self.mapping_network = MappingNetwork(z_dim=hidden_dim, w_dim=hidden_dim, num_layers=num_layers)
def transform_rel(self, token_traj, prev_pos, prev_heading=None):
if prev_heading is None:
diff_xy = prev_pos[:, :, -1, :] - prev_pos[:, :, -2, :]
prev_heading = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])
num_agent, num_step, traj_num, traj_dim = token_traj.shape
cos, sin = prev_heading.cos(), prev_heading.sin()
rot_mat = torch.zeros((num_agent, num_step, 2, 2), device=prev_heading.device)
rot_mat[:, :, 0, 0] = cos
rot_mat[:, :, 0, 1] = -sin
rot_mat[:, :, 1, 0] = sin
rot_mat[:, :, 1, 1] = cos
agent_diff_rel = torch.bmm(token_traj.view(-1, traj_num, 2), rot_mat.view(-1, 2, 2)).view(num_agent, num_step, traj_num, traj_dim)
agent_pred_rel = agent_diff_rel + prev_pos[:, :, -1:, :]
return agent_pred_rel
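    # `transform_rel` (sketch of intent): token_traj is (num_agent, num_step, num_points, 2)
    # in the agent-local frame; each point is multiplied by a rotation matrix built from the
    # previous heading (estimated from the last two positions when `prev_heading` is None)
    # and then translated by the last previous position prev_pos[:, :, -1, :].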
def agent_token_embedding(self, data, agent_token_index, agent_state, pos_a, head_a, inference=False,
filter_mask=None, av_index=None):
if filter_mask is None:
filter_mask = torch.ones_like(agent_state[:, 2], dtype=torch.bool)
num_agent, num_step, traj_dim = pos_a.shape # traj_dim=2
agent_type = data['agent']['type'][filter_mask]
veh_mask = (agent_type == 0)
ped_mask = (agent_type == 1)
cyc_mask = (agent_type == 2)
# set the position of invalid agents to the position of ego agent
# note here we only set invalid steps BEFORE the bos token!
# is_invalid = agent_state == self.invalid_state
# is_bos = agent_state == self.enter_state
# bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
# bos_mask = torch.arange(num_step).expand(num_agent, -1).to(agent_state.device) < bos_index[:, None]
# is_invalid[~bos_mask] = False
# ego_pos_a = pos_a[av_index].clone()
# ego_head_vector_a = head_vector_a[av_index].clone()
# pos_a[is_invalid] = ego_pos_a[None, :].repeat(pos_a.shape[0], 1, 1)[is_invalid]
# head_vector_a[is_invalid] = ego_head_vector_a[None, :].repeat(head_vector_a.shape[0], 1, 1)[is_invalid]
motion_vector_a, head_vector_a = self.build_vector_a(pos_a, head_a, agent_state)
trajectory_token_veh = torch.from_numpy(self.trajectory_token['veh']).clone().to(pos_a.device).to(torch.float)
trajectory_token_ped = torch.from_numpy(self.trajectory_token['ped']).clone().to(pos_a.device).to(torch.float)
trajectory_token_cyc = torch.from_numpy(self.trajectory_token['cyc']).clone().to(pos_a.device).to(torch.float)
self.agent_token_emb_veh = self.token_emb_veh(trajectory_token_veh.view(trajectory_token_veh.shape[0], -1)) # (token_size, 8)
self.agent_token_emb_ped = self.token_emb_ped(trajectory_token_ped.view(trajectory_token_ped.shape[0], -1))
self.agent_token_emb_cyc = self.token_emb_cyc(trajectory_token_cyc.view(trajectory_token_cyc.shape[0], -1))
# add bos token embedding
self.agent_token_emb_veh = torch.cat([self.agent_token_emb_veh, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())])
self.agent_token_emb_ped = torch.cat([self.agent_token_emb_ped, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())])
self.agent_token_emb_cyc = torch.cat([self.agent_token_emb_cyc, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())])
# add invalid token embedding
self.agent_token_emb_veh = torch.cat([self.agent_token_emb_veh, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())])
self.agent_token_emb_ped = torch.cat([self.agent_token_emb_ped, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())])
self.agent_token_emb_cyc = torch.cat([self.agent_token_emb_cyc, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())])
if inference:
agent_token_traj_all = torch.zeros((num_agent, self.token_size, self.shift + 1, 4, 2), device=pos_a.device)
trajectory_token_all_veh = torch.from_numpy(self.trajectory_token_all['veh']).clone().to(pos_a.device).to(torch.float)
trajectory_token_all_ped = torch.from_numpy(self.trajectory_token_all['ped']).clone().to(pos_a.device).to(torch.float)
trajectory_token_all_cyc = torch.from_numpy(self.trajectory_token_all['cyc']).clone().to(pos_a.device).to(torch.float)
agent_token_traj_all[veh_mask] = torch.cat(
[trajectory_token_all_veh[:, :self.shift], trajectory_token_veh[:, None, ...]], dim=1)
agent_token_traj_all[ped_mask] = torch.cat(
[trajectory_token_all_ped[:, :self.shift], trajectory_token_ped[:, None, ...]], dim=1)
agent_token_traj_all[cyc_mask] = torch.cat(
[trajectory_token_all_cyc[:, :self.shift], trajectory_token_cyc[:, None, ...]], dim=1)
# additional token embeddings are already added -> -1: invalid, -2: bos
agent_token_emb = torch.zeros((num_agent, num_step, self.hidden_dim), device=pos_a.device)
agent_token_emb[veh_mask] = self.agent_token_emb_veh[agent_token_index[veh_mask]]
agent_token_emb[ped_mask] = self.agent_token_emb_ped[agent_token_index[ped_mask]]
agent_token_emb[cyc_mask] = self.agent_token_emb_cyc[agent_token_index[cyc_mask]]
# 'vehicle', 'pedestrian', 'cyclist', 'background'
is_invalid = (agent_state == self.invalid_state) & (agent_state != self.enter_state)
agent_types = data['agent']['type'][filter_mask].long().repeat_interleave(repeats=num_step, dim=0)
agent_types[is_invalid.reshape(-1)] = self.all_agent_type.index('invalid')
agent_shapes = data['agent']['shape'][filter_mask, self.num_historical_steps - 1, :].repeat_interleave(repeats=num_step, dim=0)
agent_shapes[is_invalid.reshape(-1)] = self.invalid_shape_value
categorical_embs = [self.type_a_emb(agent_types), self.shape_emb(agent_shapes)]
feature_a = torch.stack(
[torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1),
angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2]),
], dim=-1) # (num_agent, num_shifted_step, 2)
x_a = self.x_a_emb(continuous_inputs=feature_a.view(-1, feature_a.size(-1)),
categorical_embs=categorical_embs)
x_a = x_a.view(-1, num_step, self.hidden_dim) # (num_agent, num_step, hidden_dim)
s_a = self.state_a_emb(agent_state.reshape(-1).long()).reshape(num_agent, num_step, self.hidden_dim)
feat_a = torch.cat((agent_token_emb, x_a, s_a), dim=-1) # (num_agent, num_step, hidden_dim * 3)
feat_a = self.fusion_emb(feat_a) # (num_agent, num_step, hidden_dim)
# seed agent feature
motion_vector_seed = motion_vector_a[av_index : av_index + 1]
head_vector_seed = head_vector_a[av_index : av_index + 1]
feat_seed = self.build_invalid_agent_feature(num_step, pos_a.device, type_index=self.all_agent_type.index('seed'),
motion_vector=motion_vector_seed, head_vector=head_vector_seed)
# replace the features of steps before bos of valid agents with the corresponding invalid agent features
# is_bos = agent_state == self.enter_state
# is_eos = agent_state == self.exit_state
# bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
# eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_step - 1))
# is_before_bos = torch.arange(num_step).expand(num_agent, -1).to(agent_state.device) < bos_index[:, None]
# is_after_eos = torch.arange(num_step).expand(num_agent, -1).to(agent_state.device) > eos_index[:, None] + 1
# feat_ina = self.build_invalid_agent_feature(num_step, pos_a.device)
# feat_a[is_before_bos | is_after_eos] = feat_ina.repeat(num_agent, 1, 1)[is_before_bos | is_after_eos]
# print("train")
# is_bos = agent_state == self.enter_state
# is_eos = agent_state == self.exit_state
# bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
# eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_step - 1))
# mask = torch.arange(num_step).expand(num_agent, -1).to(agent_state.device)
# mask = (mask >= bos_index[:, None]) & (mask <= eos_index[:, None] + 1)
# is_invalid[mask] = False
# print(feat_a.sum(dim=-1)[is_invalid])
feat_a = torch.cat([feat_a, feat_seed], dim=0) # (num_agent + 1, num_step, hidden_dim)
# feat_a_sum = feat_a.sum(dim=-1)
# for a in range(num_agent):
# print(f"agent {a}:")
# print(f"state: {agent_state[a, :]}")
# print(f"feat_a_sum: {feat_a_sum[a, :]}")
# exit(1)
if inference:
return feat_a, head_vector_a, agent_token_traj_all, agent_token_emb, categorical_embs
else:
return feat_a, head_vector_a
def build_vector_a(self, pos_a, head_a, state_a):
num_agent = pos_a.shape[0]
motion_vector_a = torch.cat([pos_a.new_zeros(num_agent, 1, self.input_dim),
pos_a[:, 1:] - pos_a[:, :-1]], dim=1)
# update the relative motion/head vectors
is_bos = state_a == self.enter_state
motion_vector_a[is_bos] = self.motion_gap
is_last_eos = state_a.roll(shifts=1, dims=1) == self.exit_state
is_last_eos[:, 0] = False
motion_vector_a[is_last_eos] = -self.motion_gap
head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1)
return motion_vector_a, head_vector_a
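    # Example for `build_vector_a` (sketch): for token positions [(0,0), (1,0), (3,0)] the
    # per-step motion vectors are [(0,0), (1,0), (2,0)] (the first step is zero-padded).
    # Steps in the `enter` state are overwritten with `motion_gap`, and the step right after
    # an `exit` with `-motion_gap`, so state transitions are distinguishable from real motion.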
def build_invalid_agent_feature(self, num_step, device, motion_vector=None, head_vector=None, type_index=None, shape_value=None):
invalid_agent_token_emb = self.no_token_emb(torch.zeros(1, device=device).long())[:, None].repeat(1, num_step, 1)
if motion_vector is None or head_vector is None:
motion_vector = torch.zeros((1, num_step, 2), device=device)
head_vector = torch.stack([torch.cos(torch.zeros(1, device=device)), torch.sin(torch.zeros(1, device=device))], dim=-1)[:, None, :].repeat(1, num_step, 1)
feature_ina = torch.stack(
[torch.norm(motion_vector[:, :, :2], p=2, dim=-1),
angle_between_2d_vectors(ctr_vector=head_vector, nbr_vector=motion_vector[:, :, :2]),
], dim=-1)
if type_index is None:
type_index = self.all_agent_type.index('invalid')
if shape_value is None:
shape_value = torch.full((1, 3), self.invalid_shape_value, device=device)
categorical_embs_ina = [self.type_a_emb(torch.tensor([type_index], device=device)),
self.shape_emb(shape_value)]
x_ina = self.x_a_emb(continuous_inputs=feature_ina.view(-1, feature_ina.size(-1)),
categorical_embs=categorical_embs_ina)
x_ina = x_ina.view(-1, num_step, self.hidden_dim) # (1, num_step, hidden_dim)
s_ina = self.state_a_emb(torch.tensor([self.invalid_state], device=device))[:, None].repeat(1, num_step, 1) # NOTE: do not use `expand`
feat_ina = torch.cat((invalid_agent_token_emb, x_ina, s_ina), dim=-1)
feat_ina = self.fusion_emb(feat_ina) # (1, num_step, hidden_dim)
return feat_ina
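    # `build_invalid_agent_feature` returns a (1, num_step, hidden_dim) placeholder feature,
    # fused from the no-token embedding, zero motion (unless a motion/head vector is given)
    # and the invalid-state embedding; with type_index set to 'seed' it serves as the
    # seed-agent feature.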
def build_temporal_edge(self, pos_a, head_a, head_vector_a, state_a, mask, inference_mask=None, av_index=None):
num_agent = pos_a.shape[0]
hist_mask = mask.clone()
if not self.temporal_attn_to_invalid:
hist_mask[state_a == self.invalid_state] = False
# set the position of invalid agents to the position of ego agent
ego_pos_a = pos_a[av_index].clone() # (num_step, 2)
ego_head_a = head_a[av_index].clone()
ego_head_vector_a = head_vector_a[av_index].clone()
ego_state_a = state_a[av_index].clone()
# is_invalid = state_a == self.invalid_state
# pos_a[is_invalid] = ego_pos_a[None, :].repeat(pos_a.shape[0], 1, 1)[is_invalid]
# head_a[is_invalid] = ego_head_a[None, :].repeat(head_a.shape[0], 1)[is_invalid]
# add seed agent
pos_a = torch.cat([pos_a, ego_pos_a[None]], dim=0)
head_a = torch.cat([head_a, ego_head_a[None]], dim=0)
state_a = torch.cat([state_a, ego_state_a[None]], dim=0)
head_vector_a = torch.cat([head_vector_a, ego_head_vector_a[None]], dim=0)
hist_mask = torch.cat([hist_mask, torch.ones_like(hist_mask[0:1])], dim=0).bool()
if not self.temporal_attn_seed:
hist_mask[-1:] = False
if inference_mask is not None:
inference_mask[-1:] = False
pos_t = pos_a.reshape(-1, self.input_dim) # (num_agent * num_step, ...)
head_t = head_a.reshape(-1)
head_vector_t = head_vector_a.reshape(-1, 2)
        # steps of invalid agents that will never predict a motion token are masked out (not attended to)
is_bos = state_a == self.enter_state
is_bos[-1] = False
bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
motion_predict_start_index = torch.clamp(bos_index - self.time_span / self.shift + 1, min=0)
motion_predict_mask = torch.arange(hist_mask.shape[1]).expand(hist_mask.shape[0], -1).to(hist_mask.device)
motion_predict_mask = motion_predict_mask >= motion_predict_start_index[:, None]
hist_mask[~motion_predict_mask] = False
if self.hist_mask and self.training:
hist_mask[
torch.arange(mask.shape[0]).unsqueeze(1), torch.randint(0, mask.shape[1], (num_agent, 10))] = False
mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)
elif inference_mask is not None:
mask_t = hist_mask.unsqueeze(2) & inference_mask.unsqueeze(1)
else:
mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)
# mask_t: (num_agent, 18, 18), edge_index_t: (2, num_edge)
edge_index_t = dense_to_sparse(mask_t)[0]
edge_index_t = edge_index_t[:, (edge_index_t[1] - edge_index_t[0] > 0) &
(edge_index_t[1] - edge_index_t[0] <= self.time_span / self.shift)]
rel_pos_t = pos_t[edge_index_t[0]] - pos_t[edge_index_t[1]]
rel_head_t = wrap_angle(head_t[edge_index_t[0]] - head_t[edge_index_t[1]])
# FIXME relative motion/head for bos/eos token
# is_next_bos = state_a.roll(shifts=-1, dims=1) == self.enter_state
# is_next_bos[:, -1] = False # the last step
# is_next_bos_t = is_next_bos.reshape(-1)
# rel_pos_t[is_next_bos_t[edge_index_t[0]]] = -self.bos_motion
# rel_pos_t[is_next_bos_t[edge_index_t[1]]] = self.bos_motion
# rel_head_t[is_next_bos_t[edge_index_t[0]]] = -torch.pi
# rel_head_t[is_next_bos_t[edge_index_t[1]]] = torch.pi
# is_last_eos = state_a.roll(shifts=1, dims=1) == self.exit_state
# is_last_eos[:, 0] = False # the first step
# is_last_eos_t = is_last_eos.reshape(-1)
# rel_pos_t[is_last_eos_t[edge_index_t[0]]] = -self.bos_motion
# rel_pos_t[is_last_eos_t[edge_index_t[1]]] = self.bos_motion
# rel_head_t[is_last_eos_t[edge_index_t[0]]] = -torch.pi
# rel_head_t[is_last_eos_t[edge_index_t[1]]] = torch.pi
# handle the bos token of ego agent
# is_invalid = state_a == self.invalid_state
# is_invalid_t = is_invalid.reshape(-1)
# is_ego_bos = (ego_state_a == self.enter_state)[None, :].expand(num_agent + 1, -1)
# is_ego_bos_t = is_ego_bos.reshape(-1)
# rel_pos_t[is_invalid_t[edge_index_t[0]] & is_ego_bos_t[edge_index_t[0]]] = 0.
# rel_pos_t[is_invalid_t[edge_index_t[1]] & is_ego_bos_t[edge_index_t[1]]] = 0.
# rel_head_t[is_invalid_t[edge_index_t[0]] & is_ego_bos_t[edge_index_t[0]]] = 0.
# rel_head_t[is_invalid_t[edge_index_t[1]] & is_ego_bos_t[edge_index_t[1]]] = 0.
# handle the invalid steps
is_invalid = state_a == self.invalid_state
is_invalid_t = is_invalid.reshape(-1)
rel_pos_t[is_invalid_t[edge_index_t[0]]] = -self.motion_gap
rel_pos_t[is_invalid_t[edge_index_t[1]]] = self.motion_gap
rel_head_t[is_invalid_t[edge_index_t[0]]] = -self.heading_gap
rel_head_t[is_invalid_t[edge_index_t[1]]] = self.heading_gap
r_t = torch.stack(
[torch.norm(rel_pos_t[:, :2], p=2, dim=-1),
angle_between_2d_vectors(ctr_vector=head_vector_t[edge_index_t[1]], nbr_vector=rel_pos_t[:, :2]),
rel_head_t,
edge_index_t[0] - edge_index_t[1]], dim=-1)
r_t = self.r_t_emb(continuous_inputs=r_t, categorical_embs=None)
return edge_index_t, r_t
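    # Temporal edges connect earlier to later token steps of the same agent: an edge
    # (src, dst) survives only if 0 < dst - src <= time_span / shift (assuming the usual
    # PyG convention edge_index[0] = source, edge_index[1] = target). E.g. with
    # time_span = 10 and shift = 5, each step attends to at most its two preceding steps.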
def build_interaction_edge(self, pos_a, head_a, head_vector_a, state_a, batch_s, mask_a, inference_mask=None, av_index=None):
num_agent, num_step, _ = pos_a.shape
pos_a = torch.cat([pos_a, pos_a[av_index][None]], dim=0)
head_a = torch.cat([head_a, head_a[av_index][None]], dim=0)
state_a = torch.cat([state_a, state_a[av_index][None]], dim=0)
head_vector_a = torch.cat([head_vector_a, head_vector_a[av_index][None]], dim=0)
pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
head_s = head_a.transpose(0, 1).reshape(-1)
head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)
if inference_mask is not None:
mask_a = mask_a & inference_mask
mask_s = mask_a.transpose(0, 1).reshape(-1)
# seed agent
mask_seed = state_a[av_index] != self.invalid_state
pos_seed = pos_a[av_index]
edge_index_seed2a = radius(x=pos_seed[:, :2], y=pos_s[:, :2], r=self.pl2seed_radius,
batch_x=torch.arange(num_step).to(pos_s.device), batch_y=batch_s, max_num_neighbors=300)
edge_index_seed2a = edge_index_seed2a[:, mask_s[edge_index_seed2a[0]] & mask_seed[edge_index_seed2a[1]]]
        # convert the step index to the flattened agent-step index (seed connections are unidirectional)
edge_index_seed2a[1, :] = (edge_index_seed2a[1, :] + 1) * (num_agent + 1) - 1
# build agent2agent bilateral connection
edge_index_a2a = radius_graph(x=pos_s[:, :2], r=self.a2a_radius, batch=batch_s, loop=False,
max_num_neighbors=300)
edge_index_a2a = subgraph(subset=mask_s, edge_index=edge_index_a2a)[0]
# add the edges which connect seed agents
edge_index_a2a = torch.cat([edge_index_a2a, edge_index_seed2a], dim=-1)
# set the position of invalid agents to the position of ego agent
# ego_pos_a = pos_a[av_index].clone()
# ego_head_a = head_a[av_index].clone()
# is_invalid = state_a == self.invalid_state
# pos_a[is_invalid] = ego_pos_a[None, :].repeat(pos_a.shape[0], 1, 1)[is_invalid]
# head_a[is_invalid] = ego_head_a[None, :].repeat(head_a.shape[0], 1)[is_invalid]
rel_pos_a2a = pos_s[edge_index_a2a[0]] - pos_s[edge_index_a2a[1]]
rel_head_a2a = wrap_angle(head_s[edge_index_a2a[0]] - head_s[edge_index_a2a[1]])
# relative motion/head for bos/eos token
# is_bos = state_a == self.enter_state
# is_bos_s = is_bos.transpose(0, 1).reshape(-1)
# rel_pos_a2a[is_bos_s[edge_index_a2a[0]]] = -self.bos_motion
# rel_pos_a2a[is_bos_s[edge_index_a2a[1]]] = self.bos_motion
# rel_head_a2a[is_bos_s[edge_index_a2a[0]]] = -torch.pi
# rel_head_a2a[is_bos_s[edge_index_a2a[1]]] = torch.pi
# is_last_eos = state_a.roll(shifts=-1, dims=1) == self.exit_state
# is_last_eos[:, 0] = False # first step
# is_last_eos_s = is_last_eos.transpose(0, 1).reshape(-1)
# rel_pos_a2a[is_last_eos_s[edge_index_a2a[0]]] = -self.bos_motion
# rel_pos_a2a[is_last_eos_s[edge_index_a2a[1]]] = self.bos_motion
# rel_head_a2a[is_last_eos_s[edge_index_a2a[0]]] = -torch.pi
# rel_head_a2a[is_last_eos_s[edge_index_a2a[1]]] = torch.pi
# handle the bos token of ego agent
# is_invalid = state_a == self.invalid_state
# is_invalid_s = is_invalid.transpose(0, 1).reshape(-1)
# is_ego_bos = (state_a[av_index] == self.enter_state)[None, :].expand(num_agent + 1, -1)
# is_ego_bos_s = is_ego_bos.transpose(0, 1).reshape(-1)
# rel_pos_a2a[is_invalid_s[edge_index_a2a[0]] & is_ego_bos_s[edge_index_a2a[0]]] = 0.
# rel_pos_a2a[is_invalid_s[edge_index_a2a[1]] & is_ego_bos_s[edge_index_a2a[1]]] = 0.
# rel_head_a2a[is_invalid_s[edge_index_a2a[0]] & is_ego_bos_s[edge_index_a2a[0]]] = 0.
# rel_head_a2a[is_invalid_s[edge_index_a2a[1]] & is_ego_bos_s[edge_index_a2a[1]]] = 0.
# handle the invalid steps
is_invalid = state_a == self.invalid_state
is_invalid_s = is_invalid.transpose(0, 1).reshape(-1)
rel_pos_a2a[is_invalid_s[edge_index_a2a[0]]] = -self.motion_gap
rel_pos_a2a[is_invalid_s[edge_index_a2a[1]]] = self.motion_gap
rel_head_a2a[is_invalid_s[edge_index_a2a[0]]] = -self.heading_gap
rel_head_a2a[is_invalid_s[edge_index_a2a[1]]] = self.heading_gap
r_a2a = torch.stack(
[torch.norm(rel_pos_a2a[:, :2], p=2, dim=-1),
angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_a2a[1]], nbr_vector=rel_pos_a2a[:, :2]),
rel_head_a2a], dim=-1)
r_a2a = self.r_a2a_emb(continuous_inputs=r_a2a, categorical_embs=None)
return edge_index_a2a, r_a2a
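    # Index bookkeeping (sketch): pos_s flattens (num_step, num_agent + 1) step-major, so
    # agent a at step t lives at flat index t * (num_agent + 1) + a. The seed agent occupies
    # the last slot (a = num_agent), hence its flat index is (t + 1) * (num_agent + 1) - 1,
    # which is exactly what the edge_index_seed2a conversion above computes.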
def build_map2agent_edge(self, data, num_step, pos_a, head_a, head_vector_a, state_a, batch_s, batch_pl,
mask, inference_mask=None, av_index=None):
num_agent, num_step, _ = pos_a.shape
mask_pl2a = mask.clone()
if inference_mask is not None:
mask_pl2a = mask_pl2a & inference_mask
mask_pl2a = mask_pl2a.transpose(0, 1).reshape(-1)
pos_a = torch.cat([pos_a, pos_a[av_index][None]], dim=0)
state_a = torch.cat([state_a, state_a[av_index][None]], dim=0)
head_a = torch.cat([head_a, head_a[av_index][None]], dim=0)
head_vector_a = torch.cat([head_vector_a, head_vector_a[av_index][None]], dim=0)
pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
head_s = head_a.transpose(0, 1).reshape(-1)
head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)
ori_pos_pl = data['pt_token']['position'][:, :self.input_dim].contiguous()
ori_orient_pl = data['pt_token']['orientation'].contiguous()
pos_pl = ori_pos_pl.repeat(num_step, 1) # not `repeat_interleave`
orient_pl = ori_orient_pl.repeat(num_step)
# seed agent
mask_seed = state_a[av_index] != self.invalid_state
pos_seed = pos_a[av_index]
edge_index_pl2seed = radius(x=pos_seed[:, :2], y=pos_pl[:, :2], r=self.pl2seed_radius,
batch_x=torch.arange(num_step).to(pos_s.device), batch_y=batch_pl, max_num_neighbors=600)
edge_index_pl2seed = edge_index_pl2seed[:, mask_seed[edge_index_pl2seed[1]]]
# convert to global index
edge_index_pl2seed[1, :] = (edge_index_pl2seed[1, :] + 1) * (num_agent + 1) - 1
# build map2agent directed graph
# edge_index_pl2a[0]: pl token; edge_index_pl2a[1]: agent token
edge_index_pl2a = radius(x=pos_s[:, :2], y=pos_pl[:, :2], r=self.pl2a_radius,
batch_x=batch_s, batch_y=batch_pl, max_num_neighbors=300)
        # (disabled) force invalid agents to interact with all map tokens visible in the current window
# invalid_node_index_a = torch.where(bos_state_s.bool())[0]
# sampled_node_index_m = torch.arange(ori_pos_pl.shape[0]).to(pos_pl.device)
# if kwargs.get('sample_pt_indices', None) is not None:
# sampled_node_index_m = sampled_node_index_m[kwargs['sample_pt_indices'].long()]
# grid_a, grid_b = torch.meshgrid(sampled_node_index_m, invalid_node_index_a, indexing='ij')
# invalid_edge_index_pl2a = torch.stack([grid_a.reshape(-1), grid_b.reshape(-1)], dim=0)
# edge_index_pl2a = torch.concat([edge_index_pl2a, invalid_edge_index_pl2a], dim=-1)
        # remove edges that connect to motion-invalid agents
edge_index_pl2a = edge_index_pl2a[:, mask_pl2a[edge_index_pl2a[1]]]
# add the edges which connect seed agents with map tokens
edge_index_pl2a = torch.cat([edge_index_pl2a, edge_index_pl2seed], dim=-1)
# set the position of invalid agents to the position of ego agent
# ego_pos_a = pos_a[av_index].clone()
# ego_head_a = head_a[av_index].clone()
# is_invalid = state_a == self.invalid_state
# pos_a[is_invalid] = ego_pos_a[None, :].repeat(pos_a.shape[0], 1, 1)[is_invalid]
# head_a[is_invalid] = ego_head_a[None, :].repeat(head_a.shape[0], 1)[is_invalid]
rel_pos_pl2a = pos_pl[edge_index_pl2a[0]] - pos_s[edge_index_pl2a[1]]
rel_orient_pl2a = wrap_angle(orient_pl[edge_index_pl2a[0]] - head_s[edge_index_pl2a[1]])
# handle the invalid steps
is_invalid = state_a == self.invalid_state
is_invalid_s = is_invalid.transpose(0, 1).reshape(-1)
rel_pos_pl2a[is_invalid_s[edge_index_pl2a[1]]] = self.motion_gap
rel_orient_pl2a[is_invalid_s[edge_index_pl2a[1]]] = self.heading_gap
r_pl2a = torch.stack(
[torch.norm(rel_pos_pl2a[:, :2], p=2, dim=-1),
angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_pl2a[1]], nbr_vector=rel_pos_pl2a[:, :2]),
rel_orient_pl2a], dim=-1)
r_pl2a = self.r_pt2a_emb(continuous_inputs=r_pl2a, categorical_embs=None)
return edge_index_pl2a, r_pl2a
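    # Map tokens follow the same step-major layout: pos_pl is tiled with `repeat` (not
    # `repeat_interleave`), so map token p at step t lives at flat index t * num_pl + p,
    # consistent with batch_pl = torch.arange(num_step).repeat_interleave(num_pl) in `forward`.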
def get_inputs(self, data: HeteroData) -> Dict[str, torch.Tensor]:
pos_a = data['agent']['token_pos']
head_a = data['agent']['token_heading']
agent_category = data['agent']['category']
agent_token_index = data['agent']['token_idx']
agent_state_index = data['agent']['state_idx']
mask = data['agent']['raw_agent_valid_mask'].clone()
# mask[agent_category != 3] = False
if not self.predict_state:
agent_state_index = None
next_token_index_gt = agent_token_index.roll(shifts=-1, dims=1)
next_state_index_gt = agent_state_index.roll(shifts=-1, dims=1)
if self.predict_state:
next_token_eval_mask = mask.clone()
next_token_eval_mask = next_token_eval_mask * next_token_eval_mask.roll(shifts=1, dims=1)
bos_token_index = torch.nonzero(agent_state_index == 2)
eos_token_index = torch.nonzero(agent_state_index == 3)
next_token_eval_mask[bos_token_index[:, 0], bos_token_index[:, 1]] = 1
for eos_token_index_ in eos_token_index:
if not next_token_eval_mask[eos_token_index_[0], eos_token_index_[1]]:
next_token_eval_mask[eos_token_index_[0], eos_token_index_[1]:] = 0
next_token_eval_mask = next_token_eval_mask.roll(shifts=-1, dims=1)
# TODO: next_state_eval_mask !!!
if next_token_index_gt[next_token_eval_mask].min() < 0:
raise RuntimeError()
next_token_eval_mask[:, -1] = False
return {'token_pos': pos_a,
'token_heading': head_a,
'agent_category': agent_category,
'next_token_idx_gt': next_token_index_gt,
'next_state_idx_gt': next_state_index_gt,
'next_token_eval_mask': next_token_eval_mask,
'raw_agent_valid_mask': data['agent']['raw_agent_valid_mask'],
}
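    # Note on `get_inputs`: the ground-truth "next" indices come from roll(shifts=-1), i.e.
    # next_token_idx_gt[:, t] is the token at step t + 1; the wrap-around at the final step
    # is excluded by setting next_token_eval_mask[:, -1] = False above.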
def forward(self,
data: HeteroData,
map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
pos_a = data['agent']['token_pos'].clone() # (num_agent, num_shifted_step, 2)
head_a = data['agent']['token_heading'].clone() # (num_agent, num_shifted_step)
num_agent, num_step, traj_dim = pos_a.shape # e.g. (50, 18, 2)
agent_category = data['agent']['category'].clone() # (num_agent,)
agent_token_index = data['agent']['token_idx'].clone() # (num_agent, num_step)
agent_state_index = data['agent']['state_idx'].clone() # (num_agent, num_step)
        agent_type_index = data['agent']['type'].clone() # (num_agent,)
agent_enter_pl_token_idx = None
agent_enter_offset_token_idx = None
device = pos_a.device
seed_step_mask = agent_state_index[:, 1:] == self.enter_state
if torch.any(seed_step_mask.sum(dim=0) > self.seed_size):
print(agent_state_index)
print(agent_state_index.shape)
print(seed_step_mask.long())
print(seed_step_mask.sum(dim=0))
raise RuntimeError(f"Seed size {self.seed_size} is too small.")
# fix pos and head of invalid agents
av_index = int(data['agent']['av_index'])
# ego_pos_a = pos_a[av_index].clone() # (num_shifted_step, 2)
# ego_head_vector_a = head_vector_a[av_index] # (num_shifted_step, 2)
# is_invalid = agent_state_index == self.invalid_state
# pos_a[is_invalid] = ego_pos_a[None, :].expand(pos_a.shape[0], -1, -1)[is_invalid]
# head_vector_a[is_invalid] = ego_head_vector_a[None, :].expand(head_vector_a.shape[0], -1, -1)[is_invalid]
if not self.predict_state:
agent_state_index = None
feat_a, head_vector_a = self.agent_token_embedding(data, agent_token_index, agent_state_index, pos_a, head_a, av_index=av_index)
# build masks
mask = data['agent']['raw_agent_valid_mask'].clone()
temporal_mask = mask.clone()
interact_mask = mask.clone()
if self.predict_state:
agent_enter_offset_token_idx = data['agent']['neighbor_token_idx']
agent_enter_pl_token_idx = data['agent']['map_bos_token_idx']
agent_enter_pl_token_id = data['agent']['map_bos_token_id']
is_bos = agent_state_index == self.enter_state
is_eos = agent_state_index == self.exit_state
bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_step - 1)) # not `-1`
temporal_mask = torch.ones_like(mask)
motion_mask = torch.arange(mask.shape[1]).expand(mask.shape[0], -1).to(device)
motion_mask = (motion_mask > bos_index[:, None]) & (motion_mask <= eos_index[:, None])
temporal_mask[motion_mask] = mask[motion_mask]
interact_mask[agent_state_index == self.enter_state] = True
interact_mask = torch.cat([interact_mask, torch.ones_like(interact_mask[:1])]).bool() # placeholder
edge_index_t, r_t = self.build_temporal_edge(pos_a, head_a, head_vector_a, agent_state_index, temporal_mask,
av_index=av_index)
# +1: placeholder for seed agent
# if isinstance(data, Batch):
# print(data['agent']['batch'], data.num_graphs)
# batch_s = torch.cat([data['agent']['batch'] + data.num_graphs * t for t in range(num_step)], dim=0)
# batch_pl = torch.cat([data['pt_token']['batch'] + data.num_graphs * t for t in range(num_step)], dim=0)
# else:
batch_s = torch.arange(num_step, device=device).repeat_interleave(data['agent']['num_nodes'] + 1)
batch_pl = torch.arange(num_step, device=device).repeat_interleave(data['pt_token']['num_nodes'])
edge_index_a2a, r_a2a = self.build_interaction_edge(pos_a, head_a, head_vector_a, agent_state_index, batch_s,
interact_mask, av_index=av_index)
agent_category = torch.cat([agent_category, torch.full(agent_category[-1:].shape, 3, device=device)])
interact_mask[agent_category != 3] = False
edge_index_pl2a, r_pl2a = self.build_map2agent_edge(data, num_step, pos_a, head_a, head_vector_a,
agent_state_index, batch_s, batch_pl, interact_mask, av_index=av_index)
# mapping network
# z = torch.randn(num_agent, self.hidden_dim).to(feat_a.device)
# w = self.mapping_network(z)
for i in range(self.num_layers):
# feat_a = feat_a + w[:, None]
feat_a = feat_a.reshape(-1, self.hidden_dim) # (num_agent, num_step, hidden_dim) -> (seq_len, hidden_dim)
feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t)
feat_a = feat_a.reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim)
feat_a = self.pt2a_attn_layers[i]((
map_enc['x_pt'].repeat_interleave(repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(
-1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a)
feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a)
feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1)
# next motion token
next_token_prob = self.token_predict_head(feat_a[:-1]) # (num_agent, num_step, token_size)
next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)
_, next_token_idx = torch.topk(next_token_prob_softmax, k=10, dim=-1) # (num_agent, num_step, 10)
next_token_index_gt = agent_token_index.roll(shifts=-1, dims=1)
# next state token
next_state_prob = self.state_predict_head(feat_a[:-1])
next_state_idx = next_state_prob.softmax(dim=-1).argmax(dim=-1, keepdim=True) # (num_agent, num_step, 1)
next_state_index_gt = agent_state_index.roll(shifts=-1, dims=1) # (invalid, valid, exit)
# seed agent
feat_seed = self.seed_head(feat_a[-1:]) + self.seed_feature.weight[:, None]
next_state_prob_seed = self.seed_state_predict_head(feat_seed)
next_state_idx_seed = next_state_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True) # (self.seed_size, num_step, 1)
next_type_prob_seed = self.seed_type_predict_head(feat_seed)
next_type_idx_seed = next_type_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True)
next_type_index_gt = agent_type_index[:, None].expand(-1, num_step).roll(shifts=-1, dims=1)
# polygon token for bos token
# next_bos_pl_prob = self.bos_pl_predict_head(feat_a)
# next_bos_pl_prob_softmax = torch.softmax(next_bos_pl_prob, dim=-1)
# _, next_bos_pl_idx = torch.topk(next_bos_pl_prob_softmax, k=1, dim=-1) # (num_agent, num_step, 1)
# next_bos_pl_index_gt = agent_enter_pl_token_id.roll(shifts=-1, dims=-1)
# offset token for bos token
# next_bos_offset_prob = self.bos_offset_predict_head(feat_a)
# next_bos_offset_prob_softmax = torch.softmax(next_bos_offset_prob, dim=-1)
# _, next_bos_offset_idx = torch.topk(next_bos_offset_prob_softmax, k=1, dim=-1)
# next_bos_offset_index_gt = agent_enter_offset_token_idx.roll(shifts=-1, dims=-1)
# next token prediction mask
bos_token_index = torch.nonzero(agent_state_index == self.enter_state)
eos_token_index = torch.nonzero(agent_state_index == self.exit_state)
# mask for motion tokens
next_token_eval_mask = mask.clone()
next_token_eval_mask = next_token_eval_mask * next_token_eval_mask.roll(shifts=-1, dims=1) * next_token_eval_mask.roll(shifts=1, dims=1)
for bos_token_index_ in bos_token_index:
next_token_eval_mask[bos_token_index_[0], bos_token_index_[1] : bos_token_index_[1] + 1] = 1
next_token_eval_mask[bos_token_index_[0], bos_token_index_[1] + 1 : bos_token_index_[1] + 2] = \
mask[bos_token_index_[0], bos_token_index_[1] + 2 : bos_token_index_[1] + 3]
next_token_eval_mask[eos_token_index[:, 0], eos_token_index[:, 1]] = 0
# mask for state tokens
next_state_eval_mask = mask.clone()
next_state_eval_mask = next_state_eval_mask * next_state_eval_mask.roll(shifts=-1, dims=1) * next_state_eval_mask.roll(shifts=1, dims=1)
for bos_token_index_ in bos_token_index:
next_state_eval_mask[bos_token_index_[0], :bos_token_index_[1]] = 0
next_state_eval_mask[bos_token_index_[0], bos_token_index_[1] : bos_token_index_[1] + 1] = 1
next_state_eval_mask[bos_token_index_[0], bos_token_index_[1] + 1 : bos_token_index_[1] + 2] = \
mask[bos_token_index_[0], bos_token_index_[1] + 2 : bos_token_index_[1] + 3]
for eos_token_index_ in eos_token_index:
next_state_eval_mask[eos_token_index_[0], eos_token_index_[1] + 1:] = 1
next_state_eval_mask[eos_token_index_[0], eos_token_index_[1] : eos_token_index_[1] + 1] = \
mask[eos_token_index_[0], eos_token_index_[1] - 1 : eos_token_index_[1]]
# seed agents
next_bos_token_index = torch.nonzero(next_state_index_gt == self.enter_state)
next_bos_token_index = next_bos_token_index[next_bos_token_index[:, 1] < num_step - 1]
next_state_index_gt_seed = torch.full((self.seed_size, num_step), self.seed_state_type.index('invalid'), device=next_state_index_gt.device)
next_type_index_gt_seed = torch.full((self.seed_size, num_step), self.seed_agent_type.index('seed'), device=next_state_index_gt.device)
next_eval_mask_seed = torch.ones_like(next_state_index_gt_seed)
num_seed = torch.zeros(num_step, device=next_state_index_gt.device).long()
for next_bos_token_index_ in next_bos_token_index:
if num_seed[next_bos_token_index_[1]] < self.seed_size:
next_state_index_gt_seed[num_seed[next_bos_token_index_[1]], next_bos_token_index_[1]] = self.seed_state_type.index('enter')
next_type_index_gt_seed[num_seed[next_bos_token_index_[1]], next_bos_token_index_[1]] = next_type_index_gt[next_bos_token_index_[0], next_bos_token_index_[1]]
num_seed[next_bos_token_index_[1]] += 1
        # the rolled "next" targets wrap around at the last timestep (its target would be step 0), so it is not supervised
next_token_eval_mask[:, -1] = 0
next_state_eval_mask[:, -1] = 0
next_eval_mask_seed[:, -1] = 0
# next_bos_token_eval_mask[:, -1] = False
# no invalid motion token will be supervised
if (next_token_index_gt[next_token_eval_mask] < 0).any():
raise RuntimeError()
next_state_index_gt[next_state_index_gt == self.exit_state] = self.valid_state_type.index('exit')
return {'x_a': feat_a,
# motion token
'next_token_idx': next_token_idx,
'next_token_prob': next_token_prob,
'next_token_idx_gt': next_token_index_gt,
'next_token_eval_mask': next_token_eval_mask.bool(),
# state token
'next_state_idx': next_state_idx,
'next_state_prob': next_state_prob,
'next_state_idx_gt': next_state_index_gt,
'next_state_eval_mask': next_state_eval_mask.bool(),
# seed agent
'next_state_idx_seed': next_state_idx_seed,
'next_state_prob_seed': next_state_prob_seed,
'next_state_idx_gt_seed': next_state_index_gt_seed,
'next_type_idx_seed': next_type_idx_seed,
'next_type_prob_seed': next_type_prob_seed,
'next_type_idx_gt_seed': next_type_index_gt_seed,
'next_eval_mask_seed': next_eval_mask_seed.bool(),
# pl token for bos
# 'next_bos_pl_idx': next_bos_pl_idx,
# 'next_bos_pl_prob': next_bos_pl_prob,
# 'next_bos_pl_index_gt': next_bos_pl_index_gt,
# offset token for bos
# 'next_bos_offset_idx': next_bos_offset_idx,
# 'next_bos_offset_prob': next_bos_offset_prob,
# 'next_bos_offset_index_gt': next_bos_offset_index_gt,
# 'next_bos_token_eval_mask': next_bos_token_eval_mask,
}
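    # A minimal training-loss sketch over the outputs of `forward` (hypothetical names; `F`
    # is torch.nn.functional and `decoder` a SMARTAgentDecoder instance -- not the repo's
    # actual loss module):
    #   out = decoder(data, map_enc)
    #   token_loss = F.cross_entropy(
    #       out['next_token_prob'][out['next_token_eval_mask']],
    #       out['next_token_idx_gt'][out['next_token_eval_mask']])
    #   state_loss = F.cross_entropy(
    #       out['next_state_prob'][out['next_state_eval_mask']],
    #       out['next_state_idx_gt'][out['next_state_eval_mask']])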
def inference(self,
data: HeteroData,
map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
start_state_idx = data['agent']['state_idx'][:, (self.num_historical_steps - 1) // self.shift]
filter_mask = (start_state_idx == self.valid_state) | (start_state_idx == self.exit_state)
seed_step_mask = data['agent']['state_idx'][:, (self.num_historical_steps - 1) // self.shift:] == self.enter_state
seed_agent_index_per_step = [torch.nonzero(seed_step_mask[:, t]).squeeze(dim=-1) for t in range(seed_step_mask.shape[1])]
if torch.any(seed_step_mask.sum(dim=0) > self.seed_size):
raise RuntimeError(f"Seed size {self.seed_size} is too small.")
# num_historical_steps=11
eval_mask = data['agent']['valid_mask'][filter_mask, self.num_historical_steps - 1]
if self.predict_state:
eval_mask = torch.ones_like(eval_mask).bool()
# agent attributes
pos_a = data['agent']['token_pos'][filter_mask].clone() # (num_agent, num_step, 2)
state_a = data['agent']['state_idx'][filter_mask].clone() # (num_agent, num_step)
head_a = data['agent']['token_heading'][filter_mask].clone() # (num_agent, num_step)
gt_traj = data['agent']['position'][filter_mask, self.num_historical_steps:, :self.input_dim].contiguous()
num_agent, num_step, traj_dim = pos_a.shape
av_index = int(data['agent']['av_index'])
av_index -= (~filter_mask[:av_index]).sum()
# map attributes
pos_pl = data['pt_token']['position'][:, :2].clone() # (num_pl, 2)
        # zero out the future steps
pos_a[:, (self.num_historical_steps - 1) // self.shift:] = 0
state_a[:, (self.num_historical_steps - 1) // self.shift:] = 0
head_a[:, (self.num_historical_steps - 1) // self.shift:] = 0
agent_valid_mask = data['agent']['raw_agent_valid_mask'][filter_mask].clone() # token_valid_mask
agent_valid_mask[:, (self.num_historical_steps - 1) // self.shift:] = True
agent_valid_mask[~eval_mask] = False
agent_token_index = data['agent']['token_idx'][filter_mask]
agent_state_index = data['agent']['state_idx'][filter_mask]
agent_type = data['agent']['type'][filter_mask]
agent_category = data['agent']['category'][filter_mask]
feat_a, head_vector_a, agent_token_traj_all, agent_token_emb, categorical_embs = self.agent_token_embedding(data,
agent_token_index,
agent_state_index,
pos_a,
head_a,
inference=True,
filter_mask=filter_mask,
av_index=av_index,
)
feat_seed = feat_a[-1:]
feat_a = feat_a[:-1]
agent_type = data["agent"]["type"][filter_mask]
veh_mask = agent_type == 0
cyc_mask = agent_type == 2
ped_mask = agent_type == 1
# self.num_recurrent_steps_val = 91 - 11 = 80
self.num_recurrent_steps_val = data["agent"]['position'].shape[1] - self.num_historical_steps
pred_traj = torch.zeros(pos_a.shape[0], self.num_recurrent_steps_val, 2, device=feat_a.device) # (num_agent, 80, 2)
pred_head = torch.zeros(pos_a.shape[0], self.num_recurrent_steps_val, device=feat_a.device)
pred_type = agent_type.clone()
pred_state = torch.zeros(pos_a.shape[0], self.num_recurrent_steps_val, device=feat_a.device)
pred_prob = torch.zeros(pos_a.shape[0], self.num_recurrent_steps_val // self.shift, device=feat_a.device) # (num_agent, 80 // 5 = 16)
next_token_idx_list = []
next_state_idx_list = []
next_bos_pl_idx_list = []
next_bos_offset_idx_list = []
feat_a_t_dict = {}
feat_sa_t_dict = {}
# build masks (init)
mask = agent_valid_mask.clone()
temporal_mask = mask.clone()
interact_mask = mask.clone()
if self.predict_state:
# find bos and eos index
is_bos = agent_state_index == self.enter_state
is_eos = agent_state_index == self.exit_state
bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_step - 1))
temporal_mask = torch.ones_like(mask)
motion_mask = torch.arange(mask.shape[1]).expand(mask.shape[0], mask.shape[1]).to(mask.device)
motion_mask = (motion_mask > bos_index[:, None]) & (motion_mask <= eos_index[:, None])
motion_mask[:, self.num_historical_steps // self.shift:] = False
temporal_mask[motion_mask] = mask[motion_mask]
interact_mask = torch.ones_like(mask)
non_motion_mask = ~motion_mask
non_motion_mask[:, self.num_historical_steps // self.shift:] = False
interact_mask[non_motion_mask] = False
interact_mask[agent_state_index == self.enter_state] = True
temporal_mask[:, (self.num_historical_steps - 1) // self.shift:] = True
interact_mask[:, (self.num_historical_steps - 1) // self.shift:] = True
# mapping network
# z = torch.randn(num_agent, self.hidden_dim).to(feat_a.device)
# w = self.mapping_network(z)
# we only need to predict 16 next tokens
for t in range(self.num_recurrent_steps_val // self.shift):
# feat_a = feat_a + w[:, None]
num_agent = pos_a.shape[0]
if t == 0:
inference_mask = temporal_mask.clone()
inference_mask = torch.cat([inference_mask, torch.ones_like(inference_mask[-1:])])
inference_mask[:, (self.num_historical_steps - 1) // self.shift + t:] = False
else:
inference_mask = torch.zeros_like(temporal_mask)
inference_mask = torch.cat([inference_mask, torch.zeros_like(inference_mask[-1:])])
inference_mask[:, max((self.num_historical_steps - 1) // self.shift + t - (self.num_interaction_steps // self.shift), 0) :
(self.num_historical_steps - 1) // self.shift + t] = True
interact_mask = torch.cat([interact_mask, torch.ones_like(interact_mask[:1])]).bool() # placeholder
edge_index_t, r_t = self.build_temporal_edge(pos_a, head_a, head_vector_a, state_a, temporal_mask, inference_mask,
av_index=av_index)
# +1: placeholder for seed agent
batch_s = torch.arange(num_step, device=pos_a.device).repeat_interleave(num_agent + 1)
batch_pl = torch.arange(num_step, device=pos_a.device).repeat_interleave(data['pt_token']['num_nodes'])
            # during inference, only the current window is attended to for the recurrent rollout
edge_index_a2a, r_a2a = self.build_interaction_edge(pos_a, head_a, head_vector_a, state_a, batch_s,
interact_mask, inference_mask, av_index=av_index)
edge_index_pl2a, r_pl2a = self.build_map2agent_edge(data, num_step, pos_a, head_a, head_vector_a, state_a, batch_s, batch_pl,
interact_mask, inference_mask, av_index=av_index)
interact_mask = interact_mask[:-1]
# if t > 0:
# feat_a_sum = feat_a.sum(dim=-1)
# for a in range(pos_a.shape[0]):
# t_1 = (self.num_historical_steps - 1) // self.shift + t - 1
# print(f"agent {a} t_1 {t_1}")
# print(f"token: {next_token_idx[a]}")
# print(f"state: {next_state_idx[a]}")
# print(f"feat_a_sum: {feat_a_sum[a, t_1]}")
for i in range(self.num_layers):
if (i in feat_a_t_dict) and (i in feat_sa_t_dict):
feat_a = feat_a_t_dict[i]
feat_seed = feat_sa_t_dict[i]
feat_a = torch.cat([feat_a, feat_seed], dim=0)
feat_a = feat_a.reshape(-1, self.hidden_dim)
feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t)
feat_a = feat_a.reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim)
feat_a = self.pt2a_attn_layers[i]((
map_enc['x_pt'].repeat_interleave(repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(
-1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a)
feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a)
feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1)
feat_seed = feat_a[-1:] # (1, num_step, hidden_dim)
feat_a = feat_a[:-1] # (num_agent, num_step, hidden_dim)
if t == 0:
feat_a_t_dict[i + 1] = feat_a
feat_sa_t_dict[i + 1] = feat_seed
else:
# update agent features at current step
n = feat_a_t_dict[i + 1].shape[0]
feat_a_t_dict[i + 1][:n, (self.num_historical_steps - 1) // self.shift - 1 + t] = feat_a[:n, (self.num_historical_steps - 1) // self.shift - 1 + t]
                    # append the features of newly inserted agents (only when new agents were added since the last step)
if feat_a.shape[0] > n:
m = feat_a.shape[0] - n
feat_a_t_dict[i + 1] = torch.cat([feat_a_t_dict[i + 1], feat_a[-m:]])
# update seed agent features at current step
feat_sa_t_dict[i + 1][:, (self.num_historical_steps - 1) // self.shift - 1 + t] = feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t]
# next motion token
next_token_prob = self.token_predict_head(feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)
topk_token_prob, next_token_idx = torch.topk(next_token_prob_softmax, k=self.beam_size, dim=-1) # both (num_agent, beam_size) e.g. (31, 5)
# next state token
next_state_prob = self.state_predict_head(feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
next_state_idx = next_state_prob.softmax(dim=-1).argmax(dim=-1)
next_state_idx[next_state_idx == self.valid_state_type.index('exit')] = self.exit_state
# seed agent
feat_seed = self.seed_head(feat_seed) + self.seed_feature.weight[:, None]
next_state_prob_seed = self.seed_state_predict_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
next_state_idx_seed = next_state_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True)
next_state_idx_seed[next_state_idx_seed == self.seed_state_type.index('enter')] = self.enter_state
next_type_prob_seed = self.seed_type_predict_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
next_type_idx_seed = next_type_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True)
# print(f"t: {t}")
# print(next_type_idx_seed[..., 0].tolist())
# bos pl prediction
# next_bos_pl_prob = self.bos_pl_predict_head(feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
# next_bos_pl_prob_softmax = torch.softmax(next_bos_pl_prob, dim=-1)
# next_bos_pl_idx = torch.argmax(next_bos_pl_prob_softmax, dim=-1)
# bos offset prediction
# next_bos_offset_prob = self.bos_offset_predict_head(feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
# next_bos_offset_prob_softmax = torch.softmax(next_bos_offset_prob, dim=-1)
# next_bos_offset_idx = torch.argmax(next_bos_offset_prob_softmax, dim=-1)
# convert the predicted token to a 0.5s (6 timesteps) trajectory
expanded_token_index = next_token_idx[..., None, None, None].expand(-1, -1, 6, 4, 2)
next_token_traj = torch.gather(agent_token_traj_all, 1, expanded_token_index) # (num_agent, beam_size, 6, 4, 2)
# apply rotation and translation on 'next_token_traj'
theta = head_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]
cos, sin = theta.cos(), theta.sin()
rot_mat = torch.zeros((num_agent, 2, 2), device=theta.device)
rot_mat[:, 0, 0] = cos
rot_mat[:, 0, 1] = sin
rot_mat[:, 1, 0] = -sin
rot_mat[:, 1, 1] = cos
agent_diff_rel = torch.bmm(next_token_traj.view(-1, 4, 2),
rot_mat[:, None, None, ...].repeat(1, self.beam_size, self.shift + 1, 1, 1).view(
-1, 2, 2)).view(num_agent, self.beam_size, self.shift + 1, 4, 2)
agent_pred_rel = agent_diff_rel + pos_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t, :][:, None, None, None, ...]
# sample 1 most probable index of top beam_size tokens, (num_agent, beam_size) -> (num_agent, 1)
# then sample the agent_pred_rel, (num_agent, beam_size, 6, 4, 2) -> (num_agent, 6, 4, 2)
sample_token_index = torch.multinomial(topk_token_prob, 1).to(agent_pred_rel.device)
next_token_idx = next_token_idx.gather(dim=1, index=sample_token_index).squeeze(-1)
agent_pred_rel = agent_pred_rel.gather(dim=1,
index=sample_token_index[..., None, None, None].expand(-1, -1, 6, 4,
2))[:, 0, ...]
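            # The sampling above is top-k: the `beam_size` most probable tokens were kept by
            # torch.topk, one index is drawn via torch.multinomial over their probabilities,
            # and the corresponding decoded contour trajectory is gathered.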
# get predicted position and heading of current shifted timesteps
diff_xy = agent_pred_rel[:, 1:, 0, :] - agent_pred_rel[:, 1:, 3, :]
pred_traj[:num_agent, t * 5 : (t + 1) * 5] = agent_pred_rel[:, 1:, ...].clone().mean(dim=2)
pred_head[:num_agent, t * 5 : (t + 1) * 5] = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])
pred_state[:num_agent, t * 5 : (t + 1) * 5] = next_state_idx[:, None].repeat(1, 5)
# pred_prob[:num_agent, t] = topk_token_prob.gather(dim=-1, index=sample_token_index)[:, 0] # (num_agent, beam_size) -> (num_agent,)
# update pos/head/state of current step
pos_a[:, (self.num_historical_steps - 1) // self.shift + t] = agent_pred_rel[:, -1, ...].clone().mean(dim=1)
diff_xy = agent_pred_rel[:, -1, 0, :] - agent_pred_rel[:, -1, 3, :]
theta = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])
head_a[:, (self.num_historical_steps - 1) // self.shift + t] = theta
state_a[:, (self.num_historical_steps - 1) // self.shift + t] = next_state_idx
            # handle the case where the predicted state token is invalid / exit
is_eos = next_state_idx == self.exit_state
is_invalid = next_state_idx == self.invalid_state
next_token_idx[is_invalid] = -1
pos_a[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = 0.
head_a[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = 0.
mask[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = False # to handle those newly-added agents
interact_mask[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = False
agent_token_emb[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = self.no_token_emb(torch.zeros(1, device=pos_a.device).long())
type_emb = categorical_embs[0].reshape(num_agent, num_step, -1)
shape_emb = categorical_embs[1].reshape(num_agent, num_step, -1)
type_emb[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = self.type_a_emb(torch.tensor(self.all_agent_type.index('invalid'), device=pos_a.device).long())
shape_emb[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = self.shape_emb(torch.full((1, 3), self.invalid_shape_value, device=pos_a.device))
categorical_embs = [type_emb.reshape(-1, self.hidden_dim), shape_emb.reshape(-1, self.hidden_dim)]
# FIXME: need to discuss!!!
# if is_eos.any():
# pos_a[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = 0.
# head_a[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = 0.
# mask[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = False # to handle those newly-added agents
# interact_mask[torch.cat([is_eos, torch.zeros(1, device=is_eos.device).bool()]), (self.num_historical_steps - 1) // self.shift + t + 1:] = False
# agent_token_emb[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = self.no_token_emb(torch.zeros(1, device=pos_a.device).long())
# type_emb = categorical_embs[0].reshape(num_agent, num_step, -1)
# shape_emb = categorical_embs[1].reshape(num_agent, num_step, -1)
# type_emb[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = self.type_a_emb(torch.tensor(self.all_agent_type.index('invalid'), device=pos_a.device).long())
# shape_emb[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = self.shape_emb(torch.full((1, 3), self.invalid_shape_value, device=pos_a.device))
# categorical_embs = [type_emb.reshape(-1, self.hidden_dim), shape_emb.reshape(-1, self.hidden_dim)]
# for sa in range(next_state_idx_seed.shape[0]):
# if next_state_idx_seed[sa] == self.enter_state:
# print(f"agent {sa} is entering at step {t}")
# insert new agents (from seed agent)
seed_agent_index_cur_step = seed_agent_index_per_step[t]
num_new_agent = min(len(seed_agent_index_cur_step), next_state_idx_seed.bool().sum())
new_agent_mask = next_state_idx_seed.bool()
next_state_idx_seed = next_state_idx_seed[new_agent_mask]
next_state_idx_seed = next_state_idx_seed[:num_new_agent]
next_type_idx_seed = next_type_idx_seed[new_agent_mask]
next_type_idx_seed = next_type_idx_seed[:num_new_agent]
selected_agent_index_cur_step = seed_agent_index_cur_step[:num_new_agent]
agent_token_index = torch.cat([agent_token_index, data['agent']['token_idx'][selected_agent_index_cur_step]])
agent_state_index = torch.cat([agent_state_index, data['agent']['state_idx'][selected_agent_index_cur_step]])
agent_category = torch.cat([agent_category, data['agent']['category'][selected_agent_index_cur_step]])
agent_valid_mask = torch.cat([agent_valid_mask, data['agent']['raw_agent_valid_mask'][selected_agent_index_cur_step]])
gt_traj = torch.cat([gt_traj, data['agent']['position'][selected_agent_index_cur_step, self.num_historical_steps:, :self.input_dim]])
# FIXME: under test: newly-entered (bos) agents are assigned token index -2
next_state_idx = torch.cat([next_state_idx, next_state_idx_seed], dim=0).long()
next_token_idx = torch.cat([next_token_idx, torch.zeros(num_new_agent, device=next_token_idx.device) - 2], dim=0).long()
mask = torch.cat([mask, torch.ones(num_new_agent, num_step, device=mask.device)], dim=0).bool()
temporal_mask = torch.cat([temporal_mask, torch.ones(num_new_agent, num_step, device=temporal_mask.device)], dim=0).bool()
interact_mask = torch.cat([interact_mask, torch.ones(num_new_agent, num_step, device=interact_mask.device)], dim=0).bool()
# new_pos_a = ego_pos_a[None].repeat(num_new_agent, 1, 1)
# new_head_a = ego_head_a[None].repeat(num_new_agent, 1)
new_pos_a = torch.zeros(num_new_agent, num_step, 2, device=pos_a.device)
new_head_a = torch.zeros(num_new_agent, num_step, device=pos_a.device)
new_state_a = torch.zeros(num_new_agent, num_step, device=state_a.device)
new_shape_a = torch.full((num_new_agent, num_step, 3), self.invalid_shape_value, device=pos_a.device)
new_type_a = torch.full((num_new_agent, num_step), self.all_agent_type.index('invalid'), device=pos_a.device)
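# newly-entered agents are initialized from ground truth at the entering step: the bos
# position/heading come from data['agent'], while shape and type are copied forward to all
# remaining steps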
if num_new_agent > 0:
gt_bos_pos_a = data['agent']['position'][seed_agent_index_cur_step[:num_new_agent], (self.num_historical_steps - 1) // self.shift + t]
new_pos_a[:, (self.num_historical_steps - 1) // self.shift + t] = gt_bos_pos_a[:, :2].clone()
pos_a = torch.cat([pos_a, new_pos_a], dim=0)
gt_bos_head_a = data['agent']['heading'][seed_agent_index_cur_step[:num_new_agent], (self.num_historical_steps - 1) // self.shift + t]
new_head_a[:, (self.num_historical_steps - 1) // self.shift + t] = gt_bos_head_a.clone()
head_a = torch.cat([head_a, new_head_a], dim=0)
gt_bos_shape_a = data['agent']['shape'][seed_agent_index_cur_step[:num_new_agent], self.num_historical_steps - 1]
gt_bos_type_a = data['agent']['type'][seed_agent_index_cur_step[:num_new_agent]]
new_shape_a[:, (self.num_historical_steps - 1) // self.shift + t:] = gt_bos_shape_a.clone()[:, None]
new_type_a[:, (self.num_historical_steps - 1) // self.shift + t:] = gt_bos_type_a.clone()[:, None]
# new_type_a[:, (self.num_historical_steps - 1) // self.shift + t] = next_type_idx_seed
pred_type = torch.cat([pred_type, new_type_a[:, (self.num_historical_steps - 1) // self.shift + t]])
new_state_a[:, (self.num_historical_steps - 1) // self.shift + t] = self.enter_state
state_a = torch.cat([state_a, new_state_a], dim=0)
mask[-num_new_agent:, :(self.num_historical_steps - 1) // self.shift + t + 1] = 0
interact_mask[-num_new_agent:, :(self.num_historical_steps - 1) // self.shift + t] = 0
# append rollout placeholders for the new agents; only the current chunk is filled with the bos pose and enter state
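# NOTE: the hard-coded 5 below is assumed to equal self.shift (each motion token spans
# 5 raw timesteps in the default configuration)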
new_pred_traj = torch.zeros(num_new_agent, self.num_recurrent_steps_val, 2, device=pos_a.device)
new_pred_traj[:, t * 5 : (t + 1) * 5] = new_pos_a[:, (self.num_historical_steps - 1) // self.shift + t][:, None].repeat(1, 5, 1)
pred_traj = torch.cat([pred_traj, new_pred_traj], dim=0)
new_pred_head = torch.zeros(num_new_agent, self.num_recurrent_steps_val, device=pos_a.device)
new_pred_head[:, t * 5 : (t + 1) * 5] = new_head_a[:, (self.num_historical_steps - 1) // self.shift + t][:, None].repeat(1, 5)
pred_head = torch.cat([pred_head, new_pred_head], dim=0)
new_pred_state = torch.zeros(num_new_agent, self.num_recurrent_steps_val, device=pos_a.device)
new_pred_state[:, t * 5 : (t + 1) * 5] = next_state_idx_seed[:, None].repeat(1, 5)
pred_state = torch.cat([pred_state, new_pred_state], dim=0)
# handle the position/heading of bos token
# bos_pl_pos = pos_pl[next_bos_pl_idx[is_bos].long()]
# bos_offset_pos = discretize_neighboring(neighbor_index=next_bos_offset_idx[is_bos])
# pos_a[is_bos, (self.num_historical_steps - 1) // self.shift + t] += (bos_pl_pos + bos_offset_pos)
# # headings before bos token remains 0 which align with training process
# head_a[is_bos, (self.num_historical_steps - 1) // self.shift + t] += 0.
# add token embeddings for the new agents (initialized with the no-token embedding)
agent_token_emb = torch.cat([agent_token_emb, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())[None, :].repeat(num_new_agent, num_step, 1)])
veh_mask = torch.cat([veh_mask, next_type_idx_seed == self.seed_agent_type.index('veh')])
ped_mask = torch.cat([ped_mask, next_type_idx_seed == self.seed_agent_type.index('ped')])
cyc_mask = torch.cat([cyc_mask, next_type_idx_seed == self.seed_agent_type.index('cyc')])
# add trajectory token templates for the new agents
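# each motion token stores a (4, 2) box contour per sub-step; the per-type templates are
# concatenated into (token_size, self.shift + 1, 4, 2) and selected by the new agents' type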
trajectory_token_veh = torch.from_numpy(self.trajectory_token['veh']).clone().to(pos_a.device).to(torch.float)
trajectory_token_ped = torch.from_numpy(self.trajectory_token['ped']).clone().to(pos_a.device).to(torch.float)
trajectory_token_cyc = torch.from_numpy(self.trajectory_token['cyc']).clone().to(pos_a.device).to(torch.float)
new_agent_token_traj_all = torch.zeros((num_new_agent, self.token_size, self.shift + 1, 4, 2), device=pos_a.device)
trajectory_token_all_veh = torch.from_numpy(self.trajectory_token_all['veh']).clone().to(pos_a.device).to(torch.float)
trajectory_token_all_ped = torch.from_numpy(self.trajectory_token_all['ped']).clone().to(pos_a.device).to(torch.float)
trajectory_token_all_cyc = torch.from_numpy(self.trajectory_token_all['cyc']).clone().to(pos_a.device).to(torch.float)
new_agent_token_traj_all[next_type_idx_seed == 0] = torch.cat(
[trajectory_token_all_veh[:, :self.shift], trajectory_token_veh[:, None, ...]], dim=1)
new_agent_token_traj_all[next_type_idx_seed == 1] = torch.cat(
[trajectory_token_all_ped[:, :self.shift], trajectory_token_ped[:, None, ...]], dim=1)
new_agent_token_traj_all[next_type_idx_seed == 2] = torch.cat(
[trajectory_token_all_cyc[:, :self.shift], trajectory_token_cyc[:, None, ...]], dim=1)
agent_token_traj_all = torch.cat([agent_token_traj_all, new_agent_token_traj_all], dim=0)
# add categorical (type/shape) embeddings for the new agents
new_categorical_embs = [self.type_a_emb(new_type_a.reshape(-1).long()), self.shape_emb(new_shape_a.reshape(-1, 3))]
categorical_embs = [torch.cat([categorical_embs[0], new_categorical_embs[0]], dim=0),
torch.cat([categorical_embs[1], new_categorical_embs[1]], dim=0)]
# update token embeddings of current step
agent_token_emb[veh_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_veh[
next_token_idx[veh_mask]]
agent_token_emb[ped_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_ped[
next_token_idx[ped_mask]]
agent_token_emb[cyc_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_cyc[
next_token_idx[cyc_mask]]
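# rebuild motion/heading vectors with the newly appended agents and zero out the steps
# after the current one, which have not been predicted yet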
motion_vector_a, head_vector_a = self.build_vector_a(pos_a, head_a, state_a)
motion_vector_a[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0.
head_vector_a[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0.
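# per-step agent features: the motion-vector norm (a speed proxy) and the angle of the
# motion vector relative to the heading vector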
x_a = torch.stack(
[torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1),
angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2])], dim=-1)
x_b = x_a.clone()
x_a = self.x_a_emb(continuous_inputs=x_a.view(-1, x_a.size(-1)),
categorical_embs=categorical_embs)
x_a = x_a.view(-1, num_step, self.hidden_dim)
s_a = self.state_a_emb(state_a.reshape(-1).long()).reshape(num_agent + num_new_agent, num_step, self.hidden_dim)
feat_a = torch.cat((agent_token_emb, x_a, s_a), dim=-1)
feat_a = self.fusion_emb(feat_a)
# if t >= 15:
# print(f"inference {t}")
# is_invalid = state_a == self.invalid_state
# is_bos = state_a == self.enter_state
# is_eos = state_a == self.exit_state
# bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
# eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_step - 1))
# mask = torch.arange(num_step).expand(num_agent + num_new_agent, -1).to(state_a.device)
# mask = (mask >= bos_index[:, None]) & (mask <= eos_index[:, None] + 1)
# is_invalid[mask] = False
# is_invalid[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = False
# print(pos_a[:, :((self.num_historical_steps - 1) // self.shift + 1 + t)])
# print(state_a[:, :((self.num_historical_steps - 1) // self.shift + 1 + t)])
# print(pos_a[is_invalid][:, 0])
# print(head_a[is_invalid])
# print(categorical_embs[0].sum(dim=-1)[is_invalid.reshape(-1)])
# print(categorical_embs[1].sum(dim=-1)[is_invalid.reshape(-1)])
# print(motion_vector_a[is_invalid][:, 0])
# print(head_vector_a[is_invalid][:, 0])
# print(x_b.sum(dim=-1)[is_invalid])
# print(x_a.sum(dim=-1)[is_invalid])
# for a in range(state_a.shape[0]):
# print(f"agent: {a}")
# print(state_a[a])
# print(is_invalid[a].long())
# print(pos_a[a, :, 0])
# print(motion_vector_a[a, :, 0])
# print(s_a.sum(dim=-1)[is_invalid])
# print(feat_a.sum(dim=-1)[is_invalid])
# (disabled) replace the features of steps before bos of valid agents with the corresponding seed-agent features
# is_bos = state_a == self.enter_state
# bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(num_step))
# before_bos_mask = torch.arange(num_step).expand(num_agent + num_new_agent, -1).to(state_a.device) < bos_index[:, None]
# feat_a[before_bos_mask] = feat_seed.repeat(num_agent + num_new_agent, 1, 1)[before_bos_mask]
# build seed agent features
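# the seed agent is anchored to the ego (av_index) motion/heading vectors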
motion_vector_seed = motion_vector_a[av_index : av_index + 1]
head_vector_seed = head_vector_a[av_index : av_index + 1]
feat_seed = self.build_invalid_agent_feature(num_step, pos_a.device, type_index=self.all_agent_type.index('seed'),
motion_vector=motion_vector_seed, head_vector=head_vector_seed)
# print(f"inference {t}")
# print(feat_seed.sum(dim=-1))
next_token_idx_list.append(next_token_idx[:, None])
next_state_idx_list.append(next_state_idx[:, None])
# next_bos_pl_idx_list.append(next_bos_pl_idx[:, None])
# next_bos_offset_idx_list.append(next_bos_offset_idx[:, None])
# TODO: check whether agents with category != 3 should be masked out (see disabled line below)
# agent_valid_mask[agent_category != 3] = False
# print("inference")
# is_invalid = state_a == self.invalid_state
# is_bos = state_a == self.enter_state
# is_eos = state_a == self.exit_state
# bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
# eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_step - 1))
# mask = torch.arange(num_step).expand(num_agent, -1).to(state_a.device)
# mask = (mask >= bos_index[:, None]) & (mask <= eos_index[:, None] + 1)
# is_invalid[mask] = False
# print(feat_a.sum(dim=-1)[is_invalid])
# print(pos_a[is_invalid][: 0])
# print(head_a[is_invalid])
# exit(1)
num_agent = pos_a.shape[0]
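# right-pad the earlier per-step predictions so every entry matches the final agent count:
# agents inserted later get token index -1 and state index 0 for steps before they existed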
for i in range(len(next_token_idx_list)):
next_token_idx_list[i] = torch.cat([next_token_idx_list[i], torch.zeros(num_agent - next_token_idx_list[i].shape[0], 1, device=next_token_idx_list[i].device) - 1], dim=0).long()
next_state_idx_list[i] = torch.cat([next_state_idx_list[i], torch.zeros(num_agent - next_state_idx_list[i].shape[0], 1, device=next_state_idx_list[i].device)], dim=0).long()
# evaluation masks: token predictions are also evaluated at enter (bos) steps;
# state predictions are evaluated up to one step past each bos and from each eos onward
next_token_eval_mask = agent_valid_mask.clone()
next_state_eval_mask = agent_valid_mask.clone()
bos_token_index = torch.nonzero(agent_state_index == self.enter_state)
eos_token_index = torch.nonzero(agent_state_index == self.exit_state)
next_token_eval_mask[bos_token_index[:, 0], bos_token_index[:, 1]] = 1
for bos_token_index_i in bos_token_index:
next_state_eval_mask[bos_token_index_i[0], :bos_token_index_i[1] + 2] = 1
for eos_token_index_i in eos_token_index:
next_state_eval_mask[eos_token_index_i[0], eos_token_index_i[1]:] = 1
# prepend reconstructed history (decoded from the ground-truth history tokens) to the predicted trajectory/heading/state
num_agent = pred_traj.shape[0]
num_init_agent = filter_mask.sum()
pred_traj = torch.cat([pred_traj, torch.zeros(num_agent, self.num_historical_steps - 1, *(pred_traj.shape[2:]), device=pred_traj.device)], dim=1)
pred_head = torch.cat([pred_head, torch.zeros(num_agent, self.num_historical_steps - 1, *(pred_head.shape[2:]), device=pred_head.device)], dim=1)
pred_state = torch.cat([pred_state, torch.zeros(num_agent, self.num_historical_steps - 1, *(pred_state.shape[2:]), device=pred_state.device)], dim=1)
pred_state[:num_init_agent, :self.num_historical_steps - 1] = data['agent']['state_idx'][filter_mask, :(self.num_historical_steps - 1) // self.shift].repeat_interleave(repeats=self.shift, dim=1)
historical_token_idx = data['agent']['token_idx'][filter_mask, :(self.num_historical_steps - 1) // self.shift]
historical_token_idx[historical_token_idx < 0] = 0
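# decode the historical motion tokens back into trajectories: gather each agent's token
# contours, rotate them by the agent's initial heading, translate by its initial position,
# then take contour means as centers and front-to-back differences as headings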
historical_token_traj_all = torch.gather(agent_token_traj_all, 1, historical_token_idx[..., None, None, None].expand(-1, -1, self.shift + 1, 4, 2))  # self.shift + 1 contour sub-steps per token
init_theta = head_a[:num_init_agent, 0]
cos, sin = init_theta.cos(), init_theta.sin()
rot_mat = torch.zeros((num_init_agent, 2, 2), device=init_theta.device)
rot_mat[:, 0, 0] = cos
rot_mat[:, 0, 1] = sin
rot_mat[:, 1, 0] = -sin
rot_mat[:, 1, 1] = cos
historical_token_traj_all = torch.bmm(historical_token_traj_all.view(-1, 4, 2),
rot_mat[:, None, None, ...].repeat(1, (self.num_historical_steps - 1) // self.shift, self.shift + 1, 1, 1).view(
-1, 2, 2)).view(num_init_agent, (self.num_historical_steps - 1) // self.shift, self.shift + 1, 4, 2)
historical_token_traj_all = historical_token_traj_all + pos_a[:num_init_agent, 0, :][:, None, None, None, ...]
pred_traj[:num_init_agent, :self.num_historical_steps - 1] = historical_token_traj_all[:, :, 1:, ...].clone().mean(dim=3).reshape(num_init_agent, -1, 2)
diff_xy = historical_token_traj_all[..., 1:, 0, :] - historical_token_traj_all[..., 1:, 3, :]
pred_head[:num_init_agent, :self.num_historical_steps - 1] = torch.arctan2(diff_xy[..., 1], diff_xy[..., 0]).reshape(num_init_agent, -1)
return {
'av_index': av_index,
'valid_mask': agent_valid_mask[:, self.num_historical_steps:],
'pos_a': pos_a[:, (self.num_historical_steps - 1) // self.shift:],
'head_a': head_a[:, (self.num_historical_steps - 1) // self.shift:],
'gt_traj': gt_traj,
'pred_traj': pred_traj,
'pred_head': pred_head,
'pred_type': list(map(lambda i: self.seed_agent_type[i], pred_type.tolist())),
'pred_state': pred_state,
'next_token_idx': torch.cat(next_token_idx_list, dim=-1), # (num_agent, num_step)
'next_token_idx_gt': agent_token_index,
'next_state_idx': torch.cat(next_state_idx_list, dim=-1) if len(next_state_idx_list) > 0 else None,
'next_state_idx_gt': agent_state_index,
'next_token_eval_mask': next_token_eval_mask,
'next_state_eval_mask': next_state_eval_mask,
# 'next_bos_pl_idx': torch.cat(next_bos_pl_idx_list, dim=-1),
# 'next_bos_offset_idx': torch.cat(next_bos_offset_idx_list, dim=-1),
}