Commit: fix: log train_metric only if defined

File changed: dev/seq2seq/run_seq2seq_flax.py
@@ -336,10 +336,7 @@ def main():
     # Set the verbosity to info of the Transformers logger (on main process only):
     logger.info(f"Training/evaluation parameters {training_args}")

-    #
-    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
-    # (the dataset will be downloaded automatically from the datasets Hub).
-    #
+    # Load dataset
     if data_args.train_file is not None or data_args.validation_file is not None:
         data_files = {
             "train": data_args.train_file,
@@ -826,7 +823,10 @@ def main():
        temp_dir=True,  # avoid issues with being in a repository
    )

+    # init variables
     last_time = time.perf_counter()
+    train_metric = None
+
     for epoch in epochs:
         state.replace(epoch=jax_utils.replicate(epoch))
         # ======================== Training ================================
@@ -871,12 +871,13 @@ def main():
         run_save_model(state, eval_metrics)

         # log final train metrics
-        train_metric = get_metrics(train_metric)
-        wandb_log(train_metric, step=step, prefix="train")
-
-        epochs.write(
-            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
-        )
+        if train_metric is not None:
+            train_metric = get_metrics(train_metric)
+            wandb_log(train_metric, step=step, prefix="train")
+
+            epochs.write(
+                f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+            )

         # Final evaluation
         eval_metrics = run_evaluation()