File size: 4,537 Bytes
cb3e494
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
import time
import torch
from transformers import AutoProcessor, AutoModelForCTC, Wav2Vec2PhonemeCTCTokenizer
import librosa
from itertools import groupby
from datasets import load_dataset

# Load the pretrained CTC model, its processor, and the phoneme tokenizer once
# at import time; the functions in this module use these shared instances.
# Alternative checkpoint, kept for reference:
# checkpoint = "bookbot/wav2vec2-ljspeech-gruut"
checkpoint = "facebook/wav2vec2-lv-60-espeak-cv-ft"
model = AutoModelForCTC.from_pretrained(checkpoint)
processor = AutoProcessor.from_pretrained(checkpoint)
tokenizer = Wav2Vec2PhonemeCTCTokenizer.from_pretrained(checkpoint)
# Sampling rate the processor's feature extractor is configured for.
sr = processor.feature_extractor.sampling_rate


def decode_phonemes(
    ids: torch.Tensor, processor: AutoProcessor, ignore_stress: bool = False
) -> str:
    """Greedy CTC-style decoding of a sequence of token ids.

    Runs of identical consecutive ids are collapsed to a single id, ids that
    belong to the tokenizer's special tokens (or its word delimiter) are
    dropped, and each remaining id is decoded individually.

    Args:
        ids: Iterable of predicted token ids for one utterance.
        processor: Object exposing ``tokenizer.all_special_ids``,
            ``tokenizer.word_delimiter_token_id`` and ``decode``.
        ignore_stress: When true, strip the IPA primary/secondary stress marks
            from the result.

    Returns:
        The decoded phonemes joined by single spaces.
    """
    vocab = processor.tokenizer
    skipped_ids = vocab.all_special_ids + [vocab.word_delimiter_token_id]

    pieces = []
    # groupby yields one key per run of equal consecutive ids.
    for token_id, _run in groupby(ids):
        if token_id in skipped_ids:
            continue
        pieces.append(processor.decode(token_id))

    decoded = " ".join(pieces)
    if not ignore_stress:
        return decoded

    # Drop IPA primary and secondary stress markers.
    for stress_mark in ("ˈ", "ˌ"):
        decoded = decoded.replace(stress_mark, "")
    return decoded


def text_to_phonemes(text: str) -> str:
    """Convert text to phonemes using the module-level phoneme tokenizer.

    Calls ``tokenizer.phonemize`` with ``phonemizer_lang="en-us"`` and prints
    how long the call took.

    Args:
        text: The text to phonemize.

    Returns:
        The value returned by ``tokenizer.phonemize`` for ``text``.
    """
    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # system clock adjustments.
    started = time.perf_counter()
    phonemes = tokenizer.phonemize(text, phonemizer_lang="en-us")
    elapsed = time.perf_counter() - started
    print(f"Execution time of text_to_phonemes: {elapsed:.6f} seconds")
    return phonemes


def separate_characters(input_string):
    """Return ``input_string`` with every space removed and the remaining
    characters re-joined by single spaces."""
    return " ".join(char for char in input_string if char != " ")


def predict_phonemes(audio_array):
    """Predict the phoneme sequence for a single audio waveform.

    Args:
        audio_array: Raw waveform samples. Assumed to already be sampled at
            the feature extractor's rate ``sr`` — TODO confirm callers
            resample before calling.

    Returns:
        Space-separated phonemes obtained by taking the argmax over the model
        logits and decoding the first item of the batch, with IPA stress marks
        removed.
    """
    # Passing the sampling rate explicitly lets the feature extractor validate
    # it instead of silently assuming the audio matches the model's rate.
    inputs = processor(
        audio_array, sampling_rate=sr, return_tensors="pt", padding=True
    )

    # Inference only: no gradients needed. Forward every tensor the processor
    # produced so an attention mask, when present, is not dropped.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Greedy decoding: most likely token at every frame.
    predicted_ids = torch.argmax(logits, dim=-1)
    return decode_phonemes(predicted_ids[0], processor, ignore_stress=True)


def adjust_phonemes(predicted: str) -> str:
    """Normalize spacing in a predicted phoneme string.

    Leading and trailing whitespace is trimmed and every run of consecutive
    spaces is collapsed to a single space. A single ``replace("  ", " ")`` pass
    would leave double spaces behind for runs of three or more.

    Args:
        predicted: Space-separated phoneme string.

    Returns:
        The string with single spaces between tokens and no surrounding
        whitespace.
    """
    # Splitting on a literal space yields empty strings for each extra space
    # in a run; dropping them collapses runs of any length.
    return " ".join(token for token in predicted.strip().split(" ") if token)


def calculate_score(expected: str, predicted: str) -> float:
    """Score a prediction as the fraction of expected phonemes matched.

    Both strings are split on whitespace and compared position by position;
    positions beyond the shorter sequence never match.

    NOTE: the comparison is purely positional, so a single inserted or deleted
    phoneme shifts every later position and lowers the score accordingly.

    Args:
        expected: Whitespace-separated reference phonemes.
        predicted: Whitespace-separated predicted phonemes.

    Returns:
        Number of positional matches divided by the number of expected
        phonemes, or ``0.0`` when ``expected`` contains no phonemes.
    """
    expected_list = expected.split()
    if not expected_list:
        # Always return a float, as the annotation promises.
        return 0.0

    predicted_list = predicted.split()
    correct_matches = sum(e == p for e, p in zip(expected_list, predicted_list))
    return correct_matches / len(expected_list)


def test_sound():
    """Run an end-to-end phoneme comparison on one dummy LibriSpeech sample.

    Loads the first validation row of
    ``patrickvonplaten/librispeech_asr_dummy``, phonemizes its transcript,
    predicts phonemes from its audio, scores the prediction and prints the
    intermediate results together with the total execution time.

    Returns:
        A dict with a single ``"text"`` key holding a multi-line report with
        the transcript, expected/predicted/adjusted phonemes and the score.
    """
    started = time.perf_counter()

    # NOTE(review): trust_remote_code=True permits running the dataset
    # repository's own loading code locally — confirm this is intended.
    ds = load_dataset(
        "patrickvonplaten/librispeech_asr_dummy",
        "clean",
        split="validation",
        trust_remote_code=True,
    )
    # Fetch the row once instead of indexing the dataset for each field.
    sample = ds[0]
    audio_array = sample["audio"]["array"]
    expected_transcript = sample["text"]

    # Reference phonemes from the transcript, split into individual characters.
    expected_phonemes = separate_characters(text_to_phonemes(expected_transcript))

    # Phonemes predicted from the audio, then whitespace-normalized.
    predicted_phonemes = predict_phonemes(audio_array)
    adjusted_phonemes = adjust_phonemes(predicted_phonemes)
    print(f"Expected Phonemes: {expected_phonemes}")
    print(f"Predicted Phonemes: {predicted_phonemes}")
    print(f"Adjusted Phonemes: {adjusted_phonemes}")

    score = calculate_score(expected_phonemes, adjusted_phonemes)

    report = (
        f"Transcript: {expected_transcript}\n"
        f"Expected Phonemes: {expected_phonemes}\n"
        f"Predicted Phonemes: {predicted_phonemes}\n"
        f"Adjusted Phonemes: {adjusted_phonemes}\n"
        f"Score: {score:.2f}"
    )
    execution_time = time.perf_counter() - started
    print(f"Execution time: {execution_time:.6f} seconds")
    return {"text": report}