Spaces:
Runtime error
Runtime error
Vaibhav Srivastav
committed on
Commit
·
379fa33
1
Parent(s):
3b8d409
adding greedy decoding
Browse files
app.py
CHANGED
@@ -22,8 +22,7 @@ def load_and_fix_data(input_file):
|
|
22 |
if sample_rate !=16000:
|
23 |
speech = librosa.resample(speech, sample_rate,16000)
|
24 |
return speech
|
25 |
-
|
26 |
-
|
27 |
def fix_transcription_casing(input_sentence):
    """Upper-case the first character of every sentence in the text.

    The text is split into sentences with ``nltk.sent_tokenize`` and the
    adjusted sentences are joined back together with single spaces.
    """
    capitalised = []
    for sentence in nltk.sent_tokenize(input_sentence):
        first, rest = sentence[0], sentence[1:]
        capitalised.append(first.capitalize() + rest)
    return ' '.join(capitalised)
|
@@ -41,10 +40,27 @@ def predict_and_decode(input_file):
|
|
41 |
transcribed_text = fix_transcription_casing(pred.lower())
|
42 |
|
43 |
return transcribed_text
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
gr.Interface(predict_and_decode,
|
46 |
inputs = gr.inputs.Audio(source="microphone", type="filepath", optional=True, label="Record/ Drop audio"),
|
47 |
-
outputs = gr.outputs.Textbox(label="
|
48 |
title="ASR using Wav2Vec 2.0 & pyctcdecode",
|
49 |
description = "Extending HF ASR models with pyctcdecode decoder",
|
50 |
layout = "horizontal",
|
|
|
22 |
if sample_rate !=16000:
|
23 |
speech = librosa.resample(speech, sample_rate,16000)
|
24 |
return speech
|
25 |
+
|
|
|
26 |
def fix_transcription_casing(input_sentence):
    """Upper-case the first character of every sentence in the text.

    The text is split into sentences with ``nltk.sent_tokenize`` and the
    adjusted sentences are joined back together with single spaces.
    """
    capitalised = []
    for sentence in nltk.sent_tokenize(input_sentence):
        first, rest = sentence[0], sentence[1:]
        capitalised.append(first.capitalize() + rest)
    return ' '.join(capitalised)
|
|
|
40 |
transcribed_text = fix_transcription_casing(pred.lower())
|
41 |
|
42 |
return transcribed_text
|
43 |
+
|
44 |
+
def predict_and_greedy_decode(input_file):
    """Transcribe an audio file using greedy (argmax) decoding.

    The audio is loaded with ``load_and_fix_data``, fed through the
    ``processor`` and ``model`` objects defined elsewhere in this file, and
    the highest-scoring token id at every position is decoded back to text.

    Args:
        input_file: Audio input, passed straight to ``load_and_fix_data``.

    Returns:
        The transcription, lower-cased and then re-cased by
        ``fix_transcription_casing``.
    """
    speech = load_and_fix_data(input_file)

    input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
    # Inference only: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        logits = model(input_values).logits

    predicted_ids = torch.argmax(logits, dim=-1)
    # batch_decode returns one string per batch item (a list), so pick the
    # single clip's transcription before calling str methods on it.
    pred = processor.batch_decode(predicted_ids)[0]

    transcribed_text = fix_transcription_casing(pred.lower())

    return transcribed_text
|
56 |
+
|
57 |
+
def return_all_predictions(input_file):
    """Run both decoders on the same input and return their outputs.

    Returns:
        A 2-tuple: the result of ``predict_and_decode`` followed by the
        result of ``predict_and_greedy_decode``.
    """
    # Keep the original call order: predict_and_decode first, greedy second.
    decoded_text = predict_and_decode(input_file)
    greedy_text = predict_and_greedy_decode(input_file)
    return decoded_text, greedy_text
|
59 |
+
|
60 |
+
|
61 |
gr.Interface(predict_and_decode,
|
62 |
inputs = gr.inputs.Audio(source="microphone", type="filepath", optional=True, label="Record/ Drop audio"),
|
63 |
+
outputs = [gr.outputs.Textbox(label="Beam CTC Decoding"), gr.outputs.Textbox(label="Greedy Decoding")],
|
64 |
title="ASR using Wav2Vec 2.0 & pyctcdecode",
|
65 |
description = "Extending HF ASR models with pyctcdecode decoder",
|
66 |
layout = "horizontal",
|