Spaces:
Running
Running
fix: OOM
Browse files
dev/seq2seq/run_seq2seq_flax.py
CHANGED
|
@@ -801,8 +801,8 @@ def main():
|
|
| 801 |
p_eval_step = jax.pmap(eval_step, "batch")
|
| 802 |
|
| 803 |
# Replicate the train state on each device
|
| 804 |
-
state = state.replicate()
|
| 805 |
del model._params
|
|
|
|
| 806 |
|
| 807 |
logger.info("***** Running training *****")
|
| 808 |
logger.info(f" Num examples = {len_train_dataset}")
|
|
|
|
| 801 |
p_eval_step = jax.pmap(eval_step, "batch")
|
| 802 |
|
| 803 |
# Replicate the train state on each device
|
|
|
|
| 804 |
del model._params
|
| 805 |
+
state = state.replicate()
|
| 806 |
|
| 807 |
logger.info("***** Running training *****")
|
| 808 |
logger.info(f" Num examples = {len_train_dataset}")
|