{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "c7f374d3-4c44-48cb-bba9-e18c099fbe38", "metadata": {}, "outputs": [], "source": [ "!which python" ] }, { "cell_type": "code", "execution_count": null, "id": "f0ae33f0-52f3-4d4c-88f9-28a458036be8", "metadata": {}, "outputs": [], "source": [ "pip_ouput = !pip install accelerate evaluate torch transformers\n", "#print(pip_ouput)" ] }, { "cell_type": "code", "execution_count": null, "id": "c8defb5e-962b-49c0-a32f-5f50f0e52f50", "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "\n", "acc_dataset = load_dataset(\"monadical-labs/acc_dataset_v2\")" ] }, { "cell_type": "code", "execution_count": null, "id": "58ea0320-19d7-4a98-954d-0d3302060e7a", "metadata": {}, "outputs": [], "source": [ "import re\n", "\n", "acc_dataset = acc_dataset.filter(lambda x: not re.search(r'\\d', x[\"text\"]))" ] }, { "cell_type": "code", "execution_count": null, "id": "47d2aa85-7c2a-488f-abeb-448718571828", "metadata": {}, "outputs": [], "source": [ "from datasets import ClassLabel\n", "import random\n", "import pandas as pd\n", "from IPython.display import display, HTML\n", "\n", "def show_random_elements(dataset, num_examples=10):\n", " assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n", " picks = []\n", " for _ in range(num_examples):\n", " pick = random.randint(0, len(dataset)-1)\n", " while pick in picks:\n", " pick = random.randint(0, len(dataset)-1)\n", " picks.append(pick)\n", " \n", " df = pd.DataFrame(dataset[picks])\n", " display(HTML(df.to_html()))" ] }, { "cell_type": "code", "execution_count": null, "id": "34743004-7d8c-46b7-81ea-ad448ec450ed", "metadata": {}, "outputs": [], "source": [ "show_random_elements(acc_dataset[\"train\"].remove_columns([\"audio\"]))" ] }, { "cell_type": "code", "execution_count": null, "id": "15beb241-e9d1-4baf-9375-10d1f6824a91", "metadata": {}, "outputs": [], "source": [ "import re\n", "chars_to_ignore_regex = '[\\,\\?\\.\\!\\-\\;\\:\\\"]'\n", "\n", "def remove_special_characters(batch):\n", " batch[\"text\"] = re.sub(chars_to_ignore_regex, '', batch[\"text\"]).upper()\n", " return batch" ] }, { "cell_type": "code", "execution_count": null, "id": "6791e858-e2a0-4494-83d5-0b2c30ded226", "metadata": {}, "outputs": [], "source": [ "acc_dataset = acc_dataset.map(remove_special_characters)" ] }, { "cell_type": "code", "execution_count": null, "id": "2199ae88-fdcd-48ed-a028-e263f6237494", "metadata": {}, "outputs": [], "source": [ "acc_dataset" ] }, { "cell_type": "code", "execution_count": null, "id": "0fd7facb-809d-4209-9b93-27a06e2e044f", "metadata": {}, "outputs": [], "source": [ "show_random_elements(acc_dataset[\"train\"].remove_columns([\"audio\"]))" ] }, { "cell_type": "code", "execution_count": null, "id": "471b4745-9398-4f32-a25d-cd5ba5d0150e", "metadata": {}, "outputs": [], "source": [ "from transformers import AutoModelForCTC, Wav2Vec2Processor\n", "\n", "model_repo_name = \"facebook/wav2vec2-large-960h\"\n", "\n", "processor = Wav2Vec2Processor.from_pretrained(model_repo_name)" ] }, { "cell_type": "code", "execution_count": null, "id": "f1561f97-4d9f-4f17-9b84-da135c55715b", "metadata": {}, "outputs": [], "source": [ "acc_dataset['train'][0][\"audio\"]" ] }, { "cell_type": "code", "execution_count": null, "id": "fba3685d-bf62-40e1-bc83-71c4456cc824", "metadata": {}, "outputs": [], "source": [ "import IPython.display as ipd\n", "import numpy as np\n", "import random\n", "\n", "rand_int = random.randint(0, 
len(acc_dataset[\"train\"]))\n", "\n", "print(acc_dataset[\"train\"][rand_int][\"text\"])\n", "ipd.Audio(data=np.asarray(acc_dataset[\"train\"][rand_int][\"audio\"][\"array\"]), autoplay=True, rate=16000)" ] }, { "cell_type": "code", "execution_count": null, "id": "8138249e-a571-4c52-9eb4-3fcdf0c10469", "metadata": {}, "outputs": [], "source": [ "rand_int = random.randint(0, len(acc_dataset[\"train\"]))\n", "\n", "print(\"Target text:\", acc_dataset[\"train\"][rand_int][\"text\"])\n", "print(\"Input array shape:\", np.asarray(acc_dataset[\"train\"][rand_int][\"audio\"][\"array\"]).shape)\n", "print(\"Sampling rate:\", acc_dataset[\"train\"][rand_int][\"audio\"][\"sampling_rate\"])" ] }, { "cell_type": "code", "execution_count": null, "id": "0c49cc8d-c108-46a9-8f6c-6e5bf5d09c1c", "metadata": {}, "outputs": [], "source": [ "def prepare_dataset(batch):\n", " audio = batch[\"audio\"]\n", "\n", " # batched output is \"un-batched\" to ensure mapping is correct\n", " batch[\"input_values\"] = processor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"]).input_values[0]\n", " \n", " batch[\"labels\"] = processor.tokenizer(batch[\"text\"]).input_ids\n", " \n", " return batch" ] }, { "cell_type": "code", "execution_count": null, "id": "9e84b0ac-85bc-4901-b605-0de1f9db716b", "metadata": {}, "outputs": [], "source": [ "acc_dataset" ] }, { "cell_type": "code", "execution_count": null, "id": "d3c2946e-a5aa-4572-9717-3ab86878d121", "metadata": {}, "outputs": [], "source": [ "acc_dataset = acc_dataset.map(prepare_dataset, remove_columns=acc_dataset.column_names[\"train\"], num_proc=4)" ] }, { "cell_type": "code", "execution_count": null, "id": "300ce7ac-5d0a-40cf-abe9-0c27149ffded", "metadata": {}, "outputs": [], "source": [ "print(acc_dataset[\"train\"][0][\"labels\"])" ] }, { "cell_type": "code", "execution_count": null, "id": "5f1f0015-734f-44fa-bce9-1a605df36280", "metadata": {}, "outputs": [], "source": [ "from dataclasses import dataclass, field\n", "import torch\n", "from typing import Any, Dict, List, Optional, Union\n", "\n", "@dataclass\n", "class DataCollatorCTCWithPadding:\n", " processor: Wav2Vec2Processor\n", " padding: Union[bool, str] = True\n", " max_length: Optional[int] = None\n", " max_length_labels: Optional[int] = None\n", " pad_to_multiple_of: Optional[int] = None\n", " pad_to_multiple_of_labels: Optional[int] = None\n", "\n", " def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n", " input_features = [{\"input_values\": feature[\"input_values\"]} for feature in features]\n", " label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n", "\n", " batch = self.processor.pad(\n", " input_features,\n", " padding=self.padding,\n", " max_length=self.max_length,\n", " pad_to_multiple_of=self.pad_to_multiple_of,\n", " return_tensors=\"pt\",\n", " )\n", " with self.processor.as_target_processor():\n", " labels_batch = self.processor.pad(\n", " label_features,\n", " padding=self.padding,\n", " max_length=self.max_length_labels,\n", " pad_to_multiple_of=self.pad_to_multiple_of_labels,\n", " return_tensors=\"pt\",\n", " )\n", "\n", " # replace padding with -100 to ignore loss correctly\n", " labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n", "\n", " batch[\"labels\"] = labels\n", "\n", " return batch" ] }, { "cell_type": "code", "execution_count": null, "id": "6519ec7b-dc55-4c37-90f7-2822c40e3e52", "metadata": {}, "outputs": [], "source": [ "data_collator = 
{ "cell_type": "code", "execution_count": null, "id": "f05bd4f0-cb15-4729-a8dd-610baaee6c8f", "metadata": {}, "outputs": [], "source": [ "import evaluate\n", "\n", "wer_metric = evaluate.load(\"wer\")\n", "cer_metric = evaluate.load(\"cer\")" ] }, { "cell_type": "code", "execution_count": null, "id": "b75cbad2-2487-4cd2-b20d-98dbc5631fa6", "metadata": {}, "outputs": [], "source": [ "def compute_metrics(pred):\n", "    pred_logits = pred.predictions\n", "    pred_ids = np.argmax(pred_logits, axis=-1)\n", "\n", "    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id\n", "\n", "    pred_str = processor.batch_decode(pred_ids)\n", "    # we do not want to group tokens when computing the metrics\n", "    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)\n", "\n", "    wer = wer_metric.compute(predictions=pred_str, references=label_str)\n", "    cer = cer_metric.compute(predictions=pred_str, references=label_str)\n", "\n", "    return {\"wer\": wer, \"cer\": cer}" ] }, { "cell_type": "code", "execution_count": null, "id": "81cd6a27-032b-46e9-9465-9f12efe0ea0e", "metadata": {}, "outputs": [], "source": [ "from transformers import Wav2Vec2ForCTC\n", "\n", "model = Wav2Vec2ForCTC.from_pretrained(\n", "    model_repo_name,\n", "    ctc_loss_reduction=\"mean\",\n", "    pad_token_id=processor.tokenizer.pad_token_id,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "f3a57e26-451e-4eb3-9b5e-0ba789895ff5", "metadata": {}, "outputs": [], "source": [ "# Freeze the convolutional feature encoder; only the transformer layers are\n", "# fine-tuned. freeze_feature_encoder() replaces the deprecated\n", "# freeze_feature_extractor().\n", "model.freeze_feature_encoder()" ] }, { "cell_type": "code", "execution_count": null, "id": "adb7eaaa-e18d-4716-af7a-0c5fdc24a95c", "metadata": {}, "outputs": [], "source": [ "from transformers import TrainingArguments\n", "\n", "# use only the final path component of the repo name so the output directory\n", "# doesn't contain a slash\n", "dir_for_training_artifacts = \"training-artifacts-\" + model_repo_name.split(\"/\")[-1]\n", "\n", "training_args = TrainingArguments(\n", "    eval_steps=50,\n", "    evaluation_strategy=\"steps\",\n", "    fp16=True,\n", "    gradient_checkpointing=True,\n", "    group_by_length=True,\n", "    learning_rate=1e-4,\n", "    logging_steps=50,\n", "    num_train_epochs=128,\n", "    output_dir=dir_for_training_artifacts,\n", "    per_device_train_batch_size=64,\n", "    save_steps=50,\n", "    save_total_limit=2,\n", "    warmup_steps=15,\n", "    weight_decay=0.01,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "1f6db0db-6ce1-4a59-bf49-d8f776fa3a67", "metadata": {}, "outputs": [], "source": [ "from transformers import Trainer\n", "\n", "trainer = Trainer(\n", "    model=model,\n", "    data_collator=data_collator,\n", "    args=training_args,\n", "    compute_metrics=compute_metrics,\n", "    train_dataset=acc_dataset[\"train\"],\n", "    eval_dataset=acc_dataset[\"validate\"],\n", "    # pass the feature extractor so it is saved alongside model checkpoints\n", "    tokenizer=processor.feature_extractor,\n", ")" ] },
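{ "cell_type": "code", "execution_count": null, "id": "7f2e9b1c-4a6d-4e8f-b3c5-d9a0e1f2a3b4", "metadata": {}, "outputs": [], "source": [ "# Optional (added sketch): evaluate the pretrained checkpoint before fine-tuning\n", "# to get a baseline WER on the validation split for comparison.\n", "\n", "#trainer.evaluate()" ] },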
\n", "\n", "#from huggingface_hub import notebook_login\n", "\n", "#notebook_login()" ] }, { "cell_type": "code", "execution_count": null, "id": "0940c20f-6d2f-4643-8aa4-ecd2e74f29ab", "metadata": {}, "outputs": [], "source": [ "trainer.train()" ] }, { "cell_type": "code", "execution_count": null, "id": "0434e129-05c5-469b-8e2c-1b16bdfd2432", "metadata": {}, "outputs": [], "source": [ "trainer.push_to_hub()" ] }, { "cell_type": "code", "execution_count": null, "id": "4d4f0605-9b4a-46d3-912c-cda97d3a6b9e", "metadata": {}, "outputs": [], "source": [ "def map_to_result(batch):\n", " with torch.no_grad():\n", " input_values = torch.tensor(batch[\"input_values\"], device=\"cuda\").unsqueeze(0)\n", " logits = model(input_values).logits\n", "\n", " pred_ids = torch.argmax(logits, dim=-1)\n", " batch[\"pred_str\"] = processor.batch_decode(pred_ids)[0]\n", " batch[\"text\"] = processor.decode(batch[\"labels\"], group_tokens=False)\n", " \n", " return batch" ] }, { "cell_type": "code", "execution_count": null, "id": "907d2b0c-23c1-4862-900c-a020d7d8b8c0", "metadata": {}, "outputs": [], "source": [ "results = acc_dataset[\"test\"].map(map_to_result)\n", "#results = acc_dataset[\"validate\"].map(map_to_result)\n", "#results = acc_dataset[\"train\"].map(map_to_result)" ] }, { "cell_type": "code", "execution_count": null, "id": "bb25801b-adb7-48e0-9849-473dec2ee765", "metadata": {}, "outputs": [], "source": [ "import evaluate \n", "\n", "\n", "wer_metric = evaluate.load(\"wer\")\n", "cer_metric = evaluate.load(\"cer\")" ] }, { "cell_type": "code", "execution_count": null, "id": "d3c3da77-c625-45fc-be34-ec43e2dbd6c2", "metadata": {}, "outputs": [], "source": [ "print(\"WER: {:.3f}\".format(wer_metric.compute(predictions=results[\"pred_str\"], references=results[\"text\"])))\n", "print(\"CER: {:.3f}\".format(cer_metric.compute(predictions=results[\"pred_str\"], references=results[\"text\"])))" ] }, { "cell_type": "code", "execution_count": null, "id": "ce444b9a-c222-4c01-a237-32b255a4617d", "metadata": {}, "outputs": [], "source": [ "def show_random_elements(dataset, num_examples=10):\n", " assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n", " picks = []\n", " for _ in range(num_examples):\n", " pick = random.randint(0, len(dataset)-1)\n", " while pick in picks:\n", " pick = random.randint(0, len(dataset)-1)\n", " picks.append(pick)\n", " \n", " df = pd.DataFrame(dataset[picks])\n", " display(HTML(df.to_html()))" ] }, { "cell_type": "code", "execution_count": null, "id": "3b5f15fa-8099-49fd-9f30-75db02fae4e1", "metadata": {}, "outputs": [], "source": [ "show_random_elements(results.select_columns([\"pred_str\", \"text\"]))" ] }, { "cell_type": "code", "execution_count": null, "id": "b4d29c7b-9610-4fc5-a30d-9ebffb41dd1d", "metadata": {}, "outputs": [], "source": [ "with torch.no_grad():\n", " logits = model(torch.tensor(acc_dataset[\"test\"][:1][\"input_values\"], device=\"cuda\")).logits\n", "\n", "pred_ids = torch.argmax(logits, dim=-1)\n", "\n", "# convert ids to tokens\n", "\" \".join(processor.tokenizer.convert_ids_to_tokens(pred_ids[0].tolist()))" ] }, { "cell_type": "code", "execution_count": null, "id": "d7bbd7a4-3c1c-4950-a44c-08800de24667", "metadata": {}, "outputs": [], "source": [ "results.select_columns([\"pred_str\", \"text\"])" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", 
"mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }