# Provenance: uploaded to the Hugging Face Hub by LeonardoBerti
# ("Upload 51 files", commit 69524d0, verified); file size 7.44 kB.
from torch import nn
import torch
from einops import rearrange
import constants as cst
from models.bin import BiN
from models.mlplob import MLP
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
class ComputeQKV(nn.Module):
    """Produce query, key and value projections of an input tensor.

    Each projection is an independent ``nn.Linear`` that maps the last
    dimension from ``hidden_dim`` to ``hidden_dim * num_heads``.
    """

    def __init__(self, hidden_dim: int, num_heads: int):
        """Build the three projections.

        Args:
            hidden_dim: Size of the last dimension of the input.
            num_heads: Multiplier applied to ``hidden_dim`` to obtain the
                projected width.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        projected_dim = hidden_dim * num_heads
        # Built in q, k, v order so parameter registration (and the random
        # initialisation sequence) is unchanged.
        self.q = nn.Linear(hidden_dim, projected_dim)
        self.k = nn.Linear(hidden_dim, projected_dim)
        self.v = nn.Linear(hidden_dim, projected_dim)

    def forward(self, x):
        """Return the tuple ``(q, k, v)`` obtained by projecting ``x``."""
        return self.q(x), self.k(x), self.v(x)
class TransformerLayer(nn.Module):
    """Self-attention block followed by an MLP, with residual connections.

    The input is projected to queries/keys/values of width
    ``hidden_dim * num_heads``, passed through multi-head attention, projected
    back to ``hidden_dim``, added to the input and layer-normalised. The MLP
    then maps the result to ``final_dim``.
    """

    def __init__(self, hidden_dim: int, num_heads: int, final_dim: int):
        """Build the sub-modules.

        Args:
            hidden_dim: Size of the last dimension of the layer input.
            num_heads: Number of attention heads; also the factor by which
                the q/k/v projections widen ``hidden_dim``.
            final_dim: Output width requested from the MLP.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        attention_dim = hidden_dim * num_heads
        # Sub-modules are created in the same order as before so that
        # parameter registration order is unchanged.
        self.norm = nn.LayerNorm(hidden_dim)
        self.qkv = ComputeQKV(hidden_dim, num_heads)
        self.attention = nn.MultiheadAttention(
            attention_dim,
            num_heads,
            batch_first=True,
            device=cst.DEVICE,
        )
        self.mlp = MLP(hidden_dim, hidden_dim * 4, final_dim)
        self.w0 = nn.Linear(attention_dim, hidden_dim)

    def forward(self, x):
        """Apply the layer to ``x``.

        Returns:
            A pair ``(output, att)`` where ``att`` holds the attention weights
            returned by ``nn.MultiheadAttention`` with per-head averaging
            disabled.
        """
        residual = x
        queries, keys, values = self.qkv(x)
        attended, att = self.attention(
            queries,
            keys,
            values,
            need_weights=True,
            average_attn_weights=False,
        )
        # Project back to hidden_dim, add the skip connection, then normalise.
        hidden = self.norm(self.w0(attended) + residual)
        out = self.mlp(hidden)
        # The second skip connection re-uses the *layer input* and is only
        # possible when the MLP did not change the last dimension.
        if out.shape[-1] != residual.shape[-1]:
            return out, att
        return out + residual, att
class TLOB(nn.Module):
    """Transformer that alternates attention over the two axes of a window.

    Pipeline, as implemented below: ``BiN`` normalisation, a linear embedding
    to ``hidden_dim``, a positional encoding (fixed sinusoidal or learned),
    ``2 * num_layers`` ``TransformerLayer`` modules with the last two axes
    swapped after every layer, and finally a stack of ``Linear``/``GELU``
    layers ending in 3 outputs per sample (presumably class logits -- confirm
    against the training code).

    Attributes:
        att_temporal: One entry per ``forward(..., store_att=True)`` call with
            the per-query arg-max / max attention of the even-indexed layers.
        att_feature: Same, for the odd-indexed layers except the last one.
        mean_att_distance_temporal: One entry per ``store_att=True`` call with
            the mean attention distance of the even-indexed layers.
    """

    def __init__(self,
                 hidden_dim: int,
                 num_layers: int,
                 seq_size: int,
                 num_features: int,
                 num_heads: int,
                 is_sin_emb: bool,
                 dataset_type: str
                 ) -> None:
        """Build the model.

        Args:
            hidden_dim: Embedding width of each time step.
            num_layers: Number of layer *pairs*; each pair holds one layer
                acting on ``hidden_dim``-wide tokens and one acting on
                ``seq_size``-wide tokens.
            seq_size: Length of the input window.
            num_features: Width of the input fed to ``BiN`` and the embedding.
            num_heads: Number of attention heads in every layer.
            is_sin_emb: Use the fixed sinusoidal positional embedding when
                true, a learned ``nn.Parameter`` otherwise.
            dataset_type: ``"LOBSTER"`` enables the order-type embedding path
                in ``forward``; any other value uses the input as is.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.is_sin_emb = is_sin_emb
        self.seq_size = seq_size
        self.num_heads = num_heads
        self.dataset_type = dataset_type
        self.layers = nn.ModuleList()
        # Not populated or used anywhere in this class; kept so the module's
        # attribute set is unchanged.
        self.first_branch = nn.ModuleList()
        self.second_branch = nn.ModuleList()
        self.order_type_embedder = nn.Embedding(3, 1)
        self.norm_layer = BiN(num_features, seq_size)
        self.emb_layer = nn.Linear(num_features, hidden_dim)
        if is_sin_emb:
            # Plain tensor attribute (not a parameter or buffer).
            self.pos_encoder = sinusoidal_positional_embedding(seq_size, hidden_dim)
        else:
            self.pos_encoder = nn.Parameter(torch.randn(1, seq_size, hidden_dim))

        for i in range(num_layers):
            if i != num_layers - 1:
                self.layers.append(TransformerLayer(hidden_dim, num_heads, hidden_dim))
                self.layers.append(TransformerLayer(seq_size, num_heads, seq_size))
            else:
                # The last pair shrinks both axes by a factor of four.
                self.layers.append(TransformerLayer(hidden_dim, num_heads, hidden_dim // 4))
                self.layers.append(TransformerLayer(seq_size, num_heads, seq_size // 4))

        self.att_temporal = []
        self.att_feature = []
        self.mean_att_distance_temporal = []

        # Head: divide the flattened width by four until it is <= 128, then
        # map to the 3 outputs.
        total_dim = (hidden_dim // 4) * (seq_size // 4)
        self.final_layers = nn.ModuleList()
        while total_dim > 128:
            self.final_layers.append(nn.Linear(total_dim, total_dim // 4))
            self.final_layers.append(nn.GELU())
            total_dim = total_dim // 4
        self.final_layers.append(nn.Linear(total_dim, 3))

    def forward(self, input, store_att=False):
        """Run the model on a batch.

        Args:
            input: Tensor indexed as ``(batch, seq, feature)``. For
                ``dataset_type == "LOBSTER"`` column 41 is read as an integer
                order type, embedded to one value and appended as the last
                feature; the remaining columns are used unchanged.
            store_att: When true, attention statistics of the *first* batch
                element are computed, returned, and appended to
                ``att_temporal``, ``att_feature`` and
                ``mean_att_distance_temporal``. When false nothing is
                appended, so these lists no longer grow by one all-zero entry
                per forward pass.

        Returns:
            ``(x, att_temporal, att_feature)``: ``x`` is the head output with
            3 values per sample; the two NumPy arrays have shapes
            ``(num_layers, num_heads, seq_size, seq_size)`` and
            ``(num_layers - 1, num_heads, hidden_dim, hidden_dim)`` and stay
            all-zero unless ``store_att`` is true.
        """
        if self.dataset_type == "LOBSTER":
            continuous_features = torch.cat([input[:, :, :41], input[:, :, 42:]], dim=2)
            order_type = input[:, :, 41].long()
            # NOTE(review): .detach() stops gradients from reaching the
            # embedding weights through this path -- confirm this is intended.
            order_type_emb = self.order_type_embedder(order_type).detach()
            x = torch.cat([continuous_features, order_type_emb], dim=2)
        else:
            x = input

        # BiN is applied with the feature axis before the sequence axis.
        x = rearrange(x, 'b s f -> b f s')
        x = self.norm_layer(x)
        x = rearrange(x, 'b f s -> b s f')
        x = self.emb_layer(x)
        x = x + self.pos_encoder

        # Always returned, so always allocated.
        att_temporal = np.zeros((self.num_layers, self.num_heads, self.seq_size, self.seq_size))
        att_feature = np.zeros((self.num_layers - 1, self.num_heads, self.hidden_dim, self.hidden_dim))
        if store_att:
            # Diagnostics that are only kept on the module; skipped otherwise.
            mean_att_distance_temporal = np.zeros((self.num_layers, self.num_heads))
            att_max_temporal = np.zeros((self.num_layers, 2, self.num_heads, self.seq_size))
            att_max_feature = np.zeros((self.num_layers - 1, 2, self.num_heads, self.hidden_dim))

        last_index = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x, att = layer(x)
            att = att.detach()
            # Swap the last two axes so the next layer attends over the other one.
            x = x.permute(0, 2, 1)
            if not store_att:
                continue
            pair = i // 2
            if i % 2 == 0:
                att_temporal[pair] = att[0].cpu().numpy()
                # For each head and query: strongest key and its weight.
                values, indices = att[0].max(dim=2)
                mean_att_distance_temporal[pair] = compute_mean_att_distance(att[0])
                att_max_temporal[pair, 0] = indices.cpu().numpy()
                att_max_temporal[pair, 1] = values.cpu().numpy()
            elif i != last_index:
                # The very last layer is skipped here (its attention map
                # presumably no longer matches the hidden_dim-sized buffers).
                att_feature[pair] = att[0].cpu().numpy()
                values, indices = att[0].max(dim=2)
                att_max_feature[pair, 0] = indices.cpu().numpy()
                att_max_feature[pair, 1] = values.cpu().numpy()

        if store_att:
            self.mean_att_distance_temporal.append(mean_att_distance_temporal)
            self.att_temporal.append(att_max_temporal)
            self.att_feature.append(att_max_feature)

        # Flatten with the last axis varying slowest, as the head was trained on.
        x = rearrange(x, 'b s f -> b (f s)')
        for layer in self.final_layers:
            x = layer(x)
        return x, att_temporal, att_feature
def sinusoidal_positional_embedding(token_sequence_size, token_embedding_dim, n=10000.0):
    """Build a fixed sinusoidal positional embedding table.

    Even columns hold ``sin(pos / n^(2i/d))`` and odd columns hold
    ``cos(pos / n^(2i/d))`` for ``i`` in ``[0, d/2)``.

    Args:
        token_sequence_size: Number of positions (rows) in the table.
        token_embedding_dim: Embedding width ``d`` (columns); must be even.
        n: Base of the geometric progression of wavelengths.

    Returns:
        Tensor of shape ``(token_sequence_size, token_embedding_dim)`` moved
        to ``cst.DEVICE``.

    Raises:
        ValueError: If ``token_embedding_dim`` is odd.
    """
    if token_embedding_dim % 2:
        raise ValueError("Sinusoidal positional embedding cannot apply to odd token embedding dim (got dim={:d})".format(token_embedding_dim))

    half_dim = token_embedding_dim // 2
    # Column vector of positions so the division below broadcasts to (T, d/2).
    positions = torch.arange(token_sequence_size).unsqueeze(1)
    # n^(2i/d) for every sin/cos pair index i.
    denominators = torch.pow(n, 2 * torch.arange(half_dim) / token_embedding_dim)
    angles = positions / denominators

    embeddings = torch.zeros(token_sequence_size, token_embedding_dim)
    embeddings[:, 0::2] = torch.sin(angles)
    embeddings[:, 1::2] = torch.cos(angles)
    return embeddings.to(cst.DEVICE, non_blocking=True)
def count_parameters(layer):
    """Print how many trainable parameters ``layer`` has.

    Only parameters whose ``requires_grad`` flag is set are counted; the
    count is the sum of their ``numel()``.
    """
    trainable = 0
    for param in layer.parameters():
        if param.requires_grad:
            trainable += param.numel()
    print(f"Number of parameters: {trainable}")
def compute_mean_att_distance(att):
    """Return the mean attention distance of every head.

    For each head the absolute attention weights are multiplied by the index
    distance ``|query - key|``, summed over all query/key pairs and divided
    by the number of queries.

    The computation is vectorised: the tensor is copied to the host once
    instead of once per entry, and inputs with more keys than queries are
    handled instead of raising ``IndexError``.

    Args:
        att: Tensor indexed as ``(head, query, key)``.

    Returns:
        ``numpy.ndarray`` of shape ``(num_heads,)`` and dtype ``float64``.
    """
    # float64 on the host mirrors the precision of the previous
    # Python-float accumulation.
    weights = att.detach().abs().cpu().to(torch.float64).numpy()
    num_queries, num_keys = weights.shape[1], weights.shape[2]
    # distance[q, k] = |q - k|
    distance = np.abs(np.arange(num_queries)[:, None] - np.arange(num_keys)[None, :])
    return (weights * distance).sum(axis=(1, 2)) / num_queries