import json
import re

import phonemizer
import torch
import torch.nn.functional as F
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.utils import is_phonemizer_available
# Per-language duration patterns (latent frames per token); VitsModel.forward tiles the
# selected pattern over the input sequence instead of running a learned duration predictor.
OSCILLATION = {
    'deu': [1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 1],
    'rmc-script_latin': [2, 2, 1, 2, 2],
    'hun': [1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 1],
    'fra': [1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 1],
    'eng': [1, 2, 2, 1, 2, 2],
    'grc': [1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 1],
    'ron': [1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2],
}
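# Illustration (not executed): for lang_code='eng' and 8 input tokens the pattern
# [1, 2, 2, 1, 2, 2] is tiled and truncated to [1, 2, 2, 1, 2, 2, 1, 2], so the first
# token occupies 1 latent frame, the second 2 frames, and so on. VitsModel.forward
# additionally overrides the first and last durations with 4 and 3 frames respectively.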
def has_non_roman_characters(input_string):
    # Find any character outside the ASCII range
    non_roman_pattern = re.compile(r"[^\x00-\x7F]")

    # Search the input string for non-Roman characters
    match = non_roman_pattern.search(input_string)
    has_non_roman = match is not None
    return has_non_roman
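# Example: has_non_roman_characters("hello") -> False, has_non_roman_characters("héllo") -> True.
# The check is really "contains any non-ASCII character", so accented Latin letters also match.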
class VitsConfig(PretrainedConfig):
    model_type = "vits"

    def __init__(
        self,
        vocab_size=38,
        hidden_size=192,
        num_hidden_layers=6,
        num_attention_heads=2,
        use_bias=True,
        ffn_dim=768,
        ffn_kernel_size=3,
        flow_size=192,
        upsample_initial_channel=512,
        upsample_rates=[8, 8, 2, 2],
        upsample_kernel_sizes=[16, 16, 4, 4],
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        prior_encoder_num_flows=4,
        prior_encoder_num_wavenet_layers=4,
        wavenet_kernel_size=5,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_bias = use_bias
        self.ffn_dim = ffn_dim
        self.ffn_kernel_size = ffn_kernel_size
        self.flow_size = flow_size
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_rates = upsample_rates
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.prior_encoder_num_flows = prior_encoder_num_flows
        self.prior_encoder_num_wavenet_layers = prior_encoder_num_wavenet_layers
        self.wavenet_kernel_size = wavenet_kernel_size
        # forward remaining kwargs (e.g. pad_token_id, used by VitsTextEncoder's embedding)
        super().__init__(**kwargs)
class VitsWaveNet(torch.nn.Module):
    def __init__(self, config, num_layers):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_layers = num_layers

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()

        # use the non-deprecated parametrizations API for weight norm
        weight_norm = nn.utils.parametrizations.weight_norm

        for i in range(num_layers):
            in_layer = torch.nn.Conv1d(
                in_channels=config.hidden_size,
                out_channels=2 * config.hidden_size,
                kernel_size=config.wavenet_kernel_size,
                dilation=1,
                padding=2,  # "same" padding for kernel_size=5, dilation=1
            )
            in_layer = weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)

            # the last layer only produces skip channels, no residual branch
            if i < num_layers - 1:
                res_skip_channels = 2 * config.hidden_size
            else:
                res_skip_channels = config.hidden_size

            res_skip_layer = torch.nn.Conv1d(config.hidden_size, res_skip_channels, 1)
            res_skip_layer = weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)
    def forward(self, inputs):
        outputs = torch.zeros_like(inputs)
        num_channels = self.hidden_size

        for i in range(self.num_layers):
            in_act = self.in_layers[i](inputs)
            # gated activation unit: tanh over the first half of the channels,
            # sigmoid over the second half, multiplied together
            t_act = torch.tanh(in_act[:, :num_channels, :])
            s_act = torch.sigmoid(in_act[:, num_channels:, :])
            acts = t_act * s_act

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.num_layers - 1:
                res_acts = res_skip_acts[:, : self.hidden_size, :]
                inputs = inputs + res_acts
                outputs = outputs + res_skip_acts[:, self.hidden_size :, :]
            else:
                outputs = outputs + res_skip_acts
        return outputs
# Copied from transformers.models.speecht5.modeling_speecht5.HifiGanResidualBlock
class HifiGanResidualBlock(nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1):
        super().__init__()
        self.leaky_relu_slope = leaky_relu_slope

        self.convs1 = nn.ModuleList(
            [
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    stride=1,
                    dilation=dilation[i],
                    padding=self.get_padding(kernel_size, dilation[i]),
                )
                for i in range(len(dilation))
            ]
        )
        self.convs2 = nn.ModuleList(
            [
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    stride=1,
                    dilation=1,
                    padding=self.get_padding(kernel_size, 1),
                )
                for _ in range(len(dilation))
            ]
        )

    def get_padding(self, kernel_size, dilation=1):
        # "same" padding so each stride-1 convolution preserves the sequence length
        return (kernel_size * dilation - dilation) // 2
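    # Example: kernel_size=3, dilation=5 -> padding=(3 * 5 - 5) // 2 = 5, which keeps the
    # output length equal to the input length for a stride-1 convolution.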
    def forward(self, hidden_states):
        for conv1, conv2 in zip(self.convs1, self.convs2):
            residual = hidden_states
            hidden_states = nn.functional.leaky_relu(hidden_states, negative_slope=self.leaky_relu_slope)
            hidden_states = conv1(hidden_states)
            hidden_states = nn.functional.leaky_relu(hidden_states, negative_slope=self.leaky_relu_slope)
            hidden_states = conv2(hidden_states)
            hidden_states = hidden_states + residual
        return hidden_states
class VitsHifiGan(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_kernels = len(config.resblock_kernel_sizes)
        self.num_upsamples = len(config.upsample_rates)
        self.conv_pre = nn.Conv1d(
            config.flow_size,
            config.upsample_initial_channel,
            kernel_size=7,
            stride=1,
            padding=3,
        )

        self.upsampler = nn.ModuleList()
        for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
            self.upsampler.append(
                nn.ConvTranspose1d(
                    config.upsample_initial_channel // (2**i),
                    config.upsample_initial_channel // (2 ** (i + 1)),
                    kernel_size=kernel_size,
                    stride=upsample_rate,
                    padding=(kernel_size - upsample_rate) // 2,
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.upsampler)):
            channels = config.upsample_initial_channel // (2 ** (i + 1))
            for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
                self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation))

        # `channels` here is the channel count after the last upsampling stage
        self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3, bias=False)
    def forward(self, spectrogram):
        hidden_states = self.conv_pre(spectrogram)

        for i in range(self.num_upsamples):
            hidden_states = F.leaky_relu(hidden_states, negative_slope=0.1, inplace=True)
            hidden_states = self.upsampler[i](hidden_states)

            res_state = self.resblocks[i * self.num_kernels](hidden_states)
            for j in range(1, self.num_kernels):
                res_state += self.resblocks[i * self.num_kernels + j](hidden_states)
            hidden_states = res_state / self.num_kernels

        hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01, inplace=True)
        hidden_states = self.conv_post(hidden_states)
        waveform = torch.tanh(hidden_states)
        return waveform
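# With the default config (upsample_rates=[8, 8, 2, 2]) the decoder turns each latent frame
# into 8 * 8 * 2 * 2 = 256 waveform samples; at the 16 kHz output typical of MMS checkpoints
# that is ~16 ms of audio per frame (the sampling rate is an assumption, not read from config).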
class VitsResidualCouplingLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.half_channels = config.flow_size // 2

        self.conv_pre = nn.Conv1d(self.half_channels, config.hidden_size, 1)
        self.wavenet = VitsWaveNet(config, num_layers=config.prior_encoder_num_wavenet_layers)
        self.conv_post = nn.Conv1d(config.hidden_size, self.half_channels, 1)

    def forward(self, x, reverse=False):
        # only the inverse (reverse) transform is implemented: the mean predicted from the
        # first half of the channels is subtracted from the second half
        first_half, second_half = torch.split(x, [self.half_channels] * 2, dim=1)
        hidden_states = self.conv_pre(first_half)
        hidden_states = self.wavenet(hidden_states)
        mean = self.conv_post(hidden_states)
        second_half = second_half - mean
        outputs = torch.cat([first_half, second_half], dim=1)
        return outputs
class VitsResidualCouplingBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.flows = nn.ModuleList()
        for _ in range(config.prior_encoder_num_flows):
            self.flows.append(VitsResidualCouplingLayer(config))

    def forward(self, x, reverse=False):
        # inference-only: apply the flows in reverse order, flipping the channel dimension
        # between coupling layers (x is e.g. [1, 192, num_frames])
        for flow in reversed(self.flows):
            x = torch.flip(x, [1])
            x = flow(x, reverse=True)
        return x
class VitsAttention(nn.Module):
    """Multi-head attention with no positional information."""

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.scaling = self.head_dim**-0.5

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)

    def _shape(self, tensor, seq_len, bsz):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states,
        layer_head_mask=None,
        output_attentions=False,
    ):
        bsz, tgt_len, _ = hidden_states.size()

        # queries over the full sequence
        query_states = self.q_proj(hidden_states) * self.scaling

        # keys/values restricted to the first 40 time frames [bs*2, time, 96=ch]
        hidden_states = hidden_states[:, :40, :]
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_output = torch.bmm(attn_weights, value_states)

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use `embed_dim` from the config (stored in the class) rather than the hidden state's
        # last dimension, because `attn_output` can be partitioned across GPUs when using
        # tensor parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)
        return attn_output
class VitsFeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        # padding of kernel_size // 2 keeps the sequence length ("same" padding)
        self.conv_1 = nn.Conv1d(config.hidden_size, config.ffn_dim, config.ffn_kernel_size, padding=config.ffn_kernel_size // 2)
        self.conv_2 = nn.Conv1d(config.ffn_dim, config.hidden_size, config.ffn_kernel_size, padding=config.ffn_kernel_size // 2)

    def forward(self, hidden_states):
        hidden_states = hidden_states.permute(0, 2, 1)
        hidden_states = F.relu(self.conv_1(hidden_states))  # in-place ReLU changes the sound
        hidden_states = self.conv_2(hidden_states)
        hidden_states = hidden_states.permute(0, 2, 1)
        return hidden_states
class VitsEncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = VitsAttention(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-5)
        self.feed_forward = VitsFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-5)

    def forward(
        self,
        hidden_states,
        output_attentions=False,
    ):
        residual = hidden_states
        hidden_states = self.attention(
            hidden_states=hidden_states,
            output_attentions=output_attentions,
        )
        hidden_states = self.layer_norm(residual + hidden_states)

        residual = hidden_states
        hidden_states = self.feed_forward(hidden_states)
        hidden_states = self.final_layer_norm(residual + hidden_states)

        outputs = (hidden_states,)
        return outputs
class VitsEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([VitsEncoderLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states):
        for layer in self.layers:
            layer_outputs = layer(hidden_states)
            hidden_states = layer_outputs[0]
        return hidden_states
class VitsTextEncoder(nn.Module):
    """Transformer text encoder (wraps VitsEncoder) that produces the prior means."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
        self.encoder = VitsEncoder(config)  # config.num_hidden_layers (6) layers of VitsAttention
        self.project = nn.Conv1d(config.hidden_size, config.flow_size * 2, kernel_size=1)

    def forward(self, input_ids):
        # fixed embedding scale of 4 (empirically ~4, or 4.856406460551018 for an 845-id German input)
        hidden_states = self.embed_tokens(input_ids) * 4
        stats = self.project(self.encoder(hidden_states=hidden_states).transpose(1, 2)).transpose(1, 2)
        # `project` outputs 2 * flow_size channels; the first half are the prior means, the second
        # half (the prior log-variances in standard VITS) is discarded in this inference-only variant
        return stats[:, :, :self.config.flow_size]
class VitsPreTrainedModel(PreTrainedModel):
    config_class = VitsConfig
    base_model_prefix = "vits"
    main_input_name = "input_ids"
    supports_gradient_checkpointing = True
class VitsModel(VitsPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.text_encoder = VitsTextEncoder(config)  # VitsEncoder with 6 layers of VitsAttention
        self.flow = VitsResidualCouplingBlock(config)
        self.decoder = VitsHifiGan(config)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        speaker_id=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        speed=None,
        lang_code='deu',  # selects the duration-oscillation pattern per voice/language
    ):
        mask_dtype = self.text_encoder.embed_tokens.weight.dtype
        if attention_mask is None:
            raise ValueError("attention_mask must be provided")
        input_padding_mask = attention_mask.unsqueeze(-1).to(mask_dtype)

        prior_means = self.text_encoder(input_ids=input_ids)
        input_padding_mask = input_padding_mask.transpose(1, 2)
        bs, in_len, _ = prior_means.shape

        # Duration oscillation: instead of a learned duration predictor, tile a fixed
        # per-language pattern of frame counts over the input tokens.
        pattern = OSCILLATION.get(lang_code, [1, 2, 1])
        duration = torch.tensor(pattern, device=prior_means.device).repeat(int(in_len / len(pattern)) + 2)[None, None, :in_len]
        duration[:, :, 0] = 4   # lengthen the first token
        duration[:, :, -1] = 3  # lengthen the last token

        # Build the monotonic alignment (attention) matrix from the cumulative durations.
        predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long()
        indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device)
        output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
        output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)

        attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
        batch_size, _, output_length, input_length = attn_mask.shape
        cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
        indices = torch.arange(output_length, dtype=duration.dtype, device=duration.device)
        valid_indices = indices.unsqueeze(0) < cum_duration
        valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
        padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1]
        attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask
        attn = attn[:, 0, :, :]
        attn = attn + 1e-4 * torch.rand_like(attn)
        attn /= attn.sum(2, keepdim=True)

        # Each prior mean is now repeated for the duration of its token. The attention could also
        # hold soft weights (e.g. 0.5/0.5 instead of 1/0) to interpolate between repeated means.
        prior_means = torch.matmul(attn, prior_means)
        # prior_means = F.interpolate(prior_means.transpose(1, 2), int(1.74 * prior_means.shape[1]), mode='linear').transpose(1, 2)  # stretch for slower speed

        latents = self.flow(prior_means.transpose(1, 2),  # optionally + torch.randn_like(prior_means) * 0.94 for sampling noise
                            reverse=True)
        waveform = self.decoder(latents)  # [bs, 1, num_samples]
        return waveform[:, 0, :]
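# Rough output-length relation (a sketch, not enforced anywhere in the code): the returned
# waveform has about sum(duration) * prod(config.upsample_rates) samples per item, i.e. roughly
# 256 samples per latent frame with the default upsample_rates of [8, 8, 2, 2].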
class VitsTokenizer(PreTrainedTokenizer):
    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        pad_token="<pad>",
        unk_token="<unk>",
        language=None,
        add_blank=True,
        normalize=True,
        phonemize=True,
        is_uroman=False,
        **kwargs,
    ):
        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)

        self.decoder = {v: k for k, v in self.encoder.items()}
        self.language = language
        self.add_blank = add_blank
        self.normalize = normalize
        self.phonemize = phonemize
        self.is_uroman = is_uroman

        super().__init__(
            pad_token=pad_token,
            unk_token=unk_token,
            language=language,
            add_blank=add_blank,
            normalize=normalize,
            phonemize=phonemize,
            is_uroman=is_uroman,
            **kwargs,
        )
    @property
    def vocab_size(self):
        return len(self.encoder)

    def get_vocab(self):
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab
    def normalize_text(self, input_string):
        """Lowercase the input string, respecting any special tokens that may be partly or entirely upper-cased."""
        all_vocabulary = list(self.encoder.keys()) + list(self.added_tokens_encoder.keys())
        filtered_text = ""

        i = 0
        while i < len(input_string):
            found_match = False
            for word in all_vocabulary:
                if input_string[i : i + len(word)] == word:
                    filtered_text += word
                    i += len(word)
                    found_match = True
                    break

            if not found_match:
                filtered_text += input_string[i].lower()
                i += 1

        return filtered_text

    def _preprocess_char(self, text):
        """Special treatment of characters in certain languages."""
        if self.language == "ron":
            text = text.replace("ț", "ţ")
        return text
    def prepare_for_tokenization(self, text: str, is_split_into_words: bool = False, normalize=None, **kwargs):
        normalize = normalize if normalize is not None else self.normalize

        if normalize:
            # normalise casing
            text = self.normalize_text(text)

        filtered_text = self._preprocess_char(text)

        if has_non_roman_characters(filtered_text) and self.is_uroman:
            # 7 uroman languages: for now the input is expected to be romanized upstream (in app.py)
            raise ValueError("Non-Roman characters found; romanize the input before tokenization.")

        if self.phonemize:
            if not is_phonemizer_available():
                raise ImportError("Please install the `phonemizer` Python package to use this tokenizer.")

            filtered_text = phonemizer.phonemize(
                filtered_text,
                language="en-us",
                backend="espeak",
                strip=True,
                preserve_punctuation=True,
                with_stress=True,
            )
            filtered_text = re.sub(r"\s+", " ", filtered_text)
        elif normalize:
            # strip any characters outside of the vocab (punctuation)
            filtered_text = "".join(list(filter(lambda char: char in self.encoder, filtered_text))).strip()

        return filtered_text, kwargs
    def _tokenize(self, text):
        """Tokenize a string by inserting the `<pad>` (blank) token at the boundary between adjacent characters."""
        tokens = list(text)

        if self.add_blank:
            # no blank between letters sounds dyslexic; more than two blanks sounds disconnected
            # length 2 * N (a length of 2 * N + 1 raises a slice-index error when N is odd)
            interspersed = [self._convert_id_to_token(0)] * (len(tokens) * 2)
            interspersed[::2] = tokens
            # append one trailing blank token
            tokens = interspersed + [self._convert_id_to_token(0)]

        return tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) into an id using the vocab."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) into a token (str) using the vocab."""
        return self.decoder.get(index)
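

# Minimal usage sketch (assumptions: a local directory containing a compatible config.json,
# model weights and vocab.json; the path below is a placeholder, not a real checkpoint, and
# phonemizer/espeak must be installed for the default phonemize=True).
if __name__ == "__main__":
    checkpoint_dir = "path/to/vits-mms-style-checkpoint"  # placeholder path
    tokenizer = VitsTokenizer(vocab_file=f"{checkpoint_dir}/vocab.json")
    model = VitsModel.from_pretrained(checkpoint_dir).eval()

    inputs = tokenizer("hallo, wie geht es dir?", return_tensors="pt")
    with torch.no_grad():
        waveform = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            lang_code="deu",  # picks the German oscillation pattern
        )  # -> float tensor of shape [batch, num_samples]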