import torch
import torch.nn as nn
from torch.nn.functional import pad

from utils import pad_cut_batch_audio

class Encoder(nn.Module):
    """Demucs encoder block: Conv1d + ReLU, then a second Conv1d gated by a GLU."""

    def __init__(self, in_channels, out_channels, cfg):
        super().__init__()
        # Downsampling convolution (kernel size and stride come from the config).
        self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=cfg['conv1']['kernel_size'],
                               stride=cfg['conv1']['stride'])
        self.relu1 = nn.ReLU()
        # Doubles the channels so the GLU can gate them back down to out_channels.
        self.conv2 = nn.Conv1d(in_channels=out_channels, out_channels=2 * out_channels,
                               kernel_size=cfg['conv2']['kernel_size'],
                               stride=cfg['conv2']['stride'])
        self.glu = nn.GLU(dim=-2)  # gate over the channel dim of a (batch, channels, time) tensor

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        # Pad the time axis to an even length before the second convolution.
        if x.shape[-1] % 2 == 1:
            x = pad(x, (0, 1))
        x = self.glu(self.conv2(x))
        return x
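
# Illustrative shape walk-through for Encoder (comments only). The kernel/stride
# values below are assumptions chosen for this sketch, not this Space's actual config:
#
#   cfg = {'conv1': {'kernel_size': 8, 'stride': 4},
#          'conv2': {'kernel_size': 1, 'stride': 1}}
#   enc = Encoder(in_channels=1, out_channels=48, cfg=cfg)
#   y = enc(torch.randn(1, 1, 16000))
#   # conv1 + ReLU: (1, 48, 3999); odd length, so padded to (1, 48, 4000)
#   # conv2: (1, 96, 4000); GLU over channels -> y.shape == (1, 48, 4000)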

class Decoder(nn.Module):
    """Demucs decoder block: Conv1d + GLU, then ConvTranspose1d (ReLU on all but the last block)."""

    def __init__(self, in_channels, out_channels, cfg, is_last=False):
        super().__init__()
        self.is_last = is_last
        # Doubles the channels so the GLU can gate them back down to in_channels.
        self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=2 * in_channels,
                               kernel_size=cfg['conv1']['kernel_size'],
                               stride=cfg['conv1']['stride'])
        self.glu = nn.GLU(dim=-2)
        # Transposed convolution upsamples the time dimension again.
        self.conv2 = nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels,
                                        kernel_size=cfg['conv2']['kernel_size'],
                                        stride=cfg['conv2']['stride'])
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.glu(self.conv1(x))
        x = self.conv2(x)
        # The last decoder produces the output waveform directly, so it skips the ReLU.
        if not self.is_last:
            x = self.relu(x)
        return x
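
# Note: Decoder mirrors Encoder. With PyTorch defaults (padding=0, dilation=1,
# output_padding=0), ConvTranspose1d produces (T_in - 1) * stride + kernel_size
# time steps, so e.g. stride=4, kernel_size=8 maps 4000 steps to 16004; the
# surplus samples are trimmed by pad_cut_batch_audio in Demucs.forward below.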

class Demucs(nn.Module):
    """Waveform-domain U-Net with an LSTM bottleneck (Demucs-style denoiser)."""

    def __init__(self, cfg):
        super().__init__()
        self.L = cfg['L']  # number of encoder/decoder levels
        # The first encoder maps mono audio (1 channel) to H channels; the
        # matching last decoder maps H channels back to a single waveform channel.
        encoders = [Encoder(in_channels=1, out_channels=cfg['H'], cfg=cfg['encoder'])]
        decoders = [Decoder(in_channels=cfg['H'], out_channels=1, cfg=cfg['decoder'], is_last=True)]
        # Each deeper level doubles the channel count.
        for i in range(self.L - 1):
            encoders.append(Encoder(in_channels=(2 ** i) * cfg['H'],
                                    out_channels=(2 ** (i + 1)) * cfg['H'],
                                    cfg=cfg['encoder']))
            decoders.append(Decoder(in_channels=(2 ** (i + 1)) * cfg['H'],
                                    out_channels=(2 ** i) * cfg['H'],
                                    cfg=cfg['decoder']))
        self.encoders = nn.ModuleList(encoders)
        self.decoders = nn.ModuleList(decoders)
        # Recurrent bottleneck operating on the deepest encoder's channels.
        self.lstm = nn.LSTM(
            input_size=(2 ** (self.L - 1)) * cfg['H'],
            hidden_size=(2 ** (self.L - 1)) * cfg['H'],
            num_layers=2, batch_first=True)

    def forward(self, x):
        # Encoder pass: keep every intermediate output for the skip connections.
        outs = [x]
        for i in range(self.L):
            out = self.encoders[i](outs[-1])
            outs.append(out)
        model_input = outs.pop(0)  # outs now holds the L encoder outputs
        # The LSTM bottleneck expects (batch, time, channels).
        x, _ = self.lstm(outs[-1].permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        # Decoder pass: align lengths, add the skip connection, then upsample.
        for i in reversed(range(self.L)):
            decoder = self.decoders[i]
            x = pad_cut_batch_audio(x, outs[i].shape)
            x = decoder(x + outs[i])
        # Match the output length to the original input length.
        x = pad_cut_batch_audio(x, model_input.shape)
        return x

    def predict(self, wav):
        """Denoise a single 1-D waveform tensor; returns a (1, num_samples) tensor."""
        with torch.no_grad():
            wav_reshaped = wav.reshape((1, 1, -1))  # add batch and channel dims
            prediction = self.forward(wav_reshaped)
        return prediction[0]
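
# Minimal smoke test, only run when this file is executed directly. The cfg
# below uses illustrative Demucs-style values (depth 5, 48 hidden channels,
# kernel 8 / stride 4 downsampling); they are assumptions for this sketch, not
# necessarily the configuration this Space was trained with. It also assumes
# utils.pad_cut_batch_audio pads or trims a batch to the requested shape.
if __name__ == "__main__":
    example_cfg = {
        'L': 5,
        'H': 48,
        'encoder': {'conv1': {'kernel_size': 8, 'stride': 4},
                    'conv2': {'kernel_size': 1, 'stride': 1}},
        'decoder': {'conv1': {'kernel_size': 1, 'stride': 1},
                    'conv2': {'kernel_size': 8, 'stride': 4}},
    }
    model = Demucs(example_cfg)
    model.eval()
    noisy = torch.randn(16000)       # one second of fake 16 kHz audio
    denoised = model.predict(noisy)  # -> (1, 16000), same length as the input
    print(noisy.shape, denoised.shape)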