Model with more data
- eval.py +20 -5
- language_model/attrs.json +1 -1
- train.ipynb +0 -0
- vocab.json +1 -1
eval.py
CHANGED

@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 from datasets import load_dataset, load_metric, Audio, Dataset
-from transformers import pipeline, AutoFeatureExtractor
+from transformers import pipeline, AutoFeatureExtractor, AutoTokenizer, Wav2Vec2ForCTC
+import os
 import re
 import argparse
 import unicodedata
@@ -106,18 +107,29 @@ def main(args):
     dataset = load_dataset(args.dataset, args.config, split=args.split, use_auth_token=True)
 
     # for testing: only process the first two examples as a test
-
+    if args.limit:
+        dataset = dataset.select(range(args.limit))
 
-    # load processor
     feature_extractor = AutoFeatureExtractor.from_pretrained(args.model_id)
+    # load processor
     sampling_rate = feature_extractor.sampling_rate
 
     # resample audio
     dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
+
+    asr = None
+
+    if os.path.exists(args.model_id):
+        model = Wav2Vec2ForCTC.from_pretrained(args.model_id)
+        tokenizer = AutoTokenizer.from_pretrained(args.model_id)
+
 
-    # load eval pipeline
-    asr = pipeline("automatic-speech-recognition", model=args.model_id)
+        # load eval pipeline
+        asr = pipeline("automatic-speech-recognition", model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
 
+    else:
+        asr = pipeline("automatic-speech-recognition", model=args.model_id)
+
     # map function to decode audio
     def map_to_pred(batch):
         prediction = asr(batch["audio"]["array"], chunk_length_s=args.chunk_length_s, stride_length_s=args.stride_length_s)
@@ -158,6 +170,9 @@ if __name__ == "__main__":
     parser.add_argument(
         "--log_outputs", action='store_true', help="If defined, write outputs to log file for analysis."
     )
+    parser.add_argument(
+        "--limit", type=int, help="If greater than zero, evaluate only the first N examples of the dataset.", default=0
+    )
     args = parser.parse_args()
 
     main(args)
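The core of the eval.py change is the branch above: when args.model_id is a local checkpoint directory, the model, tokenizer, and feature extractor are loaded explicitly and handed to the pipeline; otherwise the id is resolved on the Hugging Face Hub. A minimal standalone sketch of the same pattern (the model id and audio path are hypothetical placeholders, not values from this commit):

#!/usr/bin/env python3
# Sketch of eval.py's pipeline construction, using assumed placeholder names.
import os
from transformers import pipeline, AutoFeatureExtractor, AutoTokenizer, Wav2Vec2ForCTC

model_id = "some-org/wav2vec2-czech"  # hypothetical; a local path or a Hub id

if os.path.exists(model_id):
    # Local checkpoint: assemble the pipeline from its parts so a raw
    # training checkpoint works even without a saved pipeline config.
    model = Wav2Vec2ForCTC.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
    asr = pipeline("automatic-speech-recognition", model=model,
                   tokenizer=tokenizer, feature_extractor=feature_extractor)
else:
    # Hub id: pipeline() downloads and wires up all components itself.
    asr = pipeline("automatic-speech-recognition", model=model_id)

# Chunked inference as in map_to_pred: long audio is cut into windows of
# chunk_length_s seconds with stride_length_s seconds of overlap; the
# overlap gives the model context but is dropped from the decoded text.
print(asr("sample.wav", chunk_length_s=5.0, stride_length_s=(1.0, 1.0))["text"])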
language_model/attrs.json
CHANGED

@@ -1 +1 @@
-{"alpha": 0.
+{"alpha": 0.9, "beta": 2.5, "unk_score_offset": -10.0, "score_boundary": true}
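These keys mirror the language-model parameters of pyctcdecode, which transformers' Wav2Vec2ProcessorWithLM serializes into language_model/attrs.json. For reference, a sketch of building an equivalent decoder by hand; the label list and the lm.arpa path are placeholders, not files from this repo:

from pyctcdecode import build_ctcdecoder

labels = ["a", "b", "c"]  # the CTC vocabulary, in id order
decoder = build_ctcdecoder(
    labels,
    kenlm_model_path="lm.arpa",  # hypothetical KenLM n-gram model
    alpha=0.9,                   # weight of the LM score vs. the acoustic model
    beta=2.5,                    # word-insertion (length) bonus
    unk_score_offset=-10.0,      # log-score penalty for out-of-vocabulary words
    lm_score_boundary=True,      # stored as "score_boundary" in attrs.json
)

Relative to pyctcdecode's defaults (alpha=0.5, beta=1.5), this commit weights the n-gram LM and word insertions considerably more heavily.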
train.ipynb
ADDED

The diff for this file is too large to render.
vocab.json
CHANGED

@@ -1 +1 @@
-{"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6, "g": 7, "h": 8, "i": 9, "j": 10, "k": 11, "l": 12, "m": 13, "n": 14, "o": 15, "p": 16, "q": 17, "r": 18, "s": 19, "t": 20, "u": 21, "v": 22, "w": 23, "x": 24, "y": 25, "z": 26, "
+{"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6, "g": 7, "h": 8, "i": 9, "j": 10, "k": 11, "l": 12, "m": 13, "n": 14, "o": 15, "p": 16, "q": 17, "r": 18, "s": 19, "t": 20, "u": 21, "v": 22, "w": 23, "x": 24, "y": 25, "z": 26, "\u00e1": 27, "\u00e9": 28, "\u00ed": 29, "\u00f3": 30, "\u00fa": 31, "\u00fd": 32, "\u010d": 33, "\u010f": 34, "\u011b": 35, "\u0148": 36, "\u0159": 37, "\u0161": 38, "\u0165": 39, "\u016f": 40, "\u017e": 41, "|": 0, "[UNK]": 42, "[PAD]": 43}
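The added entries (ids 27 through 41) decode to the accented letters of the Czech alphabet (á é í ó ú ý č ď ě ň ř š ť ů ž); "|" at id 0 is the word delimiter that stands in for spaces, and [UNK]/[PAD] close out the table. A minimal sketch of loading this file as a CTC tokenizer, assuming vocab.json sits in the working directory:

from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer(
    "vocab.json",
    unk_token="[UNK]",         # id 42
    pad_token="[PAD]",         # id 43, also used as the CTC blank
    word_delimiter_token="|",  # id 0, rendered as a space when decoding
)
print(tokenizer.convert_tokens_to_ids(["\u010d", "\u0159"]))  # [33, 37] for č, ř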