# coqui2 / app.py
# Author: Adoetz (last commit: "Update app.py", 838abce, verified)
import os
import threading

import gradio as gr
import numpy as np
import torch
from scipy.io.wavfile import write as write_wav
from TTS.api import TTS
# Prefer the GPU whenever PyTorch reports one; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Process-wide cache: the currently loaded TTS model and the name it was loaded from.
global_tts = None
current_model_name = None


# Enumerate every model name the Coqui model manager knows about.
def list_available_models():
    """Return the list of model names reported by Coqui TTS.

    ``TTS().list_models()`` hands back a model-manager object; its own
    ``list_models()`` produces the actual list of names.
    """
    manager = TTS().list_models()
    names = manager.list_models()
    return names


# Decide from the model name whether a ``language`` argument must be passed.
def is_multilingual(model_name):
    """Return True when *model_name* looks like a multilingual model.

    The check is a case-insensitive substring test for ``"multilingual"``
    or ``"xtts"``.
    """
    lowered = model_name.lower()
    return any(marker in lowered for marker in ("multilingual", "xtts"))


# Function to fetch available speakers from the model
def get_available_speakers(tts):
    """Return the built-in speaker names of a loaded TTS model.

    The speaker manager is looked up on ``tts.synthesizer`` first and, if it
    is missing or ``None`` there, on ``tts.synthesizer.tts_model`` (newer
    Coqui releases keep it on the underlying model).

    Args:
        tts: A loaded ``TTS`` instance, or ``None`` if loading failed.

    Returns:
        A list of speaker names, or ``None`` when no speaker manager is found
        (only voice cloning is possible) or the lookup raises.
    """
    if tts is None:
        return None
    try:
        synthesizer = tts.synthesizer
        speaker_manager = getattr(synthesizer, "speaker_manager", None)
        if speaker_manager is None:
            # Fall back to the model-level manager used by newer Coqui versions.
            tts_model = getattr(synthesizer, "tts_model", None)
            speaker_manager = getattr(tts_model, "speaker_manager", None)
        if speaker_manager is None:
            print("Warning: No speaker manager found in the model. Using voice cloning only.")
            return None  # No pre-defined speakers
        # Gradio dropdowns need a real list, not e.g. a dict keys view.
        return list(speaker_manager.speaker_names)
    except Exception as e:
        print(f"Error fetching speakers: {e}")
        return None  # Fallback to voice cloning


# Function to list .wav files in the /clone/ folder
def list_wav_files(clone_folder="clone"):
    """Return the sorted names of the ``.wav`` files in *clone_folder*.

    The result is sorted because callers turn it into ``"<index>: <name>"``
    choices and later look files up again by index. ``os.listdir`` order is
    arbitrary, so without sorting the indices would not be stable between
    calls. The extension match is case-insensitive, and directories are
    skipped.

    Args:
        clone_folder: Directory to scan, relative to the working directory
            unless absolute. Defaults to ``"clone"``.

    Returns:
        A list of bare file names (not full paths). The list is empty if the
        folder does not exist or holds no ``.wav`` files.
    """
    if not os.path.isdir(clone_folder):
        print(f"Error: Folder '{clone_folder}' not found.")
        return []
    wav_files = sorted(
        name
        for name in os.listdir(clone_folder)
        if name.lower().endswith(".wav")
        and os.path.isfile(os.path.join(clone_folder, name))
    )
    if not wav_files:
        print(f"No .wav files found in '{clone_folder}'.")
    return wav_files


# Serializes model loading: the submit button runs up to 10 requests at once
# (concurrency_limit=10), and without this lock they could load or swap the
# global model at the same time.
_TTS_LOAD_LOCK = threading.Lock()


# Function to initialize or update the TTS model
def initialize_or_update_tts(model_name):
    """Return a TTS instance for *model_name*, loading it only when needed.

    The loaded model is cached in the module globals ``global_tts`` and
    ``current_model_name``. A call with the same name reuses the cached
    instance. A different name replaces it.

    Args:
        model_name: Coqui model identifier, as listed by ``list_available_models``.

    Returns:
        The loaded ``TTS`` instance (moved to ``device``), or ``None`` if
        loading failed. The error is printed.
    """
    global global_tts, current_model_name
    with _TTS_LOAD_LOCK:
        if global_tts is not None and model_name == current_model_name:
            return global_tts
        print(f"Loading model: {model_name}")
        # Drop the cached reference first so the old model can be freed
        # before the new one is allocated.
        global_tts = None
        current_model_name = None
        try:
            tts = TTS(model_name=model_name, progress_bar=True)
            # NOTE(review): this assigns a plain string, and it is unclear whether
            # the synthesizer exposes a `phonemizer` attribute at all, so this
            # branch may never run. Confirm before relying on it.
            if hasattr(tts.synthesizer, 'phonemizer'):
                tts.synthesizer.phonemizer = "gruut"
                print("Using gruut phonemizer.")
            tts.to(device)
        except Exception as e:
            print(f"Error loading model: {e}")
            return None
        global_tts = tts
        current_model_name = model_name
        # Return the local reference so the caller gets the model loaded here.
        return tts


# Map a clone-folder dropdown choice back to a file path
def _resolve_clone_file(wav_file_choice):
    """Map a ``"<index>: <file name>"`` dropdown choice to a path in ``clone/``.

    The file name is matched first because it stays valid even if files were
    added to or removed from ``clone/`` after the dropdown was built. The
    numeric index is only a fallback.

    Returns:
        ``(error_message, None)`` on failure, ``(None, path)`` on success.
    """
    wav_files = list_wav_files()
    if not wav_files:
        return "Error: No .wav files found for voice cloning.", None
    try:
        index_text, _, file_name = wav_file_choice.partition(":")
        file_name = file_name.strip()
        if file_name in wav_files:
            return None, os.path.join("clone", file_name)
        wav_file_index = int(index_text.strip())
    except (ValueError, AttributeError):
        return "Error: Invalid .wav file choice.", None
    if not 0 <= wav_file_index < len(wav_files):
        return "Error: Invalid .wav file index.", None
    return None, os.path.join("clone", wav_files[wav_file_index])


# Function to generate TTS audio
def generate_tts_audio(text, model_name, voice_choice, speaker_name=None, wav_file_choice=None,
                       uploaded_file=None, recorded_audio=None, language="en"):
    """Synthesize *text* and return a status message plus the audio.

    Args:
        text: Text to speak. It must contain non-whitespace characters.
        model_name: Coqui model identifier to load or reuse.
        voice_choice: Either ``"existing_speaker"`` or ``"voice_cloning"``.
        speaker_name: Built-in speaker. It is required only when the model
            reports built-in speakers.
        wav_file_choice: ``"<index>: <file name>"`` choice from ``clone/``.
        uploaded_file: Path to an uploaded reference clip.
        recorded_audio: Path to a microphone recording. For voice cloning it
            takes priority over *uploaded_file*, which takes priority over
            *wav_file_choice*.
        language: Language code passed to multilingual models only.

    Returns:
        ``(message, (sample_rate, float32_array))`` on success, or
        ``(error_message, None)`` on any failure. Errors are never raised.
    """
    try:
        # Cheap input checks come first so bad requests don't trigger a model load.
        if not text or not text.strip():
            return "Error: Text input is empty.", None
        if voice_choice not in ("existing_speaker", "voice_cloning"):
            return "Error: Invalid voice choice.", None

        tts = initialize_or_update_tts(model_name)
        if tts is None:
            return "Error: Failed to load the TTS model.", None

        reference_audio = None
        if voice_choice == "existing_speaker":
            # Only models that report built-in speakers need one. Single-speaker
            # models are presumably fine with speaker=None (TODO confirm per model).
            if not speaker_name and get_available_speakers(tts):
                return "Error: Speaker name is required for existing speaker.", None
        else:
            if recorded_audio:
                reference_audio = recorded_audio
            elif uploaded_file:
                reference_audio = uploaded_file
            elif wav_file_choice:
                error, reference_audio = _resolve_clone_file(wav_file_choice)
                if error:
                    return error, None
            else:
                return "Error: No reference audio provided for voice cloning.", None

        # Cloning passes a reference clip; otherwise a named built-in speaker.
        tts_kwargs = {"text": text}
        if reference_audio:
            tts_kwargs["speaker_wav"] = reference_audio
        else:
            tts_kwargs["speaker"] = speaker_name
        if is_multilingual(model_name):
            tts_kwargs["language"] = language
        audio_np = np.array(tts.tts(**tts_kwargs), dtype=np.float32)

        sample_rate = tts.synthesizer.output_sample_rate
        # NOTE: concurrent requests all write the same file; the UI uses the
        # returned array, not this file.
        write_wav("output.wav", sample_rate, audio_np)
        return "Audio generated successfully!", (sample_rate, audio_np)
    except Exception as e:
        return f"Error generating audio: {e}", None


# Gradio interface
def create_gradio_interface():
    """Build the Gradio Blocks UI for text-to-speech generation.

    The user picks a Coqui model and a voice type. The voice is either a
    built-in speaker or a cloned voice from a ``clone/`` file, an upload, or
    a microphone recording. The app shows a status message and the generated
    audio.

    Returns:
        The configured, not yet launched, ``gr.Blocks`` app.
    """
    available_models = list_available_models()
    wav_files = list_wav_files()
    wav_file_choices = [f"{i}: {file}" for i, file in enumerate(wav_files)]

    with gr.Blocks() as demo:
        gr.Markdown("# TTS Streaming System")
        with gr.Row():
            text_input = gr.Textbox(label="Enter text to generate speech", lines=3)
        with gr.Row():
            model_name = gr.Dropdown(
                choices=available_models,
                label="Select TTS Model",
                value=available_models[0] if available_models else None,
            )
        with gr.Row():
            voice_choice = gr.Radio(
                choices=["existing_speaker", "voice_cloning"],
                label="Select voice type",
                value="existing_speaker",
            )
        with gr.Row():
            speaker_name = gr.Dropdown(
                label="Select a speaker",
                visible=True,
            )
            wav_file_choice = gr.Dropdown(
                choices=wav_file_choices,
                label="Select a .wav file for cloning",
                visible=False,
            )
            uploaded_file = gr.Audio(
                label="Upload your own .wav file for cloning",
                type="filepath",
                visible=False,
            )
            recorded_audio = gr.Microphone(
                label="Record your voice for cloning",
                type="filepath",
                visible=False,
            )
        with gr.Row():
            submit_button = gr.Button("Generate Speech")
        with gr.Row():
            output_text = gr.Textbox(label="Output", interactive=False)
            output_audio = gr.Audio(label="Generated Audio", type="numpy", visible=True)

        def update_components(choice, selected_model, selected_speaker=None):
            """Show the inputs relevant to *choice* and refresh the speaker list.

            Returns updates for (speaker_name, wav_file_choice, uploaded_file,
            recorded_audio), in that order.
            """
            # Loading here (for either voice type) also warms up the model
            # before the first generation.
            tts = initialize_or_update_tts(selected_model)
            available_speakers = get_available_speakers(tts) or []
            if choice == "existing_speaker":
                # Keep the current selection only if the (possibly new) model
                # still offers it; otherwise clear the stale value.
                speaker_value = selected_speaker if selected_speaker in available_speakers else None
                return (
                    gr.update(visible=True, choices=available_speakers, value=speaker_value),  # speaker_name
                    gr.update(visible=False),  # wav_file_choice
                    gr.update(visible=False),  # uploaded_file
                    gr.update(visible=False),  # recorded_audio
                )
            if choice == "voice_cloning":
                return (
                    gr.update(visible=False),  # speaker_name
                    gr.update(visible=bool(wav_files)),  # wav_file_choice
                    gr.update(visible=True),  # uploaded_file
                    gr.update(visible=True),  # recorded_audio
                )
            return tuple(gr.update(visible=False) for _ in range(4))

        component_inputs = [voice_choice, model_name, speaker_name]
        component_outputs = [speaker_name, wav_file_choice, uploaded_file, recorded_audio]
        voice_choice.change(update_components, inputs=component_inputs, outputs=component_outputs)
        model_name.change(update_components, inputs=component_inputs, outputs=component_outputs)
        # Populate the speaker list for the default model when the page loads.
        # Without this it stays empty until the user changes the model or voice type.
        demo.load(update_components, inputs=component_inputs, outputs=component_outputs)

        # Enable concurrency for the submit button
        submit_button.click(
            generate_tts_audio,
            inputs=[text_input, model_name, voice_choice, speaker_name, wav_file_choice, uploaded_file, recorded_audio],
            outputs=[output_text, output_audio],
            concurrency_limit=10,  # Adjust this value based on your system's capabilities
        )
    return demo


# Script entry point: build the UI and serve it.
if __name__ == "__main__":
    demo = create_gradio_interface()
    # share=True also asks Gradio for a public link; drop it to serve locally only.
    launch_options = {"share": True}
    demo.launch(**launch_options)