yueyulin committed
Commit 66d387e · verified · 1 Parent(s): c75845b

Upload folder using huggingface_hub

Files changed (24)
  1. rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/BiCodecDetokenize.onnx +3 -0
  2. rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/BiCodecTokenize.onnx +3 -0
  3. rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/__pycache__/properties_util.cpython-311.pyc +0 -0
  4. rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/__pycache__/ref_audio_utilities.cpython-311.pyc +0 -0
  5. rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/config.json +55 -0
  6. rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/generation_config.json +6 -0
  7. rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/hf_rwkv_tokenizer.py +280 -0
  8. rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/model.safetensors +3 -0
  9. rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/model_converted.pth +3 -0
  10. rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/model_padded.pth +3 -0
  11. rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/modeling_rwkvspeech.py +6 -0
  12. rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/properties_util.py +221 -0
  13. rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/ref_audio_utilities.py +306 -0
  14. rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/rwkv_vocab_v20230424.txt +0 -0
  15. rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/spark_llm.py +202 -0
  16. rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/special_tokens_map.json +24 -0
  17. rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/texts_utilities.py +0 -0
  18. rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/tokenizer_config.json +836 -0
  19. rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/translation_data.py +55 -0
  20. rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/tts_cli.py +992 -0
  21. rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/utilities.py +209 -0
  22. rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/vocab.txt +0 -0
  23. rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/wav2vec2-large-xlsr-53.onnx +3 -0
  24. rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/webrwkv.safetensors +3 -0
rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/BiCodecDetokenize.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:055f86df2809ca8b9210154e8ddc85aa7458909d4b30aa7f996e3fe053a71e3d
+ size 385412236
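
The three lines above are a Git LFS pointer, not the ONNX model itself: `oid` is the SHA-256 of the real file and `size` is its byte count (here ~385 MB). A minimal sketch of fetching the resolved file with `huggingface_hub` — the repo id below is a placeholder, since this page does not state it:

from huggingface_hub import hf_hub_download

# hypothetical repo id; substitute the actual repo hosting this commit
path = hf_hub_download(
    repo_id="yueyulin/rwkv7-0.1B-g1-respark-voice-tunable",
    filename="rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/BiCodecDetokenize.onnx",
    revision="66d387e",  # pin to this commit
)
print(path)  # local cache path of the ~385 MB resolved file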
rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/BiCodecTokenize.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7080b9790ee020977105d78754628c2b5e03841c0bbfc0294072ec40278222ce
+ size 146225395
rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/__pycache__/properties_util.cpython-311.pyc ADDED
Binary file (5.93 kB).
rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/__pycache__/ref_audio_utilities.cpython-311.pyc ADDED
Binary file (13.3 kB).
rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/config.json ADDED
@@ -0,0 +1,55 @@
+ {
+   "a_low_rank_dim": 64,
+   "architectures": [
+     "RWKV7ForSpeech"
+   ],
+   "attn": null,
+   "attn_mode": "chunk",
+   "audio_global_vocab_size": 4096,
+   "auto_map": {
+     "AutoConfig": "spark_llm.RWKV7SpeechConfig",
+     "AutoModel": "modeling_rwkvspeech.RWKV7Model",
+     "AutoModelForCausalLM": "modeling_rwkvspeech.RWKV7ForSpeech"
+   },
+   "bos_token_id": 0,
+   "decay_low_rank_dim": 64,
+   "eos_token_id": 0,
+   "fuse_cross_entropy": true,
+   "fuse_norm": false,
+   "gate_low_rank_dim": 128,
+   "head_dim": 64,
+   "hidden_act": "sqrelu",
+   "hidden_ratio": 4.0,
+   "hidden_size": 768,
+   "initializer_range": 0.006,
+   "intermediate_size": 3072,
+   "max_position_embeddings": 2048,
+   "model_type": "rwkv7",
+   "norm_bias": true,
+   "norm_eps": 1e-05,
+   "norm_first": true,
+   "num_heads": 32,
+   "num_hidden_layers": 12,
+   "text_vocab_size": 65631,
+   "tie_word_embeddings": false,
+   "torch_dtype": "float32",
+   "transformers_version": "4.52.4",
+   "use_cache": true,
+   "use_l2warp": true,
+   "v_low_rank_dim": 32,
+   "value_dim": [
+     768,
+     768,
+     768,
+     768,
+     768,
+     768,
+     768,
+     768,
+     768,
+     768,
+     768,
+     768
+   ],
+   "vocab_size": 8193
+ }
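
The `auto_map` above routes the Auto classes to the custom code shipped in this folder (spark_llm.py, modeling_rwkvspeech.py), so loading requires `trust_remote_code=True`. A minimal loading sketch, assuming the tokenizer_config (shown further down, truncated) wires up hf_rwkv_tokenizer.py in the usual way:

from transformers import AutoModelForCausalLM, AutoTokenizer

# local folder name from this commit; a hub repo id would work the same way
model_dir = "rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1"
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True)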
rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/generation_config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 0,
+   "eos_token_id": 0,
+   "transformers_version": "4.52.4"
+ }
rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/hf_rwkv_tokenizer.py ADDED
@@ -0,0 +1,280 @@
+ # coding=utf-8
+ # Copyright 2024 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Tokenization classes for RWKV."""
+
+ import os
+ import re
+ from typing import TYPE_CHECKING, List, Optional, Tuple
+
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
+ from transformers.utils import logging
+
+
+ if TYPE_CHECKING:
+     pass
+
+ logger = logging.get_logger(__name__)
+
+
+ VOCAB_FILES_NAMES = {
+     "vocab_file": "rwkv_vocab_v20230424.txt",
+ }
+
+ class TRIE:
+     __slots__ = tuple("ch,to,values,front".split(","))
+     to: list
+     values: set
+
+     def __init__(self, front=None, ch=None):
+         self.ch = ch
+         self.to = [None for ch in range(256)]
+         self.values = set()
+         self.front = front
+
+     def __repr__(self):
+         fr = self
+         ret = []
+         while fr != None:
+             if fr.ch != None:
+                 ret.append(fr.ch)
+             fr = fr.front
+         return "<TRIE %s %s>" % (ret[::-1], self.values)
+
+     def add(self, key: bytes, idx: int = 0, val=None):
+         if idx == len(key):
+             if val is None:
+                 val = key
+             self.values.add(val)
+             return self
+         ch = key[idx]
+         if self.to[ch] is None:
+             self.to[ch] = TRIE(front=self, ch=ch)
+         return self.to[ch].add(key, idx=idx + 1, val=val)
+
+     def find_longest(self, key: bytes, idx: int = 0):
+         u: TRIE = self
+         ch: int = key[idx]
+
+         while u.to[ch] is not None:
+             u = u.to[ch]
+             idx += 1
+             if u.values:
+                 ret = idx, u, u.values
+             if idx == len(key):
+                 break
+             ch = key[idx]
+         return ret
+
+
+ class RWKV_TOKENIZER:
+     def __init__(self, file_name):
+         self.idx2token = {}
+         sorted = []  # must be already sorted
+         with open(file_name, "r", encoding="utf-8") as f:
+             lines = f.readlines()
+         for l in lines:
+             idx = int(l[: l.index(" ")])
+             x = eval(l[l.index(" ") : l.rindex(" ")])
+             x = x.encode("utf-8") if isinstance(x, str) else x
+             assert isinstance(x, bytes)
+
+             assert len(x) == int(l[l.rindex(" ") :])
+             sorted += [x]
+             self.idx2token[idx] = x
+
+         self.token2idx = {}
+         for k, v in self.idx2token.items():
+             self.token2idx[v] = int(k)
+
+         self.root = TRIE()
+         for t, i in self.token2idx.items():
+             _ = self.root.add(t, val=(t, i))
+
+     def encodeBytes(self, src: bytes):
+         idx: int = 0
+         tokens = []
+         while idx < len(src):
+             _idx: int = idx
+             idx, _, values = self.root.find_longest(src, idx)
+             assert idx != _idx
+             _, token = next(iter(values))
+             tokens.append(token)
+         return tokens
+
+     def decodeBytes(self, tokens):
+         return b"".join(map(lambda i: self.idx2token[i], tokens))
+
+     def encode(self, src):
+         if isinstance(src, str):
+             return [self.encodeBytes(src.encode("utf-8"))]
+         elif isinstance(src, list):
+             return [self.encodeBytes(s.encode("utf-8")) for s in src]
+
+     def decode(self, tokens):
+         return [self.decodeBytes(batch).decode("utf-8") for batch in tokens]
+         # try:
+         #     return self.decodeBytes(tokens).decode('utf-8')
+         # except:
+         #     return '\ufffd'  # bad utf-8
+
+     def printTokens(self, tokens):
+         for i in tokens:
+             s = self.idx2token[i]
+             try:
+                 s = s.decode("utf-8")
+             except:
+                 pass
+             print(f"{repr(s)}{i}", end=" ")
+         print()
+
+
+ class RwkvTokenizer(PreTrainedTokenizer):
+     vocab_files_names = VOCAB_FILES_NAMES
+     model_input_names = ["input_ids", "attention_mask"]
+
+     def __init__(
+         self, vocab_file, bos_token="<|rwkv_tokenizer_end_of_text|>", eos_token="<|rwkv_tokenizer_end_of_text|>", unk_token="<|rwkv_tokenizer_end_of_text|>", **kwargs
+     ):
+         if not os.path.isfile(vocab_file):
+             raise ValueError(
+                 f"Can't find a vocabulary file at path '{vocab_file}'."
+             )
+
+         with open(vocab_file, "r", encoding="utf-8") as reader:
+             tokens = reader.readlines()
+
+         if "add_bos_token" in kwargs:
+             self.add_bos_token = kwargs["add_bos_token"]
+         else:
+             self.add_bos_token = False
+         self.trie_tokenizer = RWKV_TOKENIZER(vocab_file)
+         vocab = self.trie_tokenizer.token2idx
+         self.encoder = vocab
+         self.decoder = {v: k for k, v in vocab.items()}
+         self._added_tokens_decoder = {0: AddedToken(str(bos_token))}
+         super().__init__(
+             bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs
+         )
+
+     @property
+     def vocab_size(self):
+         return len(self.encoder)
+
+     def get_vocab(self):
+         vocab = self.encoder
+         vocab.update(self.added_tokens_encoder)
+         vocab = dict(sorted(vocab.items(), key=lambda item: item[1]))
+         return vocab
+
+     def _tokenize(self, text, split_special_tokens=False):
+         # return self.wordpiece_tokenizer.tokenize(text.encode("utf-8"))
+         return self.trie_tokenizer.encode(text)[0]
+
+     def _convert_token_to_id(self, token):
+         return token
+
+     def _convert_id_to_token(self, index):
+         """Converts an index (integer) in a token (byte) using the vocab."""
+         token = self.decoder.get(index, self.unk_token)
+         if isinstance(token, (bytes)):
+             token = token.decode("utf-8", errors="replace")
+         return token
+
+     def convert_tokens_to_string(self, tokens):
+         """Converts a sequence of tokens (bytes) in a single string. Additional tokens are encoded to bytes"""
+         out_string = b"".join(
+             [k.encode(errors="replace") if isinstance(k, str) else k for k in tokens]
+         ).decode("utf-8")
+         return out_string
+
+     def save_vocabulary(
+         self, save_directory: str, filename_prefix: Optional[str] = None
+     ) -> Tuple[str]:
+         index = 0
+         if os.path.isdir(save_directory):
+             vocab_file = os.path.join(
+                 save_directory,
+                 (filename_prefix + "-" if filename_prefix else "") + "vocab.txt",
+             )
+         else:
+             vocab_file = (
+                 filename_prefix + "-" if filename_prefix else ""
+             ) + save_directory
+         with open(vocab_file, "w", encoding="utf-8") as writer:
+             for token, token_index in sorted(
+                 self.encoder.items(), key=lambda kv: kv[1]
+             ):
+                 if index != token_index:
+                     logger.warning(
+                         f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
+                         " Please check that the vocabulary is not corrupted!"
+                     )
+                     index = token_index
+                 writer.write(str(token) + "\n")
+                 index += 1
+         return (vocab_file,)
+
+     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+         if self.add_bos_token:
+             bos_token_ids = [self.bos_token_id]
+         else:
+             bos_token_ids = []
+
+         output = bos_token_ids + token_ids_0
+
+         if token_ids_1 is None:
+             return output
+
+         return output + bos_token_ids + token_ids_1
+
+     def get_special_tokens_mask(
+         self,
+         token_ids_0: List[int],
+         token_ids_1: Optional[List[int]] = None,
+         already_has_special_tokens: bool = False,
+     ) -> List[int]:
+         """
+         Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+         special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
+
+         Args:
+             token_ids_0 (`List[int]`):
+                 List of IDs.
+             token_ids_1 (`List[int]`, *optional*):
+                 Optional second list of IDs for sequence pairs.
+             already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                 Whether or not the token list is already formatted with special tokens for the model.
+
+         Returns:
+             `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+         """
+         if already_has_special_tokens:
+             return super().get_special_tokens_mask(
+                 token_ids_0=token_ids_0,
+                 token_ids_1=token_ids_1,
+                 already_has_special_tokens=True,
+             )
+
+         if not self.add_bos_token:
+             return super().get_special_tokens_mask(
+                 token_ids_0=token_ids_0,
+                 token_ids_1=token_ids_1,
+                 already_has_special_tokens=False,
+             )
+
+         if token_ids_1 is None:
+             return [1] + ([0] * len(token_ids_0))
+         return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
+
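
The TRIE above implements greedy longest-match tokenization over raw UTF-8 bytes. A small round-trip sketch, assuming rwkv_vocab_v20230424.txt sits in the working directory alongside the module:

from hf_rwkv_tokenizer import RWKV_TOKENIZER

tok = RWKV_TOKENIZER("rwkv_vocab_v20230424.txt")
ids = tok.encode("Hello, RWKV!")[0]  # encode() wraps each result in a batch list
print(ids)
print(tok.decode([ids])[0])          # round-trips back to "Hello, RWKV!"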
rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:634a0c3b1b67cf897f451b142701f49b1d862875b804d587e098341f1bf0bb57
+ size 626075280
rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/model_converted.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b869b06e4f486c4698b0f4d71d81a958db0b27b78fa8c0c73f687baac6a87b54
+ size 626155657
rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/model_padded.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e5b9292bc80b0db946657f60cc66193679bc6d87e91725e543a0a544a21a276c
+ size 840365002
rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/modeling_rwkvspeech.py ADDED
@@ -0,0 +1,6 @@
+ from model.llm.spark_llm import RWKV7SpeechConfig, RWKV7ForSpeech
+ from rwkvfla.models.rwkv7 import RWKV7Model
+
+ RWKV7ForCausalLM = RWKV7ForSpeech
+ RWKV7Model = RWKV7Model
+ RWKV7Config = RWKV7SpeechConfig
rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/properties_util.py ADDED
@@ -0,0 +1,221 @@
+ SPEED_MAP = {
+     "very_slow": "SPCT_1",
+     "slow": "SPCT_2",
+     "medium": "SPCT_3",
+     "fast": "SPCT_4",
+     "very_fast": "SPCT_5",
+ }
+
+ PITCH_MAP = {
+     "low_pitch": "SPCT_6",
+     "medium_pitch": "SPCT_7",
+     "high_pitch": "SPCT_8",
+     "very_high_pitch": "SPCT_9",
+ }
+
+ AGE_MAP = {
+     "child": "SPCT_13",
+     "teenager": "SPCT_14",
+     "youth-adult": "SPCT_15",
+     "middle-aged": "SPCT_16",
+     "elderly": "SPCT_17",
+ }
+
+
+ EMOTION_MAP = {
+     "UNKNOWN": "SPCT_21",
+     "NEUTRAL": "SPCT_22",
+     "ANGRY": "SPCT_23",
+     "HAPPY": "SPCT_24",
+     "SAD": "SPCT_25",
+     "FEARFUL": "SPCT_26",
+     "DISGUSTED": "SPCT_27",
+     "SURPRISED": "SPCT_28",
+     "SARCASTIC": "SPCT_29",
+     "EXCITED": "SPCT_30",
+     "SLEEPY": "SPCT_31",
+     "CONFUSED": "SPCT_32",
+     "EMPHASIS": "SPCT_33",
+     "LAUGHING": "SPCT_34",
+     "SINGING": "SPCT_35",
+     "WORRIED": "SPCT_36",
+     "WHISPER": "SPCT_37",
+     "ANXIOUS": "SPCT_38",
+     "NO-AGREEMENT": "SPCT_39",
+     "APOLOGETIC": "SPCT_40",
+     "CONCERNED": "SPCT_41",
+     "ENUNCIATED": "SPCT_42",
+     "ASSERTIVE": "SPCT_43",
+     "ENCOURAGING": "SPCT_44",
+     "CONTEMPT": "SPCT_45",
+ }
+
+ # Note: there were originally two GENDER_MAP definitions here, with the second
+ # overriding the first. The first included "unknown"; the second contains only
+ # "female" and "male" and is preferred, as it matches actual usage.
+ GENDER_MAP = {
+     "female": "SPCT_46",
+     "male": "SPCT_47"
+ }
+
+ def convert_standard_properties_to_tokens(age: str, gender: str, emotion: str, pitch: str, speed: str) -> str:
+     age_token = AGE_MAP[age.lower()]
+     gender_token = GENDER_MAP[gender.lower()]
+     emotion_token = EMOTION_MAP[emotion.upper()]
+     pitch_token = PITCH_MAP[pitch.lower()]
+     speed_token = SPEED_MAP[speed.lower()]
+     return "SPCT_0" + age_token + gender_token + emotion_token + pitch_token + speed_token
+
+ def convert_properties_to_tokens(age: str, gender: str, emotion: str, pitch: float, speed: float) -> str:
+     age_token = AGE_MAP[age.lower()]
+     gender_token = GENDER_MAP[gender.lower()]
+     emotion_token = EMOTION_MAP[emotion.upper()]
+     pitch_token = PITCH_MAP[classify_pitch(pitch, gender.lower(), age.lower())]
+     speed_token = SPEED_MAP[classify_speed(speed)]
+     return "SPCT_0" + age_token + gender_token + emotion_token + pitch_token + speed_token
+
+ def classify_speed(speed: float) -> str:
+     if speed <= 3.5:
+         return "very_slow"
+     elif 3.5 < speed < 4.0:
+         return "slow"
+     elif 4.0 <= speed <= 4.5:  # boundary fixed: the original skipped speed == 4.0
+         return "medium"
+     elif 4.5 < speed <= 5.0:
+         return "fast"
+     else:  # speed > 5.0
+         return "very_fast"
+
+ def classify_pitch(pitch: float, gender: str, age: str) -> str:
+     """
+     Partition pitch ranges by gender and age.
+     Based on corpus statistics:
+     - female: mean 212.08, median 208.76, 25th pct 187.40, 75th pct 232.08
+     - male: mean 136.22, median 129.65, 25th pct 113.76, 75th pct 151.42
+     """
+     gender = gender.lower()
+     age = age.lower()
+
+     # Female classification
+     if gender == "female":
+         if age == "child":
+             # Child: mean 280.12, median 279.34, range 216.91-324.25
+             if pitch < 250:
+                 return "low_pitch"
+             elif pitch < 290:
+                 return "medium_pitch"
+             else:
+                 return "high_pitch"
+         elif age == "teenager":
+             # Teenager: mean 240.61, median 238.43, 25th pct 207.54, 75th pct 270.12
+             if pitch < 208:
+                 return "low_pitch"
+             elif pitch < 238:
+                 return "medium_pitch"
+             elif pitch < 270:
+                 return "high_pitch"
+             else:
+                 return "very_high_pitch"
+         elif age == "youth-adult":
+             # Youth-Adult: mean 213.26, median 210.99, 25th pct 190.81, 75th pct 232.24
+             if pitch < 191:
+                 return "low_pitch"
+             elif pitch < 211:
+                 return "medium_pitch"
+             elif pitch < 232:
+                 return "high_pitch"
+             else:
+                 return "very_high_pitch"
+         elif age == "middle-aged":
+             # Middle-aged: mean 197.68, median 195.01, 25th pct 176.34, 75th pct 215.22
+             if pitch < 176:
+                 return "low_pitch"
+             elif pitch < 195:
+                 return "medium_pitch"
+             elif pitch < 215:
+                 return "high_pitch"
+             else:
+                 return "very_high_pitch"
+         elif age == "elderly":
+             # Elderly: mean 194.91, median 189.90, 25th pct 170.42, 75th pct 213.41
+             if pitch < 170:
+                 return "low_pitch"
+             elif pitch < 190:
+                 return "medium_pitch"
+             elif pitch < 213:
+                 return "high_pitch"
+             else:
+                 return "very_high_pitch"
+         else:
+             # Default female classification
+             if pitch < 187:
+                 return "low_pitch"
+             elif pitch < 209:
+                 return "medium_pitch"
+             elif pitch < 232:
+                 return "high_pitch"
+             else:
+                 return "very_high_pitch"
+
+     # Male classification
+     elif gender == "male":
+         if age == "teenager":
+             # Teenager: mean 150.93, median 142.50, 25th pct 121.47, 75th pct 165.55
+             if pitch < 121:
+                 return "low_pitch"
+             elif pitch < 143:
+                 return "medium_pitch"
+             elif pitch < 166:
+                 return "high_pitch"
+             else:
+                 return "very_high_pitch"
+         elif age == "youth-adult":
+             # Youth-Adult: mean 137.17, median 130.92, 25th pct 114.70, 75th pct 153.18
+             if pitch < 115:
+                 return "low_pitch"
+             elif pitch < 131:
+                 return "medium_pitch"
+             elif pitch < 153:
+                 return "high_pitch"
+             else:
+                 return "very_high_pitch"
+         elif age == "middle-aged":
+             # Middle-aged: mean 132.33, median 125.30, 25th pct 110.31, 75th pct 146.55
+             if pitch < 110:
+                 return "low_pitch"
+             elif pitch < 125:
+                 return "medium_pitch"
+             elif pitch < 147:
+                 return "high_pitch"
+             else:
+                 return "very_high_pitch"
+         elif age == "elderly":
+             # Elderly: mean 132.62, median 128.42, 25th pct 114.69, 75th pct 141.57
+             if pitch < 115:
+                 return "low_pitch"
+             elif pitch < 128:
+                 return "medium_pitch"
+             elif pitch < 142:
+                 return "high_pitch"
+             else:
+                 return "very_high_pitch"
+         else:
+             # Default male classification
+             if pitch < 114:
+                 return "low_pitch"
+             elif pitch < 130:
+                 return "medium_pitch"
+             elif pitch < 151:
+                 return "high_pitch"
+             else:
+                 return "very_high_pitch"
+
+     # Unknown gender: use generic thresholds
+     else:
+         if pitch < 130:
+             return "low_pitch"
+         elif pitch < 180:
+             return "medium_pitch"
+         elif pitch < 220:
+             return "high_pitch"
+         else:
+             return "very_high_pitch"
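
As a worked example of the control string these helpers produce (derived directly from the maps above):

from properties_util import convert_standard_properties_to_tokens

s = convert_standard_properties_to_tokens(
    age="youth-adult", gender="female", emotion="HAPPY",
    pitch="medium_pitch", speed="medium",
)
print(s)  # SPCT_0SPCT_15SPCT_46SPCT_24SPCT_7SPCT_3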
rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/ref_audio_utilities.py ADDED
@@ -0,0 +1,306 @@
+ import onnxruntime as ort
+ import numpy as np
+ import librosa
+ import soundfile as sf
+ import soxr
+ from pathlib import Path
+ from typing import Tuple, Union, Optional
+
+
+ class RefAudioUtilities:
+     """Audio processing utility class that generates tokens with ONNX models"""
+
+     def __init__(self, onnx_model_path: str, wav2vec2_path,
+                  ref_segment_duration: float = 6.0, latent_hop_length: int = 320):
+         """
+         Initialize the ONNX models
+
+         Args:
+             onnx_model_path: path to the ONNX model file
+             wav2vec2_path: path to the wav2vec2 ONNX model file; if None, the wav2vec2 model is not loaded
+             ref_segment_duration: reference audio duration in seconds
+             latent_hop_length: hop length of the latent features
+         """
+         self.ort_session = ort.InferenceSession(onnx_model_path,
+                                                 providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+         print(f"🖥️ONNX Session actual providers: {self.ort_session.get_providers()}")
+         self.sample_rate = 16000
+         self.ref_segment_duration = ref_segment_duration
+         self.latent_hop_length = latent_hop_length
+
+         # Query model input/output metadata
+         self.input_names = [input_info.name for input_info in self.ort_session.get_inputs()]
+         self.output_names = [output_info.name for output_info in self.ort_session.get_outputs()]
+
+         print(f"Model inputs: {self.input_names}")
+         print(f"Model outputs: {self.output_names}")
+
+         # Initialize the wav2vec2 model
+         self.wav2vec2_session = ort.InferenceSession(wav2vec2_path,
+                                                      providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+         print(f"🖥️Wav2Vec2 Session actual providers: {self.wav2vec2_session.get_providers()}")
+
+     def load_audio(self, audio_path: Union[str, Path], target_sr: int = 16000,
+                    volume_normalize: bool = False) -> np.ndarray:
+         """
+         Load an audio file, consistent with BiCodecTokenizer
+
+         Args:
+             audio_path: path to the audio file
+             target_sr: target sample rate
+             volume_normalize: whether to apply volume normalization
+
+         Returns:
+             Audio data array
+         """
+         if isinstance(audio_path, str):
+             audio_path = Path(audio_path)
+
+         # Load with soundfile, consistent with BiCodecTokenizer
+         audio, sr = sf.read(audio_path)
+         if len(audio.shape) > 1:
+             audio = audio[:, 0]  # if stereo, take the first channel
+
+         # Resample to the target sample rate
+         if sr != target_sr:
+             audio = soxr.resample(audio, sr, target_sr, quality="VHQ")
+             sr = target_sr
+
+         # Volume normalization
+         if volume_normalize:
+             audio = self._audio_volume_normalize(audio)
+
+         return audio
+
+     def _audio_volume_normalize(self, audio: np.ndarray, coeff: float = 0.2) -> np.ndarray:
+         """Normalize audio volume"""
+         # Sort the absolute values of the audio signal
+         temp = np.sort(np.abs(audio))
+
+         # If the maximum value is less than 0.1, scale the array to have a maximum of 0.1
+         if temp[-1] < 0.1:
+             scaling_factor = max(
+                 temp[-1], 1e-3
+             )  # Prevent division by zero with a small constant
+             audio = audio / scaling_factor * 0.1
+
+         # Filter out values less than 0.01 from temp
+         temp = temp[temp > 0.01]
+         L = temp.shape[0]  # Length of the filtered array
+
+         # If there are fewer than or equal to 10 significant values, return the audio without further processing
+         if L <= 10:
+             return audio
+
+         # Compute the average of the top 10% to 1% of values in temp
+         volume = np.mean(temp[int(0.9 * L) : int(0.99 * L)])
+
+         # Normalize the audio to the target coefficient level, clamping the scale factor between 0.1 and 10
+         audio = audio * np.clip(coeff / volume, a_min=0.1, a_max=10)
+
+         # Ensure the maximum absolute value in the audio does not exceed 1
+         max_value = np.max(np.abs(audio))
+         if max_value > 1:
+             audio = audio / max_value
+
+         return audio
+
+     def extract_mel_spectrogram(self, wav: np.ndarray, n_mels: int = 128,
+                                 n_fft: int = 1024, hop_length: int = 320,
+                                 win_length: int = 640) -> np.ndarray:
+         """
+         Extract a mel spectrogram
+
+         Args:
+             wav: audio data
+             n_mels: number of mel filter banks
+             n_fft: FFT window size
+             hop_length: hop length
+             win_length: window length
+
+         Returns:
+             Mel spectrogram
+         """
+         mel_spec = librosa.feature.melspectrogram(
+             y=wav,
+             sr=self.sample_rate,
+             n_mels=n_mels,
+             n_fft=n_fft,
+             hop_length=hop_length,
+             win_length=win_length,
+             power=1,
+             norm="slaney",
+             fmin=10,
+         )
+
+         return mel_spec
+
+     def extract_wav2vec2_features(self, wav: np.ndarray) -> np.ndarray:
+         """
+         Extract features with the ONNX wav2vec2 model, mimicking BiCodecTokenizer's behavior
+
+         Args:
+             wav: audio data
+
+         Returns:
+             Feature vectors
+         """
+         # Check that the wav2vec2 model is loaded
+         if self.wav2vec2_session is None:
+             raise RuntimeError("wav2vec2 model not loaded; pass wav2vec2_path at initialization")
+
+         # Add a batch dimension
+         input_data = wav[np.newaxis, :].astype(np.float32)  # [1, sequence_length]
+
+         # Run wav2vec2 inference
+         # Note: this ONNX model already bundles the feature-extractor preprocessing and combines several hidden layers
+         inputs = {'input': input_data}
+         outputs = self.wav2vec2_session.run(None, inputs)
+
+         # Output shape should be [1, time_steps, 1024]
+         # It is already the mean of hidden layers 11, 14 and 16
+         print(f'outputs: {outputs}')
+         print(f'outputs: {outputs[0].shape}')
+         features = outputs[0][0]  # drop the batch dimension -> [time_steps, 1024]
+
+         return features.astype(np.float32)
+
+
+     def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
+         """
+         Get the reference audio clip, consistent with BiCodecTokenizer
+
+         Args:
+             wav: raw audio data
+
+         Returns:
+             Reference audio clip
+         """
+         # Same length computation as BiCodecTokenizer
+         ref_segment_length = (
+             int(self.sample_rate * self.ref_segment_duration)
+             // self.latent_hop_length
+             * self.latent_hop_length
+         )
+         wav_length = len(wav)
+
+         if ref_segment_length > wav_length:
+             # If the audio is shorter than required, tile it until it is long enough
+             repeat_times = ref_segment_length // wav_length + 1
+             wav = np.tile(wav, repeat_times)
+
+         # Truncate to the required length
+         return wav[:ref_segment_length]
+
+     def process_audio(self, audio_path: Union[str, Path], volume_normalize: bool = False) -> Tuple[np.ndarray, np.ndarray]:
+         """
+         Process an audio file, returning the raw audio and the reference clip, consistent with BiCodecTokenizer
+
+         Args:
+             audio_path: path to the audio file
+             volume_normalize: whether to apply volume normalization
+
+         Returns:
+             (raw audio, reference audio)
+         """
+         wav = self.load_audio(audio_path, volume_normalize=volume_normalize)
+         ref_wav = self.get_ref_clip(wav)
+
+         return wav, ref_wav
+
+     def tokenize(self, audio_path: Union[str, Path]) -> Tuple[np.ndarray, np.ndarray]:
+         """
+         Generate tokens with the ONNX model
+
+         Args:
+             audio_path: path to the audio file
+
+         Returns:
+             (global_tokens, semantic_tokens)
+         """
+         # Process the audio
+         wav, ref_wav = self.process_audio(audio_path)
+
+         # Extract features
+         feat = self.extract_wav2vec2_features(wav)
+         ref_mel = self.extract_mel_spectrogram(ref_wav)
+
+         # Add batch dimensions
+         ref_mel_input = ref_mel[np.newaxis, :, :].astype(np.float32)  # [1, 128, 301]
+         feat_input = feat[np.newaxis, :, :].astype(np.float32)  # [1, feat_len, 1024]
+
+         # Run the ONNX model
+         inputs = {
+             'ref_wav_mel': ref_mel_input,
+             'feat': feat_input
+         }
+
+         outputs = self.ort_session.run(self.output_names, inputs)
+
+         # Parse the outputs
+         semantic_tokens = outputs[0]  # first output
+         global_tokens = outputs[1]  # second output
+
+         return global_tokens, semantic_tokens
+
+     def tokenize_batch(self, audio_paths: list) -> Tuple[list, list]:
+         """
+         Process audio files in batch
+
+         Args:
+             audio_paths: list of audio file paths
+
+         Returns:
+             (global_tokens_list, semantic_tokens_list)
+         """
+         global_tokens_list = []
+         semantic_tokens_list = []
+
+         for audio_path in audio_paths:
+             global_tokens, semantic_tokens = self.tokenize(audio_path)
+             global_tokens_list.append(global_tokens)
+             semantic_tokens_list.append(semantic_tokens)
+
+         return global_tokens_list, semantic_tokens_list
+
+
+ # Test function
+ def test_ref_audio_utilities():
+     """Test the RefAudioUtilities class"""
+     # Initialize the utility class
+     onnx_model_path = '/Volumes/bigdata/models/RWKVTTS_WebRWKV/BiCodecTokenize.onnx'
+     wav2vec2_path = "/Volumes/bigdata/models/RWKVTTS_WebRWKV/wav2vec2-large-xlsr-53.onnx"
+     # Use the same parameters as BiCodecTokenizer
+     utilities = RefAudioUtilities(
+         onnx_model_path,
+         wav2vec2_path,
+         ref_segment_duration=6.0,  # 6-second reference audio
+         latent_hop_length=320      # latent feature hop length
+     )
+
+     # Test audio file (a sample shipped with the project)
+     test_audio_path = "demos/刘德华/dehua_zh.wav"
+
+     if Path(test_audio_path).exists():
+         print(f"Test audio file: {test_audio_path}")
+
+         try:
+             # Generate tokens
+             global_tokens, semantic_tokens = utilities.tokenize(test_audio_path)
+
+             print(f"Global tokens shape: {global_tokens.shape}")
+             print(f"Semantic tokens shape: {semantic_tokens.shape}")
+             print(f"Global tokens: {global_tokens.flatten().tolist()}")
+             print(f"Semantic tokens : {semantic_tokens.flatten().tolist()}")
+
+         except Exception as e:
+             print(f"Error while processing audio: {e}")
+     else:
+         print(f"Test audio file does not exist: {test_audio_path}")
+         print("Please make sure the test audio file exists")
+
+
+ if __name__ == "__main__":
+     test_ref_audio_utilities()
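
A quick sanity check of the [1, 128, 301] shape noted in tokenize(): with the defaults above, the reference clip is int(16000 * 6.0) // 320 * 320 = 96000 samples, and librosa's centered STFT then yields 96000 / 320 + 1 = 301 mel frames:

sr, dur, hop = 16000, 6.0, 320
ref_len = int(sr * dur) // hop * hop
print(ref_len)             # 96000
print(ref_len // hop + 1)  # 301 frames of 128 mel bins each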
rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/rwkv_vocab_v20230424.txt ADDED
The diff for this file is too large to render. See raw diff
 
rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/spark_llm.py ADDED
@@ -0,0 +1,202 @@
+ import torch
+ import torch.nn as nn
+ from typing import Optional, Union, Tuple, Dict, Unpack
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.utils.deprecation import deprecate_kwarg
+ from rwkvfla.models.rwkv7.modeling_rwkv7 import RWKV7Model, RWKV7PreTrainedModel, Cache, RWKV7ForCausalLM
+ from rwkvfla.models.rwkv7.modeling_rwkv7 import FusedLinearCrossEntropyLoss, FusedCrossEntropyLoss
+ from transformers.generation.utils import GenerationMixin
+
+ from rwkvfla.models.rwkv7.configuration_rwkv7 import RWKV7Config
+
+ class RWKV7SpeechConfig(RWKV7Config):
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+         self.text_vocab_size = kwargs.get("text_vocab_size")
+         self.audio_global_vocab_size = kwargs.get("audio_global_vocab_size")
+
+
+ class RWKV7ForSpeech(RWKV7ForCausalLM):
+     config_class = RWKV7SpeechConfig
+     def __init__(self, config: RWKV7SpeechConfig):
+         super().__init__(config)
+         self.model = RWKV7Model(config)
+         self.vocab_size = config.vocab_size
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)  # Spark 0.5B vocab size is 8192 + 1 for eos resulting in 8193
+         self.criterion = None
+         self.text_embedder = nn.Embedding(config.text_vocab_size, config.hidden_size)
+         self.global_embedder = nn.Embedding(config.audio_global_vocab_size, config.hidden_size)  # Spark 0.5B global token size is 4096
+         # TTS Tag includes GLOBAL=0, SEMANTIC=1, START_TTS=2
+         self.tts_tag_embedder = nn.Embedding(3, config.hidden_size)
+         # Initialize weights and apply final processing
+         self.post_init()
+         self.dropout = torch.nn.Dropout(0.02)
+
+     def get_input_embeddings(self):
+         return self.model.embeddings
+
+     def set_input_embeddings(self, value):
+         self.model.embeddings = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def set_decoder(self, decoder):
+         self.model = decoder
+
+     def get_decoder(self):
+         return self.model
+
+     def generate(self, *args, **kwargs):
+         try:
+             return super().generate(*args, **kwargs)
+         except AttributeError as exception:
+             if 'past_key_values' in str(exception):
+                 raise AttributeError(
+                     f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
+                     f"which is not supported for {self.__class__.__name__}. "
+                     f"Try another generation strategy instead. "
+                     f"For the available generation strategies, check this doc: "
+                     f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
+                 )
+             else:
+                 raise exception
+
+     @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+     def prepare_inputs_for_generation(
+         self,
+         input_ids: torch.LongTensor = None,
+         past_key_values: Optional[Cache] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         use_cache: bool = True,
+         logits_to_keep: Optional[int] = None,
+         **kwargs
+     ):
+         # only last token for `inputs_ids` if the `past_key_values` is not empty.
+         if past_key_values is not None and len(past_key_values) > 0:
+             input_ids = input_ids[:, -1:]
+         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+         if inputs_embeds is not None and len(past_key_values) == 0:
+             model_inputs = {'inputs_embeds': inputs_embeds}
+         else:
+             # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+             # recompiles graphs as the stride of the inputs is a guard.
+             # Ref: https://github.com/huggingface/transformers/pull/29114
+             # TODO: use `next_tokens` directly instead.
+             model_inputs = {'input_ids': input_ids.contiguous()}
+
+         if logits_to_keep is not None:
+             model_inputs['logits_to_keep'] = logits_to_keep
+
+         model_inputs.update({
+             'past_key_values': past_key_values,
+             'use_cache': use_cache,
+             'attention_mask': attention_mask,
+             'logits_to_keep': logits_to_keep,
+         })
+         return model_inputs
+
+     @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         past_key_values: Optional[Cache] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         logits_to_keep: Optional[int] = 0,
+         **kwargs: Unpack[Dict]
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         if self.training and inputs_embeds is not None:
+             inputs_embeds = self.dropout(inputs_embeds)
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             inputs_embeds=inputs_embeds,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             **kwargs
+         )
+
+         hidden_states = outputs[0]
+         fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
+
+         loss, logits = None, None
+         if not fuse_linear_and_cross_entropy or labels is None:
+             logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
+         if labels is not None:
+             if getattr(self, 'criterion', None) is None:
+                 if fuse_linear_and_cross_entropy:
+                     criterion = FusedLinearCrossEntropyLoss()
+                 elif self.config.fuse_cross_entropy:
+                     criterion = FusedCrossEntropyLoss(inplace_backward=True)
+                 else:
+                     criterion = nn.CrossEntropyLoss()
+             else:
+                 criterion = self.criterion
+             # Enable model parallelism
+             labels = labels.to(hidden_states.device)
+             labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
+             if fuse_linear_and_cross_entropy:
+                 loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
+             else:
+                 loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+     def copy_state_dict(self, state_dict: dict):
+         """Copy parameters from a source state dict into this model, excluding embeddings and lm_head.
+         The state dict is from original RWKV7 language model
+         Args:
+             state_dict: source state dict
+         """
+         # Current model's state dict
+         target_dict = self.state_dict()
+
+         # New state dict holding the parameters to copy
+         new_state_dict = {}
+
+         # Iterate over the source state dict's keys
+         for key in state_dict.keys():
+             # Skip embeddings- and lm_head-related parameters
+             if key == 'model.embeddings.weight':
+                 new_state_dict['text_embedder.weight'] = state_dict[key]
+                 continue
+             if 'embeddings' in key or 'lm_head' in key:
+                 continue
+             # Copy the parameter if the key exists in the current model
+             if key in target_dict:
+                 new_state_dict[key] = state_dict[key]
+
+         # Load the new state dict into the current model
+         info = self.load_state_dict(new_state_dict, strict=False)
+         print(info)
+         return self
+
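
RWKV7ForSpeech never routes input_ids through the three extra embedding tables itself; a caller is expected to assemble `inputs_embeds` from them and pass that in (the `forward` above only applies dropout to whatever embeddings arrive during training). A minimal sketch of such assembly — the actual prompt layout lives in tts_cli.py, which this page does not render, so treat the ordering here as illustrative only:

import torch

# hypothetical prompt assembly: text tokens, a GLOBAL tag, audio global
# tokens, then a START_TTS tag, mirroring the TTS tag comment above
def build_inputs_embeds(model, text_ids, global_ids):
    text_emb = model.text_embedder(text_ids)                    # [B, T_text, H]
    global_tag = model.tts_tag_embedder(
        torch.zeros_like(global_ids[:, :1]))                    # GLOBAL = 0
    global_emb = model.global_embedder(global_ids)              # [B, T_glb, H]
    start_tag = model.tts_tag_embedder(
        torch.full_like(global_ids[:, :1], 2))                  # START_TTS = 2
    return torch.cat([text_emb, global_tag, global_emb, start_tag], dim=1)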
rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
+ {
+   "bos_token": {
+     "content": "<|rwkv_tokenizer_end_of_text|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": "\n\n",
+   "pad_token": {
+     "content": "<|rwkv_tokenizer_end_of_text|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "<|rwkv_tokenizer_end_of_text|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/texts_utilities.py ADDED
File without changes
rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/tokenizer_config.json ADDED
@@ -0,0 +1,836 @@
+ {
+   "add_prefix_space": false,
+   "added_tokens_decoder": {
+     "0": {
+       "content": "<|rwkv_tokenizer_end_of_text|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "65530": {
+       "content": "\n\n",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "65531": {
+       "content": "SPCT_0",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65532": {
+       "content": "SPCT_1",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65533": {
+       "content": "SPCT_2",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65534": {
+       "content": "SPCT_3",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65535": {
+       "content": "SPCT_4",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65536": {
+       "content": "SPCT_5",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65537": {
+       "content": "SPCT_6",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65538": {
+       "content": "SPCT_7",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65539": {
+       "content": "SPCT_8",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65540": {
+       "content": "SPCT_9",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65541": {
+       "content": "SPCT_10",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65542": {
+       "content": "SPCT_11",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65543": {
+       "content": "SPCT_12",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65544": {
+       "content": "SPCT_13",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65545": {
+       "content": "SPCT_14",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65546": {
+       "content": "SPCT_15",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65547": {
+       "content": "SPCT_16",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65548": {
+       "content": "SPCT_17",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65549": {
+       "content": "SPCT_18",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65550": {
+       "content": "SPCT_19",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65551": {
+       "content": "SPCT_20",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65552": {
+       "content": "SPCT_21",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65553": {
+       "content": "SPCT_22",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65554": {
+       "content": "SPCT_23",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65555": {
+       "content": "SPCT_24",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65556": {
+       "content": "SPCT_25",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65557": {
+       "content": "SPCT_26",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65558": {
+       "content": "SPCT_27",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65559": {
+       "content": "SPCT_28",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65560": {
+       "content": "SPCT_29",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65561": {
+       "content": "SPCT_30",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65562": {
+       "content": "SPCT_31",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65563": {
+       "content": "SPCT_32",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65564": {
+       "content": "SPCT_33",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65565": {
+       "content": "SPCT_34",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65566": {
+       "content": "SPCT_35",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65567": {
+       "content": "SPCT_36",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65568": {
+       "content": "SPCT_37",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65569": {
+       "content": "SPCT_38",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65570": {
+       "content": "SPCT_39",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65571": {
+       "content": "SPCT_40",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65572": {
+       "content": "SPCT_41",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65573": {
+       "content": "SPCT_42",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65574": {
+       "content": "SPCT_43",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65575": {
+       "content": "SPCT_44",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65576": {
+       "content": "SPCT_45",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65577": {
+       "content": "SPCT_46",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65578": {
+       "content": "SPCT_47",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65579": {
+       "content": "SPCT_48",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65580": {
+       "content": "SPCT_49",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65581": {
+       "content": "SPCT_50",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65582": {
+       "content": "SPCT_51",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65583": {
+       "content": "SPCT_52",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65584": {
+       "content": "SPCT_53",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65585": {
+       "content": "SPCT_54",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65586": {
+       "content": "SPCT_55",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65587": {
+       "content": "SPCT_56",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65588": {
+       "content": "SPCT_57",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65589": {
+       "content": "SPCT_58",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65590": {
+       "content": "SPCT_59",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65591": {
+       "content": "SPCT_60",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65592": {
+       "content": "SPCT_61",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65593": {
+       "content": "SPCT_62",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65594": {
+       "content": "SPCT_63",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65595": {
+       "content": "SPCT_64",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65596": {
+       "content": "SPCT_65",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65597": {
+       "content": "SPCT_66",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65598": {
+       "content": "SPCT_67",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65599": {
+       "content": "SPCT_68",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65600": {
+       "content": "SPCT_69",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65601": {
+       "content": "SPCT_70",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65602": {
+       "content": "SPCT_71",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65603": {
+       "content": "SPCT_72",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65604": {
+       "content": "SPCT_73",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65605": {
+       "content": "SPCT_74",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65606": {
+       "content": "SPCT_75",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65607": {
+       "content": "SPCT_76",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65608": {
+       "content": "SPCT_77",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65609": {
+       "content": "SPCT_78",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65610": {
+       "content": "SPCT_79",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65611": {
+       "content": "SPCT_80",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65612": {
+       "content": "SPCT_81",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65613": {
+       "content": "SPCT_82",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65614": {
+       "content": "SPCT_83",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65615": {
+       "content": "SPCT_84",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65616": {
+       "content": "SPCT_85",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65617": {
+       "content": "SPCT_86",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65618": {
+       "content": "SPCT_87",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65619": {
+       "content": "SPCT_88",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65620": {
+       "content": "SPCT_89",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65621": {
+       "content": "SPCT_90",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65622": {
+       "content": "SPCT_91",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65623": {
+       "content": "SPCT_92",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65624": {
+       "content": "SPCT_93",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65625": {
+       "content": "SPCT_94",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65626": {
+       "content": "SPCT_95",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "65627": {
+       "content": "SPCT_96",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
795
+ },
796
+ "65628": {
797
+ "content": "SPCT_97",
798
+ "lstrip": false,
799
+ "normalized": true,
800
+ "rstrip": false,
801
+ "single_word": false,
802
+ "special": false
803
+ },
804
+ "65629": {
805
+ "content": "SPCT_98",
806
+ "lstrip": false,
807
+ "normalized": true,
808
+ "rstrip": false,
809
+ "single_word": false,
810
+ "special": false
811
+ },
812
+ "65630": {
813
+ "content": "SPCT_99",
814
+ "lstrip": false,
815
+ "normalized": true,
816
+ "rstrip": false,
817
+ "single_word": false,
818
+ "special": false
819
+ }
820
+ },
821
+ "auto_map": {
822
+ "AutoTokenizer": [
823
+ "hf_rwkv_tokenizer.RwkvTokenizer",
824
+ null
825
+ ]
826
+ },
827
+ "bos_token": "<|rwkv_tokenizer_end_of_text|>",
828
+ "clean_up_tokenization_spaces": false,
829
+ "eos_token": "\n\n",
830
+ "extra_special_tokens": {},
831
+ "model_max_length": 1000000000000000019884624838656,
832
+ "pad_token": "<|rwkv_tokenizer_end_of_text|>",
833
+ "tokenizer_class": "RwkvTokenizer",
834
+ "unk_token": "<|rwkv_tokenizer_end_of_text|>",
835
+ "use_fast": false
836
+ }
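The `auto_map` block above registers the repo-local `hf_rwkv_tokenizer.RwkvTokenizer`, so the tokenizer has to be loaded with `trust_remote_code=True`, exactly as the scripts below do. A minimal sketch (the local directory path is a placeholder):

    from transformers import AutoTokenizer

    # Resolves to hf_rwkv_tokenizer.RwkvTokenizer through the auto_map entry above.
    tokenizer = AutoTokenizer.from_pretrained(
        "/path/to/rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1",
        trust_remote_code=True,
    )
    ids = tokenizer.encode("hello world", add_special_tokens=False)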
rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/translation_data.py ADDED
@@ -0,0 +1,55 @@
+ from tts_cli import TTSGenerator
+ import webrwkv_py
+ import time
+ from transformers import AutoTokenizer
+
+ model_path = "/home/yueyulin/models/rwkvtts-respark-webrwkv/"
+ decoder_path = f'{model_path}/BiCodecDetokenize.onnx'
+ device_idx = 0
+
+ webrwkv_model_path = f'{model_path}/webrwkv.safetensors'
+ print(f"🔍 Loading model file: {webrwkv_model_path} time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))}")
+ model = webrwkv_py.Model(webrwkv_model_path, 'fp32', device_idx)
+ print(f"✅ Model loaded: {webrwkv_model_path} time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))}")
+
+ runtime = model.create_thread_runtime()
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ print(f"✅ Tokenizer loaded: {model_path} time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))}")
+ generator = TTSGenerator(runtime, tokenizer, decoder_path, device_idx, model_path)
+ print(f"✅ Generator created: {model_path} time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))}")
+
+ chinese_text = "一开始,很多人把这次危机比作一九八二年或一九七三年所发生的情况,这样的类比是令人宽心的,因为这两段时期意味着典型的周期性衰退。"
+ english_text = "At the start of the crisis, many people likened it to 1982 or 1973, which was reassuring, because both dates refer to classical cyclical downturns."
+
+ global_tokens, semantic_tokens, global_time, global_speed, semantic_time, semantic_speed = generator._generate_tokens(chinese_text, 'middle-aged', 'male', 'happy', 'medium_pitch', 'medium')
+ print(f"✅ Generation finished: {chinese_text} time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))}")
+ print(f"🎯 global_tokens: {global_tokens}")
+ print(f"🎯 semantic_tokens: {semantic_tokens}")
+ print(f"🎯 global_time: {global_time}")
+ print(f"🎯 global_speed: {global_speed}")
+ print(f"🎯 semantic_time: {semantic_time}")
+ print(f"🎯 semantic_speed: {semantic_speed}")
+
+ wav_data, audio_duration, decode_time, decode_speed = generator._decode_audio(global_tokens, semantic_tokens)
+ print(f"✅ Decoding finished: {audio_duration:.2f}s of audio in {decode_time:.2f}s, {decode_speed:.1f} tokens/s")
+ generator._save_audio(wav_data, "chinese_text.wav", 16000)
+ generator.reset_runtime()
+ global_tokens, semantic_tokens, prefill_time, prefill_speed, semantic_time, semantic_speed = generator._generate_tokens_with_global_tokens(english_text, global_tokens)
+ print(f"✅ Generation finished: {english_text} time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))}")
+ print(f"🎯 global_tokens: {global_tokens}")
+ print(f"🎯 semantic_tokens: {semantic_tokens}")
+ print(f"🎯 prefill_time: {prefill_time}")
+ print(f"🎯 prefill_speed: {prefill_speed}")
+ print(f"🎯 semantic_time: {semantic_time}")
+ print(f"🎯 semantic_speed: {semantic_speed}")
+ wav_data, audio_duration, decode_time, decode_speed = generator._decode_audio(global_tokens, semantic_tokens)
+ print(f"✅ Decoding finished: {audio_duration:.2f}s of audio in {decode_time:.2f}s, {decode_speed:.1f} tokens/s")
+ generator._save_audio(wav_data, "english_text.wav", 16000)
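The script above demonstrates cross-lingual voice transfer: the 32 global tokens sampled for the Chinese sentence act as a compact voice identity and are reused so the English sentence is spoken in the same voice. A minimal sketch of that second step wrapped as a helper, using the TTSGenerator methods defined in tts_cli.py below:

    def speak_in_same_voice(generator, text, global_tokens, out_path, sample_rate=16000):
        # Reuse previously sampled global tokens so the new text keeps the same voice.
        generator.reset_runtime()
        global_tokens, semantic_tokens, _, _, _, _ = generator._generate_tokens_with_global_tokens(text, global_tokens)
        wav_data, duration, decode_time, decode_speed = generator._decode_audio(global_tokens, semantic_tokens)
        generator._save_audio(wav_data, out_path, sample_rate)
        return wav_data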
rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/tts_cli.py ADDED
@@ -0,0 +1,992 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ RWKV TTS interactive audio generation tool.
+ Uses webrwkv_py and ONNX Runtime for audio generation.
+ """
+
+ import os
+ import sys
+ import re
+ import time
+ import warnings
+ import logging
+ from pathlib import Path
+ from typing import Dict, Any, Tuple, List
+
+ import numpy as np
+ import soundfile as sf
+ import click
+
+ # Logging configuration
+ def setup_logging():
+     """Configure logging."""
+     # Read the log level from the LOG_LEVEL environment variable; default is WARNING
+     log_level_str = os.environ.get('LOG_LEVEL', 'WARNING').upper()
+     log_level = getattr(logging, log_level_str, logging.WARNING)
+
+     # Configure the log format
+     logging.basicConfig(
+         level=log_level,
+         format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+         datefmt='%Y-%m-%d %H:%M:%S'
+     )
+
+     return logging.getLogger(__name__)
+
+ # Create the logger instance
+ logger = setup_logging()
+
+ # Suppress warnings
+ warnings.filterwarnings("ignore", category=UserWarning, module="numpy")
+ warnings.filterwarnings("ignore", category=UserWarning, module="onnxruntime")
+ warnings.filterwarnings("ignore", category=UserWarning, module="torch")
+ warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
+ np.seterr(all='ignore')
+
+ # Check for and import the required libraries
+ try:
+     import webrwkv_py
+     HAS_WEBRWKV = True
+ except ImportError:
+     HAS_WEBRWKV = False
+     logger.error("❌ Error: the 'webrwkv_py' library is required")
+     logger.error("Run: pip install webrwkv_py")
+     sys.exit(1)
+
+ try:
+     import onnxruntime as ort
+     HAS_ONNX = True
+ except ImportError:
+     HAS_ONNX = False
+     logger.error("❌ Error: the 'onnxruntime' library is required")
+     logger.error("Run: pip install onnxruntime")
+     sys.exit(1)
+
+ try:
+     from transformers import AutoTokenizer
+     HAS_TRANSFORMERS = True
+ except ImportError:
+     HAS_TRANSFORMERS = False
+     logger.error("❌ Error: the 'transformers' library is required")
+     logger.error("Run: pip install transformers")
+     sys.exit(1)
+
+ try:
+     import questionary
+     HAS_QUESTIONARY = True
+ except ImportError:
+     HAS_QUESTIONARY = False
+     logger.warning("⚠️ Warning: cannot import the questionary library for the interactive UI")
+     logger.warning("Run: pip install questionary")
+     sys.exit(1)
+
+ # Import the property utilities
+ try:
+     from properties_util import (
+         SPEED_MAP, PITCH_MAP, AGE_MAP, GENDER_MAP, EMOTION_MAP
+     )
+     # Extract the choice lists from the maps
+     age_choices = list(AGE_MAP.keys())
+     gender_choices = list(GENDER_MAP.keys())
+     emotion_choices = list(EMOTION_MAP.keys())
+     pitch_choices = list(PITCH_MAP.keys())
+     speed_choices = list(SPEED_MAP.keys())
+ except ImportError:
+     logger.warning("⚠️ Warning: cannot import properties_util; falling back to default choices")
+     # Default choices
+     age_choices = ['child', 'teenager', 'youth-adult', 'middle-aged', 'elderly']
+     gender_choices = ['female', 'male']  # kept consistent with properties_util.py
+     emotion_choices = ['NEUTRAL', 'HAPPY', 'SAD', 'ANGRY', 'FEARFUL', 'DISGUSTED', 'SURPRISED']
+     pitch_choices = ['low_pitch', 'medium_pitch', 'high_pitch', 'very_high_pitch']
+     speed_choices = ['very_slow', 'slow', 'medium', 'fast', 'very_fast']
+
+ def detect_token_lang(token: str) -> str:
+     """Simple character-set-based, word-level language detection. Returns 'en' or 'zh'."""
+     if not token:
+         return 'en'
+     has_zh = re.search(r"[\u4e00-\u9fff]", token) is not None
+     has_en = re.search(r"[A-Za-z]", token) is not None
+     if has_zh and not has_en:
+         return 'zh'
+     if has_en and not has_zh:
+         return 'en'
+     if has_zh and has_en:
+         return 'zh'
+     return 'en'
+
+ def sample_logits(logits, temperature=1.0, top_p=0.85, top_k=0):
+     """Sample a token id from logits with temperature, top-p and top-k."""
+     if temperature == 0:
+         temperature = 1.0
+         top_p = 0
+
+     if isinstance(logits, list):
+         logits = np.array(logits)
+
+     try:
+         from scipy import special
+         probs = special.softmax(logits, axis=-1)
+     except ImportError:
+         # Fall back to a plain numpy softmax if scipy is unavailable
+         exp_logits = np.exp(logits - np.max(logits))
+         probs = exp_logits / np.sum(exp_logits)
+
+     top_k = int(top_k)
+
+     sorted_ids = np.argsort(probs)
+     sorted_probs = probs[sorted_ids][::-1]
+     cumulative_probs = np.cumsum(sorted_probs)
+
+     cutoff_mask = cumulative_probs >= top_p
+     if np.any(cutoff_mask):
+         cutoff_idx = np.argmax(cutoff_mask)
+         cutoff = float(sorted_probs[cutoff_idx])
+         probs[probs < cutoff] = 0
+
+     if top_k < len(probs) and top_k > 0:
+         probs[sorted_ids[:-top_k]] = 0
+
+     if temperature != 1.0:
+         probs = probs ** (1.0 / temperature)
+
+     probs = probs / np.sum(probs)
+     out = np.random.choice(a=len(probs), size=1, p=probs)
+     return int(out[0])
+
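+ # Example (illustrative): global-token generation below restricts sampling to the
+ # 4096-entry global codebook,
+ #     sampled_id = sample_logits(logits[0:4096], temperature=1.0, top_p=0.95, top_k=20)
+ # while semantic generation samples over ids 0..8192, where 8192 is the stop token:
+ #     sampled_id = sample_logits(x[0:8193], temperature=1.0, top_p=0.95, top_k=80)
+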
+ def get_unique_filename(output_dir, text, extension=".wav"):
+     """Generate a unique filename to avoid collisions."""
+     output_dir = Path(output_dir)
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     prefix = text[:3] if len(text) >= 3 else text
+     prefix = re.sub(r'[\W\s]', '', prefix).strip()
+
+     base_name = prefix
+     index = 0
+
+     while True:
+         if index == 0:
+             filename = base_name + extension
+         else:
+             filename = f"{base_name}_{index}{extension}"
+
+         filepath = output_dir / filename
+         if not filepath.exists():
+             return str(filepath)
+         index += 1
+
+ class TTSGenerator:
+     """TTS generator: handles audio generation and bookkeeping of statistics."""
+
+     def __init__(self, runtime, tokenizer, decoder_path, device, model_path):
+         self.runtime = runtime
+         self.tokenizer = tokenizer
+         self.decoder_path = decoder_path
+         self.device = device
+         self.model_path = model_path
+
+         # Initialize the RefAudioUtilities instance
+         logger.info('🎿 Loading the audio encoder models')
+         try:
+             audio_tokenizer_path = os.path.join(model_path, 'BiCodecTokenize.onnx')
+             wav2vec2_path = os.path.join(model_path, 'wav2vec2-large-xlsr-53.onnx')
+             from ref_audio_utilities import RefAudioUtilities
+             self.ref_audio_utilities = RefAudioUtilities(audio_tokenizer_path, wav2vec2_path)
+             logger.info('✅ Audio encoder models loaded')
+         except Exception as e:
+             logger.error(f'❌ Failed to load the audio encoder models: {e}')
+             self.ref_audio_utilities = None
+
+         # Cache the ONNX session
+         logger.info('🎿 Loading the ONNX decoder model')
+         try:
+             self.ort_session = ort.InferenceSession(decoder_path,
+                                                     providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+             logger.info(f"🖥️ ONNX session providers for waveform generation: {self.ort_session.get_providers()}")
+             logger.info('✅ ONNX model loaded')
+         except Exception as e:
+             logger.error(f'❌ Failed to load the ONNX model: {e}')
+             raise
+
+         # Generation statistics
+         self.generation_stats = {
+             'total_generations': 0,
+             'total_tokens': 0,
+             'total_time': 0.0,
+             'last_generation': {
+                 'text': '',
+                 'params': {},
+                 'total_time': 0.0,
+                 'total_tokens': 0,
+                 'audio_duration': 0.0,
+                 'rtf': 0.0,
+                 'global_speed': 0.0,
+                 'semantic_speed': 0.0,
+                 'decode_speed': 0.0,
+                 'timestamp': '',
+                 'output_path': ''
+             }
+         }
+
+     def reset_runtime(self):
+         """Reset the runtime state."""
+         try:
+             self.runtime.reset()
+             logger.info("🔄 Runtime state reset")
+         except Exception as e:
+             logger.warning(f"⚠️ Runtime reset failed: {e}")
+
+     def generate_audio(self, params: Dict[str, Any]) -> Tuple[np.ndarray, Dict[str, Any]]:
+         """Generate audio from the given parameters."""
+         start_time = time.time()
+
+         # Reset the runtime state
+         self.reset_runtime()
+
+         # Fetch parameters
+         text = params['text']
+
+         # Check for zero-shot mode
+         if params.get('zero_shot', False):
+             # Zero-shot mode
+             ref_audio_path = params['ref_audio_path']
+             prompt_text = params.get('prompt_text', "希望你以后能够做的,比我还好呦!")
+
+             logger.info(f"🎯 Generating audio (zero-shot mode): {text}")
+             logger.info(f"📊 Params: ref_audio={ref_audio_path}, prompt_text={prompt_text}")
+
+             # Detect the language
+             lang = detect_token_lang(text)
+             logger.info(f"🌍 Detected language: {lang}")
+
+             # Generate tokens via the zero-shot path
+             global_tokens, semantic_tokens, semantic_time, semantic_speed = self._generate_tokens_zeroshot(text, ref_audio_path, prompt_text)
+         else:
+             # Conventional mode
+             age = params['age']
+             gender = params['gender']
+             emotion = params['emotion']
+             pitch = params['pitch']
+             speed = params['speed']
+
+             logger.info(f"🎯 Generating audio: {text}")
+             logger.info(f"📊 Params: age={age}, gender={gender}, emotion={emotion}, pitch={pitch}, speed={speed}")
+
+             # Detect the language
+             lang = detect_token_lang(text)
+             logger.info(f"🌍 Detected language: {lang}")
+
+             # Generate global tokens and semantic tokens
+             global_tokens, semantic_tokens, global_time, global_speed, semantic_time, semantic_speed = self._generate_tokens(text, age, gender, emotion, pitch, speed)
+
+         # Decode the audio
+         logger.info("🎵 Decoding audio...")
+
+         # Use the shared audio-decoding helper
+         wav_data, audio_duration, decode_time, decode_speed = self._decode_audio(global_tokens, semantic_tokens)
+
+         # Compute total time and the real-time factor (RTF)
+         total_time = time.time() - start_time
+         total_tokens = len(global_tokens) + len(semantic_tokens)
+         rtf = total_time / audio_duration if audio_duration > 0 else 0
+
+         logger.info(f"📊 Total time: {total_time:.2f}s, RTF: {rtf:.2f}")
+
+         # Update the statistics
+         self.generation_stats['total_generations'] += 1
+         self.generation_stats['total_tokens'] += total_tokens
+         self.generation_stats['total_time'] += total_time
+
+         self.generation_stats['last_generation'] = {
+             'text': text,
+             'params': params,
+             'total_time': total_time,
+             'total_tokens': total_tokens,
+             'audio_duration': audio_duration,
+             'rtf': rtf,
+             'semantic_speed': semantic_speed,
+             'decode_speed': decode_speed,
+             'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
+             'output_path': ''
+         }
+
+         return wav_data, self.generation_stats['last_generation']
+
+     def _generate_tokens(self, text: str, age: str, gender: str, emotion: str, pitch: str, speed: str) -> Tuple[List[int], List[int], float, float, float, float]:
+         """
+         Generate global tokens and semantic tokens.
+
+         Args:
+             text: raw input text
+             age: age attribute
+             gender: gender attribute
+             emotion: emotion attribute
+             pitch: pitch attribute
+             speed: speed attribute
+
+         Returns:
+             Tuple: (global_tokens, semantic_tokens, global_time, global_speed, semantic_time, semantic_speed)
+         """
+         # Encode the text
+         logger.info("🔤 Encoding text...")
+         tokens = self.tokenizer.encode(text)
+         logger.info(f"✅ Text encoded: {len(tokens)} tokens")
+
+         # Generate global tokens
+         logger.info("🌐 Generating global tokens...")
+         global_start = time.time()
+
+         # Prepare the input tokens
+         TTS_TAG_0 = 8193
+         TTS_TAG_1 = 8194
+         TTS_TAG_2 = 8195
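+
+         # Token-id layout of the combined vocabulary used throughout this file:
+         #   0..8191      semantic codebook ids (8192 = semantic stop token)
+         #   8193..8195   TTS tags (TTS_TAG_0 / TTS_TAG_1 / TTS_TAG_2)
+         #   8196..12291  global codebook (raw global id + 8196)
+         #   12292..      text and property tokens (tokenizer id + 8196 + 4096)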
+
+         # Build attribute tokens via properties_util.py
+         from properties_util import convert_standard_properties_to_tokens
+         properties_text = convert_standard_properties_to_tokens(age, gender, emotion, pitch, speed)
+         logger.info(f'🔤 Attribute text: {properties_text}')
+         properties_tokens = self.tokenizer.encode(properties_text, add_special_tokens=False)
+         properties_tokens = [i + 8196 + 4096 for i in properties_tokens]
+
+         # Build the text tokens
+         text_tokens = [i + 8196 + 4096 for i in tokens]
+
+         # Combine all tokens
+         all_idx = properties_tokens + [TTS_TAG_2] + text_tokens + [TTS_TAG_0]
+         logger.info(f'🔢 Attribute tokens: {properties_tokens}')
+         logger.info(f'🔢 Text tokens: {text_tokens}')
+         logger.info(f'🎯 Combined tokens: {all_idx}')
+
+         # Prefill phase
+         logger.info("💎 Starting prefill...")
+         session = self.runtime.create_inference_session([all_idx], token_chunk_size=512)
+         step_count = 0
+         start = time.time()
+         while not session.is_complete():
+             step_count += 1
+             output = session.step()
+             if not output.batches[0].is_empty():
+                 logits = output.batches[0].data
+                 break
+
+         prefill_time = time.time() - start
+         logger.info(f"✅ Prefill finished in {step_count} steps")
+         logger.info(f"✅ Prefill finished, logits length: {len(logits)}")
+         logger.info(f"✅ Prefill finished in {prefill_time:.2f}s, {len(all_idx)/prefill_time:.1f} tokens/s")
+
+         # Generate global tokens, following the logic of tts_gui_simple.py
+         logger.info("🌍 Generating global tokens...")
+         global_tokens_size = 32
+         global_tokens = []
+
+         for i in range(global_tokens_size):
+             # Sample a token from the logits
+             sampled_id = sample_logits(logits[0:4096], temperature=1.0, top_p=0.95, top_k=20)
+             global_tokens.append(sampled_id)
+             # Predict the next token
+             sampled_id += 8196
+             logits = self.runtime.predict_next(sampled_id)
+
+         global_time = time.time() - global_start
+         global_speed = global_tokens_size / global_time if global_time > 0 else 0
+         logger.info(f"✅ Global tokens finished: {len(global_tokens)} tokens in {global_time:.2f}s, {global_speed:.1f} tokens/s")
+         logger.info(f'🎯 Generated global tokens: {global_tokens}')
+
+         # Generate semantic tokens
+         logger.info("🧠 Generating semantic tokens...")
+         semantic_start = time.time()
+
+         # Generate semantic tokens following the logic of tts_gui_simple.py
+         x = self.runtime.predict_next(TTS_TAG_1)
+         semantic_tokens = []
+
+         for i in range(2048):  # generate at most 2048 tokens
+             sampled_id = sample_logits(x[0:8193], temperature=1.0, top_p=0.95, top_k=80)
+             if sampled_id == 8192:  # stop token reached
+                 logger.info(f"🛑 Semantic generation hit the stop token after {len(semantic_tokens)} tokens")
+                 break
+             semantic_tokens.append(sampled_id)
+             x = self.runtime.predict_next(sampled_id)
+
+         semantic_time = time.time() - semantic_start
+         semantic_speed = len(semantic_tokens) / semantic_time if semantic_time > 0 else 0
+         logger.info(f"✅ Semantic tokens finished: {len(semantic_tokens)} tokens in {semantic_time:.2f}s, {semantic_speed:.1f} tokens/s")
+
+         return global_tokens, semantic_tokens, global_time, global_speed, semantic_time, semantic_speed
+
+     def _generate_tokens_with_global_tokens(self, text: str, global_tokens: List[int]) -> Tuple[List[int], List[int], float, float, float, float]:
+         """
+         Generate semantic tokens for new text using existing global tokens (the voice identity).
+         """
+         # Encode the text
+         logger.info("🔤 Encoding text...")
+         text_tokens = self.tokenizer.encode(text, add_special_tokens=False)
+         text_tokens = [i + 8196 + 4096 for i in text_tokens]
+         logger.info(f"✅ Text encoded: {len(text_tokens)} tokens")
+         global_tokens = [int(i) + 8196 for i in global_tokens]
+         logger.info(f'🎯 Reference global_tokens: {global_tokens}')
+         start = time.time()
+
+         # Prepare the input tokens
+         TTS_TAG_0 = 8193
+         TTS_TAG_1 = 8194
+         TTS_TAG_2 = 8195
+
+         # Combine all tokens
+         all_idx = [TTS_TAG_2] + text_tokens + [TTS_TAG_0] + global_tokens + [TTS_TAG_1]
+         logger.info(f'🎯 Combined tokens: {all_idx}')
+
+         # Prefill phase
+         logger.info("💎 Starting prefill...")
+         session = self.runtime.create_inference_session([all_idx], token_chunk_size=512)
+         step_count = 0
+         while not session.is_complete():
+             step_count += 1
+             output = session.step()
+             if not output.batches[0].is_empty():
+                 logits = output.batches[0].data[0]
+                 break
+         logger.info(f"✅ Prefill finished in {step_count} steps")
+         logger.info(f"✅ Prefill speed: {step_count/output.time:.1f} tokens/s")
+         logger.info(f"✅ Prefill finished, logits length: {len(logits)}")
+         prefill_time = time.time() - start
+         prefill_speed = len(all_idx) / prefill_time if prefill_time > 0 else 0
+         logger.info(f"✅ Prefill finished in {prefill_time:.2f}s, {prefill_speed:.1f} tokens/s")
+
+         # Generate semantic tokens
+         logger.info("🧠 Generating semantic tokens...")
+         semantic_start = time.time()
+
+         # Continue semantic generation from the current logits
+         x = logits
+         semantic_tokens = []
+
+         for i in range(2048):  # generate at most 2048 tokens
+             sampled_id = sample_logits(x[0:8193], temperature=1.0, top_p=0.95, top_k=80)
+             if sampled_id == 8192:  # stop token reached
+                 logger.info(f"🛑 Semantic generation hit the stop token after {len(semantic_tokens)} tokens")
+                 break
+             semantic_tokens.append(sampled_id)
+             x = self.runtime.predict_next(sampled_id)
+
+         semantic_time = time.time() - semantic_start
+         semantic_speed = len(semantic_tokens) / semantic_time if semantic_time > 0 else 0
+         logger.info(f"✅ Semantic tokens finished: {len(semantic_tokens)} tokens in {semantic_time:.2f}s, {semantic_speed:.1f} tokens/s")
+
+         # Shift the global tokens back to raw codebook ids so callers can feed them
+         # straight to _decode_audio (mirrors _generate_tokens_zeroshot)
+         global_tokens = [i - 8196 for i in global_tokens]
+         return global_tokens, semantic_tokens, prefill_time, prefill_speed, semantic_time, semantic_speed
+
+     def _generate_tokens_zeroshot(self, text: str, ref_audio_path: str, prompt_text: str = "希望你以后能够做的,比我还好呦!") -> Tuple[List[int], List[int], float, float]:
+         """
+         Generate global tokens and semantic tokens zero-shot from a reference audio clip.
+
+         Args:
+             text: raw input text
+             ref_audio_path: path to the reference audio
+             prompt_text: transcript of the reference audio, defaults to "希望你以后能够做的,比我还好呦!"
+
+         Returns:
+             Tuple: (global_tokens, semantic_tokens, semantic_time, semantic_speed)
+         """
+         if self.ref_audio_utilities is None:
+             raise RuntimeError("RefAudioUtilities is not initialized; zero-shot mode is unavailable")
+
+         # Encode the text
+         logger.info("🔤 Encoding text...")
+         text_tokens = self.tokenizer.encode(prompt_text + text, add_special_tokens=False)
+         text_tokens = [i + 8196 + 4096 for i in text_tokens]
+         logger.info(f"✅ Text encoded: {len(text_tokens)} tokens")
+
+         # Extract global and semantic tokens from the reference audio
+         logger.info("🎵 Processing the reference audio...")
+         global_tokens, prompt_semantic_tokens = self.ref_audio_utilities.tokenize(ref_audio_path)
+         logger.info("✅ Reference audio processed")
+
+         # Flatten the arrays into plain Python lists
+         global_tokens = [int(i) + 8196 for i in global_tokens.flatten()]
+         prompt_semantic_tokens = [int(i) for i in prompt_semantic_tokens.flatten()]
+
+         logger.info(f'🎯 Reference global_tokens: {global_tokens}')
+         logger.info(f'🎯 Reference semantic_tokens: {prompt_semantic_tokens}')
+
+         # The global tokens come from the reference audio, so only semantic tokens are generated
+         # Prepare the input tokens
+         TTS_TAG_0 = 8193
+         TTS_TAG_1 = 8194
+         TTS_TAG_2 = 8195
+
+         # Combine all tokens
+         all_idx = [TTS_TAG_2] + text_tokens + [TTS_TAG_0] + global_tokens + [TTS_TAG_1] + prompt_semantic_tokens
+         logger.info(f'🎯 Combined tokens: {all_idx}')
+
+         # Prefill phase
+         logger.info("💎 Starting prefill...")
+         session = self.runtime.create_inference_session([all_idx], token_chunk_size=512)
+         step_count = 0
+         start = time.time()
+         while not session.is_complete():
+             step_count += 1
+             output = session.step()
+             if not output.batches[0].is_empty():
+                 logits = output.batches[0].data
+                 break
+         prefill_time = time.time() - start
+         logger.info(f"✅ Prefill finished, logits length: {len(logits)}")
+         logger.info(f"✅ Prefill finished in {step_count} steps")
+         logger.info(f"✅ Prefill finished in {prefill_time:.2f}s, {len(all_idx)/prefill_time:.1f} tokens/s")
+
+         # Generate semantic tokens
+         logger.info("🧠 Generating semantic tokens...")
+         semantic_start = time.time()
+
+         # Continue semantic generation from the current logits
+         x = logits
+         semantic_tokens = []
+
+         for i in range(2048):  # generate at most 2048 tokens
+             sampled_id = sample_logits(x[0:8193], temperature=1.0, top_p=0.95, top_k=80)
+             if sampled_id == 8192:  # stop token reached
+                 logger.info(f"🛑 Semantic generation hit the stop token after {len(semantic_tokens)} tokens")
+                 break
+             semantic_tokens.append(sampled_id)
+             x = self.runtime.predict_next(sampled_id)
+
+         semantic_time = time.time() - semantic_start
+         semantic_speed = len(semantic_tokens) / semantic_time if semantic_time > 0 else 0
+         logger.info(f"✅ Semantic tokens finished: {len(semantic_tokens)} tokens in {semantic_time:.2f}s, {semantic_speed:.1f} tokens/s")
+
+         global_tokens = [i - 8196 for i in global_tokens]
+         return global_tokens, semantic_tokens, semantic_time, semantic_speed
+
+     def _decode_audio(self, global_tokens: List[int], semantic_tokens: List[int]) -> Tuple[np.ndarray, float, float, float]:
+         """
+         Core audio-decoding helper.
+
+         Args:
+             global_tokens: list of global tokens
+             semantic_tokens: list of semantic tokens
+
+         Returns:
+             Tuple: (wav_data, audio_duration, decode_time, decode_speed)
+         """
+         # Start timing
+         decode_start = time.time()
+
+         # Prepare the decoder inputs
+         logger.info("🔧 Preparing decoder inputs...")
+         global_tokens_array = np.array(global_tokens, dtype=np.int64).reshape(1, 1, -1)
+         semantic_tokens_array = np.array(semantic_tokens, dtype=np.int64).reshape(1, -1)
+         logger.info(f'🎯 Global tokens: {global_tokens}')
+         logger.info(f'🎯 Semantic tokens: {semantic_tokens}')
+         logger.info(f'📊 Decoder input shapes: global_tokens={global_tokens_array.shape}, semantic_tokens={semantic_tokens_array.shape}')
+
+         # Run the ONNX decoder to produce audio
+         logger.info("🎵 Running ONNX decoder inference...")
+         outputs = self.ort_session.run(None, {
+             "global_tokens": global_tokens_array,
+             "semantic_tokens": semantic_tokens_array
+         })
+         wav_data = outputs[0].reshape(-1)
+         decode_time = time.time() - decode_start
+
+         # Compute audio duration and decoding speed
+         audio_duration = len(wav_data) / 16000  # 16 kHz sample rate
+         decode_speed = len(semantic_tokens) / decode_time if decode_time > 0 else 0
+
+         logger.info(f"✅ Audio decoded: {audio_duration:.2f}s of audio in {decode_time:.2f}s, {decode_speed:.1f} tokens/s")
+
+         return wav_data, audio_duration, decode_time, decode_speed
+
+     def _save_audio(self, wav_data: np.ndarray, output_path: str, sample_rate: int = 16000) -> bool:
+         """
+         Save audio to a file.
+
+         Args:
+             wav_data: audio samples
+             output_path: output file path
+             sample_rate: sample rate, 16 kHz by default
+
+         Returns:
+             bool: whether saving succeeded
+         """
+         try:
+             sf.write(output_path, wav_data, sample_rate)
+             logger.info(f"💾 Audio saved: {output_path}")
+             return True
+         except Exception as e:
+             logger.error(f"❌ Failed to save audio: {e}")
+             return False
+
+ def display_stats(stats: Dict[str, Any]):
+     """Display generation statistics."""
+     logger.info("\n" + "="*60)
+     logger.info("📊 Generation statistics")
+     logger.info("="*60)
+
+     if stats['text']:
+         logger.info(f"🎯 Generation params: {stats['params']}")
+         logger.info(f"📝 Text: {stats['text']}")
+         logger.info(f"⏱️ Total time: {stats['total_time']:.2f}s")
+         logger.info(f"🎵 Audio duration: {stats['audio_duration']:.2f}s")
+         logger.info(f"📈 RTF: {stats['rtf']:.2f}")
+         logger.info(f"🔢 Total tokens: {stats['total_tokens']}")
+         logger.info(f"🧠 Semantic token speed: {stats['semantic_speed']:.1f} tokens/s")
+         logger.info(f"🎵 Decode speed: {stats['decode_speed']:.1f} tokens/s")
+         logger.info(f"🕐 Timestamp: {stats['timestamp']}")
+         if stats['output_path']:
+             logger.info(f"💾 Output path: {stats['output_path']}")
+     else:
+         logger.info("No generations yet")
+
+     logger.info("="*60)
+
+ def interactive_parameter_selection(generator: TTSGenerator):
+     """Interactive parameter-selection UI."""
+     logger.info("\n🎮 Entering the interactive configuration UI")
+     logger.info("💡 Use the arrow keys to select, Enter to confirm, Ctrl+C to quit")
+
+     while True:
+         try:
+             logger.info("\n" + "="*60)
+             logger.info("🎵 RWKV TTS parameter configuration")
+             logger.info("="*60)
+
+             # Choose the generation mode
+             generation_mode = questionary.select(
+                 "🎯 Select a generation mode:",
+                 choices=[
+                     "Conventional mode (attribute parameters)",
+                     "Zero-shot mode (reference audio)"
+                 ],
+                 default="Conventional mode (attribute parameters)"
+             ).ask()
+
+             if generation_mode is None:  # user pressed Ctrl+C
+                 break
+
+             is_zero_shot = generation_mode == "Zero-shot mode (reference audio)"
+
+             # Text input
+             text = questionary.text(
+                 "📝 Enter the text to synthesize:",
+                 default=generator.generation_stats['last_generation'].get('text', '你好,世界!')
+             ).ask()
+
+             if text is None:  # user pressed Ctrl+C
+                 break
+
+             # Output directory
+             output_dir = questionary.text(
+                 "📁 Enter the output directory:",
+                 default="./generated_audio"
+             ).ask()
+
+             if output_dir is None:
+                 break
+
+             if is_zero_shot:
+                 # Zero-shot mode parameters
+                 ref_audio_path = questionary.text(
+                     "🎵 Enter the reference audio path:",
+                     default="zero_shot_prompt.wav"
+                 ).ask()
+
+                 if ref_audio_path is None:
+                     break
+
+                 prompt_text = questionary.text(
+                     "💬 Enter the prompt text (optional; press Enter for the default):",
+                     default="希望你以后能够做的,能比我还好呦!"
+                 ).ask()
+
+                 if prompt_text is None:
+                     break
+
+                 # Confirm generation
+                 confirm = questionary.confirm(
+                     f"🚀 Generate audio (zero-shot mode)?\n"
+                     f"Text: {text}\n"
+                     f"Reference audio: {ref_audio_path}\n"
+                     f"Prompt text: {prompt_text}\n"
+                     f"Output directory: {output_dir}",
+                     default=True
+                 ).ask()
+
+                 if confirm:
+                     # Assemble the parameters
+                     params = {
+                         'text': text,
+                         'zero_shot': True,
+                         'ref_audio_path': ref_audio_path,
+                         'prompt_text': prompt_text,
+                         'output_dir': output_dir
+                     }
+
+                     # Generate the audio
+                     try:
+                         wav_data, stats = generator.generate_audio(params)
+
+                         # Build a unique output filename
+                         output_path = get_unique_filename(output_dir, text)
+
+                         # Save the audio
+                         if generator._save_audio(wav_data, output_path, 16000):
+                             stats['output_path'] = output_path
+                         else:
+                             logger.warning("⚠️ Audio save failed, but the generation stats were updated")
+
+                         logger.info(f"✅ Audio generated and saved to: {output_path}")
+                         stats['generation_params'] = f'ref_audio={ref_audio_path}, prompt_text={prompt_text}'
+                         # Display the statistics
+                         display_stats(stats)
+
+                     except Exception as e:
+                         logger.error(f"❌ Generation failed: {e}")
+                         import traceback
+                         traceback.print_exc()
+             else:
+                 # Conventional-mode parameters
+                 # Age
+                 age = questionary.select(
+                     "👶 Select an age:",
+                     choices=age_choices,
+                     default=age_choices[3]  # middle-aged
+                 ).ask()
+
+                 if age is None:
+                     break
+
+                 # Gender
+                 gender = questionary.select(
+                     "👤 Select a gender:",
+                     choices=gender_choices,
+                     default=gender_choices[0]  # female (first option)
+                 ).ask()
+
+                 if gender is None:
+                     break
+
+                 # Emotion
+                 emotion = questionary.select(
+                     "😊 Select an emotion:",
+                     choices=emotion_choices,
+                     default=emotion_choices[1]  # NEUTRAL
+                 ).ask()
+
+                 if emotion is None:
+                     break
+
+                 # Pitch
+                 pitch = questionary.select(
+                     "🎵 Select a pitch:",
+                     choices=pitch_choices,
+                     default=pitch_choices[1]  # medium_pitch
+                 ).ask()
+
+                 if pitch is None:
+                     break
+
+                 # Speed
+                 speed = questionary.select(
+                     "⚡ Select a speed:",
+                     choices=speed_choices,
+                     default=speed_choices[2]  # medium
+                 ).ask()
+
+                 if speed is None:
+                     break
+
+                 # Confirm generation
+                 confirm = questionary.confirm(
+                     f"🚀 Generate audio?\n"
+                     f"Text: {text}\n"
+                     f"Params: age={age}, gender={gender}, emotion={emotion}, pitch={pitch}, speed={speed}\n"
+                     f"Output directory: {output_dir}",
+                     default=True
+                 ).ask()
+
+                 if confirm:
+                     # Assemble the parameters
+                     params = {
+                         'text': text,
+                         'zero_shot': False,
+                         'age': age,
+                         'gender': gender,
+                         'emotion': emotion,
+                         'pitch': pitch,
+                         'speed': speed,
+                         'output_dir': output_dir
+                     }
+
+                     # Generate the audio
+                     try:
+                         wav_data, stats = generator.generate_audio(params)
+
+                         # Build a unique output filename
+                         output_path = get_unique_filename(output_dir, text)
+
+                         # Save the audio
+                         if generator._save_audio(wav_data, output_path, 16000):
+                             stats['output_path'] = output_path
+                         else:
+                             logger.warning("⚠️ Audio save failed, but the generation stats were updated")
+
+                         logger.info(f"✅ Audio generated and saved to: {output_path}")
+                         stats['generation_params'] = f'age={age}, gender={gender}, emotion={emotion}, pitch={pitch}, speed={speed}'
+                         # Display the statistics
+                         display_stats(stats)
+
+                     except Exception as e:
+                         logger.error(f"❌ Generation failed: {e}")
+                         import traceback
+                         traceback.print_exc()
+
+             # Ask whether to continue
+             continue_generation = questionary.confirm(
+                 "🔄 Generate another audio clip?",
+                 default=True
+             ).ask()
+
+             if not continue_generation:
+                 break
+
+         except KeyboardInterrupt:
+             logger.info("\n👋 Interrupted by user, exiting")
+             break
+         except Exception as e:
+             logger.error(f"❌ An error occurred: {e}")
+             import traceback
+             traceback.print_exc()
+             break
+
+     logger.info("👋 Thanks for using RWKV TTS!")
+
+ @click.command()
+ @click.option('--model_path', required=True, help='Path to the RWKV model directory')
+ def main(model_path):
+     """RWKV TTS entry point."""
+     logger.info("🚀 Welcome to the RWKV TTS interactive audio generation tool!")
+
+     # Check the model path
+     if not os.path.exists(model_path):
+         logger.error(f"❌ Error: model path does not exist: {model_path}")
+         return
+
+     # Derive the decoder path automatically
+     decoder_path = os.path.join(model_path, "BiCodecDetokenize.onnx")
+     logger.info(f"🔍 Decoder path set automatically: {decoder_path}")
+
+     # Inspect the model directory
+     logger.info(f"🔍 Checking the model directory: {model_path}")
+     try:
+         model_files = os.listdir(model_path)
+         logger.info("📁 Files in the model directory:")
+         for file in model_files:
+             file_path = os.path.join(model_path, file)
+             if os.path.isfile(file_path):
+                 size = os.path.getsize(file_path)
+                 logger.info(f"  📄 {file} ({size:,} bytes)")
+             else:
+                 logger.info(f"  📁 {file}/")
+     except Exception as e:
+         logger.warning(f"⚠️ Could not list the model directory: {e}")
+
+     if not os.path.exists(decoder_path):
+         logger.error(f"❌ Error: decoder path does not exist: {decoder_path}")
+         return
+
+     # Select a device
+     logger.info("\n💎 Select a device 💎")
+     try:
+         devices = webrwkv_py.get_available_adapters_py()
+     except Exception as e:
+         logger.error(f"❌ Could not list the available devices: {e}")
+         return
+
+     for i, device in enumerate(devices):
+         logger.info(f"{i}: {device}")
+
+     device_choice = input("Select a device: ")
+     try:
+         device_idx = int(device_choice)
+         if device_idx < 0 or device_idx >= len(devices):
+             logger.error("❌ Invalid device selection")
+             return
+         device = devices[device_idx]
+         logger.info(f"✅ Selected device: {device}")
+     except ValueError:
+         logger.error("❌ Invalid device selection")
+         return
+
+     # Load the model
+     logger.info("\n💎 Load the model 💎")
+     try:
+         # Try multiple possible model filenames
+         possible_model_files = [
+             'webrwkv.safetensors',
+         ]
+
+         webrwkv_model_path = None
+         for model_file in possible_model_files:
+             test_path = os.path.join(model_path, model_file)
+             if os.path.exists(test_path):
+                 webrwkv_model_path = test_path
+                 logger.info(f"✅ Found model file: {model_file}")
+                 break
+
+         if webrwkv_model_path is None:
+             logger.error("❌ No model file found")
+             logger.info(f"💡 Check that the model directory {model_path} contains one of the following files:")
+             for model_file in possible_model_files:
+                 logger.info(f"  - {model_file}")
+             return
+
+         logger.info(f"🔍 Loading model file: {webrwkv_model_path}")
+
+         # Try the new API
+         model = webrwkv_py.Model(webrwkv_model_path, 'fp32', device_idx)
+         logger.info(f"✅ Model loaded: {webrwkv_model_path}")
+     except Exception as e:
+         logger.error(f"❌ Model loading failed: {e}")
+         logger.info("💡 Please check that:")
+         logger.info(f"  1. The model file path is correct: {webrwkv_model_path}")
+         logger.info("  2. The model file is complete")
+         logger.info(f"  3. The device index is correct: {device_idx}")
+         logger.info("  4. The model file format is supported")
+         return
+
+     # Create the runtime
+     logger.info("\n💎 Create the runtime 💎")
+     try:
+         runtime = model.create_thread_runtime()
+         logger.info("✅ Runtime created")
+     except Exception as e:
+         logger.error(f"❌ Runtime creation failed: {e}")
+         return
+
+     # Load the tokenizer
+     logger.info("\n💎 Load the tokenizer 💎")
+     try:
+         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+         logger.info(f"✅ Tokenizer loaded: {model_path}")
+     except Exception as e:
+         logger.error(f"❌ Tokenizer loading failed: {e}")
+         logger.info(f"💡 Check that the model directory {model_path} contains the proper tokenizer files")
+         return
+
+     # Create the TTS generator
+     generator = TTSGenerator(runtime, tokenizer, decoder_path, device, model_path)
+
+     # Launch the interactive UI
+     logger.info("\n🎯 Launching the interactive configuration UI...")
+     interactive_parameter_selection(generator)
+
+ if __name__ == "__main__":
+     main()
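For scripted use, TTSGenerator can also be driven directly once the runtime, tokenizer and decoder are set up as in main() above (translation_data.py shows the full setup). A minimal zero-shot sketch, assuming a reference clip and its transcript are available:

    params = {
        'text': '你好,世界!',
        'zero_shot': True,
        'ref_audio_path': 'zero_shot_prompt.wav',
        'prompt_text': '希望你以后能够做的,能比我还好呦!',  # transcript of the reference clip
        'output_dir': './generated_audio',
    }
    wav_data, stats = generator.generate_audio(params)
    generator._save_audio(wav_data, get_unique_filename(params['output_dir'], params['text']), 16000)

The CLI itself is launched with `python tts_cli.py --model_path <model_dir>`; set the LOG_LEVEL environment variable (for example LOG_LEVEL=INFO) for verbose progress logs.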
rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/utilities.py ADDED
@@ -0,0 +1,209 @@
+ import os
+ import json
+ import torch
+ import numpy as np
+ from transformers import AutoTokenizer
+ from properties_util import convert_standard_properties_to_tokens
+
+ def print_properties_info(age: str, gender: str, emotion: str, pitch: float, speed: float):
+     """
+     Helper for printing the attribute information.
+
+     Args:
+         age: age
+         gender: gender
+         emotion: emotion
+         pitch: pitch
+         speed: speed
+     """
+     print(f'age: {age}, gender: {gender}, emotion: {emotion}, pitch: {pitch}, speed: {speed}')
+
+ @torch.inference_mode()
+ def extract_embeddings_for_global_tokens(model, tokenizer, text, age: str, gender: str, emotion: str, pitch: float, speed: float, global_tokens: list = None):
+     """
+     Build the embeddings needed to generate global tokens.
+
+     Args:
+         model: the model instance
+         tokenizer: the tokenizer
+         text: input text
+         age: age
+         gender: gender
+         emotion: emotion
+         pitch: pitch
+         speed: speed
+         global_tokens: global tokens
+
+     Returns:
+         torch.Tensor: the concatenated full embedding
+     """
+     device = (next(model.parameters()).device)
+     properties_tokens = convert_standard_properties_to_tokens(age, gender, emotion, pitch, speed)
+     text_tokens = tokenizer.encode(text, add_special_tokens=False)
+     properties_tokens = tokenizer.encode(properties_tokens, add_special_tokens=False)
+     text_tokens_tensor = torch.tensor(text_tokens, dtype=torch.long, device=device)
+     properties_tokens_tensor = torch.tensor(properties_tokens, dtype=torch.long, device=device)
+     text_embs = model.text_embedder(text_tokens_tensor)
+     properties_embs = model.text_embedder(properties_tokens_tensor)
+     tag_0_emb = model.tts_tag_embedder(torch.tensor([0], dtype=torch.long, device=device))
+     tag_1_emb = model.tts_tag_embedder(torch.tensor([1], dtype=torch.long, device=device))
+     tag_2_emb = model.tts_tag_embedder(torch.tensor([2], dtype=torch.long, device=device))
+     full_embs_for_sample = torch.cat([
+         properties_embs,
+         tag_2_emb, text_embs, tag_0_emb,
+     ], dim=0)
+     if global_tokens is not None:
+         global_tokens_tensor = torch.tensor(global_tokens, dtype=torch.long, device=device)
+         global_embs = model.global_embedder(global_tokens_tensor)
+         full_embs_for_sample = torch.cat([
+             full_embs_for_sample,
+             global_embs,
+             tag_1_emb
+         ], dim=0)
+     return full_embs_for_sample
+
+ def get_tokenizer(model_dir):
+     tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
+     special_tokens = {
+         'pad_token': '<|rwkv_tokenizer_end_of_text|>',
+         'additional_special_tokens': [
+             '<|endofprompt|>',
+             '[breath]', '<strong>', '</strong>', '[noise]',
+             '[laughter]', '[cough]', '[clucking]', '[accent]',
+             '[quick_breath]',
+             "<laughter>", "</laughter>",
+             "[hissing]", "[sigh]", "[vocalized-noise]",
+             "[lipsmack]", "[mn]"
+         ]
+     }
+     tokenizer.add_special_tokens(special_tokens)
+     return tokenizer
+
+ def get_respark_tts_tokenizer(model_dir):
+     tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
+     original_vocab_size = tokenizer.vocab_size
+     added_tokens_file = os.path.join(os.path.dirname(__file__), 'spark_tts_added_tokens.json')
+     with open(added_tokens_file, 'r') as f:
+         added_tokens = json.load(f)
+     tokenizer.add_special_tokens(added_tokens)
+     return tokenizer, original_vocab_size
+
+ @torch.inference_mode()
+ def generate_global_tokens(model, tokenizer, text, age: str, gender: str, emotion: str, pitch: float, speed: float,
+                            num_global_tokens: int = 4096):
+     full_embs_for_sample = extract_embeddings_for_global_tokens(model, tokenizer, text, age, gender, emotion, pitch, speed)
+     device = full_embs_for_sample.device
+     vocab_size = model.config.vocab_size
+     eos_token_id = vocab_size - 1
+     suppress_tokens = [token_id for token_id in range(num_global_tokens, vocab_size)]
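+     # suppress_tokens masks every id outside the 4096-entry global codebook, so
+     # generate() below can only emit global-codebook ids. The EOS id (vocab_size - 1)
+     # also falls in the suppressed range, but min_new_tokens == max_new_tokens == 32
+     # forces exactly 32 global tokens regardless.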
+     gen_args = {
+         "inputs_embeds": full_embs_for_sample.unsqueeze(0),
+         # seq_len is dim 0 of the unbatched [seq_len, hidden] embedding
+         "attention_mask": torch.ones((1, full_embs_for_sample.shape[0]), dtype=torch.long, device=device),
+         "max_new_tokens": 32,
+         "min_new_tokens": 32,
+         "do_sample": True,
+         "top_k": 50,
+         "top_p": 0.95,
+         "temperature": 1.0,
+         "eos_token_id": eos_token_id,
+         "pad_token_id": tokenizer.pad_token_id,
+         "use_cache": True,
+         "suppress_tokens": suppress_tokens,
+         "return_dict_in_generate": True,
+     }
+     generated_outputs = model.generate(**gen_args)
+     return generated_outputs
+
+ @torch.inference_mode()
+ def generate_input_embeddings(model, tokenizer, text, global_tokens):
+     device = (next(model.parameters()).device)
+     text_tokens = tokenizer.encode(text, add_special_tokens=False)
+     text_tokens_tensor = torch.tensor(text_tokens, dtype=torch.long, device=device)
+     text_embs = model.text_embedder(text_tokens_tensor)
+     global_tokens_tensor = torch.tensor(global_tokens, dtype=torch.long, device=device)
+     global_embs = model.global_embedder(global_tokens_tensor)
+     tag_0_emb = model.tts_tag_embedder(torch.tensor([0], dtype=torch.long, device=device))
+     tag_1_emb = model.tts_tag_embedder(torch.tensor([1], dtype=torch.long, device=device))
+     tag_2_emb = model.tts_tag_embedder(torch.tensor([2], dtype=torch.long, device=device))
+     input_embs = torch.cat([tag_2_emb, text_embs, tag_0_emb, global_embs, tag_1_emb], dim=0)
+     return input_embs
+
+ def generate_embeddings(model, tokenizer, text, bicodec, prompt_text=None, prompt_audio=None):
+     """
+     Build the input embeddings the Spark LLM needs for prediction.
+
+     Args:
+         model: the Spark LLM model
+         tokenizer: the text tokenizer
+         text: the text to synthesize
+         bicodec: a BiCodecTokenizer instance
+         prompt_text: prompt text (optional)
+         prompt_audio: prompt audio array (optional)
+
+     Returns:
+         dict: a dict containing input_embs for model prediction
+     """
+     device = next(model.parameters()).device
+
+     # 1. Process the prompt audio and extract global_tokens and semantic_tokens
+     if prompt_audio is not None:
+         # Make sure the audio data is float32
+         audio_data = np.array(prompt_audio, dtype=np.float32)
+         target_sample_rate = bicodec.config['sample_rate']
+
+         # Check whether resampling is needed.
+         # Note: prompt_audio is assumed to have been loaded with soundfile; the sample
+         # rate is handled by the caller. BiCodecTokenizer expects 16 kHz audio.
+         print(f"BiCodecTokenizer expected sample rate: {target_sample_rate}Hz")
+         print(f"Audio data shape: {audio_data.shape}")
+
+         # Extract tokens with BiCodec (returned in the order: global_tokens, semantic_tokens)
+         global_tokens, semantic_tokens = bicodec.tokenize(audio_data)
+         global_tokens = global_tokens.squeeze(0).squeeze(0).detach().cpu().tolist()
+         semantic_tokens = semantic_tokens.squeeze(0).squeeze(0).detach().cpu().tolist()
+     else:
+         global_tokens = []
+         semantic_tokens = []
+
+     # 2. Process the text
+     if prompt_text is not None:
+         # Concatenate the prompt text and the target text
+         full_text = prompt_text + text
+         # The initial semantic tokens are those extracted from prompt_audio
+         initial_semantic_tokens = semantic_tokens.copy()
+     else:
+         full_text = text
+         initial_semantic_tokens = []
+
+     # 3. Get the text tokens
+     text_tokens = tokenizer.encode(full_text, add_special_tokens=False)
+
+     # 4. Convert to tensors
+     text_tokens_tensor = torch.tensor(text_tokens, dtype=torch.long, device=device)
+     global_tokens_tensor = torch.tensor(global_tokens, dtype=torch.long, device=device)
+     semantic_tokens_tensor = torch.tensor(initial_semantic_tokens, dtype=torch.long, device=device)
+
+     # 5. Look up the embeddings
+     text_embs = model.text_embedder(text_tokens_tensor)
+     global_embs = model.global_embedder(global_tokens_tensor)
+     semantic_embs = model.model.embeddings(semantic_tokens_tensor)
+
+     # 6. Look up the special-tag embeddings
+     tag_0_emb = model.tts_tag_embedder(torch.tensor([0], dtype=torch.long, device=device))
+     tag_1_emb = model.tts_tag_embedder(torch.tensor([1], dtype=torch.long, device=device))
+     tag_2_emb = model.tts_tag_embedder(torch.tensor([2], dtype=torch.long, device=device))
+
+     # 7. Concatenate the embeddings
+     input_embs = torch.cat([
+         tag_2_emb,
+         text_embs,
+         tag_0_emb,
+         global_embs,
+         tag_1_emb,
+         semantic_embs
+     ], dim=0)
+
+     # 8. Add a batch dimension
+     input_embs = input_embs.unsqueeze(0)  # [1, seq_len, hidden_size]
+
+     return {
+         "input_embs": input_embs,
+         "global_tokens": global_tokens_tensor,
+     }
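The helpers above assemble the same `[TAG_2] text [TAG_0] global [TAG_1] semantic` sequence as tts_cli.py, but at the embedding level for the PyTorch model (spark_llm.RWKV7ForSpeech) rather than for the webrwkv runtime. A minimal voice-cloning sketch under the assumption that `model`, `tokenizer` and a BiCodecTokenizer `bicodec` are already loaded (decoding the generated semantic tokens back to audio is omitted):

    import soundfile as sf

    # The reference clip should match bicodec.config['sample_rate'] (16 kHz).
    prompt_audio, sr = sf.read("zero_shot_prompt.wav")
    emb = generate_embeddings(model, tokenizer, "你好,世界!", bicodec,
                              prompt_text="希望你以后能够做的,比我还好呦!",
                              prompt_audio=prompt_audio)
    outputs = model.generate(inputs_embeds=emb["input_embs"],
                             max_new_tokens=2048, do_sample=True, top_p=0.95)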
rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/wav2vec2-large-xlsr-53.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0947d5aed2023e06b07a0180549e64a48977863b20f1156cbf33fd97ab6e3ad6
+ size 858969041
rwkv7-0.1B-g1-respark-voice-tunable-ipa-epoch1/webrwkv.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:986776d478378952d42932269de147aee2e77332ab9ea5b1bc16c657eb5c424c
+ size 420157752