ford442 commited on
Commit
792bc61
·
verified ·
1 Parent(s): 5d5e503

Update audiocraft/models/loaders.py

Browse files
Files changed (1) hide show
  1. audiocraft/models/loaders.py +2 -0
audiocraft/models/loaders.py CHANGED
@@ -121,8 +121,10 @@ def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', depth='f
121
  model = builders.get_lm_model(cfg)
122
  if depth=='bfloat16':
123
  model = model.to(torch.bfloat16)
 
124
  if depth=='float16':
125
  model = model.to(torch.float16)
 
126
  model.load_state_dict(pkg['best_state'])
127
  model.eval()
128
  model.cfg = cfg
 
121
  model = builders.get_lm_model(cfg)
122
  if depth=='bfloat16':
123
  model = model.to(torch.bfloat16)
124
+ cfg.dtype = 'bfloat16'
125
  if depth=='float16':
126
  model = model.to(torch.float16)
127
+ cfg.dtype = 'float16'
128
  model.load_state_dict(pkg['best_state'])
129
  model.eval()
130
  model.cfg = cfg