Spaces:
Runtime error
Runtime error
File size: 2,529 Bytes
57926d1 a05c869 0448aa2 e5ea233 a05c869 57926d1 a05c869 57926d1 42a9fb0 0d56eb9 57926d1 3b8d409 57926d1 3b8d409 57926d1 0448aa2 57926d1 379fa33 0448aa2 57926d1 bbbf923 3b8d409 0448aa2 d32240b 3b8d409 0448aa2 3b8d409 0448aa2 379fa33 bbbf923 379fa33 decaa84 bbbf923 379fa33 8a068ad 607a780 379fa33 0448aa2 3b8d409 0448aa2 a084c8f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import nltk
import librosa
import torch
import gradio as gr
from pyctcdecode import build_ctcdecoder
from transformers import AutoProcessor, AutoModelForCTC
# Sentence-tokenizer data needed by nltk.sent_tokenize (used in
# fix_transcription_casing). Older NLTK releases load the "punkt" package,
# newer releases load "punkt_tab" instead, so both are requested.
# NOTE(review): nltk.download is expected to return False (not raise) when a
# package id is unavailable -- confirm against the pinned NLTK version.
nltk.download("punkt")
nltk.download("punkt_tab")

# Checkpoint name plus the processor/model pair loaded from it. These
# module-level objects are what predict_and_ctc_decode and
# predict_and_greedy_decode read.
model_name = "facebook/wav2vec2-base-960h"
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForCTC.from_pretrained(model_name)
def load_and_fix_data(input_file, target_sample_rate=16000):
    """Load an audio file as a mono waveform at the requested sampling rate.

    Args:
        input_file: Path of the audio file to read.
        target_sample_rate: Sampling rate, in Hz, of the returned waveform.
            Defaults to 16000, the rate the callers in this file pass to the
            processor.

    Returns:
        The waveform returned by ``librosa.load`` with ``mono=True``, i.e. a
        single-channel sample array at ``target_sample_rate``.
    """
    # librosa.load down-mixes and resamples while decoding, so one call covers
    # the whole job. This replaces the previous path, which decoded at
    # librosa's default rate and then called librosa.resample with positional
    # rates (keyword-only in newer librosa), and whose manual stereo mix-down
    # could never run because load() already returns mono by default.
    speech, _ = librosa.load(input_file, sr=target_sample_rate, mono=True)
    return speech
def fix_transcription_casing(input_sentence):
    """Capitalize the first character of every sentence in the text.

    The text is split with ``nltk.sent_tokenize``; each sentence has its
    leading character capitalized and the sentences are joined back together
    with single spaces.
    """
    capitalized = []
    for sentence in nltk.sent_tokenize(input_sentence):
        # Only the leading character changes; the remainder is kept verbatim.
        capitalized.append(sentence[0].capitalize() + sentence[1:])
    return " ".join(capitalized)
def predict_and_ctc_decode(input_file):
    """Transcribe an audio file with the pyctcdecode CTC decoder.

    Args:
        input_file: Path of the audio file to transcribe.

    Returns:
        The decoded text, lower-cased and then passed through
        ``fix_transcription_casing``.
    """
    speech = load_and_fix_data(input_file)
    input_values = processor(
        speech, return_tensors="pt", sampling_rate=16000
    ).input_values

    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        logits = model(input_values).logits
    # A single clip is processed, so keep the first (only) batch entry.
    frame_logits = logits.cpu().numpy()[0]

    # The decoder pairs logit column i with labels[i], so the labels are listed
    # in token-id order instead of relying on the vocab dict's key order.
    vocab = processor.tokenizer.get_vocab()
    labels = sorted(vocab, key=vocab.get)
    decoder = build_ctcdecoder(labels)
    pred = decoder.decode(frame_logits)

    return fix_transcription_casing(pred.lower())
def predict_and_greedy_decode(input_file):
    """Transcribe an audio file by greedy (argmax) decoding of the logits.

    Args:
        input_file: Path of the audio file to transcribe.

    Returns:
        The first decoded string from ``processor.batch_decode``, lower-cased
        and then passed through ``fix_transcription_casing``.
    """
    speech = load_and_fix_data(input_file)
    input_values = processor(
        speech, return_tensors="pt", sampling_rate=16000
    ).input_values

    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        logits = model(input_values).logits

    # Pick the highest-scoring token id along the last (vocabulary) axis.
    predicted_ids = torch.argmax(logits, dim=-1)
    pred = processor.batch_decode(predicted_ids)

    return fix_transcription_casing(pred[0].lower())
def return_all_predictions(input_file, model_name):
    """Run both decoding strategies on the same audio file.

    Args:
        input_file: Path of the audio file to transcribe.
        model_name: Model name received from the UI.

    Returns:
        A ``(ctc_text, greedy_text)`` tuple: the result of
        ``predict_and_ctc_decode`` followed by the result of
        ``predict_and_greedy_decode``.
    """
    # NOTE(review): model_name is only printed here; it is not passed on to
    # either prediction function.
    print(model_name)
    ctc_text = predict_and_ctc_decode(input_file)
    greedy_text = predict_and_greedy_decode(input_file)
    return ctc_text, greedy_text
# Build the demo UI around return_all_predictions and start serving it.
# Inputs: an audio clip (passed to the function as a file path) and a model
# name; outputs: one textbox per decoding strategy.
gr.Interface(
    return_all_predictions,
    inputs=[
        gr.inputs.Audio(source="microphone", type="filepath", label="Record/ Drop audio"),
        gr.inputs.Dropdown(
            ["facebook/wav2vec2-base-960h", "facebook/hubert-large-ls960-ft"],
            label="Model Name",
        ),
    ],
    outputs=[
        gr.outputs.Textbox(label="Beam CTC Decoding"),
        gr.outputs.Textbox(label="Greedy Decoding"),
    ],
    title="ASR using Wav2Vec 2.0 & pyctcdecode",
    description="Extending HF ASR models with pyctcdecode decoder",
    layout="horizontal",
    # Each example row supplies one value per input component, in input order:
    # [audio file, model name].
    examples=[
        ["test1.wav", "facebook/wav2vec2-base-960h"],
        ["test2.wav", "facebook/hubert-large-ls960-ft"],
    ],
    theme="huggingface",
).launch()