import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist

# Streamlit interface
st.title("Autoencoder Visualization")
st.subheader("Please train the autoencoder on the MNIST train dataset")
st.subheader("Test and observe the decoding below")
st.subheader("View the autoencoder architecture:")
st.markdown("[Image](https://tinyurl.com/7fxrkds8)")
st.subheader("Developed by Dr. Hemprasad Yashwant Patil, SENSE, VIT Vellore")


# Function to display the original image, its latent representation, and the reconstructed image
def display_reconstruction(index, autoencoder, encoder, x_test):
    original = x_test[index]
    latent_repr = encoder.predict(np.expand_dims(original, 0))[0]
    reconstructed = autoencoder.predict(np.expand_dims(original, 0))[0]

    fig, axs = plt.subplots(1, 3, figsize=(12, 4))

    # Display original image
    axs[0].imshow(np.reshape(original, (28, 28)), cmap='gray')
    axs[0].set_title('Original Image')

    # Display latent representation as a bar chart
    axs[1].bar(range(len(latent_repr)), latent_repr)
    axs[1].set_title('Latent Representation')

    # Display reconstructed image
    axs[2].imshow(np.reshape(reconstructed, (28, 28)), cmap='gray')
    axs[2].set_title('Reconstructed Image')

    # Hide axis ticks and frames on all three panels
    for ax in axs:
        ax.axis('off')

    st.pyplot(fig)


# Main Streamlit app
st.title("Autoencoder Training and Visualization")

# Initialize or reuse session state
if "autoencoder" not in st.session_state:
    st.session_state.autoencoder = None
if "encoder" not in st.session_state:
    st.session_state.encoder = None
if "x_test" not in st.session_state:
    st.session_state.x_test = None

# Button to trigger training
if st.button("Train Autoencoder"):
    # Load and preprocess data: scale pixels to [0, 1] and flatten 28x28 images to 784-d vectors
    (x_train, _), (x_test, _) = mnist.load_data()
    x_train = x_train.astype('float32') / 255.0
    x_test = x_test.astype('float32') / 255.0
    x_train = np.reshape(x_train, (-1, 784))  # Flatten to (None, 784)
    x_test = np.reshape(x_test, (-1, 784))

    # Save test set in session state
    st.session_state.x_test = x_test

    # Define autoencoder architecture: 784 -> 128 -> 64 -> 32 (latent) -> 64 -> 128 -> 784
    input_img = keras.Input(shape=(784,))
    encoded = layers.Dense(128, activation='relu')(input_img)
    encoded = layers.Dense(64, activation='relu')(encoded)
    latent_vector = layers.Dense(32, activation='relu')(encoded)
    decoded = layers.Dense(64, activation='relu')(latent_vector)
    decoded = layers.Dense(128, activation='relu')(decoded)
    decoded = layers.Dense(784, activation='sigmoid')(decoded)

    autoencoder = keras.Model(input_img, decoded)
    autoencoder.compile(optimizer='adam', loss='binary_crossentropy')

    # Train the autoencoder and display progress
    with st.spinner("Training in progress..."):
        autoencoder.fit(x_train, x_train,
                        epochs=5,
                        batch_size=128,
                        validation_data=(x_test, x_test))

    # Save trained models to session state
    st.session_state.autoencoder = autoencoder
    st.session_state.encoder = keras.Model(input_img, latent_vector)

# Input for image index to display (only shown once a model has been trained)
if st.session_state.autoencoder is not None:
    test_index = st.number_input("Enter an index (0-9999) to view an image from the test set:",
                                 min_value=0, max_value=9999, step=1)

    # Button to display the reconstruction
    if st.button("Display Reconstruction"):
        display_reconstruction(int(test_index),
                               st.session_state.autoencoder,
                               st.session_state.encoder,
                               st.session_state.x_test)
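
# Usage note (a sketch, not part of the original script): a Streamlit app like this
# is typically launched from the command line. The filename below is an assumption,
# not from the source:
#   streamlit run autoencoder_app.py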