# Zonos TTS Gradio demo (Hugging Face Space, "Running on Zero").
import torch
import torchaudio
import gradio as gr
from zonos.model import Zonos
from zonos.conditioning import make_cond_dict
# Global cache to hold the loaded model
MODEL = None
device = "cuda"
def load_model():
    """
    Load the Zonos model once and cache it in the module-global MODEL.

    Returns:
        The cached Zonos model, in eval mode with gradients disabled.
    """
    global MODEL
    if MODEL is None:
        model_name = "Zyphra/Zonos-v0.1-hybrid"
        print(f"Loading model: {model_name}")
        # Use the module-level `device` instead of a hard-coded "cuda"
        # so the target device is configured in exactly one place.
        MODEL = Zonos.from_pretrained(model_name, device=device)
        MODEL = MODEL.requires_grad_(False).eval()
        MODEL.bfloat16()  # optional; requires a GPU with bfloat16 support
        print("Model loaded successfully!")
    return MODEL
def tts(text, speaker_audio):
    """
    Synthesize speech for `text`, cloning the voice in `speaker_audio`.

    Args:
        text: The text prompt to speak.
        speaker_audio: (sample_rate, numpy_array) tuple as delivered by a
            Gradio Audio input with type="numpy". The array may be integer
            or float PCM, 1-D mono or 2-D (num_samples, num_channels).

    Returns:
        (sample_rate, waveform ndarray) for the Gradio audio output, or
        None when the text or the reference audio is missing.
    """
    model = load_model()
    # Nothing to synthesize / no voice to clone -> empty output widget.
    if not text or speaker_audio is None:
        return None
    sr, wav_np = speaker_audio
    wav_tensor = torch.from_numpy(wav_np)
    if not wav_tensor.is_floating_point():
        # Gradio hands over raw integer PCM; scale into [-1.0, 1.0].
        wav_tensor = wav_tensor.float() / float(torch.iinfo(torch.from_numpy(wav_np).dtype).max)
    else:
        wav_tensor = wav_tensor.float()
    # Normalize the shape to (1, num_samples).
    # NOTE: the original code unsqueezed first and then checked dim()==2,
    # so 2-D (samples, channels) input was never transposed or downmixed.
    if wav_tensor.dim() == 1:
        wav_tensor = wav_tensor.unsqueeze(0)
    else:
        if wav_tensor.shape[0] > wav_tensor.shape[1]:
            # Gradio's layout is (num_samples, num_channels); flip it.
            wav_tensor = wav_tensor.T
        # Downmix multi-channel audio to mono, keeping the channel dim.
        wav_tensor = wav_tensor.mean(dim=0, keepdim=True)
    # Build the speaker embedding from the reference audio.
    with torch.no_grad():
        spk_embedding = model.make_speaker_embedding(wav_tensor, sr)
        spk_embedding = spk_embedding.to(device, dtype=torch.bfloat16)
    # Prepare conditioning dictionary
    cond_dict = make_cond_dict(
        text=text,  # The text prompt
        speaker=spk_embedding,  # Speaker embedding from reference audio
        language="en-us",  # Hard-coded language or switch to another if needed
        device=device,
    )
    conditioning = model.prepare_conditioning(cond_dict)
    # Generate codes
    with torch.no_grad():
        # Optionally set a manual seed for reproducibility
        # torch.manual_seed(1234)
        codes = model.generate(conditioning)
    # Decode the codes into raw audio
    wav_out = model.autoencoder.decode(codes).cpu().detach().squeeze()
    sr_out = model.autoencoder.sampling_rate
    return (sr_out, wav_out.numpy())
def build_demo():
    """Assemble and return the Gradio Blocks UI for the Zonos TTS demo."""
    with gr.Blocks() as demo:
        gr.Markdown("# Simple Zonos TTS Demo (Text + Reference Audio)")
        with gr.Row():
            text_input = gr.Textbox(
                label="Text Prompt", value="Hello from Zonos!", lines=3
            )
            ref_audio_input = gr.Audio(
                label="Reference Audio (Speaker Cloning)", type="numpy"
            )
        generate_button = gr.Button("Generate")
        # Playable widget that receives the synthesized waveform.
        audio_output = gr.Audio(label="Synthesized Output", type="numpy")
        # Route button clicks through the TTS function.
        generate_button.click(
            fn=tts, inputs=[text_input, ref_audio_input], outputs=audio_output
        )
    return demo
if __name__ == "__main__":
    # Serve the UI on all interfaces with a public share link.
    app = build_demo()
    app.launch(server_name="0.0.0.0", server_port=7860, share=True)