BorisovMaksim committed on
Commit
20c7778
·
1 Parent(s): 3c183ae

rewrote demucs model

Browse files

changed configs default values
refactoring

Files changed (5) hide show
  1. datasets/valentini.py +2 -5
  2. testing/unit_tests.py +4 -0
  3. train.py +41 -17
  4. transforms.py +12 -5
  5. utils.py +13 -0
datasets/valentini.py CHANGED
@@ -2,10 +2,8 @@ import torch
2
  from torch.utils.data import Dataset
3
  from pathlib import Path
4
  import torchaudio
5
- import numpy as np
6
- from torchaudio.transforms import Resample
7
 
8
- HIGH_RANDOM_SEED = 1000
9
 
10
  class Valentini(Dataset):
11
  def __init__(self, dataset_path, val_fraction, transform=None, valid=False, *args, **kwargs):
@@ -34,9 +32,8 @@ class Valentini(Dataset):
34
  clean_wav, clean_sr = torchaudio.load(self.clean_wavs[idx])
35
 
36
  if self.transform:
37
- random_seed = 0 if self.valid else torch.randint(HIGH_RANDOM_SEED, (1,))[0]
38
  torch.manual_seed(random_seed)
39
-
40
  noisy_wav = self.transform(noisy_wav)
41
  torch.manual_seed(random_seed)
42
  clean_wav = self.transform(clean_wav)
 
2
  from torch.utils.data import Dataset
3
  from pathlib import Path
4
  import torchaudio
 
 
5
 
6
+ MAX_RANDOM_SEED = 1000
7
 
8
  class Valentini(Dataset):
9
  def __init__(self, dataset_path, val_fraction, transform=None, valid=False, *args, **kwargs):
 
32
  clean_wav, clean_sr = torchaudio.load(self.clean_wavs[idx])
33
 
34
  if self.transform:
35
+ random_seed = 0 if self.valid else torch.randint(MAX_RANDOM_SEED, (1,))[0]
36
  torch.manual_seed(random_seed)
 
37
  noisy_wav = self.transform(noisy_wav)
38
  torch.manual_seed(random_seed)
39
  clean_wav = self.transform(clean_wav)
testing/unit_tests.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+
2
+
3
+ def test_model_inference():
4
+ assert 1 == 1
train.py CHANGED
@@ -12,21 +12,30 @@ from losses import get_loss
12
  from datasets import get_datasets
13
  from testing.metrics import Metrics
14
  from datasets.minimal import Minimal
 
15
 
16
-
17
- def train(cfg: DictConfig):
18
- device = torch.device(f'cuda:{cfg.gpu}' if torch.cuda.is_available() else 'cpu')
19
-
20
  wandb.login(key=cfg['wandb']['api_key'], host=cfg['wandb']['host'])
21
  wandb.init(project=cfg['wandb']['project'],
22
  notes=cfg['wandb']['notes'],
23
- tags=cfg['wandb']['tags'],
24
- config=omegaconf.OmegaConf.to_container(
25
- cfg, resolve=True, throw_on_missing=True))
26
- wandb.run.name = cfg['wandb']['run_name']
 
 
 
 
 
 
 
 
27
 
 
 
 
28
  checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'], run_name=wandb.run.name)
29
- metrics = Metrics(rate=cfg['dataloader']['sample_rate'])
30
 
31
  model = get_model(cfg['model']).to(device)
32
  optimizer = get_optimizer(model.parameters(), cfg['optimizer'])
@@ -35,14 +44,16 @@ def train(cfg: DictConfig):
35
  minimal_dataset = Minimal(cfg)
36
 
37
  dataloaders = {
38
- 'train': DataLoader(train_dataset, batch_size=cfg['dataloader']['train_batch_size'], shuffle=True),
39
- 'val': DataLoader(valid_dataset, batch_size=cfg['dataloader']['valid_batch_size'], shuffle=True),
 
 
40
  'minimal': DataLoader(minimal_dataset)
41
  }
42
 
43
- wandb.watch(model, log_freq=100)
44
-
45
- for epoch in range(cfg['training']['num_epochs']):
46
  for phase in ['train', 'val']:
47
  if phase == 'train':
48
  model.train()
@@ -50,7 +61,8 @@ def train(cfg: DictConfig):
50
  model.eval()
51
 
52
  running_loss, running_pesq, running_stoi = 0.0, 0.0, 0.0
53
- for i, (inputs, labels) in enumerate(dataloaders[phase]):
 
54
  inputs = inputs.to(device)
55
  labels = labels.to(device)
56
 
@@ -64,11 +76,16 @@ def train(cfg: DictConfig):
64
  loss.backward()
65
  optimizer.step()
66
 
67
- running_metrics = metrics.calculate(denoised=outputs, clean=labels)
68
  running_loss += loss.item() * inputs.size(0)
69
  running_pesq += running_metrics['PESQ']
70
  running_stoi += running_metrics['STOI']
71
 
 
 
 
 
 
72
  if phase == 'train' and i % cfg['wandb']['log_interval'] == 0:
73
  wandb.log({"train_loss": running_loss / (i + 1) / inputs.size(0),
74
  "train_pesq": running_pesq / (i + 1) / inputs.size(0),
@@ -84,7 +101,13 @@ def train(cfg: DictConfig):
84
 
85
  if phase == 'val':
86
  for i, (wav, rate) in enumerate(dataloaders['minimal']):
87
- prediction = model(wav.to(device))
 
 
 
 
 
 
88
  wandb.log({
89
  f"{i}_example": wandb.Audio(
90
  prediction.detach().cpu().numpy()[0][0],
@@ -92,6 +115,7 @@ def train(cfg: DictConfig):
92
 
93
  checkpoint_saver(model, epoch, metric_val=eposh_pesq,
94
  optimizer=optimizer, loss=epoch_loss)
 
95
 
96
 
97
  if __name__ == "__main__":
 
12
  from datasets import get_datasets
13
  from testing.metrics import Metrics
14
  from datasets.minimal import Minimal
15
+ from tqdm import tqdm
16
 
17
+ def init_wandb(cfg):
 
 
 
18
  wandb.login(key=cfg['wandb']['api_key'], host=cfg['wandb']['host'])
19
  wandb.init(project=cfg['wandb']['project'],
20
  notes=cfg['wandb']['notes'],
21
+ config=omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
22
+ resume=cfg['wandb']['resume'],
23
+ name=cfg['wandb']['run_name'])
24
+ if wandb.run.resumed:
25
+ api = wandb.Api()
26
+ runs = api.runs(f"{cfg['wandb']['entity']}/{cfg['wandb']['project']}",
27
+ order='train_pesq')
28
+ run = [run for run in runs if run.name == cfg['wandb']['run_name'] and run.state != 'running'][0]
29
+ artifacts = run.logged_artifacts()
30
+ best_model = [artifact for artifact in artifacts if artifact.type == 'model'][0]
31
+
32
+ best_model.download()
33
 
34
+ def train(cfg: DictConfig):
35
+ device = torch.device(f'cuda:{cfg.gpu}' if torch.cuda.is_available() else 'cpu')
36
+ init_wandb(cfg)
37
  checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'], run_name=wandb.run.name)
38
+ metrics = Metrics(source_rate=cfg['dataloader']['sample_rate']).to(device)
39
 
40
  model = get_model(cfg['model']).to(device)
41
  optimizer = get_optimizer(model.parameters(), cfg['optimizer'])
 
44
  minimal_dataset = Minimal(cfg)
45
 
46
  dataloaders = {
47
+ 'train': DataLoader(train_dataset, batch_size=cfg['dataloader']['train_batch_size'], shuffle=True,
48
+ num_workers=cfg['dataloader']['num_workers']),
49
+ 'val': DataLoader(valid_dataset, batch_size=cfg['dataloader']['valid_batch_size'], shuffle=False,
50
+ num_workers=cfg['dataloader']['num_workers']),
51
  'minimal': DataLoader(minimal_dataset)
52
  }
53
 
54
+ wandb.watch(model, log_freq=cfg['wandb']['log_interval'])
55
+ epoch = 0
56
+ while epoch < cfg['training']['num_epochs']:
57
  for phase in ['train', 'val']:
58
  if phase == 'train':
59
  model.train()
 
61
  model.eval()
62
 
63
  running_loss, running_pesq, running_stoi = 0.0, 0.0, 0.0
64
+ loop = tqdm(dataloaders[phase])
65
+ for i, (inputs, labels) in enumerate(loop):
66
  inputs = inputs.to(device)
67
  labels = labels.to(device)
68
 
 
76
  loss.backward()
77
  optimizer.step()
78
 
79
+ running_metrics = metrics(denoised=outputs, clean=labels)
80
  running_loss += loss.item() * inputs.size(0)
81
  running_pesq += running_metrics['PESQ']
82
  running_stoi += running_metrics['STOI']
83
 
84
+ loop.set_description(f"Epoch [{epoch}/{cfg['training']['num_epochs']}][{phase}]")
85
+ loop.set_postfix(loss=running_loss / (i + 1) / inputs.size(0),
86
+ pesq=running_pesq / (i + 1) / inputs.size(0),
87
+ stoi=running_stoi / (i + 1) / inputs.size(0))
88
+
89
  if phase == 'train' and i % cfg['wandb']['log_interval'] == 0:
90
  wandb.log({"train_loss": running_loss / (i + 1) / inputs.size(0),
91
  "train_pesq": running_pesq / (i + 1) / inputs.size(0),
 
101
 
102
  if phase == 'val':
103
  for i, (wav, rate) in enumerate(dataloaders['minimal']):
104
+ if cfg['dataloader']['normalize']:
105
+ std = torch.std(wav)
106
+ wav = wav / std
107
+ prediction = model(wav.to(device))
108
+ prediction = prediction * std
109
+ else:
110
+ prediction = model(wav.to(device))
111
  wandb.log({
112
  f"{i}_example": wandb.Audio(
113
  prediction.detach().cpu().numpy()[0][0],
 
115
 
116
  checkpoint_saver(model, epoch, metric_val=eposh_pesq,
117
  optimizer=optimizer, loss=epoch_loss)
118
+ epoch += 1
119
 
120
 
121
  if __name__ == "__main__":
transforms.py CHANGED
@@ -8,17 +8,24 @@ from torchvision.transforms import RandomCrop
8
  class Transform(torch.nn.Module):
9
  def __init__(
10
  self,
11
- input_sr,
12
  sample_rate,
13
  max_seconds,
 
14
  *args,
15
  **kwargs
16
  ):
17
  super().__init__()
18
- self.resample = Resample(orig_freq=input_sr, new_freq=sample_rate)
 
 
19
  self.random_crop = RandomCrop((1, int(max_seconds * sample_rate)), pad_if_needed=True)
 
20
 
21
  def forward(self, waveform: torch.Tensor) -> torch.Tensor:
22
- resampled = self.resample(waveform)
23
- croped = self.random_crop(resampled)
24
- return croped
 
 
 
 
8
  class Transform(torch.nn.Module):
9
  def __init__(
10
  self,
11
+ input_sample_rate,
12
  sample_rate,
13
  max_seconds,
14
+ normalize,
15
  *args,
16
  **kwargs
17
  ):
18
  super().__init__()
19
+ self.input_sample_rate = input_sample_rate
20
+ self.sample_rate = sample_rate
21
+ self.resample = Resample(orig_freq=input_sample_rate, new_freq=sample_rate)
22
  self.random_crop = RandomCrop((1, int(max_seconds * sample_rate)), pad_if_needed=True)
23
+ self.normalize = normalize
24
 
25
  def forward(self, waveform: torch.Tensor) -> torch.Tensor:
26
+ if self.input_sample_rate != self.sample_rate:
27
+ waveform = self.resample(waveform)
28
+ if self.normalize:
29
+ waveform = waveform / torch.std(waveform)
30
+ cropped = self.random_crop(waveform)
31
+ return cropped
utils.py CHANGED
@@ -2,6 +2,19 @@ import torchaudio
2
  import torch
3
  import matplotlib.pyplot as plt
4
  from pathlib import Path
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
 
7
  def collect_valentini_paths(dataset_path):
 
2
  import torch
3
  import matplotlib.pyplot as plt
4
  from pathlib import Path
5
+ from torch.nn.functional import pad
6
+
7
+
8
+
9
+ def pad_cut_batch_audio(wavs, new_shape):
10
+ wav_length = wavs.shape[-1]
11
+ new_length = new_shape[-1]
12
+
13
+ if wav_length > new_length:
14
+ wavs = wavs[:, :, :new_length]
15
+ elif wav_length < new_length:
16
+ wavs = pad(wavs, (0, new_length - wav_length))
17
+ return wavs
18
 
19
 
20
  def collect_valentini_paths(dataset_path):