# -*- coding: utf-8 -*-
# @Time : 2022/2/15 7:57 PM
# @Author : JianingWang
# @File : trie
"""Character-level trie used to locate and split registered words in a text."""

import logging
from typing import List
from collections import OrderedDict

logger = logging.getLogger(__name__)


class Trie:
    """Trie over characters, stored as nested dicts in ``self.data``.

    Each node is a dict mapping a character to its child node. The special
    key ``""`` marks that the path from the root to this node spells a
    complete registered word.
    """

    def __init__(self):
        # Root node of the trie: {char: child_node, ...}
        self.data = {}

    def add(self, word: str):
        """
        Passes over every char (utf-8 char) on word and adds it to the internal `data` trie representation.
        The special key `""` is used to represent termination.

        This function is idempotent, adding twice the same word will leave the trie unchanged.
        Empty words are ignored.

        Example:

        ```python
        >>> trie = Trie()
        >>> trie.add("Hello 友達")
        >>> trie.data
        {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}}

        >>> trie.add("Hello")
        >>> trie.data
        {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}}
        ```
        """
        if not word:
            # Prevent empty string: it would mark the root itself as a word.
            return
        node = self.data
        for char in word:
            # Reuse the existing child node or create a fresh one.
            node = node.setdefault(char, {})
        node[""] = 1

    def find(self, text: str) -> List[List[int]]:
        """
        Scan `text` once and return the spans of the registered words found in it.

        Returns:
            A list of `[start, end]` pairs (`end` exclusive, i.e. the word is
            `text[start:end]`), in the order they are found. Returned spans do
            not overlap: after a match, scanning resumes at its end.

        Example:

        ```python
        >>> trie = Trie()
        >>> trie.add("[CLS]")
        >>> trie.add("extra_id_1")
        >>> trie.add("extra_id_100")
        >>> trie.find("[CLS] This is a extra_id_100")
        [[0, 5], [16, 28]]
        ```
        """
        # `states` maps the start index of every partial match currently
        # alive to the trie node reached so far. Insertion order (= order of
        # start indices) matters for the lookahead below.
        states = OrderedDict()
        offsets = []

        # Characters before `skip` were already consumed by a lookahead match.
        skip = 0
        for current, current_char in enumerate(text):
            if current < skip:
                continue

            # Partial matches that cannot consume `current_char`; removed
            # after the loop because `states` cannot shrink while iterated.
            to_remove = set()
            # Set once a match is emitted: all partial matches are dropped.
            reset = False

            for start, trie_pointer in states.items():
                if "" in trie_pointer:
                    # A complete word ends here. Before emitting it, look at
                    # this state and every state that started earlier, and
                    # try to extend them to a longer complete word.
                    for lookstart, looktrie_pointer in states.items():
                        if lookstart > start:
                            # States that started later are not considered.
                            break
                        elif lookstart < start:
                            # This state started earlier and has already
                            # consumed `current_char` in this iteration.
                            lookahead_index = current + 1
                            end = current + 1
                        else:
                            # The state that triggered the match has not
                            # consumed `current_char` yet.
                            lookahead_index = current
                            end = current
                        next_char = text[lookahead_index] if lookahead_index < len(text) else None
                        if "" in looktrie_pointer:
                            start = lookstart
                            end = lookahead_index
                            skip = lookahead_index

                        while next_char in looktrie_pointer:
                            looktrie_pointer = looktrie_pointer[next_char]
                            lookahead_index += 1
                            if "" in looktrie_pointer:
                                start = lookstart
                                end = lookahead_index
                                skip = lookahead_index

                            if lookahead_index == len(text):
                                # End of string.
                                break
                            next_char = text[lookahead_index]

                    # Store the match and restart with no partial matches.
                    offsets.append([start, end])
                    reset = True
                    break
                elif current_char in trie_pointer:
                    # The partial match continues with this character.
                    # Replacing the value of an existing key is safe while
                    # iterating: the dict does not change size.
                    trie_pointer = trie_pointer[current_char]
                    states[start] = trie_pointer
                else:
                    # The partial match is broken by this character.
                    to_remove.add(start)

            if reset:
                states = OrderedDict()
            else:
                for start in to_remove:
                    del states[start]

            # Start a new partial match at this character, unless it was
            # already consumed by the lookahead of an emitted match.
            if current >= skip and current_char in self.data:
                states[current] = self.data[current_char]

        # A word may finish exactly at the end of the text.
        for start, trie_pointer in states.items():
            if "" in trie_pointer:
                end = len(text)
                offsets.append([start, end])
                break

        return offsets

    def split(self, text: str) -> List[str]:
        """
        Split `text` on the registered words, keeping the words themselves as separate pieces.

        Example:

        ```python
        >>> trie = Trie()
        >>> trie.split("[CLS] This is a extra_id_100")
        ["[CLS] This is a extra_id_100"]

        >>> trie.add("[CLS]")
        >>> trie.add("extra_id_1")
        >>> trie.add("extra_id_100")
        >>> trie.split("[CLS] This is a extra_id_100")
        ["[CLS]", " This is a ", "extra_id_100"]
        ```
        """
        # Flatten the [start, end] spans into a single list of cut positions.
        offsets = [0]
        for span in self.find(text):
            offsets.extend(span)
        return self.cut_text(text, offsets)

    def cut_text(self, text, offsets):
        """
        Cut `text` at every position listed in `offsets`.

        Args:
            text: The string to cut.
            offsets: Cut positions, expected in non-decreasing order. The end
                of the text is always added as a final cut. The list passed
                in is not modified.

        Returns:
            The list of non-empty pieces of `text` between consecutive cuts.
            A position smaller than the previous cut is logged as an error
            and skipped.
        """
        tokens = []
        start = 0
        # Build a new list instead of appending to the caller's one.
        for end in [*offsets, len(text)]:
            if start > end:
                logger.error(
                    "There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it anyway."
                )
                continue
            elif start == end:
                # Zero-width piece (e.g. match at index 0 or adjacent matches).
                continue
            tokens.append(text[start:end])
            start = end

        return tokens

    def __reduce__(self):
        """
        Support `pickle` and `copy`.

        Returns:
            A `(callable, args, state)` tuple: the object is rebuilt by calling
            the class without arguments and restoring its instance `__dict__`.
        """
        return (self.__class__, (), self.__dict__)


if __name__ == "__main__":
    trie = Trie()
    for word in ["A", "AB", "BD", "BWA"]:
        trie.add(word)
    print(trie.__reduce__())