gzzyyxy's picture
Upload folder using huggingface_hub
c1a7f73 verified
from typing import Dict
import torch
import torch.nn as nn
from torch_cluster import radius_graph
from torch_geometric.data import Batch
from torch_geometric.data import HeteroData
from torch_geometric.utils import subgraph
from dev.modules.layers import MLPLayer, AttentionLayer, FourierEmbedding, MLPEmbedding
from dev.utils.func import weight_init, wrap_angle, angle_between_2d_vectors
class SMARTMapDecoder(nn.Module):
    """Attention-based decoder over map point tokens.

    Each map point token is embedded from a shared token vocabulary
    (``map_token['traj_src']``) plus categorical embeddings, connected to
    nearby tokens through a radius graph, refined by a stack of
    ``AttentionLayer`` modules, and finally classified over ``token_size``
    (1024) token classes.

    Args:
        dataset: Dataset identifier; stored on the module and not otherwise
            read in this class.
        input_dim: Number of position coordinates used per token; must be
            2 or 3.
        hidden_dim: Width of every embedding and attention layer.
        num_historical_steps: Stored on the module and not otherwise read in
            this class.
        pl2pl_radius: Radius used to connect map point tokens to each other.
        num_freq_bands: Number of frequency bands for the Fourier embedding
            of relative edge features.
        num_layers: Number of token-to-token attention layers.
        num_heads: Attention heads per layer.
        head_dim: Width of each attention head.
        dropout: Dropout rate handed to the attention layers.
        map_token: Mapping with a ``'traj_src'`` entry holding the token
            vocabulary. Each entry, once flattened, has to match the
            22-wide input of ``token_emb``.

    Raises:
        ValueError: If ``input_dim`` is neither 2 nor 3.
    """

    def __init__(self,
                 dataset: str,
                 input_dim: int,
                 hidden_dim: int,
                 num_historical_steps: int,
                 pl2pl_radius: float,
                 num_freq_bands: int,
                 num_layers: int,
                 num_heads: int,
                 head_dim: int,
                 dropout: float,
                 map_token) -> None:
        super().__init__()
        self.dataset = dataset
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_historical_steps = num_historical_steps
        self.pl2pl_radius = pl2pl_radius
        self.num_freq_bands = num_freq_bands
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dropout = dropout

        # Width of the relative edge feature vector built in forward():
        # three values in 2D, one more (the last-coordinate offset) in 3D.
        if input_dim == 2:
            rel_feature_dim = 3
        elif input_dim == 3:
            rel_feature_dim = 4
        else:
            raise ValueError('{} is not a valid dimension'.format(input_dim))

        # Categorical embedding tables. The table sizes are fixed here; what
        # each category index means is not defined in this file.
        # NOTE: side_pt_emb is registered but not used by forward() below.
        self.type_pt_emb = nn.Embedding(17, hidden_dim)
        self.side_pt_emb = nn.Embedding(4, hidden_dim)
        self.polygon_type_emb = nn.Embedding(4, hidden_dim)
        self.light_pl_emb = nn.Embedding(4, hidden_dim)

        self.r_pt2pt_emb = FourierEmbedding(
            input_dim=rel_feature_dim,
            hidden_dim=hidden_dim,
            num_freq_bands=num_freq_bands,
        )

        attention_stack = [
            AttentionLayer(
                hidden_dim=hidden_dim,
                num_heads=num_heads,
                head_dim=head_dim,
                dropout=dropout,
                bipartite=False,
                has_pos_emb=True,
            )
            for _ in range(num_layers)
        ]
        self.pt2pt_layers = nn.ModuleList(attention_stack)

        # Classification head over the token vocabulary.
        self.token_size = 1024
        self.token_predict_head = MLPLayer(
            input_dim=hidden_dim,
            hidden_dim=hidden_dim,
            output_dim=self.token_size,
        )

        # Embeds each flattened vocabulary entry (22 values per token).
        input_dim_token = 22
        self.token_emb = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
        self.map_token = map_token

        self.apply(weight_init)
        # When set to True, forward() restricts the radius graph to the
        # tokens selected by 'pt_valid_mask'.
        self.mask_pt = False

    def maybe_autocast(self, dtype=torch.float32):
        """Return a ``torch.cuda.amp.autocast`` context manager for ``dtype``."""
        return torch.cuda.amp.autocast(dtype=dtype)

    def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]:
        """Encode map point tokens and predict token classes.

        Args:
            data: Graph data providing the ``'pt_token'`` store (masks,
                ``position``, ``orientation``, ``token_idx``, ``type``,
                ``pl_type`` and, for batched input, ``batch``), the
                ``'map_polygon'`` store (``light_type``) and the
                ``('pt_token', 'to', 'map_polygon')`` edge index.

        Returns:
            A dict with:
              * ``'x_pt'``: token features after the attention stack.
              * ``'map_next_token_idx'``: indices of the 10 highest-scoring
                classes for each token selected by ``pt_pred_mask``.
              * ``'map_next_token_prob'``: raw outputs of
                ``token_predict_head`` for those tokens (softmax is only
                applied internally for the top-k selection).
              * ``'map_next_token_idx_gt'``: ``token_idx`` values selected
                by ``pt_target_mask``.
              * ``'map_next_token_eval_mask'``: ``pt_pred_mask`` indexed by
                itself (assuming a boolean mask, an all-True vector with
                one entry per predicted token).

        Raises:
            ValueError: If ``self.input_dim`` is neither 2 nor 3.
        """
        tokens = data['pt_token']
        valid_mask = tokens['pt_valid_mask']
        pred_mask = tokens['pt_pred_mask']
        target_mask = tokens['pt_target_mask']

        positions = tokens['position'][:, :self.input_dim].contiguous()
        headings = tokens['orientation'].contiguous()
        heading_vectors = torch.stack([headings.cos(), headings.sin()], dim=-1)

        # Embed the whole vocabulary once, then gather one row per token.
        vocab = self.map_token['traj_src'].to(positions.device).to(torch.float)
        vocab_emb = self.token_emb(vocab.view(vocab.shape[0], -1))
        x_pt = vocab_emb[tokens['token_idx']]

        # Each token takes the light type of the polygon it is linked to.
        token2polygon = data[('pt_token', 'to', 'map_polygon')]['edge_index']
        light_type = data['map_polygon']['light_type'][token2polygon[1]]
        categorical_embs = [
            self.type_pt_emb(tokens['type'].long()),
            self.polygon_type_emb(tokens['pl_type'].long()),
            self.light_pl_emb(light_type.long()),
        ]
        x_pt = x_pt + torch.stack(categorical_embs).sum(dim=0)

        # Connect tokens by planar (first two coordinates) proximity.
        batch_vector = tokens['batch'] if isinstance(data, Batch) else None
        edge_index = radius_graph(
            x=positions[:, :2],
            r=self.pl2pl_radius,
            batch=batch_vector,
            loop=False,
            max_num_neighbors=100,
        )
        if self.mask_pt:
            edge_index = subgraph(subset=valid_mask, edge_index=edge_index)[0]

        first_row = edge_index[0]
        second_row = edge_index[1]
        rel_pos = positions[first_row] - positions[second_row]
        rel_heading = wrap_angle(headings[first_row] - headings[second_row])

        # The 3D variant inserts the last-coordinate offset between the
        # bearing and the heading difference.
        if self.input_dim == 2:
            extra_features = []
        elif self.input_dim == 3:
            extra_features = [rel_pos[:, -1]]
        else:
            raise ValueError('{} is not a valid dimension'.format(self.input_dim))

        planar_offset = rel_pos[:, :2]
        edge_features = [
            torch.norm(planar_offset, p=2, dim=-1),
            angle_between_2d_vectors(ctr_vector=heading_vectors[second_row],
                                     nbr_vector=planar_offset),
            *extra_features,
            rel_heading,
        ]
        r_pt2pt = torch.stack(edge_features, dim=-1)

        # layers
        r_pt2pt = self.r_pt2pt_emb(continuous_inputs=r_pt2pt, categorical_embs=None)
        for depth in range(self.num_layers):
            x_pt = self.pt2pt_layers[depth](x_pt, r_pt2pt, edge_index)

        token_scores = self.token_predict_head(x_pt[pred_mask])
        token_probs = torch.softmax(token_scores, dim=-1)
        top_token_idx = torch.topk(token_probs, k=10, dim=-1)[1]
        token_idx_gt = tokens['token_idx'][target_mask]

        return {
            'x_pt': x_pt,
            'map_next_token_idx': top_token_idx,
            'map_next_token_prob': token_scores,
            'map_next_token_idx_gt': token_idx_gt,
            'map_next_token_eval_mask': pred_mask[pred_mask],
        }