import re
from urllib.parse import urlparse
from shutil import rmtree
import logging
import os
from pathlib import Path
import sys
import tensorflow.compat.v1 as tf
import tensorflow.compat.v2 as tf2
import mesh_tensorflow as mtf
from data.encoders import fetch_encoder
def setup_logging(args):
    """
    Configure the "tensorflow" logger and return it.

    :param args: object with a `model` attribute (path or name of the model config); its
        basename without extension is used to name the log file
    :return: the `logging.Logger` named "tensorflow", with its handlers replaced
    """
    tf.get_logger().propagate = False  # Remove double log on console
    name = os.path.splitext(os.path.basename(args.model))[0]
    # NOTE(review): the entries of this handler list were missing (the list literal was left
    # unterminated). A per-model log file plus a stdout stream is assumed here because `name`
    # is computed above and `Path` / `sys` are imported at file level -- confirm the intended
    # log location.
    Path("logs").mkdir(exist_ok=True)  # FileHandler does not create missing directories
    handlers = [
        logging.FileHandler(f"logs/{name}.log"),
        logging.StreamHandler(sys.stdout),
    ]
    logger = logging.getLogger("tensorflow")
    logger.handlers = handlers
    return logger
def get_batch_size(params):
    """
    Look up the batch size configured for the currently active mode.

    :param params: dict containing "mode" and a matching "<mode>_batch_size" entry
    :return: the value stored under "<mode>_batch_size"
    """
    batch_size_key = f"{params['mode']}_batch_size"
    return params[batch_size_key]
def add_mode_to_params(params, mode):
    """
    Record the estimator mode in `params` under the "mode" key.

    :param params: dict of run parameters; modified in place
    :param mode: one of tf.estimator.ModeKeys.PREDICT / EVAL / TRAIN
    :return: the same `params` dict, with params["mode"] set to "predict", "eval" or "train"
    :raises ValueError: if `mode` matches none of the three mode keys
    """
    if mode == tf.estimator.ModeKeys.PREDICT:
        params["mode"] = "predict"
    elif mode == tf.estimator.ModeKeys.EVAL:
        params["mode"] = "eval"
    elif mode == tf.estimator.ModeKeys.TRAIN:
        params["mode"] = "train"
    else:
        # Only reject modes that matched none of the branches above; without this `else`
        # the function raised for every mode, valid or not.
        raise ValueError(f"Invalid mode {mode}")
    return params
def simd_mesh_setup(params, mesh_shape, layout_rules):
    """Constructs SimdMesh function - instructions on how to evenly split tensors across all TPU cores

    :param params: dict whose "context" entry provides `num_hosts`,
        `tpu_host_placement_function`, `num_replicas` and `device_assignment`
    :param mesh_shape: mesh shape object; its `size` gives the number of mesh devices
    :param layout_rules: layout rules forwarded unchanged to `SimdMeshImpl`
    :return: tuple `(var_placer, mesh_impl)`
    """
    num_hosts = params["context"].num_hosts
    host_placement_fn = params["context"].tpu_host_placement_function
    device_list = [host_placement_fn(host_id=i) for i in range(num_hosts)]
    # NOTE(review): only the tail of this log statement survived, fused onto the line above;
    # the call is reconstructed as tf.logging.info -- confirm the intended logger.
    tf.logging.info(f"device_list = {device_list}")

    # TODO: Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica

    # Worker 0 caches all the TPU binaries
    worker0_mem = replica_cache_size * params["context"].num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list, devices_memory_usage)

    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, params["context"].device_assignment)

    return var_placer, mesh_impl
def remove_batch_from_layout(layout):
    """
    The tf-mesh layout splits across batch size, remove it.
    Useful for prediction steps, when you no longer want large batches.

    :param layout: string describing tf-mesh layout (comma-separated entries)
    :return: layout minus batch dimension
    """
    # Keep every entry that does NOT mention "batch"; the previous condition was inverted
    # and kept only the batch entries.
    return ",".join(entry for entry in layout.split(',') if "batch" not in entry)
def yes_or_no(question):
    """
    Prompt on stdin until the reply starts with 'y' or 'n' (case-insensitive).

    :param question: text shown to the user; ' (y/n): ' is appended to it
    :return: True if the reply starts with 'y', False if it starts with 'n'
    """
    prompt = question + ' (y/n): '
    answers = {'y': True, 'n': False}
    while True:
        first_char = str(input(prompt)).lower().strip()[:1]
        if first_char in answers:
            return answers[first_char]
def remove_gs_or_filepath(path):
    """
    Recursively delete `path`, which is either a gs:// URL or a local directory.

    :param path: "gs://..." URL (removed via the `gsutil` command line tool; its exit
        status is ignored) or a local directory path (removed with shutil.rmtree)
    :return: None
    """
    import shlex  # local import: only needed to quote the shell argument below

    if urlparse(path).scheme == "gs":
        # Quote the path so spaces or shell metacharacters cannot alter the command.
        os.system(f"gsutil rm -rf {shlex.quote(path)}")
        return
    # NOTE(review): the local-path branch was missing from this block; it is restored from
    # the function name and the otherwise unused file-level `rmtree` import -- confirm.
    rmtree(path)
def _render_config_text(params_dict):
    """Render `params_dict` as the JSON-like text block that save_config logs."""
    last_index = len(params_dict) - 1
    text = "{\n\n"
    for count, key in enumerate(params_dict):
        config_value = str(params_dict[key])
        # Values containing letters are wrapped in double quotes, except booleans and
        # lists (values starting with '[').
        # TODO: Making a manual exception for parsing epsilon right now since it's the only number in
        #       scientific notation. Should fix this.
        needs_quotes = (
            re.search('[a-zA-Z]', config_value)
            and config_value.lower() not in ('true', 'false')
            and config_value[0] != '['
            and key != "epsilon"
        )
        if needs_quotes:
            config_value = f'"{config_value}"'
        # Every entry except the last one is followed by a comma.
        separator = '\n\n' if count == last_index else ',\n\n'
        text += f'"{str(key)}"' + ' : ' + config_value + separator
    text += '\n\n}'
    return text


def save_config(params_dict, logdir):
    """
    Write the run configuration as a text summary named "run_config" under `<logdir>/config`.

    :param params_dict: dict of configuration values; each value is stringified
    :param logdir: directory (or URL prefix) under which the "config" summary dir is written
    :return: None
    """
    print(f"Saving config to {logdir}")
    text = _render_config_text(params_dict)

    sess = tf.InteractiveSession()
    summary_writer = None
    try:
        summary_op = tf.summary.text("run_config", tf.convert_to_tensor(text))
        summary_writer = tf.summary.FileWriter(f"{logdir}/config", sess.graph)
        # Evaluate the op to get the summary that add_summary writes out.
        summary = sess.run(summary_op)
        summary_writer.add_summary(summary, 0)
    finally:
        # Close the writer (flushing pending events) and the session even on failure.
        if summary_writer is not None:
            summary_writer.close()
        sess.close()
def expand_attention_types_params(params_list):
    """
    Expand a compact list of (attention types, repeat count) pairs into a flat list.

    :param params_list: iterable of items where item[0] is a sequence of values and item[1]
        is the number of times that sequence is repeated
    :return: flat list with each item[0] repeated item[1] times, in order
    """
    newlist = []
    for item in params_list:
        # NOTE(review): the body of this inner loop was missing. It is reconstructed as
        # extending with item[0] (i.e. item[0] is itself a list of types); confirm that
        # `extend` rather than `append` is intended.
        for _ in range(item[1]):
            newlist.extend(item[0])
    return newlist
def get_n_trainable_vars(graph):
    """
    Gets number of trainable vars in a MTF model.

    The total is printed to stdout; nothing is returned.

    :param graph: Mesh-Tensorflow graph
    :return: None
    """
    total_parameters = 0
    for variable in graph.trainable_variables:
        # A variable's parameter count is the product of its dimension sizes.
        variable_parameters = 1
        for dim in variable.shape.dims:
            variable_parameters *= dim.size
        total_parameters += variable_parameters
    print(f"\n\nN TRAINABLE VARS:\n{total_parameters:,}\n\n")
def print_dim_names(graph):
    """
    Print names of all Dimensions

    :param graph: Mesh-Tensorflow graph
    :return: None
    """
    # NOTE(review): the statements that collected each variable's names and printed each
    # unique name were missing; they are reconstructed from the surrounding flatten / loop
    # code. Names are sorted so the output order is deterministic.
    unique_dims = sorted({
        dim_name
        for variable in graph.all_variables
        for dim_name in variable.shape.dimension_names
    })
    print("ALL DIM NAMES:")
    for dim_name in unique_dims:
        print(dim_name)
def get_graph_info(graph):
    """
    Wrapper fn that calculates number of trainable vars in an MTF graph & prints all dim_names

    TODO: how to get un-trainable dim-names too, batch etc.

    :param graph: Mesh-Tensorflow graph
    :return: None
    """
    # NOTE(review): the body was missing; the two calls below are reconstructed from the
    # description above and the sibling helpers defined in this file.
    get_n_trainable_vars(graph)
    print_dim_names(graph)
def loss_denominator(targets, num_microbatches):
    """Denominator applied to losses.

    This is usually the size of the targets tensor (omitting ensemble
    dimensions). Alternatively, it is an override value passed to the
    class constructor.

    Args:
        targets: a mtf.Tensor
        num_microbatches: an integer - greater than one if the step has been
            serialized into multiple microbatches to save memory.

    Returns:
        a float
    """
    ret = float(targets.shape.size) * num_microbatches
    return float(ret)
def check_dataset(input_fn, params, global_step=None):
    """
    Pull one batch from the input pipeline, decode its first example and print an excerpt.

    :param input_fn: callable returning a dataset; called as `input_fn(params)` or, when
        `global_step` is given, `input_fn(params, global_step=global_step)`
    :param params: parameter dict forwarded to `input_fn` and `fetch_encoder`
    :param global_step: optional step forwarded to `input_fn`
    :return: None
    """
    if global_step is not None:
        dataset = input_fn(params, global_step=global_step)
    else:
        # Without this `else` the global_step-aware dataset was always overwritten.
        dataset = input_fn(params)
    dataset_iter = dataset.make_one_shot_iterator()
    tensor, _ = next(dataset_iter)
    enc = fetch_encoder(params)

    # Only the first example of the batch is shown. Printing happens inside the loop so
    # that `txt` is always bound when it is used.
    for p in tensor[:1]:
        txt = enc.decode(p)
        print('-' * 50)
        print(txt[:500], '\n\n...\n\n', txt[-500:])
        print('-' * 50)
def auto_layout(graph, mesh_shape, logits, loss):
    """
    Let mesh-tensorflow pick layout rules for `graph` on the given mesh shape and print them.

    :param graph: Mesh-Tensorflow graph
    :param mesh_shape: mesh shape to lay the graph out on
    :param logits: logits tensor, passed to the layout search together with `loss`
    :param loss: loss tensor
    :return: the selected layout rules
    """
    layout_rules = mtf.auto_mtf.layout(graph, mesh_shape, [logits, loss])
    print(f"Auto-selected layout:\n{layout_rules}\nRe-initialize graph with selected layout")
    # NOTE(review): the result was previously computed but never handed back; returning it
    # is backward-compatible with callers that ignore the return value.
    return layout_rules
def auto_layout_and_mesh_shape(graph, num_cores, logits, loss):
    """
    Let mesh-tensorflow pick both layout rules and a mesh shape for `graph`, and print them.

    :param graph: Mesh-Tensorflow graph
    :param num_cores: number of cores to lay the graph out on
    :param logits: logits tensor, passed to the search together with `loss`
    :param loss: loss tensor
    :return: tuple `(layout_rules, mesh_shape)`
    """
    layout_rules, mesh_shape = mtf.auto_mtf.layout_and_mesh_shape(
        graph, num_cores, [logits, loss], max_mesh_shape_dimensions=4)
    print(f"Num cores:\n{num_cores}\nAuto-selected layout:\n{layout_rules}\nAuto-selected mesh shape:\n{mesh_shape}"
          f"\nRe-initialize graph with selected layout & mesh shape")
    # NOTE(review): the results were previously computed but never handed back; returning
    # them is backward-compatible with callers that ignore the return value.
    return layout_rules, mesh_shape
def create_host_call(model_dir):
    """Construct a host_call writing scalar summaries.

    Borrowed from t2t.

    Args:
        model_dir: String containing path to train

    Returns:
        (fn, args) Pair to be called by TPUEstimator as the host_call, or None
        when the scalar summaries collection is empty.
    """
    graph = tf.get_default_graph()
    # A list of (name, lowered tensor) tuples
    summaries = graph.get_collection(mtf.utils.SCALAR_SUMMARIES_COLLECTION_KEY)

    def maybe_cast(tensor):
        # NOTE(review): the assertion message was missing (dangling comma); the tensor name
        # is used so a failure identifies the offending summary.
        assert tensor.shape.is_compatible_with([]), tensor.name
        if tensor.dtype == tf.int64:
            return tf.to_int32(tensor)
        if tensor.dtype == tf.bfloat16:
            return tf.cast(tensor, tf.float32)
        return tensor

    reshaped_tensors = [tf.reshape(maybe_cast(t), [1]) for _, t in summaries]

    # When no supported summaries are found, don't create host_call. Otherwise,
    # TPU outfeed queue would enqueue global_step while host_call doesn't dequeue
    # it, eventually causing hang.
    if not reshaped_tensors:
        return None

    def host_call_fn(global_step, *args):
        """Training host call. Creates scalar summaries for training metrics."""
        # This function is executed on the CPU and should not directly reference
        # any Tensors in the rest of the `model_fn`. To pass Tensors from the
        # model to the `model_fn`, provide as part of the `host_call`.
        global_step = tf.cast(global_step[0], tf.int64)
        with tf2.summary.create_file_writer(model_dir).as_default():
            # We cannot directly use any tensor from summaries, because each
            # tensor here must be a concat of multiple tensors from all shards.
            # Therefore, we rely on the assumption that args will have the same
            # length as summaries, and that the tensors in args are in the same
            # order as the entries of summaries.
            assert len(args) == len(summaries)
            for (name, _), tensor in zip(summaries, args):
                tf2.summary.scalar(name, tf.reduce_mean(tensor), step=global_step)
            return tf.summary.all_v2_summary_ops()

    global_step_t = tf.reshape(tf.to_int32(tf.train.get_global_step()), [1])
    return host_call_fn, [global_step_t] + reshaped_tensors
def natural_sort(l):
    """
    Sort strings so that embedded digit runs compare numerically and letters compare
    case-insensitively (e.g. "a2" sorts before "a10").

    :param l: iterable of strings
    :return: new sorted list
    """
    def _chunk_key(chunk):
        # Digit runs compare as integers, everything else as lower-cased text.
        if chunk.isdigit():
            return int(chunk)
        return chunk.lower()

    def _sort_key(entry):
        return [_chunk_key(piece) for piece in re.split('([0-9]+)', entry)]

    return sorted(l, key=_sort_key)