Commit 7fc4de1
Parent(s): b8427f9

Update datagen silence threshold to 1e-4

Files changed:
- cfg/exp/chain_inference_aug_classifier.yaml +5 -4
- remfx/datasets.py +2 -3
- remfx/models.py +26 -15
- remfx/utils.py +1 -1
cfg/exp/chain_inference_aug_classifier.yaml
CHANGED
@@ -47,14 +47,15 @@ classifier:
   lr: 3e-4
   lr_weight_decay: 1e-3
   sample_rate: ${sample_rate}
+  mixup: False
   network:
     _target_: remfx.classifier.Cnn14
     num_classes: ${num_classes}
-    n_fft:
-    hop_length:
+    n_fft: 2048
+    hop_length: 512
     n_mels: 128
-    sample_rate:
-    model_sample_rate:
+    sample_rate: ${sample_rate}
+    model_sample_rate: ${sample_rate}
     specaugment: False
 classifier_ckpt: "ckpts/classifier.ckpt"

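For reference, n_fft: 2048 with hop_length: 512 is a common STFT configuration for a 128-band mel front end such as the one a Cnn14-style classifier uses. A minimal sketch of what these values mean in practice, using torchaudio (an illustrative assumption; the actual front end inside remfx.classifier.Cnn14 may be built differently):

import torch
import torchaudio

sample_rate = 48000  # placeholder; the config resolves ${sample_rate} at runtime
mel = torchaudio.transforms.MelSpectrogram(
    sample_rate=sample_rate,
    n_fft=2048,      # 2048-sample analysis window per frame
    hop_length=512,  # frames advance by 512 samples (75% overlap)
    n_mels=128,      # matches n_mels: 128 in the config
)
audio = torch.randn(1, sample_rate)  # one second of dummy audio
spec = mel(audio)                    # shape: (1, 128, sample_rate // 512 + 1)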
remfx/datasets.py
CHANGED
@@ -259,7 +259,7 @@ class EffectDataset(Dataset):
         render_files: bool = True,
         render_root: str = None,
         mode: str = "train",
-        parallel: bool =
+        parallel: bool = False,
     ):
         super().__init__()
         self.chunks = []
@@ -342,7 +342,6 @@ class EffectDataset(Dataset):
             chunk = select_random_chunk(
                 random_file_choice, self.chunk_size, self.sample_rate
             )
-
             # Sum to mono
             if chunk.shape[0] > 1:
                 chunk = chunk.sum(0, keepdim=True)
@@ -561,7 +560,7 @@ class EffectDatamodule(pl.LightningDataModule):
     def test_dataloader(self) -> DataLoader:
         return DataLoader(
             dataset=self.test_dataset,
-            batch_size=
+            batch_size=1,  # Use small, consistent batch size for testing
             num_workers=self.num_workers,
             pin_memory=self.pin_memory,
             shuffle=False,
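Pinning batch_size=1 in test_dataloader decouples evaluation from the configured training batch size, so test metrics are always computed per example. If configurable test batches were ever wanted, one way would be a dedicated attribute rather than reusing the training value; a hypothetical sketch (test_batch_size is not in the repo):

from torch.utils.data import DataLoader

def test_dataloader(self) -> DataLoader:
    # Hypothetical: fall back to the fixed value this commit hard-codes.
    batch_size = getattr(self, "test_batch_size", 1)
    return DataLoader(
        dataset=self.test_dataset,
        batch_size=batch_size,
        num_workers=self.num_workers,
        pin_memory=self.pin_memory,
        shuffle=False,  # keep test order deterministic
    )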
remfx/models.py
CHANGED
@@ -37,7 +37,7 @@ class RemFXChainInference(pl.LightningModule):
         self.sample_rate = sample_rate
         self.effect_order = effect_order
         self.classifier = classifier
-
+        self.output_str = "IN_SISDR,OUT_SISDR,IN_STFT,OUT_STFT\n"

     def forward(self, batch, batch_idx, order=None):
         x, y, _, rem_fx_labels = batch
@@ -46,7 +46,7 @@ class RemFXChainInference(pl.LightningModule):
             effects_order = order
         else:
             effects_order = self.effect_order
-
+        old_labels = rem_fx_labels
         # Use classifier labels
         if self.classifier:
             threshold = 0.5
@@ -113,13 +113,13 @@ class RemFXChainInference(pl.LightningModule):
         output = torch.stack(output)
         output_samples = rearrange(output, "b c t -> c (b t)").unsqueeze(0)

-        log_wandb_audio_batch(
-            logger=self.logger,
-            id="output_audio",
-            samples=output_samples.cpu(),
-            sampling_rate=self.sample_rate,
-            caption="Output Data",
-        )
+        # log_wandb_audio_batch(
+        #     logger=self.logger,
+        #     id="output_audio",
+        #     samples=output_samples.cpu(),
+        #     sampling_rate=self.sample_rate,
+        #     caption="Output Data",
+        # )
         loss = self.mrstftloss(output, y) + self.l1loss(output, y) * 100
         return loss, output

@@ -158,13 +158,16 @@ class RemFXChainInference(pl.LightningModule):
             prog_bar=True,
             sync_dist=True,
         )
-
-
-
-
-        # f.write(self.output_str)
+        print(f"Input_{metric}", negate * self.metrics[metric](x, y))
+        print(f"test_{metric}", negate * self.metrics[metric](output, y))
+        self.output_str += f"{negate * self.metrics[metric](x, y).item():.4f},{negate * self.metrics[metric](output, y).item():.4f},"
+        self.output_str += "\n"
         return loss

+    def on_test_end(self) -> None:
+        with open("output.csv", "w") as f:
+            f.write(self.output_str)
+
     def sample(self, batch):
         return self.forward(batch, 0)[1]

@@ -196,6 +199,7 @@ class RemFX(pl.LightningModule):
         )
         # Log first batch metrics input vs output only once
         self.log_train_audio = True
+        self.output_str = "IN_SISDR,OUT_SISDR,IN_STFT,OUT_STFT\n"

     @property
     def device(self):
@@ -272,9 +276,16 @@ class RemFX(pl.LightningModule):
             prog_bar=True,
             sync_dist=True,
         )
-
+        print(f"Input_{metric}", negate * self.metrics[metric](x, y))
+        print(f"test_{metric}", negate * self.metrics[metric](output, y))
+        self.output_str += f"{negate * self.metrics[metric](x, y).item():.4f},{negate * self.metrics[metric](output, y).item():.4f},"
+        self.output_str += "\n"
         return loss

+    def on_test_end(self) -> None:
+        with open("output.csv", "w") as f:
+            f.write(self.output_str)
+

 class OpenUnmixModel(nn.Module):
     def __init__(
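The output_str / on_test_end pair added to both classes accumulates input-vs-output metric values as CSV text in memory during testing, then writes everything to output.csv in one pass when the test run ends. A stripped-down sketch of the mechanism with the Lightning plumbing removed (the class and method names here are hypothetical):

class MetricCSVAccumulator:
    def __init__(self):
        # Header matches the diff: input/output SI-SDR and STFT loss columns.
        self.output_str = "IN_SISDR,OUT_SISDR,IN_STFT,OUT_STFT\n"

    def add_row(self, in_sisdr, out_sisdr, in_stft, out_stft):
        # Four decimal places per value, as in the diff's f-strings.
        self.output_str += f"{in_sisdr:.4f},{out_sisdr:.4f},"
        self.output_str += f"{in_stft:.4f},{out_stft:.4f}\n"

    def flush(self, path="output.csv"):
        # Mirrors on_test_end: a single write at the end of the test run.
        with open(path, "w") as f:
            f.write(self.output_str)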
remfx/utils.py
CHANGED
@@ -159,7 +159,7 @@ def select_random_chunk(
     random_start = torch.randint(0, max_len, (1,)).item()
     chunk = audio[:, random_start : random_start + new_chunk_size]
     # Skip if energy too low
-    if torch.mean(torch.abs(chunk)) < 1e-
+    if torch.mean(torch.abs(chunk)) < 1e-4:
         return None
     resampled_chunk = torchaudio.functional.resample(chunk, sr, sample_rate)
     return resampled_chunk
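This last hunk is the change the commit message names: a candidate chunk is rejected as silence when its mean absolute amplitude falls below 1e-4. A self-contained sketch of the check, with the surrounding resampling logic elided:

import torch

def is_silent(chunk: torch.Tensor, threshold: float = 1e-4) -> bool:
    # Mean absolute amplitude as a cheap energy proxy, as in select_random_chunk.
    return torch.mean(torch.abs(chunk)).item() < threshold

quiet = torch.randn(1, 48000) * 1e-5  # mean |x| well below the threshold
loud = torch.randn(1, 48000) * 0.1
assert is_silent(quiet) and not is_silent(loud)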