Spaces:
Runtime error
Runtime error
Commit
·
e8eaf47
1
Parent(s):
943f213
Add resume from checkpoint for training
Browse files- remfx/models.py +3 -3
- scripts/train.py +4 -0
remfx/models.py
CHANGED
@@ -43,7 +43,7 @@ class RemFXChainInference(pl.LightningModule):
|
|
43 |
effects_order = order
|
44 |
else:
|
45 |
effects_order = self.effect_order
|
46 |
-
|
47 |
[ALL_EFFECTS[i] for i, effect in enumerate(effect_label) if effect == 1.0]
|
48 |
for effect_label in rem_fx_labels
|
49 |
]
|
@@ -56,7 +56,7 @@ class RemFXChainInference(pl.LightningModule):
|
|
56 |
id="input_effected_audio",
|
57 |
samples=input_samples.cpu(),
|
58 |
sampling_rate=self.sample_rate,
|
59 |
-
caption=
|
60 |
)
|
61 |
log_wandb_audio_batch(
|
62 |
logger=self.logger,
|
@@ -66,7 +66,7 @@ class RemFXChainInference(pl.LightningModule):
|
|
66 |
caption="Target Data",
|
67 |
)
|
68 |
with torch.no_grad():
|
69 |
-
for i, (elem, effects_list) in enumerate(zip(x,
|
70 |
elem = elem.unsqueeze(0) # Add batch dim
|
71 |
# Get the correct effect by search for names in effects_order
|
72 |
effect_list_names = [effect.__name__ for effect in effects_list]
|
|
|
43 |
effects_order = order
|
44 |
else:
|
45 |
effects_order = self.effect_order
|
46 |
+
effects_present = [
|
47 |
[ALL_EFFECTS[i] for i, effect in enumerate(effect_label) if effect == 1.0]
|
48 |
for effect_label in rem_fx_labels
|
49 |
]
|
|
|
56 |
id="input_effected_audio",
|
57 |
samples=input_samples.cpu(),
|
58 |
sampling_rate=self.sample_rate,
|
59 |
+
caption="Input Data",
|
60 |
)
|
61 |
log_wandb_audio_batch(
|
62 |
logger=self.logger,
|
|
|
66 |
caption="Target Data",
|
67 |
)
|
68 |
with torch.no_grad():
|
69 |
+
for i, (elem, effects_list) in enumerate(zip(x, effects_present)):
|
70 |
elem = elem.unsqueeze(0) # Add batch dim
|
71 |
# Get the correct effect by search for names in effects_order
|
72 |
effect_list_names = [effect.__name__ for effect in effects_list]
|
scripts/train.py
CHANGED
@@ -16,6 +16,10 @@ def main(cfg: DictConfig):
|
|
16 |
log.info(f"Instantiating model <{cfg.model._target_}>.")
|
17 |
model = hydra.utils.instantiate(cfg.model, _convert_="partial")
|
18 |
|
|
|
|
|
|
|
|
|
19 |
# Init all callbacks
|
20 |
callbacks = []
|
21 |
if "callbacks" in cfg:
|
|
|
16 |
log.info(f"Instantiating model <{cfg.model._target_}>.")
|
17 |
model = hydra.utils.instantiate(cfg.model, _convert_="partial")
|
18 |
|
19 |
+
if "ckpt_path" in cfg:
|
20 |
+
log.info(f"Loading checkpoint from <{cfg.ckpt_path}>.")
|
21 |
+
model = model.load_from_checkpoint(cfg.ckpt_path)
|
22 |
+
|
23 |
# Init all callbacks
|
24 |
callbacks = []
|
25 |
if "callbacks" in cfg:
|