# training_bench/models/research_model.py
# Uploaded by rider-provider-777 ("Upload research_model.py", commit c975d73).
import math
from dataclasses import dataclass
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclass
class ModelConfig:
    """Hyperparameters for ``ResearchTransformer`` and its sub-layers.

    Field order is significant: it defines the positional-argument order of
    the dataclass-generated ``__init__``.
    """

    vocab_size: int = 65536  # rows of tok_emb and output width of lm_head (weights are tied)
    n_layer: int = 6  # number of Block layers stacked in the model
    n_head: int = 8  # attention heads per Block; must divide n_embd evenly
    n_embd: int = 512  # model / embedding width
    block_size: int = 512  # max sequence length: size of pos_emb and of each causal mask
    dropout: float = 0.1  # dropout probability used in embeddings, attention and MLPs
class PreNormSelfAttention(nn.Module):
    """Causal multi-head self-attention sub-layer with a pre-LayerNorm residual.

    The input is layer-normalised and projected to queries, keys and values
    with one fused linear layer. Each position attends only to itself and
    earlier positions (lower-triangular mask). The heads are merged and
    projected back, and the result is added to the *un-normalised* input.

    Args:
        n_embd: Model width; must be divisible by ``n_head``.
        n_head: Number of attention heads.
        block_size: Longest sequence the causal mask supports.
        dropout: Dropout probability for the attention weights and for the
            projected residual branch.
    """

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        # Fused projection; output features are laid out as [q | k | v].
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.proj = nn.Linear(n_embd, n_embd, bias=False)
        self.attn_drop = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(n_embd)
        lower_tri = torch.ones(block_size, block_size).tril()
        # Non-persistent: derived from block_size, so it stays out of state_dict.
        self.register_buffer(
            "mask",
            lower_tri.reshape(1, 1, block_size, block_size),
            persistent=False,
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply causal self-attention to ``x``.

        Args:
            x: Tensor of shape ``(B, T, n_embd)`` with ``T <= block_size``.

        Returns:
            ``(x + branch, branch)``, where ``branch`` is the attention
            sub-layer's output before the residual addition.
        """
        batch, seq_len, width = x.shape
        fused = self.qkv(self.ln(x))
        # (B, T, 3*C) -> (B, T, 3, H, D) -> (B, H, 3, T, D)
        fused = fused.view(batch, seq_len, 3, self.n_head, self.head_dim).transpose(1, 3)
        query, key, value = fused.unbind(dim=2)  # each (B, H, T, D)

        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        is_future = self.mask[:, :, :seq_len, :seq_len] == 0
        scores = scores.masked_fill(is_future, float("-inf"))
        weights = self.attn_drop(F.softmax(scores, dim=-1))

        per_head = torch.matmul(weights, value)  # (B, H, T, D)
        merged = per_head.transpose(1, 2).contiguous().view(batch, seq_len, width)
        branch = self.resid_drop(self.proj(merged))
        return x + branch, branch
class PreNormMLP(nn.Module):
    """Position-wise feed-forward sub-layer with a pre-LayerNorm residual.

    Computes ``fc2(dropout(gelu(fc1(ln(x)))))`` with a 4x hidden expansion,
    applies dropout to that branch and adds it to the un-normalised input.

    Args:
        n_embd: Model width (input and output features).
        dropout: Dropout probability, applied after the activation and again
            to the branch output.
    """

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        expansion = 4
        self.ln = nn.LayerNorm(n_embd)
        self.fc1 = nn.Linear(n_embd, expansion * n_embd)
        self.fc2 = nn.Linear(expansion * n_embd, n_embd)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(x + branch, branch)`` for the feed-forward branch of ``x``."""
        hidden = self.drop(F.gelu(self.fc1(self.ln(x))))
        branch = self.drop(self.fc2(hidden))
        return x + branch, branch
class Block(nn.Module):
    """One transformer layer: causal self-attention followed by an MLP.

    Both sub-layers are pre-norm residual modules, so each receives the
    running hidden state and returns it updated together with its own branch
    output.

    Args:
        cfg: Model hyperparameters (uses ``n_embd``, ``n_head``,
            ``block_size`` and ``dropout``).
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.attn = PreNormSelfAttention(
            n_embd=cfg.n_embd,
            n_head=cfg.n_head,
            block_size=cfg.block_size,
            dropout=cfg.dropout,
        )
        self.mlp = PreNormMLP(n_embd=cfg.n_embd, dropout=cfg.dropout)

    def forward(self, x: torch.Tensor):
        """Return the updated hidden state and a dict of sub-layer branch outputs.

        The dict has keys ``"attn"`` and ``"mlp"``, each holding that
        sub-layer's output before its residual addition.
        """
        hidden, attn_branch = self.attn(x)
        hidden, mlp_branch = self.mlp(hidden)
        branches = {"attn": attn_branch, "mlp": mlp_branch}
        return hidden, branches
class ResearchTransformer(nn.Module):
    """Decoder-only causal language model built from ``Block`` layers.

    Token embeddings and learned absolute position embeddings are summed and
    passed through dropout, ``cfg.n_layer`` blocks and a final LayerNorm. The
    result is projected to vocabulary logits by ``lm_head``, whose weight is
    tied to ``tok_emb``.

    Args:
        cfg: Model hyperparameters.
    """

    class Output:
        """Result container returned by ``ResearchTransformer.forward``.

        ``forward`` sets ``logits``, ``loss`` (``None`` when no labels are
        given) and, only when requested, ``activations``.

        Defined once at class level instead of inside ``forward``, so every
        call returns an instance of the same importable type.
        """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.n_embd)
        self.pos_emb = nn.Embedding(cfg.block_size, cfg.n_embd)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layer)])
        self.ln_f = nn.LayerNorm(cfg.n_embd)
        self.lm_head = nn.Linear(cfg.n_embd, cfg.vocab_size, bias=False)
        # Weight tying: the output projection shares the token-embedding matrix.
        self.lm_head.weight = self.tok_emb.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Init Linear/Embedding weights ~ N(0, 0.02) and Linear biases to zero.

        Applied to every submodule via ``self.apply``; other modules (e.g.
        LayerNorm) keep PyTorch's default initialisation.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None, return_activations: bool = False):
        """Run the model on a batch of token ids.

        Args:
            input_ids: Integer token ids of shape ``(B, T)`` with
                ``T <= cfg.block_size``.
            attention_mask: Accepted for call-site compatibility but currently
                IGNORED. No padding mask is applied, and attention is purely
                causal, so pad tokens are attended to.
            labels: Optional targets laid out like ``input_ids`` (``(B, T)``).
                They are shifted internally, so the logits at position ``t``
                are scored against ``labels[:, t + 1]``. Entries equal to
                ``-100`` are ignored.
            return_activations: If true, also attach per-block branch outputs.

        Returns:
            A ``ResearchTransformer.Output`` with these attributes:

            - ``logits``: shape ``(B, T, vocab_size)``.
            - ``loss``: a scalar tensor, or ``None`` when no labels are given.
            - ``activations`` (only when ``return_activations`` is true): a
              list with one dict per block, holding its ``"attn"`` and
              ``"mlp"`` branch outputs.
        """
        # NOTE(review): attention_mask is unused; wiring it in requires Block and
        # PreNormSelfAttention to accept a key-padding mask.
        B, T = input_ids.size()
        assert T <= self.cfg.block_size, f"Input length {T} exceeds block size {self.cfg.block_size}"
        pos = torch.arange(0, T, dtype=torch.long, device=input_ids.device).unsqueeze(0)
        x = self.tok_emb(input_ids) + self.pos_emb(pos)
        x = self.drop(x)

        activations = []
        for blk in self.blocks:
            x, acts = blk(x)
            if return_activations:
                activations.append(acts)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Next-token objective: drop the last prediction and the first label.
            loss = F.cross_entropy(
                logits[:, :-1, :].contiguous().view(-1, logits.size(-1)),
                labels[:, 1:].contiguous().view(-1),
                ignore_index=-100,
            )

        out = self.Output()
        out.logits = logits
        out.loss = loss
        if return_activations:
            out.activations = activations
        return out

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 50):
        """Greedily extend ``input_ids`` by ``max_new_tokens`` tokens.

        At each step only the last ``cfg.block_size`` tokens are fed to the
        model. The returned tensor still keeps the full history and has shape
        ``(B, T + max_new_tokens)``.

        Dropout is disabled while generating. The module's previous
        train/eval mode is restored afterwards, even if an exception occurs.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Crop only the model's view of the context; the output keeps every token.
                context = input_ids[:, -self.cfg.block_size:]
                logits = self(context).logits
                next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
                input_ids = torch.cat([input_ids, next_token], dim=1)
        finally:
            self.train(was_training)
        return input_ids