fcyai committed
Commit 01a69aa · 1 Parent(s): 623d773
ChatTTS/ChatTTS/core.py CHANGED
@@ -1,10 +1,9 @@
-
 import os
 import json
 import logging
-from functools import partial
-from typing import Literal
 import tempfile
+from functools import partial
+from typing import Literal, Optional
 
 import torch
 from omegaconf import OmegaConf
@@ -15,19 +14,19 @@ from .model.dvae import DVAE
 from .model.gpt import GPT_warpper
 from .utils.gpu_utils import select_device
 from .utils.infer_utils import count_invalid_characters, detect_language, apply_character_map, apply_half2full_map, HomophonesReplacer
-from .utils.io_utils import get_latest_modified_file
+from .utils.io import get_latest_modified_file, del_all
 from .infer.api import refine_text, infer_code
 from .utils.download import check_all_assets, download_all_assets
-
-logging.basicConfig(level = logging.INFO)
+from .utils.log import set_utils_logger
 
 
 class Chat:
-    def __init__(self, ):
+    def __init__(self, logger=logging.getLogger(__name__)):
         self.pretrain_models = {}
         self.normalizer = {}
         self.homophones_replacer = None
-        self.logger = logging.getLogger(__name__)
+        self.logger = logger
+        set_utils_logger(logger)
 
     def check_model(self, level = logging.INFO, use_decoder = False):
         not_finish = False
@@ -45,7 +44,7 @@ class Chat:
 
         if not not_finish:
             self.logger.log(level, f'All initialized.')
 
         return not not_finish
 
     def load_models(
@@ -61,8 +60,8 @@ class Chat:
            with tempfile.TemporaryDirectory() as tmp:
                download_all_assets(tmpdir=tmp)
                if not check_all_assets(update=False):
-                    logging.error("counld not satisfy all assets needed.")
-                    exit(1)
+                    self.logger.error("counld not satisfy all assets needed.")
+                    return False
        elif source == 'huggingface':
            hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface"))
            try:
@@ -78,7 +77,7 @@ class Chat:
            self.logger.log(logging.INFO, f'Load from local: {custom_path}')
            download_path = custom_path
 
-        self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()}, **kwargs)
+        return self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()}, **kwargs)
 
    def _load(
        self,
@@ -91,17 +90,18 @@ class Chat:
        decoder_config_path: str = None,
        decoder_ckpt_path: str = None,
        tokenizer_path: str = None,
-        device: str = None,
+        device: Optional[torch.device] = None,
        compile: bool = True,
    ):
-        if not device:
+        if device is None:
            device = select_device(4096)
            self.logger.log(logging.INFO, f'use {device}')
-
+        self.device = device
+
        if vocos_config_path:
            vocos = Vocos.from_hparams(vocos_config_path).to(
                # vocos on mps will crash, use cpu fallback
-                "cpu" if torch.backends.mps.is_available() else device
+                "cpu" if "mps" in str(device) else device
            ).eval()
            assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
            vocos.load_state_dict(torch.load(vocos_ckpt_path))
@@ -118,14 +118,14 @@ class Chat:
 
        if gpt_config_path:
            cfg = OmegaConf.load(gpt_config_path)
-            gpt = GPT_warpper(**cfg).to(device).eval()
+            gpt = GPT_warpper(**cfg, device=device, logger=self.logger).eval()
            assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
            gpt.load_state_dict(torch.load(gpt_ckpt_path))
            if compile and 'cuda' in str(device):
                try:
                    gpt.gpt.forward = torch.compile(gpt.gpt.forward, backend='inductor', dynamic=True)
                except RuntimeError as e:
-                    logging.warning(f'Compile failed,{e}. fallback to normal mode.')
+                    self.logger.warning(f'Compile failed,{e}. fallback to normal mode.')
            self.pretrain_models['gpt'] = gpt
            spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), 'spk_stat.pt')
            assert os.path.exists(spk_stat_path), f'Missing spk_stat.pt: {spk_stat_path}'
@@ -146,7 +146,7 @@ class Chat:
            self.pretrain_models['tokenizer'] = tokenizer
            self.logger.log(logging.INFO, 'tokenizer loaded.')
 
-        self.check_model()
+        return self.check_model()
 
    def _infer(
        self,
@@ -179,14 +179,16 @@ class Chat:
                    self.logger.log(logging.WARNING, f'Invalid characters found! : {invalid_characters}')
                    text[i] = apply_character_map(t)
                if do_homophone_replacement and self.init_homophones_replacer():
-                    text[i] = self.homophones_replacer.replace(t)
-                    if t != text[i]:
-                        self.logger.log(logging.INFO, f'Homophones replace: {t} -> {text[i]}')
+                    text[i], replaced_words = self.homophones_replacer.replace(text[i])
+                    if replaced_words:
+                        repl_res = ', '.join([f'{_[0]}->{_[1]}' for _ in replaced_words])
+                        self.logger.log(logging.INFO, f'Homophones replace: {repl_res}')
 
        if not skip_refine_text:
            text_tokens = refine_text(
                self.pretrain_models,
                text,
+                device=self.device,
                **params_refine_text,
            )['ids']
            text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
@@ -197,16 +199,28 @@ class Chat:
 
        text = [params_infer_code.get('prompt', '') + i for i in text]
        params_infer_code.pop('prompt', '')
-        result_gen = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder, stream=stream)
+        result_gen = infer_code(
+            self.pretrain_models,
+            text,
+            device=self.device,
+            **params_infer_code,
+            return_hidden=use_decoder,
+            stream=stream,
+        )
        if use_decoder:
            field = 'hiddens'
            docoder_name = 'decoder'
        else:
            field = 'ids'
            docoder_name = 'dvae'
-        vocos_decode = lambda spec: [self.pretrain_models['vocos'].decode(
-            i.cpu() if torch.backends.mps.is_available() else i
-        ).cpu().numpy() for i in spec]
+        if "mps" in str(self.device):
+            vocos_decode = lambda spec: [self.pretrain_models['vocos'].decode(
+                i.cpu()
+            ).cpu().numpy() for i in spec]
+        else:
+            vocos_decode = lambda spec: [self.pretrain_models['vocos'].decode(
+                i
+            ).cpu().numpy() for i in spec]
        if stream:
 
            length = 0
@@ -220,13 +234,20 @@ class Chat:
                if not len(chunk_data):
                    continue
                self.logger.debug(f'new hidden {len(chunk_data)=}')
-                mel_spec = [self.pretrain_models[docoder_name](i[None].permute(0,2,1)) for i in [chunk_data]]
+                mel_spec = [self.pretrain_models[docoder_name](i[None].permute(0,2,1).to(self.device)) for i in [chunk_data]]
+                del_all(result)
+                del chunk_data
                wav = vocos_decode(mel_spec)
+                del_all(mel_spec)
                self.logger.debug(f'yield wav chunk {len(wav[0])=} {len(wav[0][0])=}')
                yield wav
            return
-        mel_spec = [self.pretrain_models[docoder_name](i[None].permute(0,2,1)) for i in next(result_gen)[field]]
-        yield vocos_decode(mel_spec)
+        result = next(result_gen)
+        mel_spec = [self.pretrain_models[docoder_name](i[None].permute(0,2,1).to(self.device)) for i in result[field]]
+        del_all(result)
+        wav = vocos_decode(mel_spec)
+        del_all(mel_spec)
+        yield wav
 
    def infer(
        self,
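The user-visible upshot of the core.py changes: `Chat()` now accepts an injectable logger, and `load_models()` reports failure by returning `False` (propagating `check_model()`'s result) instead of calling `exit(1)`. A minimal usage sketch of the new contract (the logger name is arbitrary):

```python
import logging
import ChatTTS

# The library no longer calls logging.basicConfig() at import time,
# so configure logging yourself and hand a logger to Chat.
logging.basicConfig(level=logging.INFO)
chat = ChatTTS.Chat(logging.getLogger("ChatTTS"))

# load_models() now returns a bool instead of exiting the process
# when assets cannot be downloaded or verified.
if not chat.load_models():
    raise RuntimeError("ChatTTS models failed to load")
```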
ChatTTS/ChatTTS/infer/api.py CHANGED
@@ -2,7 +2,10 @@
 import torch
 import torch.nn.functional as F
 from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
+
 from ..utils.infer_utils import CustomRepetitionPenaltyLogitsProcessorRepeat
+from ..utils.io import del_all
+from ..model.gpt import GPT_warpper
 
 def infer_code(
     models,
@@ -14,39 +17,42 @@ def infer_code(
     repetition_penalty = 1.05,
     max_new_token = 2048,
     stream=False,
+    device="cpu",
     **kwargs
 ):
-
-    device = next(models['gpt'].parameters()).device
-
+
+    gpt: GPT_warpper = models['gpt']
+
     if not isinstance(text, list):
         text = [text]
 
     if not isinstance(temperature, list):
-        temperature = [temperature] * models['gpt'].num_vq
+        temperature = [temperature] * gpt.num_vq
 
     if spk_emb is not None:
         text = [f'[Stts][spk_emb]{i}[Ptts]' for i in text]
     else:
         text = [f'[Stts][empty_spk]{i}[Ptts]' for i in text]
 
-    text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
-    input_ids = text_token['input_ids'][...,None].expand(-1, -1, models['gpt'].num_vq)
-    text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device)
-
-    inputs = {
-        'input_ids': input_ids,
-        'text_mask': text_mask,
-        'attention_mask': text_token['attention_mask'],
-    }
+    text_token_tmp = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True)
+    text_token = text_token_tmp.to(device)
+    del text_token_tmp
+    input_ids = text_token['input_ids'][...,None].expand(-1, -1, gpt.num_vq).to(gpt.device_gpt)
+    text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=gpt.device_gpt)
+
+    emb = gpt.get_emb(
+        input_ids=input_ids,
+        text_mask=text_mask,
+    )
+    del text_mask
 
-    emb = models['gpt'].get_emb(**inputs)
     if spk_emb is not None:
-        emb[inputs['input_ids'][..., 0] == models['tokenizer'].convert_tokens_to_ids('[spk_emb]')] = \
-        F.normalize(spk_emb.to(device).to(emb.dtype)[None].expand(len(text), -1), p=2.0, dim=1, eps=1e-12)
-
-    num_code = models['gpt'].emb_code[0].num_embeddings - 1
-
+        n = F.normalize(spk_emb.to(emb.dtype)[None].expand(len(text), -1), p=2.0, dim=1, eps=1e-12).to(gpt.device_gpt)
+        emb[input_ids[..., 0] == models['tokenizer'].convert_tokens_to_ids('[spk_emb]')] = n
+        del n
+
+    num_code = int(gpt.emb_code[0].num_embeddings - 1)
+
     LogitsWarpers = []
     if top_P is not None:
         LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
@@ -58,10 +64,10 @@ def infer_code(
         LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(\
             repetition_penalty, num_code, 16))
 
-    result = models['gpt'].generate(
-        emb, inputs['input_ids'],
+    result = gpt.generate(
+        emb, input_ids,
         temperature = torch.tensor(temperature, device=device),
-        attention_mask = inputs['attention_mask'],
+        attention_mask = text_token['attention_mask'],
         LogitsWarpers = LogitsWarpers,
         LogitsProcessors = LogitsProcessors,
         eos_token = num_code,
@@ -71,6 +77,11 @@ def infer_code(
         **kwargs
     )
 
+    del_all(text_token)
+    del emb, text_token, input_ids
+    del_all(LogitsWarpers)
+    del_all(LogitsProcessors)
+
     return result
 
 
@@ -83,11 +94,12 @@ def refine_text(
     repetition_penalty = 1.0,
     max_new_token = 384,
     prompt = '',
+    device="cpu",
     **kwargs
 ):
-
-    device = next(models['gpt'].parameters()).device
-
+
+    gpt: GPT_warpper = models['gpt']
+
     if not isinstance(text, list):
         text = [text]
 
@@ -97,11 +109,7 @@ def refine_text(
     text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
     text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device)
 
-    inputs = {
-        'input_ids': text_token['input_ids'][...,None].expand(-1, -1, models['gpt'].num_vq),
-        'text_mask': text_mask,
-        'attention_mask': text_token['attention_mask'],
-    }
+    input_ids = text_token['input_ids'][...,None].expand(-1, -1, gpt.num_vq)
 
     LogitsWarpers = []
     if top_P is not None:
@@ -112,11 +120,17 @@ def refine_text(
     LogitsProcessors = []
     if repetition_penalty is not None and repetition_penalty != 1:
         LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, len(models['tokenizer']), 16))
-
-    result = models['gpt'].generate(
-        models['gpt'].get_emb(**inputs), inputs['input_ids'],
+
+    emb = gpt.get_emb(
+        input_ids=input_ids,
+        text_mask=text_mask,
+    )
+    del text_mask
+
+    result = gpt.generate(
+        emb, input_ids,
         temperature = torch.tensor([temperature,], device=device),
-        attention_mask = inputs['attention_mask'],
+        attention_mask = text_token['attention_mask'],
         LogitsWarpers = LogitsWarpers,
         LogitsProcessors = LogitsProcessors,
         eos_token = torch.tensor(models['tokenizer'].convert_tokens_to_ids('[Ebreak]'), device=device)[None],
@@ -125,4 +139,10 @@ def refine_text(
         stream = False,
         **kwargs
     )
+
+    del_all(text_token)
+    del emb, text_token, input_ids
+    del_all(LogitsWarpers)
+    del_all(LogitsProcessors)
+
     return next(result)
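`infer_code` and `refine_text` no longer derive the device from the GPT's parameters; `Chat._infer` passes `self.device` explicitly, which keeps tokenized inputs on one device while the Llama backbone may sit on `gpt.device_gpt` (CPU when the selected device is MPS). A sketch of a device picker matching that convention — `pick_device` is my name, not the repo's; the real selection lives in `utils.gpu_utils.select_device`:

```python
import torch

def pick_device() -> torch.device:
    # mirrors the spirit of select_device(), minus the free-VRAM check
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

# hypothetical call shape, mirroring Chat._infer after this commit:
# result_gen = infer_code(models, texts, device=pick_device(),
#                         return_hidden=use_decoder, stream=stream)
```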
ChatTTS/ChatTTS/model/dvae.py CHANGED
@@ -1,5 +1,4 @@
 import math
-from einops import rearrange
 from vector_quantize_pytorch import GroupedResidualFSQ
 
 import torch
@@ -66,23 +65,32 @@ class GFSQ(nn.Module):
         self.G = G
         self.R = R
 
-    def _embed(self, x):
+    def _embed(self, x: torch.Tensor):
         if self.transpose:
             x = x.transpose(1,2)
+        """
         x = rearrange(
             x, "b t (g r) -> g b t r", g = self.G, r = self.R,
         )
+        """
+        x.view(-1, self.G, self.R).permute(2, 0, 1, 3)
         feat = self.quantizer.get_output_from_indices(x)
         return feat.transpose(1,2) if self.transpose else feat
 
     def forward(self, x,):
         if self.transpose:
             x = x.transpose(1,2)
         feat, ind = self.quantizer(x)
+        """
         ind = rearrange(
             ind, "g b t r ->b t (g r)",
         )
+        """
+        ind = ind.permute(1, 2, 0, 3).contiguous()
+        ind = ind.view(ind.size(0), ind.size(1), -1)
-        embed_onehot = F.one_hot(ind.long(), self.n_ind).to(x.dtype)
+        embed_onehot_tmp = F.one_hot(ind.long(), self.n_ind)
+        embed_onehot = embed_onehot_tmp.to(x.dtype)
+        del embed_onehot_tmp
         e_mean = torch.mean(embed_onehot, dim=[0,1])
         e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1)
         perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1))
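Throughout this commit the einops dependency is dropped and each `rearrange` is rewritten with `permute`/`view`, the original call being kept nearby as a docstring or comment. One caution on the `_embed` hunk above: the replacement line `x.view(-1, self.G, self.R).permute(2, 0, 1, 3)` discards its result, and a 3-D view cannot be permuted with four indices, so as written it does not appear to reproduce `"b t (g r) -> g b t r"`. A standalone sketch of the intended equivalences, assuming `x` has shape `(b, t, g*r)` as in the forward path:

```python
import torch

def split_groups(x: torch.Tensor, G: int, R: int) -> torch.Tensor:
    # einops: rearrange(x, "b t (g r) -> g b t r")
    b, t, _ = x.shape
    return x.view(b, t, G, R).permute(2, 0, 1, 3)

def merge_groups(ind: torch.Tensor) -> torch.Tensor:
    # einops: rearrange(ind, "g b t r -> b t (g r)"), as done in forward()
    g, b, t, r = ind.shape
    return ind.permute(1, 2, 0, 3).contiguous().view(b, t, g * r)

x = torch.arange(2 * 3 * 8).view(2, 3, 8)   # b=2, t=3, g*r=8
y = split_groups(x, G=2, R=4)               # shape (2, 2, 3, 4)
assert merge_groups(y).equal(x)             # round-trips exactly
```

The `forward` path (permute + contiguous + view) matches `merge_groups` above; only the `_embed` side looks suspect.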
ChatTTS/ChatTTS/model/gpt.py CHANGED
@@ -2,8 +2,10 @@ import os
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 import logging
+from typing import Union
+
+
 from tqdm import tqdm
-from einops import rearrange
 from transformers.cache_utils import Cache
 
 import torch
@@ -12,8 +14,10 @@ import torch.nn.functional as F
 import torch.nn.utils.parametrize as P
 from torch.nn.utils.parametrizations import weight_norm
 from transformers import LlamaModel, LlamaConfig
-
-
+
+from ..utils.io import del_all
+
+
 class LlamaMLP(nn.Module):
     def __init__(self, hidden_size, intermediate_size):
         super().__init__()
@@ -36,40 +40,67 @@ class GPT_warpper(nn.Module):
         num_audio_tokens,
         num_text_tokens,
         num_vq=4,
+        device="cpu",
+        logger=logging.getLogger(__name__)
     ):
         super().__init__()
 
-        self.logger = logging.getLogger(__name__)
-        self.gpt = self.build_model(gpt_config)
+        self.logger = logger
+        self.device = device
+        self.device_gpt = device if "mps" not in str(device) else "cpu"
+        self.num_vq = num_vq
+
+        self.gpt = self.build_model(gpt_config, self.device_gpt)
         self.model_dim = self.gpt.config.hidden_size
-
-        self.num_vq = num_vq
-        self.emb_code = nn.ModuleList([nn.Embedding(num_audio_tokens, self.model_dim) for i in range(self.num_vq)])
-        self.emb_text = nn.Embedding(num_text_tokens, self.model_dim)
-        self.head_text = weight_norm(nn.Linear(self.model_dim, num_text_tokens, bias=False), name='weight')
-        self.head_code = nn.ModuleList([weight_norm(nn.Linear(self.model_dim, num_audio_tokens, bias=False), name='weight') for i in range(self.num_vq)])
+        self.emb_code = nn.ModuleList(
+            [nn.Embedding(
+                num_audio_tokens, self.model_dim, device=self.device_gpt,
+            ) for _ in range(num_vq)],
+        )
+        self.emb_text = nn.Embedding(num_text_tokens, self.model_dim, device=self.device_gpt)
+
+        self.head_text = weight_norm(
+            nn.Linear(
+                self.model_dim, num_text_tokens, bias=False, device=device,
+            ),
+            name='weight',
+        )
+        self.head_code = nn.ModuleList(
+            [weight_norm(
+                nn.Linear(
+                    self.model_dim, num_audio_tokens, bias=False, device=device,
+                ),
+                name='weight',
+            ) for _ in range(self.num_vq)],
+        )
 
-    def build_model(self, config):
+    def build_model(self, config, device):
 
         configuration = LlamaConfig(**config)
         model = LlamaModel(configuration)
         del model.embed_tokens
 
-        return model
-
-    def get_emb(self, input_ids, text_mask, **kwargs):
+        return model.to(device)
 
-        emb_text = self.emb_text(input_ids[text_mask][:, 0])
+    def get_emb(self, input_ids, text_mask):
 
-        emb_code = [self.emb_code[i](input_ids[~text_mask][:, i]) for i in range(self.num_vq)]
+        emb_text = self.emb_text(input_ids[text_mask][:, 0].to(self.device_gpt))
+
+        text_mask_inv = ~text_mask
+        masked_input_ids = input_ids[text_mask_inv].to(self.device_gpt)
+        del text_mask_inv
+
+        emb_code = [self.emb_code[i](masked_input_ids[:, i]) for i in range(self.num_vq)]
         emb_code = torch.stack(emb_code, 2).sum(2)
 
         emb = torch.zeros((input_ids.shape[:-1])+(emb_text.shape[-1],), device=emb_text.device, dtype=emb_text.dtype)
         emb[text_mask] = emb_text
         emb[~text_mask] = emb_code.to(emb.dtype)
 
+        del emb_text, emb_code
+
         return emb
 
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
     ):
@@ -157,7 +188,7 @@ class GPT_warpper(nn.Module):
         emb,
         inputs_ids,
         temperature,
-        eos_token,
+        eos_token: Union[int, torch.Tensor],
         attention_mask = None,
         max_new_token = 2048,
         min_new_token = 0,
@@ -177,8 +208,8 @@ class GPT_warpper(nn.Module):
         start_idx, end_idx = inputs_ids.shape[1], torch.zeros(inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long)
         finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool()
 
-        temperature = temperature[None].expand(inputs_ids.shape[0], -1)
-        temperature = rearrange(temperature, "b n -> (b n) 1")
+        temperature = temperature.unsqueeze_(0).expand(inputs_ids.shape[0], -1).contiguous().view(-1, 1)
+        # temperature = rearrange(temperature, "b n -> (b n) 1")
 
         attention_mask_cache = torch.ones((inputs_ids.shape[0], inputs_ids.shape[1]+max_new_token,), dtype=torch.bool, device=inputs_ids.device)
         if attention_mask is not None:
@@ -189,7 +220,6 @@ class GPT_warpper(nn.Module):
         past_key_values = None
 
         for i in range(max_new_token):
-            pbar.update(1)
             model_input = self.prepare_inputs_for_generation(
                 inputs_ids,
                 past_key_values,
@@ -200,17 +230,26 @@ class GPT_warpper(nn.Module):
             if i == 0:
                 model_input['inputs_embeds'] = emb
             else:
+                inputs_ids_emb = model_input['input_ids'].to(self.device_gpt)
                 if infer_text:
-                    model_input['inputs_embeds'] = self.emb_text(model_input['input_ids'][:,:,0])
+                    model_input['inputs_embeds'] = self.emb_text(inputs_ids_emb[:,:,0])
                 else:
-                    code_emb = [self.emb_code[i](model_input['input_ids'][:,:,i]) for i in range(self.num_vq)]
+                    code_emb = [self.emb_code[i](inputs_ids_emb[:,:,i]) for i in range(self.num_vq)]
                     model_input['inputs_embeds'] = torch.stack(code_emb, 3).sum(3)
-
-            model_input['input_ids'] = None
-            outputs = self.gpt.forward(**model_input, output_attentions=return_attn)
-            del model_input
+                del inputs_ids_emb, model_input['input_ids']
+
+            outputs = self.gpt.forward(
+                attention_mask=model_input["attention_mask"].to(self.device_gpt),
+                position_ids=model_input["position_ids"].to(self.device_gpt),
+                past_key_values=model_input["past_key_values"],
+                inputs_embeds=model_input['inputs_embeds'].to(self.device_gpt),
+                use_cache=model_input['use_cache'],
+                output_attentions=return_attn,
+                cache_position=model_input['cache_position'].to(self.device_gpt),
+            )
+            del_all(model_input)
             attentions.append(outputs.attentions)
-            hidden_states = outputs[0] # 🐻
+            hidden_states = outputs[0].to(self.device) # 🐻
             past_key_values = outputs.past_key_values
             del outputs
             if return_hidden:
@@ -225,8 +264,14 @@ class GPT_warpper(nn.Module):
             logits = logits[:, -1].float()
 
             if not infer_text:
-                logits = rearrange(logits, "b c n -> (b n) c")
-                logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c")
+                # logits = rearrange(logits, "b c n -> (b n) c")
+                logits = logits.permute(0, 2, 1)
+                logits = logits.reshape(-1, logits.size(2))
+                # logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c")
+                inputs_ids_sliced = inputs_ids[:, start_idx:].permute(0, 2, 1)
+                logits_token = inputs_ids_sliced.reshape(
+                    inputs_ids_sliced.size(0)*inputs_ids_sliced.size(1), -1,
+                )
             else:
                 logits_token = inputs_ids[:, start_idx:, 0]
 
@@ -247,10 +292,11 @@ class GPT_warpper(nn.Module):
 
             del logits
 
-            idx_next = torch.multinomial(scores, num_samples=1)
+            idx_next = torch.multinomial(scores, num_samples=1).to(finish.device)
 
             if not infer_text:
-                idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
+                # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
+                idx_next = idx_next.view(-1, self.num_vq)
                 finish_or = (idx_next == eos_token).any(1)
                 finish |= finish_or
                 del finish_or
@@ -278,9 +324,11 @@ class GPT_warpper(nn.Module):
                     'attentions': attentions,
                     'hiddens':y_hiddens,
                 }
+
             if finish.all():
                 pbar.update(max_new_token-i-1)
                 break
+            pbar.update(1)
 
         inputs_ids = [inputs_ids[idx, start_idx: start_idx+i] for idx, i in enumerate(end_idx.int())]
         inputs_ids = [i[:, 0] for i in inputs_ids] if infer_text else inputs_ids
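Each removed `rearrange` in gpt.py is replaced by an equivalent `permute`/`view` chain, with the original kept as a comment. A quick standalone check of the temperature case (the values of `b` and `num_vq` are arbitrary):

```python
import torch

b, num_vq = 3, 4
temperature = torch.rand(num_vq)

# patched code (non-in-place variant of unsqueeze_ for clarity):
t_new = temperature.unsqueeze(0).expand(b, -1).contiguous().view(-1, 1)

# what rearrange(temperature[None].expand(b, -1), "b n -> (b n) 1") produced:
t_ref = temperature.repeat(b).view(b * num_vq, 1)
assert torch.equal(t_new, t_ref) and t_new.shape == (b * num_vq, 1)
```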
ChatTTS/ChatTTS/utils/download.py CHANGED
@@ -3,10 +3,8 @@ from pathlib import Path
 import hashlib
 import requests
 from io import BytesIO
-import logging
-
-logger = logging.getLogger(__name__)
 
+from .log import logger
 
 def sha256(f) -> str:
     sha256_hash = hashlib.sha256()
ChatTTS/ChatTTS/utils/gpu_utils.py CHANGED
@@ -1,9 +1,9 @@
 
 import torch
-import logging
+
+from .log import logger
 
 def select_device(min_memory=2048):
-    logger = logging.getLogger(__name__)
     if torch.cuda.is_available():
         available_gpus = []
         for i in range(torch.cuda.device_count()):
ChatTTS/ChatTTS/utils/infer_utils.py CHANGED
@@ -2,7 +2,6 @@
 import re
 import torch
 import torch.nn.functional as F
-import os
 import json
 
 
@@ -76,12 +75,15 @@ class HomophonesReplacer:
 
     def replace(self, text):
         result = []
+        replaced_words = []
         for char in text:
             if char in self.homophones_map:
-                result.append(self.homophones_map[char])
+                repl_char = self.homophones_map[char]
+                result.append(repl_char)
+                replaced_words.append((char, repl_char))
             else:
                 result.append(char)
-        return ''.join(result)
+        return ''.join(result), replaced_words
 
 def count_invalid_characters(s):
 
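`replace` now returns the replacement pairs alongside the text, which is what lets `Chat._infer` log exactly which characters changed. A runnable illustration of the new contract, using a stub map (the real map is loaded from a JSON file by the class):

```python
homophones_map = {"的": "地"}  # stand-in for the JSON-loaded map

def replace(text: str):
    result, replaced_words = [], []
    for char in text:
        if char in homophones_map:
            repl_char = homophones_map[char]
            result.append(repl_char)
            replaced_words.append((char, repl_char))
        else:
            result.append(char)
    return ''.join(result), replaced_words

new_text, pairs = replace("我的书")
assert new_text == "我地书" and pairs == [("的", "地")]
```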
ChatTTS/ChatTTS/utils/io.py ADDED
@@ -0,0 +1,33 @@
+
+import os
+import logging
+from typing import Union
+
+from .log import logger
+
+def get_latest_modified_file(directory):
+
+    files = [os.path.join(directory, f) for f in os.listdir(directory)]
+    if not files:
+        logger.log(logging.WARNING, f'No files found in the directory: {directory}')
+        return None
+    latest_file = max(files, key=os.path.getmtime)
+
+    return latest_file
+
+def del_all(d: Union[dict, list]):
+    if isinstance(d, dict):
+        lst = list(d.keys())
+        for k in lst:
+            x = d.pop(k)
+            if isinstance(x, dict) or isinstance(x, list):
+                del_all(x)
+            del x
+        return
+    elif isinstance(d, list):
+        while len(d):
+            x = d.pop()
+            if isinstance(x, dict) or isinstance(x, list):
+                del_all(x)
+            del x
+        return
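`del_all` is the commit's recurring memory-management helper: it empties nested dicts and lists in place so the tensors they reference can be reclaimed promptly, which matters on CUDA/MPS where allocations linger while still referenced. A small demonstration of the emptying behavior:

```python
from ChatTTS.utils.io import del_all

nested = {"ids": [1, 2], "meta": {"hiddens": [3, 4]}}
del_all(nested)
assert nested == {}  # every level was emptied in place
```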
ChatTTS/ChatTTS/utils/log.py ADDED
@@ -0,0 +1,8 @@
+import logging
+from pathlib import Path
+
+logger = logging.getLogger(Path(__file__).parent.name)
+
+def set_utils_logger(l: logging.Logger):
+    global logger
+    logger = l
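`ChatTTS.utils.log` holds one shared logger that `Chat.__init__` swaps out via `set_utils_logger`. One subtlety worth knowing: rebinding the module global only affects code that looks the logger up through the module object; a name bound earlier via `from .log import logger` keeps its original reference. Sketch:

```python
import logging
from ChatTTS.utils import log

custom = logging.getLogger("app")
log.set_utils_logger(custom)
assert log.logger is custom
# But any `from .log import logger` executed before the swap still points
# at the old Logger object -- only module-attribute lookups see the change.
```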
ChatTTS/docs/cn/README.md CHANGED
@@ -3,7 +3,7 @@
 <a href="https://trendshift.io/repositories/10489" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10489" alt="2noise%2FChatTTS | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
 
 # ChatTTS
-一款用于日常对话的生成式语音模型。
+一款适用于日常对话的生成式语音模型。
 
 [![Licence](https://img.shields.io/badge/LICENSE-CC%20BY--NC%204.0-green.svg?style=for-the-badge)](https://github.com/2noise/ChatTTS/blob/main/LICENSE)
 
@@ -14,6 +14,9 @@
 
 </div>
 
+> [!NOTE]
+> 注意此版本可能不是最新版,所有内容请以英文版为准。
+
 ## 简介
 
 ChatTTS 是一款专门为对话场景(例如 LLM 助手)设计的文本转语音模型。
@@ -26,7 +29,7 @@ ChatTTS 是一款专门为对话场景(例如 LLM 助手)设计的文本转
 
 ### 亮点
 
-> 你可以参考 **[Bilibili](https://www.bilibili.com/video/BV1zn4y1o7iV)** 上的这个视频了解详细的介绍.
+> 你可以参考 **[Bilibili](https://www.bilibili.com/video/BV1zn4y1o7iV)** 上的这个视频,了解本项目的详细情况。
 
 1. **对话式 TTS**: ChatTTS 针对对话式任务进行了优化,能够实现自然且富有表现力的合成语音。它支持多个说话者,便于生成互动式对话。
 2. **精细的控制**: 该模型可以预测和控制精细的韵律特征,包括笑声、停顿和插入语。
@@ -34,8 +37,8 @@ ChatTTS 是一款专门为对话场景(例如 LLM 助手)设计的文本转
 
 ### 数据集和模型
 
-- 主要模型使用 100,000+ 小时的中文和英文音频数据进行训练。
-- **[HuggingFace](https://huggingface.co/2Noise/ChatTTS)** 上的开源版本是一个在 40,000 小时数据上进行无监督微调的预训练模型。。
+- 主模型使用了 100,000+ 小时的中文和英文音频数据进行训练。
+- **[HuggingFace](https://huggingface.co/2Noise/ChatTTS)** 上的开源版本是一个在 40,000 小时数据上进行无监督微调的预训练模型。
 
 ### 路线图
 
@@ -50,7 +53,7 @@
 > [!Important]
 > 此仓库仅供学术用途。
 
-本项目旨在用于教育和研究目的,不应用于任何商业或法律目的。作者不保证信息的准确性、完整性或可靠性。此仓库中使用的信息和数据仅供学术和研究目的。数据来自公开来源,作者不声称对数据拥有任何所有权或版权。
+本项目旨在用于教育和研究目的,不适用于任何商业或法律目的。作者不保证信息的准确性、完整性和可靠性。此仓库中使用的信息和数据仅供学术和研究目的。数据来自公开来源,作者不声称对数据拥有任何所有权或版权。
 
 ChatTTS 是一款强大的文本转语音系统。但是,负责任和道德地使用这项技术非常重要。为了限制 ChatTTS 的使用,我们在 40,000 小时模型的训练过程中添加了少量高频噪声,并使用 MP3 格式尽可能压缩音频质量,以防止恶意行为者将其用于犯罪目的。同时,我们内部训练了一个检测模型,并计划在未来开源它。
 
@@ -60,7 +63,7 @@ ChatTTS 是一款强大的文本转语音系统。但是,负责任和道德地
 
 #### 合作洽谈
 
-如就模型和路线图进行合作洽谈,请发送邮件至 **[email protected]**。
+如需就模型和路线图进行合作洽谈,请发送邮件至 **[email protected]**。
 
 #### 线上讨论
 
@@ -131,7 +134,7 @@ wavs = chat.infer(texts, )
 torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
 ```
 
-### 高级用法
+### 进阶用法
 
 ```python
 ###################################
@@ -219,10 +222,14 @@
 
 - [Awesome-ChatTTS](https://github.com/libukai/Awesome-ChatTTS) 一个 ChatTTS 的资源汇总列表。
 
-## 感谢所有贡献者的付出
+## 贡献者列表
 
 [![contributors](https://contrib.rocks/image?repo=2noise/ChatTTS)](https://github.com/2noise/ChatTTS/graphs/contributors)
 
-## Star 趋势
-
-[![Star History Chart](https://api.star-history.com/svg?repos=2noise/ChatTTS&type=Date)](https://star-history.com/#2noise/ChatTTS&Date)
+## 项目浏览量
+
+<div align="center">
+
+![counter](https://counter.seku.su/cmoe?name=chattts&theme=mbs)
+
+</div>
ChatTTS/examples/cmd/run.py CHANGED
@@ -13,6 +13,10 @@ import wave
 import ChatTTS
 from IPython.display import Audio
 
+from tools.logger import get_logger
+
+logger = get_logger("Command")
+
 def save_wav_file(wav, index):
     wav_filename = f"output_audio_{index}.wav"
     # Convert numpy array to bytes and write to WAV file
@@ -22,23 +26,26 @@ def save_wav_file(wav, index):
         wf.setsampwidth(2) # Sample width in bytes
         wf.setframerate(24000) # Sample rate in Hz
         wf.writeframes(wav_bytes)
-    print(f"Audio saved to {wav_filename}")
+    logger.info(f"Audio saved to {wav_filename}")
 
 def main():
     # Retrieve text from command line argument
     text_input = sys.argv[1] if len(sys.argv) > 1 else "<YOUR TEXT HERE>"
-    print("Received text input:", text_input)
+    logger.info("Received text input: %s", text_input)
 
-    chat = ChatTTS.Chat()
-    print("Initializing ChatTTS...")
-    chat.load_models()
-    print("Models loaded successfully.")
+    chat = ChatTTS.Chat(get_logger("ChatTTS"))
+    logger.info("Initializing ChatTTS...")
+    if chat.load_models():
+        logger.info("Models loaded successfully.")
+    else:
+        logger.error("Models load failed.")
+        sys.exit(1)
 
     texts = [text_input]
-    print("Text prepared for inference:", texts)
+    logger.info("Text prepared for inference: %s", texts)
 
     wavs = chat.infer(texts, use_decoder=True)
-    print("Inference completed. Audio generation successful.")
+    logger.info("Inference completed. Audio generation successful.")
     # Save each generated wav file to a local file
     for index, wav in enumerate(wavs):
         save_wav_file(wav, index)
@@ -46,6 +53,6 @@ def main():
     return Audio(wavs[0], rate=24_000, autoplay=True)
 
 if __name__ == "__main__":
-    print("Starting the TTS application...")
+    logger.info("Starting the TTS application...")
     main()
-    print("TTS application finished.")
+    logger.info("TTS application finished.")
ChatTTS/examples/ipynb/colab.ipynb CHANGED
@@ -1,24 +1,38 @@
1
  {
2
  "cells": [
 
 
 
 
 
 
 
 
 
3
  {
4
  "cell_type": "code",
5
  "execution_count": null,
6
  "metadata": {
7
- "colab": {
8
- "base_uri": "https://localhost:8080/"
9
- },
10
- "id": "hegwDOfffwzw",
11
- "outputId": "1e221210-152b-4f5b-f009-9b9ffec2fa9f"
12
  },
13
  "outputs": [],
14
  "source": [
 
15
  "!rm -rf /content/ChatTTS\n",
16
  "!git clone https://github.com/2noise/ChatTTS.git\n",
17
  "!pip install -r /content/ChatTTS/requirements.txt\n",
18
- "!pip install nemo_text_processing WeTextProcessing\n",
19
  "!ldconfig /usr/lib64-nvidia"
20
  ]
21
  },
 
 
 
 
 
 
 
 
 
22
  {
23
  "cell_type": "code",
24
  "execution_count": null,
@@ -28,7 +42,7 @@
28
  "outputs": [],
29
  "source": [
30
  "from dotenv import load_dotenv\n",
31
- "load_dotenv(\"sha256.env\")\n",
32
  "\n",
33
  "import torch\n",
34
  "torch._dynamo.config.cache_size_limit = 64\n",
@@ -52,35 +66,69 @@
52
  "cell_type": "code",
53
  "execution_count": null,
54
  "metadata": {
55
- "colab": {
56
- "base_uri": "https://localhost:8080/",
57
- "height": 49,
58
- "referenced_widgets": [
59
- "c365a95346ec4b09a1e6467bf313baf7",
60
- "d79fd51849fd463cb08b83fdb8e5ca0c",
61
- "d247683a0a61441b971dfb39062e1fbf",
62
- "1da23fc236034f32adcaf6bb2e0e7d80",
63
- "4b2126d97c514795ab2a90f7357a203c",
64
- "9775ce64008b417fac3edd55b9e999d9",
65
- "96c9bb2eff4043b2a5dbd1e3e65375e5",
66
- "20aa0031b7bb45bf82443b48b3694166",
67
- "67252ea545d64392a1bd6ac40852e65f",
68
- "2f920c00bcac4787a0078ee035e97b43",
69
- "ba592297ff5347aebae298770a29fb8c"
70
- ]
71
- },
72
- "id": "e0QSkngRbSrg",
73
- "outputId": "138ac28b-6a33-4c31-8fe3-8481bb213d02"
74
  },
75
  "outputs": [],
76
  "source": [
77
- "chat = ChatTTS.Chat()\n",
78
- "\n",
79
- "# Use force_redownload=True if the weights updated.\n",
80
- "chat.load_models(force_redownload=True)\n",
81
- "\n",
82
- "# If you download the weights manually, set source='locals'.\n",
83
- "# chat.load_models(source='local', local_path='YOUR LOCAL PATH')"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  ]
85
  },
86
  {
@@ -105,11 +153,7 @@
105
  "cell_type": "code",
106
  "execution_count": null,
107
  "metadata": {
108
- "colab": {
109
- "base_uri": "https://localhost:8080/"
110
- },
111
- "id": "Su9FmUYAbSrh",
112
- "outputId": "7c2aa0c1-1f99-4da1-b2e5-bbcb93465d89"
113
  },
114
  "outputs": [],
115
  "source": [
@@ -123,12 +167,7 @@
123
  "cell_type": "code",
124
  "execution_count": null,
125
  "metadata": {
126
- "colab": {
127
- "base_uri": "https://localhost:8080/",
128
- "height": 76
129
- },
130
- "id": "YQRwB8lpbSri",
131
- "outputId": "62ca9282-2755-44a5-ffca-c05c5e35ce76"
132
  },
133
  "outputs": [],
134
  "source": [
@@ -139,12 +178,7 @@
139
  "cell_type": "code",
140
  "execution_count": null,
141
  "metadata": {
142
- "colab": {
143
- "base_uri": "https://localhost:8080/",
144
- "height": 76
145
- },
146
- "id": "LuFG6m7AbSri",
147
- "outputId": "d8e0e3a2-d9fe-44db-e1f4-e2596289270e"
148
  },
149
  "outputs": [],
150
  "source": [
@@ -164,11 +198,7 @@
164
  "cell_type": "code",
165
  "execution_count": null,
166
  "metadata": {
167
- "colab": {
168
- "base_uri": "https://localhost:8080/"
169
- },
170
- "id": "kma0HBEBbSrj",
171
- "outputId": "b80b9d2f-8248-41ee-f1d7-eb3bf331ee69"
172
  },
173
  "outputs": [],
174
  "source": [
@@ -183,12 +213,7 @@
183
  "cell_type": "code",
184
  "execution_count": null,
185
  "metadata": {
186
- "colab": {
187
- "base_uri": "https://localhost:8080/",
188
- "height": 76
189
- },
190
- "id": "Nl_mT9KpbSrj",
191
- "outputId": "1bfcc06a-5246-4d25-fc19-3d125362fa59"
192
  },
193
  "outputs": [],
194
  "source": [
@@ -208,11 +233,7 @@
208
  "cell_type": "code",
209
  "execution_count": null,
210
  "metadata": {
211
- "colab": {
212
- "base_uri": "https://localhost:8080/"
213
- },
214
- "id": "Qh7dcWrAbSrk",
215
- "outputId": "3b936323-170a-496b-c4c2-6caa97a8d514"
216
  },
217
  "outputs": [],
218
  "source": [
@@ -227,12 +248,7 @@
227
  "cell_type": "code",
228
  "execution_count": null,
229
  "metadata": {
230
- "colab": {
231
- "base_uri": "https://localhost:8080/",
232
- "height": 76
233
- },
234
- "id": "0ljWDWzabSrk",
235
- "outputId": "8ade2469-c226-44ae-c3a7-ff034e2bffbf"
236
  },
237
  "outputs": [],
238
  "source": [
@@ -252,11 +268,7 @@
252
  "cell_type": "code",
253
  "execution_count": null,
254
  "metadata": {
255
- "colab": {
256
- "base_uri": "https://localhost:8080/"
257
- },
258
- "id": "3hAAc0lJbSrl",
259
- "outputId": "8dc45586-fb2a-4e81-ee53-0ce6df2fc43a"
260
  },
261
  "outputs": [],
262
  "source": [
@@ -269,11 +281,7 @@
269
  "cell_type": "code",
270
  "execution_count": null,
271
  "metadata": {
272
- "colab": {
273
- "base_uri": "https://localhost:8080/"
274
- },
275
- "id": "0GVJxhd3BKQX",
276
- "outputId": "f1484519-7130-450a-b7d8-09de5fe2ffd1"
277
  },
278
  "outputs": [],
279
  "source": [
@@ -284,12 +292,7 @@
284
  "cell_type": "code",
285
  "execution_count": null,
286
  "metadata": {
287
- "colab": {
288
- "base_uri": "https://localhost:8080/",
289
- "height": 76
290
- },
291
- "id": "ngyMht74BicY",
292
- "outputId": "8c7447ad-9ac7-4264-9f53-057d47d43931"
293
  },
294
  "outputs": [],
295
  "source": [
@@ -300,11 +303,7 @@
300
  "cell_type": "code",
301
  "execution_count": null,
302
  "metadata": {
303
- "colab": {
304
- "base_uri": "https://localhost:8080/"
305
- },
306
- "id": "R2WjuVrWbSrl",
307
- "outputId": "0d644cb9-4d65-4147-bd99-d5451439be02"
308
  },
309
  "outputs": [],
310
  "source": [
@@ -316,12 +315,7 @@
316
  "cell_type": "code",
317
  "execution_count": null,
318
  "metadata": {
319
- "colab": {
320
- "base_uri": "https://localhost:8080/",
321
- "height": 76
322
- },
323
- "id": "71Y4pBdl-_Yd",
324
- "outputId": "d44fdf1a-c9e8-42ff-ab96-8712986418fa"
325
  },
326
  "outputs": [],
327
  "source": [
@@ -406,352 +400,6 @@
406
  "nbconvert_exporter": "python",
407
  "pygments_lexer": "ipython3",
408
  "version": "3.10.8"
409
- },
410
- "widgets": {
411
- "application/vnd.jupyter.widget-state+json": {
412
- "1da23fc236034f32adcaf6bb2e0e7d80": {
413
- "model_module": "@jupyter-widgets/controls",
414
- "model_module_version": "1.5.0",
415
- "model_name": "HTMLModel",
416
- "state": {
417
- "_dom_classes": [],
418
- "_model_module": "@jupyter-widgets/controls",
419
- "_model_module_version": "1.5.0",
420
- "_model_name": "HTMLModel",
421
- "_view_count": null,
422
- "_view_module": "@jupyter-widgets/controls",
423
- "_view_module_version": "1.5.0",
424
- "_view_name": "HTMLView",
425
- "description": "",
426
- "description_tooltip": null,
427
- "layout": "IPY_MODEL_2f920c00bcac4787a0078ee035e97b43",
428
- "placeholder": "​",
429
- "style": "IPY_MODEL_ba592297ff5347aebae298770a29fb8c",
430
- "value": " 11/11 [00:00&lt;00:00, 762.51it/s]"
431
- }
432
- },
433
- "20aa0031b7bb45bf82443b48b3694166": {
434
- "model_module": "@jupyter-widgets/base",
435
- "model_module_version": "1.2.0",
436
- "model_name": "LayoutModel",
437
- "state": {
438
- "_model_module": "@jupyter-widgets/base",
439
- "_model_module_version": "1.2.0",
440
- "_model_name": "LayoutModel",
441
- "_view_count": null,
442
- "_view_module": "@jupyter-widgets/base",
443
- "_view_module_version": "1.2.0",
444
- "_view_name": "LayoutView",
445
- "align_content": null,
446
- "align_items": null,
447
- "align_self": null,
448
- "border": null,
449
- "bottom": null,
450
- "display": null,
451
- "flex": null,
452
- "flex_flow": null,
453
- "grid_area": null,
454
- "grid_auto_columns": null,
455
- "grid_auto_flow": null,
456
- "grid_auto_rows": null,
457
- "grid_column": null,
458
- "grid_gap": null,
459
- "grid_row": null,
460
- "grid_template_areas": null,
461
- "grid_template_columns": null,
462
- "grid_template_rows": null,
463
- "height": null,
464
- "justify_content": null,
465
- "justify_items": null,
466
- "left": null,
467
- "margin": null,
468
- "max_height": null,
469
- "max_width": null,
470
- "min_height": null,
471
- "min_width": null,
472
- "object_fit": null,
473
- "object_position": null,
474
- "order": null,
475
- "overflow": null,
476
- "overflow_x": null,
477
- "overflow_y": null,
478
- "padding": null,
479
- "right": null,
480
- "top": null,
481
- "visibility": null,
482
- "width": null
483
- }
484
- },
485
- "2f920c00bcac4787a0078ee035e97b43": {
486
- "model_module": "@jupyter-widgets/base",
487
- "model_module_version": "1.2.0",
488
- "model_name": "LayoutModel",
489
- "state": {
490
- "_model_module": "@jupyter-widgets/base",
491
- "_model_module_version": "1.2.0",
492
- "_model_name": "LayoutModel",
493
- "_view_count": null,
494
- "_view_module": "@jupyter-widgets/base",
495
- "_view_module_version": "1.2.0",
496
- "_view_name": "LayoutView",
497
- "align_content": null,
498
- "align_items": null,
499
- "align_self": null,
500
- "border": null,
501
- "bottom": null,
502
- "display": null,
503
- "flex": null,
504
- "flex_flow": null,
505
- "grid_area": null,
506
- "grid_auto_columns": null,
507
- "grid_auto_flow": null,
508
- "grid_auto_rows": null,
509
- "grid_column": null,
510
- "grid_gap": null,
511
- "grid_row": null,
512
- "grid_template_areas": null,
513
- "grid_template_columns": null,
514
- "grid_template_rows": null,
515
- "height": null,
516
- "justify_content": null,
517
- "justify_items": null,
518
- "left": null,
519
- "margin": null,
520
- "max_height": null,
521
- "max_width": null,
522
- "min_height": null,
523
- "min_width": null,
524
- "object_fit": null,
525
- "object_position": null,
526
- "order": null,
527
- "overflow": null,
528
- "overflow_x": null,
529
- "overflow_y": null,
530
- "padding": null,
531
- "right": null,
532
- "top": null,
533
- "visibility": null,
534
- "width": null
535
- }
536
- },
537
- "4b2126d97c514795ab2a90f7357a203c": {
538
- "model_module": "@jupyter-widgets/base",
539
- "model_module_version": "1.2.0",
540
- "model_name": "LayoutModel",
541
- "state": {
542
- "_model_module": "@jupyter-widgets/base",
543
- "_model_module_version": "1.2.0",
544
- "_model_name": "LayoutModel",
545
- "_view_count": null,
546
- "_view_module": "@jupyter-widgets/base",
547
- "_view_module_version": "1.2.0",
548
- "_view_name": "LayoutView",
549
- "align_content": null,
550
- "align_items": null,
551
- "align_self": null,
552
- "border": null,
553
- "bottom": null,
554
- "display": null,
555
- "flex": null,
556
- "flex_flow": null,
557
- "grid_area": null,
558
- "grid_auto_columns": null,
559
- "grid_auto_flow": null,
560
- "grid_auto_rows": null,
561
- "grid_column": null,
562
- "grid_gap": null,
563
- "grid_row": null,
564
- "grid_template_areas": null,
565
- "grid_template_columns": null,
566
- "grid_template_rows": null,
567
- "height": null,
568
- "justify_content": null,
569
- "justify_items": null,
570
- "left": null,
571
- "margin": null,
572
- "max_height": null,
573
- "max_width": null,
574
- "min_height": null,
575
- "min_width": null,
576
- "object_fit": null,
577
- "object_position": null,
578
- "order": null,
579
- "overflow": null,
580
- "overflow_x": null,
581
- "overflow_y": null,
582
- "padding": null,
583
- "right": null,
584
- "top": null,
585
- "visibility": null,
586
- "width": null
587
- }
588
- },
589
- "67252ea545d64392a1bd6ac40852e65f": {
590
- "model_module": "@jupyter-widgets/controls",
591
- "model_module_version": "1.5.0",
592
- "model_name": "ProgressStyleModel",
593
- "state": {
594
- "_model_module": "@jupyter-widgets/controls",
595
- "_model_module_version": "1.5.0",
596
- "_model_name": "ProgressStyleModel",
597
- "_view_count": null,
598
- "_view_module": "@jupyter-widgets/base",
599
- "_view_module_version": "1.2.0",
600
- "_view_name": "StyleView",
601
- "bar_color": null,
602
- "description_width": ""
603
- }
604
- },
605
- "96c9bb2eff4043b2a5dbd1e3e65375e5": {
606
- "model_module": "@jupyter-widgets/controls",
607
- "model_module_version": "1.5.0",
608
- "model_name": "DescriptionStyleModel",
609
- "state": {
610
- "_model_module": "@jupyter-widgets/controls",
611
- "_model_module_version": "1.5.0",
612
- "_model_name": "DescriptionStyleModel",
613
- "_view_count": null,
614
- "_view_module": "@jupyter-widgets/base",
615
- "_view_module_version": "1.2.0",
616
- "_view_name": "StyleView",
617
- "description_width": ""
618
- }
619
- },
620
- "9775ce64008b417fac3edd55b9e999d9": {
621
- "model_module": "@jupyter-widgets/base",
622
- "model_module_version": "1.2.0",
623
- "model_name": "LayoutModel",
624
- "state": {
625
- "_model_module": "@jupyter-widgets/base",
626
- "_model_module_version": "1.2.0",
627
- "_model_name": "LayoutModel",
628
- "_view_count": null,
629
- "_view_module": "@jupyter-widgets/base",
630
- "_view_module_version": "1.2.0",
631
- "_view_name": "LayoutView",
632
- "align_content": null,
633
- "align_items": null,
634
- "align_self": null,
635
- "border": null,
636
- "bottom": null,
637
- "display": null,
638
- "flex": null,
639
- "flex_flow": null,
640
- "grid_area": null,
641
- "grid_auto_columns": null,
642
- "grid_auto_flow": null,
643
- "grid_auto_rows": null,
644
- "grid_column": null,
645
- "grid_gap": null,
646
- "grid_row": null,
647
- "grid_template_areas": null,
648
- "grid_template_columns": null,
649
- "grid_template_rows": null,
650
- "height": null,
651
- "justify_content": null,
652
- "justify_items": null,
653
- "left": null,
654
- "margin": null,
655
- "max_height": null,
656
- "max_width": null,
657
- "min_height": null,
658
- "min_width": null,
659
- "object_fit": null,
660
- "object_position": null,
661
- "order": null,
662
- "overflow": null,
663
- "overflow_x": null,
664
- "overflow_y": null,
665
- "padding": null,
666
- "right": null,
667
- "top": null,
668
- "visibility": null,
669
- "width": null
670
- }
671
- },
672
- "ba592297ff5347aebae298770a29fb8c": {
673
- "model_module": "@jupyter-widgets/controls",
674
- "model_module_version": "1.5.0",
675
- "model_name": "DescriptionStyleModel",
676
- "state": {
677
- "_model_module": "@jupyter-widgets/controls",
678
- "_model_module_version": "1.5.0",
679
- "_model_name": "DescriptionStyleModel",
680
- "_view_count": null,
681
- "_view_module": "@jupyter-widgets/base",
682
- "_view_module_version": "1.2.0",
683
- "_view_name": "StyleView",
684
- "description_width": ""
685
- }
686
- },
687
- "c365a95346ec4b09a1e6467bf313baf7": {
688
- "model_module": "@jupyter-widgets/controls",
689
- "model_module_version": "1.5.0",
690
- "model_name": "HBoxModel",
691
- "state": {
692
- "_dom_classes": [],
693
- "_model_module": "@jupyter-widgets/controls",
694
- "_model_module_version": "1.5.0",
695
- "_model_name": "HBoxModel",
696
- "_view_count": null,
697
- "_view_module": "@jupyter-widgets/controls",
698
- "_view_module_version": "1.5.0",
699
- "_view_name": "HBoxView",
700
- "box_style": "",
701
- "children": [
702
- "IPY_MODEL_d79fd51849fd463cb08b83fdb8e5ca0c",
703
- "IPY_MODEL_d247683a0a61441b971dfb39062e1fbf",
704
- "IPY_MODEL_1da23fc236034f32adcaf6bb2e0e7d80"
705
- ],
706
- "layout": "IPY_MODEL_4b2126d97c514795ab2a90f7357a203c"
707
- }
708
- },
709
- "d247683a0a61441b971dfb39062e1fbf": {
710
- "model_module": "@jupyter-widgets/controls",
711
- "model_module_version": "1.5.0",
712
- "model_name": "FloatProgressModel",
713
- "state": {
714
- "_dom_classes": [],
715
- "_model_module": "@jupyter-widgets/controls",
716
- "_model_module_version": "1.5.0",
717
- "_model_name": "FloatProgressModel",
718
- "_view_count": null,
719
- "_view_module": "@jupyter-widgets/controls",
720
- "_view_module_version": "1.5.0",
721
- "_view_name": "ProgressView",
722
- "bar_style": "success",
723
- "description": "",
724
- "description_tooltip": null,
725
- "layout": "IPY_MODEL_20aa0031b7bb45bf82443b48b3694166",
726
- "max": 11,
727
- "min": 0,
728
- "orientation": "horizontal",
729
- "style": "IPY_MODEL_67252ea545d64392a1bd6ac40852e65f",
730
- "value": 11
731
- }
732
- },
733
- "d79fd51849fd463cb08b83fdb8e5ca0c": {
734
- "model_module": "@jupyter-widgets/controls",
735
- "model_module_version": "1.5.0",
736
- "model_name": "HTMLModel",
737
- "state": {
738
- "_dom_classes": [],
739
- "_model_module": "@jupyter-widgets/controls",
740
- "_model_module_version": "1.5.0",
741
- "_model_name": "HTMLModel",
742
- "_view_count": null,
743
- "_view_module": "@jupyter-widgets/controls",
744
- "_view_module_version": "1.5.0",
745
- "_view_name": "HTMLView",
746
- "description": "",
747
- "description_tooltip": null,
748
- "layout": "IPY_MODEL_9775ce64008b417fac3edd55b9e999d9",
749
- "placeholder": "​",
750
- "style": "IPY_MODEL_96c9bb2eff4043b2a5dbd1e3e65375e5",
751
- "value": "Fetching 11 files: 100%"
752
- }
753
- }
754
- }
755
  }
756
  },
757
  "nbformat": 4,
 
{
 "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "xYJFXKP9xhQM"
+   },
+   "source": [
+    "## Clone Repo"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
+    "id": "hegwDOfffwzw"
    },
    "outputs": [],
    "source": [
+    "!cd /content\n",
     "!rm -rf /content/ChatTTS\n",
     "!git clone https://github.com/2noise/ChatTTS.git\n",
     "!pip install -r /content/ChatTTS/requirements.txt\n",
     "!ldconfig /usr/lib64-nvidia"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "zdzEFoknxqTH"
+   },
+   "source": [
+    "## Import Libs"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,

    "outputs": [],
    "source": [
     "from dotenv import load_dotenv\n",
+    "load_dotenv(\"ChatTTS/sha256.env\")\n",
     "\n",
     "import torch\n",
     "torch._dynamo.config.cache_size_limit = 64\n",

    "cell_type": "code",
    "execution_count": null,
    "metadata": {
+    "id": "e0QSkngRbSrg"
    },
    "outputs": [],
    "source": [
+    "chat = ChatTTS.Chat()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Here are three choices for loading models:"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### 1. Load models from Hugging Face:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# use force_redownload=True if the weights have been updated.\n",
+    "chat.load_models(source='huggingface', force_redownload=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### 2. Load models from local directories 'asset' and 'config':"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "chat.load_models()\n",
+    "# chat.load_models(source='local') same as above"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### 3. Load models from a custom path:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# write the model path into custom_path\n",
+    "chat.load_models(source='custom', custom_path='YOUR CUSTOM PATH')"
    ]
   },
   {

    "cell_type": "code",
    "execution_count": null,
    "metadata": {
+    "id": "Su9FmUYAbSrh"
    },
    "outputs": [],
    "source": [

    "cell_type": "code",
    "execution_count": null,
    "metadata": {
+    "id": "YQRwB8lpbSri"
    },
    "outputs": [],
    "source": [

    "cell_type": "code",
    "execution_count": null,
    "metadata": {
+    "id": "LuFG6m7AbSri"
    },
    "outputs": [],
    "source": [

    "cell_type": "code",
    "execution_count": null,
    "metadata": {
+    "id": "kma0HBEBbSrj"
    },
    "outputs": [],
    "source": [

    "cell_type": "code",
    "execution_count": null,
    "metadata": {
+    "id": "Nl_mT9KpbSrj"
    },
    "outputs": [],
    "source": [

    "cell_type": "code",
    "execution_count": null,
    "metadata": {
+    "id": "Qh7dcWrAbSrk"
    },
    "outputs": [],
    "source": [

    "cell_type": "code",
    "execution_count": null,
    "metadata": {
+    "id": "0ljWDWzabSrk"
    },
    "outputs": [],
    "source": [

    "cell_type": "code",
    "execution_count": null,
    "metadata": {
+    "id": "3hAAc0lJbSrl"
    },
    "outputs": [],
    "source": [

    "cell_type": "code",
    "execution_count": null,
    "metadata": {
+    "id": "0GVJxhd3BKQX"
    },
    "outputs": [],
    "source": [

    "cell_type": "code",
    "execution_count": null,
    "metadata": {
+    "id": "ngyMht74BicY"
    },
    "outputs": [],
    "source": [

    "cell_type": "code",
    "execution_count": null,
    "metadata": {
+    "id": "R2WjuVrWbSrl"
    },
    "outputs": [],
    "source": [

    "cell_type": "code",
    "execution_count": null,
    "metadata": {
+    "id": "71Y4pBdl-_Yd"
    },
    "outputs": [],
    "source": [

    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "version": "3.10.8"

  }
 },
 "nbformat": 4,
ChatTTS/examples/ipynb/example.ipynb CHANGED
@@ -13,8 +13,19 @@
     "metadata": {},
     "outputs": [],
     "source": [
+    "import os, sys\n",
+    "\n",
+    "if sys.platform == \"darwin\":\n",
+    "    os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"\n",
+    "\n",
+    "if not \"root_dir\" in globals():\n",
+    "    now_dir = os.getcwd()  # skip examples/ipynb\n",
+    "    root_dir = os.path.join(now_dir, \"../../\")\n",
+    "    sys.path.append(root_dir)\n",
+    "    print(\"init root dir to\", root_dir)\n",
+    "\n",
     "from dotenv import load_dotenv\n",
-    "load_dotenv(\"sha256.env\")\n",
+    "load_dotenv(os.path.join(root_dir, \"sha256.env\"))\n",
     "\n",
     "import torch\n",
     "torch._dynamo.config.cache_size_limit = 64\n",
@@ -38,14 +49,67 @@
     "metadata": {},
     "outputs": [],
     "source": [
-    "chat = ChatTTS.Chat()\n",
-    "chat.load_models()\n",
+    "os.chdir(root_dir)\n",
     "\n",
-    "# Use force_redownload=True if the weights updated.\n",
-    "# chat.load_models(force_redownload=True)\n",
-    "\n",
-    "# If you download the weights manually, set source='locals'.\n",
-    "# chat.load_models(source='local', local_path='YOUR LOCAL PATH')"
+    "chat = ChatTTS.Chat()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Here are three choices for loading models:"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### 1. Load models from Hugging Face:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# use force_redownload=True if the weights have been updated.\n",
+    "chat.load_models(source='huggingface', force_redownload=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### 2. Load models from local directories 'asset' and 'config':"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "chat.load_models()\n",
+    "# chat.load_models(source='local') same as above"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### 3. Load models from a custom path:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# write the model path into custom_path\n",
+    "chat.load_models(source='custom', custom_path='YOUR CUSTOM PATH')"
    ]
   },
   {
@@ -70,7 +134,7 @@
    "source": [
     "texts = [\"So we found being competitive and collaborative was a huge way of staying motivated towards our goals, so one person to call when you fall off, one person who gets you back on then one person to actually do the activity with.\",]*3 \\\n",
     "    + [\"我觉得像我们这些写程序的人,他,我觉得多多少少可能会对开源有一种情怀在吧我觉得开源是一个很好的形式。现在其实最先进的技术掌握在一些公司的手里的话,就他们并不会轻易的开放给所有的人用。\"]*3 \n",
-    " \n",
+    "\n",
     "wavs = chat.infer(texts)"
    ]
   },
@@ -239,7 +303,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.8"
+   "version": "3.9.6"
   }
  },
 "nbformat": 4,
ChatTTS/examples/web/funcs.py ADDED
@@ -0,0 +1,100 @@
+import random
+
+import torch
+import gradio as gr
+import numpy as np
+
+from tools.logger import get_logger
+logger = get_logger(" WebUI ")
+
+import ChatTTS
+chat = ChatTTS.Chat(get_logger("ChatTTS"))
+
+# Voice options: preset seeds for suitable voices
+voices = {
+    "默认": {"seed": 2},
+    "音色1": {"seed": 1111},
+    "音色2": {"seed": 2222},
+    "音色3": {"seed": 3333},
+    "音色4": {"seed": 4444},
+    "音色5": {"seed": 5555},
+    "音色6": {"seed": 6666},
+    "音色7": {"seed": 7777},
+    "音色8": {"seed": 8888},
+    "音色9": {"seed": 9999},
+    "音色10": {"seed": 11111},
+}
+
+def generate_seed():
+    return gr.update(value=random.randint(1, 100000000))
+
+# Return the seed that corresponds to the selected voice
+def on_voice_change(voice_selection):
+    return voices.get(voice_selection)['seed']
+
+def refine_text(text, audio_seed_input, text_seed_input, refine_text_flag):
+    if not refine_text_flag:
+        return text
+
+    global chat
+
+    torch.manual_seed(audio_seed_input)
+    params_refine_text = {'prompt': '[oral_2][laugh_0][break_6]'}
+
+    torch.manual_seed(text_seed_input)
+
+    text = chat.infer(text,
+                      skip_refine_text=False,
+                      refine_text_only=True,
+                      params_refine_text=params_refine_text,
+                      )
+    return text[0] if isinstance(text, list) else text
+
+def generate_audio(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, stream):
+    if not text: return None
+
+    global chat
+
+    torch.manual_seed(audio_seed_input)
+    rand_spk = chat.sample_random_speaker()
+    params_infer_code = {
+        'spk_emb': rand_spk,
+        'temperature': temperature,
+        'top_P': top_P,
+        'top_K': top_K,
+    }
+    torch.manual_seed(text_seed_input)
+
+    wav = chat.infer(
+        text,
+        skip_refine_text=True,
+        params_infer_code=params_infer_code,
+        stream=stream,
+    )
+
+    if stream:
+        for gen in wav:
+            wavs = [np.array([[]])]
+            wavs[0] = np.hstack([wavs[0], np.array(gen[0])])
+            audio = wavs[0][0]
+
+            # normalize
+            am = np.abs(audio).max() * 32768
+            if am > 32768:
+                am = 32768 * 32768 / am
+            np.multiply(audio, am, audio)
+            audio = audio.astype(np.int16)
+
+            yield 24000, audio
+        return
+
+    audio_data = np.array(wav[0]).flatten()
+    # normalize
+    am = np.abs(audio_data).max() * 32768
+    if am > 32768:
+        am = 32768 * 32768 / am
+    np.multiply(audio_data, am, audio_data)
+    audio_data = audio_data.astype(np.int16)
+    sample_rate = 24000
+
+    yield sample_rate, audio_data
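
A note on the two normalize blocks above: they rescale the float waveform into int16 range before handing it to Gradio. A minimal self-contained sketch of the same idea, written as conventional peak-limited float-to-int16 conversion (the sine input is only a stand-in for model output, not part of the commit):

    import numpy as np

    def float_to_int16(audio: np.ndarray) -> np.ndarray:
        # scale the loudest sample to just under full scale,
        # never amplifying a signal that is already inside [-1, 1]
        peak = np.abs(audio).max()
        scale = 32767.0 / max(peak, 1.0)
        return (audio * scale).astype(np.int16)

    dummy = 0.8 * np.sin(np.linspace(0, 2 * np.pi, 24000))  # one second at 24 kHz
    pcm = float_to_int16(dummy)
    assert pcm.dtype == np.int16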
ChatTTS/examples/web/webui.py CHANGED
@@ -6,106 +6,44 @@ if sys.platform == "darwin":
 now_dir = os.getcwd()
 sys.path.append(now_dir)
 
-import random
 import argparse
 
-import torch
 import gradio as gr
-import numpy as np
 
 from dotenv import load_dotenv
 load_dotenv("sha256.env")
 
-import ChatTTS
-
-# Voice options: preset seeds for suitable voices
-voices = {
-    "默认": {"seed": 2},
-    "音色1": {"seed": 1111},
-    "音色2": {"seed": 2222},
-    "音色3": {"seed": 3333},
-    "音色4": {"seed": 4444},
-    "音色5": {"seed": 5555},
-    "音色6": {"seed": 6666},
-    "音色7": {"seed": 7777},
-    "音色8": {"seed": 8888},
-    "音色9": {"seed": 9999},
-    "音色10": {"seed": 11111},
-}
-
-def generate_seed():
-    new_seed = random.randint(1, 100000000)
-    return {
-        "__type__": "update",
-        "value": new_seed
-    }
-
-# Return the seed that corresponds to the selected voice
-def on_voice_change(vocie_selection):
-    return voices.get(vocie_selection)['seed']
-
-def generate_audio(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag):
-
-    torch.manual_seed(audio_seed_input)
-    rand_spk = chat.sample_random_speaker()
-    params_infer_code = {
-        'spk_emb': rand_spk,
-        'temperature': temperature,
-        'top_P': top_P,
-        'top_K': top_K,
-    }
-    params_refine_text = {'prompt': '[oral_2][laugh_0][break_6]'}
-
-    torch.manual_seed(text_seed_input)
-
-    if refine_text_flag:
-        text = chat.infer(text,
-                          skip_refine_text=False,
-                          refine_text_only=True,
-                          params_refine_text=params_refine_text,
-                          params_infer_code=params_infer_code
-                          )
-
-    wav = chat.infer(text,
-                     skip_refine_text=True,
-                     params_refine_text=params_refine_text,
-                     params_infer_code=params_infer_code
-                     )
-
-    audio_data = np.array(wav[0]).flatten()
-    sample_rate = 24000
-    text_data = text[0] if isinstance(text, list) else text
-
-    return [(sample_rate, audio_data), text_data]
-
+from examples.web.funcs import *
 
 def main():
 
     with gr.Blocks() as demo:
-        gr.Markdown("# ChatTTS Webui")
-        gr.Markdown("ChatTTS Model: [2noise/ChatTTS](https://github.com/2noise/ChatTTS)")
+        gr.Markdown("# ChatTTS WebUI")
+        gr.Markdown("- **GitHub Repo**: https://github.com/2noise/ChatTTS")
+        gr.Markdown("- **HuggingFace Repo**: https://huggingface.co/2Noise/ChatTTS")
 
-        default_text = "四川美食确实以辣闻名,但也有不辣的选择。[uv_break]比如甜水面、赖汤圆、蛋烘糕、叶儿粑等,这些小吃口味温和,甜而不腻,也很受欢迎。[laugh]"
+        default_text = "四川美食确实以辣闻名,但也有不辣的选择。比如甜水面、赖汤圆、蛋烘糕、叶儿粑等,这些小吃口味温和,甜而不腻,也很受欢迎。"
         text_input = gr.Textbox(label="Input Text", lines=4, placeholder="Please Input Text...", value=default_text)
 
        with gr.Row():
            refine_text_checkbox = gr.Checkbox(label="Refine text", value=True)
-            temperature_slider = gr.Slider(minimum=0.00001, maximum=1.0, step=0.00001, value=0.3, label="Audio temperature")
-            top_p_slider = gr.Slider(minimum=0.1, maximum=0.9, step=0.05, value=0.7, label="top_P")
-            top_k_slider = gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_K")
+            temperature_slider = gr.Slider(minimum=0.00001, maximum=1.0, step=0.00001, value=0.3, label="Audio temperature", interactive=True)
+            top_p_slider = gr.Slider(minimum=0.1, maximum=0.9, step=0.05, value=0.7, label="top_P", interactive=True)
+            top_k_slider = gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_K", interactive=True)
 
        with gr.Row():
-            voice_options = {}
            voice_selection = gr.Dropdown(label="音色", choices=voices.keys(), value='默认')
            audio_seed_input = gr.Number(value=2, label="Audio Seed")
            generate_audio_seed = gr.Button("\U0001F3B2")
            text_seed_input = gr.Number(value=42, label="Text Seed")
            generate_text_seed = gr.Button("\U0001F3B2")
 
-        generate_button = gr.Button("Generate")
+        with gr.Row():
+            auto_play_checkbox = gr.Checkbox(label="Auto Play", value=False, scale=1)
+            stream_mode_checkbox = gr.Checkbox(label="Stream Mode", value=False, scale=1)
+            generate_button = gr.Button("Generate", scale=2)
 
        text_output = gr.Textbox(label="Output Text", interactive=False)
-        audio_output = gr.Audio(label="Output Audio")
 
        # Use Gradio callbacks to update the seed number inputs
        voice_selection.change(fn=on_voice_change, inputs=voice_selection, outputs=audio_seed_input)
@@ -117,10 +55,25 @@ def main():
        generate_text_seed.click(generate_seed,
                                 inputs=[],
                                 outputs=text_seed_input)
-
-        generate_button.click(generate_audio,
-                              inputs=[text_input, temperature_slider, top_p_slider, top_k_slider, audio_seed_input, text_seed_input, refine_text_checkbox],
-                              outputs=[audio_output, text_output])
+
+        generate_button.click(fn=lambda: "", outputs=text_output)
+        generate_button.click(refine_text,
+                              inputs=[text_input, audio_seed_input, text_seed_input, refine_text_checkbox],
+                              outputs=text_output)
+
+        @gr.render(inputs=[auto_play_checkbox, stream_mode_checkbox])
+        def make_audio(autoplay, stream):
+            audio_output = gr.Audio(
+                label="Output Audio",
+                value=None,
+                autoplay=autoplay,
+                streaming=stream,
+                interactive=False,
+                show_label=True,
+            )
+            text_output.change(generate_audio,
+                               inputs=[text_output, temperature_slider, top_p_slider, top_k_slider, audio_seed_input, text_seed_input, stream_mode_checkbox],
+                               outputs=audio_output)
 
        gr.Examples(
            examples=[
@@ -138,15 +91,22 @@ def main():
    parser.add_argument('--custom_path', type=str, default=None, help='the custom model path')
    args = parser.parse_args()
 
-    print("loading ChatTTS model...")
+    logger.info("loading ChatTTS model...")
+
    global chat
-    chat = ChatTTS.Chat()
 
    if args.custom_path == None:
-        chat.load_models()
+        ret = chat.load_models()
    else:
-        print('local model path:', args.custom_path)
-        chat.load_models('custom', custom_path=args.custom_path)
+        logger.info('local model path: %s', args.custom_path)
+        ret = chat.load_models('custom', custom_path=args.custom_path)
+
+    if ret:
+        logger.info("Models loaded successfully.")
+    else:
+        logger.error("Models load failed.")
+        sys.exit(1)
+
 
    demo.launch(server_name=args.server_name, server_port=args.server_port, root_path=args.root_path, inbrowser=True)
 
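Because generate_audio is now a generator of (sample_rate, int16_chunk) tuples, the audio widget can be fed incrementally when Stream Mode is checked. A rough sketch of consuming the same generator outside Gradio, assuming the funcs module above is importable and the models are loaded; the output filename is arbitrary:

    import wave

    from examples.web.funcs import chat, generate_audio

    chat.load_models()

    with wave.open("out.wav", "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)   # 16-bit samples
        wf.setframerate(24000)
        # with stream=False the generator yields exactly once
        for sample_rate, chunk in generate_audio(
            "四川美食确实以辣闻名", temperature=0.3, top_P=0.7, top_K=20,
            audio_seed_input=2, text_seed_input=42, stream=True,
        ):
            wf.writeframes(chunk.tobytes())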
ChatTTS/requirements.txt CHANGED
@@ -1,14 +1,13 @@
 numpy<2.0.0
-omegaconf~=2.3.0
-torch~=2.1.0
+omegaconf>=2.3.0
+torch>=2.1.0
 tqdm
-einops
 vector_quantize_pytorch
-transformers~=4.41.1
+transformers>=4.41.1
 vocos
 IPython
 gradio
 python-dotenv
-pynini==2.1.5
-WeTextProcessing
-nemo_text_processing
+pynini==2.1.5; sys_platform == 'linux'
+WeTextProcessing; sys_platform == 'linux'
+nemo_text_processing; sys_platform == 'linux'
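
The trailing environment markers make pip skip the pynini/WeTextProcessing/nemo_text_processing stack everywhere but Linux, where its wheels actually build. Code relying on that stack should guard the import the same way; a small sketch (the fallback branch is an assumption, not part of this commit):

    import sys

    if sys.platform == "linux":
        # provided by WeTextProcessing, installed on Linux only per requirements.txt
        from tn.chinese.normalizer import Normalizer
        normalizer = Normalizer()
    else:
        normalizer = None  # fall back to ChatTTS's built-in normalization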
ChatTTS/setup.py CHANGED
@@ -6,7 +6,6 @@ setup(name='chattts',
       install_requires=['omegaconf>=2.3.0',
                         'torch>=2.1.0',
                         'tqdm',
-                        'einops',
                         'vector_quantize_pytorch',
                         'transformers>=4.41.1',
                         'vocos',
ChatTTS/tools/logger/__init__.py ADDED
@@ -0,0 +1 @@
+from .log import get_logger
ChatTTS/tools/logger/log.py ADDED
@@ -0,0 +1,53 @@
+import platform
+import logging
+from datetime import datetime, timezone
+
+# from https://github.com/FloatTech/ZeroBot-Plugin/blob/c70766a989698452e60e5e48fb2f802a2444330d/console/console_windows.go#L89-L96
+colorCodePanic = "\x1b[1;31m"
+colorCodeFatal = "\x1b[1;31m"
+colorCodeError = "\x1b[31m"
+colorCodeWarn = "\x1b[33m"
+colorCodeInfo = "\x1b[37m"
+colorCodeDebug = "\x1b[32m"
+colorCodeTrace = "\x1b[36m"
+colorReset = "\x1b[0m"
+
+log_level_color_code = {
+    logging.DEBUG: colorCodeDebug,
+    logging.INFO: colorCodeInfo,
+    logging.WARN: colorCodeWarn,
+    logging.ERROR: colorCodeError,
+    logging.FATAL: colorCodeFatal,
+}
+
+log_level_msg_str = {
+    logging.DEBUG: "DEBU",
+    logging.INFO: "INFO",
+    logging.WARN: "WARN",
+    logging.ERROR: "ERRO",
+    logging.FATAL: "FATL",
+}
+
+class Formatter(logging.Formatter):
+    def __init__(self, color=platform.system().lower() != "windows"):
+        # https://stackoverflow.com/questions/2720319/python-figure-out-local-timezone
+        self.tz = datetime.now(timezone.utc).astimezone().tzinfo
+        self.color = color
+
+    def format(self, record: logging.LogRecord):
+        logstr = "[" + datetime.now(self.tz).strftime('%z %Y%m%d %H:%M:%S') + "] ["
+        if self.color:
+            logstr += log_level_color_code.get(record.levelno, colorCodeInfo)
+        logstr += log_level_msg_str.get(record.levelno, record.levelname)
+        if self.color:
+            logstr += colorReset
+        logstr += f"] {str(record.name)} | {str(record.msg)}"
+        return logstr
+
+def get_logger(name: str, lv = logging.INFO):
+    logger = logging.getLogger(name)
+    syslog = logging.StreamHandler()
+    syslog.setFormatter(Formatter())
+    logger.setLevel(lv)
+    logger.addHandler(syslog)
+    return logger
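
A quick usage sketch for the new logger (the logger name is arbitrary; output format as produced by the Formatter above):

    from tools.logger import get_logger

    logger = get_logger("Demo")
    logger.info("models loaded")
    logger.warning("falling back to CPU")
    # emits lines like "[+0800 20240610 12:00:00] [INFO] Demo | models loaded",
    # colorized on non-Windows terminals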
abc ADDED
@@ -0,0 +1 @@
+Subproject commit f4c8329f0d231b272b676e5e171fb9655b345f2e
chattts_webui_mix.ipynb CHANGED
@@ -4,7 +4,9 @@
  "metadata": {
   "colab": {
    "provenance": [],
-   "gpuType": "T4"
+   "gpuType": "T4",
+   "authorship_tag": "ABX9TyPWzXw++IDXf5gvuBHiHqmz",
+   "include_colab_link": true
   },
   "kernelspec": {
    "name": "python3",
@@ -18,8 +20,42 @@
  "cells": [
   {
    "cell_type": "markdown",
+   "metadata": {
+    "id": "view-in-github",
+    "colab_type": "text"
+   },
+   "source": [
+    "<a href=\"https://colab.research.google.com/github/6drf21e/ChatTTS_colab/blob/main/chattts_webui_mix.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 260
+    },
+    "id": "-VNe1BeDO1n0",
+    "outputId": "f3ed0cc9-b8dd-4f2a-9cdd-3106e41f485d"
+   },
+   "outputs": [
+    {
+     "output_type": "display_data",
+     "data": {
+      "text/plain": [
+       "<IPython.core.display.Markdown object>"
+      ],
+      "text/markdown": "\n### 🌟 如果你觉得 ChatTTS 和 ChatTTS_colab 项目对你有帮助,请访问以下链接给它们点个星星吧!🌟\n\n- [ChatTTS 项目](https://github.com/2noise/ChatTTS)\n\n- [ChatTTS_colab 项目](https://github.com/6drf21e/ChatTTS_colab)\n\n感谢你的支持!\n\n### 运行方法 ###\n\n- 点击菜单栏的--代码执行程序--全部运行即可\n- 执行后在下方的日志中找到类似\n\n    Running on public URL: https://**********.gradio.live <-这个就是可以访问的公网地址\n\n安装包的时候提示要重启 请点**\"否\"**\n\n\n"
+     },
+     "metadata": {}
+    }
+   ],
    "source": [
-    "> 🌟 如果你觉得 ChatTTS 和 ChatTTS_colab 项目对你有帮助,请访问以下链接给它们点个星星吧!🌟\n",
+    "from IPython.display import display, Markdown\n",
+    "\n",
+    "message = \"\"\"\n",
+    "### 🌟 如果你觉得 ChatTTS 和 ChatTTS_colab 项目对你有帮助,请访问以下链接给它们点个星星吧!🌟\n",
     "\n",
     "- [ChatTTS 项目](https://github.com/2noise/ChatTTS)\n",
     "\n",
@@ -27,18 +63,19 @@
     "\n",
     "感谢你的支持!\n",
     "\n",
-    "# 运行方法\n",
+    "### 运行方法 ###\n",
     "\n",
     "- 点击菜单栏的--代码执行程序--全部运行即可\n",
     "- 执行后在下方的日志中找到类似\n",
     "\n",
     "    Running on public URL: https://**************.gradio.live <-这个就是可以访问的公网地址\n",
     "\n",
-    "安装包的时候提示要重启 请点**\"否\"**"
-   ],
-   "metadata": {
-    "id": "Xo3k5XsTzWK6"
-   }
+    "安装包的时候提示要重启 请点**\"否\"**\n",
+    "\n",
+    "\n",
+    "\"\"\"\n",
+    "display(Markdown(message))\n"
+   ]
  },
  {
   "cell_type": "code",
@@ -47,19 +84,58 @@
    "%cd ChatTTS_colab\n",
    "!git clone -q https://github.com/2noise/ChatTTS\n",
    "%cd ChatTTS\n",
-   "!git checkout -q e6412b1\n",
+   "!git checkout -q f4c8329\n",
    "%cd ..\n",
    "!mv ChatTTS abc\n",
-   "!mv abc/* /content/ChatTTS_colab/\n",
-   "!pip install -q omegaconf vocos vector_quantize_pytorch gradio cn2an pypinyin openai jieba WeTextProcessing python-dotenv\n",
+   "!mv abc/ChatTTS ./ChatTTS\n",
+   "!pip install -q omegaconf vocos vector_quantize_pytorch gradio cn2an pypinyin openai\n",
    "# 启动 Gradio 有公网地址\n",
    "!python webui_mix.py --share\n"
   ],
   "metadata": {
-   "id": "hNDl-5muR77-"
+   "colab": {
+    "base_uri": "https://localhost:8080/"
+   },
+   "id": "hNDl-5muR77-",
+   "outputId": "9ca99a78-1354-4c4d-dfa9-30a82b1a7813"
   },
   "execution_count": null,
-  "outputs": []
+  "outputs": [
+   {
+    "output_type": "stream",
+    "name": "stdout",
+    "text": [
+     "/content/ChatTTS_colab/ChatTTS_colab\n",
+     "/content/ChatTTS_colab/ChatTTS_colab/ChatTTS\n",
+     "/content/ChatTTS_colab/ChatTTS_colab\n",
+     "Loading ChatTTS model...\n",
+     "INFO:ChatTTS.core:Load from cache: /root/.cache/huggingface/hub/models--2Noise--ChatTTS/snapshots/ce5913842aebd78e4a01a02d47244b8d62ac4ee3\n",
+     "INFO:ChatTTS.core:use cuda:0\n",
+     "INFO:ChatTTS.core:vocos loaded.\n",
+     "INFO:ChatTTS.core:dvae loaded.\n",
+     "INFO:ChatTTS.core:gpt loaded.\n",
+     "INFO:ChatTTS.core:decoder loaded.\n",
+     "INFO:ChatTTS.core:tokenizer loaded.\n",
+     "INFO:ChatTTS.core:All initialized.\n",
+     "INFO:httpx:HTTP Request: GET https://api.gradio.app/pkg-version \"HTTP/1.1 200 OK\"\n",
+     "INFO:httpx:HTTP Request: GET https://checkip.amazonaws.com/ \"HTTP/1.1 200 \"\n",
+     "Running on local URL: http://127.0.0.1:7860\n",
+     "INFO:httpx:HTTP Request: GET http://127.0.0.1:7860/startup-events \"HTTP/1.1 200 OK\"\n",
+     "INFO:httpx:HTTP Request: HEAD http://127.0.0.1:7860/ \"HTTP/1.1 200 OK\"\n",
+     "INFO:httpx:HTTP Request: GET https://api.gradio.app/v2/tunnel-request \"HTTP/1.1 200 OK\"\n",
+     "Running on public URL: https://054d1298c1303e0370.gradio.live\n",
+     "\n",
+     "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n",
+ "[' 海底二万里 。第一章\\u3000飞走的暗礁。人们一定还记得一八六六年海上发生的一件离奇的、神秘的、无法解释的怪事。且不说当时哄动沿海居民和世界舆论的各种传闻 这里只说一般航海人员特别激动的心情。欧美的进出口商人、船长和船主、各国的海军官佐以及这两大洲的各国政府都非常注意这件事。', '这事大体是这样 不久以前 好些大船在海上碰见了一一个 庞然大物 一个很长的物体 形状很像纺锤 有时发出磷光 它的体积比鲸鱼大得多 行动起来也比鲸鱼快得多。关于这个东西的出现 许多航海日志所记下的事实 如这个东西或这个生物的形状 在它运动时的难以估计的速度 它转移的惊人力量 它那种像是天生的特殊本领等等 大致是相同的。', '如果这东西是鲸鱼类动物 那么它的体积 是大大超过了生物学家曾经加以分类的鲸鱼。居维埃·拉色别德①、杜梅里②、卡特法日③ 这些生物学家一一除非看见过 也就是说 除非这些科学家本人的眼睛看见过——是不承认有这样一种怪物存在的。把多次观察的结果折中一下来看———方面丢开那些过低的估计 即这个东西只有二百英尺长 同时也不接受过于夸张的言论 即它有一英里。', '宽三英里长 ——我们可以肯定他说 晋书·阮籍传 。 其后纲维不摄 而虚无放诞之论盈于 这个奇怪的生物 如果真是存在的话 它的体积是大大超过鱼类学家所承认的体积的。这东西既然存在 而事实本身又是不可否认的 那么 由于人类好奇的心理 我们就不难理解这个怪物的出现会在全世界引起怎样的骚动。', '至于说这是荒唐无稽之谈 那是决不会有人同意的。因为 一八六六年七月二十日 加尔各答一布纳希汽船公司的喜金孙总督号 在澳大利亚海岸东边五英里 碰见了这个游动的巨大物体。巴克船长起初还以为这是没有人知道的、暗礁 他正要测定它的位置的时候 突然这个不可解释的物体喷出两道水柱 哗的一声射到空中一百五十英尺高。', '这么说 除非这座暗礁上边有间歇喷泉 不然的话 喜金孙总督号面前的东西 就是还没有人知道的一种海中哺乳类动物 它还从鼻孔中喷出有气泡的水柱呢。同年七月二十三日 西印度 太平洋汽船公司的克利斯托巴尔哥郎号 在太平洋上也碰到这样的事。喜金孙总督号看见这怪物以后三天 克利斯托巴尔哥郎号在相距七百里的地方也看见了它 由此可知义 实用性则是鉴别它们正确与否的根据。', ' 这个奇特的鲸鱼类动物能以掠人的速度从这一处转移到另一处。十五天以后 在离上面说的地点有两千里远的地方 国营轮船公司的海尔维地亚号和皇家邮船公司的山农号 在美国和欧洲之间的大西洋海面上相遇的时候 在北纬四十二度十五分、西经六十度三十五分的地方 同时看到了这个大怪物。', '根据两船同时观察得到的结果 估计这只哺乳动物的长度至少有三百五十多英尺 约一百零六米 因为山农号和海尔维地亚号两船连起来 都还比它短 两船从头至尾只有一百米长。可是 最长的鲸鱼 像常常出役于阿留申群岛的久阑马克岛和翁居里克岛①附近海面的那些鲸鱼 也只不过是五十六米 而比这再长的 从来就没有过。', '接连不断地传来的消息 横渡大西洋的贝雷尔号所做的种种观察 茵曼轮船公司的越提那号跟这个怪物的一次相碰 法国二级军舰诺曼第号军官们所写的记录 海军高级参谋弗兹一詹姆斯在克利德爵士号上所做的很精密的测算 这一切在当时的确曾经哄动一时。在民族性比较浮躁的国家里 大家都拿这件事作为谈笑资料 但在严肃和踏实的国家里 像英国、美国和德国就不同 它们对这事就非常关心。', '在各大城市里 这怪物变成了家喻户晓的事件。咖啡馆里歌唱它 报刊上嘲笑它 舞台上扮演它。谣言正好有了机会 从这怪物身上捏造出各种各样的奇闻。在一些发行量不多的报刊上派 。 出现了关于各种离奇的巨大动物的报道 从白鲸、北极海中可怕的 莫比·狄克 ①一直到庞大的 克拉肯 ②——这种怪鱼的触须可以缠住一只载重五百吨的船而把它拖到海底下去——都应有尽有。', '有些人甚至不惜引经据典 或者搬出古代的传说如亚里士多德③和蒲林尼④的见解 他们承认这类怪物的存在 或者搬出彭土皮丹主教⑤的挪威童话 保罗·埃纪德的记述 以及哈林顿的报告 这报告是不容怀疑的 他说 一八五七年 他在嘉斯第兰号上看见过一种大蛇 那种蛇以前只在那立宪号到过的海面上⑤才能看见。', '于是 在学术团体里和科学报刊中产生了相信者和怀疑者 这两派人无休止地争论着。 怪物问题 激动着人们。自以为懂科学的新闻记者和一向自以为多才的文人开起火来 他们在这次值得纪念的笔战中花费了不少的墨水 。甚至有几个人还流了两三滴血 因为有人把针对大海蛇的笔锋移向一些态度傲慢的家伙身上了。', '在六个月当中 争论继续着。彼此有理 各执一词。当时流行的小报都兴致勃勃地刊登争论的文章 它们不是攻击巴西地理学院、柏林皇家科学院、不列颠学术联合会或华盛顿斯密孙学院发表的权威论文 就是驳斥印度群岛报、摩亚诺神父的宇宙杂志、皮德曼的消息报里面的讨论和法国及其他各国大报刊的科学新闻。', '这些多才的作家故意曲解反对派也常引证的林奈①的一句话 大自然不制造蠢东西 恳求大家不要相信北海的大怪鱼、大海蛇、 莫比·狄克 和疯狂的海员们臆造出来的其它怪物的存在 不要因此而否定了大自然。最后 某一著名尖刻的讽刺报有一位最受欢迎的编辑先生草草了事地发表一篇文章物主义的一些基本范畴和基本原理。', '强调马克思主义哲学必 处理了这个怪物 他像夷包列提②那样 在大家的笑声中 给这佳物最后一次打击、把它结果了。于是机智战胜了科学。在一八六七年头几个月里 这个问题好像是人了土 不会再复洁了。但就在这个时候 人们又听说发生了一些新的事件。', '现在的问题并不是一个急待解决的科学问题 而是必须认真设法避免的一个危险。问题带了完全不同的面貌。这个怪物变成了小岛、岩石、暗礁 但它是会奔驰的、不可捉摸的、行动莫测的暗礁。一八六七年八月五日 蒙特利奥航海公司的摩拉维安号夜间驶到北纬二十七度三十分、西经七十二度十五分的地方 船右舷撞上了一座岩石 可是 任何地图也没有记载过这一带海面上有这座岩石。', '由于风力的助航和四百匹马力的推动 船的速度达到每小时十三海里。毫无疑问 如果不是船身质地优良 特别坚固 摩拉维安号被撞以后 一定要把它从加拿大载来的二百三十六名乘客一齐带到海底去。事故发生在早晨五点左右天刚破晓的时候。船上值班的海员们立即跑到船的后部 他们十分细心地观察海面。', '除了有个六百多米宽的大漩涡——好像水面受过猛烈的冲击——以外 他们什么也没有看见 只把事故发生的地点确切地记了下来。摩拉维安号继续航行 似乎并没有受到什么损伤。·它是撞上了暗礁呢 还是撞上了一只沉没的破船?。当时没有法子知道。后来到船坞检查了船底朋友?。', '这个问题是革命的首要问题。 运用马克思主义的立尝 才发现一部分龙骨折断了。这事实本身是十分严重的 可是 如果不是过了三个星期后 在相同的情况下又发生了相同的事件 它很可能跟许多其他的事件一样很快被人忘掉了。接着又发生的那一次撞船的事件 单单由于受害船的国籍和它所属公司的声望 就足以引起十分广泛的反响。', '英国著名的船主苟纳尔的名字是没有一个人不知道伪。这位精明的企业家早在一八四零年就创办了一家邮船公司 开辟了从利物浦到哈利法克斯①的航线 当时只有三艘四百匹马力、载重一千一百六十二吨的明轮木船。八年以后 公司扩大了 共有四艘六百五十匹马力、载重一千八百二十吨的船。', '再过两年 又添了两艘马力和载重量更大的船 一八五三年 苟纳尔公司继续取得装运政府邮件的特权 一连添造了阿拉伯号、波斯号、中国号、斯备脱亚号、爪哇号、俄罗斯号 这些都是头等的快船 而且是最宽大的 除了大东方号外 在海上航行的船没有能跟它们相比的。', '到一八六七年 这家公司一共有十二艘船~八艘明轮的 四艘暗轮的。我所以要把上面的情形简单地介绍一下 是要大家知道这家海运公司的重要性。它由于经营得法 是全世界都闻名的。任何航海企业 没有比这公司搞得更精明 经营得更成功的了。二十六年来学流派均是庸俗进化论的宣传者。', '实证主义者斯宾塞对其曾 苟纳尔公司的船在大西洋上航行了两千次 没有一次航行不达目的地 没有一次发生迟误 从没有遗失过一封信 损失过一个人或一只船。 因此 尽管法国竭力要抢它的生意 但是乘客们都一致愿意搭苟纳尔公司的船 这点从近年来官方的统计文献中就可以看出来。', '了解这情形以后 便没有人奇怪这家公司的一只汽船遭遇到意外事件会引起那么巨大的反响。一八六七年四月十三日 海很平静 风又是顺风 
斯备脱亚号在西经十五度十二分、北纬四十五度三十七分的海面上行驶着。它在一千匹马力的发动机推动下 速度为每小时十三海里半。', '它的机轮在海中转动 完全正常。它当时的吃水深度是六米七十厘米 排水量是六 六百八十五方米。下午四点十六分 乘客们正在大厅中吃点心的时候 在斯各脱亚号船尾、左舷机轮后面一点 似乎发生了轻微的撞击。斯各脱亚号不是撞上了什么 而是被什么撞上了。', '憧它的不是敲击的器械而是钻凿的器械。这次冲撞是十分轻微的 要不是管船舱的人员跑到甲板上来喊 船要沉了 船要沉了 。 也许船上的人谁也不会在意。旅客们起初十分惊慌 但船长安德生很快就使他们安稳下来。危险并不会立刻就发生。斯各脱亚号由防水板分为七大间 一点也不在乎个把漏洞。', '安德生船长立即跑到舱底下去。他查出第五间被海水浸人了 海水浸入十分快 证明漏洞相当大。好在这间里没有蒸汽炉 不然的话 炉火就要熄灭了。安德生船长吩咐马上停船 并且命令一个潜水员下水检查船身的损坏情形。一会儿 他知道船底有一个长两米的大洞。', '这样一个裂口是没法堵住的 斯各脱亚号尽管机轮有一半浸在水里 但也必须继续行驶。当时船离克利亚峡还有三百海里 等船驶进公司的码头 已经误了三天期 在这三天里 利物浦的人都为它惶惶不安。斯各脱亚号被架了起来 工程师们开始检查。他们眼睛所看见的情形连自己也不能相信。', '在船身吃水线下两米半的地方 露出一个很规则的等边三角形的缺口。铁皮上的伤痕十分整齐 、就是钻孔机也不能凿得这么准确 弄成这个裂口的锐利器械一定不是用普通的钢铁制的 因为 这家伙在以惊人的力量向前猛撞 凿穿了四厘米厚的铁皮以后、还能用一种很难做到的后退动作 使自己脱身逃走。', '最近这次事件的经过大致就是这样。结果这又一次使舆论哄动起来。从这时候起 所有从前原因不明的航海遇难事件 现在都算在这个怪物的账上了。这只离奇古怪的动物于是负起了所有船只沉没的责任。不幸的是船沉的数目相当大 按照统计年鉴的记载 包括帆船和汽船在内 每年的损失约有三千艘左右 至于因下落不明而断定失踪 的 每年的数目也不下两百艘 。不管有没有冤枉这怪物 人们都把船只失踪的原因算在它身上。由于它的存在 五大洲间的海上交通越来越危险了 大家都坚决要求不惜任何代价清除海上这条可怕盼鲸鱼怪。']\n",
+     "INFO:ChatTTS.core:All initialized.\n",
+     " 46% 175/384 [00:05<00:07, 29.18steps/s]\n",
+     " 73% 1501/2048 [01:18<00:28, 19.09steps/s]\n",
+     "INFO:ChatTTS.core:All initialized.\n",
+     " 62% 238/384 [00:08<00:05, 28.48steps/s]\n",
+     " 36% 736/2048 [00:28<01:07, 19.51steps/s]"
+    ]
+   }
+  ]
  }
 ]
}
config.py CHANGED
@@ -1,15 +1,13 @@
 # Description: Configuration file for the project
-llama_seed = 2581
-DEFAULT_DIR = "output"
 DEFAULT_SPEED = 5
 DEFAULT_ORAL = 2
 DEFAULT_LAUGH = 0
 DEFAULT_BK = 4
 # paragraph segmentation
-DEFAULT_SEG_LENGTH = 80
-DEFAULT_BATCH_SIZE = 3
+DEFAULT_SEG_LENGTH = 120
+DEFAULT_BATCH_SIZE = 5
 # temperature
-DEFAULT_TEMPERATURE = 0.1
+DEFAULT_TEMPERATURE = 0.3
 # top_P
 DEFAULT_TOP_P = 0.7
 # top_K
@@ -43,4 +41,4 @@ LLM_PROMPT = """
 注意: character 字段的值需要使用类似 "旁白"、"年轻男性"、"年轻女性" 等角色身份。如果有多个角色,可以使用 "年轻男性1"、"年轻男性2" 等。
 
 --故事文本--
-"""
+"""
tts_model.py CHANGED
@@ -1,15 +1,12 @@
-import datetime
-import json
+import ChatTTS
+import torch
+import numpy as np
 import os
-import re
 import time
-
-import numpy as np
-import torch
 from tqdm import tqdm
-
-import ChatTTS
+import datetime
 from config import DEFAULT_TEMPERATURE, DEFAULT_TOP_P, DEFAULT_TOP_K
+import spaces
 
 
 def load_chat_tts_model(source='huggingface', force_redownload=False, local_path=None):
@@ -22,7 +19,7 @@ def load_chat_tts_model(source='huggingface', force_redownload=False, local_path
     """
     print("Loading ChatTTS model...")
     chat = ChatTTS.Chat()
-    chat.load_models(source=source, force_redownload=force_redownload, custom_path=local_path, compile=False)
+    chat.load_models(source=source, force_redownload=force_redownload, local_path=local_path)
     return chat
 
 
@@ -47,38 +44,19 @@ def deterministic(seed=0):
     torch.backends.cudnn.deterministic = True
     torch.backends.cudnn.benchmark = False
 
-
-def generate_audio_for_seed(chat, seed, texts, batch_size, speed, refine_text_prompt, roleid=None,
-                            temperature=DEFAULT_TEMPERATURE,
-                            top_P=DEFAULT_TOP_P, top_K=DEFAULT_TOP_K, cur_tqdm=None, skip_save=False,
-                            skip_refine_text=False, speaker_type="seed", pt_file=None):
+@spaces.GPU
+def generate_audio_for_seed(chat, seed, texts, batch_size, speed, refine_text_prompt, temperature=DEFAULT_TEMPERATURE,
+                            top_P=DEFAULT_TOP_P, top_K=DEFAULT_TOP_K, cur_tqdm=None, skip_save=False):
     from utils import combine_audio, save_audio, batch_split
-    print(f"speaker_type: {speaker_type}")
-    if speaker_type == "seed":
-        if seed in [None, -1, 0, "", "random"]:
-            seed = np.random.randint(0, 9999)
-        deterministic(seed)
-        rnd_spk_emb = chat.sample_random_speaker()
-    elif speaker_type == "role":
-        # Read the preset voices from the JSON file
-        with open('./slct_voice_240605.json', 'r', encoding='utf-8') as json_file:
-            slct_idx_loaded = json.load(json_file)
-        # Convert the stored lists back into Tensor objects
-        for key in slct_idx_loaded:
-            tensor_list = slct_idx_loaded[key]["tensor"]
-            slct_idx_loaded[key]["tensor"] = torch.tensor(tensor_list)
-        # Pack the voice tensor into params_infer_code to pin this voice; lower the temperature
-        rnd_spk_emb = slct_idx_loaded[roleid]["tensor"]
-        # temperature = 0.001
-    elif speaker_type == "pt":
-        print(pt_file)
-        rnd_spk_emb = torch.load(pt_file)
-        print(rnd_spk_emb.shape)
-        if rnd_spk_emb.shape != (768,):
-            raise ValueError("维度应为 768。")
-    else:
-        raise ValueError(f"Invalid speaker_type: {speaker_type}. ")
+    # torch.manual_seed(seed)
+    # top_P = 0.7,
+    # top_K = 20,
+    # temperature = 0.3,
+    if seed in [None, -1, 0, "", "random"]:
+        seed = np.random.randint(0, 9999)
 
+    deterministic(seed)
+    rnd_spk_emb = chat.sample_random_speaker()
     params_infer_code = {
         'spk_emb': rnd_spk_emb,
         'prompt': f'[speed_{speed}]',
@@ -99,16 +77,13 @@ def generate_audio_for_seed(chat, seed, texts, batch_size, speed, refine_text_pr
     if not cur_tqdm:
         cur_tqdm = tqdm
 
-    if re.search(r'\[uv_break\]|\[laugh\]', ''.join(texts)) is not None:
-        if not skip_refine_text:
-            print("Detected [uv_break] or [laugh] in text, skipping refine_text")
-        skip_refine_text = True
-
     for batch in cur_tqdm(batch_split(texts, batch_size), desc=f"Inferring audio for seed={seed}"):
         flag += len(batch)
-        _params_infer_code = {**params_infer_code}
-        wavs = chat.infer(batch, params_infer_code=_params_infer_code, params_refine_text=params_refine_text,
-                          use_decoder=True, skip_refine_text=skip_refine_text)
+        # refine_text = chat.infer(batch, params_infer_code=params_infer_code, params_refine_text=params_refine_text, refine_text_only=True)
+        # print(refine_text)
+        # exit()
+        wavs = chat.infer(batch, params_infer_code=params_infer_code, params_refine_text=params_refine_text,
+                          use_decoder=True, skip_refine_text=False)
         all_wavs.extend(wavs)
         clear_cuda_cache()
         if skip_save:
@@ -118,28 +93,9 @@ def generate_audio_for_seed(chat, seed, texts, batch_size, speed, refine_text_pr
         elapsed_time = end_time - start_time
         print(f"Saving audio for seed {seed}, took {elapsed_time:.2f}s")
         timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H%M%S')
-        wav_filename = f"chattts-[seed_{seed}][speed_{speed}]{refine_text_prompt}[{timestamp}].wav"
-        return save_audio(wav_filename, combined_audio)
-
-
-def generate_refine_text(chat, seed, text, refine_text_prompt, temperature=DEFAULT_TEMPERATURE,
-                         top_P=DEFAULT_TOP_P, top_K=DEFAULT_TOP_K):
-    if seed in [None, -1, 0, "", "random"]:
-        seed = np.random.randint(0, 9999)
-
-    deterministic(seed)
-
-    params_refine_text = {
-        'prompt': refine_text_prompt,
-        'top_P': top_P,
-        'top_K': top_K,
-        'temperature': temperature
-    }
-    print('params_refine_text:', text)
-    print('refine_text_prompt:', refine_text_prompt)
-    refine_text = chat.infer(text, params_refine_text=params_refine_text, refine_text_only=True, skip_refine_text=False)
-    print('refine_text:', refine_text)
-    return refine_text
+        wav_filename = f"long-[seed_{seed}][speed_{speed}]{refine_text_prompt}[{timestamp}].wav"
+        save_audio(wav_filename, combined_audio)
+        return wav_filename
 
 
 def tts(chat, text_file, seed, speed, oral, laugh, bk, seg, batch, progres=None):
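
The seed handling above is what makes voices reproducible: an empty or random seed is replaced by a fresh one, and the RNGs are reseeded via deterministic(seed) before the speaker embedding is sampled, so the same seed always yields the same voice. A minimal sketch of that pattern, assuming a loaded chat instance (the helper name pick_voice is ours, not the repo's):

    import numpy as np
    import torch

    def pick_voice(chat, seed=None):
        # a random/empty seed is replaced by a fresh one, mirroring the code above
        if seed in [None, -1, 0, "", "random"]:
            seed = np.random.randint(0, 9999)
        torch.manual_seed(seed)  # reseed before sampling the speaker
        spk_emb = chat.sample_random_speaker()
        return seed, spk_emb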
utils.py CHANGED
@@ -4,16 +4,9 @@ except ImportError:
     print("The 'cn2an' module is not installed. Please install it using 'pip install cn2an'.")
     exit(1)
 
-try:
-    import jieba
-except ImportError:
-    print("The 'jieba' module is not installed. Please install it using 'pip install jieba'.")
-    exit(1)
-
 import re
 import numpy as np
 import wave
-import jieba.posseg as pseg
 
 
 def save_audio(file_name, audio, rate=24000):
@@ -24,20 +17,13 @@ def save_audio(file_name, audio, rate=24000):
     :param rate:
     :return:
     """
-    import os
-    from config import DEFAULT_DIR
     audio = (audio * 32767).astype(np.int16)
 
-    # Make sure the default output directory exists
-    if not os.path.exists(DEFAULT_DIR):
-        os.makedirs(DEFAULT_DIR)
-    full_path = os.path.join(DEFAULT_DIR, file_name)
-    with wave.open(full_path, "w") as wf:
+    with wave.open(file_name, "w") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(rate)
        wf.writeframes(audio.tobytes())
-    return full_path
 
 
 def combine_audio(wavs):
@@ -101,32 +87,16 @@ def remove_chinese_punctuation(text):
     :param text:
     :return:
     """
-    chinese_punctuation_pattern = r"[:;!(),【】『』「」《》-‘“’”:,;!\(\)\[\]><\-·]"
-    text = re.sub(chinese_punctuation_pattern, '', text)
+    chinese_punctuation_pattern = r"[:;!(),【】『』「」《》-‘“’”:,;!\(\)\[\]><\-]"
+    text = re.sub(chinese_punctuation_pattern, ' ', text)
     # Collapse runs of consecutive full stops into a single one
-    text = re.sub(r'[。,]{2,}', '。', text)
-    # Strip leading and trailing commas
-    text = re.sub(r'^,|,$', '', text)
-    return text
-
-
-def remove_english_punctuation(text):
-    """
-    Replace the listed Chinese punctuation marks with ','
-    :param text:
-    :return:
-    """
-    chinese_punctuation_pattern = r"[:;!(),【】『』「」《》-‘“’”:,;!\(\)\[\]><\-·]"
-    text = re.sub(chinese_punctuation_pattern, ',', text)
-    # Collapse runs of consecutive full stops into a single one
-    text = re.sub(r'[,\.]{2,}', '.', text)
-    # Strip leading and trailing commas
-    text = re.sub(r'^,|,$', '', text)
+    text = re.sub(r'{2,}', '', text)
     return text
 
 
 def text_normalize(text):
     """
-    Normalize the text (PaddlePaddle version)
+    Normalize the text
     :param text:
     :return:
     """
@@ -134,7 +104,14 @@ def text_normalize(text):
     # ref: https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization
     tx = TextNormalizer()
     sentences = tx.normalize(text)
+    # print(sentences)
+
     _txt = ''.join(sentences)
+    # Strip every character that is not Chinese or basic Chinese punctuation
+    _txt = re.sub(
+        r"[^\u4e00-\u9fa5,。!?、]+", "", _txt
+    )
+
     return _txt
 
 
@@ -147,20 +124,6 @@ def convert_numbers_to_chinese(text):
     return cn2an.transform(text, "an2cn")
 
 
-def detect_language(sentence):
-    # ref: https://github.com/2noise/ChatTTS/blob/main/ChatTTS/utils/infer_utils.py#L55
-    chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]')
-    english_word_pattern = re.compile(r'\b[A-Za-z]+\b')
-
-    chinese_chars = chinese_char_pattern.findall(sentence)
-    english_words = english_word_pattern.findall(sentence)
-
-    if len(chinese_chars) > len(english_words):
-        return "zh"
-    else:
-        return "en"
-
-
 def split_text(text, min_length=60):
     """
     Split the text into chunks no shorter than min_length
@@ -168,63 +131,33 @@ def split_text(text, min_length=60):
     :param min_length:
     :return:
     """
-    # Sentence-level delimiters
-    sentence_delimiters = re.compile(r'([。?!\.]+)')
-    # Runs of consecutive newlines force a paragraph break
-    paragraph_delimiters = re.compile(r'(\s*\n\s*)+')
-
-    paragraphs = re.split(paragraph_delimiters, text)
-
+    sentence_delimiters = re.compile(r'([。?!\.\n]+)')
+    sentences = re.split(sentence_delimiters, text)
+    # print(sentences)
+    # exit()
     result = []
-
-    for paragraph in paragraphs:
-        if not paragraph.strip():
-            continue  # skip empty paragraphs
-        # Paragraphs below the threshold become chunks of their own
-        if len(paragraph.strip()) < min_length:
-            result.append(paragraph.strip())
-            continue
-        # Longer paragraphs are split further
-        sentences = re.split(sentence_delimiters, paragraph)
-        current_sentence = ''
-        for sentence in sentences:
-            if re.match(sentence_delimiters, sentence):
-                current_sentence += sentence.strip() + ''
-                if len(current_sentence) >= min_length:
-                    result.append(current_sentence.strip())
-                    current_sentence = ''
-            else:
-                current_sentence += sentence.strip()
-
-        if current_sentence:
-            if len(current_sentence) < min_length and len(result) > 0:
-                result[-1] += current_sentence
-            else:
-                result.append(current_sentence)
-    if detect_language(text[:1024]) == "zh":
-        result = [normalize_zh(_.strip()) for _ in result if _.strip()]
-    else:
-        result = [normalize_en(_.strip()) for _ in result if _.strip()]
+    current_sentence = ''
+    for sentence in sentences:
+        if re.match(sentence_delimiters, sentence):
+            current_sentence += sentence.strip() + '。'
+            if len(current_sentence) >= min_length:
+                result.append(current_sentence.strip())
+                current_sentence = ''
+        else:
+            current_sentence += sentence.strip()
+    if current_sentence:
+        if len(current_sentence) < min_length and len(result) > 0:
+            result[-1] += current_sentence
+        else:
+            result.append(current_sentence)
+    # result = [convert_numbers_to_chinese(remove_chinese_punctuation(_.strip())) for _ in result if _.strip()]
+    result = [normalize_zh(_.strip()) for _ in result if _.strip()]
     return result
 
 
-def normalize_en(text):
-    # Text is no longer normalized outside ChatTTS
-    # from tn.english.normalizer import Normalizer
-    # normalizer = Normalizer()
-    # text = normalizer.normalize(text)
-    # text = remove_english_punctuation(text)
-    return text
-
-
 def normalize_zh(text):
-    # Text is no longer normalized outside ChatTTS
-    # from tn.chinese.normalizer import Normalizer
-    # normalizer = Normalizer()
-    # text = normalizer.normalize(text)
-    # text = remove_chinese_punctuation(text)
-    text = process_ddd(text)
-    return text
+    # return text_normalize(remove_chinese_punctuation(text))
+    return convert_numbers_to_chinese(remove_chinese_punctuation(text))
 
 
 def batch_split(items, batch_size=5):
@@ -256,76 +189,11 @@ def read_long_text(file_path):
     raise ValueError("无法识别文件编码")
 
 
-def replace_tokens(text):
-    remove_tokens = ['UNK']
-    for token in remove_tokens:
-        text = re.sub(r'\[' + re.escape(token) + r'\]', '', text)
-
-    tokens = ['uv_break', 'laugh', 'lbreak']
-    for token in tokens:
-        text = re.sub(r'\[' + re.escape(token) + r'\]', f'uu{token}uu', text)
-    text = text.replace('_', '')
-    return text
-
-
-def restore_tokens(text):
-    tokens = ['uvbreak', 'laugh', 'UNK', 'lbreak']
-    for token in tokens:
-        text = re.sub(r'uu' + re.escape(token) + r'uu', f'[{token}]', text)
-    text = text.replace('[uvbreak]', '[uv_break]')
-    return text
-
-
-def process_ddd(text):
-    """
-    Normalize uses of the particles "地" and "得" to "的".
-    Rationale: 地/得 mostly appear around verbs and adjectives; this does not follow
-    the grammar strictly, because misuse is common in real text.
-    jieba's segmentation accuracy also means some cases slip through, e.g. 小红帽疑惑地问
-    :param text: input text
-    :return: processed text
-    """
-    word_list = [(word, flag) for word, flag in pseg.cut(text, use_paddle=False)]
-    # print(word_list)
-    processed_words = []
-    for i, (word, flag) in enumerate(word_list):
-        if word in ["地", "得"]:
-            # Check previous and next word's flag
-            # prev_flag = word_list[i - 1][1] if i > 0 else None
-            # next_flag = word_list[i + 1][1] if i + 1 < len(word_list) else None
-
-            # if prev_flag in ['v', 'a'] or next_flag in ['v', 'a']:
-            if flag in ['uv', 'ud']:
-                processed_words.append("的")
-            else:
-                processed_words.append(word)
-        else:
-            processed_words.append(word)
-
-    return ''.join(processed_words)
-
-
-def replace_space_between_chinese(text):
-    return re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff])', '', text)
-
-
 if __name__ == '__main__':
-    # txts = [
-    #     "快速地跑过红色的大门",
-    #     "笑得很开心,学得很好",
-    #     "小红帽疑惑地问?",
-    #     "大灰狼慌张地回答",
-    #     "哦,这是为了更好地听你说话。",
-    #     "大灰狼不耐烦地说:“为了更好地抱你。”",
-    #     "他跑得很快,工作做得非常认真,这是他努力地结果。得到",
-    # ]
-    # for txt in txts:
-    #     print(txt, '-->', process_ddd(txt))
-
     txts = [
        "电影中梁朝伟扮演的陈永仁的编号27149",
        "这块黄金重达324.75克 我们班的最高总分为583分",
        "12\~23 -1.5\~2",
-        "居维埃·拉色别德①、杜梅里②、卡特法日③,"
    ]
    for txt in txts:
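
The reverted split_text greedily accumulates sentences until each chunk reaches min_length, then runs normalize_zh over every chunk (numbers converted to Chinese, listed punctuation replaced by spaces). A small usage sketch with arbitrary example text:

    from utils import split_text

    chunks = split_text("第一句话。第二句话。第三句话比较长,用来凑足最小长度。", min_length=10)
    for c in chunks:
        print(len(c), c)  # each chunk is at least ~10 characters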
webui_mix.py CHANGED
@@ -1,7 +1,3 @@
1
- import os
2
- import sys
3
-
4
- sys.path.insert(0, os.getcwd())
5
  import argparse
6
  import re
7
  import time
@@ -10,13 +6,12 @@ import pandas
10
  import numpy as np
11
  from tqdm import tqdm
12
  import random
 
13
  import gradio as gr
14
  import json
15
- from utils import normalize_zh, batch_split, normalize_audio, combine_audio
16
- from tts_model import load_chat_tts_model, clear_cuda_cache, generate_audio_for_seed
17
- from config import DEFAULT_BATCH_SIZE, DEFAULT_SPEED, DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P, DEFAULT_ORAL, \
18
- DEFAULT_LAUGH, DEFAULT_BK, DEFAULT_SEG_LENGTH
19
- import torch
20
 
21
  parser = argparse.ArgumentParser(description="Gradio ChatTTS MIX")
22
  parser.add_argument("--source", type=str, default="huggingface", help="Model source: 'huggingface' or 'local'.")
@@ -45,31 +40,30 @@ if not os.path.exists(SAVED_SEEDS_FILE):
45
 
46
  chat = load_chat_tts_model(source=args.source, local_path=args.local_path)
47
  # chat = None
48
- # chat = load_chat_tts_model(source="local", local_path=r"models")
49
 
50
  # 抽卡的最大数量
51
  max_audio_components = 10
52
 
 
 
 
 
 
 
 
53
  # 加载
54
  def load_seeds():
55
  with open(SAVED_SEEDS_FILE, "r") as f:
56
  global saved_seeds
57
-
58
- seeds = json.load(f)
59
-
60
- # 兼容旧的 JSON 格式,添加 path 字段
61
- for seed in seeds:
62
- if 'path' not in seed:
63
- seed['path'] = None
64
-
65
- saved_seeds = seeds
66
  return saved_seeds
67
 
68
 
69
  def display_seeds():
70
  seeds = load_seeds()
71
  # 转换为 List[List] 的形式
72
- return [[i, s['seed'], s['name'], s['path']] for i, s in enumerate(seeds)]
73
 
74
 
75
  saved_seeds = load_seeds()
@@ -84,14 +78,13 @@ def save_seeds():
84
 
85
 
86
  # 添加 seed
87
- def add_seed(seed, name, audio_path, save=True):
88
  for s in saved_seeds:
89
  if s['seed'] == seed:
90
  return False
91
  saved_seeds.append({
92
  'seed': seed,
93
- 'name': name,
94
- 'path': audio_path
95
  })
96
  if save:
97
  save_seeds()
@@ -117,7 +110,7 @@ def delete_seed(seed, save=True):
117
  return True
118
  return False
119
 
120
-
121
  def generate_seeds(num_seeds, texts, tq):
122
  """
123
  生成随机音频种子并保存
@@ -136,7 +129,7 @@ def generate_seeds(num_seeds, texts, tq):
136
  for _ in tq(range(num_seeds), desc=f"随机音色生成中..."):
137
  seed = np.random.randint(0, 9999)
138
 
139
- filename = generate_audio_for_seed(chat, seed, texts, 1, 5, "[oral_2][laugh_0][break_4]", None, 0.3, 0.7, 20)
140
  seeds.append((filename, seed))
141
  clear_cuda_cache()
142
 
@@ -144,12 +137,11 @@

  # Save the selected audio seed
- def do_save_seed(seed, audio_path):
-     print(f"Saving seed {seed} to {audio_path}")
      seed = seed.replace('保存种子 ', '').strip()
      if not seed:
          return
-     add_seed(int(seed), seed, audio_path)
      gr.Info(f"Seed {seed} has been saved.")

@@ -181,24 +173,11 @@ def do_delete_seed(val):
      return display_seeds()


- # Play a saved seed's audio
- def do_play_seed(val):
-     # Extract the index from val by matching [(\d+)]
-     index = re.search(r'\[(\d+)\]', val)
-     if index:
-         index = int(index.group(1))
-         seed = saved_seeds[index]['seed']
-         audio_path = saved_seeds[index]['path']
-         if audio_path:
-             return gr.update(visible=True, value=audio_path)
-     return gr.update(visible=False, value=None)
-
-
  def seed_change_btn():
      global SELECTED_SEED_INDEX
      if SELECTED_SEED_INDEX == -1:
-         return ['删除', '试听']
-     return [f'删除 idx=[{SELECTED_SEED_INDEX[0]}]', f'试听 idx=[{SELECTED_SEED_INDEX[0]}]']


  def audio_interface(num_seeds, texts, progress=gr.Progress()):
@@ -215,26 +194,11 @@ def audio_interface(num_seeds, texts, progress=gr.Progress()):
      # Pad out the shortfall
      all_wavs = wavs + [None] * (max_audio_components - len(wavs))
      all_seeds = seeds + [''] * (max_audio_components - len(seeds))
-     return [item for pair in zip(all_wavs, all_seeds, all_wavs) for item in pair]
-
-
- # Keep the file paths of the seeds just generated
- audio_paths = [gr.State(value=None) for _ in range(max_audio_components)]
-
-
- def audio_interface_with_paths(num_seeds, texts, progress=gr.Progress()):
-     """
-     Like audio_interface, but also carrying each audio's path
-     """
-     results = audio_interface(num_seeds, texts, progress)
-     wavs = results[::2]  # extract the audio file paths
-     for i, wav in enumerate(wavs):
-         audio_paths[i].value = wav  # assign directly to the State components
-     return results


  def audio_interface_empty(num_seeds, texts, progress=gr.Progress(track_tqdm=True)):
-     return [None, "", None] * max_audio_components

  def update_audio_components(slider_value):
@@ -242,9 +206,8 @@ def update_audio_components(slider_value):
      k = int(slider_value)
      audios = [gr.Audio(visible=True)] * k + [gr.Audio(visible=False)] * (max_audio_components - k)
      tbs = [gr.Textbox(visible=True)] * k + [gr.Textbox(visible=False)] * (max_audio_components - k)
-     stats = [gr.State(value=None)] * max_audio_components
      print(f'k={k}, audios={len(audios)}')
-     return [item for pair in zip(audios, tbs, stats) for item in pair]

  def seed_change(evt: gr.SelectData):
@@ -253,11 +216,11 @@ def seed_change(evt: gr.SelectData):
253
  SELECTED_SEED_INDEX = evt.index
254
  return evt.index
255
 
256
-
257
  def generate_tts_audio(text_file, num_seeds, seed, speed, oral, laugh, bk, min_length, batch_size, temperature, top_P,
258
- top_K, roleid=None, refine_text=True, speaker_type="seed", pt_file=None, progress=gr.Progress()):
259
  from tts_model import generate_audio_for_seed
260
- from utils import split_text, replace_tokens, restore_tokens
261
  if seed in [0, -1, None]:
262
  seed = random.randint(1, 9999)
263
  content = ''
@@ -265,151 +228,19 @@ def generate_tts_audio(text_file, num_seeds, seed, speed, oral, laugh, bk, min_l
265
  content = ""
266
  elif isinstance(text_file, str):
267
  content = text_file
268
- # 将 [uv_break] [laugh] 替换为 _uv_break_ _laugh_ 处理后再还原
269
- content = replace_tokens(content)
270
  texts = split_text(content, min_length=min_length)
271
- for i, text in enumerate(texts):
272
- texts[i] = restore_tokens(text)
273
 
274
  if oral < 0 or oral > 9 or laugh < 0 or laugh > 2 or bk < 0 or bk > 7:
275
  raise ValueError("oral_(0-9), laugh_(0-2), break_(0-7) out of range")
276
 
277
  refine_text_prompt = f"[oral_{oral}][laugh_{laugh}][break_{bk}]"
278
  try:
279
- output_files = generate_audio_for_seed(
280
- chat=chat,
281
- seed=seed,
282
- texts=texts,
283
- batch_size=batch_size,
284
- speed=speed,
285
- refine_text_prompt=refine_text_prompt,
286
- roleid=roleid,
287
- temperature=temperature,
288
- top_P=top_P,
289
- top_K=top_K,
290
- cur_tqdm=progress.tqdm,
291
- skip_save=False,
292
- skip_refine_text=not refine_text,
293
- speaker_type=speaker_type,
294
- pt_file=pt_file,
295
- )
296
  return output_files
297
  except Exception as e:
298
- raise e
299
-
300
-
301
- def generate_tts_audio_stream(text_file, num_seeds, seed, speed, oral, laugh, bk, min_length, batch_size, temperature,
-                               top_P,
-                               top_K, roleid=None, refine_text=True, speaker_type="seed", pt_file=None,
-                               stream_mode="fake"):
-     from utils import split_text, replace_tokens, restore_tokens
-     from tts_model import deterministic
-     if seed in [0, -1, None]:
-         seed = random.randint(1, 9999)
-     content = ''
-     if os.path.isfile(text_file):
-         content = ""
-     elif isinstance(text_file, str):
-         content = text_file
-     # Swap [uv_break] [laugh] for _uv_break_ _laugh_ while processing, then restore them
-     content = replace_tokens(content)
-     # texts = [normalize_zh(_) for _ in content.split('\n') if _.strip()]
-     texts = split_text(content, min_length=min_length)
-
-     for i, text in enumerate(texts):
-         texts[i] = restore_tokens(text)
-
-     if oral < 0 or oral > 9 or laugh < 0 or laugh > 2 or bk < 0 or bk > 7:
-         raise ValueError("oral_(0-9), laugh_(0-2), break_(0-7) out of range")
-
-     refine_text_prompt = f"[oral_{oral}][laugh_{laugh}][break_{bk}]"
-
-     print(f"speaker_type: {speaker_type}")
-     if speaker_type == "seed":
-         if seed in [None, -1, 0, "", "random"]:
-             seed = np.random.randint(0, 9999)
-         deterministic(seed)
-         rnd_spk_emb = chat.sample_random_speaker()
-     elif speaker_type == "role":
-         # Read the data from the JSON file
-         with open('./slct_voice_240605.json', 'r', encoding='utf-8') as json_file:
-             slct_idx_loaded = json.load(json_file)
-         # Convert the parts holding tensor data back into Tensor objects
-         for key in slct_idx_loaded:
-             tensor_list = slct_idx_loaded[key]["tensor"]
-             slct_idx_loaded[key]["tensor"] = torch.tensor(tensor_list)
-         # Pack the voice tensor into params_infer_code so this voice is used throughout; lower the temperature
-         rnd_spk_emb = slct_idx_loaded[roleid]["tensor"]
-         # temperature = 0.001
-     elif speaker_type == "pt":
-         print(pt_file)
-         rnd_spk_emb = torch.load(pt_file)
-         print(rnd_spk_emb.shape)
-         if rnd_spk_emb.shape != (768,):
-             raise ValueError("维度应为 768。")
-     else:
-         raise ValueError(f"Invalid speaker_type: {speaker_type}. ")
-
-     params_infer_code = {
-         'spk_emb': rnd_spk_emb,
-         'prompt': f'[speed_{speed}]',
-         'top_P': top_P,
-         'top_K': top_K,
-         'temperature': temperature
-     }
-     params_refine_text = {
-         'prompt': refine_text_prompt,
-         'top_P': top_P,
-         'top_K': top_K,
-         'temperature': temperature
-     }
-
-     if stream_mode == "real":
-         for text in texts:
-             _params_infer_code = {**params_infer_code}
-             wavs_gen = chat.infer(text, params_infer_code=_params_infer_code, params_refine_text=params_refine_text,
-                                   use_decoder=True, skip_refine_text=True, stream=True)
-             for gen in wavs_gen:
-                 wavs = [np.array([[]])]
-                 wavs[0] = np.hstack([wavs[0], np.array(gen[0])])
-                 audio = wavs[0][0]
-                 yield 24000, normalize_audio(audio)
-
-             clear_cuda_cache()
-     else:
-         for text in batch_split(texts, batch_size):
-             _params_infer_code = {**params_infer_code}
-             wavs = chat.infer(text, params_infer_code=_params_infer_code, params_refine_text=params_refine_text,
-                               use_decoder=True, skip_refine_text=False, stream=False)
-             combined_audio = combine_audio(wavs)
-             yield 24000, combined_audio[0]
-
-
- def generate_refine(text_file, oral, laugh, bk, temperature, top_P, top_K, progress=gr.Progress()):
-     from tts_model import generate_refine_text
-     from utils import split_text, replace_tokens, restore_tokens, replace_space_between_chinese
-     seed = random.randint(1, 9999)
-     refine_text_prompt = f"[oral_{oral}][laugh_{laugh}][break_{bk}]"
-     content = ''
-     if os.path.isfile(text_file):
-         content = ""
-     elif isinstance(text_file, str):
-         content = text_file
-     if re.search(r'\[uv_break\]|\[laugh\]', content) is not None:
-         gr.Info("检测到 [uv_break] [laugh],不能重复 refine ")
-         # print("检测到 [uv_break] [laugh],不能重复 refine ")
-         return content
-     batch_size = 5
-
-     content = replace_tokens(content)
-     texts = split_text(content, min_length=120)
-     print(texts)
-     for i, text in enumerate(texts):
-         texts[i] = restore_tokens(text)
-     txts = []
-     for batch in progress.tqdm(batch_split(texts, batch_size), desc=f"Refine Text Please Wait ..."):
-         txts.extend(generate_refine_text(chat, seed, batch, refine_text_prompt, temperature, top_P, top_K))
-     return replace_space_between_chinese('\n\n'.join(txts))

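The removed streaming path distinguished "real" streaming, which yields each partial chunk as chat.infer produces it, from "fake" streaming, which synthesizes a whole batch and then yields it in one piece. Schematically, with infer_streaming and synthesize as hypothetical stand-ins for the chat.infer calls above:

    def stream_audio(texts, batch_size=4, mode="fake"):
        if mode == "real":
            for text in texts:
                for chunk in infer_streaming(text):   # hypothetical per-chunk generator
                    yield 24000, chunk                # (sample rate, partial audio)
        else:
            for group in batch_split(texts, batch_size):
                yield 24000, synthesize(group)        # hypothetical whole-batch synthesis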
  def generate_seed():
@@ -422,28 +253,10 @@ def generate_seed():

  def update_label(text):
      word_count = len(text)
-     return gr.update(label=f"朗读文本({word_count} 字)")
-
-
- def inser_token(text, btn):
-     if btn == "+笑声":
-         return gr.update(
-             value=text + "[laugh]"
-         )
-     elif btn == "+停顿":
-         return gr.update(
-             value=text + "[uv_break]"
-         )


  with gr.Blocks() as demo:
-     # Project link
-     gr.Markdown("""
-     <div style='text-align: center; font-size: 16px;'>
-     🌟 <a href='https://github.com/6drf21e/ChatTTS_colab'>项目地址 欢迎 start</a> 🌟
-     </div>
-     """)
-
      with gr.Tab("音色抽卡"):
          with gr.Row():
              with gr.Column(scale=1):
@@ -454,10 +267,6 @@ with gr.Blocks() as demo:
                  ]
                  # gr.Markdown("### 随机音色抽卡")
                  gr.Markdown("""
-                 免抽卡,直接找稳定音色👇
-
-                 [ModelScope ChatTTS Speaker(国内)](https://modelscope.cn/studios/ttwwwaa/ChatTTS_Speaker) | [HuggingFace ChatTTS Speaker(国外)](https://huggingface.co/spaces/taa/ChatTTS_Speaker)
-
                  在相同的 seed 和 温度等参数下,音色具有一定的一致性。点击下面的“随机音色生成”按钮将生成多个 seed。找到满意的音色后,点击音频下方“保存”按钮。
                  **注意:不同机器使用相同种子生成的音频音色可能不同,同一机器使用相同种子多次生成的音频音色也可能变化。**
                  """)
@@ -474,29 +283,21 @@ with gr.Blocks() as demo:
                  gr.Markdown("### 种子管理界面")
                  seed_list = gr.DataFrame(
                      label="种子列表",
-                     headers=["Index", "Seed", "Name", "Path"],
-                     datatype=["number", "number", "str", "str"],
                      interactive=True,
-                     col_count=(4, "fixed"),
-                     value=display_seeds
                  )
-
                  with gr.Row():
                      refresh_button = gr.Button("刷新")
                      save_button = gr.Button("保存")
                      del_button = gr.Button("删除")
-                     play_button = gr.Button("试听")
-
-                 with gr.Row():
-                     # Audio player for previously saved seeds
-                     audio_player = gr.Audio(label="播放已保存种子音频", visible=False)
-
                  # Wire the buttons to their handlers
                  refresh_button.click(display_seeds, outputs=seed_list)
-                 seed_list.select(seed_change).success(seed_change_btn, outputs=[del_button, play_button])
                  save_button.click(do_save_seeds, inputs=[seed_list], outputs=None)
                  del_button.click(do_delete_seed, inputs=del_button, outputs=seed_list)
-                 play_button.click(do_play_seed, inputs=play_button, outputs=audio_player)

              with gr.Column(scale=1):
                  audio_components = []
@@ -504,13 +305,12 @@ with gr.Blocks() as demo:
                      visible = i < num_seeds_default
                      a = gr.Audio(f"Audio {i}", visible=visible)
                      t = gr.Button(f"Seed", visible=visible)
-                     s = gr.State(value=None)
-                     t.click(do_save_seed, inputs=[t, s], outputs=None).success(display_seeds, outputs=seed_list)
                      audio_components.append(a)
                      audio_components.append(t)
-                     audio_components.append(s)

                  num_seeds.change(update_audio_components, inputs=num_seeds, outputs=audio_components)

                  # output = gr.Column()
                  # audio = gr.Audio(label="Output Audio")

@@ -530,136 +330,46 @@ with gr.Blocks() as demo:
                                                 placeholder="Please Input Text...", value=default_text)
                  # Call update_label whenever the textbox content changes
                  text_file_input.change(update_label, inputs=text_file_input, outputs=text_file_input)
-                 # Buttons that insert pause / laugh tokens
-                 with gr.Row():
-                     break_button = gr.Button("+停顿", variant="secondary")
-                     laugh_button = gr.Button("+笑声", variant="secondary")
-                     refine_button = gr.Button("Refine Text(预处理 加入停顿词、笑声等)", variant="secondary")

              with gr.Column():
                  gr.Markdown("### 配置参数")
                  with gr.Row():
-                     with gr.Column():
-                         gr.Markdown("音色选择")
-                         num_seeds_input = gr.Number(label="生成音频的数量", value=1, precision=0, visible=False)
-                         speaker_stat = gr.State(value="seed")
-                         tab_seed = gr.Tab(label="种子")
-                         with tab_seed:
-                             with gr.Row():
-                                 seed_input = gr.Number(label="指定种子", info="种子决定音色 0则随机", value=None,
-                                                        precision=0)
-                                 generate_audio_seed = gr.Button("\U0001F3B2")
-                         tab_roleid = gr.Tab(label="内置音色")
-                         with tab_roleid:
-                             roleid_input = gr.Dropdown(label="内置音色",
-                                                        choices=[("发姐", "1"),
-                                                                 ("纯情男大学生", "2"),
-                                                                 ("阳光开朗大男孩", "3"),
-                                                                 ("知心小姐姐", "4"),
-                                                                 ("电视台女主持", "5"),
-                                                                 ("魅力大叔", "6"),
-                                                                 ("优雅甜美", "7"),
-                                                                 ("贴心男宝2", "21"),
-                                                                 ("正式打工人", "8"),
-                                                                 ("贴心男宝1", "9")],
-                                                        value="1",
-                                                        info="选择音色后会覆盖种子。感谢 @QuantumDriver 提供音色")
-                         tab_pt = gr.Tab(label="上传.PT文件")
-                         with tab_pt:
-                             pt_input = gr.File(label="上传音色文件", file_types=[".pt"], height=100)

                  with gr.Row():
-                     style_select = gr.Radio(label="预设参数", info="语速部分可自行更改",
-                                             choices=["小说朗读", "对话", "中英混合", "默认"], value="默认",
-                                             interactive=True, )
-                 with gr.Row():
-                     # refine
-                     refine_text_input = gr.Checkbox(label="Refine",
-                                                     info="打开后会自动根据下方参数添加笑声/停顿等。关闭后可自行添加 [uv_break] [laugh] 或者点击下方 Refin按钮先行转换",
-                                                     value=True)
-                     speed_input = gr.Slider(label="语速", minimum=1, maximum=10, value=DEFAULT_SPEED, step=1)
-                 with gr.Row():
-                     oral_input = gr.Slider(label="口语化", minimum=0, maximum=9, value=DEFAULT_ORAL, step=1)
-                     laugh_input = gr.Slider(label="笑声", minimum=0, maximum=2, value=DEFAULT_LAUGH, step=1)
-                     bk_input = gr.Slider(label="停顿", minimum=0, maximum=7, value=DEFAULT_BK, step=1)
                  # gr.Markdown("### 文本参数")
                  with gr.Row():
-                     min_length_input = gr.Number(label="文本分段长度", info="大于这个数值进行分段",
-                                                  value=DEFAULT_SEG_LENGTH, precision=0)
-                     batch_size_input = gr.Number(label="批大小", info="越高越快 太高爆显存 4G推荐3 其他酌情",
-                                                  value=DEFAULT_BATCH_SIZE, precision=0)
                  with gr.Accordion("其他参数", open=False):
                      with gr.Row():
                          # temperature, top_P, top_K
-                         temperature_input = gr.Slider(label="温度", minimum=0.01, maximum=1.0, step=0.01,
-                                                       value=DEFAULT_TEMPERATURE)
-                         top_P_input = gr.Slider(label="top_P", minimum=0.1, maximum=0.9, step=0.05, value=DEFAULT_TOP_P)
-                         top_K_input = gr.Slider(label="top_K", minimum=1, maximum=20, step=1, value=DEFAULT_TOP_K)
                          # reset button
                          reset_button = gr.Button("重置")

                  with gr.Row():
-                     with gr.Column():
-                         generate_button = gr.Button("生成音频", variant="primary")
-                     with gr.Column():
-                         generate_button_stream = gr.Button("流式生成音频(一边播放一边推理)", variant="primary")
-                         stream_select = gr.Radio(label="流输出方式",
-                                                  info="真流式为实验功能,播放效果:卡播卡播卡播(⏳🎵⏳🎵⏳🎵);伪流式为分段推理后输出,播放效果:卡卡卡播播播播(⏳⏳🎵🎵🎵🎵)。伪流式批次建议4以上减少卡顿",
-                                                  choices=[("真", "real"), ("伪", "fake")], value="fake", interactive=True, )

                  with gr.Row():
                      output_audio = gr.Audio(label="生成的音频文件")
-                     output_audio_stream = gr.Audio(label="流式音频", value=None,
-                                                    streaming=True,
-                                                    autoplay=True,
-                                                    # disable auto play for Windows, due to https://developer.chrome.com/blog/autoplay#webaudio
-                                                    interactive=False,
-                                                    show_label=True)

      generate_audio_seed.click(generate_seed,
                                inputs=[],
                                outputs=seed_input)

-
-     def do_tab_change(evt: gr.SelectData):
-         print(evt.selected, evt.index, evt.value, evt.target)
-         kv = {
-             "种子": "seed",
-             "内置音色": "role",
-             "上传.PT文件": "pt"
-         }
-         return kv.get(evt.value, "seed")
-
-
-     tab_seed.select(do_tab_change, outputs=speaker_stat)
-     tab_roleid.select(do_tab_change, outputs=speaker_stat)
-     tab_pt.select(do_tab_change, outputs=speaker_stat)
-
-
-     def do_style_select(x):
-         if x == "小说朗读":
-             return [4, 0, 0, 2]
-         elif x == "对话":
-             return [5, 5, 1, 4]
-         elif x == "中英混合":
-             return [4, 1, 0, 3]
-         else:
-             return [DEFAULT_SPEED, DEFAULT_ORAL, DEFAULT_LAUGH, DEFAULT_BK]
-
-
-     # style_select change handler
-     style_select.change(
-         do_style_select,
-         inputs=style_select,
-         outputs=[speed_input, oral_input, laugh_input, bk_input]
-     )
-
-     # refine button
-     refine_button.click(
-         generate_refine,
-         inputs=[text_file_input, oral_input, laugh_input, bk_input, temperature_input, top_P_input, top_K_input],
-         outputs=text_file_input
-     )
      # Reset button: restores temperature and related params
      reset_button.click(
          lambda: [0.3, 0.7, 20],
@@ -682,50 +392,9 @@ with gr.Blocks() as demo:
              temperature_input,
              top_P_input,
              top_K_input,
-             roleid_input,
-             refine_text_input,
-             speaker_stat,
-             pt_input
          ],
          outputs=[output_audio]
      )
-
-     generate_button_stream.click(
-         fn=generate_tts_audio_stream,
-         inputs=[
-             text_file_input,
-             num_seeds_input,
-             seed_input,
-             speed_input,
-             oral_input,
-             laugh_input,
-             bk_input,
-             min_length_input,
-             batch_size_input,
-             temperature_input,
-             top_P_input,
-             top_K_input,
-             roleid_input,
-             refine_text_input,
-             speaker_stat,
-             pt_input,
-             stream_select
-         ],
-         outputs=[output_audio_stream]
-     )
-
-     break_button.click(
-         inser_token,
-         inputs=[text_file_input, break_button],
-         outputs=text_file_input
-     )
-
-     laugh_button.click(
-         inser_token,
-         inputs=[text_file_input, laugh_button],
-         outputs=text_file_input
-     )
-

  with gr.Tab("角色扮演"):
          def txt_2_script(text):
              lines = text.split("\n")
@@ -757,7 +426,7 @@ with gr.Blocks() as demo:
              characters = list([_["character"] for _ in lines])
              unique_characters = list(dict.fromkeys(characters))
              print([[character, 0] for character in unique_characters])
-             return [[character, 0, 5, 2, 0, 4] for character in unique_characters]


          def get_txt_characters(text):
@@ -784,7 +453,7 @@ with gr.Blocks() as demo:
              scripts = llm_operation(api_base, api_key, model, LLM_PROMPT, text, required_keys=["txt", "character"])
              return script_2_txt(scripts)

-
          def generate_script_audio(text, models_seeds, progress=gr.Progress()):
              scripts = txt_2_script(text)  # Convert the text into a script
              characters = get_characters(scripts)  # Extract the characters from the script
@@ -795,6 +464,7 @@ with gr.Blocks() as demo:
              import itertools
              from tts_model import generate_audio_for_seed
              from utils import combine_audio, save_audio, normalize_zh

              assert isinstance(models_seeds, pd.DataFrame)

@@ -807,40 +477,18 @@ with gr.Blocks() as demo:
                      break
                  yield batch

-             column_mapping = {
-                 '角色': 'character',
-                 '种子': 'seed',
-                 '语速': 'speed',
-                 '口语': 'oral',
-                 '笑声': 'laugh',
-                 '停顿': 'break'
-             }
-             # Rename the DataFrame columns via rename
-             models_seeds = models_seeds.rename(columns=column_mapping).to_dict(orient='records')
-             # models_seeds = models_seeds.to_dict(orient='records')

              # Check that every character has a corresponding seed
-             print(models_seeds)
-             seed_lookup = {seed['character']: seed for seed in models_seeds}
-
-             character_seeds = {}
-             missing_seeds = []
-             # Iterate over all characters
-             for character in characters:
-                 character_name = character[0]
-                 seed_info = seed_lookup.get(character_name)
-                 if seed_info:
-                     character_seeds[character_name] = seed_info
-                 else:
-                     missing_seeds.append(character_name)
-
-             if missing_seeds:
-                 missing_characters_str = ', '.join(missing_seeds)
-                 gr.Info(f"以下角色没有种子,请先设置种子:{missing_characters_str}")
-                 return None
-
-             print(character_seeds)
-             # return
              refine_text_prompt = "[oral_2][laugh_0][break_4]"
              all_wavs = []

@@ -854,21 +502,13 @@ with gr.Blocks() as demo:
              batch_size = 5  # batch size
              # Process per character
              for character, lines in progress.tqdm(grouped_lines.items(), desc="生成剧本音频"):
-                 info = character_seeds[character]
-                 seed = info["seed"]
-                 speed = info["speed"]
-                 orla = info["oral"]
-                 laugh = info["laugh"]
-                 bk = info["break"]
-
-                 refine_text_prompt = f"[oral_{orla}][laugh_{laugh}][break_{bk}]"
-
                  # Process in batches
                  for batch_lines in batch(lines, batch_size):
                      texts = [normalize_zh(line["txt"]) for line in batch_lines]
-                     print(f"seed={seed} t={texts} c={character} s={speed} r={refine_text_prompt}")
-                     wavs = generate_audio_for_seed(chat, int(seed), texts, DEFAULT_BATCH_SIZE, speed,
-                                                    refine_text_prompt, None, DEFAULT_TEMPERATURE, DEFAULT_TOP_P,
                                                     DEFAULT_TOP_K, skip_save=True)  # batch-process the texts
                      batch_results[character].extend(wavs)

@@ -880,7 +520,8 @@ with gr.Blocks() as demo:
              # Combine all audio
              audio = combine_audio(all_wavs)
              fname = f"script_{int(time.time())}.wav"
-             return save_audio(fname, audio)

  script_example = {
@@ -915,7 +556,7 @@ with gr.Blocks() as demo:
              "txt": "当小红帽到达奶奶家时,她发现大灰狼伪装成了奶奶。",
              "character": "旁白"
          }, {
-             "txt": "小红帽疑惑的问",
              "character": "旁白"
          }, {
              "txt": "奶奶,你的耳朵怎么这么尖?",
@@ -964,7 +605,7 @@ with gr.Blocks() as demo:
                                                 placeholder="请输入API Base URL",
                                                 value=r"https://api.openai.com/v1")
                  openai_api_key_input = gr.Textbox(label="OpenAI API Key", placeholder="请输入API Key",
-                                                   value="sk-xxxxxxx", type="password")
                  # AI prompt text
                  ai_text_input = gr.Textbox(label="剧情简介或者一段故事", placeholder="请输入文本...", lines=2,
                                             value=ai_text_default)
@@ -975,7 +616,7 @@ with gr.Blocks() as demo:
              with gr.Column(scale=3):
                  gr.Markdown("### 脚本")
                  gr.Markdown(
-                     "脚本可以手工编写也可以从左侧的AI脚本生成按钮生成。脚本格式 **角色::文本** 一行为一句” 注意是::")
                  script_text = "\n".join(
                      [f"{_.get('character', '')}::{_.get('txt', '')}" for _ in script_example['lines']])

@@ -987,20 +628,20 @@ with gr.Blocks() as demo:
              with gr.Column(scale=1):
                  gr.Markdown("### 角色种子")
                  # DataFrame holding the converted script
-                 # Default data [speed_5][oral_2][laugh_0][break_4]
                  default_data = [
-                     ["旁白", 2222, 3, 0, 0, 2],
-                     ["年轻女性", 2, 5, 2, 0, 2],
-                     ["中年男性", 2424, 5, 2, 0, 2]
                  ]

                  script_data = gr.DataFrame(
                      value=default_data,
                      label="角色对应的音色种子,从抽卡那获取",
-                     headers=["角色", "种子", "语速", "口语", "笑声", "停顿"],
-                     datatype=["str", "number", "number", "number", "number", "number"],
                      interactive=True,
-                     col_count=(6, "fixed"),
                  )
                  # Button to generate the audio
                  script_generate_audio = gr.Button("步骤②:生成音频")
@@ -1033,4 +674,4 @@ with gr.Blocks() as demo:
              outputs=[script_audio]
          )

- demo.launch(share=args.share, inbrowser=True)

  import argparse
  import re
  import time

  import numpy as np
  from tqdm import tqdm
  import random
+ import os
  import gradio as gr
  import json
+ from utils import combine_audio, save_audio, batch_split, normalize_zh
+ from tts_model import load_chat_tts_model, clear_cuda_cache, deterministic, generate_audio_for_seed
+ import spaces

  parser = argparse.ArgumentParser(description="Gradio ChatTTS MIX")
  parser.add_argument("--source", type=str, default="huggingface", help="Model source: 'huggingface' or 'local'.")
 
  chat = load_chat_tts_model(source=args.source, local_path=args.local_path)
  # chat = None
+ # chat = load_chat_tts_model(source="local", local_path="models")

  # Maximum number of voices per gacha draw
  max_audio_components = 10

+
+ # print("loading ChatTTS model...")
+ # chat = ChatTTS.Chat()
+ # chat.load_models(source="local", local_path="models")
+ # torch.cuda.empty_cache()
+
+
  # Load saved seeds
  def load_seeds():
      with open(SAVED_SEEDS_FILE, "r") as f:
          global saved_seeds
+         saved_seeds = json.load(f)
      return saved_seeds


  def display_seeds():
      seeds = load_seeds()
      # Convert to List[List] form
+     return [[i, s['seed'], s['name']] for i, s in enumerate(seeds)]


  saved_seeds = load_seeds()
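Given what add_seed writes below, SAVED_SEEDS_FILE is presumably a JSON list of seed/name objects, for example:

    # contents of SAVED_SEEDS_FILE (values are examples only)
    [{"seed": 2222, "name": "2222"}, {"seed": 2424, "name": "2424"}]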
 
  # Add a seed
+ def add_seed(seed, name, save=True):
      for s in saved_seeds:
          if s['seed'] == seed:
              return False
      saved_seeds.append({
          'seed': seed,
+         'name': name
      })
      if save:
          save_seeds()

          return True
      return False

+ @spaces.GPU
  def generate_seeds(num_seeds, texts, tq):
      """
      Generate random voice seeds and save them

      for _ in tq(range(num_seeds), desc=f"随机音色生成中..."):
          seed = np.random.randint(0, 9999)

+         filename = generate_audio_for_seed(chat, seed, texts, 1, 5, "[oral_2][laugh_0][break_4]", 0.3, 0.7, 20)
          seeds.append((filename, seed))
          clear_cuda_cache()

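The @spaces.GPU decorator added here comes from HuggingFace's spaces package (ZeroGPU): the Space holds a GPU only while the decorated function runs. The pattern in isolation:

    import spaces

    @spaces.GPU
    def heavy_inference(text):
        # CUDA work happens here; the GPU is released when the call returns
        ...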
 
  # Save the selected audio seed
+ def do_save_seed(seed):
      seed = seed.replace('保存种子 ', '').strip()
      if not seed:
          return
+     add_seed(int(seed), seed)
      gr.Info(f"Seed {seed} has been saved.")

      return display_seeds()


  def seed_change_btn():
      global SELECTED_SEED_INDEX
      if SELECTED_SEED_INDEX == -1:
+         return '删除'
+     return f'删除 idx=[{SELECTED_SEED_INDEX[0]}]'


  def audio_interface(num_seeds, texts, progress=gr.Progress()):
      # Pad out the shortfall
      all_wavs = wavs + [None] * (max_audio_components - len(wavs))
      all_seeds = seeds + [''] * (max_audio_components - len(seeds))
+     return [item for pair in zip(all_wavs, all_seeds) for item in pair]


  def audio_interface_empty(num_seeds, texts, progress=gr.Progress(track_tqdm=True)):
+     return [None, ""] * max_audio_components


  def update_audio_components(slider_value):
      k = int(slider_value)
      audios = [gr.Audio(visible=True)] * k + [gr.Audio(visible=False)] * (max_audio_components - k)
      tbs = [gr.Textbox(visible=True)] * k + [gr.Textbox(visible=False)] * (max_audio_components - k)
      print(f'k={k}, audios={len(audios)}')
+     return [item for pair in zip(audios, tbs) for item in pair]
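The zip-then-flatten idiom interleaves the parallel component lists into the flat order Gradio expects for outputs:

    audios = ['a0', 'a1']
    tbs = ['t0', 't1']
    print([item for pair in zip(audios, tbs) for item in pair])  # ['a0', 't0', 'a1', 't1']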


  def seed_change(evt: gr.SelectData):
      SELECTED_SEED_INDEX = evt.index
      return evt.index

+ @spaces.GPU
  def generate_tts_audio(text_file, num_seeds, seed, speed, oral, laugh, bk, min_length, batch_size, temperature, top_P,
+                        top_K, progress=gr.Progress()):
      from tts_model import generate_audio_for_seed
+     from utils import split_text
      if seed in [0, -1, None]:
          seed = random.randint(1, 9999)
      content = ''

          content = ""
      elif isinstance(text_file, str):
          content = text_file
      texts = split_text(content, min_length=min_length)
+     print(texts)

      if oral < 0 or oral > 9 or laugh < 0 or laugh > 2 or bk < 0 or bk > 7:
          raise ValueError("oral_(0-9), laugh_(0-2), break_(0-7) out of range")

      refine_text_prompt = f"[oral_{oral}][laugh_{laugh}][break_{bk}]"
      try:
+         output_files = generate_audio_for_seed(chat, seed, texts, batch_size, speed, refine_text_prompt, temperature,
+                                                top_P, top_K, progress.tqdm)
          return output_files
      except Exception as e:
+         return str(e)

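The refine_text_prompt built above is nothing more than three bracketed control tokens concatenated:

    oral, laugh, bk = 2, 0, 4
    print(f"[oral_{oral}][laugh_{laugh}][break_{bk}]")  # -> [oral_2][laugh_0][break_4]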
  def generate_seed():
 

  def update_label(text):
      word_count = len(text)
+     return gr.update(label=f"朗读文本(字数: {word_count})")


  with gr.Blocks() as demo:
      with gr.Tab("音色抽卡"):
          with gr.Row():
              with gr.Column(scale=1):
                  ]
                  # gr.Markdown("### 随机音色抽卡")
                  gr.Markdown("""
                  在相同的 seed 和 温度等参数下,音色具有一定的一致性。点击下面的“随机音色生成”按钮将生成多个 seed。找到满意的音色后,点击音频下方“保存”按钮。
                  **注意:不同机器使用相同种子生成的音频音色可能不同,同一机器使用相同种子多次生成的音频音色也可能变化。**
                  """)

                  gr.Markdown("### 种子管理界面")
                  seed_list = gr.DataFrame(
                      label="种子列表",
+                     headers=["Index", "Seed", "Name"],
+                     datatype=["number", "number", "str"],
                      interactive=True,
+                     col_count=(3, "fixed"),
+                     value=display_seeds()
                  )

                  with gr.Row():
                      refresh_button = gr.Button("刷新")
                      save_button = gr.Button("保存")
                      del_button = gr.Button("删除")

                  # Wire the buttons to their handlers
                  refresh_button.click(display_seeds, outputs=seed_list)
+                 seed_list.select(seed_change).success(seed_change_btn, outputs=[del_button])
                  save_button.click(do_save_seeds, inputs=[seed_list], outputs=None)
                  del_button.click(do_delete_seed, inputs=del_button, outputs=seed_list)

              with gr.Column(scale=1):
                  audio_components = []
                      visible = i < num_seeds_default
                      a = gr.Audio(f"Audio {i}", visible=visible)
                      t = gr.Button(f"Seed", visible=visible)
+                     t.click(do_save_seed, inputs=[t], outputs=None).success(display_seeds, outputs=seed_list)
                      audio_components.append(a)
                      audio_components.append(t)

                  num_seeds.change(update_audio_components, inputs=num_seeds, outputs=audio_components)
+
                  # output = gr.Column()
                  # audio = gr.Audio(label="Output Audio")

                                                 placeholder="Please Input Text...", value=default_text)
                  # Call update_label whenever the textbox content changes
                  text_file_input.change(update_label, inputs=text_file_input, outputs=text_file_input)

              with gr.Column():
                  gr.Markdown("### 配置参数")
+                 gr.Markdown("根据需要配置以下参数来生成音频。")
                  with gr.Row():
+                     num_seeds_input = gr.Number(label="生成音频的数量", value=1, precision=0, visible=False)
+                     seed_input = gr.Number(label="指定种子(留空则随机)", value=None, precision=0)
+                     generate_audio_seed = gr.Button("\U0001F3B2")

                  with gr.Row():
+                     speed_input = gr.Slider(label="语速", minimum=1, maximum=10, value=5, step=1)
+                     oral_input = gr.Slider(label="口语化", minimum=0, maximum=9, value=2, step=1)
+
+                     laugh_input = gr.Slider(label="笑声", minimum=0, maximum=2, value=0, step=1)
+                     bk_input = gr.Slider(label="停顿", minimum=0, maximum=7, value=4, step=1)

                  # gr.Markdown("### 文本参数")
                  with gr.Row():
+                     min_length_input = gr.Number(label="文本分段长度", info="大于这个数值进行分段", value=120,
+                                                  precision=0)
+                     batch_size_input = gr.Number(label="批大小", info="同时处理的批次 越高越快 太高爆显存", value=5,
+                                                  precision=0)
                  with gr.Accordion("其他参数", open=False):
                      with gr.Row():
                          # temperature, top_P, top_K
+                         temperature_input = gr.Slider(label="温度", minimum=0.01, maximum=1.0, step=0.01, value=0.3)
+                         top_P_input = gr.Slider(label="top_P", minimum=0.1, maximum=0.9, step=0.05, value=0.7)
+                         top_K_input = gr.Slider(label="top_K", minimum=1, maximum=20, step=1, value=20)

                  # reset button
                  reset_button = gr.Button("重置")

                  with gr.Row():
+                     generate_button = gr.Button("生成音频", variant="primary")

                  with gr.Row():
                      output_audio = gr.Audio(label="生成的音频文件")

      generate_audio_seed.click(generate_seed,
                                inputs=[],
                                outputs=seed_input)

      # Reset button: restores temperature and related params
      reset_button.click(
          lambda: [0.3, 0.7, 20],

              temperature_input,
              top_P_input,
              top_K_input,
          ],
          outputs=[output_audio]
      )

  with gr.Tab("角色扮演"):
          def txt_2_script(text):
              lines = text.split("\n")

              characters = list([_["character"] for _ in lines])
              unique_characters = list(dict.fromkeys(characters))
              print([[character, 0] for character in unique_characters])
+             return [[character, 0] for character in unique_characters]


          def get_txt_characters(text):

              scripts = llm_operation(api_base, api_key, model, LLM_PROMPT, text, required_keys=["txt", "character"])
              return script_2_txt(scripts)

+         @spaces.GPU
          def generate_script_audio(text, models_seeds, progress=gr.Progress()):
              scripts = txt_2_script(text)  # Convert the text into a script
              characters = get_characters(scripts)  # Extract the characters from the script

              import itertools
              from tts_model import generate_audio_for_seed
              from utils import combine_audio, save_audio, normalize_zh
+             from config import DEFAULT_BATCH_SIZE, DEFAULT_SPEED, DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P

              assert isinstance(models_seeds, pd.DataFrame)

                      break
                  yield batch

+             models_seeds = models_seeds.to_dict(orient='records')

              # Check that every character has a corresponding seed
+             for character, _ in characters:
+                 if not any(seed['Character'] == character for seed in models_seeds):
+                     gr.Info(f"角色 {character} 没有种子,请先设置种子。")
+                     return None
+
+             # Map each character to its seed
+             character_seeds = {character: [seed['Seed'] for seed in models_seeds if seed['Character'] == character][0]
+                                for character, _ in characters}
+             # todo: could be made configurable, ideally per character
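After to_dict(orient='records'), each DataFrame row becomes a plain dict keyed by the column headers, so the lookups above operate on data shaped like the defaults defined further down:

    models_seeds = [
        {'Character': '旁白', 'Seed': 2222},
        {'Character': '年轻女性', 'Seed': 2},
        {'Character': '中年男性', 'Seed': 2424},
    ]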

              refine_text_prompt = "[oral_2][laugh_0][break_4]"
              all_wavs = []

              batch_size = 5  # batch size
              # Process per character
              for character, lines in progress.tqdm(grouped_lines.items(), desc="生成剧本音频"):
+                 seed = character_seeds.get(character, 0)

                  # Process in batches
                  for batch_lines in batch(lines, batch_size):
                      texts = [normalize_zh(line["txt"]) for line in batch_lines]
+                     print(f"seed={seed} t={texts} c={character}")
+                     wavs = generate_audio_for_seed(chat, int(seed), texts, DEFAULT_BATCH_SIZE, DEFAULT_SPEED,
+                                                    refine_text_prompt, DEFAULT_TEMPERATURE, DEFAULT_TOP_P,
                                                     DEFAULT_TOP_K, skip_save=True)  # batch-process the texts
                      batch_results[character].extend(wavs)

              # Combine all audio
              audio = combine_audio(all_wavs)
              fname = f"script_{int(time.time())}.wav"
+             save_audio(fname, audio)
+             return fname
 

      script_example = {

          "txt": "当小红帽到达奶奶家时,她发现大灰狼伪装成了奶奶。",
          "character": "旁白"
      }, {
+         "txt": "小红帽疑惑地问",
          "character": "旁白"
      }, {
          "txt": "奶奶,你的耳朵怎么这么尖?",
 
                                                 placeholder="请输入API Base URL",
                                                 value=r"https://api.openai.com/v1")
                  openai_api_key_input = gr.Textbox(label="OpenAI API Key", placeholder="请输入API Key",
+                                                   value="sk-xxxxxxx")
                  # AI prompt text
                  ai_text_input = gr.Textbox(label="剧情简介或者一段故事", placeholder="请输入文本...", lines=2,
                                             value=ai_text_default)

              with gr.Column(scale=3):
                  gr.Markdown("### 脚本")
                  gr.Markdown(
+                     "脚本可以手工编写也可以从右侧的AI脚本生成按钮生成。脚本格式 **角色::文本** 一行为一句” 注意是::")
                  script_text = "\n".join(
                      [f"{_.get('character', '')}::{_.get('txt', '')}" for _ in script_example['lines']])

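The script format is one utterance per line, with a full-width double colon separating character from text; the example script therefore serializes into lines such as:

    旁白::当小红帽到达奶奶家时,她发现大灰狼伪装成了奶奶。
    旁白::小红帽疑惑地问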
 
              with gr.Column(scale=1):
                  gr.Markdown("### 角色种子")
                  # DataFrame holding the converted script
+                 # Default data
                  default_data = [
+                     ["旁白", 2222],
+                     ["年轻女性", 2],
+                     ["中年男性", 2424]
                  ]

                  script_data = gr.DataFrame(
                      value=default_data,
                      label="角色对应的音色种子,从抽卡那获取",
+                     headers=["Character", "Seed"],
+                     datatype=["str", "number"],
                      interactive=True,
+                     col_count=(2, "fixed"),
                  )
                  # Button to generate the audio
                  script_generate_audio = gr.Button("步骤②:生成音频")

              outputs=[script_audio]
          )

+ demo.launch(share=args.share)