"""Encoder-decoder Transformer building blocks (post-norm, sinusoidal positions)."""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch of embeddings.

    Even feature columns hold ``sin(pos * w_i)`` and odd columns hold
    ``cos(pos * w_i)``, with ``w_i = 10000 ** (-2i / d_model)``.

    Args:
        d_model: Embedding width. May be odd; the final (even) column then has
            no cosine partner.
        max_seq_length: Number of positions in the precomputed table.
    """

    def __init__(self, d_model, max_seq_length):
        super().__init__()
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are d_model // 2 odd columns, which is one fewer than the even
        # columns when d_model is odd; only that many frequencies get a cosine.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Buffer rather than parameter: follows .to()/.cuda() and state_dict,
        # but is never touched by the optimizer. Stored as (1, max_len, d_model).
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Assumes ``x`` is ``(batch, seq_len, d_model)`` with
        ``seq_len <= max_seq_length``.
        """
        return x + self.pe[:, :x.size(1)]


class MultiHeadSelfAttention(nn.Module):
    """Scaled dot-product multi-head attention.

    Queries, keys and values are passed separately, so the same module is used
    both for self-attention and for encoder-decoder (cross) attention.

    Args:
        d_model: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads; each head has width
            ``d_model // num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.depth = d_model // num_heads
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.fc = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        """Reshape ``(batch, seq, d_model)`` to ``(batch, num_heads, seq, depth)``."""
        x = x.view(batch_size, -1, self.num_heads, self.depth)
        return x.permute(0, 2, 1, 3)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Query input, ``(batch, q_len, d_model)``.
            k: Key input, ``(batch, kv_len, d_model)``.
            v: Value input, ``(batch, kv_len, d_model)``.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, q_len, kv_len)``. Scores where
                ``mask == 0`` are replaced by a large negative number before
                the softmax, giving those keys (near-)zero weight; a row that
                is masked everywhere ends up with uniform weights.

        Returns:
            Tensor of shape ``(batch, q_len, d_model)``.
        """
        batch_size = q.size(0)
        q = self.split_heads(self.wq(q), batch_size)
        k = self.split_heads(self.wk(k), batch_size)
        v = self.split_heads(self.wv(v), batch_size)

        # Plain Python float: same value as a float32 sqrt, but no tensor is
        # allocated on every call.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.depth)
        if mask is not None:
            # -1e9 is not representable in float16 (masked_fill raises), so
            # clamp to the dtype's most negative finite value; float32 and
            # bfloat16 still use exactly -1e9.
            fill_value = max(-1e9, torch.finfo(scores.dtype).min)
            scores = scores.masked_fill(mask == 0, fill_value)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)
        # (batch, heads, q_len, depth) -> (batch, q_len, heads * depth)
        out = out.permute(0, 2, 1, 3).contiguous()
        out = out.view(batch_size, -1, self.d_model)
        return self.fc(out)


class FeedForwardNetwork(nn.Module):
    """Position-wise feed-forward sublayer: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the two projections to the last dimension of ``x``."""
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x


class EncoderLayer(nn.Module):
    """One encoder layer: self-attention then feed-forward, each wrapped as
    ``LayerNorm(x + Dropout(sublayer(x)))`` (post-norm residual)."""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadSelfAttention(d_model, num_heads)
        self.ffn = FeedForwardNetwork(d_model, d_ff, dropout)
        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Encode ``x``; ``mask`` is forwarded to the self-attention."""
        attn_output = self.self_attn(x, x, x, mask)
        x = self.layernorm1(x + self.dropout(attn_output))
        ffn_output = self.ffn(x)
        x = self.layernorm2(x + self.dropout(ffn_output))
        return x


class DecoderLayer(nn.Module):
    """One decoder layer: masked self-attention, cross-attention over the
    encoder output, then feed-forward; each sublayer is post-norm residual."""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadSelfAttention(d_model, num_heads)
        self.cross_attn = MultiHeadSelfAttention(d_model, num_heads)
        self.ffn = FeedForwardNetwork(d_model, d_ff, dropout)
        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.layernorm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        """Decode ``x`` against ``enc_output``.

        ``tgt_mask`` masks the decoder self-attention; ``src_mask`` masks the
        cross-attention keys (encoder positions).
        """
        self_attn_output = self.self_attn(q=x, k=x, v=x, mask=tgt_mask)
        x = self.layernorm1(x + self.dropout(self_attn_output))
        cross_attn_output = self.cross_attn(q=x, k=enc_output, v=enc_output, mask=src_mask)
        x = self.layernorm2(x + self.dropout(cross_attn_output))
        ffn_output = self.ffn(x)
        x = self.layernorm3(x + self.dropout(ffn_output))
        return x


class TransformerBlock(nn.Module):
    """Encoder-decoder Transformer that returns unnormalized target-vocab logits.

    Args:
        tokenizer: Object exposing ``get_vocab_size()``. The same size is used
            for the source embedding, target embedding and output projection.
            Required despite the ``None`` default.
        config: Mapping read with ``.get``. Keys (defaults): ``D_MODEL`` (512),
            ``NUM_HEADS`` (8), ``NUM`` (6; depth of BOTH encoder and decoder),
            ``D_FF`` (2048), ``DROPOUT`` (0.1), ``MAX_LEN`` (32, positional
            table length). ``None`` is treated as an empty mapping.
        stress: Stored as ``self.stress``; not read by any code shown here.

    Raises:
        ValueError: If ``tokenizer`` is ``None``.
    """

    def __init__(self, tokenizer=None, config=None, stress=False):
        super().__init__()
        if tokenizer is None:
            raise ValueError("TransformerBlock requires a tokenizer exposing get_vocab_size()")
        if config is None:
            config = {}
        self.config = config
        self.tokenizer = tokenizer
        self.input_vocab_size = tokenizer.get_vocab_size()
        self.target_vocab_size = tokenizer.get_vocab_size()
        self.d_model = config.get('D_MODEL', 512)
        self.num_heads = config.get('NUM_HEADS', 8)
        # NOTE(review): a single 'NUM' key sets both stack depths.
        self.num_encoder_layers = config.get('NUM', 6)
        self.num_decoder_layers = config.get('NUM', 6)
        self.d_ff = config.get('D_FF', 2048)
        # Dropout probability (a float, not an nn.Dropout module).
        self.dropout = config.get('DROPOUT', 0.1)
        self.stress = stress

        self.encoder_embedding = nn.Embedding(self.input_vocab_size, self.d_model)
        self.decoder_embedding = nn.Embedding(self.target_vocab_size, self.d_model)
        # One positional table shared by encoder and decoder inputs.
        self.pos_embedding = PositionalEncoding(self.d_model, config.get('MAX_LEN', 32))
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(self.d_model, self.num_heads, self.d_ff, self.dropout)
             for _ in range(self.num_encoder_layers)])
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(self.d_model, self.num_heads, self.d_ff, self.dropout)
             for _ in range(self.num_decoder_layers)])
        self.fc_out = nn.Linear(self.d_model, self.target_vocab_size)

    def encode(self, src, src_mask):
        """Embed source token ids, add positions, and run the encoder stack."""
        src = self.pos_embedding(self.encoder_embedding(src))
        for layer in self.encoder_layers:
            src = layer(src, src_mask)
        return src

    def decode(self, memory, src_mask, tgt, tgt_mask):
        """Embed target token ids, add positions, and run the decoder stack
        against the encoder output ``memory``."""
        tgt = self.pos_embedding(self.decoder_embedding(tgt))
        for layer in self.decoder_layers:
            tgt = layer(tgt, memory, src_mask, tgt_mask)
        return tgt

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Return logits of shape ``(batch, tgt_len, target_vocab_size)``.

        ``src_mask`` is used by encoder self-attention AND decoder
        cross-attention, so it must broadcast against both
        ``(batch, heads, src_len, src_len)`` and ``(batch, heads, tgt_len,
        src_len)``, e.g. ``(batch, 1, 1, src_len)``. ``tgt_mask`` masks the
        decoder self-attention. No softmax is applied to the output.
        """
        memory = self.encode(src, src_mask)
        output = self.decode(memory, src_mask, tgt, tgt_mask)
        output = self.fc_out(output)
        return output