fix axolotl training args dataclass annotation
Browse files
src/axolotl/utils/trainer.py
CHANGED
|
@@ -5,7 +5,7 @@ import logging
|
|
| 5 |
import math
|
| 6 |
import os
|
| 7 |
import sys
|
| 8 |
-
from dataclasses import field
|
| 9 |
from pathlib import Path
|
| 10 |
from typing import Optional
|
| 11 |
|
|
@@ -29,6 +29,7 @@ from axolotl.utils.schedulers import (
|
|
| 29 |
LOG = logging.getLogger("axolotl")
|
| 30 |
|
| 31 |
|
|
|
|
| 32 |
class AxolotlTrainingArguments(TrainingArguments):
|
| 33 |
"""
|
| 34 |
Extend the base TrainingArguments for axolotl helpers
|
|
@@ -188,7 +189,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
| 188 |
if cfg.save_safetensors:
|
| 189 |
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
|
| 190 |
|
| 191 |
-
training_args = AxolotlTrainingArguments(
|
| 192 |
per_device_train_batch_size=cfg.micro_batch_size,
|
| 193 |
per_device_eval_batch_size=cfg.eval_batch_size
|
| 194 |
if cfg.eval_batch_size is not None
|
|
|
|
| 5 |
import math
|
| 6 |
import os
|
| 7 |
import sys
|
| 8 |
+
from dataclasses import dataclass, field
|
| 9 |
from pathlib import Path
|
| 10 |
from typing import Optional
|
| 11 |
|
|
|
|
| 29 |
LOG = logging.getLogger("axolotl")
|
| 30 |
|
| 31 |
|
| 32 |
+
@dataclass
|
| 33 |
class AxolotlTrainingArguments(TrainingArguments):
|
| 34 |
"""
|
| 35 |
Extend the base TrainingArguments for axolotl helpers
|
|
|
|
| 189 |
if cfg.save_safetensors:
|
| 190 |
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
|
| 191 |
|
| 192 |
+
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
| 193 |
per_device_train_batch_size=cfg.micro_batch_size,
|
| 194 |
per_device_eval_batch_size=cfg.eval_batch_size
|
| 195 |
if cfg.eval_batch_size is not None
|