Spaces:
Runtime error
Runtime error
Commit
·
7902946
1
Parent(s):
133e1dc
Fix remfx all effects selection
Browse files- remfx/models.py +6 -5
remfx/models.py
CHANGED
@@ -67,7 +67,7 @@ class RemFXChainInference(pl.LightningModule):
|
|
67 |
rem_fx_labels = torch.where(labels > threshold, 1.0, 0.0)
|
68 |
if self.use_all_effect_models:
|
69 |
effects_present = [
|
70 |
-
[ALL_EFFECTS[i] for i, effect in enumerate(effect_label)
|
71 |
for effect_label in rem_fx_labels
|
72 |
]
|
73 |
else:
|
@@ -79,6 +79,7 @@ class RemFXChainInference(pl.LightningModule):
|
|
79 |
]
|
80 |
for effect_label in rem_fx_labels
|
81 |
]
|
|
|
82 |
output = []
|
83 |
# input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
|
84 |
# target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)
|
@@ -179,8 +180,8 @@ class RemFXChainInference(pl.LightningModule):
|
|
179 |
prog_bar=True,
|
180 |
sync_dist=True,
|
181 |
)
|
182 |
-
print(f"Input_{metric}", negate * self.metrics[metric](x, y))
|
183 |
-
print(f"test_{metric}", negate * self.metrics[metric](output, y))
|
184 |
self.output_str += f"{negate * self.metrics[metric](x, y).item():.4f},{negate * self.metrics[metric](output, y).item():.4f},"
|
185 |
self.output_str += "\n"
|
186 |
return loss
|
@@ -297,8 +298,8 @@ class RemFX(pl.LightningModule):
|
|
297 |
prog_bar=True,
|
298 |
sync_dist=True,
|
299 |
)
|
300 |
-
print(f"Input_{metric}", negate * self.metrics[metric](x, y))
|
301 |
-
print(f"test_{metric}", negate * self.metrics[metric](output, y))
|
302 |
self.output_str += f"{negate * self.metrics[metric](x, y).item():.4f},{negate * self.metrics[metric](output, y).item():.4f},"
|
303 |
self.output_str += "\n"
|
304 |
return loss
|
|
|
67 |
rem_fx_labels = torch.where(labels > threshold, 1.0, 0.0)
|
68 |
if self.use_all_effect_models:
|
69 |
effects_present = [
|
70 |
+
[ALL_EFFECTS[i] for i, effect in enumerate(effect_label)]
|
71 |
for effect_label in rem_fx_labels
|
72 |
]
|
73 |
else:
|
|
|
79 |
]
|
80 |
for effect_label in rem_fx_labels
|
81 |
]
|
82 |
+
|
83 |
output = []
|
84 |
# input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
|
85 |
# target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)
|
|
|
180 |
prog_bar=True,
|
181 |
sync_dist=True,
|
182 |
)
|
183 |
+
# print(f"Input_{metric}", negate * self.metrics[metric](x, y))
|
184 |
+
# print(f"test_{metric}", negate * self.metrics[metric](output, y))
|
185 |
self.output_str += f"{negate * self.metrics[metric](x, y).item():.4f},{negate * self.metrics[metric](output, y).item():.4f},"
|
186 |
self.output_str += "\n"
|
187 |
return loss
|
|
|
298 |
prog_bar=True,
|
299 |
sync_dist=True,
|
300 |
)
|
301 |
+
# print(f"Input_{metric}", negate * self.metrics[metric](x, y))
|
302 |
+
# print(f"test_{metric}", negate * self.metrics[metric](output, y))
|
303 |
self.output_str += f"{negate * self.metrics[metric](x, y).item():.4f},{negate * self.metrics[metric](output, y).item():.4f},"
|
304 |
self.output_str += "\n"
|
305 |
return loss
|