|
import random |
|
import numpy as np |
|
import torch |
|
from chatterbox.src.chatterbox.tts import ChatterboxTTS |
|
import gradio as gr |
|
import spaces |
|
|
|
# Pick the compute device once at import time: use CUDA when torch reports
# it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"π Running on device: {DEVICE}")

# Process-wide model cache; stays None until get_or_load_model() fills it.
MODEL = None
|
|
|
def get_or_load_model():
    """Return the shared ChatterboxTTS model, loading it on first use.

    The model is cached in the module-level ``MODEL`` global. The global is
    assigned only after both loading and device placement have succeeded, so
    a failure part-way through never leaves a half-initialised model cached
    for later calls.

    Returns:
        The cached model instance.

    Raises:
        Exception: Whatever loading or device placement raised; the error is
            printed and then re-raised unchanged.
    """
    global MODEL

    if MODEL is not None:
        return MODEL

    print("Model not loaded, initializing...")
    try:
        model = ChatterboxTTS.from_pretrained(DEVICE)
        # Move the model only when it supports `.to()`. A missing `device`
        # attribute is treated as "unknown placement", which also triggers
        # the move instead of raising AttributeError.
        reported_device = getattr(model, "device", None)
        if hasattr(model, "to") and str(reported_device) != DEVICE:
            model.to(DEVICE)
        print(f"Model loaded successfully. Internal device: {getattr(model, 'device', 'N/A')}")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise

    # Publish to the global cache only once the model is fully set up.
    MODEL = model
    return MODEL
|
|
|
|
|
# Warm the model cache at import time so the first request does not pay the
# loading cost. A failure here is reported but not fatal: the UI still starts
# and the load is attempted again on the first generation request.
try:
    get_or_load_model()
except Exception as startup_error:
    print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {startup_error}")
|
|
|
def set_seed(seed: int):
    """Seed torch (CPU and CUDA), Python's ``random`` and NumPy.

    Args:
        seed: Seed value. Negative values and values of 2**32 or more are
            accepted; for NumPy the seed is reduced modulo 2**32 because
            ``np.random.seed`` only accepts values in ``[0, 2**32 - 1]``.
    """
    torch.manual_seed(seed)
    # Seeds every CUDA device. PyTorch documents this call as silently
    # ignored when CUDA is not available, so no device check is needed.
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    # Keep in-range seeds unchanged while preventing a ValueError for
    # negative or very large seeds.
    np.random.seed(seed % 2**32)
|
|
|
@spaces.GPU
def generate_tts_audio(
    text_input: str,
    exaggeration_input: float,
    temperature_input: float,
    seed_num_input: int,
    cfgw_input: float
) -> tuple[int, np.ndarray]:
    """
    Generates TTS audio using the ChatterboxTTS model.

    Args:
        text_input: The text to synthesize; only the first 300 characters
            are passed to the model.
        exaggeration_input: Exaggeration parameter for the model.
        temperature_input: Temperature parameter for the model.
        seed_num_input: Random seed. ``0`` or ``None`` (e.g. a cleared input
            field) leaves the random state untouched; fractional values are
            truncated to an integer before the zero check.
        cfgw_input: CFG/Pace weight.

    Returns:
        A tuple containing the sample rate (int) and the audio waveform (numpy.ndarray).

    Raises:
        RuntimeError: If no model instance is available.
    """
    current_model = get_or_load_model()

    if current_model is None:
        raise RuntimeError("TTS model is not loaded.")

    # The UI number field can deliver None or a float. Normalise to int first
    # so that "0 means random" is decided on the value actually used as seed.
    seed = int(seed_num_input) if seed_num_input is not None else 0
    if seed != 0:
        set_seed(seed)

    print(f"Generating audio for text: '{text_input[:50]}...'")
    wav = current_model.generate(
        text_input[:300],
        exaggeration=exaggeration_input,
        temperature=temperature_input,
        cfg_weight=cfgw_input,
    )
    print("Audio generation complete.")

    # detach() and cpu() make the NumPy conversion safe even if the model
    # hands back a CUDA or grad-tracking tensor; both are no-ops otherwise.
    return (current_model.sr, wav.squeeze(0).detach().cpu().numpy())
|
|
|
# Gradio UI: input controls in the left column, generated audio on the right.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Ryan TTS Demo
        Fine tuned Chatterbox TTS model
        """
    )

    with gr.Row():
        with gr.Column():
            text = gr.Textbox(
                label="Text to synthesize (max chars 300)",
                value="Now let's make my favourite. Three mars bars into the pan. Then we add the tuna and just stir for a bit, just let the chocolate and fish infuse. A sprinkle of olive oil and some tomato ketchup. Now smell that. Oh boy this is going to be incredible.",
                max_lines=5,
            )
            exaggeration = gr.Slider(
                minimum=0.25,
                maximum=2,
                value=0.5,
                step=0.05,
                label="Exaggeration (Neutral = 0.5, extreme values can be unstable)",
            )
            cfg_weight = gr.Slider(
                minimum=0.2,
                maximum=1,
                value=0.5,
                step=0.05,
                label="CFG/Pace",
            )

            # Less frequently used controls are collapsed by default.
            with gr.Accordion("More options", open=False):
                seed_num = gr.Number(label="Random seed (0 for random)", value=0)
                temp = gr.Slider(
                    minimum=0.05,
                    maximum=5,
                    value=0.8,
                    step=0.05,
                    label="Temperature",
                )

            run_btn = gr.Button("Generate", variant="primary")

        with gr.Column():
            audio_output = gr.Audio(label="Output Audio")

    # Input order must match the positional parameters of generate_tts_audio:
    # text, exaggeration, temperature, seed, cfg weight.
    run_btn.click(
        fn=generate_tts_audio,
        inputs=[text, exaggeration, temp, seed_num, cfg_weight],
        outputs=[audio_output],
    )

demo.launch()