Gregniuki committed
Commit 05a4c54 · verified · 1 Parent(s): df9c347

Update app.py

Files changed (1): app.py (+8 -8)
app.py CHANGED
@@ -221,22 +221,22 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence,
     )
 
     # Process generated tensor
-    generated = generated[:, ref_audio_len:, :]
-    generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
+    generated = generated[:, ref_audio_len:, :]
+    generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
 
     # Convert to appropriate dtype and device
-    generated_mel_spec = generated_mel_spec.to(dtype=torch.float16, device=vocos.device)  # Ensure device matches vocos
-    generated_wave = vocos.decode(generated_mel_spec)
+    generated_mel_spec = generated_mel_spec.to(dtype=torch.float16, device=vocos.device)  # Ensure device matches vocos
+    generated_wave = vocos.decode(generated_mel_spec)
 
     # Adjust wave RMS if needed
-    if rms < target_rms:
-        generated_wave = generated_wave * rms / target_rms
+    if rms < target_rms:
+        generated_wave = generated_wave * rms / target_rms
 
     # Convert to numpy
-    generated_wave = generated_wave.squeeze().cpu().numpy()
+    generated_wave = generated_wave.squeeze().cpu().numpy()
 
     # Append to list
-    generated_waves.append(generated_wave)
-    spectrograms.append(generated_mel_spec[0].cpu().numpy())
+    generated_waves.append(generated_wave)
+    spectrograms.append(generated_mel_spec[0].cpu().numpy())
     # Ensure generated_mel_spec is in a compatible dtype (e.g., float32) before passing it to numpy
     # generated_mel_spec = generated_mel_spec.to(dtype=torch.float32)  # Convert to float32 if it's in bfloat16
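
For readers following along, the hunk above is the usual mel-to-waveform post-processing step in this kind of TTS app: trim off the frames that belong to the reference prompt, put the mel into (batch, n_mels, frames) order, decode it with Vocos, match loudness to the reference, and convert to numpy. Below is a minimal, self-contained sketch of that flow under some assumptions that are not confirmed by the diff: the pretrained charactr/vocos-mel-24khz checkpoint, 100 mel bins, and placeholder values for ref_audio_len, rms, and target_rms, all of which may differ in the actual app.py.

    # Minimal sketch of the post-processing in the changed hunk (assumptions noted above).
    import torch
    from einops import rearrange
    from vocos import Vocos

    vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")  # assumed vocoder checkpoint

    target_rms = 0.1     # placeholder loudness target
    rms = 0.05           # placeholder RMS measured on the reference audio
    ref_audio_len = 120  # placeholder: mel frames covered by the reference prompt

    # Placeholder model output shaped (batch=1, frames, n_mels), reference frames first.
    generated = torch.randn(1, ref_audio_len + 400, 100)

    # 1. Drop the frames that correspond to the reference prompt.
    generated = generated[:, ref_audio_len:, :]

    # 2. Vocos expects (batch, n_mels, frames).
    generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")

    # 3. Move the mel to the vocoder's device before decoding.
    device = next(vocos.parameters()).device
    generated_wave = vocos.decode(generated_mel_spec.to(device))

    # 4. Scale down if the reference was quieter than the target loudness.
    if rms < target_rms:
        generated_wave = generated_wave * rms / target_rms

    # 5. Convert to numpy; cast the mel to float32 first, since bfloat16 tensors
    #    cannot be handed to numpy directly (the concern in the trailing comments).
    generated_wave = generated_wave.squeeze().cpu().numpy()
    spectrogram = generated_mel_spec[0].float().cpu().numpy()

Note that the diff casts the mel to torch.float16 before vocos.decode; that only works cleanly if the vocoder itself runs in half precision, which is presumably why the commented-out float32 conversion was left in as a fallback. The sketch above stays in float32 so it runs unmodified on CPU.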