File size: 3,985 Bytes
f3350b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ce90f2
f3350b1
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# This code is based on the following repository written by Christian J. Steinmetz
# https://github.com/csteinmetz1/micro-tcn
from typing import Callable
import torch
import torch.nn as nn
from torch import Tensor

from remfx.utils import causal_crop, center_crop


class TCNBlock(nn.Module):
    """One TCN block: a dilated Conv1d followed by PReLU, summed with a
    1x1-convolution shortcut that is cropped to the main path's length.

    The dilated convolution uses no padding, so its output is shorter than
    its input; ``crop_fn`` trims the shortcut to match before the addition.
    """

    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        kernel_size: int = 3,
        dilation: int = 1,
        stride: int = 1,
        crop_fn: Callable = causal_crop,
    ) -> None:
        super().__init__()
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.stride = stride
        self.crop_fn = crop_fn

        # Main path: dilated convolution, padding=0 (output length shrinks).
        self.conv1 = nn.Conv1d(
            in_ch,
            out_ch,
            kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            bias=True,
        )
        # Shortcut path: 1x1 conv so channel counts match for the sum.
        self.res = nn.Conv1d(
            in_ch,
            out_ch,
            kernel_size=1,
            groups=1,
            stride=stride,
            bias=False,
        )
        self.relu = nn.PReLU(out_ch)

    def forward(self, x: Tensor) -> Tensor:
        shortcut = self.res(x)
        out = self.relu(self.conv1(x))
        # Trim the shortcut to the (shorter) conv output before adding.
        return out + self.crop_fn(shortcut, out.shape[-1])


class TCN(nn.Module):
    """Temporal Convolutional Network for audio-to-audio processing.

    Stacks ``nblocks`` :class:`TCNBlock` layers with exponentially growing
    dilation, then applies a 1x1 output convolution and a tanh squash.

    Args:
        ninputs: Number of input audio channels.
        noutputs: Number of output audio channels.
        nblocks: Number of TCN blocks in the stack.
        channel_growth: If > 1, each block's output channels are its input
            channels times this factor; otherwise every block outputs
            ``channel_width`` channels.
        channel_width: Per-block output channels when ``channel_growth``
            is disabled.
        kernel_size: Convolution kernel size used by every block.
        stack_size: The dilation exponent wraps after this many blocks.
        dilation_growth: Base of the per-block exponential dilation.
        condition: Stored flag; no conditioning is applied in this module.
        latent_dim: Latent size, used only by the optional loudness head.
        norm_type: Stored flag; no normalization is applied in this module.
        causal: If True, blocks use causal cropping; otherwise center crop.
        estimate_loudness: If True, adds a linear loudness head on the latent.
    """

    def __init__(
        self,
        ninputs: int = 1,
        noutputs: int = 1,
        nblocks: int = 4,
        channel_growth: int = 0,
        channel_width: int = 32,
        kernel_size: int = 13,
        stack_size: int = 10,
        dilation_growth: int = 10,
        condition: bool = False,
        latent_dim: int = 2,
        norm_type: str = "identity",
        causal: bool = False,
        estimate_loudness: bool = False,
    ) -> None:
        super().__init__()
        self.ninputs = ninputs
        self.noutputs = noutputs
        self.nblocks = nblocks
        self.channel_growth = channel_growth
        self.channel_width = channel_width
        self.kernel_size = kernel_size
        self.stack_size = stack_size
        self.dilation_growth = dilation_growth
        self.condition = condition
        self.latent_dim = latent_dim
        self.norm_type = norm_type
        self.causal = causal
        self.estimate_loudness = estimate_loudness

        # Causal cropping keeps only past context; center cropping trims
        # symmetrically from both sides.
        if self.causal:
            self.crop_fn = causal_crop
        else:
            self.crop_fn = center_crop

        if estimate_loudness:
            self.loudness = torch.nn.Linear(latent_dim, 1)

        # Audio model: stack of dilated conv blocks with growing dilation.
        self.process_blocks = torch.nn.ModuleList()
        out_ch = -1
        for n in range(nblocks):
            in_ch = out_ch if n > 0 else ninputs
            out_ch = in_ch * channel_growth if channel_growth > 1 else channel_width
            dilation = dilation_growth ** (n % stack_size)
            self.process_blocks.append(
                TCNBlock(
                    in_ch,
                    out_ch,
                    kernel_size,
                    dilation,
                    stride=1,
                    crop_fn=self.crop_fn,
                )
            )
        self.output = nn.Conv1d(out_ch, noutputs, kernel_size=1)

        # Model configuration.
        self.receptive_field = self.compute_receptive_field()
        self.block_size = 2048
        # NOTE(review): a plain-attribute tensor does not follow .to(device)
        # or appear in state_dict; register_buffer would fix that but changes
        # checkpoint layout — left as-is for compatibility.
        self.buffer = torch.zeros(2, self.receptive_field + self.block_size - 1)

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through every block, then the 1x1 output conv and tanh."""
        # Fixed: iterate the ModuleList directly; the index from enumerate
        # was discarded.
        for block in self.process_blocks:
            x = block(x)
        y_hat = torch.tanh(self.output(x))
        return y_hat

    def compute_receptive_field(self) -> int:
        """Compute the receptive field in samples (all blocks use stride 1)."""
        rf = self.kernel_size
        for n in range(1, self.nblocks):
            dilation = self.dilation_growth ** (n % self.stack_size)
            rf = rf + ((self.kernel_size - 1) * dilation)
        return rf