ilokavat committed on
Commit fedb882 · verified · 1 Parent(s): 1c6bbca

Upload Finetuning-notebook-wav2vec2-large-960h-on-acc-data.ipynb

Jupyter notebook used to create this model by fine-tuning [facebook/wav2vec2-large-960h](https://huggingface.co/facebook/wav2vec2-large-960h) on the [acc_dataset_v2](https://huggingface.co/datasets/monadical-labs/acc_dataset_v2) dataset.
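
For reference, the checkpoint produced by this notebook can be run like any other CTC model in `transformers`. The sketch below is illustrative only: `your-namespace/your-finetuned-model` is a placeholder for wherever the fine-tuned weights are pushed, and the audio is assumed to be 16 kHz mono, matching the preprocessing done in the notebook.

```python
# Minimal inference sketch for the fine-tuned checkpoint.
# NOTE: "your-namespace/your-finetuned-model" is a placeholder, not the actual repo id.
import torch
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

repo_id = "your-namespace/your-finetuned-model"  # placeholder repo id
processor = Wav2Vec2Processor.from_pretrained(repo_id)
model = Wav2Vec2ForCTC.from_pretrained(repo_id).eval()

# One 16 kHz example from the same dataset used for fine-tuning.
sample = load_dataset("monadical-labs/acc_dataset_v2", split="test")[0]["audio"]

inputs = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt")
with torch.no_grad():
    logits = model(inputs.input_values).logits

# Greedy argmax decoding, mirroring the notebook's map_to_result helper.
pred_ids = torch.argmax(logits, dim=-1)
print(processor.batch_decode(pred_ids)[0])
```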

Finetuning-notebook-wav2vec2-large-960h-on-acc-data.ipynb ADDED
@@ -0,0 +1,570 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c7f374d3-4c44-48cb-bba9-e18c099fbe38",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!which python"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f0ae33f0-52f3-4d4c-88f9-28a458036be8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pip_output = !pip install accelerate evaluate torch transformers\n",
+    "#print(pip_output)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c8defb5e-962b-49c0-a32f-5f50f0e52f50",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from datasets import load_dataset\n",
+    "\n",
+    "acc_dataset = load_dataset(\"monadical-labs/acc_dataset_v2\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "58ea0320-19d7-4a98-954d-0d3302060e7a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import re\n",
+    "\n",
+    "acc_dataset = acc_dataset.filter(lambda x: not re.search(r'\\d', x[\"text\"]))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "47d2aa85-7c2a-488f-abeb-448718571828",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from datasets import ClassLabel\n",
+    "import random\n",
+    "import pandas as pd\n",
+    "from IPython.display import display, HTML\n",
+    "\n",
+    "def show_random_elements(dataset, num_examples=10):\n",
+    "    assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n",
+    "    picks = []\n",
+    "    for _ in range(num_examples):\n",
+    "        pick = random.randint(0, len(dataset)-1)\n",
+    "        while pick in picks:\n",
+    "            pick = random.randint(0, len(dataset)-1)\n",
+    "        picks.append(pick)\n",
+    "    \n",
+    "    df = pd.DataFrame(dataset[picks])\n",
+    "    display(HTML(df.to_html()))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "34743004-7d8c-46b7-81ea-ad448ec450ed",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "show_random_elements(acc_dataset[\"train\"].remove_columns([\"audio\"]))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "15beb241-e9d1-4baf-9375-10d1f6824a91",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import re\n",
+    "chars_to_ignore_regex = '[\\\,\\\?\\\.\\\!\\\-\\\;\\\:\\\\\"]'\n",
+    "\n",
+    "def remove_special_characters(batch):\n",
+    "    batch[\"text\"] = re.sub(chars_to_ignore_regex, '', batch[\"text\"]).upper()\n",
+    "    return batch"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6791e858-e2a0-4494-83d5-0b2c30ded226",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "acc_dataset = acc_dataset.map(remove_special_characters)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2199ae88-fdcd-48ed-a028-e263f6237494",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "acc_dataset"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0fd7facb-809d-4209-9b93-27a06e2e044f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "show_random_elements(acc_dataset[\"train\"].remove_columns([\"audio\"]))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "471b4745-9398-4f32-a25d-cd5ba5d0150e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import AutoModelForCTC, Wav2Vec2Processor\n",
+    "\n",
+    "model_repo_name = \"facebook/wav2vec2-large-960h\"\n",
+    "\n",
+    "processor = Wav2Vec2Processor.from_pretrained(model_repo_name)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f1561f97-4d9f-4f17-9b84-da135c55715b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "acc_dataset['train'][0][\"audio\"]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fba3685d-bf62-40e1-bc83-71c4456cc824",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import IPython.display as ipd\n",
+    "import numpy as np\n",
+    "import random\n",
+    "\n",
+    "rand_int = random.randint(0, len(acc_dataset[\"train\"]) - 1)\n",
+    "\n",
+    "print(acc_dataset[\"train\"][rand_int][\"text\"])\n",
+    "ipd.Audio(data=np.asarray(acc_dataset[\"train\"][rand_int][\"audio\"][\"array\"]), autoplay=True, rate=16000)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8138249e-a571-4c52-9eb4-3fcdf0c10469",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "rand_int = random.randint(0, len(acc_dataset[\"train\"]) - 1)\n",
+    "\n",
+    "print(\"Target text:\", acc_dataset[\"train\"][rand_int][\"text\"])\n",
+    "print(\"Input array shape:\", np.asarray(acc_dataset[\"train\"][rand_int][\"audio\"][\"array\"]).shape)\n",
+    "print(\"Sampling rate:\", acc_dataset[\"train\"][rand_int][\"audio\"][\"sampling_rate\"])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0c49cc8d-c108-46a9-8f6c-6e5bf5d09c1c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def prepare_dataset(batch):\n",
+    "    audio = batch[\"audio\"]\n",
+    "\n",
+    "    # batched output is \"un-batched\" to ensure mapping is correct\n",
+    "    batch[\"input_values\"] = processor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"]).input_values[0]\n",
+    "\n",
+    "    batch[\"labels\"] = processor.tokenizer(batch[\"text\"]).input_ids\n",
+    "\n",
+    "    return batch"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9e84b0ac-85bc-4901-b605-0de1f9db716b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "acc_dataset"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d3c2946e-a5aa-4572-9717-3ab86878d121",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "acc_dataset = acc_dataset.map(prepare_dataset, remove_columns=acc_dataset.column_names[\"train\"], num_proc=4)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "300ce7ac-5d0a-40cf-abe9-0c27149ffded",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(acc_dataset[\"train\"][0][\"labels\"])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5f1f0015-734f-44fa-bce9-1a605df36280",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from dataclasses import dataclass, field\n",
+    "import torch\n",
+    "from typing import Any, Dict, List, Optional, Union\n",
+    "\n",
+    "@dataclass\n",
+    "class DataCollatorCTCWithPadding:\n",
+    "    processor: Wav2Vec2Processor\n",
+    "    padding: Union[bool, str] = True\n",
+    "    max_length: Optional[int] = None\n",
+    "    max_length_labels: Optional[int] = None\n",
+    "    pad_to_multiple_of: Optional[int] = None\n",
+    "    pad_to_multiple_of_labels: Optional[int] = None\n",
+    "\n",
+    "    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n",
+    "        input_features = [{\"input_values\": feature[\"input_values\"]} for feature in features]\n",
+    "        label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n",
+    "\n",
+    "        batch = self.processor.pad(\n",
+    "            input_features,\n",
+    "            padding=self.padding,\n",
+    "            max_length=self.max_length,\n",
+    "            pad_to_multiple_of=self.pad_to_multiple_of,\n",
+    "            return_tensors=\"pt\",\n",
+    "        )\n",
+    "        with self.processor.as_target_processor():\n",
+    "            labels_batch = self.processor.pad(\n",
+    "                label_features,\n",
+    "                padding=self.padding,\n",
+    "                max_length=self.max_length_labels,\n",
+    "                pad_to_multiple_of=self.pad_to_multiple_of_labels,\n",
+    "                return_tensors=\"pt\",\n",
+    "            )\n",
+    "\n",
+    "        # replace padding with -100 to ignore loss correctly\n",
+    "        labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n",
+    "\n",
+    "        batch[\"labels\"] = labels\n",
+    "\n",
+    "        return batch"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6519ec7b-dc55-4c37-90f7-2822c40e3e52",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f05bd4f0-cb15-4729-a8dd-610baaee6c8f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import evaluate\n",
+    "\n",
+    "\n",
+    "wer_metric = evaluate.load(\"wer\")\n",
+    "cer_metric = evaluate.load(\"cer\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b75cbad2-2487-4cd2-b20d-98dbc5631fa6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def compute_metrics(pred):\n",
+    "    pred_logits = pred.predictions\n",
+    "    pred_ids = np.argmax(pred_logits, axis=-1)\n",
+    "\n",
+    "    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id\n",
+    "\n",
+    "    pred_str = processor.batch_decode(pred_ids)\n",
+    "    # we do not want to group tokens when computing the metrics\n",
+    "    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)\n",
+    "\n",
+    "    wer = wer_metric.compute(predictions=pred_str, references=label_str)\n",
+    "\n",
+    "    return {\"wer\": wer}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "81cd6a27-032b-46e9-9465-9f12efe0ea0e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import Wav2Vec2ForCTC\n",
+    "\n",
+    "\n",
+    "model = Wav2Vec2ForCTC.from_pretrained(\n",
+    "    model_repo_name,\n",
+    "    ctc_loss_reduction=\"mean\",\n",
+    "    pad_token_id=processor.tokenizer.pad_token_id,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f3a57e26-451e-4eb3-9b5e-0ba789895ff5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model.freeze_feature_extractor()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "adb7eaaa-e18d-4716-af7a-0c5fdc24a95c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import TrainingArguments\n",
+    "\n",
+    "dir_for_training_artifacts = \"training-artifacts-\" + model_repo_name\n",
+    "\n",
+    "\n",
+    "training_args = TrainingArguments(\n",
+    "    eval_steps=50,\n",
+    "    evaluation_strategy=\"steps\",\n",
+    "    fp16=True,\n",
+    "    gradient_checkpointing=True,\n",
+    "    group_by_length=True,\n",
+    "    learning_rate=1e-4,\n",
+    "    logging_steps=50,\n",
+    "    num_train_epochs=128,\n",
+    "    output_dir=dir_for_training_artifacts,\n",
+    "    per_device_train_batch_size=64,\n",
+    "    save_steps=50,\n",
+    "    save_total_limit=2,\n",
+    "    warmup_steps=15,\n",
+    "    weight_decay=0.01,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1f6db0db-6ce1-4a59-bf49-d8f776fa3a67",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import Trainer\n",
+    "\n",
+    "trainer = Trainer(\n",
+    "    model=model,\n",
+    "    data_collator=data_collator,\n",
+    "    args=training_args,\n",
+    "    compute_metrics=compute_metrics,\n",
+    "    train_dataset=acc_dataset[\"train\"],\n",
+    "    eval_dataset=acc_dataset[\"validate\"],\n",
+    "    tokenizer=processor.feature_extractor,\n",
+    ")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "08212258-db86-44d6-b4f3-9fb936ceee85",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Authenticate with HF if you haven't already.\n",
+    "\n",
+    "#from huggingface_hub import notebook_login\n",
+    "\n",
+    "#notebook_login()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0940c20f-6d2f-4643-8aa4-ecd2e74f29ab",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "trainer.train()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0434e129-05c5-469b-8e2c-1b16bdfd2432",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "trainer.push_to_hub()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4d4f0605-9b4a-46d3-912c-cda97d3a6b9e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def map_to_result(batch):\n",
+    "    with torch.no_grad():\n",
+    "        input_values = torch.tensor(batch[\"input_values\"], device=\"cuda\").unsqueeze(0)\n",
+    "        logits = model(input_values).logits\n",
+    "\n",
+    "    pred_ids = torch.argmax(logits, dim=-1)\n",
+    "    batch[\"pred_str\"] = processor.batch_decode(pred_ids)[0]\n",
+    "    batch[\"text\"] = processor.decode(batch[\"labels\"], group_tokens=False)\n",
+    "\n",
+    "    return batch"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "907d2b0c-23c1-4862-900c-a020d7d8b8c0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "results = acc_dataset[\"test\"].map(map_to_result)\n",
+    "#results = acc_dataset[\"validate\"].map(map_to_result)\n",
+    "#results = acc_dataset[\"train\"].map(map_to_result)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "bb25801b-adb7-48e0-9849-473dec2ee765",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import evaluate\n",
+    "\n",
+    "\n",
+    "wer_metric = evaluate.load(\"wer\")\n",
+    "cer_metric = evaluate.load(\"cer\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d3c3da77-c625-45fc-be34-ec43e2dbd6c2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(\"WER: {:.3f}\".format(wer_metric.compute(predictions=results[\"pred_str\"], references=results[\"text\"])))\n",
+    "print(\"CER: {:.3f}\".format(cer_metric.compute(predictions=results[\"pred_str\"], references=results[\"text\"])))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ce444b9a-c222-4c01-a237-32b255a4617d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def show_random_elements(dataset, num_examples=10):\n",
+    "    assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n",
+    "    picks = []\n",
+    "    for _ in range(num_examples):\n",
+    "        pick = random.randint(0, len(dataset)-1)\n",
+    "        while pick in picks:\n",
+    "            pick = random.randint(0, len(dataset)-1)\n",
+    "        picks.append(pick)\n",
+    "    \n",
+    "    df = pd.DataFrame(dataset[picks])\n",
+    "    display(HTML(df.to_html()))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3b5f15fa-8099-49fd-9f30-75db02fae4e1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "show_random_elements(results.select_columns([\"pred_str\", \"text\"]))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b4d29c7b-9610-4fc5-a30d-9ebffb41dd1d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "with torch.no_grad():\n",
+    "    logits = model(torch.tensor(acc_dataset[\"test\"][:1][\"input_values\"], device=\"cuda\")).logits\n",
+    "\n",
+    "pred_ids = torch.argmax(logits, dim=-1)\n",
+    "\n",
+    "# convert ids to tokens\n",
+    "\" \".join(processor.tokenizer.convert_ids_to_tokens(pred_ids[0].tolist()))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d7bbd7a4-3c1c-4950-a44c-08800de24667",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "results.select_columns([\"pred_str\", \"text\"])"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
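
One note on the evaluation code above: the notebook loads both `wer_metric` and `cer_metric`, but the `compute_metrics` function passed to the `Trainer` reports only WER during training, and CER is computed only in the post-training cells. If per-evaluation-step CER is also wanted, a variant along these lines could be swapped in. This is a sketch, not the author's code, and it assumes the notebook's `processor`, `wer_metric`, and `cer_metric` are already defined in scope.

```python
# Sketch of a compute_metrics variant that reports CER alongside WER at each
# evaluation step (the notebook itself computes CER only after training).
import numpy as np

def compute_metrics(pred):
    # Greedy decode the model predictions.
    pred_ids = np.argmax(pred.predictions, axis=-1)

    # Restore padding tokens that were masked out with -100 for the CTC loss.
    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    return {
        "wer": wer_metric.compute(predictions=pred_str, references=label_str),
        "cer": cer_metric.compute(predictions=pred_str, references=label_str),
    }
```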