{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Qpw04rkbynx0" }, "source": [ "To run this notebook, go to \"*Runtime*\" and press \"*Run all*\" on a **free** Tesla T4 Google Colab instance!\n", "
\n", "\n", "To install Unsloth on your own computer, follow the installation instructions on our GitHub page [here](https://docs.unsloth.ai/get-started/installing-+-updating).\n", "\n", "You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), and how to [save it](#Save).\n" ] }, { "cell_type": "markdown", "metadata": { "id": "5fs-yYEaynx1" }, "source": [ "### News" ] }, { "cell_type": "markdown", "metadata": { "id": "pyJK0UZaynx2" }, "source": [ "Unsloth now supports Text-to-Speech (TTS) models. Read our [guide here](https://docs.unsloth.ai/basics/text-to-speech-tts-fine-tuning).\n", "\n", "Read our **[Gemma 3N Guide](https://docs.unsloth.ai/basics/gemma-3n-how-to-run-and-fine-tune)** and check out our new **[Dynamic 2.0](https://docs.unsloth.ai/basics/unsloth-dynamic-2.0-ggufs)** quants, which outperform other quantization methods!\n", "\n", "Visit our docs for all our [model uploads](https://docs.unsloth.ai/get-started/all-our-models) and [notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks).\n" ] }, { "cell_type": "markdown", "metadata": { "id": "SDUHv0mwynx3" }, "source": [ "### Installation" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "MY4G3EIbynx3" }, "outputs": [], "source": [ "%%capture\n", "import os\n", "if \"COLAB_\" not in \"\".join(os.environ.keys()):\n", " %pip install unsloth\n", "else:\n", " # Do this only in Colab notebooks! 
Otherwise use pip install unsloth\n", " %pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo\n", " %pip install sentencepiece protobuf \"datasets>=3.4.1,<4.0.0\" \"huggingface_hub>=0.34.0\" hf_transfer\n", " %pip install --no-deps unsloth\n", "!git clone https://github.com/SparkAudio/Spark-TTS\n", "%pip install omegaconf einx torchaudio" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "QmUBVEnvCDJv", "outputId": "42083a68-d3cc-48c9-d852-b60796377434" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n", "🦥 Unsloth Zoo will now patch everything to make training faster!\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "==((====))== Unsloth 2025.8.1: Fast Qwen2 patching. Transformers: 4.55.0.\n", " \\\\ /| NVIDIA GeForce RTX 2080 SUPER. Num GPUs = 2. Max memory: 7.785 GB. Platform: Linux.\n", "O^O/ \\_/ \\ Torch: 2.7.1+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.3.1\n", "\\ / Bfloat16 = FALSE. FA [Xformers = 0.0.31.post1. FA2 = False]\n", " \"-____-\" Free license: http://github.com/unslothai/unsloth\n", "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n", "Unsloth: Float16 full finetuning uses more memory since we upcast weights to float32.\n" ] } ], "source": [ "from unsloth import FastModel\n", "import torch\n", "from huggingface_hub import snapshot_download\n", "\n", "max_seq_length = 2048 # Choose any for long context!\n", "\n", "fourbit_models = [\n", " # 4bit dynamic quants for superior accuracy and low memory use\n", " \"unsloth/gemma-3-4b-it-unsloth-bnb-4bit\",\n", " \"unsloth/gemma-3-12b-it-unsloth-bnb-4bit\",\n", " \"unsloth/gemma-3-27b-it-unsloth-bnb-4bit\",\n", " # Qwen3 new models\n", " \"unsloth/Qwen3-4B-unsloth-bnb-4bit\",\n", " \"unsloth/Qwen3-8B-unsloth-bnb-4bit\",\n", " # Other very popular models!\n", " \"unsloth/Llama-3.1-8B\",\n", " \"unsloth/Llama-3.2-3B\",\n", " \"unsloth/Llama-3.3-70B\",\n", " \"unsloth/mistral-7b-instruct-v0.3\",\n", " \"unsloth/Phi-4\",\n", "] # More models at https://huggingface.co/unsloth\n", "\n", "# Download model and code\n", "snapshot_download(\"unsloth/Spark-TTS-0.5B\", local_dir = \"Spark-TTS-0.5B\")\n", "\n", "model, tokenizer = FastModel.from_pretrained(\n", " model_name = \"Spark-TTS-0.5B/LLM\",\n", " max_seq_length = max_seq_length,\n", " dtype = torch.float32, # Spark seems to only work on float32 for now\n", " full_finetuning = True, # We support full finetuning now!\n", " load_in_4bit = False,\n", " #token = \"hf_...\", # use one if using gated models like meta-llama/Llama-2-7b-hf\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "SXd9bTZd1aaL" }, "source": [ "We now add LoRA adapters so we only need to update 1 to 10% of all parameters! (Since `full_finetuning = True` was set above, Unsloth skips this step and trains all parameters instead.)"
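, "\n", "\n",
"As a quick sanity check after this cell, you can print what fraction of the parameters is actually trainable (a minimal sketch; `model` is the model loaded above):\n",
"\n",
"```python\n",
"trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"total = sum(p.numel() for p in model.parameters())\n",
"print(f\"Trainable: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)\")\n",
"```"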
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "6bZsfBuZDeCL", "outputId": "292447b8-fd80-4b8b-ba3f-4637a1045166" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Unsloth: Full finetuning is enabled, so .get_peft_model has no effect\n" ] } ], "source": [ "# Note: LoRA does not work with float32, only with bfloat16, so this call is a no-op here.\n", "model = FastModel.get_peft_model(\n", " model,\n", " r = 128, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n", " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n", " \"gate_proj\", \"up_proj\", \"down_proj\",],\n", " lora_alpha = 128,\n", " lora_dropout = 0, # Supports any, but = 0 is optimized\n", " bias = \"none\", # Supports any, but = \"none\" is optimized\n", " # [NEW] \"unsloth\" uses 30% less VRAM, fits 2x larger batch sizes!\n", " use_gradient_checkpointing = \"unsloth\", # True or \"unsloth\" for very long context\n", " random_state = 3407,\n", " use_rslora = False, # We support rank stabilized LoRA\n", " loftq_config = None, # And LoftQ\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "vITh0KVJ10qX" }, "source": [ "\n", "### Data Prep\n", "\n", "We will use the `Balaji-1904/TTS_KN_DS_V1.1` dataset, which is designed for training TTS models. Ensure that your dataset follows the required format: **text, audio** for single-speaker models or **source, text, audio** for multi-speaker models. You can modify this section to accommodate your own dataset, but maintaining the correct structure is essential for optimal training."
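, "\n", "\n",
"For reference, a single dataset row is expected to look roughly like this (a hypothetical example; the values are illustrative, and real `datasets` rows hold the audio samples as a NumPy array):\n",
"\n",
"```python\n",
"row = {\n",
"    \"text\": \"Text the model should learn to speak.\",\n",
"    \"audio\": {\n",
"        \"array\": [0.0] * 16000,   # one second of mono audio samples\n",
"        \"sampling_rate\": 16000,\n",
"    },\n",
"    # \"source\": \"speaker_1\",    # optional, multi-speaker datasets only\n",
"}\n",
"```"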
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "LjY75GoYUCB8" }, "outputs": [], "source": [ "from datasets import load_dataset\n", "dataset = load_dataset(\"Balaji-1904/TTS_KN_DS_V1.1\", split = \"train\")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "zK94B-Pfioto", "outputId": "3f11cf35-c173-410d-f709-43552323f26f" }, "outputs": [], "source": [ "#@title Tokenization Function\n", "\n", "import locale\n", "import torchaudio.transforms as T\n", "import os\n", "import torch\n", "import sys\n", "import numpy as np\n", "sys.path.append('Spark-TTS')\n", "from sparktts.models.audio_tokenizer import BiCodecTokenizer\n", "from sparktts.utils.audio import audio_volume_normalize\n", "\n", "audio_tokenizer = BiCodecTokenizer(\"Spark-TTS-0.5B\", \"cuda\")\n", "\n", "def extract_wav2vec2_features(wavs: torch.Tensor) -> torch.Tensor:\n", " \"\"\"Extract wav2vec2 features for a batch of exactly one waveform.\"\"\"\n", " if wavs.shape[0] != 1:\n", " raise ValueError(f\"Expected batch size 1, but got shape {wavs.shape}\")\n", " wav_np = wavs.squeeze(0).cpu().numpy()\n", "\n", " processed = audio_tokenizer.processor(\n", " wav_np,\n", " sampling_rate=16000,\n", " return_tensors=\"pt\",\n", " padding=True,\n", " )\n", " input_values = processed.input_values\n", " input_values = input_values.to(audio_tokenizer.feature_extractor.device)\n", "\n", " model_output = audio_tokenizer.feature_extractor(\n", " input_values,\n", " )\n", "\n", " if model_output.hidden_states is None:\n", " raise ValueError(\"Wav2Vec2Model did not return hidden states. 
Ensure config `output_hidden_states=True`.\")\n", "\n", " num_layers = len(model_output.hidden_states)\n", " required_layers = [11, 14, 16]\n", " if any(l >= num_layers for l in required_layers):\n", " raise IndexError(f\"Requested hidden state indices {required_layers} out of range for model with {num_layers} layers.\")\n", "\n", " feats_mix = (\n", " model_output.hidden_states[11] + model_output.hidden_states[14] + model_output.hidden_states[16]\n", " ) / 3\n", "\n", " return feats_mix\n", "def formatting_audio_func(example):\n", " text = f\"{example['source']}: {example['text']}\" if \"source\" in example else example[\"text\"]\n", " audio_array = example[\"audio\"][\"array\"]\n", " sampling_rate = example[\"audio\"][\"sampling_rate\"]\n", "\n", " target_sr = audio_tokenizer.config['sample_rate']\n", "\n", " if sampling_rate != target_sr:\n", " resampler = T.Resample(orig_freq=sampling_rate, new_freq=target_sr)\n", " audio_tensor_temp = torch.from_numpy(audio_array).float()\n", " audio_array = resampler(audio_tensor_temp).numpy()\n", "\n", " if audio_tokenizer.config[\"volume_normalize\"]:\n", " audio_array = audio_volume_normalize(audio_array)\n", "\n", " ref_wav_np = audio_tokenizer.get_ref_clip(audio_array)\n", "\n", " audio_tensor = torch.from_numpy(audio_array).unsqueeze(0).float().to(audio_tokenizer.device)\n", " ref_wav_tensor = torch.from_numpy(ref_wav_np).unsqueeze(0).float().to(audio_tokenizer.device)\n", "\n", "\n", " feat = extract_wav2vec2_features(audio_tensor)\n", "\n", " batch = {\n", "\n", " \"wav\": audio_tensor,\n", " \"ref_wav\": ref_wav_tensor,\n", " \"feat\": feat.to(audio_tokenizer.device),\n", " }\n", "\n", "\n", " semantic_token_ids, global_token_ids = audio_tokenizer.model.tokenize(batch)\n", "\n", " global_tokens = \"\".join(\n", " [f\"<|bicodec_global_{i}|>\" for i in global_token_ids.squeeze().cpu().numpy()] # Squeeze batch dim\n", " )\n", " semantic_tokens = \"\".join(\n", " [f\"<|bicodec_semantic_{i}|>\" for i in 
semantic_token_ids.squeeze().cpu().numpy()] # Squeeze batch dim\n", " )\n", "\n", " inputs = [\n", " \"<|task_tts|>\",\n", " \"<|start_content|>\",\n", " text,\n", " \"<|end_content|>\",\n", " \"<|start_global_token|>\",\n", " global_tokens,\n", " \"<|end_global_token|>\",\n", " \"<|start_semantic_token|>\",\n", " semantic_tokens,\n", " \"<|end_semantic_token|>\",\n", " \"<|im_end|>\"\n", " ]\n", " inputs = \"\".join(inputs)\n", " return {\"text\": inputs}\n", "\n", "\n", "dataset = dataset.map(formatting_audio_func, remove_columns=[\"audio\"])\n", "print(\"Moving Bicodec model and Wav2Vec2Model to cpu.\")\n", "audio_tokenizer.model.cpu()\n", "audio_tokenizer.feature_extractor.cpu()\n", "torch.cuda.empty_cache()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting torchaudio\n", " Downloading torchaudio-2.8.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (7.2 kB)\n", "Collecting torch==2.8.0 (from torchaudio)\n", " Using cached torch-2.8.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (30 kB)\n", "Requirement already satisfied: filelock in /datadrive/jupyter/devbase/Balaji/TTS_ft/lib/python3.12/site-packages (from torch==2.8.0->torchaudio) (3.18.0)\n", "Requirement already satisfied: typing-extensions>=4.10.0 in /datadrive/jupyter/devbase/Balaji/TTS_ft/lib/python3.12/site-packages (from torch==2.8.0->torchaudio) (4.14.1)\n", "Requirement already satisfied: setuptools in /datadrive/jupyter/devbase/Balaji/TTS_ft/lib/python3.12/site-packages (from torch==2.8.0->torchaudio) (80.9.0)\n", "Requirement already satisfied: sympy>=1.13.3 in /datadrive/jupyter/devbase/Balaji/TTS_ft/lib/python3.12/site-packages (from torch==2.8.0->torchaudio) (1.14.0)\n", "Requirement already satisfied: networkx in /datadrive/jupyter/devbase/Balaji/TTS_ft/lib/python3.12/site-packages (from torch==2.8.0->torchaudio) (3.5)\n", "Requirement already satisfied: jinja2 in 
/datadrive/jupyter/devbase/Balaji/TTS_ft/lib/python3.12/site-packages (from torch==2.8.0->torchaudio) (3.1.6)\n", "Requirement already satisfied: fsspec in /datadrive/jupyter/devbase/Balaji/TTS_ft/lib/python3.12/site-packages (from torch==2.8.0->torchaudio) (2025.3.0)\n", "Collecting nvidia-cuda-nvrtc-cu12==12.8.93 (from torch==2.8.0->torchaudio)\n", " Using cached nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl.metadata (1.7 kB)\n", "Collecting nvidia-cuda-runtime-cu12==12.8.90 (from torch==2.8.0->torchaudio)\n", " Using cached nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.7 kB)\n", "Collecting nvidia-cuda-cupti-cu12==12.8.90 (from torch==2.8.0->torchaudio)\n", " Using cached nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.7 kB)\n", "Collecting nvidia-cudnn-cu12==9.10.2.21 (from torch==2.8.0->torchaudio)\n", " Using cached nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl.metadata (1.8 kB)\n", "Collecting nvidia-cublas-cu12==12.8.4.1 (from torch==2.8.0->torchaudio)\n", " Using cached nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl.metadata (1.7 kB)\n", "Collecting nvidia-cufft-cu12==11.3.3.83 (from torch==2.8.0->torchaudio)\n", " Using cached nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.7 kB)\n", "Collecting nvidia-curand-cu12==10.3.9.90 (from torch==2.8.0->torchaudio)\n", " Using cached nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl.metadata (1.7 kB)\n", "Collecting nvidia-cusolver-cu12==11.7.3.90 (from torch==2.8.0->torchaudio)\n", " Using cached nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl.metadata (1.8 kB)\n", "Collecting nvidia-cusparse-cu12==12.5.8.93 (from torch==2.8.0->torchaudio)\n", " Using cached nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.8 
kB)\n", "Collecting nvidia-cusparselt-cu12==0.7.1 (from torch==2.8.0->torchaudio)\n", " Using cached nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl.metadata (7.0 kB)\n", "Collecting nvidia-nccl-cu12==2.27.3 (from torch==2.8.0->torchaudio)\n", " Using cached nvidia_nccl_cu12-2.27.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (2.0 kB)\n", "Collecting nvidia-nvtx-cu12==12.8.90 (from torch==2.8.0->torchaudio)\n", " Using cached nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.8 kB)\n", "Collecting nvidia-nvjitlink-cu12==12.8.93 (from torch==2.8.0->torchaudio)\n", " Using cached nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl.metadata (1.7 kB)\n", "Collecting nvidia-cufile-cu12==1.13.1.3 (from torch==2.8.0->torchaudio)\n", " Using cached nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.7 kB)\n", "Collecting triton==3.4.0 (from torch==2.8.0->torchaudio)\n", " Using cached triton-3.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (1.7 kB)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /datadrive/jupyter/devbase/Balaji/TTS_ft/lib/python3.12/site-packages (from sympy>=1.13.3->torch==2.8.0->torchaudio) (1.3.0)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /datadrive/jupyter/devbase/Balaji/TTS_ft/lib/python3.12/site-packages (from jinja2->torch==2.8.0->torchaudio) (3.0.2)\n", "Downloading torchaudio-2.8.0-cp312-cp312-manylinux_2_28_x86_64.whl (4.0 MB)\n", "\u001b[2K \u001b[38;2;114;156;31mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m4.0/4.0 MB\u001b[0m \u001b[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m0:00:01\u001b[0m[36m0:00:01\u001b[0m\n", "\u001b[?25hDownloading torch-2.8.0-cp312-cp312-manylinux_2_28_x86_64.whl (887.9 MB)\n", "\u001b[2K \u001b[38;2;114;156;31mβββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m887.9/887.9 
MB\u001b[0m \u001b[31m979.7 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m eta \u001b[36m0:00:01\u001b[0m[36m0:00:19\u001b[0mm\n", "\u001b[?25hDownloading nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl (594.3 MB)\n", "\u001b[2K \u001b[38;2;114;156;31mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m594.3/594.3 MB\u001b[0m \u001b[31m1.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m0:00:01\u001b[0m[36m0:00:13\u001b[0m\n", "\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (10.2 MB)\n", "\u001b[2K \u001b[38;2;114;156;31mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m10.2/10.2 MB\u001b[0m \u001b[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m0:00:01\u001b[0m[36m0:00:01\u001b[0m\n", "\u001b[?25hDownloading nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl (88.0 MB)\n", "\u001b[2K \u001b[38;2;114;156;31mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m88.0/88.0 MB\u001b[0m \u001b[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m0:00:01\u001b[0m[36m0:00:02\u001b[0m\n", "\u001b[?25hDownloading nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (954 kB)\n", "\u001b[2K \u001b[38;2;114;156;31mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m954.8/954.8 kB\u001b[0m \u001b[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m[36m0:00:01\u001b[0m[36m0:00:01\u001b[0m:01\u001b[0m\n", "\u001b[?25hDownloading nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl (706.8 MB)\n", "\u001b[2K \u001b[38;2;114;156;31mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m706.8/706.8 MB\u001b[0m \u001b[31m1.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m0:00:01\u001b[0m[36m0:00:15\u001b[0mm\n", "\u001b[?25hDownloading nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (193.1 MB)\n", 
"\u001b[2K \u001b[38;2;114;156;31mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m193.1/193.1 MB\u001b[0m \u001b[31m1.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m0:00:01\u001b[0m[36m0:00:05\u001b[0m\n", "\u001b[?25hDownloading nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (1.2 MB)\n", "\u001b[2K \u001b[38;2;114;156;31mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m1.2/1.2 MB\u001b[0m \u001b[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\n", "\u001b[?25hDownloading nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl (63.6 MB)\n", "\u001b[2K \u001b[38;2;114;156;31mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m63.6/63.6 MB\u001b[0m \u001b[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m0:00:01\u001b[0m[36m0:00:02\u001b[0m\n", "\u001b[?25hDownloading nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl (267.5 MB)\n", "\u001b[2K \u001b[38;2;114;156;31mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m267.5/267.5 MB\u001b[0m \u001b[31m1.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m0:00:01\u001b[0m[36m0:00:06\u001b[0m\n", "\u001b[?25hDownloading nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (288.2 MB)\n", "\u001b[2K \u001b[38;2;114;156;31mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m288.2/288.2 MB\u001b[0m \u001b[31m1.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m0:00:01\u001b[0m[36m0:00:07\u001b[0m\n", "\u001b[?25hDownloading nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl (287.2 MB)\n", "\u001b[2K \u001b[38;2;114;156;31mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m287.2/287.2 MB\u001b[0m \u001b[31m1.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m0:00:01\u001b[0m[36m0:00:06\u001b[0m\n", "\u001b[?25hDownloading 
nvidia_nccl_cu12-2.27.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (322.4 MB)\n", "Downloading nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl (39.3 MB)\n", "Downloading nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (89 kB)\n", "Downloading triton-3.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (155.6 MB)\n", "Installing collected packages: nvidia-cusparselt-cu12, triton, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufile-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, nvidia-cusparse-cu12, nvidia-cufft-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torch, torchaudio\n", "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", "xformers 0.0.31.post1 requires torch==2.7.1, but you have torch 2.8.0 which is incompatible.\n", "torchvision 0.22.1 requires torch==2.7.1, but you have torch 2.8.0 which is incompatible.\u001b[0m\u001b[31m\n", "\u001b[0mSuccessfully installed nvidia-cublas-cu12-12.8.4.1 nvidia-cuda-cupti-cu12-12.8.90 nvidia-cuda-nvrtc-cu12-12.8.93 nvidia-cuda-runtime-cu12-12.8.90 nvidia-cudnn-cu12-9.10.2.21 nvidia-cufft-cu12-11.3.3.83 nvidia-cufile-cu12-1.13.1.3 nvidia-curand-cu12-10.3.9.90 nvidia-cusolver-cu12-11.7.3.90 nvidia-cusparse-cu12-12.5.8.93 nvidia-cusparselt-cu12-0.7.1 nvidia-nccl-cu12-2.27.3 nvidia-nvjitlink-cu12-12.8.93 nvidia-nvtx-cu12-12.8.90 torch-2.8.0 torchaudio-2.8.0 triton-3.4.0\n", "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "%pip install torchaudio" ] }, { "cell_type": "markdown", "metadata": { "id": "idAEIeSQ3xdS" }, "source": [ "\n", "### Train the model\n", "Now let's use Hugging Face TRL's `SFTTrainer`! More docs here: [TRL SFT docs](https://huggingface.co/docs/trl/sft_trainer). We train for five full epochs here via `num_train_epochs = 5`; for a quick test run, comment out `num_train_epochs` and set `max_steps = 60` instead. We also support TRL's `DPOTrainer`!" 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "95_Nn-89DhsL" }, "outputs": [], "source": [ "from trl import SFTConfig, SFTTrainer\n", "trainer = SFTTrainer(\n", " model = model,\n", " tokenizer = tokenizer,\n", " train_dataset = dataset,\n", " dataset_text_field = \"text\",\n", " max_seq_length = max_seq_length,\n", " packing = False, # Packing can make training 5x faster for short sequences.\n", " args = SFTConfig(\n", " per_device_train_batch_size = 2,\n", " gradient_accumulation_steps = 4,\n", " warmup_steps = 5,\n", " num_train_epochs = 5, # Number of full passes over the dataset.\n", " #max_steps = 60,\n", " learning_rate = 1e-5,\n", " fp16 = False, # We're training in full float32, so disable mixed precision.\n", " bf16 = False, # We're training in full float32, so disable mixed precision.\n", " logging_steps = 1,\n", " optim = \"adamw_8bit\",\n", " weight_decay = 0.01,\n", " lr_scheduler_type = \"linear\",\n", " seed = 3407,\n", " output_dir = \"outputs\",\n", " report_to = \"tensorboard\", # Use this for WandB etc\n", " ),\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2ejIt2xSNKKp" }, "outputs": [], "source": [ "# @title Show current memory stats\n", "import torch\n", "gpu_stats = torch.cuda.get_device_properties(0)\n", "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n", "print(f\"GPU = {gpu_stats.name}. 
Max memory = {max_memory} GB.\")\n", "print(f\"{start_gpu_memory} GB of memory reserved.\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yqxqAZ7KJ4oL" }, "outputs": [], "source": [ "trainer_stats = trainer.train()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "pCqnaKmlO1U9" }, "outputs": [], "source": [ "# @title Show final memory and time stats\n", "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n", "used_percentage = round(used_memory / max_memory * 100, 3)\n", "lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n", "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n", "print(\n", " f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\"\n", ")\n", "print(f\"Peak reserved memory = {used_memory} GB.\")\n", "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n", "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n", "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")" ] }, { "cell_type": "markdown", "metadata": { "id": "ekOmTR1hSNcr" }, "source": [ "\n", "### Inference\n", "Let's run the model! You can change the prompts.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "apUdB40Ep6Ki" }, "outputs": [], "source": [ "input_text = \"Hey there my name is Elise,