Spaces:
Runtime error
Runtime error
Commit
·
9ff4511
1
Parent(s):
3f8f152
refactored code to work with hydra and wandb
Browse files- checkpoing_saver.py +45 -0
- conf/config.yaml +20 -3
- conf/dataset/valentini.yaml +4 -3
- datasets/__init__.py +18 -0
- datasets.py → datasets/valentini.py +10 -10
- denoisers/demucs.py +1 -1
- testing/metrics.py +2 -2
- train.py +42 -74
- transforms.py +22 -0
- utils.py +0 -5
checkpoing_saver.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import logging
|
4 |
+
import torch
|
5 |
+
import wandb
|
6 |
+
|
7 |
+
class CheckpointSaver:
    """Tracks the top-N model checkpoints by a validation metric and uploads them to W&B."""

    def __init__(self, dirpath, decreasing=True, top_n=5):
        """
        dirpath: Directory path where to store all model weights
        decreasing: If decreasing is `True`, then lower metric is better
        top_n: Total number of models to track based on validation metric value
        """
        if not os.path.exists(dirpath):
            os.makedirs(dirpath)
        self.dirpath = dirpath
        self.top_n = top_n
        self.decreasing = decreasing
        self.top_model_paths = []
        # np.inf (lowercase): np.Inf was removed in NumPy 2.0.
        self.best_metric_val = np.inf if decreasing else -np.inf

    def __call__(self, model, epoch, metric_val):
        """Save `model` (state_dict) if `metric_val` improves on the best value seen so far."""
        model_path = os.path.join(self.dirpath, model.__class__.__name__ + f'_epoch{epoch}.pt')
        improved = metric_val < self.best_metric_val if self.decreasing else metric_val > self.best_metric_val
        if improved:
            logging.info(
                f"Current metric value {metric_val} better than best {self.best_metric_val}, "
                f"saving model at {model_path}, & logging model weights to W&B.")
            self.best_metric_val = metric_val
            torch.save(model.state_dict(), model_path)
            self.log_artifact(f'model-ckpt-epoch-{epoch}.pt', model_path, metric_val)
            self.top_model_paths.append({'path': model_path, 'score': metric_val})
            # Keep best-first order: ascending scores when lower is better, else descending.
            self.top_model_paths = sorted(self.top_model_paths, key=lambda o: o['score'],
                                          reverse=not self.decreasing)
            if len(self.top_model_paths) > self.top_n:
                self.cleanup()

    def log_artifact(self, filename, model_path, metric_val):
        """Upload the checkpoint file to the active W&B run as a model artifact."""
        artifact = wandb.Artifact(filename, type='model', metadata={'Validation score': metric_val})
        artifact.add_file(model_path)
        wandb.run.log_artifact(artifact)

    def cleanup(self):
        """Delete checkpoints beyond the best `top_n`, both from disk and from tracking."""
        to_remove = self.top_model_paths[self.top_n:]
        logging.info(f"Removing extra models.. {to_remove}")
        for o in to_remove:
            os.remove(o['path'])
        self.top_model_paths = self.top_model_paths[:self.top_n]
|
conf/config.yaml
CHANGED
@@ -4,12 +4,29 @@ defaults:
|
|
4 |
- loss: mse
|
5 |
- optimizer: sgd
|
6 |
|
|
|
|
|
|
|
7 |
|
8 |
dataloader:
|
9 |
max_seconds: 2
|
10 |
sample_rate: 16000
|
11 |
-
|
|
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
-
|
15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
- loss: mse
|
5 |
- optimizer: sgd
|
6 |
|
7 |
+
# Training loop settings.
training:
  num_epochs: 5
  model_save_path: /media/public/checkpoints

dataloader:
  max_seconds: 2
  sample_rate: 16000
  train_batch_size: 12
  valid_batch_size: 12

# Fixed example clips denoised and logged each epoch for listening checks.
validation:
  path: /media/public/datasets/denoising/DS_10283_2791/noisy_testset_wav
  wavs:
    easy: p232_284.wav
    medium: p232_071.wav
    hard : p257_171.wav

# Weights & Biases (self-hosted) run configuration.
wandb:
  project: denoising
  log_interval: 100
  # SECURITY(review): credential committed to the repo — move it to an
  # environment variable / secret store and rotate the key.
  api_key: local-e23d01ece807cb31e69b2cf4137e4998e4b9856f
  host: http://localhost:8080/
  notes: "Experiment note"
  tags:
    - baseline
|
conf/dataset/valentini.yaml
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
1 |
+
# Valentini (VoiceBank-DEMAND, DS 10283/2791) dataset settings.
valentini:
  dataset_path: /media/public/datasets/denoising/DS_10283_2791/
  val_fraction: 0.2
  # Native sample rate of the source wavs — presumably resampled downstream
  # to the dataloader's sample_rate; confirm against the transform pipeline.
  sample_rate: 48000
|
datasets/__init__.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple
|
2 |
+
from torch.utils.data import Dataset
|
3 |
+
|
4 |
+
from datasets.valentini import Valentini
|
5 |
+
from transforms import Transform
|
6 |
+
|
7 |
+
# Registry mapping a dataset key (as used in cfg['dataset']) to its Dataset class.
DATASETS_POOL = {
    'valentini': Valentini
}
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
def get_datasets(cfg) -> Tuple[Dataset, Dataset]:
    """Build the (train, valid) dataset pair described by the first entry of `cfg['dataset']`."""
    dataset_name, params = next(iter(cfg['dataset'].items()))
    # One transform instance is shared by both splits.
    shared_transform = Transform(input_sr=params['sample_rate'], **cfg['dataloader'])
    dataset_cls = DATASETS_POOL[dataset_name]
    return (dataset_cls(valid=False, transform=shared_transform, **params),
            dataset_cls(valid=True, transform=shared_transform, **params))
|
datasets.py → datasets/valentini.py
RENAMED
@@ -1,17 +1,19 @@
|
|
1 |
import torch
|
2 |
from torch.utils.data import Dataset
|
3 |
from pathlib import Path
|
4 |
-
|
|
|
|
|
5 |
|
|
|
6 |
|
7 |
class Valentini(Dataset):
|
8 |
-
def __init__(self, dataset_path
|
9 |
-
valid=False):
|
10 |
clean_path = Path(dataset_path) / 'clean_trainset_56spk_wav'
|
11 |
noisy_path = Path(dataset_path) / 'noisy_trainset_56spk_wav'
|
12 |
clean_wavs = list(clean_path.glob("*"))
|
13 |
noisy_wavs = list(noisy_path.glob("*"))
|
14 |
-
valid_threshold = int(len(clean_wavs) *
|
15 |
if valid:
|
16 |
self.clean_wavs = clean_wavs[:valid_threshold]
|
17 |
self.noisy_wavs = noisy_wavs[:valid_threshold]
|
@@ -22,16 +24,17 @@ class Valentini(Dataset):
|
|
22 |
assert len(self.clean_wavs) == len(self.noisy_wavs)
|
23 |
|
24 |
self.transform = transform
|
|
|
25 |
|
26 |
def __len__(self):
|
27 |
return len(self.clean_wavs)
|
28 |
|
29 |
def __getitem__(self, idx):
|
30 |
-
noisy_wav =
|
31 |
-
clean_wav =
|
32 |
|
33 |
if self.transform:
|
34 |
-
random_seed = torch.randint(
|
35 |
torch.manual_seed(random_seed)
|
36 |
noisy_wav = self.transform(noisy_wav)
|
37 |
torch.manual_seed(random_seed)
|
@@ -39,8 +42,5 @@ class Valentini(Dataset):
|
|
39 |
return noisy_wav, clean_wav
|
40 |
|
41 |
|
42 |
-
DATASETS_POOL = {
|
43 |
-
'valentini': Valentini
|
44 |
-
}
|
45 |
|
46 |
|
|
|
1 |
import torch
|
2 |
from torch.utils.data import Dataset
|
3 |
from pathlib import Path
|
4 |
+
import torchaudio
|
5 |
+
import numpy as np
|
6 |
+
from torchaudio.transforms import Resample
|
7 |
|
8 |
+
HIGH_RANDOM_SEED = 1000
|
9 |
|
10 |
class Valentini(Dataset):
|
11 |
+
def __init__(self, dataset_path, val_fraction, transform=None, valid=False, *args, **kwargs):
|
|
|
12 |
clean_path = Path(dataset_path) / 'clean_trainset_56spk_wav'
|
13 |
noisy_path = Path(dataset_path) / 'noisy_trainset_56spk_wav'
|
14 |
clean_wavs = list(clean_path.glob("*"))
|
15 |
noisy_wavs = list(noisy_path.glob("*"))
|
16 |
+
valid_threshold = int(len(clean_wavs) * val_fraction)
|
17 |
if valid:
|
18 |
self.clean_wavs = clean_wavs[:valid_threshold]
|
19 |
self.noisy_wavs = noisy_wavs[:valid_threshold]
|
|
|
24 |
assert len(self.clean_wavs) == len(self.noisy_wavs)
|
25 |
|
26 |
self.transform = transform
|
27 |
+
self.valid = valid
|
28 |
|
29 |
def __len__(self):
|
30 |
return len(self.clean_wavs)
|
31 |
|
32 |
def __getitem__(self, idx):
|
33 |
+
noisy_wav, noisy_sr = torchaudio.load(self.noisy_wavs[idx])
|
34 |
+
clean_wav, clean_sr = torchaudio.load(self.clean_wavs[idx])
|
35 |
|
36 |
if self.transform:
|
37 |
+
random_seed = 0 if self.valid else torch.randint(HIGH_RANDOM_SEED, (1,))[0]
|
38 |
torch.manual_seed(random_seed)
|
39 |
noisy_wav = self.transform(noisy_wav)
|
40 |
torch.manual_seed(random_seed)
|
|
|
42 |
return noisy_wav, clean_wav
|
43 |
|
44 |
|
|
|
|
|
|
|
45 |
|
46 |
|
denoisers/demucs.py
CHANGED
@@ -34,7 +34,7 @@ class Decoder(torch.nn.Module):
|
|
34 |
self.glu = torch.nn.GLU(dim=-2)
|
35 |
self.conv2 = torch.nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels,
|
36 |
kernel_size=cfg['conv2']['kernel_size'],
|
37 |
-
stride=cfg['conv2']['
|
38 |
self.relu = torch.nn.ReLU()
|
39 |
|
40 |
def forward(self, x):
|
|
|
34 |
self.glu = torch.nn.GLU(dim=-2)
|
35 |
self.conv2 = torch.nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels,
|
36 |
kernel_size=cfg['conv2']['kernel_size'],
|
37 |
+
stride=cfg['conv2']['stride'])
|
38 |
self.relu = torch.nn.ReLU()
|
39 |
|
40 |
def forward(self, x):
|
testing/metrics.py
CHANGED
@@ -12,7 +12,7 @@ class Metrics:
|
|
12 |
self.snr = SignalNoiseRatio()
|
13 |
|
14 |
def calculate(self, denoised, clean):
|
15 |
-
return {'PESQ': self.nb_pesq(denoised, clean),
|
16 |
-
'STOI': self.stoi(denoised, clean)}
|
17 |
|
18 |
|
|
|
12 |
self.snr = SignalNoiseRatio()
|
13 |
|
14 |
def calculate(self, denoised, clean):
|
15 |
+
return {'PESQ': self.nb_pesq(denoised, clean).item(),
|
16 |
+
'STOI': self.stoi(denoised, clean).item()}
|
17 |
|
18 |
|
train.py
CHANGED
@@ -1,94 +1,61 @@
|
|
1 |
import os
|
2 |
-
from torch.utils.tensorboard import SummaryWriter
|
3 |
import torch
|
4 |
-
from torch.nn import Sequential
|
5 |
from torch.utils.data import DataLoader
|
6 |
-
from datetime import datetime
|
7 |
-
from torchvision.transforms import RandomCrop
|
8 |
-
from utils import load_wav
|
9 |
-
from denoisers.demucs import Demucs
|
10 |
from pathlib import Path
|
11 |
from omegaconf import DictConfig
|
|
|
|
|
12 |
|
13 |
-
from
|
14 |
-
from losses import LOSSES
|
15 |
-
from datasets import DATASETS_POOL
|
16 |
from denoisers import get_model
|
17 |
from optimizers import get_optimizer
|
18 |
from losses import get_loss
|
19 |
-
|
|
|
|
|
20 |
|
21 |
os.environ['CUDA_VISIBLE_DEVICES'] = "1"
|
22 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
23 |
|
24 |
-
#
|
25 |
-
#
|
26 |
-
# DATASET_PATH = Path('/media/public/dataset/denoising/DS_10283_2791/')
|
27 |
-
# VALID_WAVS = {'hard': 'p257_171.wav',
|
28 |
-
# 'medium': 'p232_071.wav',
|
29 |
-
# 'easy': 'p232_284.wav'}
|
30 |
-
# MAX_SECONDS = 2
|
31 |
-
# SAMPLE_RATE = 16000
|
32 |
-
#
|
33 |
-
# transform = Sequential(RandomCrop((1, int(MAX_SECONDS * SAMPLE_RATE)), pad_if_needed=True))
|
34 |
-
#
|
35 |
-
# training_loader = DataLoader(Valentini(valid=False, transform=transform), batch_size=12, shuffle=True)
|
36 |
-
# validation_loader = DataLoader(Valentini(valid=True, transform=transform), batch_size=12, shuffle=True)
|
37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
|
|
|
|
39 |
|
40 |
-
|
41 |
-
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
42 |
-
model = get_model(cfg['model'])
|
43 |
optimizer = get_optimizer(model.parameters(), cfg['optimizer'])
|
44 |
loss_fn = get_loss(cfg['loss'])
|
|
|
45 |
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
EPOCHS = 5
|
50 |
|
51 |
-
|
52 |
-
|
53 |
-
for tag, wav_path in VALID_WAVS.items():
|
54 |
-
wav = load_wav(DATASET_PATH / 'noisy_testset_wav' / wav_path)
|
55 |
-
writer.add_audio(tag=tag, snd_tensor=wav, sample_rate=SAMPLE_RATE)
|
56 |
-
writer.flush()
|
57 |
-
|
58 |
-
for epoch in range(EPOCHS):
|
59 |
-
print('EPOCH {}:'.format(epoch_number + 1))
|
60 |
|
|
|
61 |
model.train(True)
|
62 |
-
|
63 |
-
running_loss = 0.
|
64 |
-
last_loss = 0.
|
65 |
-
|
66 |
for i, data in enumerate(training_loader):
|
67 |
inputs, labels = data
|
68 |
inputs, labels = inputs.to(device), labels.to(device)
|
69 |
-
|
70 |
optimizer.zero_grad()
|
71 |
-
|
72 |
outputs = model(inputs)
|
73 |
-
|
74 |
loss = loss_fn(outputs, labels)
|
75 |
loss.backward()
|
76 |
-
|
77 |
optimizer.step()
|
78 |
|
79 |
-
|
80 |
-
|
81 |
-
last_loss = running_loss / 1000 # loss per batch
|
82 |
-
print(' batch {} loss: {}'.format(i + 1, last_loss))
|
83 |
-
tb_x = epoch_number * len(training_loader) + i + 1
|
84 |
-
writer.add_scalar('Loss/train', last_loss, tb_x)
|
85 |
-
running_loss = 0.
|
86 |
-
|
87 |
-
avg_loss = last_loss
|
88 |
|
89 |
model.train(False)
|
90 |
|
91 |
-
running_vloss = 0.0
|
92 |
with torch.no_grad():
|
93 |
for i, vdata in enumerate(validation_loader):
|
94 |
vinputs, vlabels = vdata
|
@@ -96,28 +63,29 @@ def train(cfg: DictConfig):
|
|
96 |
voutputs = model(vinputs)
|
97 |
vloss = loss_fn(voutputs, vlabels)
|
98 |
running_vloss += vloss
|
|
|
|
|
|
|
|
|
99 |
|
100 |
-
avg_vloss = running_vloss / (
|
101 |
-
|
|
|
102 |
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
|
|
108 |
wav = torch.reshape(wav, (1, 1, -1)).to(device)
|
109 |
prediction = model(wav)
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
if avg_vloss < best_vloss:
|
116 |
-
best_vloss = avg_vloss
|
117 |
-
model_path = 'checkpoints/model_{}_{}'.format(timestamp, epoch_number)
|
118 |
-
torch.save(model.state_dict(), model_path)
|
119 |
|
120 |
-
|
121 |
|
122 |
|
123 |
if __name__ == '__main__':
|
|
|
1 |
import os
|
|
|
2 |
import torch
|
|
|
3 |
from torch.utils.data import DataLoader
|
|
|
|
|
|
|
|
|
4 |
from pathlib import Path
|
5 |
from omegaconf import DictConfig
|
6 |
+
import wandb
|
7 |
+
import torchaudio
|
8 |
|
9 |
+
from checkpoing_saver import CheckpointSaver
|
|
|
|
|
10 |
from denoisers import get_model
|
11 |
from optimizers import get_optimizer
|
12 |
from losses import get_loss
|
13 |
+
from datasets import get_datasets
|
14 |
+
from testing.metrics import Metrics
|
15 |
+
import omegaconf
|
16 |
|
17 |
os.environ['CUDA_VISIBLE_DEVICES'] = "1"
|
18 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
+
def train(cfg: DictConfig):
|
22 |
+
wandb.login(key=cfg['wandb']['api_key'], host=cfg['wandb']['host'])
|
23 |
+
wandb.init(project=cfg['wandb']['project'],
|
24 |
+
notes=cfg['wandb']['notes'],
|
25 |
+
tags=cfg['wandb']['tags'],
|
26 |
+
config=omegaconf.OmegaConf.to_container(
|
27 |
+
cfg, resolve=True, throw_on_missing=True))
|
28 |
|
29 |
+
checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'])
|
30 |
+
metrics = Metrics(rate=cfg['dataloader']['sample_rate'])
|
31 |
|
32 |
+
model = get_model(cfg['model']).to(device)
|
|
|
|
|
33 |
optimizer = get_optimizer(model.parameters(), cfg['optimizer'])
|
34 |
loss_fn = get_loss(cfg['loss'])
|
35 |
+
train_dataset, valid_dataset = get_datasets(cfg)
|
36 |
|
37 |
+
training_loader = DataLoader(train_dataset, batch_size=cfg['dataloader']['train_batch_size'], shuffle=True)
|
38 |
+
validation_loader = DataLoader(valid_dataset, batch_size=cfg['dataloader']['valid_batch_size'], shuffle=True)
|
|
|
|
|
39 |
|
40 |
+
wandb.watch(model, log_freq=100)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
+
for epoch in range(cfg['training']['num_epochs']):
|
43 |
model.train(True)
|
|
|
|
|
|
|
|
|
44 |
for i, data in enumerate(training_loader):
|
45 |
inputs, labels = data
|
46 |
inputs, labels = inputs.to(device), labels.to(device)
|
|
|
47 |
optimizer.zero_grad()
|
|
|
48 |
outputs = model(inputs)
|
|
|
49 |
loss = loss_fn(outputs, labels)
|
50 |
loss.backward()
|
|
|
51 |
optimizer.step()
|
52 |
|
53 |
+
if i % cfg['wandb']['log_interval'] == 0:
|
54 |
+
wandb.log({"loss": loss})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
model.train(False)
|
57 |
|
58 |
+
running_vloss, running_pesq, running_stoi = 0.0, 0.0, 0.0
|
59 |
with torch.no_grad():
|
60 |
for i, vdata in enumerate(validation_loader):
|
61 |
vinputs, vlabels = vdata
|
|
|
63 |
voutputs = model(vinputs)
|
64 |
vloss = loss_fn(voutputs, vlabels)
|
65 |
running_vloss += vloss
|
66 |
+
running_metrics = metrics.calculate(denoised=voutputs, clean=vlabels)
|
67 |
+
running_pesq += running_metrics['PESQ']
|
68 |
+
running_stoi += running_metrics['STOI']
|
69 |
+
|
70 |
|
71 |
+
avg_vloss = running_vloss / len(validation_loader)
|
72 |
+
avg_pesq = running_pesq / len(validation_loader)
|
73 |
+
avg_stoi = running_stoi / len(validation_loader)
|
74 |
|
75 |
+
wandb.log({"valid_loss": avg_vloss,
|
76 |
+
"valid_pesq": avg_pesq,
|
77 |
+
"valid_stoi": avg_stoi})
|
78 |
+
|
79 |
+
for tag, wav_path in cfg['validation']['wavs'].items():
|
80 |
+
wav, rate = torchaudio.load(Path(cfg['validation']['path']) / wav_path)
|
81 |
wav = torch.reshape(wav, (1, 1, -1)).to(device)
|
82 |
prediction = model(wav)
|
83 |
+
wandb.log({
|
84 |
+
f"{tag}_epoch_{epoch}": wandb.Audio(
|
85 |
+
prediction.cpu()[0][0],
|
86 |
+
sample_rate=rate)})
|
|
|
|
|
|
|
|
|
|
|
87 |
|
88 |
+
checkpoint_saver(model, epoch, metric_val=avg_pesq)
|
89 |
|
90 |
|
91 |
if __name__ == '__main__':
|
transforms.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from torchaudio.transforms import Resample
|
4 |
+
from torchvision.transforms import RandomCrop
|
5 |
+
|
6 |
+
class Transform(torch.nn.Module):
    """Resample a waveform to the target rate, then random-crop (or pad) it to a fixed length."""

    def __init__(
            self,
            input_sr,
            sample_rate,
            max_seconds,
            *args,
            **kwargs
    ):
        super().__init__()
        # Bring source audio from its native rate to the training rate.
        self.resample = Resample(orig_freq=input_sr, new_freq=sample_rate)
        # Fixed window of `max_seconds` at the target rate; short clips are padded.
        self.random_crop = RandomCrop((1, int(max_seconds * sample_rate)), pad_if_needed=True)

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        cropped = self.random_crop(self.resample(waveform))
        return cropped
|
utils.py
CHANGED
@@ -14,11 +14,6 @@ def collect_valentini_paths(dataset_path):
|
|
14 |
return clean_wavs, noisy_wavs
|
15 |
|
16 |
|
17 |
-
def load_wav(path):
|
18 |
-
wav, org_sr = torchaudio.load(path)
|
19 |
-
wav = torchaudio.functional.resample(wav, orig_freq=org_sr, new_freq=16000)
|
20 |
-
return wav
|
21 |
-
|
22 |
|
23 |
|
24 |
def plot_spectrogram(stft, title="Spectrogram", xlim=None):
|
|
|
14 |
return clean_wavs, noisy_wavs
|
15 |
|
16 |
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
|
19 |
def plot_spectrogram(stft, title="Spectrogram", xlim=None):
|