# speecht5_tts / handler.py
# Last commit by Dupaja: "Added function to convert digits to written numbers, for handling with SpeechT5"
# (history | blame) — 5.71 kB
import librosa
import numpy as np
import torch
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from datasets import load_dataset
import time
import re
import inflect
from typing import Dict, List, Any
def _spell_year(engine, year):
    """Spell a four-digit token the way a year is usually read aloud.

    1001-1999 are read as two pairs ("nineteen ninety", "nineteen hundred",
    "nineteen oh five"); every other value (e.g. 1000, 2024, or a
    zero-padded "0999") is spelled as a plain cardinal with commas removed.
    """
    if 1000 < year < 2000:
        century, rest = divmod(year, 100)
        head = engine.number_to_words(century)
        if rest == 0:
            return f"{head} hundred"
        if rest < 10:
            return f"{head} oh {engine.number_to_words(rest)}"
        return f"{head} {engine.number_to_words(rest)}"
    # Strip commas so TTS does not insert a pause inside the number.
    return engine.number_to_words(year).replace(",", "")


def convert_numbers_to_text(input_string):
    """Replace whitespace-separated numeric tokens with their English spelling.

    Args:
        input_string: Text to normalise for TTS.

    Returns:
        The tokens re-joined with single spaces. A token of exactly four
        decimal digits is read as a year (see ``_spell_year``); any other token
        that is all decimal digits once commas are removed ("1,234") is spelled
        as a cardinal with commas stripped. Tokens with attached punctuation
        (e.g. "1990.") are left unchanged.
    """
    tokens = input_string.split()

    # Only build the inflect engine when there is something to convert.
    if not any(t.replace(",", "").isdecimal() for t in tokens):
        return " ".join(tokens)
    engine = inflect.engine()

    converted = []
    for token in tokens:
        # isdecimal (not isdigit): isdigit accepts e.g. "²", which int() rejects.
        if token.isdecimal() and len(token) == 4:
            token = _spell_year(engine, int(token))
        elif token.replace(",", "").isdecimal():
            number = int(token.replace(",", ""))
            token = engine.number_to_words(number).replace(",", "")
        converted.append(token)
    return " ".join(converted)
def split_and_recombine_text(text, desired_length=200, max_length=300):
    """Split text into chunks of roughly ``desired_length`` characters.

    Chunks are preferably cut at sentence boundaries: ``!``, ``?``, newline,
    ``.`` followed by whitespace, or a closing double quote followed by
    whitespace. Once a chunk reaches ``max_length`` characters it is
    force-split:
    - if it already holds a sentence boundary and more than
      ``desired_length / 2`` characters, it is cut back to the last boundary;
    - otherwise it is cut back to the nearest space/punctuation, but never to
      fewer than ``desired_length`` characters.

    Args:
        text: Input text. Blank-line runs and all whitespace are collapsed to
            single spaces and curly double quotes become ``"``.
        desired_length: Length at which a chunk is closed at the next
            sentence boundary.
        max_length: Length at which a chunk is force-split.

    Returns:
        Stripped, non-empty chunks. Chunks consisting only of whitespace or
        ``.,;:!?`` are dropped. Empty input yields ``[]``.
    """
    # Normalise: collapse blank lines and whitespace, ASCII-fy curly quotes.
    # NOTE: \s+ also turns newlines into spaces, so the '\n' checks below
    # never match after this step.
    text = re.sub(r'\n\n+', '\n', text)
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'[“”]', '"', text)

    rv = []
    in_quote = False
    current = ""
    split_pos = []
    pos = -1
    end_pos = len(text) - 1

    def seek(delta):
        """Move the cursor by delta, growing/shrinking ``current``; return the char at the cursor."""
        nonlocal pos, in_quote, current
        is_neg = delta < 0
        for _ in range(abs(delta)):
            if is_neg:
                pos -= 1
                current = current[:-1]
            else:
                pos += 1
                current += text[pos]
            # NOTE(review): when stepping backwards this tests the character now
            # under the cursor, not the one just removed; kept as upstream.
            if text[pos] == '"':
                in_quote = not in_quote
        return text[pos]

    def peek(delta):
        """Return the char at pos + delta, or "" outside [0, end_pos)."""
        # NOTE: the final character is never visible (p < end_pos), and ""
        # is a substring of every string, so `peek(..) in '...'` is True there.
        p = pos + delta
        return text[p] if p < end_pos and p >= 0 else ""

    def commit():
        """Close the current chunk and reset the per-chunk state."""
        nonlocal current, split_pos
        rv.append(current)
        current = ""
        split_pos = []

    while pos < end_pos:
        c = seek(1)
        # do we need to force a split?
        if len(current) >= max_length:
            if len(split_pos) > 0 and len(current) > (desired_length / 2):
                # we have at least one sentence and we are over half the desired length, seek back to the last split
                d = pos - split_pos[-1]
                seek(-d)
            else:
                # no full sentences, seek back until we are not in the middle of a word and split there
                while c not in '!?.\n ' and pos > 0 and len(current) > desired_length:
                    c = seek(-1)
            commit()
        # check for sentence boundaries
        elif not in_quote and (c in '!?\n' or (c == '.' and peek(1) in '\n ')):
            # seek forward if we have consecutive boundary markers but still within the max length
            while pos < end_pos and len(current) < max_length and peek(1) in '!?.':
                c = seek(1)
            split_pos.append(pos)
            if len(current) >= desired_length:
                commit()
        # treat end of quote as a boundary if its followed by a space or newline
        elif in_quote and peek(1) == '"' and peek(2) in '\n ':
            seek(2)
            split_pos.append(pos)
    # Whatever remains after the scan is the final chunk.
    rv.append(current)

    # clean up, remove lines with only whitespace or punctuation
    rv = [s.strip() for s in rv]
    rv = [s for s in rv if len(s) > 0 and not re.match(r'^[\s\.,;:!?]*$', s)]
    return rv
class EndpointHandler:
    """Text-to-speech endpoint handler built on SpeechT5 + HiFi-GAN.

    Loads the model, processor, vocoder and one speaker x-vector once in
    ``__init__``. ``__call__`` then synthesises audio for the text found
    under ``data["inputs"]``.
    """

    def __init__(self, path="", speaker_index=7306):
        """Load checkpoints and the speaker embedding.

        Args:
            path: Not used by this handler (checkpoints are hard-coded);
                presumably kept for the Inference Endpoints handler
                interface — confirm before removing.
            speaker_index: Row of the x-vector dataset whose ``"xvector"`` is
                used as the speaker embedding (7306 as before).
        """
        # These are Dupaja/ mirrors of the upstream microsoft/speecht5_tts,
        # microsoft/speecht5_hifigan and Matthijs/cmu-arctic-xvectors ids.
        checkpoint = "Dupaja/speecht5_tts"
        vocoder_id = "Dupaja/speecht5_hifigan"
        dataset_id = "Dupaja/cmu-arctic-xvectors"

        self.model = SpeechT5ForTextToSpeech.from_pretrained(checkpoint, low_cpu_mem_usage=True)
        self.processor = SpeechT5Processor.from_pretrained(checkpoint)
        self.vocoder = SpeechT5HifiGan.from_pretrained(vocoder_id)

        embeddings_dataset = load_dataset(dataset_id, split="validation", trust_remote_code=True)
        self.embeddings_dataset = embeddings_dataset
        # unsqueeze(0) adds a leading size-1 dimension (presumably the batch
        # axis generate_speech expects — confirm).
        self.speaker_embeddings = torch.tensor(
            embeddings_dataset[speaker_index]["xvector"]
        ).unsqueeze(0)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Synthesise speech for ``data["inputs"]``.

        Numbers are spelled out and the text is split into chunks. Each chunk
        is synthesised separately and the waveforms are concatenated.

        Returns:
            ``{"statusCode": 200, "body": {"audio": ndarray,
            "sampling_rate": int, "run_time_total": str}}``. Missing or blank
            input yields an empty ``audio`` array.
        """
        start_time = time.perf_counter()
        given_text = data.get("inputs") or ""
        given_text = convert_numbers_to_text(given_text)
        texts = split_and_recombine_text(given_text)

        audios = []
        for chunk in texts:
            inputs = self.processor(text=chunk, return_tensors="pt")
            speech = self.model.generate_speech(
                inputs["input_ids"], self.speaker_embeddings, vocoder=self.vocoder
            )
            # Move to host memory so NumPy can concatenate regardless of device.
            audios.append(speech.cpu().numpy())

        # np.concatenate raises on an empty list, so blank input gets empty audio.
        if audios:
            final_speech = np.concatenate(audios)
        else:
            final_speech = np.zeros(0, dtype=np.float32)
        run_time_total = time.perf_counter() - start_time

        sampling_rate = getattr(self.vocoder.config, "sampling_rate", 16000)
        # Return the expected response format
        return {
            "statusCode": 200,
            "body": {
                "audio": final_speech,  # NumPy array; consider encoding (e.g. WAV/base64) for JSON transport
                "sampling_rate": sampling_rate,
                "run_time_total": str(run_time_total),
            },
        }
# Module-level instance: constructing it runs EndpointHandler.__init__, which
# calls from_pretrained/load_dataset, so importing this module loads the models.
handler = EndpointHandler(path="")