File size: 6,077 Bytes
c1a7f73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from typing import Dict
import torch
import torch.nn as nn
from torch_cluster import radius_graph
from torch_geometric.data import Batch
from torch_geometric.data import HeteroData
from torch_geometric.utils import subgraph

from dev.modules.layers import MLPLayer, AttentionLayer, FourierEmbedding, MLPEmbedding
from dev.utils.func import weight_init, wrap_angle, angle_between_2d_vectors


class SMARTMapDecoder(nn.Module):

    def __init__(self,
                 dataset: str,
                 input_dim: int,
                 hidden_dim: int,
                 num_historical_steps: int,
                 pl2pl_radius: float,
                 num_freq_bands: int,
                 num_layers: int,
                 num_heads: int,
                 head_dim: int,
                 dropout: float,
                 map_token) -> None:

        super(SMARTMapDecoder, self).__init__()
        self.dataset = dataset
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_historical_steps = num_historical_steps
        self.pl2pl_radius = pl2pl_radius
        self.num_freq_bands = num_freq_bands
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dropout = dropout

        if input_dim == 2:
            input_dim_r_pt2pt = 3
        elif input_dim == 3:
            input_dim_r_pt2pt = 4
        else:
            raise ValueError('{} is not a valid dimension'.format(input_dim))

        self.type_pt_emb = nn.Embedding(17, hidden_dim)
        self.side_pt_emb = nn.Embedding(4, hidden_dim)
        self.polygon_type_emb = nn.Embedding(4, hidden_dim)
        self.light_pl_emb = nn.Embedding(4, hidden_dim)

        self.r_pt2pt_emb = FourierEmbedding(input_dim=input_dim_r_pt2pt, hidden_dim=hidden_dim,
                                            num_freq_bands=num_freq_bands)
        self.pt2pt_layers = nn.ModuleList(
            [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
                            bipartite=False, has_pos_emb=True) for _ in range(num_layers)]
        )
        self.token_size = 1024
        self.token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
                                           output_dim=self.token_size)
        input_dim_token = 22
        self.token_emb = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
        self.map_token = map_token
        self.apply(weight_init)
        self.mask_pt = False

    def maybe_autocast(self, dtype=torch.float32):
        return torch.cuda.amp.autocast(dtype=dtype)

    def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]:
        pt_valid_mask = data['pt_token']['pt_valid_mask']
        pt_pred_mask = data['pt_token']['pt_pred_mask']
        pt_target_mask = data['pt_token']['pt_target_mask']
        mask_s = pt_valid_mask

        pos_pt = data['pt_token']['position'][:, :self.input_dim].contiguous()
        orient_pt = data['pt_token']['orientation'].contiguous()
        orient_vector_pt = torch.stack([orient_pt.cos(), orient_pt.sin()], dim=-1)
        token_sample_pt = self.map_token['traj_src'].to(pos_pt.device).to(torch.float)
        pt_token_emb_src = self.token_emb(token_sample_pt.view(token_sample_pt.shape[0], -1))
        pt_token_emb = pt_token_emb_src[data['pt_token']['token_idx']]

        x_pt = pt_token_emb

        token2pl = data[('pt_token', 'to', 'map_polygon')]['edge_index']
        token_light_type = data['map_polygon']['light_type'][token2pl[1]]
        x_pt_categorical_embs = [self.type_pt_emb(data['pt_token']['type'].long()),
                                 self.polygon_type_emb(data['pt_token']['pl_type'].long()),
                                 self.light_pl_emb(token_light_type.long()),]
        x_pt = x_pt + torch.stack(x_pt_categorical_embs).sum(dim=0)
        edge_index_pt2pt = radius_graph(x=pos_pt[:, :2], r=self.pl2pl_radius,
                                        batch=data['pt_token']['batch'] if isinstance(data, Batch) else None,
                                        loop=False, max_num_neighbors=100)
        if self.mask_pt:
            edge_index_pt2pt = subgraph(subset=mask_s, edge_index=edge_index_pt2pt)[0]
        rel_pos_pt2pt = pos_pt[edge_index_pt2pt[0]] - pos_pt[edge_index_pt2pt[1]]
        rel_orient_pt2pt = wrap_angle(orient_pt[edge_index_pt2pt[0]] - orient_pt[edge_index_pt2pt[1]])
        if self.input_dim == 2:
            r_pt2pt = torch.stack(
                [torch.norm(rel_pos_pt2pt[:, :2], p=2, dim=-1),
                 angle_between_2d_vectors(ctr_vector=orient_vector_pt[edge_index_pt2pt[1]],
                                          nbr_vector=rel_pos_pt2pt[:, :2]),
                 rel_orient_pt2pt], dim=-1)
        elif self.input_dim == 3:
            r_pt2pt = torch.stack(
                [torch.norm(rel_pos_pt2pt[:, :2], p=2, dim=-1),
                 angle_between_2d_vectors(ctr_vector=orient_vector_pt[edge_index_pt2pt[1]],
                                          nbr_vector=rel_pos_pt2pt[:, :2]),
                 rel_pos_pt2pt[:, -1],
                 rel_orient_pt2pt], dim=-1)
        else:
            raise ValueError('{} is not a valid dimension'.format(self.input_dim))

        # layers
        r_pt2pt = self.r_pt2pt_emb(continuous_inputs=r_pt2pt, categorical_embs=None)
        for i in range(self.num_layers):
            x_pt = self.pt2pt_layers[i](x_pt, r_pt2pt, edge_index_pt2pt)

        next_token_prob = self.token_predict_head(x_pt[pt_pred_mask])
        next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)
        _, next_token_idx = torch.topk(next_token_prob_softmax, k=10, dim=-1)
        next_token_index_gt = data['pt_token']['token_idx'][pt_target_mask]

        return {
            'x_pt': x_pt,
            'map_next_token_idx': next_token_idx,
            'map_next_token_prob': next_token_prob,
            'map_next_token_idx_gt': next_token_index_gt,
            'map_next_token_eval_mask': pt_pred_mask[pt_pred_mask]
        }