Spaces:
Running
Running
Pedro Cuenca
committed on
Commit
·
de74f11
1
Parent(s):
9c0e5c9
fix typos and update requirements
Browse files
- seq2seq/requirements.txt +2 -0
- seq2seq/run_seq2seq_flax.py +4 -4
seq2seq/requirements.txt
CHANGED
|
@@ -4,3 +4,5 @@ jaxlib>=0.1.59
|
|
| 4 |
flax>=0.3.4
|
| 5 |
optax>=0.0.8
|
| 6 |
tensorboard
|
|
|
|
|
|
|
|
|
| 4 |
flax>=0.3.4
|
| 5 |
optax>=0.0.8
|
| 6 |
tensorboard
|
| 7 |
+
nltk
|
| 8 |
+
wandb
|
seq2seq/run_seq2seq_flax.py
CHANGED
|
@@ -19,7 +19,7 @@ Script adapted from run_summarization_flax.py
|
|
| 19 |
"""
|
| 20 |
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
|
| 21 |
|
| 22 |
-
import logging
|
| 23 |
import os
|
| 24 |
import sys
|
| 25 |
import time
|
|
@@ -60,7 +60,7 @@ from transformers.file_utils import is_offline_mode
|
|
| 60 |
|
| 61 |
import wandb
|
| 62 |
|
| 63 |
-
logger =
|
| 64 |
|
| 65 |
try:
|
| 66 |
nltk.data.find("tokenizers/punkt")
|
|
@@ -389,7 +389,7 @@ def main():
|
|
| 389 |
data_files["validation"] = data_args.validation_file
|
| 390 |
if data_args.test_file is not None:
|
| 391 |
data_files["test"] = data_args.test_file
|
| 392 |
-
dataset = load_dataset"csv", data_files=data_files, cache_dir=model_args.cache_dir, delimiter="\t")
|
| 393 |
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
| 394 |
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
| 395 |
|
|
@@ -411,7 +411,7 @@ def main():
|
|
| 411 |
|
| 412 |
|
| 413 |
# Create a custom model and initialize it randomly
|
| 414 |
-
model = CustomFlaxBartForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
| 415 |
|
| 416 |
# Use pre-trained weights for encoder
|
| 417 |
model.params['model']['encoder'] = base_model.params['model']['encoder']
|
|
|
|
| 19 |
"""
|
| 20 |
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
|
| 21 |
|
| 22 |
+
import logging as pylogging # To avoid collision with transformers.utils.logging
|
| 23 |
import os
|
| 24 |
import sys
|
| 25 |
import time
|
|
|
|
| 60 |
|
| 61 |
import wandb
|
| 62 |
|
| 63 |
+
logger = pylogging.getLogger(__name__)
|
| 64 |
|
| 65 |
try:
|
| 66 |
nltk.data.find("tokenizers/punkt")
|
|
|
|
| 389 |
data_files["validation"] = data_args.validation_file
|
| 390 |
if data_args.test_file is not None:
|
| 391 |
data_files["test"] = data_args.test_file
|
| 392 |
+
dataset = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir, delimiter="\t")
|
| 393 |
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
| 394 |
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
| 395 |
|
|
|
|
| 411 |
|
| 412 |
|
| 413 |
# Create a custom model and initialize it randomly
|
| 414 |
+
model = CustomFlaxBartForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
| 415 |
|
| 416 |
# Use pre-trained weights for encoder
|
| 417 |
model.params['model']['encoder'] = base_model.params['model']['encoder']
|