BorisovMaksim committed on
Commit
9ff4511
·
1 Parent(s): 3f8f152

refactored code to work with hydra and wandb

Browse files
checkpoing_saver.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import logging
4
+ import torch
5
+ import wandb
6
+
7
+ class CheckpointSaver:
8
+ def __init__(self, dirpath, decreasing=True, top_n=5):
9
+ """
10
+ dirpath: Directory path where to store all model weights
11
+ decreasing: If decreasing is `True`, then lower metric is better
12
+ top_n: Total number of models to track based on validation metric value
13
+ """
14
+ if not os.path.exists(dirpath): os.makedirs(dirpath)
15
+ self.dirpath = dirpath
16
+ self.top_n = top_n
17
+ self.decreasing = decreasing
18
+ self.top_model_paths = []
19
+ self.best_metric_val = np.Inf if decreasing else -np.Inf
20
+
21
+ def __call__(self, model, epoch, metric_val):
22
+ model_path = os.path.join(self.dirpath, model.__class__.__name__ + f'_epoch{epoch}.pt')
23
+ save = metric_val < self.best_metric_val if self.decreasing else metric_val > self.best_metric_val
24
+ if save:
25
+ logging.info(
26
+ f"Current metric value better than {metric_val} better than best {self.best_metric_val}, saving model at {model_path}, & logging model weights to W&B.")
27
+ self.best_metric_val = metric_val
28
+ torch.save(model.state_dict(), model_path)
29
+ self.log_artifact(f'model-ckpt-epoch-{epoch}.pt', model_path, metric_val)
30
+ self.top_model_paths.append({'path': model_path, 'score': metric_val})
31
+ self.top_model_paths = sorted(self.top_model_paths, key=lambda o: o['score'], reverse=not self.decreasing)
32
+ if len(self.top_model_paths) > self.top_n:
33
+ self.cleanup()
34
+
35
+ def log_artifact(self, filename, model_path, metric_val):
36
+ artifact = wandb.Artifact(filename, type='model', metadata={'Validation score': metric_val})
37
+ artifact.add_file(model_path)
38
+ wandb.run.log_artifact(artifact)
39
+
40
+ def cleanup(self):
41
+ to_remove = self.top_model_paths[self.top_n:]
42
+ logging.info(f"Removing extra models.. {to_remove}")
43
+ for o in to_remove:
44
+ os.remove(o['path'])
45
+ self.top_model_paths = self.top_model_paths[:self.top_n]
conf/config.yaml CHANGED
@@ -4,12 +4,29 @@ defaults:
4
  - loss: mse
5
  - optimizer: sgd
6
 
 
 
 
7
 
8
  dataloader:
9
  max_seconds: 2
10
  sample_rate: 16000
11
- batch_size: 12
 
12
 
 
 
 
 
 
 
13
 
14
- augmentations:
15
- - random_crop
 
 
 
 
 
 
 
 
4
  - loss: mse
5
  - optimizer: sgd
6
 
7
+ training:
8
+ num_epochs: 5
9
+ model_save_path: /media/public/checkpoints
10
 
11
  dataloader:
12
  max_seconds: 2
13
  sample_rate: 16000
14
+ train_batch_size: 12
15
+ valid_batch_size: 12
16
 
17
+ validation:
18
+ path: /media/public/datasets/denoising/DS_10283_2791/noisy_testset_wav
19
+ wavs:
20
+ easy: p232_284.wav
21
+ medium: p232_071.wav
22
+ hard : p257_171.wav
23
 
24
+
25
+ wandb:
26
+ project: denoising
27
+ log_interval: 100
28
+ api_key: local-e23d01ece807cb31e69b2cf4137e4998e4b9856f
29
+ host: http://localhost:8080/
30
+ notes: "Experiment note"
31
+ tags:
32
+ - baseline
conf/dataset/valentini.yaml CHANGED
@@ -1,3 +1,4 @@
1
-
2
- name: valentini
3
- path: /media/public/dataset/denoising/DS_10283_2791/
 
 
1
+ valentini:
2
+ dataset_path: /media/public/datasets/denoising/DS_10283_2791/
3
+ val_fraction: 0.2
4
+ sample_rate: 48000
datasets/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+ from torch.utils.data import Dataset
3
+
4
+ from datasets.valentini import Valentini
5
+ from transforms import Transform
6
+
7
+ DATASETS_POOL = {
8
+ 'valentini': Valentini
9
+ }
10
+
11
+
12
+
13
+ def get_datasets(cfg) -> Tuple[Dataset, Dataset]:
14
+ name, dataset_params = list(cfg['dataset'].items())[0]
15
+ transform = Transform(input_sr=dataset_params['sample_rate'], **cfg['dataloader'])
16
+ train_dataset = DATASETS_POOL[name](valid=False, transform=transform, **dataset_params)
17
+ valid_dataset = DATASETS_POOL[name](valid=True, transform=transform, **dataset_params)
18
+ return train_dataset, valid_dataset
datasets.py → datasets/valentini.py RENAMED
@@ -1,17 +1,19 @@
1
  import torch
2
  from torch.utils.data import Dataset
3
  from pathlib import Path
4
- from utils import load_wav
 
 
5
 
 
6
 
7
  class Valentini(Dataset):
8
- def __init__(self, dataset_path='/media/public/dataset/denoising/DS_10283_2791/', transform=None,
9
- valid=False):
10
  clean_path = Path(dataset_path) / 'clean_trainset_56spk_wav'
11
  noisy_path = Path(dataset_path) / 'noisy_trainset_56spk_wav'
12
  clean_wavs = list(clean_path.glob("*"))
13
  noisy_wavs = list(noisy_path.glob("*"))
14
- valid_threshold = int(len(clean_wavs) * 0.2)
15
  if valid:
16
  self.clean_wavs = clean_wavs[:valid_threshold]
17
  self.noisy_wavs = noisy_wavs[:valid_threshold]
@@ -22,16 +24,17 @@ class Valentini(Dataset):
22
  assert len(self.clean_wavs) == len(self.noisy_wavs)
23
 
24
  self.transform = transform
 
25
 
26
  def __len__(self):
27
  return len(self.clean_wavs)
28
 
29
  def __getitem__(self, idx):
30
- noisy_wav = load_wav(self.noisy_wavs[idx])
31
- clean_wav = load_wav(self.clean_wavs[idx])
32
 
33
  if self.transform:
34
- random_seed = torch.randint(100, (1,))[0]
35
  torch.manual_seed(random_seed)
36
  noisy_wav = self.transform(noisy_wav)
37
  torch.manual_seed(random_seed)
@@ -39,8 +42,5 @@ class Valentini(Dataset):
39
  return noisy_wav, clean_wav
40
 
41
 
42
- DATASETS_POOL = {
43
- 'valentini': Valentini
44
- }
45
 
46
 
 
1
  import torch
2
  from torch.utils.data import Dataset
3
  from pathlib import Path
4
+ import torchaudio
5
+ import numpy as np
6
+ from torchaudio.transforms import Resample
7
 
8
+ HIGH_RANDOM_SEED = 1000
9
 
10
  class Valentini(Dataset):
11
+ def __init__(self, dataset_path, val_fraction, transform=None, valid=False, *args, **kwargs):
 
12
  clean_path = Path(dataset_path) / 'clean_trainset_56spk_wav'
13
  noisy_path = Path(dataset_path) / 'noisy_trainset_56spk_wav'
14
  clean_wavs = list(clean_path.glob("*"))
15
  noisy_wavs = list(noisy_path.glob("*"))
16
+ valid_threshold = int(len(clean_wavs) * val_fraction)
17
  if valid:
18
  self.clean_wavs = clean_wavs[:valid_threshold]
19
  self.noisy_wavs = noisy_wavs[:valid_threshold]
 
24
  assert len(self.clean_wavs) == len(self.noisy_wavs)
25
 
26
  self.transform = transform
27
+ self.valid = valid
28
 
29
  def __len__(self):
30
  return len(self.clean_wavs)
31
 
32
  def __getitem__(self, idx):
33
+ noisy_wav, noisy_sr = torchaudio.load(self.noisy_wavs[idx])
34
+ clean_wav, clean_sr = torchaudio.load(self.clean_wavs[idx])
35
 
36
  if self.transform:
37
+ random_seed = 0 if self.valid else torch.randint(HIGH_RANDOM_SEED, (1,))[0]
38
  torch.manual_seed(random_seed)
39
  noisy_wav = self.transform(noisy_wav)
40
  torch.manual_seed(random_seed)
 
42
  return noisy_wav, clean_wav
43
 
44
 
 
 
 
45
 
46
 
denoisers/demucs.py CHANGED
@@ -34,7 +34,7 @@ class Decoder(torch.nn.Module):
34
  self.glu = torch.nn.GLU(dim=-2)
35
  self.conv2 = torch.nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels,
36
  kernel_size=cfg['conv2']['kernel_size'],
37
- stride=cfg['conv2']['kernel_size'])
38
  self.relu = torch.nn.ReLU()
39
 
40
  def forward(self, x):
 
34
  self.glu = torch.nn.GLU(dim=-2)
35
  self.conv2 = torch.nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels,
36
  kernel_size=cfg['conv2']['kernel_size'],
37
+ stride=cfg['conv2']['stride'])
38
  self.relu = torch.nn.ReLU()
39
 
40
  def forward(self, x):
testing/metrics.py CHANGED
@@ -12,7 +12,7 @@ class Metrics:
12
  self.snr = SignalNoiseRatio()
13
 
14
  def calculate(self, denoised, clean):
15
- return {'PESQ': self.nb_pesq(denoised, clean),
16
- 'STOI': self.stoi(denoised, clean)}
17
 
18
 
 
12
  self.snr = SignalNoiseRatio()
13
 
14
  def calculate(self, denoised, clean):
15
+ return {'PESQ': self.nb_pesq(denoised, clean).item(),
16
+ 'STOI': self.stoi(denoised, clean).item()}
17
 
18
 
train.py CHANGED
@@ -1,94 +1,61 @@
1
  import os
2
- from torch.utils.tensorboard import SummaryWriter
3
  import torch
4
- from torch.nn import Sequential
5
  from torch.utils.data import DataLoader
6
- from datetime import datetime
7
- from torchvision.transforms import RandomCrop
8
- from utils import load_wav
9
- from denoisers.demucs import Demucs
10
  from pathlib import Path
11
  from omegaconf import DictConfig
 
 
12
 
13
- from optimizers import OPTIMIZERS_POOL
14
- from losses import LOSSES
15
- from datasets import DATASETS_POOL
16
  from denoisers import get_model
17
  from optimizers import get_optimizer
18
  from losses import get_loss
19
-
 
 
20
 
21
  os.environ['CUDA_VISIBLE_DEVICES'] = "1"
22
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23
 
24
- #
25
- #
26
- # DATASET_PATH = Path('/media/public/dataset/denoising/DS_10283_2791/')
27
- # VALID_WAVS = {'hard': 'p257_171.wav',
28
- # 'medium': 'p232_071.wav',
29
- # 'easy': 'p232_284.wav'}
30
- # MAX_SECONDS = 2
31
- # SAMPLE_RATE = 16000
32
- #
33
- # transform = Sequential(RandomCrop((1, int(MAX_SECONDS * SAMPLE_RATE)), pad_if_needed=True))
34
- #
35
- # training_loader = DataLoader(Valentini(valid=False, transform=transform), batch_size=12, shuffle=True)
36
- # validation_loader = DataLoader(Valentini(valid=True, transform=transform), batch_size=12, shuffle=True)
37
 
 
 
 
 
 
 
 
38
 
 
 
39
 
40
- def train(cfg: DictConfig):
41
- timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
42
- model = get_model(cfg['model'])
43
  optimizer = get_optimizer(model.parameters(), cfg['optimizer'])
44
  loss_fn = get_loss(cfg['loss'])
 
45
 
46
- writer = SummaryWriter('runs/denoising_trainer_{}'.format(timestamp))
47
- epoch_number = 0
48
-
49
- EPOCHS = 5
50
 
51
- best_vloss = 1_000_000.
52
-
53
- for tag, wav_path in VALID_WAVS.items():
54
- wav = load_wav(DATASET_PATH / 'noisy_testset_wav' / wav_path)
55
- writer.add_audio(tag=tag, snd_tensor=wav, sample_rate=SAMPLE_RATE)
56
- writer.flush()
57
-
58
- for epoch in range(EPOCHS):
59
- print('EPOCH {}:'.format(epoch_number + 1))
60
 
 
61
  model.train(True)
62
-
63
- running_loss = 0.
64
- last_loss = 0.
65
-
66
  for i, data in enumerate(training_loader):
67
  inputs, labels = data
68
  inputs, labels = inputs.to(device), labels.to(device)
69
-
70
  optimizer.zero_grad()
71
-
72
  outputs = model(inputs)
73
-
74
  loss = loss_fn(outputs, labels)
75
  loss.backward()
76
-
77
  optimizer.step()
78
 
79
- running_loss += loss.item()
80
- if i % 1000 == 999:
81
- last_loss = running_loss / 1000 # loss per batch
82
- print(' batch {} loss: {}'.format(i + 1, last_loss))
83
- tb_x = epoch_number * len(training_loader) + i + 1
84
- writer.add_scalar('Loss/train', last_loss, tb_x)
85
- running_loss = 0.
86
-
87
- avg_loss = last_loss
88
 
89
  model.train(False)
90
 
91
- running_vloss = 0.0
92
  with torch.no_grad():
93
  for i, vdata in enumerate(validation_loader):
94
  vinputs, vlabels = vdata
@@ -96,28 +63,29 @@ def train(cfg: DictConfig):
96
  voutputs = model(vinputs)
97
  vloss = loss_fn(voutputs, vlabels)
98
  running_vloss += vloss
 
 
 
 
99
 
100
- avg_vloss = running_vloss / (i + 1)
101
- print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
 
102
 
103
- writer.add_scalars('Training vs. Validation Loss',
104
- {'Training': avg_loss, 'Validation': avg_vloss},
105
- epoch_number + 1)
106
- for tag, wav_path in VALID_WAVS.items():
107
- wav = load_wav(DATASET_PATH / 'noisy_testset_wav' / wav_path)
 
108
  wav = torch.reshape(wav, (1, 1, -1)).to(device)
109
  prediction = model(wav)
110
- writer.add_audio(tag=f"Model predicted {tag} on epoch {epoch}",
111
- snd_tensor=prediction,
112
- sample_rate=SAMPLE_RATE)
113
- writer.flush()
114
-
115
- if avg_vloss < best_vloss:
116
- best_vloss = avg_vloss
117
- model_path = 'checkpoints/model_{}_{}'.format(timestamp, epoch_number)
118
- torch.save(model.state_dict(), model_path)
119
 
120
- epoch_number += 1
121
 
122
 
123
  if __name__ == '__main__':
 
1
  import os
 
2
  import torch
 
3
  from torch.utils.data import DataLoader
 
 
 
 
4
  from pathlib import Path
5
  from omegaconf import DictConfig
6
+ import wandb
7
+ import torchaudio
8
 
9
+ from checkpoing_saver import CheckpointSaver
 
 
10
  from denoisers import get_model
11
  from optimizers import get_optimizer
12
  from losses import get_loss
13
+ from datasets import get_datasets
14
+ from testing.metrics import Metrics
15
+ import omegaconf
16
 
17
  os.environ['CUDA_VISIBLE_DEVICES'] = "1"
18
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ def train(cfg: DictConfig):
22
+ wandb.login(key=cfg['wandb']['api_key'], host=cfg['wandb']['host'])
23
+ wandb.init(project=cfg['wandb']['project'],
24
+ notes=cfg['wandb']['notes'],
25
+ tags=cfg['wandb']['tags'],
26
+ config=omegaconf.OmegaConf.to_container(
27
+ cfg, resolve=True, throw_on_missing=True))
28
 
29
+ checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'])
30
+ metrics = Metrics(rate=cfg['dataloader']['sample_rate'])
31
 
32
+ model = get_model(cfg['model']).to(device)
 
 
33
  optimizer = get_optimizer(model.parameters(), cfg['optimizer'])
34
  loss_fn = get_loss(cfg['loss'])
35
+ train_dataset, valid_dataset = get_datasets(cfg)
36
 
37
+ training_loader = DataLoader(train_dataset, batch_size=cfg['dataloader']['train_batch_size'], shuffle=True)
38
+ validation_loader = DataLoader(valid_dataset, batch_size=cfg['dataloader']['valid_batch_size'], shuffle=True)
 
 
39
 
40
+ wandb.watch(model, log_freq=100)
 
 
 
 
 
 
 
 
41
 
42
+ for epoch in range(cfg['training']['num_epochs']):
43
  model.train(True)
 
 
 
 
44
  for i, data in enumerate(training_loader):
45
  inputs, labels = data
46
  inputs, labels = inputs.to(device), labels.to(device)
 
47
  optimizer.zero_grad()
 
48
  outputs = model(inputs)
 
49
  loss = loss_fn(outputs, labels)
50
  loss.backward()
 
51
  optimizer.step()
52
 
53
+ if i % cfg['wandb']['log_interval'] == 0:
54
+ wandb.log({"loss": loss})
 
 
 
 
 
 
 
55
 
56
  model.train(False)
57
 
58
+ running_vloss, running_pesq, running_stoi = 0.0, 0.0, 0.0
59
  with torch.no_grad():
60
  for i, vdata in enumerate(validation_loader):
61
  vinputs, vlabels = vdata
 
63
  voutputs = model(vinputs)
64
  vloss = loss_fn(voutputs, vlabels)
65
  running_vloss += vloss
66
+ running_metrics = metrics.calculate(denoised=voutputs, clean=vlabels)
67
+ running_pesq += running_metrics['PESQ']
68
+ running_stoi += running_metrics['STOI']
69
+
70
 
71
+ avg_vloss = running_vloss / len(validation_loader)
72
+ avg_pesq = running_pesq / len(validation_loader)
73
+ avg_stoi = running_stoi / len(validation_loader)
74
 
75
+ wandb.log({"valid_loss": avg_vloss,
76
+ "valid_pesq": avg_pesq,
77
+ "valid_stoi": avg_stoi})
78
+
79
+ for tag, wav_path in cfg['validation']['wavs'].items():
80
+ wav, rate = torchaudio.load(Path(cfg['validation']['path']) / wav_path)
81
  wav = torch.reshape(wav, (1, 1, -1)).to(device)
82
  prediction = model(wav)
83
+ wandb.log({
84
+ f"{tag}_epoch_{epoch}": wandb.Audio(
85
+ prediction.cpu()[0][0],
86
+ sample_rate=rate)})
 
 
 
 
 
87
 
88
+ checkpoint_saver(model, epoch, metric_val=avg_pesq)
89
 
90
 
91
  if __name__ == '__main__':
transforms.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torchaudio.transforms import Resample
4
+ from torchvision.transforms import RandomCrop
5
+
6
+ class Transform(torch.nn.Module):
7
+ def __init__(
8
+ self,
9
+ input_sr,
10
+ sample_rate,
11
+ max_seconds,
12
+ *args,
13
+ **kwargs
14
+ ):
15
+ super().__init__()
16
+ self.resample = Resample(orig_freq=input_sr, new_freq=sample_rate)
17
+ self.random_crop = RandomCrop((1, int(max_seconds * sample_rate)), pad_if_needed=True)
18
+
19
+ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
20
+ resampled = self.resample(waveform)
21
+ croped = self.random_crop(resampled)
22
+ return croped
utils.py CHANGED
@@ -14,11 +14,6 @@ def collect_valentini_paths(dataset_path):
14
  return clean_wavs, noisy_wavs
15
 
16
 
17
- def load_wav(path):
18
- wav, org_sr = torchaudio.load(path)
19
- wav = torchaudio.functional.resample(wav, orig_freq=org_sr, new_freq=16000)
20
- return wav
21
-
22
 
23
 
24
  def plot_spectrogram(stft, title="Spectrogram", xlim=None):
 
14
  return clean_wavs, noisy_wavs
15
 
16
 
 
 
 
 
 
17
 
18
 
19
  def plot_spectrogram(stft, title="Spectrogram", xlim=None):