Lint and format
Files changed:
- .gitignore +1 -1
- docker/Dockerfile-base +0 -1
- examples/falcon/config-7b-lora.yml +0 -1
- examples/falcon/config-7b.yml +0 -1
- scripts/alpaca_json_to_jsonl.py +21 -5
- scripts/finetune.py +19 -15
- src/axolotl/datasets.py +4 -7
- src/axolotl/utils/data.py +31 -23
- tests/test_prompters.py +6 -4
.gitignore
@@ -160,4 +160,4 @@ cython_debug/
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
-.idea/
+.idea/
docker/Dockerfile-base
@@ -99,4 +99,3 @@ RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main" \
     pip3 install awscli && \
     # The base image ships with `pydantic==1.8.2` which is not working
     pip3 install -U --no-cache-dir pydantic
-
examples/falcon/config-7b-lora.yml
@@ -61,4 +61,3 @@ special_tokens:
   pad_token: "<|endoftext|>"
   bos_token: ">>ABSTRACT<<"
   eos_token: "<|endoftext|>"
-
examples/falcon/config-7b.yml
@@ -61,4 +61,3 @@ special_tokens:
   pad_token: "<|endoftext|>"
   bos_token: ">>ABSTRACT<<"
   eos_token: "<|endoftext|>"
-
scripts/alpaca_json_to_jsonl.py
@@ -1,23 +1,39 @@
+"""Module to convert json file to jsonl"""
+
 import os
 import sys
+
+from typing import Optional
 from pathlib import Path
 
 import fire
-from typing import Optional
+
+
+from axolotl.convert import (
+    FileReader,
+    StdoutWriter,
+    FileWriter,
+    JsonlSerializer,
+    JsonParser,
+    JsonToJsonlConverter,
+)
+
 
 # add src to the pythonpath so we don't need to pip install this
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 src_dir = os.path.join(project_root, "src")
 sys.path.insert(0, src_dir)
 
-from axolotl.convert import *
-
 
 def main(
-    file: Path,
+    file: Path,
     output: Optional[Path] = None,
     to_stdout: Optional[bool] = False,
 ):
+    """
+    Convert a json file to jsonl
+    """
+
     file_reader = FileReader()
     if to_stdout or output is None:
         writer = StdoutWriter()
@@ -28,7 +44,7 @@ def main(
 
     converter = JsonToJsonlConverter(file_reader, writer, json_parser, jsonl_serializer)
 
-    converter.convert(
+    converter.convert(file, output)
 
 
 if __name__ == "__main__":
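With the star import replaced by an explicit list, the converter's moving parts are visible at a glance. A minimal sketch of how they compose, mirroring main() above; the no-argument constructors for JsonParser and JsonlSerializer are assumptions, since the diff never shows them being built, and the input path is illustrative:

from pathlib import Path

from axolotl.convert import (
    FileReader,
    StdoutWriter,
    JsonParser,
    JsonlSerializer,
    JsonToJsonlConverter,
)

# Read an alpaca-style json file, parse it, and emit one jsonl line per
# record; passing output=None selects StdoutWriter, as in main() above.
converter = JsonToJsonlConverter(
    FileReader(), StdoutWriter(), JsonParser(), JsonlSerializer()
)
converter.convert(Path("alpaca_data.json"), None)  # hypothetical input file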
scripts/finetune.py
@@ -1,3 +1,5 @@
+"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
+
 import importlib
 import logging
 import os
@@ -16,15 +18,16 @@ from axolotl.utils.tokenization import check_dataset_labels
 from axolotl.utils.validation import validate_config
 from axolotl.utils.dict import DictDefault
 
-project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
-src_dir = os.path.join(project_root, "src")
-sys.path.insert(0, src_dir)
-
 from axolotl.utils.data import load_prepare_datasets
 from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.trainer import setup_trainer
 from axolotl.utils.wandb import setup_wandb_env_vars
 
+project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+src_dir = os.path.join(project_root, "src")
+sys.path.insert(0, src_dir)
+
+
 logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
 DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
 
@@ -37,7 +40,7 @@ def choose_device(cfg):
     try:
         if torch.backends.mps.is_available():
             return "mps"
-    except:
+    except Exception:  # pylint: disable=broad-exception-caught
        return "cpu"
 
    cfg.device = get_device()
@@ -73,7 +76,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
 
     model.eval()
     with torch.no_grad():
-        # gc = GenerationConfig()  # TODO swap out and use this
+        # gc = GenerationConfig()  # TODO swap out and use this # pylint: disable=fixme
         generated = model.generate(
             inputs=batch["input_ids"].to(cfg.device),
             do_sample=True,
@@ -130,12 +133,12 @@ def train(
     config = choose_config(config)
 
     # load the config from the yaml file
-    with open(config, "
-    cfg: DictDefault = DictDefault(yaml.load(
+    with open(config, encoding="utf-8") as file:
+        cfg: DictDefault = DictDefault(yaml.load(file, Loader=yaml.Loader))
     # if there are any options passed in the cli, if it is something that seems valid from the yaml,
     # then overwrite the value
     cfg_keys = cfg.keys()
-    for k in kwargs:
+    for k, _ in kwargs.items():
         # if not strict, allow writing to cfg even if it's not in the yml already
         if k in cfg_keys or cfg.strict is False:
             # handle booleans
@@ -167,13 +170,11 @@ def train(
 
     # load the tokenizer first
     logging.info("loading tokenizer...")
-    tokenizer = load_tokenizer(
-        cfg.base_model_config,
-        cfg.tokenizer_type,
-        cfg
-    )
+    tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg)
 
-    if check_not_in(
+    if check_not_in(
+        ["inference", "shard", "merge_lora"], kwargs
+    ):  # don't need to load dataset for these
         train_dataset, eval_dataset = load_prepare_datasets(
             tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
         )
@@ -262,10 +263,13 @@ def train(
 
     logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
 
+    # pylint: disable=fixme
     # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
     # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
     if cfg.local_rank == 0:
         model.save_pretrained(cfg.output_dir)
+
+    # pylint: disable=fixme
     # trainer.save_model(cfg.output_dir)  # TODO this may be needed for deepspeed to work? need to review another time
 
 
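The config hunk pins down a specific loading pattern: open the YAML with an explicit encoding, wrap the parsed mapping in DictDefault, then let CLI kwargs override keys that already exist (or any key when cfg.strict is False). A standalone sketch of that pattern; the path and the override value are illustrative stand-ins, and the real loop also handles boolean coercion:

import yaml

from axolotl.utils.dict import DictDefault

with open("examples/falcon/config-7b.yml", encoding="utf-8") as file:
    cfg = DictDefault(yaml.load(file, Loader=yaml.Loader))

kwargs = {"micro_batch_size": 2}  # stand-in for fire's CLI kwargs
for k, _ in kwargs.items():
    # only write keys the yaml already defines, unless strict is disabled
    if k in cfg.keys() or cfg.strict is False:
        cfg[k] = kwargs[k]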
src/axolotl/datasets.py
@@ -82,10 +82,8 @@ class ConstantLengthDataset(IterableDataset):
             else:
                 example_len = 0
 
-            if not example_len or (
-                not example_len
-                or buffer_len + int(add_concat_token) + example_len
-                > self.seq_length
+            if not example_len or (
+                buffer_len + int(add_concat_token) + example_len > self.seq_length
             ):
                 if buffer["input_ids"]:
                     input_ids = torch.cat(buffer["input_ids"], dim=-1)[
@@ -95,9 +93,8 @@ class ConstantLengthDataset(IterableDataset):
                         : self.seq_length
                     ]
                     labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
-                    if (
-                        labels.size() == input_ids.size()
-                        and attention_mask.size() == input_ids.size()
+                    if labels.size() == input_ids.size() and (
+                        attention_mask.size() == input_ids.size()
                     ):
                         yield {
                             "input_ids": input_ids,
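The two rewritten conditionals change layout, not logic: the buffer is flushed when the incoming example is empty or would overflow seq_length, and a packed sample is only yielded when input_ids, labels, and attention_mask agree in size. A toy evaluation of the flush condition with made-up numbers:

seq_length = 2048
buffer_len = 2040        # tokens already packed into the buffer
example_len = 16         # tokens in the incoming example
add_concat_token = True  # one separator token between examples

# same expression as the hunk above, evaluated outside the class
flush = not example_len or (
    buffer_len + int(add_concat_token) + example_len > seq_length
)
assert flush  # 2040 + 1 + 16 > 2048, so the buffer is emitted first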
src/axolotl/utils/data.py
@@ -1,14 +1,12 @@
 import logging
 from hashlib import md5
 from pathlib import Path
-from typing import Union
+from typing import Tuple, Union
 
 from datasets import (
     load_from_disk,
     load_dataset,
-    IterableDataset,
     Dataset,
-    concatenate_datasets,
     DatasetDict,
 )
 from huggingface_hub import hf_hub_download
@@ -48,10 +46,12 @@ def load_tokenized_prepared_datasets(
         md5(
             (
                 str(cfg.sequence_len)
-                + "@"
-                + "|".join(
-                    sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
-                ) + "|" + tokenizer_name
+                + "@"  # noqa: W503
+                + "|".join(  # noqa: W503
+                    sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
+                )
+                + "|"  # noqa: W503
+                + tokenizer_name  # noqa: W503
             ).encode("utf-8")
         ).hexdigest()
     )
@@ -68,7 +68,7 @@ def load_tokenized_prepared_datasets(
                 f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
             )
             dataset = dataset["train"]
-        except:
+        except Exception:  # pylint: disable=broad-except
             pass
 
     if dataset:
@@ -109,15 +109,21 @@ def load_tokenized_prepared_datasets(
                 fp = hf_hub_download(
                     repo_id=d.path, repo_type="dataset", filename=d.data_files
                 )
-                ds: Dataset = load_dataset(
+                ds: Dataset = load_dataset(
+                    "json", data_files=fp, streaming=False, split=None
+                )
             if not ds:
-                raise
+                raise ValueError("unhandled dataset load")
             # support for using a subset of the data
             if d.shards:
                 if "train" in ds:
-                    ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(
+                    ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(
+                        num_shards=d.shards, index=0
+                    )
                 else:
-                    ds: Dataset = ds.shuffle(seed=42).shard(
+                    ds: Dataset = ds.shuffle(seed=42).shard(
+                        num_shards=d.shards, index=0
+                    )
             d_type = d.type
             d_type_split = d_type.split(":")
             d_base_type = d_type_split[0]
@@ -243,7 +249,7 @@ def load_tokenized_prepared_datasets(
 
 def load_prepare_datasets(
     tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
-) ->
+) -> Tuple[Dataset, Dataset]:
     max_packed_sequence_len = (
         cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
     )
@@ -259,12 +265,14 @@ def load_prepare_datasets(
         md5(
             (
                 str(cfg.sequence_len)
-                + "@"
-                + str(max_packed_sequence_len)
-                + seed
-                + "|".join(
-                    sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
-                ) + "|" + tokenizer_name
+                + "@"  # noqa: W503
+                + str(max_packed_sequence_len)  # noqa: W503
+                + seed  # noqa: W503
+                + "|".join(  # noqa: W503
+                    sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
+                )
+                + "|"  # noqa: W503
+                + tokenizer_name  # noqa: W503
             ).encode("utf-8")
         ).hexdigest()
     )
@@ -285,7 +293,7 @@ def load_prepare_datasets(
                 f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
             )
             dataset = dataset["train"]
-        except:
+        except Exception:  # pylint: disable=broad-except
             pass
 
     if dataset:
@@ -327,9 +335,9 @@ def load_prepare_datasets(
                 d
                 for d in dataset
                 if len(d["input_ids"]) < cfg.sequence_len
-                and len(d["input_ids"]) > 0
-                and len(d["input_ids"]) == len(d["attention_mask"])
-                and len(d["input_ids"]) == len(d["labels"])
+                and len(d["input_ids"]) > 0  # noqa: W503
+                and len(d["input_ids"]) == len(d["attention_mask"])  # noqa: W503
+                and len(d["input_ids"]) == len(d["labels"])  # noqa: W503
             ]
         )
 
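Both hash hunks reflow the same cache-key computation: the fingerprint for a prepared dataset covers the sequence length, a sorted path:type:shards listing of the configured datasets, and the tokenizer name, so changing any of these invalidates the cache. A standalone sketch with illustrative values in place of the cfg fields:

from hashlib import md5

sequence_len = 2048
tokenizer_name = "huggyllama/llama-7b"  # stand-in for the configured tokenizer
datasets = [  # stand-ins for cfg.datasets entries
    {"path": "tatsu-lab/alpaca", "type": "alpaca", "shards": None},
]

ds_hash = str(
    md5(
        (
            str(sequence_len)
            + "@"
            + "|".join(
                sorted(f"{d['path']}:{d['type']}:{d['shards']}" for d in datasets)
            )
            + "|"
            + tokenizer_name
        ).encode("utf-8")
    ).hexdigest()
)
print(ds_hash)  # stable across runs for an identical configuration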
tests/test_prompters.py
@@ -12,7 +12,9 @@ class AlpacaPrompterTest(unittest.TestCase):
 
     def test_prompt_style_w_instruct(self):
         prompter = AlpacaPrompter(prompt_style=PromptStyle.instruct.value)
-        res = next(
+        res = next(
+            prompter.build_prompt("tell me a joke about the following", "alpacas")
+        )
         assert "Below is an instruction" in res
         assert "### Instruction:" in res
         assert "### Input:" in res
@@ -30,7 +32,9 @@ class AlpacaPrompterTest(unittest.TestCase):
 
     def test_prompt_style_w_chat(self):
         prompter = AlpacaPrompter(prompt_style=PromptStyle.chat.value)
-        res = next(
+        res = next(
+            prompter.build_prompt("tell me a joke about the following", "alpacas")
+        )
         assert "Below is an instruction" in res
         assert "### Instruction:" not in res
         assert "### Input:" not in res
@@ -45,5 +49,3 @@ class AlpacaPrompterTest(unittest.TestCase):
         assert "### Response:" not in res
         assert "USER:" in res
         assert "ASSISTANT:" in res
-
-