Merge pull request #15 from borisdayma/feat-fix-lr
- requirements.txt +3 -0
- seq2seq/run_seq2seq_flax.py +6 -3
- seq2seq/sweep.yaml +3 -3
requirements.txt
CHANGED
@@ -7,3 +7,6 @@ jax[tpu]>=0.2.16
 -e git+https://github.com/huggingface/datasets.git@master#egg=datasets
 flax
 jupyter
+# for logging
+tensorboard
+tensorflow
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -19,8 +19,11 @@ Script adapted from run_summarization_flax.py
 """
 # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
 
-import logging as pylogging    # To avoid collision with transformers.utils.logging
 import os
+# set a common huggingface cache folder (used with datasets and transformers)
+os.environ['HF_HOME'] = '/data/huggingface/'  # required before importing transformers & datasets
+
+import logging as pylogging    # To avoid collision with transformers.utils.logging
 import sys
 import time
 from dataclasses import dataclass, field

@@ -673,12 +676,12 @@ def main():
             grads = jax.tree_map(lambda x: x / training_args.gradient_accumulation_steps, grad_accum)
             grads = jax.lax.pmean(grads, "batch")
             new_state = state.apply_gradients(
-                grads=grads, grad_accum=jax.tree_map(jnp.zeros_like, grads), optimizer_step=state.optimizer_step
+                grads=grads, grad_accum=jax.tree_map(jnp.zeros_like, grads), optimizer_step=state.optimizer_step + 1
             )
             return new_state
 
         new_state = jax.lax.cond(
-            state.step % training_args.gradient_accumulation_steps == 0,
+            (state.step + 1) % training_args.gradient_accumulation_steps == 0,
             lambda _: update_fn(),
             lambda _: state.replace(grad_accum=grad_accum, step=state.step + 1),
             None,
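The substantive change in this file is the gradient-accumulation logic in train_step: the parameter update now fires after every gradient_accumulation_steps micro-batches ((state.step + 1) % ... == 0 rather than state.step % ... == 0, which triggered on the very first micro-batch), and optimizer_step is incremented on each update, which the PR title suggests is what lets the learning-rate schedule advance. Below is a minimal, self-contained sketch of the same accumulate-then-update pattern with a toy loss and plain SGD; TrainState, accumulate_fn and accum_steps are illustrative names, not code from the repository.

# Sketch (not the repo's code) of gradient accumulation with jax.lax.cond:
# micro-batch gradients are summed into grad_accum, and parameters are only
# updated once every accum_steps micro-batches.
from typing import NamedTuple

import jax
import jax.numpy as jnp


class TrainState(NamedTuple):
    params: jnp.ndarray          # model parameters (a single array for simplicity)
    grad_accum: jnp.ndarray      # running sum of micro-batch gradients
    step: jnp.ndarray            # micro-batch counter
    optimizer_step: jnp.ndarray  # number of parameter updates (drives an LR schedule)


def loss_fn(params, batch):
    # toy quadratic loss so the example runs end to end
    return jnp.mean((params * batch["x"] - batch["y"]) ** 2)


def train_step(state: TrainState, batch, accum_steps: int, lr: float) -> TrainState:
    grads = jax.grad(loss_fn)(state.params, batch)
    grad_accum = state.grad_accum + grads  # accumulate this micro-batch's gradients

    def update_fn(_):
        # average accumulated gradients and apply a plain SGD update
        mean_grads = grad_accum / accum_steps
        return TrainState(
            params=state.params - lr * mean_grads,
            grad_accum=jnp.zeros_like(grad_accum),    # reset the accumulator
            step=state.step + 1,
            optimizer_step=state.optimizer_step + 1,  # advance the optimizer / LR schedule
        )

    def accumulate_fn(_):
        # not an update step yet: just store the accumulated gradients
        return state._replace(grad_accum=grad_accum, step=state.step + 1)

    # update on micro-batches accum_steps, 2*accum_steps, ... (hence step + 1)
    return jax.lax.cond((state.step + 1) % accum_steps == 0, update_fn, accumulate_fn, None)


state = TrainState(
    params=jnp.ones(3),
    grad_accum=jnp.zeros(3),
    step=jnp.array(0),
    optimizer_step=jnp.array(0),
)
batch = {"x": jnp.arange(3.0), "y": jnp.ones(3)}
for _ in range(8):
    state = train_step(state, batch, accum_steps=4, lr=0.1)
print(state.step, state.optimizer_step)  # 8 micro-batches -> 2 optimizer updates

With accum_steps = 4, the update branch runs on micro-batches 4 and 8, so eight micro-batches yield two optimizer updates, as the printed counters show.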
seq2seq/sweep.yaml
CHANGED
@@ -8,9 +8,9 @@ metric:
 parameters:
   learning_rate:
     distribution: log_uniform
-    # from exp(min) to exp(max), ie 1e-
-    min: -
-    max: -
+    # from exp(min) to exp(max), ie 1e-4 to 5e-3 on log scale
+    min: -9.2
+    max: -5.3
   gradient_accumulation_steps:
     value: 8
   warmup_steps:
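As a reference for the new bounds: with wandb's log_uniform distribution the sampled value lies between exp(min) and exp(max), with its natural log uniformly distributed, so min and max are natural logs of the learning rate (as the in-file comment notes). A quick sanity check of the chosen endpoints, as a sketch:

import math

# sweep bounds are natural logs of the learning rate: lr = exp(u), u ~ U(min, max)
print(math.exp(-9.2))  # ~1.0e-4, lower end of the learning-rate range
print(math.exp(-5.3))  # ~5.0e-3, upper end of the learning-rate range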