mattricesound commited on
Commit
943f213
·
1 Parent(s): d92486b

Fix loss issue in chain_inference

Browse files
Files changed (1) hide show
  1. remfx/models.py +58 -22
remfx/models.py CHANGED
@@ -12,6 +12,7 @@ from remfx.utils import FADLoss, spectrogram
12
  from remfx.tcn import TCN
13
  from remfx.utils import causal_crop
14
  from remfx.callbacks import log_wandb_audio_batch
 
15
  from remfx import effects
16
  import asteroid
17
 
@@ -47,6 +48,23 @@ class RemFXChainInference(pl.LightningModule):
47
  for effect_label in rem_fx_labels
48
  ]
49
  output = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  with torch.no_grad():
51
  for i, (elem, effects_list) in enumerate(zip(x, effects)):
52
  elem = elem.unsqueeze(0) # Add batch dim
@@ -56,33 +74,41 @@ class RemFXChainInference(pl.LightningModule):
56
  effect for effect in effects_order if effect in effect_list_names
57
  ]
58
 
59
- log_wandb_audio_batch(
60
- logger=self.logger,
61
- id=f"{i}_Before",
62
- samples=elem.cpu(),
63
- sampling_rate=self.sample_rate,
64
- caption=effects,
65
- )
66
  for effect in effects:
67
  # Sample the model
68
  elem = self.model[effect].model.sample(elem)
69
- log_wandb_audio_batch(
70
- logger=self.logger,
71
- id=f"{i}_{effect}",
72
- samples=elem.cpu(),
73
- sampling_rate=self.sample_rate,
74
- caption=effects,
75
- )
76
- log_wandb_audio_batch(
77
- logger=self.logger,
78
- id=f"{i}_After",
79
- samples=elem.cpu(),
80
- sampling_rate=self.sample_rate,
81
- caption=effects,
82
- )
83
  output.append(elem.squeeze(0))
84
  output = torch.stack(output)
85
-
 
 
 
 
 
 
 
 
86
  loss = self.mrstftloss(output, y) + self.l1loss(output, y) * 100
87
  return loss, output
88
 
@@ -112,6 +138,16 @@ class RemFXChainInference(pl.LightningModule):
112
  prog_bar=True,
113
  sync_dist=True,
114
  )
 
 
 
 
 
 
 
 
 
 
115
 
116
  def sample(self, batch):
117
  return self.forward(batch, 0)[1]
 
12
  from remfx.tcn import TCN
13
  from remfx.utils import causal_crop
14
  from remfx.callbacks import log_wandb_audio_batch
15
+ from einops import rearrange
16
  from remfx import effects
17
  import asteroid
18
 
 
48
  for effect_label in rem_fx_labels
49
  ]
50
  output = []
51
+ input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
52
+ target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)
53
+
54
+ log_wandb_audio_batch(
55
+ logger=self.logger,
56
+ id="input_effected_audio",
57
+ samples=input_samples.cpu(),
58
+ sampling_rate=self.sample_rate,
59
+ caption=effects,
60
+ )
61
+ log_wandb_audio_batch(
62
+ logger=self.logger,
63
+ id="target_audio",
64
+ samples=target_samples.cpu(),
65
+ sampling_rate=self.sample_rate,
66
+ caption="Target Data",
67
+ )
68
  with torch.no_grad():
69
  for i, (elem, effects_list) in enumerate(zip(x, effects)):
70
  elem = elem.unsqueeze(0) # Add batch dim
 
74
  effect for effect in effects_order if effect in effect_list_names
75
  ]
76
 
77
+ # log_wandb_audio_batch(
78
+ # logger=self.logger,
79
+ # id=f"{i}_Before",
80
+ # samples=elem.cpu(),
81
+ # sampling_rate=self.sample_rate,
82
+ # caption=effects,
83
+ # )
84
  for effect in effects:
85
  # Sample the model
86
  elem = self.model[effect].model.sample(elem)
87
+ # log_wandb_audio_batch(
88
+ # logger=self.logger,
89
+ # id=f"{i}_{effect}",
90
+ # samples=elem.cpu(),
91
+ # sampling_rate=self.sample_rate,
92
+ # caption=effects,
93
+ # )
94
+ # log_wandb_audio_batch(
95
+ # logger=self.logger,
96
+ # id=f"{i}_After",
97
+ # samples=elem.cpu(),
98
+ # sampling_rate=self.sample_rate,
99
+ # caption=effects,
100
+ # )
101
  output.append(elem.squeeze(0))
102
  output = torch.stack(output)
103
+ output_samples = rearrange(output, "b c t -> c (b t)").unsqueeze(0)
104
+
105
+ log_wandb_audio_batch(
106
+ logger=self.logger,
107
+ id="output_audio",
108
+ samples=output_samples.cpu(),
109
+ sampling_rate=self.sample_rate,
110
+ caption="Output Data",
111
+ )
112
  loss = self.mrstftloss(output, y) + self.l1loss(output, y) * 100
113
  return loss, output
114
 
 
138
  prog_bar=True,
139
  sync_dist=True,
140
  )
141
+ self.log(
142
+ f"Input_{metric}",
143
+ negate * self.metrics[metric](x, y),
144
+ on_step=False,
145
+ on_epoch=True,
146
+ logger=True,
147
+ prog_bar=True,
148
+ sync_dist=True,
149
+ )
150
+ return loss
151
 
152
  def sample(self, batch):
153
  return self.forward(batch, 0)[1]