ronedgecomb committed on
Commit 1fb7c23 · verified · 1 Parent(s): ec33a78
Files changed (4)
  1. app.py +85 -5
  2. pyproject.toml +1 -0
  3. requirements.txt +1 -0
  4. uv.lock +23 -0
app.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 import json
 import os
 import re
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Tuple, Optional, Iterator
 
 import gradio as gr
 import numpy as np
@@ -108,6 +108,11 @@ class KittenTTS_1_Onnx:
             providers=chosen_providers,
         )
 
+        self.max_seq_len = self._infer_max_seq_len() or int(os.getenv("KITTEN_MAX_SEQ_LEN", "512"))
+        # reserve 2 slots for BOS/EOS tokens inserted below
+        self._chunk_budget = max(1, self.max_seq_len - 2)
+
+
     def _prepare_inputs(
         self, text: str, voice: str, speed: float
     ) -> Dict[str, np.ndarray]:
@@ -131,6 +136,61 @@ class KittenTTS_1_Onnx:
 
         return {"input_ids": input_ids, "style": style_vec, "speed": speed_arr}
 
+    def _infer_max_seq_len(self) -> Optional[int]:
+        """Try to read positional-embedding length from the ONNX initializers.
+        Falls back to env var or 512 if unavailable. Optional dependency on 'onnx'.
+        """
+        try:
+            import onnx  # optional
+        except Exception:
+            return None
+        try:
+            model = onnx.load(self.model_path)
+        except Exception:
+            return None
+
+        for tensor in model.graph.initializer:
+            name = tensor.name.lower()
+            if "position" in name and len(tensor.dims) == 2:
+                # dims[0] = max positions, dims[1] = hidden dim
+                return int(tensor.dims[0])
+        return None
+
+    def _phonemize_to_clean(self, text: str) -> str:
+        """Phonemize once and keep only characters present in the symbol set."""
+        phonemes = self._phonemizer.phonemize([text])[0]
+        token_str = " ".join(basic_english_tokenize(phonemes))
+        # keep only symbols known to the TextCleaner
+        return "".join(c for c in token_str if c in self._cleaner._dict)
+
+    def _run_onnx(self, token_ids: List[int], voice: str, speed: float) -> np.ndarray:
+        """One inference call with trimming identical to original behavior."""
+        input_ids = np.asarray([token_ids], dtype=np.int64)
+        style_vec = self.voices[voice]
+        speed_arr = np.asarray([speed], dtype=np.float32)
+        outputs = self.session.run(None, {"input_ids": input_ids, "style": style_vec, "speed": speed_arr})
+        audio = np.asarray(outputs[0], dtype=np.float32)
+        if audio.size > 15000:
+            audio = audio[5000:-10000]
+        return audio
+
+    def _chunk_token_ids(self, clean: str) -> Iterator[List[int]]:
+        """Yield BOS/segment/EOS token-id sequences within model capacity."""
+        n = len(clean)
+        i = 0
+        while i < n:
+            j = min(i + self._chunk_budget, n)
+            # prefer to cut at a space when possible, to keep phrasing natural
+            cut = clean.rfind(" ", i, j)
+            if cut != -1 and cut > i + int(0.6 * self._chunk_budget):
+                j = cut + 1  # include the space
+            seg = clean[i:j]
+            ids = self._cleaner(seg)  # segment ids
+            ids.insert(0, 0)  # BOS
+            ids.append(0)  # EOS
+            yield ids
+            i = j
+
     def generate(
         self, text: str, voice: str = "expr-voice-5-m", speed: float = 1.0
     ) -> np.ndarray:
@@ -174,10 +234,30 @@ class KittenTTS:
             repo_id=repo_id, cache_dir=cache_dir, providers=providers
         )
 
-    def generate(
-        self, text: str, voice: str = "expr-voice-5-m", speed: float = 1.0
-    ) -> np.ndarray:
-        return self._model.generate(text, voice=voice, speed=speed)
+    def generate(self, text: str, voice: str = "expr-voice-5-m", speed: float = 1.0) -> np.ndarray:
+        """Synthesize speech with automatic chunking at the model's max length."""
+        if voice not in self.available_voices:
+            raise ValueError(f"Voice '{voice}' not available. Choose from: {self.available_voices}")
+
+        # Phonemize once, then either run single-shot or chunked
+        clean = self._phonemize_to_clean(text)
+
+        # Fast path: fits in one pass
+        if len(clean) + 2 <= self.max_seq_len:
+            ids = self._cleaner(clean)
+            ids.insert(0, 0)  # BOS
+            ids.append(0)  # EOS
+            return self._run_onnx(ids, voice, speed)
+
+        # Chunked path: concatenate per-chunk audio
+        pieces: List[np.ndarray] = []
+        for ids in self._chunk_token_ids(clean):
+            pieces.append(self._run_onnx(ids, voice, speed))
+
+        if not pieces:
+            return np.array([], dtype=np.float32)
+        return pieces[0] if len(pieces) == 1 else np.concatenate(pieces)
+
 
     def generate_to_file(
         self,
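For context, a minimal smoke-test sketch of the chunked path introduced above. It is not part of the commit: the `from app import KittenTTS` import path, the no-argument constructor, and the 24 kHz output rate are assumptions, and it presumes the `KittenTTS` wrapper can reach the chunking helpers added to the ONNX model class.

```python
# Hypothetical smoke test for chunked generate() (not part of this commit).
# Assumptions: app.py is importable as `app`, KittenTTS() accepts default
# arguments, and the model's output sample rate is 24000 Hz.
import soundfile as sf

from app import KittenTTS

tts = KittenTTS()

# Long enough to exceed the ~512-position budget and trigger chunking.
text = ("Kitten TTS reads long passages by splitting them into chunks that fit "
        "the model's positional embeddings, then concatenating the audio. ") * 20

audio = tts.generate(text, voice="expr-voice-5-m", speed=1.0)
sf.write("long_passage.wav", audio, 24000)  # sample rate assumed
```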
pyproject.toml CHANGED
@@ -8,6 +8,7 @@ dependencies = [
     "gradio>=5.43.1",
     "huggingface-hub[hf-xet]>=0.34.4",
     "numpy>=2.3.2",
+    "onnx>=1.18.0",
     "onnxruntime>=1.22.1",
     "phonemizer>=3.3.0",
     "soundfile>=0.13.1",
requirements.txt CHANGED
@@ -39,6 +39,7 @@ markupsafe==3.0.2
 mdurl==0.1.2
 mpmath==1.3.0
 numpy==2.3.2
+onnx==1.18.0
 onnxruntime==1.22.1
 orjson==3.11.2
 packaging==25.0
uv.lock CHANGED
@@ -532,6 +532,7 @@ dependencies = [
     { name = "gradio" },
     { name = "huggingface-hub", extra = ["hf-xet"] },
     { name = "numpy" },
+    { name = "onnx" },
     { name = "onnxruntime" },
     { name = "phonemizer" },
     { name = "soundfile" },
@@ -542,6 +543,7 @@ requires-dist = [
     { name = "gradio", specifier = ">=5.43.1" },
     { name = "huggingface-hub", extras = ["hf-xet"], specifier = ">=0.34.4" },
     { name = "numpy", specifier = ">=2.3.2" },
+    { name = "onnx", specifier = ">=1.18.0" },
     { name = "onnxruntime", specifier = ">=1.22.1" },
     { name = "phonemizer", specifier = ">=3.3.0" },
     { name = "soundfile", specifier = ">=0.13.1" },
@@ -666,6 +668,27 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/c1/9e/1652778bce745a67b5fe05adde60ed362d38eb17d919a540e813d30f6874/numpy-2.3.2-cp314-cp314t-win_arm64.whl", hash = "sha256:092aeb3449833ea9c0bf0089d70c29ae480685dd2377ec9cdbbb620257f84631", size = 10544226, upload-time = "2025-07-24T20:56:34.509Z" },
 ]
 
+[[package]]
+name = "onnx"
+version = "1.18.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "numpy" },
+    { name = "protobuf" },
+    { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/3d/60/e56e8ec44ed34006e6d4a73c92a04d9eea6163cc12440e35045aec069175/onnx-1.18.0.tar.gz", hash = "sha256:3d8dbf9e996629131ba3aa1afd1d8239b660d1f830c6688dd7e03157cccd6b9c", size = 12563009, upload-time = "2025-05-12T22:03:09.626Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/45/da/9fb8824513fae836239276870bfcc433fa2298d34ed282c3a47d3962561b/onnx-1.18.0-cp313-cp313-macosx_12_0_universal2.whl", hash = "sha256:030d9f5f878c5f4c0ff70a4545b90d7812cd6bfe511de2f3e469d3669c8cff95", size = 18285906, upload-time = "2025-05-12T22:02:45.01Z" },
+    { url = "https://files.pythonhosted.org/packages/05/e8/762b5fb5ed1a2b8e9a4bc5e668c82723b1b789c23b74e6b5a3356731ae4e/onnx-1.18.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8521544987d713941ee1e591520044d35e702f73dc87e91e6d4b15a064ae813d", size = 17421486, upload-time = "2025-05-12T22:02:48.467Z" },
+    { url = "https://files.pythonhosted.org/packages/12/bb/471da68df0364f22296456c7f6becebe0a3da1ba435cdb371099f516da6e/onnx-1.18.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c137eecf6bc618c2f9398bcc381474b55c817237992b169dfe728e169549e8f", size = 17583581, upload-time = "2025-05-12T22:02:51.784Z" },
+    { url = "https://files.pythonhosted.org/packages/76/0d/01a95edc2cef6ad916e04e8e1267a9286f15b55c90cce5d3cdeb359d75d6/onnx-1.18.0-cp313-cp313-win32.whl", hash = "sha256:6c093ffc593e07f7e33862824eab9225f86aa189c048dd43ffde207d7041a55f", size = 15734621, upload-time = "2025-05-12T22:02:54.62Z" },
+    { url = "https://files.pythonhosted.org/packages/64/95/253451a751be32b6173a648b68f407188009afa45cd6388780c330ff5d5d/onnx-1.18.0-cp313-cp313-win_amd64.whl", hash = "sha256:230b0fb615e5b798dc4a3718999ec1828360bc71274abd14f915135eab0255f1", size = 15850472, upload-time = "2025-05-12T22:02:57.54Z" },
+    { url = "https://files.pythonhosted.org/packages/0a/b1/6fd41b026836df480a21687076e0f559bc3ceeac90f2be8c64b4a7a1f332/onnx-1.18.0-cp313-cp313-win_arm64.whl", hash = "sha256:6f91930c1a284135db0f891695a263fc876466bf2afbd2215834ac08f600cfca", size = 15823808, upload-time = "2025-05-12T22:03:00.305Z" },
+    { url = "https://files.pythonhosted.org/packages/70/f3/499e53dd41fa7302f914dd18543da01e0786a58b9a9d347497231192001f/onnx-1.18.0-cp313-cp313t-macosx_12_0_universal2.whl", hash = "sha256:2f4d37b0b5c96a873887652d1cbf3f3c70821b8c66302d84b0f0d89dd6e47653", size = 18316526, upload-time = "2025-05-12T22:03:03.691Z" },
+    { url = "https://files.pythonhosted.org/packages/84/dd/6abe5d7bd23f5ed3ade8352abf30dff1c7a9e97fc1b0a17b5d7c726e98a9/onnx-1.18.0-cp313-cp313t-win_amd64.whl", hash = "sha256:a69afd0baa372162948b52c13f3aa2730123381edf926d7ef3f68ca7cec6d0d0", size = 15865055, upload-time = "2025-05-12T22:03:06.663Z" },
+]
+
 [[package]]
 name = "onnxruntime"
 version = "1.22.1"