Spaces:
Runtime error
Runtime error
Commit
·
95d8ea8
1
Parent(s):
1160793
Add try/except for calculating PESQ scores
Browse files
- checkpoing_saver.py +10 -3
- conf/config.yaml +4 -1
- datasets/minimal.py +1 -2
- losses.py +9 -6
- main.py +2 -1
- testing/metrics.py +14 -3
- train.py +19 -13
checkpoing_saver.py
CHANGED
@@ -19,14 +19,21 @@ class CheckpointSaver:
|
|
19 |
self.best_metric_val = np.Inf if decreasing else -np.Inf
|
20 |
self.run_name = run_name
|
21 |
|
22 |
-
|
23 |
-
|
|
|
24 |
save = metric_val < self.best_metric_val if self.decreasing else metric_val > self.best_metric_val
|
25 |
if save:
|
26 |
logging.info(
|
27 |
f"Current metric value better than {metric_val} better than best {self.best_metric_val}, saving model at {model_path}, & logging model weights to W&B.")
|
28 |
self.best_metric_val = metric_val
|
29 |
-
torch.save(
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
self.log_artifact(f'model-ckpt-epoch-{epoch}.pt', model_path, metric_val)
|
31 |
self.top_model_paths.append({'path': model_path, 'score': metric_val})
|
32 |
self.top_model_paths = sorted(self.top_model_paths, key=lambda o: o['score'], reverse=not self.decreasing)
|
|
|
19 |
self.best_metric_val = np.Inf if decreasing else -np.Inf
|
20 |
self.run_name = run_name
|
21 |
|
22 |
+
|
23 |
+
def __call__(self, model, epoch, metric_val, optimizer, loss):
|
24 |
+
model_path = os.path.join(self.dirpath, model.__class__.__name__ + f'_{self.run_name}_epoch{epoch}.pt')
|
25 |
save = metric_val < self.best_metric_val if self.decreasing else metric_val > self.best_metric_val
|
26 |
if save:
|
27 |
logging.info(
|
28 |
f"Current metric value better than {metric_val} better than best {self.best_metric_val}, saving model at {model_path}, & logging model weights to W&B.")
|
29 |
self.best_metric_val = metric_val
|
30 |
+
torch.save(
|
31 |
+
{ # Save our checkpoint loc
|
32 |
+
'epoch': epoch,
|
33 |
+
'model_state_dict': model.state_dict(),
|
34 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
35 |
+
'loss': loss,
|
36 |
+
}, model_path)
|
37 |
self.log_artifact(f'model-ckpt-epoch-{epoch}.pt', model_path, metric_val)
|
38 |
self.top_model_paths.append({'path': model_path, 'score': metric_val})
|
39 |
self.top_model_paths = sorted(self.top_model_paths, key=lambda o: o['score'], reverse=not self.decreasing)
|
conf/config.yaml
CHANGED
@@ -20,10 +20,13 @@ validation:
|
|
20 |
|
21 |
|
22 |
wandb:
|
|
|
23 |
project: denoising
|
24 |
log_interval: 100
|
25 |
api_key: local-e23d01ece807cb31e69b2cf4137e4998e4b9856f
|
26 |
host: http://localhost:8080/
|
27 |
notes: "Experiment note"
|
28 |
tags:
|
29 |
-
- baseline
|
|
|
|
|
|
20 |
|
21 |
|
22 |
wandb:
|
23 |
+
run_name: default
|
24 |
project: denoising
|
25 |
log_interval: 100
|
26 |
api_key: local-e23d01ece807cb31e69b2cf4137e4998e4b9856f
|
27 |
host: http://localhost:8080/
|
28 |
notes: "Experiment note"
|
29 |
tags:
|
30 |
+
- baseline
|
31 |
+
|
32 |
+
gpu: 0
|
datasets/minimal.py
CHANGED
@@ -18,7 +18,6 @@ class Minimal(Dataset):
|
|
18 |
return len(self.wavs)
|
19 |
|
20 |
def __getitem__(self, idx):
|
21 |
-
wav, rate = torchaudio.load(self.wavs[idx])
|
22 |
wav = self.resampler(wav)
|
23 |
-
wav = torch.reshape(wav, (1, 1, -1))
|
24 |
return wav, self.target_rate
|
|
|
18 |
return len(self.wavs)
|
19 |
|
20 |
def __getitem__(self, idx):
|
21 |
+
wav, rate = torchaudio.load(Path(self.dataset_path) / self.wavs[idx])
|
22 |
wav = self.resampler(wav)
|
|
|
23 |
return wav, self.target_rate
|
losses.py
CHANGED
@@ -12,6 +12,8 @@
|
|
12 |
import torch
|
13 |
import torch.nn.functional as F
|
14 |
|
|
|
|
|
15 |
"""STFT-based Loss modules."""
|
16 |
|
17 |
|
@@ -26,7 +28,8 @@ def stft(x, fft_size, hop_size, win_length, window):
|
|
26 |
Returns:
|
27 |
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
|
28 |
"""
|
29 |
-
x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
|
|
|
30 |
real = x_stft[..., 0]
|
31 |
imag = x_stft[..., 1]
|
32 |
|
@@ -154,7 +157,7 @@ class L1_Multi_STFT(torch.nn.Module):
|
|
154 |
"""Initialize STFT loss module."""
|
155 |
super(L1_Multi_STFT, self).__init__()
|
156 |
self.multi_STFT_loss = MultiResolutionSTFTLoss()
|
157 |
-
self.l1_loss =
|
158 |
|
159 |
def forward(self, x, y):
|
160 |
"""Calculate forward propagation.
|
@@ -173,10 +176,10 @@ class L1_Multi_STFT(torch.nn.Module):
|
|
173 |
LOSSES = {
|
174 |
'mse': torch.nn.MSELoss(),
|
175 |
'L1': torch.nn.L1Loss(),
|
176 |
-
'Multi_STFT': MultiResolutionSTFTLoss,
|
177 |
-
'L1_Multi_STFT': L1_Multi_STFT
|
178 |
}
|
179 |
|
180 |
|
181 |
-
def get_loss(loss_config):
|
182 |
-
return LOSSES[loss_config['name']]
|
|
|
12 |
import torch
|
13 |
import torch.nn.functional as F
|
14 |
|
15 |
+
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
16 |
+
|
17 |
"""STFT-based Loss modules."""
|
18 |
|
19 |
|
|
|
28 |
Returns:
|
29 |
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
|
30 |
"""
|
31 |
+
x_stft = torch.stft(x[:, 0, :], fft_size, hop_size, win_length, window, return_complex=True)
|
32 |
+
x_stft = torch.view_as_real(x_stft)
|
33 |
real = x_stft[..., 0]
|
34 |
imag = x_stft[..., 1]
|
35 |
|
|
|
157 |
"""Initialize STFT loss module."""
|
158 |
super(L1_Multi_STFT, self).__init__()
|
159 |
self.multi_STFT_loss = MultiResolutionSTFTLoss()
|
160 |
+
self.l1_loss = torch.nn.L1Loss()
|
161 |
|
162 |
def forward(self, x, y):
|
163 |
"""Calculate forward propagation.
|
|
|
176 |
LOSSES = {
|
177 |
'mse': torch.nn.MSELoss(),
|
178 |
'L1': torch.nn.L1Loss(),
|
179 |
+
'Multi_STFT': MultiResolutionSTFTLoss(),
|
180 |
+
'L1_Multi_STFT': L1_Multi_STFT()
|
181 |
}
|
182 |
|
183 |
|
184 |
+
def get_loss(loss_config, device):
|
185 |
+
return LOSSES[loss_config['name']].to(device)
|
main.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1 |
import hydra
|
2 |
-
from omegaconf import DictConfig
|
3 |
from train import train
|
4 |
|
5 |
|
|
|
6 |
@hydra.main(version_base=None, config_path="conf", config_name="config")
|
7 |
def main(cfg: DictConfig):
|
8 |
train(cfg)
|
|
|
1 |
import hydra
|
2 |
+
from omegaconf import DictConfig
|
3 |
from train import train
|
4 |
|
5 |
|
6 |
+
|
7 |
@hydra.main(version_base=None, config_path="conf", config_name="config")
|
8 |
def main(cfg: DictConfig):
|
9 |
train(cfg)
|
testing/metrics.py
CHANGED
@@ -1,10 +1,10 @@
|
|
|
|
1 |
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
|
2 |
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
|
3 |
import torch
|
4 |
import torchaudio
|
5 |
from torchmetrics import SignalNoiseRatio
|
6 |
|
7 |
-
|
8 |
class Metrics:
|
9 |
def __init__(self, rate=16000):
|
10 |
self.nb_pesq = PerceptualEvaluationSpeechQuality(rate, 'wb')
|
@@ -12,7 +12,18 @@ class Metrics:
|
|
12 |
self.snr = SignalNoiseRatio()
|
13 |
|
14 |
def calculate(self, denoised, clean):
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
|
|
|
1 |
+
import pesq
|
2 |
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
|
3 |
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
|
4 |
import torch
|
5 |
import torchaudio
|
6 |
from torchmetrics import SignalNoiseRatio
|
7 |
|
|
|
8 |
class Metrics:
|
9 |
def __init__(self, rate=16000):
|
10 |
self.nb_pesq = PerceptualEvaluationSpeechQuality(rate, 'wb')
|
|
|
12 |
self.snr = SignalNoiseRatio()
|
13 |
|
14 |
def calculate(self, denoised, clean):
|
15 |
+
pesq_scores, stoi_scores = 0, 0
|
16 |
+
for denoised_wav, clean_wav in zip(denoised, clean):
|
17 |
+
try:
|
18 |
+
pesq_scores += self.nb_pesq(denoised_wav, clean_wav).item()
|
19 |
+
stoi_scores += self.stoi(denoised_wav, clean_wav).item()
|
20 |
+
except pesq.NoUtterancesError as e:
|
21 |
+
print(e)
|
22 |
+
except ValueError as e:
|
23 |
+
print(e)
|
24 |
+
|
25 |
+
|
26 |
+
return {'PESQ': pesq_scores,
|
27 |
+
'STOI': stoi_scores}
|
28 |
|
29 |
|
train.py
CHANGED
@@ -13,24 +13,24 @@ from datasets import get_datasets
|
|
13 |
from testing.metrics import Metrics
|
14 |
from datasets.minimal import Minimal
|
15 |
|
16 |
-
os.environ['CUDA_VISIBLE_DEVICES'] = "1"
|
17 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
18 |
-
|
19 |
|
20 |
def train(cfg: DictConfig):
|
|
|
|
|
21 |
wandb.login(key=cfg['wandb']['api_key'], host=cfg['wandb']['host'])
|
22 |
wandb.init(project=cfg['wandb']['project'],
|
23 |
notes=cfg['wandb']['notes'],
|
24 |
tags=cfg['wandb']['tags'],
|
25 |
config=omegaconf.OmegaConf.to_container(
|
26 |
cfg, resolve=True, throw_on_missing=True))
|
|
|
27 |
|
28 |
checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'], run_name=wandb.run.name)
|
29 |
metrics = Metrics(rate=cfg['dataloader']['sample_rate'])
|
30 |
|
31 |
model = get_model(cfg['model']).to(device)
|
32 |
optimizer = get_optimizer(model.parameters(), cfg['optimizer'])
|
33 |
-
loss_fn = get_loss(cfg['loss'])
|
34 |
train_dataset, valid_dataset = get_datasets(cfg)
|
35 |
minimal_dataset = Minimal(cfg)
|
36 |
|
@@ -70,12 +70,13 @@ def train(cfg: DictConfig):
|
|
70 |
running_stoi += running_metrics['STOI']
|
71 |
|
72 |
if phase == 'train' and i % cfg['wandb']['log_interval'] == 0:
|
73 |
-
wandb.log({"train_loss": running_loss / (i + 1),
|
74 |
-
"train_pesq": running_pesq / (i + 1),
|
75 |
-
"train_stoi": running_stoi / (i + 1)})
|
76 |
-
|
77 |
-
|
78 |
-
|
|
|
79 |
|
80 |
wandb.log({f"{phase}_loss": epoch_loss,
|
81 |
f"{phase}_pesq": eposh_pesq,
|
@@ -83,10 +84,15 @@ def train(cfg: DictConfig):
|
|
83 |
|
84 |
if phase == 'val':
|
85 |
for i, (wav, rate) in enumerate(dataloaders['minimal']):
|
86 |
-
prediction = model(wav)
|
87 |
wandb.log({
|
88 |
f"{i}_example": wandb.Audio(
|
89 |
-
prediction.cpu()[0][0],
|
90 |
sample_rate=rate)})
|
91 |
|
92 |
-
checkpoint_saver(model, epoch, metric_val=eposh_pesq
|
|
|
|
|
|
|
|
|
|
|
|
13 |
from testing.metrics import Metrics
|
14 |
from datasets.minimal import Minimal
|
15 |
|
|
|
|
|
|
|
16 |
|
17 |
def train(cfg: DictConfig):
|
18 |
+
device = torch.device(f'cuda:{cfg.gpu}' if torch.cuda.is_available() else 'cpu')
|
19 |
+
|
20 |
wandb.login(key=cfg['wandb']['api_key'], host=cfg['wandb']['host'])
|
21 |
wandb.init(project=cfg['wandb']['project'],
|
22 |
notes=cfg['wandb']['notes'],
|
23 |
tags=cfg['wandb']['tags'],
|
24 |
config=omegaconf.OmegaConf.to_container(
|
25 |
cfg, resolve=True, throw_on_missing=True))
|
26 |
+
wandb.run.name = cfg['wandb']['run_name']
|
27 |
|
28 |
checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'], run_name=wandb.run.name)
|
29 |
metrics = Metrics(rate=cfg['dataloader']['sample_rate'])
|
30 |
|
31 |
model = get_model(cfg['model']).to(device)
|
32 |
optimizer = get_optimizer(model.parameters(), cfg['optimizer'])
|
33 |
+
loss_fn = get_loss(cfg['loss'], device)
|
34 |
train_dataset, valid_dataset = get_datasets(cfg)
|
35 |
minimal_dataset = Minimal(cfg)
|
36 |
|
|
|
70 |
running_stoi += running_metrics['STOI']
|
71 |
|
72 |
if phase == 'train' and i % cfg['wandb']['log_interval'] == 0:
|
73 |
+
wandb.log({"train_loss": running_loss / (i + 1) / inputs.size(0),
|
74 |
+
"train_pesq": running_pesq / (i + 1) / inputs.size(0),
|
75 |
+
"train_stoi": running_stoi / (i + 1) / inputs.size(0)})
|
76 |
+
|
77 |
+
epoch_loss = running_loss / len(dataloaders[phase].dataset)
|
78 |
+
eposh_pesq = running_pesq / len(dataloaders[phase].dataset)
|
79 |
+
eposh_stoi = running_stoi / len(dataloaders[phase].dataset)
|
80 |
|
81 |
wandb.log({f"{phase}_loss": epoch_loss,
|
82 |
f"{phase}_pesq": eposh_pesq,
|
|
|
84 |
|
85 |
if phase == 'val':
|
86 |
for i, (wav, rate) in enumerate(dataloaders['minimal']):
|
87 |
+
prediction = model(wav.to(device))
|
88 |
wandb.log({
|
89 |
f"{i}_example": wandb.Audio(
|
90 |
+
prediction.detach().cpu().numpy()[0][0],
|
91 |
sample_rate=rate)})
|
92 |
|
93 |
+
checkpoint_saver(model, epoch, metric_val=eposh_pesq,
|
94 |
+
optimizer=optimizer, loss=epoch_loss)
|
95 |
+
|
96 |
+
|
97 |
+
if __name__ == "__main__":
|
98 |
+
pass
|