File size: 1,360 Bytes
bae913a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch.nn as nn

# Transformer model class
class DisorderPredictor(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_heads, num_layers, dropout):
        super(DisorderPredictor, self).__init__()
        self.embedding_dim = input_dim
        self.self_attention = nn.MultiheadAttention(embed_dim=input_dim, num_heads=num_heads, dropout=dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.classifier = nn.Linear(input_dim, 1)       # juts produce probabilities for 1 class
        #self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, embeddings):
        attn_out, _ = self.self_attention(embeddings, embeddings, embeddings) # Start with embeddings as random Q, K, V vectors
        transformer_out = self.transformer_encoder(attn_out) # Get outputs from encoder layers
        logits = self.classifier(transformer_out) # Linear classifier
        probs = self.sigmoid(logits.squeeze(-1)) # sigmoid for probabilities; remove the last dimension of size 1 (since we only predicted 1 class)
        return probs  # Get probabilities of dimension seq_len