from typing import Dict import torch import torch.nn as nn from torch_cluster import radius_graph from torch_geometric.data import Batch from torch_geometric.data import HeteroData from torch_geometric.utils import subgraph from dev.modules.layers import MLPLayer, AttentionLayer, FourierEmbedding, MLPEmbedding from dev.utils.func import weight_init, wrap_angle, angle_between_2d_vectors class SMARTMapDecoder(nn.Module): def __init__(self, dataset: str, input_dim: int, hidden_dim: int, num_historical_steps: int, pl2pl_radius: float, num_freq_bands: int, num_layers: int, num_heads: int, head_dim: int, dropout: float, map_token) -> None: super(SMARTMapDecoder, self).__init__() self.dataset = dataset self.input_dim = input_dim self.hidden_dim = hidden_dim self.num_historical_steps = num_historical_steps self.pl2pl_radius = pl2pl_radius self.num_freq_bands = num_freq_bands self.num_layers = num_layers self.num_heads = num_heads self.head_dim = head_dim self.dropout = dropout if input_dim == 2: input_dim_r_pt2pt = 3 elif input_dim == 3: input_dim_r_pt2pt = 4 else: raise ValueError('{} is not a valid dimension'.format(input_dim)) self.type_pt_emb = nn.Embedding(17, hidden_dim) self.side_pt_emb = nn.Embedding(4, hidden_dim) self.polygon_type_emb = nn.Embedding(4, hidden_dim) self.light_pl_emb = nn.Embedding(4, hidden_dim) self.r_pt2pt_emb = FourierEmbedding(input_dim=input_dim_r_pt2pt, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands) self.pt2pt_layers = nn.ModuleList( [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout, bipartite=False, has_pos_emb=True) for _ in range(num_layers)] ) self.token_size = 1024 self.token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=self.token_size) input_dim_token = 22 self.token_emb = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim) self.map_token = map_token self.apply(weight_init) self.mask_pt = False def maybe_autocast(self, dtype=torch.float32): return torch.cuda.amp.autocast(dtype=dtype) def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]: pt_valid_mask = data['pt_token']['pt_valid_mask'] pt_pred_mask = data['pt_token']['pt_pred_mask'] pt_target_mask = data['pt_token']['pt_target_mask'] mask_s = pt_valid_mask pos_pt = data['pt_token']['position'][:, :self.input_dim].contiguous() orient_pt = data['pt_token']['orientation'].contiguous() orient_vector_pt = torch.stack([orient_pt.cos(), orient_pt.sin()], dim=-1) token_sample_pt = self.map_token['traj_src'].to(pos_pt.device).to(torch.float) pt_token_emb_src = self.token_emb(token_sample_pt.view(token_sample_pt.shape[0], -1)) pt_token_emb = pt_token_emb_src[data['pt_token']['token_idx']] x_pt = pt_token_emb token2pl = data[('pt_token', 'to', 'map_polygon')]['edge_index'] token_light_type = data['map_polygon']['light_type'][token2pl[1]] x_pt_categorical_embs = [self.type_pt_emb(data['pt_token']['type'].long()), self.polygon_type_emb(data['pt_token']['pl_type'].long()), self.light_pl_emb(token_light_type.long()),] x_pt = x_pt + torch.stack(x_pt_categorical_embs).sum(dim=0) edge_index_pt2pt = radius_graph(x=pos_pt[:, :2], r=self.pl2pl_radius, batch=data['pt_token']['batch'] if isinstance(data, Batch) else None, loop=False, max_num_neighbors=100) if self.mask_pt: edge_index_pt2pt = subgraph(subset=mask_s, edge_index=edge_index_pt2pt)[0] rel_pos_pt2pt = pos_pt[edge_index_pt2pt[0]] - pos_pt[edge_index_pt2pt[1]] rel_orient_pt2pt = wrap_angle(orient_pt[edge_index_pt2pt[0]] - orient_pt[edge_index_pt2pt[1]]) if self.input_dim == 2: r_pt2pt = torch.stack( [torch.norm(rel_pos_pt2pt[:, :2], p=2, dim=-1), angle_between_2d_vectors(ctr_vector=orient_vector_pt[edge_index_pt2pt[1]], nbr_vector=rel_pos_pt2pt[:, :2]), rel_orient_pt2pt], dim=-1) elif self.input_dim == 3: r_pt2pt = torch.stack( [torch.norm(rel_pos_pt2pt[:, :2], p=2, dim=-1), angle_between_2d_vectors(ctr_vector=orient_vector_pt[edge_index_pt2pt[1]], nbr_vector=rel_pos_pt2pt[:, :2]), rel_pos_pt2pt[:, -1], rel_orient_pt2pt], dim=-1) else: raise ValueError('{} is not a valid dimension'.format(self.input_dim)) # layers r_pt2pt = self.r_pt2pt_emb(continuous_inputs=r_pt2pt, categorical_embs=None) for i in range(self.num_layers): x_pt = self.pt2pt_layers[i](x_pt, r_pt2pt, edge_index_pt2pt) next_token_prob = self.token_predict_head(x_pt[pt_pred_mask]) next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1) _, next_token_idx = torch.topk(next_token_prob_softmax, k=10, dim=-1) next_token_index_gt = data['pt_token']['token_idx'][pt_target_mask] return { 'x_pt': x_pt, 'map_next_token_idx': next_token_idx, 'map_next_token_prob': next_token_prob, 'map_next_token_idx_gt': next_token_index_gt, 'map_next_token_eval_mask': pt_pred_mask[pt_pred_mask] }