Spaces: Running on Zero

ZeroGPU (#2)
- ZeroGPU (9aaa3a80a1ed0bc403064fdce8564fb0ef958784)
- Remove Dockerfile (052cd2034dc414d8783cb5686eb722c5cd3744cd)
Co-authored-by: hysts <[email protected]>
- Dockerfile +0 -72
- app.py +18 -20
- assets/silence_100ms.wav +0 -0
- packages.txt +1 -0
- requirements.txt +328 -0
- zonos/autoencoder.py +26 -0
- zonos/backbone.py +50 -0
- zonos/codebook_pattern.py +12 -0
- zonos/conditioning.py +373 -0
- zonos/config.py +38 -0
- zonos/model.py +270 -0
- zonos/sampling.py +141 -0
- zonos/speaker_cloning.py +406 -0
Dockerfile
DELETED
@@ -1,72 +0,0 @@
-FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel
-
-# Install dependencies
-RUN pip install uv
-
-ENV DEBIAN_FRONTEND=noninteractive \
-    TZ=Europe/Paris
-
-# Remove any third-party apt sources to avoid issues with expiring keys.
-# Install some basic utilities
-RUN rm -f /etc/apt/sources.list.d/*.list && \
-    apt-get update && apt-get install -y --no-install-recommends \
-    curl \
-    ca-certificates \
-    sudo \
-    git \
-    wget \
-    procps \
-    git-lfs \
-    zip \
-    unzip \
-    htop \
-    vim \
-    nano \
-    bzip2 \
-    libx11-6 \
-    build-essential \
-    libsndfile-dev \
-    software-properties-common \
-    espeak \
-    tmux \
-    ffmpeg \
-    && rm -rf /var/lib/apt/lists/*
-
-RUN add-apt-repository ppa:flexiondotorg/nvtop && \
-    apt-get upgrade -y && \
-    apt-get install -y --no-install-recommends nvtop
-
-RUN curl -sL https://deb.nodesource.com/setup_14.x | bash - && \
-    apt-get install -y nodejs && \
-    npm install -g configurable-http-proxy
-
-
-# Set working directory
-WORKDIR /app
-
-# Clone the repository
-RUN git clone https://github.com/Zyphra/Zonos.git && cd Zonos
-
-# Set environment variables for writable cache directories
-ENV TRITON_CACHE_DIR=/tmp/.triton
-ENV HF_HOME=/tmp/huggingface_cache
-
-# Ensure cache directories are writable
-RUN mkdir -p $TRITON_CACHE_DIR $TRANSFORMERS_CACHE && chmod -R 777 $TRITON_CACHE_DIR $TRANSFORMERS_CACHE
-
-# Install Python dependencies
-WORKDIR /app/Zonos
-RUN uv pip install --system -e . && uv pip install --system -e .[compile]
-RUN uv pip install --system spaces
-
-# Expose the Gradio default port
-EXPOSE 7860
-
-# Run the Gradio app from /app
-WORKDIR /app
-COPY . .
-
-EXPOSE 7860
-ENV GRADIO_SERVER_NAME="0.0.0.0"
-
-CMD ["python", "app.py"]
app.py
CHANGED
@@ -1,3 +1,12 @@
+import os
+import shlex
+import subprocess
+
+subprocess.run(shlex.split("pip install flash-attn --no-build-isolation"), env=os.environ | {"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}, check=True)
+subprocess.run(shlex.split("pip install https://github.com/state-spaces/mamba/releases/download/v2.2.4/mamba_ssm-2.2.4+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"), check=True)
+subprocess.run(shlex.split("pip install https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.5.0.post8/causal_conv1d-1.5.0.post8+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"), check=True)
+
+import spaces
 import torch
 import torchaudio
 import gradio as gr
@@ -7,22 +16,10 @@ from zonos.model import Zonos
 from zonos.conditioning import make_cond_dict, supported_language_codes
 
 device = "cuda"
-
-
-
-
-def load_model_if_needed(model_choice: str):
-    global CURRENT_MODEL_TYPE, CURRENT_MODEL
-    if CURRENT_MODEL_TYPE != model_choice:
-        if CURRENT_MODEL is not None:
-            del CURRENT_MODEL
-            torch.cuda.empty_cache()
-        print(f"Loading {model_choice} model...")
-        CURRENT_MODEL = Zonos.from_pretrained(model_choice, device=device)
-        CURRENT_MODEL.requires_grad_(False).eval()
-        CURRENT_MODEL_TYPE = model_choice
-        print(f"{model_choice} model loaded successfully!")
-    return CURRENT_MODEL
+MODEL_NAMES = ["Zyphra/Zonos-v0.1-transformer", "Zyphra/Zonos-v0.1-hybrid"]
+MODELS = {name: Zonos.from_pretrained(name, device=device) for name in MODEL_NAMES}
+for model in MODELS.values():
+    model.requires_grad_(False).eval()
 
 
 def update_ui(model_choice):
@@ -30,7 +27,7 @@ def update_ui(model_choice):
     Dynamically show/hide UI elements based on the model's conditioners.
     We do NOT display 'language_id' or 'ctc_loss' even if they exist in the model.
     """
-    model = load_model_if_needed(model_choice)
+    model = MODELS[model_choice]
     cond_names = [c.name for c in model.prefix_conditioner.conditioners]
     print("Conditioners in this model:", cond_names)
 
@@ -79,6 +76,7 @@ def update_ui(model_choice):
     )
 
 
+@spaces.GPU(duration=120)
 def generate_audio(
     model_choice,
     text,
@@ -110,7 +108,7 @@ def generate_audio(
     Generates audio based on the provided UI parameters.
     We do NOT use language_id or ctc_loss even if the model has them.
     """
-    selected_model = load_model_if_needed(model_choice)
+    selected_model = MODELS[model_choice]
 
     speaker_noised_bool = bool(speaker_noised)
     fmax = float(fmax)
@@ -191,7 +189,7 @@ def build_interface():
         with gr.Row():
             with gr.Column():
                 model_choice = gr.Dropdown(
-                    choices=
+                    choices=MODEL_NAMES,
                     value="Zyphra/Zonos-v0.1-transformer",
                     label="Zonos Model Type",
                     info="Select the model variant to use.",
@@ -369,4 +367,4 @@ def build_interface():
 if __name__ == "__main__":
     demo = build_interface()
     share = getenv("GRADIO_SHARE", "False").lower() in ("true", "1", "t")
-    demo.launch(server_name="0.0.0.0", server_port=7860, share=share)
+    demo.launch(server_name="0.0.0.0", server_port=7860, share=share)
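
Taken together, the app.py changes implement the usual ZeroGPU recipe: prebuilt CUDA wheels (flash-attn, mamba-ssm, causal-conv1d) are pip-installed at import time since the Dockerfile is gone and the Space now runs in the stock image, both model variants are loaded eagerly at module scope, and only the GPU-bound entry point is wrapped in @spaces.GPU. A minimal sketch of that pattern, where the tiny nn.Linear and the run function are stand-ins for the real model and generate_audio:

import spaces
import torch

# Module-level load: happens once, before any GPU is attached.
# (torch.nn.Linear stands in for Zonos.from_pretrained.)
model = torch.nn.Linear(16, 16).requires_grad_(False).eval()

@spaces.GPU(duration=120)  # a GPU is attached only while this call runs
def run(x: torch.Tensor) -> torch.Tensor:
    # CUDA work belongs inside the decorated call.
    return model.to("cuda")(x.to("cuda")).cpu()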
assets/silence_100ms.wav
ADDED
Binary file (9.43 kB)
packages.txt
ADDED
@@ -0,0 +1 @@
+espeak-ng
requirements.txt
ADDED
@@ -0,0 +1,328 @@
+# This file was autogenerated by uv via the following command:
+#    uv pip compile pyproject.toml -o requirements.txt
+aiofiles==23.2.1
+    # via gradio
+annotated-types==0.7.0
+    # via pydantic
+anyio==4.8.0
+    # via
+    #   gradio
+    #   httpx
+    #   starlette
+attrs==25.1.0
+    # via
+    #   clldutils
+    #   csvw
+    #   jsonschema
+    #   phonemizer
+    #   referencing
+babel==2.17.0
+    # via csvw
+certifi==2025.1.31
+    # via
+    #   httpcore
+    #   httpx
+    #   requests
+cffi==1.17.1
+    # via soundfile
+charset-normalizer==3.4.1
+    # via requests
+click==8.1.8
+    # via
+    #   typer
+    #   uvicorn
+clldutils==3.21.0
+    # via segments
+colorama==0.4.6
+    # via csvw
+colorlog==6.9.0
+    # via clldutils
+csvw==3.5.1
+    # via segments
+dlinfo==2.0.0
+    # via phonemizer
+exceptiongroup==1.2.2
+    # via anyio
+fastapi==0.115.8
+    # via gradio
+ffmpy==0.5.0
+    # via gradio
+filelock==3.17.0
+    # via
+    #   huggingface-hub
+    #   torch
+    #   transformers
+    #   triton
+fsspec==2025.2.0
+    # via
+    #   gradio-client
+    #   huggingface-hub
+    #   torch
+gradio==5.16.0
+    # via
+    #   zonos (pyproject.toml)
+    #   spaces
+gradio-client==1.7.0
+    # via gradio
+h11==0.14.0
+    # via
+    #   httpcore
+    #   uvicorn
+hf-transfer==0.1.9
+    # via zonos (pyproject.toml)
+httpcore==1.0.7
+    # via httpx
+httpx==0.28.1
+    # via
+    #   gradio
+    #   gradio-client
+    #   safehttpx
+    #   spaces
+huggingface-hub==0.28.1
+    # via
+    #   zonos (pyproject.toml)
+    #   gradio
+    #   gradio-client
+    #   tokenizers
+    #   transformers
+idna==3.10
+    # via
+    #   anyio
+    #   httpx
+    #   requests
+inflect==7.5.0
+    # via zonos (pyproject.toml)
+isodate==0.7.2
+    # via
+    #   csvw
+    #   rdflib
+jinja2==3.1.5
+    # via
+    #   gradio
+    #   torch
+joblib==1.4.2
+    # via phonemizer
+jsonschema==4.23.0
+    # via csvw
+jsonschema-specifications==2024.10.1
+    # via jsonschema
+kanjize==1.6.0
+    # via zonos (pyproject.toml)
+language-tags==1.2.0
+    # via csvw
+lxml==5.3.1
+    # via clldutils
+markdown==3.7
+    # via clldutils
+markdown-it-py==3.0.0
+    # via rich
+markupsafe==2.1.5
+    # via
+    #   clldutils
+    #   gradio
+    #   jinja2
+mdurl==0.1.2
+    # via markdown-it-py
+more-itertools==10.6.0
+    # via inflect
+mpmath==1.3.0
+    # via sympy
+networkx==3.4.2
+    # via torch
+numpy==2.2.2
+    # via
+    #   zonos (pyproject.toml)
+    #   gradio
+    #   pandas
+    #   soundfile
+    #   transformers
+nvidia-cublas-cu12==12.1.3.1
+    # via
+    #   nvidia-cudnn-cu12
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-cuda-cupti-cu12==12.1.105
+    # via torch
+nvidia-cuda-nvrtc-cu12==12.1.105
+    # via torch
+nvidia-cuda-runtime-cu12==12.1.105
+    # via torch
+nvidia-cudnn-cu12==9.1.0.70
+    # via torch
+nvidia-cufft-cu12==11.0.2.54
+    # via torch
+nvidia-curand-cu12==10.3.2.106
+    # via torch
+nvidia-cusolver-cu12==11.4.5.107
+    # via torch
+nvidia-cusparse-cu12==12.1.0.106
+    # via
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-nccl-cu12==2.20.5
+    # via torch
+nvidia-nvjitlink-cu12==12.8.61
+    # via
+    #   nvidia-cusolver-cu12
+    #   nvidia-cusparse-cu12
+nvidia-nvtx-cu12==12.1.105
+    # via torch
+orjson==3.10.15
+    # via gradio
+packaging==24.2
+    # via
+    #   zonos (pyproject.toml)
+    #   gradio
+    #   gradio-client
+    #   huggingface-hub
+    #   spaces
+    #   transformers
+pandas==2.2.3
+    # via gradio
+phonemizer==3.3.0
+    # via zonos (pyproject.toml)
+pillow==11.1.0
+    # via gradio
+psutil==5.9.8
+    # via spaces
+pycparser==2.22
+    # via cffi
+pydantic==2.10.6
+    # via
+    #   fastapi
+    #   gradio
+    #   spaces
+pydantic-core==2.27.2
+    # via pydantic
+pydub==0.25.1
+    # via gradio
+pygments==2.19.1
+    # via rich
+pylatexenc==2.10
+    # via clldutils
+pyparsing==3.2.1
+    # via rdflib
+python-dateutil==2.9.0.post0
+    # via
+    #   clldutils
+    #   csvw
+    #   pandas
+python-multipart==0.0.20
+    # via gradio
+pytz==2025.1
+    # via pandas
+pyyaml==6.0.2
+    # via
+    #   gradio
+    #   huggingface-hub
+    #   transformers
+rdflib==7.1.3
+    # via csvw
+referencing==0.36.2
+    # via
+    #   jsonschema
+    #   jsonschema-specifications
+regex==2024.11.6
+    # via
+    #   segments
+    #   transformers
+requests==2.32.3
+    # via
+    #   csvw
+    #   huggingface-hub
+    #   spaces
+    #   transformers
+rfc3986==1.5.0
+    # via csvw
+rich==13.9.4
+    # via typer
+rpds-py==0.22.3
+    # via
+    #   jsonschema
+    #   referencing
+ruff==0.9.6
+    # via gradio
+safehttpx==0.1.6
+    # via gradio
+safetensors==0.5.2
+    # via transformers
+segments==2.2.1
+    # via phonemizer
+semantic-version==2.10.0
+    # via gradio
+setuptools==75.8.0
+    # via zonos (pyproject.toml)
+shellingham==1.5.4
+    # via typer
+six==1.17.0
+    # via python-dateutil
+sniffio==1.3.1
+    # via anyio
+soundfile==0.13.1
+    # via zonos (pyproject.toml)
+spaces==0.32.0
+    # via zonos (pyproject.toml)
+starlette==0.45.3
+    # via
+    #   fastapi
+    #   gradio
+sudachidict-full==20250129
+    # via zonos (pyproject.toml)
+sudachipy==0.6.10
+    # via
+    #   zonos (pyproject.toml)
+    #   sudachidict-full
+sympy==1.13.3
+    # via torch
+tabulate==0.9.0
+    # via clldutils
+tokenizers==0.21.0
+    # via transformers
+tomlkit==0.13.2
+    # via gradio
+torch==2.4.0
+    # via
+    #   zonos (pyproject.toml)
+    #   torchaudio
+torchaudio==2.4.0
+    # via zonos (pyproject.toml)
+tqdm==4.67.1
+    # via
+    #   huggingface-hub
+    #   transformers
+transformers==4.48.3
+    # via zonos (pyproject.toml)
+triton==3.0.0
+    # via torch
+typeguard==4.4.1
+    # via inflect
+typer==0.15.1
+    # via gradio
+typing-extensions==4.12.2
+    # via
+    #   anyio
+    #   fastapi
+    #   gradio
+    #   gradio-client
+    #   huggingface-hub
+    #   phonemizer
+    #   pydantic
+    #   pydantic-core
+    #   referencing
+    #   rich
+    #   spaces
+    #   torch
+    #   typeguard
+    #   typer
+    #   uvicorn
+tzdata==2025.1
+    # via pandas
+uritemplate==4.1.1
+    # via csvw
+urllib3==2.3.0
+    # via requests
+uvicorn==0.34.0
+    # via gradio
+websockets==14.2
+    # via gradio-client
+
zonos/autoencoder.py
ADDED
@@ -0,0 +1,26 @@
+import math
+
+import torch
+import torchaudio
+from transformers.models.dac import DacModel
+
+
+class DACAutoencoder:
+    def __init__(self):
+        super().__init__()
+        self.dac = DacModel.from_pretrained("descript/dac_44khz")
+        self.dac.eval().requires_grad_(False)
+        self.codebook_size = self.dac.config.codebook_size
+        self.num_codebooks = self.dac.quantizer.n_codebooks
+        self.sampling_rate = self.dac.config.sampling_rate
+
+    def preprocess(self, wav: torch.Tensor, sr: int) -> torch.Tensor:
+        wav = torchaudio.functional.resample(wav, sr, 44_100)
+        right_pad = math.ceil(wav.shape[-1] / 512) * 512 - wav.shape[-1]
+        return torch.nn.functional.pad(wav, (0, right_pad))
+
+    def encode(self, wav: torch.Tensor) -> torch.Tensor:
+        return self.dac.encode(wav).audio_codes
+
+    def decode(self, codes: torch.Tensor) -> torch.Tensor:
+        return self.dac.decode(audio_codes=codes).audio_values.unsqueeze(1)
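
preprocess resamples to 44.1 kHz and right-pads to a multiple of 512 samples (which matches the DAC frame hop, so code and waveform lengths line up). A quick round-trip sketch, assuming zonos is importable and the descript/dac_44khz weights can be downloaded:

import torch

from zonos.autoencoder import DACAutoencoder

ae = DACAutoencoder()
wav = torch.randn(1, 1, 16_000)      # [batch, channels, samples] at 16 kHz
wav = ae.preprocess(wav, sr=16_000)  # resampled to 44.1 kHz, padded to a multiple of 512
assert wav.shape[-1] % 512 == 0
codes = ae.encode(wav)               # [batch, n_codebooks, frames] integer codes
recon = ae.decode(codes)             # [batch, 1, samples] reconstructed audio
print(codes.shape, recon.shape)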
zonos/backbone.py
ADDED
@@ -0,0 +1,50 @@
+import torch
+import torch.nn as nn
+from mamba_ssm.models.mixer_seq_simple import create_block
+from mamba_ssm.ops.triton.layer_norm import layer_norm_fn
+from mamba_ssm.utils.generation import InferenceParams
+
+from zonos.config import BackboneConfig
+
+
+class ZonosBackbone(nn.Module):
+    def __init__(self, config: BackboneConfig):
+        super().__init__()
+        self.config = config
+
+        self.layers = nn.ModuleList(
+            [
+                create_block(
+                    d_model=config.d_model,
+                    d_intermediate=config.d_intermediate
+                    if (i not in config.attn_layer_idx)
+                    else config.attn_mlp_d_intermediate,
+                    ssm_cfg=config.ssm_cfg,
+                    layer_idx=i,
+                    attn_layer_idx=config.attn_layer_idx,
+                    attn_cfg=config.attn_cfg,
+                    norm_epsilon=config.norm_epsilon,
+                    residual_in_fp32=config.residual_in_fp32,
+                    fused_add_norm=True,
+                    rms_norm=config.rms_norm,
+                )
+                for i in range(config.n_layer)
+            ]
+        )
+
+        self.norm_f = nn.LayerNorm(config.d_model, eps=config.norm_epsilon)
+
+    def forward(self, hidden_states: torch.Tensor, inference_params: InferenceParams | None = None):
+        residual = None
+        for layer in self.layers:
+            hidden_states, residual = layer(hidden_states, residual, inference_params)
+
+        return layer_norm_fn(
+            hidden_states,
+            self.norm_f.weight,
+            self.norm_f.bias,
+            residual,
+            eps=self.norm_f.eps,
+            residual_in_fp32=self.config.residual_in_fp32,
+            is_rms_norm=self.config.rms_norm,
+        )
zonos/codebook_pattern.py
ADDED
@@ -0,0 +1,12 @@
+import torch
+import torch.nn.functional as F
+
+
+def apply_delay_pattern(codes: torch.Tensor, mask_token: int):
+    codes = F.pad(codes, (0, codes.shape[1]), value=mask_token)
+    return torch.stack([codes[:, k].roll(k + 1) for k in range(codes.shape[1])], dim=1)
+
+
+def revert_delay_pattern(codes: torch.Tensor):
+    _, n_q, seq_len = codes.shape
+    return torch.stack([codes[:, k, k + 1 : seq_len - n_q + k + 1] for k in range(n_q)], dim=1)
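
The delay pattern shifts codebook k right by k+1 steps, filling the vacated slots with the mask token, and revert_delay_pattern undoes the shift exactly, discarding the padding. A small self-check of the round trip, assuming zonos is importable:

import torch

from zonos.codebook_pattern import apply_delay_pattern, revert_delay_pattern

MASK = 1025  # masked_token_id from ZonosConfig

codes = torch.arange(2 * 9 * 5).reshape(2, 9, 5)  # [batch=2, n_q=9, T=5]
delayed = apply_delay_pattern(codes, MASK)        # [2, 9, 5 + 9]
# Codebook 0 is shifted right by one step; the vacated slot holds MASK.
assert (delayed[:, 0, 0] == MASK).all()
assert torch.equal(delayed[:, 0, 1:6], codes[:, 0])
# The inverse slice recovers the original frame-aligned codes exactly.
assert torch.equal(revert_delay_pattern(delayed), codes)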
zonos/conditioning.py
ADDED
@@ -0,0 +1,373 @@
+from functools import cache
+from typing import Any, Literal, Iterable
+
+import torch
+import torch.nn as nn
+
+from zonos.config import PrefixConditionerConfig
+
+
+class Conditioner(nn.Module):
+    def __init__(
+        self,
+        output_dim: int,
+        name: str,
+        cond_dim: int | None = None,
+        projection: Literal["none", "linear", "mlp"] = "none",
+        uncond_type: Literal["learned", "none"] = "none",
+        **kwargs,
+    ):
+        super().__init__()
+        self.name = name
+        self.output_dim = output_dim
+        self.cond_dim = cond_dim = cond_dim or output_dim
+
+        if projection == "linear":
+            self.project = nn.Linear(cond_dim, output_dim)
+        elif projection == "mlp":
+            self.project = nn.Sequential(
+                nn.Linear(cond_dim, output_dim),
+                nn.SiLU(),
+                nn.Linear(output_dim, output_dim),
+            )
+        else:
+            self.project = nn.Identity()
+
+        self.uncond_vector = None
+        if uncond_type == "learned":
+            self.uncond_vector = nn.Parameter(torch.zeros(output_dim))
+
+    def apply_cond(self, *inputs: Any) -> torch.Tensor:
+        raise NotImplementedError()
+
+    def forward(self, inputs: tuple[Any, ...] | None) -> torch.Tensor:
+        if inputs is None:
+            assert self.uncond_vector is not None
+            return self.uncond_vector.data.view(1, 1, -1)
+
+        cond = self.apply_cond(*inputs)
+        cond = self.project(cond)
+        return cond
+
+
+# ------- ESPEAK CONTAINMENT ZONE ------------------------------------------------------------------------------------
+import re
+import unicodedata
+
+import inflect
+import torch
+import torch.nn as nn
+from kanjize import number2kanji
+from phonemizer.backend import EspeakBackend
+from sudachipy import Dictionary, SplitMode
+
+# --- Number normalization code from https://github.com/daniilrobnikov/vits2/blob/main/text/normalize_numbers.py ---
+
+_inflect = inflect.engine()
+_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
+_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
+_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
+_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
+_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
+_number_re = re.compile(r"[0-9]+")
+
+
+def _remove_commas(m: re.Match) -> str:
+    return m.group(1).replace(",", "")
+
+
+def _expand_decimal_point(m: re.Match) -> str:
+    return m.group(1).replace(".", " point ")
+
+
+def _expand_dollars(m: re.Match) -> str:
+    match = m.group(1)
+    parts = match.split(".")
+    if len(parts) > 2:
+        return match + " dollars"  # Unexpected format
+    dollars = int(parts[0]) if parts[0] else 0
+    cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
+    if dollars and cents:
+        dollar_unit = "dollar" if dollars == 1 else "dollars"
+        cent_unit = "cent" if cents == 1 else "cents"
+        return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
+    elif dollars:
+        dollar_unit = "dollar" if dollars == 1 else "dollars"
+        return "%s %s" % (dollars, dollar_unit)
+    elif cents:
+        cent_unit = "cent" if cents == 1 else "cents"
+        return "%s %s" % (cents, cent_unit)
+    else:
+        return "zero dollars"
+
+
+def _expand_ordinal(m: re.Match) -> str:
+    return _inflect.number_to_words(m.group(0))
+
+
+def _expand_number(m: re.Match) -> str:
+    num = int(m.group(0))
+    if num > 1000 and num < 3000:
+        if num == 2000:
+            return "two thousand"
+        elif num > 2000 and num < 2010:
+            return "two thousand " + _inflect.number_to_words(num % 100)
+        elif num % 100 == 0:
+            return _inflect.number_to_words(num // 100) + " hundred"
+        else:
+            return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
+    else:
+        return _inflect.number_to_words(num, andword="")
+
+
+def normalize_numbers(text: str) -> str:
+    text = re.sub(_comma_number_re, _remove_commas, text)
+    text = re.sub(_pounds_re, r"\1 pounds", text)
+    text = re.sub(_dollars_re, _expand_dollars, text)
+    text = re.sub(_decimal_number_re, _expand_decimal_point, text)
+    text = re.sub(_ordinal_re, _expand_ordinal, text)
+    text = re.sub(_number_re, _expand_number, text)
+    return text
+
+
+# --- Number normalization code end ---
+
+
+PAD_ID, UNK_ID, BOS_ID, EOS_ID = 0, 1, 2, 3
+SPECIAL_TOKEN_IDS = [PAD_ID, UNK_ID, BOS_ID, EOS_ID]
+
+_punctuation = ';:,.!?¡¿—…"«»“”() *~-/\\&'
+_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+_letters_ipa = (
+    "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
+)
+
+symbols = [*_punctuation, *_letters, *_letters_ipa]
+_symbol_to_id = {s: i for i, s in enumerate(symbols, start=len(SPECIAL_TOKEN_IDS))}
+
+
+def _get_symbol_id(s: str) -> int:
+    return _symbol_to_id.get(s, 1)
+
+
+def get_symbol_ids(text: str) -> list[int]:
+    return list(map(_get_symbol_id, text))
+
+
+def tokenize_phonemes(phonemes: list[str]) -> tuple[torch.Tensor, list[int]]:
+    phoneme_ids = [[BOS_ID, *get_symbol_ids(phonemes), EOS_ID] for phonemes in phonemes]
+    lengths = list(map(len, phoneme_ids))
+    longest = max(lengths)
+    phoneme_ids = [[PAD_ID] * (longest - len(ids)) + ids for ids in phoneme_ids]
+    return torch.tensor(phoneme_ids), lengths
+
+
+def normalize_jp_text(text: str, tokenizer=Dictionary(dict="full").create()) -> str:
+    text = unicodedata.normalize("NFKC", text)
+    text = re.sub(r"\d+", lambda m: number2kanji(int(m[0])), text)
+    final_text = " ".join([x.reading_form() for x in tokenizer.tokenize(text, SplitMode.A)])
+    return final_text
+
+
+def clean(texts: list[str], languages: list[str]) -> list[str]:
+    texts_out = []
+    for text, language in zip(texts, languages):
+        if "ja" in language:
+            text = normalize_jp_text(text)
+        else:
+            text = normalize_numbers(text)
+        texts_out.append(text)
+    return texts_out
+
+
+@cache
+def get_backend(language: str) -> "EspeakBackend":
+    import logging
+
+    from phonemizer.backend import EspeakBackend
+
+    logger = logging.getLogger("phonemizer")
+    backend = EspeakBackend(
+        language,
+        preserve_punctuation=True,
+        with_stress=True,
+        punctuation_marks=_punctuation,
+        logger=logger,
+    )
+    logger.setLevel(logging.ERROR)
+    return backend
+
+
+def phonemize(texts: list[str], languages: list[str]) -> list[str]:
+    texts = clean(texts, languages)
+
+    batch_phonemes = []
+    for text, language in zip(texts, languages):
+        backend = get_backend(language)
+        phonemes = backend.phonemize([text], strip=True)
+        batch_phonemes.append(phonemes[0])
+
+    return batch_phonemes
+
+
+class EspeakPhonemeConditioner(Conditioner):
+    def __init__(self, output_dim: int, **kwargs):
+        super().__init__(output_dim, **kwargs)
+        self.phoneme_embedder = nn.Embedding(len(SPECIAL_TOKEN_IDS) + len(symbols), output_dim)
+
+    def apply_cond(self, texts: list[str], languages: list[str]) -> torch.Tensor:
+        """
+        Args:
+            texts: list of texts to convert to phonemes
+            languages: ISO 639-1 -or otherwise eSpeak compatible- language code
+        """
+        device = self.phoneme_embedder.weight.device
+
+        phonemes = phonemize(texts, languages)
+        phoneme_ids, _ = tokenize_phonemes(phonemes)
+        phoneme_embeds = self.phoneme_embedder(phoneme_ids.to(device))
+
+        return phoneme_embeds
+
+
+# ------- ESPEAK CONTAINMENT ZONE ------------------------------------------------------------------------------------
+
+
+class FourierConditioner(Conditioner):
+    def __init__(
+        self,
+        output_dim: int,
+        input_dim: int = 1,
+        std: float = 1.0,
+        min_val: float = 0.0,
+        max_val: float = 1.0,
+        **kwargs,
+    ):
+        assert output_dim % 2 == 0
+        super().__init__(output_dim, **kwargs)
+        self.register_buffer("weight", torch.randn([output_dim // 2, input_dim]) * std)
+        self.input_dim, self.min_val, self.max_val = input_dim, min_val, max_val
+
+    def apply_cond(self, x: torch.Tensor) -> torch.Tensor:
+        assert x.shape[-1] == self.input_dim
+        x = (x - self.min_val) / (self.max_val - self.min_val)  # [batch_size, seq_len, input_dim]
+        f = 2 * torch.pi * x.to(self.weight.dtype) @ self.weight.T  # [batch_size, seq_len, output_dim // 2]
+        return torch.cat([f.cos(), f.sin()], dim=-1)  # [batch_size, seq_len, output_dim]
+
+
+class IntegerConditioner(Conditioner):
+    def __init__(self, output_dim: int, min_val: int = 0, max_val: int = 512, **kwargs):
+        super().__init__(output_dim, **kwargs)
+        self.min_val = min_val
+        self.max_val = max_val
+        self.int_embedder = nn.Embedding(max_val - min_val + 1, output_dim)
+
+    def apply_cond(self, x: torch.Tensor) -> torch.Tensor:
+        assert x.shape[-1] == 1
+        return self.int_embedder(x.squeeze(-1) - self.min_val)  # [batch_size, seq_len, output_dim]
+
+
+class PassthroughConditioner(Conditioner):
+    def __init__(self, output_dim: int, **kwargs):
+        super().__init__(output_dim, **kwargs)
+
+    def apply_cond(self, x: torch.Tensor) -> torch.Tensor:
+        assert x.shape[-1] == self.cond_dim
+        return x
+
+
+_cond_cls_map = {
+    "PassthroughConditioner": PassthroughConditioner,
+    "EspeakPhonemeConditioner": EspeakPhonemeConditioner,
+    "FourierConditioner": FourierConditioner,
+    "IntegerConditioner": IntegerConditioner,
+}
+
+
+def build_conditioners(conditioners: list[dict], output_dim: int) -> list[Conditioner]:
+    return [_cond_cls_map[config["type"]](output_dim, **config) for config in conditioners]
+
+
+class PrefixConditioner(Conditioner):
+    def __init__(self, config: PrefixConditionerConfig, output_dim: int):
+        super().__init__(output_dim, "prefix", projection=config.projection)
+        self.conditioners = nn.ModuleList(build_conditioners(config.conditioners, output_dim))
+        self.norm = nn.LayerNorm(output_dim)
+        self.required_keys = {c.name for c in self.conditioners if c.uncond_vector is None}
+
+    def forward(self, cond_dict: dict) -> torch.Tensor:
+        if not set(cond_dict).issuperset(self.required_keys):
+            raise ValueError(f"Missing required keys: {self.required_keys - set(cond_dict)}")
+        conds = []
+        for conditioner in self.conditioners:
+            conds.append(conditioner(cond_dict.get(conditioner.name)))
+        max_bsz = max(map(len, conds))
+        assert all(c.shape[0] in (max_bsz, 1) for c in conds)
+        conds = [c.expand(max_bsz, -1, -1) for c in conds]
+        return self.norm(self.project(torch.cat(conds, dim=-2)))
+
+
+supported_language_codes = [
+    'af', 'am', 'an', 'ar', 'as', 'az', 'ba', 'bg', 'bn', 'bpy', 'bs', 'ca', 'cmn',
+    'cs', 'cy', 'da', 'de', 'el', 'en-029', 'en-gb', 'en-gb-scotland', 'en-gb-x-gbclan',
+    'en-gb-x-gbcwmd', 'en-gb-x-rp', 'en-us', 'eo', 'es', 'es-419', 'et', 'eu', 'fa',
+    'fa-latn', 'fi', 'fr-be', 'fr-ch', 'fr-fr', 'ga', 'gd', 'gn', 'grc', 'gu', 'hak',
+    'hi', 'hr', 'ht', 'hu', 'hy', 'hyw', 'ia', 'id', 'is', 'it', 'ja', 'jbo', 'ka',
+    'kk', 'kl', 'kn', 'ko', 'kok', 'ku', 'ky', 'la', 'lfn', 'lt', 'lv', 'mi', 'mk',
+    'ml', 'mr', 'ms', 'mt', 'my', 'nb', 'nci', 'ne', 'nl', 'om', 'or', 'pa', 'pap',
+    'pl', 'pt', 'pt-br', 'py', 'quc', 'ro', 'ru', 'ru-lv', 'sd', 'shn', 'si', 'sk',
+    'sl', 'sq', 'sr', 'sv', 'sw', 'ta', 'te', 'tn', 'tr', 'tt', 'ur', 'uz', 'vi',
+    'vi-vn-x-central', 'vi-vn-x-south', 'yue'
+]  # fmt: off
+
+
+def make_cond_dict(
+    text: str = "It would be nice to have time for testing, indeed.",
+    language: str = "en-us",
+    speaker: torch.Tensor | None = None,
+    emotion: list[float] = [0.3077, 0.0256, 0.0256, 0.0256, 0.0256, 0.0256, 0.2564, 0.3077],
+    fmax: float = 22050.0,
+    pitch_std: float = 20.0,
+    speaking_rate: float = 15.0,
+    vqscore_8: list[float] = [0.78] * 8,
+    ctc_loss: float = 0.0,
+    dnsmos_ovrl: float = 4.0,
+    speaker_noised: bool = False,
+    unconditional_keys: Iterable[str] = {"vqscore_8", "dnsmos_ovrl"},
+    device: str = "cuda",
+) -> dict:
+    """
+    A helper to build the 'cond_dict' that the model expects.
+    By default, it will generate a random speaker embedding
+    """
+    assert language.lower() in supported_language_codes, "Please pick a supported language"
+
+    language_code_to_id = {lang: i for i, lang in enumerate(supported_language_codes)}
+
+    cond_dict = {
+        "espeak": ([text], [language]),
+        "speaker": speaker,
+        "emotion": emotion,
+        "fmax": fmax,
+        "pitch_std": pitch_std,
+        "speaking_rate": speaking_rate,
+        "language_id": language_code_to_id[language],
+        "vqscore_8": vqscore_8,
+        "ctc_loss": ctc_loss,
+        "dnsmos_ovrl": dnsmos_ovrl,
+        "speaker_noised": int(speaker_noised),
+    }
+
+    for k in unconditional_keys:
+        cond_dict.pop(k, None)
+
+    for k, v in cond_dict.items():
+        if isinstance(v, (float, int, list)):
+            v = torch.tensor(v)
+        if isinstance(v, torch.Tensor):
+            cond_dict[k] = v.view(1, 1, -1).to(device)
+
+            if k == "emotion":
+                cond_dict[k] /= cond_dict[k].sum(dim=-1)
+
+    return cond_dict
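
app.py builds this dict per request; scalars and lists become [1, 1, -1] tensors on the target device, and any key named in unconditional_keys is dropped so the matching conditioner falls back to its learned unconditional vector. A sketch (CPU device for illustration; the speaker tensor is a placeholder and its width here is arbitrary):

import torch

from zonos.conditioning import make_cond_dict

cond = make_cond_dict(
    text="Hello from Zonos.",
    language="en-us",
    speaker=torch.zeros(128),  # placeholder; real embeddings come from Zonos.make_speaker_embedding
    device="cpu",
)
print(sorted(cond))           # 'vqscore_8' and 'dnsmos_ovrl' are dropped by default
print(cond["emotion"].shape)  # torch.Size([1, 1, 8]), normalized to sum to 1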
zonos/config.py
ADDED
@@ -0,0 +1,38 @@
+from dataclasses import dataclass, field
+from typing import Literal
+
+
+@dataclass
+class BackboneConfig:
+    d_model: int = 1024
+    d_intermediate: int = 0
+    attn_mlp_d_intermediate: int = 0
+    n_layer: int = 16
+    ssm_cfg: dict = field(default_factory=dict)
+    attn_layer_idx: list = field(default_factory=list)
+    attn_cfg: dict = field(default_factory=dict)
+    rms_norm: bool = False
+    residual_in_fp32: bool = False
+    norm_epsilon: float = 1e-5
+
+
+@dataclass
+class PrefixConditionerConfig:
+    conditioners: list[dict]
+    projection: Literal["none", "linear", "mlp"]
+
+
+@dataclass
+class ZonosConfig:
+    backbone: BackboneConfig
+    prefix_conditioner: PrefixConditionerConfig
+    eos_token_id: int = 1024
+    masked_token_id: int = 1025
+
+    @classmethod
+    def from_dict(cls, d: dict) -> "ZonosConfig":
+        d = d.copy()
+        backbone_config = BackboneConfig(**d.pop("backbone"))
+        prefix_conditioner_config = PrefixConditionerConfig(**d.pop("prefix_conditioner"))
+        config = cls(backbone_config, prefix_conditioner_config, **d)
+        return config
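
from_dict mirrors the config.json that Zonos.from_pretrained downloads: the two nested dicts become the sub-configs and any remaining top-level keys feed ZonosConfig itself. A sketch with illustrative values:

from zonos.config import ZonosConfig

cfg = ZonosConfig.from_dict(
    {
        "backbone": {"d_model": 1024, "n_layer": 16},
        "prefix_conditioner": {
            "conditioners": [{"type": "FourierConditioner", "name": "pitch_std"}],
            "projection": "none",
        },
    }
)
print(cfg.backbone.d_model, cfg.eos_token_id)  # 1024 1024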
zonos/model.py
ADDED
@@ -0,0 +1,270 @@
+import json
+from typing import Callable
+
+import safetensors
+import torch
+import torch.nn as nn
+from huggingface_hub import hf_hub_download
+from mamba_ssm.utils.generation import InferenceParams
+from tqdm import tqdm
+
+from zonos.autoencoder import DACAutoencoder
+from zonos.backbone import ZonosBackbone
+from zonos.codebook_pattern import apply_delay_pattern, revert_delay_pattern
+from zonos.conditioning import PrefixConditioner
+from zonos.config import ZonosConfig
+from zonos.sampling import sample_from_logits
+from zonos.speaker_cloning import SpeakerEmbeddingLDA
+
+
+class Zonos(nn.Module):
+    def __init__(self, config: ZonosConfig):
+        super().__init__()
+        self.config = config
+        dim = config.backbone.d_model
+        self.eos_token_id = config.eos_token_id
+        self.masked_token_id = config.masked_token_id
+
+        self.autoencoder = DACAutoencoder()
+        self.backbone = ZonosBackbone(config.backbone)
+        self.prefix_conditioner = PrefixConditioner(config.prefix_conditioner, dim)
+        self.spk_clone_model = None
+
+        # TODO: pad to multiple of at least 8
+        self.embeddings = nn.ModuleList([nn.Embedding(1026, dim) for _ in range(self.autoencoder.num_codebooks)])
+        self.heads = nn.ModuleList([nn.Linear(dim, 1025, bias=False) for _ in range(self.autoencoder.num_codebooks)])
+
+        self._cg_graph = None
+        self._cg_batch_size = None
+        self._cg_input_ids = None
+        self._cg_logits = None
+        self._cg_inference_params = None
+        self._cg_scale = None
+
+    @classmethod
+    def from_pretrained(cls, repo_id: str, revision: str | None = None, device: str = "cuda") -> "Zonos":
+        config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
+        model_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
+        return cls.from_local(config_path, model_path, device)
+
+    @classmethod
+    def from_local(cls, config_path: str, model_path: str, device: str = "cuda") -> "Zonos":
+        config = ZonosConfig.from_dict(json.load(open(config_path)))
+        model = cls(config).to(device, torch.bfloat16)
+        model.autoencoder.dac.to(device)
+
+        sd = model.state_dict()
+        with safetensors.safe_open(model_path, framework="pt") as f:
+            for k in f.keys():
+                sd[k] = f.get_tensor(k)
+        model.load_state_dict(sd)
+
+        return model
+
+    def make_speaker_embedding(self, wav: torch.Tensor, sr: int) -> torch.Tensor:
+        """Generate a speaker embedding from an audio clip."""
+        if self.spk_clone_model is None:
+            self.spk_clone_model = SpeakerEmbeddingLDA()
+        _, spk_embedding = self.spk_clone_model(wav.to(self.spk_clone_model.device), sr)
+        return spk_embedding.unsqueeze(0).bfloat16()
+
+    def embed_codes(self, codes: torch.Tensor) -> torch.Tensor:
+        return sum(emb(codes[:, i]) for i, emb in enumerate(self.embeddings))
+
+    def apply_heads(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return torch.stack([head(hidden_states) for head in self.heads], dim=1)
+
+    def _compute_logits(
+        self, hidden_states: torch.Tensor, inference_params: InferenceParams, cfg_scale: float
+    ) -> torch.Tensor:
+        """
+        Pass `hidden_states` into `backbone` and `multi_head`, applying
+        classifier-free guidance if `cfg_scale != 1.0`.
+        """
+        last_hidden_states = self.backbone(hidden_states, inference_params)[:, -1, :].unsqueeze(1)
+        logits = self.apply_heads(last_hidden_states).squeeze(2).float()
+        if cfg_scale != 1.0:
+            cond_logits, uncond_logits = logits.chunk(2)
+            logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
+        return logits
+
+    def _decode_one_token(
+        self,
+        input_ids: torch.Tensor,
+        inference_params: InferenceParams,
+        cfg_scale: float,
+    ) -> torch.Tensor:
+        """
+        Single-step decode. Prepares the hidden states, possibly replicates them
+        for CFG, and then delegates to `_compute_logits`.
+
+        Below we wrap this function with a simple CUDA Graph capturing mechanism,
+        doing 3 warmup steps if needed and then capturing or replaying the graph.
+        We only recapture if the batch size changes.
+        """
+        # TODO: support cfg_scale==1
+        if cfg_scale == 1.0:
+            hidden_states = self.embed_codes(input_ids)
+            return self._compute_logits(hidden_states, inference_params, cfg_scale)
+
+        bsz = input_ids.size(0)
+
+        need_capture = (self._cg_graph is None) or (self._cg_batch_size != bsz)
+
+        if need_capture:
+            self._cg_graph = None
+
+            self._cg_batch_size = bsz
+            self._cg_inference_params = inference_params
+            self._cg_scale = cfg_scale
+
+            for _ in range(3):
+                hidden_states = self.embed_codes(input_ids)
+                hidden_states = hidden_states.repeat(2, 1, 1)  # because cfg != 1.0
+                logits = self._compute_logits(hidden_states, inference_params, cfg_scale)
+
+            self._cg_input_ids = input_ids.clone()
+            self._cg_logits = torch.empty_like(logits)
+
+            g = torch.cuda.CUDAGraph()
+
+            def capture_region():
+                hidden_states_local = self.embed_codes(self._cg_input_ids)
+                hidden_states_local = hidden_states_local.repeat(2, 1, 1)
+                self._cg_logits = self._compute_logits(hidden_states_local, self._cg_inference_params, self._cg_scale)
+
+            with torch.cuda.graph(g):
+                capture_region()
+
+            self._cg_graph = g
+
+        else:
+            self._cg_input_ids.copy_(input_ids)
+
+        self._cg_graph.replay()
+
+        return self._cg_logits
+
+    def _prefill(
+        self,
+        prefix_hidden_states: torch.Tensor,
+        input_ids: torch.Tensor,
+        inference_params: InferenceParams,
+        cfg_scale: float,
+    ) -> torch.Tensor:
+        """
+        "Prefill" mode: we already have `prefix_hidden_states`, and we want
+        to append new embeddings, then compute the logits.
+        """
+        # Replicate input_ids if CFG is enabled
+        if cfg_scale != 1.0:
+            input_ids = input_ids.expand(prefix_hidden_states.shape[0], -1, -1)
+        hidden_states = torch.cat([prefix_hidden_states, self.embed_codes(input_ids)], dim=1)
+        return self._compute_logits(hidden_states, inference_params, cfg_scale)
+
+    def setup_cache(self, batch_size: int, max_seqlen: int, dtype: torch.dtype = torch.bfloat16) -> InferenceParams:
+        key_value_memory_dict = {
+            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
+            for i, layer in enumerate(self.backbone.layers)
+        }
+        lengths_per_sample = torch.full((batch_size,), 0, dtype=torch.int32, device="cuda")
+        return InferenceParams(max_seqlen, batch_size, 0, 0, key_value_memory_dict, lengths_per_sample)
+
+    def prepare_conditioning(self, cond_dict: dict, uncond_dict: dict | None = None) -> torch.Tensor:
+        if uncond_dict is None:
+            uncond_dict = {k: cond_dict[k] for k in self.prefix_conditioner.required_keys}
+        return torch.cat(
+            [
+                self.prefix_conditioner(cond_dict),
+                self.prefix_conditioner(uncond_dict),
+            ]
+        )
+
+    @torch.inference_mode()
+    def generate(
+        self,
+        prefix_conditioning: torch.Tensor,  # [bsz, cond_seq_len, d_model]
+        audio_prefix_codes: torch.Tensor | None = None,  # [bsz, 9, prefix_audio_seq_len]
+        max_new_tokens: int = 86 * 30,
+        cfg_scale: float = 2.0,
+        batch_size: int = 1,
+        sampling_params: dict = dict(min_p=0.1),
+        progress_bar: bool = True,
+        callback: Callable[[torch.Tensor, int, int], bool] | None = None,
+    ):
+        assert cfg_scale != 1, "TODO: add support for cfg_scale=1"
+        prefix_audio_len = 0 if audio_prefix_codes is None else audio_prefix_codes.shape[2]
+
+        unknown_token = -1
+        audio_seq_len = prefix_audio_len + max_new_tokens
+        seq_len = prefix_conditioning.shape[1] + audio_seq_len
+
+        inference_params = self.setup_cache(batch_size=batch_size * 2, max_seqlen=seq_len)
+
+        codes = torch.full((batch_size, 9, audio_seq_len), unknown_token, device="cuda")
+        if audio_prefix_codes is not None:
+            codes[..., :prefix_audio_len] = audio_prefix_codes
+
+        delayed_codes = apply_delay_pattern(codes, self.masked_token_id)
+
+        delayed_prefix_audio_codes = delayed_codes[..., : prefix_audio_len + 1]
+
+        logits = self._prefill(prefix_conditioning, delayed_prefix_audio_codes, inference_params, cfg_scale)
+        next_token = sample_from_logits(logits, **sampling_params)
+
+        offset = delayed_prefix_audio_codes.shape[2]
+        frame = delayed_codes[..., offset : offset + 1]
+        frame.masked_scatter_(frame == unknown_token, next_token)
+
+        prefix_length = prefix_conditioning.shape[1] + prefix_audio_len + 1
+        inference_params.seqlen_offset += prefix_length
+        inference_params.lengths_per_sample[:] += prefix_length
+
+        logit_bias = torch.zeros_like(logits)
+        logit_bias[:, 1:, self.eos_token_id] = -torch.inf  # only allow codebook 0 to predict EOS
+
+        stopping = torch.zeros(batch_size, dtype=torch.bool, device="cuda")
+        max_steps = delayed_codes.shape[2] - offset
+        remaining_steps = torch.full((batch_size,), max_steps, device="cuda")
+        progress = tqdm(total=max_steps, desc="Generating", disable=not progress_bar)
+
+        step = 0
+        while torch.max(remaining_steps) > 0:
+            offset += 1
+            input_ids = delayed_codes[..., offset - 1 : offset]
+            logits = self._decode_one_token(input_ids, inference_params, cfg_scale)
+
+            next_token = sample_from_logits(logits, generated_tokens=delayed_codes[..., :offset], **sampling_params)
+            eos_in_cb0 = next_token[:, 0] == self.eos_token_id
+
+            remaining_steps[eos_in_cb0[:, 0]] = torch.minimum(remaining_steps[eos_in_cb0[:, 0]], torch.tensor(9))
+            stopping |= eos_in_cb0[:, 0]
+
+            eos_codebook_idx = 9 - remaining_steps
+            eos_codebook_idx = torch.clamp(eos_codebook_idx, max=9 - 1)
+            for i in range(next_token.shape[0]):
+                if stopping[i]:
+                    idx = eos_codebook_idx[i].item()
+                    next_token[i, :idx] = self.masked_token_id
+                    next_token[i, idx] = self.eos_token_id
+
+            frame = delayed_codes[..., offset : offset + 1]
+            frame.masked_scatter_(frame == unknown_token, next_token)
+            inference_params.seqlen_offset += 1
+            inference_params.lengths_per_sample[:] += 1
+
+            remaining_steps -= 1
+
+            progress.update()
+            step += 1
+
+            if callback is not None and not callback(frame, step, max_steps):
+                break
+
+        out_codes = revert_delay_pattern(delayed_codes)
+        out_codes.masked_fill_(out_codes >= 1024, 0)
+        out_codes = out_codes[..., : offset - 9]
+
+        self._cg_graph = None  # reset cuda graph to avoid cache changes
+
+        return out_codes
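
The calling sequence the Space uses (per generate_audio in app.py) is: conditioning dict, then prefix embeddings, then autoregressive code generation with CFG, then DAC decode. A hedged end-to-end sketch, assuming a CUDA device and the default sampling (min_p=0.1); the text and the 10-second length are illustrative:

import torchaudio

from zonos.conditioning import make_cond_dict
from zonos.model import Zonos

model = Zonos.from_pretrained("Zyphra/Zonos-v0.1-transformer", device="cuda")
model.requires_grad_(False).eval()

cond_dict = make_cond_dict(text="Zonos running on ZeroGPU.", language="en-us")
conditioning = model.prepare_conditioning(cond_dict)  # stacks conditional + unconditional rows for CFG

codes = model.generate(conditioning, max_new_tokens=86 * 10)  # ~10 s at the ~86 frames/s implied by the defaults
wavs = model.autoencoder.decode(codes).cpu()                  # [batch, 1, samples] at 44.1 kHz
torchaudio.save("sample.wav", wavs[0], model.autoencoder.sampling_rate)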
zonos/sampling.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
|
5 |
+
"""torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
|
6 |
+
|
7 |
+
Args:
|
8 |
+
input (torch.Tensor): The input tensor containing probabilities.
|
9 |
+
num_samples (int): Number of samples to draw.
|
10 |
+
replacement (bool): Whether to draw with replacement or not.
|
11 |
+
Keywords args:
|
12 |
+
generator (torch.Generator): A pseudorandom number generator for sampling.
|
13 |
+
Returns:
|
14 |
+
torch.Tensor: Last dimension contains num_samples indices
|
15 |
+
sampled from the multinomial probability distribution
|
16 |
+
located in the last dimension of tensor input.
|
17 |
+
"""
|
18 |
+
|
19 |
+
if num_samples == 1:
|
20 |
+
q = torch.empty_like(input).exponential_(1, generator=generator)
|
21 |
+
return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64)
|
22 |
+
|
23 |
+
input_ = input.reshape(-1, input.shape[-1])
|
24 |
+
output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
|
25 |
+
output = output_.reshape(*list(input.shape[:-1]), -1)
|
26 |
+
return output
|
27 |
+
|
28 |
+
|
29 |
+
def apply_top_k(
|
30 |
+
probs: torch.Tensor,
|
31 |
+
k: int,
|
32 |
+
) -> torch.Tensor:
|
33 |
+
"""Sample next token from top K values along the last dimension of the input probs tensor.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
|
37 |
+
k (int): The k in “top-k”.
|
38 |
+
Returns:
|
39 |
+
torch.Tensor: Sampled tokens.
|
40 |
+
"""
|
41 |
+
v, _ = torch.topk(probs, min(k, probs.size(-1)))
|
42 |
+
pivot = v.select(-1, -1).unsqueeze(-1)
|
43 |
+
probs = torch.where(probs < pivot, 0.0, probs)
|
44 |
+
probs.div_(probs.sum(dim=-1, keepdim=True))
|
45 |
+
return probs
|
46 |
+
|
47 |
+
|
48 |
+
def apply_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
|
49 |
+
"""Sample next token from top P probabilities along the last dimension of the input probs tensor.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
|
53 |
+
p (int): The p in “top-p”.
|
54 |
+
Returns:
|
55 |
+
torch.Tensor: Sampled tokens.
|
56 |
+
"""
|
57 |
+
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
58 |
+
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
59 |
+
mask = probs_sum - probs_sort > p
|
60 |
+
probs_sort *= (~mask).float()
|
61 |
+
probs = probs.scatter(-1, probs_idx, probs_sort)
|
62 |
+
probs.div_(probs.sum(dim=-1, keepdim=True))
|
63 |
+
return probs
|
64 |
+
|
65 |
+
|
66 |
+
def apply_min_p(probs: torch.Tensor, min_p: float) -> torch.Tensor:
|
67 |
+
"""Sample next token using min-p sampling.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
scores (torch.FloatTensor): Input logits with token candidates on the last dimension.
|
71 |
+
min_p (float): Minimum token probability, scaled by the probability of the most likely token.
|
72 |
+
Must be between 0 and 1. Typical values are in the 0.01-0.2 range.
|
73 |
+
Returns:
|
74 |
+
torch.Tensor: Sampled tokens.
|
75 |
+
"""
|
76 |
+
top_probs, _ = probs.max(dim=-1, keepdim=True)
|
77 |
+
tokens_to_remove = probs < (min_p * top_probs)
|
78 |
+
probs = probs.masked_fill(tokens_to_remove, 0.0)
|
79 |
+
probs.div_(probs.sum(dim=-1, keepdim=True))
|
80 |
+
return probs
|
81 |
+
|
82 |
+
|
83 |
+
def modify_logit_for_repetition_penalty(
    logits: torch.Tensor,
    generated_tokens: torch.Tensor,
    repetition_penalty: float,
    repetition_penalty_window: int,
):
    """See https://arxiv.org/abs/1909.05858
    Apply repetition penalty over a sliding window of the last `repetition_penalty_window` tokens.
    logits: (batch_size, n_codebooks, vocab_size)
    generated_tokens: (batch_size, n_codebooks, seq_len)
    """
    generated_tokens = generated_tokens[..., -repetition_penalty_window:]
    generated_tokens = generated_tokens.clamp_max(logits.shape[-1] - 1).to(torch.int64)
    # Build a per-token penalty factor: `repetition_penalty` raised to the number
    # of times each token occurs in the window, 1.0 for unseen tokens.
    rp = torch.full_like(logits, repetition_penalty)
    factors = torch.ones_like(logits).scatter_reduce(2, generated_tokens, rp, reduce="prod")
    return torch.where(logits <= 0, logits * factors, logits / factors)


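# Editor's note: a small numeric sketch (not in the original file). With
# penalty 2.0 and one recent occurrence of token 0, a positive logit for
# token 0 is halved and a negative one would be doubled, per CTRL
# (https://arxiv.org/abs/1909.05858).
def _demo_repetition_penalty() -> torch.Tensor:
    logits = torch.tensor([[[2.0, 1.0, -1.0]]])  # (batch=1, codebooks=1, vocab=3)
    generated = torch.tensor([[[0]]])            # token 0 was just emitted
    # Expected result: [[[1.0, 1.0, -1.0]]], the positive logit of token 0 is halved.
    return modify_logit_for_repetition_penalty(logits, generated, 2.0, repetition_penalty_window=4)

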
def sample_from_logits(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_p: float = 0.0,
    top_k: int = 0,
    min_p: float = 0.0,
    generated_tokens: torch.Tensor | None = None,
    repetition_penalty: float = 3.0,
    repetition_penalty_window: int = 2,
) -> torch.Tensor:
    """Sample next token from logits using temperature, top-p, top-k, or min-p sampling.

    Args:
        logits (torch.Tensor): Input logits with token candidates on the last dimension.
        temperature (float): Sampling temperature. Lower temperature results in more deterministic samples;
            at temperature 0 the most likely token is picked greedily.
        top_p (float): The p in "top-p".
        top_k (int): The k in "top-k".
        min_p (float): Minimum token probability, scaled by the probability of the most likely token.
            Must be between 0 and 1. Typical values are in the 0.01-0.2 range.
        generated_tokens (torch.Tensor | None): Previously generated tokens, used for the repetition penalty.
        repetition_penalty (float): Penalty applied to recently generated tokens; 1.0 disables it.
        repetition_penalty_window (int): Number of most recent tokens the penalty considers.

    Returns:
        torch.Tensor: Sampled tokens.
    """
    if repetition_penalty != 1.0 and generated_tokens is not None:
        logits = modify_logit_for_repetition_penalty(logits, generated_tokens, repetition_penalty, repetition_penalty_window)

    if temperature > 0:
        probs = torch.softmax(logits / temperature, dim=-1)

        if top_p > 0:
            probs = apply_top_p(probs, top_p)
        if top_k > 0:
            probs = apply_top_k(probs, top_k)
        if min_p > 0:
            probs = apply_min_p(probs, min_p)

        next_token = multinomial(probs, num_samples=1)
    else:
        next_token = torch.argmax(logits, dim=-1, keepdim=True)

    return next_token  # [batch_size, num_codebooks, 1]
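
# Editor's note: an end-to-end sketch (not part of the original file) of how the
# helpers above compose; the shapes follow the (batch, codebooks, vocab)
# convention documented above, but the sizes and settings are assumptions.
def _demo_sample_from_logits() -> torch.Tensor:
    logits = torch.randn(1, 9, 1025)
    prev = torch.randint(0, 1025, (1, 9, 4))  # a few previously generated tokens
    next_token = sample_from_logits(
        logits,
        temperature=0.8,
        top_p=0.9,
        top_k=50,
        min_p=0.05,
        generated_tokens=prev,
        repetition_penalty=1.3,
        repetition_penalty_window=2,
    )
    assert next_token.shape == (1, 9, 1)
    return next_token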
zonos/speaker_cloning.py
ADDED
@@ -0,0 +1,406 @@
import math
import os
from functools import cache

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from huggingface_hub import hf_hub_download


class logFbankCal(nn.Module):
    def __init__(
        self,
        sample_rate: int = 16_000,
        n_fft: int = 512,
        win_length: float = 0.025,
        hop_length: float = 0.01,
        n_mels: int = 80,
    ):
        super().__init__()
        self.fbankCal = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            win_length=int(win_length * sample_rate),
            hop_length=int(hop_length * sample_rate),
            n_mels=n_mels,
        )

    def forward(self, x):
        out = self.fbankCal(x)
        out = torch.log(out + 1e-6)
        # Per-utterance mean normalization along the time axis.
        out = out - out.mean(axis=2).unsqueeze(dim=2)
        return out


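# Editor's note: illustrative only (not in the original file). One second of
# random 16 kHz audio yields an (n_mels, frames) log-mel matrix per utterance.
def _demo_logfbank() -> torch.Tensor:
    feats = logFbankCal()(torch.randn(1, 16_000))
    assert feats.shape[1] == 80  # (batch, n_mels, frames)
    return feats

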
class ASP(nn.Module):
    # Attentive statistics pooling
    def __init__(self, in_planes, acoustic_dim):
        super(ASP, self).__init__()
        outmap_size = int(acoustic_dim / 8)
        self.out_dim = in_planes * 8 * outmap_size * 2

        self.attention = nn.Sequential(
            nn.Conv1d(in_planes * 8 * outmap_size, 128, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Conv1d(128, in_planes * 8 * outmap_size, kernel_size=1),
            nn.Softmax(dim=2),
        )

    def forward(self, x):
        x = x.reshape(x.size()[0], -1, x.size()[-1])
        w = self.attention(x)
        # Attention-weighted mean and standard deviation over time, concatenated.
        mu = torch.sum(x * w, dim=2)
        sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5))
        x = torch.cat((mu, sg), 1)

        x = x.view(x.size()[0], -1)
        return x


class SimAMBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1):
        super(SimAMBasicBlock, self).__init__()
        self.conv1 = ConvLayer(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = NormLayer(planes)
        self.conv2 = ConvLayer(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = NormLayer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

        self.downsample = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.downsample = nn.Sequential(
                ConvLayer(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                NormLayer(self.expansion * planes),
            )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.SimAM(out)
        out += self.downsample(x)
        out = self.relu(out)
        return out

    def SimAM(self, X, lambda_p=1e-4):
        # Parameter-free SimAM attention: each position is reweighted by an
        # energy term derived from its squared deviation from the channel mean.
        n = X.shape[2] * X.shape[3] - 1
        d = (X - X.mean(dim=[2, 3], keepdim=True)).pow(2)
        v = d.sum(dim=[2, 3], keepdim=True) / n
        E_inv = d / (4 * (v + lambda_p)) + 0.5
        return X * self.sigmoid(E_inv)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1):
        super(BasicBlock, self).__init__()
        self.conv1 = ConvLayer(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = NormLayer(planes)
        self.conv2 = ConvLayer(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = NormLayer(planes)
        self.relu = nn.ReLU(inplace=True)

        self.downsample = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.downsample = nn.Sequential(
                ConvLayer(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                NormLayer(self.expansion * planes),
            )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.downsample(x)
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, in_planes, block, num_blocks, in_ch=1, feat_dim="2d", **kwargs):
        super(ResNet, self).__init__()
        if feat_dim == "1d":
            self.NormLayer = nn.BatchNorm1d
            self.ConvLayer = nn.Conv1d
        elif feat_dim == "2d":
            self.NormLayer = nn.BatchNorm2d
            self.ConvLayer = nn.Conv2d
        elif feat_dim == "3d":
            self.NormLayer = nn.BatchNorm3d
            self.ConvLayer = nn.Conv3d
        else:
            raise ValueError(f"Unsupported feat_dim: {feat_dim!r} (expected '1d', '2d' or '3d')")

        self.in_planes = in_planes

        self.conv1 = self.ConvLayer(in_ch, in_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = self.NormLayer(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, in_planes, num_blocks[0], stride=1, block_id=1)
        self.layer2 = self._make_layer(block, in_planes * 2, num_blocks[1], stride=2, block_id=2)
        self.layer3 = self._make_layer(block, in_planes * 4, num_blocks[2], stride=2, block_id=3)
        self.layer4 = self._make_layer(block, in_planes * 8, num_blocks[3], stride=2, block_id=4)

    def _make_layer(self, block, planes, num_blocks, stride, block_id=1):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.ConvLayer, self.NormLayer, self.in_planes, planes, stride, block_id))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x


def ResNet293(in_planes: int, **kwargs):
    return ResNet(in_planes, SimAMBasicBlock, [10, 20, 64, 3], **kwargs)


class ResNet293_based(nn.Module):
    def __init__(
        self,
        in_planes: int = 64,
        embd_dim: int = 256,
        acoustic_dim: int = 80,
        featCal=None,
        dropout: float = 0,
        **kwargs,
    ):
        super(ResNet293_based, self).__init__()
        self.featCal = featCal
        self.front = ResNet293(in_planes)
        block_expansion = SimAMBasicBlock.expansion
        self.pooling = ASP(in_planes * block_expansion, acoustic_dim)
        self.bottleneck = nn.Linear(self.pooling.out_dim, embd_dim)
        self.drop = nn.Dropout(dropout) if dropout else None

    def forward(self, x):
        x = self.featCal(x)
        x = self.front(x.unsqueeze(dim=1))
        x = self.pooling(x)
        if self.drop:
            x = self.drop(x)
        x = self.bottleneck(x)
        return x


class SEModule(nn.Module):
    def __init__(self, channels, bottleneck=128):
        super(SEModule, self).__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(channels, bottleneck, kernel_size=1, padding=0),
            nn.ReLU(),
            # nn.BatchNorm1d(bottleneck),  # Removed
            nn.Conv1d(bottleneck, channels, kernel_size=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, input):
        x = self.se(input)
        return input * x


class Bottle2neck(nn.Module):
    def __init__(self, inplanes, planes, kernel_size=None, dilation=None, scale=8):
        super(Bottle2neck, self).__init__()
        width = int(math.floor(planes / scale))
        self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(width * scale)
        self.nums = scale - 1
        convs = []
        bns = []
        num_pad = math.floor(kernel_size / 2) * dilation
        for i in range(self.nums):
            convs.append(nn.Conv1d(width, width, kernel_size=kernel_size, dilation=dilation, padding=num_pad))
            bns.append(nn.BatchNorm1d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
        self.bn3 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU()
        self.width = width
        self.se = SEModule(planes)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.bn1(out)

        # Res2Net-style multi-scale processing: split the channels into `scale`
        # groups and let each group see the convolved output of the previous one.
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(sp)
            sp = self.bns[i](sp)
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)
        out = torch.cat((out, spx[self.nums]), 1)

        out = self.conv3(out)
        out = self.relu(out)
        out = self.bn3(out)

        out = self.se(out)
        out += residual
        return out


class ECAPA_TDNN(nn.Module):
    def __init__(self, C, featCal):
        super(ECAPA_TDNN, self).__init__()
        self.featCal = featCal
        self.conv1 = nn.Conv1d(80, C, kernel_size=5, stride=1, padding=2)
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm1d(C)
        self.layer1 = Bottle2neck(C, C, kernel_size=3, dilation=2, scale=8)
        self.layer2 = Bottle2neck(C, C, kernel_size=3, dilation=3, scale=8)
        self.layer3 = Bottle2neck(C, C, kernel_size=3, dilation=4, scale=8)
        # The output shape of the MFA layer is fixed here, close to the setting in the ECAPA paper.
        self.layer4 = nn.Conv1d(3 * C, 1536, kernel_size=1)
        self.attention = nn.Sequential(
            nn.Conv1d(4608, 256, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Tanh(),  # Added
            nn.Conv1d(256, 1536, kernel_size=1),
            nn.Softmax(dim=2),
        )
        self.bn5 = nn.BatchNorm1d(3072)
        self.fc6 = nn.Linear(3072, 192)
        self.bn6 = nn.BatchNorm1d(192)

    def forward(self, x):
        x = self.featCal(x)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.bn1(x)

        x1 = self.layer1(x)
        x2 = self.layer2(x + x1)
        x3 = self.layer3(x + x1 + x2)

        x = self.layer4(torch.cat((x1, x2, x3), dim=1))
        x = self.relu(x)

        t = x.size()[-1]

        global_x = torch.cat(
            (
                x,
                torch.mean(x, dim=2, keepdim=True).repeat(1, 1, t),
                torch.sqrt(torch.var(x, dim=2, keepdim=True).clamp(min=1e-4)).repeat(1, 1, t),
            ),
            dim=1,
        )

        w = self.attention(global_x)

        mu = torch.sum(x * w, dim=2)
        sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-4))

        x = torch.cat((mu, sg), 1)
        x = self.bn5(x)
        x = self.fc6(x)
        x = self.bn6(x)

        return x


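# Editor's note: illustrative only (not in the original file). ECAPA_TDNN maps
# raw 16 kHz waveforms to 192-dim embeddings; C=512 is an assumption chosen to
# keep the sketch light (larger C is common in practice).
def _demo_ecapa() -> torch.Tensor:
    model = ECAPA_TDNN(C=512, featCal=logFbankCal()).eval()
    with torch.no_grad():
        emb = model(torch.randn(2, 16_000))
    assert emb.shape == (2, 192)
    return emb

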
class SpeakerEmbedding(nn.Module):
    def __init__(self, ckpt_path: str = "ResNet293_SimAM_ASP_base.pt", device: str = "cuda"):
        super().__init__()
        self.device = device
        with torch.device(device):
            self.model = ResNet293_based()
            self.model.load_state_dict(torch.load(ckpt_path, weights_only=True, mmap=True))
            self.model.featCal = logFbankCal()

        self.requires_grad_(False).eval()

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    @cache
    def _get_resampler(self, orig_sample_rate: int):
        return torchaudio.transforms.Resample(orig_sample_rate, 16_000).to(self.device)

    def prepare_input(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
        assert wav.ndim < 3
        if wav.ndim == 2:
            wav = wav.mean(0, keepdim=True)
        wav = self._get_resampler(sample_rate)(wav)
        return wav

    def forward(self, wav: torch.Tensor, sample_rate: int):
        wav = self.prepare_input(wav, sample_rate).to(self.device, self.dtype)
        return self.model(wav).to(wav.device)


class SpeakerEmbeddingLDA(nn.Module):
    def __init__(
        self,
        device: str = "cuda",
    ):
        super().__init__()
        spk_model_path = hf_hub_download(repo_id="Zyphra/Zonos-v0.1-speaker-embedding", filename="ResNet293_SimAM_ASP_base.pt")
        lda_spk_model_path = hf_hub_download(repo_id="Zyphra/Zonos-v0.1-speaker-embedding", filename="ResNet293_SimAM_ASP_base_LDA-128.pt")

        self.device = device
        with torch.device(device):
            self.model = SpeakerEmbedding(spk_model_path, device)
            lda_sd = torch.load(lda_spk_model_path, weights_only=True)
            out_features, in_features = lda_sd["weight"].shape
            self.lda = nn.Linear(in_features, out_features, bias=True, dtype=torch.float32)
            self.lda.load_state_dict(lda_sd)

        self.requires_grad_(False).eval()

    def forward(self, wav: torch.Tensor, sample_rate: int):
        emb = self.model(wav, sample_rate).to(torch.float32)
        return emb, self.lda(emb)
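
# Editor's note: a hedged usage sketch (not in the original file). It assumes a
# CUDA device, network access for the checkpoint download from
# Zyphra/Zonos-v0.1-speaker-embedding, and a hypothetical local file "speaker.wav".
def _demo_speaker_embedding():
    wav, sr = torchaudio.load("speaker.wav")  # hypothetical reference clip
    embedder = SpeakerEmbeddingLDA(device="cuda")
    emb, emb_lda = embedder(wav.to("cuda"), sr)
    return emb, emb_lda  # base embedding and its LDA projection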