feat(train): cleanup args
tools/train/train.py (+21 -17)
@@ -199,8 +199,11 @@ class TrainingArguments:
     per_device_train_batch_size: int = field(
         default=8, metadata={"help": "Batch size per GPU/TPU/CPU for training."}
     )
-    per_device_eval_batch_size: int = field(
-        default=8, metadata={"help": "Batch size per GPU/TPU/CPU for evaluation."}
+    per_device_eval_batch_size: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "Batch size per GPU/TPU/CPU for evaluation. Same as training batch size if not set."
+        },
     )
 
     gradient_accumulation_steps: int = field(
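The eval batch size now defaults to None so that "not set" is distinguishable from a real value; the actual fallback happens in __post_init__ (see the hunk further down). A minimal sketch of the field pattern, using a hypothetical Args dataclass and transformers.HfArgumentParser purely for illustration:

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class Args:
    # Hypothetical stand-in for TrainingArguments, for illustration only.
    per_device_train_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU/CPU for training."}
    )
    per_device_eval_batch_size: Optional[int] = field(
        default=None,
        metadata={"help": "Batch size for evaluation. Defaults to the training batch size."},
    )


# HfArgumentParser understands Optional[int] and leaves the field as None when
# the flag is omitted, so the resolved value can be computed later.
(args,) = HfArgumentParser(Args).parse_args_into_dataclasses(args=[])
assert args.per_device_eval_batch_size is None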
@@ -252,6 +255,13 @@ class TrainingArguments:
         },
     )
 
+    num_train_epochs: int = field(
+        default=3, metadata={"help": "Total number of training epochs to perform."}
+    )
+
+    warmup_steps: int = field(
+        default=0, metadata={"help": "Linear warmup over warmup_steps."}
+    )
     lr_decay: str = field(
         default=None,
         metadata={
@@ -277,13 +287,6 @@ class TrainingArguments:
         },
     )
 
-    num_train_epochs: int = field(
-        default=3, metadata={"help": "Total number of training epochs to perform."}
-    )
-    warmup_steps: int = field(
-        default=0, metadata={"help": "Linear warmup over warmup_steps."}
-    )
-
     logging_steps: int = field(
         default=40, metadata={"help": "Log every X updates steps."}
     )
@@ -334,6 +337,11 @@ class TrainingArguments:
             "adam",
             "adafactor",
         ], f"Selected optimizer not supported: {self.optim}"
+        if self.per_device_eval_batch_size is None:
+            self.per_device_eval_batch_size = self.per_device_train_batch_size
+        if self.weight_decay is None:
+            if self.optim in ["distributed_shampoo", "adam"]:
+                self.weight_decay = 0.0
         if (
             os.path.exists(self.output_dir)
             and os.listdir(self.output_dir)
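The new __post_init__ checks centralize default resolution: once the arguments object is constructed, per_device_eval_batch_size is always a concrete value, and weight_decay is a float for adam and distributed_shampoo. That guarantee is what lets the optimizer hunks below drop their `if ... is not None else 0.0` expressions. A self-contained sketch of the same logic on a hypothetical Args dataclass (not the script's full TrainingArguments):

from dataclasses import dataclass
from typing import Optional


@dataclass
class Args:
    optim: str = "adam"
    weight_decay: Optional[float] = None
    per_device_train_batch_size: int = 8
    per_device_eval_batch_size: Optional[int] = None

    def __post_init__(self):
        # Runs whenever the dataclass is built, including when an argument
        # parser constructs it, so every later consumer sees resolved values.
        if self.per_device_eval_batch_size is None:
            self.per_device_eval_batch_size = self.per_device_train_batch_size
        if self.weight_decay is None and self.optim in ("distributed_shampoo", "adam"):
            self.weight_decay = 0.0


args = Args()
assert args.per_device_eval_batch_size == 8
assert args.weight_decay == 0.0  # an adafactor run would keep None, as in the diff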
@@ -623,9 +631,7 @@ def main():
             beta2=training_args.beta2,
             diagonal_epsilon=1e-10,
             matrix_epsilon=1e-8,
-            weight_decay=training_args.weight_decay
-            if training_args.weight_decay is not None
-            else 0.0,
+            weight_decay=training_args.weight_decay,
             start_preconditioning_step=training_args.warmup_steps,
             preconditioning_compute_steps=training_args.preconditioning_compute_steps,
             statistics_compute_steps=1,
@@ -648,9 +654,7 @@ def main():
             b1=training_args.beta1,
             b2=training_args.beta2,
             eps=training_args.adam_epsilon,
-            weight_decay=training_args.weight_decay
-            if training_args.weight_decay is not None
-            else 0.0,
+            weight_decay=training_args.weight_decay,
             mask=decay_mask_fn,
         )
     elif training_args.optim == "adafactor":
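Because weight_decay is now guaranteed to be a float here, it is passed straight through to the optimizer. The mask=decay_mask_fn argument decides which parameters are actually decayed; the sketch below shows a typical optax.adamw setup with an illustrative mask that skips biases and LayerNorm scales. The real decay_mask_fn is defined elsewhere in train.py and may differ:

import jax.numpy as jnp
import optax
from flax import traverse_util


def decay_mask_fn(params):
    # Illustrative mask: decay everything except bias and LayerNorm scale params.
    flat = traverse_util.flatten_dict(params)
    mask = {path: path[-1] not in ("bias", "scale") for path in flat}
    return traverse_util.unflatten_dict(mask)


params = {"dense": {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros((4,))}}
tx = optax.adamw(
    learning_rate=1e-4,
    b1=0.9,
    b2=0.999,
    eps=1e-8,
    weight_decay=0.0,  # now taken directly from training_args.weight_decay
    mask=decay_mask_fn,
)
opt_state = tx.init(params)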
@@ -749,8 +753,8 @@ def main():
         return metrics
 
     # Create parallel version of the train and eval step
-    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,
-    p_eval_step = jax.pmap(eval_step, "batch"
+    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
+    p_eval_step = jax.pmap(eval_step, "batch")
 
     logger.info("***** Running training *****")
     logger.info(f" Num examples = {len_train_dataset}")
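In the updated lines, p_train_step is compiled with donate_argnums=(0,), which lets XLA reuse the buffers of the first argument (the training state) for its output and saves device memory on large models; the eval step donates nothing. A minimal sketch of the pattern with toy step functions, not the real train_step or eval_step:

import jax
import jax.numpy as jnp


def train_step(state, batch):
    # Toy update: pretend the averaged batch mean is the gradient.
    grad = jax.lax.pmean(jnp.mean(batch), axis_name="batch")
    return state - 0.01 * grad


def eval_step(state, batch):
    loss = jnp.mean((batch - state) ** 2)
    return jax.lax.pmean(loss, axis_name="batch")


# Donating argument 0 allows in-place reuse of the old state's device buffers.
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
p_eval_step = jax.pmap(eval_step, "batch")

n = jax.local_device_count()
state = jnp.zeros((n,))
batch = jnp.ones((n, 4))
state = p_train_step(state, batch)
metrics = p_eval_step(state, batch)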