import math
from dataclasses import dataclass
from typing import Optional, Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class ModelConfig:
    """Hyperparameters for the decoder-only research transformer."""
    vocab_size: int = 65536
    n_layer: int = 6
    n_head: int = 8
    n_embd: int = 512
    block_size: int = 512
    dropout: float = 0.1


class PreNormSelfAttention(nn.Module):
    """Multi-head causal self-attention with a pre-LayerNorm residual connection."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.proj = nn.Linear(n_embd, n_embd, bias=False)
        self.attn_drop = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(n_embd)
        # Lower-triangular causal mask; persistent=False keeps it out of the state_dict.
        mask = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", mask.view(1, 1, block_size, block_size), persistent=False)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        B, T, C = x.size()
        x_norm = self.ln(x)
        # (B, T, 3, n_head, head_dim) -> (B, n_head, 3, T, head_dim)
        qkv = self.qkv(x_norm).view(B, T, 3, self.n_head, self.head_dim).transpose(1, 3)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        # Scaled dot-product attention with causal masking.
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_drop(self.proj(y))
        out = x + y
        # Return both the residual-stream output and the raw branch output for analysis.
        return out, y


class PreNormMLP(nn.Module):
    """Position-wise feed-forward block (4x expansion, GELU) with a pre-LayerNorm residual."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        hidden = 4 * n_embd
        self.ln = nn.LayerNorm(n_embd)
        self.fc1 = nn.Linear(n_embd, hidden)
        self.fc2 = nn.Linear(hidden, n_embd)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x_norm = self.ln(x)
        h = F.gelu(self.fc1(x_norm))
        h = self.drop(h)
        y = self.fc2(h)
        y = self.drop(y)
        out = x + y
        return out, y


class Block(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.attn = PreNormSelfAttention(cfg.n_embd, cfg.n_head, cfg.block_size, cfg.dropout)
        self.mlp = PreNormMLP(cfg.n_embd, cfg.dropout)

    def forward(self, x: torch.Tensor):
        x, attn_out = self.attn(x)
        x, mlp_out = self.mlp(x)
        return x, {"attn": attn_out, "mlp": mlp_out}


@dataclass
class ModelOutput:
    """Container for forward-pass results."""
    logits: torch.Tensor
    loss: Optional[torch.Tensor] = None
    activations: Optional[List[dict]] = None


class ResearchTransformer(nn.Module):
    """Decoder-only transformer with learned positional embeddings and tied input/output embeddings."""

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.n_embd)
        self.pos_emb = nn.Embedding(cfg.block_size, cfg.n_embd)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layer)])
        self.ln_f = nn.LayerNorm(cfg.n_embd)
        self.lm_head = nn.Linear(cfg.n_embd, cfg.vocab_size, bias=False)
        # Tie the output projection to the token embedding matrix.
        self.lm_head.weight = self.tok_emb.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None, return_activations: bool = False) -> ModelOutput:
        # attention_mask is accepted for API compatibility but unused; masking is causal only.
        B, T = input_ids.size()
        assert T <= self.cfg.block_size, f"Input length {T} exceeds block size {self.cfg.block_size}"
        pos = torch.arange(0, T, dtype=torch.long, device=input_ids.device).unsqueeze(0)
        x = self.tok_emb(input_ids) + self.pos_emb(pos)
        x = self.drop(x)
        activations = []
        for blk in self.blocks:
            x, acts = blk(x)
            if return_activations:
                activations.append(acts)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            # Shift so that the prediction at position t is scored against token t + 1.
            loss = F.cross_entropy(
                logits[:, :-1, :].contiguous().view(-1, logits.size(-1)),
                labels[:, 1:].contiguous().view(-1),
                ignore_index=-100,
            )
        return ModelOutput(
            logits=logits,
            loss=loss,
            activations=activations if return_activations else None,
        )

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 50) -> torch.Tensor:
        """Greedy decoding: append the argmax token up to max_new_tokens times."""
        self.eval()
        for _ in range(max_new_tokens):
            # Crop the context to the last block_size tokens if it has grown too long.
            if input_ids.size(1) > self.cfg.block_size:
                input_ids = input_ids[:, -self.cfg.block_size:]
            out = self(input_ids)
            next_token = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=1)
        return input_ids
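

# Minimal usage sketch: the config values, batch shape, and prompt length below are
# arbitrary small settings chosen for illustration, not a training recipe.
if __name__ == "__main__":
    cfg = ModelConfig(n_layer=2, n_head=4, n_embd=128, block_size=64)
    model = ResearchTransformer(cfg)

    # Random token ids; passing input_ids as labels yields the standard
    # shifted next-token language-modeling loss.
    input_ids = torch.randint(0, cfg.vocab_size, (2, 16))
    out = model(input_ids, labels=input_ids, return_activations=True)
    print("logits:", tuple(out.logits.shape))        # (2, 16, vocab_size)
    print("loss:", float(out.loss))
    print("captured blocks:", len(out.activations))  # == cfg.n_layer

    # Greedy generation from a short prompt.
    prompt = input_ids[:, :4]
    generated = model.generate(prompt, max_new_tokens=8)
    print("generated shape:", tuple(generated.shape))  # (2, 12)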