import torch
from torch import nn

from models.bin import BiN

class MLPLOB(nn.Module):
    def __init__(self,
                 hidden_dim: int,
                 num_layers: int,
                 seq_size: int,
                 num_features: int,
                 dataset_type: str) -> None:
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dataset_type = dataset_type
        self.layers = nn.ModuleList()
        # Embeds the categorical LOBSTER order-type feature (3 types) into a single scalar.
        self.order_type_embedder = nn.Embedding(3, 1)
        self.first_layer = nn.Linear(num_features, hidden_dim)
        # BiN normalizes the input jointly over the feature and temporal dimensions.
        self.norm_layer = BiN(num_features, seq_size)
        self.layers.append(self.first_layer)
        self.layers.append(nn.GELU())
        # Each block is a pair of MLPs: one mixing the feature dimension, one mixing
        # the temporal dimension. The last pair shrinks both dimensions by a factor of 4.
        for i in range(num_layers):
            if i != num_layers - 1:
                self.layers.append(MLP(hidden_dim, hidden_dim * 4, hidden_dim))
                self.layers.append(MLP(seq_size, seq_size * 4, seq_size))
            else:
                self.layers.append(MLP(hidden_dim, hidden_dim * 2, hidden_dim // 4))
                self.layers.append(MLP(seq_size, seq_size * 2, seq_size // 4))

        # Classification head: flatten the reduced representation and progressively
        # project it down to the 3 output classes.
        total_dim = (hidden_dim // 4) * (seq_size // 4)
        self.final_layers = nn.ModuleList()
        while total_dim > 128:
            self.final_layers.append(nn.Linear(total_dim, total_dim // 4))
            self.final_layers.append(nn.GELU())
            total_dim = total_dim // 4
        self.final_layers.append(nn.Linear(total_dim, 3))
    
    def forward(self, input):
        if self.dataset_type == "LOBSTER":
            # Column 41 holds the categorical order type: replace it with its scalar
            # embedding. detach() keeps the embedding table out of the gradient computation.
            continuous_features = torch.cat([input[:, :, :41], input[:, :, 42:]], dim=2)
            order_type = input[:, :, 41].long()
            order_type_emb = self.order_type_embedder(order_type).detach()
            x = torch.cat([continuous_features, order_type_emb], dim=2)
        else:
            x = input
        # BiN expects (batch, num_features, seq_size); permute back afterwards.
        x = x.permute(0, 2, 1)
        x = self.norm_layer(x)
        x = x.permute(0, 2, 1)
        # Transposing the last two dimensions after every layer alternates between
        # feature-mixing and temporal-mixing MLPs.
        for layer in self.layers:
            x = layer(x)
            x = x.permute(0, 2, 1)
        # Flatten to (batch, (seq_size//4) * (hidden_dim//4)) for the classification head.
        x = x.reshape(x.shape[0], -1)
        for layer in self.final_layers:
            x = layer(x)
        return x
        
        
class MLP(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Linear, with a residual
    connection when input and output widths match, followed by LayerNorm and GELU."""

    def __init__(self,
                 start_dim: int,
                 hidden_dim: int,
                 final_dim: int) -> None:
        super().__init__()

        self.layer_norm = nn.LayerNorm(final_dim)
        self.fc = nn.Linear(start_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, final_dim)
        self.gelu = nn.GELU()
        
    def forward(self, x):
        residual = x
        x = self.fc(x)
        x = self.gelu(x)
        x = self.fc2(x)
        # Residual connection only when the block preserves the last dimension.
        if x.shape[-1] == residual.shape[-1]:
            x = x + residual
        x = self.layer_norm(x)
        x = self.gelu(x)
        return x
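

# --- Minimal usage sketch (not part of the original file; an assumption for illustration). ---
# It assumes models.bin.BiN is importable (run from the repository root) and picks
# arbitrary sizes; any dataset_type other than "LOBSTER" skips the order-type embedding.
# Input shape: (batch, seq_size, num_features); output: (batch, 3) class logits.
if __name__ == "__main__":
    model = MLPLOB(hidden_dim=128,
                   num_layers=3,
                   seq_size=128,
                   num_features=40,
                   dataset_type="FI-2010")  # hypothetical value, only "LOBSTER" is special-cased
    dummy = torch.rand(8, 128, 40)           # (batch, seq_size, num_features)
    logits = model(dummy)
    print(logits.shape)                      # expected: torch.Size([8, 3])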