raoyonghui committed on
Commit
0faafc9
·
1 Parent(s): 6ec52a1

support long text synthesis

Browse files
Files changed (1) hide show
  1. app.py +115 -36
app.py CHANGED
@@ -45,6 +45,77 @@ def detect_speech_language(speech_file):
45
  _, probs = whisper_model.detect_language(mel)
46
  return max(probs, key=probs.get)
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  @torch.no_grad()
50
  def get_prompt_text(speech_16k, language):
@@ -320,43 +391,51 @@ def maskgct_inference(
320
  rescale_cfg_s2a=0.75,
321
  device=torch.device("cuda:0"),
322
  ):
323
- speech_16k = librosa.load(prompt_speech_path, sr=16000)[0]
324
- speech = librosa.load(prompt_speech_path, sr=24000)[0]
325
-
326
- prompt_language = detect_speech_language(prompt_speech_path)
327
- full_prompt_text, short_prompt_text, shot_prompt_end_ts = get_prompt_text(prompt_speech_path,
328
- prompt_language)
329
- # use the first 4+ seconds wav as the prompt in case the prompt wav is too long
330
- speech = speech[0: int(shot_prompt_end_ts * 24000)]
331
- speech_16k = speech_16k[0: int(shot_prompt_end_ts*16000)]
332
-
333
- target_language = detect_text_language(target_text)
334
- combine_semantic_code, _ = text2semantic(
335
- device,
336
- speech_16k,
337
- short_prompt_text,
338
- prompt_language,
339
- target_text,
340
- target_language,
341
- target_len,
342
- n_timesteps,
343
- cfg,
344
- rescale_cfg,
345
- )
346
- acoustic_code = extract_acoustic_code(torch.tensor(speech).unsqueeze(0).to(device))
347
- _, recovered_audio = semantic2acoustic(
348
- device,
349
- combine_semantic_code,
350
- acoustic_code,
351
- n_timesteps=n_timesteps_s2a,
352
- cfg=cfg_s2a,
353
- rescale_cfg=rescale_cfg_s2a,
354
- )
355
-
356
- return recovered_audio
 
 
 
 
 
 
 
 
357
 
358
 
359
- @spaces.GPU
360
  def inference(
361
  prompt_wav,
362
  target_text,
@@ -398,7 +477,7 @@ iface = gr.Interface(
398
  fn=inference,
399
  inputs=[
400
  gr.Audio(label="Upload Prompt Wav", type="filepath"),
401
- gr.Textbox(label="Target Text"),
402
  gr.Number(
403
  label="Target Duration (in seconds), if the target duration is less than 0, the system will estimate a duration.", value=-1
404
  ), # Removed 'optional=True'
 
45
  _, probs = whisper_model.detect_language(mel)
46
  return max(probs, key=probs.get)
47
 
48
def is_chinese(string):
    """Return True if *string* contains at least one Chinese character.

    A character counts as Chinese when its code point lies in the range
    U+4E00..U+9FFF (inclusive); all other characters are ignored.

    :param string: text to scan; may be empty
    :return: bool
    """
    return any(u'\u4e00' <= ch <= u'\u9fff' for ch in string)
57
+
58
def is_english(string):
    """Return True if *string* contains at least one alphabetic character.

    The test is ``str.isalpha()``, which is Unicode-aware: besides ASCII
    letters it also accepts letters of other scripts (accented Latin,
    Chinese characters, ...). Callers that need to tell Chinese apart from
    English must therefore check ``is_chinese`` first.

    :param string: text to scan; may be empty
    :return: bool
    """
    return any(ch.isalpha() for ch in string)
67
+
68
def preprocess(sentence):
    """Normalize the final punctuation of one sentence.

    Rules, decided by the last character of *sentence*:

    * ``!`` or ``?``: kept as the terminator;
    * a letter (``is_chinese`` or ``is_english`` is true): ``。`` is appended;
    * anything else (comma, space, closing quote, ...): replaced by ``。``.

    An empty string is returned unchanged instead of raising ``IndexError``.

    :param sentence: the sentence text
    :return: the sentence, ending in ``!``, ``?`` or ``。`` (or ``""``)
    """
    if not sentence:
        return sentence

    head, last = sentence[:-1], sentence[-1]

    # NOTE(review): as written, these two branches rewrite the final
    # character to the same character. Presumably they were meant to map
    # between half-width and full-width punctuation -- confirm upstream.
    if last == "!":
        return head + "!"
    if last == "?":
        return head + "?"

    if is_chinese(last) or is_english(last):
        # The sentence ends in a letter: keep it and add a full stop.
        return sentence + "。"

    # Any other trailing character is replaced by a full stop.
    return head + "。"
78
+
79
+
80
def split_paragraph(text):
    """Split *text* into chunks that are short enough to process one by one.

    The text is scanned character by character while a weighted length is
    accumulated (Chinese character: 1, other alphabetic character: 0.3,
    anything else: 0.6). Once the weighted length reaches 5, the current
    chunk is closed right after the current character when one of these
    holds:

    * the character is sentence-ending punctuation;
    * the length exceeds 40 and the character is a comma or a space;
    * the length exceeds 60 and the character is a closing bracket/quote;
    * the length exceeds 80 (forced split, whatever the character).

    Concatenating the returned chunks yields *text* again.

    :param text: paragraph to split; may be empty
    :return: list of non-empty chunks, in order
    """
    strong_puncts = ";!?。!?;…"
    medium_puncts = strong_puncts + ", ,"
    weak_puncts = medium_puncts + "」)》”’』])>\"']】 "

    min_len = 5
    medium_after = 40
    weak_after = 60
    force_after = 80

    sentences = []
    chunk = []
    weight = 0.0
    for char in text:
        chunk.append(char)

        # Test is_chinese first: is_english relies on str.isalpha(), which
        # is also true for Chinese characters, so the reverse order would
        # count every Chinese character as 0.3 instead of 1.
        if is_chinese(char):
            weight += 1
        elif is_english(char):
            weight += 0.3
        else:
            weight += 0.6

        # Never split before the chunk has reached the minimum length.
        if weight < min_len:
            continue

        if (
            char in strong_puncts
            or (weight > medium_after and char in medium_puncts)
            or (weight > weak_after and char in weak_puncts)
            or weight > force_after
        ):
            sentences.append("".join(chunk))
            chunk = []
            weight = 0.0

    # Keep whatever is left after the last split point.
    if chunk:
        sentences.append("".join(chunk))
    return sentences
118
+
119
 
120
  @torch.no_grad()
121
  def get_prompt_text(speech_16k, language):
 
391
  rescale_cfg_s2a=0.75,
392
  device=torch.device("cuda:0"),
393
  ):
394
+ sentences = split_paragraph(target_text)
395
+ total_recovered_audio = None
396
+ print("split_paragraph: before:", target_text, "\nafter:", sentences)
397
+ for sentence in sentences:
398
+ target_text = preprocess(sentence)
399
+ speech_16k = librosa.load(prompt_speech_path, sr=16000)[0]
400
+ speech = librosa.load(prompt_speech_path, sr=24000)[0]
401
+ prompt_language = detect_speech_language(prompt_speech_path)
402
+ full_prompt_text, short_prompt_text, shot_prompt_end_ts = get_prompt_text(prompt_speech_path,
403
+ prompt_language)
404
+ # use the first 4+ seconds wav as the prompt in case the prompt wav is too long
405
+ speech = speech[0: int(shot_prompt_end_ts * 24000)]
406
+ speech_16k = speech_16k[0: int(shot_prompt_end_ts*16000)]
407
+
408
+ target_language = detect_text_language(target_text)
409
+ combine_semantic_code, _ = text2semantic(
410
+ device,
411
+ speech_16k,
412
+ short_prompt_text,
413
+ prompt_language,
414
+ target_text,
415
+ target_language,
416
+ target_len,
417
+ n_timesteps,
418
+ cfg,
419
+ rescale_cfg,
420
+ )
421
+ acoustic_code = extract_acoustic_code(torch.tensor(speech).unsqueeze(0).to(device))
422
+ _, recovered_audio = semantic2acoustic(
423
+ device,
424
+ combine_semantic_code,
425
+ acoustic_code,
426
+ n_timesteps=n_timesteps_s2a,
427
+ cfg=cfg_s2a,
428
+ rescale_cfg=rescale_cfg_s2a,
429
+ )
430
+ print("finish text:", target_text)
431
+ if total_recovered_audio is None:
432
+ total_recovered_audio = recovered_audio
433
+ else:
434
+ total_recovered_audio = np.concatenate([total_recovered_audio, recovered_audio])
435
+ return total_recovered_audio
436
 
437
 
438
+ @spaces.GPU(duration=300)
439
  def inference(
440
  prompt_wav,
441
  target_text,
 
477
  fn=inference,
478
  inputs=[
479
  gr.Audio(label="Upload Prompt Wav", type="filepath"),
480
+ gr.Textbox(label="Target Text", max_length=1024),
481
  gr.Number(
482
  label="Target Duration (in seconds), if the target duration is less than 0, the system will estimate a duration.", value=-1
483
  ), # Removed 'optional=True'