Fix `total_num_steps` (#1566)
* Fix `total_num_steps`
* Fix total_num_steps
* lint
- src/axolotl/utils/trainer.py +5 -15
```diff
@@ -330,7 +330,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                         / cfg.sample_packing_eff_est
                         / cfg.sequence_len
                         // cfg.batch_size
-                        // int(os.environ.get("WORLD_SIZE", 1))
                     )
                     - 1
                 )
@@ -359,18 +358,14 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                 train_dataset.remove_columns(["length"]),
                 batch_sampler=sampler,
             )
-            data_loader_len = len(data_loader) // cfg.batch_size
+            data_loader_len = len(data_loader) // (
+                cfg.world_size * cfg.gradient_accumulation_steps
+            )
             actual_eff = sampler.efficiency()
             LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
             # FIXME: is there a bug here somewhere? the total num steps depends
             # on the agreed on value for sample_packing_eff_est
-            total_num_steps = int(
-                math.floor(
-                    data_loader_len
-                    * cfg.num_epochs
-                    / int(os.environ.get("WORLD_SIZE", 1))
-                )
-            )
+            total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
 
             def calc_sample_packing_eff_est(estimates: List[float]):
                 LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
@@ -391,12 +386,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
             )
     else:
         total_num_steps = int(
-            math.ceil(
-                len(train_dataset)
-                * cfg.num_epochs
-                / int(os.environ.get("WORLD_SIZE", 1))
-                / cfg.batch_size
-            )
+            math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
         )
     LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
     return total_num_steps
```
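For readers following along, below is a small, self-contained sketch of the step arithmetic this diff changes. The numbers are hypothetical, and it assumes `cfg.batch_size` equals `micro_batch_size * gradient_accumulation_steps` and that `len(data_loader)` counts packed micro-batches produced by the sampler; those assumptions come from how the surrounding config is usually set up, not from this diff itself.

```python
import math

# Hypothetical config values, for illustration only.
world_size = 4                   # stands in for cfg.world_size / WORLD_SIZE
gradient_accumulation_steps = 2  # stands in for cfg.gradient_accumulation_steps
micro_batch_size = 2             # stands in for cfg.micro_batch_size
num_epochs = 3                   # stands in for cfg.num_epochs
# Assumption: cfg.batch_size = micro_batch_size * gradient_accumulation_steps
batch_size = micro_batch_size * gradient_accumulation_steps

# Sample-packed branch: len(data_loader) is the number of packed micro-batches.
packed_micro_batches = 4000      # stands in for len(data_loader)

# Before this PR: divide the dataloader length by cfg.batch_size, then divide
# the resulting step count by WORLD_SIZE again.
old_data_loader_len = packed_micro_batches // batch_size
old_total_num_steps = int(
    math.floor(old_data_loader_len * num_epochs / world_size)
)

# After this PR: one optimizer step consumes
# world_size * gradient_accumulation_steps micro-batches, and no further
# WORLD_SIZE division is applied to the step count.
new_data_loader_len = packed_micro_batches // (
    world_size * gradient_accumulation_steps
)
new_total_num_steps = int(math.floor(new_data_loader_len * num_epochs))

print(old_total_num_steps)  # 750
print(new_total_num_steps)  # 1500

# Non-packed branch after this PR: ceil(samples * epochs / batch_size),
# with the extra WORLD_SIZE division dropped.
dataset_len = 20_000             # stands in for len(train_dataset)
total_num_steps = int(math.ceil(dataset_len * num_epochs / batch_size))
print(total_num_steps)  # 15000
```

The sketch only compares the two divisors; actual values depend on the run's configuration and on how many batches the packing sampler produces.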