Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from TTS.api import TTS
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import os
|
5 |
+
import gradio as gr
|
6 |
+
from scipy.io.wavfile import write as write_wav
|
7 |
+
|
8 |
+
# Check if GPU is available
|
9 |
+
# Run on the GPU when torch reports one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Module-level cache: the currently loaded TTS model and the model name it
# was loaded under. Both start empty until a model is first requested.
global_tts = None
current_model_name = None
|
17 |
+
|
18 |
+
# Function to list available TTS models
|
19 |
+
def list_available_models():
    """Return the model names known to the installed TTS package.

    Returns:
        Whatever the TTS model catalogue lists (expected to be a list of
        model-name strings suitable for ``TTS(model_name=...)``).
    """
    catalogue = TTS().list_models()
    # NOTE(review): depending on the installed TTS release, list_models()
    # appears to return either a manager object exposing its own
    # list_models() or the names directly -- accept both shapes so the UI
    # does not crash on the latter. Confirm against the pinned TTS version.
    if hasattr(catalogue, "list_models"):
        return catalogue.list_models()
    return list(catalogue)
|
23 |
+
|
24 |
+
# Function to check if a model is multilingual
|
25 |
+
def is_multilingual(model_name):
    """Tell whether a model name looks like a multilingual model.

    The check is purely name-based and case-insensitive: the name must
    contain either "multilingual" or "xtts".

    Args:
        model_name: Model identifier string.

    Returns:
        True when one of the markers occurs in the name, else False.
    """
    lowered = model_name.lower()
    return any(marker in lowered for marker in ("multilingual", "xtts"))
|
27 |
+
|
28 |
+
# Function to fetch available speakers from the model
|
29 |
+
def get_available_speakers(tts):
    """Look up the predefined speaker names of a loaded model.

    Args:
        tts: Loaded TTS object; its ``synthesizer.speaker_manager`` attribute
            is consulted.

    Returns:
        ``speaker_manager.speaker_names`` when the synthesizer carries a
        truthy speaker manager. ``None`` when there is no speaker manager or
        when the lookup raises (a message is printed in both cases).
    """
    try:
        manager = getattr(tts.synthesizer, "speaker_manager", None)
        if manager:
            return manager.speaker_names
        print("Warning: No speaker manager found in the model. Using voice cloning only.")
        return None  # No pre-defined speakers
    except Exception as exc:
        print(f"Error fetching speakers: {exc}")
        return None  # Fallback to voice cloning
|
40 |
+
|
41 |
+
# Function to list .wav files in the /clone/ folder
|
42 |
+
def list_wav_files():
    """List the ``.wav`` files in the ``clone`` folder, sorted by name.

    The folder is resolved relative to the current working directory. The
    result is sorted because callers address entries by positional index,
    and ``os.listdir`` returns names in arbitrary order -- without sorting,
    the same index could point at different files between two calls.

    Returns:
        Sorted list of file names (not paths) ending in ``.wav``; an empty
        list when the folder is missing, is not a directory, or holds no
        ``.wav`` files (a message is printed in those cases).
    """
    clone_folder = "clone"
    # isdir (rather than exists) so a stray regular file named "clone"
    # is reported as missing instead of making os.listdir raise.
    if not os.path.isdir(clone_folder):
        print(f"Error: Folder '{clone_folder}' not found.")
        return []

    wav_files = sorted(f for f in os.listdir(clone_folder) if f.endswith(".wav"))
    if not wav_files:
        print(f"No .wav files found in '{clone_folder}'.")

    return wav_files
|
54 |
+
|
55 |
+
# Function to initialize or update the TTS model
|
56 |
+
def initialize_or_update_tts(model_name):
    """Return the cached TTS model, loading ``model_name`` first if needed.

    A (re)load happens when nothing is cached yet or when the requested name
    differs from the cached one. The new model is built and moved to
    ``device`` in a local variable and only then published to the
    module-level cache together with its name, so a failure part-way through
    (e.g. while moving to the device) leaves the previously cached model and
    name untouched and still consistent with each other.

    Args:
        model_name: Model identifier passed to ``TTS(model_name=...)``.

    Returns:
        The cached TTS instance corresponding to ``model_name``.

    Raises:
        Exception: Propagates whatever TTS raises when both load attempts
            fail or when moving the model to ``device`` fails.
    """
    global global_tts, current_model_name
    if global_tts is not None and model_name == current_model_name:
        return global_tts

    print(f"Loading model: {model_name}")
    try:
        # Try loading the model with espeak
        candidate = TTS(model_name=model_name, progress_bar=True)
    except Exception as e:
        print(f"Error loading model with espeak: {e}")
        print("Falling back to gruut phonemizer...")
        # NOTE(review): the retry uses the same arguments as the first
        # attempt, so it only helps with transient failures; the phonemizer
        # is switched afterwards, and only if the synthesizer exposes such
        # an attribute. Confirm this fallback does what is intended.
        candidate = TTS(model_name=model_name, progress_bar=True)
        if hasattr(candidate.synthesizer, 'phonemizer'):
            candidate.synthesizer.phonemizer = "gruut"

    candidate.to(device)

    # Publish model and name together, only after everything succeeded.
    global_tts = candidate
    current_model_name = model_name
    return global_tts
|
74 |
+
|
75 |
+
# Function to generate TTS audio
|
76 |
+
def _resolve_reference_audio(voice_choice, speaker_name, wav_file_choice, uploaded_file, recorded_audio):
    """Validate the voice selection and pick the reference clip for cloning.

    Returns:
        A ``(reference_audio, error)`` pair. ``error`` is a user-facing
        message when the selection is invalid, otherwise ``None``.
        ``reference_audio`` is a file path for voice cloning, or ``None``
        when an existing speaker is to be used.
    """
    if voice_choice == "existing_speaker":
        if not speaker_name:
            return None, "Error: Speaker name is required for existing speaker."
        return None, None
    if voice_choice != "voice_cloning":
        return None, "Error: Invalid voice choice."

    # Precedence for cloning sources: fresh recording, then upload, then a
    # file picked from the clone folder.
    if recorded_audio:
        return recorded_audio, None
    if uploaded_file:
        return uploaded_file, None
    if not wav_file_choice:
        return None, "Error: No reference audio provided for voice cloning."

    wav_files = list_wav_files()
    if not wav_files:
        return None, "Error: No .wav files found for voice cloning."

    try:
        # Labels look like "<index>: <file name>". Prefer the file name: the
        # folder is listed again here, so the position may no longer match
        # the one the label was built from.
        index_text, _, file_name = wav_file_choice.partition(":")
        file_name = file_name.strip()
        if file_name in wav_files:
            return os.path.join("clone", file_name), None
        wav_file_index = int(index_text.strip())
    except (ValueError, IndexError, AttributeError):
        return None, "Error: Invalid .wav file choice."

    if wav_file_index < 0 or wav_file_index >= len(wav_files):
        return None, "Error: Invalid .wav file index."
    return os.path.join("clone", wav_files[wav_file_index]), None


def generate_tts_audio(
    text,
    model_name,
    voice_choice,
    speaker_name=None,
    wav_file_choice=None,
    uploaded_file=None,
    recorded_audio=None,
    language="en",
):
    """Synthesize ``text`` and return a status message plus the audio.

    Inputs are validated before the model is loaded, so an invalid request
    is rejected without paying for a model load.

    Args:
        text: Text to speak; must contain non-whitespace characters.
        model_name: Model identifier handed to ``initialize_or_update_tts``.
        voice_choice: ``"existing_speaker"`` or ``"voice_cloning"``.
        speaker_name: Speaker to use with ``"existing_speaker"``.
        wav_file_choice: Clone-folder selection labelled
            ``"<index>: <file name>"`` (used for ``"voice_cloning"``).
        uploaded_file: Path of an uploaded reference clip.
        recorded_audio: Path of a recorded reference clip; takes precedence
            over ``uploaded_file`` and ``wav_file_choice``.
        language: Language code passed to the model when ``is_multilingual``
            reports the model as multilingual. Defaults to ``"en"``.

    Returns:
        ``(message, audio)`` where ``audio`` is ``(sample_rate, float32
        numpy array)`` on success and ``None`` on failure. Failures never
        raise; the message describes the problem.

    Side effects:
        Writes the result to ``output.wav`` in the working directory. The
        name is fixed, so overlapping calls overwrite each other's file.
    """
    try:
        if not text or not text.strip():
            return "Error: Text is required.", None

        reference_audio, error = _resolve_reference_audio(
            voice_choice, speaker_name, wav_file_choice, uploaded_file, recorded_audio
        )
        if error:
            return error, None

        # Initialize or update the TTS model
        tts = initialize_or_update_tts(model_name)

        synthesis_kwargs = {"text": text}
        if reference_audio:
            # Voice cloning from the reference clip
            synthesis_kwargs["speaker_wav"] = reference_audio
        else:
            # Predefined speaker of the model
            synthesis_kwargs["speaker"] = speaker_name
        if is_multilingual(model_name):
            synthesis_kwargs["language"] = language or "en"
        audio = tts.tts(**synthesis_kwargs)

        # Convert audio to a NumPy array
        audio_np = np.array(audio, dtype=np.float32)
        sample_rate = tts.synthesizer.output_sample_rate

        # Save the audio as a .wav file
        output_file = "output.wav"
        write_wav(output_file, sample_rate, audio_np)

        return "Audio generated successfully!", (sample_rate, audio_np)
    except Exception as e:
        return f"Error generating audio: {e}", None
|
150 |
+
|
151 |
+
# Gradio interface
|
152 |
+
def create_gradio_interface():
    """Build the Gradio Blocks UI for speech generation.

    The UI offers a text box, a model dropdown, a voice-type radio
    ("existing_speaker" / "voice_cloning"), the inputs for each voice type
    (speaker dropdown; clone-folder dropdown, upload and microphone), a
    button, and status/audio outputs. Which voice inputs are visible follows
    the voice type; the speaker list follows the selected model and is also
    filled once when the page loads, so the default "existing_speaker" mode
    is usable without first toggling anything.

    Returns:
        The assembled ``gr.Blocks`` app (not launched).
    """
    available_models = list_available_models()
    wav_files = list_wav_files()
    wav_file_choices = [f"{i}: {file}" for i, file in enumerate(wav_files)]

    with gr.Blocks() as demo:
        gr.Markdown("# TTS Streaming System")
        with gr.Row():
            text_input = gr.Textbox(label="Enter text to generate speech", lines=3)
        with gr.Row():
            model_name = gr.Dropdown(
                choices=available_models,
                label="Select TTS Model",
                value=available_models[0] if available_models else None,
            )
        with gr.Row():
            voice_choice = gr.Radio(
                choices=["existing_speaker", "voice_cloning"],
                label="Select voice type",
                value="existing_speaker",
            )
        with gr.Row():
            speaker_name = gr.Dropdown(
                label="Select a speaker",
                visible=True,
            )
            wav_file_choice = gr.Dropdown(
                choices=wav_file_choices,
                label="Select a .wav file for cloning",
                visible=False,
            )
            uploaded_file = gr.Audio(
                label="Upload your own .wav file for cloning",
                type="filepath",
                visible=False,
            )
            recorded_audio = gr.Microphone(
                label="Record your voice for cloning",
                type="filepath",
                visible=False,
            )
        with gr.Row():
            submit_button = gr.Button("Generate Speech")
        with gr.Row():
            output_text = gr.Textbox(label="Output", interactive=False)
            output_audio = gr.Audio(label="Generated Audio", type="numpy", visible=True)

        def update_components(choice, model_name, current_speaker=None):
            """Return updates for (speaker, clone file, upload, microphone)."""
            # A failing model load must not break the visibility toggling;
            # fall back to an empty speaker list and let generation report
            # the load error when the user submits.
            try:
                tts = initialize_or_update_tts(model_name)
                speakers = list(get_available_speakers(tts) or [])
            except Exception as e:
                print(f"Error loading model '{model_name}': {e}")
                speakers = []

            if choice == "existing_speaker":
                # Keep the current selection only if the (possibly new)
                # model still offers it; otherwise default to the first one.
                if current_speaker in speakers:
                    selected = current_speaker
                else:
                    selected = speakers[0] if speakers else None
                return (
                    gr.update(visible=True, choices=speakers, value=selected),  # speaker_name
                    gr.update(visible=False),  # wav_file_choice
                    gr.update(visible=False),  # uploaded_file
                    gr.update(visible=False),  # recorded_audio
                )
            if choice == "voice_cloning":
                return (
                    gr.update(visible=False),  # speaker_name
                    gr.update(visible=bool(wav_files)),  # wav_file_choice
                    gr.update(visible=True),  # uploaded_file
                    gr.update(visible=True),  # recorded_audio
                )
            return (
                gr.update(visible=False),  # speaker_name
                gr.update(visible=False),  # wav_file_choice
                gr.update(visible=False),  # uploaded_file
                gr.update(visible=False),  # recorded_audio
            )

        toggle_inputs = [voice_choice, model_name, speaker_name]
        toggle_outputs = [speaker_name, wav_file_choice, uploaded_file, recorded_audio]
        voice_choice.change(update_components, inputs=toggle_inputs, outputs=toggle_outputs)
        model_name.change(update_components, inputs=toggle_inputs, outputs=toggle_outputs)
        # Fill the speaker dropdown for the default model on page load.
        demo.load(update_components, inputs=toggle_inputs, outputs=toggle_outputs)

        # Enable concurrency for the submit button
        submit_button.click(
            generate_tts_audio,
            inputs=[text_input, model_name, voice_choice, speaker_name, wav_file_choice, uploaded_file, recorded_audio],
            outputs=[output_text, output_audio],
            concurrency_limit=10,  # Adjust this value based on your system's capabilities
        )

    return demo
|
233 |
+
|
234 |
+
# Launch Gradio interface
|
235 |
+
if __name__ == "__main__":
    # Build the UI and serve it; share=True also requests a public link.
    create_gradio_interface().launch(share=True)
|