don't use mask expansion for inference (#392)
Files changed:
- examples/llama-2/lora.yml (+1 -0)
- examples/llama-2/qlora.yml (+1 -0)
- src/axolotl/utils/models.py (+4 -2)
examples/llama-2/lora.yml

@@ -2,6 +2,7 @@ base_model: meta-llama/Llama-2-7b-hf
 base_model_config: meta-llama/Llama-2-7b-hf
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
+is_llama_derived_model: true
 
 load_in_8bit: true
 load_in_4bit: false
examples/llama-2/qlora.yml

@@ -2,6 +2,7 @@ base_model: meta-llama/Llama-2-7b-hf
 base_model_config: meta-llama/Llama-2-7b-hf
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
+is_llama_derived_model: true
 
 load_in_8bit: false
 load_in_4bit: true
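Both example configs now set is_llama_derived_model: true, the flag the updated gate in src/axolotl/utils/models.py checks before applying the mask-expansion patch. A minimal sketch to confirm the flag landed in the examples, assuming PyYAML is installed and the script runs from the repository root:

import yaml  # PyYAML

for path in ("examples/llama-2/lora.yml", "examples/llama-2/qlora.yml"):
    with open(path) as f:
        cfg = yaml.safe_load(f)
    # Both examples should now carry the flag the new gate checks.
    assert cfg["is_llama_derived_model"] is True, path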
src/axolotl/utils/models.py

@@ -138,8 +138,10 @@ def load_model(
         LOG.info("patching with xpos rope")
         replace_llama_rope_with_xpos_rope()
 
-    if cfg.is_llama_derived_model and (
-        cfg.max_packed_sequence_len or cfg.sample_packing
+    if (
+        cfg.is_llama_derived_model
+        and (cfg.max_packed_sequence_len or cfg.sample_packing)
+        and not cfg.inference
    ):
         from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
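The reworked condition has three legs: the model must be llama-derived, packing must be enabled (via max_packed_sequence_len or sample_packing), and, per the commit title, the run must not be inference, since the mask-expansion patch only matters for packed-sequence training. A minimal sketch of the gate as a standalone predicate, assuming a cfg object with the attributes shown in the diff; should_hijack_expand_mask is a hypothetical helper name, not part of the codebase:

def should_hijack_expand_mask(cfg) -> bool:
    # Patch transformers' mask expansion only for llama-derived models
    # that use sequence packing, and skip the patch entirely at inference.
    return bool(
        cfg.is_llama_derived_model
        and (cfg.max_packed_sequence_len or cfg.sample_packing)
        and not cfg.inference
    )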