# Last commit: "Update app.py" by Hem345 (2116640, verified)
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers
from keras.datasets import mnist
# Streamlit interface: page title, usage hints and attribution.
st.title("Autoencoder Visualization")
_INTRO_SUBHEADERS = (
    "Please train the autoencoder on MNIST train dataset",
    "Test and observe decoding below",
    "View the autoencoder architecture:",
)
for _intro_text in _INTRO_SUBHEADERS:
    st.subheader(_intro_text)
# Link to an external diagram of the network architecture.
st.markdown("[Image](https://tinyurl.com/7fxrkds8)")
st.subheader("Developed by Dr. Hemprasad Yashwant Patil, SENSE, VIT Vellore")
# Function to display image, latent representation, and reconstructed image
def display_reconstruction(index, autoencoder, encoder, x_test):
    """Render one test image, its latent code and its reconstruction side by side.

    Args:
        index: Row of ``x_test`` to visualise.
        autoencoder: Model mapping a flattened image to its reconstruction.
        encoder: Model mapping a flattened image to its latent vector.
        x_test: Array of flattened test images; each row must reshape to 28x28.
    """
    original = x_test[index]
    # Both models expect a leading batch axis; build the 1-item batch once.
    batch = np.expand_dims(original, 0)
    # verbose=0 keeps Keras progress bars out of the server log on every click.
    latent_repr = encoder.predict(batch, verbose=0)[0]
    reconstructed = autoencoder.predict(batch, verbose=0)[0]

    fig, (ax_orig, ax_latent, ax_recon) = plt.subplots(1, 3, figsize=(12, 4))

    # Display original image
    ax_orig.imshow(np.reshape(original, (28, 28)), cmap='gray')
    ax_orig.set_title('Original Image')
    ax_orig.axis('off')

    # Display latent representation as a bar chart. Axes stay visible here:
    # without them the bar heights have no scale to read against.
    ax_latent.bar(range(len(latent_repr)), latent_repr)
    ax_latent.set_title('Latent Representation')

    # Display reconstructed image
    ax_recon.imshow(np.reshape(reconstructed, (28, 28)), cmap='gray')
    ax_recon.set_title('Reconstructed Image')
    ax_recon.axis('off')

    st.pyplot(fig)
    # pyplot keeps every figure alive until explicitly closed; without this,
    # each Streamlit rerun would leak one figure.
    plt.close(fig)
# Main Streamlit app
st.title("Autoencoder Training and Visualization")

# Streamlit re-executes this script on every interaction, so each slot is
# seeded with None only the first time; later reruns keep stored values
# (e.g. a model trained on an earlier click).
for _state_key in ("autoencoder", "encoder", "x_test"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
def _load_flat_mnist():
    """Return MNIST (train, test) images scaled to [0, 1] and flattened to 784 columns."""
    (train_images, _), (test_images, _) = mnist.load_data()
    train_flat = np.reshape(train_images.astype('float32') / 255.0, (-1, 784))
    test_flat = np.reshape(test_images.astype('float32') / 255.0, (-1, 784))
    return train_flat, test_flat


def _build_models():
    """Build the dense 784-128-64-32-64-128-784 autoencoder and its encoder.

    Returns:
        (autoencoder, encoder): the compiled full model and a model sharing its
        first three Dense layers, mapping an input image to the 32-unit code.
    """
    inputs = keras.Input(shape=(784,))
    hidden = inputs
    for width in (128, 64):
        hidden = layers.Dense(width, activation='relu')(hidden)
    code = layers.Dense(32, activation='relu')(hidden)
    hidden = code
    for width in (64, 128):
        hidden = layers.Dense(width, activation='relu')(hidden)
    outputs = layers.Dense(784, activation='sigmoid')(hidden)

    full_model = keras.Model(inputs, outputs)
    full_model.compile(optimizer='adam', loss='binary_crossentropy')
    # The encoder reuses the same layer objects, so it sees the trained weights.
    return full_model, keras.Model(inputs, code)


# Button to trigger training
if st.button("Train Autoencoder"):
    x_train, x_test = _load_flat_mnist()
    # Keep the test set around for the reconstruction viewer on later reruns.
    st.session_state.x_test = x_test

    autoencoder, encoder = _build_models()
    with st.spinner("Training in progress..."):
        # Input is its own target: the network learns to reproduce each image.
        autoencoder.fit(
            x_train,
            x_train,
            epochs=5,
            batch_size=128,
            validation_data=(x_test, x_test),
        )

    st.session_state.autoencoder = autoencoder
    st.session_state.encoder = encoder
# Reconstruction viewer: offered only once a trained model and its test set
# are stored. Uses explicit `is not None` checks instead of relying on the
# truthiness of a Keras Model object.
_viewer_ready = all(
    st.session_state.get(_key) is not None
    for _key in ("autoencoder", "encoder", "x_test")
)
if _viewer_ready:
    # Derive the valid index range from the stored test set rather than
    # hard-coding it (MNIST's 10,000 test images give 0-9999).
    max_index = len(st.session_state.x_test) - 1
    test_index = st.number_input(
        f"Enter an index (0-{max_index}) to view an image from the test set:",
        min_value=0,
        max_value=max_index,
    )
    # Button to display the reconstruction
    if st.button("Display Reconstruction"):
        display_reconstruction(
            int(test_index),
            st.session_state.autoencoder,
            st.session_state.encoder,
            st.session_state.x_test,
        )