|
import logging

import torch
import torchaudio
import speechbrain as sb
import gradio as gr
from huggingface_hub import hf_hub_download
from pyctcdecode import build_ctcdecoder
from speechbrain.dataio.encoder import CTCTextEncoder
from speechbrain.inference.ASR import EncoderASR

logging.basicConfig(level=logging.INFO)
|
# Download the fine-tuned wav2vec 2.0 checkpoint from the Hugging Face Hub.
checkpoint_path = hf_hub_download(
    repo_id="brdhaker3/TunASR",
    filename="model/1234/save/CKPT+2024-05-27+00-52-30+00/wav2vec2.ckpt",
    local_dir="./",
)
logging.info(f"Checkpoint downloaded to: {checkpoint_path}")
|
# Load the pretrained ASR encoder and its hyperparameters.
asr_model = EncoderASR.from_hparams(
    source="brdhaker3/TunASR",
    savedir="./model",
)
|
# Rebuild the character-level tokenizer from the label encoder file shipped
# with the model.
encoder = CTCTextEncoder()
encoder.load_or_create(
    path=asr_model.hparams.encoder_file,
    from_didatasets=[[]],
    output_key="char_list",
    special_labels={"blank_label": 0, "unk_label": 1},
    sequence_input=True,
)
asr_model.tokenizer = encoder
|
# Build the label list for pyctcdecode: the CTC blank at index 0 becomes the
# empty string, as pyctcdecode expects, and the final entry is remapped to "1".
vocab = asr_model.tokenizer.ind2lab
labels = [vocab[i] for i in range(len(vocab))]
labels = [""] + labels[1:-1] + ["1"]
|
# Build a beam-search CTC decoder backed by a KenLM language model.
# alpha weights the LM score; beta is the word-insertion bonus.
decoder = build_ctcdecoder(
    labels,
    kenlm_model_path=asr_model.hparams.ngram_lm_path,
    alpha=0.5,
    beta=1.0,
)
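
# Optional sanity check (hypothetical snippet, not part of the original app):
# decoding random log-probabilities verifies that the decoder's vocabulary
# size matches len(labels) before wiring it into the model.
#
#     import numpy as np
#     dummy_log_probs = np.log(np.random.dirichlet(np.ones(len(labels)), size=50))
#     print(decoder.decode(dummy_log_probs.astype(np.float32)))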
|
class ASR(sb.core.Brain):
    def treat_wav(self, sig):
        """Process a waveform and return the transcribed text."""
        # The second argument is the relative-length tensor (1 = full length).
        feats = self.modules.wav2vec2(sig.to("cpu"), torch.tensor([1]).to("cpu"))
        feats = self.modules.enc(feats)
        logits = self.modules.ctc_lin(feats)
        p_ctc = self.hparams.log_softmax(logits)
        # Beam-search decode each utterance with the KenLM-backed decoder.
        predicted_words = []
        for logs in p_ctc:
            text = decoder.decode(logs.detach().cpu().numpy())
            predicted_words.append(text.split(" "))
        return " ".join(predicted_words[0])
|
# Wrap the loaded modules in a Brain for CPU inference.
asr_brain = ASR(
    modules=asr_model.hparams.modules,
    hparams=vars(asr_model.hparams),
    run_opts={"device": "cpu"},
    checkpointer=asr_model.hparams.checkpointer,
)
asr_brain.tokenizer = encoder
# Restore the downloaded checkpoint and switch to evaluation mode.
asr_brain.checkpointer.recover_if_possible()
asr_brain.modules.eval()
|
def treat_wav_file(file_mic, file_upload, asr=asr_brain, device="cpu"):
    """Transcribe a recorded or uploaded audio file."""
    if file_mic is not None:
        wav = file_mic
    elif file_upload is not None:
        wav = file_upload
    else:
        return "ERROR: You have to either use the microphone or upload an audio file"

    # Read the audio, downmix to mono if needed, and resample to 16 kHz.
    info = torchaudio.info(wav)
    sr = info.sample_rate
    sig = sb.dataio.dataio.read_audio(wav)
    if len(sig.shape) > 1:
        sig = torch.mean(sig, dim=1)
    sig = torch.unsqueeze(sig, 0)
    tensor_wav = sig.to(device)
    resampled = torchaudio.functional.resample(tensor_wav, sr, 16000)
    return asr.treat_wav(resampled)
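
# Example (hypothetical local check, assuming a WAV file "sample.wav" exists
# next to this script):
#
#     print(treat_wav_file(None, "sample.wav"))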
|
|
title = "Tunisian Speech Recognition" |
|
|
|
description = ''' |
|
This is a Tunisian ASR based on the **WavLM Model**, fine-tuned on a dataset of **2.5 hours**, resulting in a **W.E.R of 24%** and a **C.E.R of 9%**. |
|
\n |
|
Interested? Try it out! |
|
''' |
|
|
|
disclaimer = '''
> ⚠️ **Disclaimer:**
> This is a **demo model**. Transcription accuracy is reduced due to Hugging Face model storage constraints.
> For better performance, you can run the full model locally.
> Please check out the repository and follow the instructions: [Full Model Repo Link](https://huggingface.co/brdhaker3/TunASR)
'''
|
with gr.Blocks() as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)
    gr.Markdown(disclaimer)

    interface = gr.Interface(
        fn=treat_wav_file,
        inputs=[
            gr.Audio(sources="microphone", type="filepath", label="Record"),
            gr.Audio(sources="upload", type="filepath", label="Upload File"),
        ],
        outputs="text",
        title="",
        description="",
    )

demo.launch()