Spaces:
Runtime error
Runtime error
Commit
·
20c7778
1
Parent(s):
3c183ae
rewrote demucs model
Browse files
changed configs default values
refactoring
- datasets/valentini.py +2 -5
- testing/unit_tests.py +4 -0
- train.py +41 -17
- transforms.py +12 -5
- utils.py +13 -0
datasets/valentini.py
CHANGED
@@ -2,10 +2,8 @@ import torch
|
|
2 |
from torch.utils.data import Dataset
|
3 |
from pathlib import Path
|
4 |
import torchaudio
|
5 |
-
import numpy as np
|
6 |
-
from torchaudio.transforms import Resample
|
7 |
|
8 |
-
|
9 |
|
10 |
class Valentini(Dataset):
|
11 |
def __init__(self, dataset_path, val_fraction, transform=None, valid=False, *args, **kwargs):
|
@@ -34,9 +32,8 @@ class Valentini(Dataset):
|
|
34 |
clean_wav, clean_sr = torchaudio.load(self.clean_wavs[idx])
|
35 |
|
36 |
if self.transform:
|
37 |
-
random_seed = 0 if self.valid else torch.randint(
|
38 |
torch.manual_seed(random_seed)
|
39 |
-
|
40 |
noisy_wav = self.transform(noisy_wav)
|
41 |
torch.manual_seed(random_seed)
|
42 |
clean_wav = self.transform(clean_wav)
|
|
|
2 |
from torch.utils.data import Dataset
|
3 |
from pathlib import Path
|
4 |
import torchaudio
|
|
|
|
|
5 |
|
6 |
+
MAX_RANDOM_SEED = 1000
|
7 |
|
8 |
class Valentini(Dataset):
|
9 |
def __init__(self, dataset_path, val_fraction, transform=None, valid=False, *args, **kwargs):
|
|
|
32 |
clean_wav, clean_sr = torchaudio.load(self.clean_wavs[idx])
|
33 |
|
34 |
if self.transform:
|
35 |
+
random_seed = 0 if self.valid else torch.randint(MAX_RANDOM_SEED, (1,))[0]
|
36 |
torch.manual_seed(random_seed)
|
|
|
37 |
noisy_wav = self.transform(noisy_wav)
|
38 |
torch.manual_seed(random_seed)
|
39 |
clean_wav = self.transform(clean_wav)
|
testing/unit_tests.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
def test_model_inference():
    """Placeholder smoke test; always passes until real inference checks exist."""
    expected = 1
    assert expected == 1
train.py
CHANGED
@@ -12,21 +12,30 @@ from losses import get_loss
|
|
12 |
from datasets import get_datasets
|
13 |
from testing.metrics import Metrics
|
14 |
from datasets.minimal import Minimal
|
|
|
15 |
|
16 |
-
|
17 |
-
def train(cfg: DictConfig):
|
18 |
-
device = torch.device(f'cuda:{cfg.gpu}' if torch.cuda.is_available() else 'cpu')
|
19 |
-
|
20 |
wandb.login(key=cfg['wandb']['api_key'], host=cfg['wandb']['host'])
|
21 |
wandb.init(project=cfg['wandb']['project'],
|
22 |
notes=cfg['wandb']['notes'],
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
wandb.run.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
|
|
|
|
|
|
28 |
checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'], run_name=wandb.run.name)
|
29 |
-
metrics = Metrics(
|
30 |
|
31 |
model = get_model(cfg['model']).to(device)
|
32 |
optimizer = get_optimizer(model.parameters(), cfg['optimizer'])
|
@@ -35,14 +44,16 @@ def train(cfg: DictConfig):
|
|
35 |
minimal_dataset = Minimal(cfg)
|
36 |
|
37 |
dataloaders = {
|
38 |
-
'train': DataLoader(train_dataset, batch_size=cfg['dataloader']['train_batch_size'], shuffle=True
|
39 |
-
|
|
|
|
|
40 |
'minimal': DataLoader(minimal_dataset)
|
41 |
}
|
42 |
|
43 |
-
wandb.watch(model, log_freq=
|
44 |
-
|
45 |
-
|
46 |
for phase in ['train', 'val']:
|
47 |
if phase == 'train':
|
48 |
model.train()
|
@@ -50,7 +61,8 @@ def train(cfg: DictConfig):
|
|
50 |
model.eval()
|
51 |
|
52 |
running_loss, running_pesq, running_stoi = 0.0, 0.0, 0.0
|
53 |
-
|
|
|
54 |
inputs = inputs.to(device)
|
55 |
labels = labels.to(device)
|
56 |
|
@@ -64,11 +76,16 @@ def train(cfg: DictConfig):
|
|
64 |
loss.backward()
|
65 |
optimizer.step()
|
66 |
|
67 |
-
running_metrics = metrics
|
68 |
running_loss += loss.item() * inputs.size(0)
|
69 |
running_pesq += running_metrics['PESQ']
|
70 |
running_stoi += running_metrics['STOI']
|
71 |
|
|
|
|
|
|
|
|
|
|
|
72 |
if phase == 'train' and i % cfg['wandb']['log_interval'] == 0:
|
73 |
wandb.log({"train_loss": running_loss / (i + 1) / inputs.size(0),
|
74 |
"train_pesq": running_pesq / (i + 1) / inputs.size(0),
|
@@ -84,7 +101,13 @@ def train(cfg: DictConfig):
|
|
84 |
|
85 |
if phase == 'val':
|
86 |
for i, (wav, rate) in enumerate(dataloaders['minimal']):
|
87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
wandb.log({
|
89 |
f"{i}_example": wandb.Audio(
|
90 |
prediction.detach().cpu().numpy()[0][0],
|
@@ -92,6 +115,7 @@ def train(cfg: DictConfig):
|
|
92 |
|
93 |
checkpoint_saver(model, epoch, metric_val=eposh_pesq,
|
94 |
optimizer=optimizer, loss=epoch_loss)
|
|
|
95 |
|
96 |
|
97 |
if __name__ == "__main__":
|
|
|
12 |
from datasets import get_datasets
|
13 |
from testing.metrics import Metrics
|
14 |
from datasets.minimal import Minimal
|
15 |
+
from tqdm import tqdm
|
16 |
|
17 |
+
def init_wandb(cfg):
    """Authenticate with Weights & Biases and start (or resume) a run.

    On a resumed run, looks up the previous (non-running) run with the same
    name and downloads its logged model artifact so training can continue.
    """
    wandb.login(key=cfg['wandb']['api_key'], host=cfg['wandb']['host'])
    wandb.init(project=cfg['wandb']['project'],
               notes=cfg['wandb']['notes'],
               config=omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
               resume=cfg['wandb']['resume'],
               name=cfg['wandb']['run_name'])
    if wandb.run.resumed:
        api = wandb.Api()
        # NOTE(review): 'order' sorts the runs returned by the API; presumably
        # intended to surface the best train_pesq first — confirm against the
        # wandb Api.runs documentation.
        all_runs = api.runs(f"{cfg['wandb']['entity']}/{cfg['wandb']['project']}",
                            order='train_pesq')
        finished = [r for r in all_runs
                    if r.name == cfg['wandb']['run_name'] and r.state != 'running']
        previous_run = finished[0]
        model_artifacts = [a for a in previous_run.logged_artifacts()
                           if a.type == 'model']
        model_artifacts[0].download()
|
33 |
|
34 |
+
def train(cfg: DictConfig):
|
35 |
+
device = torch.device(f'cuda:{cfg.gpu}' if torch.cuda.is_available() else 'cpu')
|
36 |
+
init_wandb(cfg)
|
37 |
checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'], run_name=wandb.run.name)
|
38 |
+
metrics = Metrics(source_rate=cfg['dataloader']['sample_rate']).to(device)
|
39 |
|
40 |
model = get_model(cfg['model']).to(device)
|
41 |
optimizer = get_optimizer(model.parameters(), cfg['optimizer'])
|
|
|
44 |
minimal_dataset = Minimal(cfg)
|
45 |
|
46 |
dataloaders = {
|
47 |
+
'train': DataLoader(train_dataset, batch_size=cfg['dataloader']['train_batch_size'], shuffle=True,
|
48 |
+
num_workers=cfg['dataloader']['num_workers']),
|
49 |
+
'val': DataLoader(valid_dataset, batch_size=cfg['dataloader']['valid_batch_size'], shuffle=False,
|
50 |
+
num_workers=cfg['dataloader']['num_workers']),
|
51 |
'minimal': DataLoader(minimal_dataset)
|
52 |
}
|
53 |
|
54 |
+
wandb.watch(model, log_freq=cfg['wandb']['log_interval'])
|
55 |
+
epoch = 0
|
56 |
+
while epoch < cfg['training']['num_epochs']:
|
57 |
for phase in ['train', 'val']:
|
58 |
if phase == 'train':
|
59 |
model.train()
|
|
|
61 |
model.eval()
|
62 |
|
63 |
running_loss, running_pesq, running_stoi = 0.0, 0.0, 0.0
|
64 |
+
loop = tqdm(dataloaders[phase])
|
65 |
+
for i, (inputs, labels) in enumerate(loop):
|
66 |
inputs = inputs.to(device)
|
67 |
labels = labels.to(device)
|
68 |
|
|
|
76 |
loss.backward()
|
77 |
optimizer.step()
|
78 |
|
79 |
+
running_metrics = metrics(denoised=outputs, clean=labels)
|
80 |
running_loss += loss.item() * inputs.size(0)
|
81 |
running_pesq += running_metrics['PESQ']
|
82 |
running_stoi += running_metrics['STOI']
|
83 |
|
84 |
+
loop.set_description(f"Epoch [{epoch}/{cfg['training']['num_epochs']}][{phase}]")
|
85 |
+
loop.set_postfix(loss=running_loss / (i + 1) / inputs.size(0),
|
86 |
+
pesq=running_pesq / (i + 1) / inputs.size(0),
|
87 |
+
stoi=running_stoi / (i + 1) / inputs.size(0))
|
88 |
+
|
89 |
if phase == 'train' and i % cfg['wandb']['log_interval'] == 0:
|
90 |
wandb.log({"train_loss": running_loss / (i + 1) / inputs.size(0),
|
91 |
"train_pesq": running_pesq / (i + 1) / inputs.size(0),
|
|
|
101 |
|
102 |
if phase == 'val':
|
103 |
for i, (wav, rate) in enumerate(dataloaders['minimal']):
|
104 |
+
if cfg['dataloader']['normalize']:
|
105 |
+
std = torch.std(wav)
|
106 |
+
wav = wav / std
|
107 |
+
prediction = model(wav.to(device))
|
108 |
+
prediction = prediction * std
|
109 |
+
else:
|
110 |
+
prediction = model(wav.to(device))
|
111 |
wandb.log({
|
112 |
f"{i}_example": wandb.Audio(
|
113 |
prediction.detach().cpu().numpy()[0][0],
|
|
|
115 |
|
116 |
checkpoint_saver(model, epoch, metric_val=eposh_pesq,
|
117 |
optimizer=optimizer, loss=epoch_loss)
|
118 |
+
epoch += 1
|
119 |
|
120 |
|
121 |
if __name__ == "__main__":
|
transforms.py
CHANGED
@@ -8,17 +8,24 @@ from torchvision.transforms import RandomCrop
|
|
8 |
class Transform(torch.nn.Module):
|
9 |
def __init__(
|
10 |
self,
|
11 |
-
|
12 |
sample_rate,
|
13 |
max_seconds,
|
|
|
14 |
*args,
|
15 |
**kwargs
|
16 |
):
|
17 |
super().__init__()
|
18 |
-
self.
|
|
|
|
|
19 |
self.random_crop = RandomCrop((1, int(max_seconds * sample_rate)), pad_if_needed=True)
|
|
|
20 |
|
21 |
def forward(self, waveform: torch.Tensor) -> torch.Tensor:
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
8 |
class Transform(torch.nn.Module):
    """Waveform preprocessing: optional resample to the target rate, optional
    std-normalization, then a fixed-length random crop (padded if too short).
    """

    def __init__(
        self,
        input_sample_rate,
        sample_rate,
        max_seconds,
        normalize,
        *args,
        **kwargs
    ):
        super().__init__()
        self.input_sample_rate = input_sample_rate
        self.sample_rate = sample_rate
        self.normalize = normalize
        # The resampler is built eagerly even when the rates match;
        # forward() skips it in that case.
        self.resample = Resample(orig_freq=input_sample_rate, new_freq=sample_rate)
        # Crop length is expressed at the *target* rate (post-resample samples).
        self.random_crop = RandomCrop((1, int(max_seconds * sample_rate)), pad_if_needed=True)

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Return a fixed-length random crop of *waveform*.

        Assumes waveform is shaped (channels, samples) — TODO confirm with callers.
        """
        if self.input_sample_rate != self.sample_rate:
            waveform = self.resample(waveform)
        if self.normalize:
            # Scale by the global standard deviation; a zero-variance input
            # would divide by zero — presumably never occurs for real audio.
            waveform = waveform / torch.std(waveform)
        return self.random_crop(waveform)
|
utils.py
CHANGED
@@ -2,6 +2,19 @@ import torchaudio
|
|
2 |
import torch
|
3 |
import matplotlib.pyplot as plt
|
4 |
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
|
7 |
def collect_valentini_paths(dataset_path):
|
|
|
2 |
import torch
|
3 |
import matplotlib.pyplot as plt
|
4 |
from pathlib import Path
|
5 |
+
from torch.nn.functional import pad
|
6 |
+
|
7 |
+
|
8 |
+
|
9 |
+
def pad_cut_batch_audio(wavs, new_shape):
    """Force a batch of waveforms to a target length along the last axis.

    Longer signals are truncated; shorter ones are right-padded with zeros.

    Args:
        wavs: tensor of waveforms; the last dimension is time (samples).
        new_shape: target shape — only its last element (the target length)
            is consulted.

    Returns:
        A tensor whose last dimension equals ``new_shape[-1]``; all other
        dimensions are unchanged. The input tensor is returned unmodified
        when the length already matches.
    """
    current_length = wavs.shape[-1]
    target_length = new_shape[-1]

    if current_length > target_length:
        # Ellipsis indexing truncates the time axis for tensors of any rank,
        # not just the (batch, channel, time) layout the old slice assumed.
        wavs = wavs[..., :target_length]
    elif current_length < target_length:
        # Zero-pad on the right of the time axis only.
        wavs = pad(wavs, (0, target_length - current_length))
    return wavs
|
18 |
|
19 |
|
20 |
def collect_valentini_paths(dataset_path):
|