mattricesound commited on
Commit
e8eaf47
·
1 Parent(s): 943f213

Add resume from checkpoint for training

Browse files
Files changed (2) hide show
  1. remfx/models.py +3 -3
  2. scripts/train.py +4 -0
remfx/models.py CHANGED
@@ -43,7 +43,7 @@ class RemFXChainInference(pl.LightningModule):
43
  effects_order = order
44
  else:
45
  effects_order = self.effect_order
46
- effects = [
47
  [ALL_EFFECTS[i] for i, effect in enumerate(effect_label) if effect == 1.0]
48
  for effect_label in rem_fx_labels
49
  ]
@@ -56,7 +56,7 @@ class RemFXChainInference(pl.LightningModule):
56
  id="input_effected_audio",
57
  samples=input_samples.cpu(),
58
  sampling_rate=self.sample_rate,
59
- caption=effects,
60
  )
61
  log_wandb_audio_batch(
62
  logger=self.logger,
@@ -66,7 +66,7 @@ class RemFXChainInference(pl.LightningModule):
66
  caption="Target Data",
67
  )
68
  with torch.no_grad():
69
- for i, (elem, effects_list) in enumerate(zip(x, effects)):
70
  elem = elem.unsqueeze(0) # Add batch dim
71
  # Get the correct effect by search for names in effects_order
72
  effect_list_names = [effect.__name__ for effect in effects_list]
 
43
  effects_order = order
44
  else:
45
  effects_order = self.effect_order
46
+ effects_present = [
47
  [ALL_EFFECTS[i] for i, effect in enumerate(effect_label) if effect == 1.0]
48
  for effect_label in rem_fx_labels
49
  ]
 
56
  id="input_effected_audio",
57
  samples=input_samples.cpu(),
58
  sampling_rate=self.sample_rate,
59
+ caption="Input Data",
60
  )
61
  log_wandb_audio_batch(
62
  logger=self.logger,
 
66
  caption="Target Data",
67
  )
68
  with torch.no_grad():
69
+ for i, (elem, effects_list) in enumerate(zip(x, effects_present)):
70
  elem = elem.unsqueeze(0) # Add batch dim
71
  # Get the correct effect by search for names in effects_order
72
  effect_list_names = [effect.__name__ for effect in effects_list]
scripts/train.py CHANGED
@@ -16,6 +16,10 @@ def main(cfg: DictConfig):
16
  log.info(f"Instantiating model <{cfg.model._target_}>.")
17
  model = hydra.utils.instantiate(cfg.model, _convert_="partial")
18
 
 
 
 
 
19
  # Init all callbacks
20
  callbacks = []
21
  if "callbacks" in cfg:
 
16
  log.info(f"Instantiating model <{cfg.model._target_}>.")
17
  model = hydra.utils.instantiate(cfg.model, _convert_="partial")
18
 
19
+ if "ckpt_path" in cfg:
20
+ log.info(f"Loading checkpoint from <{cfg.ckpt_path}>.")
21
+ model = model.load_from_checkpoint(cfg.ckpt_path)
22
+
23
  # Init all callbacks
24
  callbacks = []
25
  if "callbacks" in cfg: