Spaces:
Sleeping
Sleeping
# This code is based on the following repository written by Christian J. Steinmetz | |
# https://github.com/csteinmetz1/micro-tcn | |
from typing import Callable | |
import torch | |
import torch.nn as nn | |
from torch import Tensor | |
from remfx.utils import causal_crop, center_crop | |
class TCNBlock(nn.Module): | |
def __init__( | |
self, | |
in_ch: int, | |
out_ch: int, | |
kernel_size: int = 3, | |
dilation: int = 1, | |
stride: int = 1, | |
crop_fn: Callable = causal_crop, | |
) -> None: | |
super().__init__() | |
self.in_ch = in_ch | |
self.out_ch = out_ch | |
self.kernel_size = kernel_size | |
self.stride = stride | |
self.crop_fn = crop_fn | |
self.conv1 = nn.Conv1d( | |
in_ch, | |
out_ch, | |
kernel_size, | |
stride=stride, | |
padding=0, | |
dilation=dilation, | |
bias=True, | |
) | |
# residual connection | |
self.res = nn.Conv1d( | |
in_ch, | |
out_ch, | |
kernel_size=1, | |
groups=1, | |
stride=stride, | |
bias=False, | |
) | |
self.relu = nn.PReLU(out_ch) | |
def forward(self, x: Tensor) -> Tensor: | |
x_in = x | |
x = self.conv1(x) | |
x = self.relu(x) | |
# residual | |
x_res = self.res(x_in) | |
# causal crop | |
x = x + self.crop_fn(x_res, x.shape[-1]) | |
return x | |
class TCN(nn.Module): | |
def __init__( | |
self, | |
ninputs: int = 1, | |
noutputs: int = 1, | |
nblocks: int = 4, | |
channel_growth: int = 0, | |
channel_width: int = 32, | |
kernel_size: int = 13, | |
stack_size: int = 10, | |
dilation_growth: int = 10, | |
condition: bool = False, | |
latent_dim: int = 2, | |
norm_type: str = "identity", | |
causal: bool = False, | |
estimate_loudness: bool = False, | |
) -> None: | |
super().__init__() | |
self.ninputs = ninputs | |
self.noutputs = noutputs | |
self.nblocks = nblocks | |
self.channel_growth = channel_growth | |
self.channel_width = channel_width | |
self.kernel_size = kernel_size | |
self.stack_size = stack_size | |
self.dilation_growth = dilation_growth | |
self.condition = condition | |
self.latent_dim = latent_dim | |
self.norm_type = norm_type | |
self.causal = causal | |
self.estimate_loudness = estimate_loudness | |
print(f"Causal: {self.causal}") | |
if self.causal: | |
self.crop_fn = causal_crop | |
else: | |
self.crop_fn = center_crop | |
if estimate_loudness: | |
self.loudness = torch.nn.Linear(latent_dim, 1) | |
# audio model | |
self.process_blocks = torch.nn.ModuleList() | |
out_ch = -1 | |
for n in range(nblocks): | |
in_ch = out_ch if n > 0 else ninputs | |
out_ch = in_ch * channel_growth if channel_growth > 1 else channel_width | |
dilation = dilation_growth ** (n % stack_size) | |
self.process_blocks.append( | |
TCNBlock( | |
in_ch, | |
out_ch, | |
kernel_size, | |
dilation, | |
stride=1, | |
crop_fn=self.crop_fn, | |
) | |
) | |
self.output = nn.Conv1d(out_ch, noutputs, kernel_size=1) | |
# model configuration | |
self.receptive_field = self.compute_receptive_field() | |
self.block_size = 2048 | |
self.buffer = torch.zeros(2, self.receptive_field + self.block_size - 1) | |
def forward(self, x: Tensor) -> Tensor: | |
x_in = x | |
for _, block in enumerate(self.process_blocks): | |
x = block(x) | |
y_hat = torch.tanh(self.output(x)) | |
return y_hat | |
def compute_receptive_field(self): | |
"""Compute the receptive field in samples.""" | |
rf = self.kernel_size | |
for n in range(1, self.nblocks): | |
dilation = self.dilation_growth ** (n % self.stack_size) | |
rf = rf + ((self.kernel_size - 1) * dilation) | |
return rf | |