mattricesound committed
Commit dcaaa71 · Parents: d54e023 90cacdf

Merge pull request #4 from mhrice/hydra-init
.gitignore CHANGED
@@ -6,4 +6,6 @@ data/
 .DS_Store
 __pycache__/
 lightning_logs/
-RemFX/
+outputs/
+logs/
+.vscode/
README.md CHANGED
@@ -1,7 +1,20 @@
 
-wget https://zenodo.org/record/7044411/files/Clean.zip?download=1 Clean.zip
-
-unzip Clean.zip
-
-python3 -m venv env
-pip install -e .
+## Install Packages
+1. `python3 -m venv env`
+2. `source env/bin/activate`
+3. `pip install -e .`
+4. `pip install -e umx`
+
+## Download the [GuitarFX Dataset](https://zenodo.org/record/7044411/)
+`./download_egfx.sh`
+
+## Train model
+1. Change the Wandb variables in `shell_vars.sh`
+2. `python train.py exp=audio_diffusion`
+   or
+   `python train.py exp=umx`
+
+To train on a GPU, add `trainer.accelerator='gpu' trainer.devices=-1` to the command line.
+
+Ex. `python train.py exp=umx trainer.accelerator='gpu' trainer.devices=-1`
+
config.yaml ADDED
@@ -0,0 +1,52 @@
+defaults:
+  - _self_
+  - exp: null
+seed: 12345
+train: True
+length: 262144
+sample_rate: 48000
+logs_dir: "./logs"
+log_every_n_steps: 1000
+
+callbacks:
+  model_checkpoint:
+    _target_: pytorch_lightning.callbacks.ModelCheckpoint
+    monitor: "valid_loss"  # name of the logged metric that determines when the model is improving
+    save_top_k: 1  # save the k best models (determined by the metric above)
+    save_last: True  # additionally, always save the model from the last epoch
+    mode: "min"  # can be "max" or "min"
+    verbose: False
+    dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
+    filename: '{epoch:02d}-{valid_loss:.3f}'
+
+datamodule:
+  _target_: datasets.Datamodule
+  dataset:
+    _target_: datasets.GuitarFXDataset
+    sample_rate: ${sample_rate}
+    root: ${oc.env:DATASET_ROOT}
+    length: ${length}
+  val_split: 0.2
+  batch_size: 16
+  num_workers: 8
+  pin_memory: True
+
+logger:
+  _target_: pytorch_lightning.loggers.WandbLogger
+  project: ${oc.env:WANDB_PROJECT}
+  entity: ${oc.env:WANDB_ENTITY}
+  # offline: False  # set True to store all logs only locally
+  job_type: "train"
+  group: ""
+  save_dir: "."
+
+trainer:
+  _target_: pytorch_lightning.Trainer
+  precision: 32  # precision used for tensors, default `32`
+  min_epochs: 0
+  max_epochs: -1
+  enable_model_summary: False
+  log_every_n_steps: 1  # log metrics every N batches
+  accumulate_grad_batches: 1
+  accelerator: null
+  devices: 1
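
Note: the `_target_` keys above are what Hydra's `hydra.utils.instantiate` consumes — it imports the dotted path and calls it with the sibling keys as keyword arguments. A minimal sketch of the mechanism (illustration only, not part of this commit; values taken from the config above):

```python
import hydra
from omegaconf import OmegaConf

# Sketch of Hydra's _target_ mechanism: instantiate() imports the dotted
# path and calls it with the remaining keys as keyword arguments.
conf = OmegaConf.create(
    {
        "_target_": "pytorch_lightning.callbacks.ModelCheckpoint",
        "monitor": "valid_loss",
        "mode": "min",
    }
)
checkpoint_cb = hydra.utils.instantiate(conf)  # -> a ModelCheckpoint instance
```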
exp/audio_diffusion.yaml ADDED
@@ -0,0 +1,16 @@
+# @package _global_
+model:
+  _target_: remfx.models.RemFXModel
+  lr: 1e-4
+  lr_beta1: 0.95
+  lr_beta2: 0.999
+  lr_eps: 1e-6
+  lr_weight_decay: 1e-3
+  sample_rate: ${sample_rate}
+  network:
+    _target_: remfx.models.DiffusionGenerationModel
+    n_channels: 1
+datamodule:
+  dataset:
+    effect_types: ["Clean"]
+  batch_size: 2
exp/demucs.yaml ADDED
@@ -0,0 +1 @@
+# @package _global_
exp/umx.yaml ADDED
@@ -0,0 +1,19 @@
+# @package _global_
+model:
+  _target_: remfx.models.RemFXModel
+  lr: 1e-4
+  lr_beta1: 0.95
+  lr_beta2: 0.999
+  lr_eps: 1e-6
+  lr_weight_decay: 1e-3
+  sample_rate: ${sample_rate}
+  network:
+    _target_: remfx.models.OpenUnmixModel
+    n_fft: 2048
+    hop_length: 512
+    n_channels: 1
+    alpha: 0.3
+    sample_rate: ${sample_rate}
+datamodule:
+  dataset:
+    effect_types: ["RAT"]
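
Note: the `# @package _global_` directive at the top of these experiment files tells Hydra to merge their keys at the root of `config.yaml` rather than under an `exp:` node, so `exp=umx` on the command line swaps in the model/network definition and the dataset overrides above. A small sketch of the merge semantics (hypothetical values, using OmegaConf directly):

```python
from omegaconf import OmegaConf

# With `# @package _global_`, keys from exp/umx.yaml land at the config root,
# overriding the corresponding defaults from config.yaml.
base = OmegaConf.create(
    {"datamodule": {"dataset": {"effect_types": None}, "batch_size": 16}}
)
exp = OmegaConf.create({"datamodule": {"dataset": {"effect_types": ["RAT"]}}})
merged = OmegaConf.merge(base, exp)
print(merged.datamodule.dataset.effect_types)  # ['RAT']
print(merged.datamodule.batch_size)            # 16 (untouched)
```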
main.py DELETED
@@ -1,19 +0,0 @@
-from audio_diffusion_pytorch import AudioDiffusionModel
-import torch
-from tqdm import tqdm
-import wandb
-
-model = AudioDiffusionModel(in_channels=1)
-wandb.init(project="RemFX", entity="mattricesound")
-
-x = torch.randn(2, 1, 2**18)
-for i in tqdm(range(100)):
-    loss = model(x)
-    loss.backward()
-    if i % 10 == 0:
-        print(loss)
-        wandb.log({"loss": loss})
-
-
-noise = torch.randn(2, 1, 2**18)
-sampled = model.sample(noise=noise, num_steps=5)
Experiments.ipynb → notebooks/Experiments.ipynb RENAMED
File without changes
diffusion_test.ipynb → notebooks/diffusion_test.ipynb RENAMED
File without changes
egfx.ipynb → notebooks/egfx.ipynb RENAMED
File without changes
guitar_generation_test.ipynb → notebooks/guitar_generation_test.ipynb RENAMED
File without changes
datasets.py → remfx/datasets.py RENAMED
@@ -1,10 +1,10 @@
-import torch
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, DataLoader, random_split
 import torchaudio
 import torchaudio.transforms as T
 import torch.nn.functional as F
 from pathlib import Path
-from typing import List
+import pytorch_lightning as pl
+from typing import Any, List
 
 # https://zenodo.org/record/7044411/
 
@@ -18,18 +18,19 @@ class GuitarFXDataset(Dataset):
         root: str,
         sample_rate: int,
         length: int = LENGTH,
-        effect_type: List[str] = None,
+        effect_types: List[str] = None,
     ):
         self.length = length
         self.wet_files = []
         self.dry_files = []
         self.labels = []
         self.root = Path(root)
-        if effect_type is None:
-            effect_type = [
+
+        if effect_types is None:
+            effect_types = [
                 d.name for d in self.root.iterdir() if d.is_dir() and d != "Clean"
             ]
-        for i, effect in enumerate(effect_type):
+        for i, effect in enumerate(effect_types):
             for pickup in Path(self.root / effect).iterdir():
                 self.wet_files += sorted(list(pickup.glob("*.wav")))
                 self.dry_files += sorted(
@@ -61,3 +62,50 @@ class GuitarFXDataset(Dataset):
         elif resampled_y.shape[-1] > self.length:
             resampled_y = resampled_y[:, : self.length]
         return (resampled_x, resampled_y, effect_label)
+
+
+class Datamodule(pl.LightningDataModule):
+    def __init__(
+        self,
+        dataset,
+        *,
+        val_split: float,
+        batch_size: int,
+        num_workers: int,
+        pin_memory: bool = False,
+        **kwargs: int,
+    ) -> None:
+        super().__init__()
+        self.dataset = dataset
+        self.val_split = val_split
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.pin_memory = pin_memory
+        self.data_train: Any = None
+        self.data_val: Any = None
+
+    def setup(self, stage: Any = None) -> None:
+        split = [1.0 - self.val_split, self.val_split]
+        train_size = int(split[0] * len(self.dataset))
+        val_size = int(split[1] * len(self.dataset))
+        self.data_train, self.data_val = random_split(
+            self.dataset, [train_size, val_size]
+        )
+
+    def train_dataloader(self) -> DataLoader:
+        return DataLoader(
+            dataset=self.data_train,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            pin_memory=self.pin_memory,
+            shuffle=True,
+        )
+
+    def val_dataloader(self) -> DataLoader:
+        return DataLoader(
+            dataset=self.data_val,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            pin_memory=self.pin_memory,
+            shuffle=False,
+        )
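
A caveat on the new `Datamodule.setup`: since both sizes are computed with `int(...)`, truncation can make `train_size + val_size` fall one short of `len(self.dataset)`, which `random_split` rejects with a ValueError. A hedged sketch of a variant that always sums correctly (illustration only, not part of this commit):

```python
from torch.utils.data import random_split

def setup(self, stage=None):
    # Drop-in replacement for Datamodule.setup: derive the validation size as
    # the remainder so the two sizes always sum to len(self.dataset), avoiding
    # int() truncation mismatches.
    train_size = int((1.0 - self.val_split) * len(self.dataset))
    val_size = len(self.dataset) - train_size
    self.data_train, self.data_val = random_split(self.dataset, [train_size, val_size])
```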
models.py → remfx/models.py RENAMED
@@ -1,63 +1,59 @@
 import torch
-from torch import Tensor
+from torch import Tensor, nn
 import pytorch_lightning as pl
 from einops import rearrange
 import wandb
-from audio_diffusion_pytorch import AudioDiffusionModel
+from audio_diffusion_pytorch import DiffusionModel
+import auraloss
 
-import sys
-
-sys.path.append("./umx")
 from umx.openunmix.model import OpenUnmix, Separator
 
 
-SAMPLE_RATE = 22050  # From audio-diffusion-pytorch
-
-
-class OpenUnmixModel(pl.LightningModule):
+class RemFXModel(pl.LightningModule):
     def __init__(
         self,
-        n_fft: int = 2048,
-        hop_length: int = 512,
-        alpha: float = 0.3,
+        lr: float,
+        lr_beta1: float,
+        lr_beta2: float,
+        lr_eps: float,
+        lr_weight_decay: float,
+        sample_rate: float,
+        network: nn.Module,
     ):
         super().__init__()
-        self.model = OpenUnmix(
-            nb_channels=1,
-            nb_bins=n_fft // 2 + 1,
-        )
-        self.n_fft = n_fft
-        self.hop_length = hop_length
-        self.alpha = alpha
-        window = torch.hann_window(n_fft)
-        self.register_buffer("window", window)
+        self.lr = lr
+        self.lr_beta1 = lr_beta1
+        self.lr_beta2 = lr_beta2
+        self.lr_eps = lr_eps
+        self.lr_weight_decay = lr_weight_decay
+        self.sample_rate = sample_rate
+        self.model = network
+
+    @property
+    def device(self):
+        return next(self.model.parameters()).device
 
-    def forward(self, x: torch.Tensor):
-        return self.model(x)
+    def configure_optimizers(self):
+        optimizer = torch.optim.AdamW(
+            list(self.model.parameters()),
+            lr=self.lr,
+            betas=(self.lr_beta1, self.lr_beta2),
+            eps=self.lr_eps,
+            weight_decay=self.lr_weight_decay,
+        )
+        return optimizer
 
     def training_step(self, batch, batch_idx):
-        loss, _ = self.common_step(batch, batch_idx, mode="train")
+        loss = self.common_step(batch, batch_idx, mode="train")
         return loss
 
     def validation_step(self, batch, batch_idx):
-        loss, Y = self.common_step(batch, batch_idx, mode="val")
-        return loss, Y
+        loss = self.common_step(batch, batch_idx, mode="valid")
 
     def common_step(self, batch, batch_idx, mode: str = "train"):
-        x, target, label = batch
-        X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
-        Y = self(X)
-        Y_hat = spectrogram(
-            target, self.window, self.n_fft, self.hop_length, self.alpha
-        )
-        loss = torch.nn.functional.mse_loss(Y, Y_hat)
-        self.log(f"{mode}_loss", loss, on_step=True, on_epoch=True)
-        return loss, Y
-
-    def configure_optimizers(self):
-        return torch.optim.Adam(
-            self.parameters(), lr=1e-4, betas=(0.95, 0.999), eps=1e-6, weight_decay=1e-3
-        )
+        loss = self.model(batch)
+        self.log(f"{mode}_loss", loss)
+        return loss
 
     def on_validation_epoch_start(self):
         self.log_next = True
@@ -65,87 +61,90 @@ class OpenUnmixModel(pl.LightningModule):
     def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
         if self.log_next:
             x, target, label = batch
-            s = Separator(
-                target_models={"other": self.model},
-                nb_channels=1,
-                sample_rate=SAMPLE_RATE,
-                n_fft=self.n_fft,
-                n_hop=self.hop_length,
-            ).to(self.device)
-            outputs = s(x).squeeze(1)
+            y = self.model.sample(x)
             log_wandb_audio_batch(
                 logger=self.logger,
                 id="sample",
                 samples=x.cpu(),
-                sampling_rate=SAMPLE_RATE,
+                sampling_rate=self.sample_rate,
                 caption=f"Epoch {self.current_epoch}",
             )
             log_wandb_audio_batch(
                 logger=self.logger,
                 id="prediction",
-                samples=outputs.cpu(),
-                sampling_rate=SAMPLE_RATE,
+                samples=y.cpu(),
+                sampling_rate=self.sample_rate,
                 caption=f"Epoch {self.current_epoch}",
             )
             log_wandb_audio_batch(
-                logger=self.loggger,
+                logger=self.logger,
                 id="target",
                 samples=target.cpu(),
-                sampling_rate=SAMPLE_RATE,
+                sampling_rate=self.sample_rate,
                 caption=f"Epoch {self.current_epoch}",
             )
             self.log_next = False
 
 
-class DiffusionGenerationModel(pl.LightningModule):
-    def __init__(self, model: torch.nn.Module):
+class OpenUnmixModel(torch.nn.Module):
+    def __init__(
+        self,
+        n_fft: int = 2048,
+        hop_length: int = 512,
+        n_channels: int = 1,
+        alpha: float = 0.3,
+        sample_rate: int = 22050,
+    ):
         super().__init__()
-        self.model = model
+        self.n_channels = n_channels
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.alpha = alpha
+        window = torch.hann_window(n_fft)
+        self.register_buffer("window", window)
 
-    def forward(self, x: torch.Tensor):
-        return self.model(x)
+        self.num_bins = self.n_fft // 2 + 1
+        self.sample_rate = sample_rate
+        self.model = OpenUnmix(
+            nb_channels=self.n_channels,
+            nb_bins=self.num_bins,
+        )
+        self.separator = Separator(
+            target_models={"other": self.model},
+            nb_channels=self.n_channels,
+            sample_rate=self.sample_rate,
+            n_fft=self.n_fft,
+            n_hop=self.hop_length,
+        )
+        self.loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
+            n_bins=self.num_bins, sample_rate=self.sample_rate
+        )
 
-    def sample(self, *args, **kwargs) -> Tensor:
-        return self.model.sample(*args, **kwargs)
+    def forward(self, batch):
+        x, target, label = batch
+        X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
+        Y = self.model(X)
+        sep_out = self.separator(x).squeeze(1)
+        loss = self.loss_fn(sep_out, target)
 
-    def training_step(self, batch, batch_idx):
-        loss = self.common_step(batch, batch_idx, mode="train")
         return loss
 
-    def validation_step(self, batch, batch_idx):
-        loss = self.common_step(batch, batch_idx, mode="val")
-
-    def common_step(self, batch, batch_idx, mode: str = "train"):
-        x, target, label = batch
-        loss = self(x)
-        self.log(f"{mode}_loss", loss, on_step=True, on_epoch=True)
-        return loss
+    def sample(self, x: Tensor) -> Tensor:
+        return self.separator(x).squeeze(1)
 
-    def configure_optimizers(self):
-        return torch.optim.Adam(
-            self.parameters(), lr=1e-4, betas=(0.95, 0.999), eps=1e-6, weight_decay=1e-3
-        )
 
-    def on_validation_epoch_start(self):
-        self.log_next = True
+class DiffusionGenerationModel(nn.Module):
+    def __init__(self, n_channels: int = 1):
+        super().__init__()
+        self.model = DiffusionModel(in_channels=n_channels)
 
-    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
-        x, target, label = batch
-        if self.log_next:
-            self.log_sample(x)
-            self.log_next = False
+    def forward(self, batch):
+        x, target, label = batch
+        return self.model(x)
 
-    @torch.no_grad()
-    def log_sample(self, batch, num_steps=10):
-        # Get start diffusion noise
-        noise = torch.randn(batch.shape, device=self.device)
-        sampled = self.sample(noise=noise, num_steps=num_steps)  # Suggested range: 2-50
-        log_wandb_audio_batch(
-            id="sample",
-            samples=sampled,
-            sampling_rate=SAMPLE_RATE,
-            caption=f"Sampled in {num_steps} steps",
-        )
+    def sample(self, x: Tensor, num_steps: int = 10) -> Tensor:
+        noise = torch.randn(x.shape).to(x)
+        return self.model.sample(noise, num_steps=num_steps)
 
 
 def log_wandb_audio_batch(
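
After this refactor, `RemFXModel` is agnostic to its `network` as long as the module follows the two conventions visible above: `forward(batch)` returns a scalar loss for an `(x, target, label)` batch, and `sample(x)` returns audio. A minimal hypothetical network satisfying that contract (illustration only, not part of this commit):

```python
import torch
from torch import Tensor, nn

class IdentityEffectRemover(nn.Module):
    """Hypothetical stand-in showing the interface RemFXModel expects."""

    def __init__(self):
        super().__init__()
        self.net = nn.Conv1d(1, 1, kernel_size=1)  # operates on (batch, 1, samples)

    def forward(self, batch) -> Tensor:
        x, target, label = batch  # same batch layout as GuitarFXDataset
        return nn.functional.mse_loss(self.net(x), target)  # scalar loss

    def sample(self, x: Tensor) -> Tensor:
        return self.net(x)  # "cleaned" audio, same shape as the input
```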
remfx/utils.py ADDED
@@ -0,0 +1,71 @@
+import logging
+from typing import List
+import pytorch_lightning as pl
+from omegaconf import DictConfig
+from pytorch_lightning.utilities import rank_zero_only
+
+
+def get_logger(name=__name__) -> logging.Logger:
+    """Initializes a multi-GPU-friendly python command line logger."""
+
+    logger = logging.getLogger(name)
+
+    # this ensures all logging levels get marked with the rank zero decorator
+    # otherwise logs would get multiplied for each GPU process in a multi-GPU setup
+    for level in (
+        "debug",
+        "info",
+        "warning",
+        "error",
+        "exception",
+        "fatal",
+        "critical",
+    ):
+        setattr(logger, level, rank_zero_only(getattr(logger, level)))
+
+    return logger
+
+
+log = get_logger(__name__)
+
+
+@rank_zero_only
+def log_hyperparameters(
+    config: DictConfig,
+    model: pl.LightningModule,
+    datamodule: pl.LightningDataModule,
+    trainer: pl.Trainer,
+    callbacks: List[pl.Callback],
+    logger: pl.loggers.logger.Logger,
+) -> None:
+    """Controls which config parts are saved by Lightning loggers.
+    Additionally saves:
+    - number of model parameters
+    """
+
+    if not trainer.logger:
+        return
+
+    hparams = {}
+
+    # choose which parts of the hydra config will be saved to loggers
+    hparams["model"] = config["model"]
+
+    # save the number of model parameters
+    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+    hparams["model/params/trainable"] = sum(
+        p.numel() for p in model.parameters() if p.requires_grad
+    )
+    hparams["model/params/non_trainable"] = sum(
+        p.numel() for p in model.parameters() if not p.requires_grad
+    )
+
+    hparams["datamodule"] = config["datamodule"]
+    hparams["trainer"] = config["trainer"]
+
+    if "seed" in config:
+        hparams["seed"] = config["seed"]
+    if "callbacks" in config:
+        hparams["callbacks"] = config["callbacks"]
+
+    logger.experiment.config.update(hparams)
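
Because each logging method is wrapped with `rank_zero_only`, multi-GPU runs emit each message once instead of once per process. Usage, as `scripts/train.py` below does it:

```python
from remfx.utils import get_logger

log = get_logger(__name__)
log.info("printed once per job, even with multiple DDP processes")
```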
download_egfx.sh → scripts/download_egfx.sh RENAMED
File without changes
scripts/train.py ADDED
@@ -0,0 +1,47 @@
+import pytorch_lightning as pl
+import hydra
+from omegaconf import DictConfig
+import remfx.utils as utils
+
+log = utils.get_logger(__name__)
+
+
+@hydra.main(version_base=None, config_path=".", config_name="config.yaml")
+def main(cfg: DictConfig):
+    # Apply seed for reproducibility
+    print(cfg)
+    pl.seed_everything(cfg.seed)
+
+    log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>.")
+    datamodule = hydra.utils.instantiate(cfg.datamodule, _convert_="partial")
+
+    log.info(f"Instantiating model <{cfg.model._target_}>.")
+    model = hydra.utils.instantiate(cfg.model, _convert_="partial")
+
+    # Init all callbacks
+    callbacks = []
+    if "callbacks" in cfg:
+        for _, cb_conf in cfg["callbacks"].items():
+            if "_target_" in cb_conf:
+                log.info(f"Instantiating callback <{cb_conf._target_}>.")
+                callbacks.append(hydra.utils.instantiate(cb_conf, _convert_="partial"))
+
+    logger = hydra.utils.instantiate(cfg.logger, _convert_="partial")
+    log.info(f"Instantiating trainer <{cfg.trainer._target_}>.")
+    trainer = hydra.utils.instantiate(
+        cfg.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
+    )
+    log.info("Logging hyperparameters!")
+    utils.log_hyperparameters(
+        config=cfg,
+        model=model,
+        datamodule=datamodule,
+        trainer=trainer,
+        callbacks=callbacks,
+        logger=logger,
+    )
+    trainer.fit(model=model, datamodule=datamodule)
+
+
+if __name__ == "__main__":
+    main()
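
The overrides shown in the README map onto this entry point: `exp=umx` fills the `exp: null` slot in the defaults list of `config.yaml`, and dotted keys like `trainer.devices=-1` override individual values. A hedged sketch using Hydra's compose API to preview the config `main(cfg)` would receive (assuming it is run from the directory containing `config.yaml`):

```python
from hydra import compose, initialize

# Preview the composed config for `python train.py exp=umx trainer.accelerator=gpu`.
with initialize(version_base=None, config_path="."):
    cfg = compose(config_name="config", overrides=["exp=umx", "trainer.accelerator=gpu"])
print(cfg.model.network._target_)  # remfx.models.OpenUnmixModel
```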
setup.py CHANGED
@@ -42,6 +42,8 @@ setup(
         "ema_pytorch",
         "einops",
         "librosa",
+        "hydra-core",
+        "auraloss",
     ],
     include_package_data=True,
     license="Apache License 2.0",
shell_vars.sh ADDED
@@ -0,0 +1,3 @@
+export DATASET_ROOT="./data/egfx"
+export WANDB_PROJECT="RemFX"
+export WANDB_ENTITY="mattricesound"
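
`config.yaml` reads all three of these through `${oc.env:...}` resolvers, so they must be exported (e.g. `source shell_vars.sh`) before launching training. A quick hypothetical pre-flight check:

```python
import os

# Fail fast if shell_vars.sh has not been sourced; config.yaml resolves these
# via ${oc.env:...} and would otherwise error when the values are first accessed.
for var in ("DATASET_ROOT", "WANDB_PROJECT", "WANDB_ENTITY"):
    assert var in os.environ, f"{var} is unset; run `source shell_vars.sh` first"
```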
train.py DELETED
@@ -1,36 +0,0 @@
-from pytorch_lightning.loggers import WandbLogger
-import pytorch_lightning as pl
-import torch
-from torch.utils.data import DataLoader
-from datasets import GuitarFXDataset
-from models import DiffusionGenerationModel, OpenUnmixModel
-
-
-SAMPLE_RATE = 22050
-TRAIN_SPLIT = 0.8
-
-
-def main():
-    wandb_logger = WandbLogger(project="RemFX", save_dir="./")
-    trainer = pl.Trainer(logger=wandb_logger, max_epochs=100)
-    guitfx = GuitarFXDataset(
-        root="./data/egfx",
-        sample_rate=SAMPLE_RATE,
-        effect_type=["Phaser"],
-    )
-    train_size = int(TRAIN_SPLIT * len(guitfx))
-    val_size = len(guitfx) - train_size
-    train_dataset, val_dataset = torch.utils.data.random_split(
-        guitfx, [train_size, val_size]
-    )
-    train = DataLoader(train_dataset, batch_size=2)
-    val = DataLoader(val_dataset, batch_size=2)
-
-    # model = DiffusionGenerationModel()
-    model = OpenUnmixModel()
-
-    trainer.fit(model=model, train_dataloaders=train, val_dataloaders=val)
-
-
-if __name__ == "__main__":
-    main()