Christian J. Steinmetz committed on
Commit
c756b1d
·
1 Parent(s): d254115

changing linear layers to MLP

Browse files
README.md CHANGED
@@ -77,4 +77,30 @@ python scripts/download.py vocalset guitarset idmt-smt-guitar idmt-smt-bass idmt
77
  To run audio effects classification:
78
  ```
79
  python scripts/train.py model=classifier "effects_to_use=[compressor, distortion, reverb, chorus, delay]" "effects_to_remove=[]" max_kept_effects=5 max_removed_effects=0 shuffle_kept_effects=True shuffle_removed_effects=True accelerator='gpu' render_root=/scratch/RemFX render_files=True
80
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  To run audio effects classification:
78
  ```
79
  python scripts/train.py model=classifier "effects_to_use=[compressor, distortion, reverb, chorus, delay]" "effects_to_remove=[]" max_kept_effects=5 max_removed_effects=0 shuffle_kept_effects=True shuffle_removed_effects=True accelerator='gpu' render_root=/scratch/RemFX render_files=True
80
+ ```
81
+
82
+ ```
83
+ srun --comment harmonai --partition=g40 --gpus=1 --cpus-per-gpu=12 --job-name=harmonai --pty bash -i
84
+ source env/bin/activate
85
+ rsync -aP /fsx/home-csteinmetz1/data/EffectSet_cjs.tar /scratch
86
+ tar -xvf EffectSet_cjs.tar
87
+ mv scratch/EffectSet_cjs ./EffectSet_cjs
88
+
89
+ export DATASET_ROOT="/admin/home-csteinmetz1/data/remfx-data"
90
+ export WANDB_PROJECT="RemFX"
91
+ export WANDB_ENTITY="cjstein"
92
+
93
+ python scripts/train.py +exp=5-5.yaml model=cls_vggish render_files=False logs_dir=/scratch/cjs-log datamodule.batch_size=64
94
+ python scripts/train.py +exp=5-5.yaml model=cls_panns_pt render_files=False logs_dir=/scratch/cjs-log datamodule.batch_size=64
95
+ python scripts/train.py +exp=5-5.yaml model=cls_wav2vec2 render_files=False logs_dir=/scratch/cjs-log datamodule.batch_size=64
96
+ python scripts/train.py +exp=5-5.yaml model=cls_wav2clip render_files=False logs_dir=/scratch/cjs-log datamodule.batch_size=64
97
+ ```
98
+
99
+ ### Installing HEAR models
100
+
101
+ wav2clip
102
+ ```
103
+ pip install hearbaseline
104
+ pip install git+https://github.com/hohsiangwu/wav2clip-hear.git
105
+ pip install git+https://github.com/qiuqiangkong/HEAR2021_Challenge_PANNs
106
+ wget https://zenodo.org/record/6332525/files/hear2021-panns_hear.pth
cfg/exp/5-5_cls.yaml CHANGED
@@ -56,4 +56,4 @@ trainer:
56
  accelerator: ${accelerator}
57
  devices: 1
58
  gradient_clip_val: 10.0
59
- max_steps: 150000
 
56
  accelerator: ${accelerator}
57
  devices: 1
58
  gradient_clip_val: 10.0
59
+ max_steps: 80000
cfg/model/cls_panns_44k_noaug.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+ model:
3
+ _target_: remfx.models.FXClassifier
4
+ lr: 3e-4
5
+ lr_weight_decay: 1e-3
6
+ sample_rate: ${sample_rate}
7
+ network:
8
+ _target_: remfx.classifier.Cnn14
9
+ num_classes: ${num_classes}
10
+ n_fft: 1024
11
+ hop_length: 256
12
+ n_mels: 128
13
+ sample_rate: 44100
14
+ model_sample_rate: 44100
15
+ specaugment: False
remfx/models.py CHANGED
@@ -422,14 +422,32 @@ class FXClassifier(pl.LightningModule):
422
  self.lr_weight_decay = lr_weight_decay
423
  self.sample_rate = sample_rate
424
  self.network = network
 
425
 
426
- def forward(self, x: torch.Tensor):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
427
  return self.network(x)
428
 
429
  def common_step(self, batch, batch_idx, mode: str = "train"):
 
430
  x, y, dry_label, wet_label = batch
431
- pred_label = self.network(x)
432
- loss = nn.functional.cross_entropy(pred_label, dry_label)
433
  self.log(
434
  f"{mode}_loss",
435
  loss,
@@ -440,11 +458,12 @@ class FXClassifier(pl.LightningModule):
440
  sync_dist=True,
441
  )
442
 
 
 
 
443
  self.log(
444
- f"{mode}_mAP",
445
- torchmetrics.functional.retrieval_average_precision(
446
- pred_label, dry_label.long()
447
- ),
448
  on_step=True,
449
  on_epoch=True,
450
  prog_bar=True,
@@ -452,6 +471,17 @@ class FXClassifier(pl.LightningModule):
452
  sync_dist=True,
453
  )
454
 
 
 
 
 
 
 
 
 
 
 
 
455
  return loss
456
 
457
  def training_step(self, batch, batch_idx):
 
422
  self.lr_weight_decay = lr_weight_decay
423
  self.sample_rate = sample_rate
424
  self.network = network
425
+ self.effects = ["distortion", "compressor", "reverb", "chorus", "delay"]
426
 
427
+ self.train_f1 = torchmetrics.classification.MultilabelF1Score(
428
+ 5, average="none", multidim_average="global"
429
+ )
430
+ self.val_f1 = torchmetrics.classification.MultilabelF1Score(
431
+ 5, average="none", multidim_average="global"
432
+ )
433
+ self.test_f1 = torchmetrics.classification.MultilabelF1Score(
434
+ 5, average="none", multidim_average="global"
435
+ )
436
+
437
+ self.metrics = {
438
+ "train": self.train_f1,
439
+ "valid": self.val_f1,
440
+ "test": self.test_f1,
441
+ }
442
+
443
+ def forward(self, x: torch.Tensor, train: bool = False):
444
  return self.network(x)
445
 
446
  def common_step(self, batch, batch_idx, mode: str = "train"):
447
+ train = True if mode == "train" else False
448
  x, y, dry_label, wet_label = batch
449
+ pred_label = self(x, train)
450
+ loss = nn.functional.cross_entropy(pred_label, wet_label)
451
  self.log(
452
  f"{mode}_loss",
453
  loss,
 
458
  sync_dist=True,
459
  )
460
 
461
+ metrics = self.metrics[mode](pred_label, wet_label.long())
462
+ avg_metrics = torch.mean(metrics)
463
+
464
  self.log(
465
+ f"{mode}_f1_avg",
466
+ avg_metrics,
 
 
467
  on_step=True,
468
  on_epoch=True,
469
  prog_bar=True,
 
471
  sync_dist=True,
472
  )
473
 
474
+ for idx, effect_name in enumerate(self.effects):
475
+ self.log(
476
+ f"{mode}_f1_{effect_name}",
477
+ metrics[idx],
478
+ on_step=True,
479
+ on_epoch=True,
480
+ prog_bar=True,
481
+ logger=True,
482
+ sync_dist=True,
483
+ )
484
+
485
  return loss
486
 
487
  def training_step(self, batch, batch_idx):