import torch.nn as nn
import torch
import torch.nn.init as init

# Swish nonlinearity: x * sigmoid(x) (equivalent to nn.SiLU in recent PyTorch releases)
class Swish(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)

# Residual block: two dilated k=3 convolutions plus a 1x1 projection back to n_in,
# with optional dropout and normalization
class ResConv1DBlock(nn.Module):
    def __init__(self, n_in, n_state, dilation=1, activation='silu', norm=None, dropout=None):
        super().__init__()

        self.dropout = dropout
        self.norm = norm
        
        # Select normalization layer (LayerNorm is applied channel-last in forward())
        def get_norm(n):
            if norm == "LN":
                return nn.LayerNorm(n)
            elif norm == "GN":
                return nn.GroupNorm(32, n)  # n must be divisible by 32
            elif norm == "BN":
                return nn.BatchNorm1d(n)
            else:
                return nn.Identity()

        self.norm1 = get_norm(n_in)
        self.norm2 = get_norm(n_state)
        self.norm3 = get_norm(n_state)  # conv2 output has n_state channels, not n_in

        # Select activation
        def get_activation(a):
            if a == "relu":
                return nn.ReLU()
            elif a == "silu":
                return Swish()
            elif a == "gelu":
                return nn.GELU()
            elif a == "leaky_relu":
                return nn.LeakyReLU(0.01)
            else:
                raise ValueError("Unsupported activation type")

        self.activation1 = get_activation(activation)
        self.activation2 = get_activation(activation)
        self.activation3 = get_activation(activation)

        # Convolution layers; padding=dilation keeps the temporal length unchanged for kernel size 3
        self.conv1 = nn.Conv1d(n_in, n_state, 3, padding=dilation, dilation=dilation)
        self.conv2 = nn.Conv1d(n_state, n_state, 3, padding=dilation, dilation=dilation)
        self.conv3 = nn.Conv1d(n_state, n_in, 1)  # 1x1 projection back to the input channel count

        if dropout:
            self.drop = nn.Dropout(dropout)

        # Kaiming (He) initialization; the ReLU gain is used as an approximation for all activations
        init.kaiming_normal_(self.conv1.weight, nonlinearity='relu')
        init.kaiming_normal_(self.conv2.weight, nonlinearity='relu')
        init.kaiming_normal_(self.conv3.weight, nonlinearity='relu')

    def _apply_norm(self, norm, x):
        # LayerNorm normalizes the last dimension, so temporarily move channels there
        if self.norm == "LN":
            return norm(x.transpose(-2, -1)).transpose(-2, -1)
        return norm(x)

    def forward(self, x):
        x_orig = x

        # Normalize and activate
        x = self._apply_norm(self.norm1, x)
        x = self.activation1(x)

        # First convolution
        x = self.conv1(x)

        # Apply dropout if specified
        if self.dropout:
            x = self.drop(x)

        # Normalize and activate again
        x = self._apply_norm(self.norm2, x)
        x = self.activation2(x)

        # Second convolution
        x = self.conv2(x)

        # Normalize, activate, and project back to n_in channels
        x = self._apply_norm(self.norm3, x)
        x = self.activation3(x)
        x = self.conv3(x)

        # Residual (skip) connection
        x = x + x_orig

        return x


# ResNet1D with multiple residual blocks
class Resnet1D(nn.Module):
    def __init__(self, n_in, n_depth, dilation_growth_rate=1, reverse_dilation=True, activation='relu', norm=None, dropout=None):
        super().__init__()
        
        # Create residual blocks with exponentially growing dilation
        blocks = [
            ResConv1DBlock(
                n_in, n_in,
                dilation=dilation_growth_rate ** depth,
                activation=activation,
                norm=norm,
                dropout=dropout,
            )
            for depth in range(n_depth)
        ]

        if reverse_dilation:
            blocks = blocks[::-1]  # Reverse so dilation decreases with depth

        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        return self.model(x)
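

# Minimal usage sketch (illustrative only, not part of the original module): the
# parameter values below are assumptions chosen for the example. It assumes
# (batch, channels, time) input and checks that both the block and the full
# Resnet1D preserve the input shape (same-padded convolutions + 1x1 projection).
if __name__ == "__main__":
    x = torch.randn(2, 64, 128)  # (batch=2, channels=64, time=128)

    block = ResConv1DBlock(n_in=64, n_state=128, dilation=2, activation="silu", norm="GN", dropout=0.1)
    assert block(x).shape == x.shape

    resnet = Resnet1D(n_in=64, n_depth=3, dilation_growth_rate=2, activation="silu", norm="GN", dropout=0.1)
    assert resnet(x).shape == x.shape
    print("Output shape:", resnet(x).shape)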