mattricesound commited on
Commit
be8f3b0
·
1 Parent(s): eae60a9

WIP: datapoint writing

Browse files
Files changed (1) hide show
  1. remfx/models.py +9 -2
remfx/models.py CHANGED
@@ -36,6 +36,7 @@ class RemFXChainInference(pl.LightningModule):
36
  self.sample_rate = sample_rate
37
  self.effect_order = effect_order
38
  self.classifier = classifier
 
39
 
40
  def forward(self, batch, batch_idx, order=None):
41
  x, y, _, rem_fx_labels = batch
@@ -48,8 +49,9 @@ class RemFXChainInference(pl.LightningModule):
48
  # Use classifier labels
49
  if self.classifier:
50
  threshold = 0.5
51
- labels = self.classifier(x)
52
- rem_fx_labels = torch.where(labels > threshold, 1.0, 0.0)
 
53
 
54
  effects_present = [
55
  [ALL_EFFECTS[i] for i, effect in enumerate(effect_label) if effect == 1.0]
@@ -155,6 +157,11 @@ class RemFXChainInference(pl.LightningModule):
155
  prog_bar=True,
156
  sync_dist=True,
157
  )
 
 
 
 
 
158
  return loss
159
 
160
  def sample(self, batch):
 
36
  self.sample_rate = sample_rate
37
  self.effect_order = effect_order
38
  self.classifier = classifier
39
+ # self.output_str = "IN_SISDR,OUT_SISDR,IN_STFT,OUT_STFT\n"
40
 
41
  def forward(self, batch, batch_idx, order=None):
42
  x, y, _, rem_fx_labels = batch
 
49
  # Use classifier labels
50
  if self.classifier:
51
  threshold = 0.5
52
+ with torch.no_grad():
53
+ labels = torch.sigmoid(self.classifier(x))
54
+ rem_fx_labels = torch.where(labels > threshold, 1.0, 0.0)
55
 
56
  effects_present = [
57
  [ALL_EFFECTS[i] for i, effect in enumerate(effect_label) if effect == 1.0]
 
157
  prog_bar=True,
158
  sync_dist=True,
159
  )
160
+ # self.output_str += f"{negate * self.metrics[metric](x, y).item():.4f},{negate * self.metrics[metric](output, y).item():.4f},"
161
+ # self.output_str += "\n"
162
+ # if batch_idx == 4:
163
+ # with open("output.csv", "w") as f:
164
+ # f.write(self.output_str)
165
  return loss
166
 
167
  def sample(self, batch):