Spaces:
Sleeping
Sleeping
import streamlit as st | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from tensorflow import keras | |
from tensorflow.keras import layers | |
from keras.datasets import mnist | |
def display_reconstruction(index, autoencoder, encoder, x_test):
    """Render a test image, its latent code, and its reconstruction in Streamlit.

    Builds a 1x3 matplotlib figure (original image, latent vector as a bar
    chart, reconstructed image) and hands it to ``st.pyplot``.

    Args:
        index: Row of ``x_test`` to visualize.
        autoencoder: Model mapping a flattened image to its reconstruction.
        encoder: Model mapping a flattened image to its latent vector.
        x_test: Array of flattened images; each row is reshaped to (28, 28)
            for display, so it must contain exactly 784 values.
    """
    original = x_test[index]
    # Both models expect a leading batch axis; predict on a batch of one.
    batch = np.expand_dims(original, 0)
    latent_repr = encoder.predict(batch)[0]
    reconstructed = autoencoder.predict(batch)[0]

    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    try:
        # Original image (pixel axes carry no information, so hide them).
        axs[0].imshow(np.reshape(original, (28, 28)), cmap='gray')
        axs[0].set_title('Original Image')
        axs[0].axis('off')

        # Latent representation as a bar chart. Axes stay visible here so the
        # activation values of each latent unit can actually be read.
        axs[1].bar(range(len(latent_repr)), latent_repr)
        axs[1].set_title('Latent Representation')

        # Reconstructed image.
        axs[2].imshow(np.reshape(reconstructed, (28, 28)), cmap='gray')
        axs[2].set_title('Reconstructed Image')
        axs[2].axis('off')

        st.pyplot(fig)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed; since
        # Streamlit calls this on every click, close it to avoid piling up figures.
        plt.close(fig)
# Main Streamlit app.
# NOTE: Streamlit re-executes this whole script top-to-bottom on every widget
# interaction, so anything that must survive a rerun (trained models and the
# test set) is kept in st.session_state.
st.title("Autoencoder Training and Visualization")

# Initialize session-state slots once; later reruns reuse the stored values.
for _state_key in ("autoencoder", "encoder", "x_test"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None

EPOCHS = 5

# Button to trigger training
if st.button("Train Autoencoder"):
    # Load MNIST, scale pixels to [0, 1], and flatten each 28x28 image to 784 values.
    (x_train, _), (x_test, _) = mnist.load_data()
    x_train = x_train.astype('float32') / 255.0
    x_test = x_test.astype('float32') / 255.0
    x_train = np.reshape(x_train, (-1, 784))
    x_test = np.reshape(x_test, (-1, 784))

    # Keep the test set so later reruns can look images up by index.
    st.session_state.x_test = x_test

    # Dense autoencoder: 784 -> 128 -> 64 -> 32 (latent) -> 64 -> 128 -> 784.
    input_img = keras.Input(shape=(784,))
    encoded = layers.Dense(128, activation='relu')(input_img)
    encoded = layers.Dense(64, activation='relu')(encoded)
    latent_vector = layers.Dense(32, activation='relu')(encoded)
    decoded = layers.Dense(64, activation='relu')(latent_vector)
    decoded = layers.Dense(128, activation='relu')(decoded)
    decoded = layers.Dense(784, activation='sigmoid')(decoded)
    autoencoder = keras.Model(input_img, decoded)
    # Sigmoid outputs in [0, 1] pair with binary cross-entropy on [0, 1] pixels.
    autoencoder.compile(optimizer='adam', loss='binary_crossentropy')

    # Advance a progress bar after each epoch so progress is visible in the UI,
    # not only in the server's console log.
    progress_bar = st.progress(0)
    update_progress = keras.callbacks.LambdaCallback(
        on_epoch_end=lambda epoch, logs: progress_bar.progress((epoch + 1) / EPOCHS)
    )
    with st.spinner("Training in progress..."):
        autoencoder.fit(
            x_train,
            x_train,
            epochs=EPOCHS,
            batch_size=128,
            validation_data=(x_test, x_test),
            callbacks=[update_progress],
        )

    # Save trained models to session state. The encoder shares its layers
    # (and therefore its trained weights) with the autoencoder.
    st.session_state.autoencoder = autoencoder
    st.session_state.encoder = keras.Model(input_img, latent_vector)

# Only offer visualization once a model has been trained in this session.
if st.session_state.autoencoder is not None:
    # Derive the valid index range from the stored test set instead of hard-coding it.
    max_index = len(st.session_state.x_test) - 1
    test_index = st.number_input(
        f"Enter an index (0-{max_index}) to view an image from the test set:",
        min_value=0,
        max_value=max_index,
    )
    # Button to display the reconstruction
    if st.button("Display Reconstruction"):
        display_reconstruction(
            test_index,
            st.session_state.autoencoder,
            st.session_state.encoder,
            st.session_state.x_test,
        )