# g2p_with_stress/G2P_lexicon/sp_tokenizer.py
import json
from typing import Optional
class Tokenizer_sp:
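    """Phoneme-level tokenizer for the G2P lexicon.

    Holds token <-> index tables plus <sos>, <eos>, <unk> and <pad> special
    tokens. The vocabulary is either built from training data or loaded
    from a JSON dictionary file.
    """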
    def __init__(self, config: Optional[dict] = None, srs: bool = True, dict_path=None, text=None):
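        """Initialise from a JSON dictionary (`dict_path`) or from `text`.

        `config` can override the special token strings (BOS_TOKEN,
        EOS_TOKEN, UNK_TOKEN, PAD_TOKEN). With `srs=True`, tokenize()
        passes symbols through without checking them against the vocabulary.
        """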
if config is None:
config = {}
self.sos = config.get('BOS_TOKEN', '<sos>')
self.eos = config.get('EOS_TOKEN', '<eos>')
self.unk = config.get('UNK_TOKEN', '<unk>')
self.pad = config.get('PAD_TOKEN', '<pad>')
self.tokens = []
self.srs = srs
if dict_path:
self.load_dict_from_file(dict_path)
elif text:
self.create_tokenizer(text)
else:
            raise ValueError("Neither dict_path nor text was provided")
def create_tokenizer(self, texts):
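        """Build the vocabulary and index tables from an iterable of phoneme lists."""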
tokens = []
for phonemes_list in texts:
for phoneme in phonemes_list:
tokens.append(phoneme)
self.tokens = [self.sos, self.eos, self.unk, self.pad] + list(set(tokens))
self.token2idx = {token: int(i) for i, token in enumerate(self.tokens)}
self.idx2token = {int(i): token for i, token in enumerate(self.tokens)}
self.unk_idx = self.token2idx[self.unk]
self.sos_idx = self.token2idx[self.sos]
self.eos_idx = self.token2idx[self.eos]
self.pad_idx = self.token2idx[self.pad]
def load_dict_from_file(self, file_path):
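        """Load the vocabulary from a JSON file mapping index strings to tokens."""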
with open(file_path, 'r') as file:
data = json.load(file)
        # JSON keys are integer indices (stored as strings); values are tokens.
        self.idx2token = {int(idx): token for idx, token in data.items()}
        self.token2idx = {token: idx for idx, token in self.idx2token.items()}
self.unk_idx = self.token2idx.get(self.unk)
self.sos_idx = self.token2idx.get(self.sos)
self.eos_idx = self.token2idx.get(self.eos)
self.pad_idx = self.token2idx.get(self.pad)
def tokenize(self, text):
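        """Wrap a sequence of phoneme tokens with <sos>/<eos>.

        With `srs=False`, out-of-vocabulary symbols are replaced by <unk>.
        """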
        if not self.srs:
            # Replace out-of-vocabulary symbols with the <unk> token so the
            # result stays a list of tokens rather than a mix of tokens and ids.
            tokens = [tok if tok in self.token2idx else self.unk for tok in text]
            return [self.sos] + tokens + [self.eos]
        else:
            return [self.sos] + list(text) + [self.eos]
def convert_tokens_to_idx(self, tokens):
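        """Map tokens to integer ids, falling back to the <unk> id."""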
idx_list = [self.token2idx.get(tok, self.unk_idx) for tok in tokens]
return idx_list
def encode(self, text, seq_len=None):
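        """Tokenize `text` and return ids, truncated to `seq_len` if given.

        Truncation happens before conversion, so a tight `seq_len` can drop
        the trailing <eos> token.
        """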
tokens = self.tokenize(text)[:seq_len]
return self.convert_tokens_to_idx(tokens)
def decode(self, idx_list):
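        """Map integer ids back to tokens; unknown ids decode to <unk>."""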
ans = []
for idx in idx_list:
try:
ans.append(self.idx2token[int(idx)])
except KeyError:
ans.append(self.idx2token[self.unk_idx])
return ans
def get_vocab_size(self):
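        """Return the number of tokens in the vocabulary."""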
return len(self.token2idx)
if __name__ == "__main__":
tokenizer_sp = Tokenizer_sp(dict_path='./my_tokenizer/my_dict_256.json')
print(tokenizer_sp.idx2token)
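    # Quick round-trip check, assuming my_dict_256.json contains the special
    # tokens: take a few non-special tokens from the loaded vocabulary,
    # encode them, then decode them back.
    sample = [tok for tok in tokenizer_sp.token2idx
              if tok not in {tokenizer_sp.sos, tokenizer_sp.eos,
                             tokenizer_sp.unk, tokenizer_sp.pad}][:4]
    ids = tokenizer_sp.encode(sample)
    print(sample, '->', ids, '->', tokenizer_sp.decode(ids))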