Gregniuki committed on
Commit a591b69 · verified · 1 parent: ef99879

Update infer/utils_infer.py

Files changed (1):
  1. infer/utils_infer.py +6 -6
infer/utils_infer.py CHANGED
@@ -116,7 +116,7 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
         vocoder.load_state_dict(state_dict)
 
         # Convert vocoder to bfloat16 if using a compatible device
-        vocoder = vocoder.eval().to(device).to(torch.bfloat16)
+        vocoder = vocoder.eval().to(device).to(torch.float32)
 
     elif vocoder_name == "bigvgan":
         try:
@@ -132,7 +132,7 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
             vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
 
         vocoder.remove_weight_norm()
-        vocoder = vocoder.eval().to(device).to(torch.bfloat16)  # Convert to bfloat16
+        vocoder = vocoder.eval().to(device).to(torch.float32)  # Convert to bfloat16
 
     return vocoder
 
@@ -147,7 +147,7 @@ def initialize_asr_pipeline(device: str = device, dtype=None):
         if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6:
             dtype = torch.float16
         elif "cpu" in device:
-            dtype = torch.bfloat16
+            dtype = torch.float32
         else:
             dtype = torch.float32
 
@@ -185,7 +185,7 @@ def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
         if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6:
             dtype = torch.float16
         elif "cpu" in device:
-            dtype = torch.bfloat16
+            dtype = torch.float32
         else:
             dtype = torch.float32
 
@@ -265,7 +265,7 @@ def load_model(
         vocab_char_map=vocab_char_map,
     ).to(device)
 
-    dtype = torch.bfloat16 if mel_spec_type == "bigvgan" else None
+    dtype = torch.float32 if mel_spec_type == "bigvgan" else None
     model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
 
     return model
@@ -471,7 +471,7 @@ def infer_batch_process(
                 sway_sampling_coef=sway_sampling_coef,
             )
 
-            generated = generated.to(torch.bfloat16)
+            generated = generated.to(torch.float32)
             generated = generated[:, ref_audio_len:, :]
             generated_mel_spec = generated.permute(0, 2, 1)
             if mel_spec_type == "vocos":
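
Net effect of the commit: every bfloat16 cast in infer/utils_infer.py becomes a float32 cast, and the CPU branch of the default-dtype selection no longer hands out bfloat16. The sketch below is not part of the commit; it merely restates the dtype rule that initialize_asr_pipeline and load_checkpoint follow after this change, with the helper name pick_default_dtype being hypothetical:

# Minimal sketch, assuming the dtype logic shown in the diff above.
# pick_default_dtype is a hypothetical helper, not a function in the repo.
import torch

def pick_default_dtype(device: str) -> torch.dtype:
    # CUDA devices with compute capability >= 6 (Pascal or newer): half precision.
    if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6:
        return torch.float16
    # CPU and any other backend: plain float32 (the old bfloat16 CPU branch is gone).
    return torch.float32

# The vocoder path is stricter still: after this commit it is always cast to float32,
# e.g. vocoder = vocoder.eval().to(device).to(torch.float32), regardless of device.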