mattricesound committed
Commit d54e023
Parents: 9e82ce4 abb9ffa

Merge pull request #3 from mhrice/umx-train-init
Files changed (8):
  1. .gitignore +3 -1
  2. .gitmodules +3 -0
  3. datasets.py +4 -2
  4. download_egfx.sh +1 -1
  5. egfx.ipynb +0 -0
  6. models.py +106 -18
  7. train.py +9 -5
  8. umx +1 -0
.gitignore CHANGED
@@ -4,4 +4,6 @@ wandb/
 *.egg-info/
 data/
 .DS_Store
-__pycache__/
+__pycache__/
+lightning_logs/
+RemFX/
.gitmodules ADDED
@@ -0,0 +1,3 @@
+[submodule "umx"]
+	path = umx
+	url = https://github.com/sigsep/open-unmix-pytorch
datasets.py CHANGED
@@ -31,8 +31,10 @@ class GuitarFXDataset(Dataset):
         ]
         for i, effect in enumerate(effect_type):
             for pickup in Path(self.root / effect).iterdir():
-                self.wet_files += list(pickup.glob("*.wav"))
-                self.dry_files += list(self.root.glob(f"Clean/{pickup.name}/**/*.wav"))
+                self.wet_files += sorted(list(pickup.glob("*.wav")))
+                self.dry_files += sorted(
+                    list(self.root.glob(f"Clean/{pickup.name}/**/*.wav"))
+                )
             self.labels += [i] * len(self.wet_files)
         print(
             f"Found {len(self.wet_files)} wet files and {len(self.dry_files)} dry files"
        )
download_egfx.sh CHANGED
@@ -16,6 +16,6 @@ wget https://zenodo.org/record/7044411/files/Spring-Reverb.zip?download=1 -O Spring-Reverb.zip
 wget https://zenodo.org/record/7044411/files/Sweep-Echo.zip?download=1 -O Sweep-Echo.zip
 wget https://zenodo.org/record/7044411/files/TapeEcho.zip?download=1 -O TapeEcho.zip
 wget https://zenodo.org/record/7044411/files/TubeScreamer.zip?download=1 -O TubeScreamer.zip
-unzip \*.zip
+unzip -n \*.zip
 
 
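The added -n flag tells unzip never to overwrite existing files, so the script can be re-run after a partial download without clobbering archives that were already extracted.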
egfx.ipynb CHANGED
The diff for this file is too large to render.
 
models.py CHANGED
@@ -1,44 +1,106 @@
-from audio_diffusion_pytorch import AudioDiffusionModel
 import torch
 from torch import Tensor
 import pytorch_lightning as pl
 from einops import rearrange
 import wandb
+from audio_diffusion_pytorch import AudioDiffusionModel
+
+import sys
+
+sys.path.append("./umx")
+from umx.openunmix.model import OpenUnmix, Separator
+
 
 SAMPLE_RATE = 22050  # From audio-diffusion-pytorch
 
 
-class TCNWrapper(pl.LightningModule):
-    def __init__(self):
+class OpenUnmixModel(pl.LightningModule):
+    def __init__(
+        self,
+        n_fft: int = 2048,
+        hop_length: int = 512,
+        alpha: float = 0.3,
+    ):
         super().__init__()
-        self.model = AudioDiffusionModel(in_channels=1)
+        self.model = OpenUnmix(
+            nb_channels=1,
+            nb_bins=n_fft // 2 + 1,
+        )
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.alpha = alpha
+        window = torch.hann_window(n_fft)
+        self.register_buffer("window", window)
 
     def forward(self, x: torch.Tensor):
         return self.model(x)
 
     def training_step(self, batch, batch_idx):
-        loss = self.common_step(batch, batch_idx, mode="train")
+        loss, _ = self.common_step(batch, batch_idx, mode="train")
        return loss
 
     def validation_step(self, batch, batch_idx):
-        loss = self.common_step(batch, batch_idx, mode="val")
+        loss, Y = self.common_step(batch, batch_idx, mode="val")
+        return loss, Y
 
     def common_step(self, batch, batch_idx, mode: str = "train"):
         x, target, label = batch
-        loss = self(x)
+        X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
+        Y = self(X)
+        Y_hat = spectrogram(
+            target, self.window, self.n_fft, self.hop_length, self.alpha
+        )
+        loss = torch.nn.functional.mse_loss(Y, Y_hat)
         self.log(f"{mode}_loss", loss, on_step=True, on_epoch=True)
-        return loss
+        return loss, Y
 
     def configure_optimizers(self):
         return torch.optim.Adam(
             self.parameters(), lr=1e-4, betas=(0.95, 0.999), eps=1e-6, weight_decay=1e-3
         )
 
+    def on_validation_epoch_start(self):
+        self.log_next = True
+
+    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
+        if self.log_next:
+            x, target, label = batch
+            s = Separator(
+                target_models={"other": self.model},
+                nb_channels=1,
+                sample_rate=SAMPLE_RATE,
+                n_fft=self.n_fft,
+                n_hop=self.hop_length,
+            ).to(self.device)
+            outputs = s(x).squeeze(1)
+            log_wandb_audio_batch(
+                logger=self.logger,
+                id="sample",
+                samples=x.cpu(),
+                sampling_rate=SAMPLE_RATE,
+                caption=f"Epoch {self.current_epoch}",
+            )
+            log_wandb_audio_batch(
+                logger=self.logger,
+                id="prediction",
+                samples=outputs.cpu(),
+                sampling_rate=SAMPLE_RATE,
+                caption=f"Epoch {self.current_epoch}",
+            )
+            log_wandb_audio_batch(
+                logger=self.logger,
+                id="target",
+                samples=target.cpu(),
+                sampling_rate=SAMPLE_RATE,
+                caption=f"Epoch {self.current_epoch}",
+            )
+            self.log_next = False
+
 
-class AudioDiffusionWrapper(pl.LightningModule):
-    def __init__(self):
+class DiffusionGenerationModel(pl.LightningModule):
+    def __init__(self, model: torch.nn.Module):
         super().__init__()
-        self.model = AudioDiffusionModel(in_channels=1)
+        self.model = model
 
     def forward(self, x: torch.Tensor):
         return self.model(x)
@@ -77,10 +139,8 @@ class AudioDiffusionWrapper(pl.LightningModule):
     def log_sample(self, batch, num_steps=10):
         # Get start diffusion noise
         noise = torch.randn(batch.shape, device=self.device)
-        sampled = self.model.sample(
-            noise=noise, num_steps=num_steps  # Suggested range: 2-50
-        )
-        self.log_wandb_audio_batch(
+        sampled = self.sample(noise=noise, num_steps=num_steps)  # Suggested range: 2-50
+        log_wandb_audio_batch(
             id="sample",
             samples=sampled,
             sampling_rate=SAMPLE_RATE,
@@ -89,17 +149,45 @@
 
 
 def log_wandb_audio_batch(
-    id: str, samples: Tensor, sampling_rate: int, caption: str = ""
+    logger: pl.loggers.WandbLogger,
+    id: str,
+    samples: Tensor,
+    sampling_rate: int,
+    caption: str = "",
 ):
     num_items = samples.shape[0]
     samples = rearrange(samples, "b c t -> b t c")
     for idx in range(num_items):
-        wandb.log(
+        logger.experiment.log(
             {
-                f"sample_{idx}_{id}": wandb.Audio(
+                f"{id}_{idx}": wandb.Audio(
                     samples[idx].cpu().numpy(),
                     caption=caption,
                     sample_rate=sampling_rate,
                 )
             }
         )
+
+
+def spectrogram(
+    x: torch.Tensor,
+    window: torch.Tensor,
+    n_fft: int,
+    hop_length: int,
+    alpha: float,
+) -> torch.Tensor:
+    bs, chs, samp = x.size()
+    x = x.view(bs * chs, -1)  # move channels onto batch dim
+
+    X = torch.stft(
+        x,
+        n_fft=n_fft,
+        hop_length=hop_length,
+        window=window,
+        return_complex=True,
+    )
+
+    # move channels back
+    X = X.view(bs, chs, X.shape[-2], X.shape[-1])
+
+    return torch.pow(X.abs() + 1e-8, alpha)
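
For orientation, a minimal sketch (not part of the commit) of the objective OpenUnmixModel now optimizes: both the wet input and the dry target are mapped to compressed-magnitude spectrograms, and the network output is regressed onto the target with MSE. Batch shapes are illustrative, and it assumes the umx submodule is importable as in models.py. Note also that log_wandb_audio_batch now takes the WandbLogger as its first argument, so call sites must pass logger=self.logger.

    import torch
    from models import OpenUnmixModel, spectrogram

    model = OpenUnmixModel(n_fft=2048, hop_length=512, alpha=0.3)

    # Dummy wet/dry pair: batch of 2 mono clips, 1 s at 22050 Hz (illustrative).
    x = torch.randn(2, 1, 22050)       # effected input
    target = torch.randn(2, 1, 22050)  # clean target

    # Same computation as common_step, without the Lightning logging:
    X = spectrogram(x, model.window, model.n_fft, model.hop_length, model.alpha)
    Y_hat = spectrogram(target, model.window, model.n_fft, model.hop_length, model.alpha)
    Y = model(X)                       # predicted target spectrogram
    loss = torch.nn.functional.mse_loss(Y, Y_hat)
    loss.backward()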
train.py CHANGED
@@ -3,17 +3,18 @@ import pytorch_lightning as pl
 import torch
 from torch.utils.data import DataLoader
 from datasets import GuitarFXDataset
-from models import AudioDiffusionWrapper
+from models import DiffusionGenerationModel, OpenUnmixModel
+
 
 SAMPLE_RATE = 22050
 TRAIN_SPLIT = 0.8
 
 
 def main():
-    # wandb_logger = WandbLogger(project="RemFX", save_dir="./")
-    trainer = pl.Trainer()  # logger=wandb_logger)
+    wandb_logger = WandbLogger(project="RemFX", save_dir="./")
+    trainer = pl.Trainer(logger=wandb_logger, max_epochs=100)
     guitfx = GuitarFXDataset(
-        root="/Users/matthewrice/mir_datasets/egfxset",
+        root="./data/egfx",
         sample_rate=SAMPLE_RATE,
         effect_type=["Phaser"],
     )
@@ -24,7 +25,10 @@ def main():
     )
     train = DataLoader(train_dataset, batch_size=2)
     val = DataLoader(val_dataset, batch_size=2)
-    model = AudioDiffusionWrapper()
+
+    # model = DiffusionGenerationModel()
+    model = OpenUnmixModel()
+
     trainer.fit(model=model, train_dataloaders=train, val_dataloaders=val)
 
 
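One caveat on the commented-out line: DiffusionGenerationModel now takes the inner network as a constructor argument, so re-enabling it requires passing a model explicitly. Hypothetical usage, not in the commit:

    from audio_diffusion_pytorch import AudioDiffusionModel
    from models import DiffusionGenerationModel

    # The wrapper no longer builds its own network; pass one in, as the old
    # AudioDiffusionWrapper did internally.
    model = DiffusionGenerationModel(AudioDiffusionModel(in_channels=1))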
umx ADDED
@@ -0,0 +1 @@
+Subproject commit 05fd4d8a0e3e50e308579052d762a342647c3408
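
Because umx is tracked as a submodule, fresh clones need git submodule update --init (or git clone --recurse-submodules) before the sys.path.append("./umx") import in models.py can resolve.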