import torch.nn as nn


# Transformer model class
class DisorderPredictor(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_heads, num_layers, dropout):
        super(DisorderPredictor, self).__init__()
        self.embedding_dim = input_dim
        self.self_attention = nn.MultiheadAttention(
            embed_dim=input_dim, num_heads=num_heads, dropout=dropout, batch_first=True
        )
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=input_dim,            # must match the embedding width coming out of self-attention
            nhead=num_heads,
            dim_feedforward=hidden_dim,   # hidden_dim sizes the feed-forward sublayer
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.classifier = nn.Linear(input_dim, 1)  # just produce a probability for a single class
        # self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, embeddings):
        attn_out, _ = self.self_attention(embeddings, embeddings, embeddings)  # embeddings serve as Q, K, and V
        transformer_out = self.transformer_encoder(attn_out)                   # outputs from the encoder layers
        logits = self.classifier(transformer_out)                              # linear classifier
        probs = self.sigmoid(logits.squeeze(-1))  # sigmoid for probabilities; drop the trailing size-1 dim (single class)
        return probs                              # per-residue probabilities of shape (batch, seq_len)
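
A minimal usage sketch follows. The embedding dimension, batch size, sequence length, and hyperparameter values below are illustrative assumptions, not values from the original; they only show the expected input and output shapes of a forward pass.

import torch

# Illustrative only: 1280-dim per-residue embeddings (e.g. from a protein language model),
# a batch of 2 sequences of 100 residues each. All numbers here are assumptions for the sketch.
model = DisorderPredictor(input_dim=1280, hidden_dim=512, num_heads=8, num_layers=2, dropout=0.1)
dummy_embeddings = torch.randn(2, 100, 1280)   # (batch, seq_len, input_dim)
probs = model(dummy_embeddings)                # (batch, seq_len) per-residue disorder probabilities
print(probs.shape)                             # torch.Size([2, 100])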