import nltk
import librosa
import torch
import gradio as gr
from pyctcdecode import build_ctcdecoder
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

nltk.download("punkt")

# Load the model and the processor (feature extractor + tokenizer)
model_name = "facebook/wav2vec2-base-960h"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)
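# The processor turns raw 16 kHz audio into model inputs; the model emits one
# logit vector per ~20 ms frame over the tokenizer's character vocabulary.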

def load_and_fix_data(input_file):
  # Read the file at its native sample rate, keeping all channels
  # (librosa's defaults would silently resample to 22,050 Hz mono)
  speech, sample_rate = librosa.load(input_file, sr=None, mono=False)
  # Downmix multi-channel audio to mono; librosa stacks channels on axis 0
  if speech.ndim > 1:
    speech = speech.mean(axis=0)
  # Resample to the 16 kHz rate Wav2Vec2 expects
  if sample_rate != 16000:
    speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16000)
  return speech

def fix_transcription_casing(input_sentence):
  # Split into sentences and capitalize the first letter of each one
  sentences = nltk.sent_tokenize(input_sentence)
  return ' '.join(s[:1].upper() + s[1:] for s in sentences)
  
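# Beam-search CTC decoding with pyctcdecode: instead of committing to the
# single best token at each frame, the decoder keeps several candidate label
# sequences in play. build_ctcdecoder can also take a kenlm_model_path
# argument to fuse an n-gram language model into the search (not used here).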
def predict_and_ctc_decode(input_file):
  speech = load_and_fix_data(input_file)

  input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
  with torch.no_grad():
    logits = model(input_values).logits.cpu().numpy()[0]

  # Order the labels by token id so column i of the logits maps to vocab_list[i]
  vocab = processor.tokenizer.get_vocab()
  vocab_list = [token for token, _ in sorted(vocab.items(), key=lambda item: item[1])]
  decoder = build_ctcdecoder(vocab_list)
  pred = decoder.decode(logits)

  transcribed_text = fix_transcription_casing(pred.lower())

  return transcribed_text

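# Greedy decoding baseline: take the argmax token at every frame, then let the
# processor collapse repeated tokens and strip CTC blanks.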
def predict_and_greedy_decode(input_file):
  speech = load_and_fix_data(input_file)

  input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
  with torch.no_grad():
    logits = model(input_values).logits

  predicted_ids = torch.argmax(logits, dim=-1)
  pred = processor.batch_decode(predicted_ids)

  transcribed_text = fix_transcription_casing(pred[0].lower())

  return transcribed_text

def return_all_predictions(input_file):
  return predict_and_ctc_decode(input_file), predict_and_greedy_decode(input_file)


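# Wire both decoders into one Gradio demo so the two transcripts can be
# compared side by side.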
gr.Interface(return_all_predictions,
             inputs=gr.Audio(sources=["microphone", "upload"], type="filepath", label="Record / drop audio"),
             outputs=[gr.Textbox(label="Beam CTC Decoding"), gr.Textbox(label="Greedy Decoding")],
             title="ASR using Wav2Vec 2.0 & pyctcdecode",
             description="Extending Hugging Face ASR models with the pyctcdecode decoder",
             examples=[["test.wav"]]).launch()
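# To try it locally (assuming this file is saved as app.py and a test.wav
# example sits next to it): run `python app.py` and open the printed URL.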