import logging

import torch
import torchaudio
import speechbrain as sb
from speechbrain.inference.ASR import EncoderASR
from speechbrain.dataio.encoder import CTCTextEncoder
from pyctcdecode import build_ctcdecoder
from huggingface_hub import hf_hub_download

# Set up logging
logging.basicConfig(level=logging.INFO)

# Download the checkpoint from Hugging Face Hub
checkpoint_path = hf_hub_download(
    repo_id="brdhaker3/TunASR",
    filename="model/1234/save/CKPT+2024-05-27+00-52-30+00/wav2vec2.ckpt",  # checkpoint path inside the repo
    local_dir="./",  # save alongside this script
)
logging.info(f"Checkpoint downloaded to: {checkpoint_path}")

# Load the ASR model
asr_model = EncoderASR.from_hparams(
    source="brdhaker3/TunASR",
    savedir="./model",
)

# Load the custom character tokenizer
encoder = CTCTextEncoder()
encoder.load_or_create(
    path=asr_model.hparams.encoder_file,
    from_didatasets=[[]],  # no dataset needed; labels are loaded from the encoder file
    output_key="char_list",
    special_labels={"blank_label": 0, "unk_label": 1},
    sequence_input=True,
)
asr_model.tokenizer = encoder
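# Optional sanity check (a minimal sketch): `ind2lab` is the index-to-label
# mapping that the CTC decoder setup below relies on.
# print(len(encoder.ind2lab), list(encoder.ind2lab.items())[:5])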

# Prepare labels for the CTC decoder
vocab = asr_model.tokenizer.ind2lab
labels = [vocab[i] for i in range(len(vocab))]  # index-ordered labels from the tokenizer
labels = [""] + labels[1:-1] + ["1"]  # pyctcdecode expects the blank token as ""

# Initialize the CTC decoder with a language model
decoder = build_ctcdecoder(
    labels,
    kenlm_model_path=asr_model.hparams.ngram_lm_path,  # Path to your LM
    alpha=0.5,  # LM weight
    beta=1.0,   # Word insertion penalty
)
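# Optional sanity check (a minimal sketch; assumes numpy is installed):
# pyctcdecode expects a (time_steps, vocab_size) array of log-probabilities
# whose columns follow the order of `labels` above.
# import numpy as np
# uniform = np.full((10, len(labels)), np.log(1.0 / len(labels)), dtype=np.float32)
# print(repr(decoder.decode(uniform)))  # uninformative input -> short or empty string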

class ASR(sb.core.Brain):
    def treat_wav(self, sig):
        """Process a waveform and return the transcribed text."""
        # wav_lens of 1.0 means the full (unpadded) signal is used
        feats = self.modules.wav2vec2(sig.to("cpu"), torch.tensor([1]).to("cpu"))
        feats = self.modules.enc(feats)
        logits = self.modules.ctc_lin(feats)
        p_ctc = self.hparams.log_softmax(logits)
        predicted_words = []
        for logs in p_ctc:
            text = decoder.decode(logs.detach().cpu().numpy())
            predicted_words.append(text.split(" "))
        return " ".join(predicted_words[0])  # single-utterance batch: return the first hypothesis

# Instantiate the Brain and restore the downloaded checkpoint
asr_brain = ASR(
    modules=asr_model.hparams.modules,
    hparams=vars(asr_model.hparams),
    run_opts={"device": "cpu"},
    checkpointer=asr_model.hparams.checkpointer,
)
asr_brain.tokenizer = encoder
asr_brain.checkpointer.recover_if_possible()  # load the most recent saved checkpoint
asr_brain.modules.eval()

def treat_wav_file(file_mic, file_upload, asr=asr_brain, device="cpu"):
    """Transcribe audio from the microphone or an uploaded file."""
    if file_mic is not None:
        wav = file_mic
    elif file_upload is not None:
        wav = file_upload
    else:
        return "ERROR: You have to either use the microphone or upload an audio file"

    # Read the audio, downmix to mono, and resample to the model's 16 kHz rate
    info = torchaudio.info(wav)
    sr = info.sample_rate
    sig = sb.dataio.dataio.read_audio(wav)
    if len(sig.shape) > 1:
        sig = torch.mean(sig, dim=1)
    sig = torch.unsqueeze(sig, 0)
    tensor_wav = sig.to(device)
    resampled = torchaudio.functional.resample(tensor_wav, sr, 16000)

    # Transcribe the audio
    return asr.treat_wav(resampled)

# Quick local test (uncomment and point at a real audio file):
# print(treat_wav_file("./audio.wav", "./audio.wav"))


# Gradio interface
import gradio as gr

title = "Tunisian Speech Recognition"

description = '''
This is a Tunisian ASR system based on the **WavLM** model, fine-tuned on **2.5 hours** of speech, reaching a **WER of 24%** and a **CER of 9%**.

Interested? Try it out!
'''

disclaimer = '''
> ⚠️ **Disclaimer:**  
> This is a **demo model**; transcription accuracy is limited by Hugging Face model storage constraints.  
> For better performance, you can run the full model locally.  
> Please check out the repository and follow the instructions: [Full Model Repo Link](https://huggingface.co/brdhaker3/TunASR)  
'''

with gr.Blocks() as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)
    gr.Markdown(disclaimer)

    gr.Interface(
        fn=treat_wav_file,
        inputs=[
            gr.Audio(sources="microphone", type='filepath', label="Record"),
            gr.Audio(sources="upload", type='filepath', label="Upload File"),
        ],
        outputs="text",
    )

demo.launch()