Threaded MultipackDistributedDataloader with prefetched samples (#759)
* Multithreading implementation [WIP]
* Added benchmarking
* 35% increased throughput
* Memory pinning
* Start threads in init
* Correct print of samples
* Sleep if queue is full
* Remove pin_memory (worse)
* Simplify logic to one thread
* Remove benchmark
* Use deque for constant speed
* Formatting
* Formatting
* Formatting
* Formatting
* Rollback to use queue
* Fix multi-epoch training
* Add num epochs arg
* Start thread in __iter__
* Formatting
* Use is_alive correctly
* Simplify loading thread
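
In outline, the dataloader change is a single producer/consumer pair: a daemon thread packs samples and pushes them into a bounded Queue, and __iter__ drains the queue until it sees a None sentinel marking the end of the final epoch. A minimal, self-contained sketch of that pattern (illustrative names, not the axolotl classes):

import time
from queue import Queue
from threading import Thread


def prefetched_iter(generate_batches, num_epochs=1, prefetch_max=1000):
    """Illustrative sketch: prefetch batches from a background thread."""
    queue = Queue(maxsize=prefetch_max)

    def worker():
        for _ in range(num_epochs):
            for batch in generate_batches():
                # Back off while the queue is full rather than blocking in put().
                while queue.full():
                    time.sleep(1)
                queue.put(batch)
        # None sentinel: tells the consumer that all epochs have been produced.
        queue.put(None)

    Thread(target=worker, daemon=True).start()

    while True:
        item = queue.get()
        if item is None:
            break
        yield item


# Hypothetical usage with a toy batch source:
# for batch in prefetched_iter(lambda: iter(range(10)), num_epochs=2):
#     train_step(batch)

The bounded queue caps how many prefetched samples sit in memory, and the sleep-based backoff mirrors the "Sleep if queue is full" commit above.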
src/axolotl/core/trainer_builder.py

@@ -111,7 +111,8 @@ class AxolotlTrainer(Trainer):
 
     args = None  # type: AxolotlTrainingArguments
 
-    def __init__(self, *args, bench_data_collator=None, **kwargs):
+    def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
+        self.num_epochs = num_epochs
         self.bench_data_collator = bench_data_collator
         super().__init__(*args, **kwargs)
 
@@ -182,6 +183,7 @@ class AxolotlTrainer(Trainer):
                     packing_efficiency_estimate=self.args.sample_packing_efficiency,
                     sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
                     device_count=int(os.environ.get("WORLD_SIZE", 1)),
+                    num_epochs=self.num_epochs,
                 )
             )
         return super().get_train_dataloader()
@@ -205,6 +207,7 @@ class AxolotlTrainer(Trainer):
                     packing_efficiency_estimate=self.args.sample_packing_efficiency,
                     sample_packing_seq_len_multiplier=self.args.eval_batch_size,
                     device_count=int(os.environ.get("WORLD_SIZE", 1)),
+                    num_epochs=self.num_epochs,
                 )
             )
         return super().get_eval_dataloader(eval_dataset)
@@ -680,6 +683,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 **data_collator_kwargs,
             ),
             callbacks=self.get_callbacks(),
+            num_epochs=self.cfg.num_epochs,
             **trainer_kwargs,
         )
         trainer = self.hook_post_create_trainer(trainer)
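
Taken together, these hunks thread the epoch count from the user config into the dataloader: HFCausalTrainerBuilder passes num_epochs=self.cfg.num_epochs when constructing the trainer, AxolotlTrainer.__init__ stores it, and both get_train_dataloader and get_eval_dataloader forward it to MultipackDistributedDataloader.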
src/axolotl/utils/dataloader.py

@@ -3,6 +3,9 @@ import hashlib
 import itertools
 import logging
 import math
+import time
+from queue import Queue
+from threading import Thread
 from typing import Any, Callable, List, Union
 
 import numba
@@ -149,6 +152,8 @@ class MultipackDistributedDataloader:
         packing_efficiency_estimate: float = 1.0,
         sample_packing_seq_len_multiplier: int = 1,
         device_count: int = 1,
+        prefetch_max: int = 1000,
+        num_epochs: int = 1,
     ):
         # Dataset
         self.dataset = dataset
@@ -167,6 +172,7 @@ class MultipackDistributedDataloader:
         self.seq_max_length = seq_max_length
         self.batch_max_length = batch_size * seq_max_length
         self.collate_fn = collate_fn
+        self.num_epochs = num_epochs
 
         self.num_replicas = 1
         self.rank = 0
@@ -177,6 +183,44 @@ class MultipackDistributedDataloader:
         self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
         self.device_count = device_count
 
+        # maxsize is maximum number of samples in queue
+        self.prefetch_max = prefetch_max
+        self.queue: Queue = Queue(maxsize=prefetch_max)
+        self.thread = None
+
+    def _worker(self):
+        LOG.info(
+            f"[WORKER] Epochs: {self.num_epochs}, Samples: {self.len_w_stats()*self.batch_size}"
+        )
+        for epoch in range(self.num_epochs):
+            for sample in self._internal_batch_generator():
+                while True:
+                    if self.queue.full():
+                        time.sleep(1)
+                    else:
+                        break
+                self.queue.put(sample)
+
+        # stop the queue when epoch is done
+        self.queue.put(None)
+
+    def __iter__(self):
+        if hasattr(self.sampler, "set_epoch"):
+            new_epoch = self.sampler.epoch + 1
+            self.sampler.set_epoch(new_epoch)
+            LOG.info(f"calling sampler.set_epoch({new_epoch})")
+
+        if self.thread is None:
+            self.thread = Thread(target=self._worker, daemon=True)
+            self.thread.start()
+
+        while True:
+            item = self.queue.get()
+
+            if item is None:
+                break
+            yield item
+
     def generate_batches(self, set_stats=False):
         LOG.info("generating packed batches")
         if self.sampler:
@@ -206,11 +250,7 @@ class MultipackDistributedDataloader:
 
         return batches, totseqs
 
-    def __iter__(self):
-        if hasattr(self.sampler, "set_epoch"):
-            new_epoch = self.sampler.epoch + 1
-            self.sampler.set_epoch(new_epoch)
-            LOG.info(f"calling sampler.set_epoch({new_epoch})")
+    def _internal_batch_generator(self):
         all_batches, _ = self.generate_batches(set_stats=True)
         features = self.dataset.features.keys()
         len_remaining = self._len_est()
src/axolotl/utils/trainer.py

@@ -216,6 +216,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
                 packing_efficiency_estimate=cfg.sample_packing_eff_est,
                 sample_packing_seq_len_multiplier=cfg.micro_batch_size,
                 device_count=int(os.environ.get("WORLD_SIZE", 1)),
+                num_epochs=cfg.num_epochs,
             )
             data_loader_len = data_loader.len_w_stats()
             actual_eff = data_loader.efficiency()