# app.py — last commit 0292a29 by Hem345 ("Update app.py", verified); 3.14 kB.
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers
from keras.datasets import mnist
# Function to display image, latent representation, and reconstructed image
def display_reconstruction(index, autoencoder, encoder, x_test, image_shape=(28, 28)):
    """Render one test sample, its latent code, and its reconstruction in Streamlit.

    Args:
        index: Row of ``x_test`` to visualize.
        autoencoder: Model mapping a batch of flattened images to reconstructions.
        encoder: Model mapping the same batch to latent vectors.
        x_test: Array of flattened images, one per row (this app stores MNIST
            flattened to 784 values per row).
        image_shape: Shape used to reshape a flattened row back into a 2-D image
            for display. Defaults to MNIST's (28, 28).
    """
    original = x_test[index]
    # Both models expect a leading batch axis, so wrap the single sample.
    batch = np.expand_dims(original, 0)
    # verbose=0 stops Keras from printing a progress bar for a one-sample predict.
    latent_repr = encoder.predict(batch, verbose=0)[0]
    reconstructed = autoencoder.predict(batch, verbose=0)[0]

    fig, (ax_orig, ax_latent, ax_recon) = plt.subplots(1, 3, figsize=(12, 4))
    try:
        # Original image
        ax_orig.imshow(np.reshape(original, image_shape), cmap='gray')
        ax_orig.set_title('Original Image')
        ax_orig.axis('off')

        # Latent representation as a bar chart; axes are kept so the values
        # of each latent dimension remain readable.
        ax_latent.bar(range(len(latent_repr)), latent_repr)
        ax_latent.set_title('Latent Representation')
        ax_latent.set_xlabel('Latent dimension')

        # Reconstructed image
        ax_recon.imshow(np.reshape(reconstructed, image_shape), cmap='gray')
        ax_recon.set_title('Reconstructed Image')
        ax_recon.axis('off')

        st.pyplot(fig)
    finally:
        # pyplot keeps every figure alive until it is closed; Streamlit reruns
        # this script on each interaction, so unclosed figures would pile up.
        plt.close(fig)
# Main Streamlit app
# Streamlit re-executes this whole script on every widget interaction, so the
# trained models and the test set are kept in st.session_state across reruns.


def _load_mnist_flat():
    """Load MNIST, scale pixels to [0, 1], and flatten each image to 784 values.

    Returns:
        (x_train, x_test) as float32 arrays of shape (N, 784). Labels are
        dropped because the autoencoder is trained to reproduce its own input.
    """
    (x_train, _), (x_test, _) = mnist.load_data()
    x_train = np.reshape(x_train.astype('float32') / 255.0, (-1, 784))
    x_test = np.reshape(x_test.astype('float32') / 255.0, (-1, 784))
    return x_train, x_test


def _build_autoencoder():
    """Build the dense 784-128-64-32-64-128-784 autoencoder and its encoder half.

    Returns:
        (autoencoder, encoder): the compiled full model, and a model sharing its
        layers that maps an input to the 32-unit latent layer's output.
    """
    input_img = keras.Input(shape=(784,))
    encoded = layers.Dense(128, activation='relu')(input_img)
    encoded = layers.Dense(64, activation='relu')(encoded)
    latent_vector = layers.Dense(32, activation='relu')(encoded)
    decoded = layers.Dense(64, activation='relu')(latent_vector)
    decoded = layers.Dense(128, activation='relu')(decoded)
    decoded = layers.Dense(784, activation='sigmoid')(decoded)
    autoencoder = keras.Model(input_img, decoded)
    # Sigmoid outputs with binary cross-entropy match targets scaled to [0, 1].
    autoencoder.compile(optimizer='adam', loss='binary_crossentropy')
    encoder = keras.Model(input_img, latent_vector)
    return autoencoder, encoder


st.title("Autoencoder Training and Visualization")

# Initialize or reuse session state
for _state_key in ("autoencoder", "encoder", "x_test"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None

# Button to trigger training
if st.button("Train Autoencoder"):
    with st.spinner("Training in progress..."):
        x_train, x_test = _load_mnist_flat()
        autoencoder, encoder = _build_autoencoder()
        autoencoder.fit(x_train, x_train, epochs=5, batch_size=128,
                        validation_data=(x_test, x_test))
    # Publish to session state only after training succeeds, so the stored
    # models and test set always belong to the same run.
    st.session_state.x_test = x_test
    st.session_state.autoencoder = autoencoder
    st.session_state.encoder = encoder

# Input for image index to display. Compare against None explicitly instead of
# relying on the truthiness of a Keras model object.
if st.session_state.autoencoder is not None and st.session_state.x_test is not None:
    max_index = len(st.session_state.x_test) - 1
    test_index = st.number_input(
        f"Enter an index (0-{max_index}) to view an image from the test set:",
        min_value=0,
        max_value=max_index,
    )
    # Button to display the reconstruction
    if st.button("Display Reconstruction"):
        display_reconstruction(test_index, st.session_state.autoencoder,
                               st.session_state.encoder, st.session_state.x_test)