import torch
from torch import nn
from einops import rearrange
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import constants as cst
from models.bin import BiN
from models.mlplob import MLP


class ComputeQKV(nn.Module):
    """Projects the input into query, key and value tensors for multi-head attention."""

    def __init__(self, hidden_dim: int, num_heads: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        # Each projection expands the hidden dimension by the number of heads, so the
        # attention module downstream runs with embed_dim = hidden_dim * num_heads.
        self.q = nn.Linear(hidden_dim, hidden_dim * num_heads)
        self.k = nn.Linear(hidden_dim, hidden_dim * num_heads)
        self.v = nn.Linear(hidden_dim, hidden_dim * num_heads)

    def forward(self, x):
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        return q, k, v
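

# A minimal shape-check sketch for ComputeQKV. The batch size, sequence length and
# dimensions below are illustrative assumptions, not values fixed by the model.
def _example_compute_qkv():
    layer = ComputeQKV(hidden_dim=16, num_heads=4)
    x = torch.randn(2, 10, 16)              # (batch, seq, hidden_dim)
    q, k, v = layer(x)
    assert q.shape == (2, 10, 16 * 4)       # each projection expands to hidden_dim * num_heads
    return q.shape, k.shape, v.shape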


class TransformerLayer(nn.Module):
    """Multi-head self-attention block followed by an MLP, each with a residual connection."""

    def __init__(self, hidden_dim: int, num_heads: int, final_dim: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.norm = nn.LayerNorm(hidden_dim)
        self.qkv = ComputeQKV(hidden_dim, num_heads)
        self.attention = nn.MultiheadAttention(hidden_dim * num_heads, num_heads, batch_first=True, device=cst.DEVICE)
        self.mlp = MLP(hidden_dim, hidden_dim * 4, final_dim)
        # Maps the concatenated heads (hidden_dim * num_heads) back to hidden_dim.
        self.w0 = nn.Linear(hidden_dim * num_heads, hidden_dim)

    def forward(self, x):
        res = x
        q, k, v = self.qkv(x)
        x, att = self.attention(q, k, v, average_attn_weights=False, need_weights=True)
        x = self.w0(x)
        x = x + res
        x = self.norm(x)
        x = self.mlp(x)
        # The second residual is only added when the MLP preserves the hidden dimension,
        # i.e. in every layer except the last one, where final_dim is reduced.
        if x.shape[-1] == res.shape[-1]:
            x = x + res
        return x, att
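

# Hedged usage sketch for a single TransformerLayer. The sizes are illustrative
# assumptions, and MLP from models.mlplob is assumed to map the last dimension from
# hidden_dim to final_dim, as the residual check in forward suggests.
def _example_transformer_layer():
    layer = TransformerLayer(hidden_dim=32, num_heads=2, final_dim=32).to(cst.DEVICE)
    x = torch.randn(4, 20, 32, device=cst.DEVICE)   # (batch, seq, hidden_dim)
    out, att = layer(x)
    # With final_dim == hidden_dim the output keeps the input shape; with
    # average_attn_weights=False recent PyTorch returns per-head attention weights.
    assert out.shape == (4, 20, 32)
    assert att.shape == (4, 2, 20, 20)              # (batch, num_heads, seq, seq)
    return out, att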


class TLOB(nn.Module):
    """Alternates temporal-axis and feature-axis transformer layers over LOB sequences,
    then flattens the result into a 3-class output head."""

    def __init__(self,
                 hidden_dim: int,
                 num_layers: int,
                 seq_size: int,
                 num_features: int,
                 num_heads: int,
                 is_sin_emb: bool,
                 dataset_type: str
                 ) -> None:
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.is_sin_emb = is_sin_emb
        self.seq_size = seq_size
        self.num_heads = num_heads
        self.dataset_type = dataset_type
        self.layers = nn.ModuleList()
        self.first_branch = nn.ModuleList()
        self.second_branch = nn.ModuleList()
        self.order_type_embedder = nn.Embedding(3, 1)
        self.norm_layer = BiN(num_features, seq_size)
        self.emb_layer = nn.Linear(num_features, hidden_dim)
        if is_sin_emb:
            self.pos_encoder = sinusoidal_positional_embedding(seq_size, hidden_dim)
        else:
            self.pos_encoder = nn.Parameter(torch.randn(1, seq_size, hidden_dim))

        # Each iteration appends a pair of layers: one attending over the temporal axis
        # (features of size hidden_dim) and one attending over the feature axis
        # (features of size seq_size). The last pair shrinks both dimensions by 4.
        for i in range(num_layers):
            if i != num_layers - 1:
                self.layers.append(TransformerLayer(hidden_dim, num_heads, hidden_dim))
                self.layers.append(TransformerLayer(seq_size, num_heads, seq_size))
            else:
                self.layers.append(TransformerLayer(hidden_dim, num_heads, hidden_dim // 4))
                self.layers.append(TransformerLayer(seq_size, num_heads, seq_size // 4))

        self.att_temporal = []
        self.att_feature = []
        self.mean_att_distance_temporal = []

        # Progressively reduce the flattened representation until it is small enough
        # for the 3-class classifier.
        total_dim = (hidden_dim // 4) * (seq_size // 4)
        self.final_layers = nn.ModuleList()
        while total_dim > 128:
            self.final_layers.append(nn.Linear(total_dim, total_dim // 4))
            self.final_layers.append(nn.GELU())
            total_dim = total_dim // 4
        self.final_layers.append(nn.Linear(total_dim, 3))

    def forward(self, input, store_att=False):
        if self.dataset_type == "LOBSTER":
            # Column 41 holds the categorical order type: embed it and concatenate it
            # back with the continuous features.
            continuous_features = torch.cat([input[:, :, :41], input[:, :, 42:]], dim=2)
            order_type = input[:, :, 41].long()
            order_type_emb = self.order_type_embedder(order_type).detach()
            x = torch.cat([continuous_features, order_type_emb], dim=2)
        else:
            x = input
        x = rearrange(x, 'b s f -> b f s')
        x = self.norm_layer(x)
        x = rearrange(x, 'b f s -> b s f')
        x = self.emb_layer(x)
        x = x + self.pos_encoder

        mean_att_distance_temporal = np.zeros((self.num_layers, self.num_heads))
        att_max_temporal = np.zeros((self.num_layers, 2, self.num_heads, self.seq_size))
        att_max_feature = np.zeros((self.num_layers - 1, 2, self.num_heads, self.hidden_dim))
        att_temporal = np.zeros((self.num_layers, self.num_heads, self.seq_size, self.seq_size))
        att_feature = np.zeros((self.num_layers - 1, self.num_heads, self.hidden_dim, self.hidden_dim))

        for i in range(len(self.layers)):
            x, att = self.layers[i](x)
            att = att.detach()
            # Transpose so the next layer attends over the other axis
            # (temporal and feature layers alternate).
            x = x.permute(0, 2, 1)
            if store_att:
                if i % 2 == 0:
                    att_temporal[i // 2] = att[0].cpu().numpy()
                    values, indices = att[0].max(dim=2)
                    mean_att_distance_temporal[i // 2] = compute_mean_att_distance(att[0])
                    att_max_temporal[i // 2, 0] = indices.cpu().numpy()
                    att_max_temporal[i // 2, 1] = values.cpu().numpy()
                elif i % 2 == 1 and i != len(self.layers) - 1:
                    att_feature[i // 2] = att[0].cpu().numpy()
                    values, indices = att[0].max(dim=2)
                    att_max_feature[i // 2, 0] = indices.cpu().numpy()
                    att_max_feature[i // 2, 1] = values.cpu().numpy()

        if store_att:
            self.mean_att_distance_temporal.append(mean_att_distance_temporal)
            self.att_temporal.append(att_max_temporal)
            self.att_feature.append(att_max_feature)

        x = rearrange(x, 'b s f -> b (f s) 1')
        x = x.reshape(x.shape[0], -1)
        for layer in self.final_layers:
            x = layer(x)
        return x, att_temporal, att_feature
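

# Hedged end-to-end sketch: builds a small TLOB and runs a forward pass on random data.
# All hyperparameters and the "FI-2010" dataset_type string are illustrative assumptions;
# BiN and MLP are assumed to preserve/map shapes as their usage above implies.
def _example_tlob_forward():
    model = TLOB(hidden_dim=40, num_layers=2, seq_size=64, num_features=40,
                 num_heads=1, is_sin_emb=True, dataset_type="FI-2010").to(cst.DEVICE)
    batch = torch.randn(8, 64, 40, device=cst.DEVICE)   # (batch, seq_size, num_features)
    logits, att_temporal, att_feature = model(batch, store_att=False)
    assert logits.shape == (8, 3)                        # 3-class output head
    return logits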


def sinusoidal_positional_embedding(token_sequence_size, token_embedding_dim, n=10000.0):

    if token_embedding_dim % 2 != 0:
        raise ValueError("Sinusoidal positional embedding cannot apply to odd token embedding dim (got dim={:d})".format(token_embedding_dim))

    T = token_sequence_size
    d = token_embedding_dim

    positions = torch.arange(0, T).unsqueeze(1)
    embeddings = torch.zeros(T, d)

    # Standard transformer encoding: sine on even indices, cosine on odd indices,
    # with geometrically increasing wavelengths controlled by n.
    denominators = torch.pow(n, 2 * torch.arange(0, d // 2) / d)
    embeddings[:, 0::2] = torch.sin(positions / denominators)
    embeddings[:, 1::2] = torch.cos(positions / denominators)

    return embeddings.to(cst.DEVICE, non_blocking=True)
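

# Quick hedged sanity check for the positional embedding; the sizes are illustrative.
def _example_sinusoidal_embedding():
    emb = sinusoidal_positional_embedding(token_sequence_size=64, token_embedding_dim=40)
    assert emb.shape == (64, 40)
    assert torch.all(emb >= -1.0) and torch.all(emb <= 1.0)   # sine/cosine values stay in [-1, 1]
    return emb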


def count_parameters(layer):
    print(f"Number of parameters: {sum(p.numel() for p in layer.parameters() if p.requires_grad)}")


def compute_mean_att_distance(att):
    # att has shape (num_heads, query, key): weight each attention value by its
    # query-key distance, accumulate per (head, key), then average over keys.
    att_distances = np.zeros((att.shape[0], att.shape[1]))
    for h in range(att.shape[0]):
        for key in range(att.shape[2]):
            for query in range(att.shape[1]):
                distance = abs(query - key)
                att_distances[h, key] += torch.abs(att[h, query, key]).cpu().item() * distance
    mean_distances = att_distances.mean(axis=1)
    return mean_distances
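

# Hedged vectorized alternative to compute_mean_att_distance (not part of the original
# model; a sketch assuming a square attention map of shape (num_heads, seq, seq)).
# It replaces the triple Python loop with a single broadcasted contraction.
def compute_mean_att_distance_vectorized(att):
    num_heads, num_queries, num_keys = att.shape
    positions = torch.arange(num_queries, device=att.device)
    distance = (positions.unsqueeze(1) - positions.unsqueeze(0)).abs().float()   # (seq, seq)
    # Sum |att| * distance over queries for each (head, key), then average over keys.
    att_distances = (att.abs() * distance.unsqueeze(0)).sum(dim=1)               # (num_heads, seq)
    return att_distances.mean(dim=1).cpu().numpy()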