Remsky commited on
Commit
27f8803
·
1 Parent(s): 3a912ba

Add v1.0.0 model support with KPipeline implementation

Browse files
Files changed (5) hide show
  1. README.md +1 -0
  2. app.py +36 -16
  3. requirements.txt +5 -1
  4. tts_factory.py +22 -0
  5. tts_model_v1.py +168 -0
README.md CHANGED
@@ -10,6 +10,7 @@ pinned: true
10
  short_description: Accelerated Text-To-Speech on Kokoro-82M
11
  models:
12
  - hexgrad/kLegacy
 
13
  ---
14
 
15
  # Kokoro TTS Demo Space
 
10
  short_description: Accelerated Text-To-Speech on Kokoro-82M
11
  models:
12
  - hexgrad/kLegacy
13
+ - hexgrad/Kokoro-82M
14
  ---
15
 
16
  # Kokoro TTS Demo Space
app.py CHANGED
@@ -9,13 +9,13 @@ from lib import format_audio_output
9
  from lib.ui_content import header_html, demo_text_info, styling
10
  from lib.book_utils import get_available_books, get_book_info, get_chapter_text
11
  from lib.text_utils import count_tokens
12
- from tts_model import TTSModel
13
 
14
  # Set HF_HOME for faster restarts with cached models/voices
15
  os.environ["HF_HOME"] = "/data/.huggingface"
16
 
17
- # Create TTS model instance
18
- model = TTSModel()
19
 
20
  # Configure logging
21
  logging.basicConfig(level=logging.DEBUG)
@@ -24,21 +24,24 @@ logging.getLogger('matplotlib').setLevel(logging.WARNING)
24
  logger = logging.getLogger(__name__)
25
  logger.debug("Starting app initialization...")
26
 
27
- model = TTSModel()
28
-
29
- def initialize_model():
30
  """Initialize model and get voices"""
31
- if model.model is None:
 
 
 
32
  if not model.initialize():
33
  raise gr.Error("Failed to initialize model")
34
-
35
- voices = model.list_voices()
36
- if not voices:
37
- raise gr.Error("No voices found. Please check the voices directory.")
38
 
39
- default_voice = 'af_sky' if 'af_sky' in voices else voices[0] if voices else None
40
-
41
- return gr.update(choices=voices, value=default_voice)
 
 
 
 
 
 
42
 
43
  def update_progress(chunk_num, total_chunks, tokens_per_sec, rtf, progress_state, start_time, gpu_timeout, progress):
44
  # Calculate time metrics
@@ -382,6 +385,14 @@ with gr.Blocks(title="Kokoro TTS Demo", css=styling) as demo:
382
  )
383
 
384
  with gr.Group():
 
 
 
 
 
 
 
 
385
  voice_dropdown = gr.Dropdown(
386
  label="Voice(s)",
387
  choices=[], # Start empty, will be populated after initialization
@@ -390,6 +401,15 @@ with gr.Blocks(title="Kokoro TTS Demo", css=styling) as demo:
390
  multiselect=True
391
  )
392
 
 
 
 
 
 
 
 
 
 
393
  speed_slider = gr.Slider(
394
  label="Speed",
395
  minimum=0.5,
@@ -436,9 +456,9 @@ with gr.Blocks(title="Kokoro TTS Demo", css=styling) as demo:
436
  with gr.Column():
437
  gr.Markdown(demo_text_info)
438
 
439
- # Initialize voices on load
440
  demo.load(
441
- fn=initialize_model,
442
  outputs=[voice_dropdown]
443
  )
444
 
 
9
  from lib.ui_content import header_html, demo_text_info, styling
10
  from lib.book_utils import get_available_books, get_book_info, get_chapter_text
11
  from lib.text_utils import count_tokens
12
+ from tts_factory import TTSFactory
13
 
14
  # Set HF_HOME for faster restarts with cached models/voices
15
  os.environ["HF_HOME"] = "/data/.huggingface"
16
 
17
+ # Initialize model variable
18
+ model = None
19
 
20
  # Configure logging
21
  logging.basicConfig(level=logging.DEBUG)
 
24
  logger = logging.getLogger(__name__)
25
  logger.debug("Starting app initialization...")
26
 
27
+ def initialize_model(version="v0.19"):
 
 
28
  """Initialize model and get voices"""
29
+ global model
30
+ try:
31
+ # Create model instance using factory
32
+ model = TTSFactory.create_model(version)
33
  if not model.initialize():
34
  raise gr.Error("Failed to initialize model")
 
 
 
 
35
 
36
+ voices = model.list_voices()
37
+ if not voices:
38
+ raise gr.Error("No voices found. Please check the voices directory.")
39
+
40
+ default_voice = 'af_sky' if 'af_sky' in voices else voices[0] if voices else None
41
+
42
+ return gr.update(choices=voices, value=default_voice)
43
+ except Exception as e:
44
+ raise gr.Error(f"Failed to initialize model: {str(e)}")
45
 
46
  def update_progress(chunk_num, total_chunks, tokens_per_sec, rtf, progress_state, start_time, gpu_timeout, progress):
47
  # Calculate time metrics
 
385
  )
386
 
387
  with gr.Group():
388
+ version_dropdown = gr.Dropdown(
389
+ label="Model Version",
390
+ choices=["v0.19", "v1.0.0"],
391
+ value="v0.19",
392
+ allow_custom_value=False,
393
+ multiselect=False
394
+ )
395
+
396
  voice_dropdown = gr.Dropdown(
397
  label="Voice(s)",
398
  choices=[], # Start empty, will be populated after initialization
 
401
  multiselect=True
402
  )
403
 
404
+ def on_version_change(version):
405
+ return initialize_model(version)
406
+
407
+ version_dropdown.change(
408
+ fn=on_version_change,
409
+ inputs=[version_dropdown],
410
+ outputs=[voice_dropdown]
411
+ )
412
+
413
  speed_slider = gr.Slider(
414
  label="Speed",
415
  minimum=0.5,
 
456
  with gr.Column():
457
  gr.Markdown(demo_text_info)
458
 
459
+ # Initialize voices on load with default version
460
  demo.load(
461
+ fn=lambda: initialize_model("v0.19"),
462
  outputs=[voice_dropdown]
463
  )
464
 
requirements.txt CHANGED
@@ -9,4 +9,8 @@ regex==2024.11.6
9
  tiktoken==0.8.0
10
  transformers==4.47.1
11
  munch==4.0.0
12
- matplotlib==3.4.3
 
 
 
 
 
9
  tiktoken==0.8.0
10
  transformers==4.47.1
11
  munch==4.0.0
12
+ matplotlib==3.4.3
13
+
14
+ # v1.0.0 dependencies
15
+ kokoro>=1.0.0
16
+ misaki[en]>=0.1.0
tts_factory.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tts_model import TTSModel
2
+ from tts_model_v1 import TTSModelV1
3
+
4
+ class TTSFactory:
5
+ """Factory class to create appropriate TTS model version"""
6
+
7
+ @staticmethod
8
+ def create_model(version="v0.19"):
9
+ """Create TTS model instance for specified version
10
+
11
+ Args:
12
+ version: Model version to use ("v0.19" or "v1.0.0")
13
+
14
+ Returns:
15
+ TTSModel or TTSModelV1 instance
16
+ """
17
+ if version == "v0.19":
18
+ return TTSModel()
19
+ elif version == "v1.0.0":
20
+ return TTSModelV1()
21
+ else:
22
+ raise ValueError(f"Unsupported version: {version}")
tts_model_v1.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import time
5
+ from typing import Tuple, List
6
+ import soundfile as sf
7
+ from kokoro import KPipeline
8
+ import spaces
9
+
10
+ class TTSModelV1:
11
+ """KPipeline-based TTS model for v1.0.0"""
12
+
13
+ def __init__(self):
14
+ self.pipeline = None
15
+ self.voices_dir = "voices"
16
+ self.model_repo = "hexgrad/Kokoro-82M"
17
+
18
+ def initialize(self) -> bool:
19
+ """Initialize KPipeline and verify voices"""
20
+ try:
21
+ print("Initializing v1.0.0 model...")
22
+
23
+ # Initialize KPipeline with American English
24
+ self.pipeline = KPipeline(lang_code='a')
25
+
26
+ # Verify local voice files are available
27
+ voices_dir = os.path.join(self.voices_dir, "voices")
28
+ if not os.path.exists(voices_dir):
29
+ raise ValueError("Voice files not found")
30
+
31
+ # Verify voices were downloaded successfully
32
+ available_voices = self.list_voices()
33
+ if not available_voices:
34
+ print("Warning: No voices found after initialization")
35
+ else:
36
+ print(f"Found {len(available_voices)} voices")
37
+
38
+ print("Model initialization complete")
39
+ return True
40
+
41
+ except Exception as e:
42
+ print(f"Error initializing model: {str(e)}")
43
+ return False
44
+
45
+ def list_voices(self) -> List[str]:
46
+ """List available voices"""
47
+ voices = []
48
+ voices_subdir = os.path.join(self.voices_dir, "voices")
49
+ if os.path.exists(voices_subdir):
50
+ for file in os.listdir(voices_subdir):
51
+ if file.endswith(".pt"):
52
+ voice_name = file[:-3]
53
+ voices.append(voice_name)
54
+ return voices
55
+
56
+ @spaces.GPU(duration=None) # Duration will be set by the UI
57
+ def generate_speech(self, text: str, voice_names: list[str], speed: float = 1.0, gpu_timeout: int = 60, progress_callback=None, progress_state=None, progress=None) -> Tuple[np.ndarray, float]:
58
+ """Generate speech from text using KPipeline
59
+
60
+ Args:
61
+ text: Input text to convert to speech
62
+ voice_names: List of voice names to use (will be mixed if multiple)
63
+ speed: Speech speed multiplier
64
+ progress_callback: Optional callback function
65
+ progress_state: Dictionary tracking generation progress metrics
66
+ progress: Progress callback from Gradio
67
+ """
68
+ try:
69
+ start_time = time.time()
70
+
71
+ if not text or not voice_names:
72
+ raise ValueError("Text and voice name are required")
73
+
74
+ # Handle voice mixing
75
+ if isinstance(voice_names, list) and len(voice_names) > 1:
76
+ t_voices = []
77
+ for voice in voice_names:
78
+ try:
79
+ voice_path = os.path.join(self.voices_dir, "voices", f"{voice}.pt")
80
+ try:
81
+ voicepack = torch.load(voice_path, weights_only=True)
82
+ except Exception as e:
83
+ print(f"Warning: weights_only load failed, attempting full load: {str(e)}")
84
+ voicepack = torch.load(voice_path, weights_only=False)
85
+ t_voices.append(voicepack)
86
+ except Exception as e:
87
+ print(f"Warning: Failed to load voice {voice}: {str(e)}")
88
+
89
+ # Combine voices by taking mean
90
+ voicepack = torch.mean(torch.stack(t_voices), dim=0)
91
+ voice_name = "_".join(voice_names)
92
+ # Save mixed voice temporarily
93
+ mixed_voice_path = os.path.join(self.voices_dir, "voices", f"{voice_name}.pt")
94
+ torch.save(voicepack, mixed_voice_path)
95
+ else:
96
+ voice_name = voice_names[0]
97
+
98
+ # Generate speech using KPipeline
99
+ generator = self.pipeline(
100
+ text,
101
+ voice=voice_name,
102
+ speed=speed,
103
+ split_pattern=r'\n+' # Default chunking pattern
104
+ )
105
+
106
+ # Process chunks and collect metrics
107
+ audio_chunks = []
108
+ chunk_times = []
109
+ chunk_sizes = []
110
+ total_tokens = 0
111
+
112
+ for i, (gs, ps, audio) in enumerate(generator):
113
+ chunk_start = time.time()
114
+
115
+ # Store chunk audio
116
+ audio_chunks.append(audio)
117
+
118
+ # Calculate metrics
119
+ chunk_time = time.time() - chunk_start
120
+ chunk_times.append(chunk_time)
121
+ chunk_sizes.append(len(gs)) # Use grapheme length as chunk size
122
+
123
+ # Update progress if callback provided
124
+ if progress_callback:
125
+ chunk_duration = len(audio) / 24000
126
+ rtf = chunk_time / chunk_duration
127
+ progress_callback(
128
+ i + 1,
129
+ -1, # Total chunks unknown with generator
130
+ len(gs) / chunk_time, # tokens/sec
131
+ rtf,
132
+ progress_state,
133
+ start_time,
134
+ gpu_timeout,
135
+ progress
136
+ )
137
+
138
+ print(f"Chunk {i+1} processed in {chunk_time:.2f}s")
139
+ print(f"Graphemes: {gs}")
140
+ print(f"Phonemes: {ps}")
141
+
142
+ # Concatenate audio chunks
143
+ audio = np.concatenate(audio_chunks)
144
+
145
+ # Cleanup temporary mixed voice if created
146
+ if len(voice_names) > 1:
147
+ try:
148
+ os.remove(mixed_voice_path)
149
+ except:
150
+ pass
151
+
152
+ # Return audio and metrics
153
+ return (
154
+ audio,
155
+ len(audio) / 24000,
156
+ {
157
+ "chunk_times": chunk_times,
158
+ "chunk_sizes": chunk_sizes,
159
+ "tokens_per_sec": [float(x) for x in progress_state["tokens_per_sec"]] if progress_state else [],
160
+ "rtf": [float(x) for x in progress_state["rtf"]] if progress_state else [],
161
+ "total_tokens": total_tokens,
162
+ "total_time": time.time() - start_time
163
+ }
164
+ )
165
+
166
+ except Exception as e:
167
+ print(f"Error generating speech: {str(e)}")
168
+ raise