Spaces:
Runtime error
Runtime error
import pesq | |
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality | |
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility | |
from torchaudio.transforms import Resample | |
import torch | |
import torchaudio | |
from torchmetrics import SignalNoiseRatio | |
class Metrics(torch.nn.Module): | |
def __init__(self, source_rate, target_rate=16000, *args, **kwargs): | |
super().__init__(*args, **kwargs) | |
self.source_rate = source_rate | |
self.target_rate = target_rate | |
self.resampler = Resample(orig_freq=source_rate, new_freq=target_rate) | |
self.nb_pesq = PerceptualEvaluationSpeechQuality(target_rate, 'wb') | |
self.stoi = ShortTimeObjectiveIntelligibility(target_rate, False) | |
self.snr = SignalNoiseRatio() | |
def forward(self, denoised, clean): | |
pesq_scores, stoi_scores = 0, 0 | |
for denoised_wav, clean_wav in zip(denoised, clean): | |
if self.source_rate != self.target_rate: | |
denoised_wav = self.resampler(denoised_wav) | |
clean_wav = self.resampler(clean_wav) | |
try: | |
pesq_scores += self.nb_pesq(denoised_wav, clean_wav).item() | |
stoi_scores += self.stoi(denoised_wav, clean_wav).item() | |
except pesq.NoUtterancesError as e: | |
print(e) | |
except ValueError as e: | |
print(e) | |
return {'PESQ': pesq_scores, | |
'STOI': stoi_scores} | |