|
import torch.nn as nn |
|
|
|
|
|
class Encoder(nn.Module):
    """Encode a batch of sequences into fixed-size latent vectors.

    An LSTM reads the sequence; the top layer's output at the final time step
    is projected to the latent space by a linear layer.

    Args:
        input_size: Number of features per time step (passed to ``nn.LSTM``).
        hidden_size: Width of the LSTM hidden state.
        latent_size: Width of the latent vector produced by ``self.latent``.
        num_lstm_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, latent_size, num_lstm_layers):
        super(Encoder, self).__init__()

        # batch_first=True: the LSTM expects (batch, seq_len, input_size).
        self.encoder_lstm = nn.LSTM(
            input_size, hidden_size, num_lstm_layers, batch_first=True
        )
        self.latent = nn.Linear(hidden_size, latent_size)

    def forward(self, x):
        """Return the latent encoding of ``x``.

        Delegates to :meth:`encode` so the two entry points can never diverge.
        """
        return self.encode(x)

    def encode(self, x):
        """Encode a batch of sequences.

        Args:
            x: Tensor laid out as (batch, seq_len, input_size) (batch_first).

        Returns:
            Tensor of shape (batch, latent_size): the linear projection of the
            top LSTM layer's output at the last time step.
        """
        # The final (h_n, c_n) states are not needed; the last time step of
        # lstm_out already carries the top layer's final output.
        lstm_out, _ = self.encoder_lstm(x)
        last_step = lstm_out[:, -1, :]
        return self.latent(last_step)
|
|
|
|
|
class Decoder(nn.Module):
    """Decode latent vectors into flattened sequences with a two-layer MLP.

    Args:
        input_size: Number of features per time step in the reconstruction.
        latent_size: Width of the incoming latent vector.
        sequence_length: Number of time steps to reconstruct.
        mlp_hidden_size: Width of the MLP's hidden layer (default 128, the
            previously hard-coded value).
    """

    def __init__(self, input_size, latent_size, sequence_length, mlp_hidden_size=128):
        super(Decoder, self).__init__()

        self.sequence_length = sequence_length

        # Layer order (Linear, ReLU, Linear) is kept so state_dict keys
        # (decoder_mlp.0.*, decoder_mlp.2.*) remain checkpoint-compatible.
        self.decoder_mlp = nn.Sequential(
            nn.Linear(latent_size, mlp_hidden_size),
            nn.ReLU(),
            nn.Linear(mlp_hidden_size, input_size * sequence_length),
        )

    def forward(self, x):
        """Decode ``x``.

        Args:
            x: Tensor whose last dimension is ``latent_size``.

        Returns:
            Tensor whose last dimension is ``input_size * sequence_length``;
            it is flat, so callers reshape it into (…, sequence_length,
            input_size) themselves.
        """
        return self.decoder_mlp(x)
|
|
|
|
|
class Autoencoder(nn.Module):
    """Sequence autoencoder: LSTM ``Encoder`` followed by an MLP ``Decoder``.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the encoder's LSTM hidden state.
        latent_size: Width of the bottleneck vector.
        sequence_length: Number of time steps the decoder reconstructs.
        num_lstm_layers: Number of stacked LSTM layers in the encoder.
    """

    def __init__(
        self, input_size, hidden_size, latent_size, sequence_length, num_lstm_layers=1
    ):
        super().__init__()

        self.sequence_length = sequence_length
        self.hidden_size = hidden_size

        # Construct the encoder first: submodule creation order determines
        # the order in which parameter initialisation draws from the RNG.
        self.encoder = Encoder(input_size, hidden_size, latent_size, num_lstm_layers)
        self.decoder = Decoder(input_size, latent_size, sequence_length)

    def forward(self, x):
        """Reconstruct ``x`` through the latent bottleneck.

        Args:
            x: Tensor indexed as (batch, seq_len, features).

        Returns:
            Tensor of shape (batch, sequence_length, x.size(2)).
        """
        code = self.encoder(x)
        flat = self.decoder(code)
        n_features = x.size(2)
        return flat.view(-1, self.sequence_length, n_features)
|
|