import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class ModelConfig:
    """Hyper-parameters for :class:`ResearchTransformer`."""

    vocab_size: int = 65536  # rows of the token embedding / width of the logits
    n_layer: int = 6  # number of transformer blocks
    n_head: int = 8  # attention heads per block; must divide n_embd
    n_embd: int = 512  # residual-stream width
    block_size: int = 512  # max sequence length (positional table + causal mask size)
    dropout: float = 0.1  # used on embeddings, attention weights and residual branches


@dataclass
class ModelOutput:
    """Result of :meth:`ResearchTransformer.forward`.

    Attributes:
        logits: ``(B, T, vocab_size)`` output of the LM head.
        loss: Shifted next-token cross-entropy when ``labels`` were given,
            otherwise ``None``.
        activations: When ``return_activations=True``, one
            ``{"attn": ..., "mlp": ...}`` dict per block holding that block's
            residual-branch outputs; otherwise ``None``.
    """

    logits: torch.Tensor
    loss: Optional[torch.Tensor] = None
    activations: Optional[List[Dict[str, torch.Tensor]]] = None


class PreNormSelfAttention(nn.Module):
    """Causal multi-head self-attention with pre-LayerNorm and a residual add."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.proj = nn.Linear(n_embd, n_embd, bias=False)
        self.attn_drop = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(n_embd)
        # Lower-triangular causal mask, True = "may attend". Bool keeps it 4x
        # smaller than a float mask and unaffected by .half()/.to(dtype).
        # Non-persistent: rebuilt from block_size, never written to state_dict.
        mask = torch.tril(torch.ones(block_size, block_size, dtype=torch.bool))
        self.register_buffer("mask", mask.view(1, 1, block_size, block_size), persistent=False)

    def forward(
        self,
        x: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply attention to ``x`` of shape ``(B, T, C)``.

        Args:
            x: Residual stream, ``(B, T, C)``.
            key_padding_mask: Optional bool tensor broadcastable to
                ``(B, n_head, T, T)``; ``False`` marks keys that must not be
                attended to (padding). ``None`` means every key is valid.

        Returns:
            ``(x + y, y)`` where ``y`` is the attention branch output.
        """
        B, T, C = x.size()
        x_norm = self.ln(x)
        # (B, T, 3C) -> (B, T, 3, H, D) -> (B, H, 3, T, D); dim 2 selects q/k/v.
        qkv = self.qkv(x_norm).view(B, T, 3, self.n_head, self.head_dim).transpose(1, 3)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]  # each (B, H, T, D)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, H, T, T)
        if key_padding_mask is not None:
            # Finite minimum instead of -inf: a query whose every visible key is
            # padding (e.g. leading pads) still gets a finite softmax. A NaN row
            # here would poison later layers through 0 * NaN in att @ v.
            att = att.masked_fill(~key_padding_mask, torch.finfo(att.dtype).min)
        # Causal fill last, so future positions are always excluded exactly.
        att = att.masked_fill(~self.mask[:, :, :T, :T], float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v  # (B, H, T, D)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_drop(self.proj(y))
        return x + y, y


class PreNormMLP(nn.Module):
    """Two-layer GELU feed-forward sub-layer with pre-LayerNorm and a residual add."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        hidden = 4 * n_embd
        self.ln = nn.LayerNorm(n_embd)
        self.fc1 = nn.Linear(n_embd, hidden)
        self.fc2 = nn.Linear(hidden, n_embd)
        # Same dropout module is applied to the hidden and the output activations.
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(x + y, y)`` where ``y`` is the MLP branch output."""
        h = self.drop(F.gelu(self.fc1(self.ln(x))))
        y = self.drop(self.fc2(h))
        return x + y, y


class Block(nn.Module):
    """One transformer layer: attention sub-layer followed by MLP sub-layer."""

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.attn = PreNormSelfAttention(cfg.n_embd, cfg.n_head, cfg.block_size, cfg.dropout)
        self.mlp = PreNormMLP(cfg.n_embd, cfg.dropout)

    def forward(
        self,
        x: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Run both sub-layers.

        Returns:
            The updated residual stream and a dict with the raw branch outputs
            under ``"attn"`` and ``"mlp"``.
        """
        x, attn_out = self.attn(x, key_padding_mask)
        x, mlp_out = self.mlp(x)
        return x, {"attn": attn_out, "mlp": mlp_out}


class ResearchTransformer(nn.Module):
    """Decoder-only transformer LM with learned positions and tied input/output embeddings."""

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.n_embd)
        self.pos_emb = nn.Embedding(cfg.block_size, cfg.n_embd)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layer)])
        self.ln_f = nn.LayerNorm(cfg.n_embd)
        self.lm_head = nn.Linear(cfg.n_embd, cfg.vocab_size, bias=False)
        # Weight tying: the LM head shares its parameter with the token embedding.
        self.lm_head.weight = self.tok_emb.weight
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """N(0, 0.02) weights for Linear/Embedding, zero biases for Linear."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        return_activations: bool = False,
    ) -> ModelOutput:
        """Compute logits (and optionally loss / per-block activations).

        Args:
            input_ids: Token ids of shape ``(B, T)`` with ``T <= block_size``.
            attention_mask: Optional ``(B, T)`` tensor; non-zero marks real
                tokens, zero marks padding that no query may attend to.
                ``None`` (or all ones) attends to every causal position.
            labels: Optional ``(B, T)`` targets aligned with ``input_ids``; they
                are shifted internally so position ``t`` predicts ``labels[t+1]``.
                Entries equal to ``-100`` are ignored by the loss.
            return_activations: Also return each block's branch outputs.

        Returns:
            A :class:`ModelOutput`.

        Raises:
            ValueError: If ``attention_mask`` is not of shape ``(B, T)``.
        """
        B, T = input_ids.size()
        assert T <= self.cfg.block_size, f"Input length {T} exceeds block size {self.cfg.block_size}"

        key_padding_mask = None
        if attention_mask is not None:
            if attention_mask.shape != (B, T):
                raise ValueError(
                    f"attention_mask must have shape {(B, T)}, got {tuple(attention_mask.shape)}"
                )
            # (B, T) -> (B, 1, 1, T): broadcast over heads and query positions.
            key_padding_mask = (attention_mask != 0).view(B, 1, 1, T)

        pos = torch.arange(0, T, dtype=torch.long, device=input_ids.device).unsqueeze(0)
        x = self.drop(self.tok_emb(input_ids) + self.pos_emb(pos))

        activations: List[Dict[str, torch.Tensor]] = []
        for blk in self.blocks:
            x, acts = blk(x, key_padding_mask)
            if return_activations:
                activations.append(acts)

        logits = self.lm_head(self.ln_f(x))

        loss = None
        if labels is not None:
            loss = F.cross_entropy(
                logits[:, :-1, :].reshape(-1, logits.size(-1)),
                labels[:, 1:].reshape(-1),
                ignore_index=-100,
            )

        return ModelOutput(
            logits=logits,
            loss=loss,
            activations=activations if return_activations else None,
        )

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 50) -> torch.Tensor:
        """Greedily append ``max_new_tokens`` tokens to ``input_ids``.

        Each step conditions only on the last ``block_size`` tokens, but the
        returned tensor is the full sequence (original prompt plus all
        generated tokens). The module's train/eval mode is restored on exit.

        Args:
            input_ids: Prompt token ids, ``(B, T)``.
            max_new_tokens: Number of tokens to generate.

        Returns:
            ``(B, T + max_new_tokens)`` token ids.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Crop only the model input; keep the full history for the result.
                context = input_ids[:, -self.cfg.block_size:]
                logits = self(context).logits
                next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
                input_ids = torch.cat([input_ids, next_token], dim=1)
        finally:
            self.train(was_training)
        return input_ids