| import torch.nn as nn | |
class Encoder(nn.Module):
    """LSTM encoder that maps a sequence to a single latent vector.

    The input is run through a (possibly stacked) LSTM built with
    ``batch_first=True``; the LSTM output at the final time step is then
    projected to ``latent_size`` features by a linear layer.

    Args:
        input_size: Number of features per time step.
        hidden_size: Number of features in the LSTM hidden state.
        latent_size: Number of features in the returned latent vector.
        num_lstm_layers: Number of stacked LSTM layers.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        latent_size: int,
        num_lstm_layers: int,
    ) -> None:
        super().__init__()
        self.encoder_lstm = nn.LSTM(
            input_size, hidden_size, num_lstm_layers, batch_first=True
        )
        self.latent = nn.Linear(hidden_size, latent_size)

    def forward(self, x):
        """Encode a batch of sequences.

        Args:
            x: Tensor indexed as ``(batch, time, input_size)``.

        Returns:
            Tensor of shape ``(batch, latent_size)``.
        """
        # The final (h_n, c_n) state tuple is not needed: only the output at
        # the last time step feeds the latent projection.
        lstm_out, _ = self.encoder_lstm(x)
        return self.latent(lstm_out[:, -1, :])

    def encode(self, x):
        """Return the latent vector for ``x``; same computation as ``forward``.

        Delegates to ``forward`` so the two entry points cannot drift apart.
        ``self.forward`` is called directly (not ``self(x)``) so that, as
        before, module-level hooks on the encoder are not triggered here.
        """
        return self.forward(x)
class Decoder(nn.Module):
    """MLP decoder that expands a latent vector into a flattened sequence.

    The network is ``Linear -> ReLU -> Linear`` and emits
    ``input_size * sequence_length`` features per sample; reshaping that flat
    output into a sequence is left to the caller.

    Args:
        input_size: Number of features per reconstructed time step.
        latent_size: Number of features in the incoming latent vector.
        sequence_length: Number of time steps encoded in the flat output.
        mlp_hidden_size: Width of the hidden layer (keyword-only). The
            default of 128 matches the previously hard-coded value, so
            existing callers and checkpoints are unaffected.
    """

    def __init__(
        self,
        input_size: int,
        latent_size: int,
        sequence_length: int,
        *,
        mlp_hidden_size: int = 128,
    ) -> None:
        super().__init__()
        self.sequence_length = sequence_length
        self.decoder_mlp = nn.Sequential(
            nn.Linear(latent_size, mlp_hidden_size),
            nn.ReLU(),
            nn.Linear(mlp_hidden_size, input_size * sequence_length),
        )

    def forward(self, x):
        """Decode latent vectors.

        Args:
            x: Tensor whose last dimension is ``latent_size``.

        Returns:
            Tensor whose last dimension is ``input_size * sequence_length``.
        """
        return self.decoder_mlp(x)
class Autoencoder(nn.Module):
    """Sequence autoencoder: an ``Encoder`` followed by a ``Decoder``.

    Args:
        input_size: Number of features per time step.
        hidden_size: Hidden size passed to the encoder.
        latent_size: Size of the latent vector between encoder and decoder.
        sequence_length: Number of time steps in the reconstructed output.
        num_lstm_layers: Number of LSTM layers passed to the encoder.
    """

    def __init__(
        self, input_size, hidden_size, latent_size, sequence_length, num_lstm_layers=1
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.sequence_length = sequence_length
        # Encoder is registered before decoder, keeping parameter order stable.
        self.encoder = Encoder(
            input_size=input_size,
            hidden_size=hidden_size,
            latent_size=latent_size,
            num_lstm_layers=num_lstm_layers,
        )
        self.decoder = Decoder(
            input_size=input_size,
            latent_size=latent_size,
            sequence_length=sequence_length,
        )

    def forward(self, x):
        """Reconstruct ``x`` from its latent encoding.

        Args:
            x: Tensor with at least three dimensions; the size of dimension 2
                is used as the per-step feature count of the output.

        Returns:
            The decoder output viewed as
            ``(-1, self.sequence_length, x.size(2))``.
        """
        feature_count = x.size(2)
        flat_reconstruction = self.decoder(self.encoder(x))
        # The decoder emits one flat vector per sample; fold it back into
        # (steps, features), letting the leading dimension be inferred.
        return flat_reconstruction.view(-1, self.sequence_length, feature_count)