Nunzio commited on
Commit
fd709e5
·
1 Parent(s): 6fc931a

Delete check_perplexity.ipynb

Browse files
Files changed (1) hide show
  1. check_perplexity.ipynb +0 -691
check_perplexity.ipynb DELETED
@@ -1,691 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "metadata": {},
6
- "source": [
7
- "### Original GPT-J perlexity"
8
- ]
9
- },
10
- {
11
- "cell_type": "code",
12
- "execution_count": 1,
13
- "metadata": {},
14
- "outputs": [],
15
- "source": [
16
- "import torch\n",
17
- "import torch.nn as nn\n",
18
- "import torch.nn.functional as F\n",
19
- "\n",
20
- "from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise\n",
21
- "import transformers\n",
22
- "from tqdm.auto import tqdm\n",
23
- "\n",
24
- "\n",
25
- "\n",
26
- "model_name = \"EleutherAI/gpt-j-6B\"\n",
27
- "gpt = transformers.AutoModelForCausalLM.from_pretrained(model_name)\n",
28
- "tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)"
29
- ]
30
- },
31
- {
32
- "cell_type": "code",
33
- "execution_count": 11,
34
- "metadata": {},
35
- "outputs": [],
36
- "source": [
37
- "device = 'cuda' if torch.cuda.is_available else 'cpu'\n",
38
- "gpt.to(device).train(False);"
39
- ]
40
- },
41
- {
42
- "cell_type": "code",
43
- "execution_count": 4,
44
- "metadata": {},
45
- "outputs": [
46
- {
47
- "name": "stderr",
48
- "output_type": "stream",
49
- "text": [
50
- "Reusing dataset wikitext (/home/jheuristic/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)\n"
51
- ]
52
- },
53
- {
54
- "data": {
55
- "application/vnd.jupyter.widget-view+json": {
56
- "model_id": "47f0459174da4ee2bf064c9ae81fdecd",
57
- "version_major": 2,
58
- "version_minor": 0
59
- },
60
- "text/plain": [
61
- " 0%| | 0/3 [00:00<?, ?it/s]"
62
- ]
63
- },
64
- "metadata": {},
65
- "output_type": "display_data"
66
- }
67
- ],
68
- "source": [
69
- "from datasets import load_dataset\n",
70
- "data = load_dataset('wikitext', 'wikitext-2-v1')['test']"
71
- ]
72
- },
73
- {
74
- "cell_type": "code",
75
- "execution_count": 62,
76
- "metadata": {},
77
- "outputs": [
78
- {
79
- "data": {
80
- "application/vnd.jupyter.widget-view+json": {
81
- "model_id": "26cca02205624aafa740e55542ca2e6c",
82
- "version_major": 2,
83
- "version_minor": 0
84
- },
85
- "text/plain": [
86
- " 0%| | 0/4358 [00:00<?, ?it/s]"
87
- ]
88
- },
89
- "metadata": {},
90
- "output_type": "display_data"
91
- }
92
- ],
93
- "source": [
94
- "\n",
95
- "numerator, denominator = 0, 0\n",
96
- "collator = transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)\n",
97
- "loader = torch.utils.data.DataLoader(data, batch_size=1, num_workers=0, shuffle=False)\n",
98
- "\n",
99
- "\n",
100
- "with torch.no_grad(), torch.cuda.amp.autocast(), tqdm(loader) as progressbar:\n",
101
- " for i, row in enumerate(progressbar):\n",
102
- " if max(map(len, row['text'])) <= 1:\n",
103
- " continue\n",
104
- " batch = tokenizer(**row, truncation=False, return_tensors='pt')\n",
105
- " batch = {k: v.cuda() for k, v in batch.items()}\n",
106
- "\n",
107
- " out = gpt.forward(**batch,)\n",
108
- "\n",
109
- " loss = F.cross_entropy(out.logits[:, :-1, :].flatten(0, -2), batch['input_ids'][:, 1:].flatten(),\n",
110
- " reduction='none')\n",
111
- "\n",
112
- " numerator += loss.sum().item()\n",
113
- " denominator += len(loss)\n",
114
- " progressbar.desc = f\"{numerator/denominator:.3f}\""
115
- ]
116
- },
117
- {
118
- "cell_type": "code",
119
- "execution_count": 63,
120
- "metadata": {},
121
- "outputs": [
122
- {
123
- "data": {
124
- "text/plain": [
125
- "18.435175441788164"
126
- ]
127
- },
128
- "execution_count": 63,
129
- "metadata": {},
130
- "output_type": "execute_result"
131
- }
132
- ],
133
- "source": [
134
- "# test perplexity\n",
135
- "import math\n",
136
- "math.exp(numerator/denominator)"
137
- ]
138
- },
139
- {
140
- "cell_type": "markdown",
141
- "metadata": {},
142
- "source": [
143
- "### Quantized GPT-J Perplexity"
144
- ]
145
- },
146
- {
147
- "cell_type": "code",
148
- "execution_count": 64,
149
- "metadata": {},
150
- "outputs": [],
151
- "source": [
152
- "\n",
153
- "import torch\n",
154
- "import torch.nn as nn\n",
155
- "from torch.cuda.amp import custom_fwd, custom_bwd\n",
156
- " \n",
157
- "from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise\n",
158
- "import transformers\n",
159
- "\n",
160
- "\n",
161
- "class DequantizeAndLinear(torch.autograd.Function):\n",
162
- " \n",
163
- " @staticmethod\n",
164
- " @custom_fwd\n",
165
- " def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,\n",
166
- " absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):\n",
167
- " weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)\n",
168
- " ctx.save_for_backward(input, weights_quantized, absmax, code)\n",
169
- " ctx._has_bias = bias is not None\n",
170
- " return F.linear(input, weights_deq, bias)\n",
171
- " \n",
172
- " @staticmethod\n",
173
- " @custom_bwd\n",
174
- " def backward(ctx, grad_output: torch.Tensor):\n",
175
- " assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]\n",
176
- " input, weights_quantized, absmax, code = ctx.saved_tensors\n",
177
- " # grad_output: [*batch, out_features]\n",
178
- " weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)\n",
179
- " grad_input = grad_output @ weights_deq\n",
180
- " grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None\n",
181
- " return grad_input, None, None, None, grad_bias\n",
182
- "\n",
183
- "\n",
184
- "class FrozenBNBLinear(nn.Module):\n",
185
- " def __init__(self, weight, absmax, code, bias=None):\n",
186
- " assert isinstance(bias, nn.Parameter) or bias is None\n",
187
- " super().__init__()\n",
188
- " self.out_features, self.in_features = weight.shape\n",
189
- " self.register_buffer(\"weight\", weight.requires_grad_(False))\n",
190
- " self.register_buffer(\"absmax\", absmax.requires_grad_(False))\n",
191
- " self.register_buffer(\"code\", code.requires_grad_(False))\n",
192
- " self.bias = bias\n",
193
- " \n",
194
- " def forward(self, input):\n",
195
- " return DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)\n",
196
- " \n",
197
- " @classmethod\n",
198
- " def from_linear(cls, linear: nn.Linear) -> \"FrozenBNBLinear\":\n",
199
- " weights_int8, state = quantize_blockise_lowmemory(linear.weight)\n",
200
- " return cls(weights_int8, *state, linear.bias)\n",
201
- " \n",
202
- " def __repr__(self):\n",
203
- " return f\"{self.__class__.__name__}({self.in_features}, {self.out_features})\"\n",
204
- " \n",
205
- " \n",
206
- "class FrozenBNBEmbedding(nn.Module):\n",
207
- " def __init__(self, weight, absmax, code):\n",
208
- " super().__init__()\n",
209
- " self.num_embeddings, self.embedding_dim = weight.shape\n",
210
- " self.register_buffer(\"weight\", weight.requires_grad_(False))\n",
211
- " self.register_buffer(\"absmax\", absmax.requires_grad_(False))\n",
212
- " self.register_buffer(\"code\", code.requires_grad_(False))\n",
213
- " \n",
214
- " def forward(self, x, **kwargs):\n",
215
- " with torch.no_grad():\n",
216
- " # note: both quantuized weights and input indices are *not* differentiable\n",
217
- " weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)\n",
218
- " return F.embedding(x, weight_deq, **kwargs)\n",
219
- " \n",
220
- " @classmethod\n",
221
- " def from_embedding(cls, embedding: nn.Embedding) -> \"FrozenBNBEmbedding\":\n",
222
- " weights_int8, state = quantize_blockise_lowmemory(embedding.weight)\n",
223
- " return cls(weights_int8, *state)\n",
224
- " \n",
225
- " def __repr__(self):\n",
226
- " return f\"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})\"\n",
227
- " \n",
228
- " \n",
229
- "def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):\n",
230
- " assert chunk_size % 4096 == 0\n",
231
- " code = None\n",
232
- " chunks = []\n",
233
- " absmaxes = []\n",
234
- " flat_tensor = matrix.view(-1)\n",
235
- " for i in range((matrix.numel() - 1) // chunk_size + 1):\n",
236
- " input_chunk = flat_tensor[i * chunk_size: (i + 1) * chunk_size].clone()\n",
237
- " quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)\n",
238
- " chunks.append(quantized_chunk)\n",
239
- " absmaxes.append(absmax_chunk)\n",
240
- " \n",
241
- " matrix_i8 = torch.cat(chunks).reshape_as(matrix)\n",
242
- " absmax = torch.cat(absmaxes)\n",
243
- " return matrix_i8, (absmax, code)\n",
244
- "\n",
245
- "\n",
246
- "def dummify(model, adapter_dim: int = 0):\n",
247
- " for module in list(model.modules()):\n",
248
- " for name, child in module.named_children():\n",
249
- " if isinstance(child, nn.Linear):\n",
250
- " print(name, child)\n",
251
- " setattr(\n",
252
- " module,\n",
253
- " name,\n",
254
- " FrozenBNBLinear(\n",
255
- " weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),\n",
256
- " absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),\n",
257
- " code=torch.zeros(256),\n",
258
- " bias=child.bias,\n",
259
- " ),\n",
260
- " )\n",
261
- " elif isinstance(child, nn.Embedding):\n",
262
- " setattr(\n",
263
- " module,\n",
264
- " name,\n",
265
- " FrozenBNBEmbedding(\n",
266
- " weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),\n",
267
- " absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),\n",
268
- " code=torch.zeros(256),\n",
269
- " )\n",
270
- " ),\n",
271
- "\n",
272
- "\n",
273
- "def bnbfy_(model, adapter_dim: int = 0):\n",
274
- " for module in list(model.modules()):\n",
275
- " for name, child in module.named_children():\n",
276
- " if isinstance(child, nn.Linear):\n",
277
- " print(name, child)\n",
278
- " setattr(module, name, FrozenBNBLinear.from_linear(child))\n",
279
- " \n",
280
- " elif isinstance(child, nn.Embedding):\n",
281
- " print(name, child)\n",
282
- " setattr(module, name, FrozenBNBEmbedding.from_embedding(child))"
283
- ]
284
- },
285
- {
286
- "cell_type": "code",
287
- "execution_count": 66,
288
- "metadata": {},
289
- "outputs": [],
290
- "source": [
291
- "class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):\n",
292
- " def __init__(self, config):\n",
293
- " print(\"MONKEYPATCH BLOCK\")\n",
294
- " super().__init__(config)\n",
295
- "\n",
296
- " dummify(self.attn)\n",
297
- " dummify(self.mlp)\n",
298
- "\n",
299
- "transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock\n",
300
- "\n",
301
- "\n",
302
- "class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):\n",
303
- " def __init__(self, config):\n",
304
- " super().__init__(config)\n",
305
- " dummify(self)\n",
306
- "class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):\n",
307
- " def __init__(self, config):\n",
308
- " super().__init__(config)\n",
309
- " dummify(self)\n"
310
- ]
311
- },
312
- {
313
- "cell_type": "code",
314
- "execution_count": 67,
315
- "metadata": {},
316
- "outputs": [
317
- {
318
- "data": {
319
- "application/vnd.jupyter.widget-view+json": {
320
- "model_id": "1c98b9ebbf8d44d8b0bc422d4bfce21f",
321
- "version_major": 2,
322
- "version_minor": 0
323
- },
324
- "text/plain": [
325
- "Downloading: 0%| | 0.00/0.98k [00:00<?, ?B/s]"
326
- ]
327
- },
328
- "metadata": {},
329
- "output_type": "display_data"
330
- },
331
- {
332
- "data": {
333
- "application/vnd.jupyter.widget-view+json": {
334
- "model_id": "04bc6b612ff146308ec0b63fc15640f8",
335
- "version_major": 2,
336
- "version_minor": 0
337
- },
338
- "text/plain": [
339
- "Downloading: 0%| | 0.00/5.75G [00:00<?, ?B/s]"
340
- ]
341
- },
342
- "metadata": {},
343
- "output_type": "display_data"
344
- },
345
- {
346
- "name": "stdout",
347
- "output_type": "stream",
348
- "text": [
349
- "MONKEYPATCH BLOCK\n",
350
- "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
351
- "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
352
- "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
353
- "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
354
- "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
355
- "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
356
- "MONKEYPATCH BLOCK\n",
357
- "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
358
- "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
359
- "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
360
- "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
361
- "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
362
- "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
363
- "MONKEYPATCH BLOCK\n",
364
- "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
365
- "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
366
- "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
367
- "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
368
- "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
369
- "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
370
- "MONKEYPATCH BLOCK\n",
371
- "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
372
- "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
373
- "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
374
- "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
375
- "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
376
- "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
377
- "MONKEYPATCH BLOCK\n",
378
- "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
379
- "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
380
- "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
381
- "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
382
- "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
383
- "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
384
- "MONKEYPATCH BLOCK\n",
385
- "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
386
- "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
387
- "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
388
- "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
389
- "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
390
- "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
391
- "MONKEYPATCH BLOCK\n",
392
- "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
393
- "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
394
- "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
395
- "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
396
- "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
397
- "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
398
- "MONKEYPATCH BLOCK\n",
399
- "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
400
- "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
401
- "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
402
- "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
403
- "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
404
- "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
405
- "MONKEYPATCH BLOCK\n",
406
- "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
407
- "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
408
- "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
409
- "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
410
- "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
411
- "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
412
- "MONKEYPATCH BLOCK\n",
413
- "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
414
- "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
415
- "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
416
- "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
417
- "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
418
- "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
419
- "MONKEYPATCH BLOCK\n",
420
- "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
421
- "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
422
- "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
423
- "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
424
- "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
425
- "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
426
- "MONKEYPATCH BLOCK\n",
427
- "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
428
- "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
429
- "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
430
- "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
431
- "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
432
- "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
433
- "MONKEYPATCH BLOCK\n",
434
- "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
435
- "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
436
- "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
437
- "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
438
- "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
439
- "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
440
- "MONKEYPATCH BLOCK\n",
441
- "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
442
- "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
443
- "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
444
- "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
445
- "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
446
- "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
447
- "MONKEYPATCH BLOCK\n",
448
- "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
449
- "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
450
- "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
451
- "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
452
- "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
453
- "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
454
- "MONKEYPATCH BLOCK\n",
455
- "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
456
- "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
457
- "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
458
- "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
459
- "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
460
- "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
461
- "MONKEYPATCH BLOCK\n",
462
- "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
463
- "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
464
- "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
465
- "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
466
- "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
467
- "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
468
- "MONKEYPATCH BLOCK\n",
469
- "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
470
- "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
471
- "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
472
- "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
473
- "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
474
- "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
475
- "MONKEYPATCH BLOCK\n",
476
- "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
477
- "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
478
- "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
479
- "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
480
- "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
481
- "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
482
- "MONKEYPATCH BLOCK\n",
483
- "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
484
- "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
485
- "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
486
- "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
487
- "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
488
- "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
489
- "MONKEYPATCH BLOCK\n",
490
- "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
491
- "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
492
- "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
493
- "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
494
- "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
495
- "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
496
- "MONKEYPATCH BLOCK\n"
497
- ]
498
- },
499
- {
500
- "name": "stdout",
501
- "output_type": "stream",
502
- "text": [
503
- "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
504
- "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
505
- "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
506
- "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
507
- "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
508
- "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
509
- "MONKEYPATCH BLOCK\n",
510
- "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
511
- "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
512
- "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
513
- "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
514
- "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
515
- "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
516
- "MONKEYPATCH BLOCK\n",
517
- "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
518
- "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
519
- "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
520
- "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
521
- "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
522
- "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
523
- "MONKEYPATCH BLOCK\n",
524
- "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
525
- "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
526
- "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
527
- "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
528
- "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
529
- "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
530
- "MONKEYPATCH BLOCK\n",
531
- "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
532
- "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
533
- "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
534
- "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
535
- "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
536
- "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
537
- "MONKEYPATCH BLOCK\n",
538
- "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
539
- "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
540
- "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
541
- "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
542
- "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
543
- "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
544
- "MONKEYPATCH BLOCK\n",
545
- "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
546
- "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
547
- "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
548
- "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
549
- "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
550
- "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
551
- "lm_head Linear(in_features=4096, out_features=50400, bias=True)\n"
552
- ]
553
- }
554
- ],
555
- "source": [
556
- "config = transformers.GPTJConfig.from_pretrained(\"EleutherAI/gpt-j-6B\")\n",
557
- "tokenizer = transformers.AutoTokenizer.from_pretrained(\"EleutherAI/gpt-j-6B\")\n",
558
- "gpt = GPTJForCausalLM.from_pretrained(\"hivemind/gpt-j-6B-8bit\", low_cpu_mem_usage=True)"
559
- ]
560
- },
561
- {
562
- "cell_type": "code",
563
- "execution_count": 68,
564
- "metadata": {},
565
- "outputs": [],
566
- "source": [
567
- "device = 'cuda' if torch.cuda.is_available else 'cpu'\n",
568
- "gpt.to(device).train(False);"
569
- ]
570
- },
571
- {
572
- "cell_type": "code",
573
- "execution_count": 69,
574
- "metadata": {},
575
- "outputs": [
576
- {
577
- "name": "stderr",
578
- "output_type": "stream",
579
- "text": [
580
- "Reusing dataset wikitext (/home/jheuristic/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)\n"
581
- ]
582
- },
583
- {
584
- "data": {
585
- "application/vnd.jupyter.widget-view+json": {
586
- "model_id": "bfbf0e20ed194d679d2f877085f679cb",
587
- "version_major": 2,
588
- "version_minor": 0
589
- },
590
- "text/plain": [
591
- " 0%| | 0/3 [00:00<?, ?it/s]"
592
- ]
593
- },
594
- "metadata": {},
595
- "output_type": "display_data"
596
- }
597
- ],
598
- "source": [
599
- "from datasets import load_dataset\n",
600
- "data = load_dataset('wikitext', 'wikitext-2-v1')['test']"
601
- ]
602
- },
603
- {
604
- "cell_type": "code",
605
- "execution_count": 70,
606
- "metadata": {},
607
- "outputs": [
608
- {
609
- "data": {
610
- "application/vnd.jupyter.widget-view+json": {
611
- "model_id": "53d7e76934de4a1498306d49e4f41ad2",
612
- "version_major": 2,
613
- "version_minor": 0
614
- },
615
- "text/plain": [
616
- " 0%| | 0/4358 [00:00<?, ?it/s]"
617
- ]
618
- },
619
- "metadata": {},
620
- "output_type": "display_data"
621
- }
622
- ],
623
- "source": [
624
- "\n",
625
- "numerator, denominator = 0, 0\n",
626
- "collator = transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)\n",
627
- "loader = torch.utils.data.DataLoader(data, batch_size=1, num_workers=0, shuffle=False)\n",
628
- "\n",
629
- "\n",
630
- "with torch.no_grad(), torch.cuda.amp.autocast(), tqdm(loader) as progressbar:\n",
631
- " for i, row in enumerate(progressbar):\n",
632
- " if max(map(len, row['text'])) <= 1:\n",
633
- " continue\n",
634
- " batch = tokenizer(**row, truncation=False, return_tensors='pt')\n",
635
- " batch = {k: v.cuda() for k, v in batch.items()}\n",
636
- "\n",
637
- " out = gpt.forward(**batch,)\n",
638
- "\n",
639
- " loss = F.cross_entropy(out.logits[:, :-1, :].flatten(0, -2), batch['input_ids'][:, 1:].flatten(),\n",
640
- " reduction='none')\n",
641
- "\n",
642
- " numerator += loss.sum().item()\n",
643
- " denominator += len(loss)\n",
644
- " progressbar.desc = f\"{numerator/denominator:.3f}\""
645
- ]
646
- },
647
- {
648
- "cell_type": "code",
649
- "execution_count": 71,
650
- "metadata": {},
651
- "outputs": [
652
- {
653
- "data": {
654
- "text/plain": [
655
- "18.427138288946292"
656
- ]
657
- },
658
- "execution_count": 71,
659
- "metadata": {},
660
- "output_type": "execute_result"
661
- }
662
- ],
663
- "source": [
664
- "# test perplexity\n",
665
- "import math\n",
666
- "math.exp(numerator/denominator)"
667
- ]
668
- }
669
- ],
670
- "metadata": {
671
- "kernelspec": {
672
- "display_name": "py38",
673
- "language": "python",
674
- "name": "py38"
675
- },
676
- "language_info": {
677
- "codemirror_mode": {
678
- "name": "ipython",
679
- "version": 3
680
- },
681
- "file_extension": ".py",
682
- "mimetype": "text/x-python",
683
- "name": "python",
684
- "nbconvert_exporter": "python",
685
- "pygments_lexer": "ipython3",
686
- "version": "3.8.1"
687
- }
688
- },
689
- "nbformat": 4,
690
- "nbformat_minor": 2
691
- }