Steveeeeeeen (HF staff) committed on
Commit 748ecaa · verified · 1 parent: 89d56a3

Create app.py

Files changed (1)
  app.py  +406  -0
app.py ADDED
@@ -0,0 +1,406 @@
import torch
import torchaudio
import gradio as gr

from zonos.model import Zonos
from zonos.conditioning import make_cond_dict, supported_language_codes

device = "cuda"
CURRENT_MODEL_TYPE = None
CURRENT_MODEL = None

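# Cache a single model instance in module globals; switching variants frees the
# previous model and empties the CUDA cache so only one checkpoint stays on the GPU.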
def load_model_if_needed(model_choice: str):
    global CURRENT_MODEL_TYPE, CURRENT_MODEL
    if CURRENT_MODEL_TYPE != model_choice:
        if CURRENT_MODEL is not None:
            del CURRENT_MODEL
            torch.cuda.empty_cache()
        print(f"Loading {model_choice} model...")
        if model_choice == "Transformer":
            CURRENT_MODEL = Zonos.from_pretrained("Zyphra/Zonos-v0.1-transformer", device=device)
        else:
            CURRENT_MODEL = Zonos.from_pretrained("Zyphra/Zonos-v0.1-hybrid", device=device)
        CURRENT_MODEL.to(device)
        CURRENT_MODEL.bfloat16()
        CURRENT_MODEL.eval()
        CURRENT_MODEL_TYPE = model_choice
        print(f"{model_choice} model loaded successfully!")
    else:
        print(f"{model_choice} model is already loaded.")
    return CURRENT_MODEL

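# NOTE: update_ui returns a 26-tuple; its order must match the `outputs` lists
# wired up in build_interface() below.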
def update_ui(model_choice):
    """
    Dynamically show/hide UI elements based on the model's conditioners.
    We do NOT display 'language_id' or 'ctc_loss' even if they exist in the model.
    """
    model = load_model_if_needed(model_choice)
    cond_names = [c.name for c in model.prefix_conditioner.conditioners]
    print("Conditioners in this model:", cond_names)

    text_update = gr.update(visible=("espeak" in cond_names))
    language_update = gr.update(visible=("espeak" in cond_names))
    speaker_audio_update = gr.update(visible=("speaker" in cond_names))
    prefix_audio_update = gr.update(visible=True)
    skip_speaker_update = gr.update(visible=("speaker" in cond_names))
    skip_emotion_update = gr.update(visible=("emotion" in cond_names))
    emotion1_update = gr.update(visible=("emotion" in cond_names))
    emotion2_update = gr.update(visible=("emotion" in cond_names))
    emotion3_update = gr.update(visible=("emotion" in cond_names))
    emotion4_update = gr.update(visible=("emotion" in cond_names))
    emotion5_update = gr.update(visible=("emotion" in cond_names))
    emotion6_update = gr.update(visible=("emotion" in cond_names))
    emotion7_update = gr.update(visible=("emotion" in cond_names))
    emotion8_update = gr.update(visible=("emotion" in cond_names))
    skip_vqscore_8_update = gr.update(visible=("vqscore_8" in cond_names))
    vq_single_slider_update = gr.update(visible=("vqscore_8" in cond_names))
    fmax_slider_update = gr.update(visible=("fmax" in cond_names))
    skip_fmax_update = gr.update(visible=("fmax" in cond_names))
    pitch_std_slider_update = gr.update(visible=("pitch_std" in cond_names))
    skip_pitch_std_update = gr.update(visible=("pitch_std" in cond_names))
    speaking_rate_slider_update = gr.update(visible=("speaking_rate" in cond_names))
    skip_speaking_rate_update = gr.update(visible=("speaking_rate" in cond_names))
    dnsmos_slider_update = gr.update(visible=("dnsmos_ovrl" in cond_names))
    skip_dnsmos_ovrl_update = gr.update(visible=("dnsmos_ovrl" in cond_names))
    speaker_noised_checkbox_update = gr.update(visible=("speaker_noised" in cond_names))
    skip_speaker_noised_update = gr.update(visible=("speaker_noised" in cond_names))

    return (
        text_update,  # 1
        language_update,  # 2
        speaker_audio_update,  # 3
        prefix_audio_update,  # 4
        skip_speaker_update,  # 5
        skip_emotion_update,  # 6
        emotion1_update,  # 7
        emotion2_update,  # 8
        emotion3_update,  # 9
        emotion4_update,  # 10
        emotion5_update,  # 11
        emotion6_update,  # 12
        emotion7_update,  # 13
        emotion8_update,  # 14
        skip_vqscore_8_update,  # 15
        vq_single_slider_update,  # 16
        fmax_slider_update,  # 17
        skip_fmax_update,  # 18
        pitch_std_slider_update,  # 19
        skip_pitch_std_update,  # 20
        speaking_rate_slider_update,  # 21
        skip_speaking_rate_update,  # 22
        dnsmos_slider_update,  # 23
        skip_dnsmos_ovrl_update,  # 24
        speaker_noised_checkbox_update,  # 25
        skip_speaker_noised_update,  # 26
    )

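# Main synthesis callback; it receives every UI control value positionally, in
# the same order as the `inputs` list passed to generate_button.click().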
def generate_audio(
    model_choice,
    text,
    language,
    speaker_audio,
    prefix_audio,
    skip_speaker,
    skip_emotion,
    e1,
    e2,
    e3,
    e4,
    e5,
    e6,
    e7,
    e8,
    skip_vqscore_8,
    vq_single,
    fmax,
    skip_fmax,
    pitch_std,
    skip_pitch_std,
    speaking_rate,
    skip_speaking_rate,
    dnsmos_ovrl,
    skip_dnsmos_ovrl,
    speaker_noised,
    skip_speaker_noised,
    cfg_scale,
    min_p,
    seed,
):
    """
    Generates audio based on the provided UI parameters.
    We do NOT use language_id or ctc_loss even if the model has them.
    """
    selected_model = load_model_if_needed(model_choice)

    uncond_keys = []
    if skip_speaker:
        uncond_keys.append("speaker")
    if skip_emotion:
        uncond_keys.append("emotion")
    if skip_vqscore_8:
        uncond_keys.append("vqscore_8")
    if skip_fmax:
        uncond_keys.append("fmax")
    if skip_pitch_std:
        uncond_keys.append("pitch_std")
    if skip_speaking_rate:
        uncond_keys.append("speaking_rate")
    if skip_dnsmos_ovrl:
        uncond_keys.append("dnsmos_ovrl")
    if skip_speaker_noised:
        uncond_keys.append("speaker_noised")

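    # Gradio may hand back strings or Python numbers; coerce everything to the
    # concrete types the conditioning pipeline expects.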
    speaker_noised_bool = bool(speaker_noised)
    fmax = float(fmax)
    pitch_std = float(pitch_std)
    speaking_rate = float(speaking_rate)
    dnsmos_ovrl = float(dnsmos_ovrl)
    cfg_scale = float(cfg_scale)
    min_p = float(min_p)
    seed = int(seed)
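    # 86 appears to be the autoencoder's code frame rate in Hz, so this caps
    # generation at roughly 30 seconds of audio.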
    max_new_tokens = 86 * 30

    torch.manual_seed(seed)

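    # Optional voice cloning: embed the reference audio unless speaker
    # conditioning is being skipped.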
    speaker_embedding = None
    if speaker_audio is not None and not skip_speaker:
        wav, sr = torchaudio.load(speaker_audio)
        speaker_embedding = selected_model.make_speaker_embedding(wav, sr)
        speaker_embedding = speaker_embedding.to(device, dtype=torch.bfloat16)

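    # Optional audio prefix: downmix to mono, resample to the autoencoder's rate,
    # and encode to discrete codes so generation continues from this audio.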
    audio_prefix_codes = None
    if prefix_audio is not None:
        wav_prefix, sr_prefix = torchaudio.load(prefix_audio)
        wav_prefix = wav_prefix.mean(0, keepdim=True)
        wav_prefix = torchaudio.functional.resample(wav_prefix, sr_prefix, selected_model.autoencoder.sampling_rate)
        wav_prefix = wav_prefix.to(device, dtype=torch.float32)
        with torch.autocast(device, dtype=torch.float32):
            audio_prefix_codes = selected_model.autoencoder.encode(wav_prefix.unsqueeze(0))

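    # Pack the eight emotion sliders into a single [1, 8] tensor; the single VQ
    # slider value is broadcast to an eight-element vector for vqscore_8.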
    emotion_tensor = torch.tensor(
        [[float(e1), float(e2), float(e3), float(e4), float(e5), float(e6), float(e7), float(e8)]], device=device
    )

    vq_val = float(vq_single)
    vq_tensor = torch.tensor([vq_val] * 8, device=device).unsqueeze(0)

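    # Build the conditioning dictionary; any key listed in uncond_keys is treated
    # as unconditional rather than taken from the UI values above.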
    cond_dict = make_cond_dict(
        text=text,
        language=language,
        speaker=speaker_embedding,
        emotion=emotion_tensor,
        vqscore_8=vq_tensor,
        fmax=fmax,
        pitch_std=pitch_std,
        speaking_rate=speaking_rate,
        dnsmos_ovrl=dnsmos_ovrl,
        speaker_noised=speaker_noised_bool,
        device=device,
        unconditional_keys=uncond_keys,
    )
    conditioning = selected_model.prepare_conditioning(cond_dict)

    codes = selected_model.generate(
        prefix_conditioning=conditioning,
        audio_prefix_codes=audio_prefix_codes,
        max_new_tokens=max_new_tokens,
        cfg_scale=cfg_scale,
        batch_size=1,
        sampling_params=dict(min_p=min_p),
    )

    wav_out = selected_model.autoencoder.decode(codes).cpu().detach()
    sr_out = selected_model.autoencoder.sampling_rate
    if wav_out.dim() == 2 and wav_out.size(0) > 1:
        wav_out = wav_out[0:1, :]
    return sr_out, wav_out.squeeze().numpy()

def build_interface():
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                model_choice = gr.Dropdown(
                    choices=["Hybrid", "Transformer"],
                    value="Transformer",
                    label="Zonos Model Type",
                    info="Select the model variant to use.",
                )
                text = gr.Textbox(
                    label="Text to Synthesize", value="Zonos uses eSpeak for text to phoneme conversion!", lines=4
                )
                language = gr.Dropdown(
                    choices=supported_language_codes,
                    value="en-us",
                    label="Language Code",
                    info="Select a language code.",
                )
                prefix_audio = gr.Audio(
                    value="assets/silence_100ms.wav",
                    label="Optional Prefix Audio (continue from this audio)",
                    type="filepath",
                )
            with gr.Column():
                speaker_audio = gr.Audio(
                    label="Optional Speaker Audio (for cloning)",
                    type="filepath",
                )
                speaker_noised_checkbox = gr.Checkbox(label="Denoise Speaker?", value=False)

        with gr.Column():
            gr.Markdown("## Conditioning Parameters")

            with gr.Row():
                dnsmos_slider = gr.Slider(1.0, 5.0, value=4.0, step=0.1, label="DNSMOS Overall")
                fmax_slider = gr.Slider(0, 24000, value=22050, step=1, label="Fmax (Hz)")
                vq_single_slider = gr.Slider(0.5, 0.8, 0.78, 0.01, label="VQ Score")
                pitch_std_slider = gr.Slider(0.0, 400.0, value=20.0, step=1, label="Pitch Std")
                speaking_rate_slider = gr.Slider(0.0, 40.0, value=15.0, step=1, label="Speaking Rate")

            gr.Markdown("### Emotion Sliders")
            with gr.Row():
                emotion1 = gr.Slider(0.0, 1.0, 0.6, 0.05, label="Happiness")
                emotion2 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Sadness")
                emotion3 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Disgust")
                emotion4 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Fear")
            with gr.Row():
                emotion5 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Surprise")
                emotion6 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Anger")
                emotion7 = gr.Slider(0.0, 1.0, 0.5, 0.05, label="Other")
                emotion8 = gr.Slider(0.0, 1.0, 0.6, 0.05, label="Neutral")

            gr.Markdown("### Unconditional Toggles")
            with gr.Row():
                skip_speaker = gr.Checkbox(label="Skip Speaker", value=False)
                skip_emotion = gr.Checkbox(label="Skip Emotion", value=False)
                skip_vqscore_8 = gr.Checkbox(label="Skip VQ Score", value=True)
                skip_fmax = gr.Checkbox(label="Skip Fmax", value=False)
                skip_pitch_std = gr.Checkbox(label="Skip Pitch Std", value=False)
                skip_speaking_rate = gr.Checkbox(label="Skip Speaking Rate", value=False)
                skip_dnsmos_ovrl = gr.Checkbox(label="Skip DNSMOS", value=True)
                skip_speaker_noised = gr.Checkbox(label="Skip Noised Speaker", value=False)

        with gr.Column():
            gr.Markdown("## Generation Parameters")
            with gr.Row():
                cfg_scale_slider = gr.Slider(1.0, 5.0, 2.0, 0.1, label="CFG Scale")
                min_p_slider = gr.Slider(0.0, 1.0, 0.1, 0.01, label="Min P")
                seed_number = gr.Number(label="Seed", value=420, precision=0)

        generate_button = gr.Button("Generate Audio")
        output_audio = gr.Audio(label="Generated Audio", type="numpy")

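        # Re-evaluate which controls are visible whenever the model variant
        # changes; the outputs order mirrors update_ui's return tuple.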
        model_choice.change(
            fn=update_ui,
            inputs=[model_choice],
            outputs=[
                text,  # 1
                language,  # 2
                speaker_audio,  # 3
                prefix_audio,  # 4
                skip_speaker,  # 5
                skip_emotion,  # 6
                emotion1,  # 7
                emotion2,  # 8
                emotion3,  # 9
                emotion4,  # 10
                emotion5,  # 11
                emotion6,  # 12
                emotion7,  # 13
                emotion8,  # 14
                skip_vqscore_8,  # 15
                vq_single_slider,  # 16
                fmax_slider,  # 17
                skip_fmax,  # 18
                pitch_std_slider,  # 19
                skip_pitch_std,  # 20
                speaking_rate_slider,  # 21
                skip_speaking_rate,  # 22
                dnsmos_slider,  # 23
                skip_dnsmos_ovrl,  # 24
                speaker_noised_checkbox,  # 25
                skip_speaker_noised,  # 26
            ],
        )

        # On page load, trigger the same UI refresh
        demo.load(
            fn=update_ui,
            inputs=[model_choice],
            outputs=[
                text,
                language,
                speaker_audio,
                prefix_audio,
                skip_speaker,
                skip_emotion,
                emotion1,
                emotion2,
                emotion3,
                emotion4,
                emotion5,
                emotion6,
                emotion7,
                emotion8,
                skip_vqscore_8,
                vq_single_slider,
                fmax_slider,
                skip_fmax,
                pitch_std_slider,
                skip_pitch_std,
                speaking_rate_slider,
                skip_speaking_rate,
                dnsmos_slider,
                skip_dnsmos_ovrl,
                speaker_noised_checkbox,
                skip_speaker_noised,
            ],
        )

        # Generate audio on button click
        generate_button.click(
            fn=generate_audio,
            inputs=[
                model_choice,
                text,
                language,
                speaker_audio,
                prefix_audio,
                skip_speaker,
                skip_emotion,
                emotion1,
                emotion2,
                emotion3,
                emotion4,
                emotion5,
                emotion6,
                emotion7,
                emotion8,
                skip_vqscore_8,
                vq_single_slider,
                fmax_slider,
                skip_fmax,
                pitch_std_slider,
                skip_pitch_std,
                speaking_rate_slider,
                skip_speaking_rate,
                dnsmos_slider,
                skip_dnsmos_ovrl,
                speaker_noised_checkbox,
                skip_speaker_noised,
                cfg_scale_slider,
                min_p_slider,
                seed_number,
            ],
            outputs=[output_audio],
        )

    return demo


if __name__ == "__main__":
    demo = build_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
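
Note: `share=True` requests a public gradio.live tunnel in addition to binding on 0.0.0.0:7860; when the app runs as a Hugging Face Space the flag is generally redundant, since the Space already exposes a public URL.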