Neu256 committed
Commit f6a4b47 · verified · 1 Parent(s): be73e4f

Create model.py

Files changed (1)
  1. model.py +249 -0
model.py ADDED
import math

import torch
import torch.nn as nn
from torch.nn import functional as F

from utils import DEVICE


class RMSNorm(nn.Module):
    """Root-mean-square layer normalization: rescales to unit RMS, then applies a learned gain."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Divide by the RMS over the last dimension; eps guards against division by zero.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32 for numerical stability, then cast back to the input dtype.
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
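
A minimal sketch of what RMSNorm does, separate from the file above: with the weight at its initial value of ones, every output vector ends up with a root-mean-square of roughly 1 over the last dimension. It assumes model.py (and its utils dependency) is importable; the sizes are arbitrary.

    # Illustrative only, not part of model.py.
    import torch
    from model import RMSNorm

    norm = RMSNorm(dim=64)
    x = 10.0 * torch.randn(2, 8, 64)      # arbitrary batch, sequence length and scale
    y = norm(x)
    print(y.pow(2).mean(-1).sqrt())       # roughly 1.0 everywhere: inputs are rescaled to unit RMS
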
class Attention(nn.Module):
    """
    Multi-head self-attention with rotary position embeddings (RoPE) and a causal mask.
    """

    def __init__(self, num_heads, head_size, num_embed, dropout):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size

        self.wq = nn.Linear(num_embed, num_heads * head_size, bias=False)
        self.wk = nn.Linear(num_embed, num_heads * head_size, bias=False)
        self.wv = nn.Linear(num_embed, num_heads * head_size, bias=False)
        self.wo = nn.Linear(num_heads * head_size, num_embed, bias=False)

        # RoPE inverse frequencies (base 500000), one per pair of channels.
        inv_freq = 1 / (500000 ** (torch.arange(0, head_size, 2)[: (head_size // 2)].float() / head_size))
        self.register_buffer("inv_freq", inv_freq)

        self.dropout = nn.Dropout(dropout)

    def reshape_for_broadcast(self, freq_cis, x):
        # Reshape (T, head_size // 2) rotation factors to (1, 1, T, head_size // 2)
        # so they broadcast over the batch and head dimensions.
        ndim = x.ndim
        shape = [1] * (ndim - 2) + list(freq_cis.shape)
        return freq_cis.view(*shape)

    def apply_rope(self, x, position, freq):
        # Build complex rotation factors e^(i * t * freq) for positions 0..position-1.
        t = torch.arange(position, device=freq.device, dtype=torch.float32)
        freqs = torch.outer(t, freq)
        freq_cis = torch.polar(torch.ones_like(freqs), freqs)
        # View the last dimension as complex pairs, rotate, then flatten back to real channels.
        x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freq_cis = self.reshape_for_broadcast(freq_cis, x)
        x_out = torch.view_as_real(x_ * freq_cis).flatten(3)
        return x_out.type_as(x)

    def forward(self, x):
        B, T, C = x.shape

        # Causal mask: position i may only attend to positions <= i.
        mask = torch.triu(torch.full((T, T), float("-inf"), device=x.device), diagonal=1)

        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(B, T, self.num_heads, self.head_size)
        xk = xk.view(B, T, self.num_heads, self.head_size)
        xv = xv.view(B, T, self.num_heads, self.head_size)

        # (B, T, H, head_size) -> (B, H, T, head_size)
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        xq = self.apply_rope(xq, T, self.inv_freq)
        xk = self.apply_rope(xk, T, self.inv_freq)

        attn_weights = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_size)
        attn_weights = attn_weights + mask
        attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(xq)
        output = torch.matmul(attn_weights, xv)
        # Concatenate heads: (B, H, T, head_size) -> (B, T, H * head_size).
        output = output.transpose(1, 2).contiguous().view(B, T, self.num_heads * self.head_size)
        return self.dropout(self.wo(output))
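
A quick shape check of the attention layer, again separate from the file and with arbitrary sizes (head_size must be even because RoPE rotates channels in pairs): the output keeps the input's (batch, seq_len, num_embed) shape, since wo projects the concatenated heads back to num_embed.

    # Illustrative only, not part of model.py.
    import torch
    from model import Attention

    attn = Attention(num_heads=4, head_size=16, num_embed=32, dropout=0.0)
    x = torch.randn(2, 10, 32)            # (batch, seq_len, num_embed)
    print(attn(x).shape)                  # torch.Size([2, 10, 32])
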
class MLP(nn.Module):
    """
    Implementation of a feed-forward (MLP) sub-module with a single hidden layer
    and SiLU activation, used inside each Transformer block.
    """

    def __init__(self, num_embed, dropout):
        """
        Constructor for the MLP.

        Args:
            num_embed (int): The number of embedding dimensions.
            dropout (float): Dropout probability applied to the output.
        """
        super().__init__()
        # Hidden width of 4 * num_embed scaled by 2/3, as in LLaMA-style feed-forward layers.
        hidden = int(4 * num_embed * 2 / 3)

        self.w1 = nn.Linear(num_embed, hidden, bias=False)
        self.w2 = nn.Linear(hidden, num_embed, bias=False)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Forward pass of the MLP.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, num_embed).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len, num_embed).
        """
        return self.dropout(self.w2(F.silu(self.w1(x))))
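
For illustration only (separate from the file, arbitrary sizes): the hidden width is int(4 * num_embed * 2 / 3), which for num_embed = 512 comes to 1365, and the output keeps the input shape.

    # Illustrative only, not part of model.py.
    import torch
    from model import MLP

    mlp = MLP(num_embed=512, dropout=0.0)
    print(mlp.w1.out_features)            # 1365 == int(4 * 512 * 2 / 3)
    x = torch.randn(2, 10, 512)
    print(mlp(x).shape)                   # torch.Size([2, 10, 512])
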
class TransformerBlock(nn.Module):
    """
    This class groups multi-head attention and the MLP together
    so the block can be stacked repeatedly inside the Transformer.
    """

    def __init__(self, num_heads, head_size, num_embed, dropout):
        super().__init__()

        self.mha = Attention(
            num_heads=num_heads,
            head_size=head_size,
            num_embed=num_embed,
            dropout=dropout,
        )

        self.mlp = MLP(num_embed=num_embed, dropout=dropout)

        # Pre-normalization: RMSNorm is applied before each sub-layer.
        self.norm1 = RMSNorm(num_embed)
        self.norm2 = RMSNorm(num_embed)

    def forward(self, x):
        """
        Applies one attention + MLP block to the input sequence.

        Args:
            x (torch.Tensor): A tensor of shape (batch_size, sequence_length, embedding_dim).

        Returns:
            torch.Tensor: A tensor of the same shape.
        """
        # Residual connections around the pre-normed attention and MLP sub-layers.
        x = x + self.mha(self.norm1(x))
        x = x + self.mlp(self.norm2(x))

        return x
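
One block composes attention and the MLP behind pre-norm residual connections, so it also preserves the input shape. A small sketch with arbitrary sizes, separate from the file:

    # Illustrative only, not part of model.py.
    import torch
    from model import TransformerBlock

    block = TransformerBlock(num_heads=4, head_size=16, num_embed=32, dropout=0.0)
    x = torch.randn(2, 10, 32)
    print(block(x).shape)                                 # torch.Size([2, 10, 32])
    print(sum(p.numel() for p in block.parameters()))     # linear layers plus the two RMSNorm gains
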
class Transformer(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.model_type = 'Prome'
        self.vocab_size = kwargs.get("vocab_size", 100)
        self.num_embed = kwargs.get("num_embed", 32)
        self.block_size = kwargs.get("block_size", 8)
        self.num_heads = kwargs.get("num_heads", 4)
        self.head_size = kwargs.get("head_size", 128)
        self.num_layers = kwargs.get("num_layers", 4)
        self.dropout = kwargs.get("dropout", 0.2)
        self.max_seq_len = kwargs.get("max_seq_len", 1024)
        # A simple lookup table that maps each token id to its embedding vector.
        # See: https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
        self.token_embedding_table = nn.Embedding(self.vocab_size, self.num_embed)
        # No absolute position embedding table is needed: positional information
        # is injected via RoPE inside the attention layers.

        self.decoder = nn.Sequential(
            *[
                TransformerBlock(
                    num_heads=self.num_heads,
                    head_size=self.head_size,
                    num_embed=self.num_embed,
                    dropout=self.dropout,
                )
                for _ in range(self.num_layers)
            ]
        )

        self.lm_head = nn.Linear(self.num_embed, self.vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are (B, T) tensors of integer token ids;
        # the embedding lookup produces x of shape (B, T, num_embed).
        x = self.token_embedding_table(idx)

        x = self.decoder(x)

        # Project to vocabulary logits: (B, T, vocab_size).
        logits = self.lm_head(x)

        # Compute the loss when targets are provided.
        if targets is not None:
            # cross_entropy expects (batch, num_classes), so flatten the batch and
            # time dimensions: (B * T, vocab_size) logits against (B * T,) targets.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            loss = None

        return logits, loss

    @torch.no_grad()
    def generate(self, idx: torch.Tensor, max_new_tokens: int, temperature: float = 0.6, top_p: float = 0.9):
        for _ in range(max_new_tokens):
            # Condition only on the last max_seq_len tokens.
            idx_crop = idx[:, -self.max_seq_len:]

            logits, _ = self(idx_crop)
            # Keep only the logits for the final position.
            logits = logits[:, -1, :]

            if temperature > 0:
                # Temperature-scaled nucleus (top-p) sampling.
                probs = F.softmax(logits / temperature, dim=-1)
                idx_next = self.sample_top_p(probs, top_p)
            else:
                # Greedy decoding when temperature is zero.
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T + 1)
        return idx

    def sample_top_p(self, probs: torch.Tensor, top_p: float) -> torch.Tensor:
        # Sort probabilities in descending order and compute their cumulative sum.
        sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        # Create a mask for top-p filtering: keep tokens while the cumulative
        # probability of the tokens before them stays within top_p.
        top_p_mask = cumulative_probs <= top_p
        top_p_mask[..., 1:] = top_p_mask[..., :-1].clone()
        top_p_mask[..., 0] = True  # always keep the most probable token

        filtered_probs = sorted_probs * top_p_mask
        filtered_probs /= filtered_probs.sum(dim=-1, keepdim=True)  # Normalize filtered probabilities

        # Sample in sorted space, then map back to the original vocabulary indices.
        next_token = torch.multinomial(filtered_probs, num_samples=1)
        return torch.gather(sorted_indices, -1, next_token)
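
Putting it together, a hedged end-to-end sketch (separate from model.py; hyperparameters are arbitrary and untrained weights produce noise): the forward pass returns logits and, when targets are given, the flattened cross-entropy loss, while generate extends the prompt token by token with temperature and top-p sampling.

    # Illustrative only, not part of model.py.
    import torch
    from model import Transformer

    model = Transformer(vocab_size=100, num_embed=64, num_heads=4, head_size=16,
                        num_layers=2, dropout=0.0, max_seq_len=64)

    # Training-style call: vocabulary logits plus a scalar loss.
    idx = torch.randint(0, 100, (8, 32))        # (batch, seq_len) token ids
    targets = torch.randint(0, 100, (8, 32))    # next-token targets, same shape
    logits, loss = model(idx, targets)
    print(logits.shape, loss.item())            # torch.Size([8, 32, 100]) and a float
    loss.backward()                             # gradients flow into all block parameters

    # Inference: extend a one-token prompt by 20 sampled tokens.
    model.eval()
    prompt = torch.zeros((1, 1), dtype=torch.long)
    out = model.generate(prompt, max_new_tokens=20, temperature=0.6, top_p=0.9)
    print(out.shape)                            # torch.Size([1, 21]): prompt plus 20 new tokens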