winglian committed on
Commit
ba45531
·
unverified ·
1 Parent(s): 8a1572a

fixes to save on fractional save_steps (#1643)

Browse files
src/axolotl/core/trainer_builder.py CHANGED
@@ -43,7 +43,7 @@ from axolotl.utils.callbacks import (
43
  LossWatchDogCallback,
44
  SaveAxolotlConfigtoWandBCallback,
45
  SaveBetterTransformerModelCallback,
46
- SaveModelOnTrainEndCallback,
47
  bench_eval_callback_factory,
48
  causal_lm_bench_eval_callback_factory,
49
  log_prediction_callback_factory,
@@ -945,7 +945,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
945
  if self.cfg.loss_watchdog_threshold is not None:
946
  callbacks.append(LossWatchDogCallback(self.cfg))
947
 
948
- callbacks.append(SaveModelOnTrainEndCallback())
949
 
950
  return callbacks
951
 
@@ -1431,7 +1431,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
1431
 
1432
  def get_callbacks(self):
1433
  callbacks = super().get_callbacks()
1434
- callbacks.append(SaveModelOnTrainEndCallback())
1435
 
1436
  return callbacks
1437
 
 
43
  LossWatchDogCallback,
44
  SaveAxolotlConfigtoWandBCallback,
45
  SaveBetterTransformerModelCallback,
46
+ SaveModelCallback,
47
  bench_eval_callback_factory,
48
  causal_lm_bench_eval_callback_factory,
49
  log_prediction_callback_factory,
 
945
  if self.cfg.loss_watchdog_threshold is not None:
946
  callbacks.append(LossWatchDogCallback(self.cfg))
947
 
948
+ callbacks.append(SaveModelCallback())
949
 
950
  return callbacks
951
 
 
1431
 
1432
  def get_callbacks(self):
1433
  callbacks = super().get_callbacks()
1434
+ callbacks.append(SaveModelCallback())
1435
 
1436
  return callbacks
1437
 
src/axolotl/utils/callbacks/__init__.py CHANGED
@@ -3,6 +3,7 @@
3
  from __future__ import annotations
4
 
5
  import logging
 
6
  import os
7
  from shutil import copyfile
8
  from tempfile import NamedTemporaryFile
@@ -775,7 +776,7 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
775
  return control
776
 
777
 
778
- class SaveModelOnTrainEndCallback(TrainerCallback):
779
  """Callback to save model on train end"""
780
 
781
  def on_step_end( # pylint: disable=unused-argument
@@ -788,6 +789,13 @@ class SaveModelOnTrainEndCallback(TrainerCallback):
788
  # Save
789
  if state.global_step >= state.max_steps:
790
  control.should_save = True
 
 
 
 
 
 
 
791
 
792
  def on_train_end( # pylint: disable=unused-argument
793
  self, args, state, control, **kwargs
 
3
  from __future__ import annotations
4
 
5
  import logging
6
+ import math
7
  import os
8
  from shutil import copyfile
9
  from tempfile import NamedTemporaryFile
 
776
  return control
777
 
778
 
779
+ class SaveModelCallback(TrainerCallback):
780
  """Callback to save model on train end"""
781
 
782
  def on_step_end( # pylint: disable=unused-argument
 
789
  # Save
790
  if state.global_step >= state.max_steps:
791
  control.should_save = True
792
+ elif (
793
+ args.save_strategy == IntervalStrategy.STEPS
794
+ and state.save_steps < 1.0
795
+ and state.global_step % math.ceil(state.save_steps * state.max_steps) == 0
796
+ ):
797
+ # workaround to save model on fractional save_steps
798
+ control.should_save = True
799
 
800
  def on_train_end( # pylint: disable=unused-argument
801
  self, args, state, control, **kwargs