Remsky committed on
Commit
372ebd3
·
1 Parent(s): 8a4e253

Refactor TTS model initialization and add support for multiple versions

Browse files
Files changed (1) hide show
  1. tts_model_v1.py +52 -27
tts_model_v1.py CHANGED
@@ -12,21 +12,19 @@ class TTSModelV1:
12
 
13
  def __init__(self):
14
  self.pipeline = None
15
- self.voices_dir = "voices"
16
  self.model_repo = "hexgrad/Kokoro-82M"
 
17
 
18
  def initialize(self) -> bool:
19
  """Initialize KPipeline and verify voices"""
20
  try:
21
  print("Initializing v1.0.0 model...")
22
 
23
- # Initialize KPipeline with American English
24
- self.pipeline = None
25
 
26
- # Verify local voice files are available
27
- voices_dir = os.path.join(self.voices_dir, "voices")
28
- if not os.path.exists(voices_dir):
29
- raise ValueError("Voice files not found")
30
 
31
  # Verify voices were downloaded successfully
32
  available_voices = self.list_voices()
@@ -45,9 +43,8 @@ class TTSModelV1:
45
  def list_voices(self) -> List[str]:
46
  """List available voices"""
47
  voices = []
48
- voices_subdir = os.path.join(self.voices_dir, "voices")
49
- if os.path.exists(voices_subdir):
50
- for file in os.listdir(voices_subdir):
51
  if file.endswith(".pt"):
52
  voice_name = file[:-3]
53
  voices.append(voice_name)
@@ -68,7 +65,8 @@ class TTSModelV1:
68
  try:
69
  start_time = time.time()
70
  if self.pipeline is None:
71
- self.pipeline = KPipeline(lang_code='a')
 
72
 
73
  if not text or not voice_names:
74
  raise ValueError("Text and voice name are required")
@@ -78,7 +76,7 @@ class TTSModelV1:
78
  t_voices = []
79
  for voice in voice_names:
80
  try:
81
- voice_path = os.path.join(self.voices_dir, "voices", f"{voice}.pt")
82
  try:
83
  voicepack = torch.load(voice_path, weights_only=True)
84
  except Exception as e:
@@ -92,7 +90,7 @@ class TTSModelV1:
92
  voicepack = torch.mean(torch.stack(t_voices), dim=0)
93
  voice_name = "_".join(voice_names)
94
  # Save mixed voice temporarily
95
- mixed_voice_path = os.path.join(self.voices_dir, "voices", f"{voice_name}.pt")
96
  torch.save(voicepack, mixed_voice_path)
97
  else:
98
  voice_name = voice_names[0]
@@ -105,41 +103,68 @@ class TTSModelV1:
105
  split_pattern=r'\n+' # Default chunking pattern
106
  )
107
 
108
- # Process chunks and collect metrics
109
  audio_chunks = []
110
  chunk_times = []
111
  chunk_sizes = []
112
  total_tokens = 0
113
 
 
 
 
 
 
 
 
 
 
114
  for i, (gs, ps, audio) in enumerate(generator):
115
  chunk_start = time.time()
116
-
117
- # Store chunk audio
118
  audio_chunks.append(audio)
119
 
120
  # Calculate metrics
121
  chunk_time = time.time() - chunk_start
 
 
 
 
 
 
 
 
122
  chunk_times.append(chunk_time)
123
- chunk_sizes.append(len(gs)) # Use grapheme length as chunk size
 
 
 
 
 
124
 
125
- # Update progress if callback provided
126
- if progress_callback:
127
- chunk_duration = len(audio) / 24000
128
- rtf = chunk_time / chunk_duration
 
 
 
 
 
 
 
 
 
 
 
129
  progress_callback(
130
  i + 1,
131
- -1, # Total chunks unknown with generator
132
- len(gs) / chunk_time, # tokens/sec
133
  rtf,
134
  progress_state,
135
  start_time,
136
  gpu_timeout,
137
  progress
138
  )
139
-
140
- print(f"Chunk {i+1} processed in {chunk_time:.2f}s")
141
- print(f"Graphemes: {gs}")
142
- print(f"Phonemes: {ps}")
143
 
144
  # Concatenate audio chunks
145
  audio = np.concatenate(audio_chunks)
 
12
 
13
  def __init__(self):
14
  self.pipeline = None
 
15
  self.model_repo = "hexgrad/Kokoro-82M"
16
+ self.voices_dir = os.path.join(os.path.dirname(__file__), "reference", "reference_other_repo", "voices")
17
 
18
  def initialize(self) -> bool:
19
  """Initialize KPipeline and verify voices"""
20
  try:
21
  print("Initializing v1.0.0 model...")
22
 
23
+ self.pipeline = None # cannot be initialized outside of GPU decorator
 
24
 
25
+ # Verify voices directory exists
26
+ if not os.path.exists(self.voices_dir):
27
+ raise ValueError(f"Voice files not found at {self.voices_dir}")
 
28
 
29
  # Verify voices were downloaded successfully
30
  available_voices = self.list_voices()
 
43
  def list_voices(self) -> List[str]:
44
  """List available voices"""
45
  voices = []
46
+ if os.path.exists(self.voices_dir):
47
+ for file in os.listdir(self.voices_dir):
 
48
  if file.endswith(".pt"):
49
  voice_name = file[:-3]
50
  voices.append(voice_name)
 
65
  try:
66
  start_time = time.time()
67
  if self.pipeline is None:
68
+ lang_code = voice_names[0][0] if voice_names else 'a'
69
+ self.pipeline = KPipeline(lang_code=lang_code)
70
 
71
  if not text or not voice_names:
72
  raise ValueError("Text and voice name are required")
 
76
  t_voices = []
77
  for voice in voice_names:
78
  try:
79
+ voice_path = os.path.join(self.voices_dir, f"{voice}.pt")
80
  try:
81
  voicepack = torch.load(voice_path, weights_only=True)
82
  except Exception as e:
 
90
  voicepack = torch.mean(torch.stack(t_voices), dim=0)
91
  voice_name = "_".join(voice_names)
92
  # Save mixed voice temporarily
93
+ mixed_voice_path = os.path.join(self.voices_dir, f"{voice_name}.pt")
94
  torch.save(voicepack, mixed_voice_path)
95
  else:
96
  voice_name = voice_names[0]
 
103
  split_pattern=r'\n+' # Default chunking pattern
104
  )
105
 
106
+ # Initialize tracking
107
  audio_chunks = []
108
  chunk_times = []
109
  chunk_sizes = []
110
  total_tokens = 0
111
 
112
+ # Get generator from pipeline
113
+ generator = self.pipeline(
114
+ text,
115
+ voice=voice_name,
116
+ speed=speed,
117
+ split_pattern=r'\n+'
118
+ )
119
+
120
+ # Process chunks
121
  for i, (gs, ps, audio) in enumerate(generator):
122
  chunk_start = time.time()
 
 
123
  audio_chunks.append(audio)
124
 
125
  # Calculate metrics
126
  chunk_time = time.time() - chunk_start
127
+ chunk_tokens = len(gs)
128
+ total_tokens += chunk_tokens
129
+
130
+ # Calculate speed metrics
131
+ chunk_duration = len(audio) / 24000
132
+ rtf = chunk_time / chunk_duration
133
+ chunk_tokens_per_sec = chunk_tokens / chunk_time
134
+
135
  chunk_times.append(chunk_time)
136
+ chunk_sizes.append(len(gs))
137
+
138
+ print(f"Chunk {i+1} processed in {chunk_time:.2f}s")
139
+ print(f"Current tokens/sec: {chunk_tokens_per_sec:.2f}")
140
+ print(f"Real-time factor: {rtf:.2f}x")
141
+ print(f"{(1/rtf):.1f}x faster than real-time")
142
 
143
+ # Update progress
144
+ if progress_callback and progress_state:
145
+ # Initialize lists if needed
146
+ if "tokens_per_sec" not in progress_state:
147
+ progress_state["tokens_per_sec"] = []
148
+ if "rtf" not in progress_state:
149
+ progress_state["rtf"] = []
150
+ if "chunk_times" not in progress_state:
151
+ progress_state["chunk_times"] = []
152
+
153
+ # Update progress state
154
+ progress_state["tokens_per_sec"].append(chunk_tokens_per_sec)
155
+ progress_state["rtf"].append(rtf)
156
+ progress_state["chunk_times"].append(chunk_time)
157
+
158
  progress_callback(
159
  i + 1,
160
+ -1, # Let UI handle total chunks
161
+ chunk_tokens_per_sec,
162
  rtf,
163
  progress_state,
164
  start_time,
165
  gpu_timeout,
166
  progress
167
  )
 
 
 
 
168
 
169
  # Concatenate audio chunks
170
  audio = np.concatenate(audio_chunks)