Spaces:
Sleeping
Sleeping
Commit
·
be8f3b0
1
Parent(s):
eae60a9
WIP: datapoint writing
Browse files- remfx/models.py +9 -2
remfx/models.py
CHANGED
@@ -36,6 +36,7 @@ class RemFXChainInference(pl.LightningModule):
|
|
36 |
self.sample_rate = sample_rate
|
37 |
self.effect_order = effect_order
|
38 |
self.classifier = classifier
|
|
|
39 |
|
40 |
def forward(self, batch, batch_idx, order=None):
|
41 |
x, y, _, rem_fx_labels = batch
|
@@ -48,8 +49,9 @@ class RemFXChainInference(pl.LightningModule):
|
|
48 |
# Use classifier labels
|
49 |
if self.classifier:
|
50 |
threshold = 0.5
|
51 |
-
|
52 |
-
|
|
|
53 |
|
54 |
effects_present = [
|
55 |
[ALL_EFFECTS[i] for i, effect in enumerate(effect_label) if effect == 1.0]
|
@@ -155,6 +157,11 @@ class RemFXChainInference(pl.LightningModule):
|
|
155 |
prog_bar=True,
|
156 |
sync_dist=True,
|
157 |
)
|
|
|
|
|
|
|
|
|
|
|
158 |
return loss
|
159 |
|
160 |
def sample(self, batch):
|
|
|
36 |
self.sample_rate = sample_rate
|
37 |
self.effect_order = effect_order
|
38 |
self.classifier = classifier
|
39 |
+
# self.output_str = "IN_SISDR,OUT_SISDR,IN_STFT,OUT_STFT\n"
|
40 |
|
41 |
def forward(self, batch, batch_idx, order=None):
|
42 |
x, y, _, rem_fx_labels = batch
|
|
|
49 |
# Use classifier labels
|
50 |
if self.classifier:
|
51 |
threshold = 0.5
|
52 |
+
with torch.no_grad():
|
53 |
+
labels = torch.sigmoid(self.classifier(x))
|
54 |
+
rem_fx_labels = torch.where(labels > threshold, 1.0, 0.0)
|
55 |
|
56 |
effects_present = [
|
57 |
[ALL_EFFECTS[i] for i, effect in enumerate(effect_label) if effect == 1.0]
|
|
|
157 |
prog_bar=True,
|
158 |
sync_dist=True,
|
159 |
)
|
160 |
+
# self.output_str += f"{negate * self.metrics[metric](x, y).item():.4f},{negate * self.metrics[metric](output, y).item():.4f},"
|
161 |
+
# self.output_str += "\n"
|
162 |
+
# if batch_idx == 4:
|
163 |
+
# with open("output.csv", "w") as f:
|
164 |
+
# f.write(self.output_str)
|
165 |
return loss
|
166 |
|
167 |
def sample(self, batch):
|