Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from tensorflow import keras | |
| from tensorflow.keras import layers | |
| from keras.datasets import mnist | |
# Function to display image, latent representation, and reconstructed image
def display_reconstruction(index, autoencoder, encoder, x_test, image_shape=(28, 28)):
    """Render one test image, its latent code, and its reconstruction.

    Draws a 1x3 matplotlib figure into the Streamlit page:
    the original image, a bar chart of the encoder output, and the
    autoencoder output.

    Args:
        index: Row of ``x_test`` to visualise.
        autoencoder: Model mapping a flattened image to its reconstruction.
        encoder: Model mapping a flattened image to its latent vector.
        x_test: Array of flattened test images, one image per row.
        image_shape: Shape used to un-flatten a row for display.
            Defaults to ``(28, 28)``, matching the 784-wide inputs built by the app.
    """
    original = x_test[index]
    # Both models expect a leading batch axis; build the batch once.
    batch = np.expand_dims(original, 0)
    # verbose=0 keeps Keras' per-call progress bar out of the server log.
    latent_repr = encoder.predict(batch, verbose=0)[0]
    reconstructed = autoencoder.predict(batch, verbose=0)[0]

    fig, axs = plt.subplots(1, 3, figsize=(12, 4))

    # Display original image
    axs[0].imshow(np.reshape(original, image_shape), cmap='gray')
    axs[0].set_title('Original Image')
    axs[0].axis('off')

    # Display latent representation as a bar chart.
    # Axes stay visible here so the activation scale can be read.
    axs[1].bar(range(len(latent_repr)), latent_repr)
    axs[1].set_title('Latent Representation')
    axs[1].set_xlabel('Latent unit')

    # Display reconstructed image
    axs[2].imshow(np.reshape(reconstructed, image_shape), cmap='gray')
    axs[2].set_title('Reconstructed Image')
    axs[2].axis('off')

    st.pyplot(fig)
    # Streamlit reruns the script on every interaction. Close the figure so
    # pyplot does not accumulate open figures (and memory) across reruns.
    plt.close(fig)
# Main Streamlit app
st.title("Autoencoder Training and Visualization")

# Streamlit re-executes this script on every interaction. Session state is
# what survives between runs, so each slot is seeded with None exactly once;
# later code checks these slots to see whether training has happened.
for _state_key in ("autoencoder", "encoder", "x_test"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
# Button to trigger training
if st.button("Train Autoencoder"):
    # Data: MNIST pixels scaled from 0..255 to [0, 1], flattened to 784-long rows.
    (train_images, _), (test_images, _) = mnist.load_data()
    train_images, test_images = (
        np.reshape(split.astype('float32') / 255.0, (-1, 784))
        for split in (train_images, test_images)
    )

    # Keep the held-out images so the reconstruction viewer can reuse them
    # on later reruns.
    st.session_state.x_test = test_images

    # Model: symmetric dense autoencoder 784 -> 128 -> 64 -> 32 -> 64 -> 128 -> 784.
    input_img = keras.Input(shape=(784,))
    hidden = input_img
    for width in (128, 64):
        hidden = layers.Dense(width, activation='relu')(hidden)
    latent_vector = layers.Dense(32, activation='relu')(hidden)
    hidden = latent_vector
    for width in (64, 128):
        hidden = layers.Dense(width, activation='relu')(hidden)
    # Sigmoid output matches the [0, 1] pixel range expected by binary cross-entropy.
    reconstruction = layers.Dense(784, activation='sigmoid')(hidden)

    autoencoder = keras.Model(input_img, reconstruction)
    autoencoder.compile(optimizer='adam', loss='binary_crossentropy')

    # The spinner is shown only while fit() runs; the test split doubles as
    # validation data.
    with st.spinner("Training in progress..."):
        autoencoder.fit(
            train_images,
            train_images,
            epochs=5,
            batch_size=128,
            validation_data=(test_images, test_images),
        )

    # Persist the trained full model plus an encoder that shares its layers
    # (input -> latent), so both survive Streamlit reruns.
    st.session_state.autoencoder = autoencoder
    st.session_state.encoder = keras.Model(input_img, latent_vector)
# Input for image index to display; shown only once a model has been trained.
# Explicit None check: a Keras Model's truthiness is not a meaningful signal.
if st.session_state.autoencoder is not None:
    # Derive the upper bound from the stored test set instead of assuming
    # 10,000 rows. For MNIST this yields 9999, so the label is unchanged.
    max_index = len(st.session_state.x_test) - 1
    test_index = st.number_input(
        f"Enter an index (0-{max_index}) to view an image from the test set:",
        min_value=0,
        max_value=max_index,
    )
    # Button to display the reconstruction
    if st.button("Display Reconstruction"):
        display_reconstruction(
            test_index,
            st.session_state.autoencoder,
            st.session_state.encoder,
            st.session_state.x_test,
        )
# Display architecture image
st.subheader("Autoencoder Architecture")
# Relative path to the architecture diagram; presumably resolved from the
# app's working directory (confirm against the deployment layout).
architecture_image_path = 'image1.png'
# st.image accepts a file path directly, so PIL's Image.open is unnecessary.
# PIL was never imported here, so the previous code raised NameError.
st.image(architecture_image_path, caption="Autoencoder Architecture")