Commit 8949a8c
Parent(s): 14ae0ea

Initial ptl model and training script for umx

Files changed:
- .gitignore +3 -1
- .gitmodules +3 -0
- models.py +97 -16
- train.py +9 -5
- umx +1 -0
.gitignore CHANGED
@@ -4,4 +4,6 @@ wandb/
 *.egg-info/
 data/
 .DS_Store
-__pycache__/
+__pycache__/
+lightning_logs/
+RemFX/
.gitmodules ADDED
@@ -0,0 +1,3 @@
+[submodule "umx"]
+	path = umx
+	url = https://github.com/sigsep/open-unmix-pytorch
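
models.py (below) consumes this submodule through a hardcoded, machine-specific sys.path.append. As a minimal sketch, not part of this commit, a portable alternative would resolve the repository root relative to the file, assuming it sits next to the umx checkout and the submodule has been initialized with git submodule update --init:

import sys
from pathlib import Path

# Hypothetical portable variant of the absolute sys.path.append in models.py:
# locate the repository root relative to this file instead of hardcoding it.
REPO_ROOT = Path(__file__).resolve().parent
sys.path.append(str(REPO_ROOT))

# Same import models.py uses; it resolves once the umx submodule is present.
from umx.openunmix.model import OpenUnmix, Separator  # noqa: E402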
models.py CHANGED
@@ -1,44 +1,103 @@
-from audio_diffusion_pytorch import AudioDiffusionModel
 import torch
 from torch import Tensor
 import pytorch_lightning as pl
 from einops import rearrange
 import wandb
+from audio_diffusion_pytorch import AudioDiffusionModel
+
+import sys
+
+sys.path.append("/Users/matthewrice/Developer/remfx/umx/")
+from umx.openunmix.model import OpenUnmix, Separator
+
 
 SAMPLE_RATE = 22050  # From audio-diffusion-pytorch
 
 
-class
-    def __init__(
+class OpenUnmixModel(pl.LightningModule):
+    def __init__(
+        self,
+        n_fft: int = 2048,
+        hop_length: int = 512,
+        alpha: float = 0.3,
+    ):
         super().__init__()
-        self.model =
+        self.model = OpenUnmix(
+            nb_channels=1,
+            nb_bins=n_fft // 2 + 1,
+        )
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.alpha = alpha
+        window = torch.hann_window(n_fft)
+        self.register_buffer("window", window)
 
     def forward(self, x: torch.Tensor):
         return self.model(x)
 
     def training_step(self, batch, batch_idx):
-        loss = self.common_step(batch, batch_idx, mode="train")
+        loss, _ = self.common_step(batch, batch_idx, mode="train")
         return loss
 
     def validation_step(self, batch, batch_idx):
-        loss = self.common_step(batch, batch_idx, mode="val")
+        loss, Y = self.common_step(batch, batch_idx, mode="val")
+        return loss, Y
 
     def common_step(self, batch, batch_idx, mode: str = "train"):
        x, target, label = batch
-
+        X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
+        Y = self(X)
+        Y_hat = spectrogram(
+            target, self.window, self.n_fft, self.hop_length, self.alpha
+        )
+        loss = torch.nn.functional.mse_loss(Y, Y_hat)
         self.log(f"{mode}_loss", loss, on_step=True, on_epoch=True)
-        return loss
+        return loss, Y
 
     def configure_optimizers(self):
         return torch.optim.Adam(
             self.parameters(), lr=1e-4, betas=(0.95, 0.999), eps=1e-6, weight_decay=1e-3
         )
 
+    def on_validation_epoch_start(self):
+        self.log_next = True
+
+    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
+        if self.log_next:
+            x, target, label = batch
+            s = Separator(
+                target_models={"other": self.model},
+                nb_channels=1,
+                sample_rate=SAMPLE_RATE,
+                n_fft=self.n_fft,
+                n_hop=self.hop_length,
+            )
+            outputs = s(x).squeeze(1)
+            log_wandb_audio_batch(
+                id="sample",
+                samples=x,
+                sampling_rate=SAMPLE_RATE,
+                caption=f"Epoch {self.current_epoch}",
+            )
+            log_wandb_audio_batch(
+                id="prediction",
+                samples=outputs,
+                sampling_rate=SAMPLE_RATE,
+                caption=f"Epoch {self.current_epoch}",
+            )
+            log_wandb_audio_batch(
+                id="target",
+                samples=target,
+                sampling_rate=SAMPLE_RATE,
+                caption=f"Epoch {self.current_epoch}",
+            )
+            self.log_next = False
+
 
-class
-    def __init__(self):
+class DiffusionGenerationModel(pl.LightningModule):
+    def __init__(self, model: torch.nn.Module):
         super().__init__()
-        self.model =
+        self.model = model
 
     def forward(self, x: torch.Tensor):
         return self.model(x)
@@ -77,10 +136,8 @@ class AudioDiffusionWrapper(pl.LightningModule):
     def log_sample(self, batch, num_steps=10):
         # Get start diffusion noise
         noise = torch.randn(batch.shape, device=self.device)
-        sampled = self.
-
-        )
-        self.log_wandb_audio_batch(
+        sampled = self.sample(noise=noise, num_steps=num_steps)  # Suggested range: 2-50
+        log_wandb_audio_batch(
             id="sample",
             samples=sampled,
             sampling_rate=SAMPLE_RATE,
@@ -96,10 +153,34 @@ def log_wandb_audio_batch(
     for idx in range(num_items):
         wandb.log(
             {
-                f"
+                f"{id}_{idx}": wandb.Audio(
                     samples[idx].cpu().numpy(),
                     caption=caption,
                     sample_rate=sampling_rate,
                 )
             }
         )
+
+
+def spectrogram(
+    x: torch.Tensor,
+    window: torch.Tensor,
+    n_fft: int,
+    hop_length: int,
+    alpha: float,
+) -> torch.Tensor:
+    bs, chs, samp = x.size()
+    x = x.view(bs * chs, -1)  # move channels onto batch dim
+
+    X = torch.stft(
+        x,
+        n_fft=n_fft,
+        hop_length=hop_length,
+        window=window,
+        return_complex=True,
+    )
+
+    # move channels back
+    X = X.view(bs, chs, X.shape[-2], X.shape[-1])
+
+    return torch.pow(X.abs() + 1e-8, alpha)
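
For orientation, a minimal shape check of the spectrogram() helper added above, using the OpenUnmixModel defaults (n_fft=2048, hop_length=512, alpha=0.3); the batch and sample sizes are arbitrary, and the sketch assumes models.py and its umx import resolve on the local machine:

import torch
from models import spectrogram  # the helper added in this commit

n_fft, hop_length, alpha = 2048, 512, 0.3
window = torch.hann_window(n_fft)

x = torch.randn(2, 1, 22050)  # (batch, channels, samples)
X = spectrogram(x, window, n_fft, hop_length, alpha)

# torch.stft with n_fft=2048 yields n_fft // 2 + 1 = 1025 frequency bins,
# matching the nb_bins passed to OpenUnmix in the constructor above.
print(X.shape)  # torch.Size([2, 1, 1025, 44])

The alpha exponent (0.3) compresses the magnitude range before the MSE loss in common_step, de-emphasizing loud bins relative to quiet ones.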
train.py CHANGED
@@ -3,17 +3,18 @@ import pytorch_lightning as pl
 import torch
 from torch.utils.data import DataLoader
 from datasets import GuitarFXDataset
-from models import
+from models import DiffusionGenerationModel, OpenUnmixModel
+
 
 SAMPLE_RATE = 22050
 TRAIN_SPLIT = 0.8
 
 
 def main():
-
-    trainer = pl.Trainer(
+    wandb_logger = WandbLogger(project="RemFX", save_dir="./")
+    trainer = pl.Trainer(logger=wandb_logger, max_epochs=10)
     guitfx = GuitarFXDataset(
-        root="/Users/matthewrice/
+        root="/Users/matthewrice/Developer/remfx/data/egfx",
         sample_rate=SAMPLE_RATE,
         effect_type=["Phaser"],
     )
@@ -24,7 +25,10 @@ def main():
     )
     train = DataLoader(train_dataset, batch_size=2)
     val = DataLoader(val_dataset, batch_size=2)
-
+
+    # model = DiffusionGenerationModel()
+    model = OpenUnmixModel()
+
     trainer.fit(model=model, train_dataloaders=train, val_dataloaders=val)
 
 
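
Note that the new main() references WandbLogger, whose import is not visible in the hunks above (the file's first two lines fall outside the diff context). If it is not already present there, the standard import would be:

from pytorch_lightning.loggers import WandbLogger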
umx ADDED
@@ -0,0 +1 @@
+Subproject commit 05fd4d8a0e3e50e308579052d762a342647c3408