Merge pull request #90 from borisdayma/feat-new
dev/seq2seq/run_seq2seq_flax.py  CHANGED  (+18 -31)
@@ -100,12 +100,6 @@ class ModelArguments:
             "help": "Pretrained config name or path if not the same as model_name"
         },
     )
-    tokenizer_name: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "Pretrained tokenizer name or path if not the same as model_name"
-        },
-    )
     cache_dir: Optional[str] = field(
         default=None,
         metadata={
@@ -422,7 +416,7 @@ def wandb_log(metrics, step=None, prefix=None):
             f"{prefix}/{k}" if prefix is not None else k: v for k, v in metrics.items()
         }
         if step is not None:
-            log_metrics["train/step"] =
+            log_metrics["train/step"] = step
         wandb.log(log_metrics)


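The helper now stores the raw integer it is given under "train/step", which later becomes the default x-axis via wandb.define_metric (see the @@ -870 hunk below). A minimal sketch of that W&B pattern outside this script; the project name and metric values are illustrative:

import wandb

# Log everything against a custom "train/step" axis instead of W&B's internal step.
run = wandb.init(project="demo")
wandb.define_metric("*", step_metric="train/step")

for step in range(100):
    loss = 1.0 / (step + 1)  # placeholder metric
    # Include the x-axis key alongside the metrics, as wandb_log() does.
    wandb.log({"train/step": step, "train/loss": loss})

run.finish()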
@@ -534,11 +528,6 @@ def main():
         )

     else:
-        base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
-            model_args.model_name_or_path,
-            seed=training_args.seed,
-            dtype=getattr(jnp, model_args.dtype),
-        )
         # Set up our new model config
         config = BartConfig.from_pretrained(model_args.model_name_or_path)
         config.tie_word_embeddings = False
@@ -563,11 +552,6 @@ def main():
         config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
     )

-    # Use pre-trained weights for encoder
-    model.params["model"]["encoder"] = base_model.params["model"]["encoder"]
-    model.params["model"]["shared"] = base_model.params["model"]["shared"]
-    del base_model
-
     # Load tokenizer if it has not been set
     if tokenizer is None:
         tokenizer = AutoTokenizer.from_pretrained(
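These two hunks drop the pretrained base model entirely: the encoder and shared embeddings are no longer copied over, and the model is built from a modified BartConfig alone, with untied word embeddings per the config change above. A hedged sketch of that kind of construction; the model class and checkpoint name are assumptions, not taken from this script:

import jax.numpy as jnp
from transformers import BartConfig, FlaxBartForConditionalGeneration

# Build a fresh Flax BART model from a modified config instead of copying
# pretrained encoder/shared weights into it.
config = BartConfig.from_pretrained("facebook/bart-large-cnn")  # example checkpoint
config.tie_word_embeddings = False  # mirrors the config change in this diff

model = FlaxBartForConditionalGeneration(
    config,
    seed=0,              # stands in for training_args.seed
    dtype=jnp.float32,   # stands in for getattr(jnp, model_args.dtype)
)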
@@ -862,7 +846,7 @@ def main():
         f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
     )
     logger.info(
-        f" Total train batch size (w. parallel &
+        f" Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}"
     )
     logger.info(f" Total global steps = {total_steps}")
     logger.info(f" Total optimization steps = {total_optimization_steps}")
@@ -870,7 +854,7 @@ def main():
     epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)

     # set default x-axis as 'train/step'
-    wandb_log({}, step=state.step)
+    wandb_log({}, step=unreplicate(state.step))
     wandb.define_metric("*", step_metric="train/step")

     # add interesting config parameters
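state.step lives on a pmap-replicated train state, so it is a per-device array rather than a plain integer; unreplicate pulls back a single copy before the value reaches wandb_log. A small sketch of that behaviour with a toy value, assuming the usual flax.jax_utils helpers:

from flax import jax_utils
import jax.numpy as jnp

# A replicated value has a leading device axis, one entry per local device.
step = jnp.array(7)
replicated = jax_utils.replicate(step)
print(replicated.shape)                        # (n_devices,)

# unreplicate() takes the first copy, giving a scalar that is safe to log.
print(int(jax_utils.unreplicate(replicated)))  # 7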
@@ -909,7 +893,7 @@ def main():
         eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

         # log metrics
-        wandb_log(eval_metrics, step=state.step, prefix="eval")
+        wandb_log(eval_metrics, step=unreplicate(state.step), prefix="eval")

         # Print metrics and update progress bar
         desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
@@ -943,6 +927,10 @@ def main():

         # save to W&B
         if data_args.log_model:
+            # save some space
+            c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
+            c.cleanup(wandb.util.from_human_size("5GB"))
+
             metadata = {"step": step, "epoch": epoch}
             if eval_metrics is not None:
                 metadata["eval/loss"] = eval_metrics["loss"]
@@ -970,11 +958,8 @@ def main():
             artifact.add_file(
                 str(Path(training_args.output_dir) / "training_state.json")
             )
-            wandb.run.log_artifact(artifact)

-
-            c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
-            c.cleanup(wandb.util.from_human_size("5GB"))
+            wandb.run.log_artifact(artifact)

         # save to the hub
         if training_args.push_to_hub:
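The artifact cache cleanup moves to the top of the save block (the @@ -943 hunk above), and the checkpoint files are still uploaded with wandb.run.log_artifact. A minimal sketch of logging checkpoint files as a versioned W&B artifact; the paths, artifact name, and metadata values are illustrative:

import wandb

run = wandb.init(project="demo")

# Bundle checkpoint files plus some metadata into a versioned artifact.
artifact = wandb.Artifact(
    name=f"model-{run.id}",
    type="model",
    metadata={"step": 1000, "eval/loss": 2.3},  # placeholder values
)
artifact.add_file("output/flax_model.msgpack")
artifact.add_file("output/training_state.json")
run.log_artifact(artifact)
run.finish()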
@@ -988,7 +973,8 @@ def main():

     for epoch in epochs:
         # ======================== Training ================================
-
+        step = unreplicate(state.step)
+        wandb_log({"train/epoch": epoch}, step=step)

         # Create sampling rng
         rng, input_rng = jax.random.split(rng)
@@ -1010,19 +996,20 @@ def main():
             total=steps_per_epoch,
         ):
             state, train_metric = p_train_step(state, batch)
+            step = unreplicate(state.step)

-            if
+            if step % data_args.log_interval == 0 and jax.process_index() == 0:
                 # log metrics
-                wandb_log(unreplicate(train_metric), step=
+                wandb_log(unreplicate(train_metric), step=step, prefix="train")

-            if training_args.eval_steps and
+            if training_args.eval_steps and step % training_args.eval_steps == 0:
                 run_evaluation()

-            if
-                run_save_model(state,
+            if step % data_args.save_model_steps == 0:
+                run_save_model(state, step, epoch)

         # log final train metrics
-        wandb_log(unreplicate(train_metric), step=
+        wandb_log(unreplicate(train_metric), step=step, prefix="train")

         train_metric = unreplicate(train_metric)
         epochs.write(
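The rewritten loop keys everything off the optimizer step: metrics, evaluation, and checkpoints are each gated on step % interval == 0. A stripped-down sketch of that cadence with placeholder work functions; the interval names follow the script's arguments but the values are made up:

def train_step():
    return {"loss": 0.0}  # placeholder for p_train_step

def run_evaluation():
    print("evaluating")

def run_save_model(step):
    print(f"saving checkpoint at step {step}")

def train(num_steps=2000, log_interval=50, eval_steps=500, save_model_steps=1000):
    for step in range(1, num_steps + 1):
        train_metric = train_step()
        if step % log_interval == 0:
            print(f"step {step}: {train_metric}")  # wandb_log(...) in the real script
        if eval_steps and step % eval_steps == 0:
            run_evaluation()
        if step % save_model_steps == 0:
            run_save_model(step)

train()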
|
|
|