Steveeeeeeen committed (verified)
Commit 1be704d · 1 Parent(s): 1611a5c

Update app.py

Files changed (1)
  1. app.py +328 -95
app.py CHANGED
@@ -1,139 +1,372 @@
 import torch
 import torchaudio
 import gradio as gr
-import spaces

 from zonos.model import Zonos
 from zonos.conditioning import make_cond_dict, supported_language_codes

-# We'll keep a global dictionary of loaded models to avoid reloading
-MODELS_CACHE = {}
 device = "cuda"

-banner_url = "https://huggingface.co/datasets/Steveeeeeeen/random_images/resolve/main/ZonosHeader.png"
-BANNER = f'<div style="display: flex; justify-content: space-around;"><img src="{banner_url}" alt="Banner" style="width: 40vw; min-width: 150px; max-width: 300px;"> </div>'

-def load_model(model_name: str):
-    """
-    Loads or retrieves a cached Zonos model, sets it to eval and bfloat16.
     """
-    global MODELS_CACHE
-    if model_name not in MODELS_CACHE:
-        print(f"Loading model: {model_name}")
-        model = Zonos.from_pretrained(model_name, device=device)
-        model = model.requires_grad_(False).eval()
-        model.bfloat16()  # optional if GPU supports bfloat16
-        MODELS_CACHE[model_name] = model
-        print(f"Model loaded successfully: {model_name}")
-    return MODELS_CACHE[model_name]
-
-@spaces.GPU(duration=90)
-def tts(text, speaker_audio, selected_language, model_choice):
     """
-    text: str (Text prompt to synthesize)
-    speaker_audio: (sample_rate, numpy_array) from Gradio if type="numpy"
-    selected_language: str (language code)
-    model_choice: str (which Zonos model to use, e.g., "Zyphra/Zonos-v0.1-hybrid")

-    Returns (sr_out, wav_out_numpy).
-    """
-    model = load_model(model_choice)

-    if not text:
-        return None

-    # If the user did not provide a reference audio, skip
-    if speaker_audio is None:
-        return None

-    # Gradio gives audio in (sample_rate, numpy_array) format
-    sr, wav_np = speaker_audio

-    # Convert to Torch tensor
-    wav_tensor = torch.from_numpy(wav_np).float()

-    # If stereo (shape [channels, samples]) or multi-channel, downmix to mono
-    # e.g. shape (2, samples) -> shape (samples,) by averaging
-    if wav_tensor.ndim == 2 and wav_tensor.shape[0] > 1:
-        wav_tensor = wav_tensor.mean(dim=0)  # shape => (samples,)

-    # Now add a batch dimension => shape (1, samples)
-    wav_tensor = wav_tensor.unsqueeze(0)

-    # Get speaker embedding
-    with torch.no_grad():
-        spk_embedding = model.make_speaker_embedding(wav_tensor, sr)
-        spk_embedding = spk_embedding.to(device, dtype=torch.bfloat16)

-    # Prepare conditioning dictionary
     cond_dict = make_cond_dict(
         text=text,
-        speaker=spk_embedding,
-        language=selected_language,
         device=device,
     )
-    conditioning = model.prepare_conditioning(cond_dict)

-    # Generate codes
-    with torch.no_grad():
-        codes = model.generate(conditioning)

-    # Decode the codes into raw audio
-    wav_out = model.autoencoder.decode(codes).cpu().detach().squeeze()
-    sr_out = model.autoencoder.sampling_rate

-    return (sr_out, wav_out.numpy())

-def build_demo():
-    with gr.Blocks(theme='davehornik/Tealy') as demo:
-        gr.HTML(BANNER, elem_id="banner")
-        gr.Markdown("## Zonos-v0.1 TTS Demo")
-        gr.Markdown(
-            """
-            > **Zero-shot TTS with Voice Cloning**: Input text and a 10–30 second speaker sample to generate high-quality text-to-speech output.

-            > **Audio Prefix Inputs**: Enhance speaker matching by adding an audio prefix to the text, enabling behaviors like whispering that are hard to achieve with voice cloning alone.

-            > **Multilingual Support**: Supports English, Japanese, Chinese, French, and German.
-            """
-        )
         with gr.Row():
-            text_input = gr.Textbox(
-                label="Text Prompt",
-                value="Hello from Zonos!",
-                lines=3
             )
-            ref_audio_input = gr.Audio(
-                label="Reference Audio (Speaker Cloning)",
-                type="numpy"
-                # Optionally add mono=True if you want Gradio to always downmix automatically:
-                # mono=True
             )

-        model_dropdown = gr.Dropdown(
-            label="Model Choice",
-            choices=["Zyphra/Zonos-v0.1-transformer", "Zyphra/Zonos-v0.1-hybrid"],
-            value="Zyphra/Zonos-v0.1-hybrid",
-            interactive=True,
-        )
-        language_dropdown = gr.Dropdown(
-            label="Language Code",
-            choices=["en-us", "ja", "cmn", "fr-fr", "de"],
-            value="en-us",
-            interactive=True,
         )

-        generate_button = gr.Button("Generate")
-        audio_output = gr.Audio(label="Synthesized Output", type="numpy")

         generate_button.click(
-            fn=tts,
-            inputs=[text_input, ref_audio_input, language_dropdown, model_dropdown],
-            outputs=audio_output,
         )

     return demo

 if __name__ == "__main__":
-    demo_app = build_demo()
-    demo_app.launch(server_name="0.0.0.0", server_port=7860, share=True)
 import torch
 import torchaudio
 import gradio as gr
+from os import getenv

 from zonos.model import Zonos
 from zonos.conditioning import make_cond_dict, supported_language_codes

 device = "cuda"
+CURRENT_MODEL_TYPE = None
+CURRENT_MODEL = None


+def load_model_if_needed(model_choice: str):
+    global CURRENT_MODEL_TYPE, CURRENT_MODEL
+    if CURRENT_MODEL_TYPE != model_choice:
+        if CURRENT_MODEL is not None:
+            del CURRENT_MODEL
+            torch.cuda.empty_cache()
+        print(f"Loading {model_choice} model...")
+        CURRENT_MODEL = Zonos.from_pretrained(model_choice, device=device)
+        CURRENT_MODEL.requires_grad_(False).eval()
+        CURRENT_MODEL_TYPE = model_choice
+        print(f"{model_choice} model loaded successfully!")
+    return CURRENT_MODEL
+
+
+def update_ui(model_choice):
     """
+    Dynamically show/hide UI elements based on the model's conditioners.
+    We do NOT display 'language_id' or 'ctc_loss' even if they exist in the model.
     """
+    model = load_model_if_needed(model_choice)
+    cond_names = [c.name for c in model.prefix_conditioner.conditioners]
+    print("Conditioners in this model:", cond_names)

+    text_update = gr.update(visible=("espeak" in cond_names))
+    language_update = gr.update(visible=("espeak" in cond_names))
+    speaker_audio_update = gr.update(visible=("speaker" in cond_names))
+    prefix_audio_update = gr.update(visible=True)
+    emotion1_update = gr.update(visible=("emotion" in cond_names))
+    emotion2_update = gr.update(visible=("emotion" in cond_names))
+    emotion3_update = gr.update(visible=("emotion" in cond_names))
+    emotion4_update = gr.update(visible=("emotion" in cond_names))
+    emotion5_update = gr.update(visible=("emotion" in cond_names))
+    emotion6_update = gr.update(visible=("emotion" in cond_names))
+    emotion7_update = gr.update(visible=("emotion" in cond_names))
+    emotion8_update = gr.update(visible=("emotion" in cond_names))
+    vq_single_slider_update = gr.update(visible=("vqscore_8" in cond_names))
+    fmax_slider_update = gr.update(visible=("fmax" in cond_names))
+    pitch_std_slider_update = gr.update(visible=("pitch_std" in cond_names))
+    speaking_rate_slider_update = gr.update(visible=("speaking_rate" in cond_names))
+    dnsmos_slider_update = gr.update(visible=("dnsmos_ovrl" in cond_names))
+    speaker_noised_checkbox_update = gr.update(visible=("speaker_noised" in cond_names))
+    unconditional_keys_update = gr.update(
+        choices=[name for name in cond_names if name not in ("espeak", "language_id")]
+    )

+    return (
+        text_update,
+        language_update,
+        speaker_audio_update,
+        prefix_audio_update,
+        emotion1_update,
+        emotion2_update,
+        emotion3_update,
+        emotion4_update,
+        emotion5_update,
+        emotion6_update,
+        emotion7_update,
+        emotion8_update,
+        vq_single_slider_update,
+        fmax_slider_update,
+        pitch_std_slider_update,
+        speaking_rate_slider_update,
+        dnsmos_slider_update,
+        speaker_noised_checkbox_update,
+        unconditional_keys_update,
+    )


+def generate_audio(
+    model_choice,
+    text,
+    language,
+    speaker_audio,
+    prefix_audio,
+    e1,
+    e2,
+    e3,
+    e4,
+    e5,
+    e6,
+    e7,
+    e8,
+    vq_single,
+    fmax,
+    pitch_std,
+    speaking_rate,
+    dnsmos_ovrl,
+    speaker_noised,
+    cfg_scale,
+    min_p,
+    seed,
+    randomize_seed,
+    unconditional_keys,
+    progress=gr.Progress(),
+):
+    """
+    Generates audio based on the provided UI parameters.
+    We do NOT use language_id or ctc_loss even if the model has them.
+    """
+    selected_model = load_model_if_needed(model_choice)
+
+    speaker_noised_bool = bool(speaker_noised)
+    fmax = float(fmax)
+    pitch_std = float(pitch_std)
+    speaking_rate = float(speaking_rate)
+    dnsmos_ovrl = float(dnsmos_ovrl)
+    cfg_scale = float(cfg_scale)
+    min_p = float(min_p)
+    seed = int(seed)
+    max_new_tokens = 86 * 30

+    if randomize_seed:
+        seed = torch.randint(0, 2**32 - 1, (1,)).item()
+    torch.manual_seed(seed)

+    speaker_embedding = None
+    if speaker_audio is not None and "speaker" not in unconditional_keys:
+        wav, sr = torchaudio.load(speaker_audio)
+        speaker_embedding = selected_model.make_speaker_embedding(wav, sr)
+        speaker_embedding = speaker_embedding.to(device, dtype=torch.bfloat16)

+    audio_prefix_codes = None
+    if prefix_audio is not None:
+        wav_prefix, sr_prefix = torchaudio.load(prefix_audio)
+        wav_prefix = wav_prefix.mean(0, keepdim=True)
+        wav_prefix = torchaudio.functional.resample(wav_prefix, sr_prefix, selected_model.autoencoder.sampling_rate)
+        wav_prefix = wav_prefix.to(device, dtype=torch.float32)
+        with torch.autocast(device, dtype=torch.float32):
+            audio_prefix_codes = selected_model.autoencoder.encode(wav_prefix.unsqueeze(0))

+    emotion_tensor = torch.tensor(list(map(float, [e1, e2, e3, e4, e5, e6, e7, e8])), device=device)
+
+    vq_val = float(vq_single)
+    vq_tensor = torch.tensor([vq_val] * 8, device=device).unsqueeze(0)

     cond_dict = make_cond_dict(
         text=text,
+        language=language,
+        speaker=speaker_embedding,
+        emotion=emotion_tensor,
+        vqscore_8=vq_tensor,
+        fmax=fmax,
+        pitch_std=pitch_std,
+        speaking_rate=speaking_rate,
+        dnsmos_ovrl=dnsmos_ovrl,
+        speaker_noised=speaker_noised_bool,
         device=device,
+        unconditional_keys=unconditional_keys,
     )
+    conditioning = selected_model.prepare_conditioning(cond_dict)
+
+    estimated_generation_duration = 30 * len(text) / 400
+    estimated_total_steps = int(estimated_generation_duration * 86)

+    def update_progress(_frame: torch.Tensor, step: int, _total_steps: int) -> bool:
+        progress((step, estimated_total_steps))
+        return True

+    codes = selected_model.generate(
+        prefix_conditioning=conditioning,
+        audio_prefix_codes=audio_prefix_codes,
+        max_new_tokens=max_new_tokens,
+        cfg_scale=cfg_scale,
+        batch_size=1,
+        sampling_params=dict(min_p=min_p),
+        callback=update_progress,
+    )

+    wav_out = selected_model.autoencoder.decode(codes).cpu().detach()
+    sr_out = selected_model.autoencoder.sampling_rate
+    if wav_out.dim() == 2 and wav_out.size(0) > 1:
+        wav_out = wav_out[0:1, :]
+    return (sr_out, wav_out.squeeze().numpy()), seed

+def build_interface():
+    with gr.Blocks() as demo:
+        with gr.Row():
+            with gr.Column():
+                model_choice = gr.Dropdown(
+                    choices=["Zyphra/Zonos-v0.1-transformer", "Zyphra/Zonos-v0.1-hybrid"],
+                    value="Zyphra/Zonos-v0.1-transformer",
+                    label="Zonos Model Type",
+                    info="Select the model variant to use.",
+                )
+                text = gr.Textbox(
+                    label="Text to Synthesize",
+                    value="Zonos uses eSpeak for text to phoneme conversion!",
+                    lines=4,
+                    max_length=500,  # approximately
+                )
+                language = gr.Dropdown(
+                    choices=supported_language_codes,
+                    value="en-us",
+                    label="Language Code",
+                    info="Select a language code.",
+                )
+                prefix_audio = gr.Audio(
+                    value="assets/silence_100ms.wav",
+                    label="Optional Prefix Audio (continue from this audio)",
+                    type="filepath",
+                )
+            with gr.Column():
+                speaker_audio = gr.Audio(
+                    label="Optional Speaker Audio (for cloning)",
+                    type="filepath",
+                )
+                speaker_noised_checkbox = gr.Checkbox(label="Denoise Speaker?", value=False)

         with gr.Row():
+            with gr.Column():
+                gr.Markdown("## Conditioning Parameters")
+                dnsmos_slider = gr.Slider(1.0, 5.0, value=4.0, step=0.1, label="DNSMOS Overall")
+                fmax_slider = gr.Slider(0, 24000, value=24000, step=1, label="Fmax (Hz)")
+                vq_single_slider = gr.Slider(0.5, 0.8, 0.78, 0.01, label="VQ Score")
+                pitch_std_slider = gr.Slider(0.0, 300.0, value=45.0, step=1, label="Pitch Std")
+                speaking_rate_slider = gr.Slider(5.0, 30.0, value=15.0, step=0.5, label="Speaking Rate")
+
+            with gr.Column():
+                gr.Markdown("## Generation Parameters")
+                cfg_scale_slider = gr.Slider(1.0, 5.0, 2.0, 0.1, label="CFG Scale")
+                min_p_slider = gr.Slider(0.0, 1.0, 0.15, 0.01, label="Min P")
+                seed_number = gr.Number(label="Seed", value=420, precision=0)
+                randomize_seed_toggle = gr.Checkbox(label="Randomize Seed (before generation)", value=True)
+
+        with gr.Accordion("Advanced Parameters", open=False):
+            gr.Markdown(
+                "### Unconditional Toggles\n"
+                "Checking a box will make the model ignore the corresponding conditioning value and make it unconditional.\n"
+                'Practically this means the given conditioning feature will be unconstrained and "filled in automatically".'
             )
+            with gr.Row():
+                unconditional_keys = gr.CheckboxGroup(
+                    [
+                        "speaker",
+                        "emotion",
+                        "vqscore_8",
+                        "fmax",
+                        "pitch_std",
+                        "speaking_rate",
+                        "dnsmos_ovrl",
+                        "speaker_noised",
+                    ],
+                    value=["emotion"],
+                    label="Unconditional Keys",
+                )
+
+            gr.Markdown(
+                "### Emotion Sliders\n"
+                "Warning: The way these sliders work is not intuitive and may require some trial and error to get the desired effect.\n"
+                "Certain configurations can cause the model to become unstable. Setting emotion to unconditional may help."
             )
+            with gr.Row():
+                emotion1 = gr.Slider(0.0, 1.0, 1.0, 0.05, label="Happiness")
+                emotion2 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Sadness")
+                emotion3 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Disgust")
+                emotion4 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Fear")
+            with gr.Row():
+                emotion5 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Surprise")
+                emotion6 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Anger")
+                emotion7 = gr.Slider(0.0, 1.0, 0.1, 0.05, label="Other")
+                emotion8 = gr.Slider(0.0, 1.0, 0.2, 0.05, label="Neutral")

+        with gr.Column():
+            generate_button = gr.Button("Generate Audio")
+            output_audio = gr.Audio(label="Generated Audio", type="numpy", autoplay=True)
+
+        model_choice.change(
+            fn=update_ui,
+            inputs=[model_choice],
+            outputs=[
+                text,
+                language,
+                speaker_audio,
+                prefix_audio,
+                emotion1,
+                emotion2,
+                emotion3,
+                emotion4,
+                emotion5,
+                emotion6,
+                emotion7,
+                emotion8,
+                vq_single_slider,
+                fmax_slider,
+                pitch_std_slider,
+                speaking_rate_slider,
+                dnsmos_slider,
+                speaker_noised_checkbox,
+                unconditional_keys,
+            ],
         )

+        # On page load, trigger the same UI refresh
+        demo.load(
+            fn=update_ui,
+            inputs=[model_choice],
+            outputs=[
+                text,
+                language,
+                speaker_audio,
+                prefix_audio,
+                emotion1,
+                emotion2,
+                emotion3,
+                emotion4,
+                emotion5,
+                emotion6,
+                emotion7,
+                emotion8,
+                vq_single_slider,
+                fmax_slider,
+                pitch_std_slider,
+                speaking_rate_slider,
+                dnsmos_slider,
+                speaker_noised_checkbox,
+                unconditional_keys,
+            ],
+        )

+        # Generate audio on button click
         generate_button.click(
+            fn=generate_audio,
+            inputs=[
+                model_choice,
+                text,
+                language,
+                speaker_audio,
+                prefix_audio,
+                emotion1,
+                emotion2,
+                emotion3,
+                emotion4,
+                emotion5,
+                emotion6,
+                emotion7,
+                emotion8,
+                vq_single_slider,
+                fmax_slider,
+                pitch_std_slider,
+                speaking_rate_slider,
+                dnsmos_slider,
+                speaker_noised_checkbox,
+                cfg_scale_slider,
+                min_p_slider,
+                seed_number,
+                randomize_seed_toggle,
+                unconditional_keys,
+            ],
+            outputs=[output_audio, seed_number],
         )

     return demo

+
 if __name__ == "__main__":
+    demo = build_interface()
+    share = getenv("GRADIO_SHARE", "False").lower() in ("true", "1", "t")
+    demo.launch(server_name="0.0.0.0", server_port=7860, share=share)
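
The sketch below is not part of the commit; it is a minimal, illustrative script showing how the same generation path that the new generate_audio function wires into Gradio could be driven directly, assuming only the Zonos calls that appear in this diff (Zonos.from_pretrained, make_speaker_embedding, make_cond_dict, prepare_conditioning, generate, autoencoder.decode). The file paths, sampling values, and reliance on default arguments are assumptions, not confirmed API guarantees.

import torch
import torchaudio

from zonos.model import Zonos
from zonos.conditioning import make_cond_dict

device = "cuda"

# Load one model variant (same call the app uses) and freeze it for inference.
model = Zonos.from_pretrained("Zyphra/Zonos-v0.1-transformer", device=device)
model.requires_grad_(False).eval()

# Optional speaker cloning from a reference clip (path is a placeholder).
wav, sr = torchaudio.load("reference_speaker.wav")
speaker_embedding = model.make_speaker_embedding(wav, sr).to(device, dtype=torch.bfloat16)

# Minimal conditioning; the app passes many more keys (emotion, vqscore_8, fmax, ...).
cond_dict = make_cond_dict(
    text="Hello from Zonos!",
    language="en-us",
    speaker=speaker_embedding,
    device=device,
)
conditioning = model.prepare_conditioning(cond_dict)

# Generate up to ~30 s of codes (86 tokens per second, as in the app) and decode them.
codes = model.generate(
    prefix_conditioning=conditioning,
    audio_prefix_codes=None,
    max_new_tokens=86 * 30,
    cfg_scale=2.0,
    batch_size=1,
    sampling_params=dict(min_p=0.15),
)
wav_out = model.autoencoder.decode(codes).cpu().detach().squeeze()
torchaudio.save("output.wav", wav_out.unsqueeze(0), model.autoencoder.sampling_rate)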