mattricesound committed
Commit f3350b1 · 1 Parent(s): 15b101a
Files changed (4):
  1. cfg/model/tcn.yaml +27 -0
  2. remfx/models.py +31 -0
  3. remfx/tcn.py +145 -0
  4. remfx/utils.py +12 -0
cfg/model/tcn.yaml ADDED
@@ -0,0 +1,27 @@
+ # @package _global_
+ model:
+   _target_: remfx.models.RemFX
+   lr: 1e-4
+   lr_beta1: 0.95
+   lr_beta2: 0.999
+   lr_eps: 1e-6
+   lr_weight_decay: 1e-3
+   sample_rate: ${sample_rate}
+   network:
+     _target_: remfx.models.TCNModel
+     ninputs: 1
+     noutputs: 1
+     nblocks: 4
+     channel_growth: 0
+     channel_width: 32
+     kernel_size: 13
+     stack_size: 10
+     dilation_growth: 10
+     condition: False
+     latent_dim: 2
+     norm_type: "identity"
+     causal: False
+     estimate_loudness: False
+     sample_rate: ${sample_rate}
+     num_bins: 1025
+
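Since the config uses Hydra's `_target_` convention, `hydra.utils.instantiate` can build the model directly from it. A minimal sketch, not part of the commit; the 48 kHz value standing in for `${sample_rate}` is an assumption:

```python
# Hypothetical usage sketch: resolve ${sample_rate} and build the network node.
import hydra.utils
from omegaconf import OmegaConf

base = OmegaConf.create({"sample_rate": 48000})  # assumed experiment-level value
cfg = OmegaConf.merge(base, OmegaConf.load("cfg/model/tcn.yaml"))

# Instantiates remfx.models.TCNModel, which wraps the TCN added in remfx/tcn.py
network = hydra.utils.instantiate(cfg.model.network)
```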
remfx/models.py CHANGED
@@ -12,6 +12,7 @@ from umx.openunmix.model import OpenUnmix, Separator
  from remfx.utils import FADLoss, spectrogram
  from remfx.dptnet import DPTNet_base
  from remfx.dcunet import RefineSpectrogramUnet
+ from remfx.tcn import TCN


  class RemFX(pl.LightningModule):
@@ -240,6 +241,36 @@ class DCUNetModel(nn.Module):
          return output


+ class TCNModel(nn.Module):
+     def __init__(self, sample_rate, num_bins, **kwargs):
+         super().__init__()
+         self.model = TCN(**kwargs)
+         self.mrstftloss = MultiResolutionSTFTLoss(
+             n_bins=num_bins, sample_rate=sample_rate
+         )
+         self.l1loss = nn.L1Loss()
+
+     def forward(self, batch):
+         x, target = batch
+         output = self.model(x)  # B x 1 x T
+         # Pad or crop along the time axis to match the input length
+         if output.shape[-1] > x.shape[-1]:
+             output = output[..., : x.shape[-1]]
+         elif output.shape[-1] < x.shape[-1]:
+             output = F.pad(output, (0, x.shape[-1] - output.shape[-1]))
+         loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
+         return loss, output
+
+     def sample(self, x: Tensor) -> Tensor:
+         output = self.model(x)  # B x 1 x T
+         # Pad or crop along the time axis to match the input length
+         if output.shape[-1] > x.shape[-1]:
+             output = output[..., : x.shape[-1]]
+         elif output.shape[-1] < x.shape[-1]:
+             output = F.pad(output, (0, x.shape[-1] - output.shape[-1]))
+         return output
+
+
  class FXClassifier(pl.LightningModule):
      def __init__(
          self,
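A quick smoke test for the new wrapper, not part of the commit; the shapes and hyperparameters mirror `cfg/model/tcn.yaml`, 48 kHz is an assumed sample rate, and the snippet presumes `MultiResolutionSTFTLoss`, `F`, and `Tensor` are already imported elsewhere in `remfx/models.py`:

```python
# Hypothetical smoke test: forward returns (loss, output) with output padded
# back to the input length; sample() returns the processed audio only.
import torch
from remfx.models import TCNModel

model = TCNModel(
    sample_rate=48000,  # assumed value for ${sample_rate}
    num_bins=1025,
    nblocks=4,
    kernel_size=13,
    dilation_growth=10,
    causal=False,
)
x = torch.randn(2, 1, 48000)       # B x 1 x T
target = torch.randn(2, 1, 48000)
loss, output = model((x, target))
assert output.shape == x.shape
```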
remfx/tcn.py ADDED
@@ -0,0 +1,145 @@
+ # This code is based on the following repository written by Christian J. Steinmetz
+ # https://github.com/csteinmetz1/micro-tcn
+ from typing import Callable
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+
+ from remfx.utils import causal_crop, center_crop
+
+
+ class TCNBlock(nn.Module):
+     def __init__(
+         self,
+         in_ch: int,
+         out_ch: int,
+         kernel_size: int = 3,
+         dilation: int = 1,
+         stride: int = 1,
+         crop_fn: Callable = causal_crop,
+     ) -> None:
+         super().__init__()
+         self.in_ch = in_ch
+         self.out_ch = out_ch
+         self.kernel_size = kernel_size
+         self.stride = stride
+
+         self.crop_fn = crop_fn
+         # "Same" padding for stride 1; currently unused, since the conv runs
+         # unpadded and the residual path is cropped to match instead
+         padding = (kernel_size + (kernel_size - 1) * (dilation - 1) - 1) // 2
+         self.conv1 = nn.Conv1d(
+             in_ch,
+             out_ch,
+             kernel_size,
+             stride=stride,
+             padding=0,
+             dilation=dilation,
+             bias=True,
+         )
+         # residual connection
+         self.res = nn.Conv1d(
+             in_ch,
+             out_ch,
+             kernel_size=1,
+             groups=1,
+             stride=stride,
+             bias=False,
+         )
+         self.relu = nn.PReLU(out_ch)
+
+     def forward(self, x: Tensor) -> Tensor:
+         x_in = x
+         x = self.conv1(x)
+         x = self.relu(x)
+
+         # residual: crop the skip path to the (shorter) conv output length
+         x_res = self.res(x_in)
+         x = x + self.crop_fn(x_res, x.shape[-1])
+
+         return x
+
+
+ class TCN(nn.Module):
+     def __init__(
+         self,
+         ninputs: int = 1,
+         noutputs: int = 1,
+         nblocks: int = 4,
+         channel_growth: int = 0,
+         channel_width: int = 32,
+         kernel_size: int = 13,
+         stack_size: int = 10,
+         dilation_growth: int = 10,
+         condition: bool = False,
+         latent_dim: int = 2,
+         norm_type: str = "identity",
+         causal: bool = False,
+         estimate_loudness: bool = False,
+     ) -> None:
+         super().__init__()
+         self.ninputs = ninputs
+         self.noutputs = noutputs
+         self.nblocks = nblocks
+         self.channel_growth = channel_growth
+         self.channel_width = channel_width
+         self.kernel_size = kernel_size
+         self.stack_size = stack_size
+         self.dilation_growth = dilation_growth
+         self.condition = condition
+         self.latent_dim = latent_dim
+         self.norm_type = norm_type
+         self.causal = causal
+         self.estimate_loudness = estimate_loudness
+
+         print(f"Causal: {self.causal}")
+         if self.causal:
+             self.crop_fn = causal_crop
+         else:
+             self.crop_fn = center_crop
+
+         if estimate_loudness:
+             self.loudness = torch.nn.Linear(latent_dim, 1)
+
+         # audio model
+         self.process_blocks = torch.nn.ModuleList()
+         out_ch = -1
+         for n in range(nblocks):
+             in_ch = out_ch if n > 0 else ninputs
+             out_ch = in_ch * channel_growth if channel_growth > 1 else channel_width
+             dilation = dilation_growth ** (n % stack_size)
+             self.process_blocks.append(
+                 TCNBlock(
+                     in_ch,
+                     out_ch,
+                     kernel_size,
+                     dilation,
+                     stride=1,
+                     crop_fn=self.crop_fn,
+                 )
+             )
+         self.output = nn.Conv1d(out_ch, noutputs, kernel_size=1)
+
+         # model configuration
+         self.receptive_field = self.compute_receptive_field()
+         self.block_size = 2048
+         self.buffer = torch.zeros(2, self.receptive_field + self.block_size - 1)
+
+     def forward(self, x: Tensor) -> Tensor:
+         x_in = x
+         for block in self.process_blocks:
+             x = block(x)
+         # y_hat = torch.tanh(self.output(x))
+         # crop the input to the network's (shorter) output length, then apply
+         # the predicted per-sample gain; use self.crop_fn so the non-causal
+         # (center-crop) mode stays consistent with the blocks above
+         x_in = self.crop_fn(x_in, x.shape[-1])
+         gain_ln = self.output(x)
+         y_hat = torch.tanh(gain_ln * x_in)
+         return y_hat
+
+     def compute_receptive_field(self):
+         """Compute the receptive field in samples."""
+         rf = self.kernel_size
+         for n in range(1, self.nblocks):
+             dilation = self.dilation_growth ** (n % self.stack_size)
+             rf = rf + ((self.kernel_size - 1) * dilation)
+         return rf
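With the defaults above (kernel_size=13, dilation_growth=10, stack_size=10, nblocks=4) the block dilations are 1, 10, 100, and 1000, so the receptive field is 13 + 12·10 + 12·100 + 12·1000 = 13,333 samples, roughly 0.28 s at 48 kHz. A quick check, not part of the commit:

```python
# Verify compute_receptive_field() against the hand computation above.
from remfx.tcn import TCN

tcn = TCN(nblocks=4, kernel_size=13, dilation_growth=10, stack_size=10)
assert tcn.receptive_field == 13333
```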
remfx/utils.py CHANGED
@@ -204,3 +204,15 @@ def concat_complex(a: torch.tensor, b: torch.tensor, dim: int = 1) -> torch.tensor:
      a_real, a_img = a.chunk(2, dim)
      b_real, b_img = b.chunk(2, dim)
      return torch.cat([a_real, b_real, a_img, b_img], dim=dim)
+
+
+ def center_crop(x, length: int):
+     # keep `length` samples from the middle of the last (time) axis
+     start = (x.shape[-1] - length) // 2
+     stop = start + length
+     return x[..., start:stop]
+
+
+ def causal_crop(x, length: int):
+     # keep the `length` most recent samples, stopping one sample before the end
+     stop = x.shape[-1] - 1
+     start = stop - length
+     return x[..., start:stop]
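A short illustration of the two cropping modes, not part of the commit. Note that `causal_crop` keeps the most recent samples but stops one sample before the end of the signal:

```python
import torch
from remfx.utils import center_crop, causal_crop

x = torch.arange(10).reshape(1, 1, 10)
print(center_crop(x, 4))  # tensor([[[3, 4, 5, 6]]]) -- middle of the signal
print(causal_crop(x, 4))  # tensor([[[5, 6, 7, 8]]]) -- end, offset by one
```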