rider-provider-777 committed on
Commit c975d73 · verified · 1 Parent(s): 3b26fdb

Upload research_model.py

Files changed (1)
  1. models/research_model.py +147 -0
models/research_model.py ADDED
@@ -0,0 +1,147 @@
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class ModelConfig:
    vocab_size: int = 65536
    n_layer: int = 6
    n_head: int = 8
    n_embd: int = 512
    block_size: int = 512
    dropout: float = 0.1


class PreNormSelfAttention(nn.Module):
    """Causal multi-head self-attention with a pre-norm residual connection."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.proj = nn.Linear(n_embd, n_embd, bias=False)
        self.attn_drop = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(n_embd)

        # Lower-triangular causal mask; not persisted in checkpoints.
        mask = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", mask.view(1, 1, block_size, block_size), persistent=False)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        B, T, C = x.size()
        x_norm = self.ln(x)
        # (B, T, 3C) -> (B, n_head, 3, T, head_dim), then split into q, k, v.
        qkv = self.qkv(x_norm).view(B, T, 3, self.n_head, self.head_dim).transpose(1, 3)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_drop(self.proj(y))
        out = x + y
        # Return both the residual-updated stream and the raw sublayer output.
        return out, y


class PreNormMLP(nn.Module):
    """Position-wise feed-forward block with a pre-norm residual connection."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        hidden = 4 * n_embd
        self.ln = nn.LayerNorm(n_embd)
        self.fc1 = nn.Linear(n_embd, hidden)
        self.fc2 = nn.Linear(hidden, n_embd)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x_norm = self.ln(x)
        h = F.gelu(self.fc1(x_norm))
        h = self.drop(h)
        y = self.fc2(h)
        y = self.drop(y)
        out = x + y
        return out, y


class Block(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.attn = PreNormSelfAttention(cfg.n_embd, cfg.n_head, cfg.block_size, cfg.dropout)
        self.mlp = PreNormMLP(cfg.n_embd, cfg.dropout)

    def forward(self, x: torch.Tensor):
        x, attn_out = self.attn(x)
        x, mlp_out = self.mlp(x)
        return x, {"attn": attn_out, "mlp": mlp_out}


@dataclass
class ModelOutput:
    logits: torch.Tensor
    loss: Optional[torch.Tensor] = None
    activations: Optional[List[dict]] = None


class ResearchTransformer(nn.Module):

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.n_embd)
        self.pos_emb = nn.Embedding(cfg.block_size, cfg.n_embd)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layer)])
        self.ln_f = nn.LayerNorm(cfg.n_embd)
        self.lm_head = nn.Linear(cfg.n_embd, cfg.vocab_size, bias=False)
        # Tie the output projection to the token embedding weights.
        self.lm_head.weight = self.tok_emb.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None, return_activations: bool = False):
        # attention_mask is accepted for API compatibility but unused;
        # causal masking is applied inside each attention block.
        B, T = input_ids.size()
        assert T <= self.cfg.block_size, f"Input length {T} exceeds block size {self.cfg.block_size}"
        pos = torch.arange(0, T, dtype=torch.long, device=input_ids.device).unsqueeze(0)
        x = self.tok_emb(input_ids) + self.pos_emb(pos)
        x = self.drop(x)

        activations = []
        for blk in self.blocks:
            x, acts = blk(x)
            if return_activations:
                activations.append(acts)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Shift so that position t predicts token t + 1.
            loss = F.cross_entropy(
                logits[:, :-1, :].contiguous().view(-1, logits.size(-1)),
                labels[:, 1:].contiguous().view(-1),
                ignore_index=-100,
            )

        return ModelOutput(
            logits=logits,
            loss=loss,
            activations=activations if return_activations else None,
        )

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 50):
        self.eval()
        for _ in range(max_new_tokens):
            # Crop the context to the last block_size tokens if it grows too long.
            if input_ids.size(1) > self.cfg.block_size:
                input_ids = input_ids[:, -self.cfg.block_size:]
            out = self(input_ids)
            # Greedy decoding: take the most likely next token.
            next_token = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=1)
        return input_ids