Spaces:
Runtime error
Runtime error
File size: 1,459 Bytes
95d8ea8 091b1e0 e2b0b28 091b1e0 bd0a813 091b1e0 e2b0b28 bd0a813 091b1e0 e2b0b28 95d8ea8 e2b0b28 95d8ea8 091b1e0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import pesq
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
from torchaudio.transforms import Resample
import torch
import torchaudio
from torchmetrics import SignalNoiseRatio
class Metrics(torch.nn.Module):
    """Accumulate PESQ and STOI scores over denoised/clean waveform pairs.

    Waveforms are resampled from ``source_rate`` to ``target_rate`` before
    scoring whenever the two rates differ.

    Args:
        source_rate: Sampling rate of the waveforms passed to ``forward``.
        target_rate: Sampling rate at which the metrics are constructed and
            to which the inputs are resampled.
        *args: Forwarded to ``torch.nn.Module.__init__``.
        **kwargs: Forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(self, source_rate, target_rate=16000, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.source_rate = source_rate
        self.target_rate = target_rate
        self.resampler = Resample(orig_freq=source_rate, new_freq=target_rate)
        # NOTE(review): despite the ``nb_`` prefix this metric is built with
        # mode 'wb'. The attribute name is kept so external users of
        # ``nb_pesq`` keep working -- confirm which mode is intended.
        self.nb_pesq = PerceptualEvaluationSpeechQuality(target_rate, 'wb')
        # The second positional argument is presumably ``extended`` --
        # confirm against the torchmetrics signature.
        self.stoi = ShortTimeObjectiveIntelligibility(target_rate, False)
        # Not used by ``forward``; kept because callers may access it.
        self.snr = SignalNoiseRatio()

    def forward(self, denoised, clean):
        """Return the summed PESQ and STOI scores of the paired waveforms.

        Args:
            denoised: Iterable of denoised waveforms.
            clean: Iterable of reference waveforms, paired with ``denoised``
                through ``zip`` (extra items on either side are ignored).

        Returns:
            A dict with the keys ``'PESQ'`` and ``'STOI'`` holding the *sums*
            (not the means) of the per-pair scores. A pair for which either
            metric raises ``pesq.NoUtterancesError`` or ``ValueError`` is
            reported with ``print`` and contributes to neither total.
        """
        # The rates do not change while iterating, so compare them once.
        needs_resampling = self.source_rate != self.target_rate
        pesq_scores, stoi_scores = 0, 0
        for denoised_wav, clean_wav in zip(denoised, clean):
            if needs_resampling:
                denoised_wav = self.resampler(denoised_wav)
                clean_wav = self.resampler(clean_wav)
            try:
                pesq_value = self.nb_pesq(denoised_wav, clean_wav).item()
                stoi_value = self.stoi(denoised_wav, clean_wav).item()
            except ValueError as e:
                print(e)
            except pesq.NoUtterancesError as e:
                print(e)
            else:
                # Accumulate only once both metrics succeeded, so the two
                # totals always cover exactly the same set of pairs.
                pesq_scores += pesq_value
                stoi_scores += stoi_value
        return {'PESQ': pesq_scores,
                'STOI': stoi_scores}
|