Update asr_lm_eng.py
Browse files- asr_lm_eng.py +48 -63
asr_lm_eng.py
CHANGED
|
@@ -21,54 +21,56 @@ processor = AutoProcessor.from_pretrained(MODEL_ID)
|
|
| 21 |
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
|
| 22 |
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
#
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
|
| 69 |
|
| 70 |
def transcribe(audio_data=None, lang="eng (English)"):
|
| 71 |
|
|
|
|
|
|
|
| 72 |
if not audio_data:
|
| 73 |
return "<<ERROR: Empty Audio Input>>"
|
| 74 |
|
|
@@ -113,24 +115,7 @@ def transcribe(audio_data=None, lang="eng (English)"):
|
|
| 113 |
with torch.no_grad():
|
| 114 |
outputs = model(**inputs).logits
|
| 115 |
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
transcription = processor.decode(ids)
|
| 119 |
-
else:
|
| 120 |
-
assert False
|
| 121 |
-
# beam_search_result = beam_search_decoder(outputs.to("cpu"))
|
| 122 |
-
# transcription = " ".join(beam_search_result[0][0].words).strip()
|
| 123 |
|
| 124 |
return transcription
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
ASR_EXAMPLES = [
|
| 128 |
-
["upload/english.mp3", "eng (English)"],
|
| 129 |
-
# ["upload/tamil.mp3", "tam (Tamil)"],
|
| 130 |
-
# ["upload/burmese.mp3", "mya (Burmese)"],
|
| 131 |
-
]
|
| 132 |
-
|
| 133 |
-
ASR_NOTE = """
|
| 134 |
-
The above demo doesn't use beam-search decoding using a language model.
|
| 135 |
-
Checkout the instructions [here](https://huggingface.co/facebook/mms-1b-all) on how to run LM decoding for better accuracy.
|
| 136 |
-
"""
|
|
|
|
| 21 |
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
|
| 22 |
|
| 23 |
|
| 24 |
+
lm_decoding_config = {}
|
| 25 |
+
lm_decoding_configfile = hf_hub_download(
|
| 26 |
+
repo_id="facebook/mms-cclms",
|
| 27 |
+
filename="decoding_config.json",
|
| 28 |
+
subfolder="mms-1b-all",
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
with open(lm_decoding_configfile) as f:
|
| 32 |
+
lm_decoding_config = json.loads(f.read())
|
| 33 |
+
|
| 34 |
+
# allow language model decoding for "eng"
|
| 35 |
+
|
| 36 |
+
decoding_config = lm_decoding_config["eng"]
|
| 37 |
+
|
| 38 |
+
lm_file = hf_hub_download(
|
| 39 |
+
repo_id="facebook/mms-cclms",
|
| 40 |
+
filename=decoding_config["lmfile"].rsplit("/", 1)[1],
|
| 41 |
+
subfolder=decoding_config["lmfile"].rsplit("/", 1)[0],
|
| 42 |
+
)
|
| 43 |
+
token_file = hf_hub_download(
|
| 44 |
+
repo_id="facebook/mms-cclms",
|
| 45 |
+
filename=decoding_config["tokensfile"].rsplit("/", 1)[1],
|
| 46 |
+
subfolder=decoding_config["tokensfile"].rsplit("/", 1)[0],
|
| 47 |
+
)
|
| 48 |
+
lexicon_file = None
|
| 49 |
+
if decoding_config["lexiconfile"] is not None:
|
| 50 |
+
lexicon_file = hf_hub_download(
|
| 51 |
+
repo_id="facebook/mms-cclms",
|
| 52 |
+
filename=decoding_config["lexiconfile"].rsplit("/", 1)[1],
|
| 53 |
+
subfolder=decoding_config["lexiconfile"].rsplit("/", 1)[0],
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
beam_search_decoder = ctc_decoder(
|
| 57 |
+
lexicon=lexicon_file,
|
| 58 |
+
tokens=token_file,
|
| 59 |
+
lm=lm_file,
|
| 60 |
+
nbest=1,
|
| 61 |
+
beam_size=500,
|
| 62 |
+
beam_size_token=50,
|
| 63 |
+
lm_weight=float(decoding_config["lmweight"]),
|
| 64 |
+
word_score=float(decoding_config["wordscore"]),
|
| 65 |
+
sil_score=float(decoding_config["silweight"]),
|
| 66 |
+
blank_token="<s>",
|
| 67 |
+
)
|
| 68 |
|
| 69 |
|
| 70 |
def transcribe(audio_data=None, lang="eng (English)"):
|
| 71 |
|
| 72 |
+
assert lang.startswith("eng")
|
| 73 |
+
|
| 74 |
if not audio_data:
|
| 75 |
return "<<ERROR: Empty Audio Input>>"
|
| 76 |
|
|
|
|
| 115 |
with torch.no_grad():
|
| 116 |
outputs = model(**inputs).logits
|
| 117 |
|
| 118 |
+
beam_search_result = beam_search_decoder(outputs.to("cpu"))
|
| 119 |
+
transcription = " ".join(beam_search_result[0][0].words).strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
return transcription
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|