mattricesound committed on
Commit
90cacdf
·
1 Parent(s): e4c0874

Fix folder structure

Browse files
.gitignore CHANGED
@@ -6,7 +6,6 @@ data/
6
  .DS_Store
7
  __pycache__/
8
  lightning_logs/
9
- RemFX/
10
  outputs/
11
  logs/
12
  .vscode/
 
6
  .DS_Store
7
  __pycache__/
8
  lightning_logs/
 
9
  outputs/
10
  logs/
11
  .vscode/
datasets.py → remfx/datasets.py RENAMED
File without changes
remfx/models.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import Tensor, nn
3
+ import pytorch_lightning as pl
4
+ from einops import rearrange
5
+ import wandb
6
+ from audio_diffusion_pytorch import DiffusionModel
7
+ import auraloss
8
+
9
+ from umx.openunmix.model import OpenUnmix, Separator
10
+
11
+
12
class RemFXModel(pl.LightningModule):
    """Lightning wrapper around a network whose forward pass returns the loss.

    The wrapped ``network`` is called with the whole batch and its return value
    is logged and used as the loss. It must also provide ``sample(x)``, which
    is used to log predictions for the first batch of every validation epoch.

    Args:
        lr: Learning rate for AdamW.
        lr_beta1: First AdamW beta coefficient.
        lr_beta2: Second AdamW beta coefficient.
        lr_eps: AdamW epsilon.
        lr_weight_decay: AdamW weight decay.
        sample_rate: Sample rate handed to the audio logger.
        network: Module that maps a batch to a loss and exposes ``sample``.
    """

    def __init__(
        self,
        lr: float,
        lr_beta1: float,
        lr_beta2: float,
        lr_eps: float,
        lr_weight_decay: float,
        sample_rate: float,
        network: nn.Module,
    ):
        super().__init__()
        self.lr = lr
        self.lr_beta1 = lr_beta1
        self.lr_beta2 = lr_beta2
        self.lr_eps = lr_eps
        self.lr_weight_decay = lr_weight_decay
        self.sample_rate = sample_rate
        self.model = network
        # Declared here so the flag always exists; it is armed at the start of
        # each validation epoch and cleared after the first batch is logged.
        self.log_next = False

    @property
    def device(self):
        """Return the device of the wrapped network's first parameter."""
        return next(self.model.parameters()).device

    def configure_optimizers(self):
        """Build the AdamW optimizer over the wrapped network's parameters."""
        optimizer = torch.optim.AdamW(
            list(self.model.parameters()),
            lr=self.lr,
            betas=(self.lr_beta1, self.lr_beta2),
            eps=self.lr_eps,
            weight_decay=self.lr_weight_decay,
        )
        return optimizer

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one batch."""
        return self.common_step(batch, batch_idx, mode="train")

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the validation loss for one batch."""
        # The loss used to be computed and silently dropped; return it so the
        # step mirrors ``training_step``.
        return self.common_step(batch, batch_idx, mode="valid")

    def common_step(self, batch, batch_idx, mode: str = "train"):
        """Run the network on ``batch`` and log the result as ``{mode}_loss``."""
        loss = self.model(batch)
        self.log(f"{mode}_loss", loss)
        return loss

    def on_validation_epoch_start(self):
        """Arm audio logging for the first batch of this validation epoch."""
        self.log_next = True

    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx=0):
        """Log input, prediction and target audio once per validation epoch."""
        if not self.log_next:
            return

        x, target, label = batch
        y = self.model.sample(x)
        caption = f"Epoch {self.current_epoch}"
        # Same order as before: input, then prediction, then target.
        for audio_id, audio in (("sample", x), ("prediction", y), ("target", target)):
            log_wandb_audio_batch(
                logger=self.logger,
                id=audio_id,
                samples=audio.cpu(),
                sampling_rate=self.sample_rate,
                caption=caption,
            )
        self.log_next = False
87
+
88
+
89
class OpenUnmixModel(torch.nn.Module):
    """Open-Unmix network and separator scored with a multi-resolution STFT loss.

    Args:
        n_fft: FFT size used for the window, the bin count and the separator.
        hop_length: Hop size passed to the separator and the spectrogram helper.
        n_channels: Number of audio channels given to both sub-modules.
        alpha: Exponent applied to the magnitude by the spectrogram helper.
        sample_rate: Sample rate passed to the separator and the loss.
    """

    def __init__(
        self,
        n_fft: int = 2048,
        hop_length: int = 512,
        n_channels: int = 1,
        alpha: float = 0.3,
        sample_rate: int = 22050,
    ):
        super().__init__()
        # Plain hyper-parameters first.
        self.n_channels = n_channels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.alpha = alpha
        self.sample_rate = sample_rate
        self.num_bins = self.n_fft // 2 + 1

        # Registered as a buffer so it follows the module across devices.
        self.register_buffer("window", torch.hann_window(n_fft))

        self.model = OpenUnmix(
            nb_channels=self.n_channels,
            nb_bins=self.num_bins,
        )
        # The separator wraps the very same network instance under "other".
        self.separator = Separator(
            target_models={"other": self.model},
            nb_channels=self.n_channels,
            sample_rate=self.sample_rate,
            n_fft=self.n_fft,
            n_hop=self.hop_length,
        )
        self.loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
            n_bins=self.num_bins, sample_rate=self.sample_rate
        )

    def forward(self, batch):
        """Return the STFT loss between the separated input and the target."""
        x, target, label = batch
        compressed_spec = spectrogram(
            x, self.window, self.n_fft, self.hop_length, self.alpha
        )
        # NOTE(review): this output is never used for the loss. The call is
        # kept because it may still touch module state; confirm before removing.
        self.model(compressed_spec)
        separated = self.separator(x).squeeze(1)
        return self.loss_fn(separated, target)

    def sample(self, x: Tensor) -> Tensor:
        """Run the separator on ``x`` and squeeze dimension 1 of its output."""
        separated = self.separator(x)
        return separated.squeeze(1)
134
+
135
+
136
class DiffusionGenerationModel(nn.Module):
    """Thin wrapper around ``DiffusionModel`` taking ``(x, target, label)`` batches.

    Args:
        n_channels: Passed to ``DiffusionModel`` as ``in_channels``.
    """

    def __init__(self, n_channels: int = 1):
        super().__init__()
        self.model = DiffusionModel(in_channels=n_channels)

    def forward(self, batch):
        """Call the wrapped model on the first element of a three-item batch."""
        # Only the first element is used; the batch must still unpack to three.
        x, _target, _label = batch
        return self.model(x)

    def sample(self, x: Tensor, num_steps: int = 10) -> Tensor:
        """Sample from the wrapped model, starting from noise shaped like ``x``."""
        noise = torch.randn(x.shape)
        # ``.to(x)`` matches the dtype and device of the reference tensor.
        noise = noise.to(x)
        return self.model.sample(noise, num_steps=num_steps)
148
+
149
+
150
def log_wandb_audio_batch(
    logger: pl.loggers.WandbLogger,
    id: str,
    samples: Tensor,
    sampling_rate: int,
    caption: str = "",
):
    """Log every item of an audio batch to W&B under the key ``{id}_{index}``.

    Args:
        logger: Logger whose ``experiment.log`` receives each audio item.
        id: Prefix for the logged keys.
        samples: Batch laid out as ``b c t``; it is rearranged to ``b t c``.
        sampling_rate: Sample rate forwarded to ``wandb.Audio``.
        caption: Caption attached to every logged item.
    """
    channels_last = rearrange(samples, "b c t -> b t c")
    # One log call per item, keyed by its position in the batch.
    for idx, clip in enumerate(channels_last):
        audio = wandb.Audio(
            clip.cpu().numpy(),
            caption=caption,
            sample_rate=sampling_rate,
        )
        logger.experiment.log({f"{id}_{idx}": audio})
169
+
170
+
171
def spectrogram(
    x: torch.Tensor,
    window: torch.Tensor,
    n_fft: int,
    hop_length: int,
    alpha: float,
) -> torch.Tensor:
    """Return the power-compressed STFT magnitude of a batch of signals.

    Args:
        x: Three-dimensional tensor unpacked as (batch, channels, samples).
        window: Window passed to ``torch.stft``.
        n_fft: FFT size.
        hop_length: Hop size between frames.
        alpha: Exponent applied to the magnitude.

    Returns:
        ``(|STFT(x)| + 1e-8) ** alpha`` with shape (batch, channels, bins, frames).
    """
    bs, chs, _ = x.size()
    # Move channels onto the batch dim. ``reshape`` (rather than ``view``) also
    # accepts non-contiguous inputs such as permuted tensors.
    flat = x.reshape(bs * chs, -1)

    X = torch.stft(
        flat,
        n_fft=n_fft,
        hop_length=hop_length,
        window=window,
        return_complex=True,
    )

    # Move channels back.
    X = X.reshape(bs, chs, X.shape[-2], X.shape[-1])

    # The small offset keeps the power well-defined where the magnitude is zero.
    return torch.pow(X.abs() + 1e-8, alpha)
remfx/utils.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import List
3
+ import pytorch_lightning as pl
4
+ from omegaconf import DictConfig
5
+ from pytorch_lightning.utilities import rank_zero_only
6
+
7
+
8
def get_logger(name=__name__) -> logging.Logger:
    """Return a command-line logger whose level methods are rank-zero only.

    Each logging method is wrapped with ``rank_zero_only`` so that, in a
    multi-GPU run, messages are not repeated once per process.
    """
    rank_zero_levels = (
        "debug",
        "info",
        "warning",
        "error",
        "exception",
        "fatal",
        "critical",
    )

    logger = logging.getLogger(name)
    for level in rank_zero_levels:
        original_method = getattr(logger, level)
        setattr(logger, level, rank_zero_only(original_method))

    return logger
27
+
28
+
29
+ log = get_logger(__name__)
30
+
31
+
32
@rank_zero_only
def log_hyperparameters(
    config: DictConfig,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: List[pl.Callback],
    logger: pl.loggers.logger.Logger,
) -> None:
    """Push selected parts of the config to the logger's experiment config.

    Additionally saves:
        - number of model parameters (total, trainable, non-trainable)

    Does nothing when the trainer has no logger attached.
    """
    if not trainer.logger:
        return

    # Count parameters in a single pass over the model.
    total = 0
    trainable = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count

    # Choose which parts of the hydra config will be saved to loggers.
    hparams = {
        "model": config["model"],
        "model/params/total": total,
        "model/params/trainable": trainable,
        "model/params/non_trainable": total - trainable,
        "datamodule": config["datamodule"],
        "trainer": config["trainer"],
    }

    # Optional sections are copied only when the config defines them.
    for optional_key in ("seed", "callbacks"):
        if optional_key in config:
            hparams[optional_key] = config[optional_key]

    logger.experiment.config.update(hparams)
download_egfx.sh → scripts/download_egfx.sh RENAMED
File without changes
train.py → scripts/train.py RENAMED
File without changes