import torch.nn as nn


class Encoder(nn.Module):
    """Compresses an input sequence into a fixed-size latent vector via an LSTM."""

    def __init__(self, input_size, hidden_size, latent_size, num_lstm_layers):
        super(Encoder, self).__init__()
        self.encoder_lstm = nn.LSTM(
            input_size, hidden_size, num_lstm_layers, batch_first=True
        )
        # Projects the final LSTM output down to the latent dimension.
        self.latent = nn.Linear(hidden_size, latent_size)

    def forward(self, x):
        # x: (batch, sequence_length, input_size)
        lstm_out, _ = self.encoder_lstm(x)
        # Use the output at the last time step as a summary of the sequence.
        h_last = lstm_out[:, -1, :]
        latent = self.latent(h_last)
        return latent

    def encode(self, x):
        # Convenience alias: encoding is identical to the forward pass.
        return self.forward(x)


class Decoder(nn.Module):
    """Reconstructs a flattened sequence from a latent vector with an MLP."""

    def __init__(self, input_size, latent_size, sequence_length):
        super(Decoder, self).__init__()
        self.sequence_length = sequence_length
        self.decoder_mlp = nn.Sequential(
            nn.Linear(latent_size, 128),
            nn.ReLU(),
            # One flat output covering all time steps; the caller reshapes it.
            nn.Linear(128, input_size * sequence_length),
        )

    def forward(self, x):
        # x: (batch, latent_size) -> (batch, input_size * sequence_length)
        decoded = self.decoder_mlp(x)
        return decoded


class Autoencoder(nn.Module):
    """Sequence autoencoder: LSTM encoder paired with an MLP decoder."""

    def __init__(
        self, input_size, hidden_size, latent_size, sequence_length, num_lstm_layers=1
    ):
        super(Autoencoder, self).__init__()
        self.sequence_length = sequence_length
        self.hidden_size = hidden_size
        self.encoder = Encoder(input_size, hidden_size, latent_size, num_lstm_layers)
        self.decoder = Decoder(input_size, latent_size, sequence_length)

    def forward(self, x):
        # x: (batch, sequence_length, input_size)
        latent = self.encoder(x)
        decoded = self.decoder(latent)
        # Restore the (batch, sequence_length, input_size) sequence shape.
        decoded = decoded.view(-1, self.sequence_length, x.size(2))
        return decoded
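

# A minimal usage sketch, not part of the original file: the shapes and
# hyperparameter values below are illustrative assumptions, chosen only to
# show that the reconstruction matches the input shape and that a plain
# MSE reconstruction loss trains end to end.
if __name__ == "__main__":
    import torch

    batch_size, sequence_length, input_size = 8, 20, 3
    model = Autoencoder(
        input_size=input_size,
        hidden_size=64,
        latent_size=16,
        sequence_length=sequence_length,
    )

    x = torch.randn(batch_size, sequence_length, input_size)
    reconstruction = model(x)
    print(reconstruction.shape)  # torch.Size([8, 20, 3]), same as the input

    # One illustrative optimization step with an MSE reconstruction loss.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss = nn.functional.mse_loss(model(x), x)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()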