#!/usr/bin/env python3
import sys
import torch
from transformers import AutoModelForCTC, AutoProcessor
from datasets import load_dataset, load_metric
import torchaudio.functional as F
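# Usage: python run_ctc_model.py <model_id> <lang> <lang_phoneme> <num_samples>
# <lang_phoneme> selects the phonemizer language for the tokenizer; pass an
# empty string to skip phonemization.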
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = sys.argv[1]
lang = sys.argv[2]
lang_phoneme = sys.argv[3]
num_samples = int(sys.argv[4])
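# Load the fine-tuned CTC model and its matching processor from the Hub.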
model = AutoModelForCTC.from_pretrained(model_id).to(device)
processor = AutoProcessor.from_pretrained(model_id)
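# Stream the Common Voice test split so samples are fetched lazily instead of
# downloading the full dataset up front.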
ds = load_dataset("common_voice", lang, split="test", streaming=True)
sample_iter = iter(ds)
wer = load_metric("wer")
cer = load_metric("cer")
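# Collect reference and predicted token ids so WER/CER can be computed over
# all samples at the end.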
targets_ids = []
predictions_ids = []
for i in range(num_samples):
    sample = next(sample_iter)

    # Common Voice audio is recorded at 48 kHz; resample to the 16 kHz the model expects.
    resampled_audio = F.resample(torch.tensor(sample["audio"]["array"]), 48_000, 16_000).numpy()
    input_values = processor(resampled_audio, sampling_rate=16_000, return_tensors="pt").input_values

    with torch.no_grad():
        logits = model(input_values.to(device)).logits

    prediction_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(prediction_ids)[0]

    print(f"Correct: {sample['sentence']}")
    print(f"Predict: {transcription}")
    print(20 * '-')

    predictions_ids.append(prediction_ids[0].tolist())

    # Tokenize the reference; pass the phonemizer language only when one was given.
    kwargs = {}
    if len(lang_phoneme) > 0:
        kwargs["phonemizer_lang"] = lang_phoneme
    targets_ids.append(processor.tokenizer(sample["sentence"], **kwargs).input_ids)
print("Compute metrics.....")
import ipdb; ipdb.set_trace()
transcriptions = processor.batch_decode(predictions_ids)
targets_str = processor.batch_decode(targets_ids, group_tokens=False)
wer = wer.compute(predictions=transcriptions, references=targets_str)
cer = cer.compute(predictions=transcriptions, references=targets_str)
print("wer", wer)
print("cer", cer)