feat: log epoch + check params
- dev/seq2seq/do_big_run.sh +4 -4
- dev/seq2seq/do_small_run.sh +3 -3
- dev/seq2seq/run_seq2seq_flax.py +21 -30
dev/seq2seq/do_big_run.sh
CHANGED
@@ -1,16 +1,16 @@
 python run_seq2seq_flax.py \
-    --max_source_length 128 \
     --dataset_repo_or_path dalle-mini/encoded \
     --train_file **/train/*/*.jsonl \
     --validation_file **/valid/*/*.jsonl \
+    --len_train 42684248 \
+    --len_eval 34328 \
     --streaming \
-    --len_train … \
-    --len_eval 100 \
+    --normalize_text \
     --output_dir output \
     --per_device_train_batch_size 56 \
     --per_device_eval_batch_size 56 \
     --preprocessing_num_workers 80 \
-    --warmup_steps … \
+    --warmup_steps 500 \
     --gradient_accumulation_steps 8 \
     --do_train \
     --do_eval \
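Both runs pass `--streaming`, under which `datasets` yields an iterable with no `len()`, so the script cannot infer how many optimizer steps make up an epoch; the new `--len_train`/`--len_eval` flags supply those counts explicitly (and `run_seq2seq_flax.py` below now refuses to stream without them). A minimal sketch of the arithmetic these lengths enable; the helper name and the 8-device count are illustrative assumptions, not the script's code:

```python
# Sketch: deriving steps per epoch from an explicit dataset length.
# steps_per_epoch and n_devices=8 are assumptions for illustration.
import math

def steps_per_epoch(len_train: int, per_device_batch: int,
                    n_devices: int, grad_accum: int) -> int:
    # One optimizer step consumes this many samples across all devices:
    samples_per_step = per_device_batch * n_devices * grad_accum
    return math.ceil(len_train / samples_per_step)

# With the values from do_big_run.sh, assuming 8 devices (a TPU v3-8):
print(steps_per_epoch(42_684_248, 56, 8, 8))  # -> 11910
```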
dev/seq2seq/do_small_run.sh
CHANGED
@@ -2,9 +2,9 @@ python run_seq2seq_flax.py \
     --dataset_repo_or_path dalle-mini/encoded \
     --train_file **/train/*/*.jsonl \
     --validation_file **/valid/*/*.jsonl \
+    --len_train 42684248 \
+    --len_eval 34328 \
     --streaming \
-    --len_train 1000000 \
-    --len_eval 1000 \
     --output_dir output \
     --per_device_train_batch_size 56 \
     --per_device_eval_batch_size 56 \
@@ -15,5 +15,5 @@ python run_seq2seq_flax.py \
     --do_eval \
     --adafactor \
     --num_train_epochs 1 \
-    --max_train_samples … \
+    --max_train_samples 10000 \
     --learning_rate 0.005
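The small run keeps `--max_train_samples 10000` as a quick smoke test. Since a streamed dataset cannot be sliced, a cap like this is typically applied with `IterableDataset.take`; a hedged sketch of that pattern (the repo path and glob mirror the flags above, but this is not the script's own loading code):

```python
# Sketch: capping a streamed dataset, in the spirit of --max_train_samples.
# Illustrative only; the script's actual loading code is not in this diff.
from datasets import load_dataset

train = load_dataset(
    "dalle-mini/encoded",
    data_files={"train": "**/train/*/*.jsonl"},
    split="train",
    streaming=True,
)
capped = train.take(10_000)  # lazily yields only the first 10,000 examples
```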
dev/seq2seq/run_seq2seq_flax.py
CHANGED
@@ -138,16 +138,6 @@ class DataTrainingArguments:
     Arguments pertaining to what data we are going to input our model for training and eval.
     """

-    dataset_name: Optional[str] = field(
-        default=None,
-        metadata={"help": "The name of the dataset to use (via the datasets library)."},
-    )
-    dataset_config_name: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "The configuration name of the dataset to use (via the datasets library)."
-        },
-    )
     text_column: Optional[str] = field(
         default="caption",
         metadata={
@@ -260,14 +250,10 @@ class DataTrainingArguments:
     )

     def __post_init__(self):
-        if (
-            self.dataset_name is None
-            and self.train_file is None
-            and self.validation_file is None
-        ):
-            raise ValueError(
-                "Need either a dataset name or a training/validation file."
-            )
+        if self.dataset_repo_or_path is None:
+            raise ValueError("Need a dataset repository or path.")
+        if self.train_file is None or self.validation_file is None:
+            raise ValueError("Need training/validation file.")
         else:
             if self.train_file is not None:
                 extension = self.train_file.split(".")[-1]
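With `dataset_name` gone, `__post_init__` can fail fast on exactly the two inputs the script supports. A self-contained sketch of the same pattern, with field defaults invented for brevity:

```python
# Sketch: fail-fast validation in __post_init__, mirroring the checks above.
from dataclasses import dataclass
from typing import Optional

@dataclass
class DataArgs:
    dataset_repo_or_path: Optional[str] = None
    train_file: Optional[str] = None
    validation_file: Optional[str] = None

    def __post_init__(self):
        if self.dataset_repo_or_path is None:
            raise ValueError("Need a dataset repository or path.")
        if self.train_file is None or self.validation_file is None:
            raise ValueError("Need training/validation file.")

try:
    DataArgs(dataset_repo_or_path="dalle-mini/encoded")
except ValueError as err:
    print(err)  # Need training/validation file.
```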
@@ -287,6 +273,10 @@ class DataTrainingArguments:
             ], "`validation_file` should be a tsv, csv or json file."
         if self.val_max_target_length is None:
             self.val_max_target_length = self.max_target_length
+        if self.streaming and (self.len_train is None or self.len_eval is None):
+            raise ValueError(
+                "Streaming requires providing length of training and validation datasets"
+            )


 class TrainState(train_state.TrainState):
@@ -467,18 +457,6 @@ def main():
             "Use --overwrite_output_dir to overcome."
         )

-    # Set up wandb run
-    wandb.init(
-        entity="dalle-mini",
-        project="dalle-mini",
-        job_type="Seq2Seq",
-        config=parser.parse_args(),
-    )
-
-    # set default x-axis as 'train/step'
-    wandb.define_metric("train/step")
-    wandb.define_metric("*", step_metric="train/step")
-
     # Make one log on every process with the configuration for debugging.
     pylogging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -528,6 +506,18 @@ def main():

         return step, optimizer_step, opt_state

+    # Set up wandb run
+    wandb.init(
+        entity="dalle-mini",
+        project="dalle-mini",
+        job_type="Seq2Seq",
+        config=parser.parse_args(),
+    )
+
+    # set default x-axis as 'train/step'
+    wandb.define_metric("train/step")
+    wandb.define_metric("*", step_metric="train/step")
+
     if model_args.from_checkpoint is not None:
         artifact = wandb.run.use_artifact(model_args.from_checkpoint)
         artifact_dir = artifact.download()
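The `wandb.init` block moves after the checkpoint-restore helper but still runs before `wandb.run.use_artifact` needs an active run. The `define_metric` pair sets `train/step` as the x-axis for every logged series, so charts line up on training steps rather than W&B's internal step counter. A short sketch of the effect (the loss key is illustrative):

```python
# Sketch: charting all metrics against a custom step metric in W&B.
import wandb

wandb.init(entity="dalle-mini", project="dalle-mini", job_type="Seq2Seq")
wandb.define_metric("train/step")                    # declare the x-axis series
wandb.define_metric("*", step_metric="train/step")   # plot everything against it

# Any log that includes train/step positions its other keys on that axis:
wandb.log({"train/step": 100, "train/loss": 2.31})
```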
@@ -1006,6 +996,7 @@ def main():

     for epoch in epochs:
         # ======================== Training ================================
+        wandb_log({"train/epoch": epoch}, step=global_step)

         # Create sampling rng
         rng, input_rng = jax.random.split(rng)
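`wandb_log` is the script's own wrapper around `wandb.log`; its definition sits outside this diff. Logging `train/epoch` keyed to `global_step` at the top of each epoch marks epoch boundaries on the shared step axis. A hedged sketch of what such a wrapper plausibly looks like, given the `define_metric` setup above; the signature and body are assumptions:

```python
# Sketch of a wandb_log-style helper (assumed; real definition not in this diff).
# It attaches the step under the declared step metric so charts stay aligned.
from typing import Optional

import wandb

def wandb_log(metrics: dict, step: Optional[int] = None,
              prefix: Optional[str] = None):
    if prefix is not None:
        metrics = {f"{prefix}/{k}": v for k, v in metrics.items()}
    if step is not None:
        metrics = {**metrics, "train/step": step}
    wandb.log(metrics)

# As called above at the start of each epoch:
# wandb_log({"train/epoch": epoch}, step=global_step)
```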