Spaces:
Runtime error
Runtime error
bWX1204813
committed on
Commit
·
091b1e0
0
Parent(s):
initial
Browse files- .gitignore +3 -0
- EDA.ipynb +0 -0
- README.md +3 -0
- __pycache__/denoise.cpython-38.pyc +0 -0
- __pycache__/metrics.cpython-38.pyc +0 -0
- __pycache__/utils.cpython-38.pyc +0 -0
- app.py +43 -0
- denoisers/.ipynb_checkpoints/SpectralGating-checkpoint.py +26 -0
- denoisers/.ipynb_checkpoints/demucs-checkpoint.py +67 -0
- denoisers/SpectralGating.py +26 -0
- denoisers/__pycache__/SpectralGating.cpython-38.pyc +0 -0
- denoisers/demucs.py +67 -0
- evaluation.py +62 -0
- main.py +15 -0
- metrics.py +21 -0
- tutorial.ipynb +0 -0
- utils.py +53 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
.idea/*
|
2 |
+
.ipynb_checkpoints/*
|
3 |
+
nohup.out
|
EDA.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
README.md
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
| Attempt | #1 | #2 |
|
2 |
+
| :---: | :---: | :---: |
|
3 |
+
| Seconds | 301 | 283 |
|
__pycache__/denoise.cpython-38.pyc
ADDED
Binary file (3.81 kB). View file
|
|
__pycache__/metrics.cpython-38.pyc
ADDED
Binary file (867 Bytes). View file
|
|
__pycache__/utils.cpython-38.pyc
ADDED
Binary file (1.86 kB). View file
|
|
app.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
|
4 |
+
from re import M
|
5 |
+
import uuid
|
6 |
+
import shutil
|
7 |
+
import ffmpeg
|
8 |
+
import logging
|
9 |
+
import gradio as gr
|
10 |
+
|
11 |
+
|
12 |
+
from denoisers.SpectralGating import SpectralGating
|
13 |
+
|
14 |
+
model = SpectralGating()
|
15 |
+
|
16 |
+
|
17 |
+
def denoising_transform(audio):
    """Denoise one recording and return the path of the cleaned WAV file.

    The input is transcoded with ffmpeg to mono, 16-bit PCM at 22050 Hz and
    written under ``cache_wav/source``; the module-level ``model`` then writes
    its result under ``cache_wav/target``.

    Args:
        audio: Path of the recording to denoise, passed to ``ffmpeg.input``.

    Returns:
        Path of the denoised WAV file.
    """
    src_path = "cache_wav/source/{}.wav".format(str(uuid.uuid4()))
    tgt_path = "cache_wav/target/{}.wav".format(str(uuid.uuid4()))

    # ffmpeg does not create missing output directories, so make sure both
    # cache folders exist before anything is written to them.
    for path in (src_path, tgt_path):
        os.makedirs(os.path.dirname(path), exist_ok=True)

    (ffmpeg.input(audio)
     .output(src_path, acodec='pcm_s16le', ac=1, ar=22050)
     .run()
     )

    model.predict(src_path, tgt_path)
    return tgt_path
|
28 |
+
|
29 |
+
|
30 |
+
# Microphone recording in, audio file out; both are exchanged as file paths.
inputs = gr.inputs.Audio(label="Source Audio", source="microphone", type='filepath')
outputs = gr.outputs.Audio(label="Target Audio", type='filepath')

# Page title: the interface wraps denoising_transform, i.e. it removes noise
# from the recorded audio.
title = "Speech Denoising with Spectral Gating (BETA)"

gr.Interface(
    denoising_transform, inputs, outputs, title=title,
    allow_flagging='never',
).launch(
    server_name='localhost',
    server_port=7871,
    #ssl_keyfile='example.key',
    #ssl_certfile="example.crt",
)
|
denoisers/.ipynb_checkpoints/SpectralGating-checkpoint.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import noisereduce as nr
|
2 |
+
import torch
|
3 |
+
import torchaudio
|
4 |
+
|
5 |
+
|
6 |
+
class SpectralGating(torch.nn.Module):
    """Denoiser that delegates to ``noisereduce.reduce_noise``.

    example: wav_noisy = '/media/public/datasets/denoising/DS_10283_2791/noisy_trainset_56spk_wav/p312_002.wav'
    """

    def __init__(self, rate=16000):
        """Store the sample rate used by ``forward``.

        Args:
            rate: Sample rate (Hz) of the waveforms passed to ``forward``.
        """
        super(SpectralGating, self).__init__()
        self.rate = rate

    @staticmethod
    def _denoise(wav, rate):
        """Run ``nr.reduce_noise`` on ``wav`` and return the result as a tensor."""
        if isinstance(wav, torch.Tensor):
            # Hand the library a plain CPU array: converting a tensor that
            # requires grad or lives on another device would raise.
            wav = wav.detach().cpu().numpy()
        return torch.Tensor(nr.reduce_noise(y=wav, sr=rate))

    def forward(self, wav):
        """Denoise an in-memory waveform sampled at ``self.rate``.

        Args:
            wav: Waveform as a tensor or array accepted by ``nr.reduce_noise``.

        Returns:
            The denoised waveform as a ``torch.Tensor``.
        """
        return self._denoise(wav, self.rate)

    def predict(self, wav_path, out_path):
        """Denoise the audio file at ``wav_path`` and save it to ``out_path``.

        The sample rate is taken from the file itself, not from ``self.rate``.

        Args:
            wav_path: Path of the noisy input file.
            out_path: Path the denoised audio is written to.

        Returns:
            The denoised waveform as a ``torch.Tensor``.
        """
        data, rate = torchaudio.load(wav_path)
        reduced_noise = self._denoise(data, rate)
        torchaudio.save(out_path, reduced_noise, rate)
        return reduced_noise
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
|
26 |
+
|
denoisers/.ipynb_checkpoints/demucs-checkpoint.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
class Encoder(torch.nn.Module):
    """Downsampling block: strided Conv1d -> ReLU -> Conv1d -> GLU.

    Maps ``(batch, in_channels, time)`` to ``(batch, out_channels, time')``.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size_1=8, stride_1=4,
                 kernel_size_2=1, stride_2=1):
        super(Encoder, self).__init__()

        self.conv1 = torch.nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                                     kernel_size=kernel_size_1, stride=stride_1)
        self.relu1 = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv1d(in_channels=out_channels, out_channels=2 * out_channels,
                                     kernel_size=kernel_size_2, stride=stride_2)
        # conv2 doubles the channels so the GLU can halve them again; the gate
        # must therefore split the channel axis (-2), not GLU's default last
        # (time) axis.
        self.glu = torch.nn.GLU(dim=-2)

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.glu(self.conv2(x))
        return x


class Decoder(torch.nn.Module):
    """Upsampling block: Conv1d -> GLU -> ConvTranspose1d -> optional ReLU.

    Maps ``(batch, in_channels, time)`` to ``(batch, out_channels, time')``.

    Args:
        activation: When False the trailing ReLU is replaced by an identity,
            which is what an output layer producing signed samples needs.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size_1=3, stride_1=1,
                 kernel_size_2=8, stride_2=4,
                 activation=True):
        super(Decoder, self).__init__()

        # Pad so an odd kernel with stride 1 keeps the time length; otherwise
        # the decoder output no longer lines up with the encoder skip tensors.
        self.conv1 = torch.nn.Conv1d(in_channels=in_channels, out_channels=2 * in_channels,
                                     kernel_size=kernel_size_1, stride=stride_1,
                                     padding=kernel_size_1 // 2)
        # Gate along channels (see Encoder): 2 * in_channels -> in_channels.
        self.glu = torch.nn.GLU(dim=-2)
        self.conv2 = torch.nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels,
                                              kernel_size=kernel_size_2, stride=stride_2)
        self.relu = torch.nn.ReLU() if activation else torch.nn.Identity()

    def forward(self, x):
        x = self.glu(self.conv1(x))
        x = self.relu(self.conv2(x))
        return x


class Demucs(torch.nn.Module):
    """Three-level encoder / LSTM / decoder network with skip connections.

    ``forward`` takes a batch of single-channel signals shaped
    ``(batch, 1, time)`` and returns a tensor of the same shape.
    """

    def __init__(self):
        super(Demucs, self).__init__()

        self.encoder1 = Encoder(in_channels=1, out_channels=64)
        self.encoder2 = Encoder(in_channels=64, out_channels=128)
        self.encoder3 = Encoder(in_channels=128, out_channels=256)

        self.lstm = torch.nn.LSTM(input_size=256, hidden_size=256, num_layers=2)

        self.decoder1 = Decoder(in_channels=256, out_channels=128)
        self.decoder2 = Decoder(in_channels=128, out_channels=64)
        # No ReLU on the last layer: the output is a signal that must be able
        # to take negative values.
        self.decoder3 = Decoder(in_channels=64, out_channels=1, activation=False)

    def _valid_length(self, length):
        """Return the smallest length >= ``length`` that the network reproduces exactly.

        For such a length every decoder output has the same number of frames
        as the encoder output it is summed with.
        """
        for enc in (self.encoder1, self.encoder2, self.encoder3):
            kernel, stride = enc.conv1.kernel_size[0], enc.conv1.stride[0]
            # Integer ceil((length - kernel) / stride) + 1, at least one frame.
            length = max(-(-(length - kernel) // stride) + 1, 1)
        for dec in (self.decoder1, self.decoder2, self.decoder3):
            kernel, stride = dec.conv2.kernel_size[0], dec.conv2.stride[0]
            length = (length - 1) * stride + kernel
        return length

    def forward(self, x):
        length = x.shape[-1]
        # Right-pad with zeros so the skip connections line up frame for frame.
        x = torch.nn.functional.pad(x, (0, self._valid_length(length) - length))

        out1 = self.encoder1(x)
        out2 = self.encoder2(out1)
        out3 = self.encoder3(out2)

        # nn.LSTM (batch_first=False) expects (time, batch, features) and
        # returns (output, (h_n, c_n)); only the output sequence is used.
        seq, _ = self.lstm(out3.permute(2, 0, 1).contiguous())
        x = seq.permute(1, 2, 0)

        x = self.decoder1(x + out3)
        x = self.decoder2(x + out2)
        x = self.decoder3(x + out1)

        # Drop the frames that only exist because of the padding above.
        return x[..., :length]
|
denoisers/SpectralGating.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import noisereduce as nr
|
2 |
+
import torch
|
3 |
+
import torchaudio
|
4 |
+
|
5 |
+
|
6 |
+
class SpectralGating(torch.nn.Module):
    """Denoiser that delegates to ``noisereduce.reduce_noise``.

    example: wav_noisy = '/media/public/datasets/denoising/DS_10283_2791/noisy_trainset_56spk_wav/p312_002.wav'
    """

    def __init__(self, rate=16000):
        """Store the sample rate used by ``forward``.

        Args:
            rate: Sample rate (Hz) of the waveforms passed to ``forward``.
        """
        super(SpectralGating, self).__init__()
        self.rate = rate

    @staticmethod
    def _denoise(wav, rate):
        """Run ``nr.reduce_noise`` on ``wav`` and return the result as a tensor."""
        if isinstance(wav, torch.Tensor):
            # Hand the library a plain CPU array: converting a tensor that
            # requires grad or lives on another device would raise.
            wav = wav.detach().cpu().numpy()
        return torch.Tensor(nr.reduce_noise(y=wav, sr=rate))

    def forward(self, wav):
        """Denoise an in-memory waveform sampled at ``self.rate``.

        Args:
            wav: Waveform as a tensor or array accepted by ``nr.reduce_noise``.

        Returns:
            The denoised waveform as a ``torch.Tensor``.
        """
        return self._denoise(wav, self.rate)

    def predict(self, wav_path, out_path):
        """Denoise the audio file at ``wav_path`` and save it to ``out_path``.

        The sample rate is taken from the file itself, not from ``self.rate``.

        Args:
            wav_path: Path of the noisy input file.
            out_path: Path the denoised audio is written to.

        Returns:
            The denoised waveform as a ``torch.Tensor``.
        """
        data, rate = torchaudio.load(wav_path)
        reduced_noise = self._denoise(data, rate)
        torchaudio.save(out_path, reduced_noise, rate)
        return reduced_noise
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
|
26 |
+
|
denoisers/__pycache__/SpectralGating.cpython-38.pyc
ADDED
Binary file (1.2 kB). View file
|
|
denoisers/demucs.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
class Encoder(torch.nn.Module):
    """Downsampling block: strided Conv1d -> ReLU -> Conv1d -> GLU.

    Maps ``(batch, in_channels, time)`` to ``(batch, out_channels, time')``.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size_1=8, stride_1=4,
                 kernel_size_2=1, stride_2=1):
        super(Encoder, self).__init__()

        self.conv1 = torch.nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                                     kernel_size=kernel_size_1, stride=stride_1)
        self.relu1 = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv1d(in_channels=out_channels, out_channels=2 * out_channels,
                                     kernel_size=kernel_size_2, stride=stride_2)
        # conv2 doubles the channels so the GLU can halve them again; the gate
        # must therefore split the channel axis (-2), not GLU's default last
        # (time) axis.
        self.glu = torch.nn.GLU(dim=-2)

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.glu(self.conv2(x))
        return x


class Decoder(torch.nn.Module):
    """Upsampling block: Conv1d -> GLU -> ConvTranspose1d -> optional ReLU.

    Maps ``(batch, in_channels, time)`` to ``(batch, out_channels, time')``.

    Args:
        activation: When False the trailing ReLU is replaced by an identity,
            which is what an output layer producing signed samples needs.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size_1=3, stride_1=1,
                 kernel_size_2=8, stride_2=4,
                 activation=True):
        super(Decoder, self).__init__()

        # Pad so an odd kernel with stride 1 keeps the time length; otherwise
        # the decoder output no longer lines up with the encoder skip tensors.
        self.conv1 = torch.nn.Conv1d(in_channels=in_channels, out_channels=2 * in_channels,
                                     kernel_size=kernel_size_1, stride=stride_1,
                                     padding=kernel_size_1 // 2)
        # Gate along channels (see Encoder): 2 * in_channels -> in_channels.
        self.glu = torch.nn.GLU(dim=-2)
        self.conv2 = torch.nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels,
                                              kernel_size=kernel_size_2, stride=stride_2)
        self.relu = torch.nn.ReLU() if activation else torch.nn.Identity()

    def forward(self, x):
        x = self.glu(self.conv1(x))
        x = self.relu(self.conv2(x))
        return x


class Demucs(torch.nn.Module):
    """Three-level encoder / LSTM / decoder network with skip connections.

    ``forward`` takes a batch of single-channel signals shaped
    ``(batch, 1, time)`` and returns a tensor of the same shape.
    """

    def __init__(self):
        super(Demucs, self).__init__()

        self.encoder1 = Encoder(in_channels=1, out_channels=64)
        self.encoder2 = Encoder(in_channels=64, out_channels=128)
        self.encoder3 = Encoder(in_channels=128, out_channels=256)

        self.lstm = torch.nn.LSTM(input_size=256, hidden_size=256, num_layers=2)

        self.decoder1 = Decoder(in_channels=256, out_channels=128)
        self.decoder2 = Decoder(in_channels=128, out_channels=64)
        # No ReLU on the last layer: the output is a signal that must be able
        # to take negative values.
        self.decoder3 = Decoder(in_channels=64, out_channels=1, activation=False)

    def _valid_length(self, length):
        """Return the smallest length >= ``length`` that the network reproduces exactly.

        For such a length every decoder output has the same number of frames
        as the encoder output it is summed with.
        """
        for enc in (self.encoder1, self.encoder2, self.encoder3):
            kernel, stride = enc.conv1.kernel_size[0], enc.conv1.stride[0]
            # Integer ceil((length - kernel) / stride) + 1, at least one frame.
            length = max(-(-(length - kernel) // stride) + 1, 1)
        for dec in (self.decoder1, self.decoder2, self.decoder3):
            kernel, stride = dec.conv2.kernel_size[0], dec.conv2.stride[0]
            length = (length - 1) * stride + kernel
        return length

    def forward(self, x):
        length = x.shape[-1]
        # Right-pad with zeros so the skip connections line up frame for frame.
        x = torch.nn.functional.pad(x, (0, self._valid_length(length) - length))

        out1 = self.encoder1(x)
        out2 = self.encoder2(out1)
        out3 = self.encoder3(out2)

        # nn.LSTM (batch_first=False) expects (time, batch, features) and
        # returns (output, (h_n, c_n)); only the output sequence is used.
        seq, _ = self.lstm(out3.permute(2, 0, 1).contiguous())
        x = seq.permute(1, 2, 0)

        x = self.decoder1(x + out3)
        x = self.decoder2(x + out2)
        x = self.decoder3(x + out1)

        # Drop the frames that only exist because of the padding above.
        return x[..., :length]
|
evaluation.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from tqdm import tqdm
|
3 |
+
|
4 |
+
from utils import load_wav, collect_valentini_paths
|
5 |
+
from metrics import Metrics
|
6 |
+
from denoisers.SpectralGating import SpectralGating
|
7 |
+
|
8 |
+
|
9 |
+
# Dataset type -> function that takes the dataset root and returns
# (clean_wavs, noisy_wavs).
PARSERS = {
    'valentini': collect_valentini_paths
}
# Model name -> denoiser class; evaluate_on_dataset instantiates it with no
# arguments.
MODELS = {
    'baseline': SpectralGating
}
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
def evaluate_on_dataset(model_name, dataset_path, dataset_type, ideal):
    """Compute mean PESQ and STOI over a dataset of clean/noisy pairs.

    Args:
        model_name: Key into ``MODELS``; ignored (may be None) when ``ideal``
            is true.
        dataset_path: Root folder handed to the dataset parser.
        dataset_type: Key into ``PARSERS``.
        ideal: When true, no model is run and the noisy input itself is
            scored against the clean reference.

    Returns:
        Dict with keys ``'PESQ'`` and ``'STOI'`` holding the mean scores.

    Raises:
        ValueError: If the dataset is empty or the clean and noisy file lists
            have different lengths.
    """
    parser = PARSERS[dataset_type]
    clean_wavs, noisy_wavs = parser(dataset_path)
    if not clean_wavs:
        raise ValueError(f"No audio files found in dataset at {dataset_path!r}")
    if len(clean_wavs) != len(noisy_wavs):
        # zip() would silently drop the surplus files and skew the mean.
        raise ValueError(f"Found {len(clean_wavs)} clean but {len(noisy_wavs)} noisy files")

    # Only build a model when one is actually evaluated.
    model = None if ideal else MODELS[model_name]()

    metrics = Metrics()
    mean_scores = {'PESQ': 0, 'STOI': 0}
    for clean_path, noisy_path in tqdm(zip(clean_wavs, noisy_wavs), total=len(clean_wavs)):
        clean_wav = load_wav(clean_path)
        noisy_wav = load_wav(noisy_path)
        if ideal:
            # NOTE(review): this scores the unprocessed noisy signal against
            # the clean one - confirm that is what "ideal" is meant to report.
            scores = metrics.calculate(noisy_wav, clean_wav)
        else:
            # Prediction is the model output, target is the clean reference.
            scores = metrics.calculate(model(noisy_wav), clean_wav)

        mean_scores['PESQ'] += scores['PESQ']
        mean_scores['STOI'] += scores['STOI']

    mean_scores['PESQ'] = mean_scores['PESQ'].numpy() / len(clean_wavs)
    mean_scores['STOI'] = mean_scores['STOI'].numpy() / len(clean_wavs)

    return mean_scores
|
41 |
+
|
42 |
+
|
43 |
+
if __name__ == '__main__':
    def _str_to_bool(value):
        """Parse a command-line boolean; ``type=bool`` would treat 'False' as True."""
        lowered = value.lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(prog='Program to evaluate denoising')
    parser.add_argument('--dataset_path', type=str,
                        default='/media/public/datasets/denoising/DS_10283_2791/',
                        help='Path to dataset folder')
    parser.add_argument('--dataset_type', type=str, required=True,
                        choices=['valentini'])
    parser.add_argument('--model_name', type=str,
                        choices=['baseline'])
    # Accepts a bare "--ideal" as well as "--ideal True" / "--ideal False".
    parser.add_argument('--ideal', type=_str_to_bool, nargs='?', const=True, default=False,
                        help="Evaluate metrics on testing data with ideal denoising")

    args = parser.parse_args()
    if not args.ideal and args.model_name is None:
        parser.error('--model_name is required unless --ideal is set')

    mean_scores = evaluate_on_dataset(model_name=args.model_name,
                                      dataset_path=args.dataset_path,
                                      dataset_type=args.dataset_type,
                                      ideal=args.ideal)
    print(f"Metrics on {args.dataset_type} dataset with "
          f"{args.model_name if args.model_name is not None else 'ideal denoising'} = {mean_scores}")
|
main.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
print(torch.__version__)
|
3 |
+
from torchaudio.utils import download_asset
|
4 |
+
|
5 |
+
|
6 |
+
def print_hi(name):
    """Print a greeting of the form ``Hi, <name>`` to standard output."""
    print(f'Hi, {name}')
|
9 |
+
|
10 |
+
|
11 |
+
# Press the green button in the gutter to run the script.
|
12 |
+
if __name__ == '__main__':
|
13 |
+
print_hi('PyCharm')
|
14 |
+
|
15 |
+
# See PyCharm help at https://www.jetbrains.com/help/pycharm/
|
metrics.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
|
2 |
+
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
|
3 |
+
import torch
|
4 |
+
import torchaudio
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
|
9 |
+
class Metrics:
    """Pair of speech metrics (PESQ and STOI) configured for one sample rate."""

    def __init__(self, rate=16000):
        """Create both metric objects for signals sampled at ``rate`` Hz."""
        # The attribute is called nb_pesq, but the mode requested is 'wb'.
        self.nb_pesq = PerceptualEvaluationSpeechQuality(fs=rate, mode='wb')
        self.stoi = ShortTimeObjectiveIntelligibility(fs=rate, extended=False)

    def calculate(self, preds, target):
        """Score ``preds`` against ``target``.

        Returns:
            Dict with the PESQ result under ``'PESQ'`` and the STOI result
            under ``'STOI'``.
        """
        pesq_score = self.nb_pesq(preds, target)
        stoi_score = self.stoi(preds, target)
        return {'PESQ': pesq_score, 'STOI': stoi_score}
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
|
tutorial.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
utils.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torchaudio
|
2 |
+
import torch
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
|
7 |
+
def collect_valentini_paths(dataset_path):
    """List the clean and noisy test files of a Valentini-style dataset folder.

    Args:
        dataset_path: Folder containing ``clean_testset_wav`` and
            ``noisy_testset_wav``.

    Returns:
        Tuple ``(clean_wavs, noisy_wavs)`` of path lists, each sorted by path.
    """
    root = Path(dataset_path)

    # glob() yields entries in arbitrary order; sort both lists so that callers
    # pairing them by position get matching files.
    # Assumes both folders use the same file names - TODO confirm.
    clean_wavs = sorted((root / 'clean_testset_wav').glob("*"))
    noisy_wavs = sorted((root / 'noisy_testset_wav').glob("*"))

    return clean_wavs, noisy_wavs
|
15 |
+
|
16 |
+
|
17 |
+
def load_wav(path, sample_rate=16000):
    """Load an audio file and resample it to ``sample_rate``.

    Args:
        path: Audio file to read with ``torchaudio.load``.
        sample_rate: Target sample rate in Hz (defaults to 16000).

    Returns:
        The resampled waveform tensor.
    """
    wav, org_sr = torchaudio.load(path)
    wav = torchaudio.functional.resample(wav, orig_freq=org_sr, new_freq=sample_rate)
    return wav
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
def plot_spectrogram(stft, title="Spectrogram", xlim=None):
    """Show the magnitude of ``stft`` in decibels as an image.

    Args:
        stft: Tensor whose ``abs()`` is the magnitude to display.
        title: Figure title.
        xlim: Optional ``(left, right)`` limits for the x axis; the full
            range is shown when None.
    """
    magnitude = stft.abs()
    # The small offset keeps log10 finite where the magnitude is zero.
    spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
    figure, axis = plt.subplots(1, 1)
    img = axis.imshow(spectrogram, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto")
    if xlim is not None:
        axis.set_xlim(xlim)
    figure.suptitle(title)
    plt.colorbar(img, ax=axis)
    plt.show()
|
32 |
+
|
33 |
+
|
34 |
+
def plot_mask(mask, title="Mask", xlim=None):
    """Show ``mask`` as an image.

    Args:
        mask: Tensor to display; converted with ``.numpy()``.
        title: Figure title.
        xlim: Optional ``(left, right)`` limits for the x axis; the full
            range is shown when None.
    """
    mask = mask.numpy()
    figure, axis = plt.subplots(1, 1)
    img = axis.imshow(mask, cmap="viridis", origin="lower", aspect="auto")
    if xlim is not None:
        axis.set_xlim(xlim)
    figure.suptitle(title)
    plt.colorbar(img, ax=axis)
    plt.show()
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
def generate_mixture(waveform_clean, waveform_noise, target_snr):
    """Mix noise into a clean waveform at a requested signal-to-noise ratio.

    Args:
        waveform_clean: Clean signal tensor.
        waveform_noise: Noise tensor that can be added to ``waveform_clean``;
            it is left unmodified.
        target_snr: Desired SNR of the mixture in dB.

    Returns:
        ``waveform_clean`` plus the rescaled noise.
    """
    power_clean_signal = waveform_clean.pow(2).mean()
    power_noise_signal = waveform_noise.pow(2).mean()
    current_snr = 10 * torch.log10(power_clean_signal / power_noise_signal)

    # Amplitude gain that moves the noise from the current SNR to the target.
    # Scale out of place so the caller's noise tensor is not altered.
    scaled_noise = waveform_noise * 10 ** (-(target_snr - current_snr) / 20)
    return waveform_clean + scaled_noise
|
53 |
+
|