File size: 1,459 Bytes
95d8ea8
091b1e0
 
e2b0b28
091b1e0
 
bd0a813
091b1e0
e2b0b28
 
 
 
 
 
 
 
bd0a813
091b1e0
e2b0b28
95d8ea8
 
e2b0b28
 
 
95d8ea8
 
 
 
 
 
 
 
 
 
091b1e0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import pesq
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
from torchaudio.transforms import Resample
import torch
import torchaudio
from torchmetrics import SignalNoiseRatio

class Metrics(torch.nn.Module):
    """Accumulate PESQ and STOI scores over a batch of denoised signals.

    Each (denoised, clean) pair is optionally resampled from ``source_rate``
    to ``target_rate`` and scored individually; the per-sample scores are
    summed and returned as plain Python numbers.

    Args:
        source_rate: Sample rate of the signals passed to ``forward``.
        target_rate: Sample rate the metrics are configured with. Signals are
            resampled to this rate when it differs from ``source_rate``.
        *args, **kwargs: Forwarded unchanged to ``torch.nn.Module.__init__``.
    """

    def __init__(self, source_rate, target_rate=16000, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.source_rate = source_rate
        self.target_rate = target_rate
        self.resampler = Resample(orig_freq=source_rate, new_freq=target_rate)
        # NOTE: despite the "nb" prefix this metric is configured in 'wb'
        # mode. The attribute name is kept so existing accesses keep working.
        self.nb_pesq = PerceptualEvaluationSpeechQuality(target_rate, 'wb')
        # Second positional argument is passed as False (see torchmetrics'
        # ShortTimeObjectiveIntelligibility signature for its meaning).
        self.stoi = ShortTimeObjectiveIntelligibility(target_rate, False)
        # Not used by forward(); kept available as an attribute for callers.
        self.snr = SignalNoiseRatio()

    @torch.no_grad()  # scores are returned as floats; no autograd graph is needed
    def forward(self, denoised, clean):
        """Return the summed PESQ and STOI scores of the given signal pairs.

        Args:
            denoised: Iterable yielding one denoised waveform per item
                (presumably a batch tensor iterated along dim 0 -- confirm
                against the caller).
            clean: Iterable yielding the matching reference waveforms. Pairs
                are formed with ``zip``, so extra items in the longer input
                are ignored.

        Returns:
            ``{'PESQ': <sum>, 'STOI': <sum>}``. The values are sums over the
            successfully scored samples, NOT averages; both are ``0`` when no
            sample could be scored.

        A sample for which either metric raises ``pesq.NoUtterancesError`` or
        ``ValueError`` is reported via ``print`` and left out of BOTH sums, so
        the two totals always cover the same set of samples.
        """
        pesq_total, stoi_total = 0, 0
        needs_resampling = self.source_rate != self.target_rate

        for denoised_wav, clean_wav in zip(denoised, clean):
            if needs_resampling:
                denoised_wav = self.resampler(denoised_wav)
                clean_wav = self.resampler(clean_wav)
            try:
                # Compute both scores before accumulating either one: if the
                # second metric fails, the first must not be counted alone.
                pesq_score = self.nb_pesq(denoised_wav, clean_wav).item()
                stoi_score = self.stoi(denoised_wav, clean_wav).item()
            except (pesq.NoUtterancesError, ValueError) as err:
                # Best-effort scoring: report the problem and skip the sample.
                print(err)
            else:
                pesq_total += pesq_score
                stoi_total += stoi_score

        return {'PESQ': pesq_total,
                'STOI': stoi_total}