Remsky committed on
Commit
80c0dbf
·
1 Parent(s): bb43905

Refactor TTSModelV1 to load voice mappings from JSON and simplify voice selection

Browse files
Files changed (2) hide show
  1. tts_model_v1.py +10 -60
  2. voices/v1_voices.json +32 -0
tts_model_v1.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import torch
3
  import numpy as np
4
  import time
@@ -6,54 +7,32 @@ from typing import Tuple, List
6
  import soundfile as sf
7
  from kokoro import KPipeline
8
  import spaces
9
- from lib.file_utils import download_voice_files, ensure_dir
10
 
11
  class TTSModelV1:
12
  """KPipeline-based TTS model for v1.0.0"""
13
 
14
  def __init__(self):
15
  self.pipeline = None
16
- self.model_repo = "hexgrad/Kokoro-82M"
17
- # Use v1 voices from Kokoro-82M repo
18
- self.voices_dir = os.path.join(os.path.dirname(__file__), "voices")
 
19
 
20
  def initialize(self) -> bool:
21
- """Initialize KPipeline and verify voices"""
22
  try:
23
  print("Initializing v1.0.0 model...")
24
-
25
  self.pipeline = None # cannot be initialized outside of GPU decorator
26
-
27
- # Download v1 voices if needed
28
- ensure_dir(self.voices_dir)
29
- if not os.path.exists(os.path.join(self.voices_dir, "voices")):
30
- print("Downloading v1 voices...")
31
- download_voice_files(self.model_repo, "voices", self.voices_dir)
32
-
33
- # Verify voices were downloaded successfully
34
- available_voices = self.list_voices()
35
- if not available_voices:
36
- print("Warning: No voices found after initialization")
37
- else:
38
- print(f"Found {len(available_voices)} voices")
39
-
40
  print("Model initialization complete")
41
  return True
42
-
43
  except Exception as e:
44
  print(f"Error initializing model: {str(e)}")
45
  return False
46
 
47
  def list_voices(self) -> List[str]:
48
  """List available voices"""
49
- voices = []
50
- voices_dir = os.path.join(self.voices_dir, "voices")
51
- if os.path.exists(voices_dir):
52
- for file in os.listdir(voices_dir):
53
- if file.endswith(".pt"):
54
- voice_name = file[:-3]
55
- voices.append(voice_name)
56
- return voices
57
 
58
  @spaces.GPU(duration=None) # Duration will be set by the UI
59
  def generate_speech(self, text: str, voice_names: list[str], speed: float = 1.0, gpu_timeout: int = 60, progress_callback=None, progress_state=None, progress=None) -> Tuple[np.ndarray, float]:
@@ -76,35 +55,12 @@ class TTSModelV1:
76
  if not text or not voice_names:
77
  raise ValueError("Text and voice name are required")
78
 
79
- # Handle voice mixing
80
  if isinstance(voice_names, list) and len(voice_names) > 1:
81
- t_voices = []
82
- for voice in voice_names:
83
- try:
84
- voice_path = os.path.join(self.voices_dir, "voices", f"{voice}.pt")
85
- try:
86
- voicepack = torch.load(voice_path, weights_only=True)
87
- except Exception as e:
88
- print(f"Warning: weights_only load failed, attempting full load: {str(e)}")
89
- voicepack = torch.load(voice_path, weights_only=False)
90
- t_voices.append(voicepack)
91
- except Exception as e:
92
- print(f"Warning: Failed to load voice {voice}: {str(e)}")
93
-
94
- # Combine voices by taking mean
95
- voicepack = torch.mean(torch.stack(t_voices), dim=0)
96
  voice_name = "_".join(voice_names)
97
- # Save mixed voice temporarily
98
- mixed_voice_path = os.path.join(self.voices_dir, "voices", f"{voice_name}.pt")
99
- torch.save(voicepack, mixed_voice_path)
100
  else:
101
  voice_name = voice_names[0]
102
- voice_path = os.path.join(self.voices_dir, "voices", f"{voice_name}.pt")
103
- try:
104
- voicepack = torch.load(voice_path, weights_only=True)
105
- except Exception as e:
106
- print(f"Warning: weights_only load failed, attempting full load: {str(e)}")
107
- voicepack = torch.load(voice_path, weights_only=False)
108
 
109
  # Initialize tracking
110
  audio_chunks = []
@@ -172,12 +128,6 @@ class TTSModelV1:
172
  # Concatenate audio chunks
173
  audio = np.concatenate(audio_chunks)
174
 
175
- # Cleanup temporary mixed voice if created
176
- if len(voice_names) > 1:
177
- try:
178
- os.remove(mixed_voice_path)
179
- except:
180
- pass
181
 
182
  # Return audio and metrics
183
  return (
 
1
  import os
2
+ import json
3
  import torch
4
  import numpy as np
5
  import time
 
7
  import soundfile as sf
8
  from kokoro import KPipeline
9
  import spaces
 
10
 
11
  class TTSModelV1:
12
  """KPipeline-based TTS model for v1.0.0"""
13
 
14
  def __init__(self):
15
  self.pipeline = None
16
+ # Load v1 voice mappings
17
+ voice_map_path = os.path.join(os.path.dirname(__file__), "voices", "v1_voices.json")
18
+ with open(voice_map_path) as f:
19
+ self.voice_map = json.load(f)
20
 
21
  def initialize(self) -> bool:
22
+ """Initialize KPipeline"""
23
  try:
24
  print("Initializing v1.0.0 model...")
 
25
  self.pipeline = None # cannot be initialized outside of GPU decorator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  print("Model initialization complete")
27
  return True
 
28
  except Exception as e:
29
  print(f"Error initializing model: {str(e)}")
30
  return False
31
 
32
  def list_voices(self) -> List[str]:
33
  """List available voices"""
34
+ # Return all voices from voice map
35
+ return self.voice_map["american"] + self.voice_map["british"]
 
 
 
 
 
 
36
 
37
  @spaces.GPU(duration=None) # Duration will be set by the UI
38
  def generate_speech(self, text: str, voice_names: list[str], speed: float = 1.0, gpu_timeout: int = 60, progress_callback=None, progress_state=None, progress=None) -> Tuple[np.ndarray, float]:
 
55
  if not text or not voice_names:
56
  raise ValueError("Text and voice name are required")
57
 
58
+ # Handle voice selection
59
  if isinstance(voice_names, list) and len(voice_names) > 1:
60
+ # For multiple voices, join them with underscore
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  voice_name = "_".join(voice_names)
 
 
 
62
  else:
63
  voice_name = voice_names[0]
 
 
 
 
 
 
64
 
65
  # Initialize tracking
66
  audio_chunks = []
 
128
  # Concatenate audio chunks
129
  audio = np.concatenate(audio_chunks)
130
 
 
 
 
 
 
 
131
 
132
  # Return audio and metrics
133
  return (
voices/v1_voices.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "american": [
3
+ "af_alloy",
4
+ "af_aoede",
5
+ "af_bella",
6
+ "af_jessica",
7
+ "af_kore",
8
+ "af_nicole",
9
+ "af_nova",
10
+ "af_river",
11
+ "af_sarah",
12
+ "af_sky",
13
+ "am_adam",
14
+ "am_echo",
15
+ "am_eric",
16
+ "am_fenrir",
17
+ "am_liam",
18
+ "am_michael",
19
+ "am_onyx",
20
+ "am_puck"
21
+ ],
22
+ "british": [
23
+ "bf_alice",
24
+ "bf_emma",
25
+ "bf_isabella",
26
+ "bf_lily",
27
+ "bm_daniel",
28
+ "bm_fable",
29
+ "bm_george",
30
+ "bm_lewis"
31
+ ]
32
+ }