Merge pull request #62 from OpenAccess-AI-Collective/qlora-fixes
Browse files- README.md +5 -5
- scripts/finetune.py +1 -0
- src/axolotl/utils/data.py +8 -11
- src/axolotl/utils/models.py +6 -4
- src/axolotl/utils/trainer.py +1 -1
- src/axolotl/utils/validation.py +15 -4
README.md
CHANGED
|
@@ -24,7 +24,7 @@
|
|
| 24 |
|
| 25 |
## Quickstart ⚡
|
| 26 |
|
| 27 |
-
**Requirements**: Python 3.9.
|
| 28 |
|
| 29 |
```bash
|
| 30 |
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
|
@@ -45,7 +45,7 @@ accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
|
|
| 45 |
|
| 46 |
### Environment
|
| 47 |
|
| 48 |
-
- Docker
|
| 49 |
```bash
|
| 50 |
docker run --gpus '"all"' --rm -it winglian/axolotl:main
|
| 51 |
```
|
|
@@ -334,7 +334,7 @@ strict:
|
|
| 334 |
|
| 335 |
### Accelerate
|
| 336 |
|
| 337 |
-
Configure accelerate
|
| 338 |
|
| 339 |
```bash
|
| 340 |
accelerate config
|
|
@@ -368,7 +368,7 @@ Pass the appropriate flag to the train command:
|
|
| 368 |
Add below flag to train command above
|
| 369 |
|
| 370 |
```bash
|
| 371 |
-
--merge_lora --lora_model_dir="./completed-model"
|
| 372 |
```
|
| 373 |
|
| 374 |
## Common Errors 🧰
|
|
@@ -389,7 +389,7 @@ Try set `fp16: true`
|
|
| 389 |
Try to turn off xformers.
|
| 390 |
|
| 391 |
## Need help? 🙋♂️
|
| 392 |
-
|
| 393 |
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
|
| 394 |
|
| 395 |
## Contributing 🤝
|
|
|
|
| 24 |
|
| 25 |
## Quickstart ⚡
|
| 26 |
|
| 27 |
+
**Requirements**: Python 3.9.
|
| 28 |
|
| 29 |
```bash
|
| 30 |
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
|
|
|
| 45 |
|
| 46 |
### Environment
|
| 47 |
|
| 48 |
+
- Docker
|
| 49 |
```bash
|
| 50 |
docker run --gpus '"all"' --rm -it winglian/axolotl:main
|
| 51 |
```
|
|
|
|
| 334 |
|
| 335 |
### Accelerate
|
| 336 |
|
| 337 |
+
Configure accelerate
|
| 338 |
|
| 339 |
```bash
|
| 340 |
accelerate config
|
|
|
|
| 368 |
Add below flag to train command above
|
| 369 |
|
| 370 |
```bash
|
| 371 |
+
--merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
|
| 372 |
```
|
| 373 |
|
| 374 |
## Common Errors 🧰
|
|
|
|
| 389 |
Try to turn off xformers.
|
| 390 |
|
| 391 |
## Need help? 🙋♂️
|
| 392 |
+
|
| 393 |
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
|
| 394 |
|
| 395 |
## Contributing 🤝
|
scripts/finetune.py
CHANGED
|
@@ -176,6 +176,7 @@ def train(
|
|
| 176 |
if "merge_lora" in kwargs and cfg.adapter is not None:
|
| 177 |
logging.info("running merge of LoRA with base model")
|
| 178 |
model = model.merge_and_unload()
|
|
|
|
| 179 |
|
| 180 |
if cfg.local_rank == 0:
|
| 181 |
logging.info("saving merged model")
|
|
|
|
| 176 |
if "merge_lora" in kwargs and cfg.adapter is not None:
|
| 177 |
logging.info("running merge of LoRA with base model")
|
| 178 |
model = model.merge_and_unload()
|
| 179 |
+
model.to(dtype=torch.float16)
|
| 180 |
|
| 181 |
if cfg.local_rank == 0:
|
| 182 |
logging.info("saving merged model")
|
src/axolotl/utils/data.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import logging
|
| 2 |
from hashlib import md5
|
| 3 |
from pathlib import Path
|
|
|
|
| 4 |
|
| 5 |
from datasets import (
|
| 6 |
load_from_disk,
|
|
@@ -80,7 +81,7 @@ def load_tokenized_prepared_datasets(
|
|
| 80 |
logging.info("Loading raw datasets...")
|
| 81 |
datasets = []
|
| 82 |
for d in cfg.datasets:
|
| 83 |
-
ds = None
|
| 84 |
ds_from_hub = False
|
| 85 |
try:
|
| 86 |
load_dataset(d.path, streaming=True, use_auth_token=True)
|
|
@@ -90,36 +91,32 @@ def load_tokenized_prepared_datasets(
|
|
| 90 |
|
| 91 |
# prefer local dataset, even if hub exists
|
| 92 |
if Path(d.path).exists():
|
| 93 |
-
ds:
|
| 94 |
"json", data_files=d.path, streaming=False, split=None
|
| 95 |
)
|
| 96 |
elif ds_from_hub:
|
| 97 |
if d.data_files:
|
| 98 |
-
ds = load_dataset(
|
| 99 |
d.path,
|
| 100 |
streaming=False,
|
| 101 |
data_files=d.data_files,
|
| 102 |
use_auth_token=True,
|
| 103 |
)
|
| 104 |
else:
|
| 105 |
-
ds = load_dataset(d.path, streaming=False, use_auth_token=True)
|
| 106 |
else:
|
| 107 |
fp = hf_hub_download(
|
| 108 |
repo_id=d.path, repo_type="dataset", filename=d.data_files
|
| 109 |
)
|
| 110 |
-
ds = load_dataset("json", data_files=fp, streaming=False, split=None)
|
| 111 |
if not ds:
|
| 112 |
raise Exception("unhandled dataset load")
|
| 113 |
# support for using a subset of the data
|
| 114 |
if d.shards:
|
| 115 |
-
<<<<<<< Updated upstream
|
| 116 |
-
ds = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
|
| 117 |
-
=======
|
| 118 |
if "train" in ds:
|
| 119 |
-
ds = ds.shuffle(seed=42)["train"].shard(num_shards=
|
| 120 |
else:
|
| 121 |
-
ds = ds.shuffle(seed=42).shard(num_shards=
|
| 122 |
-
>>>>>>> Stashed changes
|
| 123 |
d_type = d.type
|
| 124 |
d_type_split = d_type.split(":")
|
| 125 |
d_base_type = d_type_split[0]
|
|
|
|
| 1 |
import logging
|
| 2 |
from hashlib import md5
|
| 3 |
from pathlib import Path
|
| 4 |
+
from typing import Union
|
| 5 |
|
| 6 |
from datasets import (
|
| 7 |
load_from_disk,
|
|
|
|
| 81 |
logging.info("Loading raw datasets...")
|
| 82 |
datasets = []
|
| 83 |
for d in cfg.datasets:
|
| 84 |
+
ds: Union[Dataset, DatasetDict] = None
|
| 85 |
ds_from_hub = False
|
| 86 |
try:
|
| 87 |
load_dataset(d.path, streaming=True, use_auth_token=True)
|
|
|
|
| 91 |
|
| 92 |
# prefer local dataset, even if hub exists
|
| 93 |
if Path(d.path).exists():
|
| 94 |
+
ds: Dataset = load_dataset(
|
| 95 |
"json", data_files=d.path, streaming=False, split=None
|
| 96 |
)
|
| 97 |
elif ds_from_hub:
|
| 98 |
if d.data_files:
|
| 99 |
+
ds: Dataset = load_dataset(
|
| 100 |
d.path,
|
| 101 |
streaming=False,
|
| 102 |
data_files=d.data_files,
|
| 103 |
use_auth_token=True,
|
| 104 |
)
|
| 105 |
else:
|
| 106 |
+
ds: Dataset = load_dataset(d.path, streaming=False, use_auth_token=True)
|
| 107 |
else:
|
| 108 |
fp = hf_hub_download(
|
| 109 |
repo_id=d.path, repo_type="dataset", filename=d.data_files
|
| 110 |
)
|
| 111 |
+
ds: Dataset = load_dataset("json", data_files=fp, streaming=False, split=None)
|
| 112 |
if not ds:
|
| 113 |
raise Exception("unhandled dataset load")
|
| 114 |
# support for using a subset of the data
|
| 115 |
if d.shards:
|
|
|
|
|
|
|
|
|
|
| 116 |
if "train" in ds:
|
| 117 |
+
ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
|
| 118 |
else:
|
| 119 |
+
ds: Dataset = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0)
|
|
|
|
| 120 |
d_type = d.type
|
| 121 |
d_type_split = d_type.split(":")
|
| 122 |
d_base_type = d_type_split[0]
|
src/axolotl/utils/models.py
CHANGED
|
@@ -85,7 +85,7 @@ def load_model(
|
|
| 85 |
raise e
|
| 86 |
|
| 87 |
model_kwargs = {}
|
| 88 |
-
if cfg.adapter == "qlora":
|
| 89 |
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
| 90 |
load_in_4bit=True,
|
| 91 |
llm_int8_threshold=6.0,
|
|
@@ -247,8 +247,10 @@ def load_model(
|
|
| 247 |
model.resize_token_embeddings(embeddings_len)
|
| 248 |
|
| 249 |
if (
|
| 250 |
-
(cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora"
|
| 251 |
-
|
|
|
|
|
|
|
| 252 |
logging.info("converting PEFT model w/ prepare_model_for_int8_training")
|
| 253 |
model = prepare_model_for_int8_training(model)
|
| 254 |
|
|
@@ -297,7 +299,7 @@ def load_adapter(model, cfg, adapter):
|
|
| 297 |
|
| 298 |
if adapter is None:
|
| 299 |
return model, None
|
| 300 |
-
if adapter
|
| 301 |
return load_lora(model, cfg)
|
| 302 |
if adapter == "llama-adapter":
|
| 303 |
return load_llama_adapter(model, cfg)
|
|
|
|
| 85 |
raise e
|
| 86 |
|
| 87 |
model_kwargs = {}
|
| 88 |
+
if cfg.adapter == "qlora" and cfg.load_in_4bit:
|
| 89 |
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
| 90 |
load_in_4bit=True,
|
| 91 |
llm_int8_threshold=6.0,
|
|
|
|
| 247 |
model.resize_token_embeddings(embeddings_len)
|
| 248 |
|
| 249 |
if (
|
| 250 |
+
((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora")
|
| 251 |
+
and not cfg.load_4bit
|
| 252 |
+
and (load_in_8bit or cfg.load_in_4bit)
|
| 253 |
+
):
|
| 254 |
logging.info("converting PEFT model w/ prepare_model_for_int8_training")
|
| 255 |
model = prepare_model_for_int8_training(model)
|
| 256 |
|
|
|
|
| 299 |
|
| 300 |
if adapter is None:
|
| 301 |
return model, None
|
| 302 |
+
if adapter in ["lora", "qlora"]:
|
| 303 |
return load_lora(model, cfg)
|
| 304 |
if adapter == "llama-adapter":
|
| 305 |
return load_llama_adapter(model, cfg)
|
src/axolotl/utils/trainer.py
CHANGED
|
@@ -205,7 +205,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
| 205 |
)
|
| 206 |
callbacks.append(early_stop_cb)
|
| 207 |
|
| 208 |
-
if cfg.local_rank == 0 and cfg.adapter
|
| 209 |
callbacks.append(SavePeftModelCallback)
|
| 210 |
|
| 211 |
data_collator_kwargs = {
|
|
|
|
| 205 |
)
|
| 206 |
callbacks.append(early_stop_cb)
|
| 207 |
|
| 208 |
+
if cfg.local_rank == 0 and cfg.adapter in ["lora", "qlora"]: # only save in rank 0
|
| 209 |
callbacks.append(SavePeftModelCallback)
|
| 210 |
|
| 211 |
data_collator_kwargs = {
|
src/axolotl/utils/validation.py
CHANGED
|
@@ -1,9 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
def validate_config(cfg):
|
| 2 |
if cfg.adapter == "qlora":
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
# TODO
|
| 8 |
# MPT 7b
|
| 9 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
|
| 4 |
def validate_config(cfg):
|
| 5 |
if cfg.adapter == "qlora":
|
| 6 |
+
if cfg.merge_lora:
|
| 7 |
+
# can't merge qlora if loaded in 8bit or 4bit
|
| 8 |
+
assert cfg.load_in_8bit is False
|
| 9 |
+
assert cfg.load_4bit is False
|
| 10 |
+
assert cfg.load_in_4bit is False
|
| 11 |
+
else:
|
| 12 |
+
assert cfg.load_in_8bit is False
|
| 13 |
+
assert cfg.load_4bit is False
|
| 14 |
+
assert cfg.load_in_4bit is True
|
| 15 |
+
if cfg.load_in_8bit and cfg.adapter == "lora":
|
| 16 |
+
logging.warning("we recommend setting `load_in_8bit: true`")
|
| 17 |
+
|
| 18 |
# TODO
|
| 19 |
# MPT 7b
|
| 20 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|