Steveeeeeeen and hysts committed
Commit 0af138e · verified · 1 Parent(s): 1be704d
Dockerfile DELETED
@@ -1,72 +0,0 @@
1
- FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel
2
-
3
- # Install dependencies
4
- RUN pip install uv
5
-
6
- ENV DEBIAN_FRONTEND=noninteractive \
7
- TZ=Europe/Paris
8
-
9
- # Remove any third-party apt sources to avoid issues with expiring keys.
10
- # Install some basic utilities
11
- RUN rm -f /etc/apt/sources.list.d/*.list && \
12
- apt-get update && apt-get install -y --no-install-recommends \
13
- curl \
14
- ca-certificates \
15
- sudo \
16
- git \
17
- wget \
18
- procps \
19
- git-lfs \
20
- zip \
21
- unzip \
22
- htop \
23
- vim \
24
- nano \
25
- bzip2 \
26
- libx11-6 \
27
- build-essential \
28
- libsndfile-dev \
29
- software-properties-common \
30
- espeak \
31
- tmux \
32
- ffmpeg \
33
- && rm -rf /var/lib/apt/lists/*
34
-
35
- RUN add-apt-repository ppa:flexiondotorg/nvtop && \
36
- apt-get upgrade -y && \
37
- apt-get install -y --no-install-recommends nvtop
38
-
39
- RUN curl -sL https://deb.nodesource.com/setup_14.x | bash - && \
40
- apt-get install -y nodejs && \
41
- npm install -g configurable-http-proxy
42
-
43
-
44
- # Set working directory
45
- WORKDIR /app
46
-
47
- # Clone the repository
48
- RUN git clone https://github.com/Zyphra/Zonos.git && cd Zonos
49
-
50
- # Set environment variables for writable cache directories
51
- ENV TRITON_CACHE_DIR=/tmp/.triton
52
- ENV HF_HOME=/tmp/huggingface_cache
53
-
54
- # Ensure cache directories are writable
55
- RUN mkdir -p $TRITON_CACHE_DIR $TRANSFORMERS_CACHE && chmod -R 777 $TRITON_CACHE_DIR $TRANSFORMERS_CACHE
56
-
57
- # Install Python dependencies
58
- WORKDIR /app/Zonos
59
- RUN uv pip install --system -e . && uv pip install --system -e .[compile]
60
- RUN uv pip install --system spaces
61
-
62
- # Expose the Gradio default port
63
- EXPOSE 7860
64
-
65
- # Run the Gradio app from /app
66
- WORKDIR /app
67
- COPY . .
68
-
69
- EXPOSE 7860
70
- ENV GRADIO_SERVER_NAME="0.0.0.0"
71
-
72
- CMD ["python", "app.py"]
app.py CHANGED
@@ -1,3 +1,12 @@
1
  import torch
2
  import torchaudio
3
  import gradio as gr
@@ -7,22 +16,10 @@ from zonos.model import Zonos
7
  from zonos.conditioning import make_cond_dict, supported_language_codes
8
 
9
  device = "cuda"
10
- CURRENT_MODEL_TYPE = None
11
- CURRENT_MODEL = None
12
-
13
-
14
- def load_model_if_needed(model_choice: str):
15
- global CURRENT_MODEL_TYPE, CURRENT_MODEL
16
- if CURRENT_MODEL_TYPE != model_choice:
17
- if CURRENT_MODEL is not None:
18
- del CURRENT_MODEL
19
- torch.cuda.empty_cache()
20
- print(f"Loading {model_choice} model...")
21
- CURRENT_MODEL = Zonos.from_pretrained(model_choice, device=device)
22
- CURRENT_MODEL.requires_grad_(False).eval()
23
- CURRENT_MODEL_TYPE = model_choice
24
- print(f"{model_choice} model loaded successfully!")
25
- return CURRENT_MODEL
26
 
27
 
28
  def update_ui(model_choice):
@@ -30,7 +27,7 @@ def update_ui(model_choice):
30
  Dynamically show/hide UI elements based on the model's conditioners.
31
  We do NOT display 'language_id' or 'ctc_loss' even if they exist in the model.
32
  """
33
- model = load_model_if_needed(model_choice)
34
  cond_names = [c.name for c in model.prefix_conditioner.conditioners]
35
  print("Conditioners in this model:", cond_names)
36
 
@@ -79,6 +76,7 @@ def update_ui(model_choice):
79
  )
80
 
81
 
 
82
  def generate_audio(
83
  model_choice,
84
  text,
@@ -110,7 +108,7 @@ def generate_audio(
110
  Generates audio based on the provided UI parameters.
111
  We do NOT use language_id or ctc_loss even if the model has them.
112
  """
113
- selected_model = load_model_if_needed(model_choice)
114
 
115
  speaker_noised_bool = bool(speaker_noised)
116
  fmax = float(fmax)
@@ -191,7 +189,7 @@ def build_interface():
191
  with gr.Row():
192
  with gr.Column():
193
  model_choice = gr.Dropdown(
194
- choices=["Zyphra/Zonos-v0.1-transformer", "Zyphra/Zonos-v0.1-hybrid"],
195
  value="Zyphra/Zonos-v0.1-transformer",
196
  label="Zonos Model Type",
197
  info="Select the model variant to use.",
@@ -369,4 +367,4 @@ def build_interface():
369
  if __name__ == "__main__":
370
  demo = build_interface()
371
  share = getenv("GRADIO_SHARE", "False").lower() in ("true", "1", "t")
372
- demo.launch(server_name="0.0.0.0", server_port=7860, share=share)
 
1
+ import os
2
+ import shlex
3
+ import subprocess
4
+
5
+ subprocess.run(shlex.split("pip install flash-attn --no-build-isolation"), env=os.environ | {"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}, check=True)
6
+ subprocess.run(shlex.split("pip install https://github.com/state-spaces/mamba/releases/download/v2.2.4/mamba_ssm-2.2.4+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"), check=True)
7
+ subprocess.run(shlex.split("pip install https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.5.0.post8/causal_conv1d-1.5.0.post8+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"), check=True)
8
+
9
+ import spaces
10
  import torch
11
  import torchaudio
12
  import gradio as gr
 
16
  from zonos.conditioning import make_cond_dict, supported_language_codes
17
 
18
  device = "cuda"
19
+ MODEL_NAMES = ["Zyphra/Zonos-v0.1-transformer", "Zyphra/Zonos-v0.1-hybrid"]
20
+ MODELS = {name: Zonos.from_pretrained(name, device=device) for name in MODEL_NAMES}
21
+ for model in MODELS.values():
22
+ model.requires_grad_(False).eval()
23
 
24
 
25
  def update_ui(model_choice):
 
27
  Dynamically show/hide UI elements based on the model's conditioners.
28
  We do NOT display 'language_id' or 'ctc_loss' even if they exist in the model.
29
  """
30
+ model = MODELS[model_choice]
31
  cond_names = [c.name for c in model.prefix_conditioner.conditioners]
32
  print("Conditioners in this model:", cond_names)
33
 
 
76
  )
77
 
78
 
79
+ @spaces.GPU(duration=120)
80
  def generate_audio(
81
  model_choice,
82
  text,
 
108
  Generates audio based on the provided UI parameters.
109
  We do NOT use language_id or ctc_loss even if the model has them.
110
  """
111
+ selected_model = MODELS[model_choice]
112
 
113
  speaker_noised_bool = bool(speaker_noised)
114
  fmax = float(fmax)
 
189
  with gr.Row():
190
  with gr.Column():
191
  model_choice = gr.Dropdown(
192
+ choices=MODEL_NAMES,
193
  value="Zyphra/Zonos-v0.1-transformer",
194
  label="Zonos Model Type",
195
  info="Select the model variant to use.",
 
367
  if __name__ == "__main__":
368
  demo = build_interface()
369
  share = getenv("GRADIO_SHARE", "False").lower() in ("true", "1", "t")
370
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=share)
assets/silence_100ms.wav ADDED
Binary file (9.43 kB).
 
packages.txt ADDED
@@ -0,0 +1 @@
1
+ espeak-ng
requirements.txt ADDED
@@ -0,0 +1,328 @@
1
+ # This file was autogenerated by uv via the following command:
2
+ # uv pip compile pyproject.toml -o requirements.txt
3
+ aiofiles==23.2.1
4
+ # via gradio
5
+ annotated-types==0.7.0
6
+ # via pydantic
7
+ anyio==4.8.0
8
+ # via
9
+ # gradio
10
+ # httpx
11
+ # starlette
12
+ attrs==25.1.0
13
+ # via
14
+ # clldutils
15
+ # csvw
16
+ # jsonschema
17
+ # phonemizer
18
+ # referencing
19
+ babel==2.17.0
20
+ # via csvw
21
+ certifi==2025.1.31
22
+ # via
23
+ # httpcore
24
+ # httpx
25
+ # requests
26
+ cffi==1.17.1
27
+ # via soundfile
28
+ charset-normalizer==3.4.1
29
+ # via requests
30
+ click==8.1.8
31
+ # via
32
+ # typer
33
+ # uvicorn
34
+ clldutils==3.21.0
35
+ # via segments
36
+ colorama==0.4.6
37
+ # via csvw
38
+ colorlog==6.9.0
39
+ # via clldutils
40
+ csvw==3.5.1
41
+ # via segments
42
+ dlinfo==2.0.0
43
+ # via phonemizer
44
+ exceptiongroup==1.2.2
45
+ # via anyio
46
+ fastapi==0.115.8
47
+ # via gradio
48
+ ffmpy==0.5.0
49
+ # via gradio
50
+ filelock==3.17.0
51
+ # via
52
+ # huggingface-hub
53
+ # torch
54
+ # transformers
55
+ # triton
56
+ fsspec==2025.2.0
57
+ # via
58
+ # gradio-client
59
+ # huggingface-hub
60
+ # torch
61
+ gradio==5.16.0
62
+ # via
63
+ # zonos (pyproject.toml)
64
+ # spaces
65
+ gradio-client==1.7.0
66
+ # via gradio
67
+ h11==0.14.0
68
+ # via
69
+ # httpcore
70
+ # uvicorn
71
+ hf-transfer==0.1.9
72
+ # via zonos (pyproject.toml)
73
+ httpcore==1.0.7
74
+ # via httpx
75
+ httpx==0.28.1
76
+ # via
77
+ # gradio
78
+ # gradio-client
79
+ # safehttpx
80
+ # spaces
81
+ huggingface-hub==0.28.1
82
+ # via
83
+ # zonos (pyproject.toml)
84
+ # gradio
85
+ # gradio-client
86
+ # tokenizers
87
+ # transformers
88
+ idna==3.10
89
+ # via
90
+ # anyio
91
+ # httpx
92
+ # requests
93
+ inflect==7.5.0
94
+ # via zonos (pyproject.toml)
95
+ isodate==0.7.2
96
+ # via
97
+ # csvw
98
+ # rdflib
99
+ jinja2==3.1.5
100
+ # via
101
+ # gradio
102
+ # torch
103
+ joblib==1.4.2
104
+ # via phonemizer
105
+ jsonschema==4.23.0
106
+ # via csvw
107
+ jsonschema-specifications==2024.10.1
108
+ # via jsonschema
109
+ kanjize==1.6.0
110
+ # via zonos (pyproject.toml)
111
+ language-tags==1.2.0
112
+ # via csvw
113
+ lxml==5.3.1
114
+ # via clldutils
115
+ markdown==3.7
116
+ # via clldutils
117
+ markdown-it-py==3.0.0
118
+ # via rich
119
+ markupsafe==2.1.5
120
+ # via
121
+ # clldutils
122
+ # gradio
123
+ # jinja2
124
+ mdurl==0.1.2
125
+ # via markdown-it-py
126
+ more-itertools==10.6.0
127
+ # via inflect
128
+ mpmath==1.3.0
129
+ # via sympy
130
+ networkx==3.4.2
131
+ # via torch
132
+ numpy==2.2.2
133
+ # via
134
+ # zonos (pyproject.toml)
135
+ # gradio
136
+ # pandas
137
+ # soundfile
138
+ # transformers
139
+ nvidia-cublas-cu12==12.1.3.1
140
+ # via
141
+ # nvidia-cudnn-cu12
142
+ # nvidia-cusolver-cu12
143
+ # torch
144
+ nvidia-cuda-cupti-cu12==12.1.105
145
+ # via torch
146
+ nvidia-cuda-nvrtc-cu12==12.1.105
147
+ # via torch
148
+ nvidia-cuda-runtime-cu12==12.1.105
149
+ # via torch
150
+ nvidia-cudnn-cu12==9.1.0.70
151
+ # via torch
152
+ nvidia-cufft-cu12==11.0.2.54
153
+ # via torch
154
+ nvidia-curand-cu12==10.3.2.106
155
+ # via torch
156
+ nvidia-cusolver-cu12==11.4.5.107
157
+ # via torch
158
+ nvidia-cusparse-cu12==12.1.0.106
159
+ # via
160
+ # nvidia-cusolver-cu12
161
+ # torch
162
+ nvidia-nccl-cu12==2.20.5
163
+ # via torch
164
+ nvidia-nvjitlink-cu12==12.8.61
165
+ # via
166
+ # nvidia-cusolver-cu12
167
+ # nvidia-cusparse-cu12
168
+ nvidia-nvtx-cu12==12.1.105
169
+ # via torch
170
+ orjson==3.10.15
171
+ # via gradio
172
+ packaging==24.2
173
+ # via
174
+ # zonos (pyproject.toml)
175
+ # gradio
176
+ # gradio-client
177
+ # huggingface-hub
178
+ # spaces
179
+ # transformers
180
+ pandas==2.2.3
181
+ # via gradio
182
+ phonemizer==3.3.0
183
+ # via zonos (pyproject.toml)
184
+ pillow==11.1.0
185
+ # via gradio
186
+ psutil==5.9.8
187
+ # via spaces
188
+ pycparser==2.22
189
+ # via cffi
190
+ pydantic==2.10.6
191
+ # via
192
+ # fastapi
193
+ # gradio
194
+ # spaces
195
+ pydantic-core==2.27.2
196
+ # via pydantic
197
+ pydub==0.25.1
198
+ # via gradio
199
+ pygments==2.19.1
200
+ # via rich
201
+ pylatexenc==2.10
202
+ # via clldutils
203
+ pyparsing==3.2.1
204
+ # via rdflib
205
+ python-dateutil==2.9.0.post0
206
+ # via
207
+ # clldutils
208
+ # csvw
209
+ # pandas
210
+ python-multipart==0.0.20
211
+ # via gradio
212
+ pytz==2025.1
213
+ # via pandas
214
+ pyyaml==6.0.2
215
+ # via
216
+ # gradio
217
+ # huggingface-hub
218
+ # transformers
219
+ rdflib==7.1.3
220
+ # via csvw
221
+ referencing==0.36.2
222
+ # via
223
+ # jsonschema
224
+ # jsonschema-specifications
225
+ regex==2024.11.6
226
+ # via
227
+ # segments
228
+ # transformers
229
+ requests==2.32.3
230
+ # via
231
+ # csvw
232
+ # huggingface-hub
233
+ # spaces
234
+ # transformers
235
+ rfc3986==1.5.0
236
+ # via csvw
237
+ rich==13.9.4
238
+ # via typer
239
+ rpds-py==0.22.3
240
+ # via
241
+ # jsonschema
242
+ # referencing
243
+ ruff==0.9.6
244
+ # via gradio
245
+ safehttpx==0.1.6
246
+ # via gradio
247
+ safetensors==0.5.2
248
+ # via transformers
249
+ segments==2.2.1
250
+ # via phonemizer
251
+ semantic-version==2.10.0
252
+ # via gradio
253
+ setuptools==75.8.0
254
+ # via zonos (pyproject.toml)
255
+ shellingham==1.5.4
256
+ # via typer
257
+ six==1.17.0
258
+ # via python-dateutil
259
+ sniffio==1.3.1
260
+ # via anyio
261
+ soundfile==0.13.1
262
+ # via zonos (pyproject.toml)
263
+ spaces==0.32.0
264
+ # via zonos (pyproject.toml)
265
+ starlette==0.45.3
266
+ # via
267
+ # fastapi
268
+ # gradio
269
+ sudachidict-full==20250129
270
+ # via zonos (pyproject.toml)
271
+ sudachipy==0.6.10
272
+ # via
273
+ # zonos (pyproject.toml)
274
+ # sudachidict-full
275
+ sympy==1.13.3
276
+ # via torch
277
+ tabulate==0.9.0
278
+ # via clldutils
279
+ tokenizers==0.21.0
280
+ # via transformers
281
+ tomlkit==0.13.2
282
+ # via gradio
283
+ torch==2.4.0
284
+ # via
285
+ # zonos (pyproject.toml)
286
+ # torchaudio
287
+ torchaudio==2.4.0
288
+ # via zonos (pyproject.toml)
289
+ tqdm==4.67.1
290
+ # via
291
+ # huggingface-hub
292
+ # transformers
293
+ transformers==4.48.3
294
+ # via zonos (pyproject.toml)
295
+ triton==3.0.0
296
+ # via torch
297
+ typeguard==4.4.1
298
+ # via inflect
299
+ typer==0.15.1
300
+ # via gradio
301
+ typing-extensions==4.12.2
302
+ # via
303
+ # anyio
304
+ # fastapi
305
+ # gradio
306
+ # gradio-client
307
+ # huggingface-hub
308
+ # phonemizer
309
+ # pydantic
310
+ # pydantic-core
311
+ # referencing
312
+ # rich
313
+ # spaces
314
+ # torch
315
+ # typeguard
316
+ # typer
317
+ # uvicorn
318
+ tzdata==2025.1
319
+ # via pandas
320
+ uritemplate==4.1.1
321
+ # via csvw
322
+ urllib3==2.3.0
323
+ # via requests
324
+ uvicorn==0.34.0
325
+ # via gradio
326
+ websockets==14.2
327
+ # via gradio-client
328
+
zonos/autoencoder.py ADDED
@@ -0,0 +1,26 @@
1
+ import math
2
+
3
+ import torch
4
+ import torchaudio
5
+ from transformers.models.dac import DacModel
6
+
7
+
8
+ class DACAutoencoder:
9
+ def __init__(self):
10
+ super().__init__()
11
+ self.dac = DacModel.from_pretrained("descript/dac_44khz")
12
+ self.dac.eval().requires_grad_(False)
13
+ self.codebook_size = self.dac.config.codebook_size
14
+ self.num_codebooks = self.dac.quantizer.n_codebooks
15
+ self.sampling_rate = self.dac.config.sampling_rate
16
+
17
+ def preprocess(self, wav: torch.Tensor, sr: int) -> torch.Tensor:
18
+ wav = torchaudio.functional.resample(wav, sr, 44_100)
19
+ right_pad = math.ceil(wav.shape[-1] / 512) * 512 - wav.shape[-1]
20
+ return torch.nn.functional.pad(wav, (0, right_pad))
21
+
22
+ def encode(self, wav: torch.Tensor) -> torch.Tensor:
23
+ return self.dac.encode(wav).audio_codes
24
+
25
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
26
+ return self.dac.decode(audio_codes=codes).audio_values.unsqueeze(1)
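
Illustrative usage sketch (not part of the commit; shapes assume the descript/dac_44khz checkpoint wrapped above):

import torch
from zonos.autoencoder import DACAutoencoder

ae = DACAutoencoder()                    # downloads descript/dac_44khz
wav = torch.randn(1, 1, 16_000)          # ~1 s of mono audio at 16 kHz
wav = ae.preprocess(wav, sr=16_000)      # resample to 44.1 kHz, right-pad to a multiple of 512
codes = ae.encode(wav)                   # integer codes, [1, num_codebooks, n_frames]
recon = ae.decode(codes)                 # waveform, [1, 1, n_samples]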
zonos/backbone.py ADDED
@@ -0,0 +1,50 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from mamba_ssm.models.mixer_seq_simple import create_block
4
+ from mamba_ssm.ops.triton.layer_norm import layer_norm_fn
5
+ from mamba_ssm.utils.generation import InferenceParams
6
+
7
+ from zonos.config import BackboneConfig
8
+
9
+
10
+ class ZonosBackbone(nn.Module):
11
+ def __init__(self, config: BackboneConfig):
12
+ super().__init__()
13
+ self.config = config
14
+
15
+ self.layers = nn.ModuleList(
16
+ [
17
+ create_block(
18
+ d_model=config.d_model,
19
+ d_intermediate=config.d_intermediate
20
+ if (i not in config.attn_layer_idx)
21
+ else config.attn_mlp_d_intermediate,
22
+ ssm_cfg=config.ssm_cfg,
23
+ layer_idx=i,
24
+ attn_layer_idx=config.attn_layer_idx,
25
+ attn_cfg=config.attn_cfg,
26
+ norm_epsilon=config.norm_epsilon,
27
+ residual_in_fp32=config.residual_in_fp32,
28
+ fused_add_norm=True,
29
+ rms_norm=config.rms_norm,
30
+ )
31
+ for i in range(config.n_layer)
32
+ ]
33
+ )
34
+
35
+ self.norm_f = nn.LayerNorm(config.d_model, eps=config.norm_epsilon)
36
+
37
+ def forward(self, hidden_states: torch.Tensor, inference_params: InferenceParams | None = None):
38
+ residual = None
39
+ for layer in self.layers:
40
+ hidden_states, residual = layer(hidden_states, residual, inference_params)
41
+
42
+ return layer_norm_fn(
43
+ hidden_states,
44
+ self.norm_f.weight,
45
+ self.norm_f.bias,
46
+ residual,
47
+ eps=self.norm_f.eps,
48
+ residual_in_fp32=self.config.residual_in_fp32,
49
+ is_rms_norm=self.config.rms_norm,
50
+ )
zonos/codebook_pattern.py ADDED
@@ -0,0 +1,12 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
+ def apply_delay_pattern(codes: torch.Tensor, mask_token: int):
6
+ codes = F.pad(codes, (0, codes.shape[1]), value=mask_token)
7
+ return torch.stack([codes[:, k].roll(k + 1) for k in range(codes.shape[1])], dim=1)
8
+
9
+
10
+ def revert_delay_pattern(codes: torch.Tensor):
11
+ _, n_q, seq_len = codes.shape
12
+ return torch.stack([codes[:, k, k + 1 : seq_len - n_q + k + 1] for k in range(n_q)], dim=1)
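
Quick sanity check of the delay pattern (illustrative, not part of the commit): codebook k is shifted right by k + 1 steps with mask_token filling the gap, and the two functions round-trip exactly:

import torch
from zonos.codebook_pattern import apply_delay_pattern, revert_delay_pattern

codes = torch.arange(2 * 3 * 4).reshape(2, 3, 4)   # [batch=2, n_q=3, T=4]
delayed = apply_delay_pattern(codes, mask_token=1025)
print(delayed.shape)                               # torch.Size([2, 3, 7]) -- T grows by n_q
assert torch.equal(revert_delay_pattern(delayed), codes)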
zonos/conditioning.py ADDED
@@ -0,0 +1,373 @@
1
+ from functools import cache
2
+ from typing import Any, Literal, Iterable
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from zonos.config import PrefixConditionerConfig
8
+
9
+
10
+ class Conditioner(nn.Module):
11
+ def __init__(
12
+ self,
13
+ output_dim: int,
14
+ name: str,
15
+ cond_dim: int | None = None,
16
+ projection: Literal["none", "linear", "mlp"] = "none",
17
+ uncond_type: Literal["learned", "none"] = "none",
18
+ **kwargs,
19
+ ):
20
+ super().__init__()
21
+ self.name = name
22
+ self.output_dim = output_dim
23
+ self.cond_dim = cond_dim = cond_dim or output_dim
24
+
25
+ if projection == "linear":
26
+ self.project = nn.Linear(cond_dim, output_dim)
27
+ elif projection == "mlp":
28
+ self.project = nn.Sequential(
29
+ nn.Linear(cond_dim, output_dim),
30
+ nn.SiLU(),
31
+ nn.Linear(output_dim, output_dim),
32
+ )
33
+ else:
34
+ self.project = nn.Identity()
35
+
36
+ self.uncond_vector = None
37
+ if uncond_type == "learned":
38
+ self.uncond_vector = nn.Parameter(torch.zeros(output_dim))
39
+
40
+ def apply_cond(self, *inputs: Any) -> torch.Tensor:
41
+ raise NotImplementedError()
42
+
43
+ def forward(self, inputs: tuple[Any, ...] | None) -> torch.Tensor:
44
+ if inputs is None:
45
+ assert self.uncond_vector is not None
46
+ return self.uncond_vector.data.view(1, 1, -1)
47
+
48
+ cond = self.apply_cond(*inputs)
49
+ cond = self.project(cond)
50
+ return cond
51
+
52
+
53
+ # ------- ESPEAK CONTAINMENT ZONE ------------------------------------------------------------------------------------------------------------------------------------------------
54
+ import re
55
+ import unicodedata
56
+
57
+ import inflect
58
+ import torch
59
+ import torch.nn as nn
60
+ from kanjize import number2kanji
61
+ from phonemizer.backend import EspeakBackend
62
+ from sudachipy import Dictionary, SplitMode
63
+
64
+ # --- Number normalization code from https://github.com/daniilrobnikov/vits2/blob/main/text/normalize_numbers.py ---
65
+
66
+ _inflect = inflect.engine()
67
+ _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
68
+ _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
69
+ _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
70
+ _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
71
+ _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
72
+ _number_re = re.compile(r"[0-9]+")
73
+
74
+
75
+ def _remove_commas(m: re.Match) -> str:
76
+ return m.group(1).replace(",", "")
77
+
78
+
79
+ def _expand_decimal_point(m: re.Match) -> str:
80
+ return m.group(1).replace(".", " point ")
81
+
82
+
83
+ def _expand_dollars(m: re.Match) -> str:
84
+ match = m.group(1)
85
+ parts = match.split(".")
86
+ if len(parts) > 2:
87
+ return match + " dollars" # Unexpected format
88
+ dollars = int(parts[0]) if parts[0] else 0
89
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
90
+ if dollars and cents:
91
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
92
+ cent_unit = "cent" if cents == 1 else "cents"
93
+ return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
94
+ elif dollars:
95
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
96
+ return "%s %s" % (dollars, dollar_unit)
97
+ elif cents:
98
+ cent_unit = "cent" if cents == 1 else "cents"
99
+ return "%s %s" % (cents, cent_unit)
100
+ else:
101
+ return "zero dollars"
102
+
103
+
104
+ def _expand_ordinal(m: re.Match) -> str:
105
+ return _inflect.number_to_words(m.group(0))
106
+
107
+
108
+ def _expand_number(m: re.Match) -> str:
109
+ num = int(m.group(0))
110
+ if num > 1000 and num < 3000:
111
+ if num == 2000:
112
+ return "two thousand"
113
+ elif num > 2000 and num < 2010:
114
+ return "two thousand " + _inflect.number_to_words(num % 100)
115
+ elif num % 100 == 0:
116
+ return _inflect.number_to_words(num // 100) + " hundred"
117
+ else:
118
+ return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
119
+ else:
120
+ return _inflect.number_to_words(num, andword="")
121
+
122
+
123
+ def normalize_numbers(text: str) -> str:
124
+ text = re.sub(_comma_number_re, _remove_commas, text)
125
+ text = re.sub(_pounds_re, r"\1 pounds", text)
126
+ text = re.sub(_dollars_re, _expand_dollars, text)
127
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
128
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
129
+ text = re.sub(_number_re, _expand_number, text)
130
+ return text
131
+
132
+
133
+ # --- Number normalization code end ---
134
+
135
+
136
+ PAD_ID, UNK_ID, BOS_ID, EOS_ID = 0, 1, 2, 3
137
+ SPECIAL_TOKEN_IDS = [PAD_ID, UNK_ID, BOS_ID, EOS_ID]
138
+
139
+ _punctuation = ';:,.!?¡¿—…"«»“”() *~-/\\&'
140
+ _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
141
+ _letters_ipa = (
142
+ "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
143
+ )
144
+
145
+ symbols = [*_punctuation, *_letters, *_letters_ipa]
146
+ _symbol_to_id = {s: i for i, s in enumerate(symbols, start=len(SPECIAL_TOKEN_IDS))}
147
+
148
+
149
+ def _get_symbol_id(s: str) -> int:
150
+ return _symbol_to_id.get(s, 1)
151
+
152
+
153
+ def get_symbol_ids(text: str) -> list[int]:
154
+ return list(map(_get_symbol_id, text))
155
+
156
+
157
+ def tokenize_phonemes(phonemes: list[str]) -> tuple[torch.Tensor, list[int]]:
158
+ phoneme_ids = [[BOS_ID, *get_symbol_ids(phonemes), EOS_ID] for phonemes in phonemes]
159
+ lengths = list(map(len, phoneme_ids))
160
+ longest = max(lengths)
161
+ phoneme_ids = [[PAD_ID] * (longest - len(ids)) + ids for ids in phoneme_ids]
162
+ return torch.tensor(phoneme_ids), lengths
163
+
164
+
165
+ def normalize_jp_text(text: str, tokenizer=Dictionary(dict="full").create()) -> str:
166
+ text = unicodedata.normalize("NFKC", text)
167
+ text = re.sub(r"\d+", lambda m: number2kanji(int(m[0])), text)
168
+ final_text = " ".join([x.reading_form() for x in tokenizer.tokenize(text, SplitMode.A)])
169
+ return final_text
170
+
171
+
172
+ def clean(texts: list[str], languages: list[str]) -> list[str]:
173
+ texts_out = []
174
+ for text, language in zip(texts, languages):
175
+ if "ja" in language:
176
+ text = normalize_jp_text(text)
177
+ else:
178
+ text = normalize_numbers(text)
179
+ texts_out.append(text)
180
+ return texts_out
181
+
182
+
183
+ @cache
184
+ def get_backend(language: str) -> "EspeakBackend":
185
+ import logging
186
+
187
+ from phonemizer.backend import EspeakBackend
188
+
189
+ logger = logging.getLogger("phonemizer")
190
+ backend = EspeakBackend(
191
+ language,
192
+ preserve_punctuation=True,
193
+ with_stress=True,
194
+ punctuation_marks=_punctuation,
195
+ logger=logger,
196
+ )
197
+ logger.setLevel(logging.ERROR)
198
+ return backend
199
+
200
+
201
+ def phonemize(texts: list[str], languages: list[str]) -> list[str]:
202
+ texts = clean(texts, languages)
203
+
204
+ batch_phonemes = []
205
+ for text, language in zip(texts, languages):
206
+ backend = get_backend(language)
207
+ phonemes = backend.phonemize([text], strip=True)
208
+ batch_phonemes.append(phonemes[0])
209
+
210
+ return batch_phonemes
211
+
212
+
213
+ class EspeakPhonemeConditioner(Conditioner):
214
+ def __init__(self, output_dim: int, **kwargs):
215
+ super().__init__(output_dim, **kwargs)
216
+ self.phoneme_embedder = nn.Embedding(len(SPECIAL_TOKEN_IDS) + len(symbols), output_dim)
217
+
218
+ def apply_cond(self, texts: list[str], languages: list[str]) -> torch.Tensor:
219
+ """
220
+ Args:
221
+ texts: list of texts to convert to phonemes
222
+ languages: ISO 639-1 (or otherwise eSpeak-compatible) language codes
223
+ """
224
+ device = self.phoneme_embedder.weight.device
225
+
226
+ phonemes = phonemize(texts, languages)
227
+ phoneme_ids, _ = tokenize_phonemes(phonemes)
228
+ phoneme_embeds = self.phoneme_embedder(phoneme_ids.to(device))
229
+
230
+ return phoneme_embeds
231
+
232
+
233
+ # ------- ESPEAK CONTAINMENT ZONE ------------------------------------------------------------------------------------------------------------------------------------------------
234
+
235
+
236
+ class FourierConditioner(Conditioner):
237
+ def __init__(
238
+ self,
239
+ output_dim: int,
240
+ input_dim: int = 1,
241
+ std: float = 1.0,
242
+ min_val: float = 0.0,
243
+ max_val: float = 1.0,
244
+ **kwargs,
245
+ ):
246
+ assert output_dim % 2 == 0
247
+ super().__init__(output_dim, **kwargs)
248
+ self.register_buffer("weight", torch.randn([output_dim // 2, input_dim]) * std)
249
+ self.input_dim, self.min_val, self.max_val = input_dim, min_val, max_val
250
+
251
+ def apply_cond(self, x: torch.Tensor) -> torch.Tensor:
252
+ assert x.shape[-1] == self.input_dim
253
+ x = (x - self.min_val) / (self.max_val - self.min_val) # [batch_size, seq_len, input_dim]
254
+ f = 2 * torch.pi * x.to(self.weight.dtype) @ self.weight.T # [batch_size, seq_len, output_dim // 2]
255
+ return torch.cat([f.cos(), f.sin()], dim=-1) # [batch_size, seq_len, output_dim]
256
+
257
+
258
+ class IntegerConditioner(Conditioner):
259
+ def __init__(self, output_dim: int, min_val: int = 0, max_val: int = 512, **kwargs):
260
+ super().__init__(output_dim, **kwargs)
261
+ self.min_val = min_val
262
+ self.max_val = max_val
263
+ self.int_embedder = nn.Embedding(max_val - min_val + 1, output_dim)
264
+
265
+ def apply_cond(self, x: torch.Tensor) -> torch.Tensor:
266
+ assert x.shape[-1] == 1
267
+ return self.int_embedder(x.squeeze(-1) - self.min_val) # [batch_size, seq_len, output_dim]
268
+
269
+
270
+ class PassthroughConditioner(Conditioner):
271
+ def __init__(self, output_dim: int, **kwargs):
272
+ super().__init__(output_dim, **kwargs)
273
+
274
+ def apply_cond(self, x: torch.Tensor) -> torch.Tensor:
275
+ assert x.shape[-1] == self.cond_dim
276
+ return x
277
+
278
+
279
+ _cond_cls_map = {
280
+ "PassthroughConditioner": PassthroughConditioner,
281
+ "EspeakPhonemeConditioner": EspeakPhonemeConditioner,
282
+ "FourierConditioner": FourierConditioner,
283
+ "IntegerConditioner": IntegerConditioner,
284
+ }
285
+
286
+
287
+ def build_conditioners(conditioners: list[dict], output_dim: int) -> list[Conditioner]:
288
+ return [_cond_cls_map[config["type"]](output_dim, **config) for config in conditioners]
289
+
290
+
291
+ class PrefixConditioner(Conditioner):
292
+ def __init__(self, config: PrefixConditionerConfig, output_dim: int):
293
+ super().__init__(output_dim, "prefix", projection=config.projection)
294
+ self.conditioners = nn.ModuleList(build_conditioners(config.conditioners, output_dim))
295
+ self.norm = nn.LayerNorm(output_dim)
296
+ self.required_keys = {c.name for c in self.conditioners if c.uncond_vector is None}
297
+
298
+ def forward(self, cond_dict: dict) -> torch.Tensor:
299
+ if not set(cond_dict).issuperset(self.required_keys):
300
+ raise ValueError(f"Missing required keys: {self.required_keys - set(cond_dict)}")
301
+ conds = []
302
+ for conditioner in self.conditioners:
303
+ conds.append(conditioner(cond_dict.get(conditioner.name)))
304
+ max_bsz = max(map(len, conds))
305
+ assert all(c.shape[0] in (max_bsz, 1) for c in conds)
306
+ conds = [c.expand(max_bsz, -1, -1) for c in conds]
307
+ return self.norm(self.project(torch.cat(conds, dim=-2)))
308
+
309
+
310
+ supported_language_codes = [
311
+ 'af', 'am', 'an', 'ar', 'as', 'az', 'ba', 'bg', 'bn', 'bpy', 'bs', 'ca', 'cmn',
312
+ 'cs', 'cy', 'da', 'de', 'el', 'en-029', 'en-gb', 'en-gb-scotland', 'en-gb-x-gbclan',
313
+ 'en-gb-x-gbcwmd', 'en-gb-x-rp', 'en-us', 'eo', 'es', 'es-419', 'et', 'eu', 'fa',
314
+ 'fa-latn', 'fi', 'fr-be', 'fr-ch', 'fr-fr', 'ga', 'gd', 'gn', 'grc', 'gu', 'hak',
315
+ 'hi', 'hr', 'ht', 'hu', 'hy', 'hyw', 'ia', 'id', 'is', 'it', 'ja', 'jbo', 'ka',
316
+ 'kk', 'kl', 'kn', 'ko', 'kok', 'ku', 'ky', 'la', 'lfn', 'lt', 'lv', 'mi', 'mk',
317
+ 'ml', 'mr', 'ms', 'mt', 'my', 'nb', 'nci', 'ne', 'nl', 'om', 'or', 'pa', 'pap',
318
+ 'pl', 'pt', 'pt-br', 'py', 'quc', 'ro', 'ru', 'ru-lv', 'sd', 'shn', 'si', 'sk',
319
+ 'sl', 'sq', 'sr', 'sv', 'sw', 'ta', 'te', 'tn', 'tr', 'tt', 'ur', 'uz', 'vi',
320
+ 'vi-vn-x-central', 'vi-vn-x-south', 'yue'
321
+ ] # fmt: off
322
+
323
+
324
+ def make_cond_dict(
325
+ text: str = "It would be nice to have time for testing, indeed.",
326
+ language: str = "en-us",
327
+ speaker: torch.Tensor | None = None,
328
+ emotion: list[float] = [0.3077, 0.0256, 0.0256, 0.0256, 0.0256, 0.0256, 0.2564, 0.3077],
329
+ fmax: float = 22050.0,
330
+ pitch_std: float = 20.0,
331
+ speaking_rate: float = 15.0,
332
+ vqscore_8: list[float] = [0.78] * 8,
333
+ ctc_loss: float = 0.0,
334
+ dnsmos_ovrl: float = 4.0,
335
+ speaker_noised: bool = False,
336
+ unconditional_keys: Iterable[str] = {"vqscore_8", "dnsmos_ovrl"},
337
+ device: str = "cuda",
338
+ ) -> dict:
339
+ """
340
+ A helper to build the 'cond_dict' that the model expects.
341
+ By default, it will generate a random speaker embedding
342
+ """
343
+ assert language.lower() in supported_language_codes, "Please pick a supported language"
344
+
345
+ language_code_to_id = {lang: i for i, lang in enumerate(supported_language_codes)}
346
+
347
+ cond_dict = {
348
+ "espeak": ([text], [language]),
349
+ "speaker": speaker,
350
+ "emotion": emotion,
351
+ "fmax": fmax,
352
+ "pitch_std": pitch_std,
353
+ "speaking_rate": speaking_rate,
354
+ "language_id": language_code_to_id[language],
355
+ "vqscore_8": vqscore_8,
356
+ "ctc_loss": ctc_loss,
357
+ "dnsmos_ovrl": dnsmos_ovrl,
358
+ "speaker_noised": int(speaker_noised),
359
+ }
360
+
361
+ for k in unconditional_keys:
362
+ cond_dict.pop(k, None)
363
+
364
+ for k, v in cond_dict.items():
365
+ if isinstance(v, (float, int, list)):
366
+ v = torch.tensor(v)
367
+ if isinstance(v, torch.Tensor):
368
+ cond_dict[k] = v.view(1, 1, -1).to(device)
369
+
370
+ if k == "emotion":
371
+ cond_dict[k] /= cond_dict[k].sum(dim=-1)
372
+
373
+ return cond_dict
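
For reference (illustrative, not part of the commit), the English number normalization above expands currency, plain numbers, and years before phonemization:

from zonos.conditioning import normalize_numbers

print(normalize_numbers("I paid $12.50 for 3 items in 2004."))
# I paid twelve dollars, fifty cents for three items in two thousand four.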
zonos/config.py ADDED
@@ -0,0 +1,38 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Literal
3
+
4
+
5
+ @dataclass
6
+ class BackboneConfig:
7
+ d_model: int = 1024
8
+ d_intermediate: int = 0
9
+ attn_mlp_d_intermediate: int = 0
10
+ n_layer: int = 16
11
+ ssm_cfg: dict = field(default_factory=dict)
12
+ attn_layer_idx: list = field(default_factory=list)
13
+ attn_cfg: dict = field(default_factory=dict)
14
+ rms_norm: bool = False
15
+ residual_in_fp32: bool = False
16
+ norm_epsilon: float = 1e-5
17
+
18
+
19
+ @dataclass
20
+ class PrefixConditionerConfig:
21
+ conditioners: list[dict]
22
+ projection: Literal["none", "linear", "mlp"]
23
+
24
+
25
+ @dataclass
26
+ class ZonosConfig:
27
+ backbone: BackboneConfig
28
+ prefix_conditioner: PrefixConditionerConfig
29
+ eos_token_id: int = 1024
30
+ masked_token_id: int = 1025
31
+
32
+ @classmethod
33
+ def from_dict(cls, d: dict) -> "ZonosConfig":
34
+ d = d.copy()
35
+ backbone_config = BackboneConfig(**d.pop("backbone"))
36
+ prefix_conditioner_config = PrefixConditionerConfig(**d.pop("prefix_conditioner"))
37
+ config = cls(backbone_config, prefix_conditioner_config, **d)
38
+ return config
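
A minimal, hypothetical dict accepted by ZonosConfig.from_dict (real configs ship as config.json in the model repos; the values below are illustrative):

from zonos.config import ZonosConfig

cfg = ZonosConfig.from_dict({
    "backbone": {"d_model": 1024, "n_layer": 16},
    "prefix_conditioner": {"conditioners": [], "projection": "none"},
})
print(cfg.backbone.d_model, cfg.eos_token_id, cfg.masked_token_id)  # 1024 1024 1025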
zonos/model.py ADDED
@@ -0,0 +1,270 @@
1
+ import json
2
+ from typing import Callable
3
+
4
+ import safetensors
5
+ import torch
6
+ import torch.nn as nn
7
+ from huggingface_hub import hf_hub_download
8
+ from mamba_ssm.utils.generation import InferenceParams
9
+ from tqdm import tqdm
10
+
11
+ from zonos.autoencoder import DACAutoencoder
12
+ from zonos.backbone import ZonosBackbone
13
+ from zonos.codebook_pattern import apply_delay_pattern, revert_delay_pattern
14
+ from zonos.conditioning import PrefixConditioner
15
+ from zonos.config import ZonosConfig
16
+ from zonos.sampling import sample_from_logits
17
+ from zonos.speaker_cloning import SpeakerEmbeddingLDA
18
+
19
+
20
+ class Zonos(nn.Module):
21
+ def __init__(self, config: ZonosConfig):
22
+ super().__init__()
23
+ self.config = config
24
+ dim = config.backbone.d_model
25
+ self.eos_token_id = config.eos_token_id
26
+ self.masked_token_id = config.masked_token_id
27
+
28
+ self.autoencoder = DACAutoencoder()
29
+ self.backbone = ZonosBackbone(config.backbone)
30
+ self.prefix_conditioner = PrefixConditioner(config.prefix_conditioner, dim)
31
+ self.spk_clone_model = None
32
+
33
+ # TODO: pad to multiple of at least 8
34
+ self.embeddings = nn.ModuleList([nn.Embedding(1026, dim) for _ in range(self.autoencoder.num_codebooks)])
35
+ self.heads = nn.ModuleList([nn.Linear(dim, 1025, bias=False) for _ in range(self.autoencoder.num_codebooks)])
36
+
37
+ self._cg_graph = None
38
+ self._cg_batch_size = None
39
+ self._cg_input_ids = None
40
+ self._cg_logits = None
41
+ self._cg_inference_params = None
42
+ self._cg_scale = None
43
+
44
+ @classmethod
45
+ def from_pretrained(cls, repo_id: str, revision: str | None = None, device: str = "cuda") -> "Zonos":
46
+ config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
47
+ model_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
48
+ return cls.from_local(config_path, model_path, device)
49
+
50
+ @classmethod
51
+ def from_local(cls, config_path: str, model_path: str, device: str = "cuda") -> "Zonos":
52
+ config = ZonosConfig.from_dict(json.load(open(config_path)))
53
+ model = cls(config).to(device, torch.bfloat16)
54
+ model.autoencoder.dac.to(device)
55
+
56
+ sd = model.state_dict()
57
+ with safetensors.safe_open(model_path, framework="pt") as f:
58
+ for k in f.keys():
59
+ sd[k] = f.get_tensor(k)
60
+ model.load_state_dict(sd)
61
+
62
+ return model
63
+
64
+ def make_speaker_embedding(self, wav: torch.Tensor, sr: int) -> torch.Tensor:
65
+ """Generate a speaker embedding from an audio clip."""
66
+ if self.spk_clone_model is None:
67
+ self.spk_clone_model = SpeakerEmbeddingLDA()
68
+ _, spk_embedding = self.spk_clone_model(wav.to(self.spk_clone_model.device), sr)
69
+ return spk_embedding.unsqueeze(0).bfloat16()
70
+
71
+ def embed_codes(self, codes: torch.Tensor) -> torch.Tensor:
72
+ return sum(emb(codes[:, i]) for i, emb in enumerate(self.embeddings))
73
+
74
+ def apply_heads(self, hidden_states: torch.Tensor) -> torch.Tensor:
75
+ return torch.stack([head(hidden_states) for head in self.heads], dim=1)
76
+
77
+ def _compute_logits(
78
+ self, hidden_states: torch.Tensor, inference_params: InferenceParams, cfg_scale: float
79
+ ) -> torch.Tensor:
80
+ """
81
+ Pass `hidden_states` into `backbone` and `multi_head`, applying
82
+ classifier-free guidance if `cfg_scale != 1.0`.
83
+ """
84
+ last_hidden_states = self.backbone(hidden_states, inference_params)[:, -1, :].unsqueeze(1)
85
+ logits = self.apply_heads(last_hidden_states).squeeze(2).float()
86
+ if cfg_scale != 1.0:
87
+ cond_logits, uncond_logits = logits.chunk(2)
88
+ logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
89
+ return logits
90
+
91
+ def _decode_one_token(
92
+ self,
93
+ input_ids: torch.Tensor,
94
+ inference_params: InferenceParams,
95
+ cfg_scale: float,
96
+ ) -> torch.Tensor:
97
+ """
98
+ Single-step decode. Prepares the hidden states, possibly replicates them
99
+ for CFG, and then delegates to `_compute_logits`.
100
+
101
+ Below we wrap this function with a simple CUDA Graph capturing mechanism,
102
+ doing 3 warmup steps if needed and then capturing or replaying the graph.
103
+ We only recapture if the batch size changes.
104
+ """
105
+ # TODO: support cfg_scale==1
106
+ if cfg_scale == 1.0:
107
+ hidden_states = self.embed_codes(input_ids)
108
+ return self._compute_logits(hidden_states, inference_params, cfg_scale)
109
+
110
+ bsz = input_ids.size(0)
111
+
112
+ need_capture = (self._cg_graph is None) or (self._cg_batch_size != bsz)
113
+
114
+ if need_capture:
115
+ self._cg_graph = None
116
+
117
+ self._cg_batch_size = bsz
118
+ self._cg_inference_params = inference_params
119
+ self._cg_scale = cfg_scale
120
+
121
+ for _ in range(3):
122
+ hidden_states = self.embed_codes(input_ids)
123
+ hidden_states = hidden_states.repeat(2, 1, 1) # because cfg != 1.0
124
+ logits = self._compute_logits(hidden_states, inference_params, cfg_scale)
125
+
126
+ self._cg_input_ids = input_ids.clone()
127
+ self._cg_logits = torch.empty_like(logits)
128
+
129
+ g = torch.cuda.CUDAGraph()
130
+
131
+ def capture_region():
132
+ hidden_states_local = self.embed_codes(self._cg_input_ids)
133
+ hidden_states_local = hidden_states_local.repeat(2, 1, 1)
134
+ self._cg_logits = self._compute_logits(hidden_states_local, self._cg_inference_params, self._cg_scale)
135
+
136
+ with torch.cuda.graph(g):
137
+ capture_region()
138
+
139
+ self._cg_graph = g
140
+
141
+ else:
142
+ self._cg_input_ids.copy_(input_ids)
143
+
144
+ self._cg_graph.replay()
145
+
146
+ return self._cg_logits
147
+
148
+ def _prefill(
149
+ self,
150
+ prefix_hidden_states: torch.Tensor,
151
+ input_ids: torch.Tensor,
152
+ inference_params: InferenceParams,
153
+ cfg_scale: float,
154
+ ) -> torch.Tensor:
155
+ """
156
+ "Prefill" mode: we already have `prefix_hidden_states`, and we want
157
+ to append new embeddings, then compute the logits.
158
+ """
159
+ # Replicate input_ids if CFG is enabled
160
+ if cfg_scale != 1.0:
161
+ input_ids = input_ids.expand(prefix_hidden_states.shape[0], -1, -1)
162
+ hidden_states = torch.cat([prefix_hidden_states, self.embed_codes(input_ids)], dim=1)
163
+ return self._compute_logits(hidden_states, inference_params, cfg_scale)
164
+
165
+ def setup_cache(self, batch_size: int, max_seqlen: int, dtype: torch.dtype = torch.bfloat16) -> InferenceParams:
166
+ key_value_memory_dict = {
167
+ i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
168
+ for i, layer in enumerate(self.backbone.layers)
169
+ }
170
+ lengths_per_sample = torch.full((batch_size,), 0, dtype=torch.int32, device="cuda")
171
+ return InferenceParams(max_seqlen, batch_size, 0, 0, key_value_memory_dict, lengths_per_sample)
172
+
173
+ def prepare_conditioning(self, cond_dict: dict, uncond_dict: dict | None = None) -> torch.Tensor:
174
+ if uncond_dict is None:
175
+ uncond_dict = {k: cond_dict[k] for k in self.prefix_conditioner.required_keys}
176
+ return torch.cat(
177
+ [
178
+ self.prefix_conditioner(cond_dict),
179
+ self.prefix_conditioner(uncond_dict),
180
+ ]
181
+ )
182
+
183
+ @torch.inference_mode()
184
+ def generate(
185
+ self,
186
+ prefix_conditioning: torch.Tensor, # [bsz, cond_seq_len, d_model]
187
+ audio_prefix_codes: torch.Tensor | None = None, # [bsz, 9, prefix_audio_seq_len]
188
+ max_new_tokens: int = 86 * 30,
189
+ cfg_scale: float = 2.0,
190
+ batch_size: int = 1,
191
+ sampling_params: dict = dict(min_p=0.1),
192
+ progress_bar: bool = True,
193
+ callback: Callable[[torch.Tensor, int, int], bool] | None = None,
194
+ ):
195
+ assert cfg_scale != 1, "TODO: add support for cfg_scale=1"
196
+ prefix_audio_len = 0 if audio_prefix_codes is None else audio_prefix_codes.shape[2]
197
+
198
+ unknown_token = -1
199
+ audio_seq_len = prefix_audio_len + max_new_tokens
200
+ seq_len = prefix_conditioning.shape[1] + audio_seq_len
201
+
202
+ inference_params = self.setup_cache(batch_size=batch_size * 2, max_seqlen=seq_len)
203
+
204
+ codes = torch.full((batch_size, 9, audio_seq_len), unknown_token, device="cuda")
205
+ if audio_prefix_codes is not None:
206
+ codes[..., :prefix_audio_len] = audio_prefix_codes
207
+
208
+ delayed_codes = apply_delay_pattern(codes, self.masked_token_id)
209
+
210
+ delayed_prefix_audio_codes = delayed_codes[..., : prefix_audio_len + 1]
211
+
212
+ logits = self._prefill(prefix_conditioning, delayed_prefix_audio_codes, inference_params, cfg_scale)
213
+ next_token = sample_from_logits(logits, **sampling_params)
214
+
215
+ offset = delayed_prefix_audio_codes.shape[2]
216
+ frame = delayed_codes[..., offset : offset + 1]
217
+ frame.masked_scatter_(frame == unknown_token, next_token)
218
+
219
+ prefix_length = prefix_conditioning.shape[1] + prefix_audio_len + 1
220
+ inference_params.seqlen_offset += prefix_length
221
+ inference_params.lengths_per_sample[:] += prefix_length
222
+
223
+ logit_bias = torch.zeros_like(logits)
224
+ logit_bias[:, 1:, self.eos_token_id] = -torch.inf # only allow codebook 0 to predict EOS
225
+
226
+ stopping = torch.zeros(batch_size, dtype=torch.bool, device="cuda")
227
+ max_steps = delayed_codes.shape[2] - offset
228
+ remaining_steps = torch.full((batch_size,), max_steps, device="cuda")
229
+ progress = tqdm(total=max_steps, desc="Generating", disable=not progress_bar)
230
+
231
+ step = 0
232
+ while torch.max(remaining_steps) > 0:
233
+ offset += 1
234
+ input_ids = delayed_codes[..., offset - 1 : offset]
235
+ logits = self._decode_one_token(input_ids, inference_params, cfg_scale)
236
+
237
+ next_token = sample_from_logits(logits, generated_tokens=delayed_codes[..., :offset], **sampling_params)
238
+ eos_in_cb0 = next_token[:, 0] == self.eos_token_id
239
+
240
+ remaining_steps[eos_in_cb0[:, 0]] = torch.minimum(remaining_steps[eos_in_cb0[:, 0]], torch.tensor(9))
241
+ stopping |= eos_in_cb0[:, 0]
242
+
243
+ eos_codebook_idx = 9 - remaining_steps
244
+ eos_codebook_idx = torch.clamp(eos_codebook_idx, max=9 - 1)
245
+ for i in range(next_token.shape[0]):
246
+ if stopping[i]:
247
+ idx = eos_codebook_idx[i].item()
248
+ next_token[i, :idx] = self.masked_token_id
249
+ next_token[i, idx] = self.eos_token_id
250
+
251
+ frame = delayed_codes[..., offset : offset + 1]
252
+ frame.masked_scatter_(frame == unknown_token, next_token)
253
+ inference_params.seqlen_offset += 1
254
+ inference_params.lengths_per_sample[:] += 1
255
+
256
+ remaining_steps -= 1
257
+
258
+ progress.update()
259
+ step += 1
260
+
261
+ if callback is not None and not callback(frame, step, max_steps):
262
+ break
263
+
264
+ out_codes = revert_delay_pattern(delayed_codes)
265
+ out_codes.masked_fill_(out_codes >= 1024, 0)
266
+ out_codes = out_codes[..., : offset - 9]
267
+
268
+ self._cg_graph = None # reset cuda graph to avoid cache changes
269
+
270
+ return out_codes
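
A minimal end-to-end sketch of the API added above (illustrative, not part of the commit; it broadly mirrors what app.py does, and the output path and text are made up):

import torchaudio
from zonos.model import Zonos
from zonos.conditioning import make_cond_dict

model = Zonos.from_pretrained("Zyphra/Zonos-v0.1-transformer", device="cuda")
cond_dict = make_cond_dict(text="Hello from Zonos.", language="en-us")
conditioning = model.prepare_conditioning(cond_dict)
codes = model.generate(conditioning)              # [1, 9, n_frames]
wav = model.autoencoder.decode(codes).cpu()[0]    # [1, n_samples] at 44.1 kHz
torchaudio.save("sample.wav", wav, model.autoencoder.sampling_rate)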
zonos/sampling.py ADDED
@@ -0,0 +1,141 @@
1
+ import torch
2
+
3
+
4
+ def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
5
+ """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
6
+
7
+ Args:
8
+ input (torch.Tensor): The input tensor containing probabilities.
9
+ num_samples (int): Number of samples to draw.
10
+ replacement (bool): Whether to draw with replacement or not.
11
+ Keywords args:
12
+ generator (torch.Generator): A pseudorandom number generator for sampling.
13
+ Returns:
14
+ torch.Tensor: Last dimension contains num_samples indices
15
+ sampled from the multinomial probability distribution
16
+ located in the last dimension of tensor input.
17
+ """
18
+
19
+ if num_samples == 1:
20
+ q = torch.empty_like(input).exponential_(1, generator=generator)
21
+ return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64)
22
+
23
+ input_ = input.reshape(-1, input.shape[-1])
24
+ output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
25
+ output = output_.reshape(*list(input.shape[:-1]), -1)
26
+ return output
27
+
28
+
29
+ def apply_top_k(
30
+ probs: torch.Tensor,
31
+ k: int,
32
+ ) -> torch.Tensor:
33
+ """Keep only the top K probability values along the last dimension of the input probs tensor and renormalize.
34
+
35
+ Args:
36
+ probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
37
+ k (int): The k in “top-k”.
38
+ Returns:
39
+ torch.Tensor: Renormalized probabilities after top-k filtering.
40
+ """
41
+ v, _ = torch.topk(probs, min(k, probs.size(-1)))
42
+ pivot = v.select(-1, -1).unsqueeze(-1)
43
+ probs = torch.where(probs < pivot, 0.0, probs)
44
+ probs.div_(probs.sum(dim=-1, keepdim=True))
45
+ return probs
46
+
47
+
48
+ def apply_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
49
+ """Keep only the top tokens whose cumulative probability mass is within P (nucleus filtering) and renormalize.
50
+
51
+ Args:
52
+ probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
53
+ p (float): The p in “top-p”.
54
+ Returns:
55
+ torch.Tensor: Renormalized probabilities after top-p filtering.
56
+ """
57
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
58
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
59
+ mask = probs_sum - probs_sort > p
60
+ probs_sort *= (~mask).float()
61
+ probs = probs.scatter(-1, probs_idx, probs_sort)
62
+ probs.div_(probs.sum(dim=-1, keepdim=True))
63
+ return probs
64
+
65
+
66
+ def apply_min_p(probs: torch.Tensor, min_p: float) -> torch.Tensor:
67
+ """Zero out tokens whose probability is below min_p times the top token's probability, then renormalize.
68
+
69
+ Args:
70
+ probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
71
+ min_p (float): Minimum token probability, scaled by the probability of the most likely token.
72
+ Must be between 0 and 1. Typical values are in the 0.01-0.2 range.
73
+ Returns:
74
+ torch.Tensor: Renormalized probabilities after min-p filtering.
75
+ """
76
+ top_probs, _ = probs.max(dim=-1, keepdim=True)
77
+ tokens_to_remove = probs < (min_p * top_probs)
78
+ probs = probs.masked_fill(tokens_to_remove, 0.0)
79
+ probs.div_(probs.sum(dim=-1, keepdim=True))
80
+ return probs
81
+
82
+
83
+ def modify_logit_for_repetition_penalty(
84
+ logits: torch.Tensor,
85
+ generated_tokens: torch.Tensor,
86
+ repetition_penalty: float,
87
+ repetition_penalty_window: int,
88
+ ):
89
+ """See https://arxiv.org/abs/1909.05858
90
+ Apply repetition penalty over a sliding window of the last `repetition_penalty_window` tokens.
91
+ logits: (batch_size, n_codebooks, vocab_size)
92
+ generated_tokens: (batch_size, n_codebooks, seq_len)
93
+ """
94
+ generated_tokens = generated_tokens[..., -repetition_penalty_window:]
95
+ generated_tokens = generated_tokens.clamp_max(logits.shape[-1] - 1).to(torch.int64)
96
+ rp = torch.full_like(logits, repetition_penalty)
97
+ factors = torch.ones_like(logits).scatter_reduce(2, generated_tokens, rp, reduce="prod")
98
+ return torch.where(logits <= 0, logits * factors, logits / factors)
99
+
100
+
101
+ def sample_from_logits(
102
+ logits: torch.Tensor,
103
+ temperature: float = 1.0,
104
+ top_p: float = 0.0,
105
+ top_k: int = 0,
106
+ min_p: float = 0.0,
107
+ generated_tokens: torch.Tensor | None = None,
108
+ repetition_penalty: float = 3.0,
109
+ repetition_penalty_window: int = 2,
110
+ ) -> torch.Tensor:
111
+ """Sample next token from logits using temperature, top-p, top-k, or min-p sampling.
112
+
113
+ Args:
114
+ logits (torch.Tensor): Input logits with token candidates on the last dimension.
115
+ temperature (float): Sampling temperature. Lower temperature results in more deterministic samples.
116
+ top_p (float): The p in “top-p”.
117
+ top_k (int): The k in “top-k”.
118
+ min_p (float): Minimum token probability, scaled by the probability of the most likely token.
119
+ Must be between 0 and 1. Typical values are in the 0.01-0.2 range.
120
+
121
+ Returns:
122
+ torch.Tensor: Sampled tokens.
123
+ """
124
+ if repetition_penalty != 1.0 and generated_tokens is not None:
125
+ logits = modify_logit_for_repetition_penalty(logits, generated_tokens, repetition_penalty, repetition_penalty_window)
126
+
127
+ if temperature > 0:
128
+ probs = torch.softmax(logits / temperature, dim=-1)
129
+
130
+ if top_p > 0:
131
+ probs = apply_top_p(probs, top_p)
132
+ if top_k > 0:
133
+ probs = apply_top_k(probs, top_k)
134
+ if min_p > 0:
135
+ probs = apply_min_p(probs, min_p)
136
+
137
+ next_token = multinomial(probs, num_samples=1)
138
+ else:
139
+ next_token = torch.argmax(logits, dim=-1, keepdim=True)
140
+
141
+ return next_token # [batch_size, num_codebooks, 1]
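
Shape convention for the sampler above (illustrative): logits carry one vocabulary distribution per codebook, and one token per codebook comes back:

import torch
from zonos.sampling import sample_from_logits

logits = torch.randn(2, 9, 1025)               # [batch, n_codebooks, vocab_size]
tok = sample_from_logits(logits, min_p=0.1)    # temperature defaults to 1.0
print(tok.shape)                               # torch.Size([2, 9, 1])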
zonos/speaker_cloning.py ADDED
@@ -0,0 +1,406 @@
1
+ import math
2
+ from functools import cache
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torchaudio
8
+ from huggingface_hub import hf_hub_download
9
+ import os
10
+
11
+
12
+ class logFbankCal(nn.Module):
13
+ def __init__(
14
+ self,
15
+ sample_rate: int = 16_000,
16
+ n_fft: int = 512,
17
+ win_length: float = 0.025,
18
+ hop_length: float = 0.01,
19
+ n_mels: int = 80,
20
+ ):
21
+ super().__init__()
22
+ self.fbankCal = torchaudio.transforms.MelSpectrogram(
23
+ sample_rate=sample_rate,
24
+ n_fft=n_fft,
25
+ win_length=int(win_length * sample_rate),
26
+ hop_length=int(hop_length * sample_rate),
27
+ n_mels=n_mels,
28
+ )
29
+
30
+ def forward(self, x):
31
+ out = self.fbankCal(x)
32
+ out = torch.log(out + 1e-6)
33
+ out = out - out.mean(axis=2).unsqueeze(dim=2)
34
+ return out
35
+
36
+
37
+ class ASP(nn.Module):
38
+ # Attentive statistics pooling
39
+ def __init__(self, in_planes, acoustic_dim):
40
+ super(ASP, self).__init__()
41
+ outmap_size = int(acoustic_dim / 8)
42
+ self.out_dim = in_planes * 8 * outmap_size * 2
43
+
44
+ self.attention = nn.Sequential(
45
+ nn.Conv1d(in_planes * 8 * outmap_size, 128, kernel_size=1),
46
+ nn.ReLU(),
47
+ nn.BatchNorm1d(128),
48
+ nn.Conv1d(128, in_planes * 8 * outmap_size, kernel_size=1),
49
+ nn.Softmax(dim=2),
50
+ )
51
+
52
+ def forward(self, x):
53
+ x = x.reshape(x.size()[0], -1, x.size()[-1])
54
+ w = self.attention(x)
55
+ mu = torch.sum(x * w, dim=2)
56
+ sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5))
57
+ x = torch.cat((mu, sg), 1)
58
+
59
+ x = x.view(x.size()[0], -1)
60
+ return x
61
+
62
+
63
+ class SimAMBasicBlock(nn.Module):
64
+ expansion = 1
65
+
66
+ def __init__(self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1):
67
+ super(SimAMBasicBlock, self).__init__()
68
+ self.conv1 = ConvLayer(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
69
+ self.bn1 = NormLayer(planes)
70
+ self.conv2 = ConvLayer(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
71
+ self.bn2 = NormLayer(planes)
72
+ self.relu = nn.ReLU(inplace=True)
73
+ self.sigmoid = nn.Sigmoid()
74
+
75
+ self.downsample = nn.Sequential()
76
+ if stride != 1 or in_planes != self.expansion * planes:
77
+ self.downsample = nn.Sequential(
78
+ ConvLayer(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
79
+ NormLayer(self.expansion * planes),
80
+ )
81
+
82
+ def forward(self, x):
83
+ out = self.relu(self.bn1(self.conv1(x)))
84
+ out = self.bn2(self.conv2(out))
85
+ out = self.SimAM(out)
86
+ out += self.downsample(x)
87
+ out = self.relu(out)
88
+ return out
89
+
90
+ def SimAM(self, X, lambda_p=1e-4):
91
+ n = X.shape[2] * X.shape[3] - 1
92
+ d = (X - X.mean(dim=[2, 3], keepdim=True)).pow(2)
93
+ v = d.sum(dim=[2, 3], keepdim=True) / n
94
+ E_inv = d / (4 * (v + lambda_p)) + 0.5
95
+ return X * self.sigmoid(E_inv)
96
+
97
+
98
+ class BasicBlock(nn.Module):
99
+ expansion = 1
100
+
101
+ def __init__(self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1):
102
+ super(BasicBlock, self).__init__()
103
+ self.conv1 = ConvLayer(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
104
+ self.bn1 = NormLayer(planes)
105
+ self.conv2 = ConvLayer(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
106
+ self.bn2 = NormLayer(planes)
107
+ self.relu = nn.ReLU(inplace=True)
108
+
109
+ self.downsample = nn.Sequential()
110
+ if stride != 1 or in_planes != self.expansion * planes:
111
+ self.downsample = nn.Sequential(
112
+ ConvLayer(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
113
+ NormLayer(self.expansion * planes),
114
+ )
115
+
116
+ def forward(self, x):
117
+ out = self.relu(self.bn1(self.conv1(x)))
118
+ out = self.bn2(self.conv2(out))
119
+ out += self.downsample(x)
120
+ out = self.relu(out)
121
+ return out
122
+
123
+
124
+ class Bottleneck(nn.Module):
125
+ expansion = 4
126
+
127
+ def __init__(self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1):
128
+ super(Bottleneck, self).__init__()
129
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
130
+ self.bn1 = nn.BatchNorm2d(planes)
131
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
132
+ self.bn2 = nn.BatchNorm2d(planes)
133
+ self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
134
+ self.bn3 = nn.BatchNorm2d(self.expansion * planes)
135
+
136
+ self.shortcut = nn.Sequential()
137
+ if stride != 1 or in_planes != self.expansion * planes:
138
+ self.shortcut = nn.Sequential(
139
+ nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
140
+ nn.BatchNorm2d(self.expansion * planes),
141
+ )
142
+
143
+ def forward(self, x):
144
+ out = F.relu(self.bn1(self.conv1(x)))
145
+ out = F.relu(self.bn2(self.conv2(out)))
146
+ out = self.bn3(self.conv3(out))
147
+ out += self.shortcut(x)
148
+ out = F.relu(out)
149
+ return out
150
+
151
+
152
+class ResNet(nn.Module):
+    def __init__(self, in_planes, block, num_blocks, in_ch=1, feat_dim="2d", **kwargs):
+        super(ResNet, self).__init__()
+        if feat_dim == "1d":
+            self.NormLayer = nn.BatchNorm1d
+            self.ConvLayer = nn.Conv1d
+        elif feat_dim == "2d":
+            self.NormLayer = nn.BatchNorm2d
+            self.ConvLayer = nn.Conv2d
+        elif feat_dim == "3d":
+            self.NormLayer = nn.BatchNorm3d
+            self.ConvLayer = nn.Conv3d
+        else:
+            raise ValueError(f"Unsupported feat_dim: {feat_dim!r} (expected '1d', '2d' or '3d')")
+
+        self.in_planes = in_planes
+
+        self.conv1 = self.ConvLayer(in_ch, in_planes, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn1 = self.NormLayer(in_planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.layer1 = self._make_layer(block, in_planes, num_blocks[0], stride=1, block_id=1)
+        self.layer2 = self._make_layer(block, in_planes * 2, num_blocks[1], stride=2, block_id=2)
+        self.layer3 = self._make_layer(block, in_planes * 4, num_blocks[2], stride=2, block_id=3)
+        self.layer4 = self._make_layer(block, in_planes * 8, num_blocks[3], stride=2, block_id=4)
+
+    def _make_layer(self, block, planes, num_blocks, stride, block_id=1):
+        strides = [stride] + [1] * (num_blocks - 1)
+        layers = []
+        for stride in strides:
+            layers.append(block(self.ConvLayer, self.NormLayer, self.in_planes, planes, stride, block_id))
+            self.in_planes = planes * block.expansion
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.relu(self.bn1(self.conv1(x)))
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        return x
+
+
+def ResNet293(in_planes: int, **kwargs):
+    return ResNet(in_planes, SimAMBasicBlock, [10, 20, 64, 3], **kwargs)
+
+
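+# Editor's note: ResNet293 is the SimAM-ResNet trunk with [10, 20, 64, 3] basic blocks
+# per stage. With the default 2-D layers, an input of shape [B, 1, F, T] (e.g. F = 80
+# mel bins) comes out as [B, in_planes * 8, F / 8, T / 8]: only layer2-layer4 downsample
+# (stride 2) and SimAMBasicBlock has expansion 1.
+
+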
+class ResNet293_based(nn.Module):
+    def __init__(
+        self,
+        in_planes: int = 64,
+        embd_dim: int = 256,
+        acoustic_dim: int = 80,
+        featCal=None,
+        dropout: float = 0,
+        **kwargs,
+    ):
+        super(ResNet293_based, self).__init__()
+        self.featCal = featCal
+        self.front = ResNet293(in_planes)
+        block_expansion = SimAMBasicBlock.expansion
+        self.pooling = ASP(in_planes * block_expansion, acoustic_dim)
+        self.bottleneck = nn.Linear(self.pooling.out_dim, embd_dim)
+        self.drop = nn.Dropout(dropout) if dropout else None
+
+    def forward(self, x):
+        x = self.featCal(x)
+        x = self.front(x.unsqueeze(dim=1))
+        x = self.pooling(x)
+        if self.drop:
+            x = self.drop(x)
+        x = self.bottleneck(x)
+        return x
+
+
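+# Editor's note: rough shape walk-through with the defaults, assuming featCal produces
+# 80-dim log-mel features of shape [B, 80, T]: unsqueeze -> [B, 1, 80, T] -> ResNet293
+# -> [B, 512, 10, T / 8] -> ASP pooling -> [B, pooling.out_dim] -> bottleneck Linear ->
+# [B, 256].
+
+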
+class SEModule(nn.Module):
+    def __init__(self, channels, bottleneck=128):
+        super(SEModule, self).__init__()
+        self.se = nn.Sequential(
+            nn.AdaptiveAvgPool1d(1),
+            nn.Conv1d(channels, bottleneck, kernel_size=1, padding=0),
+            nn.ReLU(),
+            # nn.BatchNorm1d(bottleneck),  # Removed
+            nn.Conv1d(bottleneck, channels, kernel_size=1, padding=0),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, input):
+        x = self.se(input)
+        return input * x
+
+
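+# Editor's note: SEModule is a standard squeeze-and-excitation gate over 1-D feature
+# maps: global average pooling over time, a small bottleneck (128 channels by default),
+# and a per-channel sigmoid that rescales the input.
+
+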
+class Bottle2neck(nn.Module):
+    def __init__(self, inplanes, planes, kernel_size=None, dilation=None, scale=8):
+        super(Bottle2neck, self).__init__()
+        width = int(math.floor(planes / scale))
+        self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
+        self.bn1 = nn.BatchNorm1d(width * scale)
+        self.nums = scale - 1
+        convs = []
+        bns = []
+        num_pad = math.floor(kernel_size / 2) * dilation
+        for i in range(self.nums):
+            convs.append(nn.Conv1d(width, width, kernel_size=kernel_size, dilation=dilation, padding=num_pad))
+            bns.append(nn.BatchNorm1d(width))
+        self.convs = nn.ModuleList(convs)
+        self.bns = nn.ModuleList(bns)
+        self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
+        self.bn3 = nn.BatchNorm1d(planes)
+        self.relu = nn.ReLU()
+        self.width = width
+        self.se = SEModule(planes)
+
+    def forward(self, x):
+        residual = x
+        out = self.conv1(x)
+        out = self.relu(out)
+        out = self.bn1(out)
+
+        spx = torch.split(out, self.width, 1)
+        for i in range(self.nums):
+            if i == 0:
+                sp = spx[i]
+            else:
+                sp = sp + spx[i]
+            sp = self.convs[i](sp)
+            sp = self.relu(sp)
+            sp = self.bns[i](sp)
+            if i == 0:
+                out = sp
+            else:
+                out = torch.cat((out, sp), 1)
+        out = torch.cat((out, spx[self.nums]), 1)
+
+        out = self.conv3(out)
+        out = self.relu(out)
+        out = self.bn3(out)
+
+        out = self.se(out)
+        out += residual
+        return out
+
+
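+# Editor's note: Bottle2neck is the Res2Net-style block used by ECAPA-TDNN: the
+# 1x1-expanded features are split into `scale` groups, each group after the first is
+# summed with the previous group's output before its dilated convolution, the last group
+# is concatenated untouched, and the SE module recalibrates channels before the residual
+# addition.
+
+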
+class ECAPA_TDNN(nn.Module):
+    def __init__(self, C, featCal):
+        super(ECAPA_TDNN, self).__init__()
+        self.featCal = featCal
+        self.conv1 = nn.Conv1d(80, C, kernel_size=5, stride=1, padding=2)
+        self.relu = nn.ReLU()
+        self.bn1 = nn.BatchNorm1d(C)
+        self.layer1 = Bottle2neck(C, C, kernel_size=3, dilation=2, scale=8)
+        self.layer2 = Bottle2neck(C, C, kernel_size=3, dilation=3, scale=8)
+        self.layer3 = Bottle2neck(C, C, kernel_size=3, dilation=4, scale=8)
+        # The output width of the MFA (multi-layer feature aggregation) layer is fixed
+        # at 1536, close to the setting in the ECAPA-TDNN paper.
+        self.layer4 = nn.Conv1d(3 * C, 1536, kernel_size=1)
+        self.attention = nn.Sequential(
+            nn.Conv1d(4608, 256, kernel_size=1),
+            nn.ReLU(),
+            nn.BatchNorm1d(256),
+            nn.Tanh(),  # Added
+            nn.Conv1d(256, 1536, kernel_size=1),
+            nn.Softmax(dim=2),
+        )
+        self.bn5 = nn.BatchNorm1d(3072)
+        self.fc6 = nn.Linear(3072, 192)
+        self.bn6 = nn.BatchNorm1d(192)
+
+    def forward(self, x):
+        x = self.featCal(x)
+        x = self.conv1(x)
+        x = self.relu(x)
+        x = self.bn1(x)
+
+        x1 = self.layer1(x)
+        x2 = self.layer2(x + x1)
+        x3 = self.layer3(x + x1 + x2)
+
+        x = self.layer4(torch.cat((x1, x2, x3), dim=1))
+        x = self.relu(x)
+
+        t = x.size()[-1]
+
+        global_x = torch.cat(
+            (
+                x,
+                torch.mean(x, dim=2, keepdim=True).repeat(1, 1, t),
+                torch.sqrt(torch.var(x, dim=2, keepdim=True).clamp(min=1e-4)).repeat(1, 1, t),
+            ),
+            dim=1,
+        )
+
+        w = self.attention(global_x)
+
+        mu = torch.sum(x * w, dim=2)
+        sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-4))
+
+        x = torch.cat((mu, sg), 1)
+        x = self.bn5(x)
+        x = self.fc6(x)
+        x = self.bn6(x)
+
+        return x
+
+
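+# Editor's note: the pooling in ECAPA_TDNN.forward is attentive statistics pooling with
+# global context: attention weights over time are computed from the frame features
+# concatenated with their global mean and std (hence 3 * 1536 = 4608 input channels),
+# and the attention-weighted mean and std are concatenated into a 3072-dim vector before
+# the final 192-dim embedding.
+
+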
+class SpeakerEmbedding(nn.Module):
+    def __init__(self, ckpt_path: str = "ResNet293_SimAM_ASP_base.pt", device: str = "cuda"):
+        super().__init__()
+        self.device = device
+        with torch.device(device):
+            self.model = ResNet293_based()
+            self.model.load_state_dict(torch.load(ckpt_path, weights_only=True, mmap=True))
+            self.model.featCal = logFbankCal()
+
+        self.requires_grad_(False).eval()
+
+    @property
+    def dtype(self):
+        return next(self.parameters()).dtype
+
+    @cache
+    def _get_resampler(self, orig_sample_rate: int):
+        return torchaudio.transforms.Resample(orig_sample_rate, 16_000).to(self.device)
+
+    def prepare_input(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
+        assert wav.ndim < 3
+        if wav.ndim == 2:
+            wav = wav.mean(0, keepdim=True)
+        wav = self._get_resampler(sample_rate)(wav)
+        return wav
+
+    def forward(self, wav: torch.Tensor, sample_rate: int):
+        wav = self.prepare_input(wav, sample_rate).to(self.device, self.dtype)
+        return self.model(wav).to(wav.device)
+
+
+class SpeakerEmbeddingLDA(nn.Module):
+    def __init__(
+        self,
+        device: str = "cuda",
+    ):
+        super().__init__()
+        spk_model_path = hf_hub_download(repo_id="Zyphra/Zonos-v0.1-speaker-embedding", filename="ResNet293_SimAM_ASP_base.pt")
+        lda_spk_model_path = hf_hub_download(repo_id="Zyphra/Zonos-v0.1-speaker-embedding", filename="ResNet293_SimAM_ASP_base_LDA-128.pt")
+
+        self.device = device
+        with torch.device(device):
+            self.model = SpeakerEmbedding(spk_model_path, device)
+            lda_sd = torch.load(lda_spk_model_path, weights_only=True)
+            out_features, in_features = lda_sd["weight"].shape
+            self.lda = nn.Linear(in_features, out_features, bias=True, dtype=torch.float32)
+            self.lda.load_state_dict(lda_sd)
+
+        self.requires_grad_(False).eval()
+
+    def forward(self, wav: torch.Tensor, sample_rate: int):
+        emb = self.model(wav, sample_rate).to(torch.float32)
+        return emb, self.lda(emb)
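+
+
+# Editor's note: minimal, commented-out usage sketch for the two classes above
+# ("reference.wav" is a hypothetical path; a CUDA device is assumed, matching the
+# default device="cuda"):
+#
+#   wav, sr = torchaudio.load("reference.wav")   # [channels, samples]
+#   embedder = SpeakerEmbeddingLDA()              # downloads both checkpoints from the Hub
+#   emb, emb_lda = embedder(wav.to("cuda"), sr)   # base embedding and its LDA projection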