run eval on the first step to get a baseline (#617)
Browse files* run eval on the first step to get a baseline
* wandb kleeps getting moved around by pre-commit ...
src/axolotl/utils/callbacks.py
CHANGED
|
@@ -66,6 +66,29 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
|
|
| 66 |
return control
|
| 67 |
|
| 68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
class SaveBetterTransformerModelCallback(
|
| 70 |
TrainerCallback
|
| 71 |
): # pylint: disable=too-few-public-methods
|
|
|
|
| 66 |
return control
|
| 67 |
|
| 68 |
|
| 69 |
+
class EvalFirstStepCallback(
|
| 70 |
+
TrainerCallback
|
| 71 |
+
): # pylint: disable=too-few-public-methods disable=unused-argument
|
| 72 |
+
"""
|
| 73 |
+
Callback to trigger evals on the first step
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
def on_step_end(
|
| 77 |
+
self,
|
| 78 |
+
args: TrainingArguments,
|
| 79 |
+
state: TrainerState,
|
| 80 |
+
control: TrainerControl,
|
| 81 |
+
**kwargs,
|
| 82 |
+
):
|
| 83 |
+
if (
|
| 84 |
+
args.evaluation_strategy == IntervalStrategy.STEPS
|
| 85 |
+
and args.eval_steps < 1.0
|
| 86 |
+
and state.global_step == 1
|
| 87 |
+
):
|
| 88 |
+
control.should_evaluate = True
|
| 89 |
+
return control
|
| 90 |
+
|
| 91 |
+
|
| 92 |
class SaveBetterTransformerModelCallback(
|
| 93 |
TrainerCallback
|
| 94 |
): # pylint: disable=too-few-public-methods
|
src/axolotl/utils/trainer.py
CHANGED
|
@@ -28,6 +28,7 @@ from transformers.trainer_pt_utils import SequentialDistributedSampler
|
|
| 28 |
|
| 29 |
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
| 30 |
from axolotl.utils.callbacks import (
|
|
|
|
| 31 |
GPUStatsCallback,
|
| 32 |
SaveBetterTransformerModelCallback,
|
| 33 |
SavePeftModelCallback,
|
|
@@ -704,6 +705,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|
| 704 |
|
| 705 |
callbacks = []
|
| 706 |
callbacks.append(GPUStatsCallback(cfg))
|
|
|
|
| 707 |
|
| 708 |
if cfg.relora_steps:
|
| 709 |
callbacks.append(ReLoRACallback(cfg))
|
|
|
|
| 28 |
|
| 29 |
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
| 30 |
from axolotl.utils.callbacks import (
|
| 31 |
+
EvalFirstStepCallback,
|
| 32 |
GPUStatsCallback,
|
| 33 |
SaveBetterTransformerModelCallback,
|
| 34 |
SavePeftModelCallback,
|
|
|
|
| 705 |
|
| 706 |
callbacks = []
|
| 707 |
callbacks.append(GPUStatsCallback(cfg))
|
| 708 |
+
callbacks.append(EvalFirstStepCallback)
|
| 709 |
|
| 710 |
if cfg.relora_steps:
|
| 711 |
callbacks.append(ReLoRACallback(cfg))
|