Commit 4015e59
Parent(s): c1c0b27
Upload 6 files

- config/eval_gpt2.py +8 -0
- config/eval_gpt2_large.py +8 -0
- config/eval_gpt2_medium.py +8 -0
- config/eval_gpt2_xl.py +8 -0
- config/finetune_shakespeare.py +25 -0
- config/train_gpt2.py +25 -0
config/eval_gpt2.py
ADDED
@@ -0,0 +1,8 @@
+# evaluate the base gpt2
+# n_layer=12, n_head=12, n_embd=768
+# 124M parameters
+batch_size = 8
+eval_iters = 500 # use more iterations to get good estimate
+eval_only = True
+wandb_log = False
+init_from = 'gpt2'
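A note on how these config files are consumed: in a nanoGPT-style setup the training script defines its settings as module-level globals, then execs any config file passed on the command line, so each line above simply rebinds one of those globals. A minimal sketch of that pattern (illustrative only, not the repo's exact configurator code):

    # defaults, as they might appear near the top of train.py
    batch_size = 12
    eval_iters = 200
    eval_only = False        # if True, run one evaluation pass and exit
    init_from = 'scratch'

    # apply any config files given on the command line:
    # executing config/eval_gpt2.py rebinds the globals defined above
    import sys
    for arg in sys.argv[1:]:
        if not arg.startswith('--'):
            exec(open(arg).read())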
config/eval_gpt2_large.py
ADDED
@@ -0,0 +1,8 @@
+# evaluate the base gpt2-large
+# n_layer=36, n_head=20, n_embd=1280
+# 774M parameters
+batch_size = 8
+eval_iters = 500 # use more iterations to get good estimate
+eval_only = True
+wandb_log = False
+init_from = 'gpt2-large'
config/eval_gpt2_medium.py
ADDED
@@ -0,0 +1,8 @@
+# evaluate the base gpt2-medium
+# n_layer=24, n_head=16, n_embd=1024
+# 350M parameters
+batch_size = 8
+eval_iters = 500 # use more iterations to get good estimate
+eval_only = True
+wandb_log = False
+init_from = 'gpt2-medium'
config/eval_gpt2_xl.py
ADDED
@@ -0,0 +1,8 @@
+# evaluate the base gpt2-xl
+# n_layer=48, n_head=25, n_embd=1600
+# 1558M parameters
+batch_size = 8
+eval_iters = 500 # use more iterations to get good estimate
+eval_only = True
+wandb_log = False
+init_from = 'gpt2-xl'
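Since all four eval configs set eval_only = True, they run a single evaluation of the pretrained checkpoint and exit rather than train. Assuming the usual entry point (see the torchrun example in config/train_gpt2.py below), a single-GPU launch would look like:

    $ python train.py config/eval_gpt2_medium.py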
config/finetune_shakespeare.py
ADDED
@@ -0,0 +1,25 @@
+import time
+
+out_dir = 'out-shakespeare'
+eval_interval = 5
+eval_iters = 40
+wandb_log = False # feel free to turn on
+wandb_project = 'shakespeare'
+wandb_run_name = 'ft-' + str(time.time())
+
+dataset = 'shakespeare'
+init_from = 'gpt2-xl' # this is the largest GPT-2 model
+
+# only save checkpoints if the validation loss improves
+always_save_checkpoint = False
+
+# the number of examples per iter:
+# 1 batch_size * 32 grad_accum * 1024 tokens = 32,768 tokens/iter
+# shakespeare has 301,966 tokens, so 1 epoch ~= 9.2 iters
+batch_size = 1
+gradient_accumulation_steps = 32
+max_iters = 20
+
+# finetune at constant LR
+learning_rate = 3e-5
+decay_lr = False
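The epoch arithmetic in the comments checks out; a quick sketch (the 1024 comes from the comment above and is GPT-2's default block_size, which this config does not override):

    batch_size = 1
    gradient_accumulation_steps = 32
    block_size = 1024  # GPT-2 context length, per the comment
    tokens_per_iter = batch_size * gradient_accumulation_steps * block_size
    print(tokens_per_iter)                       # 32768 tokens/iter

    shakespeare_tokens = 301_966                 # per the comment
    print(shakespeare_tokens / tokens_per_iter)  # ~9.2 iters per epoch

So max_iters = 20 amounts to roughly two epochs over the dataset.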
config/train_gpt2.py
ADDED
@@ -0,0 +1,25 @@
+# config for training GPT-2 (124M) down to very nice loss of ~2.85 on 1 node of 8X A100 40GB
+# launch as the following (e.g. in a screen session) and wait ~5 days:
+# $ torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py
+
+wandb_log = True
+wandb_project = 'owt'
+wandb_run_name = 'gpt2-124M'
+
+# these make the total batch size be ~0.5M
+# 12 batch size * 1024 block size * 5 gradaccum * 8 GPUs = 491,520
+batch_size = 12
+block_size = 1024
+gradient_accumulation_steps = 5 * 8
+
+# this makes total number of tokens be 300B
+max_iters = 600000
+lr_decay_iters = 600000
+
+# eval stuff
+eval_interval = 1000
+eval_iters = 200
+log_interval = 10
+
+# weight decay
+weight_decay = 1e-1
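The batch-size and token-count comments can be verified the same way; note that the config folds the 8-GPU factor into gradient_accumulation_steps = 5 * 8, matching the factorization in the comment:

    batch_size = 12
    block_size = 1024
    grad_accum = 5
    n_gpu = 8
    tokens_per_iter = batch_size * block_size * grad_accum * n_gpu
    print(tokens_per_iter)              # 491520, i.e. ~0.5M tokens per iteration

    max_iters = 600_000
    print(tokens_per_iter * max_iters)  # 294,912,000,000, ~= the 300B in the comment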