ilokavat committed · Commit f125260 · verified · 1 parent: f31c073

Upload Finetuning-notebook-whisper-on-acc-data.ipynb

Finetuning-notebook-whisper-on-acc-data.ipynb ADDED
@@ -0,0 +1,507 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c7f374d3-4c44-48cb-bba9-e18c099fbe38",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!which python"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f0ae33f0-52f3-4d4c-88f9-28a458036be8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# datasets and jiwer (needed by load_dataset and the WER/CER metrics below) are also required\n",
+ "pip_output = !pip install accelerate datasets evaluate jiwer torch transformers\n",
+ "#print(pip_output)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c8defb5e-962b-49c0-a32f-5f50f0e52f50",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from datasets import load_dataset\n",
+ "\n",
+ "acc_dataset = load_dataset(\"monadical-labs/acc_dataset_v3\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "47d2aa85-7c2a-488f-abeb-448718571828",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from datasets import ClassLabel\n",
+ "import random\n",
+ "import pandas as pd\n",
+ "from IPython.display import display, HTML\n",
+ "\n",
+ "def show_random_elements(dataset, num_examples=10):\n",
+ "    assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n",
+ "    picks = []\n",
+ "    for _ in range(num_examples):\n",
+ "        pick = random.randint(0, len(dataset)-1)\n",
+ "        while pick in picks:\n",
+ "            pick = random.randint(0, len(dataset)-1)\n",
+ "        picks.append(pick)\n",
+ "\n",
+ "    df = pd.DataFrame(dataset[picks])\n",
+ "    display(HTML(df.to_html()))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "34743004-7d8c-46b7-81ea-ad448ec450ed",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "show_random_elements(acc_dataset[\"train\"].remove_columns([\"audio\"]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2199ae88-fdcd-48ed-a028-e263f6237494",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "acc_dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3eee3cc1-dbbd-47a5-b053-b052f087e070",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# train on the \"text_with_digits\" transcripts instead of the plain \"text\" column\n",
+ "for split in acc_dataset:\n",
+ "    acc_dataset[split] = acc_dataset[split].remove_columns([\"text\"])\n",
+ "    acc_dataset[split] = acc_dataset[split].rename_column(\"text_with_digits\", \"text\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0fd7facb-809d-4209-9b93-27a06e2e044f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "show_random_elements(acc_dataset[\"train\"].remove_columns([\"audio\"]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "471b4745-9398-4f32-a25d-cd5ba5d0150e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers import WhisperFeatureExtractor, WhisperForConditionalGeneration, WhisperProcessor, WhisperTokenizer\n",
+ "\n",
+ "model_name = \"openai/whisper-medium.en\"\n",
+ "\n",
+ "model = WhisperForConditionalGeneration.from_pretrained(model_name)\n",
+ "processor = WhisperProcessor.from_pretrained(model_name, language=\"English\", task=\"transcribe\")"
+ ]
+ },
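+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a1b2c3d4-0001-4ccc-8ddd-eeeeffff0001",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Added sanity check (illustrative, not part of the original run): inspect the\n",
+ "# decoder start token that the data collator defined below strips from labels.\n",
+ "start_id = model.config.decoder_start_token_id\n",
+ "print(start_id, processor.tokenizer.decode([start_id]))"
+ ]
+ },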
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "dafa7e33-4628-426a-863e-3b50b9027929",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "input_str = acc_dataset['train'][9][\"text\"]\n",
+ "labels = processor.tokenizer(input_str).input_ids\n",
+ "decoded_with_special = processor.tokenizer.decode(labels, skip_special_tokens=False)\n",
+ "decoded_str = processor.tokenizer.decode(labels, skip_special_tokens=True)\n",
+ "\n",
+ "print(f\"Input: {input_str}\")\n",
+ "print(f\"Decoded w/ special: {decoded_with_special}\")\n",
+ "print(f\"Decoded w/out special: {decoded_str}\")\n",
+ "print(f\"Are equal: {input_str == decoded_str}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f1561f97-4d9f-4f17-9b84-da135c55715b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "acc_dataset['train'][0][\"audio\"]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "fba3685d-bf62-40e1-bc83-71c4456cc824",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import IPython.display as ipd\n",
+ "import numpy as np\n",
+ "import random\n",
+ "\n",
+ "# randint is inclusive on both ends, so subtract 1 to stay in range\n",
+ "rand_int = random.randint(0, len(acc_dataset[\"train\"]) - 1)\n",
+ "\n",
+ "print(acc_dataset[\"train\"][rand_int][\"text\"])\n",
+ "#ipd.Audio(data=np.asarray(acc_dataset[\"train\"][rand_int][\"audio\"][\"array\"]), autoplay=True, rate=16000)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8138249e-a571-4c52-9eb4-3fcdf0c10469",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "rand_int = random.randint(0, len(acc_dataset[\"train\"]) - 1)\n",
+ "\n",
+ "print(\"Target text:\", acc_dataset[\"train\"][rand_int][\"text\"])\n",
+ "print(\"Input array shape:\", np.asarray(acc_dataset[\"train\"][rand_int][\"audio\"][\"array\"]).shape)\n",
+ "print(\"Sampling rate:\", acc_dataset[\"train\"][rand_int][\"audio\"][\"sampling_rate\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0c49cc8d-c108-46a9-8f6c-6e5bf5d09c1c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def prepare_dataset(batch):\n",
+ "    audio = batch[\"audio\"]\n",
+ "\n",
+ "    # compute log-Mel input features; batched output is \"un-batched\" to ensure mapping is correct\n",
+ "    batch[\"input_features\"] = processor.feature_extractor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"]).input_features[0]\n",
+ "\n",
+ "    # encode the target transcript to label token ids\n",
+ "    batch[\"labels\"] = processor.tokenizer(batch[\"text\"]).input_ids\n",
+ "\n",
+ "    return batch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d3c2946e-a5aa-4572-9717-3ab86878d121",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "acc_dataset = acc_dataset.map(prepare_dataset)"
+ ]
+ },
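+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a1b2c3d4-0002-4ccc-8ddd-eeeeffff0002",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Added sanity check (illustrative, not part of the original run): Whisper's feature\n",
+ "# extractor pads/truncates to 30 s, so each example should be 80 mel bins x 3000 frames.\n",
+ "import numpy as np\n",
+ "print(np.asarray(acc_dataset[\"train\"][0][\"input_features\"]).shape)"
+ ]
+ },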
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5f1f0015-734f-44fa-bce9-1a605df36280",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "from dataclasses import dataclass\n",
+ "from typing import Any, Dict, List, Union\n",
+ "\n",
+ "@dataclass\n",
+ "class DataCollatorSpeechSeq2SeqWithPadding:\n",
+ "    processor: Any\n",
+ "    decoder_start_token_id: int\n",
+ "\n",
+ "    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n",
+ "        # pad the log-Mel input features to a uniform batch tensor\n",
+ "        input_features = [{\"input_features\": feature[\"input_features\"]} for feature in features]\n",
+ "        batch = self.processor.feature_extractor.pad(input_features, return_tensors=\"pt\")\n",
+ "\n",
+ "        # pad the tokenized label sequences separately\n",
+ "        label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n",
+ "        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors=\"pt\")\n",
+ "\n",
+ "        # replace padding with -100 so those positions are ignored by the loss\n",
+ "        labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n",
+ "\n",
+ "        # drop the leading decoder start token; the model re-appends it during training\n",
+ "        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():\n",
+ "            labels = labels[:, 1:]\n",
+ "\n",
+ "        batch[\"labels\"] = labels\n",
+ "\n",
+ "        return batch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6519ec7b-dc55-4c37-90f7-2822c40e3e52",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "data_collator = DataCollatorSpeechSeq2SeqWithPadding(\n",
+ "    processor=processor,\n",
+ "    decoder_start_token_id=model.config.decoder_start_token_id,\n",
+ ")"
+ ]
+ },
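+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a1b2c3d4-0003-4ccc-8ddd-eeeeffff0003",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Added sanity check (illustrative, not part of the original run): collate two mapped\n",
+ "# examples and confirm that padded label positions are masked with -100.\n",
+ "example_batch = data_collator([acc_dataset[\"train\"][i] for i in range(2)])\n",
+ "print(example_batch[\"input_features\"].shape)\n",
+ "print(example_batch[\"labels\"][:, -10:])"
+ ]
+ },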
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f05bd4f0-cb15-4729-a8dd-610baaee6c8f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import evaluate\n",
+ "\n",
+ "wer_metric = evaluate.load(\"wer\")\n",
+ "cer_metric = evaluate.load(\"cer\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b75cbad2-2487-4cd2-b20d-98dbc5631fa6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def compute_metrics(pred):\n",
+ "    pred_ids = pred.predictions\n",
+ "    label_ids = pred.label_ids\n",
+ "\n",
+ "    # restore the pad token so the tokenizer can decode the references\n",
+ "    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id\n",
+ "\n",
+ "    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n",
+ "    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
+ "\n",
+ "    wer = wer_metric.compute(predictions=pred_str, references=label_str)\n",
+ "\n",
+ "    return {\"wer\": wer}"
+ ]
+ },
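+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a1b2c3d4-0004-4ccc-8ddd-eeeeffff0004",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Added toy example (illustrative, not part of the original run): WER counts word-level\n",
+ "# substitutions, insertions, and deletions, so a single inserted word against a\n",
+ "# two-word reference gives WER 0.5.\n",
+ "print(wer_metric.compute(predictions=[\"hello there world\"], references=[\"hello world\"]))"
+ ]
+ },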
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "adb7eaaa-e18d-4716-af7a-0c5fdc24a95c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers import Seq2SeqTrainingArguments\n",
+ "\n",
+ "# use only the checkpoint name: the \"/\" in \"openai/whisper-medium.en\" would\n",
+ "# otherwise turn the output directory into a nested path\n",
+ "dir_for_training_artifacts = \"training-artifacts-\" + model_name.split(\"/\")[-1]\n",
+ "\n",
+ "eval_step_count = 25\n",
+ "max_step_count = 300\n",
+ "\n",
+ "training_args = Seq2SeqTrainingArguments(\n",
+ "    evaluation_strategy=\"steps\",\n",
+ "    eval_steps=eval_step_count,\n",
+ "    fp16=True,\n",
+ "    generation_max_length=225,\n",
+ "    gradient_checkpointing=True,\n",
+ "    greater_is_better=False,\n",
+ "    learning_rate=5e-5,\n",
+ "    load_best_model_at_end=True,\n",
+ "    logging_steps=eval_step_count,\n",
+ "    max_steps=max_step_count,\n",
+ "    metric_for_best_model=\"wer\",\n",
+ "    output_dir=dir_for_training_artifacts,\n",
+ "    per_device_eval_batch_size=4,\n",
+ "    per_device_train_batch_size=32,\n",
+ "    predict_with_generate=True,\n",
+ "    push_to_hub=True,\n",
+ "    warmup_steps=eval_step_count,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1f6db0db-6ce1-4a59-bf49-d8f776fa3a67",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers import Seq2SeqTrainer\n",
+ "\n",
+ "trainer = Seq2SeqTrainer(\n",
+ "    args=training_args,\n",
+ "    model=model,\n",
+ "    train_dataset=acc_dataset[\"train\"],\n",
+ "    eval_dataset=acc_dataset[\"validate\"],\n",
+ "    data_collator=data_collator,\n",
+ "    compute_metrics=compute_metrics,\n",
+ "    # the feature extractor is passed as the \"tokenizer\" so it is saved and pushed with checkpoints\n",
+ "    tokenizer=processor.feature_extractor,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "08212258-db86-44d6-b4f3-9fb936ceee85",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Authenticate with HF if you haven't already.\n",
+ "\n",
+ "#from huggingface_hub import notebook_login\n",
+ "\n",
+ "#notebook_login()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0940c20f-6d2f-4643-8aa4-ecd2e74f29ab",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0434e129-05c5-469b-8e2c-1b16bdfd2432",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainer.push_to_hub()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4d4f0605-9b4a-46d3-912c-cda97d3a6b9e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def map_to_result(batch):\n",
+ "    with torch.no_grad():\n",
+ "        input_values = torch.tensor(batch[\"input_features\"], device=\"cuda\").unsqueeze(0)\n",
+ "        predicted_ids = model.generate(input_values)\n",
+ "\n",
+ "    # skip special tokens so predictions line up with the plain reference text\n",
+ "    batch[\"pred_str\"] = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]\n",
+ "    return batch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "907d2b0c-23c1-4862-900c-a020d7d8b8c0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "results = acc_dataset[\"test\"].map(map_to_result)\n",
+ "#results = acc_dataset[\"validate\"].map(map_to_result)\n",
+ "#results = acc_dataset[\"train\"].map(map_to_result)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bb25801b-adb7-48e0-9849-473dec2ee765",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import evaluate\n",
+ "\n",
+ "wer_metric = evaluate.load(\"wer\")\n",
+ "cer_metric = evaluate.load(\"cer\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8ee6948b-f8ea-4f39-8f33-0519ba8d8d85",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "results[\"pred_str\"][0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d3c3da77-c625-45fc-be34-ec43e2dbd6c2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"WER: {:.3f}\".format(wer_metric.compute(predictions=results[\"pred_str\"], references=results[\"text\"])))\n",
+ "print(\"CER: {:.3f}\".format(cer_metric.compute(predictions=results[\"pred_str\"], references=results[\"text\"])))"
+ ]
+ },
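+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a1b2c3d4-0005-4ccc-8ddd-eeeeffff0005",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Added helper (illustrative, not part of the original run): per-example WER makes\n",
+ "# it easy to spot the worst transcriptions before eyeballing random rows below.\n",
+ "per_example_wer = [\n",
+ "    wer_metric.compute(predictions=[p], references=[r])\n",
+ "    for p, r in zip(results[\"pred_str\"], results[\"text\"])\n",
+ "]\n",
+ "worst = sorted(range(len(per_example_wer)), key=lambda i: -per_example_wer[i])[:5]\n",
+ "for i in worst:\n",
+ "    print(f\"{per_example_wer[i]:.2f} | {results['text'][i]} -> {results['pred_str'][i]}\")"
+ ]
+ },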
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ce444b9a-c222-4c01-a237-32b255a4617d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def show_random_elements(dataset, num_examples=10):\n",
+ "    assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n",
+ "    picks = []\n",
+ "    for _ in range(num_examples):\n",
+ "        pick = random.randint(0, len(dataset)-1)\n",
+ "        while pick in picks:\n",
+ "            pick = random.randint(0, len(dataset)-1)\n",
+ "        picks.append(pick)\n",
+ "\n",
+ "    df = pd.DataFrame(dataset[picks])\n",
+ "    display(HTML(df.to_html()))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3b5f15fa-8099-49fd-9f30-75db02fae4e1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "show_random_elements(results.select_columns([\"text\", \"pred_str\"]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b4d29c7b-9610-4fc5-a30d-9ebffb41dd1d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with torch.no_grad():\n",
+ "    predicted_ids = model.generate(torch.tensor(acc_dataset[\"train\"][:1][\"input_features\"], device=\"cuda\"))\n",
+ "\n",
+ "print(predicted_ids)\n",
+ "\n",
+ "# convert ids to tokens\n",
+ "processor.batch_decode(predicted_ids, skip_special_tokens=False)[0]"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }