Merge pull request #24 from borisdayma/feat--log-model-frequently

seq2seq/run_seq2seq_flax.py  +44 -29

seq2seq/run_seq2seq_flax.py  CHANGED
@@ -84,7 +84,7 @@ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
 OUTPUT_VOCAB_SIZE = 16384 + 1  # encoded image token space + 1 for bos
 OUTPUT_LENGTH = 256 + 1  # number of encoded tokens + 1 for bos
 BOS_TOKEN_ID = 16384
-BASE_MODEL = 'facebook/bart-large'
+BASE_MODEL = 'facebook/bart-large-cnn'  # we currently have issues with bart-large


 @dataclass
@@ -231,6 +231,12 @@ class DataTrainingArguments:
     log_model: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )
+    save_model_steps: Optional[int] = field(
+        default=3000,  # about once every hour in our experiments
+        metadata={
+            "help": "For logging the model more frequently. Used only when `log_model` is set."
+        },
+    )

     def __post_init__(self):
         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
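The new save_model_steps field lives in the DataTrainingArguments dataclass. Scripts in this family usually hand such dataclasses to transformers.HfArgumentParser, which would expose the field as a --save_model_steps flag. The sketch below shows that pattern under that assumption; the trimmed-down dataclass and the flag values are illustrative, not taken from this diff.

# Sketch only: assumes the script parses its dataclasses with HfArgumentParser,
# as Hugging Face example scripts usually do. Field names mirror the diff;
# everything else here is illustrative.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class DataTrainingArguments:
    log_model: bool = field(default=False, metadata={"help": "Log model to W&B."})
    save_model_steps: Optional[int] = field(
        default=3000, metadata={"help": "How often to save/log the model."}
    )


parser = HfArgumentParser(DataTrainingArguments)
# e.g. python seq2seq/run_seq2seq_flax.py ... --log_model --save_model_steps 1500
(data_args,) = parser.parse_args_into_dataclasses()
print(data_args.log_model, data_args.save_model_steps)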
@@ -340,7 +346,7 @@ def wandb_log(metrics, step=None, prefix=None):
     if jax.process_index() == 0:
         log_metrics = {f'{prefix}/{k}' if prefix is not None else k: jax.device_get(v) for k,v in metrics.items()}
         if step is not None:
-            log_metrics
+            log_metrics['train/step'] = step
         wandb.log(log_metrics)


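The wandb_log change records the training step as an ordinary metric key (train/step) rather than relying on W&B's implicit step counter. As an aside not stated in this PR, one reason this pattern is handy is that a logged step metric can serve as a custom x-axis for other charts via wandb.define_metric. A self-contained sketch, with a placeholder project name and dummy values:

# Standalone sketch: use a logged "train/step" value as the x-axis for other
# metrics via wandb.define_metric. Requires `pip install wandb`; the project
# name and the loop below are placeholders.
import wandb

run = wandb.init(project="my-project")
wandb.define_metric("train/step")
wandb.define_metric("train/*", step_metric="train/step")

for step in range(0, 1000, 100):
    wandb.log({"train/loss": 1.0 / (step + 1), "train/step": step})

run.finish()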
@@ -773,6 +779,38 @@ def main():

         return eval_metrics

+    def run_save_model(step, epoch, eval_metrics=None):
+        if jax.process_index() == 0:
+            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+
+            # save model locally
+            model.save_pretrained(
+                training_args.output_dir,
+                params=params,
+            )
+
+            # save to W&B
+            if data_args.log_model:
+                metadata = {'step': step, 'epoch': epoch}
+                if eval_metrics is not None:
+                    metadata['eval/loss'] = eval_metrics['loss']
+                artifact = wandb.Artifact(
+                    name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
+                )
+                artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
+                artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
+                wandb.run.log_artifact(artifact)
+
+            # save to the hub
+            if training_args.push_to_hub:
+                model.save_pretrained(
+                    training_args.output_dir,
+                    params=params,
+                    push_to_hub=training_args.push_to_hub,
+                    commit_message=f"Saving weights and logs of epoch {epoch+1}",
+                    temp_dir=True  # avoid issues with being in a repository
+                )
+
     for epoch in epochs:
         # ======================== Training ================================
         train_start = time.time()
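With log_model enabled, every run_save_model call logs a bart_model artifact named model-{run id} containing flax_model.msgpack and config.json. As an aside, a later W&B run could pull the newest version of that artifact back down roughly as follows; the project name, job type, and run id are placeholders, not values from the diff:

# Sketch: download the latest model artifact logged by the training run.
# The project name and the run id are placeholders.
import wandb

run = wandb.init(project="my-project", job_type="inference")
artifact = run.use_artifact("model-<run_id>:latest", type="bart_model")
local_dir = artifact.download()  # folder with flax_model.msgpack and config.json
print("model files in", local_dir)
run.finish()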
@@ -795,6 +833,9 @@ def main():

             if global_step % training_args.eval_steps == 0:
                 run_evaluation()
+
+            if global_step % data_args.save_model_steps == 0:
+                run_save_model(global_step, epoch)

         # log final train metrics
         wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
@@ -809,34 +850,8 @@ def main():
         eval_metrics = run_evaluation()

         # save checkpoint after each epoch and push checkpoint to the hub
-        if jax.process_index() == 0:
-            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
-
-            # save model locally
-            model.save_pretrained(
-                training_args.output_dir,
-                params=params,
-            )
-
-            # save to W&B
-            if data_args.log_model:
-                metadata = {'epoch': epoch+1, 'eval/loss': eval_metrics['loss']}
-                artifact = wandb.Artifact(
-                    name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
-                )
-                artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
-                artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
-                wandb.run.log_artifact(artifact)
+        run_save_model(global_step, epoch, eval_metrics)

-            # save to the hub
-            if training_args.push_to_hub:
-                model.save_pretrained(
-                    training_args.output_dir,
-                    params=params,
-                    push_to_hub=training_args.push_to_hub,
-                    commit_message=f"Saving weights and logs of epoch {epoch+1}",
-                    temp_dir=True  # avoid issues with being in a repository
-                )

     # ======================== Prediction loop ==============================
     if training_args.do_predict:
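Note that only the end-of-epoch call in the last hunk passes eval_metrics, so only those artifact versions get an eval/loss entry in their metadata; the mid-epoch saves record just step and epoch. If needed, that metadata can be inspected later through the W&B public API, as in this sketch (entity, project, and run id are placeholders):

# Sketch: read back the metadata attached to a logged model artifact version.
# "entity/project" and the run id are placeholders.
import wandb

api = wandb.Api()
artifact = api.artifact("entity/project/model-<run_id>:latest", type="bart_model")
print(artifact.metadata)  # e.g. {'step': ..., 'epoch': ..., 'eval/loss': ...}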