Gregniuki commited on
Commit
8600d7b
·
verified ·
1 Parent(s): 266d24a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -21
app.py CHANGED
@@ -208,27 +208,35 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence,
208
  print(f"Duration: {duration} seconds")
209
  # inference
210
  with torch.inference_mode():
211
- generated, _ = ema_model.sample(
212
- cond=audio,
213
- text=final_text_list,
214
- duration=duration,
215
- steps=nfe_step,
216
- cfg_strength=cfg_strength,
217
- sway_sampling_coef=sway_sampling_coef,
218
- )
219
-
220
- generated = generated[:, ref_audio_len:, :]
221
- generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
222
- generated_mel_spec = generated_mel_spec.to(dtype=torch.float16) # Convert to float16
223
- generated_wave = vocos.decode(generated_mel_spec.cpu())
224
- if rms < target_rms:
225
- generated_wave = generated_wave * rms / target_rms
226
-
227
- # wav -> numpy
228
- generated_wave = generated_wave.squeeze().cpu().numpy()
229
-
230
- generated_waves.append(generated_wave)
231
- # spectrograms.append(generated_mel_spec[0].cpu().numpy())
 
 
 
 
 
 
 
 
232
  # Ensure generated_mel_spec is in a compatible dtype (e.g., float32) before passing it to numpy
233
  # generated_mel_spec = generated_mel_spec.to(dtype=torch.float32) # Convert to float32 if it's in bfloat16
234
 
 
208
  print(f"Duration: {duration} seconds")
209
  # inference
210
  with torch.inference_mode():
211
+ # Ensure all inputs are on the same device as ema_model
212
+ audio = audio.to(ema_model.device) # Match ema_model's device
213
+ final_text_list = [t.to(ema_model.device) if isinstance(t, torch.Tensor) else t for t in final_text_list]
214
+ generated, _ = ema_model.sample(
215
+ cond=audio,
216
+ text=final_text_list,
217
+ duration=duration,
218
+ steps=nfe_step,
219
+ cfg_strength=cfg_strength,
220
+ sway_sampling_coef=sway_sampling_coef,
221
+ )
222
+
223
+ # Process generated tensor
224
+ generated = generated[:, ref_audio_len:, :]
225
+ generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
226
+
227
+ # Convert to appropriate dtype and device
228
+ generated_mel_spec = generated_mel_spec.to(dtype=torch.float16, device=vocos.device) # Ensure device matches vocos
229
+ generated_wave = vocos.decode(generated_mel_spec)
230
+
231
+ # Adjust wave RMS if needed
232
+ if rms < target_rms:
233
+ generated_wave = generated_wave * rms / target_rms
234
+
235
+ # Convert to numpy
236
+ generated_wave = generated_wave.squeeze().cpu().numpy()
237
+
238
+ # Append to list
239
+ generated_waves.append(generated_wave)
+ spectrograms.append(generated_mel_spec[0].cpu().numpy())
240
  # Ensure generated_mel_spec is in a compatible dtype (e.g., float32) before passing it to numpy
241
  # generated_mel_spec = generated_mel_spec.to(dtype=torch.float32) # Convert to float32 if it's in bfloat16
242