import os
import time

import torch
from transformers import AutoProcessor, AutoModelForCTC, Wav2Vec2PhonemeCTCTokenizer
import librosa
from itertools import groupby
from datasets import load_dataset

# Load the model and processor
# checkpoint = "bookbot/wav2vec2-ljspeech-gruut"
checkpoint = "facebook/wav2vec2-lv-60-espeak-cv-ft"
model = AutoModelForCTC.from_pretrained(checkpoint)
processor = AutoProcessor.from_pretrained(checkpoint)
tokenizer = Wav2Vec2PhonemeCTCTokenizer.from_pretrained(checkpoint)
sr = processor.feature_extractor.sampling_rate


def decode_phonemes(
    ids: torch.Tensor, processor: AutoProcessor, ignore_stress: bool = False
) -> str:
    """CTC-like decoding. First removes consecutive duplicates, then removes special tokens."""
    # Remove consecutive duplicates
    ids = [id_ for id_, _ in groupby(ids)]
    special_token_ids = processor.tokenizer.all_special_ids + [
        processor.tokenizer.word_delimiter_token_id
    ]
    # Convert each id to its token, skipping special tokens
    phonemes = [processor.decode(id_) for id_ in ids if id_ not in special_token_ids]
    # Join phonemes into a single space-separated string
    prediction = " ".join(phonemes)
    # Ignore IPA stress marks if specified
    if ignore_stress:
        prediction = prediction.replace("ˈ", "").replace("ˌ", "")
    return prediction


def text_to_phonemes(text: str) -> str:
    """Convert text to phonemes using the tokenizer's espeak-based phonemizer backend."""
    s_time = time.time()
    # phonemes = phonemize(text, language="en-us", backend="espeak", strip=True)
    phonemes = tokenizer.phonemize(text, phonemizer_lang="en-us")
    e_time = time.time()
    print(f"Execution time of text_to_phonemes: {e_time - s_time:.6f} seconds")
    return phonemes


def separate_characters(input_string: str) -> str:
    """Split a phoneme string into space-separated characters (existing spaces removed first)."""
    no_spaces = input_string.replace(" ", "")
    spaced_string = " ".join(no_spaces)
    return spaced_string
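

# Illustrative example of the behavior (hypothetical input, pure string handling):
#   separate_characters("həˈloʊ wɜːld") -> "h ə ˈ l o ʊ w ɜ ː l d"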


def predict_phonemes(audio_array):
    # Load audio file and preprocess
    # audio_array, _ = librosa.load(audio_path, sr=sr)
    inputs = processor(audio_array, sampling_rate=sr, return_tensors="pt", padding=True)
    # Perform inference
    with torch.no_grad():
        logits = model(inputs["input_values"]).logits
    # Decode the predicted phonemes
    predicted_ids = torch.argmax(logits, dim=-1)
    predicted_phonemes = decode_phonemes(
        predicted_ids[0], processor, ignore_stress=True
    )
    return predicted_phonemes  # Return the predicted phonemes


def adjust_phonemes(predicted: str) -> str:
    # Replace specific phonemes or patterns as needed
    # adjusted = predicted.replace(" ə ", " ")  # Remove schwa if it appears alone
    adjusted = " ".join(predicted.split())  # Collapse repeated whitespace and trim the ends
    return adjusted


def calculate_score(expected: str, predicted: str) -> float:
    expected_list = expected.split()
    predicted_list = predicted.split()
    # Count correct matches by comparing position by position
    correct_matches = sum(1 for e, p in zip(expected_list, predicted_list) if e == p)
    # Score is the ratio of correct matches to the number of expected phonemes
    score = correct_matches / len(expected_list) if expected_list else 0.0
    return score
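

# Illustrative example of the scoring (hypothetical phoneme strings):
#   calculate_score("h ə l oʊ", "h ə l uː") == 0.75   # 3 of 4 positions match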


def test_sound():
    start_time = time.time()
    ds = load_dataset(
        "patrickvonplaten/librispeech_asr_dummy",
        "clean",
        split="validation",
        trust_remote_code=True,
    )
    audio_array = ds[0]["audio"]["array"]
    text = ds[0]["text"]
    # audio_path = "hello.wav"
    # text = "Hello"
    expected_transcript = text  # Expected transcript
    expected_phonemes = text_to_phonemes(text)  # Expected phonemes for the transcript
    expected_phonemes = separate_characters(expected_phonemes)
    # Call the phoneme prediction function
    predicted_phonemes = predict_phonemes(audio_array)
    adjusted_phonemes = adjust_phonemes(predicted_phonemes)
    print(f"Expected Phonemes: {expected_phonemes}")
    print(f"Predicted Phonemes: {predicted_phonemes}")
    print(f"Adjusted Phonemes: {adjusted_phonemes}")
    # Calculate score based on expected and predicted phonemes
    score = calculate_score(expected_phonemes, adjusted_phonemes)
    # Prepare the output
    result_text = (
        f"Transcript: {expected_transcript}\n"
        f"Expected Phonemes: {expected_phonemes}\n"
        f"Predicted Phonemes: {predicted_phonemes}\n"
        f"Adjusted Phonemes: {adjusted_phonemes}\n"
        f"Score: {score:.2f}"
    )
    end_time = time.time()
    execution_time = end_time - start_time
    print(f"Execution time: {execution_time:.6f} seconds")
    return {"text": result_text}