Update audiocraft/models/loaders.py
Browse files
audiocraft/models/loaders.py
CHANGED
@@ -121,8 +121,10 @@ def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', depth='f
|
|
121 |
model = builders.get_lm_model(cfg)
|
122 |
if depth=='bfloat16':
|
123 |
model = model.to(torch.bfloat16)
|
|
|
124 |
if depth=='float16':
|
125 |
model = model.to(torch.float16)
|
|
|
126 |
model.load_state_dict(pkg['best_state'])
|
127 |
model.eval()
|
128 |
model.cfg = cfg
|
|
|
121 |
model = builders.get_lm_model(cfg)
|
122 |
if depth=='bfloat16':
|
123 |
model = model.to(torch.bfloat16)
|
124 |
+
cfg.dtype = 'bfloat16'
|
125 |
if depth=='float16':
|
126 |
model = model.to(torch.float16)
|
127 |
+
cfg.dtype = 'float16'
|
128 |
model.load_state_dict(pkg['best_state'])
|
129 |
model.eval()
|
130 |
model.cfg = cfg
|