# Defaults for finetuning with train.py.
#
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - MIXTURE_OR_TASK_NAME
# - TASK_FEATURE_LENGTHS
# - TRAIN_STEPS # includes pretrain steps
# - MODEL_DIR # automatically set when using xm_launch
# - INITIAL_CHECKPOINT_PATH
#
# When running locally, MODEL_DIR needs to be passed via the `--gin.MODEL_DIR` flag.
#
# `TRAIN_STEPS` should include pre-training steps, e.g., if pre-trained ckpt
# has 1M steps, TRAIN_STEPS = 1.1M will perform 0.1M fine-tuning steps.
#
# Commonly overridden options:
# - DROPOUT_RATE
# - BATCH_SIZE
# - PjitPartitioner.num_partitions
# - Trainer.num_microbatches
# - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data, or preprocess
# on the fly. Most common tasks are cached, hence this is set to True by
# default.
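#
# Illustrative launch sketch (not part of this config; the model gin file,
# bucket paths, task name, feature lengths, and step counts are placeholders,
# adjust them to your setup). A checkpoint pre-trained for 1,000,000 steps and
# fine-tuned for another 20,000 steps therefore gets TRAIN_STEPS = 1020000:
#
#   python ${T5X_DIR}/t5x/train.py \
#     --gin_file="t5x/examples/t5/t5_1_1/base.gin" \
#     --gin_file="t5x/configs/runs/finetune.gin" \
#     --gin.MODEL_DIR="'gs://your-bucket/finetune_run'" \
#     --gin.MIXTURE_OR_TASK_NAME="'your_seqio_task'" \
#     --gin.TASK_FEATURE_LENGTHS="{'inputs': 512, 'targets': 64}" \
#     --gin.TRAIN_STEPS=1020000 \
#     --gin.INITIAL_CHECKPOINT_PATH="'gs://your-bucket/pretrained/checkpoint_1000000'"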
from __gin__ import dynamic_registration
import __main__ as train_script
import seqio
from t5x import gin_utils
from t5x import partitioning
from t5x import utils
from t5x import trainer

# Must be overridden
MODEL_DIR = %gin.REQUIRED
MIXTURE_OR_TASK_NAME = %gin.REQUIRED
TASK_FEATURE_LENGTHS = %gin.REQUIRED
TRAIN_STEPS = %gin.REQUIRED
INITIAL_CHECKPOINT_PATH = %gin.REQUIRED

# Commonly overridden
DROPOUT_RATE = 0.1
USE_CACHED_TASKS = True
BATCH_SIZE = 128

# Sometimes overridden
EVAL_STEPS = 20

# Convenience overrides.
EVALUATOR_USE_MEMORY_CACHE = True
EVALUATOR_NUM_EXAMPLES = None  # Use all examples in the infer_eval dataset.
JSON_WRITE_N_RESULTS = None  # Write all inferences.

# HW RNG is faster than SW, but has limited determinism.
# Most notably it is not deterministic across different
# submeshes.
USE_HARDWARE_RNG = False
# None always uses the faster hardware RNG.
RANDOM_SEED = None

# DEPRECATED: Import this module in your gin file.
MIXTURE_OR_TASK_MODULE = None
train_script.train:
  model = %MODEL  # imported from separate gin file
  model_dir = %MODEL_DIR
  train_dataset_cfg = @train/utils.DatasetConfig()
  train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
  infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig()
  checkpoint_cfg = @utils.CheckpointConfig()
  partitioner = @partitioning.PjitPartitioner()
  trainer_cls = @trainer.Trainer
  total_steps = %TRAIN_STEPS
  eval_steps = %EVAL_STEPS
  eval_period = 1000
  random_seed = %RANDOM_SEED
  use_hardware_rng = %USE_HARDWARE_RNG
  summarize_config_fn = @gin_utils.summarize_gin_config
  inference_evaluator_cls = @seqio.Evaluator

partitioning.PjitPartitioner:
  num_partitions = 1
  model_parallel_submesh = None
  logical_axis_rules = @partitioning.standard_logical_axis_rules()

seqio.Evaluator:
  logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger]
  num_examples = %EVALUATOR_NUM_EXAMPLES
  use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE

seqio.JSONLogger:
  write_n_results = %JSON_WRITE_N_RESULTS

train/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'train'
  batch_size = %BATCH_SIZE
  shuffle = True
  seed = None  # use a new seed each run/restart
  use_cached = %USE_CACHED_TASKS
  pack = True
  module = %MIXTURE_OR_TASK_MODULE

train_eval/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'validation'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = True
  module = %MIXTURE_OR_TASK_MODULE

infer_eval/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = None  # compute max
  split = 'validation'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = False
  module = %MIXTURE_OR_TASK_MODULE

utils.CheckpointConfig:
  restore = @utils.RestoreCheckpointConfig()
  save = @utils.SaveCheckpointConfig()

utils.RestoreCheckpointConfig:
  path = %INITIAL_CHECKPOINT_PATH
  mode = 'specific'
  dtype = 'float32'

utils.SaveCheckpointConfig:
  period = 5000
  dtype = 'float32'
  keep = None  # keep all checkpoints
  save_dataset = False  # don't checkpoint dataset state

trainer.Trainer:
  num_microbatches = None
  learning_rate_fn = @utils.create_learning_rate_scheduler()

utils.create_learning_rate_scheduler:
  factors = 'constant'
  base_learning_rate = 0.001
  warmup_steps = 1000
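
# The commonly overridden options above are usually collected in a small user
# gin file that includes these defaults. A minimal sketch (illustrative values;
# the include paths assume the standard t5x repository layout):
#
#   from __gin__ import dynamic_registration
#   from t5x import partitioning
#   from t5x import trainer
#
#   include 't5x/examples/t5/t5_1_1/base.gin'  # provides the MODEL binding
#   include 't5x/configs/runs/finetune.gin'    # this file
#
#   DROPOUT_RATE = 0.0
#   BATCH_SIZE = 64
#   USE_CACHED_TASKS = False  # preprocess SeqIO data on the fly
#   partitioning.PjitPartitioner.num_partitions = 2
#   trainer.Trainer.num_microbatches = 4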