Pretrain transforms (#1261)
* wip for pretraining/iterable data with arbitrary prompt strategies
* more fixes, wip
* more fixes for custom pretraining
* iterable ds wrapper not needed
* remove extra features
* chore: lint
* update pretraining example yml
* fix order for partials
* fixup for tests

- examples/tiny-llama/pretrain.yml +1 -0
- src/axolotl/datasets.py +1 -1
- src/axolotl/prompt_strategies/pretrain.py +58 -0
- src/axolotl/utils/data.py +49 -36
- tests/test_packed_pretraining.py +36 -25
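
Taken together, the changes below route streaming pretraining data through the same prompt-strategy machinery used for regular datasets: a partial of get_dataset_wrapper is handed to wrap_pretraining_dataset, which shuffles the stream, tokenizes it with the selected strategy, and optionally packs it. A hedged, simplified sketch of that flow (build_pretraining_dataset is a hypothetical helper name; the real call site is prepare_dataset in src/axolotl/utils/data.py):

# Hedged sketch of the new pretraining data flow; simplified, not the literal call site.
import functools

from datasets import load_dataset

from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset


def build_pretraining_dataset(cfg, tokenizer):  # hypothetical wrapper around prepare_dataset's logic
    ds_cfg = cfg.pretraining_dataset[0]  # e.g. {"path": "c4", "name": "en", "type": "pretrain"}
    # Partial that can wrap a small in-memory Dataset with the configured prompt strategy.
    ds_wrapper_partial = functools.partial(
        get_dataset_wrapper,
        ds_cfg,
        tokenizer,
        cfg,
        ds_cfg["type"] or "pretrain",
    )
    # Stream the raw corpus; wrap_pretraining_dataset shuffles, tokenizes and (optionally) packs it.
    streamed = load_dataset(ds_cfg["path"], name=ds_cfg["name"], split="train", streaming=True)
    return wrap_pretraining_dataset(
        streamed,
        tokenizer,
        cfg,
        ds_wrapper_partial,
        max_tokens=cfg.sequence_len,
        batch_size=cfg.micro_batch_size,
        seed=cfg.seed or 42,
    )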
examples/tiny-llama/pretrain.yml
CHANGED

@@ -12,6 +12,7 @@ max_steps: 200
 pretraining_dataset:
   path: c4
   name: en
+  type: pretrain
 dataset_prepared_path:
 val_set_size: 0.0
 output_dir: ./model-out
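The new type key selects which prompt strategy tokenizes the streamed corpus; when it is omitted, the code in src/axolotl/utils/data.py below falls back to "pretrain". A hedged illustration of that fallback, assuming a DictDefault-style config (as used in the test at the bottom of this diff) that returns None for missing keys:

# Hedged illustration of the strategy-type fallback used in utils/data.py.
from axolotl.utils.dict import DictDefault

cfg = DictDefault({"pretraining_dataset": [{"path": "c4", "name": "en"}]})  # no "type" set
strategy_type = cfg.pretraining_dataset[0]["type"] or "pretrain"  # missing key -> None -> "pretrain"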
src/axolotl/datasets.py
CHANGED

@@ -31,7 +31,7 @@ class TokenizedPromptDataset(Dataset):
     def __init__(  # pylint: disable=super-init-not-called
         self,
         prompt_tokenizer: PromptTokenizingStrategy,
-        dataset: …,
+        dataset: Dataset,
         process_count: Optional[int] = None,
         keep_in_memory: Optional[bool] = False,
         **kwargs,
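The annotation change reflects that TokenizedPromptDataset is fed plain in-memory datasets.Dataset objects; the streaming path builds such datasets batch by batch in encode_packed_pretraining below. A minimal hedged usage sketch, where the tokenizer, config values, and sample text are illustrative and the strategy classes come from the new module in the next file:

# Hedged usage sketch for the updated signature; values and text are illustrative.
from datasets import Dataset
from transformers import AutoTokenizer

from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.pretrain import PretrainTokenizationStrategy, PretrainTokenizer

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")  # model choice is illustrative
tokenizer.pad_token = "</s>"
strategy = PretrainTokenizationStrategy(
    PretrainTokenizer(), tokenizer, False, 2048, max_length=2048 * 64
)

raw = Dataset.from_dict({"text": ["some raw pretraining text ..."]})
tokenized = TokenizedPromptDataset(strategy, raw)  # the dataset argument is a plain datasets.Dataset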
src/axolotl/prompt_strategies/pretrain.py
ADDED

@@ -0,0 +1,58 @@
+"""pretraining prompt strategies"""
+from typing import Generator
+
+from transformers import BatchEncoding
+
+from axolotl.prompt_tokenizers import PromptTokenizingStrategy
+
+
+class PretrainTokenizer:
+    """basic tokenization class for pretraining"""
+
+    def build_prompt(self, prompt) -> Generator[str, None, None]:
+        yield prompt
+
+
+class PretrainTokenizationStrategy(PromptTokenizingStrategy):
+    """handles tokenization for pretraining with strides"""
+
+    @property
+    def supports_batched(self):
+        return True
+
+    def __init__(self, *args, max_length=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        if max_length:
+            self.max_length = max_length
+
+    def _tokenize(
+        self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
+    ) -> BatchEncoding:
+        res = self.tokenizer(
+            prompt,
+            truncation=True,
+            max_length=self.max_length - 1,
+            add_special_tokens=True,
+            return_overflowing_tokens=True,
+            stride=256,
+        )
+        res["input_ids"] = [
+            seq + [self.tokenizer.eos_token_id] for seq in res["input_ids"]
+        ]
+        res["attention_mask"] = [seq + [1] for seq in res["attention_mask"]]
+
+        return res
+
+    def tokenize_prompt(self, prompt):
+        return self._tokenize(prompt["text"])
+
+
+def load(tokenizer, cfg):
+    strat = PretrainTokenizationStrategy(
+        PretrainTokenizer(),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+        max_length=cfg.sequence_len * 64,
+    )
+    return strat
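For reference, a hedged, self-contained sketch of what this strategy does on its own; the tokenizer, config values, and sample text are illustrative, and in axolotl the module is normally resolved via the dataset's type field rather than imported directly:

# Hedged sketch of the stride-based pretraining tokenization in isolation.
from transformers import AutoTokenizer

from axolotl.prompt_strategies.pretrain import load as load_pretrain_strategy
from axolotl.utils.dict import DictDefault

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")  # model choice is illustrative
tokenizer.pad_token = "</s>"
cfg = DictDefault({"train_on_inputs": False, "sequence_len": 2048})  # minimal cfg for load()

strategy = load_pretrain_strategy(tokenizer, cfg)
encoded = strategy.tokenize_prompt({"text": ["a very long pretraining document ..."]})
# Long documents are split into overlapping windows (stride=256) of at most
# cfg.sequence_len * 64 tokens, and every window gets an EOS token appended;
# overflow bookkeeping columns from the tokenizer ride along and are dropped later,
# in encode_packed_pretraining.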
src/axolotl/utils/data.py
CHANGED

@@ -4,7 +4,7 @@ import hashlib
 import logging
 from collections import defaultdict
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 import yaml

@@ -88,12 +88,21 @@ def prepare_dataset(cfg, tokenizer):
         path = cfg.pretraining_dataset[0]["path"]
         name = cfg.pretraining_dataset[0]["name"]

-        …
-        …
+        ds_wrapper_partial = functools.partial(
+            get_dataset_wrapper,
+            cfg.pretraining_dataset[0],
             tokenizer,
             cfg,
-            …
+            cfg.pretraining_dataset[0]["type"] or "pretrain",
+        )
+
+        train_dataset = wrap_pretraining_dataset(
+            load_dataset(path, streaming=True, split="train", name=name),
+            tokenizer,
+            cfg,
+            ds_wrapper_partial,
             max_tokens=cfg.sequence_len,
+            batch_size=cfg.micro_batch_size,
             seed=cfg.seed or 42,
         )
         # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230

@@ -383,9 +392,9 @@ def load_tokenized_prepared_datasets(

             dataset_wrapper, dataset_prompter = get_dataset_wrapper(
                 config_dataset=config_dataset,
-                dataset=ds,
                 tokenizer=tokenizer,
                 cfg=cfg,
+                dataset=ds,
                 d_base_type=d_base_type,
                 d_prompt_style=d_prompt_style,
             )

@@ -496,7 +505,12 @@ def load_prepare_datasets(


 def get_dataset_wrapper(
-    config_dataset, …
+    config_dataset,
+    tokenizer,
+    cfg,
+    d_base_type,
+    dataset,
+    d_prompt_style=None,
 ):
     dataset_wrapper = None
     dataset_prompter = None

@@ -507,7 +521,8 @@
     }

     if (
-        …
+        isinstance(dataset, Dataset)
+        and "input_ids" in dataset.features
         and "attention_mask" in dataset.features
         and "labels" in dataset.features
     ):

@@ -765,69 +780,60 @@ def encode_pretraining(
     return ret


-def …
+def wrap_pretraining_dataset(
+    dataset,
+    tokenizer,
+    cfg,
+    ds_wrapper_fn,
+    max_tokens=2048,
+    batch_size=1,
+    seed=42,
+    buffer_size=10_000,
+):
     if cfg.sample_packing:
         collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
             tokenizer,
             return_tensors="pt",
             padding=True,
-            pad_to_multiple_of=max_tokens * …
+            pad_to_multiple_of=max_tokens * batch_size,
         )
         encode = functools.partial(
             encode_packed_pretraining,
-            tokenizer,
             collate_fn,
+            ds_wrapper_fn,
             max_seq_length=max_tokens,
-            batch_size=…
+            batch_size=batch_size,
         )
         # set this to 1 so downstream data_loader doesn't try to increase the batch again
         cfg.micro_batch_size = 1
     else:
         encode = functools.partial(encode_pretraining, tokenizer, max_tokens)

-    dataset = …
-    dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
+    dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
     dataset = dataset.map(
         encode,
         batched=True,
-        batch_size=…
-        input_columns="text",
+        batch_size=buffer_size,
+        # input_columns="text",
         # remove all the existing columns after mapping since they end up having
         # a different length than the encoded/tokenized column
         remove_columns=dataset.features.keys(),
-        desc="Encoding Pretraining",
     )
     return dataset


 def encode_packed_pretraining(
-    tokenizer: PreTrainedTokenizerBase,
     collate_fn,
-    …
+    ds_wrapper: Callable,
+    examples: Dict[str, List],
     max_seq_length: int = 2048,
     batch_size: int = 4,
 ) -> Dict[str, List]:
     # pylint: disable=duplicate-code
     # tokenize all the examples
     # rows get split with stride (overlap)
-    …
-        examples,
-        truncation=True,
-        max_length=max_seq_length - 1,
-        add_special_tokens=True,
-        return_overflowing_tokens=True,
-        stride=256,
-    )
-
-    input_ids = [seq + [tokenizer.eos_token_id] for seq in res["input_ids"]]
-    attention_mask = [seq + [1] for seq in res["attention_mask"]]
-
-    tokenized_examples = {
-        "input_ids": input_ids,
-        "attention_mask": attention_mask,
-    }
+    train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]

-    train_dataset = Dataset.from_dict(tokenized_examples)
     train_dataset = process_pretraining_datasets_for_packing(
         train_dataset, max_seq_length
     )

@@ -845,7 +851,14 @@ def encode_packed_pretraining(
     for batch in sampler:
         for data in batch:
             features = train_dataset[data]
-            …
+            if "num_truncated_tokens" in features:
+                del features["num_truncated_tokens"]
+            if "num_truncated_tokens" in features:
+                del features["num_truncated_tokens"]
+            if "overflow_to_sample_mapping" in features:
+                del features["overflow_to_sample_mapping"]
+            if "labels" not in features:
+                features["labels"] = features["input_ids"].copy()
             collated_features = collate_fn(features)

             for feature in features.keys():
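wrap_pretraining_dataset leans on one pattern: a batched map() over the shuffled stream whose original columns are removed, because each raw row can fan out into several encoded rows (overflowing windows, packed sequences). A hedged toy demonstration of that pattern, with invented data and an invented encode function standing in for encode_packed_pretraining:

# Hedged toy demo of the batched-map pattern wrap_pretraining_dataset relies on.
# The data and encode function are invented; they only mimic the "one raw row fans out
# into several encoded rows" behavior that forces remove_columns on the raw columns.
from datasets import Dataset


def toy_encode(examples):
    # one raw "text" row can become several fixed-size rows, like the stride tokenization does
    return {"input_ids": [word for text in examples["text"] for word in text.split()]}


raw = Dataset.from_dict({"text": ["aaa bbb ccc", "ddd eee", "fff"]}).to_iterable_dataset()
encoded = raw.map(
    toy_encode,
    batched=True,
    batch_size=2,  # wrap_pretraining_dataset passes buffer_size here
    remove_columns=["text"],  # raw columns no longer line up 1:1 with the encoded rows
)
print(list(encoded))  # [{'input_ids': 'aaa'}, {'input_ids': 'bbb'}, ...]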
tests/test_packed_pretraining.py
CHANGED

@@ -1,14 +1,14 @@
 """Module for testing streaming dataset sequence packing"""
+import functools
 import unittest
-from functools import partial

 import torch
 from datasets import load_dataset
 from torch.utils.data import DataLoader
 from transformers import AutoTokenizer

-from axolotl.utils.…
-from axolotl.utils.…
+from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset
+from axolotl.utils.dict import DictDefault


 class TestPretrainingPacking(unittest.TestCase):

@@ -20,8 +20,6 @@ class TestPretrainingPacking(unittest.TestCase):
         # pylint: disable=duplicate-code
         self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
         self.tokenizer.pad_token = "</s>"
-        self.max_seq_length = 2048
-        self.batch_size = 2

     def test_packing_stream_dataset(self):
         # pylint: disable=duplicate-code

@@ -31,30 +29,43 @@
             streaming=True,
         )["train"]

-        …
-        …
-        …
-        …
-        …
+        cfg = DictDefault(
+            {
+                "pretraining_dataset": [
+                    {
+                        "path": "c4",
+                        "name": "en",
+                        "type": "pretrain",
+                    }
+                ],
+                "sample_packing": True,
+                "pad_to_sequence_len": True,
+                "sequence_len": 2048,
+                "micro_batch_size": 2,
+            }
         )

-        …
-        …
+        ds_wrapper_partial = functools.partial(
+            get_dataset_wrapper,
+            cfg.pretraining_dataset[0],
             self.tokenizer,
-            …
-            …
-            batch_size=self.batch_size,
+            cfg,
+            cfg.pretraining_dataset[0]["type"] or "pretrain",
         )

-        …
-        …
-        …
-        …
-        …
+        original_bsz = cfg.micro_batch_size
+        train_dataset = wrap_pretraining_dataset(
+            dataset,
+            self.tokenizer,
+            cfg,
+            ds_wrapper_partial,
+            max_tokens=cfg.sequence_len,
+            batch_size=cfg.micro_batch_size,
+            seed=cfg.seed or 42,
         )

         trainer_loader = DataLoader(
-            …
+            train_dataset,
             batch_size=1,
             collate_fn=None,
             drop_last=True,

@@ -64,16 +75,16 @@ class TestPretrainingPacking(unittest.TestCase):
             if idx > 10:
                 break
             assert data["input_ids"].shape == torch.Size(
-                [1, …
+                [1, original_bsz * cfg.sequence_len]
             )
             assert data["position_ids"].shape == torch.Size(
-                [1, …
+                [1, original_bsz * cfg.sequence_len]
             )
             assert data["labels"].shape == torch.Size(
-                [1, …
+                [1, original_bsz * cfg.sequence_len]
             )
             assert data["attention_mask"].shape == torch.Size(
-                [1, …
+                [1, original_bsz * cfg.sequence_len]
             )
             idx += 1
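The assertions encode the packing invariant: wrap_pretraining_dataset resets micro_batch_size to 1, so each DataLoader item is a single packed row whose width is the original micro batch size times the sequence length. A quick worked check with the values from the test cfg (the test itself can be run with pytest tests/test_packed_pretraining.py; note it streams c4 and therefore needs network access):

# Worked check of the expected packed width, using the values from the test cfg above.
original_bsz = 2       # cfg.micro_batch_size before wrap_pretraining_dataset resets it to 1
sequence_len = 2048    # cfg.sequence_len
assert original_bsz * sequence_len == 4096  # each tensor asserted above is [1, 4096]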