black formatting
- scripts/finetune.py +13 -8
- setup.py +9 -9
- src/axolotl/datasets.py +15 -4
- src/axolotl/prompt_tokenizers.py +13 -9
- src/axolotl/prompters.py +11 -8
- src/axolotl/utils/callbacks.py +11 -2
- src/axolotl/utils/data.py +20 -8
- src/axolotl/utils/models.py +22 -7
- src/axolotl/utils/schedulers.py +4 -1
- src/axolotl/utils/tokenization.py +3 -4
- src/axolotl/utils/trainer.py +11 -4
scripts/finetune.py
CHANGED

@@ -191,7 +191,9 @@ def train(
     if cfg.debug:
         logging.info("check_dataset_labels...")
         check_dataset_labels(
+            train_dataset.select(
+                [random.randrange(0, len(train_dataset) - 1) for i in range(5)]
+            ),
             tokenizer,
         )

@@ -218,17 +220,20 @@ def train(
     logging.info("Starting trainer...")
     resume_from_checkpoint = cfg.resume_from_checkpoint
     if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
+        possible_checkpoints = [
+            str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
+        ]
         if len(possible_checkpoints) > 0:
+            sorted_paths = sorted(
+                possible_checkpoints, key=lambda path: int(path.split("-")[-1])
+            )
             resume_from_checkpoint = sorted_paths[-1]
+            logging.info(
+                f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
+            )
     trainer.train(resume_from_checkpoint=resume_from_checkpoint)

+    logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")

     # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
     trainer.save_pretrained(cfg.output_dir)
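The reflowed auto-resume block selects the newest checkpoint by parsing the step number out of the directory name instead of relying on lexical order (so checkpoint-1000 beats checkpoint-200). A minimal standalone sketch of that selection logic, with a hypothetical output directory:

from pathlib import Path

def latest_checkpoint(output_dir):
    # Gather checkpoint-* directories and sort them numerically by the trailing step.
    possible_checkpoints = [str(cp) for cp in Path(output_dir).glob("checkpoint-*")]
    if not possible_checkpoints:
        return None
    sorted_paths = sorted(possible_checkpoints, key=lambda path: int(path.split("-")[-1]))
    return sorted_paths[-1]

# latest_checkpoint("./lora-out") -> e.g. "lora-out/checkpoint-1000" when that is the highest step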
setup.py
CHANGED

@@ -10,22 +10,22 @@ with open("./requirements.txt", "r") as requirements_file:
         install_requires.append(r)

 setup(
+    name="axolotl",
+    version="0.1",
     description="You know you're going to axolotl questions",
+    package_dir={"": "src"},
     packages=find_packages(),
     install_requires=install_requires,
     extras_require={
+        "int4": [
             "alpaca_lora_4bit @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
         ],
+        "int4_triton": [
             "alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
         ],
+        "extras": [
+            "flash-attn",
+            "deepspeed",
+        ],
     },
 )
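With the extras split into named groups, the optional dependencies defined above would typically be pulled in at install time with something like `pip install -e ".[int4]"` for the 4-bit backend or `pip install -e ".[extras]"` for flash-attn and deepspeed (commands shown only as an illustration; exact usage depends on the environment).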
src/axolotl/datasets.py
CHANGED

@@ -31,6 +31,7 @@ class TokenizedPromptDataset(IterableDataset):
         except InvalidDataException:
             pass

+
 # TODO this isn't the best since it can't interleave datasets
 class ConstantLengthDataset(IterableDataset):
     """

@@ -40,6 +41,7 @@ class ConstantLengthDataset(IterableDataset):
         dataset (dataset.Dataset): Dataset with text files.
         seq_length (int): Length of token sequences to return.
     """
+
     def __init__(
         self,
         tokenizer,

@@ -93,14 +95,19 @@ class ConstantLengthDataset(IterableDataset):
                         : self.seq_length
                     ]
                     labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
+                    if (
+                        labels.size() == input_ids.size()
+                        and attention_mask.size() == input_ids.size()
+                    ):
                         yield {
                             "input_ids": input_ids,
                             "labels": labels,
                             "attention_mask": attention_mask,
                         }
                     else:
+                        logging.warning(
+                            "dropping batch due to tensor size mismatch"
+                        )
                 buffer = {"input_ids": [], "attention_mask": [], "labels": []}
                 buffer_len = 0

@@ -116,11 +123,15 @@ class ConstantLengthDataset(IterableDataset):
                     attention_mask.append(1)
                     labels.append(self.concat_token_id)

+                input_ids_with_concat = torch.tensor(
+                    input_ids, dtype=self.tokens_dtype
+                )
                 attention_mask_with_concat = torch.tensor(
                     attention_mask, dtype=self.tokens_dtype
                 )
+                labels_with_concat = torch.tensor(
+                    labels, dtype=self.tokens_dtype
+                )

                 buffer["input_ids"].append(input_ids_with_concat)
                 buffer["attention_mask"].append(attention_mask_with_concat)
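The reformatted guard only yields a packed example when input_ids, labels, and attention_mask all ended up the same length, and drops the batch otherwise. A minimal sketch of that check in isolation (the tensor contents are made up):

import logging
import torch

input_ids = torch.tensor([1, 5, 9, 2])
labels = torch.tensor([1, 5, 9, 2])
attention_mask = torch.tensor([1, 1, 1, 1])

if labels.size() == input_ids.size() and attention_mask.size() == input_ids.size():
    batch = {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}
else:
    # mirrors the warning in ConstantLengthDataset
    logging.warning("dropping batch due to tensor size mismatch")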
src/axolotl/prompt_tokenizers.py
CHANGED

@@ -126,10 +126,8 @@ class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):


 class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> str:
+        return prompt["text"]

     def tokenize_prompt(self, prompt):
         instruction = self.parse_instruction_fields(prompt)

@@ -139,9 +137,7 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
         return tokenized_full_prompt

     def _build_full_prompt(self, instruction):
+        return self.prompter.build_prompt(instruction)


 class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):

@@ -149,8 +145,16 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
         raise NotImplementedError

     def tokenize_prompt(self, prompt):
+        (
+            instruction,
+            input,
+            output,
+            reflection,
+            corrected,
+        ) = self.parse_instruction_fields(prompt)
+        full_prompt = self._build_full_prompt(
+            instruction, input, output, reflection, corrected
+        )
         tokenized_full_prompt = self._tokenize(full_prompt)
         if not self.train_on_inputs:
             user_prompt = self.prompter.build_prompt(
src/axolotl/prompters.py
CHANGED

@@ -36,10 +36,7 @@ class JeopardyPrompter(AlpacaPrompter):


 class CompletionPrompter(AlpacaPrompter):
+    def build_prompt(self, instruction: str) -> str:
         return instruction

     def get_response(self, output: str) -> str:

@@ -75,7 +72,9 @@ class ReflectAlpacaPrompter:
         else:
             res = self.prompt_no_input.format(instruction=instruction)
         if output and reflection and corrected:
+            label = self.agent_label.format(
+                output=output, reflection=reflection, corrected=corrected
+            )
             res = f"{res}{label}"
         return res

@@ -200,9 +199,13 @@ class ShareGPTPrompter:
             if len(parts) != 2:
                 break
             parts[0] += sep
+            round_len = (
+                len(tokenizer(rou)["input_ids"]) - 1
+            )  # -1 ignores the bos_token generated for this
             # we have to strip the initial part, any dangling whitespace creates an additional ghost token
+            instruction_len = (
+                len(tokenizer(parts[0].strip())["input_ids"]) - 1
+            )  # -1 ignores the bos_token generated for this
             target[cur_len : cur_len + instruction_len] = [
                 IGNORE_TOKEN_ID
             ] * instruction_len

@@ -212,7 +215,7 @@ class ShareGPTPrompter:
             break

         # Fix: Truncate the target to have the same length as input_ids
+        target = target[: len(tokenized_result["input_ids"])]
         # target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len)

         attention_mask = [
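The `- 1` in round_len and instruction_len compensates for the BOS token the tokenizer prepends to every call; without it each round would over-count by one and the label mask would drift. A toy illustration with a stand-in tokenizer (the BOS id and word-per-token split are assumptions, not the real tokenizer):

BOS_ID = 1  # hypothetical bos_token_id

def toy_tokenize(text):
    # Pretend each whitespace-separated word is one token and a BOS is always prepended.
    return [BOS_ID] + [hash(word) % 1000 + 2 for word in text.split()]

rou = "USER: hi ASSISTANT: hello"
parts0 = "USER: hi ASSISTANT:"

round_len = len(toy_tokenize(rou)) - 1                   # tokens in the round, BOS excluded
instruction_len = len(toy_tokenize(parts0.strip())) - 1  # tokens to mask, BOS excluded
print(round_len, instruction_len)  # 4 3 with this toy tokenizer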
src/axolotl/utils/callbacks.py
CHANGED

@@ -1,8 +1,15 @@
 import os

+from transformers import (
+    Seq2SeqTrainer,
+    TrainerCallback,
+    TrainingArguments,
+    TrainerState,
+    TrainerControl,
+)
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

+
 class SavePeftModelCallback(TrainerCallback):
     def on_save(
         self,

@@ -11,7 +18,9 @@ class SavePeftModelCallback(TrainerCallback):
         control: TrainerControl,
         **kwargs,
     ):
+        checkpoint_folder = os.path.join(
+            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
+        )

         peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
         kwargs["model"].save_pretrained(peft_model_path)
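The reformatted on_save body rebuilds the same folder name the Trainer uses for its checkpoints and drops the PEFT adapter inside it. A small sketch of how that path is composed (output dir and step are placeholders):

import os

PREFIX_CHECKPOINT_DIR = "checkpoint"  # same value transformers.trainer_utils exposes

output_dir = "./lora-out"  # placeholder for args.output_dir
global_step = 500          # placeholder for state.global_step

checkpoint_folder = os.path.join(output_dir, f"{PREFIX_CHECKPOINT_DIR}-{global_step}")
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
print(peft_model_path)  # ./lora-out/checkpoint-500/adapter_model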
src/axolotl/utils/data.py
CHANGED

@@ -2,7 +2,13 @@ import logging
 from hashlib import md5
 from pathlib import Path

+from datasets import (
+    load_from_disk,
+    load_dataset,
+    IterableDataset,
+    Dataset,
+    concatenate_datasets,
+)
 from huggingface_hub import hf_hub_download

 from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset

@@ -75,7 +81,9 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
                 else:
                     ds = load_dataset(d.path, streaming=True)
             else:
+                fp = hf_hub_download(
+                    repo_id=d.path, repo_type="dataset", filename=d.data_files
+                )
                 ds = load_dataset("json", data_files=fp, streaming=True, split=None)
             if not ds:
                 raise Exception("unhandled dataset load")

@@ -140,7 +148,9 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
             samples = samples + [i for i in d]
         dataset = Dataset.from_list(samples).shuffle(seed=42)
         if cfg.local_rank == 0:
+            logging.info(
+                f"Saving merged prepared dataset to disk... {prepared_ds_path}"
+            )
             dataset.save_to_disk(prepared_ds_path)

     if cfg.max_packed_sequence_len is not None:

@@ -153,12 +163,14 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
         dataset = Dataset.from_list([_ for _ in constant_len_dataset])

     if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
+        logging.info(
+            f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
+        )
+        dataset = dataset.shard(
+            num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx
+        )

+    dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
     train_dataset = dataset["train"]
     eval_dataset = dataset["test"]
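When both dataset_shard_num and dataset_shard_idx are set, the block above keeps only one shard of the prepared dataset before the train/test split. A minimal sketch of that call with a throwaway in-memory dataset:

from datasets import Dataset

# Throwaway dataset standing in for the prepared samples.
dataset = Dataset.from_list([{"input_ids": [i]} for i in range(10)])

# Keep shard #1 of 4, mirroring cfg.dataset_shard_num / cfg.dataset_shard_idx.
dataset = dataset.shard(num_shards=4, index=1)
print(len(dataset))  # roughly len(original) / num_shards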
src/axolotl/utils/models.py
CHANGED

@@ -8,15 +8,19 @@ import transformers
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
+    PreTrainedModel,
+    AutoConfig,
 )
+
 try:
     from transformers import (
         LlamaForCausalLM,
         LlamaTokenizer,
     )
 except:
+    logging.warning(
+        "This version of transformers does not support Llama. Consider upgrading."
+    )

 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN

@@ -40,7 +44,9 @@ def load_model(
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
     tokenizer = None
+    is_llama_derived_model = "llama" in base_model or (
+        cfg.model_type and "llama" in cfg.model_type.lower()
+    )

     if is_llama_derived_model and cfg.flash_attention:
         if cfg.device not in ["mps", "cpu"] and inference is False:

@@ -49,11 +55,16 @@ def load_model(
             logging.info("patching with flash attention")
             replace_llama_attn_with_flash_attn()
     elif is_llama_derived_model and cfg.xformers_attention:
+        from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import (
+            hijack_llama_attention,
+        )
+
         logging.info("patching with xformers attention")
         hijack_llama_attention()

+    torch_dtype = (
+        torch.float16 if cfg.load_in_8bit or cfg.fp16 or cfg.bf16 else torch.float32
+    )
     try:
         if cfg.load_4bit:
             from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (

@@ -74,8 +85,12 @@ def load_model(
         try:
             snapshot_download_kwargs = {}
             if cfg.base_model_ignore_patterns:
+                snapshot_download_kwargs[
+                    "ignore_patterns"
+                ] = cfg.base_model_ignore_patterns
+            cache_model_path = Path(
+                snapshot_download(base_model, **snapshot_download_kwargs)
+            )
             files = (
                 list(cache_model_path.glob("*.pt"))
                 + list(cache_model_path.glob("*.safetensors"))
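The torch_dtype expression picks float16 whenever 8-bit loading or mixed precision is requested and falls back to float32 otherwise. The same selection as a standalone sketch (the cfg dict stands in for the real config object):

import torch

cfg = {"load_in_8bit": False, "fp16": True, "bf16": False}  # stand-in config

torch_dtype = (
    torch.float16 if cfg["load_in_8bit"] or cfg["fp16"] or cfg["bf16"] else torch.float32
)
print(torch_dtype)  # torch.float16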
src/axolotl/utils/schedulers.py
CHANGED

@@ -26,7 +26,10 @@ class InterpolatingLogScheduler(LRScheduler):
         if self.last_epoch <= 0:
             lrs = [self.min_lr for base_lr in self.base_lrs]
         elif self.last_epoch < self.num_steps:
+            lrs = [
+                self.min_lr * (self.q ** (self.last_epoch - 1))
+                for base_lr in self.base_lrs
+            ]
         else:
             lrs = [self.max_lr for base_lr in self.base_lrs]
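The list comprehension implements a geometric (log-linear) ramp: each step multiplies the previous learning rate by a fixed ratio q, so the LR climbs from min_lr toward max_lr in equal steps on a log scale. A small sketch, assuming q is precomputed elsewhere in the class as (max_lr / min_lr) ** (1 / (num_steps - 1)):

min_lr, max_lr, num_steps = 1e-6, 1e-3, 10
q = (max_lr / min_lr) ** (1 / (num_steps - 1))  # assumed definition of self.q

for last_epoch in range(1, num_steps):
    lr = min_lr * (q ** (last_epoch - 1))
    print(last_epoch, f"{lr:.2e}")
# step 1 starts at min_lr; later steps climb toward max_lr, which the else branch returns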
src/axolotl/utils/tokenization.py
CHANGED

@@ -1,6 +1,7 @@
 from termcolor import colored
 import logging

+
 def check_dataset_labels(dataset, tokenizer):
     # the dataset is already shuffled, so let's just check the first 5 elements
     for idx in range(5):

@@ -11,7 +12,7 @@ def check_example_labels(example, tokenizer):
     # Get the input_ids, labels, and attention_mask from the dataset
     input_ids = example["input_ids"]
     labels = example["labels"]
+    attention_mask = example["attention_mask"]

     # You can compare the input_ids and labels element-wise
     # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0

@@ -21,9 +22,7 @@ def check_example_labels(example, tokenizer):
     ):
         decoded_input_token = tokenizer.decode(input_id)
         # Choose the color based on whether the label has the ignore value or not
+        color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
         colored_token = colored(decoded_input_token, color) + colored(
             f"({label_id}, {mask}, {input_id})", "white"
         )
src/axolotl/utils/trainer.py
CHANGED

@@ -78,7 +78,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):

     training_args = transformers.TrainingArguments(
         per_device_train_batch_size=cfg.micro_batch_size,
+        per_device_eval_batch_size=cfg.eval_batch_size
+        if cfg.eval_batch_size is not None
+        else cfg.micro_batch_size,
         gradient_accumulation_steps=cfg.gradient_accumulation_steps,
         eval_accumulation_steps=cfg.gradient_accumulation_steps,
         num_train_epochs=cfg.num_epochs,

@@ -90,14 +92,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         output_dir=cfg.output_dir,
         save_total_limit=3,
         load_best_model_at_end=True
+        if cfg.val_set_size > 0
+        and save_steps is not None
+        and save_steps % eval_steps == 0
+        and cfg.load_in_8bit is not True
         else False,
         ddp_find_unused_parameters=False if cfg.ddp else None,
         group_by_length=cfg.group_by_length,
         report_to="wandb" if cfg.use_wandb else None,
         run_name=cfg.wandb_run_id if cfg.use_wandb else None,
         optim=cfg.optimizer if cfg.optimizer else "adamw_hf",
+        lr_scheduler_type=cfg.lr_scheduler
+        if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep")
+        else "cosine",
         weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
         **training_arguments_kwargs,
     )

@@ -184,7 +191,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         data_collator_kwargs["pad_to_multiple_of"] = 8

     callbacks = []
+    if cfg.adapter == "lora":
         callbacks.append(SavePeftModelCallback)

     trainer = transformers.Trainer(
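The lr_scheduler_type argument now defaults to "cosine" unless a scheduler is configured and it is not one of the names ("one_cycle", "log_sweep") that are presumably wired up outside TrainingArguments. The same conditional-expression pattern in isolation (cfg is a stand-in mapping):

cfg = {"lr_scheduler": "log_sweep"}  # stand-in for the config object

lr_scheduler_type = (
    cfg["lr_scheduler"]
    if cfg["lr_scheduler"] and cfg["lr_scheduler"] not in ("one_cycle", "log_sweep")
    else "cosine"
)
print(lr_scheduler_type)  # "cosine"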