"""Minimal Gradio demo for Zonos text-to-speech with reference-audio cloning."""

import torch
import torchaudio
import gradio as gr

from zonos.model import Zonos
from zonos.conditioning import make_cond_dict

# Global cache to hold the loaded model
MODEL = None
device = "cuda"


def load_model():
    """Load the Zonos model once and cache it in the module-level ``MODEL``.

    Adjust ``model_name`` to the checkpoint you want to use.

    Returns:
        The cached model, with gradients disabled and in eval mode.
    """
    global MODEL
    if MODEL is None:
        model_name = "Zyphra/Zonos-v0.1-hybrid"
        print(f"Loading model: {model_name}")
        # Load onto the same module-level device that tts() moves its tensors
        # to, so the two can never disagree.
        model = Zonos.from_pretrained(model_name, device=device)
        model = model.requires_grad_(False).eval()
        model.bfloat16()  # optional, if your GPU supports bfloat16
        # Publish to the cache only after setup finished, so a failure above
        # does not leave a half-configured model behind for later calls.
        MODEL = model
        print("Model loaded successfully!")
    return MODEL


def _reference_audio_to_tensor(wav_np):
    """Convert a Gradio numpy waveform to a float32 2-D tensor, time axis last.

    Args:
        wav_np: 1-D (mono) or 2-D (multi-channel) numpy array of samples.

    Returns:
        A float32 tensor whose last axis is the longer (time) axis.

    Raises:
        ValueError: If ``wav_np`` is neither 1-D nor 2-D.
    """
    wav = torch.from_numpy(wav_np)
    if wav.is_floating_point():
        wav = wav.float()
    else:
        # Integer PCM (Gradio normally delivers int16): rescale so that
        # full-scale maps to roughly [-1, 1).
        # NOTE(review): assumes signed PCM; unsigned data would not be centred.
        full_scale = float(torch.iinfo(wav.dtype).max) + 1.0
        wav = wav.float() / full_scale

    if wav.dim() == 1:
        return wav.unsqueeze(0)
    if wav.dim() != 2:
        raise ValueError(
            f"Expected 1-D or 2-D audio, got shape {tuple(wav.shape)}"
        )
    # Multi-channel audio arrives as (samples, channels); put the longer axis
    # last. This check has to run on the raw 2-D array -- adding a leading
    # axis first would make it 3-D and the orientation fix would never apply.
    if wav.shape[0] > wav.shape[1]:
        wav = wav.T
    return wav


def tts(text, speaker_audio):
    """Synthesize ``text`` in the voice of the reference recording.

    Args:
        text: The text prompt to speak.
        speaker_audio: ``(sample_rate, numpy_array)`` from Gradio when the
            audio component uses ``type="numpy"``, or ``None``.

    Returns:
        ``(sample_rate, waveform)`` for the Gradio audio output, or ``None``
        when the text or the reference audio is missing or empty.
    """
    model = load_model()

    if not text or not text.strip():
        return None

    # If the user hasn't provided any audio, there is nothing to clone from.
    if speaker_audio is None:
        return None

    # Gradio provides audio in the format (sample_rate, numpy_array)
    sr, wav_np = speaker_audio
    if wav_np is None or wav_np.size == 0:
        return None

    wav_tensor = _reference_audio_to_tensor(wav_np)

    # Get speaker embedding
    with torch.no_grad():
        spk_embedding = model.make_speaker_embedding(wav_tensor, sr)
        spk_embedding = spk_embedding.to(device, dtype=torch.bfloat16)

    # Prepare conditioning dictionary
    cond_dict = make_cond_dict(
        text=text,  # The text prompt
        speaker=spk_embedding,  # Speaker embedding from reference audio
        language="en-us",  # Hard-coded language; switch if needed
        device=device,
    )
    conditioning = model.prepare_conditioning(cond_dict)

    # Generate codes and decode them into raw audio
    with torch.no_grad():
        # Optionally set a manual seed for reproducibility
        # torch.manual_seed(1234)
        codes = model.generate(conditioning)
        wav_out = model.autoencoder.decode(codes).cpu().detach().squeeze()

    sr_out = model.autoencoder.sampling_rate
    # Force float32 before leaving torch: numpy has no bfloat16, so the
    # conversion would fail if the decoder ever returned reduced precision.
    return (sr_out, wav_out.float().numpy())


def build_demo():
    """Build the Gradio Blocks UI: text box, reference audio, generate button.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Simple Zonos TTS Demo (Text + Reference Audio)")

        with gr.Row():
            text_input = gr.Textbox(
                label="Text Prompt",
                value="Hello from Zonos!",
                lines=3,
            )
            ref_audio_input = gr.Audio(
                label="Reference Audio (Speaker Cloning)",
                type="numpy",
            )

        generate_button = gr.Button("Generate")

        # The output will be an audio widget that Gradio will play
        audio_output = gr.Audio(label="Synthesized Output", type="numpy")

        # Bind the generate button
        generate_button.click(
            fn=tts,
            inputs=[text_input, ref_audio_input],
            outputs=audio_output,
        )

    return demo


if __name__ == "__main__":
    demo_app = build_demo()
    # NOTE(review): binding to 0.0.0.0 with share=True makes the demo reachable
    # from other machines and through a public Gradio link -- confirm intended.
    demo_app.launch(server_name="0.0.0.0", server_port=7860, share=True)