Spaces:
Running
on
Zero
Running
on
Zero
Add v1.0.0 model support with KPipeline implementation
Browse files
- README.md +1 -0
- app.py +36 -16
- requirements.txt +5 -1
- tts_factory.py +22 -0
- tts_model_v1.py +168 -0
README.md
CHANGED
@@ -10,6 +10,7 @@ pinned: true
|
|
10 |
short_description: Accelerated Text-To-Speech on Kokoro-82M
|
11 |
models:
|
12 |
- hexgrad/kLegacy
|
|
|
13 |
---
|
14 |
|
15 |
# Kokoro TTS Demo Space
|
|
|
10 |
short_description: Accelerated Text-To-Speech on Kokoro-82M
|
11 |
models:
|
12 |
- hexgrad/kLegacy
|
13 |
+
- hexgrad/Kokoro-82M
|
14 |
---
|
15 |
|
16 |
# Kokoro TTS Demo Space
|
app.py
CHANGED
@@ -9,13 +9,13 @@ from lib import format_audio_output
|
|
9 |
from lib.ui_content import header_html, demo_text_info, styling
|
10 |
from lib.book_utils import get_available_books, get_book_info, get_chapter_text
|
11 |
from lib.text_utils import count_tokens
|
12 |
-
from
|
13 |
|
14 |
# Set HF_HOME for faster restarts with cached models/voices
|
15 |
os.environ["HF_HOME"] = "/data/.huggingface"
|
16 |
|
17 |
-
#
|
18 |
-
model =
|
19 |
|
20 |
# Configure logging
|
21 |
logging.basicConfig(level=logging.DEBUG)
|
@@ -24,21 +24,24 @@ logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
|
24 |
logger = logging.getLogger(__name__)
|
25 |
logger.debug("Starting app initialization...")
|
26 |
|
27 |
-
|
28 |
-
|
29 |
-
def initialize_model():
|
30 |
"""Initialize model and get voices"""
|
31 |
-
|
|
|
|
|
|
|
32 |
if not model.initialize():
|
33 |
raise gr.Error("Failed to initialize model")
|
34 |
-
|
35 |
-
voices = model.list_voices()
|
36 |
-
if not voices:
|
37 |
-
raise gr.Error("No voices found. Please check the voices directory.")
|
38 |
|
39 |
-
|
40 |
-
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
def update_progress(chunk_num, total_chunks, tokens_per_sec, rtf, progress_state, start_time, gpu_timeout, progress):
|
44 |
# Calculate time metrics
|
@@ -382,6 +385,14 @@ with gr.Blocks(title="Kokoro TTS Demo", css=styling) as demo:
|
|
382 |
)
|
383 |
|
384 |
with gr.Group():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
385 |
voice_dropdown = gr.Dropdown(
|
386 |
label="Voice(s)",
|
387 |
choices=[], # Start empty, will be populated after initialization
|
@@ -390,6 +401,15 @@ with gr.Blocks(title="Kokoro TTS Demo", css=styling) as demo:
|
|
390 |
multiselect=True
|
391 |
)
|
392 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
393 |
speed_slider = gr.Slider(
|
394 |
label="Speed",
|
395 |
minimum=0.5,
|
@@ -436,9 +456,9 @@ with gr.Blocks(title="Kokoro TTS Demo", css=styling) as demo:
|
|
436 |
with gr.Column():
|
437 |
gr.Markdown(demo_text_info)
|
438 |
|
439 |
-
# Initialize voices on load
|
440 |
demo.load(
|
441 |
-
fn=initialize_model,
|
442 |
outputs=[voice_dropdown]
|
443 |
)
|
444 |
|
|
|
9 |
from lib.ui_content import header_html, demo_text_info, styling
|
10 |
from lib.book_utils import get_available_books, get_book_info, get_chapter_text
|
11 |
from lib.text_utils import count_tokens
|
12 |
+
from tts_factory import TTSFactory
|
13 |
|
14 |
# Set HF_HOME for faster restarts with cached models/voices
|
15 |
os.environ["HF_HOME"] = "/data/.huggingface"
|
16 |
|
17 |
+
# Initialize model variable
|
18 |
+
model = None
|
19 |
|
20 |
# Configure logging
|
21 |
logging.basicConfig(level=logging.DEBUG)
|
|
|
24 |
logger = logging.getLogger(__name__)
|
25 |
logger.debug("Starting app initialization...")
|
26 |
|
27 |
+
def initialize_model(version="v0.19"):
    """Initialize the requested model version and return a voice dropdown update.

    Args:
        version: Model version passed to ``TTSFactory.create_model``.

    Returns:
        A ``gr.update`` carrying the available voices as choices and a default
        selection ('af_sky' when present, otherwise the first listed voice).

    Raises:
        gr.Error: If the model cannot be created or initialized, or if it
            reports no voices.
    """
    global model
    try:
        # Build and initialize a candidate first: the global is only replaced
        # once the new model is known to work, so a failed switch does not
        # leave the app holding a half-initialized model.
        candidate = TTSFactory.create_model(version)
        if not candidate.initialize():
            raise gr.Error("Failed to initialize model")

        voices = candidate.list_voices()
        if not voices:
            raise gr.Error("No voices found. Please check the voices directory.")

        model = candidate

        # voices is guaranteed non-empty here, so voices[0] is always valid.
        # NOTE(review): a scalar default is returned as before; confirm this is
        # what the target dropdown expects if it is configured as multiselect.
        default_voice = 'af_sky' if 'af_sky' in voices else voices[0]

        return gr.update(choices=voices, value=default_voice)
    except gr.Error:
        # Already a user-facing error; re-raise untouched so its message is
        # not wrapped a second time by the generic handler below.
        raise
    except Exception as e:
        raise gr.Error(f"Failed to initialize model: {str(e)}") from e
|
45 |
|
46 |
def update_progress(chunk_num, total_chunks, tokens_per_sec, rtf, progress_state, start_time, gpu_timeout, progress):
|
47 |
# Calculate time metrics
|
|
|
385 |
)
|
386 |
|
387 |
with gr.Group():
|
388 |
+
version_dropdown = gr.Dropdown(
|
389 |
+
label="Model Version",
|
390 |
+
choices=["v0.19", "v1.0.0"],
|
391 |
+
value="v0.19",
|
392 |
+
allow_custom_value=False,
|
393 |
+
multiselect=False
|
394 |
+
)
|
395 |
+
|
396 |
voice_dropdown = gr.Dropdown(
|
397 |
label="Voice(s)",
|
398 |
choices=[], # Start empty, will be populated after initialization
|
|
|
401 |
multiselect=True
|
402 |
)
|
403 |
|
404 |
+
def on_version_change(version):
    """Re-run model initialization for the selected version and return its result."""
    dropdown_update = initialize_model(version)
    return dropdown_update
|
406 |
+
|
407 |
+
version_dropdown.change(
|
408 |
+
fn=on_version_change,
|
409 |
+
inputs=[version_dropdown],
|
410 |
+
outputs=[voice_dropdown]
|
411 |
+
)
|
412 |
+
|
413 |
speed_slider = gr.Slider(
|
414 |
label="Speed",
|
415 |
minimum=0.5,
|
|
|
456 |
with gr.Column():
|
457 |
gr.Markdown(demo_text_info)
|
458 |
|
459 |
+
# Initialize voices on load with default version
|
460 |
demo.load(
|
461 |
+
fn=lambda: initialize_model("v0.19"),
|
462 |
outputs=[voice_dropdown]
|
463 |
)
|
464 |
|
requirements.txt
CHANGED
@@ -9,4 +9,8 @@ regex==2024.11.6
|
|
9 |
tiktoken==0.8.0
|
10 |
transformers==4.47.1
|
11 |
munch==4.0.0
|
12 |
-
matplotlib==3.4.3
|
|
|
|
|
|
|
|
|
|
9 |
tiktoken==0.8.0
|
10 |
transformers==4.47.1
|
11 |
munch==4.0.0
|
12 |
+
matplotlib==3.4.3
|
13 |
+
|
14 |
+
# v1.0.0 dependencies
|
15 |
+
kokoro>=1.0.0
|
16 |
+
misaki[en]>=0.1.0
|
tts_factory.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tts_model import TTSModel
|
2 |
+
from tts_model_v1 import TTSModelV1
|
3 |
+
|
4 |
+
class TTSFactory:
    """Factory that builds the TTS model matching a requested version."""

    @staticmethod
    def create_model(version="v0.19"):
        """Return a new model instance for the given version.

        Args:
            version: Model version to use ("v0.19" or "v1.0.0")

        Returns:
            TTSModel instance for "v0.19", TTSModelV1 instance for "v1.0.0"

        Raises:
            ValueError: If version is not one of the supported values
        """
        # Guard-clause form: each supported version returns immediately,
        # anything else falls through to the error.
        if version == "v1.0.0":
            return TTSModelV1()
        if version == "v0.19":
            return TTSModel()
        raise ValueError(f"Unsupported version: {version}")
|
tts_model_v1.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import time
|
5 |
+
from typing import Tuple, List
|
6 |
+
import soundfile as sf
|
7 |
+
from kokoro import KPipeline
|
8 |
+
import spaces
|
9 |
+
|
10 |
+
class TTSModelV1:
    """KPipeline-based TTS model for v1.0.0"""

    # Sample rate (samples per second) used to convert audio length to seconds.
    SAMPLE_RATE = 24000

    def __init__(self):
        self.pipeline = None  # created by initialize()
        self.voices_dir = "voices"
        self.model_repo = "hexgrad/Kokoro-82M"

    def _voices_path(self) -> str:
        """Return the directory that holds the local ``<voice>.pt`` files."""
        return os.path.join(self.voices_dir, "voices")

    def initialize(self) -> bool:
        """Initialize KPipeline and verify voices

        Returns:
            True when the pipeline was created and the voices directory exists,
            False when any step raised (the error is printed, not propagated).
        """
        try:
            print("Initializing v1.0.0 model...")

            # Initialize KPipeline with American English
            self.pipeline = KPipeline(lang_code='a')

            # Verify local voice files are available
            if not os.path.isdir(self._voices_path()):
                raise ValueError("Voice files not found")

            # An empty voices directory is reported but is not fatal here
            available_voices = self.list_voices()
            if not available_voices:
                print("Warning: No voices found after initialization")
            else:
                print(f"Found {len(available_voices)} voices")

            print("Model initialization complete")
            return True

        except Exception as e:
            print(f"Error initializing model: {str(e)}")
            return False

    def list_voices(self) -> List[str]:
        """List available voices

        Returns:
            Names of the ``.pt`` files in the voices directory with the
            extension stripped, sorted so the order (and therefore any
            "first voice" default chosen by a caller) is deterministic.
            Empty when the directory does not exist.
        """
        voices_subdir = self._voices_path()
        if not os.path.isdir(voices_subdir):
            return []
        return sorted(
            file[:-3]
            for file in os.listdir(voices_subdir)
            if file.endswith(".pt")
        )

    def _load_voicepack(self, voice: str):
        """Load one voice tensor from the local voices directory."""
        voice_path = os.path.join(self._voices_path(), f"{voice}.pt")
        try:
            return torch.load(voice_path, weights_only=True)
        except Exception as e:
            print(f"Warning: weights_only load failed, attempting full load: {str(e)}")
            # SECURITY: weights_only=False unpickles arbitrary objects and can
            # execute code embedded in the file. Only acceptable while every
            # file in the voices directory comes from a trusted source.
            return torch.load(voice_path, weights_only=False)

    def _mix_voicepacks(self, voice_names):
        """Return the element-wise mean of every voice that loads successfully."""
        t_voices = []
        for voice in voice_names:
            try:
                t_voices.append(self._load_voicepack(voice))
            except Exception as e:
                print(f"Warning: Failed to load voice {voice}: {str(e)}")

        if not t_voices:
            # torch.stack([]) would fail with an unrelated-looking message
            raise ValueError("None of the requested voices could be loaded")

        # Combine voices by taking mean
        return torch.mean(torch.stack(t_voices), dim=0)

    @spaces.GPU(duration=None)  # Duration will be set by the UI
    def generate_speech(self, text: str, voice_names: list[str], speed: float = 1.0, gpu_timeout: int = 60, progress_callback=None, progress_state=None, progress=None) -> Tuple[np.ndarray, float, dict]:
        """Generate speech from text using KPipeline

        Args:
            text: Input text to convert to speech
            voice_names: List of voice names to use (will be mixed if multiple);
                a single name given as a plain string is also accepted
            speed: Speech speed multiplier
            gpu_timeout: Passed through unchanged to progress_callback
            progress_callback: Optional callback function
            progress_state: Dictionary tracking generation progress metrics
            progress: Progress callback from Gradio

        Returns:
            Tuple of (audio, duration in seconds, metrics dict). The metrics
            dict has the keys "chunk_times", "chunk_sizes", "tokens_per_sec",
            "rtf", "total_tokens" and "total_time".

        Raises:
            ValueError: If text or voice_names is empty, if no requested voice
                could be loaded for mixing, or if no audio was produced.
            RuntimeError: If initialize() has not created the pipeline yet.
        """
        mixed_voice_path = None  # set only when a temporary mixed voice is written
        try:
            start_time = time.time()

            if not text or not voice_names:
                raise ValueError("Text and voice name are required")

            if self.pipeline is None:
                raise RuntimeError("Model is not initialized; call initialize() first")

            # A bare string would otherwise be indexed to its first character
            if isinstance(voice_names, str):
                voice_names = [voice_names]
            else:
                voice_names = list(voice_names)

            # Handle voice mixing
            if len(voice_names) > 1:
                voicepack = self._mix_voicepacks(voice_names)
                voice_name = "_".join(voice_names)
                # Save mixed voice temporarily; the path is recorded before
                # saving so the finally block also removes a partial file.
                mixed_voice_path = os.path.join(self._voices_path(), f"{voice_name}.pt")
                torch.save(voicepack, mixed_voice_path)
                # NOTE(review): the pipeline is given the mixed voice *name*,
                # not the saved path - confirm KPipeline resolves that name to
                # the file written above.
            else:
                voice_name = voice_names[0]

            # Generate speech using KPipeline
            generator = self.pipeline(
                text,
                voice=voice_name,
                speed=speed,
                split_pattern=r'\n+'  # Default chunking pattern
            )

            # Process chunks and collect metrics
            audio_chunks = []
            chunk_times = []
            chunk_sizes = []
            total_tokens = 0

            # The timer must start before the generator is advanced: starting
            # it inside the loop body only measures a list append (about 0s).
            # Assumes the pipeline yields chunks lazily - TODO confirm.
            chunk_start = time.time()
            for i, (gs, ps, audio) in enumerate(generator):
                chunk_time = time.time() - chunk_start

                # Store chunk audio
                audio_chunks.append(audio)

                # Calculate metrics
                chunk_times.append(chunk_time)
                chunk_sizes.append(len(gs))  # Use grapheme length as chunk size
                total_tokens += len(gs)

                # Update progress if callback provided
                if progress_callback:
                    chunk_duration = len(audio) / self.SAMPLE_RATE
                    # Guard both ratios: an empty chunk or a timer below clock
                    # resolution must not abort generation with a division error.
                    rtf = chunk_time / chunk_duration if chunk_duration > 0 else 0.0
                    tokens_per_sec = len(gs) / chunk_time if chunk_time > 0 else 0.0
                    progress_callback(
                        i + 1,
                        -1,  # Total chunks unknown with generator
                        tokens_per_sec,
                        rtf,
                        progress_state,
                        start_time,
                        gpu_timeout,
                        progress
                    )

                print(f"Chunk {i+1} processed in {chunk_time:.2f}s")
                print(f"Graphemes: {gs}")
                print(f"Phonemes: {ps}")

                # Restart the timer after reporting so callback and logging
                # time is not attributed to the next chunk.
                chunk_start = time.time()

            if not audio_chunks:
                raise ValueError("No audio was generated for the given text")

            # Concatenate audio chunks
            audio = np.concatenate(audio_chunks)

            # Return audio and metrics
            return (
                audio,
                len(audio) / self.SAMPLE_RATE,
                {
                    "chunk_times": chunk_times,
                    "chunk_sizes": chunk_sizes,
                    "tokens_per_sec": [float(x) for x in progress_state["tokens_per_sec"]] if progress_state else [],
                    "rtf": [float(x) for x in progress_state["rtf"]] if progress_state else [],
                    "total_tokens": total_tokens,
                    "total_time": time.time() - start_time
                }
            )

        except Exception as e:
            print(f"Error generating speech: {str(e)}")
            raise
        finally:
            # Cleanup temporary mixed voice if created. Runs on failure too,
            # so a leftover file never shows up as a selectable voice.
            if mixed_voice_path is not None:
                try:
                    os.remove(mixed_voice_path)
                except FileNotFoundError:
                    pass  # never written, or already removed
                except OSError as cleanup_error:
                    print(f"Warning: Failed to remove temporary voice {mixed_voice_path}: {str(cleanup_error)}")
|