fix relative path for fixtures
src/axolotl/utils/models.py

@@ -129,6 +129,7 @@ def load_model(
         llm_int8_threshold=6.0,
         llm_int8_has_fp16_weight=False,
         bnb_4bit_compute_dtype=torch_dtype,
+        bnb_4bit_compute_dtype=torch_dtype,
         bnb_4bit_use_double_quant=True,
         bnb_4bit_quant_type="nf4",
     )
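For context, this block builds the transformers BitsAndBytesConfig that controls 4-bit (QLoRA-style) loading. A minimal sketch of how such a config is typically constructed and passed to from_pretrained; the model id and the bfloat16 compute dtype are illustrative assumptions, not taken from this commit:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

torch_dtype = torch.bfloat16  # assumption: bf16 compute dtype for NF4

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch_dtype,  # dtype used for matmuls on dequantized weights
    bnb_4bit_use_double_quant=True,  # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",  # 4-bit NormalFloat, as used by QLoRA
)

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",  # illustrative repo id
    quantization_config=bnb_config,
    device_map="auto",
)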
@@ -280,8 +281,8 @@ def load_model(
     # llama is PROBABLY model parallelizable, but the default isn't that it is
     # so let's only set it for the 4bit, see
     # https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
-    setattr(model,
-    setattr(model,
+    setattr(model, "is_parallelizable", True)
+    setattr(model, "model_parallel", True)

     requires_grad = []
     for name, param in model.named_parameters(recurse=True):
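These two attributes are what transformers' Trainer inspects to decide that a model is already spread across devices, in which case it skips DataParallel wrapping and the usual move to a single device. A minimal sketch of the same idea; the model id is illustrative:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")  # illustrative
# Trainer checks both flags: is_parallelizable (the architecture supports
# model parallelism) and model_parallel (the weights are already placed
# across devices), and sets its internal is_model_parallel accordingly.
setattr(model, "is_parallelizable", True)
setattr(model, "model_parallel", True)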
src/axolotl/utils/trainer.py

@@ -125,7 +125,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         output_dir=cfg.output_dir,
         save_total_limit=3,
         load_best_model_at_end=(
-            cfg.val_set_size > 0
+            cfg.load_best_model_at_end is not False
+            and cfg.val_set_size > 0
             and save_steps
             and save_steps % eval_steps == 0
             and cfg.load_in_8bit is not True
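The guard matters because transformers enforces that checkpointing and evaluation line up when load_best_model_at_end=True: TrainingArguments raises a ValueError if save_steps is not a round multiple of eval_steps, or if the save and evaluation strategies differ. A minimal sketch of a valid pairing; the concrete step values and output_dir are illustrative:

from transformers import TrainingArguments

eval_steps, save_steps = 50, 100  # 100 % 50 == 0, so the combination is accepted

training_args = TrainingArguments(
    output_dir="./out",  # illustrative
    evaluation_strategy="steps",
    eval_steps=eval_steps,
    save_strategy="steps",
    save_steps=save_steps,
    save_total_limit=3,
    load_best_model_at_end=True,  # only valid because save/eval steps align
)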
tests/test_prompt_tokenizers.py

@@ -1,6 +1,8 @@
+"""Module for testing prompt tokenizers."""
 import json
 import logging
 import unittest
+
 from pathlib import Path

 from transformers import AutoTokenizer

@@ -12,6 +14,10 @@ logging.basicConfig(level="INFO")


 class TestPromptTokenizationStrategies(unittest.TestCase):
+    """
+    Test class for prompt tokenization strategies.
+    """
+
     def setUp(self) -> None:
         self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
         self.tokenizer.add_special_tokens(

@@ -24,10 +30,15 @@ class TestPromptTokenizationStrategies(unittest.TestCase):

     def test_sharegpt_integration(self):
         print(Path(__file__).parent)
-        with open(
+        with open(
+            Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
+        ) as fin:
             data = fin.read()
             conversation = json.loads(data)
-        with open(
+        with open(
+            Path(__file__).parent / "fixtures/conversation.tokenized.json",
+            encoding="utf-8",
+        ) as fin:
             data = fin.read()
             tokenized_conversation = json.loads(data)
         prompter = ShareGPTPrompter("chat")
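This hunk is the substance of the commit: fixture files are now resolved relative to the test module via Path(__file__).parent instead of the process working directory, so the test passes whether pytest is launched from the repo root or from tests/. A small sketch of the same pattern; load_fixture is a hypothetical helper, not part of this repo:

import json
from pathlib import Path

FIXTURES_DIR = Path(__file__).parent / "fixtures"

def load_fixture(name: str):
    """Read and parse a JSON fixture that lives next to this test module."""
    with open(FIXTURES_DIR / name, encoding="utf-8") as fin:
        return json.loads(fin.read())

# e.g. conversation = load_fixture("conversation.json")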