Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -33,6 +33,14 @@ def display_reconstruction(index, autoencoder, encoder, x_test):
|
|
| 33 |
# Main Streamlit app
|
| 34 |
st.title("Autoencoder Training and Visualization")
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
# Button to trigger training
|
| 37 |
if st.button("Train Autoencoder"):
|
| 38 |
# Load and preprocess data
|
|
@@ -42,7 +50,10 @@ if st.button("Train Autoencoder"):
|
|
| 42 |
x_test = x_test.astype('float32') / 255.0
|
| 43 |
|
| 44 |
x_train = np.reshape(x_train, (-1, 784)) # Flatten to (None, 784)
|
| 45 |
-
x_test = np
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
# Define autoencoder architecture
|
| 48 |
input_img = keras.Input(shape=(784,))
|
|
@@ -62,12 +73,14 @@ if st.button("Train Autoencoder"):
|
|
| 62 |
with st.spinner("Training in progress..."):
|
| 63 |
autoencoder.fit(x_train, x_train, epochs=5, batch_size=128, validation_data=(x_test, x_test))
|
| 64 |
|
| 65 |
-
#
|
| 66 |
-
|
|
|
|
| 67 |
|
| 68 |
-
|
|
|
|
| 69 |
test_index = st.number_input("Enter an index (0-9999) to view an image from the test set:", min_value=0, max_value=9999)
|
| 70 |
|
| 71 |
# Button to display the reconstruction
|
| 72 |
if st.button("Display Reconstruction"):
|
| 73 |
-
display_reconstruction(test_index, autoencoder, encoder, x_test)
|
|
|
|
| 33 |
# Main Streamlit app
|
| 34 |
st.title("Autoencoder Training and Visualization")
|
| 35 |
|
| 36 |
+
# Initialize or reuse session state
|
| 37 |
+
if "autoencoder" not in st.session_state:
|
| 38 |
+
st.session_state.autoencoder = None
|
| 39 |
+
if "encoder" not in st.session_state:
|
| 40 |
+
st.session_state.encoder = None
|
| 41 |
+
if "x_test" not in st.session_state:
|
| 42 |
+
st.session_state.x_test = None
|
| 43 |
+
|
| 44 |
# Button to trigger training
|
| 45 |
if st.button("Train Autoencoder"):
|
| 46 |
# Load and preprocess data
|
|
|
|
| 50 |
x_test = x_test.astype('float32') / 255.0
|
| 51 |
|
| 52 |
x_train = np.reshape(x_train, (-1, 784)) # Flatten to (None, 784)
|
| 53 |
+
x_test = np reshape(x_test, (-1, 784))
|
| 54 |
+
|
| 55 |
+
# Save test set in session state
|
| 56 |
+
st.session_state.x_test = x_test
|
| 57 |
|
| 58 |
# Define autoencoder architecture
|
| 59 |
input_img = keras.Input(shape=(784,))
|
|
|
|
| 73 |
with st.spinner("Training in progress..."):
|
| 74 |
autoencoder.fit(x_train, x_train, epochs=5, batch_size=128, validation_data=(x_test, x_test))
|
| 75 |
|
| 76 |
+
# Save trained models to session state
|
| 77 |
+
st.session_state.autoencoder = autoencoder
|
| 78 |
+
st.session_state.encoder = keras.Model(input_img, latent_vector)
|
| 79 |
|
| 80 |
+
# Input for image index to display
|
| 81 |
+
if st.session_state.autoencoder:
|
| 82 |
test_index = st.number_input("Enter an index (0-9999) to view an image from the test set:", min_value=0, max_value=9999)
|
| 83 |
|
| 84 |
# Button to display the reconstruction
|
| 85 |
if st.button("Display Reconstruction"):
|
| 86 |
+
display_reconstruction(test_index, st.session_state.autoencoder, st.session_state.encoder, st.session_state.x_test)
|