Hem345 commited on
Commit
8692d12
·
verified ·
1 Parent(s): 63ae85b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -0
app.py CHANGED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ from tensorflow import keras
5
+ from tensorflow.keras import layers
6
+ from keras.datasets import mnist
7
+
8
+ # Function to display image, latent representation, and reconstructed image
9
+ def display_reconstruction(index, autoencoder, encoder, x_test):
10
+ original = x_test[index]
11
+ latent_repr = encoder.predict(np.expand_dims(original, 0))[0]
12
+ reconstructed = autoencoder.predict(np.expand_dims(original, 0))[0]
13
+
14
+ fig, axs = plt.subplots(1, 3, figsize=(12, 4))
15
+
16
+ # Display original image
17
+ axs[0].imshow(np.reshape(original, (28, 28)), cmap='gray')
18
+ axs[0].set_title('Original Image')
19
+
20
+ # Display latent representation as a bar chart
21
+ axs[1].bar(range(len(latent_repr)), latent_repr)
22
+ axs[1].set_title('Latent Representation')
23
+
24
+ # Display reconstructed image
25
+ axs[2].imshow(np.reshape(reconstructed, (28, 28)), cmap='gray')
26
+ axs[2].set_title('Reconstructed Image')
27
+
28
+ for ax in axs:
29
+ ax.axis('off')
30
+
31
+ st.pyplot(fig)
32
+
33
+ # Main Streamlit app
34
+ st.title("Autoencoder Training and Visualization")
35
+
36
+ # Button to trigger training
37
+ if st.button("Train Autoencoder"):
38
+ # Load and preprocess data
39
+ (x_train, _), (x_test, _) = mnist.load_data()
40
+
41
+ x_train = x_train.astype('float32') / 255.0
42
+ x_test = x_test.astype('float32') / 255.0
43
+
44
+ x_train = np.reshape(x_train, (-1, 784)) # Flatten to (None, 784)
45
+ x_test = np.reshape(x_test, (-1, 784))
46
+
47
+ # Define autoencoder architecture
48
+ input_img = keras.Input(shape=(784,))
49
+
50
+ encoded = layers.Dense(128, activation='relu')(input_img)
51
+ encoded = layers.Dense(64, activation='relu')(encoded)
52
+ latent_vector = layers.Dense(32, activation='relu')(encoded)
53
+
54
+ decoded = layers.Dense(64, activation='relu')(latent_vector)
55
+ decoded = layers.Dense(128, activation='relu')(decoded)
56
+ decoded = layers.Dense(784, activation='sigmoid')(decoded)
57
+
58
+ autoencoder = keras.Model(input_img, decoded)
59
+ autoencoder.compile(optimizer='adam', loss='binary_crossentropy')
60
+
61
+ # Train the autoencoder and display progress
62
+ with st.spinner("Training in progress..."):
63
+ autoencoder.fit(x_train, x_train, epochs=5, batch_size=128, validation_data=(x_test, x_test))
64
+
65
+ # Create encoder model
66
+ encoder = keras.Model(input_img, latent_vector)
67
+
68
+ # Input for image index to display
69
+ test_index = st.number_input("Enter an index (0-9999) to view an image from the test set:", min_value=0, max_value=9999)
70
+
71
+ # Button to display the reconstruction
72
+ if st.button("Display Reconstruction"):
73
+ display_reconstruction(test_index, autoencoder, encoder, x_test)