Merge pull request #16 from borisdayma/feat-log_model
- requirements.txt +1 -1
- seq2seq/run_seq2seq_flax.py +34 -11
requirements.txt
CHANGED

@@ -9,4 +9,4 @@ flax
 jupyter
 # for logging
 tensorboard
-
+tensorflow
seq2seq/run_seq2seq_flax.py
CHANGED

@@ -199,7 +199,7 @@ class DataTrainingArguments:
         },
     )
     preprocessing_num_workers: Optional[int] = field(
-        default=None,
+        default=80,  # ensure we have the same datasets cached data and avoid using too much space
         metadata={"help": "The number of processes to use for the preprocessing."},
     )
     source_prefix: Optional[str] = field(
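Pinning the worker count matters because `datasets` splits `Dataset.map` work into `num_proc` shards, and for disk-backed datasets each shard gets its own cache file, so a different worker count on another machine yields a different cache layout. A minimal, self-contained sketch of the idiom (toy in-memory dataset; `preprocess_function` is a hypothetical stand-in for the script's tokenization step):

from datasets import Dataset

dataset = Dataset.from_dict({"text": ["a", "bb", "ccc", "dddd"]})

def preprocess_function(batch):
    # hypothetical stand-in for tokenization
    return {"length": [len(t) for t in batch["text"]]}

# datasets shards the work (and, for disk-backed datasets, the cache files)
# by num_proc, so a stable value keeps preprocessed caches reproducible.
tokenized = dataset.map(preprocess_function, batched=True, num_proc=2)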
@@ -225,6 +225,9 @@ class DataTrainingArguments:
             "value if set."
         },
     )
+    log_model: bool = field(
+        default=False, metadata={"help": "Log model to W&B after each epoch."}
+    )
 
     def __post_init__(self):
         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
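Because these dataclasses are parsed with `HfArgumentParser` (as in the Hugging Face seq2seq examples), the new field surfaces directly as a `--log_model` command-line flag. A minimal, self-contained sketch (the tiny `DataArgs` class below is a hypothetical stand-in for the script's much larger `DataTrainingArguments`):

from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class DataArgs:  # hypothetical stand-in for DataTrainingArguments
    log_model: bool = field(
        default=False, metadata={"help": "Log model to W&B after each epoch."}
    )

# A bool field with default=False becomes a flag: passing it sets True.
(data_args,) = HfArgumentParser(DataArgs).parse_args_into_dataclasses(args=["--log_model"])
print(data_args.log_model)  # True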
@@ -812,6 +815,36 @@ def main():
             cur_step = epoch * (len(train_dataset) // train_batch_size)
             write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
 
+        # save checkpoint after each epoch and push checkpoint to the hub
+        if jax.process_index() == 0:
+            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+
+            # save model locally
+            model.save_pretrained(
+                training_args.output_dir,
+                params=params,
+            )
+
+            # save to W&B
+            if data_args.log_model:
+                metadata = {'epoch': epoch+1, 'eval/loss': eval_metrics['loss']}
+                artifact = wandb.Artifact(
+                    name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
+                )
+                artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
+                artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
+                wandb.run.log_artifact(artifact)
+
+            # save to the hub
+            if training_args.push_to_hub:
+                model.save_pretrained(
+                    training_args.output_dir,
+                    params=params,
+                    push_to_hub=training_args.push_to_hub,
+                    commit_message=f"Saving weights and logs of epoch {epoch+1}",
+                    temp_dir=True,  # avoid issues with being in a repository
+                )
+
     # ======================== Prediction loop ==============================
     if training_args.do_predict:
         logger.info("*** Predict ***")

@@ -851,16 +884,6 @@ def main():
         desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
         logger.info(desc)
 
-    # save checkpoint after each epoch and push checkpoint to the hub
-    if jax.process_index() == 0:
-        params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
-        model.save_pretrained(
-            training_args.output_dir,
-            params=params,
-            push_to_hub=training_args.push_to_hub,
-            commit_message=f"Saving weights and logs of epoch {epoch+1}",
-        )
-
 
 if __name__ == "__main__":
     main()
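A side note on `jax.tree_map(lambda x: x[0], state.params)`: under `pmap`, every parameter carries a leading device axis, so indexing `[0]` takes one device's copy before `jax.device_get` pulls it to host memory for saving. A minimal, self-contained sketch of the same idiom (toy pytree in place of the script's real `state.params`; `flax.jax_utils.unreplicate` is the equivalent helper):

import jax
import numpy as np
from flax import jax_utils

# Toy stand-in for pmap-replicated params: leading axis = local device count.
replicated = {"w": np.stack([np.ones(3)] * jax.local_device_count())}

# What the PR does: drop the device axis, then fetch to host memory.
params = jax.device_get(jax.tree_map(lambda x: x[0], replicated))

# Equivalent flax helper.
params_too = jax.device_get(jax_utils.unreplicate(replicated))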
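With `--log_model` set, each epoch's checkpoint lands in W&B as a versioned artifact named `model-<run id>`. A hypothetical way to pull one back later (project name, run id, and the `:latest` version alias are all illustrative):

import wandb

run = wandb.init(project="my-project")  # hypothetical project
artifact = run.use_artifact("model-3abc123x:latest", type="bart_model")
artifact_dir = artifact.download()  # contains flax_model.msgpack and config.json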