Update infer/utils_infer.py
infer/utils_infer.py CHANGED (+6 -6)
@@ -116,7 +116,7 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
         vocoder.load_state_dict(state_dict)

         # Convert vocoder to bfloat16 if using a compatible device
-        vocoder = vocoder.eval().to(device).to(torch.
+        vocoder = vocoder.eval().to(device).to(torch.float32)

     elif vocoder_name == "bigvgan":
         try:
@@ -132,7 +132,7 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
             vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)

         vocoder.remove_weight_norm()
-        vocoder = vocoder.eval().to(device).to(torch.
+        vocoder = vocoder.eval().to(device).to(torch.float32)  # Convert to bfloat16

     return vocoder

@@ -147,7 +147,7 @@ def initialize_asr_pipeline(device: str = device, dtype=None):
     if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6:
         dtype = torch.float16
     elif "cpu" in device:
-        dtype = torch.
+        dtype = torch.float32
     else:
         dtype = torch.float32

@@ -185,7 +185,7 @@ def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
     if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6:
         dtype = torch.float16
     elif "cpu" in device:
-        dtype = torch.
+        dtype = torch.float32
     else:
         dtype = torch.float32

@@ -265,7 +265,7 @@ def load_model(
         vocab_char_map=vocab_char_map,
     ).to(device)

-    dtype = torch.
+    dtype = torch.float32 if mel_spec_type == "bigvgan" else None
     model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)

     return model
@@ -471,7 +471,7 @@ def infer_batch_process(
                 sway_sampling_coef=sway_sampling_coef,
             )

-            generated = generated.to(torch.
+            generated = generated.to(torch.float32)
             generated = generated[:, ref_audio_len:, :]
             generated_mel_spec = generated.permute(0, 2, 1)
             if mel_spec_type == "vocos":