Spaces:
Sleeping
Sleeping
Commit
·
e8a69d9
1
Parent(s):
61b9249
Log FAD only during test. Use rendered files during test.
Browse files- README.md +6 -5
- cfg/config.yaml +1 -1
- remfx/models.py +3 -0
- scripts/test.py +55 -0
README.md
CHANGED
@@ -35,10 +35,11 @@ Ex. `python scripts/train.py +exp=umx_distortion trainer.accelerator='gpu' train
|
|
35 |
- `reverb`
|
36 |
- `all` (choose random effect to apply to each file)
|
37 |
|
|
|
|
|
|
|
|
|
38 |
## Misc.
|
39 |
By default, files are rendered to `input_dir / processed / train/val/test`.
|
40 |
-
To skip rendering files (use previously rendered), add `render_files=False` to the command-line
|
41 |
-
To change the rendered location, add `render_root={path/to/dir}` to the command-line
|
42 |
-
Test
|
43 |
-
Experiment dictates data, ckpt dictates model
|
44 |
-
`python scripts/test.py +exp=umx_distortion.yaml +ckpt_path=test_ckpts/umx_dist.ckpt`
|
|
|
35 |
- `reverb`
|
36 |
- `all` (choose random effect to apply to each file)
|
37 |
|
38 |
+
### Testing
|
39 |
+
Experiment dictates data, ckpt dictates model
|
40 |
+
`python scripts/test.py +exp=umx_distortion.yaml +ckpt_path=test_ckpts/umx_dist.ckpt`
|
41 |
+
|
42 |
## Misc.
|
43 |
By default, files are rendered to `input_dir / processed / train/val/test`.
|
44 |
+
To skip rendering files (use previously rendered), add `render_files=False` to the command-line (added to test by default).
|
45 |
+
To change the rendered location, add `render_root={path/to/dir}` to the command-line (use this for train and test)
|
|
|
|
|
|
cfg/config.yaml
CHANGED
@@ -19,7 +19,7 @@ callbacks:
|
|
19 |
save_last: True # additionaly always save model from last epoch
|
20 |
mode: "min" # can be "max" or "min"
|
21 |
verbose: False
|
22 |
-
dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
|
23 |
filename: '{epoch:02d}-{valid_loss:.3f}'
|
24 |
|
25 |
datamodule:
|
|
|
19 |
save_last: True # additionaly always save model from last epoch
|
20 |
mode: "min" # can be "max" or "min"
|
21 |
verbose: False
|
22 |
+
dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}-${exp}
|
23 |
filename: '{epoch:02d}-{valid_loss:.3f}'
|
24 |
|
25 |
datamodule:
|
remfx/models.py
CHANGED
@@ -79,6 +79,9 @@ class RemFXModel(pl.LightningModule):
|
|
79 |
negate = -1
|
80 |
else:
|
81 |
negate = 1
|
|
|
|
|
|
|
82 |
self.log(
|
83 |
f"{mode}_{metric}",
|
84 |
negate * self.metrics[metric](output, y),
|
|
|
79 |
negate = -1
|
80 |
else:
|
81 |
negate = 1
|
82 |
+
# Only Log FAD on test set
|
83 |
+
if metric == "FAD" and mode != "test":
|
84 |
+
continue
|
85 |
self.log(
|
86 |
f"{mode}_{metric}",
|
87 |
negate * self.metrics[metric](output, y),
|
scripts/test.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytorch_lightning as pl
|
2 |
+
import hydra
|
3 |
+
from omegaconf import DictConfig
|
4 |
+
import remfx.utils as utils
|
5 |
+
from pytorch_lightning.utilities.model_summary import ModelSummary
|
6 |
+
from remfx.models import RemFXModel
|
7 |
+
import torch
|
8 |
+
|
9 |
+
log = utils.get_logger(__name__)
|
10 |
+
|
11 |
+
|
12 |
+
@hydra.main(version_base=None, config_path="../cfg", config_name="config.yaml")
|
13 |
+
def main(cfg: DictConfig):
|
14 |
+
# Apply seed for reproducibility
|
15 |
+
if cfg.seed:
|
16 |
+
pl.seed_everything(cfg.seed)
|
17 |
+
cfg.render_files = False
|
18 |
+
log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>.")
|
19 |
+
datamodule = hydra.utils.instantiate(cfg.datamodule, _convert_="partial")
|
20 |
+
log.info(f"Instantiating model <{cfg.model._target_}>.")
|
21 |
+
model = hydra.utils.instantiate(cfg.model, _convert_="partial")
|
22 |
+
state_dict = torch.load(cfg.ckpt_path, map_location=torch.device("cpu"))[
|
23 |
+
"state_dict"
|
24 |
+
]
|
25 |
+
model.load_state_dict(state_dict)
|
26 |
+
|
27 |
+
# Init all callbacks
|
28 |
+
callbacks = []
|
29 |
+
if "callbacks" in cfg:
|
30 |
+
for _, cb_conf in cfg["callbacks"].items():
|
31 |
+
if "_target_" in cb_conf:
|
32 |
+
log.info(f"Instantiating callback <{cb_conf._target_}>.")
|
33 |
+
callbacks.append(hydra.utils.instantiate(cb_conf, _convert_="partial"))
|
34 |
+
|
35 |
+
logger = hydra.utils.instantiate(cfg.logger, _convert_="partial")
|
36 |
+
log.info(f"Instantiating trainer <{cfg.trainer._target_}>.")
|
37 |
+
trainer = hydra.utils.instantiate(
|
38 |
+
cfg.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
|
39 |
+
)
|
40 |
+
log.info("Logging hyperparameters!")
|
41 |
+
utils.log_hyperparameters(
|
42 |
+
config=cfg,
|
43 |
+
model=model,
|
44 |
+
datamodule=datamodule,
|
45 |
+
trainer=trainer,
|
46 |
+
callbacks=callbacks,
|
47 |
+
logger=logger,
|
48 |
+
)
|
49 |
+
summary = ModelSummary(model)
|
50 |
+
print(summary)
|
51 |
+
trainer.test(model=model, datamodule=datamodule)
|
52 |
+
|
53 |
+
|
54 |
+
if __name__ == "__main__":
|
55 |
+
main()
|