import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers
from keras.datasets import mnist


# Page header: app title, usage hints, a link to the architecture diagram,
# and the author credit. Rendered on every Streamlit rerun.
st.title("Autoencoder Visualization")
for _header_line in (
    "Please train the autoencoder on MNIST train dataset",
    "Test and observe decoding below",
    "View the autoencoder architecture:",
):
    st.subheader(_header_line)
del _header_line  # keep the loop variable out of the module namespace
st.markdown("[Image](https://tinyurl.com/7fxrkds8)")
st.subheader("Developed by Dr. Hemprasad Yashwant Patil, SENSE, VIT Vellore")


def display_reconstruction(index, autoencoder, encoder, x_test):
    """Render a test image, its latent code and its reconstruction side by side.

    Args:
        index: Row of ``x_test`` to visualise.
        autoencoder: Model mapping a batch of flattened images to reconstructions.
        encoder: Model mapping the same input batch to the latent vector.
        x_test: Array of flattened test images. Each row is reshaped to 28x28
            for display, so rows are assumed to hold 784 pixels (as produced
            by the training step of this app).
    """
    original = x_test[index]
    # Keras models expect a leading batch axis; build the batch of one once.
    batch = np.expand_dims(original, 0)
    # verbose=0 keeps Keras from printing a progress bar on every click.
    latent_repr = encoder.predict(batch, verbose=0)[0]
    reconstructed = autoencoder.predict(batch, verbose=0)[0]

    fig, (ax_orig, ax_latent, ax_recon) = plt.subplots(1, 3, figsize=(12, 4))

    # Original image
    ax_orig.imshow(np.reshape(original, (28, 28)), cmap='gray')
    ax_orig.set_title('Original Image')
    ax_orig.axis('off')

    # Latent representation as a bar chart. Its axes stay visible so the
    # magnitude of each latent unit can actually be read off the plot.
    ax_latent.bar(range(len(latent_repr)), latent_repr)
    ax_latent.set_title('Latent Representation')
    ax_latent.set_xlabel('Latent unit')

    # Reconstructed image
    ax_recon.imshow(np.reshape(reconstructed, (28, 28)), cmap='gray')
    ax_recon.set_title('Reconstructed Image')
    ax_recon.axis('off')

    st.pyplot(fig)
    # Streamlit reruns the script on every interaction; close the figure so
    # pyplot's global figure registry does not grow without bound.
    plt.close(fig)

# Main Streamlit app
st.title("Autoencoder Training and Visualization")

# Create empty session-state slots on the first run only; later reruns keep
# whatever a previous training run stored there.
for _state_key in ("autoencoder", "encoder", "x_test"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
del _state_key  # keep the loop variable out of the module namespace

def _load_flat_mnist():
    """Load MNIST, scale pixels to [0, 1] and flatten each image to 784 values.

    Labels are discarded because the autoencoder is trained to reproduce its
    own input.

    Returns:
        (x_train, x_test): float32 arrays of shape (n, 784).
    """
    # Use the datasets module of the same `keras` used to build the model,
    # rather than mixing in the standalone `keras` package.
    (x_train, _), (x_test, _) = keras.datasets.mnist.load_data()

    x_train = x_train.astype('float32') / 255.0
    x_test = x_test.astype('float32') / 255.0

    x_train = np.reshape(x_train, (-1, 784))  # Flatten to (None, 784)
    x_test = np.reshape(x_test, (-1, 784))
    return x_train, x_test


def _build_autoencoder():
    """Build and compile the dense 784-128-64-32-64-128-784 autoencoder.

    Returns:
        (autoencoder, encoder): the full model mapping input to reconstruction,
        and a model that shares its layers and maps input to the 32-d latent
        vector. Because the layer objects are shared, training ``autoencoder``
        also updates ``encoder``.
    """
    input_img = keras.Input(shape=(784,))

    encoded = layers.Dense(128, activation='relu')(input_img)
    encoded = layers.Dense(64, activation='relu')(encoded)
    latent_vector = layers.Dense(32, activation='relu')(encoded)

    decoded = layers.Dense(64, activation='relu')(latent_vector)
    decoded = layers.Dense(128, activation='relu')(decoded)
    # Sigmoid keeps outputs in [0, 1], matching the scaled pixels and BCE loss.
    decoded = layers.Dense(784, activation='sigmoid')(decoded)

    autoencoder = keras.Model(input_img, decoded)
    autoencoder.compile(optimizer='adam', loss='binary_crossentropy')
    encoder = keras.Model(input_img, latent_vector)
    return autoencoder, encoder


# Button to trigger training
if st.button("Train Autoencoder"):
    with st.spinner("Training in progress..."):
        # Loading may download MNIST on first use, so it belongs under the spinner.
        x_train, x_test = _load_flat_mnist()
        autoencoder, encoder = _build_autoencoder()
        # The test set doubles as validation data to report held-out loss.
        autoencoder.fit(x_train, x_train, epochs=5, batch_size=128, validation_data=(x_test, x_test))

    # Publish the test set and both models together, only after training
    # succeeded, so session state never mixes results from different runs.
    st.session_state.x_test = x_test
    st.session_state.autoencoder = autoencoder
    st.session_state.encoder = encoder

# Input for image index to display. Shown only once training has stored a
# model and its test set; compare against None explicitly rather than relying
# on the truthiness of a Keras model object.
if st.session_state.autoencoder is not None and st.session_state.x_test is not None:
    # Derive the valid range from the stored test set instead of hard-coding
    # it; for MNIST's 10,000 test images this is exactly 0-9999.
    max_index = len(st.session_state.x_test) - 1
    test_index = st.number_input(
        f"Enter an index (0-{max_index}) to view an image from the test set:",
        min_value=0,
        max_value=max_index,
    )

    # Button to display the reconstruction
    if st.button("Display Reconstruction"):
        display_reconstruction(
            int(test_index),
            st.session_state.autoencoder,
            st.session_state.encoder,
            st.session_state.x_test,
        )