Update app.py
app.py CHANGED
@@ -208,27 +208,35 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence,
     print(f"Duration: {duration} seconds")
     # inference
     with torch.inference_mode():
-        [21 removed lines (old 211-231); original content not captured in this view]
+        # Ensure all inputs are on the same device as ema_model
+        audio = audio.to(ema_model.device)  # Match ema_model's device
+        final_text_list = [t.to(ema_model.device) if isinstance(t, torch.Tensor) else t for t in final_text_list]
+        generated, _ = ema_model.sample(
+            cond=audio,
+            text=final_text_list,
+            duration=duration,
+            steps=nfe_step,
+            cfg_strength=cfg_strength,
+            sway_sampling_coef=sway_sampling_coef,
+        )
+
+        # Process generated tensor
+        generated = generated[:, ref_audio_len:, :]
+        generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
+
+        # Convert to appropriate dtype and device
+        generated_mel_spec = generated_mel_spec.to(dtype=torch.float16, device=vocos.device)  # Ensure device matches vocos
+        generated_wave = vocos.decode(generated_mel_spec)
+
+        # Adjust wave RMS if needed
+        if rms < target_rms:
+            generated_wave = generated_wave * rms / target_rms
+
+        # Convert to numpy
+        generated_wave = generated_wave.squeeze().cpu().numpy()
+
+        # Append to lists
+        generated_waves.append(generated_wave)
+        spectrograms.append(generated_mel_spec[0].cpu().numpy())
         # Ensure generated_mel_spec is in a compatible dtype (e.g., float32) before passing it to numpy
         # generated_mel_spec = generated_mel_spec.to(dtype=torch.float32) # Convert to float32 if it's in bfloat16
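The core of this change is keeping every tensor on a consistent device and in a numpy-compatible dtype: the conditioning audio and any text tensors are moved to ema_model's device before sampling, and the mel spectrogram is cast and moved to match vocos before decoding and the final .cpu().numpy() call. The sketch below isolates that pattern with stand-in tensors; the helper names (align_to_model_device, mel_to_numpy) and the toy shapes are illustrative only and are not part of app.py.

import torch

def align_to_model_device(audio: torch.Tensor, model_device: torch.device) -> torch.Tensor:
    # Move conditioning audio onto the model's device, mirroring
    # `audio = audio.to(ema_model.device)` in the diff.
    return audio.to(model_device)

def mel_to_numpy(mel: torch.Tensor):
    # numpy has no bfloat16, so cast to float32 first; float16/float32
    # tensors pass through .cpu().numpy() without issue.
    if mel.dtype == torch.bfloat16:
        mel = mel.to(torch.float32)
    return mel.squeeze(0).cpu().numpy()

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    audio = torch.randn(1, 24000)                          # stand-in conditioning audio
    mel = torch.randn(1, 100, 256, dtype=torch.bfloat16)   # stand-in mel spectrogram

    audio = align_to_model_device(audio, device)
    mel_np = mel_to_numpy(mel)
    print(audio.device, mel_np.dtype)  # mel_np is float32, safe for numpy consumers

This also explains the trailing commented-out hint in the diff: calling .numpy() on a bfloat16 tensor fails because numpy has no bfloat16 type, whereas the float16 mel used for vocos.decode converts cleanly.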