Adoetz committed on
Commit
af17061
·
verified ·
1 Parent(s): 58c996f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +237 -0
app.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from TTS.api import TTS
2
+ import numpy as np
3
+ import torch
4
+ import os
5
+ import gradio as gr
6
+ from scipy.io.wavfile import write as write_wav
7
+
8
+ # Check if GPU is available
9
# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Module-level slots for the currently loaded TTS model and its model name.
global_tts = None
current_model_name = None
17
+
18
+ # Function to list available TTS models
19
def list_available_models():
    """Return whatever the TTS model manager lists as available models.

    A bare ``TTS()`` instance is created only to reach ``list_models()``;
    the object it returns is then asked for its own ``list_models()``.
    """
    manager = TTS().list_models()
    return manager.list_models()
23
+
24
+ # Function to check if a model is multilingual
25
def is_multilingual(model_name):
    """Tell whether *model_name* looks like a multilingual model.

    The decision is a case-insensitive substring match: the name must
    contain "multilingual" or "xtts".
    """
    lowered = model_name.lower()
    return any(marker in lowered for marker in ("multilingual", "xtts"))
27
+
28
+ # Function to fetch available speakers from the model
29
def get_available_speakers(tts):
    """Return the speaker names exposed by a loaded TTS object, or None.

    The speaker manager is looked up first on ``tts.synthesizer`` and then on
    ``tts.synthesizer.tts_model``; the first truthy manager found supplies
    ``speaker_names``. Checking only the synthesizer misses models that keep
    the manager on the wrapped model object.

    Args:
        tts: Object with a ``synthesizer`` attribute (a loaded ``TTS``).

    Returns:
        The manager's ``speaker_names``, or ``None`` when no speaker manager
        is found or any error occurs while looking (callers then fall back
        to voice cloning).
    """
    try:
        synthesizer = tts.synthesizer
        # Order matters: keep the original lookup location first so models
        # that already worked keep returning the same names.
        owners = (synthesizer, getattr(synthesizer, "tts_model", None))
        for owner in owners:
            manager = getattr(owner, "speaker_manager", None)
            if manager:
                return manager.speaker_names

        print("Warning: No speaker manager found in the model. Using voice cloning only.")
        return None  # No pre-defined speakers
    except Exception as e:
        # Best-effort lookup: a broken model object must not crash the UI.
        print(f"Error fetching speakers: {e}")
        return None  # Fallback to voice cloning
40
+
41
+ # Function to list .wav files in the /clone/ folder
42
def list_wav_files():
    """Return the sorted names of the .wav files in the local ``clone`` folder.

    The result is sorted because callers address files by their position in
    this list across separate calls, and ``os.listdir`` makes no ordering
    guarantee. The extension match is case-insensitive and only regular
    files are returned.

    Returns:
        A list of file names (not paths). Empty when the folder is missing,
        is not a directory, or holds no .wav files; a message is printed in
        each of those cases.
    """
    clone_folder = "clone"
    # isdir (not exists): a plain file named "clone" would make listdir raise.
    if not os.path.isdir(clone_folder):
        print(f"Error: Folder '{clone_folder}' not found.")
        return []

    wav_files = sorted(
        entry
        for entry in os.listdir(clone_folder)
        if entry.lower().endswith(".wav")
        and os.path.isfile(os.path.join(clone_folder, entry))
    )
    if not wav_files:
        print(f"No .wav files found in '{clone_folder}'.")

    return wav_files
54
+
55
+ # Function to initialize or update the TTS model
56
def initialize_or_update_tts(model_name):
    """Return the cached TTS model, loading *model_name* if it is not current.

    The model is rebuilt when nothing is cached yet or when *model_name*
    differs from ``current_model_name``. The new model is constructed and
    moved to ``device`` in a local variable first; ``global_tts`` and
    ``current_model_name`` are only updated together once that succeeded, so
    a failed load or device move can never leave the cache holding one model
    under another model's name.

    Args:
        model_name: Model identifier passed to ``TTS(model_name=...)``.

    Returns:
        The loaded ``TTS`` object (also stored in ``global_tts``).

    Raises:
        Whatever the second ``TTS(...)`` attempt or ``.to(device)`` raises;
        the previously cached model is left untouched in that case.
    """
    global global_tts, current_model_name
    if global_tts is not None and model_name == current_model_name:
        return global_tts

    print(f"Loading model: {model_name}")
    try:
        # Try loading the model with espeak
        candidate = TTS(model_name=model_name, progress_bar=True)
    except Exception as e:
        print(f"Error loading model with espeak: {e}")
        print("Falling back to gruut phonemizer...")
        # NOTE(review): the retry passes the same arguments as the first
        # attempt, and it is unclear from this file whether assigning
        # synthesizer.phonemizer afterwards changes phonemization - confirm.
        candidate = TTS(model_name=model_name, progress_bar=True)
        if hasattr(candidate.synthesizer, 'phonemizer'):
            candidate.synthesizer.phonemizer = "gruut"

    candidate.to(device)

    # Commit both cache fields only after every step above succeeded.
    global_tts = candidate
    current_model_name = model_name
    return global_tts
74
+
75
+ # Function to generate TTS audio
76
def _resolve_reference_audio(voice_choice, speaker_name, wav_file_choice, uploaded_file, recorded_audio):
    """Validate the voice inputs and pick the reference clip for cloning.

    Returns:
        ``(reference_audio, error)``. ``error`` is a user-facing message when
        the inputs are invalid, else ``None``. ``reference_audio`` is a file
        path for voice cloning and ``None`` for an existing speaker.
    """
    if voice_choice == "existing_speaker":
        if not speaker_name:
            return None, "Error: Speaker name is required for existing speaker."
        return None, None
    if voice_choice != "voice_cloning":
        return None, "Error: Invalid voice choice."

    # Priority for cloning: fresh recording, then upload, then clone folder.
    if recorded_audio:
        return recorded_audio, None
    if uploaded_file:
        return uploaded_file, None
    if not wav_file_choice:
        return None, "Error: No reference audio provided for voice cloning."

    wav_files = list_wav_files()
    if not wav_files:
        return None, "Error: No .wav files found for voice cloning."

    # Dropdown labels have the form "<index>: <file name>".
    try:
        index_text, _, chosen_name = wav_file_choice.partition(":")
        wav_file_index = int(index_text.strip())
    except (ValueError, AttributeError):
        return None, "Error: Invalid .wav file choice."

    chosen_name = chosen_name.strip()
    if chosen_name:
        # Trust the file name over the index: the folder is listed again
        # here, so the position may differ from when the label was built.
        if chosen_name not in wav_files:
            return None, "Error: Invalid .wav file choice."
        return os.path.join("clone", chosen_name), None

    if wav_file_index < 0 or wav_file_index >= len(wav_files):
        return None, "Error: Invalid .wav file index."
    return os.path.join("clone", wav_files[wav_file_index]), None


def generate_tts_audio(text, model_name, voice_choice, speaker_name=None, wav_file_choice=None, uploaded_file=None, recorded_audio=None, language="en"):
    """Synthesize *text* and return a status message plus the audio.

    Inputs are validated before the model is loaded, so an invalid request
    does not trigger a model load.

    Args:
        text: Text to speak; must contain non-whitespace characters.
        model_name: Model identifier handed to ``initialize_or_update_tts``.
        voice_choice: ``"existing_speaker"`` or ``"voice_cloning"``.
        speaker_name: Speaker to use for ``"existing_speaker"``.
        wav_file_choice: Clone-folder label of the form ``"<index>: <name>"``.
        uploaded_file: Path of an uploaded reference clip.
        recorded_audio: Path of a recorded reference clip.
        language: Language code passed to the model when ``is_multilingual``
            reports the model as multilingual. Defaults to ``"en"``.

    Returns:
        ``(message, audio)`` where ``audio`` is ``(sample_rate, float32
        ndarray)`` on success and ``None`` on any error. Errors are reported
        in ``message`` rather than raised.
    """
    try:
        if not text or not text.strip():
            return "Error: Text is required.", None

        reference_audio, error = _resolve_reference_audio(
            voice_choice, speaker_name, wav_file_choice, uploaded_file, recorded_audio
        )
        if error:
            return error, None

        # Initialize or update the TTS model
        tts = initialize_or_update_tts(model_name)

        # Build one argument set instead of four near-identical calls.
        synthesis_args = {"text": text}
        if reference_audio:
            synthesis_args["speaker_wav"] = reference_audio  # voice cloning
        else:
            synthesis_args["speaker"] = speaker_name  # existing speaker
        if is_multilingual(model_name):
            synthesis_args["language"] = language
        audio = tts.tts(**synthesis_args)

        # Convert audio to a NumPy array
        audio_np = np.array(audio, dtype=np.float32)
        sample_rate = tts.synthesizer.output_sample_rate

        # Save the audio as a .wav file
        # NOTE(review): fixed file name - overlapping requests overwrite it.
        output_file = "output.wav"
        write_wav(output_file, sample_rate, audio_np)

        return "Audio generated successfully!", (sample_rate, audio_np)
    except Exception as e:
        # UI boundary: surface the failure as a message instead of raising.
        return f"Error generating audio: {e}", None
150
+
151
+ # Gradio interface
152
def create_gradio_interface():
    """Build and return the Gradio Blocks UI for the TTS app.

    The layout has a text box, a model dropdown, a voice-type radio, the
    speaker / clone-file / upload / microphone inputs, a submit button and
    the status and audio outputs. Which voice inputs are visible follows the
    radio selection. The speaker list is refreshed when the radio or the
    model changes and also once when the page loads, so the default
    "existing_speaker" mode starts with a populated speaker dropdown.

    Returns:
        The ``gr.Blocks`` instance (not yet launched).
    """
    available_models = list_available_models()
    wav_files = list_wav_files()
    wav_file_choices = [f"{i}: {file}" for i, file in enumerate(wav_files)]

    with gr.Blocks() as demo:
        gr.Markdown("# TTS Streaming System")
        with gr.Row():
            text_input = gr.Textbox(label="Enter text to generate speech", lines=3)
        with gr.Row():
            model_name = gr.Dropdown(
                choices=available_models,
                label="Select TTS Model",
                value=available_models[0] if available_models else None,
            )
        with gr.Row():
            voice_choice = gr.Radio(
                choices=["existing_speaker", "voice_cloning"],
                label="Select voice type",
                value="existing_speaker",
            )
        with gr.Row():
            speaker_name = gr.Dropdown(
                label="Select a speaker",
                visible=True,
            )
            wav_file_choice = gr.Dropdown(
                choices=wav_file_choices,
                label="Select a .wav file for cloning",
                visible=False,
            )
            uploaded_file = gr.Audio(
                label="Upload your own .wav file for cloning",
                type="filepath",
                visible=False,
            )
            recorded_audio = gr.Microphone(
                label="Record your voice for cloning",
                type="filepath",
                visible=False,
            )
        with gr.Row():
            submit_button = gr.Button("Generate Speech")
        with gr.Row():
            output_text = gr.Textbox(label="Output", interactive=False)
            output_audio = gr.Audio(label="Generated Audio", type="numpy", visible=True)

        def update_components(choice, model_name):
            """Return updates for (speaker, clone file, upload, microphone)."""
            tts = initialize_or_update_tts(model_name)
            available_speakers = get_available_speakers(tts)

            cloning = choice == "voice_cloning"
            if choice == "existing_speaker":
                speaker_update = gr.update(
                    visible=True,
                    choices=available_speakers if available_speakers else [],
                )
            else:
                speaker_update = gr.update(visible=False)

            return (
                speaker_update,  # speaker_name
                # The clone-file dropdown is only useful if files exist.
                gr.update(visible=cloning and bool(wav_files)),  # wav_file_choice
                gr.update(visible=cloning),  # uploaded_file
                gr.update(visible=cloning),  # recorded_audio
            )

        voice_inputs = [voice_choice, model_name]
        voice_outputs = [speaker_name, wav_file_choice, uploaded_file, recorded_audio]

        voice_choice.change(update_components, inputs=voice_inputs, outputs=voice_outputs)
        model_name.change(update_components, inputs=voice_inputs, outputs=voice_outputs)
        # Without this, the speaker dropdown stays empty until the user
        # changes the radio or the model, because nothing else fills it.
        demo.load(update_components, inputs=voice_inputs, outputs=voice_outputs)

        # Enable concurrency for the submit button
        submit_button.click(
            generate_tts_audio,
            inputs=[text_input, model_name, voice_choice, speaker_name, wav_file_choice, uploaded_file, recorded_audio],
            outputs=[output_text, output_audio],
            concurrency_limit=10,  # Adjust this value based on your system's capabilities
        )

    return demo
233
+
234
+ # Launch Gradio interface
235
if __name__ == "__main__":
    # Build the UI, then serve it; share=True requests a public link.
    launch_options = {"share": True}
    demo = create_gradio_interface()
    demo.launch(**launch_options)