fixes to save on fractional save_steps (#1643)
Browse files
src/axolotl/core/trainer_builder.py
CHANGED
@@ -43,7 +43,7 @@ from axolotl.utils.callbacks import (
|
|
43 |
LossWatchDogCallback,
|
44 |
SaveAxolotlConfigtoWandBCallback,
|
45 |
SaveBetterTransformerModelCallback,
|
46 |
-
SaveModelOnTrainEndCallback,
|
47 |
bench_eval_callback_factory,
|
48 |
causal_lm_bench_eval_callback_factory,
|
49 |
log_prediction_callback_factory,
|
@@ -945,7 +945,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
945 |
if self.cfg.loss_watchdog_threshold is not None:
|
946 |
callbacks.append(LossWatchDogCallback(self.cfg))
|
947 |
|
948 |
-
callbacks.append(SaveModelOnTrainEndCallback())
|
949 |
|
950 |
return callbacks
|
951 |
|
@@ -1431,7 +1431,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|
1431 |
|
1432 |
def get_callbacks(self):
|
1433 |
callbacks = super().get_callbacks()
|
1434 |
-
callbacks.append(SaveModelOnTrainEndCallback())
|
1435 |
|
1436 |
return callbacks
|
1437 |
|
|
|
43 |
LossWatchDogCallback,
|
44 |
SaveAxolotlConfigtoWandBCallback,
|
45 |
SaveBetterTransformerModelCallback,
|
46 |
+
SaveModelCallback,
|
47 |
bench_eval_callback_factory,
|
48 |
causal_lm_bench_eval_callback_factory,
|
49 |
log_prediction_callback_factory,
|
|
|
945 |
if self.cfg.loss_watchdog_threshold is not None:
|
946 |
callbacks.append(LossWatchDogCallback(self.cfg))
|
947 |
|
948 |
+
callbacks.append(SaveModelCallback())
|
949 |
|
950 |
return callbacks
|
951 |
|
|
|
1431 |
|
1432 |
def get_callbacks(self):
|
1433 |
callbacks = super().get_callbacks()
|
1434 |
+
callbacks.append(SaveModelCallback())
|
1435 |
|
1436 |
return callbacks
|
1437 |
|
src/axolotl/utils/callbacks/__init__.py
CHANGED
@@ -3,6 +3,7 @@
|
|
3 |
from __future__ import annotations
|
4 |
|
5 |
import logging
|
|
|
6 |
import os
|
7 |
from shutil import copyfile
|
8 |
from tempfile import NamedTemporaryFile
|
@@ -775,7 +776,7 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
|
775 |
return control
|
776 |
|
777 |
|
778 |
-
class SaveModelOnTrainEndCallback(TrainerCallback):
|
779 |
"""Callback to save model on train end"""
|
780 |
|
781 |
def on_step_end( # pylint: disable=unused-argument
|
@@ -788,6 +789,13 @@ class SaveModelOnTrainEndCallback(TrainerCallback):
|
|
788 |
# Save
|
789 |
if state.global_step >= state.max_steps:
|
790 |
control.should_save = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
791 |
|
792 |
def on_train_end( # pylint: disable=unused-argument
|
793 |
self, args, state, control, **kwargs
|
|
|
3 |
from __future__ import annotations
|
4 |
|
5 |
import logging
|
6 |
+
import math
|
7 |
import os
|
8 |
from shutil import copyfile
|
9 |
from tempfile import NamedTemporaryFile
|
|
|
776 |
return control
|
777 |
|
778 |
|
779 |
+
class SaveModelCallback(TrainerCallback):
|
780 |
"""Callback to save model on train end"""
|
781 |
|
782 |
def on_step_end( # pylint: disable=unused-argument
|
|
|
789 |
# Save
|
790 |
if state.global_step >= state.max_steps:
|
791 |
control.should_save = True
|
792 |
+
elif (
|
793 |
+
args.save_strategy == IntervalStrategy.STEPS
|
794 |
+
and state.save_steps < 1.0
|
795 |
+
and state.global_step % math.ceil(state.save_steps * state.max_steps) == 0
|
796 |
+
):
|
797 |
+
# workaround to save model on fractional save_steps
|
798 |
+
control.should_save = True
|
799 |
|
800 |
def on_train_end( # pylint: disable=unused-argument
|
801 |
self, args, state, control, **kwargs
|