import torch
import torch.nn as nn
from torch.nn.functional import pad

from utils import pad_cut_batch_audio
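# `pad_cut_batch_audio` lives in the Space's local utils module and is not shown here.
# From the way Demucs.forward uses it, it is assumed to zero-pad or trim a batch of
# audio tensors along the last (time) axis until the shape matches a target shape,
# roughly along the lines of this hypothetical sketch (the real helper may differ):
#
#     def pad_cut_batch_audio(x, target_shape):
#         diff = target_shape[-1] - x.shape[-1]
#         return pad(x, (0, diff)) if diff >= 0 else x[..., :target_shape[-1]]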
class Encoder(torch.nn.Module):
    """One downsampling block of the Demucs U-Net: a strided Conv1d + ReLU, then a
    channel-doubling Conv1d gated by a GLU over the channel axis."""

    def __init__(self, in_channels, out_channels, cfg):
        super(Encoder, self).__init__()
        # Strided convolution that downsamples the time axis.
        self.conv1 = torch.nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                                     kernel_size=cfg['conv1']['kernel_size'],
                                     stride=cfg['conv1']['stride'])
        self.relu1 = torch.nn.ReLU()
        # Doubles the channels so the GLU can halve them back to out_channels.
        self.conv2 = torch.nn.Conv1d(in_channels=out_channels, out_channels=2 * out_channels,
                                     kernel_size=cfg['conv2']['kernel_size'],
                                     stride=cfg['conv2']['stride'])
        # GLU over the channel dimension (dim=-2 for (batch, channels, time) tensors).
        self.glu = torch.nn.GLU(dim=-2)

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        # Keep the time length even before the second convolution.
        if x.shape[-1] % 2 == 1:
            x = pad(x, (0, 1))
        x = self.glu(self.conv2(x))
        return x
class Decoder(torch.nn.Module):
    """One upsampling block of the Demucs U-Net: a channel-doubling Conv1d gated by a
    GLU, then a transposed convolution that upsamples the time axis. The final decoder
    (is_last=True) skips the output ReLU so the predicted waveform can be negative."""

    def __init__(self, in_channels, out_channels, cfg, is_last=False):
        super(Decoder, self).__init__()
        self.is_last = is_last
        # Doubles the channels so the GLU can halve them back to in_channels.
        self.conv1 = torch.nn.Conv1d(in_channels=in_channels, out_channels=2 * in_channels,
                                     kernel_size=cfg['conv1']['kernel_size'],
                                     stride=cfg['conv1']['stride'])
        self.glu = torch.nn.GLU(dim=-2)
        # Transposed convolution that upsamples the time axis back toward the input rate.
        self.conv2 = torch.nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels,
                                              kernel_size=cfg['conv2']['kernel_size'],
                                              stride=cfg['conv2']['stride'])
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.glu(self.conv1(x))
        x = self.conv2(x)
        if not self.is_last:
            x = self.relu(x)
        return x
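# Shape sketch for one encoder/decoder pair, assuming purely for illustration the
# kernel_size=8 / stride=4 outer convolutions and kernel_size=1 / stride=1 GLU
# convolutions used by the original Demucs paper (the actual values come from cfg):
#
#     Encoder: (B, C_in, T)      --conv1+ReLU--------> (B, C_out, ~T/4)
#                                --conv2+GLU---------> (B, C_out, ~T/4)
#     Decoder: (B, C_out, ~T/4)  --conv1+GLU---------> (B, C_out, ~T/4)
#                                --conv2 (transposed)-> (B, C_in, ~T)
#
# Each decoder therefore roughly undoes its mirrored encoder; the remaining
# off-by-a-few-samples mismatches are corrected by pad_cut_batch_audio in
# Demucs.forward below.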
class Demucs(torch.nn.Module):
    """Waveform-to-waveform Demucs-style denoiser: L encoder blocks, a two-layer LSTM
    bottleneck, and L decoder blocks with additive skip connections."""

    def __init__(self, cfg):
        super(Demucs, self).__init__()
        self.L = cfg['L']
        # The first level maps the mono waveform (1 channel) to H channels; each deeper
        # level doubles the channel count. Decoders mirror the encoders level by level.
        encoders = [Encoder(in_channels=1, out_channels=cfg['H'], cfg=cfg['encoder'])]
        decoders = [Decoder(in_channels=cfg['H'], out_channels=1, cfg=cfg['decoder'], is_last=True)]
        for i in range(self.L - 1):
            encoders.append(Encoder(in_channels=(2 ** i) * cfg['H'],
                                    out_channels=(2 ** (i + 1)) * cfg['H'],
                                    cfg=cfg['encoder']))
            decoders.append(Decoder(in_channels=(2 ** (i + 1)) * cfg['H'],
                                    out_channels=(2 ** i) * cfg['H'],
                                    cfg=cfg['decoder']))
        self.encoders = nn.ModuleList(encoders)
        self.decoders = nn.ModuleList(decoders)
        # Sequence model over the deepest (most downsampled) representation.
        self.lstm = torch.nn.LSTM(
            input_size=(2 ** (self.L - 1)) * cfg['H'],
            hidden_size=(2 ** (self.L - 1)) * cfg['H'], num_layers=2, batch_first=True)

    def forward(self, x):
        # Run the encoder stack, keeping every intermediate output for the skip connections.
        outs = [x]
        for i in range(self.L):
            out = self.encoders[i](outs[-1])
            outs.append(out)
        model_input = outs.pop(0)
        # The LSTM expects (batch, time, channels); the convolutions use (batch, channels, time).
        x, _ = self.lstm(outs[-1].permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        # Decode from the deepest level up, adding the matching encoder output at each level.
        for i in reversed(range(self.L)):
            decoder = self.decoders[i]
            x = pad_cut_batch_audio(x, outs[i].shape)
            x = decoder(x + outs[i])
        # Match the output length to the original input length.
        x = pad_cut_batch_audio(x, model_input.shape)
        return x

    def predict(self, wav):
        # Accepts a flat waveform tensor, reshapes it to (batch=1, channels=1, samples),
        # and returns the detached denoised waveform of shape (1, samples).
        prediction = self.forward(torch.reshape(wav, (1, 1, -1)))
        return prediction.detach()[0]
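
# Illustrative usage sketch (not part of the original file). The cfg layout below is an
# assumption inferred from the keys read above ('L', 'H', and per-block 'conv1'/'conv2'
# kernel_size and stride); the values mirror the original Demucs paper, but the Space
# may use different ones.
if __name__ == "__main__":
    example_cfg = {
        'L': 3,
        'H': 16,
        'encoder': {'conv1': {'kernel_size': 8, 'stride': 4},
                    'conv2': {'kernel_size': 1, 'stride': 1}},
        'decoder': {'conv1': {'kernel_size': 1, 'stride': 1},
                    'conv2': {'kernel_size': 8, 'stride': 4}},
    }
    model = Demucs(example_cfg)
    wav = torch.randn(16000)        # one second of random "audio" at 16 kHz
    denoised = model.predict(wav)   # -> tensor of shape (1, 16000)
    print(denoised.shape)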