Pretrain multipack v2 (#1470)
requirements.txt CHANGED

@@ -40,3 +40,4 @@ gcsfs
 # adlfs

 trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
+zstandard==0.22.0
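The zstandard pin presumably lets zstd-compressed pretraining shards (for example .jsonl.zst files) be streamed; a minimal decompression sketch under that assumption, with a hypothetical path, not part of the PR:

import io
import json
import zstandard

# Stream records out of a zstd-compressed JSONL shard (path is hypothetical).
def read_jsonl_zst(path):
    with open(path, "rb") as handle:
        reader = zstandard.ZstdDecompressor().stream_reader(handle)
        for line in io.TextIOWrapper(reader, encoding="utf-8"):
            yield json.loads(line)

for record in read_jsonl_zst("data/shard-00000.jsonl.zst"):  # hypothetical shard
    print(record)
    break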
src/axolotl/utils/collators.py CHANGED

@@ -217,13 +217,24 @@ class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     Collator for multipack specific to the using the BatchSampler
     """

+    def __init__(self, *args, multipack_attn=True, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.multipack_attn = multipack_attn
+
     def __call__(self, features, return_tensors=None):
         chunked_data = {}
         for feature in features.keys():
             if feature == "length":
                 continue
             if feature == "attention_mask":
-                arrays = [(1) * np.array(item) for item in features[feature]]
+                if self.multipack_attn:
+                    arrays = [
+                        (i + 1) * np.array(item[feature])
+                        for i, item in enumerate(features[feature])
+                        if feature in item
+                    ]
+                else:
+                    arrays = [(1) * np.array(item) for item in features[feature]]
                 chunked_data[feature] = np.concatenate(arrays)
             else:
                 arrays = [np.array(item) for item in features[feature]]
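For intuition, the new flag only changes how the per-sample attention masks are merged: with multipack_attn on, each packed sample keeps a distinct integer id in the concatenated mask, so a packing-aware attention path can keep samples from attending across boundaries; with it off, the mask collapses to plain ones. A minimal sketch of that arithmetic (hypothetical mask data, not the collator's actual inputs):

import numpy as np

# Three packed samples with hypothetical per-sample attention masks.
masks = [[1, 1, 1], [1, 1], [1, 1, 1, 1]]

# multipack_attn=True: number each sample so boundaries survive concatenation.
packed = np.concatenate([(i + 1) * np.array(m) for i, m in enumerate(masks)])
print(packed)  # [1 1 1 2 2 3 3 3 3]

# multipack_attn=False: plain ones, packed samples are not distinguished.
flat = np.concatenate([(1) * np.array(m) for m in masks])
print(flat)  # [1 1 1 1 1 1 1 1 1]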
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED

@@ -511,6 +511,14 @@ class AxolotlInputConfig(
     eval_sample_packing: Optional[bool] = None
     pad_to_sequence_len: Optional[bool] = None

+    pretrain_multipack_buffer_size: Optional[int] = 10_000
+    pretrain_multipack_attn: Optional[bool] = Field(
+        default=True,
+        metadata={
+            "help": "whether to prevent cross attention for packed sequences during pretraining",
+        },
+    )
+
     xformers_attention: Optional[bool] = None
     sdp_attention: Optional[bool] = None
     s2_attention: Optional[bool] = None
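To make the defaults concrete, here is a small stand-in for the two new fields (a plain dataclass, not the real pydantic model) together with the fallback that data.py applies when the buffer size is left unset:

from dataclasses import dataclass
from typing import Optional

# Hypothetical stand-in for the two new AxolotlInputConfig fields.
@dataclass
class PretrainMultipackOptions:
    pretrain_multipack_buffer_size: Optional[int] = 10_000
    # True (the default) keeps packed samples from attending to each other.
    pretrain_multipack_attn: Optional[bool] = True

cfg = PretrainMultipackOptions(pretrain_multipack_buffer_size=None)
buffer_size = cfg.pretrain_multipack_buffer_size or 10_000  # same fallback as data.py
print(buffer_size, cfg.pretrain_multipack_attn)  # 10000 True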
src/axolotl/utils/data.py CHANGED

@@ -108,6 +108,7 @@ def prepare_dataset(cfg, tokenizer):
             max_tokens=cfg.sequence_len,
             batch_size=cfg.micro_batch_size,
             seed=cfg.seed or 42,
+            buffer_size=cfg.pretrain_multipack_buffer_size or 10_000,
         )
         # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
         train_dataset = train_dataset.with_format("torch")

@@ -816,6 +817,7 @@ def wrap_pretraining_dataset(
             return_tensors="pt",
             padding=True,
             pad_to_multiple_of=max_tokens * batch_size,
+            multipack_attn=cfg.pretrain_multipack_attn,
         )
         encode = functools.partial(
             encode_packed_pretraining,

@@ -823,6 +825,7 @@ def wrap_pretraining_dataset(
             ds_wrapper_fn,
             max_seq_length=max_tokens,
             batch_size=batch_size,
+            multipack_attn=cfg.pretrain_multipack_attn,
         )
         # set this to 1 so downstream data_loader doesn't try to increase the batch again
         cfg.micro_batch_size = 1

@@ -861,6 +864,7 @@ def encode_packed_pretraining(
     examples: Dict[str, List],
     max_seq_length: int = 2048,
     batch_size: int = 4,
+    multipack_attn: Optional[bool] = False,
 ) -> Dict[str, List]:
     # pylint: disable=duplicate-code
     # tokenize all the examples

@@ -868,7 +872,9 @@ def encode_packed_pretraining(
     train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]

     train_dataset = process_pretraining_datasets_for_packing(
-        train_dataset, max_seq_length
+        train_dataset,
+        max_seq_length,
+        skip_position_ids=not multipack_attn,
     )

     sampler = MultipackBatchSampler(
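The two settings are threaded from the config into the packing pipeline as keyword arguments, and multipack_attn is additionally negated into skip_position_ids when the streamed batch is preprocessed. A minimal sketch of that functools.partial threading, using a stand-in encode function rather than the real encode_packed_pretraining:

import functools

# Stand-in for encode_packed_pretraining: only the keyword threading is shown.
def encode(examples, max_seq_length=2048, batch_size=4, multipack_attn=False):
    # mirrors skip_position_ids=not multipack_attn in the real call
    return {"n_examples": len(examples), "skip_position_ids": not multipack_attn}

pretrain_multipack_attn = True  # assumed config value; the new field defaults to True

encode_fn = functools.partial(
    encode,
    max_seq_length=2048,
    batch_size=4,
    multipack_attn=pretrain_multipack_attn,
)
print(encode_fn(["sample a", "sample b"]))  # {'n_examples': 2, 'skip_position_ids': False}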
src/axolotl/utils/trainer.py CHANGED

@@ -172,17 +172,21 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
     return train_dataset, eval_dataset


-def process_pretraining_datasets_for_packing(train_dataset, sequence_len):
+def process_pretraining_datasets_for_packing(
+    train_dataset, sequence_len, skip_position_ids=True
+):
     drop_long = partial(drop_long_seq, sequence_len=sequence_len)

     train_dataset = train_dataset.filter(
         drop_long,
         desc="Dropping Long Sequences",
     )
-    train_dataset = train_dataset.map(
-        add_position_ids,
-        desc="Add position_id column (Pretraining Sample Packing)",
-    )
+    if skip_position_ids:
+        train_dataset = train_dataset.map(
+            add_position_ids,
+            desc="Add position_id column (Pretraining Sample Packing)",
+        )
+
     return train_dataset


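Position ids matter for packing because each packed sample should keep positions that restart at zero rather than continuing across sample boundaries. A hypothetical stand-in for what the position_ids column looks like (not the repository's actual add_position_ids implementation):

import numpy as np

# Hypothetical stand-in: give each sample positions restarting at 0,
# so packed samples keep their own position numbering.
def add_position_ids(sample):
    sample["position_ids"] = np.arange(len(sample["input_ids"]))
    return sample

packed_samples = [{"input_ids": [5, 6, 7]}, {"input_ids": [8, 9]}]
for sample in packed_samples:
    add_position_ids(sample)
print([sample["position_ids"].tolist() for sample in packed_samples])  # [[0, 1, 2], [0, 1]]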