# Copyright 2023 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Some layered modules/functions to help users writing custom training loop."""
import inspect
import tensorflow as tf, tf_keras
def create_global_step() -> tf.Variable:
"""Creates a `tf.Variable` suitable for use as a global step counter.
Creating and managing a global step variable may be necessary for
`AbstractTrainer` subclasses that perform multiple parameter updates per
`Controller` "step", or use different optimizers on different steps.
In these cases, an `optimizer.iterations` property generally can't be used
directly, since it would correspond to parameter updates instead of iterations
in the `Controller`'s training loop. Such use cases should simply call
`step.assign_add(1)` at the end of each step.
Returns:
A non-trainable scalar `tf.Variable` of dtype `tf.int64`, with only the
first replica's value retained when synchronizing across replicas in
a distributed setting.
"""
  return tf.Variable(
      0,
      dtype=tf.int64,
      name="global_step",
      trainable=False,
      aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
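

# A minimal usage sketch (`MyTrainer` and `_train_step` are hypothetical, not
# part of Orbit): a trainer that applies several optimizer updates per outer
# `Controller` step keeps its own counter and advances it once per step:
#
#   class MyTrainer(orbit.AbstractTrainer):
#
#     def __init__(self):
#       self.global_step = create_global_step()
#
#     def train(self, num_iterations):
#       for _ in tf.range(num_iterations):
#         self._train_step()              # may apply several optimizer updates
#         self.global_step.assign_add(1)  # one increment per `Controller` step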

def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
  """A utility function to help create a `tf.distribute.DistributedDataset`.

  Args:
    strategy: An instance of `tf.distribute.Strategy`.
    dataset_or_fn: An instance of `tf.data.Dataset`, or a "dataset function"
      returning a `tf.data.Dataset`. If it is a function, it may optionally
      have an argument named `input_context` which will be passed a
      `tf.distribute.InputContext` instance.
    *args: Any positional arguments to pass through to `dataset_or_fn`.
    **kwargs: Any keyword arguments to pass through to `dataset_or_fn`, except
      that the `input_options` keyword is used to specify a
      `tf.distribute.InputOptions` for making the distributed dataset.

  Returns:
    A distributed Dataset.
  """
  if strategy is None:
    strategy = tf.distribute.get_strategy()

  input_options = kwargs.pop("input_options", None)

  if isinstance(dataset_or_fn, tf.data.Dataset):
    return strategy.experimental_distribute_dataset(dataset_or_fn,
                                                    input_options)

  if not callable(dataset_or_fn):
    raise ValueError("`dataset_or_fn` should be either callable or an instance "
                     "of `tf.data.Dataset`.")

  def dataset_fn(input_context):
    """Wraps `dataset_or_fn` for strategy.distribute_datasets_from_function."""
    # If `dataset_or_fn` is a function and has an argument named
    # `input_context`, pass through the given `input_context`. Otherwise
    # `input_context` will be ignored.
    argspec = inspect.getfullargspec(dataset_or_fn)
    arg_names = argspec.args

    if "input_context" in arg_names:
      kwargs["input_context"] = input_context

    return dataset_or_fn(*args, **kwargs)

  return strategy.distribute_datasets_from_function(dataset_fn, input_options)
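

# A minimal usage sketch (`make_dataset` below is a hypothetical example, not
# part of Orbit): a dataset function can accept the optional `input_context`
# argument to shard data across input pipelines:
#
#   def make_dataset(batch_size, input_context=None):
#     dataset = tf.data.Dataset.range(1000)
#     if input_context is not None:
#       dataset = dataset.shard(input_context.num_input_pipelines,
#                               input_context.input_pipeline_id)
#     return dataset.batch(batch_size)
#
#   strategy = tf.distribute.MirroredStrategy()
#   distributed = make_distributed_dataset(
#       strategy, make_dataset, batch_size=32)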

def get_value(x):
  """Returns input values, converting any TensorFlow values to NumPy values.

  Args:
    x: The input. May be a `tf.Tensor` or `tf.Variable`.

  Returns:
    If the input is a TensorFlow `Tensor`, returns the `Tensor`'s equivalent
    NumPy value. Otherwise, just returns the input.
  """
  if not tf.is_tensor(x):
    return x
  return x.numpy()
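

# A minimal usage sketch (values shown are illustrative): `get_value` is handy
# when a quantity may be either a plain Python number or a TF value, e.g. when
# logging metrics:
#
#   loss = tf.constant(0.25)
#   step = create_global_step()
#   get_value(loss)  # -> 0.25, as a NumPy scalar
#   get_value(step)  # -> 0, as a NumPy int64
#   get_value(3.5)   # -> 3.5, returned unchanged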