Hem345 commited on
Commit
7cc9dcf
·
verified ·
1 Parent(s): 8692d12

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -5
app.py CHANGED
@@ -33,6 +33,14 @@ def display_reconstruction(index, autoencoder, encoder, x_test):
33
  # Main Streamlit app
34
  st.title("Autoencoder Training and Visualization")
35
 
 
 
 
 
 
 
 
 
36
  # Button to trigger training
37
  if st.button("Train Autoencoder"):
38
  # Load and preprocess data
@@ -42,7 +50,10 @@ if st.button("Train Autoencoder"):
42
  x_test = x_test.astype('float32') / 255.0
43
 
44
  x_train = np.reshape(x_train, (-1, 784)) # Flatten to (None, 784)
45
- x_test = np.reshape(x_test, (-1, 784))
 
 
 
46
 
47
  # Define autoencoder architecture
48
  input_img = keras.Input(shape=(784,))
@@ -62,12 +73,14 @@ if st.button("Train Autoencoder"):
62
  with st.spinner("Training in progress..."):
63
  autoencoder.fit(x_train, x_train, epochs=5, batch_size=128, validation_data=(x_test, x_test))
64
 
65
- # Create encoder model
66
- encoder = keras.Model(input_img, latent_vector)
 
67
 
68
- # Input for image index to display
 
69
  test_index = st.number_input("Enter an index (0-9999) to view an image from the test set:", min_value=0, max_value=9999)
70
 
71
  # Button to display the reconstruction
72
  if st.button("Display Reconstruction"):
73
- display_reconstruction(test_index, autoencoder, encoder, x_test)
 
33
  # Main Streamlit app
34
  st.title("Autoencoder Training and Visualization")
35
 
36
+ # Initialize or reuse session state
37
+ if "autoencoder" not in st.session_state:
38
+ st.session_state.autoencoder = None
39
+ if "encoder" not in st.session_state:
40
+ st.session_state.encoder = None
41
+ if "x_test" not in st.session_state:
42
+ st.session_state.x_test = None
43
+
44
  # Button to trigger training
45
  if st.button("Train Autoencoder"):
46
  # Load and preprocess data
 
50
  x_test = x_test.astype('float32') / 255.0
51
 
52
  x_train = np.reshape(x_train, (-1, 784)) # Flatten to (None, 784)
53
+ x_test = np reshape(x_test, (-1, 784))
54
+
55
+ # Save test set in session state
56
+ st.session_state.x_test = x_test
57
 
58
  # Define autoencoder architecture
59
  input_img = keras.Input(shape=(784,))
 
73
  with st.spinner("Training in progress..."):
74
  autoencoder.fit(x_train, x_train, epochs=5, batch_size=128, validation_data=(x_test, x_test))
75
 
76
+ # Save trained models to session state
77
+ st.session_state.autoencoder = autoencoder
78
+ st.session_state.encoder = keras.Model(input_img, latent_vector)
79
 
80
+ # Input for image index to display
81
+ if st.session_state.autoencoder:
82
  test_index = st.number_input("Enter an index (0-9999) to view an image from the test set:", min_value=0, max_value=9999)
83
 
84
  # Button to display the reconstruction
85
  if st.button("Display Reconstruction"):
86
+ display_reconstruction(test_index, st.session_state.autoencoder, st.session_state.encoder, st.session_state.x_test)