# src/model.py
import torch
import torch.nn as nn

class TransformerModel(nn.Module):
    """Transformer encoder over token + learned positional embeddings,
    with a linear head projecting back to the vocabulary."""

    def __init__(self, vocab_size, embed_size, num_heads, hidden_dim, num_layers, dropout=0.1):
        super().__init__()
        self.embed_size = embed_size
        self.token_embedding = nn.Embedding(vocab_size, embed_size)
        # Learned positional embeddings; supports sequences up to 5000 tokens
        self.position_embedding = nn.Embedding(5000, embed_size)

        encoder_layers = nn.TransformerEncoderLayer(
            d_model=embed_size,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)

        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        # src: [batch_size, seq_length] tensor of token ids
        batch_size, seq_length = src.size()
        positions = torch.arange(seq_length, device=src.device).unsqueeze(0).repeat(batch_size, 1)

        x = self.token_embedding(src) + self.position_embedding(positions)
        x = self.dropout(x)
        x = x.permute(1, 0, 2)  # Transformer expects [seq_length, batch_size, embed_size]
        transformer_out = self.transformer_encoder(x, mask=src_mask)
        transformer_out = transformer_out.permute(1, 0, 2)
        logits = self.fc_out(transformer_out)
        return logits

    def generate_square_subsequent_mask(self, sz):
        # Causal (subsequent) mask: position i may attend only to positions <= i.
        # Allowed positions get 0.0, future positions get -inf (added to the attention scores).
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
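

# Minimal usage sketch (illustrative only; the vocabulary and model sizes below are
# placeholder values, not taken from any project configuration):
if __name__ == "__main__":
    model = TransformerModel(vocab_size=10000, embed_size=256, num_heads=8,
                             hidden_dim=512, num_layers=4)
    src = torch.randint(0, 10000, (2, 16))            # [batch_size=2, seq_length=16] token ids
    mask = model.generate_square_subsequent_mask(16)  # causal mask for left-to-right attention
    logits = model(src, mask)
    print(logits.shape)  # torch.Size([2, 16, 10000])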