File size: 1,757 Bytes
5451fa1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch.nn as nn


class Encoder(nn.Module):
    """LSTM encoder mapping a batch-first sequence to a fixed-size latent vector.

    Args:
        input_size: Number of features per time step.
        hidden_size: Hidden size of the LSTM.
        latent_size: Dimension of the latent vector produced by the final
            linear projection.
        num_lstm_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, latent_size, num_lstm_layers):
        super().__init__()

        self.encoder_lstm = nn.LSTM(
            input_size, hidden_size, num_lstm_layers, batch_first=True
        )
        self.latent = nn.Linear(hidden_size, latent_size)

    def forward(self, x):
        """Encode ``x`` into a latent vector.

        Args:
            x: Tensor of shape (batch, time, input_size); the LSTM is built
                with ``batch_first=True``.

        Returns:
            Tensor of shape (batch, latent_size).
        """
        # Only the per-step outputs are needed; the final (h, c) states are discarded.
        lstm_out, _ = self.encoder_lstm(x)
        # Top-layer output at the last time step serves as the sequence summary.
        h_last = lstm_out[:, -1, :]
        return self.latent(h_last)

    def encode(self, x):
        """Return the latent vector for ``x``; same computation as :meth:`forward`.

        Kept for backward compatibility. Like the original, this calls
        ``forward`` directly rather than ``self(x)``, so module hooks are not run.
        """
        return self.forward(x)


class Decoder(nn.Module):
    """Two-layer MLP decoder mapping a latent vector to a flattened sequence.

    The output is flat, with last dimension ``input_size * sequence_length``.
    Reshaping it back into (..., sequence_length, input_size) is left to the
    caller.

    Args:
        input_size: Number of features per reconstructed time step.
        latent_size: Dimension of the incoming latent vector.
        sequence_length: Number of time steps to reconstruct.
        mlp_hidden_size: Width of the MLP hidden layer. The default of 128
            matches the previously hard-coded value.
    """

    def __init__(self, input_size, latent_size, sequence_length, mlp_hidden_size=128):
        super().__init__()

        self.sequence_length = sequence_length

        # Layer order (Linear, ReLU, Linear) is unchanged, so state_dict keys
        # ``decoder_mlp.0.*`` / ``decoder_mlp.2.*`` stay compatible.
        self.decoder_mlp = nn.Sequential(
            nn.Linear(latent_size, mlp_hidden_size),
            nn.ReLU(),
            nn.Linear(mlp_hidden_size, input_size * sequence_length),
        )

    def forward(self, x):
        """Decode ``x`` (last dimension ``latent_size``) into a flat reconstruction."""
        return self.decoder_mlp(x)


class Autoencoder(nn.Module):
    """Sequence autoencoder built from an LSTM ``Encoder`` and an MLP ``Decoder``.

    Args:
        input_size: Number of features per time step.
        hidden_size: Hidden size of the encoder LSTM (also stored on the module).
        latent_size: Dimension of the latent code between encoder and decoder.
        sequence_length: Number of time steps the decoder reconstructs.
        num_lstm_layers: Number of stacked LSTM layers in the encoder.
    """

    def __init__(
        self, input_size, hidden_size, latent_size, sequence_length, num_lstm_layers=1
    ):
        super().__init__()

        # Plain configuration attributes first, then the submodules, in the
        # same registration order as before (encoder, then decoder).
        self.sequence_length = sequence_length
        self.hidden_size = hidden_size

        self.encoder = Encoder(input_size, hidden_size, latent_size, num_lstm_layers)
        self.decoder = Decoder(input_size, latent_size, sequence_length)

    def forward(self, x):
        """Encode then decode ``x``.

        Args:
            x: Batch-first input, indexed as (batch, time, feature).

        Returns:
            Reconstruction of shape (batch, sequence_length, x.size(2)). The
            decoder's flat output is unflattened here.
        """
        return self.decoder(self.encoder(x)).view(-1, self.sequence_length, x.size(2))