Update audiocraft/models/loaders.py
Browse files- audiocraft/models/loaders.py +11 -9
audiocraft/models/loaders.py
CHANGED
@@ -104,15 +104,17 @@ def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', depth='f
|
|
104 |
pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
|
105 |
cfg = OmegaConf.create(pkg['xp.cfg'])
|
106 |
cfg.device = str(device)
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
|
|
|
|
116 |
_delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
|
117 |
_delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
|
118 |
_delete_param(cfg, 'conditioners.args.drop_desc_p')
|
|
|
104 |
pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
|
105 |
cfg = OmegaConf.create(pkg['xp.cfg'])
|
106 |
cfg.device = str(device)
|
107 |
+
|
108 |
+
cfg.dtype = 'float32'
|
109 |
+
#if cfg.device == 'cpu':
|
110 |
+
# cfg.dtype = 'float32'
|
111 |
+
#else:
|
112 |
+
# if depth=='float32':
|
113 |
+
# cfg.dtype = 'float32'
|
114 |
+
# if depth=='bfloat16':
|
115 |
+
# cfg.dtype = 'bfloat16'
|
116 |
+
# if depth=='float16':
|
117 |
+
# cfg.dtype = 'float16'
|
118 |
_delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
|
119 |
_delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
|
120 |
_delete_param(cfg, 'conditioners.args.drop_desc_p')
|