feat(train): merge logged dict
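The step-level logging branch previously issued two separate wandb_log calls, one for the train metrics and one for the state parameters; this change merges them into a single call over one dict, and renames the metrics accumulator from train_metric to train_metrics. A minimal sketch of the merged-logging pattern, assuming a wandb_log helper that prefixes keys and forwards to wandb.log (the helper below is a hypothetical stand-in, not the one from this repo):

import wandb

wandb.init(mode="disabled")  # no-op run so the sketch runs offline

def wandb_log(metrics, step=None, prefix=None):
    # Hypothetical stand-in for the script's helper: prefix each key and
    # emit one wandb.log call so every value lands on the same step.
    wandb.log({f"{prefix}/{k}" if prefix else k: v for k, v in metrics.items()}, step=step)

metrics = {"loss": 0.42, "learning_rate": 3e-4}             # as returned by the train step
state_dict = {"epoch": 2, "time": 128.5, "samples": 65536}  # keys after k.split("_")[-1]

# One call over the merged dict instead of two separate calls:
wandb_log({**metrics, **state_dict}, step=100, prefix="train")

Merging with {**metrics, **state_dict} is plain dict unpacking; if the two dicts shared a key the right-hand one would win, but the key sets here are disjoint.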
tools/train/train.py (+8 -8)
@@ -797,7 +797,7 @@ def main():
 
     # init variables
     last_time = time.perf_counter()
-    train_metric = None
+    train_metrics = None
 
     for epoch in epochs:
         state.replace(epoch=jax_utils.replicate(epoch))
@@ -821,20 +821,20 @@
             last_time = new_time
 
             # train step
-            state, train_metric = p_train_step(
+            state, train_metrics = p_train_step(
                 state, batch, jax_utils.replicate(delta_time)
            )
             step = unreplicate(state.step)
 
             if step % training_args.logging_steps == 0 and jax.process_index() == 0:
                 # log metrics
-                wandb_log(unreplicate(train_metric), step=step, prefix="train")
+                metrics = unreplicate(train_metrics)
                 # log state parameters
                 state_dict = {
                     k.split("_")[-1]: unreplicate(getattr(state, k))
                     for k in ["epoch", "train_time", "train_samples"]
                 }
-                wandb_log(state_dict, step=step, prefix="train")
+                wandb_log({**metrics, **state_dict}, step=step, prefix="train")
 
             eval_metrics = None
             if training_args.eval_steps and step % training_args.eval_steps == 0:
@@ -844,12 +844,12 @@
                 run_save_model(state, eval_metrics)
 
         # log final train metrics
-        if train_metric is not None:
-            train_metric = unreplicate(train_metric)
-            wandb_log(train_metric, step=step, prefix="train")
+        if train_metrics is not None:
+            train_metrics = unreplicate(train_metrics)
+            wandb_log(train_metrics, step=step, prefix="train")
 
             epochs.write(
-                f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+                f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
             )
 
     # Final evaluation
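For context on the replicate/unreplicate calls: p_train_step is evidently the pmap-ed train step (the p_ prefix and the replicated inputs suggest pmap), so the metrics it returns carry a leading device axis, and flax's jax_utils converts between host values and per-device copies. A minimal standalone round-trip (values are illustrative only, not taken from this script):

import jax
import jax.numpy as jnp
from flax import jax_utils

metrics = {"loss": jnp.asarray(0.5), "learning_rate": jnp.asarray(3e-4)}

# replicate: stack a copy of every leaf along a new leading device axis.
replicated = jax_utils.replicate(metrics)
print(replicated["loss"].shape, jax.local_device_count())  # (n_devices,) n_devices

# unreplicate: take one device's copy back off the leading axis; this is
# what the loop does before handing train_metrics to wandb_log.
assert jax_utils.unreplicate(replicated)["loss"] == metrics["loss"]

This is also why the final logging block calls unreplicate on train_metrics again: after the last pmap-ed step the accumulator still holds per-device copies.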