Steveeeeeeen (HF staff) committed
Commit d743fc1 · verified · 1 Parent(s): 0739217

Update app.py

Files changed (1)
  1. app.py +60 -392
app.py CHANGED
@@ -3,404 +3,72 @@ import torchaudio
  import gradio as gr

  from zonos.model import Zonos
- from zonos.conditioning import make_cond_dict, supported_language_codes

- device = "cuda"
- CURRENT_MODEL_TYPE = None
- CURRENT_MODEL = None

-
- def load_model_if_needed(model_choice: str):
-     global CURRENT_MODEL_TYPE, CURRENT_MODEL
-     if CURRENT_MODEL_TYPE != model_choice:
-         if CURRENT_MODEL is not None:
-             del CURRENT_MODEL
-             torch.cuda.empty_cache()
-         print(f"Loading {model_choice} model...")
-         if model_choice == "Transformer":
-             CURRENT_MODEL = Zonos.from_pretrained("Zyphra/Zonos-v0.1-transformer", device=device)
-         else:
-             CURRENT_MODEL = Zonos.from_pretrained("Zyphra/Zonos-v0.1-hybrid", device=device)
-         CURRENT_MODEL.to(device)
-         CURRENT_MODEL.bfloat16()
-         CURRENT_MODEL.eval()
-         CURRENT_MODEL_TYPE = model_choice
-         print(f"{model_choice} model loaded successfully!")
-     else:
-         print(f"{model_choice} model is already loaded.")
-     return CURRENT_MODEL
-
-
- def update_ui(model_choice):
-     """
-     Dynamically show/hide UI elements based on the model's conditioners.
-     We do NOT display 'language_id' or 'ctc_loss' even if they exist in the model.
-     """
-     model = load_model_if_needed(model_choice)
-     cond_names = [c.name for c in model.prefix_conditioner.conditioners]
-     print("Conditioners in this model:", cond_names)
-
-     text_update = gr.update(visible=("espeak" in cond_names))
-     language_update = gr.update(visible=("espeak" in cond_names))
-     speaker_audio_update = gr.update(visible=("speaker" in cond_names))
-     prefix_audio_update = gr.update(visible=True)
-     skip_speaker_update = gr.update(visible=("speaker" in cond_names))
-     skip_emotion_update = gr.update(visible=("emotion" in cond_names))
-     emotion1_update = gr.update(visible=("emotion" in cond_names))
-     emotion2_update = gr.update(visible=("emotion" in cond_names))
-     emotion3_update = gr.update(visible=("emotion" in cond_names))
-     emotion4_update = gr.update(visible=("emotion" in cond_names))
-     emotion5_update = gr.update(visible=("emotion" in cond_names))
-     emotion6_update = gr.update(visible=("emotion" in cond_names))
-     emotion7_update = gr.update(visible=("emotion" in cond_names))
-     emotion8_update = gr.update(visible=("emotion" in cond_names))
-     skip_vqscore_8_update = gr.update(visible=("vqscore_8" in cond_names))
-     vq_single_slider_update = gr.update(visible=("vqscore_8" in cond_names))
-     fmax_slider_update = gr.update(visible=("fmax" in cond_names))
-     skip_fmax_update = gr.update(visible=("fmax" in cond_names))
-     pitch_std_slider_update = gr.update(visible=("pitch_std" in cond_names))
-     skip_pitch_std_update = gr.update(visible=("pitch_std" in cond_names))
-     speaking_rate_slider_update = gr.update(visible=("speaking_rate" in cond_names))
-     skip_speaking_rate_update = gr.update(visible=("speaking_rate" in cond_names))
-     dnsmos_slider_update = gr.update(visible=("dnsmos_ovrl" in cond_names))
-     skip_dnsmos_ovrl_update = gr.update(visible=("dnsmos_ovrl" in cond_names))
-     speaker_noised_checkbox_update = gr.update(visible=("speaker_noised" in cond_names))
-     skip_speaker_noised_update = gr.update(visible=("speaker_noised" in cond_names))
-
-     return (
-         text_update, # 1
-         language_update, # 2
-         speaker_audio_update, # 3
-         prefix_audio_update, # 4
-         skip_speaker_update, # 5
-         skip_emotion_update, # 6
-         emotion1_update, # 7
-         emotion2_update, # 8
-         emotion3_update, # 9
-         emotion4_update, # 10
-         emotion5_update, # 11
-         emotion6_update, # 12
-         emotion7_update, # 13
-         emotion8_update, # 14
-         skip_vqscore_8_update, # 15
-         vq_single_slider_update, # 16
-         fmax_slider_update, # 17
-         skip_fmax_update, # 18
-         pitch_std_slider_update, # 19
-         skip_pitch_std_update, # 20
-         speaking_rate_slider_update, # 21
-         skip_speaking_rate_update, # 22
-         dnsmos_slider_update, # 23
-         skip_dnsmos_ovrl_update, # 24
-         speaker_noised_checkbox_update, # 25
-         skip_speaker_noised_update, # 26
-     )
-
-
- def generate_audio(
-     model_choice,
-     text,
-     language,
-     speaker_audio,
-     prefix_audio,
-     skip_speaker,
-     skip_emotion,
-     e1,
-     e2,
-     e3,
-     e4,
-     e5,
-     e6,
-     e7,
-     e8,
-     skip_vqscore_8,
-     vq_single,
-     fmax,
-     skip_fmax,
-     pitch_std,
-     skip_pitch_std,
-     speaking_rate,
-     skip_speaking_rate,
-     dnsmos_ovrl,
-     skip_dnsmos_ovrl,
-     speaker_noised,
-     skip_speaker_noised,
-     cfg_scale,
-     min_p,
-     seed,
- ):
      """
-     Generates audio based on the provided UI parameters.
-     We do NOT use language_id or ctc_loss even if the model has them.
      """
-     selected_model = load_model_if_needed(model_choice)
-
-     uncond_keys = []
-     if skip_speaker:
-         uncond_keys.append("speaker")
-     if skip_emotion:
-         uncond_keys.append("emotion")
-     if skip_vqscore_8:
-         uncond_keys.append("vqscore_8")
-     if skip_fmax:
-         uncond_keys.append("fmax")
-     if skip_pitch_std:
-         uncond_keys.append("pitch_std")
-     if skip_speaking_rate:
-         uncond_keys.append("speaking_rate")
-     if skip_dnsmos_ovrl:
-         uncond_keys.append("dnsmos_ovrl")
-     if skip_speaker_noised:
-         uncond_keys.append("speaker_noised")
-
-     speaker_noised_bool = bool(speaker_noised)
-     fmax = float(fmax)
-     pitch_std = float(pitch_std)
-     speaking_rate = float(speaking_rate)
-     dnsmos_ovrl = float(dnsmos_ovrl)
-     cfg_scale = float(cfg_scale)
-     min_p = float(min_p)
-     seed = int(seed)
-     max_new_tokens = 86 * 30
-
-     torch.manual_seed(seed)
-
-     speaker_embedding = None
-     if speaker_audio is not None and not skip_speaker:
-         wav, sr = torchaudio.load(speaker_audio)
-         speaker_embedding = selected_model.make_speaker_embedding(wav, sr)
-         speaker_embedding = speaker_embedding.to(device, dtype=torch.bfloat16)
-
-     audio_prefix_codes = None
-     if prefix_audio is not None:
-         wav_prefix, sr_prefix = torchaudio.load(prefix_audio)
-         wav_prefix = wav_prefix.mean(0, keepdim=True)
-         wav_prefix = torchaudio.functional.resample(wav_prefix, sr_prefix, selected_model.autoencoder.sampling_rate)
-         wav_prefix = wav_prefix.to(device, dtype=torch.float32)
-         with torch.autocast(device, dtype=torch.float32):
-             audio_prefix_codes = selected_model.autoencoder.encode(wav_prefix.unsqueeze(0))
-
-     emotion_tensor = torch.tensor(
-         [[float(e1), float(e2), float(e3), float(e4), float(e5), float(e6), float(e7), float(e8)]], device=device
-     )
-
-     vq_val = float(vq_single)
-     vq_tensor = torch.tensor([vq_val] * 8, device=device).unsqueeze(0)
-
      cond_dict = make_cond_dict(
          text=text,
-         language=language,
-         speaker=speaker_embedding,
-         emotion=emotion_tensor,
-         vqscore_8=vq_tensor,
-         fmax=fmax,
-         pitch_std=pitch_std,
-         speaking_rate=speaking_rate,
-         dnsmos_ovrl=dnsmos_ovrl,
-         speaker_noised=speaker_noised_bool,
-         device=device,
-         unconditional_keys=uncond_keys,
-     )
-     conditioning = selected_model.prepare_conditioning(cond_dict)
-
-     codes = selected_model.generate(
-         prefix_conditioning=conditioning,
-         audio_prefix_codes=audio_prefix_codes,
-         max_new_tokens=max_new_tokens,
-         cfg_scale=cfg_scale,
-         batch_size=1,
-         sampling_params=dict(min_p=min_p),
      )
-
-     wav_out = selected_model.autoencoder.decode(codes).cpu().detach()
-     sr_out = selected_model.autoencoder.sampling_rate
-     if wav_out.dim() == 2 and wav_out.size(0) > 1:
-         wav_out = wav_out[0:1, :]
-     return sr_out, wav_out.squeeze().numpy()
-
-
- def build_interface():
-     with gr.Blocks() as demo:
-         with gr.Row():
-             with gr.Column():
-                 model_choice = gr.Dropdown(
-                     choices=["Hybrid", "Transformer"],
-                     value="Transformer",
-                     label="Zonos Model Type",
-                     info="Select the model variant to use.",
-                 )
-                 text = gr.Textbox(
-                     label="Text to Synthesize", value="Zonos uses eSpeak for text to phoneme conversion!", lines=4
-                 )
-                 language = gr.Dropdown(
-                     choices=supported_language_codes,
-                     value="en-us",
-                     label="Language Code",
-                     info="Select a language code.",
-                 )
-                 prefix_audio = gr.Audio(
-                     value="assets/silence_100ms.wav",
-                     label="Optional Prefix Audio (continue from this audio)",
-                     type="filepath",
-                 )
-             with gr.Column():
-                 speaker_audio = gr.Audio(
-                     label="Optional Speaker Audio (for cloning)",
-                     type="filepath",
-                 )
-                 speaker_noised_checkbox = gr.Checkbox(label="Denoise Speaker?", value=False)
-
-         with gr.Column():
-             gr.Markdown("## Conditioning Parameters")
-
-             with gr.Row():
-                 dnsmos_slider = gr.Slider(1.0, 5.0, value=4.0, step=0.1, label="DNSMOS Overall")
-                 fmax_slider = gr.Slider(0, 24000, value=22050, step=1, label="Fmax (Hz)")
-                 vq_single_slider = gr.Slider(0.5, 0.8, 0.78, 0.01, label="VQ Score")
-                 pitch_std_slider = gr.Slider(0.0, 400.0, value=20.0, step=1, label="Pitch Std")
-                 speaking_rate_slider = gr.Slider(0.0, 40.0, value=15.0, step=1, label="Speaking Rate")
-
-             gr.Markdown("### Emotion Sliders")
-             with gr.Row():
-                 emotion1 = gr.Slider(0.0, 1.0, 0.6, 0.05, label="Happiness")
-                 emotion2 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Sadness")
-                 emotion3 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Disgust")
-                 emotion4 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Fear")
-             with gr.Row():
-                 emotion5 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Surprise")
-                 emotion6 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Anger")
-                 emotion7 = gr.Slider(0.0, 1.0, 0.5, 0.05, label="Other")
-                 emotion8 = gr.Slider(0.0, 1.0, 0.6, 0.05, label="Neutral")
-
-             gr.Markdown("### Unconditional Toggles")
-             with gr.Row():
-                 skip_speaker = gr.Checkbox(label="Skip Speaker", value=False)
-                 skip_emotion = gr.Checkbox(label="Skip Emotion", value=False)
-                 skip_vqscore_8 = gr.Checkbox(label="Skip VQ Score", value=True)
-                 skip_fmax = gr.Checkbox(label="Skip Fmax", value=False)
-                 skip_pitch_std = gr.Checkbox(label="Skip Pitch Std", value=False)
-                 skip_speaking_rate = gr.Checkbox(label="Skip Speaking Rate", value=False)
-                 skip_dnsmos_ovrl = gr.Checkbox(label="Skip DNSMOS", value=True)
-                 skip_speaker_noised = gr.Checkbox(label="Skip Noised Speaker", value=False)
-
-         with gr.Column():
-             gr.Markdown("## Generation Parameters")
-             with gr.Row():
-                 cfg_scale_slider = gr.Slider(1.0, 5.0, 2.0, 0.1, label="CFG Scale")
-                 min_p_slider = gr.Slider(0.0, 1.0, 0.1, 0.01, label="Min P")
-                 seed_number = gr.Number(label="Seed", value=420, precision=0)
-
-         generate_button = gr.Button("Generate Audio")
-         output_audio = gr.Audio(label="Generated Audio", type="numpy")
-
-         model_choice.change(
-             fn=update_ui,
-             inputs=[model_choice],
-             outputs=[
-                 text, # 1
-                 language, # 2
-                 speaker_audio, # 3
-                 prefix_audio, # 4
-                 skip_speaker, # 5
-                 skip_emotion, # 6
-                 emotion1, # 7
-                 emotion2, # 8
-                 emotion3, # 9
-                 emotion4, # 10
-                 emotion5, # 11
-                 emotion6, # 12
-                 emotion7, # 13
-                 emotion8, # 14
-                 skip_vqscore_8, # 15
-                 vq_single_slider, # 16
-                 fmax_slider, # 17
-                 skip_fmax, # 18
-                 pitch_std_slider, # 19
-                 skip_pitch_std, # 20
-                 speaking_rate_slider, # 21
-                 skip_speaking_rate, # 22
-                 dnsmos_slider, # 23
-                 skip_dnsmos_ovrl, # 24
-                 speaker_noised_checkbox, # 25
-                 skip_speaker_noised, # 26
-             ],
-         )
-
-         # On page load, trigger the same UI refresh
-         demo.load(
-             fn=update_ui,
-             inputs=[model_choice],
-             outputs=[
-                 text,
-                 language,
-                 speaker_audio,
-                 prefix_audio,
-                 skip_speaker,
-                 skip_emotion,
-                 emotion1,
-                 emotion2,
-                 emotion3,
-                 emotion4,
-                 emotion5,
-                 emotion6,
-                 emotion7,
-                 emotion8,
-                 skip_vqscore_8,
-                 vq_single_slider,
-                 fmax_slider,
-                 skip_fmax,
-                 pitch_std_slider,
-                 skip_pitch_std,
-                 speaking_rate_slider,
-                 skip_speaking_rate,
-                 dnsmos_slider,
-                 skip_dnsmos_ovrl,
-                 speaker_noised_checkbox,
-                 skip_speaker_noised,
-             ],
-         )
-
-         # Generate audio on button click
-         generate_button.click(
-             fn=generate_audio,
-             inputs=[
-                 model_choice,
-                 text,
-                 language,
-                 speaker_audio,
-                 prefix_audio,
-                 skip_speaker,
-                 skip_emotion,
-                 emotion1,
-                 emotion2,
-                 emotion3,
-                 emotion4,
-                 emotion5,
-                 emotion6,
-                 emotion7,
-                 emotion8,
-                 skip_vqscore_8,
-                 vq_single_slider,
-                 fmax_slider,
-                 skip_fmax,
-                 pitch_std_slider,
-                 skip_pitch_std,
-                 speaking_rate_slider,
-                 skip_speaking_rate,
-                 dnsmos_slider,
-                 skip_dnsmos_ovrl,
-                 speaker_noised_checkbox,
-                 skip_speaker_noised,
-                 cfg_scale_slider,
-                 min_p_slider,
-                 seed_number,
-             ],
-             outputs=[output_audio],
-         )
-
-     return demo
-

  if __name__ == "__main__":
-     demo = build_interface()
-     demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
  import gradio as gr

  from zonos.model import Zonos
+ from zonos.conditioning import make_cond_dict

+ # Load the hybrid model
+ model = Zonos.from_pretrained("Zyphra/Zonos-v0.1-hybrid", device="cuda")
+ model.bfloat16() # Switch model weights to bfloat16 precision (optional, but recommended for GPU)

+ # Main inference function for Gradio
+ def tts(text, reference_audio):
      """
+     text: str
+     reference_audio: (int, numpy.ndarray) tuple of (sample_rate, data), as provided by gr.Audio(type="numpy")
+     """
+     if reference_audio is None:
+         return "No reference audio provided."
+
+     # reference_audio[0] is the sample rate
+     # reference_audio[1] is a NumPy array of shape (num_samples,) or (num_samples, 1)
+     sr, wav_np = reference_audio
+
+     # Convert NumPy audio to Torch tensor
+     wav_torch = torch.from_numpy(wav_np).float().unsqueeze(0) # shape: (1, num_samples)
+     if wav_torch.dim() == 2 and wav_torch.shape[0] > wav_torch.shape[1]:
+         # If the shape is (samples, 1), reorder to (1, samples)
+         wav_torch = wav_torch.T
+
+     # Create speaker embedding
+     spk_embedding = model.make_speaker_embedding(wav_torch, sr)
+
+     # Prepare conditioning
      cond_dict = make_cond_dict(
          text=text,
+         speaker=spk_embedding.to(torch.bfloat16),
+         language="en-us",
      )
+     conditioning = model.prepare_conditioning(cond_dict)
+
+     # Generate codes
+     with torch.no_grad():
+         torch.manual_seed(421) # Seeding for reproducible results
+         codes = model.generate(conditioning)
+
+     # Decode the codes into waveform
+     wavs = model.autoencoder.decode(codes).cpu()
+     out_audio = wavs[0].numpy() # shape: (num_samples,)
+
+     # Return as (sample_rate, audio_ndarray) for Gradio's "audio" output
+     return (model.autoencoder.sampling_rate, out_audio)
+
+
+ # Define the Gradio interface
+ # - text input for the prompt
+ # - audio input for the speaker reference
+ # - audio output with the generated speech
+ demo = gr.Interface(
+     fn=tts,
+     inputs=[
+         gr.Textbox(label="Text to Synthesize"),
+         gr.Audio(source="upload", type="numpy", label="Reference Audio (for speaker embedding)"),
+     ],
+     outputs=gr.Audio(label="Generated Audio"),
+     title="Zonos TTS Demo (Hybrid)",
+     description=(
+         "Provide a reference audio snippet for speaker embedding, "
+         "enter text, and generate speech with Zonos TTS."
+     ),
+ )

  if __name__ == "__main__":
+     demo.launch(debug=True)
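
For local testing, here is a minimal sketch of calling the new tts() function directly, without the Gradio UI. It is an illustrative example rather than part of the commit: it assumes the same environment as app.py (the zonos package installed and a CUDA device available) and uses placeholder paths "reference.wav" and "output.wav".

import torch
import torchaudio

from app import tts  # importing app.py loads the hybrid model and builds the interface

# Load a short speaker sample and collapse it to a mono float waveform.
wav, sr = torchaudio.load("reference.wav")  # wav: (channels, num_samples), float32 in [-1, 1]
mono = wav.mean(0).numpy()

# tts expects the Gradio-style (sample_rate, data) tuple and returns one as well.
out_sr, out_audio = tts("Hello from Zonos!", (sr, mono))

# Save the generated waveform as a (channels, num_samples) tensor.
torchaudio.save("output.wav", torch.as_tensor(out_audio).reshape(1, -1), out_sr)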