class CharTokenizer:
    """Character-level tokenizer: maps each distinct character to an integer id."""

    PAD_IDX = 0  # reserved for padding so pad ids never collide with a real character

    def __init__(self):
        self.chars = []
        self.char2idx = {}
        self.idx2char = {}

    def fit(self, texts):
        # Rebuild the vocabulary from scratch so fit() can be called repeatedly.
        vocab = set()
        for text in texts:
            vocab.update(text)
        self.chars = sorted(vocab)
        # Real characters start at 1; index 0 stays reserved for padding.
        self.char2idx = {char: idx + 1 for idx, char in enumerate(self.chars)}
        self.idx2char = {idx: char for char, idx in self.char2idx.items()}

    def encode(self, text, max_length=None):
        # Characters not seen during fit() are silently dropped.
        encoded = [self.char2idx[char] for char in text if char in self.char2idx]
        if max_length is not None:
            # Truncate to max_length, then right-pad with PAD_IDX if too short
            # (after truncation the multiplier is <= 0, so no padding is added).
            encoded = encoded[:max_length] + [self.PAD_IDX] * (max_length - len(encoded))
        return encoded

    def decode(self, tokens):
        # Padding and any out-of-vocabulary ids are skipped.
        return ''.join(self.idx2char[token] for token in tokens if token in self.idx2char)
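
# Minimal usage sketch (illustrative only; the corpus and max_length are
# arbitrary choices, not part of the class): fit on a small corpus, then
# round-trip a string through encode/decode.
if __name__ == "__main__":
    tokenizer = CharTokenizer()
    tokenizer.fit(["hello world", "goodbye"])

    ids = tokenizer.encode("hello", max_length=8)
    print(ids)                    # [6, 4, 7, 7, 8, 0, 0, 0] -- trailing zeros are padding
    print(tokenizer.decode(ids))  # 'hello' -- pad ids are not in idx2char, so they are skipped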