|
|
|
import itertools
import sys

import torch
import torchaudio.functional as F
from datasets import load_dataset, load_metric
from transformers import AutoModelForCTC, AutoProcessor
|
|
|
def main():
    """Evaluate a CTC speech model on streamed Common Voice test samples.

    Usage::

        python <script> MODEL_ID LANG PHONEMIZER_LANG NUM_SAMPLES

    ``PHONEMIZER_LANG`` may be an empty string (``""``); in that case the
    reference sentences are tokenized without a ``phonemizer_lang`` argument.

    For each sample, prints the reference sentence and the greedy CTC
    prediction. At the end, prints WER and CER between the decoded predictions
    and the references. The references are tokenized and then decoded again,
    so both sides pass through the same tokenizer vocabulary.
    """
    if len(sys.argv) != 5:
        sys.exit(
            f"usage: {sys.argv[0]} MODEL_ID LANG PHONEMIZER_LANG NUM_SAMPLES\n"
            '  (pass "" as PHONEMIZER_LANG to tokenize targets without a phonemizer)'
        )
    model_id, lang, lang_phoneme, num_samples_arg = sys.argv[1:5]
    num_samples = int(num_samples_arg)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = AutoModelForCTC.from_pretrained(model_id).to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    # Resample to the rate the model's feature extractor expects instead of
    # assuming 16 kHz.
    target_sr = processor.feature_extractor.sampling_rate

    ds = load_dataset("common_voice", lang, split="test", streaming=True)

    # NOTE(review): newer `datasets` releases deprecate `load_metric` in favour
    # of the `evaluate` library. It is kept here to avoid adding a dependency.
    wer_metric = load_metric("wer")
    cer_metric = load_metric("cer")

    # Pass `phonemizer_lang` only when one was given. This value does not
    # change between samples, so it is built once outside the loop.
    tokenizer_kwargs = {"phonemizer_lang": lang_phoneme} if lang_phoneme else {}

    targets_ids = []
    predictions_ids = []
    # islice stops cleanly if the split has fewer than `num_samples` rows,
    # where a bare `next()` would raise StopIteration.
    for sample in itertools.islice(ds, num_samples):
        audio = sample["audio"]
        # Use the clip's declared sampling rate rather than a hard-coded 48 kHz.
        waveform = torch.tensor(audio["array"])
        resampled_audio = F.resample(
            waveform, audio["sampling_rate"], target_sr
        ).numpy()
        input_values = processor(
            resampled_audio, sampling_rate=target_sr, return_tensors="pt"
        ).input_values

        with torch.no_grad():
            logits = model(input_values.to(device)).logits

        # Greedy CTC decoding: take the most likely token at each frame.
        prediction_ids = torch.argmax(logits, dim=-1)
        # batch_decode returns a list; there is one item in this batch.
        transcription = processor.batch_decode(prediction_ids)[0]

        print(f"Correct: {sample['sentence']}")
        print(f"Predict: {transcription}")
        print(20 * "-")

        predictions_ids.append(prediction_ids[0].tolist())
        targets_ids.append(
            processor.tokenizer(sample["sentence"], **tokenizer_kwargs).input_ids
        )

    if not predictions_ids:
        sys.exit("No samples were evaluated; nothing to score.")
    if len(predictions_ids) < num_samples:
        print(
            f"Warning: only {len(predictions_ids)} of {num_samples} "
            "requested samples were available."
        )

    print("Compute metrics.....")

    transcriptions = processor.batch_decode(predictions_ids)
    # group_tokens=False: reference ids are not CTC frame outputs, so repeated
    # consecutive tokens (e.g. double letters) are real and must not be merged.
    targets_str = processor.batch_decode(targets_ids, group_tokens=False)

    wer_score = wer_metric.compute(predictions=transcriptions, references=targets_str)
    cer_score = cer_metric.compute(predictions=transcriptions, references=targets_str)

    print("wer", wer_score)
    print("cer", cer_score)


if __name__ == "__main__":
    main()
|
|