denoising / evaluation.py
BorisovMaksim's picture
initinal
77c6c52
raw
history blame
2.09 kB
import argparse
from tqdm import tqdm
from utils import load_wav, collect_valentini_paths
from metrics import Metrics
from denoisers.SpectralGating import SpectralGating
PARSERS = {
'valentini': collect_valentini_paths
}
MODELS = {
'baseline': SpectralGating
}
def evaluate_on_dataset(model_name, dataset_path, dataset_type):
if model_name is not None:
model = MODELS[model_name]()
parser = PARSERS[dataset_type]
clean_wavs, noisy_wavs = parser(dataset_path)
metrics = Metrics()
mean_scores = {'PESQ': 0, 'STOI': 0}
for clean_path, noisy_path in tqdm(zip(clean_wavs, noisy_wavs), total=len(clean_wavs)):
clean_wav = load_wav(clean_path)
noisy_wav = load_wav(noisy_path)
if model_name is None:
scores = metrics.calculate(noisy_wav, clean_wav)
else:
denoised_wav = model(noisy_wav)
scores = metrics.calculate(noisy_wav, denoised_wav)
mean_scores['PESQ'] += scores['PESQ']
mean_scores['STOI'] += scores['STOI']
mean_scores['PESQ'] = mean_scores['PESQ'].numpy() / len(clean_wavs)
mean_scores['STOI'] = mean_scores['STOI'].numpy() / len(clean_wavs)
return mean_scores
if __name__ == '__main__':
parser = argparse.ArgumentParser(prog='Program to evaluate denoising')
parser.add_argument('--dataset_path', type=str,
default='/media/public/datasets/denoising/DS_10283_2791/',
help='Path to dataset folder')
parser.add_argument('--dataset_type', type=str, required=True,
choices=['valentini'])
parser.add_argument('--model_name', type=str,
choices=['baseline'])
args = parser.parse_args()
mean_scores = evaluate_on_dataset(model_name=args.model_name,
dataset_path=args.dataset_path,
dataset_type=args.dataset_type)
print(f"Metrics on {args.dataset_type} dataset with "
f"{args.model_name if args.model_name is not None else 'ideal denoising'} = {mean_scores}")