TharunSivamani committed on
Commit 4228c6f · verified · 1 Parent(s): 0ed0ae8

model file

Files changed (1)
  1. model.py +270 -0
model.py ADDED
@@ -0,0 +1,270 @@
+ import time
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+ import warnings
+ warnings.simplefilter(action='ignore', category=FutureWarning)
+
+ # hyperparameters
+ batch_size = 8
+ block_size = 2048
+ eval_interval = 500
+ learning_rate = 3e-4
+ device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
+ eval_iters = 200
+ n_embd = 784  # not a multiple of n_head; head_size rounds down to 65, so the 12 heads concatenate to 780 dims before the output projection maps back to 784
+ n_head = 12
+ n_layer = 12
+ dropout = 0.1
+
+ # GPU setup (H100): select the target device and clear any cached allocations
+ if torch.cuda.is_available():
+     torch.cuda.set_device(device)
+     torch.cuda.empty_cache()
+
+ # Mixed precision training setup
+ scaler = torch.cuda.amp.GradScaler()
+
+ torch.manual_seed(1337)
+
+ with open('input.txt', 'r', encoding='utf-8') as f:
+     text = f.read()
+
+ chars = sorted(list(set(text)))
+ vocab_size = 50257  # GPT-2 vocabulary size; the character-level encode/decode below only ever uses the first len(chars) ids
+ stoi = {ch: i for i, ch in enumerate(chars)}
+ itos = {i: ch for i, ch in enumerate(chars)}
+ encode = lambda s: [stoi[c] for c in s]
+ decode = lambda l: ''.join([itos[i] for i in l])
+
+ data = torch.tensor(encode(text), dtype=torch.long)
+ n = int(0.9 * len(data))
+ train_data = data[:n]
+ val_data = data[n:]
+
+ def get_batch(split):
+     data = train_data if split == 'train' else val_data
+     ix = torch.randint(len(data) - block_size, (batch_size,))
+     x = torch.stack([data[i:i + block_size] for i in ix])
+     y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
+     x, y = x.to(device), y.to(device)
+     return x, y
+
+ @torch.no_grad()
+ def estimate_loss():
+     out = {}
+     model.eval()
+     eval_start_time = time.time()
+     for split in ['train', 'val']:
+         losses = torch.zeros(eval_iters)
+         for k in range(eval_iters):
+             X, Y = get_batch(split)
+             with torch.cuda.amp.autocast():
+                 logits, loss = model(X, Y)
+             losses[k] = loss.item()
+         out[split] = losses.mean()
+     eval_time = time.time() - eval_start_time
+     print(f"Evaluation time: {eval_time:.2f} seconds")
+     model.train()
+     return out
+
+ class Head(nn.Module):
+     """ one head of self-attention """
+
+     def __init__(self, head_size):
+         super().__init__()
+         self.key = nn.Linear(n_embd, head_size, bias=False)
+         self.query = nn.Linear(n_embd, head_size, bias=False)
+         self.value = nn.Linear(n_embd, head_size, bias=False)
+         self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
+
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         # input of size (batch, time-step, channels)
+         # output of size (batch, time-step, head size)
+         B,T,C = x.shape
+         k = self.key(x)   # (B,T,hs)
+         q = self.query(x) # (B,T,hs)
+         # compute attention scores ("affinities")
+         wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
+         wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
+         wei = F.softmax(wei, dim=-1) # (B, T, T)
+         wei = self.dropout(wei)
+         # perform the weighted aggregation of the values
+         v = self.value(x) # (B,T,hs)
+         out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
+         return out
+
+ class MultiHeadAttention(nn.Module):
+     """ multiple heads of self-attention in parallel """
+
+     def __init__(self, num_heads, head_size):
+         super().__init__()
+         self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
+         self.proj = nn.Linear(head_size * num_heads, n_embd)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         out = torch.cat([h(x) for h in self.heads], dim=-1)
+         out = self.dropout(self.proj(out))
+         return out
+
+ class FeedFoward(nn.Module):
+     """ a simple linear layer followed by a non-linearity """
+
+     def __init__(self, n_embd):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(n_embd, 4 * n_embd),
+             nn.ReLU(),
+             nn.Linear(4 * n_embd, n_embd),
+             nn.Dropout(dropout),
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+ class Block(nn.Module):
+     """ Transformer block: communication followed by computation """
+
+     def __init__(self, n_embd, n_head):
+         # n_embd: embedding dimension, n_head: the number of heads we'd like
+         super().__init__()
+         head_size = n_embd // n_head
+         self.sa = MultiHeadAttention(n_head, head_size)
+         self.ffwd = FeedFoward(n_embd)
+         self.ln1 = nn.LayerNorm(n_embd)
+         self.ln2 = nn.LayerNorm(n_embd)
+
+     def forward(self, x):
+         x = x + self.sa(self.ln1(x))
+         x = x + self.ffwd(self.ln2(x))
+         return x
+
+ class GPTLanguageModel(nn.Module):
+
+     def __init__(self):
+         super().__init__()
+         # each token directly reads off the logits for the next token from a lookup table
+         self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
+         self.position_embedding_table = nn.Embedding(block_size, n_embd)
+         self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
+         self.ln_f = nn.LayerNorm(n_embd) # final layer norm
+         self.lm_head = nn.Linear(n_embd, vocab_size)
+
+         # better init, not covered in the original GPT video, but important, will cover in followup video
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def forward(self, idx, targets=None):
+         B, T = idx.shape
+
+         # idx and targets are both (B,T) tensor of integers
+         tok_emb = self.token_embedding_table(idx) # (B,T,C)
+         pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
+         x = tok_emb + pos_emb # (B,T,C)
+         x = self.blocks(x) # (B,T,C)
+         x = self.ln_f(x) # (B,T,C)
+         logits = self.lm_head(x) # (B,T,vocab_size)
+
+         if targets is None:
+             loss = None
+         else:
+             B, T, C = logits.shape
+             logits = logits.view(B*T, C)
+             targets = targets.view(B*T)
+             loss = F.cross_entropy(logits, targets)
+
+         return logits, loss
+
+     def generate(self, idx, max_new_tokens):
+         # idx is (B, T) array of indices in the current context
+         for _ in range(max_new_tokens):
+             # crop idx to the last block_size tokens
+             idx_cond = idx[:, -block_size:]
+             # get the predictions
+             logits, loss = self(idx_cond)
+             # focus only on the last time step
+             logits = logits[:, -1, :] # becomes (B, C)
+             # apply softmax to get probabilities
+             probs = F.softmax(logits, dim=-1) # (B, C)
+             # sample from the distribution
+             idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
+             # append sampled index to the running sequence
+             idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
+         return idx
+
+ model = GPTLanguageModel()
+ m = model.to(device)
+ # print the number of parameters in the model
+ print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')
+
+ # optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+
+ # training_start_time = time.time()
+
+ # iter = 0
+ # print("Initializing training...")
+
+ # while True:
+
+ #     # Evaluate losses at evaluation intervals
+ #     if iter % eval_interval == 0:
+ #         losses = estimate_loss()
+ #         print(f"Step {iter}: train loss = {losses['train']:.4f}, val loss = {losses['val']:.4f}")
+
+ #         # Stop training if train loss is below the threshold
+ #         if losses['train'] < 0.099999:
+ #             print(f"Step {iter}: train loss = {losses['train']:.4f}, val loss = {losses['val']:.4f}")
+ #             print("Training Loss is less than 0.099999. Stopping training.")
+
+ #             model_save_path = 'model.pth'
+ #             torch.save(model.state_dict(), model_save_path)
+ #             print(f"Model saved to {model_save_path}")
+
+ #             torch.save(optimizer.state_dict(), 'optimizer.pth')
+ #             print("Optimizer state saved.")
+
+ #             break
+
+ #     # Fetch training batch
+ #     xb, yb = get_batch('train')
+
+ #     # Start iteration timing
+ #     iter_start_time = time.time()
+
+ #     # Forward pass with mixed precision
+ #     with torch.amp.autocast('cuda'):
+ #         logits, loss = model(xb, yb)
+
+ #     # Backward pass and optimization
+ #     optimizer.zero_grad()
+ #     scaler.scale(loss).backward()
+ #     scaler.step(optimizer)
+ #     scaler.update()
+
+ #     # Log every 50 iterations
+ #     if iter % 50 == 0:
+ #         iter_time = time.time() - iter_start_time
+ #         print(f"Iteration {iter}: loss = {loss.item():.4f}, time = {iter_time:.2f} seconds")
+
+ #     # Increment iteration counter
+ #     iter += 1
+
+ # # Log total training time
+ # training_time = time.time() - training_start_time
+ # print(f"Total training time: {training_time:.2f} seconds")
+
+ # Generate text from the model
+ context = torch.zeros((1, 1), dtype=torch.long, device=device)
+ # print("Generated text:")
+ # print(decode(model.generate(context, max_new_tokens=500)[0].tolist()))
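The training loop and the final generation call are committed in commented-out form, so the file as added only builds the model and prints its parameter count. Below is a minimal inference sketch, assuming the commented-out loop was run once so that 'model.pth' exists next to this file; the import executes model.py's module-level setup, so 'input.txt' must be present as well. All names used (GPTLanguageModel, generate, decode, device) come from model.py itself.

# Minimal inference sketch (assumption: the training loop in model.py was
# enabled and run once, so 'model.pth' exists; importing model also runs the
# module-level setup, which reads 'input.txt' and builds a fresh model).
import torch
from model import GPTLanguageModel, decode, device

net = GPTLanguageModel()
net.load_state_dict(torch.load('model.pth', map_location=device))
net.to(device)
net.eval()

# same starting context as the commented-out call at the bottom of model.py
context = torch.zeros((1, 1), dtype=torch.long, device=device)
with torch.no_grad():
    out = net.generate(context, max_new_tokens=500)
# decode is character-level; like the original commented-out call, this assumes
# the trained model only samples token ids below len(chars)
print(decode(out[0].tolist()))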