import torch
import torch.nn as nn


class Encoder(nn.Module):
    """Encodes a (batch, seq_len, input_size) sequence into a fixed-size latent vector."""

    def __init__(self, input_size, hidden_size, latent_size, num_lstm_layers):
        super().__init__()

        self.encoder_lstm = nn.LSTM(
            input_size, hidden_size, num_lstm_layers, batch_first=True
        )
        # Project the last LSTM output down to the latent dimension.
        self.latent = nn.Linear(hidden_size, latent_size)

    def forward(self, x):
        lstm_out, _ = self.encoder_lstm(x)
        # Use the output at the final timestep as a summary of the whole sequence.
        h_last = lstm_out[:, -1, :]
        latent = self.latent(h_last)

        return latent

    def encode(self, x):
        # Convenience alias for forward, so callers can be explicit about intent.
        return self(x)


class Decoder(nn.Module):
    """Maps a latent vector to a flattened (input_size * sequence_length) reconstruction."""

    def __init__(self, input_size, latent_size, sequence_length):
        super().__init__()

        self.sequence_length = sequence_length

        # Small MLP that expands the latent vector to the flattened output sequence.
        self.decoder_mlp = nn.Sequential(
            nn.Linear(latent_size, 128),
            nn.ReLU(),
            nn.Linear(128, input_size * sequence_length),
        )

    def forward(self, x):
        # Output stays flat; the caller reshapes it to (batch, seq_len, input_size).
        decoded = self.decoder_mlp(x)
        return decoded


class Autoencoder(nn.Module):
    """Sequence autoencoder: LSTM encoder to a latent vector, MLP decoder back to the sequence."""

    def __init__(
        self, input_size, hidden_size, latent_size, sequence_length, num_lstm_layers=1
    ):
        super().__init__()

        self.sequence_length = sequence_length
        self.hidden_size = hidden_size

        self.encoder = Encoder(input_size, hidden_size, latent_size, num_lstm_layers)
        self.decoder = Decoder(input_size, latent_size, sequence_length)

    def forward(self, x):
        latent = self.encoder(x)
        decoded = self.decoder(latent)
        # Reshape the flat decoder output back to (batch, seq_len, input_size).
        decoded = decoded.view(-1, self.sequence_length, x.size(2))

        return decoded
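

# Minimal smoke-test sketch (shapes and hyperparameters below are illustrative
# assumptions, not part of the original module): push a random batch through the
# autoencoder and verify that the reconstruction matches the input shape.
if __name__ == "__main__":
    batch_size, seq_len, n_features = 8, 20, 3  # assumed example dimensions

    model = Autoencoder(
        input_size=n_features,
        hidden_size=64,
        latent_size=16,
        sequence_length=seq_len,
    )

    x = torch.randn(batch_size, seq_len, n_features)
    recon = model(x)
    assert recon.shape == x.shape  # expect (8, 20, 3)

    # A typical training step would minimize reconstruction error, for example:
    loss = nn.functional.mse_loss(recon, x)
    print(f"reconstruction MSE: {loss.item():.4f}")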