feat(train): improve pjit speed
Files changed:
- src/dalle_mini/data.py (+7 -28)
- tools/train/train.py (+78 -42)
src/dalle_mini/data.py
CHANGED

@@ -152,14 +152,7 @@ class Dataset:
             ),
         )
 
-    def dataloader(
-        self, split, per_device_batch_size, gradient_accumulation_steps=None, epoch=None
-    ):
-        num_devices = jax.local_device_count()
-        total_batch_size = per_device_batch_size * num_devices
-        if gradient_accumulation_steps is not None:
-            total_batch_size *= gradient_accumulation_steps
-
+    def dataloader(self, split, batch_size, epoch=None):
         def _dataloader_datasets_non_streaming(
             dataset: Dataset,
             rng: jax.random.PRNGKey = None,
@@ -168,7 +161,7 @@ class Dataset:
             Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
             Shuffle batches if rng is set.
             """
-            steps_per_epoch = len(dataset) // total_batch_size
+            steps_per_epoch = len(dataset) // batch_size
 
             if rng is not None:
                 batch_idx = jax.random.permutation(rng, len(dataset))
@@ -176,20 +169,13 @@ class Dataset:
                 batch_idx = jnp.arange(len(dataset))
 
             batch_idx = batch_idx[
-                : steps_per_epoch * total_batch_size
+                : steps_per_epoch * batch_size
             ]  # Skip incomplete batch.
-            batch_idx = batch_idx.reshape((steps_per_epoch, total_batch_size))
+            batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
 
             for idx in batch_idx:
                 batch = dataset[idx]
                 batch = {k: jnp.array(v) for k, v in batch.items()}
-                if gradient_accumulation_steps is not None:
-                    batch = jax.tree_map(
-                        lambda x: x.reshape(
-                            (gradient_accumulation_steps, -1) + x.shape[1:]
-                        ),
-                        batch,
-                    )
                 yield batch
 
         def _dataloader_datasets_streaming(
@@ -205,22 +191,15 @@ class Dataset:
             # For validation data we put the entire set on each host as we could lose
            # too many samples on pods
             if epoch is not None:
-
+                assert split == "train"
+                # reshuffle training data at each epoch
                 dataset.set_epoch(epoch)
                 epoch += 1
             for item in dataset:
                 for k, v in item.items():
                     batch[k].append(v)
-                if len(batch[keys[0]]) == total_batch_size:
+                if len(batch[keys[0]]) == batch_size:
                     batch = {k: jnp.array(v) for k, v in batch.items()}
-                    if gradient_accumulation_steps is not None:
-                        # training mode
-                        batch = jax.tree_map(
-                            lambda x: x.reshape(
-                                (gradient_accumulation_steps, -1) + x.shape[1:]
-                            ),
-                            batch,
-                        )
                     yield batch
                     batch = {k: [] for k in keys}
             first_loop = False
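The dataloader no longer knows about devices or gradient accumulation: it simply yields flat batches of `batch_size` items, and the training loop reintroduces the extra axes. A minimal sketch of that contract, with assumed sizes and a dummy `input_ids` column:

    import numpy as np

    # assumed sizes, for illustration only
    per_device_bs, grad_accum, dp_devices, seq_len = 8, 2, 4, 64
    batch_size_per_node = per_device_bs * grad_accum * dp_devices  # 64

    # the dataloader now yields flat batches like this
    batch = {"input_ids": np.zeros((batch_size_per_node, seq_len), dtype=np.int32)}

    # the training loop (train.py below) reintroduces the device/accumulation axes
    reshaped = {
        k: v.reshape((dp_devices, grad_accum, -1) + v.shape[1:]) for k, v in batch.items()
    }
    assert reshaped["input_ids"].shape == (dp_devices, grad_accum, per_device_bs, seq_len)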
tools/train/train.py
CHANGED

@@ -36,12 +36,12 @@ import transformers
 import wandb
 from datasets import Dataset
 from distributed_shampoo import GraftingType, distributed_shampoo
-from flax.core.frozen_dict import FrozenDict, freeze
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.serialization import from_bytes, to_bytes
 from flax.training import train_state
 from flax.training.common_utils import onehot, stack_forest
 from jax.experimental import PartitionSpec, maps
-from jax.experimental.pjit import pjit
+from jax.experimental.pjit import pjit, with_sharding_constraint
 from tqdm import tqdm
 from transformers import HfArgumentParser
 
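These two new imports carry much of the change: `unfreeze` makes the batch pytree mutable again inside the step, and `with_sharding_constraint` pins intermediate values to the mesh so the partitioner cannot replicate them. A self-contained sketch of the latter, assuming a 1-D mesh named "batch" over all local devices (toy function, not the training code):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.experimental import PartitionSpec, maps
    from jax.experimental.pjit import pjit, with_sharding_constraint

    def f(x):
        y = x * 2
        # without the constraint, the partitioner may replicate y on every device;
        # this pins it to stay sharded along the "batch" mesh axis
        y = with_sharding_constraint(y, PartitionSpec("batch"))
        return y.sum()

    p_f = pjit(f, in_axis_resources=PartitionSpec("batch"), out_axis_resources=None)

    devices = np.asarray(jax.devices())
    mesh = maps.Mesh(devices, ("batch",))
    with maps.mesh(mesh.devices, mesh.axis_names):
        out = p_f(jnp.ones((devices.size, 4)))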
@@ -551,12 +551,12 @@ def main():
     num_epochs = training_args.num_train_epochs
     # batch size
     minibatch_size = (
-        training_args.per_device_train_batch_size * jax.local_device_count()
+        training_args.per_device_train_batch_size * training_args.dp_devices
     )
     batch_size_per_node = minibatch_size * training_args.gradient_accumulation_steps
     batch_size_per_step = batch_size_per_node * jax.process_count()
     eval_batch_size = (
-        training_args.per_device_eval_batch_size * jax.local_device_count()
+        training_args.per_device_eval_batch_size * training_args.dp_devices
    )
     len_train_dataset, len_eval_dataset = dataset.length
     steps_per_epoch = (
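For concreteness, here is the bookkeeping above with assumed values (2 hosts, 4 data-parallel devices per host, 2 accumulation steps):

    per_device_train_batch_size = 8   # assumed
    dp_devices = 4                    # data-parallel devices on this node (assumed)
    gradient_accumulation_steps = 2   # assumed
    process_count = 2                 # number of hosts (assumed)

    minibatch_size = per_device_train_batch_size * dp_devices            # 32: one micro-step
    batch_size_per_node = minibatch_size * gradient_accumulation_steps   # 64: what the dataloader yields
    batch_size_per_step = batch_size_per_node * process_count            # 128: samples per optimizer update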
@@ -762,6 +762,10 @@ def main():
     # free memory
     del model._params
 
+    # define batch specs
+    keys = ["attention_mask", "decoder_input_ids", "input_ids", "labels"]
+    batch_spec = freeze({k: PartitionSpec("batch") for k in keys})
+
     # label smoothed cross entropy
     def loss_fn(logits, labels):
         loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
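`batch_spec` must mirror the batch pytree leaf-for-leaf, which is why it is a FrozenDict keyed exactly like the batch. A sketch with a shortened, hypothetical key set:

    from flax.core.frozen_dict import freeze
    from jax.experimental import PartitionSpec

    # hypothetical two-key batch, for illustration
    keys = ["input_ids", "labels"]
    batch_spec = freeze({k: PartitionSpec("batch") for k in keys})
    # pjit pairs in_axis_resources with arguments leaf-by-leaf, so this spec only
    # matches a FrozenDict batch with exactly these keys; each leaf's leading
    # axis is then sharded over the "batch" mesh axis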
@@ -771,16 +775,18 @@ def main():
     # Define gradient update step fn
     def train_step(state, batch, delta_time):
         # check correct batch shape during compilation
-        assert batch["labels"].shape[0:2] == (
+        assert batch["labels"].shape[0:3] == (
+            training_args.dp_devices,
             training_args.gradient_accumulation_steps,
-            minibatch_size,
-        ), f"Expected label batch of shape gradient_acculumation x minibatch_size and got {batch['labels'].shape}"
+            training_args.per_device_train_batch_size,
+        ), f"Expected label batch of shape dp_devices x gradient_acculumation x batch_per_device and got {batch['labels'].shape}"
         # create a new rng
         dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
         # use a different rng per node
         dropout_rng = jax.random.fold_in(dropout_rng, jax.process_index())
 
         def compute_loss(params, minibatch):
+            minibatch = unfreeze(minibatch)
             labels = minibatch.pop("labels")
             logits = state.apply_fn(
                 **minibatch, params=params, dropout_rng=dropout_rng, train=True
@@ -789,32 +795,52 @@ def main():
 
         grad_fn = jax.value_and_grad(compute_loss)
 
-        if training_args.gradient_accumulation_steps == 1:
-            minibatch = jax.tree_map(lambda x: x[0], batch)
-            loss, grads = grad_fn(state.params, minibatch)
-        else:
+        def loss_grad_per_device(device_batch):
+            # device_batch has format (gradient_accumulation_steps, batch_size, ...)
 
-            def _cumul_loss_grads(i, cumul_loss_grads):
-                minibatch = jax.tree_map(lambda x: x[i], batch)
-                return jax.tree_map(
-                    lambda x, y: x + y,
-                    cumul_loss_grads,
-                    grad_fn(state.params, minibatch),
+            if training_args.gradient_accumulation_steps == 1:
+                minibatch = jax.tree_map(
+                    lambda x: x[0],
+                    device_batch,
                 )
+                loss, grads = grad_fn(state.params, minibatch)
+            else:
 
-            init_loss_grads = (
-                0.0,
-                jax.tree_map(jnp.zeros_like, state.params),
-            )
-            loss, grads = jax.tree_map(
-                lambda x: x / training_args.gradient_accumulation_steps,
-                jax.lax.fori_loop(
-                    0,
-                    training_args.gradient_accumulation_steps,
-                    _cumul_loss_grads,
-                    init_loss_grads,
-                ),
-            )
+                def _cumul_loss_grads(i, cumul_loss_grads):
+                    minibatch = jax.tree_map(
+                        lambda x: x[i],
+                        device_batch,
+                    )
+                    return jax.tree_map(
+                        lambda x, y: x + y,
+                        cumul_loss_grads,
+                        grad_fn(state.params, minibatch),
+                    )
+
+                init_loss_grads = (
+                    0.0,
+                    jax.tree_map(jnp.zeros_like, state.params),
+                )
+                loss, grads = jax.tree_map(
+                    lambda x: x / training_args.gradient_accumulation_steps,
+                    jax.lax.fori_loop(
+                        0,
+                        training_args.gradient_accumulation_steps,
+                        _cumul_loss_grads,
+                        init_loss_grads,
+                    ),
+                )
+            return loss, grads
+
+        # calculate loss, grads per dp device
+        # batch has shape (dp_devices, gradient_accumulation_steps, batch_per_dp_device, ...)
+        loss, grads = jax.vmap(loss_grad_per_device, in_axes=0, out_axes=(0, 0))(batch)
+        # enforce sharding constraints to avoid OOM
+        loss = with_sharding_constraint(loss, PartitionSpec("batch"))
+        grads = with_sharding_constraint(grads, PartitionSpec("batch"))
+        # calculate the mean over all devices
+        loss = jnp.mean(loss)
+        grads = jax.tree_map(lambda x: jnp.mean(x, axis=0), grads)
 
         state = state.apply_gradients(
             grads=grads,
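The pattern above, `jax.vmap` over the `dp_devices` axis and `jax.lax.fori_loop` over accumulation steps, keeps the compiled step small regardless of the accumulation factor, since the loop body is traced once. A self-contained sketch of the same pattern with a toy loss (assumed shapes, not the model code):

    import jax
    import jax.numpy as jnp

    ACCUM_STEPS = 2  # assumed

    def compute_loss(params, minibatch):
        # toy stand-in for the model forward pass
        return jnp.mean((minibatch * params) ** 2)

    grad_fn = jax.value_and_grad(compute_loss)

    def loss_grad_per_device(params, device_batch):
        # device_batch: (ACCUM_STEPS, batch_per_device, features) for one dp device
        def _cumul(i, acc):
            minibatch = jax.tree_map(lambda x: x[i], device_batch)
            return jax.tree_map(lambda x, y: x + y, acc, grad_fn(params, minibatch))

        init = (0.0, jax.tree_map(jnp.zeros_like, params))
        loss_grads = jax.lax.fori_loop(0, ACCUM_STEPS, _cumul, init)
        # average instead of sum so loss/grad scale is independent of ACCUM_STEPS
        return jax.tree_map(lambda x: x / ACCUM_STEPS, loss_grads)

    params = jnp.ones((3,))
    batch = jnp.ones((4, ACCUM_STEPS, 5, 3))  # (dp_devices, accum, batch_per_device, features)
    loss, grads = jax.vmap(loss_grad_per_device, in_axes=(None, 0))(params, batch)
    loss, grads = jnp.mean(loss), jax.tree_map(lambda g: jnp.mean(g, axis=0), grads)

Dividing by the step count inside the loop matches the division in the diff and keeps gradient magnitudes comparable across accumulation settings.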
@@ -832,6 +858,7 @@ def main():
 
     # Define eval fn
     def eval_step(params, batch):
+        batch = unfreeze(batch)
         labels = batch.pop("labels")
         logits = model(**batch, params=params, train=False)[0]
         loss = loss_fn(logits, labels)
@@ -843,13 +870,13 @@ def main():
     # Create parallel version of the train and eval step
     p_train_step = pjit(
         train_step,
-        in_axis_resources=(state_spec, None, None),
+        in_axis_resources=(state_spec, batch_spec, None),
         out_axis_resources=(state_spec, None),
         donate_argnums=(0,),
     )
     p_eval_step = pjit(
         eval_step,
-        in_axis_resources=(param_spec, None),
+        in_axis_resources=(param_spec, batch_spec),
         out_axis_resources=None,
     )
 
@@ -890,9 +917,7 @@ def main():
         # ======================== Evaluating ==============================
         eval_metrics = []
         if training_args.do_eval:
-            eval_loader = dataset.dataloader(
-                "eval", training_args.per_device_eval_batch_size
-            )
+            eval_loader = dataset.dataloader("eval", eval_batch_size)
             eval_steps = (
                 len_eval_dataset // eval_batch_size
                 if len_eval_dataset is not None
@@ -905,8 +930,8 @@ def main():
                 leave=False,
                 total=eval_steps,
             ):
-                # Model forward
-                metrics = p_eval_step(state.params, batch)
+                # TODO: make this more efficient once training loop is fast
+                metrics = p_eval_step(state.params, freeze(batch))
                 eval_metrics.append(metrics)
 
             # normalize eval metrics
@@ -1010,8 +1035,7 @@ def main():
         # Generate an epoch by shuffling sampling indices from the train dataset
         train_loader = dataset.dataloader(
             "train",
-            training_args.per_device_train_batch_size,
-            training_args.gradient_accumulation_steps,
+            batch_size_per_node,
             epoch,
         )
         # train
@@ -1022,15 +1046,27 @@ def main():
             leave=False,
             total=steps_per_epoch,
         ):
-
             # calculate delta time (we have a lag of one step but it's ok)
             new_time = time.perf_counter()
             delta_time = new_time - last_time
             last_time = new_time
 
+            # reshape data into (dp_devices, gradient_accumulation_steps, batch_per_dp_device, ...)
+            batch = jax.tree_map(
+                lambda x: x.reshape(
+                    (
+                        training_args.dp_devices,
+                        training_args.gradient_accumulation_steps,
+                        -1,
+                    )
+                    + x.shape[1:]
+                ),
+                batch,
+            )
+
             # train step
-            state, train_metrics = p_train_step(state, batch, delta_time)
-            step = state.step
+            state, train_metrics = p_train_step(state, freeze(batch), delta_time)
+            step = int(state.step)
 
             if step % training_args.logging_steps == 0 and jax.process_index() == 0:
                 all_metrics = metrics_logger.get_all_train_metrics(
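The `freeze(batch)` at the call site and `unfreeze` inside the step are two halves of one constraint: pjit matches arguments against `in_axis_resources` by pytree structure, and `batch_spec` is a FrozenDict, so the batch must be frozen before the call; inside the step it is unfrozen again so `pop("labels")` works. A small round-trip sketch (hypothetical two-key batch):

    import jax.numpy as jnp
    from flax.core.frozen_dict import freeze, unfreeze

    batch = {
        "input_ids": jnp.zeros((2, 4), dtype=jnp.int32),
        "labels": jnp.zeros((2, 4), dtype=jnp.int32),
    }
    frozen = freeze(batch)          # FrozenDict: same tree structure as batch_spec
    # ... inside the pjit-ed step ...
    mutable = unfreeze(frozen)      # back to a plain dict
    labels = mutable.pop("labels")  # pop() is only legal on the mutable copy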