import os
import time
from itertools import groupby

import librosa
import torch
from datasets import load_dataset
from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2PhonemeCTCTokenizer

# Load the model, processor and phoneme tokenizer once at import time; the
# functions below use these module-level objects.
# Alternative checkpoint: "bookbot/wav2vec2-ljspeech-gruut"
checkpoint = "facebook/wav2vec2-lv-60-espeak-cv-ft"
model = AutoModelForCTC.from_pretrained(checkpoint)
processor = AutoProcessor.from_pretrained(checkpoint)
tokenizer = Wav2Vec2PhonemeCTCTokenizer.from_pretrained(checkpoint)
sr = processor.feature_extractor.sampling_rate

# IPA primary and secondary stress marks, removed when ignore_stress=True.
_STRESS_MARKS = ("ˈ", "ˌ")


def decode_phonemes(
    ids: torch.Tensor, processor: AutoProcessor, ignore_stress: bool = False
) -> str:
    """Decode predicted CTC ids into a space-separated phoneme string.

    Consecutive duplicate ids are collapsed first, then special tokens (the
    tokenizer's special ids plus its word-delimiter id) are dropped, and each
    remaining id is decoded on its own with ``processor.decode``.

    Args:
        ids: 1-D sequence of predicted token ids, e.g. one row of
            ``argmax(logits)``. A tensor or any iterable of ints.
        processor: Processor whose ``tokenizer`` provides the special ids and
            whose ``decode`` turns a single id into text.
        ignore_stress: If True, remove the IPA stress marks ``ˈ`` and ``ˌ``.

    Returns:
        The decoded pieces joined by single spaces.
    """
    # Compare plain Python ints rather than 0-d tensors; decode() converts
    # tensors to ints internally anyway, so the decoded text is unchanged.
    id_list = ids.tolist() if isinstance(ids, torch.Tensor) else list(ids)

    # CTC-style collapse of repeated ids before filtering.
    collapsed = [id_ for id_, _ in groupby(id_list)]

    # Presumably includes the pad/blank token used by CTC; the word delimiter
    # id may be None for phoneme tokenizers, which is harmless here.
    special_token_ids = set(processor.tokenizer.all_special_ids)
    special_token_ids.add(processor.tokenizer.word_delimiter_token_id)

    phonemes = [
        processor.decode(id_) for id_ in collapsed if id_ not in special_token_ids
    ]
    prediction = " ".join(phonemes)

    if ignore_stress:
        for mark in _STRESS_MARKS:
            prediction = prediction.replace(mark, "")
    return prediction


def text_to_phonemes(text: str) -> str:
    """Convert *text* to phonemes with the tokenizer's en-us phonemizer.

    Prints how long the conversion took.
    """
    start = time.perf_counter()
    phonemes = tokenizer.phonemize(text, phonemizer_lang="en-us")
    elapsed = time.perf_counter() - start
    print(f"Execution time of text_to_phonemes: {elapsed:.6f} seconds")
    return phonemes


def separate_characters(input_string: str) -> str:
    """Remove spaces from *input_string* and put one space between characters.

    >>> separate_characters("ab c")
    'a b c'
    """
    return " ".join(input_string.replace(" ", ""))


def predict_phonemes(audio_array, sampling_rate=None):
    """Predict a phoneme string for a single audio clip.

    Args:
        audio_array: Raw waveform samples for one clip.
        sampling_rate: Sample rate of *audio_array*. Defaults to the feature
            extractor's rate ``sr``. It is forwarded to the processor so the
            input rate can be validated instead of being assumed.

    Returns:
        Space-separated phonemes with stress marks removed.
    """
    if sampling_rate is None:
        sampling_rate = sr

    inputs = processor(
        audio_array, sampling_rate=sampling_rate, return_tensors="pt", padding=True
    )

    # Inference only: no gradients needed.
    with torch.no_grad():
        logits = model(inputs["input_values"]).logits

    # Greedy decoding: most likely token per frame.
    predicted_ids = torch.argmax(logits, dim=-1)
    # A single clip was passed in, so decode the first (only) row.
    return decode_phonemes(predicted_ids[0], processor, ignore_stress=True)


def adjust_phonemes(predicted: str) -> str:
    """Normalize whitespace in a predicted phoneme string.

    Every run of whitespace becomes a single space and both ends are trimmed.
    """
    return " ".join(predicted.split())


def calculate_score(expected: str, predicted: str) -> float:
    """Return the fraction of expected phonemes matched at the same position.

    Both strings are split on whitespace and compared index by index. Extra
    predicted phonemes beyond the expected length are ignored. Returns 0.0
    when *expected* contains no phonemes.

    Note: the comparison is purely positional, so one inserted or deleted
    phoneme misaligns every later position.
    """
    expected_list = expected.split()
    if not expected_list:
        return 0.0
    predicted_list = predicted.split()
    correct_matches = sum(e == p for e, p in zip(expected_list, predicted_list))
    return correct_matches / len(expected_list)


def test_sound():
    """Score predicted phonemes for the first LibriSpeech dummy sample.

    Loads the sample, phonemizes its transcript, predicts phonemes from its
    audio, prints intermediate results and total timing, and returns
    ``{"text": report}``, where *report* is a multi-line summary that ends
    with the score.
    """
    start_time = time.perf_counter()

    ds = load_dataset(
        "patrickvonplaten/librispeech_asr_dummy",
        "clean",
        split="validation",
        trust_remote_code=True,
    )
    # Index the row once; each ds[0] access re-materializes the row.
    sample = ds[0]
    audio = sample["audio"]
    expected_transcript = sample["text"]

    # Reference phonemes, spaced one character apart (presumably to match the
    # per-character spacing produced when predictions are decoded).
    expected_phonemes = separate_characters(text_to_phonemes(expected_transcript))

    predicted_phonemes = predict_phonemes(
        audio["array"], sampling_rate=audio["sampling_rate"]
    )
    adjusted_phonemes = adjust_phonemes(predicted_phonemes)

    print(f"Expected Phonemes: {expected_phonemes}")
    print(f"Predicted Phonemes: {predicted_phonemes}")
    print(f"Adjusted Phonemes: {adjusted_phonemes}")

    score = calculate_score(expected_phonemes, adjusted_phonemes)

    report = (
        f"Transcript: {expected_transcript}\n"
        f"Expected Phonemes: {expected_phonemes}\n"
        f"Predicted Phonemes: {predicted_phonemes}\n"
        f"Adjusted Phonemes: {adjusted_phonemes}\n"
        f"Score: {score:.2f}"
    )

    execution_time = time.perf_counter() - start_time
    print(f"Execution time: {execution_time:.6f} seconds")
    return {"text": report}