# speak-smart / phoneme.py
# Committer: uzagi
# Commit: e90c6d1 ("Update phoneme.py", verified)
import os
import time
import torch
from transformers import AutoProcessor, AutoModelForCTC, Wav2Vec2PhonemeCTCTokenizer
import librosa
from itertools import groupby
from datasets import load_dataset
# Load the phoneme-recognition model and its processor/tokenizer once, at
# import time; the functions below read these module-level objects.
# Alternative checkpoint (currently unused): "bookbot/wav2vec2-ljspeech-gruut"
checkpoint = "facebook/wav2vec2-lv-60-espeak-cv-ft"
# CTC acoustic model; predict_phonemes() takes the argmax of its logits.
model = AutoModelForCTC.from_pretrained(checkpoint)
# Feature extractor + tokenizer bundle, used to build model inputs and to
# decode predicted ids back to phonemes.
processor = AutoProcessor.from_pretrained(checkpoint)
# Loaded separately so text_to_phonemes() can call its phonemize() method.
# NOTE(review): processor.tokenizer is presumably the same tokenizer class for
# this checkpoint; reusing it would avoid a second load. Confirm before changing.
tokenizer = Wav2Vec2PhonemeCTCTokenizer.from_pretrained(checkpoint)
# Sampling rate the feature extractor is configured for; input audio is
# expected to be at this rate.
sr = processor.feature_extractor.sampling_rate
def decode_phonemes(
    ids: torch.Tensor, processor: "AutoProcessor", ignore_stress: bool = False
) -> str:
    """Greedily decode one sequence of CTC token ids into a phoneme string.

    Consecutive duplicate ids are collapsed first (the CTC rule). Every id in
    the tokenizer's ``all_special_ids``, or equal to its
    ``word_delimiter_token_id``, is then dropped. Each remaining id is decoded
    individually with ``processor.decode``, and the pieces are joined with
    single spaces.

    Args:
        ids: 1-D sequence of predicted token ids, e.g. one row of
            ``torch.argmax(logits, dim=-1)``. A plain list of ints also works.
        processor: Object exposing ``decode`` and a ``tokenizer`` with
            ``all_special_ids`` and ``word_delimiter_token_id``. The annotation
            is a string so the function can be defined without importing
            ``transformers``.
        ignore_stress: If True, remove the IPA stress marks ``ˈ`` and ``ˌ``.

    Returns:
        Space-separated phonemes. With ``ignore_stress`` the whitespace is
        re-normalized, so removed marks never leave double or edge spaces.
    """
    # NOTE(review): with Wav2Vec2PhonemeCTCTokenizer, decoding a single id
    # appears to space-separate the characters of a multi-character token
    # (e.g. "oʊ" -> "o ʊ"). separate_characters() mirrors that on the expected
    # side. Confirm before changing how ids are decoded.

    # Convert to plain ints up front. Iterating a tensor yields 0-d tensors,
    # which would turn every equality/membership test below into a tensor op.
    id_list = ids.tolist() if hasattr(ids, "tolist") else list(ids)

    # Remove consecutive duplicates (CTC collapse).
    collapsed = [id_ for id_, _ in groupby(id_list)]

    tokenizer = processor.tokenizer
    special_token_ids = set(tokenizer.all_special_ids)
    # Guard against a None delimiter id (presumably when the tokenizer defines
    # no word delimiter) rather than storing None in the filter set.
    delimiter_id = tokenizer.word_delimiter_token_id
    if delimiter_id is not None:
        special_token_ids.add(delimiter_id)

    phonemes = [
        processor.decode(id_) for id_ in collapsed if id_ not in special_token_ids
    ]
    prediction = " ".join(phonemes)

    if ignore_stress:
        prediction = prediction.replace("ˈ", "").replace("ˌ", "")
        # A token that consisted only of a stress mark leaves an empty slot
        # behind; collapse the resulting whitespace runs.
        prediction = " ".join(prediction.split())
    return prediction
def text_to_phonemes(text: str, phonemizer_lang: str = "en-us") -> str:
    """Convert text to phonemes with the module-level phoneme tokenizer.

    Delegates to ``tokenizer.phonemize`` and prints how long the call took.

    Args:
        text: Text to phonemize.
        phonemizer_lang: Language code forwarded to ``tokenizer.phonemize``.
            Defaults to ``"en-us"``, the value previously hard-coded.

    Returns:
        The phoneme string returned by ``tokenizer.phonemize``.
    """
    # perf_counter is monotonic and high-resolution, so it suits interval
    # timing better than the wall-clock time.time().
    start = time.perf_counter()
    phonemes = tokenizer.phonemize(text, phonemizer_lang=phonemizer_lang)
    elapsed = time.perf_counter() - start
    print(f"Execution time of text_to_phonemes: {elapsed:.6f} seconds")
    return phonemes
def separate_characters(input_string):
    """Lay out the characters of *input_string* one per space-separated slot.

    Only the ASCII space " " is discarded. Every other character, including
    other whitespace, is kept as its own item.
    """
    kept_chars = [char for char in input_string if char != " "]
    return " ".join(kept_chars)
def predict_phonemes(audio_array, sampling_rate=None):
    """Run the acoustic model on one waveform and return its phoneme string.

    Args:
        audio_array: Raw audio samples for a single utterance. The function
            does not resample. TODO confirm that callers supply audio at the
            feature extractor's rate.
        sampling_rate: Rate of ``audio_array`` in Hz. Defaults to ``sr``, the
            feature extractor's configured rate. Passing it lets the feature
            extractor check the audio against its own rate.

    Returns:
        Space-separated phonemes from greedy CTC decoding of the first batch
        item, with stress marks removed.
    """
    if sampling_rate is None:
        sampling_rate = sr
    inputs = processor(
        audio_array,
        sampling_rate=sampling_rate,
        return_tensors="pt",
        padding=True,
    )
    with torch.no_grad():
        # Forward everything the processor returned: input_values, plus
        # attention_mask when the feature extractor produces one.
        logits = model(**inputs).logits
    # Greedy decoding: take the most likely token for every frame.
    predicted_ids = torch.argmax(logits, dim=-1)
    return decode_phonemes(predicted_ids[0], processor, ignore_stress=True)
def adjust_phonemes(predicted: str) -> str:
    """Normalize whitespace in a predicted phoneme string.

    Collapses every run of whitespace into a single space and trims both ends,
    so the result splits cleanly into one item per phoneme.
    """
    # str.split() with no argument splits on any whitespace run and drops
    # empty pieces. That both collapses double spaces and strips the ends;
    # the previous replace(" ", " ") replaced a space with itself (a no-op).
    return " ".join(predicted.split())
def calculate_score(expected: str, predicted: str) -> float:
    """Score how much of the expected phoneme sequence appears in the prediction.

    Both strings are split on whitespace into items. The two item lists are
    aligned with ``difflib.SequenceMatcher``. The score is the number of
    expected items matched in order, divided by the number of expected items.

    Aligning, instead of comparing position by position, keeps one inserted or
    missing item from shifting, and so mismatching, everything after it.

    Returns:
        A float in [0.0, 1.0]; 0.0 when *expected* has no items.
    """
    import difflib

    expected_list = expected.split()
    predicted_list = predicted.split()
    if not expected_list:
        return 0.0
    # autojunk=False: for long sequences SequenceMatcher would otherwise treat
    # very frequent items (common phonemes) as junk and skip them.
    matcher = difflib.SequenceMatcher(
        None, expected_list, predicted_list, autojunk=False
    )
    correct_matches = sum(block.size for block in matcher.get_matching_blocks())
    return correct_matches / len(expected_list)
def test_sound():
    """Run the whole pipeline on the first LibriSpeech dummy validation sample.

    Phonemizes the sample's transcript as the reference, predicts phonemes
    from its audio, and prints expected, predicted and adjusted phonemes plus
    the total execution time.

    Returns:
        ``{"text": report}``, where *report* is a multi-line summary including
        the transcript and the score.
    """
    # perf_counter is monotonic, unlike time.time(), so it suits interval timing.
    start_time = time.perf_counter()
    ds = load_dataset(
        "patrickvonplaten/librispeech_asr_dummy",
        "clean",
        split="validation",
        trust_remote_code=True,
    )
    # Index the row once; each ds[0] access re-materializes the whole row.
    sample = ds[0]
    audio_array = sample["audio"]["array"]
    expected_transcript = sample["text"]

    # Reference phonemes, spaced one character per item (presumably to match
    # the character-level spacing of the decoded prediction).
    expected_phonemes = separate_characters(text_to_phonemes(expected_transcript))

    predicted_phonemes = predict_phonemes(audio_array)
    adjusted_phonemes = adjust_phonemes(predicted_phonemes)
    print(f"Expected Phonemes: {expected_phonemes}")
    print(f"Predicted Phonemes: {predicted_phonemes}")
    print(f"Adjusted Phonemes: {adjusted_phonemes}")

    score = calculate_score(expected_phonemes, adjusted_phonemes)
    report = (
        f"Transcript: {expected_transcript}\n"
        f"Expected Phonemes: {expected_phonemes}\n"
        f"Predicted Phonemes: {predicted_phonemes}\n"
        f"Adjusted Phonemes: {adjusted_phonemes}\n"
        f"Score: {score:.2f}"
    )
    execution_time = time.perf_counter() - start_time
    print(f"Execution time: {execution_time:.6f} seconds")
    return {"text": report}