team-7 / model.py
micpst's picture
add source files
5451fa1
import torch.nn as nn
class Encoder(nn.Module):
    """LSTM encoder that compresses a sequence into one latent vector.

    The input is run through a (possibly stacked) LSTM created with
    ``batch_first=True``; the LSTM output at the last time step is then
    projected to ``latent_size`` by a linear layer.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        latent_size: int,
        num_lstm_layers: int,
    ):
        """Build the LSTM and the latent projection.

        Args:
            input_size: Number of features per time step.
            hidden_size: Size of the LSTM hidden state.
            latent_size: Size of the latent vector produced per sequence.
            num_lstm_layers: Number of stacked LSTM layers.
        """
        super().__init__()
        self.encoder_lstm = nn.LSTM(
            input_size, hidden_size, num_lstm_layers, batch_first=True
        )
        self.latent = nn.Linear(hidden_size, latent_size)

    def forward(self, x):
        """Return the latent vector for ``x``; see :meth:`encode`."""
        # Single source of truth: forward and encode used to be copy-pasted
        # twins, which invites them drifting apart.
        return self.encode(x)

    def encode(self, x):
        """Encode a batch of sequences into latent vectors.

        Args:
            x: Batched 3-D input; with ``batch_first=True`` the LSTM reads it
                as ``(batch, seq_len, input_size)``.

        Returns:
            Tensor of shape ``(batch, latent_size)``.
        """
        # The final (h_n, c_n) state is not needed; only the per-step outputs.
        lstm_out, _ = self.encoder_lstm(x)
        # Keep only the output at the last time step of every sequence.
        last_step = lstm_out[:, -1, :]
        return self.latent(last_step)
class Decoder(nn.Module):
    """MLP decoder that expands a latent vector into a flattened sequence.

    The network is ``Linear -> ReLU -> Linear`` and its last dimension has
    ``input_size * sequence_length`` elements. No reshaping is done here.
    """

    def __init__(
        self,
        input_size: int,
        latent_size: int,
        sequence_length: int,
        mlp_hidden_size: int = 128,
    ):
        """Build the decoder MLP.

        Args:
            input_size: Number of features per time step of the sequence
                being reconstructed.
            latent_size: Size of the incoming latent vector.
            sequence_length: Number of time steps being reconstructed.
            mlp_hidden_size: Width of the hidden layer. Defaults to 128, the
                value that was previously hard-coded, so existing callers
                and checkpoints are unaffected.
        """
        super().__init__()
        self.sequence_length = sequence_length
        self.decoder_mlp = nn.Sequential(
            nn.Linear(latent_size, mlp_hidden_size),
            nn.ReLU(),
            nn.Linear(mlp_hidden_size, input_size * sequence_length),
        )

    def forward(self, x):
        """Decode latent vectors.

        Args:
            x: Tensor whose last dimension is ``latent_size``.

        Returns:
            Tensor whose last dimension is ``input_size * sequence_length``.
        """
        return self.decoder_mlp(x)
class Autoencoder(nn.Module):
    """Sequence autoencoder built from ``Encoder`` and ``Decoder``.

    The encoder turns the input into a latent tensor, the decoder expands
    it, and the result is viewed as
    ``(-1, sequence_length, <size of input dim 2>)``.
    """

    def __init__(
        self, input_size, hidden_size, latent_size, sequence_length, num_lstm_layers=1
    ):
        """Create the encoder/decoder pair.

        Args:
            input_size: Number of features per time step.
            hidden_size: Hidden size handed to the encoder.
            latent_size: Size of the latent vector shared by both halves.
            sequence_length: Number of time steps in the reconstruction.
            num_lstm_layers: Number of LSTM layers handed to the encoder.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.sequence_length = sequence_length
        # Encoder is registered before decoder, as in the original ordering.
        self.encoder = Encoder(
            input_size=input_size,
            hidden_size=hidden_size,
            latent_size=latent_size,
            num_lstm_layers=num_lstm_layers,
        )
        self.decoder = Decoder(
            input_size=input_size,
            latent_size=latent_size,
            sequence_length=sequence_length,
        )

    def forward(self, x):
        """Encode ``x``, decode it, and fold the result back into a sequence.

        Args:
            x: Input tensor with at least three dimensions; dimension 2 is
                read as the per-step feature count.

        Returns:
            The decoder output viewed as
            ``(-1, self.sequence_length, x.size(2))``.
        """
        flat = self.decoder(self.encoder(x))
        feature_count = x.size(2)
        return flat.view(-1, self.sequence_length, feature_count)