mattricesound commited on
Commit
f65f2ca
·
2 Parent(s): 9f1e632 4cb9c24

Merge pull request #40 from mhrice/classifier-inference

Browse files
README.md CHANGED
@@ -47,6 +47,9 @@ see `cfg/exp/default.yaml` for an example.
47
  - `reverb`
48
  - `delay`
49
 
 
 
 
50
  ## Run inference on directory
51
  Assumes directory is structured as
52
  - root
@@ -64,6 +67,7 @@ Change root path in `shell_vars.sh` and `source shell_vars.sh`
64
  `python scripts/chain_inference.py +exp=chain_inference_custom`
65
 
66
 
 
67
  ## Misc.
68
  By default, files are rendered to `input_dir / processed / {string_of_effects} / {train|val|test}`.
69
 
 
47
  - `reverb`
48
  - `delay`
49
 
50
+ ## Chain Inference
51
+ `python scripts/chain_inference.py +exp=chain_inference`
52
+
53
  ## Run inference on directory
54
  Assumes directory is structured as
55
  - root
 
67
  `python scripts/chain_inference.py +exp=chain_inference_custom`
68
 
69
 
70
+
71
  ## Misc.
72
  By default, files are rendered to `input_dir / processed / {string_of_effects} / {train|val|test}`.
73
 
cfg/exp/chain_inference.yaml CHANGED
@@ -63,4 +63,6 @@ inference_effects_ordering:
63
  - "RandomPedalboardReverb"
64
  - "RandomPedalboardChorus"
65
  - "RandomPedalboardDelay"
66
- num_bins: 1025
 
 
 
63
  - "RandomPedalboardReverb"
64
  - "RandomPedalboardChorus"
65
  - "RandomPedalboardDelay"
66
+ num_bins: 1025
67
+ inference_effects_shuffle: False
68
+ inference_use_all_effect_models: False
cfg/exp/chain_inference_aug.yaml CHANGED
@@ -63,4 +63,6 @@ inference_effects_ordering:
63
  - "RandomPedalboardReverb"
64
  - "RandomPedalboardChorus"
65
  - "RandomPedalboardDelay"
66
- num_bins: 1025
 
 
 
63
  - "RandomPedalboardReverb"
64
  - "RandomPedalboardChorus"
65
  - "RandomPedalboardDelay"
66
+ num_bins: 1025
67
+ inference_effects_shuffle: False
68
+ inference_use_all_effect_models: False
cfg/exp/chain_inference_aug_classifier.yaml ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+ defaults:
3
+ - override /model: demucs
4
+ - override /effects: all
5
+ seed: 12345
6
+ sample_rate: 48000
7
+ chunk_size: 262144 # 5.5s
8
+ logs_dir: "./logs"
9
+ render_root: "/scratch/EffectSet"
10
+ accelerator: "gpu"
11
+ log_audio: True
12
+ # Effects
13
+ num_kept_effects: [0,0] # [min, max]
14
+ num_removed_effects: [0,5] # [min, max]
15
+ shuffle_kept_effects: True
16
+ shuffle_removed_effects: True
17
+ num_classes: 5
18
+ effects_to_keep:
19
+ effects_to_remove:
20
+ - distortion
21
+ - compressor
22
+ - reverb
23
+ - chorus
24
+ - delay
25
+ datamodule:
26
+ batch_size: 16
27
+ num_workers: 8
28
+
29
+ dcunet:
30
+ _target_: remfx.models.RemFX
31
+ lr: 1e-4
32
+ lr_beta1: 0.95
33
+ lr_beta2: 0.999
34
+ lr_eps: 1e-6
35
+ lr_weight_decay: 1e-3
36
+ sample_rate: ${sample_rate}
37
+ network:
38
+ _target_: remfx.models.DCUNetModel
39
+ architecture: "Large-DCUNet-20"
40
+ stft_kernel_size: 512
41
+ fix_length_mode: "pad"
42
+ sample_rate: ${sample_rate}
43
+ num_bins: 1025
44
+
45
+ classifier:
46
+ _target_: remfx.models.FXClassifier
47
+ lr: 3e-4
48
+ lr_weight_decay: 1e-3
49
+ sample_rate: ${sample_rate}
50
+ mixup: False
51
+ network:
52
+ _target_: remfx.classifier.Cnn14
53
+ num_classes: ${num_classes}
54
+ n_fft: 2048
55
+ hop_length: 512
56
+ n_mels: 128
57
+ sample_rate: ${sample_rate}
58
+ model_sample_rate: ${sample_rate}
59
+ specaugment: False
60
+ classifier_ckpt: "ckpts/classifier.ckpt"
61
+
62
+ ckpts:
63
+ RandomPedalboardDistortion:
64
+ model: ${model}
65
+ ckpt_path: "ckpts/demucs_distortion_aug.ckpt"
66
+ RandomPedalboardCompressor:
67
+ model: ${model}
68
+ ckpt_path: "ckpts/demucs_compressor_aug.ckpt"
69
+ RandomPedalboardReverb:
70
+ model: ${dcunet}
71
+ ckpt_path: "ckpts/dcunet_reverb_aug.ckpt"
72
+ RandomPedalboardChorus:
73
+ model: ${dcunet}
74
+ ckpt_path: "ckpts/dcunet_chorus_aug.ckpt"
75
+ RandomPedalboardDelay:
76
+ model: ${dcunet}
77
+ ckpt_path: "ckpts/dcunet_delay_aug.ckpt"
78
+
79
+ inference_effects_ordering:
80
+ - "RandomPedalboardDistortion"
81
+ - "RandomPedalboardCompressor"
82
+ - "RandomPedalboardReverb"
83
+ - "RandomPedalboardChorus"
84
+ - "RandomPedalboardDelay"
85
+ num_bins: 1025
86
+ inference_effects_shuffle: False
87
+ inference_use_all_effect_models: False
cfg/exp/chain_inference_custom.yaml CHANGED
@@ -68,4 +68,6 @@ inference_effects_ordering:
68
  - "RandomPedalboardReverb"
69
  - "RandomPedalboardChorus"
70
  - "RandomPedalboardDelay"
71
- num_bins: 1025
 
 
 
68
  - "RandomPedalboardReverb"
69
  - "RandomPedalboardChorus"
70
  - "RandomPedalboardDelay"
71
+ num_bins: 1025
72
+ inference_effects_shuffle: False
73
+ inference_use_all_effect_models: False
remfx/callbacks.py CHANGED
@@ -64,9 +64,6 @@ class AudioCallback(Callback):
64
  ]
65
  for i, label in enumerate(effects_present_name):
66
  self.log(f"{'_'.join(label)}", 0.0)
67
- # self.log(f"{effects}_{i}", label)
68
- # trainer.logger.experiment.log(
69
- # {f"effects_{i}": f"{'_'.join(label)}"}
70
  else:
71
  y = pl_module.model.sample(x)
72
  # Concat samples together for easier viewing in dashboard
 
64
  ]
65
  for i, label in enumerate(effects_present_name):
66
  self.log(f"{'_'.join(label)}", 0.0)
 
 
 
67
  else:
68
  y = pl_module.model.sample(x)
69
  # Concat samples together for easier viewing in dashboard
remfx/classifier.py CHANGED
@@ -1,9 +1,11 @@
1
  import torch
2
  import torchaudio
3
  import torch.nn as nn
4
- import hearbaseline
5
- import hearbaseline.vggish
6
- import hearbaseline.wav2vec2
 
 
7
 
8
  import wav2clip_hear
9
  import panns_hear
 
1
  import torch
2
  import torchaudio
3
  import torch.nn as nn
4
+
5
+ # import hearbaseline
6
+
7
+ # import hearbaseline.vggish
8
+ # import hearbaseline.wav2vec2
9
 
10
  import wav2clip_hear
11
  import panns_hear
remfx/datasets.py CHANGED
@@ -13,10 +13,10 @@ from typing import Any, List, Dict
13
  from torch.utils.data import Dataset, DataLoader
14
  from remfx.utils import select_random_chunk
15
  import multiprocessing
 
16
 
17
 
18
- # https://zenodo.org/record/1193957 -> VocalSet
19
-
20
  ALL_EFFECTS = effect_lib.Pedalboard_Effects
21
  # print(ALL_EFFECTS)
22
 
@@ -404,6 +404,7 @@ class EffectDataset(Dataset):
404
  self.effects_to_keep = [] if effects_to_keep is None else effects_to_keep
405
  self.effects_to_remove = [] if effects_to_remove is None else effects_to_remove
406
  self.normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=-20)
 
407
  self.effects = effect_modules
408
  self.shuffle_kept_effects = shuffle_kept_effects
409
  self.shuffle_removed_effects = shuffle_removed_effects
@@ -471,7 +472,6 @@ class EffectDataset(Dataset):
471
  chunk = select_random_chunk(
472
  random_file_choice, self.chunk_size, self.sample_rate
473
  )
474
-
475
  # Sum to mono
476
  if chunk.shape[0] > 1:
477
  chunk = chunk.sum(0, keepdim=True)
@@ -568,46 +568,52 @@ class EffectDataset(Dataset):
568
  # Index in effect settings
569
  effect_names_to_apply = [self.effects_to_keep[i] for i in effect_indices]
570
  effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
571
- # Apply
572
- dry_labels = []
573
- for effect in effects_to_apply:
574
- # Normalize in-between effects
575
- dry = self.normalize(effect(dry))
576
- dry_labels.append(ALL_EFFECTS.index(type(effect)))
577
-
578
- # Apply effects_to_remove
579
- # Shuffle effects if specified
580
- if self.shuffle_removed_effects:
581
- effect_indices = torch.randperm(len(self.effects_to_remove))
582
- else:
583
- effect_indices = torch.arange(len(self.effects_to_remove))
584
- wet = torch.clone(dry)
585
- r1 = self.num_removed_effects[0]
586
- r2 = self.num_removed_effects[1]
587
- num_removed_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
588
- effect_indices = effect_indices[:num_removed_effects]
589
- # Index in effect settings
590
- effect_names_to_apply = [self.effects_to_remove[i] for i in effect_indices]
591
- effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
592
- # Apply
593
- wet_labels = []
594
- for effect in effects_to_apply:
595
- # Normalize in-between effects
596
- wet = self.normalize(effect(wet))
597
- wet_labels.append(ALL_EFFECTS.index(type(effect)))
598
-
599
- wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
600
- dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
601
-
602
- for label_idx in wet_labels:
603
- wet_labels_tensor[label_idx] = 1.0
604
-
605
- for label_idx in dry_labels:
606
- dry_labels_tensor[label_idx] = 1.0
607
-
608
- # Normalize
609
- normalized_dry = self.normalize(dry)
610
- normalized_wet = self.normalize(wet)
 
 
 
 
 
 
611
  return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
612
 
613
 
@@ -692,7 +698,7 @@ class EffectDatamodule(pl.LightningDataModule):
692
  def test_dataloader(self) -> DataLoader:
693
  return DataLoader(
694
  dataset=self.test_dataset,
695
- batch_size=self.test_batch_size, # Use small, consistent batch size for testing
696
  num_workers=self.num_workers,
697
  pin_memory=self.pin_memory,
698
  shuffle=False,
 
13
  from torch.utils.data import Dataset, DataLoader
14
  from remfx.utils import select_random_chunk
15
  import multiprocessing
16
+ from auraloss.freq import MultiResolutionSTFTLoss
17
 
18
 
19
+ STFT_THRESH = 1e-3
 
20
  ALL_EFFECTS = effect_lib.Pedalboard_Effects
21
  # print(ALL_EFFECTS)
22
 
 
404
  self.effects_to_keep = [] if effects_to_keep is None else effects_to_keep
405
  self.effects_to_remove = [] if effects_to_remove is None else effects_to_remove
406
  self.normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=-20)
407
+ self.mrstft = MultiResolutionSTFTLoss(sample_rate=sample_rate)
408
  self.effects = effect_modules
409
  self.shuffle_kept_effects = shuffle_kept_effects
410
  self.shuffle_removed_effects = shuffle_removed_effects
 
472
  chunk = select_random_chunk(
473
  random_file_choice, self.chunk_size, self.sample_rate
474
  )
 
475
  # Sum to mono
476
  if chunk.shape[0] > 1:
477
  chunk = chunk.sum(0, keepdim=True)
 
568
  # Index in effect settings
569
  effect_names_to_apply = [self.effects_to_keep[i] for i in effect_indices]
570
  effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
571
+ # stft comparison
572
+ stft = 0
573
+ while stft < STFT_THRESH:
574
+ # Apply
575
+ dry_labels = []
576
+ for effect in effects_to_apply:
577
+ # Normalize in-between effects
578
+ dry = self.normalize(effect(dry))
579
+ dry_labels.append(ALL_EFFECTS.index(type(effect)))
580
+
581
+ # Apply effects_to_remove
582
+ # Shuffle effects if specified
583
+ if self.shuffle_removed_effects:
584
+ effect_indices = torch.randperm(len(self.effects_to_remove))
585
+ else:
586
+ effect_indices = torch.arange(len(self.effects_to_remove))
587
+ wet = torch.clone(dry)
588
+ r1 = self.num_removed_effects[0]
589
+ r2 = self.num_removed_effects[1]
590
+ num_removed_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
591
+ effect_indices = effect_indices[:num_removed_effects]
592
+ # Index in effect settings
593
+ effect_names_to_apply = [self.effects_to_remove[i] for i in effect_indices]
594
+ effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
595
+ # Apply
596
+ wet_labels = []
597
+ for effect in effects_to_apply:
598
+ # Normalize in-between effects
599
+ wet = self.normalize(effect(wet))
600
+ wet_labels.append(ALL_EFFECTS.index(type(effect)))
601
+
602
+ wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
603
+ dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
604
+
605
+ for label_idx in wet_labels:
606
+ wet_labels_tensor[label_idx] = 1.0
607
+
608
+ for label_idx in dry_labels:
609
+ dry_labels_tensor[label_idx] = 1.0
610
+
611
+ # Normalize
612
+ normalized_dry = self.normalize(dry)
613
+ normalized_wet = self.normalize(wet)
614
+
615
+ # Check STFT, pick different effects if necessary
616
+ stft = self.mrstft(normalized_wet, normalized_dry)
617
  return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
618
 
619
 
 
698
  def test_dataloader(self) -> DataLoader:
699
  return DataLoader(
700
  dataset=self.test_dataset,
701
+ batch_size=1, # Use small, consistent batch size for testing
702
  num_workers=self.num_workers,
703
  pin_memory=self.pin_memory,
704
  shuffle=False,
remfx/models.py CHANGED
@@ -16,12 +16,22 @@ from remfx.callbacks import log_wandb_audio_batch
16
  from einops import rearrange
17
  from remfx import effects
18
  import asteroid
 
19
 
20
  ALL_EFFECTS = effects.Pedalboard_Effects
21
 
22
 
23
  class RemFXChainInference(pl.LightningModule):
24
- def __init__(self, models, sample_rate, num_bins, effect_order):
 
 
 
 
 
 
 
 
 
25
  super().__init__()
26
  self.model = models
27
  self.mrstftloss = MultiResolutionSTFTLoss(
@@ -36,6 +46,10 @@ class RemFXChainInference(pl.LightningModule):
36
  )
37
  self.sample_rate = sample_rate
38
  self.effect_order = effect_order
 
 
 
 
39
 
40
  def forward(self, batch, batch_idx, order=None):
41
  x, y, _, rem_fx_labels = batch
@@ -44,28 +58,46 @@ class RemFXChainInference(pl.LightningModule):
44
  effects_order = order
45
  else:
46
  effects_order = self.effect_order
47
- effects_present = [
48
- [ALL_EFFECTS[i] for i, effect in enumerate(effect_label) if effect == 1.0]
49
- for effect_label in rem_fx_labels
50
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  output = []
52
- input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
53
- target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)
54
-
55
- log_wandb_audio_batch(
56
- logger=self.logger,
57
- id="input_effected_audio",
58
- samples=input_samples.cpu(),
59
- sampling_rate=self.sample_rate,
60
- caption="Input Data",
61
- )
62
- log_wandb_audio_batch(
63
- logger=self.logger,
64
- id="target_audio",
65
- samples=target_samples.cpu(),
66
- sampling_rate=self.sample_rate,
67
- caption="Target Data",
68
- )
69
  with torch.no_grad():
70
  for i, (elem, effects_list) in enumerate(zip(x, effects_present)):
71
  elem = elem.unsqueeze(0) # Add batch dim
@@ -101,22 +133,22 @@ class RemFXChainInference(pl.LightningModule):
101
  # )
102
  output.append(elem.squeeze(0))
103
  output = torch.stack(output)
104
- output_samples = rearrange(output, "b c t -> c (b t)").unsqueeze(0)
105
-
106
- log_wandb_audio_batch(
107
- logger=self.logger,
108
- id="output_audio",
109
- samples=output_samples.cpu(),
110
- sampling_rate=self.sample_rate,
111
- caption="Output Data",
112
- )
113
  loss = self.mrstftloss(output, y) + self.l1loss(output, y) * 100
114
  return loss, output
115
 
116
  def test_step(self, batch, batch_idx):
117
  x, y, _, _ = batch # x, y = (B, C, T), (B, C, T)
118
- # Random order
119
- # random.shuffle(self.effect_order)
 
120
  loss, output = self.forward(batch, batch_idx, order=self.effect_order)
121
  # Crop target to match output
122
  if output.shape[-1] < y.shape[-1]:
@@ -148,8 +180,16 @@ class RemFXChainInference(pl.LightningModule):
148
  prog_bar=True,
149
  sync_dist=True,
150
  )
 
 
 
 
151
  return loss
152
 
 
 
 
 
153
  def sample(self, batch):
154
  return self.forward(batch, 0)[1]
155
 
@@ -181,6 +221,7 @@ class RemFX(pl.LightningModule):
181
  )
182
  # Log first batch metrics input vs output only once
183
  self.log_train_audio = True
 
184
 
185
  @property
186
  def device(self):
@@ -257,9 +298,16 @@ class RemFX(pl.LightningModule):
257
  prog_bar=True,
258
  sync_dist=True,
259
  )
260
-
 
 
 
261
  return loss
262
 
 
 
 
 
263
 
264
  class OpenUnmixModel(nn.Module):
265
  def __init__(
 
16
  from einops import rearrange
17
  from remfx import effects
18
  import asteroid
19
+ import random
20
 
21
  ALL_EFFECTS = effects.Pedalboard_Effects
22
 
23
 
24
  class RemFXChainInference(pl.LightningModule):
25
+ def __init__(
26
+ self,
27
+ models,
28
+ sample_rate,
29
+ num_bins,
30
+ effect_order,
31
+ classifier=None,
32
+ shuffle_effect_order=False,
33
+ use_all_effect_models=False,
34
+ ):
35
  super().__init__()
36
  self.model = models
37
  self.mrstftloss = MultiResolutionSTFTLoss(
 
46
  )
47
  self.sample_rate = sample_rate
48
  self.effect_order = effect_order
49
+ self.classifier = classifier
50
+ self.shuffle_effect_order = shuffle_effect_order
51
+ self.output_str = "IN_SISDR,OUT_SISDR,IN_STFT,OUT_STFT\n"
52
+ self.use_all_effect_models = use_all_effect_models
53
 
54
  def forward(self, batch, batch_idx, order=None):
55
  x, y, _, rem_fx_labels = batch
 
58
  effects_order = order
59
  else:
60
  effects_order = self.effect_order
61
+
62
+ # Use classifier labels
63
+ if self.classifier:
64
+ threshold = 0.5
65
+ with torch.no_grad():
66
+ labels = torch.sigmoid(self.classifier(x))
67
+ rem_fx_labels = torch.where(labels > threshold, 1.0, 0.0)
68
+ if self.use_all_effect_models:
69
+ effects_present = [
70
+ [ALL_EFFECTS[i] for i, effect in enumerate(effect_label)]
71
+ for effect_label in rem_fx_labels
72
+ ]
73
+ else:
74
+ effects_present = [
75
+ [
76
+ ALL_EFFECTS[i]
77
+ for i, effect in enumerate(effect_label)
78
+ if effect == 1.0
79
+ ]
80
+ for effect_label in rem_fx_labels
81
+ ]
82
+
83
  output = []
84
+ # input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
85
+ # target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)
86
+
87
+ # log_wandb_audio_batch(
88
+ # logger=self.logger,
89
+ # id="input_effected_audio",
90
+ # samples=input_samples.cpu(),
91
+ # sampling_rate=self.sample_rate,
92
+ # caption="Input Data",
93
+ # )
94
+ # log_wandb_audio_batch(
95
+ # logger=self.logger,
96
+ # id="target_audio",
97
+ # samples=target_samples.cpu(),
98
+ # sampling_rate=self.sample_rate,
99
+ # caption="Target Data",
100
+ # )
101
  with torch.no_grad():
102
  for i, (elem, effects_list) in enumerate(zip(x, effects_present)):
103
  elem = elem.unsqueeze(0) # Add batch dim
 
133
  # )
134
  output.append(elem.squeeze(0))
135
  output = torch.stack(output)
136
+
137
+ # log_wandb_audio_batch(
138
+ # logger=self.logger,
139
+ # id="output_audio",
140
+ # samples=output_samples.cpu(),
141
+ # sampling_rate=self.sample_rate,
142
+ # caption="Output Data",
143
+ # )
 
144
  loss = self.mrstftloss(output, y) + self.l1loss(output, y) * 100
145
  return loss, output
146
 
147
  def test_step(self, batch, batch_idx):
148
  x, y, _, _ = batch # x, y = (B, C, T), (B, C, T)
149
+ if self.shuffle_effect_order:
150
+ # Random order
151
+ random.shuffle(self.effect_order)
152
  loss, output = self.forward(batch, batch_idx, order=self.effect_order)
153
  # Crop target to match output
154
  if output.shape[-1] < y.shape[-1]:
 
180
  prog_bar=True,
181
  sync_dist=True,
182
  )
183
+ # print(f"Input_{metric}", negate * self.metrics[metric](x, y))
184
+ # print(f"test_{metric}", negate * self.metrics[metric](output, y))
185
+ self.output_str += f"{negate * self.metrics[metric](x, y).item():.4f},{negate * self.metrics[metric](output, y).item():.4f},"
186
+ self.output_str += "\n"
187
  return loss
188
 
189
+ def on_test_end(self) -> None:
190
+ with open("output.csv", "w") as f:
191
+ f.write(self.output_str)
192
+
193
  def sample(self, batch):
194
  return self.forward(batch, 0)[1]
195
 
 
221
  )
222
  # Log first batch metrics input vs output only once
223
  self.log_train_audio = True
224
+ self.output_str = "IN_SISDR,OUT_SISDR,IN_STFT,OUT_STFT\n"
225
 
226
  @property
227
  def device(self):
 
298
  prog_bar=True,
299
  sync_dist=True,
300
  )
301
+ # print(f"Input_{metric}", negate * self.metrics[metric](x, y))
302
+ # print(f"test_{metric}", negate * self.metrics[metric](output, y))
303
+ self.output_str += f"{negate * self.metrics[metric](x, y).item():.4f},{negate * self.metrics[metric](output, y).item():.4f},"
304
+ self.output_str += "\n"
305
  return loss
306
 
307
+ def on_test_end(self) -> None:
308
+ with open("output.csv", "w") as f:
309
+ f.write(self.output_str)
310
+
311
 
312
  class OpenUnmixModel(nn.Module):
313
  def __init__(
remfx/utils.py CHANGED
@@ -159,7 +159,7 @@ def select_random_chunk(
159
  random_start = torch.randint(0, max_len, (1,)).item()
160
  chunk = audio[:, random_start : random_start + new_chunk_size]
161
  # Skip if energy too low
162
- if torch.mean(torch.abs(chunk)) < 1e-6:
163
  return None
164
  resampled_chunk = torchaudio.functional.resample(chunk, sr, sample_rate)
165
  return resampled_chunk
 
159
  random_start = torch.randint(0, max_len, (1,)).item()
160
  chunk = audio[:, random_start : random_start + new_chunk_size]
161
  # Skip if energy too low
162
+ if torch.mean(torch.abs(chunk)) < 1e-4:
163
  return None
164
  resampled_chunk = torchaudio.functional.resample(chunk, sr, sample_rate)
165
  return resampled_chunk
scripts/chain_inference.py CHANGED
@@ -15,7 +15,7 @@ def main(cfg: DictConfig):
15
  pl.seed_everything(cfg.seed)
16
  log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>.")
17
  datamodule = hydra.utils.instantiate(cfg.datamodule, _convert_="partial")
18
- log.info(f"Instantiating model <{cfg.model._target_}>.")
19
  models = {}
20
  for effect in cfg.ckpts:
21
  model = hydra.utils.instantiate(cfg.ckpts[effect].model, _convert_="partial")
@@ -26,6 +26,16 @@ def main(cfg: DictConfig):
26
  model.to(device)
27
  models[effect] = model
28
 
 
 
 
 
 
 
 
 
 
 
29
  callbacks = []
30
  if "callbacks" in cfg:
31
  for _, cb_conf in cfg["callbacks"].items():
@@ -54,6 +64,9 @@ def main(cfg: DictConfig):
54
  sample_rate=cfg.sample_rate,
55
  num_bins=cfg.num_bins,
56
  effect_order=cfg.inference_effects_ordering,
 
 
 
57
  )
58
  trainer.test(model=inference_model, datamodule=datamodule)
59
 
 
15
  pl.seed_everything(cfg.seed)
16
  log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>.")
17
  datamodule = hydra.utils.instantiate(cfg.datamodule, _convert_="partial")
18
+ log.info("Instantiating Chain Inference Models")
19
  models = {}
20
  for effect in cfg.ckpts:
21
  model = hydra.utils.instantiate(cfg.ckpts[effect].model, _convert_="partial")
 
26
  model.to(device)
27
  models[effect] = model
28
 
29
+ classifier = None
30
+ if "classifier" in cfg:
31
+ log.info(f"Instantiating classifier <{cfg.classifier._target_}>.")
32
+ classifier = hydra.utils.instantiate(cfg.classifier, _convert_="partial")
33
+ ckpt_path = cfg.classifier_ckpt
34
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
+ state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
36
+ classifier.load_state_dict(state_dict)
37
+ classifier.to(device)
38
+
39
  callbacks = []
40
  if "callbacks" in cfg:
41
  for _, cb_conf in cfg["callbacks"].items():
 
64
  sample_rate=cfg.sample_rate,
65
  num_bins=cfg.num_bins,
66
  effect_order=cfg.inference_effects_ordering,
67
+ classifier=classifier,
68
+ shuffle_effect_order=cfg.inference_effects_shuffle,
69
+ use_all_effect_models=cfg.inference_use_all_effect_models,
70
  )
71
  trainer.test(model=inference_model, datamodule=datamodule)
72