from torch import nn
import torch
from einops import rearrange
import constants as cst
from models.bin import BiN
from models.mlplob import MLP
import numpy as np


class ComputeQKV(nn.Module):
    """Projects the input into query, key and value tensors, one hidden_dim-wide block per head."""

    def __init__(self, hidden_dim: int, num_heads: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.q = nn.Linear(hidden_dim, hidden_dim * num_heads)
        self.k = nn.Linear(hidden_dim, hidden_dim * num_heads)
        self.v = nn.Linear(hidden_dim, hidden_dim * num_heads)

    def forward(self, x):
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        return q, k, v


class TransformerLayer(nn.Module):
    """Multi-head self-attention followed by an MLP, with residual connections and layer normalization."""

    def __init__(self, hidden_dim: int, num_heads: int, final_dim: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.norm = nn.LayerNorm(hidden_dim)
        self.qkv = ComputeQKV(hidden_dim, num_heads)
        self.attention = nn.MultiheadAttention(hidden_dim * num_heads, num_heads, batch_first=True, device=cst.DEVICE)
        self.mlp = MLP(hidden_dim, hidden_dim * 4, final_dim)
        self.w0 = nn.Linear(hidden_dim * num_heads, hidden_dim)

    def forward(self, x):
        res = x
        q, k, v = self.qkv(x)
        x, att = self.attention(q, k, v, average_attn_weights=False, need_weights=True)
        x = self.w0(x)  # project the concatenated heads back to hidden_dim
        x = x + res
        x = self.norm(x)
        x = self.mlp(x)
        # The second residual connection is only applied when the MLP preserves the dimension.
        if x.shape[-1] == res.shape[-1]:
            x = x + res
        return x, att


class TLOB(nn.Module):
    """Transformer for limit order book data that alternates attention over the
    temporal dimension and the feature dimension."""

    def __init__(self,
                 hidden_dim: int,
                 num_layers: int,
                 seq_size: int,
                 num_features: int,
                 num_heads: int,
                 is_sin_emb: bool,
                 dataset_type: str,
                 ) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.is_sin_emb = is_sin_emb
        self.seq_size = seq_size
        self.num_heads = num_heads
        self.dataset_type = dataset_type
        self.layers = nn.ModuleList()
        self.first_branch = nn.ModuleList()
        self.second_branch = nn.ModuleList()
        self.order_type_embedder = nn.Embedding(3, 1)
        self.norm_layer = BiN(num_features, seq_size)
        self.emb_layer = nn.Linear(num_features, hidden_dim)
        # Positional information: either a fixed sinusoidal table or a learned embedding.
        if is_sin_emb:
            self.pos_encoder = sinusoidal_positional_embedding(seq_size, hidden_dim)
        else:
            self.pos_encoder = nn.Parameter(torch.randn(1, seq_size, hidden_dim))
        # Each iteration adds a temporal-attention layer (tokens are time steps) and a
        # feature-attention layer (tokens are feature channels); the last pair shrinks
        # both dimensions by a factor of 4 before the classification head.
        for i in range(num_layers):
            if i != num_layers - 1:
                self.layers.append(TransformerLayer(hidden_dim, num_heads, hidden_dim))
                self.layers.append(TransformerLayer(seq_size, num_heads, seq_size))
            else:
                self.layers.append(TransformerLayer(hidden_dim, num_heads, hidden_dim // 4))
                self.layers.append(TransformerLayer(seq_size, num_heads, seq_size // 4))
        self.att_temporal = []
        self.att_feature = []
        self.mean_att_distance_temporal = []
        # Classification head: keep dividing the flattened dimension by 4 until it is
        # at most 128, then map to the 3 output classes.
        total_dim = (hidden_dim // 4) * (seq_size // 4)
        self.final_layers = nn.ModuleList()
        while total_dim > 128:
            self.final_layers.append(nn.Linear(total_dim, total_dim // 4))
            self.final_layers.append(nn.GELU())
            total_dim = total_dim // 4
        self.final_layers.append(nn.Linear(total_dim, 3))

    def forward(self, input, store_att=False):
        if self.dataset_type == "LOBSTER":
            # Column 41 is the categorical order type: embed it (without training the
            # embedding) and concatenate it back to the continuous features.
            continuous_features = torch.cat([input[:, :, :41], input[:, :, 42:]], dim=2)
            order_type = input[:, :, 41].long()
            order_type_emb = self.order_type_embedder(order_type).detach()
            x = torch.cat([continuous_features, order_type_emb], dim=2)
        else:
            x = input
        x = rearrange(x, 'b s f -> b f s')
        x = self.norm_layer(x)
        x = rearrange(x, 'b f s -> b s f')
        x = self.emb_layer(x)
        x = x + self.pos_encoder
        # Buffers for the attention statistics collected when store_att is True.
        mean_att_distance_temporal = np.zeros((self.num_layers, self.num_heads))
        att_max_temporal = np.zeros((self.num_layers, 2, self.num_heads, self.seq_size))
        att_max_feature = np.zeros((self.num_layers - 1, 2, self.num_heads, self.hidden_dim))
        att_temporal = np.zeros((self.num_layers, self.num_heads, self.seq_size, self.seq_size))
        att_feature = np.zeros((self.num_layers - 1, self.num_heads, self.hidden_dim, self.hidden_dim))
        for i in range(len(self.layers)):
            x, att = self.layers[i](x)
            att = att.detach()
            # Transpose so that the next layer attends over the other dimension
            # (time steps <-> feature channels).
            x = x.permute(0, 2, 1)
            if store_att:
                if i % 2 == 0:
                    # Temporal-attention layer: store the full map, the per-query maximum
                    # and the mean attention distance for the first sample in the batch.
                    att_temporal[i // 2] = att[0].cpu().numpy()
                    values, indices = att[0].max(dim=2)
                    mean_att_distance_temporal[i // 2] = compute_mean_att_distance(att[0])
                    att_max_temporal[i // 2, 0] = indices.cpu().numpy()
                    att_max_temporal[i // 2, 1] = values.cpu().numpy()
                elif i % 2 == 1 and i != len(self.layers) - 1:
                    # Feature-attention layer (the last one is skipped because its token
                    # dimension is reduced to hidden_dim // 4).
                    att_feature[i // 2] = att[0].cpu().numpy()
                    values, indices = att[0].max(dim=2)
                    att_max_feature[i // 2, 0] = indices.cpu().numpy()
                    att_max_feature[i // 2, 1] = values.cpu().numpy()
        self.mean_att_distance_temporal.append(mean_att_distance_temporal)
        if store_att:
            self.att_temporal.append(att_max_temporal)
            self.att_feature.append(att_max_feature)
        # Flatten (seq, feature) into a single vector per sample for the classification head.
        x = rearrange(x, 'b s f -> b (f s)')
        for layer in self.final_layers:
            x = layer(x)
        return x, att_temporal, att_feature


def sinusoidal_positional_embedding(token_sequence_size, token_embedding_dim, n=10000.0):
    if token_embedding_dim % 2 != 0:
        raise ValueError("Sinusoidal positional embedding cannot apply to odd token embedding dim (got dim={:d})".format(token_embedding_dim))
    T = token_sequence_size
    d = token_embedding_dim
    positions = torch.arange(0, T).unsqueeze_(1)
    embeddings = torch.zeros(T, d)
    denominators = torch.pow(n, 2 * torch.arange(0, d // 2) / d)   # 10000^(2i/d_model), i is the index of embedding
    embeddings[:, 0::2] = torch.sin(positions / denominators)      # sin(pos/10000^(2i/d_model))
    embeddings[:, 1::2] = torch.cos(positions / denominators)      # cos(pos/10000^(2i/d_model))
    return embeddings.to(cst.DEVICE, non_blocking=True)


def count_parameters(layer):
    print(f"Number of parameters: {sum(p.numel() for p in layer.parameters() if p.requires_grad)}")


def compute_mean_att_distance(att):
    """For each head, sum |attention| * |query - key| over queries at every key position,
    then average over keys. att is expected to have shape (num_heads, seq_len, seq_len)."""
    att_distances = np.zeros((att.shape[0], att.shape[1]))
    for h in range(att.shape[0]):
        for key in range(att.shape[2]):
            for query in range(att.shape[1]):
                distance = abs(query - key)
                att_distances[h, key] += torch.abs(att[h, query, key]).cpu().item() * distance
    mean_distances = att_distances.mean(axis=1)
    return mean_distances
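

# Usage sketch: a minimal forward pass on random data. The hyperparameters below
# (hidden_dim=128, num_layers=4, seq_size=128, num_features=40, num_heads=1) and the
# dataset_type string "FI-2010" are illustrative assumptions, not values prescribed by
# this module; the actual configuration, `constants.DEVICE`, `BiN`, and `MLP` come from
# the rest of the project.
if __name__ == "__main__":
    model = TLOB(hidden_dim=128, num_layers=4, seq_size=128, num_features=40,
                 num_heads=1, is_sin_emb=True, dataset_type="FI-2010").to(cst.DEVICE)
    batch = torch.randn(8, 128, 40, device=cst.DEVICE)  # (batch, seq_size, num_features)
    logits, att_temporal, att_feature = model(batch)    # store_att defaults to False
    print(logits.shape)  # expected: torch.Size([8, 3])
    count_parameters(model)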