Spaces:
Sleeping
Sleeping
Commit
·
f8fea2a
1
Parent(s):
1b72821
Fix DCUNet
Browse files- remfx/models.py +23 -0
remfx/models.py
CHANGED
@@ -326,6 +326,29 @@ class DPTNetModel(nn.Module):
|
|
326 |
return self.model(x.squeeze(1))
|
327 |
|
328 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
329 |
class TCNModel(nn.Module):
|
330 |
def __init__(self, sample_rate, num_bins, **kwargs):
|
331 |
super().__init__()
|
|
|
326 |
return self.model(x.squeeze(1))
|
327 |
|
328 |
|
329 |
+
class DCUNetModel(nn.Module):
|
330 |
+
def __init__(self, sample_rate, num_bins, **kwargs):
|
331 |
+
super().__init__()
|
332 |
+
self.model = asteroid.models.DCUNet(**kwargs)
|
333 |
+
self.mrstftloss = MultiResolutionSTFTLoss(
|
334 |
+
n_bins=num_bins, sample_rate=sample_rate
|
335 |
+
)
|
336 |
+
self.l1loss = nn.L1Loss()
|
337 |
+
|
338 |
+
def forward(self, batch):
|
339 |
+
x, target = batch
|
340 |
+
output = self.model(x.squeeze(1)) # B x T
|
341 |
+
# Crop target to match output
|
342 |
+
if output.shape[-1] < target.shape[-1]:
|
343 |
+
target = causal_crop(target, output.shape[-1])
|
344 |
+
loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
|
345 |
+
return loss, output
|
346 |
+
|
347 |
+
def sample(self, x: Tensor) -> Tensor:
|
348 |
+
output = self.model(x.squeeze(1)) # B x T
|
349 |
+
return output
|
350 |
+
|
351 |
+
|
352 |
class TCNModel(nn.Module):
|
353 |
def __init__(self, sample_rate, num_bins, **kwargs):
|
354 |
super().__init__()
|