# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Extract pre-computed feature vectors from BERT.""" |
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import codecs
import collections
import json
import re

import modeling
import tokenization
import tensorflow as tf

flags = tf.flags

FLAGS = flags.FLAGS

flags.DEFINE_string(
    "input_file", None,
    "Input file to extract features for: one text per line, or two texts "
    "separated by ' ||| ' for a text pair.")

flags.DEFINE_string(
    "output_file", None,
    "Output file: one JSON-serialized record per line of the input file.")

flags.DEFINE_string(
    "layers", "-1,-2,-3,-4",
    "Comma-separated encoder layer indexes to extract; negative values "
    "index from the top layer.")

flags.DEFINE_string(
    "bert_config_file", None,
    "The config json file corresponding to the pre-trained BERT model. "
    "This specifies the model architecture.")

flags.DEFINE_integer(
    "max_seq_length", 128,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.")

flags.DEFINE_string(
    "init_checkpoint", None,
    "Initial checkpoint (usually from a pre-trained BERT model).")

flags.DEFINE_string("vocab_file", None,
                    "The vocabulary file that the BERT model was trained on.")

flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_integer("batch_size", 32, "Batch size for predictions.")

flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")

flags.DEFINE_string("master", None,
                    "If using a TPU, the address of the master.")

flags.DEFINE_integer(
    "num_tpu_cores", 8,
    "Only used if `use_tpu` is True. Total number of TPU cores to use.")

flags.DEFINE_bool(
    "use_one_hot_embeddings", False,
    "If True, tf.one_hot will be used for embedding lookups, otherwise "
    "tf.nn.embedding_lookup will be used. On TPUs, this should be True "
    "since it is much faster.")


class InputExample(object):
  """A single text (or pair of texts) to extract features for."""

  def __init__(self, unique_id, text_a, text_b):
    self.unique_id = unique_id
    self.text_a = text_a
    self.text_b = text_b


class InputFeatures(object):
  """A single set of features of data."""

  def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids):
    self.unique_id = unique_id
    self.tokens = tokens
    self.input_ids = input_ids
    self.input_mask = input_mask
    self.input_type_ids = input_type_ids
|
def input_fn_builder(features, seq_length):
  """Creates an `input_fn` closure to be passed to TPUEstimator."""

  all_unique_ids = []
  all_input_ids = []
  all_input_mask = []
  all_input_type_ids = []

  for feature in features:
    all_unique_ids.append(feature.unique_id)
    all_input_ids.append(feature.input_ids)
    all_input_mask.append(feature.input_mask)
    all_input_type_ids.append(feature.input_type_ids)

  def input_fn(params):
    """The actual input function."""
    batch_size = params["batch_size"]

    num_examples = len(features)

    # This is for demo purposes and does NOT scale to large data sets. We do
    # not use Dataset.from_generator() because that uses tf.py_func which is
    # not TPU compatible. The right way to load data is with TFRecordReader.
    d = tf.data.Dataset.from_tensor_slices({
        "unique_ids":
            tf.constant(all_unique_ids, shape=[num_examples], dtype=tf.int32),
        "input_ids":
            tf.constant(
                all_input_ids, shape=[num_examples, seq_length],
                dtype=tf.int32),
        "input_mask":
            tf.constant(
                all_input_mask,
                shape=[num_examples, seq_length],
                dtype=tf.int32),
        "input_type_ids":
            tf.constant(
                all_input_type_ids,
                shape=[num_examples, seq_length],
                dtype=tf.int32),
    })

    d = d.batch(batch_size=batch_size, drop_remainder=False)
    return d

  return input_fn
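
# A minimal sketch of the TFRecord-based pipeline mentioned above for large
# datasets (hypothetical helper, not used by this script; parsed int64
# features would still need a cast to tf.int32 before feeding the model):
#
#   def file_based_input_fn(tfrecord_file, seq_length, batch_size):
#     name_to_features = {
#         "unique_ids": tf.FixedLenFeature([], tf.int64),
#         "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
#         "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
#         "input_type_ids": tf.FixedLenFeature([seq_length], tf.int64),
#     }
#     d = tf.data.TFRecordDataset(tfrecord_file)
#     d = d.map(
#         lambda record: tf.parse_single_example(record, name_to_features))
#     return d.batch(batch_size)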
|
def model_fn_builder(bert_config, init_checkpoint, layer_indexes, use_tpu,
                     use_one_hot_embeddings):
  """Returns `model_fn` closure for TPUEstimator."""

  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""

    unique_ids = features["unique_ids"]
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    input_type_ids = features["input_type_ids"]

    model = modeling.BertModel(
        config=bert_config,
        is_training=False,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=input_type_ids,
        use_one_hot_embeddings=use_one_hot_embeddings)

    if mode != tf.estimator.ModeKeys.PREDICT:
      raise ValueError("Only PREDICT mode is supported: %s" % (mode))

    tvars = tf.trainable_variables()
    scaffold_fn = None
    (assignment_map,
     initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(
         tvars, init_checkpoint)
    if use_tpu:

      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    tf.logging.info("**** Trainable Variables ****")
    for var in tvars:
      init_string = ""
      if var.name in initialized_variable_names:
        init_string = ", *INIT_FROM_CKPT*"
      tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                      init_string)

    all_layers = model.get_all_encoder_layers()

    predictions = {
        "unique_id": unique_ids,
    }

    for (i, layer_index) in enumerate(layer_indexes):
      predictions["layer_output_%d" % i] = all_layers[layer_index]

    output_spec = tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
    return output_spec

  return model_fn
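
# Note on layer indexing (a worked example): with the default
# --layers=-1,-2,-3,-4 and a 12-layer BERT-Base encoder, "layer_output_0"
# holds the top (12th) encoder layer and "layer_output_3" the 9th, since
# negative indexes count back from the end of
# model.get_all_encoder_layers().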
|
def convert_examples_to_features(examples, seq_length, tokenizer):
  """Converts a list of `InputExample`s into a list of `InputFeatures`."""

  features = []
  for (ex_index, example) in enumerate(examples):
    tokens_a = tokenizer.tokenize(example.text_a)

    tokens_b = None
    if example.text_b:
      tokens_b = tokenizer.tokenize(example.text_b)

    if tokens_b:
      # Modifies `tokens_a` and `tokens_b` in place so that the total
      # length is less than the specified length.
      # Account for [CLS], [SEP], [SEP] with "- 3"
      _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3)
    else:
      # Account for [CLS] and [SEP] with "- 2"
      if len(tokens_a) > seq_length - 2:
        tokens_a = tokens_a[0:(seq_length - 2)]

    # The convention in BERT is:
    # (a) For sequence pairs:
    #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
    #  type_ids: 0     0  0    0    0     0       0 0     1  1  1  1   1 1
    # (b) For single sequences:
    #  tokens:   [CLS] the dog is hairy . [SEP]
    #  type_ids: 0     0   0   0  0     0 0
    #
    # Where "type_ids" are used to indicate whether this is the first
    # sequence or the second sequence. The embedding vectors for `type=0`
    # and `type=1` were learned during pre-training and are added to the
    # wordpiece embedding vector (and position vector). This is not
    # *strictly* necessary since the [SEP] token unambiguously separates
    # the sequences, but it makes it easier for the model to learn the
    # concept of sequences.
    #
    # For classification tasks, the first vector (corresponding to [CLS])
    # is used as the "sentence vector". Note that this only makes sense
    # because the entire model is fine-tuned.
    tokens = []
    input_type_ids = []
    tokens.append("[CLS]")
    input_type_ids.append(0)
    for token in tokens_a:
      tokens.append(token)
      input_type_ids.append(0)
    tokens.append("[SEP]")
    input_type_ids.append(0)

    if tokens_b:
      for token in tokens_b:
        tokens.append(token)
        input_type_ids.append(1)
      tokens.append("[SEP]")
      input_type_ids.append(1)

    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    # The mask has 1 for real tokens and 0 for padding tokens. Only real
    # tokens are attended to.
    input_mask = [1] * len(input_ids)

    # Zero-pad up to the sequence length.
    while len(input_ids) < seq_length:
      input_ids.append(0)
      input_mask.append(0)
      input_type_ids.append(0)

    assert len(input_ids) == seq_length
    assert len(input_mask) == seq_length
    assert len(input_type_ids) == seq_length

    if ex_index < 5:
      tf.logging.info("*** Example ***")
      tf.logging.info("unique_id: %s" % (example.unique_id))
      tf.logging.info("tokens: %s" % " ".join(
          [tokenization.printable_text(x) for x in tokens]))
      tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
      tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
      tf.logging.info(
          "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids]))

    features.append(
        InputFeatures(
            unique_id=example.unique_id,
            tokens=tokens,
            input_ids=input_ids,
            input_mask=input_mask,
            input_type_ids=input_type_ids))
  return features
|
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
  """Truncates a sequence pair in place to the maximum length."""

  # This is a simple heuristic which will always truncate the longer
  # sequence one token at a time. This makes more sense than truncating an
  # equal percent of tokens from each, since if one sequence is very short
  # then each token that's truncated likely contains more information than
  # a longer sequence.
  while True:
    total_length = len(tokens_a) + len(tokens_b)
    if total_length <= max_length:
      break
    if len(tokens_a) > len(tokens_b):
      tokens_a.pop()
    else:
      tokens_b.pop()
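
# A worked illustration of the heuristic above: with max_length=6, a
# 5-token tokens_a and a 4-token tokens_b are trimmed to 3 and 3 -- each
# iteration pops one token from whichever list is currently longer (ties
# trim tokens_b) until the combined length fits.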
|
def read_examples(input_file):
  """Read a list of `InputExample`s from an input file."""
  examples = []
  unique_id = 0
  with tf.gfile.GFile(input_file, "r") as reader:
    while True:
      line = tokenization.convert_to_unicode(reader.readline())
      if not line:
        break
      line = line.strip()
      text_a = None
      text_b = None
      m = re.match(r"^(.*) \|\|\| (.*)$", line)
      if m is None:
        text_a = line
      else:
        text_a = m.group(1)
        text_b = m.group(2)
      examples.append(
          InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))
      unique_id += 1
  return examples
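
# A sketch of the expected --input_file format (the sentences are
# illustrative): each line is a single text, or two texts separated by
# " ||| " when features for a text pair are wanted, e.g.
#
#   Who was Jim Henson ?
#   Jim Henson was a puppeteer ||| He created the Muppets .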
|
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  layer_indexes = [int(x) for x in FLAGS.layers.split(",")]

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

  is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
  run_config = tf.contrib.tpu.RunConfig(
      master=FLAGS.master,
      tpu_config=tf.contrib.tpu.TPUConfig(
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host))

  examples = read_examples(FLAGS.input_file)

  features = convert_examples_to_features(
      examples=examples, seq_length=FLAGS.max_seq_length, tokenizer=tokenizer)

  unique_id_to_feature = {}
  for feature in features:
    unique_id_to_feature[feature.unique_id] = feature

  model_fn = model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      layer_indexes=layer_indexes,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_one_hot_embeddings)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = tf.contrib.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      predict_batch_size=FLAGS.batch_size)

  input_fn = input_fn_builder(
      features=features, seq_length=FLAGS.max_seq_length)

  with codecs.getwriter("utf-8")(tf.gfile.Open(FLAGS.output_file,
                                               "w")) as writer:
    for result in estimator.predict(input_fn, yield_single_examples=True):
      unique_id = int(result["unique_id"])
      feature = unique_id_to_feature[unique_id]
      output_json = collections.OrderedDict()
      output_json["linex_index"] = unique_id
      all_features = []
      for (i, token) in enumerate(feature.tokens):
        all_layers = []
        for (j, layer_index) in enumerate(layer_indexes):
          layer_output = result["layer_output_%d" % j]
          layers = collections.OrderedDict()
          layers["index"] = layer_index
          layers["values"] = [
              round(float(x), 6) for x in layer_output[i:(i + 1)].flat
          ]
          all_layers.append(layers)
        # `token_features` (rather than `features`) avoids shadowing the
        # list of InputFeatures above.
        token_features = collections.OrderedDict()
        token_features["token"] = token
        token_features["layers"] = all_layers
        all_features.append(token_features)
      output_json["features"] = all_features
      writer.write(json.dumps(output_json) + "\n")
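
# Output format sketch: one JSON object per input line, where each token
# carries one entry per requested layer (the float values are elided here):
#
#   {"linex_index": 0,
#    "features": [
#      {"token": "[CLS]",
#       "layers": [{"index": -1, "values": [...]},
#                  {"index": -2, "values": [...]}]},
#      ...]}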
|
if __name__ == "__main__": |
|
flags.mark_flag_as_required("input_file") |
|
flags.mark_flag_as_required("vocab_file") |
|
flags.mark_flag_as_required("bert_config_file") |
|
flags.mark_flag_as_required("init_checkpoint") |
|
flags.mark_flag_as_required("output_file") |
|
tf.app.run() |
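
# A minimal sketch of consuming the output file (a hypothetical reader,
# not part of this script; it assumes the JSON-lines format written above
# and the illustrative path from the example invocation):
#
#   import json
#
#   with open("/tmp/output.jsonl") as f:
#     for line in f:
#       record = json.loads(line)
#       # Vector for [CLS] from the first requested layer (768 floats
#       # for BERT-Base):
#       cls_vec = record["features"][0]["layers"][0]["values"]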