BorisovMaksim committed on
Commit
95d8ea8
·
1 Parent(s): 1160793

Add try/except for calculating PESQ scores

Browse files
Files changed (7) hide show
  1. checkpoing_saver.py +10 -3
  2. conf/config.yaml +4 -1
  3. datasets/minimal.py +1 -2
  4. losses.py +9 -6
  5. main.py +2 -1
  6. testing/metrics.py +14 -3
  7. train.py +19 -13
checkpoing_saver.py CHANGED
@@ -19,14 +19,21 @@ class CheckpointSaver:
19
  self.best_metric_val = np.Inf if decreasing else -np.Inf
20
  self.run_name = run_name
21
 
22
- def __call__(self, model, epoch, metric_val):
23
- model_path = os.path.join(self.dirpath, self.run_name, model.__class__.__name__ + f'_epoch{epoch}.pt')
 
24
  save = metric_val < self.best_metric_val if self.decreasing else metric_val > self.best_metric_val
25
  if save:
26
  logging.info(
27
  f"Current metric value better than {metric_val} better than best {self.best_metric_val}, saving model at {model_path}, & logging model weights to W&B.")
28
  self.best_metric_val = metric_val
29
- torch.save(model.state_dict(), model_path)
 
 
 
 
 
 
30
  self.log_artifact(f'model-ckpt-epoch-{epoch}.pt', model_path, metric_val)
31
  self.top_model_paths.append({'path': model_path, 'score': metric_val})
32
  self.top_model_paths = sorted(self.top_model_paths, key=lambda o: o['score'], reverse=not self.decreasing)
 
19
  self.best_metric_val = np.Inf if decreasing else -np.Inf
20
  self.run_name = run_name
21
 
22
+
23
+ def __call__(self, model, epoch, metric_val, optimizer, loss):
24
+ model_path = os.path.join(self.dirpath, model.__class__.__name__ + f'_{self.run_name}_epoch{epoch}.pt')
25
  save = metric_val < self.best_metric_val if self.decreasing else metric_val > self.best_metric_val
26
  if save:
27
  logging.info(
28
  f"Current metric value better than {metric_val} better than best {self.best_metric_val}, saving model at {model_path}, & logging model weights to W&B.")
29
  self.best_metric_val = metric_val
30
+ torch.save(
31
+ { # Save our checkpoint loc
32
+ 'epoch': epoch,
33
+ 'model_state_dict': model.state_dict(),
34
+ 'optimizer_state_dict': optimizer.state_dict(),
35
+ 'loss': loss,
36
+ }, model_path)
37
  self.log_artifact(f'model-ckpt-epoch-{epoch}.pt', model_path, metric_val)
38
  self.top_model_paths.append({'path': model_path, 'score': metric_val})
39
  self.top_model_paths = sorted(self.top_model_paths, key=lambda o: o['score'], reverse=not self.decreasing)
conf/config.yaml CHANGED
@@ -20,10 +20,13 @@ validation:
20
 
21
 
22
  wandb:
 
23
  project: denoising
24
  log_interval: 100
25
  api_key: local-e23d01ece807cb31e69b2cf4137e4998e4b9856f
26
  host: http://localhost:8080/
27
  notes: "Experiment note"
28
  tags:
29
- - baseline
 
 
 
20
 
21
 
22
  wandb:
23
+ run_name: default
24
  project: denoising
25
  log_interval: 100
26
  api_key: local-e23d01ece807cb31e69b2cf4137e4998e4b9856f
27
  host: http://localhost:8080/
28
  notes: "Experiment note"
29
  tags:
30
+ - baseline
31
+
32
+ gpu: 0
datasets/minimal.py CHANGED
@@ -18,7 +18,6 @@ class Minimal(Dataset):
18
  return len(self.wavs)
19
 
20
  def __getitem__(self, idx):
21
- wav, rate = torchaudio.load(self.wavs[idx])
22
  wav = self.resampler(wav)
23
- wav = torch.reshape(wav, (1, 1, -1))
24
  return wav, self.target_rate
 
18
  return len(self.wavs)
19
 
20
  def __getitem__(self, idx):
21
+ wav, rate = torchaudio.load(Path(self.dataset_path) / self.wavs[idx])
22
  wav = self.resampler(wav)
 
23
  return wav, self.target_rate
losses.py CHANGED
@@ -12,6 +12,8 @@
12
  import torch
13
  import torch.nn.functional as F
14
 
 
 
15
  """STFT-based Loss modules."""
16
 
17
 
@@ -26,7 +28,8 @@ def stft(x, fft_size, hop_size, win_length, window):
26
  Returns:
27
  Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
28
  """
29
- x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
 
30
  real = x_stft[..., 0]
31
  imag = x_stft[..., 1]
32
 
@@ -154,7 +157,7 @@ class L1_Multi_STFT(torch.nn.Module):
154
  """Initialize STFT loss module."""
155
  super(L1_Multi_STFT, self).__init__()
156
  self.multi_STFT_loss = MultiResolutionSTFTLoss()
157
- self.l1_loss = torch.nn.L1Loss()
158
 
159
  def forward(self, x, y):
160
  """Calculate forward propagation.
@@ -173,10 +176,10 @@ class L1_Multi_STFT(torch.nn.Module):
173
  LOSSES = {
174
  'mse': torch.nn.MSELoss(),
175
  'L1': torch.nn.L1Loss(),
176
- 'Multi_STFT': MultiResolutionSTFTLoss,
177
- 'L1_Multi_STFT': L1_Multi_STFT
178
  }
179
 
180
 
181
- def get_loss(loss_config):
182
- return LOSSES[loss_config['name']]
 
12
  import torch
13
  import torch.nn.functional as F
14
 
15
+ # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
+
17
  """STFT-based Loss modules."""
18
 
19
 
 
28
  Returns:
29
  Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
30
  """
31
+ x_stft = torch.stft(x[:, 0, :], fft_size, hop_size, win_length, window, return_complex=True)
32
+ x_stft = torch.view_as_real(x_stft)
33
  real = x_stft[..., 0]
34
  imag = x_stft[..., 1]
35
 
 
157
  """Initialize STFT loss module."""
158
  super(L1_Multi_STFT, self).__init__()
159
  self.multi_STFT_loss = MultiResolutionSTFTLoss()
160
+ self.l1_loss = torch.nn.L1Loss()
161
 
162
  def forward(self, x, y):
163
  """Calculate forward propagation.
 
176
  LOSSES = {
177
  'mse': torch.nn.MSELoss(),
178
  'L1': torch.nn.L1Loss(),
179
+ 'Multi_STFT': MultiResolutionSTFTLoss(),
180
+ 'L1_Multi_STFT': L1_Multi_STFT()
181
  }
182
 
183
 
184
+ def get_loss(loss_config, device):
185
+ return LOSSES[loss_config['name']].to(device)
main.py CHANGED
@@ -1,8 +1,9 @@
1
  import hydra
2
- from omegaconf import DictConfig, OmegaConf
3
  from train import train
4
 
5
 
 
6
  @hydra.main(version_base=None, config_path="conf", config_name="config")
7
  def main(cfg: DictConfig):
8
  train(cfg)
 
1
  import hydra
2
+ from omegaconf import DictConfig
3
  from train import train
4
 
5
 
6
+
7
  @hydra.main(version_base=None, config_path="conf", config_name="config")
8
  def main(cfg: DictConfig):
9
  train(cfg)
testing/metrics.py CHANGED
@@ -1,10 +1,10 @@
 
1
  from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
2
  from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
3
  import torch
4
  import torchaudio
5
  from torchmetrics import SignalNoiseRatio
6
 
7
-
8
  class Metrics:
9
  def __init__(self, rate=16000):
10
  self.nb_pesq = PerceptualEvaluationSpeechQuality(rate, 'wb')
@@ -12,7 +12,18 @@ class Metrics:
12
  self.snr = SignalNoiseRatio()
13
 
14
  def calculate(self, denoised, clean):
15
- return {'PESQ': self.nb_pesq(denoised, clean).item(),
16
- 'STOI': self.stoi(denoised, clean).item()}
 
 
 
 
 
 
 
 
 
 
 
17
 
18
 
 
1
+ import pesq
2
  from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
3
  from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
4
  import torch
5
  import torchaudio
6
  from torchmetrics import SignalNoiseRatio
7
 
 
8
  class Metrics:
9
  def __init__(self, rate=16000):
10
  self.nb_pesq = PerceptualEvaluationSpeechQuality(rate, 'wb')
 
12
  self.snr = SignalNoiseRatio()
13
 
14
  def calculate(self, denoised, clean):
15
+ pesq_scores, stoi_scores = 0, 0
16
+ for denoised_wav, clean_wav in zip(denoised, clean):
17
+ try:
18
+ pesq_scores += self.nb_pesq(denoised_wav, clean_wav).item()
19
+ stoi_scores += self.stoi(denoised_wav, clean_wav).item()
20
+ except pesq.NoUtterancesError as e:
21
+ print(e)
22
+ except ValueError as e:
23
+ print(e)
24
+
25
+
26
+ return {'PESQ': pesq_scores,
27
+ 'STOI': stoi_scores}
28
 
29
 
train.py CHANGED
@@ -13,24 +13,24 @@ from datasets import get_datasets
13
  from testing.metrics import Metrics
14
  from datasets.minimal import Minimal
15
 
16
- os.environ['CUDA_VISIBLE_DEVICES'] = "1"
17
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18
-
19
 
20
  def train(cfg: DictConfig):
 
 
21
  wandb.login(key=cfg['wandb']['api_key'], host=cfg['wandb']['host'])
22
  wandb.init(project=cfg['wandb']['project'],
23
  notes=cfg['wandb']['notes'],
24
  tags=cfg['wandb']['tags'],
25
  config=omegaconf.OmegaConf.to_container(
26
  cfg, resolve=True, throw_on_missing=True))
 
27
 
28
  checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'], run_name=wandb.run.name)
29
  metrics = Metrics(rate=cfg['dataloader']['sample_rate'])
30
 
31
  model = get_model(cfg['model']).to(device)
32
  optimizer = get_optimizer(model.parameters(), cfg['optimizer'])
33
- loss_fn = get_loss(cfg['loss'])
34
  train_dataset, valid_dataset = get_datasets(cfg)
35
  minimal_dataset = Minimal(cfg)
36
 
@@ -70,12 +70,13 @@ def train(cfg: DictConfig):
70
  running_stoi += running_metrics['STOI']
71
 
72
  if phase == 'train' and i % cfg['wandb']['log_interval'] == 0:
73
- wandb.log({"train_loss": running_loss / (i + 1),
74
- "train_pesq": running_pesq / (i + 1),
75
- "train_stoi": running_stoi / (i + 1)})
76
- epoch_loss = running_loss / len(dataloaders[phase])
77
- eposh_pesq = running_pesq / len(dataloaders[phase])
78
- eposh_stoi = running_stoi / len(dataloaders[phase])
 
79
 
80
  wandb.log({f"{phase}_loss": epoch_loss,
81
  f"{phase}_pesq": eposh_pesq,
@@ -83,10 +84,15 @@ def train(cfg: DictConfig):
83
 
84
  if phase == 'val':
85
  for i, (wav, rate) in enumerate(dataloaders['minimal']):
86
- prediction = model(wav)
87
  wandb.log({
88
  f"{i}_example": wandb.Audio(
89
- prediction.cpu()[0][0],
90
  sample_rate=rate)})
91
 
92
- checkpoint_saver(model, epoch, metric_val=eposh_pesq)
 
 
 
 
 
 
13
  from testing.metrics import Metrics
14
  from datasets.minimal import Minimal
15
 
 
 
 
16
 
17
  def train(cfg: DictConfig):
18
+ device = torch.device(f'cuda:{cfg.gpu}' if torch.cuda.is_available() else 'cpu')
19
+
20
  wandb.login(key=cfg['wandb']['api_key'], host=cfg['wandb']['host'])
21
  wandb.init(project=cfg['wandb']['project'],
22
  notes=cfg['wandb']['notes'],
23
  tags=cfg['wandb']['tags'],
24
  config=omegaconf.OmegaConf.to_container(
25
  cfg, resolve=True, throw_on_missing=True))
26
+ wandb.run.name = cfg['wandb']['run_name']
27
 
28
  checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'], run_name=wandb.run.name)
29
  metrics = Metrics(rate=cfg['dataloader']['sample_rate'])
30
 
31
  model = get_model(cfg['model']).to(device)
32
  optimizer = get_optimizer(model.parameters(), cfg['optimizer'])
33
+ loss_fn = get_loss(cfg['loss'], device)
34
  train_dataset, valid_dataset = get_datasets(cfg)
35
  minimal_dataset = Minimal(cfg)
36
 
 
70
  running_stoi += running_metrics['STOI']
71
 
72
  if phase == 'train' and i % cfg['wandb']['log_interval'] == 0:
73
+ wandb.log({"train_loss": running_loss / (i + 1) / inputs.size(0),
74
+ "train_pesq": running_pesq / (i + 1) / inputs.size(0),
75
+ "train_stoi": running_stoi / (i + 1) / inputs.size(0)})
76
+
77
+ epoch_loss = running_loss / len(dataloaders[phase].dataset)
78
+ eposh_pesq = running_pesq / len(dataloaders[phase].dataset)
79
+ eposh_stoi = running_stoi / len(dataloaders[phase].dataset)
80
 
81
  wandb.log({f"{phase}_loss": epoch_loss,
82
  f"{phase}_pesq": eposh_pesq,
 
84
 
85
  if phase == 'val':
86
  for i, (wav, rate) in enumerate(dataloaders['minimal']):
87
+ prediction = model(wav.to(device))
88
  wandb.log({
89
  f"{i}_example": wandb.Audio(
90
+ prediction.detach().cpu().numpy()[0][0],
91
  sample_rate=rate)})
92
 
93
+ checkpoint_saver(model, epoch, metric_val=eposh_pesq,
94
+ optimizer=optimizer, loss=epoch_loss)
95
+
96
+
97
+ if __name__ == "__main__":
98
+ pass