# Source: NikiPshg / G2P_lexicon — "Upload 21 files" (commit b3ef6db, verified)
from G2P_lexicon.config_models import config_sp
from G2P_lexicon.transformer import TransformerBlock
from G2P_lexicon.sp_tokenizer import Tokenizer_sp
import torch
import os
dirname = os.path.dirname(__file__)
class Stress_Pred:
    """Add stress markers to a phoneme sequence with a transformer model.

    Tokens whose ids appear in the module-level ``list_tokens_without_stress``
    (unstressed vowel variants) are decoded by the model, which replaces them
    with their stressed counterparts (e.g. ``AH`` -> ``AH0``); every other
    token is copied through unchanged.
    """

    def __init__(self, model, tokenizer, max_len=32):
        """
        Args:
            model: transformer exposing ``encode``, ``decode`` and ``fc_out``.
            tokenizer: tokenizer providing ``encode``/``decode`` plus the
                ``pad_idx`` and ``sos_idx`` special-token ids.
            max_len: fixed encoder sequence length the model was trained with
                (inputs are right-padded to this length; default 32).
        """
        self.SP = model
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.SP.eval()  # inference only — disable dropout etc.

    def __call__(self, srs):
        """Return the phoneme list ``srs`` with stress digits added to vowels.

        Raises:
            ValueError: if the tokenized input is longer than ``max_len``.
        """
        with torch.no_grad():
            enc_input_tokens = self.tokenizer.encode(srs)
            pad_id = torch.tensor(self.tokenizer.pad_idx)
            enc_num_padding_tokens = self.max_len - len(enc_input_tokens)
            # Guard against over-long input: a negative repeat count would
            # otherwise raise an opaque RuntimeError inside torch.
            if enc_num_padding_tokens < 0:
                raise ValueError(
                    f"input of {len(enc_input_tokens)} tokens exceeds the "
                    f"maximum supported length of {self.max_len}"
                )
            # Right-pad the token sequence to the fixed encoder length.
            encoder_input = torch.cat(
                [
                    torch.tensor(enc_input_tokens),
                    pad_id.repeat(enc_num_padding_tokens),
                ],
                dim=0)
            # 1 for real tokens, 0 for padding; shape (1, 1, max_len).
            encoder_mask = (encoder_input != pad_id).unsqueeze(0).unsqueeze(0).int()
            label = self.greedy_decode_stress(
                src=encoder_input,
                src_mask=encoder_mask,
                start_token=self.tokenizer.sos_idx,
            )
            return label

    def greedy_decode_stress(self, src, src_mask, start_token):
        """Greedily decode stress markers for the padded sequence ``src``.

        Args:
            src: 1-D tensor of token ids, padded to ``max_len``.
            src_mask: attention mask produced in ``__call__``.
            start_token: SOS token id (kept for interface compatibility;
                not used — the SOS token is copied from ``src`` directly).

        Returns:
            List of phonemes with SOS/EOS stripped.
        """
        # Number of non-padding tokens (was a hard-coded `!= 3`; use the
        # tokenizer's pad id so this stays in sync with __call__).
        len_src = (src != self.tokenizer.pad_idx).int().sum().item()
        # Positions copied verbatim: tokens that have no unstressed-vowel id
        # (consonants, specials). Only the remaining (vowel) positions are
        # predicted by the model.
        copy_positions = torch.tensor(
            [idx for idx, tok in enumerate(src)
             if tok not in list_tokens_without_stress]
        )[:len_src]
        src = src.unsqueeze(0)
        src_mask = src_mask.unsqueeze(0)
        input_decoder = self.SP.encode(src, src_mask)
        label = torch.tensor([]).type_as(src.data)
        for idx in range(len_src):
            if idx in copy_positions:
                # Non-vowel token: pass through unchanged.
                label = torch.cat(
                    [label, torch.ones(1, 1).type_as(src.data).fill_(src[0][idx])],
                    dim=1)
            else:
                # Vowel token: ask the model for the stressed variant,
                # conditioning on everything decoded so far (causal mask).
                tgt_mask = (torch.tril(torch.ones((label.size(1), label.size(1))))
                            .type_as(src.data)).unsqueeze(0)
                out = self.SP.decode(input_decoder, src_mask, label, tgt_mask)
                prob = self.SP.fc_out(out[:, -1])
                _, next_word = torch.max(prob, dim=1)
                next_word = next_word.data[0]
                label = torch.cat(
                    [label, torch.ones(1, 1).type_as(src.data).fill_(next_word)],
                    dim=1)
        # Drop SOS and EOS before returning the phoneme list.
        pred = self.tokenizer.decode(label[0].tolist())[1:-1]
        return pred
# Resolve bundled resources relative to this module. Multi-argument
# os.path.join replaces the original Windows-only backslash literals
# ("my_tokenizer\sp_dict.json"), which broke on POSIX and contained
# invalid escape sequences (\s, \m).
dict_path = os.path.join(dirname, "my_tokenizer", "sp_dict.json")
model_path = os.path.join(dirname, "models", "model_sp.pt")

tokenizer_sp = Tokenizer_sp(dict_path=dict_path)

# Ids of the *unstressed* base vowel tokens: every phoneme ending in a stress
# digit (e.g. "AH0") has an unstressed counterpart ("AH") whose id we record.
set_tokens_without_stress = set()
for token, phoneme in tokenizer_sp.idx2token.items():
    # Guard phoneme[-1] against empty strings in the vocabulary.
    if phoneme and phoneme[-1].isdigit():
        set_tokens_without_stress.add(tokenizer_sp.token2idx[phoneme[:-1]])
list_tokens_without_stress = list(set_tokens_without_stress)

sp_model = TransformerBlock(config=config_sp,
                            tokenizer=tokenizer_sp)
# Weights are loaded on CPU so the module works without a GPU.
sp_model.load_state_dict(
    torch.load(model_path, map_location=torch.device('cpu')))

SP = Stress_Pred(model=sp_model,
                 tokenizer=tokenizer_sp)

if __name__ == '__main__':
    # Example: add stress markers to an unstressed phoneme sequence.
    print(SP(['N', 'IH', 'K', 'IY', 'T', 'AH']))  # ['N', 'IH2', 'K', 'IY1', 'T', 'AH0']