Maxime tmm1 commited on
Commit
7fd662d
·
unverified ·
1 Parent(s): 9e69968

Update src/axolotl/utils/models.py

Browse files

Co-authored-by: Aman Gupta Karmani <[email protected]>

Files changed (1) hide show
  1. src/axolotl/utils/models.py +1 -1
src/axolotl/utils/models.py CHANGED
@@ -368,7 +368,7 @@ def load_model(
368
 
369
  # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
370
  # convert them back to fp16/bf16 for flash-attn compatibility.
371
- if (fix_dtype or not cfg.adapter) and (
372
  cfg.flash_attention and cfg.is_llama_derived_model
373
  ):
374
  for name, module in model.named_modules():
 
368
 
369
  # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
370
  # convert them back to fp16/bf16 for flash-attn compatibility.
371
+ if fix_dtype and (
372
  cfg.flash_attention and cfg.is_llama_derived_model
373
  ):
374
  for name, module in model.named_modules():