import torch
import torch.nn as nn


class DisorderPredictor(nn.Module):
    """Transformer-based per-position disorder predictor over precomputed embeddings."""

    def __init__(self, input_dim, hidden_dim, num_heads, num_layers, dropout):
        super().__init__()
        self.embedding_dim = input_dim
        # batch_first=True so inputs are (batch, seq_len, input_dim),
        # matching the encoder layers below.
        self.self_attention = nn.MultiheadAttention(
            embed_dim=input_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        # d_model must match the attention output width (input_dim), not hidden_dim;
        # hidden_dim instead sizes the encoder's feed-forward sublayer.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=input_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # One logit per position, squashed to a probability by the sigmoid.
        self.classifier = nn.Linear(input_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, embeddings):
        # Self-attention over the input embeddings (query = key = value).
        attn_out, _ = self.self_attention(embeddings, embeddings, embeddings)
        transformer_out = self.transformer_encoder(attn_out)
        logits = self.classifier(transformer_out)
        # (batch, seq_len, 1) -> (batch, seq_len) of per-position probabilities.
        probs = self.sigmoid(logits.squeeze(-1))
        return probs
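

if __name__ == "__main__":
    # Minimal usage sketch, not from the original source: the hyperparameters
    # and shapes below are illustrative assumptions (e.g. 1024-dim embeddings
    # from an upstream encoder such as a protein language model).
    model = DisorderPredictor(
        input_dim=1024,
        hidden_dim=2048,
        num_heads=8,
        num_layers=4,
        dropout=0.1,
    )
    embeddings = torch.randn(2, 100, 1024)  # (batch, seq_len, input_dim)
    probs = model(embeddings)
    print(probs.shape)  # torch.Size([2, 100]): per-position disorder probabilities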
|
|