import torch
import numpy as np
import onnxruntime as ort
TOKEN_LIMIT = 510
SAMPLE_RATE = 24_000
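# Note (assumption): TOKEN_LIMIT leaves headroom for the two zero padding
# tokens that preprocess() adds around each sequence; SAMPLE_RATE is the
# rate at which the model generates audio.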

class Kokoro:
    def __init__(self, model_path: str, style_vector_path: str, tokenizer, lang: str = 'en-us') -> None:
        """
        Initializes the Kokoro ONNX inference wrapper.

        Args:
            model_path (str): Path to the ONNX model file.
            style_vector_path (str): Path to the style vector file.
            tokenizer: Tokenizer used to phonemize and tokenize input text.
            lang (str): Language code for the tokenizer.
        """
        self.sess = ort.InferenceSession(model_path)
        self.style_vector_path = style_vector_path
        self.tokenizer = tokenizer
        self.lang = lang

    def preprocess(self, text):
        """
        Converts input text to tokenized numerical IDs and loads the style vector.

        Args:
            text (str): Input text to preprocess.

        Returns:
            tuple: Padded token IDs and the matching style vector. When the
                text exceeds TOKEN_LIMIT, both are returned as per-chunk lists.
        """
        # Convert text to phonemes, then to token IDs
        phonemes = self.tokenizer.phonemize(text, lang=self.lang)
        tokenized_phonemes = self.tokenizer.tokenize(phonemes)
        if not tokenized_phonemes:
            raise ValueError("No tokens found after tokenization")
        style_vector = torch.load(self.style_vector_path, weights_only=True)
        if len(tokenized_phonemes) > TOKEN_LIMIT:
            # Long input: split into chunks and pair each chunk with the
            # style vector row indexed by its token count.
            token_chunks = self.split_into_chunks(tokenized_phonemes)
            tokens_list = []
            styles_list = []
            for chunk in token_chunks:
                token_chunk = [[0, *chunk, 0]]  # pad with start/end tokens
                style_chunk = style_vector[len(chunk)].numpy()
                tokens_list.append(token_chunk)
                styles_list.append(style_chunk)
            return tokens_list, styles_list
        style_vector = style_vector[len(tokenized_phonemes)].numpy()
        tokenized_phonemes = [[0, *tokenized_phonemes, 0]]  # pad with start/end tokens
        return tokenized_phonemes, style_vector
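
    # preprocess() return shapes, for reference (hypothetical token IDs; the
    # style rows are assumed to be numpy arrays of shape (1, 256)):
    #   short text -> ([[0, 5, 12, 9, 0]], style_row)
    #   long text  -> ([token_chunk, token_chunk, ...], [style_row, style_row, ...])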

    @staticmethod
    def split_into_chunks(tokens):
        """
        Splits a list of tokens into chunks of at most TOKEN_LIMIT tokens.

        Args:
            tokens (list): List of tokens to split.

        Returns:
            list: List of token chunks.
        """
        return [tokens[i:i + TOKEN_LIMIT] for i in range(0, len(tokens), TOKEN_LIMIT)]

    def infer(self, tokens, style_vector, speed=1.0):
        """
        Runs inference using the ONNX model.

        Args:
            tokens (list): Tokenized input for the model.
            style_vector (numpy.ndarray): Style vector for the model.
            speed (float): Speed parameter for inference.

        Returns:
            numpy.ndarray: Generated audio data.
        """
        # ONNX Runtime expects numpy inputs; token IDs are assumed to be int64
        tokens = np.asarray(tokens, dtype=np.int64)
        audio = self.sess.run(
            None,
            {
                'tokens': tokens,
                'style': style_vector,
                'speed': np.array([speed], dtype=np.float32),
            }
        )[0]
        return audio

    def generate_audio(self, text, speed=1.0):
        """
        Full pipeline: preprocess, infer, and return the generated audio.

        Args:
            text (str): Input text to generate audio from.
            speed (float): Speed parameter for inference.

        Returns:
            tuple: Concatenated audio samples and the sample rate.
        """
        # Preprocess text into token IDs and style vector(s)
        tokenized_data, styles_data = self.preprocess(text)
        audio_segments = []
        if len(tokenized_data) > 1:  # list of token chunks: text exceeded TOKEN_LIMIT
            for token_chunk, style_chunk in zip(tokenized_data, styles_data):
                audio = self.infer(token_chunk, style_chunk, speed=speed)
                audio_segments.append(audio)
        else:  # single segment within the token limit
            audio = self.infer(tokenized_data, styles_data, speed=speed)
            audio_segments.append(audio)
        full_audio = np.concatenate(audio_segments)
        return full_audio, SAMPLE_RATE
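

# Minimal usage sketch. The tokenizer is supplied from outside this file; a
# hypothetical `tokenizer` module exposing phonemize() and tokenize() is
# assumed below, as are the model and voice file names.
if __name__ == "__main__":
    import soundfile as sf  # assumed available for writing the WAV file

    from tokenizer import Tokenizer  # hypothetical module providing the tokenizer

    tts = Kokoro(
        model_path="kokoro.onnx",      # assumed model file name
        style_vector_path="voice.pt",  # assumed style vector file name
        tokenizer=Tokenizer(),
        lang="en-us",
    )
    audio, sample_rate = tts.generate_audio("Hello from Kokoro!", speed=1.0)
    sf.write("output.wav", audio, sample_rate)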