File size: 1,691 Bytes
091b1e0
 
 
 
20c7778
 
 
 
 
 
 
 
 
 
 
 
 
091b1e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torchaudio
import torch
import matplotlib.pyplot as plt
from pathlib import Path
from torch.nn.functional import pad



def pad_cut_batch_audio(wavs, new_shape):
    wav_length = wavs.shape[-1]
    new_length = new_shape[-1]

    if wav_length > new_length:
        wavs = wavs[:, :, :new_length]
    elif wav_length < new_length:
        wavs = pad(wavs, (0, new_length - wav_length))
    return wavs


def collect_valentini_paths(dataset_path):
    clean_path = Path(dataset_path) / 'clean_testset_wav'
    noisy_path = Path(dataset_path) / 'noisy_testset_wav'

    clean_wavs = list(clean_path.glob("*"))
    noisy_wavs = list(noisy_path.glob("*"))

    return clean_wavs, noisy_wavs




def plot_spectrogram(stft, title="Spectrogram", xlim=None):
    magnitude = stft.abs()
    spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
    figure, axis = plt.subplots(1, 1)
    img = axis.imshow(spectrogram, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto")
    figure.suptitle(title)
    plt.colorbar(img, ax=axis)
    plt.show()


def plot_mask(mask, title="Mask", xlim=None):
    mask = mask.numpy()
    figure, axis = plt.subplots(1, 1)
    img = axis.imshow(mask, cmap="viridis", origin="lower", aspect="auto")
    figure.suptitle(title)
    plt.colorbar(img, ax=axis)
    plt.show()


    

def generate_mixture(waveform_clean, waveform_noise, target_snr):
    
    power_clean_signal = waveform_clean.pow(2).mean()
    power_noise_signal = waveform_noise.pow(2).mean()
    current_snr = 10 * torch.log10(power_clean_signal / power_noise_signal)
    
    waveform_noise *= 10 ** (-(target_snr - current_snr) / 20)
    return waveform_clean + waveform_noise