Zonos / app.py
Steveeeeeeen's picture
Steveeeeeeen HF staff
Update app.py
e5d26e9 verified
raw
history blame
3.38 kB
import torch
import torchaudio
import gradio as gr
from zonos.model import Zonos
from zonos.conditioning import make_cond_dict
# Torch device string used for the inference tensors built in this app.
device = "cuda"

# Lazily-populated cache for the Zonos model; load_model() fills it on first use.
MODEL = None
def load_model():
    """Load the Zonos model on first use and return the cached instance.

    The checkpoint is loaded with ``device=device`` (the module-level device
    setting), frozen with ``requires_grad_(False)``, switched to eval mode and
    cast to bfloat16. Later calls return the same cached object.

    Adjust ``model_name`` below to use a different checkpoint.

    Returns:
        The cached Zonos model.
    """
    global MODEL
    if MODEL is None:
        model_name = "Zyphra/Zonos-v0.1-hybrid"
        print(f"Loading model: {model_name}")
        # Use the shared `device` setting instead of a second hard-coded
        # "cuda" literal, so the model and the conditioning tensors built in
        # tts() cannot drift onto different devices.
        model = Zonos.from_pretrained(model_name, device=device)
        model = model.requires_grad_(False).eval()
        model.bfloat16()  # optional, if your GPU supports bfloat16
        # Publish to the global cache only once fully configured, so a failure
        # above never leaves a half-initialized model for later calls to reuse.
        MODEL = model
        print("Model loaded successfully!")
    return MODEL
def _reference_audio_to_tensor(wav_np):
    """Convert a Gradio ``type="numpy"`` sample array into a mono float tensor.

    Gradio passes 1-D arrays for mono clips and 2-D ``(num_samples,
    num_channels)`` arrays for multi-channel clips, typically with an integer
    PCM dtype such as int16. Integer samples are rescaled to roughly [-1, 1]
    (the range ``torchaudio.load`` produces by default); float input keeps its
    values. Multi-channel input is downmixed by averaging the channels.

    Args:
        wav_np: NumPy array of audio samples.

    Returns:
        float32 tensor of shape ``(1, num_samples)``.

    Raises:
        ValueError: if the array is neither 1-D nor 2-D.
    """
    wav = torch.from_numpy(wav_np)
    if wav.dtype == torch.uint8:
        # 8-bit PCM is unsigned with its zero level at 128.
        wav = (wav.float() - 128.0) / 128.0
    elif not wav.dtype.is_floating_point:
        # Signed integer PCM: divide by |min| (e.g. 32768 for int16) so
        # full-scale samples land in [-1, 1] instead of +-32768.
        wav = wav.float() / float(-torch.iinfo(wav.dtype).min)
    else:
        wav = wav.float()

    if wav.dim() == 1:
        return wav.unsqueeze(0)
    if wav.dim() == 2:
        # Treat the longer axis as time (the original heuristic). Checking this
        # BEFORE adding a batch axis is what makes stereo (N, C) input work.
        if wav.shape[0] > wav.shape[1]:
            wav = wav.T
        # wav is now (num_channels, num_samples); downmix to a single channel.
        return wav.mean(dim=0, keepdim=True)
    raise ValueError(
        f"Expected 1-D or 2-D reference audio, got shape {tuple(wav.shape)}"
    )


def tts(text, speaker_audio, language="en-us"):
    """Synthesize ``text`` in the voice of a reference recording.

    Args:
        text: Text to speak. Empty text yields ``None``.
        speaker_audio: ``(sample_rate, numpy_array)`` as produced by a Gradio
            ``gr.Audio(type="numpy")`` input, or ``None`` when no reference
            was provided (which yields ``None``).
        language: Language code passed to ``make_cond_dict``. Defaults to
            ``"en-us"``, the value this demo previously hard-coded.

    Returns:
        ``(sample_rate, waveform)`` for a Gradio audio output, or ``None``
        when the text or the reference audio is missing or empty.

    Raises:
        ValueError: if the reference audio array is neither 1-D nor 2-D.
    """
    # Validate inputs before touching the model so trivial requests neither
    # wait for nor fail on model loading.
    if not text or speaker_audio is None:
        return None

    # Gradio provides audio in the format (sample_rate, numpy_array).
    sr, wav_np = speaker_audio
    if wav_np.size == 0:
        # An empty clip carries no speaker information; treat it as missing.
        return None
    wav_tensor = _reference_audio_to_tensor(wav_np)

    model = load_model()
    with torch.no_grad():
        spk_embedding = model.make_speaker_embedding(wav_tensor, sr)
        spk_embedding = spk_embedding.to(device, dtype=torch.bfloat16)

        cond_dict = make_cond_dict(
            text=text,               # The text prompt
            speaker=spk_embedding,   # Speaker embedding from reference audio
            language=language,
            device=device,
        )
        conditioning = model.prepare_conditioning(cond_dict)

        # Optionally set a manual seed for reproducibility:
        # torch.manual_seed(1234)
        codes = model.generate(conditioning)

        # Decode the codes into raw audio.
        wav_out = model.autoencoder.decode(codes).cpu().squeeze()

    # NumPy has no bfloat16 dtype, so make sure the waveform is float32
    # before converting it for Gradio.
    return (model.autoencoder.sampling_rate, wav_out.float().numpy())
def build_demo():
    """Assemble the Gradio Blocks UI for the TTS demo.

    Layout, top to bottom:
    - a Markdown title;
    - a row with the text prompt and the reference-audio input;
    - a "Generate" button;
    - an audio player for the result.

    Clicking the button calls ``tts`` with the two inputs.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks() as app:
        gr.Markdown("# Simple Zonos TTS Demo (Text + Reference Audio)")

        with gr.Row():
            prompt_box = gr.Textbox(
                lines=3,
                value="Hello from Zonos!",
                label="Text Prompt",
            )
            reference_audio = gr.Audio(
                type="numpy",
                label="Reference Audio (Speaker Cloning)",
            )

        run_button = gr.Button("Generate")
        # Player for the (sample_rate, array) tuple that tts() returns.
        synthesized_audio = gr.Audio(type="numpy", label="Synthesized Output")

        run_button.click(
            fn=tts,
            inputs=[prompt_box, reference_audio],
            outputs=synthesized_audio,
        )
    return app
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 and request a public share link.
    build_demo().launch(
        share=True,
        server_port=7860,
        server_name="0.0.0.0",
    )