Upload folder using huggingface_hub
Browse files- .gitattributes +9 -0
- .ipynb_checkpoints/Spark_TTS_FT-checkpoint.ipynb +1211 -0
- .ipynb_checkpoints/config-checkpoint.yaml +7 -0
- BiCodec/config.yaml +60 -0
- BiCodec/model.safetensors +3 -0
- LLM/.gitattributes +36 -0
- LLM/added_tokens.json +0 -0
- LLM/chat_template.jinja +54 -0
- LLM/config.json +56 -0
- LLM/generation_config.json +8 -0
- LLM/merges.txt +0 -0
- LLM/model.safetensors +3 -0
- LLM/special_tokens_map.json +31 -0
- LLM/tokenizer.json +3 -0
- LLM/tokenizer_config.json +0 -0
- LLM/vocab.json +0 -0
- README.md +208 -0
- Spark_TTS_FT.ipynb +1732 -0
- config.yaml +7 -0
- src/figures/gradio_TTS.png +0 -0
- src/figures/gradio_control.png +0 -0
- src/figures/infer_control.png +3 -0
- src/figures/infer_voice_cloning.png +3 -0
- src/logo/HKUST.jpg +3 -0
- src/logo/NPU.jpg +3 -0
- src/logo/NTU.jpg +0 -0
- src/logo/SJU.jpg +3 -0
- src/logo/SparkAudio.jpg +0 -0
- src/logo/SparkAudio2.jpg +0 -0
- src/logo/SparkTTS.jpg +0 -0
- src/logo/SparkTTS.png +3 -0
- src/logo/mobvoi.jpg +3 -0
- src/logo/mobvoi.png +3 -0
- wav2vec2-large-xlsr-53/README.md +29 -0
- wav2vec2-large-xlsr-53/config.json +83 -0
- wav2vec2-large-xlsr-53/preprocessor_config.json +9 -0
- wav2vec2-large-xlsr-53/pytorch_model.bin +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
LLM/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
37 |
+
src/figures/infer_control.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
src/figures/infer_voice_cloning.png filter=lfs diff=lfs merge=lfs -text
|
39 |
+
src/logo/HKUST.jpg filter=lfs diff=lfs merge=lfs -text
|
40 |
+
src/logo/NPU.jpg filter=lfs diff=lfs merge=lfs -text
|
41 |
+
src/logo/SJU.jpg filter=lfs diff=lfs merge=lfs -text
|
42 |
+
src/logo/SparkTTS.png filter=lfs diff=lfs merge=lfs -text
|
43 |
+
src/logo/mobvoi.jpg filter=lfs diff=lfs merge=lfs -text
|
44 |
+
src/logo/mobvoi.png filter=lfs diff=lfs merge=lfs -text
|
.ipynb_checkpoints/Spark_TTS_FT-checkpoint.ipynb
ADDED
@@ -0,0 +1,1211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"id": "Qpw04rkbynx0"
|
7 |
+
},
|
8 |
+
"source": [
|
9 |
+
"To run this, press \"*Runtime*\" and press \"*Run all*\" on a **free** Tesla T4 Google Colab instance!\n",
|
10 |
+
"<div class=\"align-center\">\n",
|
11 |
+
"<a href=\"https://unsloth.ai/\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png\" width=\"115\"></a>\n",
|
12 |
+
"<a href=\"https://discord.gg/unsloth\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/Discord button.png\" width=\"145\"></a>\n",
|
13 |
+
"<a href=\"https://docs.unsloth.ai/\"><img src=\"https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true\" width=\"125\"></a></a> Join Discord if you need help + ⭐ <i>Star us on <a href=\"https://github.com/unslothai/unsloth\">Github</a> </i> ⭐\n",
|
14 |
+
"</div>\n",
|
15 |
+
"\n",
|
16 |
+
"To install Unsloth on your own computer, follow the installation instructions on our Github page [here](https://docs.unsloth.ai/get-started/installing-+-updating).\n",
|
17 |
+
"\n",
|
18 |
+
"You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), & [how to save it](#Save)\n"
|
19 |
+
]
|
20 |
+
},
|
21 |
+
{
|
22 |
+
"cell_type": "markdown",
|
23 |
+
"metadata": {
|
24 |
+
"id": "5fs-yYEaynx1"
|
25 |
+
},
|
26 |
+
"source": [
|
27 |
+
"### News"
|
28 |
+
]
|
29 |
+
},
|
30 |
+
{
|
31 |
+
"cell_type": "markdown",
|
32 |
+
"metadata": {
|
33 |
+
"id": "pyJK0UZaynx2"
|
34 |
+
},
|
35 |
+
"source": [
|
36 |
+
"Unsloth now supports Text-to-Speech (TTS) models. Read our [guide here](https://docs.unsloth.ai/basics/text-to-speech-tts-fine-tuning).\n",
|
37 |
+
"\n",
|
38 |
+
"Read our **[Gemma 3N Guide](https://docs.unsloth.ai/basics/gemma-3n-how-to-run-and-fine-tune)** and check out our new **[Dynamic 2.0](https://docs.unsloth.ai/basics/unsloth-dynamic-2.0-ggufs)** quants which outperforms other quantization methods!\n",
|
39 |
+
"\n",
|
40 |
+
"Visit our docs for all our [model uploads](https://docs.unsloth.ai/get-started/all-our-models) and [notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks).\n"
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "markdown",
|
45 |
+
"metadata": {
|
46 |
+
"id": "SDUHv0mwynx3"
|
47 |
+
},
|
48 |
+
"source": [
|
49 |
+
"### Installation"
|
50 |
+
]
|
51 |
+
},
|
52 |
+
{
|
53 |
+
"cell_type": "code",
|
54 |
+
"execution_count": 1,
|
55 |
+
"metadata": {
|
56 |
+
"id": "MY4G3EIbynx3"
|
57 |
+
},
|
58 |
+
"outputs": [],
|
59 |
+
"source": [
|
60 |
+
"%%capture\n",
|
61 |
+
"import os\n",
|
62 |
+
"if \"COLAB_\" not in \"\".join(os.environ.keys()):\n",
|
63 |
+
" %pip install unsloth\n",
|
64 |
+
"else:\n",
|
65 |
+
" # Do this only in Colab notebooks! Otherwise use pip install unsloth\n",
|
66 |
+
" %pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo\n",
|
67 |
+
" %pip install sentencepiece protobuf \"datasets>=3.4.1,<4.0.0\" \"huggingface_hub>=0.34.0\" hf_transfer\n",
|
68 |
+
" %pip install --no-deps unsloth\n",
|
69 |
+
"%git clone https://github.com/SparkAudio/Spark-TTS\n",
|
70 |
+
"%pip install omegaconf einx"
|
71 |
+
]
|
72 |
+
},
|
73 |
+
{
|
74 |
+
"cell_type": "markdown",
|
75 |
+
"metadata": {
|
76 |
+
"id": "AkWYsztAs9Ky"
|
77 |
+
},
|
78 |
+
"source": [
|
79 |
+
"### Unsloth\n",
|
80 |
+
"\n",
|
81 |
+
"`FastModel` supports loading nearly any model now! This includes Vision and Text models!\n",
|
82 |
+
"\n",
|
83 |
+
"Thank you to [Etherl](https://huggingface.co/Etherll) for creating this notebook!"
|
84 |
+
]
|
85 |
+
},
|
86 |
+
{
|
87 |
+
"cell_type": "code",
|
88 |
+
"execution_count": 2,
|
89 |
+
"metadata": {
|
90 |
+
"colab": {
|
91 |
+
"base_uri": "https://localhost:8080/"
|
92 |
+
},
|
93 |
+
"execution": {
|
94 |
+
"iopub.execute_input": "2025-03-22T00:48:54.511089Z",
|
95 |
+
"iopub.status.busy": "2025-03-22T00:48:54.510770Z",
|
96 |
+
"iopub.status.idle": "2025-03-22T00:51:37.363415Z",
|
97 |
+
"shell.execute_reply": "2025-03-22T00:51:37.362696Z",
|
98 |
+
"shell.execute_reply.started": "2025-03-22T00:48:54.511053Z"
|
99 |
+
},
|
100 |
+
"id": "QmUBVEnvCDJv",
|
101 |
+
"outputId": "42083a68-d3cc-48c9-d852-b60796377434"
|
102 |
+
},
|
103 |
+
"outputs": [
|
104 |
+
{
|
105 |
+
"name": "stdout",
|
106 |
+
"output_type": "stream",
|
107 |
+
"text": [
|
108 |
+
"🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n",
|
109 |
+
"🦥 Unsloth Zoo will now patch everything to make training faster!\n",
|
110 |
+
"==((====))== Unsloth 2025.8.1: Fast Qwen2 patching. Transformers: 4.54.1.\n",
|
111 |
+
" \\\\ /| Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.\n",
|
112 |
+
"O^O/ \\_/ \\ Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0\n",
|
113 |
+
"\\ / Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]\n",
|
114 |
+
" \"-____-\" Free license: http://github.com/unslothai/unsloth\n",
|
115 |
+
"Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n",
|
116 |
+
"Unsloth: Float16 full finetuning uses more memory since we upcast weights to float32.\n"
|
117 |
+
]
|
118 |
+
}
|
119 |
+
],
|
120 |
+
"source": [
|
121 |
+
"from unsloth import FastModel\n",
|
122 |
+
"import torch\n",
|
123 |
+
"from huggingface_hub import snapshot_download\n",
|
124 |
+
"\n",
|
125 |
+
"max_seq_length = 2048 # Choose any for long context!\n",
|
126 |
+
"\n",
|
127 |
+
"fourbit_models = [\n",
|
128 |
+
" # 4bit dynamic quants for superior accuracy and low memory use\n",
|
129 |
+
" \"unsloth/gemma-3-4b-it-unsloth-bnb-4bit\",\n",
|
130 |
+
" \"unsloth/gemma-3-12b-it-unsloth-bnb-4bit\",\n",
|
131 |
+
" \"unsloth/gemma-3-27b-it-unsloth-bnb-4bit\",\n",
|
132 |
+
" # Qwen3 new models\n",
|
133 |
+
" \"unsloth/Qwen3-4B-unsloth-bnb-4bit\",\n",
|
134 |
+
" \"unsloth/Qwen3-8B-unsloth-bnb-4bit\",\n",
|
135 |
+
" # Other very popular models!\n",
|
136 |
+
" \"unsloth/Llama-3.1-8B\",\n",
|
137 |
+
" \"unsloth/Llama-3.2-3B\",\n",
|
138 |
+
" \"unsloth/Llama-3.3-70B\",\n",
|
139 |
+
" \"unsloth/mistral-7b-instruct-v0.3\",\n",
|
140 |
+
" \"unsloth/Phi-4\",\n",
|
141 |
+
"] # More models at https://huggingface.co/unsloth\n",
|
142 |
+
"\n",
|
143 |
+
"# Download model and code\n",
|
144 |
+
"snapshot_download(\"unsloth/Spark-TTS-0.5B\", local_dir = \"Spark-TTS-0.5B\")\n",
|
145 |
+
"\n",
|
146 |
+
"model, tokenizer = FastModel.from_pretrained(\n",
|
147 |
+
" model_name = f\"Spark-TTS-0.5B/LLM\",\n",
|
148 |
+
" max_seq_length = max_seq_length,\n",
|
149 |
+
" dtype = torch.float32, # Spark seems to only work on float32 for now\n",
|
150 |
+
" full_finetuning = True, # We support full finetuning now!\n",
|
151 |
+
" load_in_4bit = False,\n",
|
152 |
+
" #token = \"hf_...\", # use one if using gated models like meta-llama/Llama-2-7b-hf\n",
|
153 |
+
")"
|
154 |
+
]
|
155 |
+
},
|
156 |
+
{
|
157 |
+
"cell_type": "markdown",
|
158 |
+
"metadata": {
|
159 |
+
"id": "SXd9bTZd1aaL"
|
160 |
+
},
|
161 |
+
"source": [
|
162 |
+
"We now add LoRA adapters so we only need to update 1 to 10% of all parameters!"
|
163 |
+
]
|
164 |
+
},
|
165 |
+
{
|
166 |
+
"cell_type": "code",
|
167 |
+
"execution_count": 3,
|
168 |
+
"metadata": {
|
169 |
+
"colab": {
|
170 |
+
"base_uri": "https://localhost:8080/"
|
171 |
+
},
|
172 |
+
"execution": {
|
173 |
+
"iopub.execute_input": "2025-03-22T00:51:37.365079Z",
|
174 |
+
"iopub.status.busy": "2025-03-22T00:51:37.364731Z",
|
175 |
+
"iopub.status.idle": "2025-03-22T00:51:44.221612Z",
|
176 |
+
"shell.execute_reply": "2025-03-22T00:51:44.220949Z",
|
177 |
+
"shell.execute_reply.started": "2025-03-22T00:51:37.365045Z"
|
178 |
+
},
|
179 |
+
"id": "6bZsfBuZDeCL",
|
180 |
+
"outputId": "292447b8-fd80-4b8b-ba3f-4637a1045166"
|
181 |
+
},
|
182 |
+
"outputs": [
|
183 |
+
{
|
184 |
+
"name": "stdout",
|
185 |
+
"output_type": "stream",
|
186 |
+
"text": [
|
187 |
+
"Unsloth: Full finetuning is enabled, so .get_peft_model has no effect\n"
|
188 |
+
]
|
189 |
+
}
|
190 |
+
],
|
191 |
+
"source": [
|
192 |
+
"#LoRA does not work with float32 only works with bfloat16 !!!\n",
|
193 |
+
"model = FastModel.get_peft_model(\n",
|
194 |
+
" model,\n",
|
195 |
+
" r = 128, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n",
|
196 |
+
" target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
|
197 |
+
" \"gate_proj\", \"up_proj\", \"down_proj\",],\n",
|
198 |
+
" lora_alpha = 128,\n",
|
199 |
+
" lora_dropout = 0, # Supports any, but = 0 is optimized\n",
|
200 |
+
" bias = \"none\", # Supports any, but = \"none\" is optimized\n",
|
201 |
+
" # [NEW] \"unsloth\" uses 30% less VRAM, fits 2x larger batch sizes!\n",
|
202 |
+
" use_gradient_checkpointing = \"unsloth\", # True or \"unsloth\" for very long context\n",
|
203 |
+
" random_state = 3407,\n",
|
204 |
+
" use_rslora = False, # We support rank stabilized LoRA\n",
|
205 |
+
" loftq_config = None, # And LoftQ\n",
|
206 |
+
")"
|
207 |
+
]
|
208 |
+
},
|
209 |
+
{
|
210 |
+
"cell_type": "markdown",
|
211 |
+
"metadata": {
|
212 |
+
"id": "vITh0KVJ10qX"
|
213 |
+
},
|
214 |
+
"source": [
|
215 |
+
"<a name=\"Data\"></a>\n",
|
216 |
+
"### Data Prep \n",
|
217 |
+
"\n",
|
218 |
+
"We will use the `MrDragonFox/Elise`, which is designed for training TTS models. Ensure that your dataset follows the required format: **text, audio** for single-speaker models or **source, text, audio** for multi-speaker models. You can modify this section to accommodate your own dataset, but maintaining the correct structure is essential for optimal training."
|
219 |
+
]
|
220 |
+
},
|
221 |
+
{
|
222 |
+
"cell_type": "code",
|
223 |
+
"execution_count": 4,
|
224 |
+
"metadata": {
|
225 |
+
"execution": {
|
226 |
+
"iopub.execute_input": "2025-03-22T00:51:44.222880Z",
|
227 |
+
"iopub.status.busy": "2025-03-22T00:51:44.222617Z",
|
228 |
+
"iopub.status.idle": "2025-03-22T00:52:16.516878Z",
|
229 |
+
"shell.execute_reply": "2025-03-22T00:52:16.516033Z",
|
230 |
+
"shell.execute_reply.started": "2025-03-22T00:51:44.222848Z"
|
231 |
+
},
|
232 |
+
"id": "LjY75GoYUCB8"
|
233 |
+
},
|
234 |
+
"outputs": [],
|
235 |
+
"source": [
|
236 |
+
"from datasets import load_dataset\n",
|
237 |
+
"dataset = load_dataset(\"Balaji-1904/TTS_KN_DS_V1.1\", split = \"train\")"
|
238 |
+
]
|
239 |
+
},
|
240 |
+
{
|
241 |
+
"cell_type": "code",
|
242 |
+
"execution_count": null,
|
243 |
+
"metadata": {
|
244 |
+
"colab": {
|
245 |
+
"base_uri": "https://localhost:8080/",
|
246 |
+
"height": 173,
|
247 |
+
"referenced_widgets": [
|
248 |
+
"a3b0c0581f1f4c428baaadd8e9a39b6f",
|
249 |
+
"2315228ff2b141afabe1263471f5364b",
|
250 |
+
"0474debc340943bd85f3daf92aebf7aa",
|
251 |
+
"cff1b0fa2ea24f45aab26685353eefdd",
|
252 |
+
"b7e20be79df246f19b35114a690e44f0",
|
253 |
+
"426eb100a94642f79e6b99777406a265",
|
254 |
+
"a36b5cf197dd4bd9a7f70aa6671b804c",
|
255 |
+
"0de4d0f282404edfbc191dca73f15f35",
|
256 |
+
"e58b5ad2f781475d8af2ddb38009baa6",
|
257 |
+
"33fbacbb2aa146cd90586357eec1dc3e",
|
258 |
+
"930b4d1d5f4b494b830df4d4c398e67c"
|
259 |
+
]
|
260 |
+
},
|
261 |
+
"execution": {
|
262 |
+
"iopub.execute_input": "2025-03-22T00:52:16.518175Z",
|
263 |
+
"iopub.status.busy": "2025-03-22T00:52:16.517841Z",
|
264 |
+
"iopub.status.idle": "2025-03-22T00:52:35.039329Z",
|
265 |
+
"shell.execute_reply": "2025-03-22T00:52:35.038356Z",
|
266 |
+
"shell.execute_reply.started": "2025-03-22T00:52:16.518146Z"
|
267 |
+
},
|
268 |
+
"id": "zK94B-Pfioto",
|
269 |
+
"outputId": "3f11cf35-c173-410d-f709-43552323f26f"
|
270 |
+
},
|
271 |
+
"outputs": [
|
272 |
+
{
|
273 |
+
"name": "stderr",
|
274 |
+
"output_type": "stream",
|
275 |
+
"text": [
|
276 |
+
"/usr/local/lib/python3.11/dist-packages/torch/nn/utils/weight_norm.py:143: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.\n",
|
277 |
+
" WeightNorm.apply(module, name, dim)\n"
|
278 |
+
]
|
279 |
+
},
|
280 |
+
{
|
281 |
+
"name": "stdout",
|
282 |
+
"output_type": "stream",
|
283 |
+
"text": [
|
284 |
+
"Missing tensor: mel_transformer.spectrogram.window\n",
|
285 |
+
"Missing tensor: mel_transformer.mel_scale.fb\n"
|
286 |
+
]
|
287 |
+
},
|
288 |
+
{
|
289 |
+
"name": "stderr",
|
290 |
+
"output_type": "stream",
|
291 |
+
"text": [
|
292 |
+
"Parameter 'function'=<function formatting_audio_func at 0x7bd438943100> of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.\n",
|
293 |
+
"WARNING:datasets.fingerprint:Parameter 'function'=<function formatting_audio_func at 0x7bd438943100> of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.\n"
|
294 |
+
]
|
295 |
+
},
|
296 |
+
{
|
297 |
+
"data": {
|
298 |
+
"application/vnd.jupyter.widget-view+json": {
|
299 |
+
"model_id": "a3b0c0581f1f4c428baaadd8e9a39b6f",
|
300 |
+
"version_major": 2,
|
301 |
+
"version_minor": 0
|
302 |
+
},
|
303 |
+
"text/plain": [
|
304 |
+
"Map: 0%| | 0/401 [00:00<?, ? examples/s]"
|
305 |
+
]
|
306 |
+
},
|
307 |
+
"metadata": {},
|
308 |
+
"output_type": "display_data"
|
309 |
+
}
|
310 |
+
],
|
311 |
+
"source": [
|
312 |
+
"#@title Tokenization Function\n",
|
313 |
+
"\n",
|
314 |
+
"import locale\n",
|
315 |
+
"import torchaudio.transforms as T\n",
|
316 |
+
"import os\n",
|
317 |
+
"import torch\n",
|
318 |
+
"import sys\n",
|
319 |
+
"import numpy as np\n",
|
320 |
+
"sys.path.append('Spark-TTS')\n",
|
321 |
+
"from sparktts.models.audio_tokenizer import BiCodecTokenizer\n",
|
322 |
+
"from sparktts.utils.audio import audio_volume_normalize\n",
|
323 |
+
"\n",
|
324 |
+
"audio_tokenizer = BiCodecTokenizer(\"Spark-TTS-0.5B\", \"cuda\")\n",
|
325 |
+
"def extract_wav2vec2_features( wavs: torch.Tensor) -> torch.Tensor:\n",
|
326 |
+
" \"\"\"extract wav2vec2 features\"\"\"\n",
|
327 |
+
"\n",
|
328 |
+
" if wavs.shape[0] != 1:\n",
|
329 |
+
"\n",
|
330 |
+
" raise ValueError(f\"Expected batch size 1, but got shape {wavs.shape}\")\n",
|
331 |
+
" wav_np = wavs.squeeze(0).cpu().numpy()\n",
|
332 |
+
"\n",
|
333 |
+
" processed = audio_tokenizer.processor(\n",
|
334 |
+
" wav_np,\n",
|
335 |
+
" sampling_rate=16000,\n",
|
336 |
+
" return_tensors=\"pt\",\n",
|
337 |
+
" padding=True,\n",
|
338 |
+
" )\n",
|
339 |
+
" input_values = processed.input_values\n",
|
340 |
+
"\n",
|
341 |
+
" input_values = input_values.to(audio_tokenizer.feature_extractor.device)\n",
|
342 |
+
"\n",
|
343 |
+
" model_output = audio_tokenizer.feature_extractor(\n",
|
344 |
+
" input_values,\n",
|
345 |
+
" )\n",
|
346 |
+
"\n",
|
347 |
+
"\n",
|
348 |
+
" if model_output.hidden_states is None:\n",
|
349 |
+
" raise ValueError(\"Wav2Vec2Model did not return hidden states. Ensure config `output_hidden_states=True`.\")\n",
|
350 |
+
"\n",
|
351 |
+
" num_layers = len(model_output.hidden_states)\n",
|
352 |
+
" required_layers = [11, 14, 16]\n",
|
353 |
+
" if any(l >= num_layers for l in required_layers):\n",
|
354 |
+
" raise IndexError(f\"Requested hidden state indices {required_layers} out of range for model with {num_layers} layers.\")\n",
|
355 |
+
"\n",
|
356 |
+
" feats_mix = (\n",
|
357 |
+
" model_output.hidden_states[11] + model_output.hidden_states[14] + model_output.hidden_states[16]\n",
|
358 |
+
" ) / 3\n",
|
359 |
+
"\n",
|
360 |
+
" return feats_mix\n",
|
361 |
+
"def formatting_audio_func(example):\n",
|
362 |
+
" text = f\"{example['source']}: {example['text']}\" if \"source\" in example else example[\"text\"]\n",
|
363 |
+
" audio_array = example[\"audio\"][\"array\"]\n",
|
364 |
+
" sampling_rate = example[\"audio\"][\"sampling_rate\"]\n",
|
365 |
+
"\n",
|
366 |
+
" target_sr = audio_tokenizer.config['sample_rate']\n",
|
367 |
+
"\n",
|
368 |
+
" if sampling_rate != target_sr:\n",
|
369 |
+
" resampler = T.Resample(orig_freq=sampling_rate, new_freq=target_sr)\n",
|
370 |
+
" audio_tensor_temp = torch.from_numpy(audio_array).float()\n",
|
371 |
+
" audio_array = resampler(audio_tensor_temp).numpy()\n",
|
372 |
+
"\n",
|
373 |
+
" if audio_tokenizer.config[\"volume_normalize\"]:\n",
|
374 |
+
" audio_array = audio_volume_normalize(audio_array)\n",
|
375 |
+
"\n",
|
376 |
+
" ref_wav_np = audio_tokenizer.get_ref_clip(audio_array)\n",
|
377 |
+
"\n",
|
378 |
+
" audio_tensor = torch.from_numpy(audio_array).unsqueeze(0).float().to(audio_tokenizer.device)\n",
|
379 |
+
" ref_wav_tensor = torch.from_numpy(ref_wav_np).unsqueeze(0).float().to(audio_tokenizer.device)\n",
|
380 |
+
"\n",
|
381 |
+
"\n",
|
382 |
+
" feat = extract_wav2vec2_features(audio_tensor)\n",
|
383 |
+
"\n",
|
384 |
+
" batch = {\n",
|
385 |
+
"\n",
|
386 |
+
" \"wav\": audio_tensor,\n",
|
387 |
+
" \"ref_wav\": ref_wav_tensor,\n",
|
388 |
+
" \"feat\": feat.to(audio_tokenizer.device),\n",
|
389 |
+
" }\n",
|
390 |
+
"\n",
|
391 |
+
"\n",
|
392 |
+
" semantic_token_ids, global_token_ids = audio_tokenizer.model.tokenize(batch)\n",
|
393 |
+
"\n",
|
394 |
+
" global_tokens = \"\".join(\n",
|
395 |
+
" [f\"<|bicodec_global_{i}|>\" for i in global_token_ids.squeeze().cpu().numpy()] # Squeeze batch dim\n",
|
396 |
+
" )\n",
|
397 |
+
" semantic_tokens = \"\".join(\n",
|
398 |
+
" [f\"<|bicodec_semantic_{i}|>\" for i in semantic_token_ids.squeeze().cpu().numpy()] # Squeeze batch dim\n",
|
399 |
+
" )\n",
|
400 |
+
"\n",
|
401 |
+
" inputs = [\n",
|
402 |
+
" \"<|task_tts|>\",\n",
|
403 |
+
" \"<|start_content|>\",\n",
|
404 |
+
" text,\n",
|
405 |
+
" \"<|end_content|>\",\n",
|
406 |
+
" \"<|start_global_token|>\",\n",
|
407 |
+
" global_tokens,\n",
|
408 |
+
" \"<|end_global_token|>\",\n",
|
409 |
+
" \"<|start_semantic_token|>\",\n",
|
410 |
+
" semantic_tokens,\n",
|
411 |
+
" \"<|end_semantic_token|>\",\n",
|
412 |
+
" \"<|im_end|>\"\n",
|
413 |
+
" ]\n",
|
414 |
+
" inputs = \"\".join(inputs)\n",
|
415 |
+
" return {\"text\": inputs}\n",
|
416 |
+
"\n",
|
417 |
+
"\n",
|
418 |
+
"dataset = dataset.map(formatting_audio_func, remove_columns=[\"audio\"])\n",
|
419 |
+
"print(\"Moving Bicodec model and Wav2Vec2Model to cpu.\")\n",
|
420 |
+
"audio_tokenizer.model.cpu()\n",
|
421 |
+
"audio_tokenizer.feature_extractor.cpu()\n",
|
422 |
+
"torch.cuda.empty_cache()"
|
423 |
+
]
|
424 |
+
},
|
425 |
+
{
|
426 |
+
"cell_type": "markdown",
|
427 |
+
"metadata": {
|
428 |
+
"id": "idAEIeSQ3xdS"
|
429 |
+
},
|
430 |
+
"source": [
|
431 |
+
"<a name=\"Train\"></a>\n",
|
432 |
+
"### Train the model\n",
|
433 |
+
"Now let's use Huggingface TRL's `SFTTrainer`! More docs here: [TRL SFT docs](https://huggingface.co/docs/trl/sft_trainer). We do 60 steps to speed things up, but you can set `num_train_epochs=1` for a full run, and turn off `max_steps=None`. We also support TRL's `DPOTrainer`!"
|
434 |
+
]
|
435 |
+
},
|
436 |
+
{
|
437 |
+
"cell_type": "code",
|
438 |
+
"execution_count": null,
|
439 |
+
"metadata": {
|
440 |
+
"execution": {
|
441 |
+
"iopub.execute_input": "2025-03-22T00:34:09.688959Z",
|
442 |
+
"iopub.status.busy": "2025-03-22T00:34:09.688649Z",
|
443 |
+
"iopub.status.idle": "2025-03-22T00:34:09.729661Z",
|
444 |
+
"shell.execute_reply": "2025-03-22T00:34:09.729001Z",
|
445 |
+
"shell.execute_reply.started": "2025-03-22T00:34:09.688939Z"
|
446 |
+
},
|
447 |
+
"id": "95_Nn-89DhsL"
|
448 |
+
},
|
449 |
+
"outputs": [],
|
450 |
+
"source": [
|
451 |
+
"from trl import SFTConfig, SFTTrainer\n",
|
452 |
+
"trainer = SFTTrainer(\n",
|
453 |
+
" model = model,\n",
|
454 |
+
" tokenizer = tokenizer,\n",
|
455 |
+
" train_dataset = dataset,\n",
|
456 |
+
" dataset_text_field = \"text\",\n",
|
457 |
+
" max_seq_length = max_seq_length,\n",
|
458 |
+
" packing = False, # Can make training 5x faster for short sequences.\n",
|
459 |
+
" args = SFTConfig(\n",
|
460 |
+
" per_device_train_batch_size = 2,\n",
|
461 |
+
" gradient_accumulation_steps = 4,\n",
|
462 |
+
" warmup_steps = 5,\n",
|
463 |
+
" num_train_epochs = 5, # Set this for 1 full training run.\n",
|
464 |
+
" #max_steps = 60,\n",
|
465 |
+
" learning_rate = 2e-4,\n",
|
466 |
+
" fp16 = False, # We're doing full float32 s disable mixed precision\n",
|
467 |
+
" bf16 = False, # We're doing full float32 s disable mixed precision\n",
|
468 |
+
" logging_steps = 1,\n",
|
469 |
+
" optim = \"adamw_8bit\",\n",
|
470 |
+
" weight_decay = 0.01,\n",
|
471 |
+
" lr_scheduler_type = \"linear\",\n",
|
472 |
+
" seed = 3407,\n",
|
473 |
+
" output_dir = \"outputs\",\n",
|
474 |
+
" report_to = \"tensorboard\", # Use this for WandB etc\n",
|
475 |
+
" ),\n",
|
476 |
+
")"
|
477 |
+
]
|
478 |
+
},
|
479 |
+
{
|
480 |
+
"cell_type": "code",
|
481 |
+
"execution_count": null,
|
482 |
+
"metadata": {
|
483 |
+
"id": "2ejIt2xSNKKp"
|
484 |
+
},
|
485 |
+
"outputs": [],
|
486 |
+
"source": [
|
487 |
+
"# @title Show current memory stats\n",
|
488 |
+
"gpu_stats = torch.cuda.get_device_properties(0)\n",
|
489 |
+
"start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
|
490 |
+
"max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n",
|
491 |
+
"print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n",
|
492 |
+
"print(f\"{start_gpu_memory} GB of memory reserved.\")"
|
493 |
+
]
|
494 |
+
},
|
495 |
+
{
|
496 |
+
"cell_type": "code",
|
497 |
+
"execution_count": null,
|
498 |
+
"metadata": {
|
499 |
+
"execution": {
|
500 |
+
"iopub.execute_input": "2025-03-22T00:34:12.049152Z",
|
501 |
+
"iopub.status.busy": "2025-03-22T00:34:12.048862Z",
|
502 |
+
"iopub.status.idle": "2025-03-22T00:34:14.404349Z",
|
503 |
+
"shell.execute_reply": "2025-03-22T00:34:14.403239Z",
|
504 |
+
"shell.execute_reply.started": "2025-03-22T00:34:12.049130Z"
|
505 |
+
},
|
506 |
+
"id": "yqxqAZ7KJ4oL"
|
507 |
+
},
|
508 |
+
"outputs": [],
|
509 |
+
"source": [
|
510 |
+
"trainer_stats = trainer.train()"
|
511 |
+
]
|
512 |
+
},
|
513 |
+
{
|
514 |
+
"cell_type": "code",
|
515 |
+
"execution_count": null,
|
516 |
+
"metadata": {
|
517 |
+
"cellView": "form",
|
518 |
+
"id": "pCqnaKmlO1U9"
|
519 |
+
},
|
520 |
+
"outputs": [],
|
521 |
+
"source": [
|
522 |
+
"# @title Show final memory and time stats\n",
|
523 |
+
"used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
|
524 |
+
"used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n",
|
525 |
+
"used_percentage = round(used_memory / max_memory * 100, 3)\n",
|
526 |
+
"lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n",
|
527 |
+
"print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n",
|
528 |
+
"print(\n",
|
529 |
+
" f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\"\n",
|
530 |
+
")\n",
|
531 |
+
"print(f\"Peak reserved memory = {used_memory} GB.\")\n",
|
532 |
+
"print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n",
|
533 |
+
"print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n",
|
534 |
+
"print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")"
|
535 |
+
]
|
536 |
+
},
|
537 |
+
{
|
538 |
+
"cell_type": "markdown",
|
539 |
+
"metadata": {
|
540 |
+
"id": "ekOmTR1hSNcr"
|
541 |
+
},
|
542 |
+
"source": [
|
543 |
+
"<a name=\"Inference\"></a>\n",
|
544 |
+
"### Inference\n",
|
545 |
+
"Let's run the model! You can change the prompts\n"
|
546 |
+
]
|
547 |
+
},
|
548 |
+
{
|
549 |
+
"cell_type": "code",
|
550 |
+
"execution_count": null,
|
551 |
+
"metadata": {
|
552 |
+
"id": "apUdB40Ep6Ki"
|
553 |
+
},
|
554 |
+
"outputs": [],
|
555 |
+
"source": [
|
556 |
+
"input_text = \"Hey there my name is Elise, <giggles> and I'm a speech generation model that can sound like a person.\"\n",
|
557 |
+
"\n",
|
558 |
+
"chosen_voice = None # None for single-speaker"
|
559 |
+
]
|
560 |
+
},
|
561 |
+
{
|
562 |
+
"cell_type": "code",
|
563 |
+
"execution_count": null,
|
564 |
+
"metadata": {
|
565 |
+
"cellView": "form",
|
566 |
+
"execution": {
|
567 |
+
"iopub.execute_input": "2025-03-22T00:52:35.040842Z",
|
568 |
+
"iopub.status.busy": "2025-03-22T00:52:35.040125Z",
|
569 |
+
"iopub.status.idle": "2025-03-22T00:52:35.050560Z",
|
570 |
+
"shell.execute_reply": "2025-03-22T00:52:35.049663Z",
|
571 |
+
"shell.execute_reply.started": "2025-03-22T00:52:35.040818Z"
|
572 |
+
},
|
573 |
+
"id": "krYI8PrRJ6MX"
|
574 |
+
},
|
575 |
+
"outputs": [],
|
576 |
+
"source": [
|
577 |
+
"#@title Run Inference\n",
|
578 |
+
"\n",
|
579 |
+
"import torch\n",
|
580 |
+
"import re\n",
|
581 |
+
"import numpy as np\n",
|
582 |
+
"from typing import Dict, Any\n",
|
583 |
+
"import torchaudio.transforms as T\n",
|
584 |
+
"\n",
|
585 |
+
"FastModel.for_inference(model) # Enable native 2x faster inference\n",
|
586 |
+
"\n",
|
587 |
+
"@torch.inference_mode()\n",
|
588 |
+
"def generate_speech_from_text(\n",
|
589 |
+
" text: str,\n",
|
590 |
+
" temperature: float = 0.8, # Generation temperature\n",
|
591 |
+
" top_k: int = 50, # Generation top_k\n",
|
592 |
+
" top_p: float = 1, # Generation top_p\n",
|
593 |
+
" max_new_audio_tokens: int = 2048, # Max tokens for audio part\n",
|
594 |
+
" device: torch.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
595 |
+
") -> np.ndarray:\n",
|
596 |
+
" \"\"\"\n",
|
597 |
+
" Generates speech audio from text using default voice control parameters.\n",
|
598 |
+
"\n",
|
599 |
+
" Args:\n",
|
600 |
+
" text (str): The text input to be converted to speech.\n",
|
601 |
+
" temperature (float): Sampling temperature for generation.\n",
|
602 |
+
" top_k (int): Top-k sampling parameter.\n",
|
603 |
+
" top_p (float): Top-p (nucleus) sampling parameter.\n",
|
604 |
+
" max_new_audio_tokens (int): Max number of new tokens to generate (limits audio length).\n",
|
605 |
+
" device (torch.device): Device to run inference on.\n",
|
606 |
+
"\n",
|
607 |
+
" Returns:\n",
|
608 |
+
" np.ndarray: Generated waveform as a NumPy array.\n",
|
609 |
+
" \"\"\"\n",
|
610 |
+
"\n",
|
611 |
+
" torch.compiler.reset()\n",
|
612 |
+
"\n",
|
613 |
+
" prompt = \"\".join([\n",
|
614 |
+
" \"<|task_tts|>\",\n",
|
615 |
+
" \"<|start_content|>\",\n",
|
616 |
+
" text,\n",
|
617 |
+
" \"<|end_content|>\",\n",
|
618 |
+
" \"<|start_global_token|>\"\n",
|
619 |
+
" ])\n",
|
620 |
+
"\n",
|
621 |
+
" model_inputs = tokenizer([prompt], return_tensors=\"pt\").to(device)\n",
|
622 |
+
"\n",
|
623 |
+
" print(\"Generating token sequence...\")\n",
|
624 |
+
" generated_ids = model.generate(\n",
|
625 |
+
" **model_inputs,\n",
|
626 |
+
" max_new_tokens=max_new_audio_tokens, # Limit generation length\n",
|
627 |
+
" do_sample=True,\n",
|
628 |
+
" temperature=temperature,\n",
|
629 |
+
" top_k=top_k,\n",
|
630 |
+
" top_p=top_p,\n",
|
631 |
+
" eos_token_id=tokenizer.eos_token_id, # Stop token\n",
|
632 |
+
" pad_token_id=tokenizer.pad_token_id # Use models pad token id\n",
|
633 |
+
" )\n",
|
634 |
+
" print(\"Token sequence generated.\")\n",
|
635 |
+
"\n",
|
636 |
+
"\n",
|
637 |
+
" generated_ids_trimmed = generated_ids[:, model_inputs.input_ids.shape[1]:]\n",
|
638 |
+
"\n",
|
639 |
+
"\n",
|
640 |
+
" predicts_text = tokenizer.batch_decode(generated_ids_trimmed, skip_special_tokens=False)[0]\n",
|
641 |
+
" # print(f\"\\nGenerated Text (for parsing):\\n{predicts_text}\\n\") # Debugging\n",
|
642 |
+
"\n",
|
643 |
+
" # Extract semantic token IDs using regex\n",
|
644 |
+
" semantic_matches = re.findall(r\"<\\|bicodec_semantic_(\\d+)\\|>\", predicts_text)\n",
|
645 |
+
" if not semantic_matches:\n",
|
646 |
+
" print(\"Warning: No semantic tokens found in the generated output.\")\n",
|
647 |
+
" # Handle appropriately - perhaps return silence or raise error\n",
|
648 |
+
" return np.array([], dtype=np.float32)\n",
|
649 |
+
"\n",
|
650 |
+
" pred_semantic_ids = torch.tensor([int(token) for token in semantic_matches]).long().unsqueeze(0) # Add batch dim\n",
|
651 |
+
"\n",
|
652 |
+
" # Extract global token IDs using regex (assuming controllable mode also generates these)\n",
|
653 |
+
" global_matches = re.findall(r\"<\\|bicodec_global_(\\d+)\\|>\", predicts_text)\n",
|
654 |
+
" if not global_matches:\n",
|
655 |
+
" print(\"Warning: No global tokens found in the generated output (controllable mode). Might use defaults or fail.\")\n",
|
656 |
+
" pred_global_ids = torch.zeros((1, 1), dtype=torch.long)\n",
|
657 |
+
" else:\n",
|
658 |
+
" pred_global_ids = torch.tensor([int(token) for token in global_matches]).long().unsqueeze(0) # Add batch dim\n",
|
659 |
+
"\n",
|
660 |
+
" pred_global_ids = pred_global_ids.unsqueeze(0) # Shape becomes (1, 1, N_global)\n",
|
661 |
+
"\n",
|
662 |
+
" print(f\"Found {pred_semantic_ids.shape[1]} semantic tokens.\")\n",
|
663 |
+
" print(f\"Found {pred_global_ids.shape[2]} global tokens.\")\n",
|
664 |
+
"\n",
|
665 |
+
"\n",
|
666 |
+
" # 5. Detokenize using BiCodecTokenizer\n",
|
667 |
+
" print(\"Detokenizing audio tokens...\")\n",
|
668 |
+
" # Ensure audio_tokenizer and its internal model are on the correct device\n",
|
669 |
+
" audio_tokenizer.device = device\n",
|
670 |
+
" audio_tokenizer.model.to(device)\n",
|
671 |
+
" # Squeeze the extra dimension from global tokens as seen in SparkTTS example\n",
|
672 |
+
" wav_np = audio_tokenizer.detokenize(\n",
|
673 |
+
" pred_global_ids.to(device).squeeze(0), # Shape (1, N_global)\n",
|
674 |
+
" pred_semantic_ids.to(device) # Shape (1, N_semantic)\n",
|
675 |
+
" )\n",
|
676 |
+
" print(\"Detokenization complete.\")\n",
|
677 |
+
"\n",
|
678 |
+
" return wav_np\n",
|
679 |
+
"\n",
|
680 |
+
"if __name__ == \"__main__\":\n",
|
681 |
+
" print(f\"Generating speech for: '{input_text}'\")\n",
|
682 |
+
" text = f\"{chosen_voice}: \" + input_text if chosen_voice else input_text\n",
|
683 |
+
" generated_waveform = generate_speech_from_text(input_text)\n",
|
684 |
+
"\n",
|
685 |
+
" if generated_waveform.size > 0:\n",
|
686 |
+
" import soundfile as sf\n",
|
687 |
+
" output_filename = \"generated_speech_controllable.wav\"\n",
|
688 |
+
" sample_rate = audio_tokenizer.config.get(\"sample_rate\", 16000)\n",
|
689 |
+
" sf.write(output_filename, generated_waveform, sample_rate)\n",
|
690 |
+
" print(f\"Audio saved to {output_filename}\")\n",
|
691 |
+
"\n",
|
692 |
+
" # Optional: Play in notebook\n",
|
693 |
+
" from IPython.display import Audio, display\n",
|
694 |
+
" display(Audio(generated_waveform, rate=sample_rate))\n",
|
695 |
+
" else:\n",
|
696 |
+
" print(\"Audio generation failed (no tokens found?).\")"
|
697 |
+
]
|
698 |
+
},
|
699 |
+
{
|
700 |
+
"cell_type": "markdown",
|
701 |
+
"metadata": {
|
702 |
+
"id": "uMuVrWbjAzhc"
|
703 |
+
},
|
704 |
+
"source": [
|
705 |
+
"<a name=\"Save\"></a>\n",
|
706 |
+
"### Saving, loading finetuned models\n",
|
707 |
+
"To save the final model as LoRA adapters, either use Huggingface's `push_to_hub` for an online save or `save_pretrained` for a local save.\n",
|
708 |
+
"\n",
|
709 |
+
"**[NOTE]** This ONLY saves the LoRA adapters, and not the full model. To save to 16bit or GGUF, scroll down!"
|
710 |
+
]
|
711 |
+
},
|
712 |
+
{
|
713 |
+
"cell_type": "code",
|
714 |
+
"execution_count": null,
|
715 |
+
"metadata": {
|
716 |
+
"id": "upcOlWe7A1vc"
|
717 |
+
},
|
718 |
+
"outputs": [],
|
719 |
+
"source": [
|
720 |
+
"model.save_pretrained(\"lora_model\") # Local saving\n",
|
721 |
+
"tokenizer.save_pretrained(\"lora_model\")\n",
|
722 |
+
"# model.push_to_hub(\"your_name/lora_model\", token = \"...\") # Online saving\n",
|
723 |
+
"# tokenizer.push_to_hub(\"your_name/lora_model\", token = \"...\") # Online saving"
|
724 |
+
]
|
725 |
+
},
|
726 |
+
{
|
727 |
+
"cell_type": "markdown",
|
728 |
+
"metadata": {
|
729 |
+
"id": "f422JgM9sdVT"
|
730 |
+
},
|
731 |
+
"source": [
|
732 |
+
"\n",
|
733 |
+
"### Saving to float16\n",
|
734 |
+
"\n",
|
735 |
+
"We also support saving to `float16` directly. Select `merged_16bit` for float16 or `merged_4bit` for int4. We also allow `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens."
|
736 |
+
]
|
737 |
+
},
|
738 |
+
{
|
739 |
+
"cell_type": "code",
|
740 |
+
"execution_count": null,
|
741 |
+
"metadata": {
|
742 |
+
"colab": {
|
743 |
+
"base_uri": "https://localhost:8080/"
|
744 |
+
},
|
745 |
+
"id": "iHjt_SMYsd3P",
|
746 |
+
"outputId": "bd8cccb7-6b95-45bf-80da-de120988447e"
|
747 |
+
},
|
748 |
+
"outputs": [
|
749 |
+
{
|
750 |
+
"name": "stderr",
|
751 |
+
"output_type": "stream",
|
752 |
+
"text": [
|
753 |
+
"Unsloth: You have 1 CPUs. Using `safe_serialization` is 10x slower.\n",
|
754 |
+
"We shall switch to Pytorch saving, which might take 3 minutes and not 30 minutes.\n",
|
755 |
+
"To force `safe_serialization`, set it to `None` instead.\n",
|
756 |
+
"Unsloth: Kaggle/Colab has limited disk space. We need to delete the downloaded\n",
|
757 |
+
"model which will save 4-16GB of disk space, allowing you to save on Kaggle/Colab.\n",
|
758 |
+
"Unsloth: Will remove a cached repo with size 15.1G\n"
|
759 |
+
]
|
760 |
+
},
|
761 |
+
{
|
762 |
+
"name": "stdout",
|
763 |
+
"output_type": "stream",
|
764 |
+
"text": [
|
765 |
+
"Unsloth: Merging 4bit and LoRA weights to 16bit...\n",
|
766 |
+
"Unsloth: Will use up to 3.99 out of 12.67 RAM for saving.\n",
|
767 |
+
"Unsloth: Saving model... This might take 5 minutes ...\n"
|
768 |
+
]
|
769 |
+
},
|
770 |
+
{
|
771 |
+
"name": "stderr",
|
772 |
+
"output_type": "stream",
|
773 |
+
"text": [
|
774 |
+
"100%|██████████| 28/28 [00:01<00:00, 27.83it/s]\n"
|
775 |
+
]
|
776 |
+
},
|
777 |
+
{
|
778 |
+
"name": "stdout",
|
779 |
+
"output_type": "stream",
|
780 |
+
"text": [
|
781 |
+
"Unsloth: Saving tokenizer... Done.\n",
|
782 |
+
"Unsloth: Saving model/pytorch_model-00001-of-00002.bin...\n",
|
783 |
+
"Unsloth: Saving model/pytorch_model-00002-of-00002.bin...\n",
|
784 |
+
"Done.\n"
|
785 |
+
]
|
786 |
+
}
|
787 |
+
],
|
788 |
+
"source": [
|
789 |
+
"# Merge to 16bit\n",
|
790 |
+
"if False: model.save_pretrained_merged(\"model\", tokenizer, save_method = \"merged_16bit\",)\n",
|
791 |
+
"if False: model.push_to_hub_merged(\"hf/model\", tokenizer, save_method = \"merged_16bit\", token = \"\")\n",
|
792 |
+
"\n",
|
793 |
+
"# Merge to 4bit\n",
|
794 |
+
"if False: model.save_pretrained_merged(\"model\", tokenizer, save_method = \"merged_4bit\",)\n",
|
795 |
+
"if False: model.push_to_hub_merged(\"hf/model\", tokenizer, save_method = \"merged_4bit\", token = \"\")\n",
|
796 |
+
"\n",
|
797 |
+
"# Just LoRA adapters\n",
|
798 |
+
"if False:\n",
|
799 |
+
" model.save_pretrained(\"model\")\n",
|
800 |
+
" tokenizer.save_pretrained(\"model\")\n",
|
801 |
+
"if False:\n",
|
802 |
+
" model.push_to_hub(\"hf/model\", token = \"\")\n",
|
803 |
+
" tokenizer.push_to_hub(\"hf/model\", token = \"\")\n"
|
804 |
+
]
|
805 |
+
},
|
806 |
+
{
|
807 |
+
"cell_type": "markdown",
|
808 |
+
"metadata": {
|
809 |
+
"id": "egOSE7Cgynx7"
|
810 |
+
},
|
811 |
+
"source": [
|
812 |
+
"And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs or want to keep updated with the latest LLM stuff, or need help, join projects etc, feel free to join our Discord!\n",
|
813 |
+
"\n",
|
814 |
+
"Some other links:\n",
|
815 |
+
"1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)\n",
|
816 |
+
"2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)\n",
|
817 |
+
"3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)\n",
|
818 |
+
"6. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://docs.unsloth.ai/get-started/unsloth-notebooks)!\n",
|
819 |
+
"\n",
|
820 |
+
"<div class=\"align-center\">\n",
|
821 |
+
" <a href=\"https://unsloth.ai\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png\" width=\"115\"></a>\n",
|
822 |
+
" <a href=\"https://discord.gg/unsloth\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/Discord.png\" width=\"145\"></a>\n",
|
823 |
+
" <a href=\"https://docs.unsloth.ai/\"><img src=\"https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true\" width=\"125\"></a>\n",
|
824 |
+
"\n",
|
825 |
+
" Join Discord if you need help + ⭐️ <i>Star us on <a href=\"https://github.com/unslothai/unsloth\">Github</a> </i> ⭐️\n",
|
826 |
+
"</div>\n"
|
827 |
+
]
|
828 |
+
}
|
829 |
+
],
|
830 |
+
"metadata": {
|
831 |
+
"accelerator": "GPU",
|
832 |
+
"colab": {
|
833 |
+
"gpuType": "T4",
|
834 |
+
"provenance": []
|
835 |
+
},
|
836 |
+
"kaggle": {
|
837 |
+
"accelerator": "nvidiaTeslaT4",
|
838 |
+
"dataSources": [],
|
839 |
+
"dockerImageVersionId": 30919,
|
840 |
+
"isGpuEnabled": true,
|
841 |
+
"isInternetEnabled": true,
|
842 |
+
"language": "python",
|
843 |
+
"sourceType": "notebook"
|
844 |
+
},
|
845 |
+
"kernelspec": {
|
846 |
+
"display_name": "TTS_ft",
|
847 |
+
"language": "python",
|
848 |
+
"name": "tts_ft"
|
849 |
+
},
|
850 |
+
"language_info": {
|
851 |
+
"codemirror_mode": {
|
852 |
+
"name": "ipython",
|
853 |
+
"version": 3
|
854 |
+
},
|
855 |
+
"file_extension": ".py",
|
856 |
+
"mimetype": "text/x-python",
|
857 |
+
"name": "python",
|
858 |
+
"nbconvert_exporter": "python",
|
859 |
+
"pygments_lexer": "ipython3",
|
860 |
+
"version": "3.12.3"
|
861 |
+
},
|
862 |
+
"widgets": {
|
863 |
+
"application/vnd.jupyter.widget-state+json": {
|
864 |
+
"0474debc340943bd85f3daf92aebf7aa": {
|
865 |
+
"model_module": "@jupyter-widgets/controls",
|
866 |
+
"model_module_version": "1.5.0",
|
867 |
+
"model_name": "FloatProgressModel",
|
868 |
+
"state": {
|
869 |
+
"_dom_classes": [],
|
870 |
+
"_model_module": "@jupyter-widgets/controls",
|
871 |
+
"_model_module_version": "1.5.0",
|
872 |
+
"_model_name": "FloatProgressModel",
|
873 |
+
"_view_count": null,
|
874 |
+
"_view_module": "@jupyter-widgets/controls",
|
875 |
+
"_view_module_version": "1.5.0",
|
876 |
+
"_view_name": "ProgressView",
|
877 |
+
"bar_style": "",
|
878 |
+
"description": "",
|
879 |
+
"description_tooltip": null,
|
880 |
+
"layout": "IPY_MODEL_0de4d0f282404edfbc191dca73f15f35",
|
881 |
+
"max": 401,
|
882 |
+
"min": 0,
|
883 |
+
"orientation": "horizontal",
|
884 |
+
"style": "IPY_MODEL_e58b5ad2f781475d8af2ddb38009baa6",
|
885 |
+
"value": 354
|
886 |
+
}
|
887 |
+
},
|
888 |
+
"0de4d0f282404edfbc191dca73f15f35": {
|
889 |
+
"model_module": "@jupyter-widgets/base",
|
890 |
+
"model_module_version": "1.2.0",
|
891 |
+
"model_name": "LayoutModel",
|
892 |
+
"state": {
|
893 |
+
"_model_module": "@jupyter-widgets/base",
|
894 |
+
"_model_module_version": "1.2.0",
|
895 |
+
"_model_name": "LayoutModel",
|
896 |
+
"_view_count": null,
|
897 |
+
"_view_module": "@jupyter-widgets/base",
|
898 |
+
"_view_module_version": "1.2.0",
|
899 |
+
"_view_name": "LayoutView",
|
900 |
+
"align_content": null,
|
901 |
+
"align_items": null,
|
902 |
+
"align_self": null,
|
903 |
+
"border": null,
|
904 |
+
"bottom": null,
|
905 |
+
"display": null,
|
906 |
+
"flex": null,
|
907 |
+
"flex_flow": null,
|
908 |
+
"grid_area": null,
|
909 |
+
"grid_auto_columns": null,
|
910 |
+
"grid_auto_flow": null,
|
911 |
+
"grid_auto_rows": null,
|
912 |
+
"grid_column": null,
|
913 |
+
"grid_gap": null,
|
914 |
+
"grid_row": null,
|
915 |
+
"grid_template_areas": null,
|
916 |
+
"grid_template_columns": null,
|
917 |
+
"grid_template_rows": null,
|
918 |
+
"height": null,
|
919 |
+
"justify_content": null,
|
920 |
+
"justify_items": null,
|
921 |
+
"left": null,
|
922 |
+
"margin": null,
|
923 |
+
"max_height": null,
|
924 |
+
"max_width": null,
|
925 |
+
"min_height": null,
|
926 |
+
"min_width": null,
|
927 |
+
"object_fit": null,
|
928 |
+
"object_position": null,
|
929 |
+
"order": null,
|
930 |
+
"overflow": null,
|
931 |
+
"overflow_x": null,
|
932 |
+
"overflow_y": null,
|
933 |
+
"padding": null,
|
934 |
+
"right": null,
|
935 |
+
"top": null,
|
936 |
+
"visibility": null,
|
937 |
+
"width": null
|
938 |
+
}
|
939 |
+
},
|
940 |
+
"2315228ff2b141afabe1263471f5364b": {
|
941 |
+
"model_module": "@jupyter-widgets/controls",
|
942 |
+
"model_module_version": "1.5.0",
|
943 |
+
"model_name": "HTMLModel",
|
944 |
+
"state": {
|
945 |
+
"_dom_classes": [],
|
946 |
+
"_model_module": "@jupyter-widgets/controls",
|
947 |
+
"_model_module_version": "1.5.0",
|
948 |
+
"_model_name": "HTMLModel",
|
949 |
+
"_view_count": null,
|
950 |
+
"_view_module": "@jupyter-widgets/controls",
|
951 |
+
"_view_module_version": "1.5.0",
|
952 |
+
"_view_name": "HTMLView",
|
953 |
+
"description": "",
|
954 |
+
"description_tooltip": null,
|
955 |
+
"layout": "IPY_MODEL_426eb100a94642f79e6b99777406a265",
|
956 |
+
"placeholder": "",
|
957 |
+
"style": "IPY_MODEL_a36b5cf197dd4bd9a7f70aa6671b804c",
|
958 |
+
"value": "Map: 88%"
|
959 |
+
}
|
960 |
+
},
|
961 |
+
"33fbacbb2aa146cd90586357eec1dc3e": {
|
962 |
+
"model_module": "@jupyter-widgets/base",
|
963 |
+
"model_module_version": "1.2.0",
|
964 |
+
"model_name": "LayoutModel",
|
965 |
+
"state": {
|
966 |
+
"_model_module": "@jupyter-widgets/base",
|
967 |
+
"_model_module_version": "1.2.0",
|
968 |
+
"_model_name": "LayoutModel",
|
969 |
+
"_view_count": null,
|
970 |
+
"_view_module": "@jupyter-widgets/base",
|
971 |
+
"_view_module_version": "1.2.0",
|
972 |
+
"_view_name": "LayoutView",
|
973 |
+
"align_content": null,
|
974 |
+
"align_items": null,
|
975 |
+
"align_self": null,
|
976 |
+
"border": null,
|
977 |
+
"bottom": null,
|
978 |
+
"display": null,
|
979 |
+
"flex": null,
|
980 |
+
"flex_flow": null,
|
981 |
+
"grid_area": null,
|
982 |
+
"grid_auto_columns": null,
|
983 |
+
"grid_auto_flow": null,
|
984 |
+
"grid_auto_rows": null,
|
985 |
+
"grid_column": null,
|
986 |
+
"grid_gap": null,
|
987 |
+
"grid_row": null,
|
988 |
+
"grid_template_areas": null,
|
989 |
+
"grid_template_columns": null,
|
990 |
+
"grid_template_rows": null,
|
991 |
+
"height": null,
|
992 |
+
"justify_content": null,
|
993 |
+
"justify_items": null,
|
994 |
+
"left": null,
|
995 |
+
"margin": null,
|
996 |
+
"max_height": null,
|
997 |
+
"max_width": null,
|
998 |
+
"min_height": null,
|
999 |
+
"min_width": null,
|
1000 |
+
"object_fit": null,
|
1001 |
+
"object_position": null,
|
1002 |
+
"order": null,
|
1003 |
+
"overflow": null,
|
1004 |
+
"overflow_x": null,
|
1005 |
+
"overflow_y": null,
|
1006 |
+
"padding": null,
|
1007 |
+
"right": null,
|
1008 |
+
"top": null,
|
1009 |
+
"visibility": null,
|
1010 |
+
"width": null
|
1011 |
+
}
|
1012 |
+
},
|
1013 |
+
"426eb100a94642f79e6b99777406a265": {
|
1014 |
+
"model_module": "@jupyter-widgets/base",
|
1015 |
+
"model_module_version": "1.2.0",
|
1016 |
+
"model_name": "LayoutModel",
|
1017 |
+
"state": {
|
1018 |
+
"_model_module": "@jupyter-widgets/base",
|
1019 |
+
"_model_module_version": "1.2.0",
|
1020 |
+
"_model_name": "LayoutModel",
|
1021 |
+
"_view_count": null,
|
1022 |
+
"_view_module": "@jupyter-widgets/base",
|
1023 |
+
"_view_module_version": "1.2.0",
|
1024 |
+
"_view_name": "LayoutView",
|
1025 |
+
"align_content": null,
|
1026 |
+
"align_items": null,
|
1027 |
+
"align_self": null,
|
1028 |
+
"border": null,
|
1029 |
+
"bottom": null,
|
1030 |
+
"display": null,
|
1031 |
+
"flex": null,
|
1032 |
+
"flex_flow": null,
|
1033 |
+
"grid_area": null,
|
1034 |
+
"grid_auto_columns": null,
|
1035 |
+
"grid_auto_flow": null,
|
1036 |
+
"grid_auto_rows": null,
|
1037 |
+
"grid_column": null,
|
1038 |
+
"grid_gap": null,
|
1039 |
+
"grid_row": null,
|
1040 |
+
"grid_template_areas": null,
|
1041 |
+
"grid_template_columns": null,
|
1042 |
+
"grid_template_rows": null,
|
1043 |
+
"height": null,
|
1044 |
+
"justify_content": null,
|
1045 |
+
"justify_items": null,
|
1046 |
+
"left": null,
|
1047 |
+
"margin": null,
|
1048 |
+
"max_height": null,
|
1049 |
+
"max_width": null,
|
1050 |
+
"min_height": null,
|
1051 |
+
"min_width": null,
|
1052 |
+
"object_fit": null,
|
1053 |
+
"object_position": null,
|
1054 |
+
"order": null,
|
1055 |
+
"overflow": null,
|
1056 |
+
"overflow_x": null,
|
1057 |
+
"overflow_y": null,
|
1058 |
+
"padding": null,
|
1059 |
+
"right": null,
|
1060 |
+
"top": null,
|
1061 |
+
"visibility": null,
|
1062 |
+
"width": null
|
1063 |
+
}
|
1064 |
+
},
|
1065 |
+
"930b4d1d5f4b494b830df4d4c398e67c": {
|
1066 |
+
"model_module": "@jupyter-widgets/controls",
|
1067 |
+
"model_module_version": "1.5.0",
|
1068 |
+
"model_name": "DescriptionStyleModel",
|
1069 |
+
"state": {
|
1070 |
+
"_model_module": "@jupyter-widgets/controls",
|
1071 |
+
"_model_module_version": "1.5.0",
|
1072 |
+
"_model_name": "DescriptionStyleModel",
|
1073 |
+
"_view_count": null,
|
1074 |
+
"_view_module": "@jupyter-widgets/base",
|
1075 |
+
"_view_module_version": "1.2.0",
|
1076 |
+
"_view_name": "StyleView",
|
1077 |
+
"description_width": ""
|
1078 |
+
}
|
1079 |
+
},
|
1080 |
+
"a36b5cf197dd4bd9a7f70aa6671b804c": {
|
1081 |
+
"model_module": "@jupyter-widgets/controls",
|
1082 |
+
"model_module_version": "1.5.0",
|
1083 |
+
"model_name": "DescriptionStyleModel",
|
1084 |
+
"state": {
|
1085 |
+
"_model_module": "@jupyter-widgets/controls",
|
1086 |
+
"_model_module_version": "1.5.0",
|
1087 |
+
"_model_name": "DescriptionStyleModel",
|
1088 |
+
"_view_count": null,
|
1089 |
+
"_view_module": "@jupyter-widgets/base",
|
1090 |
+
"_view_module_version": "1.2.0",
|
1091 |
+
"_view_name": "StyleView",
|
1092 |
+
"description_width": ""
|
1093 |
+
}
|
1094 |
+
},
|
1095 |
+
"a3b0c0581f1f4c428baaadd8e9a39b6f": {
|
1096 |
+
"model_module": "@jupyter-widgets/controls",
|
1097 |
+
"model_module_version": "1.5.0",
|
1098 |
+
"model_name": "HBoxModel",
|
1099 |
+
"state": {
|
1100 |
+
"_dom_classes": [],
|
1101 |
+
"_model_module": "@jupyter-widgets/controls",
|
1102 |
+
"_model_module_version": "1.5.0",
|
1103 |
+
"_model_name": "HBoxModel",
|
1104 |
+
"_view_count": null,
|
1105 |
+
"_view_module": "@jupyter-widgets/controls",
|
1106 |
+
"_view_module_version": "1.5.0",
|
1107 |
+
"_view_name": "HBoxView",
|
1108 |
+
"box_style": "",
|
1109 |
+
"children": [
|
1110 |
+
"IPY_MODEL_2315228ff2b141afabe1263471f5364b",
|
1111 |
+
"IPY_MODEL_0474debc340943bd85f3daf92aebf7aa",
|
1112 |
+
"IPY_MODEL_cff1b0fa2ea24f45aab26685353eefdd"
|
1113 |
+
],
|
1114 |
+
"layout": "IPY_MODEL_b7e20be79df246f19b35114a690e44f0"
|
1115 |
+
}
|
1116 |
+
},
|
1117 |
+
"b7e20be79df246f19b35114a690e44f0": {
|
1118 |
+
"model_module": "@jupyter-widgets/base",
|
1119 |
+
"model_module_version": "1.2.0",
|
1120 |
+
"model_name": "LayoutModel",
|
1121 |
+
"state": {
|
1122 |
+
"_model_module": "@jupyter-widgets/base",
|
1123 |
+
"_model_module_version": "1.2.0",
|
1124 |
+
"_model_name": "LayoutModel",
|
1125 |
+
"_view_count": null,
|
1126 |
+
"_view_module": "@jupyter-widgets/base",
|
1127 |
+
"_view_module_version": "1.2.0",
|
1128 |
+
"_view_name": "LayoutView",
|
1129 |
+
"align_content": null,
|
1130 |
+
"align_items": null,
|
1131 |
+
"align_self": null,
|
1132 |
+
"border": null,
|
1133 |
+
"bottom": null,
|
1134 |
+
"display": null,
|
1135 |
+
"flex": null,
|
1136 |
+
"flex_flow": null,
|
1137 |
+
"grid_area": null,
|
1138 |
+
"grid_auto_columns": null,
|
1139 |
+
"grid_auto_flow": null,
|
1140 |
+
"grid_auto_rows": null,
|
1141 |
+
"grid_column": null,
|
1142 |
+
"grid_gap": null,
|
1143 |
+
"grid_row": null,
|
1144 |
+
"grid_template_areas": null,
|
1145 |
+
"grid_template_columns": null,
|
1146 |
+
"grid_template_rows": null,
|
1147 |
+
"height": null,
|
1148 |
+
"justify_content": null,
|
1149 |
+
"justify_items": null,
|
1150 |
+
"left": null,
|
1151 |
+
"margin": null,
|
1152 |
+
"max_height": null,
|
1153 |
+
"max_width": null,
|
1154 |
+
"min_height": null,
|
1155 |
+
"min_width": null,
|
1156 |
+
"object_fit": null,
|
1157 |
+
"object_position": null,
|
1158 |
+
"order": null,
|
1159 |
+
"overflow": null,
|
1160 |
+
"overflow_x": null,
|
1161 |
+
"overflow_y": null,
|
1162 |
+
"padding": null,
|
1163 |
+
"right": null,
|
1164 |
+
"top": null,
|
1165 |
+
"visibility": null,
|
1166 |
+
"width": null
|
1167 |
+
}
|
1168 |
+
},
|
1169 |
+
"cff1b0fa2ea24f45aab26685353eefdd": {
|
1170 |
+
"model_module": "@jupyter-widgets/controls",
|
1171 |
+
"model_module_version": "1.5.0",
|
1172 |
+
"model_name": "HTMLModel",
|
1173 |
+
"state": {
|
1174 |
+
"_dom_classes": [],
|
1175 |
+
"_model_module": "@jupyter-widgets/controls",
|
1176 |
+
"_model_module_version": "1.5.0",
|
1177 |
+
"_model_name": "HTMLModel",
|
1178 |
+
"_view_count": null,
|
1179 |
+
"_view_module": "@jupyter-widgets/controls",
|
1180 |
+
"_view_module_version": "1.5.0",
|
1181 |
+
"_view_name": "HTMLView",
|
1182 |
+
"description": "",
|
1183 |
+
"description_tooltip": null,
|
1184 |
+
"layout": "IPY_MODEL_33fbacbb2aa146cd90586357eec1dc3e",
|
1185 |
+
"placeholder": "",
|
1186 |
+
"style": "IPY_MODEL_930b4d1d5f4b494b830df4d4c398e67c",
|
1187 |
+
"value": " 354/401 [03:01<00:22, 2.11 examples/s]"
|
1188 |
+
}
|
1189 |
+
},
|
1190 |
+
"e58b5ad2f781475d8af2ddb38009baa6": {
|
1191 |
+
"model_module": "@jupyter-widgets/controls",
|
1192 |
+
"model_module_version": "1.5.0",
|
1193 |
+
"model_name": "ProgressStyleModel",
|
1194 |
+
"state": {
|
1195 |
+
"_model_module": "@jupyter-widgets/controls",
|
1196 |
+
"_model_module_version": "1.5.0",
|
1197 |
+
"_model_name": "ProgressStyleModel",
|
1198 |
+
"_view_count": null,
|
1199 |
+
"_view_module": "@jupyter-widgets/base",
|
1200 |
+
"_view_module_version": "1.2.0",
|
1201 |
+
"_view_name": "StyleView",
|
1202 |
+
"bar_color": null,
|
1203 |
+
"description_width": ""
|
1204 |
+
}
|
1205 |
+
}
|
1206 |
+
}
|
1207 |
+
}
|
1208 |
+
},
|
1209 |
+
"nbformat": 4,
|
1210 |
+
"nbformat_minor": 4
|
1211 |
+
}
|
.ipynb_checkpoints/config-checkpoint.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
highpass_cutoff_freq: 40
|
2 |
+
sample_rate: 16000
|
3 |
+
segment_duration: 2.4 # (s)
|
4 |
+
max_val_duration: 12 # (s)
|
5 |
+
latent_hop_length: 320
|
6 |
+
ref_segment_duration: 6
|
7 |
+
volume_normalize: true
|
BiCodec/config.yaml
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
audio_tokenizer:
|
2 |
+
mel_params:
|
3 |
+
sample_rate: 16000
|
4 |
+
n_fft: 1024
|
5 |
+
win_length: 640
|
6 |
+
hop_length: 320
|
7 |
+
mel_fmin: 10
|
8 |
+
mel_fmax: null
|
9 |
+
num_mels: 128
|
10 |
+
|
11 |
+
encoder:
|
12 |
+
input_channels: 1024
|
13 |
+
vocos_dim: 384
|
14 |
+
vocos_intermediate_dim: 2048
|
15 |
+
vocos_num_layers: 12
|
16 |
+
out_channels: 1024
|
17 |
+
sample_ratios: [1,1]
|
18 |
+
|
19 |
+
decoder:
|
20 |
+
input_channel: 1024
|
21 |
+
channels: 1536
|
22 |
+
rates: [8, 5, 4, 2]
|
23 |
+
kernel_sizes: [16,11,8,4]
|
24 |
+
|
25 |
+
quantizer:
|
26 |
+
input_dim: 1024
|
27 |
+
codebook_size: 8192
|
28 |
+
codebook_dim: 8
|
29 |
+
commitment: 0.25
|
30 |
+
codebook_loss_weight: 2.0
|
31 |
+
use_l2_normlize: True
|
32 |
+
threshold_ema_dead_code: 0.2
|
33 |
+
|
34 |
+
speaker_encoder:
|
35 |
+
input_dim: 128
|
36 |
+
out_dim: 1024
|
37 |
+
latent_dim: 128
|
38 |
+
token_num: 32
|
39 |
+
fsq_levels: [4, 4, 4, 4, 4, 4]
|
40 |
+
fsq_num_quantizers: 1
|
41 |
+
|
42 |
+
prenet:
|
43 |
+
input_channels: 1024
|
44 |
+
vocos_dim: 384
|
45 |
+
vocos_intermediate_dim: 2048
|
46 |
+
vocos_num_layers: 12
|
47 |
+
out_channels: 1024
|
48 |
+
condition_dim: 1024
|
49 |
+
sample_ratios: [1,1]
|
50 |
+
use_tanh_at_final: False
|
51 |
+
|
52 |
+
postnet:
|
53 |
+
input_channels: 1024
|
54 |
+
vocos_dim: 384
|
55 |
+
vocos_intermediate_dim: 2048
|
56 |
+
vocos_num_layers: 6
|
57 |
+
out_channels: 1024
|
58 |
+
use_tanh_at_final: False
|
59 |
+
|
60 |
+
|
BiCodec/model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e9940cd48d4446e4340ced82d234bf5618350dd9f5db900ebe47a4fdb03867ec
|
3 |
+
size 625518756
|
LLM/.gitattributes
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
LLM/added_tokens.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
LLM/chat_template.jinja
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{%- if tools %}
|
2 |
+
{{- '<|im_start|>system\n' }}
|
3 |
+
{%- if messages[0]['role'] == 'system' %}
|
4 |
+
{{- messages[0]['content'] }}
|
5 |
+
{%- else %}
|
6 |
+
{{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}
|
7 |
+
{%- endif %}
|
8 |
+
{{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
|
9 |
+
{%- for tool in tools %}
|
10 |
+
{{- "\n" }}
|
11 |
+
{{- tool | tojson }}
|
12 |
+
{%- endfor %}
|
13 |
+
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
|
14 |
+
{%- else %}
|
15 |
+
{%- if messages[0]['role'] == 'system' %}
|
16 |
+
{{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}
|
17 |
+
{%- else %}
|
18 |
+
{{- '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n' }}
|
19 |
+
{%- endif %}
|
20 |
+
{%- endif %}
|
21 |
+
{%- for message in messages %}
|
22 |
+
{%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
|
23 |
+
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
|
24 |
+
{%- elif message.role == "assistant" %}
|
25 |
+
{{- '<|im_start|>' + message.role }}
|
26 |
+
{%- if message.content %}
|
27 |
+
{{- '\n' + message.content }}
|
28 |
+
{%- endif %}
|
29 |
+
{%- for tool_call in message.tool_calls %}
|
30 |
+
{%- if tool_call.function is defined %}
|
31 |
+
{%- set tool_call = tool_call.function %}
|
32 |
+
{%- endif %}
|
33 |
+
{{- '\n<tool_call>\n{"name": "' }}
|
34 |
+
{{- tool_call.name }}
|
35 |
+
{{- '", "arguments": ' }}
|
36 |
+
{{- tool_call.arguments | tojson }}
|
37 |
+
{{- '}\n</tool_call>' }}
|
38 |
+
{%- endfor %}
|
39 |
+
{{- '<|im_end|>\n' }}
|
40 |
+
{%- elif message.role == "tool" %}
|
41 |
+
{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
|
42 |
+
{{- '<|im_start|>user' }}
|
43 |
+
{%- endif %}
|
44 |
+
{{- '\n<tool_response>\n' }}
|
45 |
+
{{- message.content }}
|
46 |
+
{{- '\n</tool_response>' }}
|
47 |
+
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
48 |
+
{{- '<|im_end|>\n' }}
|
49 |
+
{%- endif %}
|
50 |
+
{%- endif %}
|
51 |
+
{%- endfor %}
|
52 |
+
{%- if add_generation_prompt %}
|
53 |
+
{{- '<|im_start|>assistant\n' }}
|
54 |
+
{%- endif %}
|
LLM/config.json
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"Qwen2ForCausalLM"
|
4 |
+
],
|
5 |
+
"attention_dropout": 0.0,
|
6 |
+
"bos_token_id": 151643,
|
7 |
+
"eos_token_id": 151645,
|
8 |
+
"hidden_act": "silu",
|
9 |
+
"hidden_size": 896,
|
10 |
+
"initializer_range": 0.02,
|
11 |
+
"intermediate_size": 4864,
|
12 |
+
"layer_types": [
|
13 |
+
"full_attention",
|
14 |
+
"full_attention",
|
15 |
+
"full_attention",
|
16 |
+
"full_attention",
|
17 |
+
"full_attention",
|
18 |
+
"full_attention",
|
19 |
+
"full_attention",
|
20 |
+
"full_attention",
|
21 |
+
"full_attention",
|
22 |
+
"full_attention",
|
23 |
+
"full_attention",
|
24 |
+
"full_attention",
|
25 |
+
"full_attention",
|
26 |
+
"full_attention",
|
27 |
+
"full_attention",
|
28 |
+
"full_attention",
|
29 |
+
"full_attention",
|
30 |
+
"full_attention",
|
31 |
+
"full_attention",
|
32 |
+
"full_attention",
|
33 |
+
"full_attention",
|
34 |
+
"full_attention",
|
35 |
+
"full_attention",
|
36 |
+
"full_attention"
|
37 |
+
],
|
38 |
+
"max_position_embeddings": 32768,
|
39 |
+
"max_window_layers": 21,
|
40 |
+
"model_type": "qwen2",
|
41 |
+
"num_attention_heads": 14,
|
42 |
+
"num_hidden_layers": 24,
|
43 |
+
"num_key_value_heads": 2,
|
44 |
+
"pad_token_id": 151643,
|
45 |
+
"rms_norm_eps": 1e-06,
|
46 |
+
"rope_scaling": null,
|
47 |
+
"rope_theta": 1000000.0,
|
48 |
+
"sliding_window": null,
|
49 |
+
"tie_word_embeddings": true,
|
50 |
+
"torch_dtype": "float32",
|
51 |
+
"transformers_version": "4.54.1",
|
52 |
+
"unsloth_version": "2025.8.1",
|
53 |
+
"use_cache": true,
|
54 |
+
"use_sliding_window": false,
|
55 |
+
"vocab_size": 166000
|
56 |
+
}
|
LLM/generation_config.json
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"bos_token_id": 151643,
|
4 |
+
"eos_token_id": 151645,
|
5 |
+
"max_length": 32768,
|
6 |
+
"pad_token_id": 151643,
|
7 |
+
"transformers_version": "4.54.1"
|
8 |
+
}
|
LLM/merges.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
LLM/model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:de8e649c4c889e92eca6d18afbb7ea7be71ac874797c29e954f1ff89bfd4e237
|
3 |
+
size 2026568872
|
LLM/special_tokens_map.json
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"additional_special_tokens": [
|
3 |
+
"<|im_start|>",
|
4 |
+
"<|im_end|>",
|
5 |
+
"<|object_ref_start|>",
|
6 |
+
"<|object_ref_end|>",
|
7 |
+
"<|box_start|>",
|
8 |
+
"<|box_end|>",
|
9 |
+
"<|quad_start|>",
|
10 |
+
"<|quad_end|>",
|
11 |
+
"<|vision_start|>",
|
12 |
+
"<|vision_end|>",
|
13 |
+
"<|vision_pad|>",
|
14 |
+
"<|image_pad|>",
|
15 |
+
"<|video_pad|>"
|
16 |
+
],
|
17 |
+
"eos_token": {
|
18 |
+
"content": "<|im_end|>",
|
19 |
+
"lstrip": false,
|
20 |
+
"normalized": false,
|
21 |
+
"rstrip": false,
|
22 |
+
"single_word": false
|
23 |
+
},
|
24 |
+
"pad_token": {
|
25 |
+
"content": "<|endoftext|>",
|
26 |
+
"lstrip": false,
|
27 |
+
"normalized": false,
|
28 |
+
"rstrip": false,
|
29 |
+
"single_word": false
|
30 |
+
}
|
31 |
+
}
|
LLM/tokenizer.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9c8b057d6ca205a429cc3428b9fc815f0d6ee1d53106dd5e5b129ef9db2ff057
|
3 |
+
size 14129172
|
LLM/tokenizer_config.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
LLM/vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
README.md
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: cc-by-nc-sa-4.0
|
3 |
+
language:
|
4 |
+
- en
|
5 |
+
- zh
|
6 |
+
tags:
|
7 |
+
- text-to-speech
|
8 |
+
library_tag: spark-tts
|
9 |
+
base_model:
|
10 |
+
- SparkAudio/Spark-TTS-0.5B
|
11 |
+
---
|
12 |
+
<div>
|
13 |
+
<p style="margin-bottom: 0; margin-top: 0;">
|
14 |
+
<strong>See <a href="https://huggingface.co/collections/unsloth/text-to-speech-tts-models-68007ab12522e96be1e02155">our collection</a> for all our TTS model uploads.</strong>
|
15 |
+
</p>
|
16 |
+
<p style="margin-bottom: 0;">
|
17 |
+
<em>Learn to fine-tune TTS models - <a href="https://docs.unsloth.ai/basics/text-to-speech-tts-fine-tuning">Read our Guide</a>.</em>
|
18 |
+
</p>
|
19 |
+
<p style="margin-top: 0;margin-bottom: 0;">
|
20 |
+
<em><a href="https://docs.unsloth.ai/basics/unsloth-dynamic-v2.0-gguf">Unsloth Dynamic 2.0</a> achieves superior accuracy & outperforms other leading quants.</em>
|
21 |
+
</p>
|
22 |
+
<div style="display: flex; gap: 5px; align-items: center; ">
|
23 |
+
<a href="https://github.com/unslothai/unsloth/">
|
24 |
+
<img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="133">
|
25 |
+
</a>
|
26 |
+
<a href="https://discord.gg/unsloth">
|
27 |
+
<img src="https://github.com/unslothai/unsloth/raw/main/images/Discord%20button.png" width="173">
|
28 |
+
</a>
|
29 |
+
<a href="https://docs.unsloth.ai/basics/text-to-speech-tts-fine-tuning">
|
30 |
+
<img src="https://raw.githubusercontent.com/unslothai/unsloth/refs/heads/main/images/documentation%20green%20button.png" width="143">
|
31 |
+
</a>
|
32 |
+
</div>
|
33 |
+
<h1 style="margin-top: 0rem;">✨ Run & Fine-tune TTS models with Unsloth!</h1>
|
34 |
+
</div>
|
35 |
+
|
36 |
+
- Fine-tune TTS models for free using our Google [Colab notebooks here](https://docs.unsloth.ai/get-started/unsloth-notebooks#text-to-speech-tts-notebooks)!
|
37 |
+
- Read our Blog about TTS support: [unsloth.ai/blog/tts](https://docs.unsloth.ai/basics/text-to-speech-tts-fine-tuning)
|
38 |
+
|
39 |
+
| Unsloth supports | Free Notebooks | Performance | Memory use |
|
40 |
+
|-----------------|--------------------------------------------------------------------------------------------------------------------------|-------------|----------|
|
41 |
+
| **Spark-TTS** | [▶️ Start on Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Spark_TTS_(0_5B).ipynb) | 1.5x faster | 58% less |
|
42 |
+
| **Whisper Large V3** | [▶️ Start on Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Whisper.ipynb) | 1.5x faster | 50% less |
|
43 |
+
| **Qwen3 (14B)** | [▶️ Start on Colab](https://docs.unsloth.ai/get-started/unsloth-notebooks) | 2x faster | 70% less |
|
44 |
+
| **Llama 3.2 Vision (11B)** | [▶️ Start on Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb) | 1.8x faster | 50% less |
|
45 |
+
|
46 |
+
<div align="center">
|
47 |
+
<h1>
|
48 |
+
Spark-TTS
|
49 |
+
</h1>
|
50 |
+
<p>
|
51 |
+
Official model for <br>
|
52 |
+
<b><em>Spark-TTS: An Efficient LLM-Based Text-to-Speech Model with Single-Stream Decoupled Speech Tokens</em></b>
|
53 |
+
</p>
|
54 |
+
<p>
|
55 |
+
<img src="src/logo/SparkTTS.jpg" alt="Spark-TTS Logo" style="width: 200px; height: 200px;">
|
56 |
+
</p>
|
57 |
+
</div>
|
58 |
+
|
59 |
+
|
60 |
+
## Spark-TTS 🔥
|
61 |
+
|
62 |
+
### 👉🏻 [Spark-TTS Demos](https://sparkaudio.github.io/spark-tts/) 👈🏻
|
63 |
+
|
64 |
+
### 👉🏻 [Github Repo](https://github.com/SparkAudio/Spark-TTS) 👈🏻
|
65 |
+
|
66 |
+
### 👉🏻 [Paper](https://arxiv.org/pdf/2503.01710) 👈🏻
|
67 |
+
|
68 |
+
### Overview
|
69 |
+
|
70 |
+
Spark-TTS is an advanced text-to-speech system that uses the power of large language models (LLM) for highly accurate and natural-sounding voice synthesis. It is designed to be efficient, flexible, and powerful for both research and production use.
|
71 |
+
|
72 |
+
### Key Features
|
73 |
+
|
74 |
+
- **Simplicity and Efficiency**: Built entirely on Qwen2.5, Spark-TTS eliminates the need for additional generation models like flow matching. Instead of relying on separate models to generate acoustic features, it directly reconstructs audio from the code predicted by the LLM. This approach streamlines the process, improving efficiency and reducing complexity.
|
75 |
+
- **High-Quality Voice Cloning**: Supports zero-shot voice cloning, which means it can replicate a speaker's voice even without specific training data for that voice. This is ideal for cross-lingual and code-switching scenarios, allowing for seamless transitions between languages and voices without requiring separate training for each one.
|
76 |
+
- **Bilingual Support**: Supports both Chinese and English, and is capable of zero-shot voice cloning for cross-lingual and code-switching scenarios, enabling the model to synthesize speech in multiple languages with high naturalness and accuracy.
|
77 |
+
- **Controllable Speech Generation**: Supports creating virtual speakers by adjusting parameters such as gender, pitch, and speaking rate.
|
78 |
+
|
79 |
+
---
|
80 |
+
|
81 |
+
<table align="center">
|
82 |
+
<tr>
|
83 |
+
<td align="center"><b>Inference Overview of Voice Cloning</b><br><img src="src/figures/infer_voice_cloning.png" width="80%" /></td>
|
84 |
+
</tr>
|
85 |
+
<tr>
|
86 |
+
<td align="center"><b>Inference Overview of Controlled Generation</b><br><img src="src/figures/infer_control.png" width="80%" /></td>
|
87 |
+
</tr>
|
88 |
+
</table>
|
89 |
+
|
90 |
+
|
91 |
+
## Install
|
92 |
+
**Clone and Install**
|
93 |
+
|
94 |
+
- Clone the repo
|
95 |
+
``` sh
|
96 |
+
git clone https://github.com/SparkAudio/Spark-TTS.git
|
97 |
+
cd Spark-TTS
|
98 |
+
```
|
99 |
+
|
100 |
+
- Install Conda: please see https://docs.conda.io/en/latest/miniconda.html
|
101 |
+
- Create Conda env:
|
102 |
+
|
103 |
+
``` sh
|
104 |
+
conda create -n sparktts -y python=3.12
|
105 |
+
conda activate sparktts
|
106 |
+
pip install -r requirements.txt
|
107 |
+
# If you are in mainland China, you can set the mirror as follows:
|
108 |
+
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
|
109 |
+
```
|
110 |
+
|
111 |
+
**Model Download**
|
112 |
+
|
113 |
+
Download via python:
|
114 |
+
```python
|
115 |
+
from huggingface_hub import snapshot_download
|
116 |
+
|
117 |
+
snapshot_download("SparkAudio/Spark-TTS-0.5B", local_dir="pretrained_models/Spark-TTS-0.5B")
|
118 |
+
```
|
119 |
+
|
120 |
+
Download via git clone:
|
121 |
+
```sh
|
122 |
+
mkdir -p pretrained_models
|
123 |
+
|
124 |
+
# Make sure you have git-lfs installed (https://git-lfs.com)
|
125 |
+
git lfs install
|
126 |
+
|
127 |
+
git clone https://huggingface.co/SparkAudio/Spark-TTS-0.5B pretrained_models/Spark-TTS-0.5B
|
128 |
+
```
|
129 |
+
|
130 |
+
**Basic Usage**
|
131 |
+
|
132 |
+
You can simply run the demo with the following commands:
|
133 |
+
``` sh
|
134 |
+
cd example
|
135 |
+
bash infer.sh
|
136 |
+
```
|
137 |
+
|
138 |
+
Alternatively, you can directly execute the following command in the command line to perform inference:
|
139 |
+
|
140 |
+
``` sh
|
141 |
+
python -m cli.inference \
|
142 |
+
--text "text to synthesis." \
|
143 |
+
--device 0 \
|
144 |
+
--save_dir "path/to/save/audio" \
|
145 |
+
--model_dir pretrained_models/Spark-TTS-0.5B \
|
146 |
+
--prompt_text "transcript of the prompt audio" \
|
147 |
+
--prompt_speech_path "path/to/prompt_audio"
|
148 |
+
```
|
149 |
+
|
150 |
+
**UI Usage**
|
151 |
+
|
152 |
+
You can start the UI interface by running `python webui.py`, which allows you to perform Voice Cloning and Voice Creation. Voice Cloning supports uploading reference audio or directly recording the audio.
|
153 |
+
|
154 |
+
|
155 |
+
| **Voice Cloning** | **Voice Creation** |
|
156 |
+
|:-------------------:|:-------------------:|
|
157 |
+
|  |  |
|
158 |
+
|
159 |
+
|
160 |
+
## To-Do List
|
161 |
+
|
162 |
+
- [x] Release the Spark-TTS paper.
|
163 |
+
- [ ] Release the training code.
|
164 |
+
- [ ] Release the training dataset, VoxBox.
|
165 |
+
|
166 |
+
## Citation
|
167 |
+
|
168 |
+
```
|
169 |
+
@misc{wang2025sparktts,
|
170 |
+
title={Spark-TTS: An Efficient LLM-Based Text-to-Speech Model with Single-Stream Decoupled Speech Tokens},
|
171 |
+
author={Xinsheng Wang and Mingqi Jiang and Ziyang Ma and Ziyu Zhang and Songxiang Liu and Linqin Li and Zheng Liang and Qixi Zheng and Rui Wang and Xiaoqin Feng and Weizhen Bian and Zhen Ye and Sitong Cheng and Ruibin Yuan and Zhixian Zhao and Xinfa Zhu and Jiahao Pan and Liumeng Xue and Pengcheng Zhu and Yunlin Chen and Zhifei Li and Xie Chen and Lei Xie and Yike Guo and Wei Xue},
|
172 |
+
year={2025},
|
173 |
+
eprint={2503.01710},
|
174 |
+
archivePrefix={arXiv},
|
175 |
+
primaryClass={cs.SD},
|
176 |
+
url={https://arxiv.org/abs/2503.01710},
|
177 |
+
}
|
178 |
+
```
|
179 |
+
|
180 |
+
|
181 |
+
## ⚠ License Update
|
182 |
+
|
183 |
+
The model's license has been updated from Apache 2.0 to CC BY-NC-SA due to the licensing terms of some training data.
|
184 |
+
|
185 |
+
Key Changes:
|
186 |
+
|
187 |
+
- The model can only be used for non-commercial purposes.
|
188 |
+
|
189 |
+
- Any modifications or derivatives must also be released under CC BY-NC-SA 4.0.
|
190 |
+
|
191 |
+
- Proper attribution is required when using or modifying the model.
|
192 |
+
|
193 |
+
Please ensure compliance with the new license terms.
|
194 |
+
|
195 |
+
|
196 |
+
## ⚠️ Usage Disclaimer
|
197 |
+
|
198 |
+
This project provides a zero-shot voice cloning TTS model intended for academic research, educational purposes, and legitimate applications, such as personalized speech synthesis, assistive technologies, and linguistic research.
|
199 |
+
|
200 |
+
Please note:
|
201 |
+
|
202 |
+
- Do not use this model for unauthorized voice cloning, impersonation, fraud, scams, deepfakes, or any illegal activities.
|
203 |
+
|
204 |
+
- Ensure compliance with local laws and regulations when using this model and uphold ethical standards.
|
205 |
+
|
206 |
+
- The developers assume no liability for any misuse of this model.
|
207 |
+
|
208 |
+
We advocate for the responsible development and use of AI and encourage the community to uphold safety and ethical principles in AI research and applications. If you have any concerns regarding ethics or misuse, please contact us.
|
Spark_TTS_FT.ipynb
ADDED
@@ -0,0 +1,1732 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"id": "Qpw04rkbynx0"
|
7 |
+
},
|
8 |
+
"source": [
|
9 |
+
"To run this, press \"*Runtime*\" and press \"*Run all*\" on a **free** Tesla T4 Google Colab instance!\n",
|
10 |
+
"<div class=\"align-center\">\n",
|
11 |
+
"<a href=\"https://unsloth.ai/\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png\" width=\"115\"></a>\n",
|
12 |
+
"<a href=\"https://discord.gg/unsloth\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/Discord button.png\" width=\"145\"></a>\n",
|
13 |
+
"<a href=\"https://docs.unsloth.ai/\"><img src=\"https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true\" width=\"125\"></a></a> Join Discord if you need help + ⭐ <i>Star us on <a href=\"https://github.com/unslothai/unsloth\">Github</a> </i> ⭐\n",
|
14 |
+
"</div>\n",
|
15 |
+
"\n",
|
16 |
+
"To install Unsloth on your own computer, follow the installation instructions on our Github page [here](https://docs.unsloth.ai/get-started/installing-+-updating).\n",
|
17 |
+
"\n",
|
18 |
+
"You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), & [how to save it](#Save)\n"
|
19 |
+
]
|
20 |
+
},
|
21 |
+
{
|
22 |
+
"cell_type": "markdown",
|
23 |
+
"metadata": {
|
24 |
+
"id": "5fs-yYEaynx1"
|
25 |
+
},
|
26 |
+
"source": [
|
27 |
+
"### News"
|
28 |
+
]
|
29 |
+
},
|
30 |
+
{
|
31 |
+
"cell_type": "markdown",
|
32 |
+
"metadata": {
|
33 |
+
"id": "pyJK0UZaynx2"
|
34 |
+
},
|
35 |
+
"source": [
|
36 |
+
"Unsloth now supports Text-to-Speech (TTS) models. Read our [guide here](https://docs.unsloth.ai/basics/text-to-speech-tts-fine-tuning).\n",
|
37 |
+
"\n",
|
38 |
+
"Read our **[Gemma 3N Guide](https://docs.unsloth.ai/basics/gemma-3n-how-to-run-and-fine-tune)** and check out our new **[Dynamic 2.0](https://docs.unsloth.ai/basics/unsloth-dynamic-2.0-ggufs)** quants which outperforms other quantization methods!\n",
|
39 |
+
"\n",
|
40 |
+
"Visit our docs for all our [model uploads](https://docs.unsloth.ai/get-started/all-our-models) and [notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks).\n"
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "markdown",
|
45 |
+
"metadata": {
|
46 |
+
"id": "SDUHv0mwynx3"
|
47 |
+
},
|
48 |
+
"source": [
|
49 |
+
"### Installation"
|
50 |
+
]
|
51 |
+
},
|
52 |
+
{
|
53 |
+
"cell_type": "code",
|
54 |
+
"execution_count": 1,
|
55 |
+
"metadata": {
|
56 |
+
"id": "MY4G3EIbynx3"
|
57 |
+
},
|
58 |
+
"outputs": [],
|
59 |
+
"source": [
|
60 |
+
"%%capture\n",
|
61 |
+
"import os\n",
|
62 |
+
"if \"COLAB_\" not in \"\".join(os.environ.keys()):\n",
|
63 |
+
" %pip install unsloth\n",
|
64 |
+
"else:\n",
|
65 |
+
" # Do this only in Colab notebooks! Otherwise use pip install unsloth\n",
|
66 |
+
" %pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo\n",
|
67 |
+
" %pip install sentencepiece protobuf \"datasets>=3.4.1,<4.0.0\" \"huggingface_hub>=0.34.0\" hf_transfer\n",
|
68 |
+
" %pip install --no-deps unsloth\n",
|
69 |
+
"%git clone https://github.com/SparkAudio/Spark-TTS\n",
|
70 |
+
"%pip install omegaconf einx"
|
71 |
+
]
|
72 |
+
},
|
73 |
+
{
|
74 |
+
"cell_type": "code",
|
75 |
+
"execution_count": 2,
|
76 |
+
"metadata": {
|
77 |
+
"colab": {
|
78 |
+
"base_uri": "https://localhost:8080/"
|
79 |
+
},
|
80 |
+
"id": "QmUBVEnvCDJv",
|
81 |
+
"outputId": "42083a68-d3cc-48c9-d852-b60796377434"
|
82 |
+
},
|
83 |
+
"outputs": [
|
84 |
+
{
|
85 |
+
"name": "stdout",
|
86 |
+
"output_type": "stream",
|
87 |
+
"text": [
|
88 |
+
"🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n",
|
89 |
+
"🦥 Unsloth Zoo will now patch everything to make training faster!\n"
|
90 |
+
]
|
91 |
+
},
|
92 |
+
{
|
93 |
+
"data": {
|
94 |
+
"application/vnd.jupyter.widget-view+json": {
|
95 |
+
"model_id": "9ad0d25a6f8549d1ac79addbe171b758",
|
96 |
+
"version_major": 2,
|
97 |
+
"version_minor": 0
|
98 |
+
},
|
99 |
+
"text/plain": [
|
100 |
+
".gitattributes: 0.00B [00:00, ?B/s]"
|
101 |
+
]
|
102 |
+
},
|
103 |
+
"metadata": {},
|
104 |
+
"output_type": "display_data"
|
105 |
+
},
|
106 |
+
{
|
107 |
+
"data": {
|
108 |
+
"application/vnd.jupyter.widget-view+json": {
|
109 |
+
"model_id": "7e83dd9464b64a6d963c349d1660a28c",
|
110 |
+
"version_major": 2,
|
111 |
+
"version_minor": 0
|
112 |
+
},
|
113 |
+
"text/plain": [
|
114 |
+
"config.yaml: 0.00B [00:00, ?B/s]"
|
115 |
+
]
|
116 |
+
},
|
117 |
+
"metadata": {},
|
118 |
+
"output_type": "display_data"
|
119 |
+
},
|
120 |
+
{
|
121 |
+
"data": {
|
122 |
+
"application/vnd.jupyter.widget-view+json": {
|
123 |
+
"model_id": "332e86b12a4c45a89a95f1f265ca0f12",
|
124 |
+
"version_major": 2,
|
125 |
+
"version_minor": 0
|
126 |
+
},
|
127 |
+
"text/plain": [
|
128 |
+
"BiCodec/model.safetensors: 0%| | 0.00/626M [00:00<?, ?B/s]"
|
129 |
+
]
|
130 |
+
},
|
131 |
+
"metadata": {},
|
132 |
+
"output_type": "display_data"
|
133 |
+
},
|
134 |
+
{
|
135 |
+
"data": {
|
136 |
+
"application/vnd.jupyter.widget-view+json": {
|
137 |
+
"model_id": "c1a54d8c9dc8472e8f0f37603ccd3904",
|
138 |
+
"version_major": 2,
|
139 |
+
"version_minor": 0
|
140 |
+
},
|
141 |
+
"text/plain": [
|
142 |
+
"added_tokens.json: 0.00B [00:00, ?B/s]"
|
143 |
+
]
|
144 |
+
},
|
145 |
+
"metadata": {},
|
146 |
+
"output_type": "display_data"
|
147 |
+
},
|
148 |
+
{
|
149 |
+
"data": {
|
150 |
+
"application/vnd.jupyter.widget-view+json": {
|
151 |
+
"model_id": "8402d2f2ef204022b0727f2b09437bad",
|
152 |
+
"version_major": 2,
|
153 |
+
"version_minor": 0
|
154 |
+
},
|
155 |
+
"text/plain": [
|
156 |
+
"config.json: 0%| | 0.00/658 [00:00<?, ?B/s]"
|
157 |
+
]
|
158 |
+
},
|
159 |
+
"metadata": {},
|
160 |
+
"output_type": "display_data"
|
161 |
+
},
|
162 |
+
{
|
163 |
+
"data": {
|
164 |
+
"application/vnd.jupyter.widget-view+json": {
|
165 |
+
"model_id": "43f438eabd1843cc8c5977f0ef6226ec",
|
166 |
+
"version_major": 2,
|
167 |
+
"version_minor": 0
|
168 |
+
},
|
169 |
+
"text/plain": [
|
170 |
+
"merges.txt: 0.00B [00:00, ?B/s]"
|
171 |
+
]
|
172 |
+
},
|
173 |
+
"metadata": {},
|
174 |
+
"output_type": "display_data"
|
175 |
+
},
|
176 |
+
{
|
177 |
+
"data": {
|
178 |
+
"application/vnd.jupyter.widget-view+json": {
|
179 |
+
"model_id": "87dce305eba54c1797547c06a2ab7cf6",
|
180 |
+
"version_major": 2,
|
181 |
+
"version_minor": 0
|
182 |
+
},
|
183 |
+
"text/plain": [
|
184 |
+
"LLM/model.safetensors: 0%| | 0.00/2.03G [00:00<?, ?B/s]"
|
185 |
+
]
|
186 |
+
},
|
187 |
+
"metadata": {},
|
188 |
+
"output_type": "display_data"
|
189 |
+
},
|
190 |
+
{
|
191 |
+
"data": {
|
192 |
+
"application/vnd.jupyter.widget-view+json": {
|
193 |
+
"model_id": "3ea6e51894454a5c82bb4cfe1fd0a47f",
|
194 |
+
"version_major": 2,
|
195 |
+
"version_minor": 0
|
196 |
+
},
|
197 |
+
"text/plain": [
|
198 |
+
"special_tokens_map.json: 0%| | 0.00/613 [00:00<?, ?B/s]"
|
199 |
+
]
|
200 |
+
},
|
201 |
+
"metadata": {},
|
202 |
+
"output_type": "display_data"
|
203 |
+
},
|
204 |
+
{
|
205 |
+
"data": {
|
206 |
+
"application/vnd.jupyter.widget-view+json": {
|
207 |
+
"model_id": "94e7da1bdc7549e0ba4dcd0b73d38667",
|
208 |
+
"version_major": 2,
|
209 |
+
"version_minor": 0
|
210 |
+
},
|
211 |
+
"text/plain": [
|
212 |
+
"LLM/tokenizer.json: 0%| | 0.00/14.1M [00:00<?, ?B/s]"
|
213 |
+
]
|
214 |
+
},
|
215 |
+
"metadata": {},
|
216 |
+
"output_type": "display_data"
|
217 |
+
},
|
218 |
+
{
|
219 |
+
"data": {
|
220 |
+
"application/vnd.jupyter.widget-view+json": {
|
221 |
+
"model_id": "1aa226f63eac4ee48537df6b26d921c1",
|
222 |
+
"version_major": 2,
|
223 |
+
"version_minor": 0
|
224 |
+
},
|
225 |
+
"text/plain": [
|
226 |
+
"tokenizer_config.json: 0.00B [00:00, ?B/s]"
|
227 |
+
]
|
228 |
+
},
|
229 |
+
"metadata": {},
|
230 |
+
"output_type": "display_data"
|
231 |
+
},
|
232 |
+
{
|
233 |
+
"data": {
|
234 |
+
"application/vnd.jupyter.widget-view+json": {
|
235 |
+
"model_id": "420eaeeb7bee4c21964c17968c266ac1",
|
236 |
+
"version_major": 2,
|
237 |
+
"version_minor": 0
|
238 |
+
},
|
239 |
+
"text/plain": [
|
240 |
+
"vocab.json: 0.00B [00:00, ?B/s]"
|
241 |
+
]
|
242 |
+
},
|
243 |
+
"metadata": {},
|
244 |
+
"output_type": "display_data"
|
245 |
+
},
|
246 |
+
{
|
247 |
+
"data": {
|
248 |
+
"application/vnd.jupyter.widget-view+json": {
|
249 |
+
"model_id": "bdcb3d5d6a8e4e969afa77631e7c3104",
|
250 |
+
"version_major": 2,
|
251 |
+
"version_minor": 0
|
252 |
+
},
|
253 |
+
"text/plain": [
|
254 |
+
"README.md: 0.00B [00:00, ?B/s]"
|
255 |
+
]
|
256 |
+
},
|
257 |
+
"metadata": {},
|
258 |
+
"output_type": "display_data"
|
259 |
+
},
|
260 |
+
{
|
261 |
+
"data": {
|
262 |
+
"application/vnd.jupyter.widget-view+json": {
|
263 |
+
"model_id": "1cd60c7dbe61410ca5bc61310367635a",
|
264 |
+
"version_major": 2,
|
265 |
+
"version_minor": 0
|
266 |
+
},
|
267 |
+
"text/plain": [
|
268 |
+
"config.yaml: 0%| | 0.00/169 [00:00<?, ?B/s]"
|
269 |
+
]
|
270 |
+
},
|
271 |
+
"metadata": {},
|
272 |
+
"output_type": "display_data"
|
273 |
+
},
|
274 |
+
{
|
275 |
+
"data": {
|
276 |
+
"application/vnd.jupyter.widget-view+json": {
|
277 |
+
"model_id": "0ea819afc66b437ca8b0dc7337f5ce5f",
|
278 |
+
"version_major": 2,
|
279 |
+
"version_minor": 0
|
280 |
+
},
|
281 |
+
"text/plain": [
|
282 |
+
"gradio_TTS.png: 0%| | 0.00/81.8k [00:00<?, ?B/s]"
|
283 |
+
]
|
284 |
+
},
|
285 |
+
"metadata": {},
|
286 |
+
"output_type": "display_data"
|
287 |
+
},
|
288 |
+
{
|
289 |
+
"data": {
|
290 |
+
"application/vnd.jupyter.widget-view+json": {
|
291 |
+
"model_id": "00f074bbbc5b44d59c590cc217187aa5",
|
292 |
+
"version_major": 2,
|
293 |
+
"version_minor": 0
|
294 |
+
},
|
295 |
+
"text/plain": [
|
296 |
+
"gradio_control.png: 0%| | 0.00/62.2k [00:00<?, ?B/s]"
|
297 |
+
]
|
298 |
+
},
|
299 |
+
"metadata": {},
|
300 |
+
"output_type": "display_data"
|
301 |
+
},
|
302 |
+
{
|
303 |
+
"data": {
|
304 |
+
"application/vnd.jupyter.widget-view+json": {
|
305 |
+
"model_id": "d050a4b7cf2b4f78af51986b9c2eee45",
|
306 |
+
"version_major": 2,
|
307 |
+
"version_minor": 0
|
308 |
+
},
|
309 |
+
"text/plain": [
|
310 |
+
"src/figures/infer_control.png: 0%| | 0.00/127k [00:00<?, ?B/s]"
|
311 |
+
]
|
312 |
+
},
|
313 |
+
"metadata": {},
|
314 |
+
"output_type": "display_data"
|
315 |
+
},
|
316 |
+
{
|
317 |
+
"data": {
|
318 |
+
"application/vnd.jupyter.widget-view+json": {
|
319 |
+
"model_id": "6ed5ce435b89443f9cca00ed1b97311e",
|
320 |
+
"version_major": 2,
|
321 |
+
"version_minor": 0
|
322 |
+
},
|
323 |
+
"text/plain": [
|
324 |
+
"src/figures/infer_voice_cloning.png: 0%| | 0.00/119k [00:00<?, ?B/s]"
|
325 |
+
]
|
326 |
+
},
|
327 |
+
"metadata": {},
|
328 |
+
"output_type": "display_data"
|
329 |
+
},
|
330 |
+
{
|
331 |
+
"data": {
|
332 |
+
"application/vnd.jupyter.widget-view+json": {
|
333 |
+
"model_id": "0a7db4ff0d204ed4839471cbd8ebefef",
|
334 |
+
"version_major": 2,
|
335 |
+
"version_minor": 0
|
336 |
+
},
|
337 |
+
"text/plain": [
|
338 |
+
"src/logo/HKUST.jpg: 0%| | 0.00/102k [00:00<?, ?B/s]"
|
339 |
+
]
|
340 |
+
},
|
341 |
+
"metadata": {},
|
342 |
+
"output_type": "display_data"
|
343 |
+
},
|
344 |
+
{
|
345 |
+
"data": {
|
346 |
+
"application/vnd.jupyter.widget-view+json": {
|
347 |
+
"model_id": "d7b682f3d5d142c68ec6bea0be196792",
|
348 |
+
"version_major": 2,
|
349 |
+
"version_minor": 0
|
350 |
+
},
|
351 |
+
"text/plain": [
|
352 |
+
"src/logo/NPU.jpg: 0%| | 0.00/152k [00:00<?, ?B/s]"
|
353 |
+
]
|
354 |
+
},
|
355 |
+
"metadata": {},
|
356 |
+
"output_type": "display_data"
|
357 |
+
},
|
358 |
+
{
|
359 |
+
"data": {
|
360 |
+
"application/vnd.jupyter.widget-view+json": {
|
361 |
+
"model_id": "bd49989b32d3492894bf08b084059ba6",
|
362 |
+
"version_major": 2,
|
363 |
+
"version_minor": 0
|
364 |
+
},
|
365 |
+
"text/plain": [
|
366 |
+
"NTU.jpg: 0%| | 0.00/77.6k [00:00<?, ?B/s]"
|
367 |
+
]
|
368 |
+
},
|
369 |
+
"metadata": {},
|
370 |
+
"output_type": "display_data"
|
371 |
+
},
|
372 |
+
{
|
373 |
+
"data": {
|
374 |
+
"application/vnd.jupyter.widget-view+json": {
|
375 |
+
"model_id": "b4576071c87448ef8ba94df410964d6c",
|
376 |
+
"version_major": 2,
|
377 |
+
"version_minor": 0
|
378 |
+
},
|
379 |
+
"text/plain": [
|
380 |
+
"src/logo/SJU.jpg: 0%| | 0.00/364k [00:00<?, ?B/s]"
|
381 |
+
]
|
382 |
+
},
|
383 |
+
"metadata": {},
|
384 |
+
"output_type": "display_data"
|
385 |
+
},
|
386 |
+
{
|
387 |
+
"data": {
|
388 |
+
"application/vnd.jupyter.widget-view+json": {
|
389 |
+
"model_id": "3dbdd98fca6741d2874849b2b26662db",
|
390 |
+
"version_major": 2,
|
391 |
+
"version_minor": 0
|
392 |
+
},
|
393 |
+
"text/plain": [
|
394 |
+
"SparkAudio.jpg: 0%| | 0.00/89.0k [00:00<?, ?B/s]"
|
395 |
+
]
|
396 |
+
},
|
397 |
+
"metadata": {},
|
398 |
+
"output_type": "display_data"
|
399 |
+
},
|
400 |
+
{
|
401 |
+
"data": {
|
402 |
+
"application/vnd.jupyter.widget-view+json": {
|
403 |
+
"model_id": "ce753e6904ff4dd4ae5c5824ac554d76",
|
404 |
+
"version_major": 2,
|
405 |
+
"version_minor": 0
|
406 |
+
},
|
407 |
+
"text/plain": [
|
408 |
+
"SparkAudio2.jpg: 0%| | 0.00/40.7k [00:00<?, ?B/s]"
|
409 |
+
]
|
410 |
+
},
|
411 |
+
"metadata": {},
|
412 |
+
"output_type": "display_data"
|
413 |
+
},
|
414 |
+
{
|
415 |
+
"data": {
|
416 |
+
"application/vnd.jupyter.widget-view+json": {
|
417 |
+
"model_id": "90c48554b64b46f388ee14df2c401a02",
|
418 |
+
"version_major": 2,
|
419 |
+
"version_minor": 0
|
420 |
+
},
|
421 |
+
"text/plain": [
|
422 |
+
"SparkTTS.jpg: 0%| | 0.00/52.5k [00:00<?, ?B/s]"
|
423 |
+
]
|
424 |
+
},
|
425 |
+
"metadata": {},
|
426 |
+
"output_type": "display_data"
|
427 |
+
},
|
428 |
+
{
|
429 |
+
"data": {
|
430 |
+
"application/vnd.jupyter.widget-view+json": {
|
431 |
+
"model_id": "059f5fe90c324bd7b0aef23095af1c21",
|
432 |
+
"version_major": 2,
|
433 |
+
"version_minor": 0
|
434 |
+
},
|
435 |
+
"text/plain": [
|
436 |
+
"src/logo/SparkTTS.png: 0%| | 0.00/102k [00:00<?, ?B/s]"
|
437 |
+
]
|
438 |
+
},
|
439 |
+
"metadata": {},
|
440 |
+
"output_type": "display_data"
|
441 |
+
},
|
442 |
+
{
|
443 |
+
"data": {
|
444 |
+
"application/vnd.jupyter.widget-view+json": {
|
445 |
+
"model_id": "ccf1938072024151ab5c50492866e253",
|
446 |
+
"version_major": 2,
|
447 |
+
"version_minor": 0
|
448 |
+
},
|
449 |
+
"text/plain": [
|
450 |
+
"src/logo/mobvoi.jpg: 0%| | 0.00/431k [00:00<?, ?B/s]"
|
451 |
+
]
|
452 |
+
},
|
453 |
+
"metadata": {},
|
454 |
+
"output_type": "display_data"
|
455 |
+
},
|
456 |
+
{
|
457 |
+
"data": {
|
458 |
+
"application/vnd.jupyter.widget-view+json": {
|
459 |
+
"model_id": "771681ce27b94c71a61da27b133427ac",
|
460 |
+
"version_major": 2,
|
461 |
+
"version_minor": 0
|
462 |
+
},
|
463 |
+
"text/plain": [
|
464 |
+
"src/logo/mobvoi.png: 0%| | 0.00/120k [00:00<?, ?B/s]"
|
465 |
+
]
|
466 |
+
},
|
467 |
+
"metadata": {},
|
468 |
+
"output_type": "display_data"
|
469 |
+
},
|
470 |
+
{
|
471 |
+
"data": {
|
472 |
+
"application/vnd.jupyter.widget-view+json": {
|
473 |
+
"model_id": "243ff52bb35242eeb330a2bb2ffe4166",
|
474 |
+
"version_major": 2,
|
475 |
+
"version_minor": 0
|
476 |
+
},
|
477 |
+
"text/plain": [
|
478 |
+
"README.md: 0.00B [00:00, ?B/s]"
|
479 |
+
]
|
480 |
+
},
|
481 |
+
"metadata": {},
|
482 |
+
"output_type": "display_data"
|
483 |
+
},
|
484 |
+
{
|
485 |
+
"data": {
|
486 |
+
"application/vnd.jupyter.widget-view+json": {
|
487 |
+
"model_id": "c17c5bd399fd411d8f2ee43f79539cca",
|
488 |
+
"version_major": 2,
|
489 |
+
"version_minor": 0
|
490 |
+
},
|
491 |
+
"text/plain": [
|
492 |
+
"config.json: 0.00B [00:00, ?B/s]"
|
493 |
+
]
|
494 |
+
},
|
495 |
+
"metadata": {},
|
496 |
+
"output_type": "display_data"
|
497 |
+
},
|
498 |
+
{
|
499 |
+
"data": {
|
500 |
+
"application/vnd.jupyter.widget-view+json": {
|
501 |
+
"model_id": "2d6ae8fc962b41aeb4ce1fec0d3f0864",
|
502 |
+
"version_major": 2,
|
503 |
+
"version_minor": 0
|
504 |
+
},
|
505 |
+
"text/plain": [
|
506 |
+
"preprocessor_config.json: 0%| | 0.00/212 [00:00<?, ?B/s]"
|
507 |
+
]
|
508 |
+
},
|
509 |
+
"metadata": {},
|
510 |
+
"output_type": "display_data"
|
511 |
+
},
|
512 |
+
{
|
513 |
+
"data": {
|
514 |
+
"application/vnd.jupyter.widget-view+json": {
|
515 |
+
"model_id": "f3394d8a215e406f8f50b8770dd354d3",
|
516 |
+
"version_major": 2,
|
517 |
+
"version_minor": 0
|
518 |
+
},
|
519 |
+
"text/plain": [
|
520 |
+
"wav2vec2-large-xlsr-53/pytorch_model.bin: 0%| | 0.00/1.27G [00:00<?, ?B/s]"
|
521 |
+
]
|
522 |
+
},
|
523 |
+
"metadata": {},
|
524 |
+
"output_type": "display_data"
|
525 |
+
},
|
526 |
+
{
|
527 |
+
"name": "stdout",
|
528 |
+
"output_type": "stream",
|
529 |
+
"text": [
|
530 |
+
"==((====))== Unsloth 2025.8.1: Fast Qwen2 patching. Transformers: 4.55.0.\n",
|
531 |
+
" \\\\ /| NVIDIA GeForce RTX 2080 SUPER. Num GPUs = 2. Max memory: 7.785 GB. Platform: Linux.\n",
|
532 |
+
"O^O/ \\_/ \\ Torch: 2.7.1+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.3.1\n",
|
533 |
+
"\\ / Bfloat16 = FALSE. FA [Xformers = 0.0.31.post1. FA2 = False]\n",
|
534 |
+
" \"-____-\" Free license: http://github.com/unslothai/unsloth\n",
|
535 |
+
"Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n",
|
536 |
+
"Unsloth: Float16 full finetuning uses more memory since we upcast weights to float32.\n"
|
537 |
+
]
|
538 |
+
}
|
539 |
+
],
|
540 |
+
"source": [
|
541 |
+
"from unsloth import FastModel\n",
|
542 |
+
"import torch\n",
|
543 |
+
"from huggingface_hub import snapshot_download\n",
|
544 |
+
"\n",
|
545 |
+
"max_seq_length = 2048 # Choose any for long context!\n",
|
546 |
+
"\n",
|
547 |
+
"fourbit_models = [\n",
|
548 |
+
" # 4bit dynamic quants for superior accuracy and low memory use\n",
|
549 |
+
" \"unsloth/gemma-3-4b-it-unsloth-bnb-4bit\",\n",
|
550 |
+
" \"unsloth/gemma-3-12b-it-unsloth-bnb-4bit\",\n",
|
551 |
+
" \"unsloth/gemma-3-27b-it-unsloth-bnb-4bit\",\n",
|
552 |
+
" # Qwen3 new models\n",
|
553 |
+
" \"unsloth/Qwen3-4B-unsloth-bnb-4bit\",\n",
|
554 |
+
" \"unsloth/Qwen3-8B-unsloth-bnb-4bit\",\n",
|
555 |
+
" # Other very popular models!\n",
|
556 |
+
" \"unsloth/Llama-3.1-8B\",\n",
|
557 |
+
" \"unsloth/Llama-3.2-3B\",\n",
|
558 |
+
" \"unsloth/Llama-3.3-70B\",\n",
|
559 |
+
" \"unsloth/mistral-7b-instruct-v0.3\",\n",
|
560 |
+
" \"unsloth/Phi-4\",\n",
|
561 |
+
"] # More models at https://huggingface.co/unsloth\n",
|
562 |
+
"\n",
|
563 |
+
"# Download model and code\n",
|
564 |
+
"snapshot_download(\"unsloth/Spark-TTS-0.5B\", local_dir = \"Spark-TTS-0.5B\")\n",
|
565 |
+
"\n",
|
566 |
+
"model, tokenizer = FastModel.from_pretrained(\n",
|
567 |
+
" model_name = f\"Spark-TTS-0.5B/LLM\",\n",
|
568 |
+
" max_seq_length = max_seq_length,\n",
|
569 |
+
" dtype = torch.float32, # Spark seems to only work on float32 for now\n",
|
570 |
+
" full_finetuning = True, # We support full finetuning now!\n",
|
571 |
+
" load_in_4bit = False,\n",
|
572 |
+
" #token = \"hf_...\", # use one if using gated models like meta-llama/Llama-2-7b-hf\n",
|
573 |
+
")"
|
574 |
+
]
|
575 |
+
},
|
576 |
+
{
|
577 |
+
"cell_type": "markdown",
|
578 |
+
"metadata": {
|
579 |
+
"id": "SXd9bTZd1aaL"
|
580 |
+
},
|
581 |
+
"source": [
|
582 |
+
"We now add LoRA adapters so we only need to update 1 to 10% of all parameters!"
|
583 |
+
]
|
584 |
+
},
|
585 |
+
{
|
586 |
+
"cell_type": "code",
|
587 |
+
"execution_count": 3,
|
588 |
+
"metadata": {
|
589 |
+
"colab": {
|
590 |
+
"base_uri": "https://localhost:8080/"
|
591 |
+
},
|
592 |
+
"id": "6bZsfBuZDeCL",
|
593 |
+
"outputId": "292447b8-fd80-4b8b-ba3f-4637a1045166"
|
594 |
+
},
|
595 |
+
"outputs": [
|
596 |
+
{
|
597 |
+
"name": "stdout",
|
598 |
+
"output_type": "stream",
|
599 |
+
"text": [
|
600 |
+
"Unsloth: Full finetuning is enabled, so .get_peft_model has no effect\n"
|
601 |
+
]
|
602 |
+
}
|
603 |
+
],
|
604 |
+
"source": [
|
605 |
+
"#LoRA does not work with float32 only works with bfloat16 !!!\n",
|
606 |
+
"model = FastModel.get_peft_model(\n",
|
607 |
+
" model,\n",
|
608 |
+
" r = 128, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n",
|
609 |
+
" target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
|
610 |
+
" \"gate_proj\", \"up_proj\", \"down_proj\",],\n",
|
611 |
+
" lora_alpha = 128,\n",
|
612 |
+
" lora_dropout = 0, # Supports any, but = 0 is optimized\n",
|
613 |
+
" bias = \"none\", # Supports any, but = \"none\" is optimized\n",
|
614 |
+
" # [NEW] \"unsloth\" uses 30% less VRAM, fits 2x larger batch sizes!\n",
|
615 |
+
" use_gradient_checkpointing = \"unsloth\", # True or \"unsloth\" for very long context\n",
|
616 |
+
" random_state = 3407,\n",
|
617 |
+
" use_rslora = False, # We support rank stabilized LoRA\n",
|
618 |
+
" loftq_config = None, # And LoftQ\n",
|
619 |
+
")"
|
620 |
+
]
|
621 |
+
},
|
622 |
+
{
|
623 |
+
"cell_type": "markdown",
|
624 |
+
"metadata": {
|
625 |
+
"id": "vITh0KVJ10qX"
|
626 |
+
},
|
627 |
+
"source": [
|
628 |
+
"<a name=\"Data\"></a>\n",
|
629 |
+
"### Data Prep \n",
|
630 |
+
"\n",
|
631 |
+
"We will use the `Balaji-1904/TTS_KN_DS_V1.1`, which is designed for training TTS models. Ensure that your dataset follows the required format: **text, audio** for single-speaker models or **source, text, audio** for multi-speaker models. You can modify this section to accommodate your own dataset, but maintaining the correct structure is essential for optimal training."
|
632 |
+
]
|
633 |
+
},
|
634 |
+
{
|
635 |
+
"cell_type": "code",
|
636 |
+
"execution_count": 4,
|
637 |
+
"metadata": {
|
638 |
+
"id": "LjY75GoYUCB8"
|
639 |
+
},
|
640 |
+
"outputs": [],
|
641 |
+
"source": [
|
642 |
+
"from datasets import load_dataset\n",
|
643 |
+
"dataset = load_dataset(\"Balaji-1904/TTS_KN_DS_V1.1\", split = \"train\")"
|
644 |
+
]
|
645 |
+
},
|
646 |
+
{
|
647 |
+
"cell_type": "code",
|
648 |
+
"execution_count": 5,
|
649 |
+
"metadata": {
|
650 |
+
"colab": {
|
651 |
+
"base_uri": "https://localhost:8080/",
|
652 |
+
"height": 173,
|
653 |
+
"referenced_widgets": [
|
654 |
+
"a3b0c0581f1f4c428baaadd8e9a39b6f",
|
655 |
+
"2315228ff2b141afabe1263471f5364b",
|
656 |
+
"0474debc340943bd85f3daf92aebf7aa",
|
657 |
+
"cff1b0fa2ea24f45aab26685353eefdd",
|
658 |
+
"b7e20be79df246f19b35114a690e44f0",
|
659 |
+
"426eb100a94642f79e6b99777406a265",
|
660 |
+
"a36b5cf197dd4bd9a7f70aa6671b804c",
|
661 |
+
"0de4d0f282404edfbc191dca73f15f35",
|
662 |
+
"e58b5ad2f781475d8af2ddb38009baa6",
|
663 |
+
"33fbacbb2aa146cd90586357eec1dc3e",
|
664 |
+
"930b4d1d5f4b494b830df4d4c398e67c"
|
665 |
+
]
|
666 |
+
},
|
667 |
+
"id": "zK94B-Pfioto",
|
668 |
+
"outputId": "3f11cf35-c173-410d-f709-43552323f26f"
|
669 |
+
},
|
670 |
+
"outputs": [
|
671 |
+
{
|
672 |
+
"ename": "ModuleNotFoundError",
|
673 |
+
"evalue": "No module named 'torchaudio'",
|
674 |
+
"output_type": "error",
|
675 |
+
"traceback": [
|
676 |
+
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
677 |
+
"\u001b[31mModuleNotFoundError\u001b[39m Traceback (most recent call last)",
|
678 |
+
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[5]\u001b[39m\u001b[32m, line 4\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;66;03m#@title Tokenization Function\u001b[39;00m\n\u001b[32m 3\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mlocale\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m4\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtorchaudio\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mtransforms\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mT\u001b[39;00m\n\u001b[32m 5\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mos\u001b[39;00m\n\u001b[32m 6\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtorch\u001b[39;00m\n",
|
679 |
+
"\u001b[31mModuleNotFoundError\u001b[39m: No module named 'torchaudio'"
|
680 |
+
]
|
681 |
+
}
|
682 |
+
],
|
683 |
+
"source": [
|
684 |
+
"#@title Tokenization Function\n",
|
685 |
+
"\n",
|
686 |
+
"import locale\n",
|
687 |
+
"import torchaudio.transforms as T\n",
|
688 |
+
"import os\n",
|
689 |
+
"import torch\n",
|
690 |
+
"import sys\n",
|
691 |
+
"import numpy as np\n",
|
692 |
+
"sys.path.append('Spark-TTS')\n",
|
693 |
+
"from sparktts.models.audio_tokenizer import BiCodecTokenizer\n",
|
694 |
+
"from sparktts.utils.audio import audio_volume_normalize\n",
|
695 |
+
"\n",
|
696 |
+
"audio_tokenizer = BiCodecTokenizer(\"Spark-TTS-0.5B\", \"cuda\")\n",
|
697 |
+
"def extract_wav2vec2_features( wavs: torch.Tensor) -> torch.Tensor:\n",
|
698 |
+
" \"\"\"extract wav2vec2 features\"\"\"\n",
|
699 |
+
"\n",
|
700 |
+
" if wavs.shape[0] != 1:\n",
|
701 |
+
"\n",
|
702 |
+
" raise ValueError(f\"Expected batch size 1, but got shape {wavs.shape}\")\n",
|
703 |
+
" wav_np = wavs.squeeze(0).cpu().numpy()\n",
|
704 |
+
"\n",
|
705 |
+
" processed = audio_tokenizer.processor(\n",
|
706 |
+
" wav_np,\n",
|
707 |
+
" sampling_rate=16000,\n",
|
708 |
+
" return_tensors=\"pt\",\n",
|
709 |
+
" padding=True,\n",
|
710 |
+
" )\n",
|
711 |
+
" input_values = processed.input_values\n",
|
712 |
+
"\n",
|
713 |
+
" input_values = input_values.to(audio_tokenizer.feature_extractor.device)\n",
|
714 |
+
"\n",
|
715 |
+
" model_output = audio_tokenizer.feature_extractor(\n",
|
716 |
+
" input_values,\n",
|
717 |
+
" )\n",
|
718 |
+
"\n",
|
719 |
+
"\n",
|
720 |
+
" if model_output.hidden_states is None:\n",
|
721 |
+
" raise ValueError(\"Wav2Vec2Model did not return hidden states. Ensure config `output_hidden_states=True`.\")\n",
|
722 |
+
"\n",
|
723 |
+
" num_layers = len(model_output.hidden_states)\n",
|
724 |
+
" required_layers = [11, 14, 16]\n",
|
725 |
+
" if any(l >= num_layers for l in required_layers):\n",
|
726 |
+
" raise IndexError(f\"Requested hidden state indices {required_layers} out of range for model with {num_layers} layers.\")\n",
|
727 |
+
"\n",
|
728 |
+
" feats_mix = (\n",
|
729 |
+
" model_output.hidden_states[11] + model_output.hidden_states[14] + model_output.hidden_states[16]\n",
|
730 |
+
" ) / 3\n",
|
731 |
+
"\n",
|
732 |
+
" return feats_mix\n",
|
733 |
+
"def formatting_audio_func(example):\n",
|
734 |
+
" text = f\"{example['source']}: {example['text']}\" if \"source\" in example else example[\"text\"]\n",
|
735 |
+
" audio_array = example[\"audio\"][\"array\"]\n",
|
736 |
+
" sampling_rate = example[\"audio\"][\"sampling_rate\"]\n",
|
737 |
+
"\n",
|
738 |
+
" target_sr = audio_tokenizer.config['sample_rate']\n",
|
739 |
+
"\n",
|
740 |
+
" if sampling_rate != target_sr:\n",
|
741 |
+
" resampler = T.Resample(orig_freq=sampling_rate, new_freq=target_sr)\n",
|
742 |
+
" audio_tensor_temp = torch.from_numpy(audio_array).float()\n",
|
743 |
+
" audio_array = resampler(audio_tensor_temp).numpy()\n",
|
744 |
+
"\n",
|
745 |
+
" if audio_tokenizer.config[\"volume_normalize\"]:\n",
|
746 |
+
" audio_array = audio_volume_normalize(audio_array)\n",
|
747 |
+
"\n",
|
748 |
+
" ref_wav_np = audio_tokenizer.get_ref_clip(audio_array)\n",
|
749 |
+
"\n",
|
750 |
+
" audio_tensor = torch.from_numpy(audio_array).unsqueeze(0).float().to(audio_tokenizer.device)\n",
|
751 |
+
" ref_wav_tensor = torch.from_numpy(ref_wav_np).unsqueeze(0).float().to(audio_tokenizer.device)\n",
|
752 |
+
"\n",
|
753 |
+
"\n",
|
754 |
+
" feat = extract_wav2vec2_features(audio_tensor)\n",
|
755 |
+
"\n",
|
756 |
+
" batch = {\n",
|
757 |
+
"\n",
|
758 |
+
" \"wav\": audio_tensor,\n",
|
759 |
+
" \"ref_wav\": ref_wav_tensor,\n",
|
760 |
+
" \"feat\": feat.to(audio_tokenizer.device),\n",
|
761 |
+
" }\n",
|
762 |
+
"\n",
|
763 |
+
"\n",
|
764 |
+
" semantic_token_ids, global_token_ids = audio_tokenizer.model.tokenize(batch)\n",
|
765 |
+
"\n",
|
766 |
+
" global_tokens = \"\".join(\n",
|
767 |
+
" [f\"<|bicodec_global_{i}|>\" for i in global_token_ids.squeeze().cpu().numpy()] # Squeeze batch dim\n",
|
768 |
+
" )\n",
|
769 |
+
" semantic_tokens = \"\".join(\n",
|
770 |
+
" [f\"<|bicodec_semantic_{i}|>\" for i in semantic_token_ids.squeeze().cpu().numpy()] # Squeeze batch dim\n",
|
771 |
+
" )\n",
|
772 |
+
"\n",
|
773 |
+
" inputs = [\n",
|
774 |
+
" \"<|task_tts|>\",\n",
|
775 |
+
" \"<|start_content|>\",\n",
|
776 |
+
" text,\n",
|
777 |
+
" \"<|end_content|>\",\n",
|
778 |
+
" \"<|start_global_token|>\",\n",
|
779 |
+
" global_tokens,\n",
|
780 |
+
" \"<|end_global_token|>\",\n",
|
781 |
+
" \"<|start_semantic_token|>\",\n",
|
782 |
+
" semantic_tokens,\n",
|
783 |
+
" \"<|end_semantic_token|>\",\n",
|
784 |
+
" \"<|im_end|>\"\n",
|
785 |
+
" ]\n",
|
786 |
+
" inputs = \"\".join(inputs)\n",
|
787 |
+
" return {\"text\": inputs}\n",
|
788 |
+
"\n",
|
789 |
+
"\n",
|
790 |
+
"dataset = dataset.map(formatting_audio_func, remove_columns=[\"audio\"])\n",
|
791 |
+
"print(\"Moving Bicodec model and Wav2Vec2Model to cpu.\")\n",
|
792 |
+
"audio_tokenizer.model.cpu()\n",
|
793 |
+
"audio_tokenizer.feature_extractor.cpu()\n",
|
794 |
+
"torch.cuda.empty_cache()"
|
795 |
+
]
|
796 |
+
},
|
797 |
+
{
|
798 |
+
"cell_type": "code",
|
799 |
+
"execution_count": 6,
|
800 |
+
"metadata": {},
|
801 |
+
"outputs": [
|
802 |
+
{
|
803 |
+
"name": "stdout",
|
804 |
+
"output_type": "stream",
|
805 |
+
"text": [
|
806 |
+
"Collecting torchaudio\n",
|
807 |
+
" Downloading torchaudio-2.8.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (7.2 kB)\n",
|
808 |
+
"Collecting torch==2.8.0 (from torchaudio)\n",
|
809 |
+
" Using cached torch-2.8.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (30 kB)\n",
|
810 |
+
"Requirement already satisfied: filelock in /datadrive/jupyter/devbase/Balaji/TTS_ft/lib/python3.12/site-packages (from torch==2.8.0->torchaudio) (3.18.0)\n",
|
811 |
+
"Requirement already satisfied: typing-extensions>=4.10.0 in /datadrive/jupyter/devbase/Balaji/TTS_ft/lib/python3.12/site-packages (from torch==2.8.0->torchaudio) (4.14.1)\n",
|
812 |
+
"Requirement already satisfied: setuptools in /datadrive/jupyter/devbase/Balaji/TTS_ft/lib/python3.12/site-packages (from torch==2.8.0->torchaudio) (80.9.0)\n",
|
813 |
+
"Requirement already satisfied: sympy>=1.13.3 in /datadrive/jupyter/devbase/Balaji/TTS_ft/lib/python3.12/site-packages (from torch==2.8.0->torchaudio) (1.14.0)\n",
|
814 |
+
"Requirement already satisfied: networkx in /datadrive/jupyter/devbase/Balaji/TTS_ft/lib/python3.12/site-packages (from torch==2.8.0->torchaudio) (3.5)\n",
|
815 |
+
"Requirement already satisfied: jinja2 in /datadrive/jupyter/devbase/Balaji/TTS_ft/lib/python3.12/site-packages (from torch==2.8.0->torchaudio) (3.1.6)\n",
|
816 |
+
"Requirement already satisfied: fsspec in /datadrive/jupyter/devbase/Balaji/TTS_ft/lib/python3.12/site-packages (from torch==2.8.0->torchaudio) (2025.3.0)\n",
|
817 |
+
"Collecting nvidia-cuda-nvrtc-cu12==12.8.93 (from torch==2.8.0->torchaudio)\n",
|
818 |
+
" Using cached nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl.metadata (1.7 kB)\n",
|
819 |
+
"Collecting nvidia-cuda-runtime-cu12==12.8.90 (from torch==2.8.0->torchaudio)\n",
|
820 |
+
" Using cached nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.7 kB)\n",
|
821 |
+
"Collecting nvidia-cuda-cupti-cu12==12.8.90 (from torch==2.8.0->torchaudio)\n",
|
822 |
+
" Using cached nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.7 kB)\n",
|
823 |
+
"Collecting nvidia-cudnn-cu12==9.10.2.21 (from torch==2.8.0->torchaudio)\n",
|
824 |
+
" Using cached nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl.metadata (1.8 kB)\n",
|
825 |
+
"Collecting nvidia-cublas-cu12==12.8.4.1 (from torch==2.8.0->torchaudio)\n",
|
826 |
+
" Using cached nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl.metadata (1.7 kB)\n",
|
827 |
+
"Collecting nvidia-cufft-cu12==11.3.3.83 (from torch==2.8.0->torchaudio)\n",
|
828 |
+
" Using cached nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.7 kB)\n",
|
829 |
+
"Collecting nvidia-curand-cu12==10.3.9.90 (from torch==2.8.0->torchaudio)\n",
|
830 |
+
" Using cached nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl.metadata (1.7 kB)\n",
|
831 |
+
"Collecting nvidia-cusolver-cu12==11.7.3.90 (from torch==2.8.0->torchaudio)\n",
|
832 |
+
" Using cached nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl.metadata (1.8 kB)\n",
|
833 |
+
"Collecting nvidia-cusparse-cu12==12.5.8.93 (from torch==2.8.0->torchaudio)\n",
|
834 |
+
" Using cached nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.8 kB)\n",
|
835 |
+
"Collecting nvidia-cusparselt-cu12==0.7.1 (from torch==2.8.0->torchaudio)\n",
|
836 |
+
" Using cached nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl.metadata (7.0 kB)\n",
|
837 |
+
"Collecting nvidia-nccl-cu12==2.27.3 (from torch==2.8.0->torchaudio)\n",
|
838 |
+
" Using cached nvidia_nccl_cu12-2.27.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (2.0 kB)\n",
|
839 |
+
"Collecting nvidia-nvtx-cu12==12.8.90 (from torch==2.8.0->torchaudio)\n",
|
840 |
+
" Using cached nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.8 kB)\n",
|
841 |
+
"Collecting nvidia-nvjitlink-cu12==12.8.93 (from torch==2.8.0->torchaudio)\n",
|
842 |
+
" Using cached nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl.metadata (1.7 kB)\n",
|
843 |
+
"Collecting nvidia-cufile-cu12==1.13.1.3 (from torch==2.8.0->torchaudio)\n",
|
844 |
+
" Using cached nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.7 kB)\n",
|
845 |
+
"Collecting triton==3.4.0 (from torch==2.8.0->torchaudio)\n",
|
846 |
+
" Using cached triton-3.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (1.7 kB)\n",
|
847 |
+
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /datadrive/jupyter/devbase/Balaji/TTS_ft/lib/python3.12/site-packages (from sympy>=1.13.3->torch==2.8.0->torchaudio) (1.3.0)\n",
|
848 |
+
"Requirement already satisfied: MarkupSafe>=2.0 in /datadrive/jupyter/devbase/Balaji/TTS_ft/lib/python3.12/site-packages (from jinja2->torch==2.8.0->torchaudio) (3.0.2)\n",
|
849 |
+
"Downloading torchaudio-2.8.0-cp312-cp312-manylinux_2_28_x86_64.whl (4.0 MB)\n",
|
850 |
+
"\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.0/4.0 MB\u001b[0m \u001b[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m0:00:01\u001b[0m[36m0:00:01\u001b[0m\n",
|
851 |
+
"\u001b[?25hDownloading torch-2.8.0-cp312-cp312-manylinux_2_28_x86_64.whl (887.9 MB)\n",
|
852 |
+
"\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m887.9/887.9 MB\u001b[0m \u001b[31m979.7 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m eta \u001b[36m0:00:01\u001b[0m[36m0:00:19\u001b[0mm\n",
|
853 |
+
"\u001b[?25hDownloading nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl (594.3 MB)\n",
|
854 |
+
"\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m594.3/594.3 MB\u001b[0m \u001b[31m1.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m0:00:01\u001b[0m[36m0:00:13\u001b[0m\n",
|
855 |
+
"\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (10.2 MB)\n",
|
856 |
+
"\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.2/10.2 MB\u001b[0m \u001b[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m0:00:01\u001b[0m[36m0:00:01\u001b[0m\n",
|
857 |
+
"\u001b[?25hDownloading nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl (88.0 MB)\n",
|
858 |
+
"\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m88.0/88.0 MB\u001b[0m \u001b[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m0:00:01\u001b[0m[36m0:00:02\u001b[0m\n",
|
859 |
+
"\u001b[?25hDownloading nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (954 kB)\n",
|
860 |
+
"\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m954.8/954.8 kB\u001b[0m \u001b[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m[36m0:00:01\u001b[0m[36m0:00:01\u001b[0m:01\u001b[0m\n",
|
861 |
+
"\u001b[?25hDownloading nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl (706.8 MB)\n",
|
862 |
+
"\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m706.8/706.8 MB\u001b[0m \u001b[31m1.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m0:00:01\u001b[0m[36m0:00:15\u001b[0mm\n",
|
863 |
+
"\u001b[?25hDownloading nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (193.1 MB)\n",
|
864 |
+
"\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m193.1/193.1 MB\u001b[0m \u001b[31m1.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m0:00:01\u001b[0m[36m0:00:05\u001b[0m\n",
|
865 |
+
"\u001b[?25hDownloading nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (1.2 MB)\n",
|
866 |
+
"\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.2/1.2 MB\u001b[0m \u001b[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\n",
|
867 |
+
"\u001b[?25hDownloading nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl (63.6 MB)\n",
|
868 |
+
"\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m63.6/63.6 MB\u001b[0m \u001b[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m0:00:01\u001b[0m[36m0:00:02\u001b[0m\n",
|
869 |
+
"\u001b[?25hDownloading nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl (267.5 MB)\n",
|
870 |
+
"\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m267.5/267.5 MB\u001b[0m \u001b[31m1.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m0:00:01\u001b[0m[36m0:00:06\u001b[0m\n",
|
871 |
+
"\u001b[?25hDownloading nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (288.2 MB)\n",
|
872 |
+
"\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m288.2/288.2 MB\u001b[0m \u001b[31m1.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m0:00:01\u001b[0m[36m0:00:07\u001b[0m\n",
|
873 |
+
"\u001b[?25hDownloading nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl (287.2 MB)\n",
|
874 |
+
"\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m287.2/287.2 MB\u001b[0m \u001b[31m1.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m0:00:01\u001b[0m[36m0:00:06\u001b[0m\n",
|
875 |
+
"\u001b[?25hDownloading nvidia_nccl_cu12-2.27.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (322.4 MB)\n",
|
876 |
+
"\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m322.4/322.4 MB\u001b[0m \u001b[31m1.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m0:00:01\u001b[0m[36m0:00:07\u001b[0m\n",
|
877 |
+
"\u001b[?25hDownloading nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl (39.3 MB)\n",
|
878 |
+
"\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m39.3/39.3 MB\u001b[0m \u001b[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m0:00:01\u001b[0m[36m0:00:01\u001b[0m\n",
|
879 |
+
"\u001b[?25hDownloading nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (89 kB)\n",
|
880 |
+
"\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m90.0/90.0 kB\u001b[0m \u001b[31m1.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m:01\u001b[0m\n",
|
881 |
+
"\u001b[?25hDownloading triton-3.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (155.6 MB)\n",
|
882 |
+
"\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m155.6/155.6 MB\u001b[0m \u001b[31m1.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m0:00:01\u001b[0m[36m0:00:04\u001b[0m\n",
|
883 |
+
"\u001b[?25hInstalling collected packages: nvidia-cusparselt-cu12, triton, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufile-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, nvidia-cusparse-cu12, nvidia-cufft-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torch, torchaudio\n",
|
884 |
+
" Attempting uninstall: nvidia-cusparselt-cu12\n",
|
885 |
+
" Found existing installation: nvidia-cusparselt-cu12 0.6.3\n",
|
886 |
+
" Uninstalling nvidia-cusparselt-cu12-0.6.3:\n",
|
887 |
+
" Successfully uninstalled nvidia-cusparselt-cu12-0.6.3\n",
|
888 |
+
" Attempting uninstall: triton\n",
|
889 |
+
" Found existing installation: triton 3.3.1\n",
|
890 |
+
" Uninstalling triton-3.3.1:\n",
|
891 |
+
" Successfully uninstalled triton-3.3.1\n",
|
892 |
+
" Attempting uninstall: nvidia-nvtx-cu12\n",
|
893 |
+
" Found existing installation: nvidia-nvtx-cu12 12.6.77\n",
|
894 |
+
" Uninstalling nvidia-nvtx-cu12-12.6.77:\n",
|
895 |
+
" Successfully uninstalled nvidia-nvtx-cu12-12.6.77\n",
|
896 |
+
" Attempting uninstall: nvidia-nvjitlink-cu12\n",
|
897 |
+
" Found existing installation: nvidia-nvjitlink-cu12 12.6.85\n",
|
898 |
+
" Uninstalling nvidia-nvjitlink-cu12-12.6.85:\n",
|
899 |
+
" Successfully uninstalled nvidia-nvjitlink-cu12-12.6.85\n",
|
900 |
+
" Attempting uninstall: nvidia-nccl-cu12\n",
|
901 |
+
" Found existing installation: nvidia-nccl-cu12 2.26.2\n",
|
902 |
+
" Uninstalling nvidia-nccl-cu12-2.26.2:\n",
|
903 |
+
" Successfully uninstalled nvidia-nccl-cu12-2.26.2\n",
|
904 |
+
" Attempting uninstall: nvidia-curand-cu12\n",
|
905 |
+
" Found existing installation: nvidia-curand-cu12 10.3.7.77\n",
|
906 |
+
" Uninstalling nvidia-curand-cu12-10.3.7.77:\n",
|
907 |
+
" Successfully uninstalled nvidia-curand-cu12-10.3.7.77\n",
|
908 |
+
" Attempting uninstall: nvidia-cufile-cu12\n",
|
909 |
+
" Found existing installation: nvidia-cufile-cu12 1.11.1.6\n",
|
910 |
+
" Uninstalling nvidia-cufile-cu12-1.11.1.6:\n",
|
911 |
+
" Successfully uninstalled nvidia-cufile-cu12-1.11.1.6\n",
|
912 |
+
" Attempting uninstall: nvidia-cuda-runtime-cu12\n",
|
913 |
+
" Found existing installation: nvidia-cuda-runtime-cu12 12.6.77\n",
|
914 |
+
" Uninstalling nvidia-cuda-runtime-cu12-12.6.77:\n",
|
915 |
+
" Successfully uninstalled nvidia-cuda-runtime-cu12-12.6.77\n",
|
916 |
+
" Attempting uninstall: nvidia-cuda-nvrtc-cu12\n",
|
917 |
+
" Found existing installation: nvidia-cuda-nvrtc-cu12 12.6.77\n",
|
918 |
+
" Uninstalling nvidia-cuda-nvrtc-cu12-12.6.77:\n",
|
919 |
+
" Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.6.77\n",
|
920 |
+
" Attempting uninstall: nvidia-cuda-cupti-cu12\n",
|
921 |
+
" Found existing installation: nvidia-cuda-cupti-cu12 12.6.80\n",
|
922 |
+
" Uninstalling nvidia-cuda-cupti-cu12-12.6.80:\n",
|
923 |
+
" Successfully uninstalled nvidia-cuda-cupti-cu12-12.6.80\n",
|
924 |
+
" Attempting uninstall: nvidia-cublas-cu12\n",
|
925 |
+
" Found existing installation: nvidia-cublas-cu12 12.6.4.1\n",
|
926 |
+
" Uninstalling nvidia-cublas-cu12-12.6.4.1:\n",
|
927 |
+
" Successfully uninstalled nvidia-cublas-cu12-12.6.4.1\n",
|
928 |
+
" Attempting uninstall: nvidia-cusparse-cu12\n",
|
929 |
+
" Found existing installation: nvidia-cusparse-cu12 12.5.4.2\n",
|
930 |
+
" Uninstalling nvidia-cusparse-cu12-12.5.4.2:\n",
|
931 |
+
" Successfully uninstalled nvidia-cusparse-cu12-12.5.4.2\n",
|
932 |
+
" Attempting uninstall: nvidia-cufft-cu12\n",
|
933 |
+
" Found existing installation: nvidia-cufft-cu12 11.3.0.4\n",
|
934 |
+
" Uninstalling nvidia-cufft-cu12-11.3.0.4:\n",
|
935 |
+
" Successfully uninstalled nvidia-cufft-cu12-11.3.0.4\n",
|
936 |
+
" Attempting uninstall: nvidia-cudnn-cu12\n",
|
937 |
+
" Found existing installation: nvidia-cudnn-cu12 9.5.1.17\n",
|
938 |
+
" Uninstalling nvidia-cudnn-cu12-9.5.1.17:\n",
|
939 |
+
" Successfully uninstalled nvidia-cudnn-cu12-9.5.1.17\n",
|
940 |
+
" Attempting uninstall: nvidia-cusolver-cu12\n",
|
941 |
+
" Found existing installation: nvidia-cusolver-cu12 11.7.1.2\n",
|
942 |
+
" Uninstalling nvidia-cusolver-cu12-11.7.1.2:\n",
|
943 |
+
" Successfully uninstalled nvidia-cusolver-cu12-11.7.1.2\n",
|
944 |
+
" Attempting uninstall: torch\n",
|
945 |
+
" Found existing installation: torch 2.7.1\n",
|
946 |
+
" Uninstalling torch-2.7.1:\n",
|
947 |
+
" Successfully uninstalled torch-2.7.1\n",
|
948 |
+
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
|
949 |
+
"xformers 0.0.31.post1 requires torch==2.7.1, but you have torch 2.8.0 which is incompatible.\n",
|
950 |
+
"torchvision 0.22.1 requires torch==2.7.1, but you have torch 2.8.0 which is incompatible.\u001b[0m\u001b[31m\n",
|
951 |
+
"\u001b[0mSuccessfully installed nvidia-cublas-cu12-12.8.4.1 nvidia-cuda-cupti-cu12-12.8.90 nvidia-cuda-nvrtc-cu12-12.8.93 nvidia-cuda-runtime-cu12-12.8.90 nvidia-cudnn-cu12-9.10.2.21 nvidia-cufft-cu12-11.3.3.83 nvidia-cufile-cu12-1.13.1.3 nvidia-curand-cu12-10.3.9.90 nvidia-cusolver-cu12-11.7.3.90 nvidia-cusparse-cu12-12.5.8.93 nvidia-cusparselt-cu12-0.7.1 nvidia-nccl-cu12-2.27.3 nvidia-nvjitlink-cu12-12.8.93 nvidia-nvtx-cu12-12.8.90 torch-2.8.0 torchaudio-2.8.0 triton-3.4.0\n",
|
952 |
+
"Note: you may need to restart the kernel to use updated packages.\n"
|
953 |
+
]
|
954 |
+
}
|
955 |
+
],
|
956 |
+
"source": [
|
957 |
+
"%pip install torchaudio"
|
958 |
+
]
|
959 |
+
},
|
960 |
+
{
|
961 |
+
"cell_type": "markdown",
|
962 |
+
"metadata": {
|
963 |
+
"id": "idAEIeSQ3xdS"
|
964 |
+
},
|
965 |
+
"source": [
|
966 |
+
"<a name=\"Train\"></a>\n",
|
967 |
+
"### Train the model\n",
|
968 |
+
"Now let's use Huggingface TRL's `SFTTrainer`! More docs here: [TRL SFT docs](https://huggingface.co/docs/trl/sft_trainer). We do 60 steps to speed things up, but you can set `num_train_epochs=1` for a full run, and turn off `max_steps=None`. We also support TRL's `DPOTrainer`!"
|
969 |
+
]
|
970 |
+
},
|
971 |
+
{
|
972 |
+
"cell_type": "code",
|
973 |
+
"execution_count": null,
|
974 |
+
"metadata": {
|
975 |
+
"id": "95_Nn-89DhsL"
|
976 |
+
},
|
977 |
+
"outputs": [],
|
978 |
+
"source": [
|
979 |
+
"from trl import SFTConfig, SFTTrainer\n",
|
980 |
+
"trainer = SFTTrainer(\n",
|
981 |
+
" model = model,\n",
|
982 |
+
" tokenizer = tokenizer,\n",
|
983 |
+
" train_dataset = dataset,\n",
|
984 |
+
" dataset_text_field = \"text\",\n",
|
985 |
+
" max_seq_length = max_seq_length,\n",
|
986 |
+
" packing = False, # Can make training 5x faster for short sequences.\n",
|
987 |
+
" args = SFTConfig(\n",
|
988 |
+
" per_device_train_batch_size = 2,\n",
|
989 |
+
" gradient_accumulation_steps = 4,\n",
|
990 |
+
" warmup_steps = 5,\n",
|
991 |
+
" num_train_epochs = 5, # Set this for 1 full training run.\n",
|
992 |
+
" #max_steps = 60,\n",
|
993 |
+
" learning_rate = 1e-5,\n",
|
994 |
+
" fp16 = False, # We're doing full float32 s disable mixed precision\n",
|
995 |
+
" bf16 = False, # We're doing full float32 s disable mixed precision\n",
|
996 |
+
" logging_steps = 1,\n",
|
997 |
+
" optim = \"adamw_8bit\",\n",
|
998 |
+
" weight_decay = 0.01,\n",
|
999 |
+
" lr_scheduler_type = \"linear\",\n",
|
1000 |
+
" seed = 3407,\n",
|
1001 |
+
" output_dir = \"outputs\",\n",
|
1002 |
+
" report_to = \"tensorboard\", # Use this for WandB etc\n",
|
1003 |
+
" ),\n",
|
1004 |
+
")"
|
1005 |
+
]
|
1006 |
+
},
|
1007 |
+
{
|
1008 |
+
"cell_type": "code",
|
1009 |
+
"execution_count": null,
|
1010 |
+
"metadata": {
|
1011 |
+
"id": "2ejIt2xSNKKp"
|
1012 |
+
},
|
1013 |
+
"outputs": [],
|
1014 |
+
"source": [
|
1015 |
+
"# @title Show current memory stats\n",
|
1016 |
+
"gpu_stats = torch.cuda.get_device_properties(0)\n",
|
1017 |
+
"start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
|
1018 |
+
"max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n",
|
1019 |
+
"print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n",
|
1020 |
+
"print(f\"{start_gpu_memory} GB of memory reserved.\")"
|
1021 |
+
]
|
1022 |
+
},
|
1023 |
+
{
|
1024 |
+
"cell_type": "code",
|
1025 |
+
"execution_count": null,
|
1026 |
+
"metadata": {
|
1027 |
+
"id": "yqxqAZ7KJ4oL"
|
1028 |
+
},
|
1029 |
+
"outputs": [],
|
1030 |
+
"source": [
|
1031 |
+
"trainer_stats = trainer.train()"
|
1032 |
+
]
|
1033 |
+
},
|
1034 |
+
{
|
1035 |
+
"cell_type": "code",
|
1036 |
+
"execution_count": null,
|
1037 |
+
"metadata": {
|
1038 |
+
"cellView": "form",
|
1039 |
+
"id": "pCqnaKmlO1U9"
|
1040 |
+
},
|
1041 |
+
"outputs": [],
|
1042 |
+
"source": [
|
1043 |
+
"# @title Show final memory and time stats\n",
|
1044 |
+
"used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
|
1045 |
+
"used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n",
|
1046 |
+
"used_percentage = round(used_memory / max_memory * 100, 3)\n",
|
1047 |
+
"lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n",
|
1048 |
+
"print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n",
|
1049 |
+
"print(\n",
|
1050 |
+
" f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\"\n",
|
1051 |
+
")\n",
|
1052 |
+
"print(f\"Peak reserved memory = {used_memory} GB.\")\n",
|
1053 |
+
"print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n",
|
1054 |
+
"print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n",
|
1055 |
+
"print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")"
|
1056 |
+
]
|
1057 |
+
},
|
1058 |
+
{
|
1059 |
+
"cell_type": "markdown",
|
1060 |
+
"metadata": {
|
1061 |
+
"id": "ekOmTR1hSNcr"
|
1062 |
+
},
|
1063 |
+
"source": [
|
1064 |
+
"<a name=\"Inference\"></a>\n",
|
1065 |
+
"### Inference\n",
|
1066 |
+
"Let's run the model! You can change the prompts\n"
|
1067 |
+
]
|
1068 |
+
},
|
1069 |
+
{
|
1070 |
+
"cell_type": "code",
|
1071 |
+
"execution_count": null,
|
1072 |
+
"metadata": {
|
1073 |
+
"id": "apUdB40Ep6Ki"
|
1074 |
+
},
|
1075 |
+
"outputs": [],
|
1076 |
+
"source": [
|
1077 |
+
"input_text = \"Hey there my name is Elise, <giggles> and I'm a speech generation model that can sound like a person.\"\n",
|
1078 |
+
"\n",
|
1079 |
+
"chosen_voice = None # None for single-speaker"
|
1080 |
+
]
|
1081 |
+
},
|
1082 |
+
{
|
1083 |
+
"cell_type": "code",
|
1084 |
+
"execution_count": null,
|
1085 |
+
"metadata": {
|
1086 |
+
"cellView": "form",
|
1087 |
+
"execution": {
|
1088 |
+
"iopub.execute_input": "2025-03-22T00:52:35.040842Z",
|
1089 |
+
"iopub.status.busy": "2025-03-22T00:52:35.040125Z",
|
1090 |
+
"iopub.status.idle": "2025-03-22T00:52:35.050560Z",
|
1091 |
+
"shell.execute_reply": "2025-03-22T00:52:35.049663Z",
|
1092 |
+
"shell.execute_reply.started": "2025-03-22T00:52:35.040818Z"
|
1093 |
+
},
|
1094 |
+
"id": "krYI8PrRJ6MX"
|
1095 |
+
},
|
1096 |
+
"outputs": [],
|
1097 |
+
"source": [
|
1098 |
+
"#@title Run Inference\n",
|
1099 |
+
"\n",
|
1100 |
+
"import torch\n",
|
1101 |
+
"import re\n",
|
1102 |
+
"import numpy as np\n",
|
1103 |
+
"from typing import Dict, Any\n",
|
1104 |
+
"import torchaudio.transforms as T\n",
|
1105 |
+
"\n",
|
1106 |
+
"FastModel.for_inference(model) # Enable native 2x faster inference\n",
|
1107 |
+
"\n",
|
1108 |
+
"@torch.inference_mode()\n",
|
1109 |
+
"def generate_speech_from_text(\n",
|
1110 |
+
" text: str,\n",
|
1111 |
+
" temperature: float = 0.8, # Generation temperature\n",
|
1112 |
+
" top_k: int = 50, # Generation top_k\n",
|
1113 |
+
" top_p: float = 1, # Generation top_p\n",
|
1114 |
+
" max_new_audio_tokens: int = 2048, # Max tokens for audio part\n",
|
1115 |
+
" device: torch.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
1116 |
+
") -> np.ndarray:\n",
|
1117 |
+
" \"\"\"\n",
|
1118 |
+
" Generates speech audio from text using default voice control parameters.\n",
|
1119 |
+
"\n",
|
1120 |
+
" Args:\n",
|
1121 |
+
" text (str): The text input to be converted to speech.\n",
|
1122 |
+
" temperature (float): Sampling temperature for generation.\n",
|
1123 |
+
" top_k (int): Top-k sampling parameter.\n",
|
1124 |
+
" top_p (float): Top-p (nucleus) sampling parameter.\n",
|
1125 |
+
" max_new_audio_tokens (int): Max number of new tokens to generate (limits audio length).\n",
|
1126 |
+
" device (torch.device): Device to run inference on.\n",
|
1127 |
+
"\n",
|
1128 |
+
" Returns:\n",
|
1129 |
+
" np.ndarray: Generated waveform as a NumPy array.\n",
|
1130 |
+
" \"\"\"\n",
|
1131 |
+
"\n",
|
1132 |
+
" torch.compiler.reset()\n",
|
1133 |
+
"\n",
|
1134 |
+
" prompt = \"\".join([\n",
|
1135 |
+
" \"<|task_tts|>\",\n",
|
1136 |
+
" \"<|start_content|>\",\n",
|
1137 |
+
" text,\n",
|
1138 |
+
" \"<|end_content|>\",\n",
|
1139 |
+
" \"<|start_global_token|>\"\n",
|
1140 |
+
" ])\n",
|
1141 |
+
"\n",
|
1142 |
+
" model_inputs = tokenizer([prompt], return_tensors=\"pt\").to(device)\n",
|
1143 |
+
"\n",
|
1144 |
+
" print(\"Generating token sequence...\")\n",
|
1145 |
+
" generated_ids = model.generate(\n",
|
1146 |
+
" **model_inputs,\n",
|
1147 |
+
" max_new_tokens=max_new_audio_tokens, # Limit generation length\n",
|
1148 |
+
" do_sample=True,\n",
|
1149 |
+
" temperature=temperature,\n",
|
1150 |
+
" top_k=top_k,\n",
|
1151 |
+
" top_p=top_p,\n",
|
1152 |
+
" eos_token_id=tokenizer.eos_token_id, # Stop token\n",
|
1153 |
+
" pad_token_id=tokenizer.pad_token_id # Use models pad token id\n",
|
1154 |
+
" )\n",
|
1155 |
+
" print(\"Token sequence generated.\")\n",
|
1156 |
+
"\n",
|
1157 |
+
"\n",
|
1158 |
+
" generated_ids_trimmed = generated_ids[:, model_inputs.input_ids.shape[1]:]\n",
|
1159 |
+
"\n",
|
1160 |
+
"\n",
|
1161 |
+
" predicts_text = tokenizer.batch_decode(generated_ids_trimmed, skip_special_tokens=False)[0]\n",
|
1162 |
+
" # print(f\"\\nGenerated Text (for parsing):\\n{predicts_text}\\n\") # Debugging\n",
|
1163 |
+
"\n",
|
1164 |
+
" # Extract semantic token IDs using regex\n",
|
1165 |
+
" semantic_matches = re.findall(r\"<\\|bicodec_semantic_(\\d+)\\|>\", predicts_text)\n",
|
1166 |
+
" if not semantic_matches:\n",
|
1167 |
+
" print(\"Warning: No semantic tokens found in the generated output.\")\n",
|
1168 |
+
" # Handle appropriately - perhaps return silence or raise error\n",
|
1169 |
+
" return np.array([], dtype=np.float32)\n",
|
1170 |
+
"\n",
|
1171 |
+
" pred_semantic_ids = torch.tensor([int(token) for token in semantic_matches]).long().unsqueeze(0) # Add batch dim\n",
|
1172 |
+
"\n",
|
1173 |
+
" # Extract global token IDs using regex (assuming controllable mode also generates these)\n",
|
1174 |
+
" global_matches = re.findall(r\"<\\|bicodec_global_(\\d+)\\|>\", predicts_text)\n",
|
1175 |
+
" if not global_matches:\n",
|
1176 |
+
" print(\"Warning: No global tokens found in the generated output (controllable mode). Might use defaults or fail.\")\n",
|
1177 |
+
" pred_global_ids = torch.zeros((1, 1), dtype=torch.long)\n",
|
1178 |
+
" else:\n",
|
1179 |
+
" pred_global_ids = torch.tensor([int(token) for token in global_matches]).long().unsqueeze(0) # Add batch dim\n",
|
1180 |
+
"\n",
|
1181 |
+
" pred_global_ids = pred_global_ids.unsqueeze(0) # Shape becomes (1, 1, N_global)\n",
|
1182 |
+
"\n",
|
1183 |
+
" print(f\"Found {pred_semantic_ids.shape[1]} semantic tokens.\")\n",
|
1184 |
+
" print(f\"Found {pred_global_ids.shape[2]} global tokens.\")\n",
|
1185 |
+
"\n",
|
1186 |
+
"\n",
|
1187 |
+
" # 5. Detokenize using BiCodecTokenizer\n",
|
1188 |
+
" print(\"Detokenizing audio tokens...\")\n",
|
1189 |
+
" # Ensure audio_tokenizer and its internal model are on the correct device\n",
|
1190 |
+
" audio_tokenizer.device = device\n",
|
1191 |
+
" audio_tokenizer.model.to(device)\n",
|
1192 |
+
" # Squeeze the extra dimension from global tokens as seen in SparkTTS example\n",
|
1193 |
+
" wav_np = audio_tokenizer.detokenize(\n",
|
1194 |
+
" pred_global_ids.to(device).squeeze(0), # Shape (1, N_global)\n",
|
1195 |
+
" pred_semantic_ids.to(device) # Shape (1, N_semantic)\n",
|
1196 |
+
" )\n",
|
1197 |
+
" print(\"Detokenization complete.\")\n",
|
1198 |
+
"\n",
|
1199 |
+
" return wav_np\n",
|
1200 |
+
"\n",
|
1201 |
+
"if __name__ == \"__main__\":\n",
|
1202 |
+
" print(f\"Generating speech for: '{input_text}'\")\n",
|
1203 |
+
" text = f\"{chosen_voice}: \" + input_text if chosen_voice else input_text\n",
|
1204 |
+
" generated_waveform = generate_speech_from_text(input_text)\n",
|
1205 |
+
"\n",
|
1206 |
+
" if generated_waveform.size > 0:\n",
|
1207 |
+
" import soundfile as sf\n",
|
1208 |
+
" output_filename = \"generated_speech_controllable.wav\"\n",
|
1209 |
+
" sample_rate = audio_tokenizer.config.get(\"sample_rate\", 16000)\n",
|
1210 |
+
" sf.write(output_filename, generated_waveform, sample_rate)\n",
|
1211 |
+
" print(f\"Audio saved to {output_filename}\")\n",
|
1212 |
+
"\n",
|
1213 |
+
" # Optional: Play in notebook\n",
|
1214 |
+
" from IPython.display import Audio, display\n",
|
1215 |
+
" display(Audio(generated_waveform, rate=sample_rate))\n",
|
1216 |
+
" else:\n",
|
1217 |
+
" print(\"Audio generation failed (no tokens found?).\")"
|
1218 |
+
]
|
1219 |
+
},
|
1220 |
+
{
|
1221 |
+
"cell_type": "markdown",
|
1222 |
+
"metadata": {
|
1223 |
+
"id": "uMuVrWbjAzhc"
|
1224 |
+
},
|
1225 |
+
"source": [
|
1226 |
+
"<a name=\"Save\"></a>\n",
|
1227 |
+
"### Saving, loading finetuned models\n",
|
1228 |
+
"To save the final model as LoRA adapters, either use Huggingface's `push_to_hub` for an online save or `save_pretrained` for a local save.\n",
|
1229 |
+
"\n",
|
1230 |
+
"**[NOTE]** This ONLY saves the LoRA adapters, and not the full model. To save to 16bit or GGUF, scroll down!"
|
1231 |
+
]
|
1232 |
+
},
|
1233 |
+
{
|
1234 |
+
"cell_type": "code",
|
1235 |
+
"execution_count": null,
|
1236 |
+
"metadata": {
|
1237 |
+
"id": "upcOlWe7A1vc"
|
1238 |
+
},
|
1239 |
+
"outputs": [],
|
1240 |
+
"source": [
|
1241 |
+
"model.save_pretrained(\"lora_model\") # Local saving\n",
|
1242 |
+
"tokenizer.save_pretrained(\"lora_model\")\n",
|
1243 |
+
"# model.push_to_hub(\"your_name/lora_model\", token = \"...\") # Online saving\n",
|
1244 |
+
"# tokenizer.push_to_hub(\"your_name/lora_model\", token = \"...\") # Online saving"
|
1245 |
+
]
|
1246 |
+
},
|
1247 |
+
{
|
1248 |
+
"cell_type": "markdown",
|
1249 |
+
"metadata": {
|
1250 |
+
"id": "f422JgM9sdVT"
|
1251 |
+
},
|
1252 |
+
"source": [
|
1253 |
+
"\n",
|
1254 |
+
"### Saving to float16\n",
|
1255 |
+
"\n",
|
1256 |
+
"We also support saving to `float16` directly. Select `merged_16bit` for float16 or `merged_4bit` for int4. We also allow `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens."
|
1257 |
+
]
|
1258 |
+
},
|
1259 |
+
{
|
1260 |
+
"cell_type": "code",
|
1261 |
+
"execution_count": null,
|
1262 |
+
"metadata": {
|
1263 |
+
"colab": {
|
1264 |
+
"base_uri": "https://localhost:8080/"
|
1265 |
+
},
|
1266 |
+
"id": "iHjt_SMYsd3P",
|
1267 |
+
"outputId": "bd8cccb7-6b95-45bf-80da-de120988447e"
|
1268 |
+
},
|
1269 |
+
"outputs": [
|
1270 |
+
{
|
1271 |
+
"name": "stderr",
|
1272 |
+
"output_type": "stream",
|
1273 |
+
"text": [
|
1274 |
+
"Unsloth: You have 1 CPUs. Using `safe_serialization` is 10x slower.\n",
|
1275 |
+
"We shall switch to Pytorch saving, which might take 3 minutes and not 30 minutes.\n",
|
1276 |
+
"To force `safe_serialization`, set it to `None` instead.\n",
|
1277 |
+
"Unsloth: Kaggle/Colab has limited disk space. We need to delete the downloaded\n",
|
1278 |
+
"model which will save 4-16GB of disk space, allowing you to save on Kaggle/Colab.\n",
|
1279 |
+
"Unsloth: Will remove a cached repo with size 15.1G\n"
|
1280 |
+
]
|
1281 |
+
},
|
1282 |
+
{
|
1283 |
+
"name": "stdout",
|
1284 |
+
"output_type": "stream",
|
1285 |
+
"text": [
|
1286 |
+
"Unsloth: Merging 4bit and LoRA weights to 16bit...\n",
|
1287 |
+
"Unsloth: Will use up to 3.99 out of 12.67 RAM for saving.\n",
|
1288 |
+
"Unsloth: Saving model... This might take 5 minutes ...\n"
|
1289 |
+
]
|
1290 |
+
},
|
1291 |
+
{
|
1292 |
+
"name": "stderr",
|
1293 |
+
"output_type": "stream",
|
1294 |
+
"text": [
|
1295 |
+
"100%|██████████| 28/28 [00:01<00:00, 27.83it/s]\n"
|
1296 |
+
]
|
1297 |
+
},
|
1298 |
+
{
|
1299 |
+
"name": "stdout",
|
1300 |
+
"output_type": "stream",
|
1301 |
+
"text": [
|
1302 |
+
"Unsloth: Saving tokenizer... Done.\n",
|
1303 |
+
"Unsloth: Saving model/pytorch_model-00001-of-00002.bin...\n",
|
1304 |
+
"Unsloth: Saving model/pytorch_model-00002-of-00002.bin...\n",
|
1305 |
+
"Done.\n"
|
1306 |
+
]
|
1307 |
+
}
|
1308 |
+
],
|
1309 |
+
"source": [
|
1310 |
+
"# Merge to 16bit\n",
|
1311 |
+
"if False: model.save_pretrained_merged(\"model\", tokenizer, save_method = \"merged_16bit\",)\n",
|
1312 |
+
"if False: model.push_to_hub_merged(\"hf/model\", tokenizer, save_method = \"merged_16bit\", token = \"\")\n",
|
1313 |
+
"\n",
|
1314 |
+
"# Merge to 4bit\n",
|
1315 |
+
"if False: model.save_pretrained_merged(\"model\", tokenizer, save_method = \"merged_4bit\",)\n",
|
1316 |
+
"if False: model.push_to_hub_merged(\"hf/model\", tokenizer, save_method = \"merged_4bit\", token = \"\")\n",
|
1317 |
+
"\n",
|
1318 |
+
"# Just LoRA adapters\n",
|
1319 |
+
"if False:\n",
|
1320 |
+
" model.save_pretrained(\"model\")\n",
|
1321 |
+
" tokenizer.save_pretrained(\"model\")\n",
|
1322 |
+
"if False:\n",
|
1323 |
+
" model.push_to_hub(\"hf/model\", token = \"\")\n",
|
1324 |
+
" tokenizer.push_to_hub(\"hf/model\", token = \"\")\n"
|
1325 |
+
]
|
1326 |
+
},
|
1327 |
+
{
|
1328 |
+
"cell_type": "markdown",
|
1329 |
+
"metadata": {
|
1330 |
+
"id": "egOSE7Cgynx7"
|
1331 |
+
},
|
1332 |
+
"source": [
|
1333 |
+
"And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs or want to keep updated with the latest LLM stuff, or need help, join projects etc, feel free to join our Discord!\n",
|
1334 |
+
"\n",
|
1335 |
+
"Some other links:\n",
|
1336 |
+
"1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)\n",
|
1337 |
+
"2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)\n",
|
1338 |
+
"3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)\n",
|
1339 |
+
"6. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://docs.unsloth.ai/get-started/unsloth-notebooks)!\n",
|
1340 |
+
"\n",
|
1341 |
+
"<div class=\"align-center\">\n",
|
1342 |
+
" <a href=\"https://unsloth.ai\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png\" width=\"115\"></a>\n",
|
1343 |
+
" <a href=\"https://discord.gg/unsloth\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/Discord.png\" width=\"145\"></a>\n",
|
1344 |
+
" <a href=\"https://docs.unsloth.ai/\"><img src=\"https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true\" width=\"125\"></a>\n",
|
1345 |
+
"\n",
|
1346 |
+
" Join Discord if you need help + ⭐️ <i>Star us on <a href=\"https://github.com/unslothai/unsloth\">Github</a> </i> ⭐️\n",
|
1347 |
+
"</div>\n"
|
1348 |
+
]
|
1349 |
+
}
|
1350 |
+
],
|
1351 |
+
"metadata": {
|
1352 |
+
"accelerator": "GPU",
|
1353 |
+
"colab": {
|
1354 |
+
"gpuType": "T4",
|
1355 |
+
"provenance": []
|
1356 |
+
},
|
1357 |
+
"kaggle": {
|
1358 |
+
"accelerator": "nvidiaTeslaT4",
|
1359 |
+
"dataSources": [],
|
1360 |
+
"dockerImageVersionId": 30919,
|
1361 |
+
"isGpuEnabled": true,
|
1362 |
+
"isInternetEnabled": true,
|
1363 |
+
"language": "python",
|
1364 |
+
"sourceType": "notebook"
|
1365 |
+
},
|
1366 |
+
"kernelspec": {
|
1367 |
+
"display_name": "TTS_ft",
|
1368 |
+
"language": "python",
|
1369 |
+
"name": "tts_ft"
|
1370 |
+
},
|
1371 |
+
"language_info": {
|
1372 |
+
"codemirror_mode": {
|
1373 |
+
"name": "ipython",
|
1374 |
+
"version": 3
|
1375 |
+
},
|
1376 |
+
"file_extension": ".py",
|
1377 |
+
"mimetype": "text/x-python",
|
1378 |
+
"name": "python",
|
1379 |
+
"nbconvert_exporter": "python",
|
1380 |
+
"pygments_lexer": "ipython3",
|
1381 |
+
"version": "3.12.3"
|
1382 |
+
},
|
1383 |
+
"widgets": {
|
1384 |
+
"application/vnd.jupyter.widget-state+json": {
|
1385 |
+
"0474debc340943bd85f3daf92aebf7aa": {
|
1386 |
+
"model_module": "@jupyter-widgets/controls",
|
1387 |
+
"model_module_version": "1.5.0",
|
1388 |
+
"model_name": "FloatProgressModel",
|
1389 |
+
"state": {
|
1390 |
+
"_dom_classes": [],
|
1391 |
+
"_model_module": "@jupyter-widgets/controls",
|
1392 |
+
"_model_module_version": "1.5.0",
|
1393 |
+
"_model_name": "FloatProgressModel",
|
1394 |
+
"_view_count": null,
|
1395 |
+
"_view_module": "@jupyter-widgets/controls",
|
1396 |
+
"_view_module_version": "1.5.0",
|
1397 |
+
"_view_name": "ProgressView",
|
1398 |
+
"bar_style": "",
|
1399 |
+
"description": "",
|
1400 |
+
"description_tooltip": null,
|
1401 |
+
"layout": "IPY_MODEL_0de4d0f282404edfbc191dca73f15f35",
|
1402 |
+
"max": 401,
|
1403 |
+
"min": 0,
|
1404 |
+
"orientation": "horizontal",
|
1405 |
+
"style": "IPY_MODEL_e58b5ad2f781475d8af2ddb38009baa6",
|
1406 |
+
"value": 354
|
1407 |
+
}
|
1408 |
+
},
|
1409 |
+
"0de4d0f282404edfbc191dca73f15f35": {
|
1410 |
+
"model_module": "@jupyter-widgets/base",
|
1411 |
+
"model_module_version": "1.2.0",
|
1412 |
+
"model_name": "LayoutModel",
|
1413 |
+
"state": {
|
1414 |
+
"_model_module": "@jupyter-widgets/base",
|
1415 |
+
"_model_module_version": "1.2.0",
|
1416 |
+
"_model_name": "LayoutModel",
|
1417 |
+
"_view_count": null,
|
1418 |
+
"_view_module": "@jupyter-widgets/base",
|
1419 |
+
"_view_module_version": "1.2.0",
|
1420 |
+
"_view_name": "LayoutView",
|
1421 |
+
"align_content": null,
|
1422 |
+
"align_items": null,
|
1423 |
+
"align_self": null,
|
1424 |
+
"border": null,
|
1425 |
+
"bottom": null,
|
1426 |
+
"display": null,
|
1427 |
+
"flex": null,
|
1428 |
+
"flex_flow": null,
|
1429 |
+
"grid_area": null,
|
1430 |
+
"grid_auto_columns": null,
|
1431 |
+
"grid_auto_flow": null,
|
1432 |
+
"grid_auto_rows": null,
|
1433 |
+
"grid_column": null,
|
1434 |
+
"grid_gap": null,
|
1435 |
+
"grid_row": null,
|
1436 |
+
"grid_template_areas": null,
|
1437 |
+
"grid_template_columns": null,
|
1438 |
+
"grid_template_rows": null,
|
1439 |
+
"height": null,
|
1440 |
+
"justify_content": null,
|
1441 |
+
"justify_items": null,
|
1442 |
+
"left": null,
|
1443 |
+
"margin": null,
|
1444 |
+
"max_height": null,
|
1445 |
+
"max_width": null,
|
1446 |
+
"min_height": null,
|
1447 |
+
"min_width": null,
|
1448 |
+
"object_fit": null,
|
1449 |
+
"object_position": null,
|
1450 |
+
"order": null,
|
1451 |
+
"overflow": null,
|
1452 |
+
"overflow_x": null,
|
1453 |
+
"overflow_y": null,
|
1454 |
+
"padding": null,
|
1455 |
+
"right": null,
|
1456 |
+
"top": null,
|
1457 |
+
"visibility": null,
|
1458 |
+
"width": null
|
1459 |
+
}
|
1460 |
+
},
|
1461 |
+
"2315228ff2b141afabe1263471f5364b": {
|
1462 |
+
"model_module": "@jupyter-widgets/controls",
|
1463 |
+
"model_module_version": "1.5.0",
|
1464 |
+
"model_name": "HTMLModel",
|
1465 |
+
"state": {
|
1466 |
+
"_dom_classes": [],
|
1467 |
+
"_model_module": "@jupyter-widgets/controls",
|
1468 |
+
"_model_module_version": "1.5.0",
|
1469 |
+
"_model_name": "HTMLModel",
|
1470 |
+
"_view_count": null,
|
1471 |
+
"_view_module": "@jupyter-widgets/controls",
|
1472 |
+
"_view_module_version": "1.5.0",
|
1473 |
+
"_view_name": "HTMLView",
|
1474 |
+
"description": "",
|
1475 |
+
"description_tooltip": null,
|
1476 |
+
"layout": "IPY_MODEL_426eb100a94642f79e6b99777406a265",
|
1477 |
+
"placeholder": "",
|
1478 |
+
"style": "IPY_MODEL_a36b5cf197dd4bd9a7f70aa6671b804c",
|
1479 |
+
"value": "Map: 88%"
|
1480 |
+
}
|
1481 |
+
},
|
1482 |
+
"33fbacbb2aa146cd90586357eec1dc3e": {
|
1483 |
+
"model_module": "@jupyter-widgets/base",
|
1484 |
+
"model_module_version": "1.2.0",
|
1485 |
+
"model_name": "LayoutModel",
|
1486 |
+
"state": {
|
1487 |
+
"_model_module": "@jupyter-widgets/base",
|
1488 |
+
"_model_module_version": "1.2.0",
|
1489 |
+
"_model_name": "LayoutModel",
|
1490 |
+
"_view_count": null,
|
1491 |
+
"_view_module": "@jupyter-widgets/base",
|
1492 |
+
"_view_module_version": "1.2.0",
|
1493 |
+
"_view_name": "LayoutView",
|
1494 |
+
"align_content": null,
|
1495 |
+
"align_items": null,
|
1496 |
+
"align_self": null,
|
1497 |
+
"border": null,
|
1498 |
+
"bottom": null,
|
1499 |
+
"display": null,
|
1500 |
+
"flex": null,
|
1501 |
+
"flex_flow": null,
|
1502 |
+
"grid_area": null,
|
1503 |
+
"grid_auto_columns": null,
|
1504 |
+
"grid_auto_flow": null,
|
1505 |
+
"grid_auto_rows": null,
|
1506 |
+
"grid_column": null,
|
1507 |
+
"grid_gap": null,
|
1508 |
+
"grid_row": null,
|
1509 |
+
"grid_template_areas": null,
|
1510 |
+
"grid_template_columns": null,
|
1511 |
+
"grid_template_rows": null,
|
1512 |
+
"height": null,
|
1513 |
+
"justify_content": null,
|
1514 |
+
"justify_items": null,
|
1515 |
+
"left": null,
|
1516 |
+
"margin": null,
|
1517 |
+
"max_height": null,
|
1518 |
+
"max_width": null,
|
1519 |
+
"min_height": null,
|
1520 |
+
"min_width": null,
|
1521 |
+
"object_fit": null,
|
1522 |
+
"object_position": null,
|
1523 |
+
"order": null,
|
1524 |
+
"overflow": null,
|
1525 |
+
"overflow_x": null,
|
1526 |
+
"overflow_y": null,
|
1527 |
+
"padding": null,
|
1528 |
+
"right": null,
|
1529 |
+
"top": null,
|
1530 |
+
"visibility": null,
|
1531 |
+
"width": null
|
1532 |
+
}
|
1533 |
+
},
|
1534 |
+
"426eb100a94642f79e6b99777406a265": {
|
1535 |
+
"model_module": "@jupyter-widgets/base",
|
1536 |
+
"model_module_version": "1.2.0",
|
1537 |
+
"model_name": "LayoutModel",
|
1538 |
+
"state": {
|
1539 |
+
"_model_module": "@jupyter-widgets/base",
|
1540 |
+
"_model_module_version": "1.2.0",
|
1541 |
+
"_model_name": "LayoutModel",
|
1542 |
+
"_view_count": null,
|
1543 |
+
"_view_module": "@jupyter-widgets/base",
|
1544 |
+
"_view_module_version": "1.2.0",
|
1545 |
+
"_view_name": "LayoutView",
|
1546 |
+
"align_content": null,
|
1547 |
+
"align_items": null,
|
1548 |
+
"align_self": null,
|
1549 |
+
"border": null,
|
1550 |
+
"bottom": null,
|
1551 |
+
"display": null,
|
1552 |
+
"flex": null,
|
1553 |
+
"flex_flow": null,
|
1554 |
+
"grid_area": null,
|
1555 |
+
"grid_auto_columns": null,
|
1556 |
+
"grid_auto_flow": null,
|
1557 |
+
"grid_auto_rows": null,
|
1558 |
+
"grid_column": null,
|
1559 |
+
"grid_gap": null,
|
1560 |
+
"grid_row": null,
|
1561 |
+
"grid_template_areas": null,
|
1562 |
+
"grid_template_columns": null,
|
1563 |
+
"grid_template_rows": null,
|
1564 |
+
"height": null,
|
1565 |
+
"justify_content": null,
|
1566 |
+
"justify_items": null,
|
1567 |
+
"left": null,
|
1568 |
+
"margin": null,
|
1569 |
+
"max_height": null,
|
1570 |
+
"max_width": null,
|
1571 |
+
"min_height": null,
|
1572 |
+
"min_width": null,
|
1573 |
+
"object_fit": null,
|
1574 |
+
"object_position": null,
|
1575 |
+
"order": null,
|
1576 |
+
"overflow": null,
|
1577 |
+
"overflow_x": null,
|
1578 |
+
"overflow_y": null,
|
1579 |
+
"padding": null,
|
1580 |
+
"right": null,
|
1581 |
+
"top": null,
|
1582 |
+
"visibility": null,
|
1583 |
+
"width": null
|
1584 |
+
}
|
1585 |
+
},
|
1586 |
+
"930b4d1d5f4b494b830df4d4c398e67c": {
|
1587 |
+
"model_module": "@jupyter-widgets/controls",
|
1588 |
+
"model_module_version": "1.5.0",
|
1589 |
+
"model_name": "DescriptionStyleModel",
|
1590 |
+
"state": {
|
1591 |
+
"_model_module": "@jupyter-widgets/controls",
|
1592 |
+
"_model_module_version": "1.5.0",
|
1593 |
+
"_model_name": "DescriptionStyleModel",
|
1594 |
+
"_view_count": null,
|
1595 |
+
"_view_module": "@jupyter-widgets/base",
|
1596 |
+
"_view_module_version": "1.2.0",
|
1597 |
+
"_view_name": "StyleView",
|
1598 |
+
"description_width": ""
|
1599 |
+
}
|
1600 |
+
},
|
1601 |
+
"a36b5cf197dd4bd9a7f70aa6671b804c": {
|
1602 |
+
"model_module": "@jupyter-widgets/controls",
|
1603 |
+
"model_module_version": "1.5.0",
|
1604 |
+
"model_name": "DescriptionStyleModel",
|
1605 |
+
"state": {
|
1606 |
+
"_model_module": "@jupyter-widgets/controls",
|
1607 |
+
"_model_module_version": "1.5.0",
|
1608 |
+
"_model_name": "DescriptionStyleModel",
|
1609 |
+
"_view_count": null,
|
1610 |
+
"_view_module": "@jupyter-widgets/base",
|
1611 |
+
"_view_module_version": "1.2.0",
|
1612 |
+
"_view_name": "StyleView",
|
1613 |
+
"description_width": ""
|
1614 |
+
}
|
1615 |
+
},
|
1616 |
+
"a3b0c0581f1f4c428baaadd8e9a39b6f": {
|
1617 |
+
"model_module": "@jupyter-widgets/controls",
|
1618 |
+
"model_module_version": "1.5.0",
|
1619 |
+
"model_name": "HBoxModel",
|
1620 |
+
"state": {
|
1621 |
+
"_dom_classes": [],
|
1622 |
+
"_model_module": "@jupyter-widgets/controls",
|
1623 |
+
"_model_module_version": "1.5.0",
|
1624 |
+
"_model_name": "HBoxModel",
|
1625 |
+
"_view_count": null,
|
1626 |
+
"_view_module": "@jupyter-widgets/controls",
|
1627 |
+
"_view_module_version": "1.5.0",
|
1628 |
+
"_view_name": "HBoxView",
|
1629 |
+
"box_style": "",
|
1630 |
+
"children": [
|
1631 |
+
"IPY_MODEL_2315228ff2b141afabe1263471f5364b",
|
1632 |
+
"IPY_MODEL_0474debc340943bd85f3daf92aebf7aa",
|
1633 |
+
"IPY_MODEL_cff1b0fa2ea24f45aab26685353eefdd"
|
1634 |
+
],
|
1635 |
+
"layout": "IPY_MODEL_b7e20be79df246f19b35114a690e44f0"
|
1636 |
+
}
|
1637 |
+
},
|
1638 |
+
"b7e20be79df246f19b35114a690e44f0": {
|
1639 |
+
"model_module": "@jupyter-widgets/base",
|
1640 |
+
"model_module_version": "1.2.0",
|
1641 |
+
"model_name": "LayoutModel",
|
1642 |
+
"state": {
|
1643 |
+
"_model_module": "@jupyter-widgets/base",
|
1644 |
+
"_model_module_version": "1.2.0",
|
1645 |
+
"_model_name": "LayoutModel",
|
1646 |
+
"_view_count": null,
|
1647 |
+
"_view_module": "@jupyter-widgets/base",
|
1648 |
+
"_view_module_version": "1.2.0",
|
1649 |
+
"_view_name": "LayoutView",
|
1650 |
+
"align_content": null,
|
1651 |
+
"align_items": null,
|
1652 |
+
"align_self": null,
|
1653 |
+
"border": null,
|
1654 |
+
"bottom": null,
|
1655 |
+
"display": null,
|
1656 |
+
"flex": null,
|
1657 |
+
"flex_flow": null,
|
1658 |
+
"grid_area": null,
|
1659 |
+
"grid_auto_columns": null,
|
1660 |
+
"grid_auto_flow": null,
|
1661 |
+
"grid_auto_rows": null,
|
1662 |
+
"grid_column": null,
|
1663 |
+
"grid_gap": null,
|
1664 |
+
"grid_row": null,
|
1665 |
+
"grid_template_areas": null,
|
1666 |
+
"grid_template_columns": null,
|
1667 |
+
"grid_template_rows": null,
|
1668 |
+
"height": null,
|
1669 |
+
"justify_content": null,
|
1670 |
+
"justify_items": null,
|
1671 |
+
"left": null,
|
1672 |
+
"margin": null,
|
1673 |
+
"max_height": null,
|
1674 |
+
"max_width": null,
|
1675 |
+
"min_height": null,
|
1676 |
+
"min_width": null,
|
1677 |
+
"object_fit": null,
|
1678 |
+
"object_position": null,
|
1679 |
+
"order": null,
|
1680 |
+
"overflow": null,
|
1681 |
+
"overflow_x": null,
|
1682 |
+
"overflow_y": null,
|
1683 |
+
"padding": null,
|
1684 |
+
"right": null,
|
1685 |
+
"top": null,
|
1686 |
+
"visibility": null,
|
1687 |
+
"width": null
|
1688 |
+
}
|
1689 |
+
},
|
1690 |
+
"cff1b0fa2ea24f45aab26685353eefdd": {
|
1691 |
+
"model_module": "@jupyter-widgets/controls",
|
1692 |
+
"model_module_version": "1.5.0",
|
1693 |
+
"model_name": "HTMLModel",
|
1694 |
+
"state": {
|
1695 |
+
"_dom_classes": [],
|
1696 |
+
"_model_module": "@jupyter-widgets/controls",
|
1697 |
+
"_model_module_version": "1.5.0",
|
1698 |
+
"_model_name": "HTMLModel",
|
1699 |
+
"_view_count": null,
|
1700 |
+
"_view_module": "@jupyter-widgets/controls",
|
1701 |
+
"_view_module_version": "1.5.0",
|
1702 |
+
"_view_name": "HTMLView",
|
1703 |
+
"description": "",
|
1704 |
+
"description_tooltip": null,
|
1705 |
+
"layout": "IPY_MODEL_33fbacbb2aa146cd90586357eec1dc3e",
|
1706 |
+
"placeholder": "",
|
1707 |
+
"style": "IPY_MODEL_930b4d1d5f4b494b830df4d4c398e67c",
|
1708 |
+
"value": " 354/401 [03:01<00:22, 2.11 examples/s]"
|
1709 |
+
}
|
1710 |
+
},
|
1711 |
+
"e58b5ad2f781475d8af2ddb38009baa6": {
|
1712 |
+
"model_module": "@jupyter-widgets/controls",
|
1713 |
+
"model_module_version": "1.5.0",
|
1714 |
+
"model_name": "ProgressStyleModel",
|
1715 |
+
"state": {
|
1716 |
+
"_model_module": "@jupyter-widgets/controls",
|
1717 |
+
"_model_module_version": "1.5.0",
|
1718 |
+
"_model_name": "ProgressStyleModel",
|
1719 |
+
"_view_count": null,
|
1720 |
+
"_view_module": "@jupyter-widgets/base",
|
1721 |
+
"_view_module_version": "1.2.0",
|
1722 |
+
"_view_name": "StyleView",
|
1723 |
+
"bar_color": null,
|
1724 |
+
"description_width": ""
|
1725 |
+
}
|
1726 |
+
}
|
1727 |
+
}
|
1728 |
+
}
|
1729 |
+
},
|
1730 |
+
"nbformat": 4,
|
1731 |
+
"nbformat_minor": 4
|
1732 |
+
}
|
config.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
highpass_cutoff_freq: 40
|
2 |
+
sample_rate: 16000
|
3 |
+
segment_duration: 2.4 # (s)
|
4 |
+
max_val_duration: 12 # (s)
|
5 |
+
latent_hop_length: 320
|
6 |
+
ref_segment_duration: 6
|
7 |
+
volume_normalize: true
|
src/figures/gradio_TTS.png
ADDED
![]() |
src/figures/gradio_control.png
ADDED
![]() |
src/figures/infer_control.png
ADDED
![]() |
Git LFS Details
|
src/figures/infer_voice_cloning.png
ADDED
![]() |
Git LFS Details
|
src/logo/HKUST.jpg
ADDED
![]() |
Git LFS Details
|
src/logo/NPU.jpg
ADDED
![]() |
Git LFS Details
|
src/logo/NTU.jpg
ADDED
![]() |
src/logo/SJU.jpg
ADDED
![]() |
Git LFS Details
|
src/logo/SparkAudio.jpg
ADDED
![]() |
src/logo/SparkAudio2.jpg
ADDED
![]() |
src/logo/SparkTTS.jpg
ADDED
![]() |
src/logo/SparkTTS.png
ADDED
![]() |
Git LFS Details
|
src/logo/mobvoi.jpg
ADDED
![]() |
Git LFS Details
|
src/logo/mobvoi.png
ADDED
![]() |
Git LFS Details
|
wav2vec2-large-xlsr-53/README.md
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
language: multilingual
|
3 |
+
datasets:
|
4 |
+
- common_voice
|
5 |
+
tags:
|
6 |
+
- speech
|
7 |
+
license: apache-2.0
|
8 |
+
---
|
9 |
+
|
10 |
+
# Wav2Vec2-XLSR-53
|
11 |
+
|
12 |
+
[Facebook's XLSR-Wav2Vec2](https://ai.facebook.com/blog/wav2vec-20-learning-the-structure-of-speech-from-raw-audio/)
|
13 |
+
|
14 |
+
The base model pretrained on 16kHz sampled speech audio. When using the model make sure that your speech input is also sampled at 16Khz. Note that this model should be fine-tuned on a downstream task, like Automatic Speech Recognition. Check out [this blog](https://huggingface.co/blog/fine-tune-wav2vec2-english) for more information.
|
15 |
+
|
16 |
+
[Paper](https://arxiv.org/abs/2006.13979)
|
17 |
+
|
18 |
+
Authors: Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli
|
19 |
+
|
20 |
+
**Abstract**
|
21 |
+
This paper presents XLSR which learns cross-lingual speech representations by pretraining a single model from the raw waveform of speech in multiple languages. We build on wav2vec 2.0 which is trained by solving a contrastive task over masked latent speech representations and jointly learns a quantization of the latents shared across languages. The resulting model is fine-tuned on labeled data and experiments show that cross-lingual pretraining significantly outperforms monolingual pretraining. On the CommonVoice benchmark, XLSR shows a relative phoneme error rate reduction of 72% compared to the best known results. On BABEL, our approach improves word error rate by 16% relative compared to a comparable system. Our approach enables a single multilingual speech recognition model which is competitive to strong individual models. Analysis shows that the latent discrete speech representations are shared across languages with increased sharing for related languages. We hope to catalyze research in low-resource speech understanding by releasing XLSR-53, a large model pretrained in 53 languages.
|
22 |
+
|
23 |
+
The original model can be found under https://github.com/pytorch/fairseq/tree/master/examples/wav2vec#wav2vec-20.
|
24 |
+
|
25 |
+
# Usage
|
26 |
+
|
27 |
+
See [this notebook](https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/Fine_Tune_XLSR_Wav2Vec2_on_Turkish_ASR_with_%F0%9F%A4%97_Transformers.ipynb) for more information on how to fine-tune the model.
|
28 |
+
|
29 |
+

|
wav2vec2-large-xlsr-53/config.json
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"activation_dropout": 0.0,
|
3 |
+
"apply_spec_augment": true,
|
4 |
+
"architectures": [
|
5 |
+
"Wav2Vec2ForPreTraining"
|
6 |
+
],
|
7 |
+
"attention_dropout": 0.1,
|
8 |
+
"bos_token_id": 1,
|
9 |
+
"codevector_dim": 768,
|
10 |
+
"contrastive_logits_temperature": 0.1,
|
11 |
+
"conv_bias": true,
|
12 |
+
"conv_dim": [
|
13 |
+
512,
|
14 |
+
512,
|
15 |
+
512,
|
16 |
+
512,
|
17 |
+
512,
|
18 |
+
512,
|
19 |
+
512
|
20 |
+
],
|
21 |
+
"conv_kernel": [
|
22 |
+
10,
|
23 |
+
3,
|
24 |
+
3,
|
25 |
+
3,
|
26 |
+
3,
|
27 |
+
2,
|
28 |
+
2
|
29 |
+
],
|
30 |
+
"conv_stride": [
|
31 |
+
5,
|
32 |
+
2,
|
33 |
+
2,
|
34 |
+
2,
|
35 |
+
2,
|
36 |
+
2,
|
37 |
+
2
|
38 |
+
],
|
39 |
+
"ctc_loss_reduction": "sum",
|
40 |
+
"ctc_zero_infinity": false,
|
41 |
+
"diversity_loss_weight": 0.1,
|
42 |
+
"do_stable_layer_norm": true,
|
43 |
+
"eos_token_id": 2,
|
44 |
+
"feat_extract_activation": "gelu",
|
45 |
+
"feat_extract_dropout": 0.0,
|
46 |
+
"feat_extract_norm": "layer",
|
47 |
+
"feat_proj_dropout": 0.1,
|
48 |
+
"feat_quantizer_dropout": 0.0,
|
49 |
+
"final_dropout": 0.0,
|
50 |
+
"gradient_checkpointing": false,
|
51 |
+
"hidden_act": "gelu",
|
52 |
+
"hidden_dropout": 0.1,
|
53 |
+
"hidden_size": 1024,
|
54 |
+
"initializer_range": 0.02,
|
55 |
+
"intermediate_size": 4096,
|
56 |
+
"layer_norm_eps": 1e-05,
|
57 |
+
"layerdrop": 0.1,
|
58 |
+
"mask_channel_length": 10,
|
59 |
+
"mask_channel_min_space": 1,
|
60 |
+
"mask_channel_other": 0.0,
|
61 |
+
"mask_channel_prob": 0.0,
|
62 |
+
"mask_channel_selection": "static",
|
63 |
+
"mask_feature_length": 10,
|
64 |
+
"mask_feature_prob": 0.0,
|
65 |
+
"mask_time_length": 10,
|
66 |
+
"mask_time_min_space": 1,
|
67 |
+
"mask_time_other": 0.0,
|
68 |
+
"mask_time_prob": 0.075,
|
69 |
+
"mask_time_selection": "static",
|
70 |
+
"model_type": "wav2vec2",
|
71 |
+
"num_attention_heads": 16,
|
72 |
+
"num_codevector_groups": 2,
|
73 |
+
"num_codevectors_per_group": 320,
|
74 |
+
"num_conv_pos_embedding_groups": 16,
|
75 |
+
"num_conv_pos_embeddings": 128,
|
76 |
+
"num_feat_extract_layers": 7,
|
77 |
+
"num_hidden_layers": 24,
|
78 |
+
"num_negatives": 100,
|
79 |
+
"pad_token_id": 0,
|
80 |
+
"proj_codevector_dim": 768,
|
81 |
+
"transformers_version": "4.7.0.dev0",
|
82 |
+
"vocab_size": 32
|
83 |
+
}
|
wav2vec2-large-xlsr-53/preprocessor_config.json
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"do_normalize": true,
|
3 |
+
"feature_extractor_type": "Wav2Vec2FeatureExtractor",
|
4 |
+
"feature_size": 1,
|
5 |
+
"padding_side": "right",
|
6 |
+
"padding_value": 0,
|
7 |
+
"return_attention_mask": true,
|
8 |
+
"sampling_rate": 16000
|
9 |
+
}
|
wav2vec2-large-xlsr-53/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:314340227371a608f71adcd5f0de5933824fe77e55822aa4b24dba9c1c364dcb
|
3 |
+
size 1269737156
|