feat: log metrics more frequently
seq2seq/run_seq2seq_flax.py  +16 -3

seq2seq/run_seq2seq_flax.py  CHANGED
@@ -215,6 +215,13 @@ class DataTrainingArguments:
     overwrite_cache: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )
+    log_interval: Optional[int] = field(
+        default=500,
+        metadata={
+            "help": "Log training metrics every `log_interval` steps instead of only once "
+            "per epoch."
+        },
+    )
 
     def __post_init__(self):
         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
@@ -307,12 +314,12 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
 
     train_metrics = get_metrics(train_metrics)
     for key, vals in train_metrics.items():
-        tag = f"train_{key}"
+        tag = f"train_epoch/{key}"
         for i, val in enumerate(vals):
             summary_writer.scalar(tag, val, step - len(vals) + i + 1)
 
     for metric_name, value in eval_metrics.items():
-        summary_writer.scalar(f"eval_{metric_name}", value, step)
+        summary_writer.scalar(f"eval/{metric_name}", value, step)
 
 
 def create_learning_rate_fn(
@@ -718,6 +725,7 @@ def main():
 
     train_time = 0
     epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
+    global_step = 0
     for epoch in epochs:
         # ======================== Training ================================
         train_start = time.time()
@@ -730,11 +738,16 @@ def main():
         train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
         steps_per_epoch = len(train_dataset) // train_batch_size
         # train
-        for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
+        for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
+            global_step += 1
             batch = next(train_loader)
             state, train_metric = p_train_step(state, batch)
             train_metrics.append(train_metric)
 
+            if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
+                for k, v in unreplicate(train_metric).items():
+                    wandb.log({f"train/{k}": jax.device_get(v)}, step=global_step)
+
         train_time += time.time() - train_start
 
         train_metric = unreplicate(train_metric)
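The net effect: `global_step` advances with every batch, and every `log_interval` steps the first process unreplicates the most recent `train_metric` and sends it to Weights & Biases, instead of waiting for the end-of-epoch TensorBoard summary. Because `log_interval` lives in `DataTrainingArguments`, it should also surface as a CLI flag through the script's argument parser (e.g. `--log_interval 100`). Below is a minimal, self-contained sketch of this logging pattern, under stated assumptions rather than a copy of the script: `fake_train_step` stands in for the pmapped `p_train_step`, a constant batch replaces the data loader, and the wandb project name is made up.

# A minimal, runnable sketch of the logging pattern this commit adds.
# `fake_train_step` is a stand-in for the script's `p_train_step`.
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate, unreplicate
import wandb

log_interval = 2       # the commit defaults this to 500
steps_per_epoch = 10
num_epochs = 1

@jax.pmap
def fake_train_step(state, batch):
    # Stand-in step: nudges the state and returns a per-device metrics dict.
    loss = jnp.mean(batch) + state
    return state + 1.0, {"loss": loss, "learning_rate": jnp.float32(1e-3)}

def run():
    wandb.init(project="seq2seq-demo", mode="offline")  # offline so the sketch runs anywhere
    n = jax.local_device_count()
    state = replicate(jnp.float32(0.0))   # replicate train state across local devices
    global_step = 0
    for epoch in range(num_epochs):
        for step in range(steps_per_epoch):
            global_step += 1
            batch = jnp.ones((n, 4))      # leading axis must equal device count for pmap
            state, train_metric = fake_train_step(state, batch)
            # Log intra-epoch metrics from the main process every `log_interval`
            # steps, mirroring the check added inside the training loop.
            if global_step % log_interval == 0 and jax.process_index() == 0:
                metrics = {f"train/{k}": jax.device_get(v)
                           for k, v in unreplicate(train_metric).items()}
                wandb.log(metrics, step=global_step)
    wandb.finish()

if __name__ == "__main__":
    run()

One small divergence: the sketch batches all metrics into a single `wandb.log` call, while the diff issues one call per metric key; both attach to the same `step`, so the resulting charts are equivalent. Passing `step=global_step` keeps the x-axis monotonic across epochs, and the `jax.process_index() == 0` guard avoids duplicate rows on multi-host runs.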