# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test Subtokenizer and string helper methods."""

import collections
import tempfile

import tensorflow as tf, tf_keras

from official.legacy.transformer.utils import tokenizer


class SubtokenizerTest(tf.test.TestCase):

  def _init_subtokenizer(self, vocab_list):
    # Write a temporary vocab file (one quoted subtoken per line) and build a
    # Subtokenizer from it with no reserved tokens.
    temp_file = tempfile.NamedTemporaryFile(delete=False)
    with tf.io.gfile.GFile(temp_file.name, "w") as w:
      for subtoken in vocab_list:
        w.write("'%s'" % subtoken)
        w.write("\n")
    return tokenizer.Subtokenizer(temp_file.name, reserved_tokens=[])

  def test_encode(self):
    vocab_list = ["123_", "test", "ing_"]
    subtokenizer = self._init_subtokenizer(vocab_list)
    s = "testing 123"
    encoded_list = subtokenizer.encode(s)
    # Subtoken ids are indices into vocab_list: "test"=1, "ing_"=2, "123_"=0.
    self.assertEqual([1, 2, 0], encoded_list)

  def test_decode(self):
    vocab_list = ["123_", "test", "ing_"]
    subtokenizer = self._init_subtokenizer(vocab_list)
    encoded_list = [1, 2, 0]  # testing 123
    decoded_str = subtokenizer.decode(encoded_list)
    self.assertEqual("testing 123", decoded_str)

  def test_subtoken_ids_to_tokens(self):
    vocab_list = ["123_", "test", "ing_"]
    subtokenizer = self._init_subtokenizer(vocab_list)
    encoded_list = [1, 2, 0]  # testing 123
    token_list = subtokenizer._subtoken_ids_to_tokens(encoded_list)
    self.assertEqual([u"testing", u"123"], token_list)


class StringHelperTest(tf.test.TestCase):

  def test_split_string_to_tokens(self):
    text = "test? testing 123."

    tokens = tokenizer._split_string_to_tokens(
        text, tokenizer._ALPHANUMERIC_CHAR_SET)
    self.assertEqual(["test", "? ", "testing", "123", "."], tokens)

  def test_join_tokens_to_string(self):
    tokens = ["test", "? ", "testing", "123", "."]

    s = tokenizer._join_tokens_to_string(
        tokens, tokenizer._ALPHANUMERIC_CHAR_SET)
    self.assertEqual("test? testing 123.", s)

  def test_escape_token(self):
    token = u"abc_\\4"
    alphabet = set("abc_\\u;")

    escaped_token = tokenizer._escape_token(token, alphabet)
    self.assertEqual("abc\\u\\\\\\52;_", escaped_token)

  def test_unescape_token(self):
    escaped_token = u"Underline: \\u, Backslash: \\\\, Unicode: \\52;"

    unescaped_token = tokenizer._unescape_token(escaped_token)
    self.assertEqual("Underline: _, Backslash: \\, Unicode: 4", unescaped_token)

  def test_list_to_index_dict(self):
    lst = ["test", "strings"]

    d = tokenizer._list_to_index_dict(lst)
    self.assertDictEqual({"test": 0, "strings": 1}, d)

  def test_split_token_to_subtokens(self):
    token = "abc"
    subtoken_dict = {"a": 0, "b": 1, "c": 2, "ab": 3}
    max_subtoken_length = 2

    subtokens = tokenizer._split_token_to_subtokens(
        token, subtoken_dict, max_subtoken_length)
    self.assertEqual(["ab", "c"], subtokens)

  def test_generate_alphabet_dict(self):
    s = ["testing", "123"]
    reserved_tokens = ["???"]

    alphabet = tokenizer._generate_alphabet_dict(s, reserved_tokens)

    self.assertIn("?", alphabet)
    self.assertIn("t", alphabet)
    self.assertIn("e", alphabet)
    self.assertIn("s", alphabet)
    self.assertIn("i", alphabet)
    self.assertIn("n", alphabet)
    self.assertIn("g", alphabet)
    self.assertIn("1", alphabet)
    self.assertIn("2", alphabet)
    self.assertIn("3", alphabet)

  def test_count_and_gen_subtokens(self):
    token_counts = {"abc": 5}
    alphabet = set("abc_")
    subtoken_dict = {"a": 0, "b": 1, "c": 2, "_": 3}
    max_subtoken_length = 2

    subtoken_counts = tokenizer._count_and_gen_subtokens(
        token_counts, alphabet, subtoken_dict, max_subtoken_length)

    self.assertIsInstance(subtoken_counts, collections.defaultdict)
    self.assertDictEqual(
        {
            "a": 5,
            "b": 5,
            "c": 5,
            "_": 5,
            "ab": 5,
            "bc": 5,
            "c_": 5,
            "abc": 5,
            "bc_": 5,
            "abc_": 5
        }, subtoken_counts)

  def test_filter_and_bucket_subtokens(self):
    subtoken_counts = collections.defaultdict(int, {
        "a": 2,
        "b": 4,
        "c": 1,
        "ab": 6,
        "ac": 3,
        "abbc": 5
    })
    min_count = 3

    subtoken_buckets = tokenizer._filter_and_bucket_subtokens(
        subtoken_counts, min_count)

    self.assertEqual(len(subtoken_buckets[0]), 0)
    self.assertEqual(set("b"), subtoken_buckets[1])
    self.assertEqual(set(["ab", "ac"]), subtoken_buckets[2])
    self.assertEqual(len(subtoken_buckets[3]), 0)
    self.assertEqual(set(["abbc"]), subtoken_buckets[4])

  def test_gen_new_subtoken_list(self):
    subtoken_counts = collections.defaultdict(int, {
        "translate": 10,
        "t": 40,
        "tr": 16,
        "tra": 12
    })
    min_count = 5
    alphabet = set("translate")
    reserved_tokens = ["reserved", "tokens"]

    subtoken_list, max_token_length = tokenizer._gen_new_subtoken_list(
        subtoken_counts, min_count, alphabet, reserved_tokens)

    # Check that "tra" isn't in the list (its count should be decremented to 2,
    # so it should not be added to the candidate list).
    self.assertNotIn("tra", subtoken_list)
    self.assertIn("tr", subtoken_list)
    self.assertIn("t", subtoken_list)

    self.assertEqual(len("translate"), max_token_length)

  def test_generate_subtokens(self):
    token_counts = {"ab": 1, "bc": 3, "abc": 5}
    alphabet = set("abc_")
    min_count = 100
    num_iterations = 1
    reserved_tokens = ["reserved", "tokens"]

    vocab_list = tokenizer._generate_subtokens(token_counts, alphabet,
                                               min_count, num_iterations,
                                               reserved_tokens)

    # Check that reserved tokens are at the front of the list.
    self.assertEqual(vocab_list[:2], reserved_tokens)

    # Check that each character in the alphabet is in the vocab list.
    for c in alphabet:
      self.assertIn(c, vocab_list)


if __name__ == "__main__":
  tf.test.main()