ford442 commited on
Commit
5d5e503
·
verified ·
1 Parent(s): e81ccc6

Update audiocraft/models/loaders.py

Browse files
Files changed (1) hide show
  1. audiocraft/models/loaders.py +11 -9
audiocraft/models/loaders.py CHANGED
@@ -104,15 +104,17 @@ def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', depth='f
104
  pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
105
  cfg = OmegaConf.create(pkg['xp.cfg'])
106
  cfg.device = str(device)
107
- if cfg.device == 'cpu':
108
- cfg.dtype = 'float32'
109
- else:
110
- if depth=='float32':
111
- cfg.dtype = 'float32'
112
- if depth=='bfloat16':
113
- cfg.dtype = 'bfloat16'
114
- if depth=='float16':
115
- cfg.dtype = 'float16'
 
 
116
  _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
117
  _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
118
  _delete_param(cfg, 'conditioners.args.drop_desc_p')
 
104
  pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
105
  cfg = OmegaConf.create(pkg['xp.cfg'])
106
  cfg.device = str(device)
107
+
108
+ cfg.dtype = 'float32'
109
+ #if cfg.device == 'cpu':
110
+ # cfg.dtype = 'float32'
111
+ #else:
112
+ # if depth=='float32':
113
+ # cfg.dtype = 'float32'
114
+ # if depth=='bfloat16':
115
+ # cfg.dtype = 'bfloat16'
116
+ # if depth=='float16':
117
+ # cfg.dtype = 'float16'
118
  _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
119
  _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
120
  _delete_param(cfg, 'conditioners.args.drop_desc_p')