speak-smart / phoneme.py
uzagi
add phoneme
cb3e494
raw
history blame
5.53 kB
import os
import time
import torch
from transformers import AutoProcessor, AutoModelForCTC, Wav2Vec2PhonemeCTCTokenizer
import librosa
from itertools import groupby
from datasets import load_dataset
from phonemizer import phonemize
from phonemizer.backend.espeak.wrapper import EspeakWrapper
# PHONEMIZER_ESPEAK_LIBRARY="c:\Program Files\eSpeak NG\libespeak-ng.dll"
# PHONEMIZER_ESPEAK_PATH="c:\Program Files\eSpeak NG"
# ESPEAK_PATH = os.getenv("PHONEMIZER_ESPEAK_LIBRARY")
# if ESPEAK_PATH is not None:
# EspeakWrapper.set_library(ESPEAK_PATH)
# print(f"Loaded environment variables PHONEMIZER_ESPEAK_LIBRARY: {ESPEAK_PATH}")
# print(f"Using espeak library: {EspeakWrapper.library_path}")
# Load the model and processor once at import time; the functions below read
# these module-level objects (model, processor, tokenizer, sr) as globals.
# Alternative checkpoint kept for reference:
# checkpoint = "bookbot/wav2vec2-ljspeech-gruut"
checkpoint = "facebook/wav2vec2-lv-60-espeak-cv-ft"
model = AutoModelForCTC.from_pretrained(checkpoint)
processor = AutoProcessor.from_pretrained(checkpoint)
# Phoneme tokenizer for the same checkpoint; text_to_phonemes uses its
# phonemize() method to build the reference phoneme string.
tokenizer = Wav2Vec2PhonemeCTCTokenizer.from_pretrained(checkpoint)
# Sampling rate the feature extractor is configured for. Audio given to
# predict_phonemes is presumably expected at this rate — confirm with callers.
sr = processor.feature_extractor.sampling_rate
def decode_phonemes(
    ids: torch.Tensor, processor: "AutoProcessor", ignore_stress: bool = False
) -> str:
    """CTC-style greedy decoding of a sequence of predicted token ids.

    First collapses runs of consecutive identical ids, then drops special
    tokens (and the tokenizer's word-delimiter token), decodes each remaining
    id individually and joins the results with single spaces.

    Args:
        ids: 1-D sequence of token ids, e.g. one row of ``argmax(logits)``.
            A tensor, numpy array or any iterable of ints is accepted.
        processor: object exposing ``tokenizer.all_special_ids``,
            ``tokenizer.word_delimiter_token_id`` and ``decode(id)``.
        ignore_stress: if True, remove the IPA primary (ˈ) and secondary (ˌ)
            stress marks from the result.

    Returns:
        The decoded phonemes separated by single spaces.
    """
    # Work on plain Python ints: comparing 0-d tensors one by one allocates a
    # new tensor for every comparison and is very slow.
    id_list = ids.tolist() if hasattr(ids, "tolist") else list(ids)

    # Collapse consecutive duplicates (CTC repeat merging).
    collapsed = [id_ for id_, _ in groupby(id_list)]

    tokenizer = processor.tokenizer
    # A set gives O(1) membership tests. word_delimiter_token_id may be None,
    # which is harmless here because no int id equals None.
    skip_ids = set(tokenizer.all_special_ids)
    skip_ids.add(tokenizer.word_delimiter_token_id)

    # Decode each surviving id on its own so every phoneme token stays a
    # separate, space-delimited unit.
    phonemes = [processor.decode(id_) for id_ in collapsed if id_ not in skip_ids]
    prediction = " ".join(phonemes)

    if ignore_stress:
        prediction = prediction.replace("ˈ", "").replace("ˌ", "")
    return prediction
def text_to_phonemes(text: str) -> str:
    """Convert text to a phoneme string using the module-level tokenizer.

    Delegates to ``tokenizer.phonemize`` with ``phonemizer_lang="en-us"`` and
    prints how long the call took.

    Args:
        text: the text to phonemize.

    Returns:
        The phoneme string produced by ``tokenizer.phonemize``.
    """
    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which can jump if the wall clock is adjusted.
    start = time.perf_counter()
    phonemes = tokenizer.phonemize(text, phonemizer_lang="en-us")
    elapsed = time.perf_counter() - start
    print(f"Execution time of text_to_phonemes: {elapsed:.6f} seconds")
    return phonemes
def text_to_phonemes_2(text: str) -> str:
    """Convert text to phonemes by calling ``phonemizer.phonemize`` directly.

    Uses the espeak backend with language "en-us" and ``strip=True``. No
    separator is passed, so phonemizer's default separator applies. Prints
    how long the call took.

    Args:
        text: the text to phonemize.

    Returns:
        The phoneme string returned by ``phonemize``.
    """
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    phonemes = phonemize(text, language="en-us", backend="espeak", strip=True)
    elapsed = time.perf_counter() - start
    print(f"Execution time of text_to_phonemes_2: {elapsed:.6f} seconds")
    return phonemes
def separate_characters(input_string):
    """Return every non-space character of *input_string*, space-separated.

    All ASCII spaces are removed first, so the original spacing does not
    survive. Other whitespace (tabs, newlines) is treated like any other
    character.

    >>> separate_characters("ab c")
    'a b c'
    """
    return " ".join(input_string.replace(" ", ""))
def predict_phonemes(audio_array, sampling_rate=None):
    """Predict a phoneme string from a raw audio waveform.

    Runs the module-level CTC ``model`` on the audio, takes the arg-max token
    for every frame, and decodes the first sequence of the batch with
    ``decode_phonemes`` (stress marks removed).

    Args:
        audio_array: the waveform, presumably 1-D mono — confirm with callers.
            When resampling is needed it must be something ``librosa.resample``
            accepts (a floating-point numpy array).
        sampling_rate: sampling rate of ``audio_array``. ``None`` (the default)
            means the audio is already at ``sr``, the feature extractor's rate.
            Any other rate is first resampled to ``sr``.

    Returns:
        The predicted phonemes separated by spaces.
    """
    if sampling_rate is not None and sampling_rate != sr:
        audio_array = librosa.resample(audio_array, orig_sr=sampling_rate, target_sr=sr)

    # Passing sampling_rate lets the feature extractor check the rate instead
    # of only warning that the argument was not provided.
    inputs = processor(audio_array, sampling_rate=sr, return_tensors="pt", padding=True)

    with torch.no_grad():
        logits = model(inputs["input_values"]).logits

    predicted_ids = torch.argmax(logits, dim=-1)
    # Only the first sequence of the batch is decoded.
    return decode_phonemes(predicted_ids[0], processor, ignore_stress=True)
def adjust_phonemes(predicted: str) -> str:
    """Normalize whitespace in a predicted phoneme string.

    Every run of whitespace is collapsed to a single space, and leading and
    trailing whitespace is removed, leaving exactly one space between
    phonemes.

    Args:
        predicted: phoneme string, e.g. the output of ``predict_phonemes``.

    Returns:
        The whitespace-normalized phoneme string.
    """
    # str.split() with no argument splits on any whitespace run and drops
    # empty fields. The previous replace(" ", " ") was a no-op, so double
    # spaces were never actually removed.
    return " ".join(predicted.split())
def _lcs_length(first, second):
    """Return the length of the longest common subsequence of two sequences."""
    # Classic dynamic programme, keeping only the previous row.
    previous = [0] * (len(second) + 1)
    for item_a in first:
        current = [0]
        for j, item_b in enumerate(second, start=1):
            if item_a == item_b:
                current.append(previous[j - 1] + 1)
            else:
                current.append(max(previous[j], current[j - 1]))
        previous = current
    return previous[-1]


def calculate_score(expected: str, predicted: str) -> float:
    """Score how well *predicted* phonemes match *expected* ones.

    Both strings are split on whitespace into phoneme lists. The number of
    correct matches is the length of their longest common subsequence, i.e.
    phonemes matched in order after aligning around insertions and deletions.
    A positional comparison would instead score a prediction with one extra
    leading phoneme as almost entirely wrong.

    Args:
        expected: reference phonemes, whitespace-separated.
        predicted: predicted phonemes, whitespace-separated.

    Returns:
        ``matches / len(expected phonemes)``, in [0.0, 1.0]. Returns 0.0 when
        *expected* contains no phonemes.
    """
    expected_list = expected.split()
    predicted_list = predicted.split()
    if not expected_list:
        return 0.0
    correct_matches = _lcs_length(expected_list, predicted_list)
    return correct_matches / len(expected_list)
def test_sound():
    """Score the model's phoneme prediction for one LibriSpeech sample.

    Loads the first validation utterance of the dummy LibriSpeech dataset.
    The transcript is phonemized as the reference, phonemes are predicted
    from the audio, and the two are printed and scored with
    ``calculate_score``.

    Returns:
        A dict with a single ``"text"`` key holding a human-readable report.
    """
    start_time = time.perf_counter()
    ds = load_dataset(
        "patrickvonplaten/librispeech_asr_dummy",
        "clean",
        split="validation",
        trust_remote_code=True,
    )
    # Index once: each ds[0] access decodes the audio again.
    sample = ds[0]
    audio_array = sample["audio"]["array"]
    expected_transcript = sample["text"]

    # The tokenizer's phonemize() already delimits phonemes (by its phone
    # delimiter, a space by default). Splitting into single characters would
    # break multi-character phonemes such as "oʊ" or "ɑː" and misalign them
    # with the model's tokens, so only normalize the whitespace.
    expected_phonemes = " ".join(text_to_phonemes(expected_transcript).split())

    predicted_phonemes = predict_phonemes(audio_array)
    adjusted_phonemes = adjust_phonemes(predicted_phonemes)

    print(f"Expected Phonemes: {expected_phonemes}")
    print(f"Predicted Phonemes: {predicted_phonemes}")
    print(f"Adjusted Phonemes: {adjusted_phonemes}")

    score = calculate_score(expected_phonemes, adjusted_phonemes)

    report = (
        f"Transcript: {expected_transcript}\n"
        f"Expected Phonemes: {expected_phonemes}\n"
        f"Predicted Phonemes: {predicted_phonemes}\n"
        f"Adjusted Phonemes: {adjusted_phonemes}\n"
        f"Score: {score:.2f}"
    )
    execution_time = time.perf_counter() - start_time
    print(f"Execution time: {execution_time:.6f} seconds")
    return {"text": report}