Gregniuki committed on
Commit
c7c1bcf
·
verified ·
1 Parent(s): 48dfb69

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -31
app.py CHANGED
@@ -191,39 +191,30 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence,
191
  spectrograms = []
192
 
193
 
194
- for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
195
- if len(ref_text[-1].encode('utf-8')) == 1:
196
- ref_text = gen_text
197
-
198
- # Prepare the text
199
- #if len(ref_text[-1].encode('utf-8')) == 1:
200
- # ref_text = ref_text + " "
201
- # gen_text = gen_text
202
- text_list = [ref_text + gen_text]
203
- final_text_list = convert_char_to_pinyin(text_list)
204
-
205
- # Calculate duration
206
- # ref_audio_len = audio.shape[-1] // hop_length
207
- # zh_pause_punc = r"。,、;:?!,"
208
- # gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
209
- # duration = min(10, max(1, int(round(gen_text_len / (speed * 10)))))*100
210
-
211
- # Calculate duration
212
- print(f"ref len: {len(ref_text.encode('utf-8'))} chars")
213
- print(f"gen len: {len(gen_text.encode('utf-8'))} chars")
214
- ref_audio_len = audio.shape[-1] // hop_length
215
- zh_pause_punc = r"。,、;:?!"
216
- ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
217
- gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
218
- if len(ref_text[-1].encode('utf-8')) == 1:
219
- duration = min(2000, max(270, (ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed))))
220
- print(f"Duration: {duration} seconds")
221
- else:
222
- ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
223
- duration = min(2000, max(270, (ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed))))
224
 
225
- print(f"Duration: {duration} seconds")
 
 
 
 
 
 
 
226
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  # inference
228
  with torch.inference_mode():
229
  generated, _ = ema_model.sample(
 
191
  spectrograms = []
192
 
193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
+ for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
196
+ # If the last character of ref_text is a single UTF-8 byte (ASCII, e.g. punctuation), pad ref_text with a trailing space
197
+ if len(ref_text[-1].encode('utf-8')) == 1:
198
+ ref_text = ref_text + ' '
199
+
200
+ # Prepare the text for pinyin conversion
201
+ text_list = [ref_text + gen_text]
202
+ final_text_list = convert_char_to_pinyin(text_list)
203
 
204
+ # Calculate text lengths including punctuation-based adjustments
205
+ print(f"ref len: {len(ref_text.encode('utf-8'))} chars")
206
+ print(f"gen len: {len(gen_text.encode('utf-8'))} chars")
207
+
208
+ ref_audio_len = audio.shape[-1] // hop_length
209
+ zh_pause_punc = r"。,、;:?!"
210
+ ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
211
+ gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
212
+
213
+ # Calculate duration based on the lengths of ref_text and gen_text
214
+ duration = min(2000, max(270, (ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed))))
215
+
216
+ # Print the calculated duration
217
+ print(f"Duration: {duration} seconds")
218
  # inference
219
  with torch.inference_mode():
220
  generated, _ = ema_model.sample(