# denoising/testing/metrics.py
# Last commit: "rewrote demucs model" (e2b0b28) by BorisovMaksim
import pesq
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
from torchaudio.transforms import Resample
import torch
import torchaudio
from torchmetrics import SignalNoiseRatio
class Metrics(torch.nn.Module):
    """Score denoised speech against clean references with PESQ and STOI.

    Waveforms are resampled from ``source_rate`` to ``target_rate`` (only when
    the two differ) before the metrics are evaluated.

    Args:
        source_rate: Sample rate of the waveforms passed to ``forward``.
        target_rate: Sample rate the metrics are constructed with.
        *args: Forwarded to ``torch.nn.Module.__init__``.
        **kwargs: Forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(self, source_rate, target_rate=16000, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.source_rate = source_rate
        self.target_rate = target_rate
        self.resampler = Resample(orig_freq=source_rate, new_freq=target_rate)
        # NOTE(review): the attribute is called ``nb_pesq`` although the metric
        # is built in 'wb' mode. The name is kept so existing attribute access
        # (and any saved module state) keeps working.
        self.nb_pesq = PerceptualEvaluationSpeechQuality(target_rate, 'wb')
        # Second positional argument is presumably torchmetrics' ``extended``
        # flag (extended STOI disabled) -- confirm against the installed version.
        self.stoi = ShortTimeObjectiveIntelligibility(target_rate, False)
        # Not used by ``forward``; kept as a public attribute for callers.
        self.snr = SignalNoiseRatio()

    def forward(self, denoised, clean):
        """Return PESQ and STOI summed over paired waveforms.

        Args:
            denoised: Iterable of denoised waveforms, one item per sample
                (assumed to be tensors -- TODO confirm layout with callers).
            clean: Iterable of the matching clean reference waveforms.

        Returns:
            dict: ``{'PESQ': total, 'STOI': total}``. The values are sums over
            the scored pairs, not means. A pair for which either metric raises
            ``ValueError`` or ``pesq.NoUtterancesError`` is printed and
            contributes to neither total, so both totals always cover the same
            set of samples.
        """
        pesq_total, stoi_total = 0, 0
        needs_resampling = self.source_rate != self.target_rate
        for denoised_wav, clean_wav in zip(denoised, clean):
            if needs_resampling:
                denoised_wav = self.resampler(denoised_wav)
                clean_wav = self.resampler(clean_wav)
            try:
                # Compute both scores before touching the totals: a failure in
                # the second metric must not leave the first one half-counted.
                pesq_score = self.nb_pesq(denoised_wav, clean_wav).item()
                stoi_score = self.stoi(denoised_wav, clean_wav).item()
            except ValueError as e:
                print(e)
            except pesq.NoUtterancesError as e:
                print(e)
            else:
                pesq_total += pesq_score
                stoi_total += stoi_score
        return {'PESQ': pesq_total,
                'STOI': stoi_total}