{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "md-overview",
   "metadata": {},
   "source": [
    "# Fine-tune `openai/whisper-medium.en` on `monadical-labs/acc_dataset_v3`\n",
    "\n",
    "This notebook:\n",
    "\n",
    "1. installs the required packages and loads the dataset,\n",
    "2. uses the dataset's `text_with_digits` column as the target `text`,\n",
    "3. extracts Whisper input features and label ids for every example,\n",
    "4. fine-tunes the model with `Seq2SeqTrainer`, scoring WER on the `validate` split,\n",
    "5. reports WER and CER on the `test` split.\n",
    "\n",
    "It is meant to be run top to bottom on a fresh kernel."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-setup",
   "metadata": {},
   "source": [
    "## Environment setup\n",
    "\n",
    "`jiwer` backs the `wer` / `cer` metrics loaded through `evaluate`, and the `audio` extra of `datasets` provides audio decoding.\n",
    "\n",
    "**TODO:** package versions are not pinned; pin them once a known-good set is established."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7f374d3-4c44-48cb-bba9-e18c099fbe38",
   "metadata": {},
   "outputs": [],
   "source": [
    "!which python"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0ae33f0-52f3-4d4c-88f9-28a458036be8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip (not !pip) installs into the environment of the running kernel; -q keeps the output short.\n",
    "%pip install -q accelerate \"datasets[audio]\" evaluate jiwer torch transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "imports",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from dataclasses import dataclass\n",
    "from typing import Any, Dict, List, Union\n",
    "\n",
    "import evaluate\n",
    "import IPython.display as ipd\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from datasets import ClassLabel, load_dataset\n",
    "from IPython.display import HTML, display\n",
    "from transformers import (\n",
    "    Seq2SeqTrainer,\n",
    "    Seq2SeqTrainingArguments,\n",
    "    WhisperFeatureExtractor,\n",
    "    WhisperForConditionalGeneration,\n",
    "    WhisperProcessor,\n",
    "    WhisperTokenizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-data",
   "metadata": {},
   "source": [
    "## Load and inspect the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8defb5e-962b-49c0-a32f-5f50f0e52f50",
   "metadata": {},
   "outputs": [],
   "source": [
    "acc_dataset = load_dataset(\"monadical-labs/acc_dataset_v3\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47d2aa85-7c2a-488f-abeb-448718571828",
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_random_elements(dataset, num_examples=10):\n",
    "    \"\"\"Display `num_examples` distinct, randomly chosen rows of `dataset` as an HTML table.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if `num_examples` is larger than `len(dataset)`.\n",
    "    \"\"\"\n",
    "    assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n",
    "    # random.sample draws without replacement, so no manual de-duplication loop is needed.\n",
    "    picks = random.sample(range(len(dataset)), num_examples)\n",
    "\n",
    "    df = pd.DataFrame(dataset[picks])\n",
    "    display(HTML(df.to_html()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34743004-7d8c-46b7-81ea-ad448ec450ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_random_elements(acc_dataset[\"train\"].remove_columns([\"audio\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2199ae88-fdcd-48ed-a028-e263f6237494",
   "metadata": {},
   "outputs": [],
   "source": [
    "acc_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3eee3cc1-dbbd-47a5-b053-b052f087e070",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the `text_with_digits` transcription as the training target `text`.\n",
    "for split in acc_dataset:\n",
    "    # Guarded so that re-running this cell is a no-op instead of dropping the\n",
    "    # already-renamed `text` column and then failing on the rename.\n",
    "    if \"text_with_digits\" in acc_dataset[split].column_names:\n",
    "        acc_dataset[split] = (\n",
    "            acc_dataset[split]\n",
    "            .remove_columns([\"text\"])\n",
    "            .rename_column(\"text_with_digits\", \"text\")\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fd7facb-809d-4209-9b93-27a06e2e044f",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_random_elements(acc_dataset[\"train\"].remove_columns([\"audio\"]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-model",
   "metadata": {},
   "source": [
    "## Load the model and processor\n",
    "\n",
    "The cells after the load check that the tokenizer round-trips a transcription and look at a few raw audio samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "471b4745-9398-4f32-a25d-cd5ba5d0150e",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"openai/whisper-medium.en\"\n",
    "\n",
    "model = WhisperForConditionalGeneration.from_pretrained(model_name)\n",
    "# NOTE(review): this is an English-only (`.en`) checkpoint; confirm that passing\n",
    "# `language` / `task` to its processor is intended (see the special tokens printed below).\n",
    "processor = WhisperProcessor.from_pretrained(model_name, language=\"English\", task=\"transcribe\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dafa7e33-4628-426a-863e-3b50b9027929",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenize one transcription and decode it again, with and without special tokens.\n",
    "input_str = acc_dataset[\"train\"][9][\"text\"]\n",
    "labels = processor.tokenizer(input_str).input_ids\n",
    "decoded_with_special = processor.tokenizer.decode(labels, skip_special_tokens=False)\n",
    "decoded_str = processor.tokenizer.decode(labels, skip_special_tokens=True)\n",
    "\n",
    "print(f\"Input: {input_str}\")\n",
    "print(f\"Decoded w/ special: {decoded_with_special}\")\n",
    "print(f\"Decoded w/out special: {decoded_str}\")\n",
    "print(f\"Are equal: {input_str == decoded_str}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1561f97-4d9f-4f17-9b84-da135c55715b",
   "metadata": {},
   "outputs": [],
   "source": [
    "acc_dataset[\"train\"][0][\"audio\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fba3685d-bf62-40e1-bc83-71c4456cc824",
   "metadata": {},
   "outputs": [],
   "source": [
    "# randrange excludes its upper bound; randint(0, len(...)) could return an out-of-range index.\n",
    "rand_int = random.randrange(len(acc_dataset[\"train\"]))\n",
    "\n",
    "print(acc_dataset[\"train\"][rand_int][\"text\"])\n",
    "# Optional playback of the same sample:\n",
    "# ipd.Audio(data=np.asarray(acc_dataset[\"train\"][rand_int][\"audio\"][\"array\"]), autoplay=True, rate=16000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8138249e-a571-4c52-9eb4-3fcdf0c10469",
   "metadata": {},
   "outputs": [],
   "source": [
    "rand_int = random.randrange(len(acc_dataset[\"train\"]))\n",
    "# Fetch the row once instead of re-reading it for every printed field.\n",
    "sample = acc_dataset[\"train\"][rand_int]\n",
    "\n",
    "print(\"Target text:\", sample[\"text\"])\n",
    "print(\"Input array shape:\", np.asarray(sample[\"audio\"][\"array\"]).shape)\n",
    "print(\"Sampling rate:\", sample[\"audio\"][\"sampling_rate\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-features",
   "metadata": {},
   "source": [
    "## Prepare features and labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c49cc8d-c108-46a9-8f6c-6e5bf5d09c1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_dataset(batch):\n",
    "    \"\"\"Add `input_features` and `labels` to one dataset example and return it.\n",
    "\n",
    "    `input_features` comes from the processor's feature extractor applied to the\n",
    "    example's audio; `labels` are the tokenizer ids of the example's `text`.\n",
    "    The map call below does not pass `batched=True`, so `batch` is presumably a\n",
    "    single example despite its name (confirm if the map call changes).\n",
    "    \"\"\"\n",
    "    audio = batch[\"audio\"]\n",
    "\n",
    "    # The feature extractor output is indexed with [0] to un-batch it for this example.\n",
    "    batch[\"input_features\"] = processor.feature_extractor(\n",
    "        audio[\"array\"], sampling_rate=audio[\"sampling_rate\"]\n",
    "    ).input_features[0]\n",
    "\n",
    "    batch[\"labels\"] = processor.tokenizer(batch[\"text\"]).input_ids\n",
    "\n",
    "    return batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3c2946e-a5aa-4572-9717-3ab86878d121",
   "metadata": {},
   "outputs": [],
   "source": [
    "acc_dataset = acc_dataset.map(prepare_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f1f0015-734f-44fa-bce9-1a605df36280",
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class DataCollatorSpeechSeq2SeqWithPadding:\n",
    "    \"\"\"Pad a list of prepared examples into one batch of PyTorch tensors.\"\"\"\n",
    "\n",
    "    # Processor whose feature extractor and tokenizer perform the padding.\n",
    "    processor: Any\n",
    "    # Token id stripped from the front of the labels when every row starts with it.\n",
    "    decoder_start_token_id: int\n",
    "\n",
    "    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n",
    "        # Audio features and label ids are padded separately, each by its own component.\n",
    "        input_features = [{\"input_features\": feature[\"input_features\"]} for feature in features]\n",
    "        batch = self.processor.feature_extractor.pad(input_features, return_tensors=\"pt\")\n",
    "\n",
    "        label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n",
    "        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors=\"pt\")\n",
    "\n",
    "        # Padded positions are replaced with -100 (presumably so the loss ignores them - confirm).\n",
    "        labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n",
    "\n",
    "        # If every sequence already begins with the decoder start token, cut it off.\n",
    "        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():\n",
    "            labels = labels[:, 1:]\n",
    "\n",
    "        batch[\"labels\"] = labels\n",
    "\n",
    "        return batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6519ec7b-dc55-4c37-90f7-2822c40e3e52",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_collator = DataCollatorSpeechSeq2SeqWithPadding(\n",
    "    processor=processor,\n",
    "    decoder_start_token_id=model.config.decoder_start_token_id,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-metrics",
   "metadata": {},
   "source": [
    "## Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f05bd4f0-cb15-4729-a8dd-610baaee6c8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "wer_metric = evaluate.load(\"wer\")\n",
    "cer_metric = evaluate.load(\"cer\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b75cbad2-2487-4cd2-b20d-98dbc5631fa6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_metrics(pred):\n",
    "    \"\"\"Decode predicted and reference ids and return their error rates.\n",
    "\n",
    "    Args:\n",
    "        pred: object exposing `predictions` and `label_ids` arrays of token ids.\n",
    "\n",
    "    Returns:\n",
    "        Dict with the word error rate under \"wer\" and the character error rate under \"cer\".\n",
    "    \"\"\"\n",
    "    pred_ids = pred.predictions\n",
    "\n",
    "    # -100 is the value the data collator writes over padded label positions; map it\n",
    "    # back to the pad token so the ids can be decoded. np.where builds a new array, so\n",
    "    # `pred.label_ids` itself is left untouched.\n",
    "    label_ids = np.where(pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids)\n",
    "\n",
    "    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n",
    "    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
    "\n",
    "    return {\n",
    "        \"wer\": wer_metric.compute(predictions=pred_str, references=label_str),\n",
    "        \"cer\": cer_metric.compute(predictions=pred_str, references=label_str),\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08212258-db86-44d6-b4f3-9fb936ceee85",
   "metadata": {},
   "source": [
    "## Training\n",
    "\n",
    "The training arguments set `push_to_hub=True`, so authenticate with the Hugging Face Hub **before** running the cells below if you haven't already, for example:\n",
    "\n",
    "```python\n",
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adb7eaaa-e18d-4716-af7a-0c5fdc24a95c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): model_name contains a \"/\", so this is the nested path\n",
    "# \"training-artifacts-openai/whisper-medium.en\" - confirm that is intended.\n",
    "dir_for_training_artifacts = \"training-artifacts-\" + model_name\n",
    "\n",
    "eval_step_count = 25\n",
    "max_step_count = 300\n",
    "\n",
    "training_args = Seq2SeqTrainingArguments(\n",
    "    evaluation_strategy=\"steps\",\n",
    "    eval_steps=eval_step_count,\n",
    "    fp16=True,\n",
    "    generation_max_length=225,\n",
    "    gradient_checkpointing=True,\n",
    "    greater_is_better=False,\n",
    "    learning_rate=5e-5,\n",
    "    load_best_model_at_end=True,\n",
    "    logging_steps=eval_step_count,\n",
    "    max_steps=max_step_count,\n",
    "    metric_for_best_model=\"wer\",\n",
    "    output_dir=dir_for_training_artifacts,\n",
    "    per_device_eval_batch_size=4,\n",
    "    per_device_train_batch_size=32,\n",
    "    predict_with_generate=True,\n",
    "    push_to_hub=True,\n",
    "    # Checkpoint at every evaluation. The default save interval (500 steps) is longer\n",
    "    # than max_steps, which would leave load_best_model_at_end without a checkpoint\n",
    "    # per evaluation to choose the best one from.\n",
    "    save_steps=eval_step_count,\n",
    "    # Keep disk usage bounded; the best checkpoint is retained.\n",
    "    save_total_limit=2,\n",
    "    warmup_steps=eval_step_count,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f6db0db-6ce1-4a59-bf49-d8f776fa3a67",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Seq2SeqTrainer(\n",
    "    args=training_args,\n",
    "    model=model,\n",
    "    train_dataset=acc_dataset[\"train\"],\n",
    "    eval_dataset=acc_dataset[\"validate\"],\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics,\n",
    "    tokenizer=processor.feature_extractor,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0940c20f-6d2f-4643-8aa4-ecd2e74f29ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0434e129-05c5-469b-8e2c-1b16bdfd2432",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.push_to_hub()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-evaluation",
   "metadata": {},
   "source": [
    "## Evaluate on the test split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d4f0605-9b4a-46d3-912c-cda97d3a6b9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def map_to_result(batch):\n",
    "    \"\"\"Transcribe one prepared example with `model` and store the text in `batch[\"pred_str\"]`.\"\"\"\n",
    "    with torch.no_grad():\n",
    "        # Put the features on whatever device the model lives on instead of assuming CUDA.\n",
    "        input_values = torch.tensor(batch[\"input_features\"], device=model.device).unsqueeze(0)\n",
    "        predicted_ids = model.generate(input_values)\n",
    "\n",
    "    # Special tokens are dropped so the prediction is comparable with the plain\n",
    "    # reference `text` when WER / CER are computed below.\n",
    "    batch[\"pred_str\"] = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]\n",
    "    return batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "907d2b0c-23c1-4862-900c-a020d7d8b8c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make sure the model is in inference mode before generating.\n",
    "model.eval()\n",
    "\n",
    "# Swap \"test\" for \"validate\" or \"train\" to score another split.\n",
    "results = acc_dataset[\"test\"].map(map_to_result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ee6948b-f8ea-4f39-8f33-0519ba8d8d85",
   "metadata": {},
   "outputs": [],
   "source": [
    "results[\"pred_str\"][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3c3da77-c625-45fc-be34-ec43e2dbd6c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_predictions = results[\"pred_str\"]\n",
    "test_references = results[\"text\"]\n",
    "\n",
    "print(\"WER: {:.3f}\".format(wer_metric.compute(predictions=test_predictions, references=test_references)))\n",
    "print(\"CER: {:.3f}\".format(cer_metric.compute(predictions=test_predictions, references=test_references)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b5f15fa-8099-49fd-9f30-75db02fae4e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_random_elements(results.select_columns([\"text\", \"pred_str\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4d29c7b-9610-4fc5-a30d-9ebffb41dd1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    predicted_ids = model.generate(\n",
    "        torch.tensor(acc_dataset[\"train\"][:1][\"input_features\"], device=model.device)\n",
    "    )\n",
    "\n",
    "print(predicted_ids)\n",
    "\n",
    "# convert ids to tokens; special tokens are kept here on purpose so they can be inspected\n",
    "processor.batch_decode(predicted_ids, skip_special_tokens=False)[0]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}