casperhansen committed
Commit 05bd6f1 (unverified) · Parent: 20aa4b5

Threaded MultipackDistributedDataloader with prefetched samples (#759)


* Multithreading implementation [WIP]

* Added benchmarking

* 35% increase in throughput

* Memory pinning

* Start threads in init

* Correct print of samples

* Sleep if queue is full

* Remove pin_memory (worse)

* Simplify logic to one thread

* Remove benchmark

* Use deque for constant speed

* Formatting

* Formatting

* Formatting

* Formatting

* Rollback to use queue

* Fix multi-epoch training

* Add num epochs arg

* Start thread in __iter__

* Formatting

* Use is_alive correctly

* Simplify loading thread
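
The commit trail above converges on a classic bounded producer/consumer design: one background thread packs batches ahead of the trainer, a size-limited queue caps prefetched memory, and a sentinel value marks epoch boundaries. Below is a minimal, self-contained sketch of that pattern; the class and parameter names are illustrative, not the PR's actual API.

import time
from queue import Queue
from threading import Thread


class PrefetchingLoader:
    """Hypothetical example of the prefetch pattern used in this PR."""

    def __init__(self, make_batches, num_epochs=1, prefetch_max=1000):
        self.make_batches = make_batches  # callable returning an iterable of batches
        self.num_epochs = num_epochs
        # maxsize bounds how many prefetched batches can sit in memory at once
        self.queue = Queue(maxsize=prefetch_max)
        self.thread = None

    def _worker(self):
        for _ in range(self.num_epochs):
            for batch in self.make_batches():
                self.queue.put(batch)  # blocks when the queue is full
            self.queue.put(None)  # per-epoch sentinel: one __iter__ call drains one epoch

    def __iter__(self):
        if self.thread is None:  # start the single producer thread lazily, exactly once
            self.thread = Thread(target=self._worker, daemon=True)
            self.thread.start()
        while (item := self.queue.get()) is not None:
            yield item


# Usage: each `for` loop over the loader consumes exactly one epoch.
loader = PrefetchingLoader(lambda: iter(range(5)), num_epochs=2, prefetch_max=8)
for epoch in range(2):
    print(f"epoch {epoch}:", list(loader))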

src/axolotl/core/trainer_builder.py CHANGED
@@ -111,7 +111,8 @@ class AxolotlTrainer(Trainer):
 
     args = None  # type: AxolotlTrainingArguments
 
-    def __init__(self, *args, bench_data_collator=None, **kwargs):
+    def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
+        self.num_epochs = num_epochs
         self.bench_data_collator = bench_data_collator
         super().__init__(*args, **kwargs)
 
@@ -182,6 +183,7 @@ class AxolotlTrainer(Trainer):
                     packing_efficiency_estimate=self.args.sample_packing_efficiency,
                     sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
                     device_count=int(os.environ.get("WORLD_SIZE", 1)),
+                    num_epochs=self.num_epochs,
                 )
             )
         return super().get_train_dataloader()
@@ -205,6 +207,7 @@ class AxolotlTrainer(Trainer):
                     packing_efficiency_estimate=self.args.sample_packing_efficiency,
                     sample_packing_seq_len_multiplier=self.args.eval_batch_size,
                     device_count=int(os.environ.get("WORLD_SIZE", 1)),
+                    num_epochs=self.num_epochs,
                 )
             )
         return super().get_eval_dataloader(eval_dataset)
@@ -680,6 +683,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 **data_collator_kwargs,
             ),
             callbacks=self.get_callbacks(),
+            num_epochs=self.cfg.num_epochs,
             **trainer_kwargs,
         )
         trainer = self.hook_post_create_trainer(trainer)
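
The change above threads the epoch count from the user config down to the dataloader: HFCausalTrainerBuilder reads cfg.num_epochs and passes it to AxolotlTrainer.__init__, and both get_train_dataloader and get_eval_dataloader forward it into MultipackDistributedDataloader, which needs it up front to know how many passes its worker thread must produce.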
src/axolotl/utils/dataloader.py CHANGED
@@ -3,6 +3,9 @@ import hashlib
 import itertools
 import logging
 import math
+import time
+from queue import Queue
+from threading import Thread
 from typing import Any, Callable, List, Union
 
 import numba
@@ -149,6 +152,8 @@ class MultipackDistributedDataloader:
         packing_efficiency_estimate: float = 1.0,
         sample_packing_seq_len_multiplier: int = 1,
         device_count: int = 1,
+        prefetch_max: int = 1000,
+        num_epochs: int = 1,
     ):
         # Dataset
         self.dataset = dataset
@@ -167,6 +172,7 @@ class MultipackDistributedDataloader:
         self.seq_max_length = seq_max_length
         self.batch_max_length = batch_size * seq_max_length
         self.collate_fn = collate_fn
+        self.num_epochs = num_epochs
 
         self.num_replicas = 1
         self.rank = 0
@@ -177,6 +183,44 @@ class MultipackDistributedDataloader:
         self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
         self.device_count = device_count
 
+        # maxsize is the maximum number of samples in the queue
+        self.prefetch_max = prefetch_max
+        self.queue: Queue = Queue(maxsize=prefetch_max)
+        self.thread = None
+
+    def _worker(self):
+        LOG.info(
+            f"[WORKER] Epochs: {self.num_epochs}, Samples: {self.len_w_stats()*self.batch_size}"
+        )
+        for epoch in range(self.num_epochs):
+            for sample in self._internal_batch_generator():
+                while True:
+                    if self.queue.full():
+                        time.sleep(1)
+                    else:
+                        break
+                self.queue.put(sample)
+
+            # stop the queue when epoch is done
+            self.queue.put(None)
+
+    def __iter__(self):
+        if hasattr(self.sampler, "set_epoch"):
+            new_epoch = self.sampler.epoch + 1
+            self.sampler.set_epoch(new_epoch)
+            LOG.info(f"calling sampler.set_epoch({new_epoch})")
+
+        if self.thread is None:
+            self.thread = Thread(target=self._worker, daemon=True)
+            self.thread.start()
+
+        while True:
+            item = self.queue.get()
+
+            if item is None:
+                break
+            yield item
+
     def generate_batches(self, set_stats=False):
         LOG.info("generating packed batches")
         if self.sampler:
@@ -206,11 +250,7 @@ class MultipackDistributedDataloader:
 
         return batches, totseqs
 
-    def __iter__(self):
-        if hasattr(self.sampler, "set_epoch"):
-            new_epoch = self.sampler.epoch + 1
-            self.sampler.set_epoch(new_epoch)
-            LOG.info(f"calling sampler.set_epoch({new_epoch})")
+    def _internal_batch_generator(self):
         all_batches, _ = self.generate_batches(set_stats=True)
         features = self.dataset.features.keys()
         len_remaining = self._len_est()
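
Two details of the loader are worth noting. The worker enqueues a None sentinel after each epoch, so every call to __iter__ drains exactly one epoch before returning; because the thread is created only once (guarded by self.thread is None) and runs all num_epochs passes, repeated iteration by the HF Trainer lines up with the prefetched stream. The sleep-while-full check is conservative: queue.Queue.put already blocks when the queue is at maxsize, so the explicit poll merely throttles the producer before the blocking put.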
src/axolotl/utils/trainer.py CHANGED
@@ -216,6 +216,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
             packing_efficiency_estimate=cfg.sample_packing_eff_est,
             sample_packing_seq_len_multiplier=cfg.micro_batch_size,
             device_count=int(os.environ.get("WORLD_SIZE", 1)),
+            num_epochs=cfg.num_epochs,
         )
         data_loader_len = data_loader.len_w_stats()
         actual_eff = data_loader.efficiency()