4bit quantized support (wip)
- README.md +2 -2
- configs/cerebras_1_3B_alpaca.yml +2 -1
- configs/llama_65B_alpaca.yml +1 -1
- configs/llama_7B_alpaca.yml +1 -1
- configs/pythia_1_2B_alpaca.yml +1 -1
- pyproject.toml +0 -3
- requirements.txt +1 -2
- scripts/finetune.py +69 -19
- setup.cfg +0 -33
- setup.py +30 -0
- src/axolotl/datasets.py +1 -0
README.md CHANGED

@@ -29,8 +29,8 @@ shuf -n2000 data/vicuna_cleaned.jsonl > data/vicuna_cleaned.subset0.jsonl
 ```

 - Create a new or update the existing YAML config (config/pythia_1_2B_alpaca.yml)[config/pythia_1_2B_alpaca.yml]
-- Install python dependencies `pip3 install -
-- Configure accelerate `accelerate
+- Install python dependencies `pip3 install -e .[triton]` or `pip3 install -e .[cuda]`
+- Configure accelerate `accelerate config` or update `~/.cache/huggingface/accelerate/default_config.yaml`

 ```yaml
 compute_environment: LOCAL_MACHINE
configs/cerebras_1_3B_alpaca.yml CHANGED

@@ -11,7 +11,7 @@ datasets:
     type: gpteacher
   - path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
     type: gpteacher
-dataset_prepared_path:
+dataset_prepared_path: last_run_prepared
 val_set_size: 0.05
 adapter: lora
 sequence_len: 2048
@@ -34,6 +34,7 @@ train_on_inputs: false
 group_by_length: false
 bf16: True
 tf32: True
+gradient_checkpointing:
 early_stopping_patience:
 resume_from_checkpoint:
 local_rank:
configs/llama_65B_alpaca.yml CHANGED

@@ -11,7 +11,7 @@ datasets:
     type: gpteacher
   - path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
     type: gpteacher
-dataset_prepared_path:
+dataset_prepared_path: last_run_prepared
 val_set_size: 0.04
 adapter: lora
 lora_model_dir:
configs/llama_7B_alpaca.yml CHANGED

@@ -11,7 +11,7 @@ datasets:
     type: gpteacher
   - path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
     type: gpteacher
-dataset_prepared_path:
+dataset_prepared_path: last_run_prepared
 val_set_size: 0.04
 adapter: lora
 lora_model_dir:
configs/pythia_1_2B_alpaca.yml CHANGED

@@ -11,7 +11,7 @@ datasets:
     type: gpteacher
   - path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
     type: gpteacher
-dataset_prepared_path:
+dataset_prepared_path: last_run_prepared
 val_set_size: 0.05
 adapter: lora
 lora_model_dir:
pyproject.toml DELETED

@@ -1,3 +0,0 @@
-[build-system]
-requires = ["setuptools", "wheel"]
-build-backend = "setuptools.build_meta"
requirements.txt CHANGED

@@ -1,5 +1,4 @@
-git+https://github.com/huggingface/
-git+https://github.com/huggingface/transformers.git
+transformers @ git+https://github.com/huggingface/transformers.git
 attrdict
 fire
 PyYAML==6.0
scripts/finetune.py CHANGED

@@ -13,12 +13,6 @@ import transformers
 import yaml
 from attrdict import AttrDefault
 from datasets import load_dataset, IterableDataset, Dataset, load_from_disk
-from peft import (
-    LoraConfig,
-    get_peft_model,
-    prepare_model_for_int8_training,
-    PeftModel,
-)
 from torch import nn
 from transformers import (
     AutoModelForCausalLM,
@@ -45,7 +39,7 @@ from axolotl.prompt_tokenizers import (
 from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter

 logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
-DEFAULT_DATASET_PREPARED_PATH = "
+DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"


 def setup_wandb_env_vars(cfg):
@@ -60,7 +54,11 @@ def setup_wandb_env_vars(cfg):
         os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id


-def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", inference: bool=False):
+def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, adapter="lora", inference: bool=False):
+    # TODO refactor as a kwarg
+    load_in_8bit = cfg.load_in_8bit
+    tokenizer = None
+
     if adapter != "lora":
         raise NotImplementedError(f"{adapter} peft adapter not available")
     if "llama" in base_model:
@@ -70,7 +68,43 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", infe

     torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,
     try:
-        if "llama" in base_model:
+        if cfg.load_4bit:
+            from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import replace_peft_model_with_int4_lora_model
+            replace_peft_model_with_int4_lora_model()
+
+        from peft import (
+            LoraConfig,
+            get_peft_model,
+            prepare_model_for_int8_training,
+            PeftModel,
+        )
+    except Exception as e:
+        logging.exception(e)
+        raise e
+
+    try:
+        if cfg.load_4bit and "llama" in base_model:
+            from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
+            from huggingface_hub import snapshot_download
+
+            cache_model_path = Path(snapshot_download(base_model))
+            # TODO search .glob for a .pt, .safetensor, or .bin
+            cache_model_path.glob("*.pt")
+            files = list(cache_model_path.glob('*.pt')) + list(cache_model_path.glob('*.safetensor')) + list(cache_model_path.glob('*.bin'))
+            if len(files) > 0:
+                model_path = str(files[0])
+            else:
+                logging.warning("unable to find a cached model file, this will likely fail...")
+                model_path = str(cache_model_path)
+            model, tokenizer = load_llama_model_4bit_low_ram(
+                base_model_config if base_model_config else base_model,
+                model_path,
+                device_map=cfg.device_map,
+                groupsize=-1,
+                is_v1_model=True,
+            )
+            load_in_8bit = False
+        elif "llama" in base_model:
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
                 load_in_8bit=cfg.load_in_8bit,
@@ -92,13 +126,14 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", infe
             device_map=cfg.device_map,
         )

-    try:
-        if "llama" in base_model:
-            tokenizer = LlamaTokenizer.from_pretrained(model)
-        else:
-            tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
-    except:
-        tokenizer = AutoTokenizer.from_pretrained(base_model)
+    if not tokenizer:
+        try:
+            if "llama" in base_model:
+                tokenizer = LlamaTokenizer.from_pretrained(model)
+            else:
+                tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
+        except:
+            tokenizer = AutoTokenizer.from_pretrained(base_model)

     if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
         tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
@@ -107,7 +142,7 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", infe
         tokenizer.add_special_tokens({"pad_token": "[PAD]"})
     os.environ["TOKENIZERS_PARALLELISM"] = "false"

-    if cfg.load_in_8bit:
+    if load_in_8bit:
         model = prepare_model_for_int8_training(model)

     lora_config = LoraConfig(
@@ -128,6 +163,16 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", infe
     if cfg.ddp:
         model.to(f"cuda:{cfg.local_rank}")

+    if cfg.load_4bit:
+        # Scales to half
+        print('Fitting 4bit scales and zeros to half')
+        for n, m in model.named_modules():
+            if 'Autograd4bitQuantLinear' in str(type(m)) or 'Linear4bitLt' in str(type(m)):
+                if hasattr(m, "is_v1_model") and m.is_v1_model:
+                    m.zeros = m.zeros.half()
+                m.scales = m.scales.half()
+                m.bias = m.bias.half()
+
     # TODO resume_from_checkpoint handling
     model.print_trainable_parameters()
     return model, tokenizer, lora_config
@@ -243,6 +288,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     training_arguments_kwargs["tf32"] = cfg.tf32
     training_arguments_kwargs["warmup_steps"] = warmup_steps
     training_arguments_kwargs["logging_steps"] = logging_steps
+    if cfg.gradient_checkpointing is not None:
+        training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing

     training_args = transformers.TrainingArguments(
         per_device_train_batch_size=cfg.micro_batch_size,
@@ -260,7 +307,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         group_by_length=cfg.group_by_length,
         report_to="wandb" if cfg.use_wandb else None,
         run_name=cfg.wandb_run_id if cfg.use_wandb else None,
-        gradient_checkpointing=cfg.gradient_checkpointing,
         **training_arguments_kwargs,
     )

@@ -356,11 +402,13 @@ def train(
         cfg.bf16 = False

     # Load the model and tokenizer
+    logging.info("loading model, tokenizer, and lora_config...")
     model, tokenizer, lora_config = load_model(
-        cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter, inference=("inference" in kwargs)
+        cfg.base_model, cfg.base_model_config, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter, inference=("inference" in kwargs)
     )

     if "inference" in kwargs:
+        logging.info("calling do_inference function")
         do_inference(cfg, model, tokenizer)
         return

@@ -369,6 +417,7 @@ def train(
         dataset = load_from_disk(cfg.dataset_prepared_path)
         logging.info("Prepared dataset loaded from disk...")
     else:
+        logging.info("Loading raw datasets...")
         datasets = []
         for d in cfg.datasets:
             if Path(d.path).exists():
@@ -402,6 +451,7 @@ def train(
         constant_len_dataset = ConstantLengthDataset(
             tokenizer, datasets, seq_length=cfg.sequence_len
         )
+        logging.info("merging, packing, shuffling, and splitting master dataset")
         dataset = Dataset.from_list(
             [_ for _ in constant_len_dataset]
         ).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42)
setup.cfg DELETED

@@ -1,33 +0,0 @@
-[metadata]
-name = axolotl
-version = 0.1.0
-description = You know you're going to axolotl questions
-author = Wing Lian
-author_email = [email protected]
-license = MIT
-
-[options]
-package_dir =
-    =src
-packages = find:
-install_requires =
-    transformers @ git+https://github.com/huggingface/transformers.git@main
-    peft @ git+https://github.com/huggingface/peft.git@main
-    attrdict
-    fire
-    PyYAML == 6.0
-    black
-    bitsandbytes
-    datasets
-    accelerate
-    sentencepiece
-    wandb
-    flash-attn
-    einops
-
-[options.packages.find]
-where = src
-
-[options.extras_require]
-gptq_cuda = alpaca_lora_4bit[cuda] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip#egg=alpaca_lora_4bit[cuda]
-gptq_triton = alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip#egg=alpaca_lora_4bit[triton]
setup.py ADDED

@@ -0,0 +1,30 @@
+import sys
+from setuptools import setup, find_packages
+
+install_requires = []
+with open("./requirements.txt", "r") as requirements_file:
+    # don't include peft yet until we check the int4
+    reqs = [r.strip() for r in requirements_file.readlines() if "peft" not in r]
+    reqs = [r for r in reqs if r[0] != "#"]
+    for r in reqs:
+        install_requires.append(r)
+
+setup(
+    name='axolotl',
+    version='0.1',
+    description="You know you're going to axolotl questions",
+    package_dir={'': 'src'},
+    packages=find_packages(),
+    install_requires=install_requires,
+    extras_require={
+        None: [
+            "peft @ git+https://github.com/huggingface/peft.git",
+        ],
+        'int4_cuda': [
+            "alpaca_lora_4bit[cuda] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip#egg=alpaca_lora_4bit[cuda]",
+        ],
+        'int4_triton': [
+            "alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip#egg=alpaca_lora_4bit[triton]",
+        ],
+    },
+)
src/axolotl/datasets.py CHANGED

@@ -31,6 +31,7 @@ class TokenizedPromptDataset(IterableDataset):
                 pass


+# TODO this isn't the best since it can't interleave datasets
 class ConstantLengthDataset(IterableDataset):
     """
     Iterable dataset that returns constant length chunks of tokens from stream of text files.