winglian committed · Commit bb991fd · 1 Parent(s): d653859

fix bug when model_type not explicitly passed

Files changed (1)
  1. src/axolotl/utils/models.py +1 -1
src/axolotl/utils/models.py CHANGED
@@ -35,7 +35,7 @@ def load_model(
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
     tokenizer = None
-    is_llama_derived_model = "llama" in base_model or "llama" in cfg.model_type.lower()
+    is_llama_derived_model = "llama" in base_model or (cfg.model_type and "llama" in cfg.model_type.lower())
 
     if is_llama_derived_model and cfg.flash_attention:
         if cfg.device not in ["mps", "cpu"] and inference is False:
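
For context, a minimal sketch of the failure mode this commit fixes (the Cfg class and model name below are hypothetical stand-ins, not axolotl code): when model_type is not explicitly passed, cfg.model_type resolves to None, so the old expression calls .lower() on None and raises AttributeError whenever base_model does not itself contain "llama". The fix adds a truthiness guard so Python's `or`/`and` short-circuiting stops before .lower() is reached.

    # Hypothetical stand-in for the config object; model_type resolves
    # to None when it is not explicitly set.
    class Cfg:
        model_type = None

    cfg = Cfg()
    base_model = "EleutherAI/gpt-neox-20b"  # no "llama" in the name

    # Old expression: the left operand is False, so Python evaluates the
    # right operand and calls .lower() on None -> AttributeError.
    try:
        is_llama = "llama" in base_model or "llama" in cfg.model_type.lower()
    except AttributeError as err:
        print(f"old expression fails: {err}")

    # Fixed expression: (cfg.model_type and ...) short-circuits on None,
    # so a missing model_type yields a falsy result instead of an error.
    is_llama = "llama" in base_model or (cfg.model_type and "llama" in cfg.model_type.lower())
    print(bool(is_llama))  # False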