import torch
import torch.nn as nn
from torch.nn.functional import pad

from utils import pad_cut_batch_audio


class Encoder(nn.Module):
    """One Demucs encoder block: Conv1d -> ReLU -> Conv1d -> GLU."""

    def __init__(self, in_channels, out_channels, cfg):
        super().__init__()

        self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=cfg['conv1']['kernel_size'],
                               stride=cfg['conv1']['stride'])
        self.relu1 = nn.ReLU()
        # conv2 doubles the channels so the GLU, which halves them again,
        # leaves the block with `out_channels` outputs.
        self.conv2 = nn.Conv1d(in_channels=out_channels, out_channels=2 * out_channels,
                               kernel_size=cfg['conv2']['kernel_size'],
                               stride=cfg['conv2']['stride'])
        self.glu = nn.GLU(dim=-2)  # gate along the channel axis of (B, C, T)

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        # Pad odd-length feature maps with one zero sample so the time
        # dimension is even before conv2.
        if x.shape[-1] % 2 == 1:
            x = pad(x, (0, 1))
        x = self.glu(self.conv2(x))
        return x


class Decoder(nn.Module):
    """One Demucs decoder block: Conv1d -> GLU -> ConvTranspose1d -> ReLU.

    The last block (is_last=True) skips the final ReLU so the reconstructed
    waveform can take negative values.
    """

    def __init__(self, in_channels, out_channels, cfg, is_last=False):
        super().__init__()
        self.is_last = is_last
        # conv1 doubles the channels for the GLU gate, which halves them back.
        self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=2 * in_channels,
                               kernel_size=cfg['conv1']['kernel_size'],
                               stride=cfg['conv1']['stride'])
        self.glu = nn.GLU(dim=-2)
        # The transposed convolution upsamples the time axis back toward the
        # resolution of the matching encoder level.
        self.conv2 = nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels,
                                        kernel_size=cfg['conv2']['kernel_size'],
                                        stride=cfg['conv2']['stride'])
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.glu(self.conv1(x))
        x = self.conv2(x)
        if not self.is_last:
            x = self.relu(x)
        return x


class Demucs(nn.Module):
    """Demucs-style U-Net over raw waveforms: strided conv encoders, a
    bottleneck LSTM, and transposed-conv decoders joined by skip connections."""

    def __init__(self, cfg):
        super().__init__()
        self.L = cfg['L']  # number of encoder/decoder levels

        # decoders[0] is the output block (H channels -> mono waveform);
        # deeper blocks are appended so that decoders[i] mirrors encoders[i].
        encoders = [Encoder(in_channels=1, out_channels=cfg['H'], cfg=cfg['encoder'])]
        decoders = [Decoder(in_channels=cfg['H'], out_channels=1, cfg=cfg['decoder'], is_last=True)]
        for i in range(self.L - 1):
            # The channel width doubles at every level: H, 2H, 4H, ...
            encoders.append(Encoder(in_channels=(2 ** i) * cfg['H'],
                                    out_channels=(2 ** (i + 1)) * cfg['H'],
                                    cfg=cfg['encoder']))
            decoders.append(Decoder(in_channels=(2 ** (i + 1)) * cfg['H'],
                                    out_channels=(2 ** i) * cfg['H'],
                                    cfg=cfg['decoder']))
        self.encoders = nn.ModuleList(encoders)
        self.decoders = nn.ModuleList(decoders)

        # Two-layer LSTM over the deepest features; the hidden size equals the
        # channel count, so the feature shape is unchanged.
        self.lstm = nn.LSTM(
            input_size=(2 ** (self.L - 1)) * cfg['H'],
            hidden_size=(2 ** (self.L - 1)) * cfg['H'], num_layers=2, batch_first=True)

    def forward(self, x):
        # Run the encoder stack, keeping every level's output for the skips.
        outs = [x]
        for i in range(self.L):
            outs.append(self.encoders[i](outs[-1]))
        model_input = outs.pop(0)  # outs[i] is now the output of encoder i

        # The LSTM expects (batch, time, channels); convs use (batch, channels, time).
        x, _ = self.lstm(outs[-1].permute(0, 2, 1))
        x = x.permute(0, 2, 1)

        # Decode from the deepest level upward, adding the matching skip
        # connection after padding/cutting x to the same length.
        for i in reversed(range(self.L)):
            x = pad_cut_batch_audio(x, outs[i].shape)
            x = self.decoders[i](x + outs[i])
        # Restore the exact input length.
        x = pad_cut_batch_audio(x, model_input.shape)
        return x

    def predict(self, wav):
        # Accepts a 1-D waveform: add batch and channel dimensions, run the
        # model, and return the detached (channels, time) tensor.
        prediction = self.forward(torch.reshape(wav, (1, 1, -1)))
        return prediction.detach()[0]
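

# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch, not part of the original module. The config
# values below are assumptions chosen for illustration (the 8/4 and 1/1
# kernel/stride pairs follow the original Demucs architecture; L and H are
# kept small so this runs quickly on CPU). Like the module itself, it assumes
# `utils.pad_cut_batch_audio` is importable.
if __name__ == "__main__":
    cfg = {
        'L': 3,   # assumed depth for the demo
        'H': 16,  # assumed base channel width for the demo
        'encoder': {'conv1': {'kernel_size': 8, 'stride': 4},
                    'conv2': {'kernel_size': 1, 'stride': 1}},
        'decoder': {'conv1': {'kernel_size': 1, 'stride': 1},
                    'conv2': {'kernel_size': 8, 'stride': 4}},
    }
    model = Demucs(cfg)
    wav = torch.randn(16000)  # one second of dummy 16 kHz audio
    out = model.predict(wav)
    print(out.shape)  # expected: torch.Size([1, 16000])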