Update WER score + Publish Benchmark Results
README.md CHANGED
@@ -23,7 +23,7 @@ model-index:
     metrics:
     - name: Test WER
       type: wer
-      value: 23.
+      value: 23.80
 ---
 
 # Sinai Voice Arabic Speech Recognition Model
@@ -31,12 +31,136 @@ model-index:
 Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53)
 on Arabic using the [Common Voice](https://huggingface.co/datasets/common_voice)
 
-
-## Usage
+Most of the evaluation code in this documentation is inspired by [elgeish/wav2vec2-large-xlsr-53-arabic](https://huggingface.co/elgeish/wav2vec2-large-xlsr-53-arabic)
 
 Please install:
 - [PyTorch](https://pytorch.org/)
-- `$ pip3 install jiwer lang_trans torchaudio datasets transformers`
+- `$ pip3 install jiwer lang_trans torchaudio datasets transformers pandas tqdm`
+
+## Benchmark
+
+We evaluated the model against other Arabic-STT Wav2Vec models.
+
+|    | model                                 | using_transliteration |      WER |
+|---:|:--------------------------------------|:----------------------|---------:|
+|  0 | bakrianoo/sinai-voice-ar-stt          | True                  | 0.238001 |
+|  1 | elgeish/wav2vec2-large-xlsr-53-arabic | True                  | 0.266527 |
+|  2 | othrif/wav2vec2-large-xlsr-arabic     | True                  | 0.298122 |
+|  3 | bakrianoo/sinai-voice-ar-stt          | False                 | 0.448987 |
+|  4 | othrif/wav2vec2-large-xlsr-arabic     | False                 | 0.464004 |
+|  5 | anas/wav2vec2-large-xlsr-arabic       | True                  | 0.506191 |
+|  6 | anas/wav2vec2-large-xlsr-arabic       | False                 | 0.622288 |
+
+
+<details>
+<summary>We used the following <b>CODE</b> to generate the above results</summary>
+
+```python
+import jiwer
+import torch
+from tqdm.auto import tqdm
+import torchaudio
+from datasets import load_dataset
+from lang_trans.arabic import buckwalter
+from transformers import set_seed, Wav2Vec2ForCTC, Wav2Vec2Processor
+import pandas as pd
+
+# load test dataset
+set_seed(42)
+test_split = load_dataset("common_voice", "ar", split="test")
+
+# init sample rate resamplers
+resamplers = {  # all three sampling rates exist in test split
+    48000: torchaudio.transforms.Resample(48000, 16000),
+    44100: torchaudio.transforms.Resample(44100, 16000),
+    32000: torchaudio.transforms.Resample(32000, 16000),
+}
+
+# WER composer
+transformation = jiwer.Compose([
+    # normalize some diacritics, remove punctuation, and replace Persian letters with Arabic ones
+    jiwer.SubstituteRegexes({
+        r'[auiFNKo\~_،؟»\?;:\-,\.؛«!"]': "", "\u06D6": "",
+        r"[\|\{]": "A", "p": "h", "ک": "k", "ی": "y"}),
+    # default transformation below
+    jiwer.RemoveMultipleSpaces(),
+    jiwer.Strip(),
+    jiwer.SentencesToListOfWords(),
+    jiwer.RemoveEmptyStrings(),
+])
+
+def prepare_example(example):
+    speech, sampling_rate = torchaudio.load(example["path"])
+    if sampling_rate in resamplers:
+        example["speech"] = resamplers[sampling_rate](speech).squeeze().numpy()
+    else:
+        example["speech"] = resamplers[48000](speech).squeeze().numpy()  # fall back to the 48 kHz resampler
+    return example
+
+def predict(batch):
+    inputs = processor(batch["speech"], sampling_rate=16000, return_tensors="pt", padding=True)
+    with torch.no_grad():
+        predicted = torch.argmax(model(inputs.input_values.to("cuda")).logits, dim=-1)
+    predicted[predicted == -100] = processor.tokenizer.pad_token_id  # see fine-tuning script
+    batch["predicted"] = processor.batch_decode(predicted)
+    return batch
+
+# prepare the test dataset
+test_split = test_split.map(prepare_example)
+
+stt_models = {
+    "elgeish/wav2vec2-large-xlsr-53-arabic",
+    "othrif/wav2vec2-large-xlsr-arabic",
+    "anas/wav2vec2-large-xlsr-arabic",
+    "bakrianoo/sinai-voice-ar-stt"
+}
+
+stt_results = []
+
+for model_path in tqdm(stt_models):
+    processor = Wav2Vec2Processor.from_pretrained(model_path)
+    model = Wav2Vec2ForCTC.from_pretrained(model_path).to("cuda").eval()
+
+    test_split_preds = test_split.map(predict, batched=True, batch_size=56, remove_columns=["speech"])
+
+    orig_metrics = jiwer.compute_measures(
+        truth=[s for s in test_split_preds["sentence"]],
+        hypothesis=[s for s in test_split_preds["predicted"]],
+        truth_transform=transformation,
+        hypothesis_transform=transformation,
+    )
+
+    trans_metrics = jiwer.compute_measures(
+        truth=[buckwalter.trans(s) for s in test_split_preds["sentence"]],  # Buckwalter transliteration
+        hypothesis=[buckwalter.trans(s) for s in test_split_preds["predicted"]],  # Buckwalter transliteration
+        truth_transform=transformation,
+        hypothesis_transform=transformation,
+    )
+
+    stt_results.append({
+        "model": model_path,
+        "using_transliteration": True,
+        "WER": trans_metrics["wer"]
+    })
+
+    stt_results.append({
+        "model": model_path,
+        "using_transliteration": False,
+        "WER": orig_metrics["wer"]
+    })
+
+    del model
+    del processor
+
+stt_results_df = pd.DataFrame(stt_results)
+stt_results_df = stt_results_df.sort_values('WER', axis=0, ascending=True)
+stt_results_df.head(n=50)
+
+```
+</details>
+
+
+## Usage
 
 The model can be used directly (without a language model) as follows:
 ```python
@@ -51,10 +175,15 @@ resamplers = { # all three sampling rates exist in test split
     44100: torchaudio.transforms.Resample(44100, 16000),
     32000: torchaudio.transforms.Resample(32000, 16000),
 }
+
 def prepare_example(example):
     speech, sampling_rate = torchaudio.load(example["path"])
-
+    if sampling_rate in resamplers:
+        example["speech"] = resamplers[sampling_rate](speech).squeeze().numpy()
+    else:
+        example["speech"] = resamplers[48000](speech).squeeze().numpy()  # fall back to the 48 kHz resampler
     return example
+
 dataset = dataset.map(prepare_example)
 processor = Wav2Vec2Processor.from_pretrained("bakrianoo/sinai-voice-ar-stt")
 model = Wav2Vec2ForCTC.from_pretrained("bakrianoo/sinai-voice-ar-stt").eval()
@@ -103,9 +232,8 @@ predicted: أين المشكل
 reference: وَلِلَّهِ يَسْجُدُ مَا فِي السَّمَاوَاتِ وَمَا فِي الْأَرْضِ مِنْ دَابَّةٍ وَالْمَلَائِكَةُ وَهُمْ لَا يَسْتَكْبِرُونَ
 predicted: ولله يسجد ما في السماوات وما في الأرض من دابة والملائكة وهم لا يستكبرون
 ```
-## Evaluation
 
-
+## Evaluation
 
 The model can be evaluated as follows on the Arabic test data of Common Voice:
 ```python
@@ -122,10 +250,15 @@ resamplers = { # all three sampling rates exist in test split
     44100: torchaudio.transforms.Resample(44100, 16000),
     32000: torchaudio.transforms.Resample(32000, 16000),
 }
+
 def prepare_example(example):
     speech, sampling_rate = torchaudio.load(example["path"])
-
+    if sampling_rate in resamplers:
+        example["speech"] = resamplers[sampling_rate](speech).squeeze().numpy()
+    else:
+        example["speech"] = resamplers[48000](speech).squeeze().numpy()  # fall back to the 48 kHz resampler
     return example
+
 test_split = test_split.map(prepare_example)
 processor = Wav2Vec2Processor.from_pretrained("bakrianoo/sinai-voice-ar-stt")
 model = Wav2Vec2ForCTC.from_pretrained("bakrianoo/sinai-voice-ar-stt").to("cuda").eval()
@@ -141,8 +274,8 @@ test_split = test_split.map(predict, batched=True, batch_size=16, remove_columns
 transformation = jiwer.Compose([
     # normalize some diacritics, remove punctuation, and replace Persian letters with Arabic ones
     jiwer.SubstituteRegexes({
-        r'[auiFNKo…
-        r"[…
+        r'[auiFNKo\~_،؟»\?;:\-,\.؛«!"]': "", "\u06D6": "",
+        r"[\|\{]": "A", "p": "h", "ک": "k", "ی": "y"}),
     # default transformation below
     jiwer.RemoveMultipleSpaces(),
     jiwer.Strip(),
@@ -158,7 +291,7 @@ metrics = jiwer.compute_measures(
 )
 print(f"WER: {metrics['wer']:.2%}")
 ```
-**Test Result**: 23.
+**Test Result**: 23.80%
 
 
 ## Other Arabic Voice recognition Models
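A note on the `using_transliteration` column in the benchmark added above: the `trans_metrics` rows compare references and hypotheses after Buckwalter transliteration, and the first `SubstituteRegexes` rule then strips the ASCII letters that Buckwalter uses for diacritics, so fully vocalized references (like the Quranic reference/prediction pair shown in the README) no longer count as errors against unvocalized predictions. Below is a minimal sketch of that effect; it only assumes `lang_trans` (already in the install line) and Python's `re`, and the sample word is illustrative rather than taken from the dataset.

```python
import re

from lang_trans.arabic import buckwalter


def strip_buckwalter_diacritics(text: str) -> str:
    """Mirror the first SubstituteRegexes rule: drop Buckwalter diacritic letters and punctuation."""
    return re.sub(r'[auiFNKo\~_،؟»\?;:\-,\.؛«!"]', "", text)


with_diacritics = "وَلَد"    # "boy", written with short-vowel diacritics
without_diacritics = "ولد"   # the same word as it is usually typed

# Compared as raw Arabic script, the two strings differ, so word-level WER counts a substitution.
print(with_diacritics == without_diacritics)  # False

# After Buckwalter transliteration the diacritics become ASCII letters ("walad"),
# and stripping them leaves the same consonant skeleton for both spellings.
print(strip_buckwalter_diacritics(buckwalter.trans(with_diacritics)))     # wld
print(strip_buckwalter_diacritics(buckwalter.trans(without_diacritics)))  # wld
```

This is why every model in the table scores noticeably better with `using_transliteration = True`: the transliterated comparison ignores diacritics and the punctuation covered by the regex, while the `False` rows measure WER on the raw transcripts.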