"""Much of this code is adapted from Andrej Karpathy's NanoGPT."""
import math
from dataclasses import dataclass
import torch
from torch import nn
from torch.nn import functional as F
from .model import GPT, MLP, GPTConfig
class NonCausalSelfAttention(nn.Module):
    """Multi-head self-attention without a causal mask.

    Every position may attend to every other position in the sequence: the
    fused kernel is called with ``is_causal=False`` and the manual path applies
    no mask.

    Args:
        config: object providing ``n_embd``, ``n_head``, ``bias`` and ``dropout``.
    """

    def __init__(self, config):
        # nn.Module.__init__ must run before any submodule is assigned;
        # otherwise nn.Module.__setattr__ raises AttributeError.
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # Fused query/key/value projection for all heads at once.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Output projection back to the embedding width.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # Use PyTorch's fused kernel only when it exists and no attention
        # dropout is configured; otherwise fall back to the manual path below.
        self.flash = (
            hasattr(torch.nn.functional, "scaled_dot_product_attention")
            and self.dropout == 0.0
        )

    def forward(self, x):
        """Apply non-causal multi-head self-attention.

        Args:
            x: tensor of shape ``(B, T, C)`` with ``C == n_embd``.

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.size()
        head_size = C // self.n_head
        # Project once, split into q/k/v and move heads to dim 1:
        # (B, T, C) -> (B, n_head, T, head_size).
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)

        if self.flash:
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=False
            )
        else:
            # Manual scaled dot-product attention, no mask.
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v
        # Re-assemble heads side by side: (B, n_head, T, head_size) -> (B, T, C).
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.c_proj(y))
class FineBlock(nn.Module):
    """Pre-LayerNorm transformer block built on non-causal self-attention.

    Computes ``x + attn(ln_1(x))`` and then ``x + mlp(ln_2(x))``.

    Args:
        config: object providing at least ``n_embd`` plus whatever
            ``NonCausalSelfAttention`` and ``MLP`` read from it.
    """

    def __init__(self, config):
        # Required before registering submodules on an nn.Module.
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = NonCausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)  # feed-forward sublayer imported from .model

    def forward(self, x):
        """Apply the attention and MLP residual sublayers to ``x``."""
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
class FineGPT(GPT):
    """GPT variant that predicts one codebook from the summed codebook embeddings.

    The input ``idx`` carries ``n_codes_total`` codebooks per position. Each
    codebook has its own token embedding table. For a target ``pred_idx`` the
    embeddings of codebooks ``0..pred_idx`` (inclusive) are summed, run through
    non-causal ``FineBlock`` layers, and projected by the head for ``pred_idx``.
    """

    def __init__(self, config):
        # Let the base class finish nn.Module setup (and create the
        # ``lm_head`` removed below) before replacing its modules.
        super().__init__(config)
        del self.lm_head
        self.config = config
        self.n_codes_total = config.n_codes_total
        self.transformer = nn.ModuleDict(
            dict(
                # One token embedding table per codebook.
                wtes=nn.ModuleList(
                    [
                        nn.Embedding(config.input_vocab_size, config.n_embd)
                        for _ in range(config.n_codes_total)
                    ]
                ),
                wpe=nn.Embedding(config.block_size, config.n_embd),
                drop=nn.Dropout(config.dropout),
                h=nn.ModuleList([FineBlock(config) for _ in range(config.n_layer)]),
                ln_f=nn.LayerNorm(config.n_embd),
            )
        )
        # One output head per predictable codebook
        # (codebooks n_codes_given .. n_codes_total - 1).
        self.lm_heads = nn.ModuleList(
            [
                nn.Linear(config.n_embd, config.output_vocab_size, bias=False)
                for _ in range(config.n_codes_given, self.n_codes_total)
            ]
        )
        # Tie each head to the input embedding of the codebook it predicts.
        # lm_heads[i] predicts codebook n_codes_given + i (see forward); with the
        # default n_codes_given == 1 this is the same as tying wtes[i + 1].
        for i in range(self.n_codes_total - config.n_codes_given):
            self.transformer.wtes[config.n_codes_given + i].weight = self.lm_heads[i].weight

    def forward(self, pred_idx, idx):
        """Compute logits for codebook ``pred_idx``.

        Args:
            pred_idx: index of the codebook to predict; must be at least
                ``n_codes_given`` (and greater than 0).
            idx: integer tensor of shape ``(b, t, n_codes_total)`` with
                ``t <= block_size``.

        Returns:
            Logits of shape ``(b, t, output_vocab_size)``.

        Raises:
            AssertionError: if the sequence is too long, ``idx`` has the wrong
                number of codebooks, or ``pred_idx`` refers to a codebook that
                has no output head.
        """
        device = idx.device
        b, t, codes = idx.size()
        assert (
            t <= self.config.block_size
        ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        assert pred_idx > 0, "cannot predict 0th codebook"
        # Heads only exist for codebooks >= n_codes_given; a smaller pred_idx
        # would yield a negative index below and silently pick the wrong head.
        assert pred_idx >= self.config.n_codes_given, (
            f"cannot predict given codebook {pred_idx} "
            f"(n_codes_given={self.config.n_codes_given})"
        )
        assert codes == self.n_codes_total, (b, t, codes)
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # (1, t)

        # Embed each codebook separately and stack along a new last dim:
        # (b, t, n_embd, n_codes_total).
        tok_embs = [
            wte(idx[:, :, i]).unsqueeze(-1) for i, wte in enumerate(self.transformer.wtes)
        ]
        tok_emb = torch.cat(tok_embs, dim=-1)
        pos_emb = self.transformer.wpe(pos)  # (1, t, n_embd)
        # Sum the embeddings of codebooks 0..pred_idx (inclusive).
        x = tok_emb[:, :, :, : pred_idx + 1].sum(dim=-1)
        x = self.transformer.drop(x + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        return self.lm_heads[pred_idx - self.config.n_codes_given](x)

    def get_num_params(self, non_embedding=True):
        """Return the number of parameters in the model.

        ``self.parameters()`` yields each shared tensor once, so weights tied
        between ``transformer.wtes`` and ``lm_heads`` are counted a single time.

        Args:
            non_embedding: when True (default), subtract the position embedding
                and every per-codebook token embedding. Because the token
                embeddings of predicted codebooks are tied to ``lm_heads``, this
                also removes those heads' weights from the count.

        Returns:
            The parameter count as an int.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= sum(wte.weight.numel() for wte in self.transformer.wtes)
            n_params -= self.transformer.wpe.weight.numel()
        return n_params
@dataclass
class FineGPTConfig(GPTConfig):
    """Configuration for ``FineGPT``: ``GPTConfig`` plus codebook counts.

    The ``@dataclass`` decorator is required so the new attributes become real
    dataclass fields, i.e. accepted by the generated ``__init__`` and included
    in ``dataclasses.asdict``.
    """

    # Number of codebooks per position in the model input.
    n_codes_total: int = 8
    # Number of leading codebooks that are only conditioned on (no output
    # head); heads exist for codebooks n_codes_given .. n_codes_total - 1.
    n_codes_given: int = 1