Steveeeeeeen (HF staff) committed
Commit b1f1246 · verified · 1 parent: c28b0ef

Update app.py

Files changed (1)
app.py +14 -31
app.py CHANGED
@@ -5,29 +5,21 @@ import gradio as gr
 from zonos.model import Zonos
 from zonos.conditioning import make_cond_dict
 
-# Load the hybrid model
 model = Zonos.from_pretrained("Zyphra/Zonos-v0.1-hybrid", device="cuda")
-model.bfloat16()  # Switch model weights to bfloat16 precision (optional, but recommended for GPU)
+model.bfloat16()
 
-# Main inference function for Gradio
 def tts(text, reference_audio):
-    """
-    text: str
-    reference_audio: (numpy.ndarray, int) -> (data, sample_rate)
-    """
     if reference_audio is None:
-        return "No reference audio provided."
+        return None
 
-    # reference_audio[0] is a NumPy float32 array of shape (num_samples, 1) or (num_samples,)
-    # reference_audio[1] is the sample rate
-    wav_np, sr = reference_audio
+    # Gradio returns (sample_rate, audio_data) for type="numpy"
+    sr, wav_np = reference_audio
 
-    # Convert NumPy audio to Torch tensor
-    wav_torch = torch.from_numpy(wav_np).float().unsqueeze(0)  # shape: (1, num_samples)
+    # Convert NumPy audio data to Torch tensor
+    wav_torch = torch.from_numpy(wav_np).float().unsqueeze(0)
     if wav_torch.dim() == 2 and wav_torch.shape[0] > wav_torch.shape[1]:
-        # If the shape is (samples, 1), reorder to (1, samples)
         wav_torch = wav_torch.T
-
+
     # Create speaker embedding
     spk_embedding = model.embed_spk_audio(wav_torch, sr)
 
@@ -39,35 +31,26 @@ def tts(text, reference_audio):
     )
     conditioning = model.prepare_conditioning(cond_dict)
 
-    # Generate codes
+    # Generate codes & decode
    with torch.no_grad():
-        torch.manual_seed(421)  # Seeding for reproducible results
+        torch.manual_seed(421)
         codes = model.generate(conditioning)
 
-    # Decode the codes into waveform
     wavs = model.autoencoder.decode(codes).cpu()
-    out_audio = wavs[0].numpy()  # shape: (num_samples,)
-
-    # Return as (sample_rate, audio_ndarray) for Gradio's "audio" output
+    out_audio = wavs[0].numpy()
+
+    # Return a tuple of (sample_rate, audio_data) for playback
     return (model.autoencoder.sampling_rate, out_audio)
 
-
-# Define the Gradio interface
-# - text input for the prompt
-# - audio input for the speaker reference
-# - audio output with the generated speech
 demo = gr.Interface(
     fn=tts,
     inputs=[
         gr.Textbox(label="Text to Synthesize"),
-        gr.Audio(label="Reference Audio (for speaker embedding)"),
+        gr.Audio(type="numpy", label="Reference Audio (Speaker)"),
     ],
     outputs=gr.Audio(label="Generated Audio"),
     title="Zonos TTS Demo (Hybrid)",
-    description=(
-        "Provide a reference audio snippet for speaker embedding, "
-        "enter text, and generate speech with Zonos TTS."
-    ),
+    description="Upload a reference audio for speaker embedding, enter text, and generate speech!"
 )
 
 if __name__ == "__main__":
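
A note on the substantive fix above: as the new comment in the diff states, gr.Audio(type="numpy") hands the callback a (sample_rate, audio_data) tuple, so the previous wav_np, sr = reference_audio unpacked the two values in the wrong order, passing an integer sample rate where audio samples were expected. A minimal standalone sketch of that convention (normalize_gradio_audio and the fake input are illustrative, not part of the commit):

import numpy as np

def normalize_gradio_audio(reference_audio):
    # gr.Audio(type="numpy") passes (sample_rate, data); unpacking as
    # (data, sample_rate) -- the pre-commit order -- silently swaps them.
    sr, wav_np = reference_audio
    if wav_np.ndim == 2:  # (num_samples, num_channels) -> mono
        wav_np = wav_np.mean(axis=1)
    return sr, wav_np.astype(np.float32)

# A 1-second, 44.1 kHz stereo clip shaped the way Gradio delivers it
fake_input = (44100, np.zeros((44100, 2), dtype=np.int16))
sr, mono = normalize_gradio_audio(fake_input)
print(sr, mono.shape)  # 44100 (44100,)

Running the sketch prints the sample rate and the mono sample-array shape in the expected order, which is exactly what the reordered unpacking in this commit restores.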