refactor(train): cleanup
tools/train/train.py  +51 −31
@@ -310,12 +310,40 @@ class TrainingArguments:
         metadata={"help": "Reference to a wandb artifact for resuming training."},
     )
 
+    wandb_entity: Optional[str] = field(
+        default=None,
+        metadata={"help": "The wandb entity to use (for teams)."},
+    )
+    wandb_project: str = field(
+        default="dalle-mini",
+        metadata={"help": "The name of the wandb project."},
+    )
+    wandb_job_type: str = field(
+        default="Seq2Seq",
+        metadata={"help": "The name of the wandb job type."},
+    )
+
+    assert_TPU_available: bool = field(
+        default=False,
+        metadata={"help": "Verify that TPU is not in use."},
+    )
+
     def __post_init__(self):
         assert self.optim in [
             "distributed_shampoo",
             "adam",
             "adafactor",
         ], f"Selected optimizer not supported: {self.optim}"
+        if (
+            os.path.exists(self.output_dir)
+            and os.listdir(self.output_dir)
+            and self.do_train
+            and not self.overwrite_output_dir
+        ):
+            raise ValueError(
+                f"Output directory ({self.output_dir}) already exists and is not empty. "
+                "Use --overwrite_output_dir to overcome."
+            )
 
 
 class TrainState(train_state.TrainState):
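The output-directory check now lives in `TrainingArguments.__post_init__`, so an invalid flag combination is rejected as soon as the dataclass is constructed, before any setup work in `main()`. A minimal sketch of the pattern, using a hypothetical `Args` class rather than the project's real one:

import os
from dataclasses import dataclass

@dataclass
class Args:
    output_dir: str
    do_train: bool = True
    overwrite_output_dir: bool = False

    def __post_init__(self):
        # Runs right after __init__, so bad configurations fail fast.
        if (
            os.path.exists(self.output_dir)
            and os.listdir(self.output_dir)
            and self.do_train
            and not self.overwrite_output_dir
        ):
            raise ValueError(
                f"Output directory ({self.output_dir}) already exists and is not empty."
            )

Args(output_dir=".")  # raises ValueError when run from a non-empty directory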
@@ -396,17 +424,6 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
-    if (
-        os.path.exists(training_args.output_dir)
-        and os.listdir(training_args.output_dir)
-        and training_args.do_train
-        and not training_args.overwrite_output_dir
-    ):
-        raise ValueError(
-            f"Output directory ({training_args.output_dir}) already exists and is not empty."
-            "Use --overwrite_output_dir to overcome."
-        )
-
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -433,14 +450,18 @@ def main():
     )
 
     logger.info(f"Local TPUs: {jax.local_device_count()}")
-
+    logger.info(f"Global TPUs: {jax.device_count()}")
+    if training_args.assert_TPU_available:
+        assert (
+            jax.local_device_count() == 8
+        ), "TPUs in use, please check running processes"
 
     # Set up wandb run
     if jax.process_index() == 0:
         wandb.init(
-            entity=…,
-            project=…,
-            job_type=…,
+            entity=training_args.wandb_entity,
+            project=training_args.wandb_project,
+            job_type=training_args.wandb_job_type,
             config=parser.parse_args(),
         )
 
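With the entity, project, and job type now read from `TrainingArguments`, the run destination can be changed from the command line (e.g. `--wandb_project`) instead of by editing the script. The new TPU guard expects exactly 8 local devices, which is what a single TPU v2-8/v3-8 host exposes; a sketch of the same check as a standalone helper (the helper name is ours, not the script's):

import jax

def check_tpus(expected: int = 8) -> None:
    # A v2-8/v3-8 host exposes 8 local cores; fewer usually means another
    # process holds the TPU or JAX silently fell back to CPU.
    found = jax.local_device_count()
    assert found == expected, f"Expected {expected} TPU cores, found {found}"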
@@ -520,17 +541,14 @@ def main():
     train_batch_size = (
         training_args.per_device_train_batch_size * jax.local_device_count()
     )
-    …
-    …
-        * training_args.gradient_accumulation_steps
-        * jax.process_count()
-    )
+    batch_size_per_node = train_batch_size * training_args.gradient_accumulation_steps
+    batch_size_per_step = batch_size_per_node * jax.process_count()
     eval_batch_size = (
         training_args.per_device_eval_batch_size * jax.local_device_count()
     )
     len_train_dataset, len_eval_dataset = dataset.length
     steps_per_epoch = (
-        len_train_dataset // …
+        len_train_dataset // batch_size_per_node
         if len_train_dataset is not None
         else None
     )
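The renamed quantities factor the effective batch size explicitly: `batch_size_per_node` is what one host consumes per optimizer update (per-device size × local devices × gradient accumulation steps), and `batch_size_per_step` scales that by the number of hosts. A worked example with assumed values:

per_device_train_batch_size = 8  # assumed
local_device_count = 8           # e.g. one TPU v3-8 host
gradient_accumulation_steps = 4  # assumed
process_count = 2                # assumed number of hosts

train_batch_size = per_device_train_batch_size * local_device_count   # 64
batch_size_per_node = train_batch_size * gradient_accumulation_steps  # 256
batch_size_per_step = batch_size_per_node * process_count             # 512

# steps_per_epoch divides by the per-node batch, presumably because each
# process iterates over its own shard of the dataset.
len_train_dataset = 10_240  # assumed
steps_per_epoch = len_train_dataset // batch_size_per_node  # 40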
@@ -708,14 +726,12 @@ def main():
             grads=grads,
             dropout_rng=new_dropout_rng,
             train_time=state.train_time + delta_time,
-            train_samples=state.train_samples + …,
+            train_samples=state.train_samples + batch_size_per_step,
         )
 
         metrics = {
             "loss": loss,
-            "learning_rate": learning_rate_fn(
-                state.step // training_args.gradient_accumulation_steps
-            ),
+            "learning_rate": learning_rate_fn(state.step),
         }
         metrics = jax.lax.pmean(metrics, axis_name="batch")
 
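`learning_rate_fn` is now evaluated at `state.step` directly instead of `state.step // gradient_accumulation_steps`, suggesting the step counter now advances once per optimizer update rather than once per micro-batch. A sketch of how a schedule function of this shape behaves, using an optax schedule as a stand-in (the values are assumptions, not the project's configuration):

import optax

# Linear warmup from 0 to 5e-3 over the first 1,000 optimizer steps,
# held constant afterwards.
learning_rate_fn = optax.linear_schedule(
    init_value=0.0, end_value=5e-3, transition_steps=1_000
)

print(learning_rate_fn(0))      # 0.0
print(learning_rate_fn(500))    # 0.0025, halfway through warmup
print(learning_rate_fn(2_000))  # 0.005, schedule saturated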
@@ -733,19 +749,20 @@ def main():
         return metrics
 
     # Create parallel version of the train and eval step
-    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
-    p_eval_step = jax.pmap(eval_step, "batch")
+    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, 1))
+    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(1,))
 
     logger.info("***** Running training *****")
     logger.info(f" Num examples = {len_train_dataset}")
     logger.info(f" Num Epochs = {num_epochs}")
     logger.info(
-        f" …"
+        f" Batch size per device = {training_args.per_device_train_batch_size}"
     )
     logger.info(f" Number of devices = {jax.device_count()}")
     logger.info(
-        f" …"
+        f" Gradient accumulation steps = {training_args.gradient_accumulation_steps}"
     )
+    logger.info(f" Batch size per update = {batch_size_per_step}")
     logger.info(f" Model parameters = {num_params:,}")
     epochs = tqdm(
         range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
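`donate_argnums` tells XLA that the listed inputs' device buffers may be reused for the outputs; donating the train state, which is overwritten on every step, roughly halves its peak memory footprint, and the batch arguments are donated for the same reason. A minimal sketch of the mechanism with toy functions (not the script's real steps):

import jax
import jax.numpy as jnp

def train_step(state, batch):
    # Toy update: the old state buffer can be recycled for the new state.
    return state - 0.01 * jnp.mean(batch, axis=-1, keepdims=True)

# Argument 0 (state) is donated, so XLA may update it in place.
p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))

n = jax.local_device_count()
state = jnp.zeros((n, 4))
batch = jnp.ones((n, 4))
state = p_train_step(state, batch)  # rebind the name; the old buffer is gone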
@@ -762,8 +779,9 @@ def main():
         {
             "len_train_dataset": len_train_dataset,
             "len_eval_dataset": len_eval_dataset,
-            "…": …,
+            "batch_size_per_step": batch_size_per_step,
             "num_params": num_params,
+            "num_devices": jax.device_count(),
         }
     )
 
@@ -774,7 +792,9 @@ def main():
         # ======================== Evaluating ==============================
         eval_metrics = []
         if training_args.do_eval:
-            eval_loader = dataset.dataloader(…)
+            eval_loader = dataset.dataloader(
+                "eval", training_args.per_device_eval_batch_size
+            )
             eval_steps = (
                 len_eval_dataset // eval_batch_size
                 if len_eval_dataset is not None
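The loader is built from the per-device eval batch size while `eval_steps` divides by the per-node `eval_batch_size`; with assumed numbers the bookkeeping looks like this:

per_device_eval_batch_size = 16  # assumed
local_device_count = 8
eval_batch_size = per_device_eval_batch_size * local_device_count  # 128

len_eval_dataset = 1_600  # assumed
eval_steps = len_eval_dataset // eval_batch_size  # 12 full batches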