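# Gradio demo: transcribe speech with Wav2Vec 2.0 and compare pyctcdecode
# beam-search decoding against plain greedy (argmax) decoding.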
import nltk
import librosa
import torch
import gradio as gr
from pyctcdecode import build_ctcdecoder
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
nltk.download("punkt")
# Load the pretrained Wav2Vec 2.0 model and its processor
model_name = "facebook/wav2vec2-base-960h"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)
def load_and_fix_data(input_file):
    # Read the audio file
    speech, sample_rate = librosa.load(input_file)
    # Collapse stereo to a single channel
    if len(speech.shape) > 1:
        speech = speech[:, 0] + speech[:, 1]
    # Resample to the 16 kHz rate expected by Wav2Vec 2.0
    if sample_rate != 16000:
        speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16000)
    return speech
def fix_transcription_casing(input_sentence):
    # Capitalize the first character of every sentence in the transcription
    sentences = nltk.sent_tokenize(input_sentence)
    return " ".join([s.replace(s[0], s[0].capitalize(), 1) for s in sentences])
def predict_and_ctc_decode(input_file):
    speech = load_and_fix_data(input_file)
    input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
    # Frame-level logits as a numpy array for pyctcdecode
    logits = model(input_values).logits.cpu().detach().numpy()[0]
    vocab_list = list(processor.tokenizer.get_vocab().keys())
    # Build a beam-search CTC decoder over the model vocabulary
    decoder = build_ctcdecoder(vocab_list)
    pred = decoder.decode(logits)
    transcribed_text = fix_transcription_casing(pred.lower())
    return transcribed_text
def predict_and_greedy_decode(input_file):
    speech = load_and_fix_data(input_file)
    input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
    logits = model(input_values).logits
    # Greedy decoding: take the highest-scoring token at every frame
    predicted_ids = torch.argmax(logits, dim=-1)
    pred = processor.batch_decode(predicted_ids)
    transcribed_text = fix_transcription_casing(pred[0].lower())
    return transcribed_text
def return_all_predictions(input_file):
    return predict_and_ctc_decode(input_file), predict_and_greedy_decode(input_file)
gr.Interface(
    return_all_predictions,
    inputs=gr.inputs.Audio(source="microphone", type="filepath", optional=True, label="Record/ Drop audio"),
    outputs=[gr.outputs.Textbox(label="Beam CTC Decoding"), gr.outputs.Textbox(label="Greedy Decoding")],
    title="ASR using Wav2Vec 2.0 & pyctcdecode",
    description="Extending HF ASR models with pyctcdecode decoder",
    layout="horizontal",
    examples=[["test.wav"]],
    theme="huggingface",
).launch()