import logging

import torch
import torchaudio
import gradio as gr
import speechbrain as sb
from speechbrain.inference.ASR import EncoderASR
from speechbrain.dataio.encoder import CTCTextEncoder
from pyctcdecode import build_ctcdecoder
from huggingface_hub import hf_hub_download  # Used to fetch the checkpoint from the Hub
# Set up logging
logging.basicConfig(level=logging.INFO)

# Download the fine-tuned checkpoint from the Hugging Face Hub
checkpoint_path = hf_hub_download(
    repo_id="brdhaker3/TunASR",
    filename="model/1234/save/CKPT+2024-05-27+00-52-30+00/wav2vec2.ckpt",  # Checkpoint path inside the repo
    local_dir="./",  # Save it to the current directory
)
logging.info(f"Checkpoint downloaded to: {checkpoint_path}")
# Load the ASR model
asr_model = EncoderASR.from_hparams(
    source="brdhaker3/TunASR",
    savedir="./model",
)
# Load the custom character tokenizer shipped with the model
encoder = CTCTextEncoder()
encoder.load_or_create(
    path=asr_model.hparams.encoder_file,
    from_didatasets=[[]],
    output_key="char_list",
    special_labels={"blank_label": 0, "unk_label": 1},
    sequence_input=True,
)
asr_model.tokenizer = encoder
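# Optional sanity check of the tokenizer (an illustrative sketch, not part of
# the original pipeline; encode_sequence/decode_ndim are standard
# CategoricalEncoder methods in SpeechBrain):
# ids = encoder.encode_sequence(list("سلام"))
# print(encoder.decode_ndim(ids))  # Should round-trip back to the characters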
# Prepare labels for the CTC decoder: pyctcdecode expects the blank token ("") at index 0
vocab = asr_model.tokenizer.ind2lab
labels = [vocab[i] for i in range(len(vocab))]  # Extract labels from the tokenizer
labels = [""] + labels[1:-1] + ["1"]  # Replace the first and last special tokens to match the decoder's format
# Initialize the CTC beam-search decoder with an n-gram language model
decoder = build_ctcdecoder(
    labels,
    kenlm_model_path=asr_model.hparams.ngram_lm_path,  # Path to the KenLM language model
    alpha=0.5,  # LM weight
    beta=1.0,  # Word insertion bonus
)
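# Optional smoke test of the decoder on uniform dummy log-probabilities
# (illustrative only; the (time, vocab) shape is an assumption):
# import numpy as np
# dummy_log_probs = np.full((50, len(labels)), np.log(1.0 / len(labels)), dtype=np.float32)
# print(decoder.decode(dummy_log_probs))  # Expected to yield an empty or near-empty string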
class ASR(sb.core.Brain):
    def treat_wav(self, sig):
        """Process a waveform and return the transcribed text."""
        feats = self.modules.wav2vec2(sig.to("cpu"), torch.tensor([1.0]).to("cpu"))
        feats = self.modules.enc(feats)
        logits = self.modules.ctc_lin(feats)
        p_ctc = self.hparams.log_softmax(logits)
        # Beam-search decode each utterance in the batch with the LM-backed decoder
        predicted_words = []
        for logs in p_ctc:
            text = decoder.decode(logs.detach().cpu().numpy())
            predicted_words.append(text.split(" "))
        return " ".join(predicted_words[0])
# Wrap the pretrained modules in a Brain instance for inference
asr_brain = ASR(
    modules=asr_model.hparams.modules,
    hparams=vars(asr_model.hparams),
    run_opts={"device": "cpu"},
    checkpointer=asr_model.hparams.checkpointer,
)
asr_brain.tokenizer = encoder
asr_brain.checkpointer.recover_if_possible()
asr_brain.modules.eval()
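# Example of transcribing a tensor directly, bypassing file I/O (illustrative;
# a real signal is needed for meaningful output -- one second of silence at
# 16 kHz should decode to an empty string):
# silence = torch.zeros(1, 16000)
# print(asr_brain.treat_wav(silence))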
def treat_wav_file(file_mic, file_upload, asr=asr_brain, device="cpu"):
    """Transcribe audio coming from the microphone or an uploaded file."""
    if file_mic is not None:
        wav = file_mic
    elif file_upload is not None:
        wav = file_upload
    else:
        return "ERROR: Please either record with the microphone or upload an audio file."

    # Read the audio, downmix to mono if needed, and resample to 16 kHz
    info = torchaudio.info(wav)
    sr = info.sample_rate
    sig = sb.dataio.dataio.read_audio(wav)
    if len(sig.shape) > 1:
        sig = torch.mean(sig, dim=1)  # Average the channels to get mono audio
    sig = torch.unsqueeze(sig, 0)  # Add a batch dimension
    tensor_wav = sig.to(device)
    resampled = torchaudio.functional.resample(tensor_wav, sr, 16000)

    # Transcribe the audio
    sentence = asr.treat_wav(resampled)
    return sentence
# Test the function on a local file (microphone slot filled, upload slot left empty)
# print(treat_wav_file("./audio.wav", None))
# Gradio interface
title = "Tunisian Speech Recognition"
description = '''
This is a Tunisian ASR system based on the **WavLM** model, fine-tuned on **2.5 hours** of speech, achieving a **WER of 24%** and a **CER of 9%**.
\n
Interested? Try it out!
'''
disclaimer = '''
> ⚠️ **Disclaimer:**
> This is a **demo model**; transcription accuracy is limited due to Hugging Face model storage constraints.
> For better performance, you can run the full model locally.
> Please check out the repository and follow the instructions: [Full Model Repo Link](https://huggingface.co/brdhaker3/TunASR)
'''
with gr.Blocks() as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)
    gr.Markdown(disclaimer)
    interface = gr.Interface(
        fn=treat_wav_file,
        inputs=[
            gr.Audio(sources="microphone", type="filepath", label="Record"),
            gr.Audio(sources="upload", type="filepath", label="Upload File"),
        ],
        outputs="text",
        title="",
        description="",
    )

demo.launch()
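# Note: when running outside a Hugging Face Space, demo.launch(share=True)
# can be used to get a temporary public URL (an assumption about the
# deployment context; the default local launch is kept above).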