import logging

import torch
import torchaudio
import speechbrain as sb
from speechbrain.inference.ASR import EncoderASR
from speechbrain.dataio.encoder import CTCTextEncoder
from pyctcdecode import build_ctcdecoder
from huggingface_hub import hf_hub_download
# Set up logging
logging.basicConfig(level=logging.INFO)
# Download the wav2vec2 checkpoint from the Hugging Face Hub
checkpoint_path = hf_hub_download(
    repo_id="brdhaker3/TunASR",
    filename="model/1234/save/CKPT+2024-05-27+00-52-30+00/wav2vec2.ckpt",  # checkpoint path inside the repo
    local_dir="./",  # save it to the current directory
)
logging.info(f"Checkpoint downloaded to: {checkpoint_path}")
# Load the pretrained ASR model
asr_model = EncoderASR.from_hparams(
    source="brdhaker3/TunASR",
    savedir="./model",
)
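# Note: EncoderASR also exposes its own transcribe_file()/transcribe_batch()
# using the decoding defined in its hparams; the code below swaps in a custom
# tokenizer and a KenLM-fused pyctcdecode beam search instead.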
# Load the custom character tokenizer
encoder = CTCTextEncoder()
encoder.load_or_create(
    path=asr_model.hparams.encoder_file,
    from_didatasets=[[]],
    output_key="char_list",
    special_labels={"blank_label": 0, "unk_label": 1},
    sequence_input=True,
)
asr_model.tokenizer = encoder
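# Sanity-check sketch for the loaded tokenizer (illustrative only; the actual
# symbol set depends on encoder_file, and the sample string is hypothetical):
#   ids = encoder.encode_sequence(list("abc"))      # characters -> integer indices
#   assert encoder.decode_ndim(ids) == list("abc")  # indices -> characters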
# Prepare labels for the CTC decoder
vocab = asr_model.tokenizer.ind2lab
labels = [vocab[i] for i in range(len(vocab))] # Extract labels from the tokenizer
labels = [""] + labels[1:-1] + ["1"] # Adjust labels to match CTC format
# Initialize the CTC beam-search decoder with a KenLM language model
decoder = build_ctcdecoder(
    labels,
    kenlm_model_path=asr_model.hparams.ngram_lm_path,  # path to the n-gram LM
    alpha=0.5,  # LM weight
    beta=1.0,  # word insertion penalty
)
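# Usage sketch: the decoder consumes a (time, vocab) numpy array of
# log-probabilities and returns a string, e.g.
#   text = decoder.decode(log_probs)  # log_probs: np.ndarray of shape [T, len(labels)]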
class ASR(sb.core.Brain):
    def treat_wav(self, sig):
        """Process a waveform and return the transcribed text."""
        # wav2vec2 features -> encoder -> CTC head -> log-probabilities;
        # torch.tensor([1]) is the relative length of the single utterance
        feats = self.modules.wav2vec2(sig.to("cpu"), torch.tensor([1]).to("cpu"))
        feats = self.modules.enc(feats)
        logits = self.modules.ctc_lin(feats)
        p_ctc = self.hparams.log_softmax(logits)
        # Beam-search decode each item in the batch with the LM-fused decoder
        predicted_words = []
        for logs in p_ctc:
            text = decoder.decode(logs.detach().cpu().numpy())
            predicted_words.append(text.split(" "))
        return " ".join(predicted_words[0])
# Wrap the loaded modules in a SpeechBrain Brain for inference
asr_brain = ASR(
    modules=asr_model.hparams.modules,
    hparams=vars(asr_model.hparams),
    run_opts={"device": "cpu"},
    checkpointer=asr_model.hparams.checkpointer,
)
asr_brain.tokenizer = encoder
asr_brain.checkpointer.recover_if_possible()  # restore the trained weights
asr_brain.modules.eval()  # switch to inference mode
def treat_wav_file(file_mic, file_upload, asr=asr_brain, device="cpu"):
    """Transcribe audio coming from the microphone or an uploaded file."""
    if file_mic is not None:
        wav = file_mic
    elif file_upload is not None:
        wav = file_upload
    else:
        return "ERROR: You have to either use the microphone or upload an audio file"
    # Read and preprocess the audio file
    info = torchaudio.info(wav)
    sr = info.sample_rate
    sig = sb.dataio.dataio.read_audio(wav)
    if len(sig.shape) > 1:
        sig = torch.mean(sig, dim=1)  # downmix multichannel audio to mono
    sig = torch.unsqueeze(sig, 0)  # add a batch dimension
    tensor_wav = sig.to(device)
    resampled = torchaudio.functional.resample(tensor_wav, sr, 16000)  # the model expects 16 kHz
    # Transcribe the audio
    sentence = asr.treat_wav(resampled)
    return sentence
# Quick local test (uncomment to run; assumes an audio file at ./audio.wav):
# print(treat_wav_file("./audio.wav", None))
# Gradio interface
import gradio as gr
title = "Tunisian Speech Recognition"
description = '''
This is a Tunisian ASR system based on the **WavLM** model, fine-tuned on **2.5 hours** of speech, reaching a **WER of 24%** and a **CER of 9%**.

Interested? Try it out!
'''
disclaimer = '''
> ⚠️ **Disclaimer:**
> This is a **demo model**; transcription accuracy is limited due to Hugging Face model storage constraints.
> For better performance, you can run the full model locally.
> Please check out the repository and follow the instructions: [Full Model Repo Link](https://huggingface.co/brdhaker3/TunASR)
'''
with gr.Blocks() as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)
    gr.Markdown(disclaimer)
    interface = gr.Interface(
        fn=treat_wav_file,
        inputs=[
            gr.Audio(sources="microphone", type="filepath", label="Record"),
            gr.Audio(sources="upload", type="filepath", label="Upload File"),
        ],
        outputs="text",
        title="",
        description="",
    )
demo.launch()