# NOTE(review): removed scraped file-viewer residue (file-size banner, git
# commit hashes, and a line-number gutter) that was not part of the source
# and made the file unparseable as Python.
import nltk
import librosa
import torch
import gradio as gr
from pyctcdecode import build_ctcdecoder
from transformers import AutoProcessor, AutoModelForCTC

# Fetch the Punkt sentence-tokenizer data used by fix_transcription_casing.
nltk.download("punkt")

# Load the ASR model and its processor once at startup; every request reuses
# these module-level objects.
model_name = "facebook/wav2vec2-base-960h"
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForCTC.from_pretrained(model_name)

def load_and_fix_data(input_file):
  """Load an audio file as a mono 1-D waveform resampled to 16 kHz.

  Parameters
  ----------
  input_file : str
      Path to the audio file to read.

  Returns
  -------
  numpy.ndarray
      1-D float waveform at 16 kHz, ready for the Wav2Vec2 processor.
  """
  # librosa.load defaults to mono=True / sr=22050, so speech is normally
  # already 1-D here; the guard below covers multi-channel arrays.
  speech, sample_rate = librosa.load(input_file)
  # Down-mix to mono. librosa lays out multi-channel audio as
  # (channels, samples), so sum over the channel axis (axis 0); the previous
  # column indexing assumed the transposed layout.
  if speech.ndim > 1:
      speech = speech.sum(axis=0)
  # Resample to the 16 kHz rate wav2vec2 models expect. librosa >= 0.10
  # removed the positional sr arguments, so pass them by keyword.
  if sample_rate != 16000:
    speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16000)
  return speech

def fix_transcription_casing(input_sentence):
  """Capitalize the first character of each sentence in *input_sentence*.

  The text is split with nltk's Punkt sentence tokenizer and the sentences
  are rejoined with single spaces.
  """
  sentences = nltk.sent_tokenize(input_sentence)
  # s[:1] is safe on an empty string, unlike the previous s[0]-based
  # replace, which raised IndexError; for non-empty sentences the result
  # is identical (first character upper-cased).
  return ' '.join(s[:1].upper() + s[1:] for s in sentences)
  
def predict_and_ctc_decode(input_file):
  """Transcribe *input_file* using CTC beam-search decoding (pyctcdecode).

  Returns the transcription with sentence-initial capitalization applied.
  """
  speech = load_and_fix_data(input_file)

  input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
  # Inference only: disable autograd so no gradient graph is built.
  with torch.no_grad():
    logits = model(input_values).logits.cpu().numpy()[0]

  # Build a beam-search decoder over the model's character vocabulary.
  # NOTE(review): rebuilt on every request; could be cached per model.
  vocab_list = list(processor.tokenizer.get_vocab().keys())
  decoder = build_ctcdecoder(vocab_list)
  pred = decoder.decode(logits)

  transcribed_text = fix_transcription_casing(pred.lower())

  return transcribed_text

def predict_and_greedy_decode(input_file):
  """Transcribe *input_file* using simple greedy (argmax) CTC decoding.

  Returns the transcription with sentence-initial capitalization applied.
  """
  speech = load_and_fix_data(input_file)

  input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
  # Inference only: disable autograd so no gradient graph is built.
  with torch.no_grad():
    logits = model(input_values).logits

  # Greedy decode: take the most likely token at each frame, then let the
  # processor collapse repeats / remove CTC blanks.
  predicted_ids = torch.argmax(logits, dim=-1)
  pred = processor.batch_decode(predicted_ids)

  transcribed_text = fix_transcription_casing(pred[0].lower())

  return transcribed_text

def return_all_predictions(input_file, model_name):
  """Run both decoding strategies on *input_file* and return both texts.

  NOTE(review): *model_name* comes from the UI dropdown but is only
  printed — the globally loaded model is always used. TODO confirm intent.
  """
  print(model_name)
  beam_text = predict_and_ctc_decode(input_file)
  greedy_text = predict_and_greedy_decode(input_file)
  return beam_text, greedy_text


# Build and launch the demo UI: audio + model-name inputs, two text outputs
# (beam-search decode and greedy decode).
gr.Interface(return_all_predictions,
             inputs = [gr.inputs.Audio(source="microphone", type="filepath", label="Record/ Drop audio"),
                       gr.inputs.Dropdown(["facebook/wav2vec2-base-960h", "facebook/hubert-large-ls960-ft"], label="Model Name")],
             outputs = [gr.outputs.Textbox(label="Beam CTC Decoding"), gr.outputs.Textbox(label="Greedy Decoding")],
             title="ASR using Wav2Vec 2.0 & pyctcdecode",
             description = "Extending HF ASR models with pyctcdecode decoder",
             layout = "horizontal",
             # Each example row must supply one value per input component:
             # [audio_file, model_name]. The original grouped the two audio
             # files in one row and the two model names in another.
             examples = [["test1.wav", "facebook/wav2vec2-base-960h"],
                         ["test2.wav", "facebook/hubert-large-ls960-ft"]],
             theme="huggingface").launch()