feat(train): handle distributed_shampoo in pjit
tools/train/train.py  (+40 -36)

@@ -25,7 +25,7 @@ import sys
 import time
 from dataclasses import asdict, dataclass, field
 from pathlib import Path
-from typing import Callable, Optional
+from typing import Any, Callable, NamedTuple, Optional

 import datasets
 import jax
@@ -36,7 +36,7 @@ import transformers
 import wandb
 from datasets import Dataset
 from distributed_shampoo import GraftingType, distributed_shampoo
-from flax.core.frozen_dict import freeze
+from flax.core.frozen_dict import FrozenDict, freeze
 from flax.serialization import from_bytes, to_bytes
 from flax.training import train_state
 from flax.training.common_utils import onehot, stack_forest
@@ -523,6 +523,12 @@ def main():
         use_fast=True,
     )

+    # get PartitionSpec for model params (required to be a dict)
+    param_spec = set_partitions(model.params)
+
+    # convert params to frozen dict
+    model._params = freeze(model.params)
+
     # Preprocessing the datasets.
     # We need to normalize and tokenize inputs and targets.

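
A note on the two additions above: `set_partitions` builds the `PartitionSpec` tree for the model params (per the comment, a dict is expected), and `freeze` turns that dict into an immutable `FrozenDict` with the same pytree structure. A minimal sketch of the `freeze` conversion, using a toy parameter tree invented here for illustration (not from this repo):

import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze

# toy parameter tree, purely illustrative
params = {"dense": {"kernel": jnp.ones((2, 3)), "bias": jnp.zeros((3,))}}

frozen = freeze(params)  # same tree structure, immutable FrozenDict nodes
assert isinstance(frozen, FrozenDict)
assert isinstance(frozen["dense"], FrozenDict)
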
@@ -620,6 +626,13 @@ def main():
             precision=jax.lax.Precision.HIGHEST,
             best_effort_memory_usage_reduction=training_args.optim_quantized,
         )
+        # get the real optimizer and helper functions
+        update_fn = optimizer.update
+        optimizer = optimizer.init(model.params)
+        opt_fn = NamedTuple("opt_fn", pspec_fn=Any, shape_and_dtype_fn=Any)(
+            optimizer.pspec_fn, optimizer.shape_and_dtype_fn
+        )
+        optimizer = optax.GradientTransformation(optimizer.init_fn, update_fn)

     elif training_args.optim == "adam":
         optimizer = optax.adamw(
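
The last added line above works because `optax.GradientTransformation` is just a named `(init, update)` pair, so an equivalent transformation can be reassembled from the `init_fn` returned by the Shampoo helper object and the `update_fn` saved beforehand. A short illustrative sketch of the same idea with a stock optax optimizer (not code from this repo):

import jax
import jax.numpy as jnp
import optax

tx = optax.sgd(1e-3)
# GradientTransformation is a NamedTuple of (init, update); rebuilding one
# from those two functions yields an equivalent optimizer
rebuilt = optax.GradientTransformation(tx.init, tx.update)

params = {"w": jnp.ones((3,))}
opt_state = rebuilt.init(params)
grads = jax.grad(lambda p: jnp.sum(p["w"] ** 2))(params)
updates, opt_state = rebuilt.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
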
@@ -636,43 +649,40 @@ def main():
             clipping_threshold=training_args.max_grad_norm,
         )

-    # get PartitionSpec for model params
-    param_spec = set_partitions(model.params)
-
     # get PartitionSpec for optimizer state
     def get_opt_state_spec_and_shape(param_spec):
-        if training_args.optim
+        if training_args.optim in ["adam", "adafactor"]:
             # get opt_state shape without actual init
             opt_state_shape = jax.eval_shape(optimizer.init, model.params)

+            if training_args.optim == "adam":
+
+                def _opt_state_spec_per_leaf(x):
+                    if isinstance(x, FrozenDict):
+                        # variables with same structure as params
+                        return param_spec
+                    else:
+                        # other variables such as count
+                        return None
+
+                opt_state_spec = jax.tree_map(
+                    _opt_state_spec_per_leaf,
+                    opt_state_shape,
+                    # return None spec for empty elements
+                    is_leaf=lambda x: isinstance(x, (FrozenDict, optax.EmptyState)),
+                )

+            elif training_args.optim == "adafactor":
+                # factorized state must be replicated (rank different than params)
+                opt_state_spec = None

         elif training_args.optim == "distributed_shampoo":
-            _opt_state = optimizer.init(model.params)
-            opt_state_spec = _opt_state.pspec_fn(
+            opt_state_spec = opt_fn.pspec_fn(
                 params=model.params,
-                params_partition_spec=
+                params_partition_spec=param_spec,
                 partition_spec_for_statistics=PartitionSpec(None, "batch", None),
             )
-            opt_state_shape =
+            opt_state_shape = opt_fn.shape_and_dtype_fn(model.params)
         else:
             raise NotImplementedError
         return opt_state_spec, opt_state_shape
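
Two JAX idioms do the work in this helper: `jax.eval_shape` computes the optimizer-state pytree abstractly (only shapes and dtypes, no memory is allocated for the moments), and `jax.tree_map` with `is_leaf` then assigns a spec per sub-tree. A small self-contained sketch of the `eval_shape` part, with a toy parameter tree that is not from this repo:

import jax
import jax.numpy as jnp
import optax

params = {"w": jnp.ones((1024, 1024)), "b": jnp.zeros((1024,))}
tx = optax.adamw(1e-4)

# abstract evaluation: leaves are ShapeDtypeStruct, the Adam moments are
# never materialized on device
opt_state_shape = jax.eval_shape(tx.init, params)
print(jax.tree_map(lambda x: x.shape, opt_state_shape))
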
@@ -714,18 +724,12 @@ def main():
             in_axis_resources=(param_spec,),
             out_axis_resources=state_spec,
             donate_argnums=(0,),
-        )(
+        )(model.params)

     else:
         # restore opt_state
         with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
             opt_state = from_bytes(opt_state_shape, f.read())
-        # need to freeze dict for pjit
-        opt_state = jax.tree_map(
-            lambda x: freeze(x) if isinstance(x, dict) else x,
-            opt_state,
-            is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)),
-        )

         # restore other attributes
         with (Path(artifact_dir) / "training_state.json").open("r") as f:
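
On the restore branch kept above, `flax.serialization.from_bytes` only needs a target pytree with the right structure and fills in the deserialized values, which is why the abstract `opt_state_shape` is enough as a template. A minimal round-trip sketch with a toy state (not the repo's data):

import jax.numpy as jnp
from flax.serialization import from_bytes, to_bytes

opt_state = {"count": jnp.array(0), "mu": {"w": jnp.zeros((3,))}}

raw = to_bytes(opt_state)              # msgpack-encoded bytes
restored = from_bytes(opt_state, raw)  # target supplies structure, bytes supply values
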
@@ -746,7 +750,7 @@ def main():
             in_axis_resources=(param_spec, opt_state_spec),
             out_axis_resources=state_spec,
             donate_argnums=(0, 1),
-        )(
+        )(model.params, opt_state)

         # remove opt_state from CPU
         del opt_state
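
For context on the two pjit call sites patched above: pjit stages out the function with every input and output constrained by the given PartitionSpec trees, and it has to run inside a device mesh. Below is a minimal sketch of that pattern, assuming the jax.experimental pjit/maps API in use around the time of this commit (these import paths have since moved); the mesh layout, axis names, and the function itself are invented for illustration:

import numpy as np
import jax
from jax.experimental import PartitionSpec, maps
from jax.experimental.pjit import pjit

# 1D data-parallel mesh over all devices, with a trivial model-parallel axis
devices = np.asarray(jax.devices()).reshape(jax.device_count(), 1)
mesh = maps.Mesh(devices, ("batch", "mp"))

spec = PartitionSpec("batch", None)

def scale(x):
    return 2.0 * x

x = np.ones((jax.device_count() * 2, 4), dtype=np.float32)

with mesh:
    y = pjit(scale, in_axis_resources=spec, out_axis_resources=spec)(x)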