# CAID benchmark disorder predictor (commit bae913a)
import torch
import torch.nn as nn
# Transformer model class
class DisorderPredictor(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_heads, num_layers, dropout):
        super(DisorderPredictor, self).__init__()
        self.embedding_dim = input_dim
        # batch_first=True so inputs are shaped (batch, seq_len, input_dim)
        self.self_attention = nn.MultiheadAttention(embed_dim=input_dim, num_heads=num_heads,
                                                    dropout=dropout, batch_first=True)
        # The encoder consumes the attention output, so d_model must equal input_dim;
        # hidden_dim sets the width of the feed-forward sublayer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=input_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.classifier = nn.Linear(input_dim, 1)  # just produce a probability for 1 class
        # self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, embeddings):
        # Self-attention over the input embeddings (Q = K = V = embeddings)
        attn_out, _ = self.self_attention(embeddings, embeddings, embeddings)
        transformer_out = self.transformer_encoder(attn_out)  # Pass through the encoder stack
        logits = self.classifier(transformer_out)             # Per-residue logit
        # Sigmoid for probabilities; squeeze out the trailing dimension of size 1 (single class)
        probs = self.sigmoid(logits.squeeze(-1))
        return probs  # Probabilities of shape (batch, seq_len)
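

# Minimal usage sketch (not part of the original commit): the hyperparameters and the
# random "embeddings" tensor below are placeholders chosen for illustration. In the CAID
# benchmark setting, the embeddings would come from a protein language model, with one
# vector per residue.
if __name__ == "__main__":
    batch_size, seq_len, input_dim = 2, 50, 320   # hypothetical shapes
    model = DisorderPredictor(input_dim=input_dim, hidden_dim=640,
                              num_heads=8, num_layers=2, dropout=0.1)
    embeddings = torch.randn(batch_size, seq_len, input_dim)  # stand-in for real residue embeddings
    probs = model(embeddings)
    print(probs.shape)  # torch.Size([2, 50]) -- one disorder probability per residue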