Spaces:
Build error
Build error
# Copyright 2022 The T5X Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Extensions to Jax/Flax core functions for Mixture of Experts training. | |
""" | |
import dataclasses | |
import re | |
from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union | |
import flax | |
import jax | |
import numpy as np | |
from t5x import train_state | |
# Type Stubs | |
ParamTree = Any | |
PyTreeDef = Any | |
Gradients = Union[flax.core.FrozenDict, train_state.TrainState] | |
def match_fn(prefix: Optional[str]) -> Callable[[str], bool]: | |
"""Creates a function returning true iff a string matches the prefix. | |
Args: | |
prefix: Regex prefix to match. If none, then return match function will not | |
match any strings. | |
Returns: | |
Prefix match function. | |
""" | |
if not prefix: | |
return lambda name: False | |
params_regex = re.compile(f'^{prefix}') | |
return lambda name: params_regex.match(name) is not None | |
def scale_sharded_grads(grads: Gradients, | |
sharded_match_fn: Optional[Callable[[str], bool]], | |
scale_factor: float) -> Gradients: | |
"""Scales sharded grads, identified by sharded_match_fn, by scale_factor. | |
Args: | |
grads: Parameter gradients. | |
sharded_match_fn: Filter function for distinguishing sharded parameters from | |
replicated parameters. | |
scale_factor: Amount by which to scale sharded parameter gradients. | |
Returns: | |
Gradients matching input, expect with sharded parameter gradients rescaled. | |
""" | |
if sharded_match_fn: | |
names_and_grads, tree_def = _tree_flatten_with_names(grads) | |
scaled_grads = [ | |
grad * scale_factor if sharded_match_fn(name) else grad | |
for name, grad in names_and_grads | |
] | |
return tree_def.unflatten(scaled_grads) | |
else: | |
return grads | |
def tree_map_with_names(f, param_tree, match_name_fn=lambda name: True): | |
"""Like jax.tree_map but with a filter on the leaf path name. | |
Args: | |
f: The function to be applied to each parameter in `param_tree`. | |
param_tree: The tree of parameters `f` should be applied to. | |
match_name_fn: This function is called with each tree leave's path name, | |
which has a path-like format ('a/b/c'), and decides whether `f` should be | |
applied to that leaf or the leaf should be kept as-is. | |
Returns: | |
A tree identical in structure to `param_tree` but with the leaves the | |
result of calling `f` on them in the cases where `match_name_fn` returns | |
True for that leaf's path name. | |
""" | |
names_and_vals, tree_def = _tree_flatten_with_names(param_tree) | |
vals = [f(v) if match_name_fn(name) else v for name, v in names_and_vals] | |
return tree_def.unflatten(vals) | |
def _tree_flatten_with_names( | |
tree: ParamTree) -> Tuple[Sequence[Tuple[str, Any]], PyTreeDef]: | |
"""Like jax.tree_flatten but also fetches leaf names. | |
Specialized to parameter trees of the form {'key0': {'subkey0': Any}, ...}. | |
Args: | |
tree: Tree of parameters to flatten. | |
Returns: | |
- A list of leaf name and value pairs: [(name, value), ...]. | |
- A tree definition object representing the structure of the flattened tree. | |
""" | |
# PyTrees don't treat None values as leaves, so we explicitly declare them as | |
# such. | |
vals, tree_def = jax.tree_flatten(tree, is_leaf=lambda x: x is None) | |
# 'Fake' token tree that is use to track jax internal tree traversal and | |
# adjust our custom tree traversal to be compatible with it. | |
tokens = range(len(vals)) | |
token_tree = tree_def.unflatten(tokens) | |
val_names, perm = zip(*_traverse_with_names(token_tree)) | |
inv_perm = np.argsort(perm) | |
# Custom traversal should visit the same number of leaves. | |
if len(val_names) != len(vals): | |
raise ValueError(f'Pytree traversal detected {len(val_names)} names, ' | |
f'but {len(vals)} leafs.\nTreeDef is:\n{tree_def}') | |
return [(val_names[i], v) for i, v in zip(inv_perm, vals)], tree_def | |
def _traverse_with_names( | |
param_tree: ParamTree) -> Iterable[Tuple[str, ParamTree]]: | |
"""Traverses nested dicts/dataclasses and emits (leaf_name, leaf_val).""" | |
if dataclasses.is_dataclass(param_tree): | |
param_tree = flax.serialization.to_state_dict(param_tree) | |
if isinstance(param_tree, (dict, flax.core.FrozenDict)): | |
keys = sorted(param_tree.keys()) | |
for key in keys: | |
for path, v in _traverse_with_names(param_tree[key]): | |
yield (key + '/' + path).rstrip('/'), v | |
else: | |
yield '', param_tree | |