Commit d8f7979
Parent(s): 12e94ae
Restore classifier, move shell scripts to scripts
Files changed:
- README.md +12 -12
- remfx/classifier.py +19 -21
- remfx/models.py +1 -46
- scripts/chain_inference.py +2 -0
- download_ckpts.sh → scripts/download_ckpts.sh +0 -0
- download_eval_datasets.sh → scripts/download_eval_datasets.sh +0 -0
- eval.sh → scripts/eval.sh +3 -3
- remfx_detect.sh → scripts/remfx_detect.sh +1 -1
README.md
CHANGED
@@ -16,12 +16,12 @@ This repo can be used for many different tasks. Here are some examples.
 ## Run RemFX Detect on a single file
 First, need to download the checkpoints from [zenodo](https://zenodo.org/record/8179396)
 ```
-
-
+scripts/download_checkpoints.sh
+scripts/remfx_detect.sh wet.wav -o dry.wav
 ```
 ## Download the [General Purpose Audio Effect Removal evaluation datasets](https://zenodo.org/record/8187288)
 ```
-
+scripts/download_eval_datasets.sh
 ```

 ## Download the starter datasets
@@ -73,28 +73,28 @@ Also note that the training assumes you have a GPU. To train on CPU, set `accele
 First download the General Purpose Audio Effect Removal evaluation datasets (see above).
 To use the pretrained RemFX model, download the checkpoints
 ```
-
+scripts/download_checkpoints.sh
 ```
 Then run the evaluation script, select the RemFX configuration, between `remfx_oracle`, `remfx_detect`, and `remfx_all`. Then select N, the number of effects to remove.
 ```
-
-
-
-
-
-
+scripts/eval.sh remfx_detect 0-0
+scripts/eval.sh remfx_detect 1-1
+scripts/eval.sh remfx_detect 2-2
+scripts/eval.sh remfx_detect 3-3
+scripts/eval.sh remfx_detect 4-4
+scripts/eval.sh remfx_detect 5-5

 ```
 To eval a custom monolithic model, first train a model (see Training)
 Then run the evaluation script, with the config used and checkpoint_path.
 ```
-
+scripts/eval.sh distortion_aug 0-0 -ckpt "logs/ckpts/2023-07-26-10-10-27/epoch\=05-valid_loss\=8.623.ckpt"
 ```

 To eval a custom effect-specific model as part of the inference chain, first train a model (see Training), then edit `cfg/exp/remfx_{desired_configuration}.yaml -> ckpts -> {effect}`.
 Then run the evaluation script.
 ```
-
+scripts/eval.sh remfx_detect 0-0
 ```

 The script assumes that RemFX_eval_datasets is in the top-level directory.
remfx/classifier.py
CHANGED
@@ -1,11 +1,9 @@
 import torch
 import torchaudio
 import torch.nn as nn
-
-
-
-# import hearbaseline.vggish
-# import hearbaseline.wav2vec2
+import hearbaseline
+import hearbaseline.vggish
+import hearbaseline.wav2vec2

 import wav2clip_hear
 import panns_hear
@@ -173,10 +171,10 @@ class Cnn14(nn.Module):

         self.fc1 = nn.Linear(2048, 2048, bias=True)

-        self.fc_audioset = nn.Linear(2048, num_classes, bias=True)
-
-
-
+        # self.fc_audioset = nn.Linear(2048, num_classes, bias=True)
+        self.heads = torch.nn.ModuleList()
+        for _ in range(num_classes):
+            self.heads.append(nn.Linear(2048, 1, bias=True))

         self.init_weight()

@@ -192,7 +190,7 @@ class Cnn14(nn.Module):
     def init_weight(self):
         init_bn(self.bn0)
         init_layer(self.fc1)
-        init_layer(self.fc_audioset)
+        # init_layer(self.fc_audioset)

     def forward(self, x: torch.Tensor, train: bool = False):
         """
@@ -212,12 +210,12 @@ class Cnn14(nn.Module):
         # axs[1].imshow(x[0, :, :, :].detach().squeeze().cpu().numpy())
         # plt.savefig("spec_augment.png", dpi=300)

-        x = x.permute(0, 2, 1, 3)
-        x = self.bn0(x)
-        x = x.permute(0, 2, 1, 3)
+        # x = x.permute(0, 2, 1, 3)
+        # x = self.bn0(x)
+        # x = x.permute(0, 2, 1, 3)

         # apply standardization
-
+        x = (x - x.mean(dim=0, keepdim=True)) / x.std(dim=0, keepdim=True)

         x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
         x = F.dropout(x, p=0.2, training=train)
@@ -239,13 +237,13 @@ class Cnn14(nn.Module):
         x = F.dropout(x, p=0.5, training=train)
         x = F.relu_(self.fc1(x))

-
-
-
+        outputs = []
+        for head in self.heads:
+            outputs.append(torch.sigmoid(head(x)))
+
+        # clipwise_output = self.fc_audioset(x)

-
-        return clipwise_output
-        # return outputs
+        return outputs


 class ConvBlock(nn.Module):
@@ -296,4 +294,4 @@ class ConvBlock(nn.Module):
         else:
             raise Exception("Incorrect argument!")

-        return x
+        return x
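Note on the restored head: instead of a single `fc_audioset` projection over all classes, the classifier keeps one `nn.Linear(2048, 1)` per effect and applies a sigmoid to each, so every effect gets an independent presence probability (multi-label rather than single-label output). A minimal standalone sketch of that pattern, with assumed names (`MultiLabelHead`, `feat_dim`, `n_effects`) that are not part of the repo:

```python
import torch
import torch.nn as nn


class MultiLabelHead(nn.Module):
    """One sigmoid head per effect, mirroring the restored Cnn14 output stage."""

    def __init__(self, feat_dim: int = 2048, n_effects: int = 5):
        super().__init__()
        self.heads = nn.ModuleList(nn.Linear(feat_dim, 1) for _ in range(n_effects))

    def forward(self, x: torch.Tensor):
        # Each head scores one effect independently; sigmoid keeps outputs in [0, 1].
        return [torch.sigmoid(head(x)) for head in self.heads]


# Example: a batch of 4 pooled embeddings -> a list of 5 tensors, each of shape (4, 1).
probs = MultiLabelHead()(torch.randn(4, 2048))
```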
remfx/models.py
CHANGED
@@ -143,17 +143,8 @@ class RemFXChainInference(pl.LightningModule):
             prog_bar=True,
             sync_dist=True,
         )
-        # print(f"Input_{metric}", negate * self.metrics[metric](x, y))
-        # print(f"test_{metric}", negate * self.metrics[metric](output, y))
-        # self.output_str += f"{negate * self.metrics[metric](x, y).item():.4f},{negate * self.metrics[metric](output, y).item():.4f},"
-        # self.output_str += "\n"
         return loss

-    def on_test_end(self) -> None:
-        pass
-        # with open("output.csv", "w") as f:
-        #     f.write(self.output_str)
-
     def sample(self, batch):
         return self.forward(batch, 0)[1]

@@ -438,7 +429,6 @@ def mixup(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):

     return mixed_x, mixed_y, lam

-
 class FXClassifier(pl.LightningModule):
     def __init__(
         self,
@@ -458,42 +448,7 @@ class FXClassifier(pl.LightningModule):
         self.mixup = mixup
         self.label_smoothing = label_smoothing

-        self.loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
         self.loss_fn = torch.nn.BCELoss()
-
-        if False:
-            self.train_f1 = torchmetrics.classification.MultilabelF1Score(
-                5, average="none", multidim_average="global"
-            )
-            self.val_f1 = torchmetrics.classification.MultilabelF1Score(
-                5, average="none", multidim_average="global"
-            )
-            self.test_f1 = torchmetrics.classification.MultilabelF1Score(
-                5, average="none", multidim_average="global"
-            )
-
-            self.train_f1_avg = torchmetrics.classification.MultilabelF1Score(
-                5, threshold=0.5, average="macro", multidim_average="global"
-            )
-            self.val_f1_avg = torchmetrics.classification.MultilabelF1Score(
-                5, threshold=0.5, average="macro", multidim_average="global"
-            )
-            self.test_f1_avg = torchmetrics.classification.MultilabelF1Score(
-                5, threshold=0.5, average="macro", multidim_average="global"
-            )
-
-            self.metrics = {
-                "train": self.train_acc,
-                "valid": self.val_acc,
-                "test": self.test_acc,
-            }
-
-            self.avg_metrics = {
-                "train": self.train_f1_avg,
-                "valid": self.val_f1_avg,
-                "test": self.test_f1_avg,
-            }
-
         self.metrics = torch.nn.ModuleDict()
         for effect in self.effects:
             self.metrics[f"train_{effect}_acc"] = torchmetrics.classification.Accuracy(
@@ -578,4 +533,4 @@
             lr=self.lr,
             weight_decay=self.lr_weight_decay,
         )
-        return optimizer
+        return optimizer
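For reference, the classifier retained here trains with `BCELoss` against the per-head sigmoid outputs and tracks one accuracy metric per effect in a `torch.nn.ModuleDict` (keys like `train_{effect}_acc`). A rough sketch of that metric bookkeeping under assumed effect names; the real `FXClassifier` builds more splits and metrics than shown, and the exact `Accuracy` arguments are an assumption:

```python
import torch
import torchmetrics

effects = ["distortion", "compressor", "reverb", "chorus", "delay"]  # assumed labels
loss_fn = torch.nn.BCELoss()

metrics = torch.nn.ModuleDict()
for effect in effects:
    # One binary accuracy per effect, updated from that effect's sigmoid output.
    metrics[f"train_{effect}_acc"] = torchmetrics.classification.Accuracy(task="binary")

preds = torch.rand(8)                        # sigmoid probabilities for one effect
targets = torch.randint(0, 2, (8,)).float()  # 0/1 presence labels
loss = loss_fn(preds, targets)
acc = metrics["train_distortion_acc"](preds, targets.int())
```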
scripts/chain_inference.py
CHANGED
@@ -45,6 +45,7 @@ def main(cfg: DictConfig):

     logger = hydra.utils.instantiate(cfg.logger, _convert_="partial")
     log.info(f"Instantiating trainer <{cfg.trainer._target_}>.")
+    cfg.trainer.accelerator = "gpu" if torch.cuda.is_available() else "cpu"
     trainer = hydra.utils.instantiate(
         cfg.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
     )
@@ -68,6 +69,7 @@ def main(cfg: DictConfig):
         shuffle_effect_order=cfg.inference_effects_shuffle,
         use_all_effect_models=cfg.inference_use_all_effect_models,
     )
+
     trainer.test(model=inference_model, datamodule=datamodule)


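The one functional addition here makes the trainer's accelerator follow the host: query `torch.cuda.is_available()` and override the Hydra config before instantiating the trainer. A small self-contained illustration of the same override pattern using a hypothetical config dict (not the repo's actual `cfg`):

```python
import torch
from omegaconf import OmegaConf

# Hypothetical trainer config; in the repo this comes from Hydra as cfg.trainer.
cfg = OmegaConf.create({"trainer": {"accelerator": "gpu", "devices": 1}})

# Fall back to CPU automatically when no CUDA device is present.
cfg.trainer.accelerator = "gpu" if torch.cuda.is_available() else "cpu"
print(OmegaConf.to_yaml(cfg.trainer))
```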
download_ckpts.sh → scripts/download_ckpts.sh
RENAMED
File without changes
download_eval_datasets.sh → scripts/download_eval_datasets.sh
RENAMED
File without changes
eval.sh → scripts/eval.sh
RENAMED
@@ -1,13 +1,13 @@
 #! /bin/bash

 # Example usage:
-#
-#
+# scripts/eval.sh remfx_detect 0-0
+# scripts/eval.sh distortion_aug 0-0 -ckpt logs/ckpts/2023-01-21-12-21-44
 # First 2 arguments are required, third argument is optional

 # Default value for the optional parameter
 ckpt_path=""
-
+export DATASET_ROOT=RemFX_eval_datasets
 # Function to display script usage
 function display_usage {
     echo "Usage: $0 <experiment> <dataset> [-ckpt {ckpt_path}]"
remfx_detect.sh → scripts/remfx_detect.sh
RENAMED
@@ -1,7 +1,7 @@
 #! /bin/bash

 # Example usage:
-#
+# scripts/remfx_detect.sh wet.wav -o examples/output.wav
 # first argument is required, second argument is optional

 # Check if first argument is empty