Upload research_model.py
models/research_model.py ADDED (+147, -0)
@@ -0,0 +1,147 @@
import math
from dataclasses import dataclass
from typing import Optional, Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class ModelConfig:
    vocab_size: int = 65536
    n_layer: int = 6
    n_head: int = 8
    n_embd: int = 512
    block_size: int = 512
    dropout: float = 0.1


class PreNormSelfAttention(nn.Module):
    """Causal multi-head self-attention with pre-LayerNorm and a residual connection."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.proj = nn.Linear(n_embd, n_embd, bias=False)
        self.attn_drop = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(n_embd)

        # Lower-triangular causal mask, excluded from the state dict (persistent=False).
        mask = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", mask.view(1, 1, block_size, block_size), persistent=False)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        B, T, C = x.size()
        x_norm = self.ln(x)
        # (B, T, 3, n_head, head_dim) -> (B, n_head, 3, T, head_dim)
        qkv = self.qkv(x_norm).view(B, T, 3, self.n_head, self.head_dim).transpose(1, 3)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_drop(self.proj(y))
        out = x + y
        # Return both the updated residual stream and the branch output,
        # so callers can record per-layer activations.
        return out, y


class PreNormMLP(nn.Module):
    """Pre-LayerNorm feed-forward block (4x expansion, GELU) with a residual connection."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        hidden = 4 * n_embd
        self.ln = nn.LayerNorm(n_embd)
        self.fc1 = nn.Linear(n_embd, hidden)
        self.fc2 = nn.Linear(hidden, n_embd)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x_norm = self.ln(x)
        h = F.gelu(self.fc1(x_norm))
        h = self.drop(h)
        y = self.fc2(h)
        y = self.drop(y)
        out = x + y
        return out, y


class Block(nn.Module):
    """One transformer block: attention followed by MLP, each with its own residual."""

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.attn = PreNormSelfAttention(cfg.n_embd, cfg.n_head, cfg.block_size, cfg.dropout)
        self.mlp = PreNormMLP(cfg.n_embd, cfg.dropout)

    def forward(self, x: torch.Tensor):
        x, attn_out = self.attn(x)
        x, mlp_out = self.mlp(x)
        return x, {"attn": attn_out, "mlp": mlp_out}


class ResearchTransformer(nn.Module):
    """Decoder-only transformer with learned positional embeddings and tied input/output embeddings."""

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.n_embd)
        self.pos_emb = nn.Embedding(cfg.block_size, cfg.n_embd)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layer)])
        self.ln_f = nn.LayerNorm(cfg.n_embd)
        self.lm_head = nn.Linear(cfg.n_embd, cfg.vocab_size, bias=False)
        # Weight tying: the LM head shares the token-embedding matrix.
        self.lm_head.weight = self.tok_emb.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None, return_activations: bool = False):
        # Note: attention_mask is accepted for API compatibility but currently unused;
        # causal masking is handled inside PreNormSelfAttention.
        B, T = input_ids.size()
        assert T <= self.cfg.block_size, f"Input length {T} exceeds block size {self.cfg.block_size}"
        pos = torch.arange(0, T, dtype=torch.long, device=input_ids.device).unsqueeze(0)
        x = self.tok_emb(input_ids) + self.pos_emb(pos)
        x = self.drop(x)

        activations: List[dict] = []
        for blk in self.blocks:
            x, acts = blk(x)
            if return_activations:
                activations.append(acts)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Shift so that position t predicts token t + 1.
            loss = F.cross_entropy(
                logits[:, :-1, :].contiguous().view(-1, logits.size(-1)),
                labels[:, 1:].contiguous().view(-1),
                ignore_index=-100,
            )

        # Lightweight namespace for the return value (logits, loss, optional activations).
        class Output:
            pass

        out = Output()
        out.logits = logits
        out.loss = loss
        if return_activations:
            out.activations = activations
        return out

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 50):
        # Greedy decoding; note that this switches the module to eval mode as a side effect.
        self.eval()
        for _ in range(max_new_tokens):
            # Crop the context to the last block_size tokens if it has grown too long.
            if input_ids.size(1) > self.cfg.block_size:
                input_ids = input_ids[:, -self.cfg.block_size:]
            out = self(input_ids)
            next_token = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=1)
        return input_ids
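The commit ships no usage example. As a quick sanity check, a minimal sketch of how the classes above could be exercised might look like the following; the import path assumes models/ is importable as a package, the small config values and random token ids are placeholders for illustration, and nothing here is part of the uploaded file.

# Minimal usage sketch (not part of the commit): exercises ModelConfig,
# ResearchTransformer.forward with labels/activations, and generate.
import torch

from models.research_model import ModelConfig, ResearchTransformer  # assumed import path

cfg = ModelConfig(vocab_size=1000, n_layer=2, n_head=4, n_embd=128, block_size=64, dropout=0.0)
model = ResearchTransformer(cfg)

# Random token ids stand in for real data; labels equal the inputs,
# since forward() shifts them internally for next-token prediction.
input_ids = torch.randint(0, cfg.vocab_size, (2, 32))
out = model(input_ids, labels=input_ids, return_activations=True)

print(out.logits.shape)                  # torch.Size([2, 32, 1000])
print(out.loss.item())                   # scalar cross-entropy loss
print(len(out.activations))              # one dict per block: {"attn": ..., "mlp": ...}
print(out.activations[0]["attn"].shape)  # torch.Size([2, 32, 128])

# Greedy decoding from a short prompt (generate() puts the model in eval mode).
prompt = torch.randint(0, cfg.vocab_size, (1, 8))
generated = model.generate(prompt, max_new_tokens=16)
print(generated.shape)                   # torch.Size([1, 24])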