File size: 9,815 Bytes
0852299 5d88ea6 61cc1ca 0852299 55d4567 c382069 61cc1ca 0852299 6131316 0852299 55d4567 0852299 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 |
import streamlit as st
from tensorflow import keras
import os
import matplotlib.pyplot as plt
from io import BytesIO
from NNVisualiser import NNVisualiser
import glob
import inspect
from tensorflow.keras.models import save_model
import tempfile
import re
import zipfile
import io
# Function to create a ZIP file of all PNG files
def create_zip_of_png_files(directory=None):
    """Bundle every ``.png`` file in *directory* into an in-memory ZIP archive.

    Args:
        directory: Folder to collect PNGs from. Defaults to the current working
            directory, which is where this app looks for saved plots.

    Returns:
        An ``io.BytesIO`` holding the ZIP archive, rewound to position 0 so it
        can be handed straight to ``st.download_button``. Entries are stored
        under their bare file names, in sorted order.
    """
    if directory is None:
        directory = os.getcwd()
    # Only regular files: a directory that happens to end in ".png" is skipped.
    # Sorting makes the archive order deterministic.
    png_files = sorted(
        name for name in os.listdir(directory)
        if name.endswith('.png') and os.path.isfile(os.path.join(directory, name))
    )
    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, 'w') as zip_file:
        for png_file in png_files:
            try:
                zip_file.write(os.path.join(directory, png_file), arcname=png_file)
            except FileNotFoundError:
                # The PNG vanished after listing (e.g. removed by a concurrent
                # rerun's cleanup); skip it rather than fail the whole download.
                continue
    zip_buffer.seek(0)  # Rewind so readers start at the beginning of the archive
    return zip_buffer
def generate_title_from_method_name(method_name):
    """Turn a camelCase ``plot*`` method name into a human-readable heading.

    A leading ``"plot"`` is dropped. The remainder is split into its
    capitalised words; the pattern discards text that starts lowercase and any
    digits. The words are then wrapped as ``"Plotting <words> Plot "``.
    E.g. ``"plotFlowForNetwork"`` -> ``"Plotting Flow For Network Plot "``.
    """
    prefix = "plot"
    stem = method_name[len(prefix):] if method_name.startswith(prefix) else method_name
    # Each match is one capital letter followed by any run of lowercase letters.
    capitalised_words = re.findall(r'[A-Z][a-z]*', stem)
    heading_body = " ".join(capitalised_words)
    return "Plotting " + heading_body + " Plot "
def downloadKerasModel(keras_model=None):
    """Serialise a Keras model to ``.keras`` bytes for ``st.download_button``.

    Args:
        keras_model: Model to save. Defaults to the module-level ``model``
            loaded from the sidebar selection.

    Returns:
        The raw bytes of the saved ``.keras`` file.
    """
    if keras_model is None:
        keras_model = model
    # Save to a path inside a self-cleaning temporary directory and read it back
    # by path. No temp file is left behind after each rerun, and the file is
    # never reopened while another handle still holds it (which fails on Windows).
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = os.path.join(tmp_dir, "model.keras")
        save_model(keras_model, model_path)
        with open(model_path, "rb") as model_file:
            return model_file.read()
# Function to build folder hierarchy up to max_depth levels (excluding files and hidden folders)
# @st.cache_data
def generate_folder_hierarchy(root_folder, max_depth=7):
    """Return the directory tree under *root_folder* as nested dicts.

    Each key is a directory name and each value is a dict of its
    sub-directories; leaf directories map to ``{}``. Files are ignored. So are
    directories whose name starts with ``.`` (e.g. ``.git``) or is exactly
    ``'1'``.

    Args:
        root_folder: Directory to scan. It is not itself a key in the result.
        max_depth: Maximum number of path components (relative to
            *root_folder*) an included directory may have. The default of 7
            matches the seven sidebar selection levels (repository through
            neurons count).

    Returns:
        Nested dict whose sibling keys are in sorted order.
    """
    folder_dict = {}
    for dirpath, dirnames, _filenames in os.walk(root_folder):
        rel_path = os.path.relpath(dirpath, root_folder)
        # root_folder itself is depth 0; "a/b" is depth 2.
        parts = [] if rel_path == '.' else rel_path.split(os.sep)
        depth = len(parts)
        if depth > max_depth:
            dirnames[:] = []
            continue
        if depth == max_depth:
            # Any child would exceed max_depth: stop os.walk from descending,
            # instead of traversing (and discarding) the whole deeper tree.
            dirnames[:] = []
        else:
            # In-place assignment prunes what os.walk visits next. Sorting makes
            # the resulting dict (and thus the selectbox options) deterministic.
            # NOTE(review): the '1' exclusion presumably skips versioned model
            # sub-folders — confirm intent.
            dirnames[:] = sorted(d for d in dirnames if not d.startswith('.') and d != '1')
        sub_dict = folder_dict
        for part in parts:
            sub_dict = sub_dict.setdefault(part, {})
    return folder_dict
@st.cache_data
def getPlotMethods():
    """List the names of NNVisualiser's plotting functions.

    Collects every plain function defined on the ``NNVisualiser`` class whose
    name begins with ``'plot'``. The names come in the alphabetical order
    produced by ``inspect.getmembers``. The result is cached by Streamlit
    across reruns.
    """
    plot_names = []
    for member_name, _member in inspect.getmembers(NNVisualiser, inspect.isfunction):
        if member_name.startswith('plot'):
            plot_names.append(member_name)
    return plot_names
# Scan the working directory for the model repository tree.
root_folder = os.getcwd()  # Replace with your folder path
folder_hierarchy = generate_folder_hierarchy(root_folder)

# Streamlit app
st.title("Repository : Simple ANN Models with UAT Architecture")
st.write("A Collection of ANN Models with a 1-xReLU-1 Architecture for Basic 1D Functions on Bounded Intervals")


def _select_level(label, subtree):
    """Show a sidebar selectbox over the keys of *subtree* and return the pick.

    ``st.sidebar.selectbox`` returns ``None`` for an empty option list, which
    would crash the next ``subtree[...]`` lookup with a ``KeyError``. Instead,
    report the missing level and stop this script run.
    """
    options = list(subtree.keys())
    if not options:
        st.error(f"No sub-folders found for '{label}' under {root_folder}.")
        st.stop()
    return st.sidebar.selectbox(label, options)


# Each selection narrows the tree to the sub-folders of the chosen folder.
node = folder_hierarchy
repo = _select_level("Select Model Repository", node)
node = node[repo]
initialisation = _select_level("Select Initialisation", node)
node = node[initialisation]
sampleSize = _select_level("Select Sample Size", node)
node = node[sampleSize]
batchSize = _select_level("Select Batch Size", node)
node = node[batchSize]
epochs = _select_level("Select Epochs Count", node)
node = node[epochs]
functions = _select_level("Select Function", node)
node = node[functions]
neurons = _select_level("Select Neurons Count", node)

# Display the selected values
st.write(f"You selected: {repo} : {initialisation} : {sampleSize} : {batchSize} : {epochs} : {functions} : {neurons}")

# The seven selections form the path of the model folder under root_folder.
modelPath = os.path.join(root_folder, repo, initialisation, sampleSize, batchSize, epochs, functions, neurons)
model = keras.models.load_model(modelPath)
visualiser = NNVisualiser(model)
visualiser.setSavePlots(True)
def get_layer_info(model):
    """Summarise each layer of *model* as a small dict.

    Returns a list with one entry per element of ``model.layers``, in order.
    Each entry holds:
      - ``'index'``: position in ``model.layers``
      - ``'type'``: the layer's class name
      - ``'units'``: the layer's ``units`` attribute (its neuron count), or
        ``None`` when the layer has no such attribute
    """
    return [
        {
            'index': position,
            'type': layer.__class__.__name__,
            'units': getattr(layer, 'units', None),  # Number of neurons
        }
        for position, layer in enumerate(model.layers)
    ]
layer_info = get_layer_info(model)
# Extract layer indices and neuron counts
layer_indices = [layer['index'] for layer in layer_info]
neuron_counts = [layer['units'] for layer in layer_info]
# Only layers reporting a positive ``units`` count can be browsed neuron by
# neuron; range(None) would raise for layers without a ``units`` attribute.
neuron_layer_indices = [index for index, units in zip(layer_indices, neuron_counts) if units]

# Dropdown for selecting plots from NNVisualiser
plotMethods = getPlotMethods()
if not plotMethods:
    st.error("NNVisualiser exposes no plot* methods to choose from.")
    st.stop()
selectedPlotMethod = st.sidebar.selectbox("Select Plot", plotMethods)

# Remove plots left over from the previous run, so only the new ones are shown and zipped.
# NOTE(review): the PNGs live in the shared working directory, so concurrent
# sessions can delete or overwrite each other's plots.
for stale_file in glob.glob("*.png"):
    try:
        os.remove(stale_file)
    except FileNotFoundError:
        pass  # Already gone (e.g. removed by another session): nothing to do
    except OSError as err:
        st.write(f"Error in removing previous plots: {stale_file} ({err})")

st.session_state.title_text = generate_title_from_method_name(selectedPlotMethod)
st.title(st.session_state.title_text)

# Call the selected NNVisualiser plot method. It plots without returning a
# figure; with setSavePlots(True) its output is expected as PNGs in the
# working directory, which the code below displays and zips.
method = getattr(visualiser, selectedPlotMethod, None)
if method is None:
    st.warning(f"NNVisualiser has no method named {selectedPlotMethod!r}.")
else:
    figures_before = set(plt.get_fignums())
    if 'Neuron' in selectedPlotMethod:
        selected_layer_index = st.sidebar.selectbox("Select Layer Index", neuron_layer_indices)
        if selected_layer_index is None:
            st.warning("The selected model has no layers with neurons to plot.")
        else:
            # Dropdown for selecting a neuron within the selected layer
            neuron_indices = list(range(neuron_counts[selected_layer_index]))
            selected_neuron_index = st.sidebar.selectbox("Select Neuron Index", neuron_indices)
            method(selected_layer_index, selected_neuron_index)
    elif 'Layer' in selectedPlotMethod:
        selected_layer_index = st.sidebar.selectbox("Select Layer Index", layer_indices)
        method(selected_layer_index)
    else:
        method()
    # The Streamlit server process outlives each rerun, so pyplot figures left
    # open by the plot method would accumulate. Close only those opened here.
    for fig_num in set(plt.get_fignums()) - figures_before:
        plt.close(fig_num)

# NOTE(review): these session_state entries are not read anywhere in this file
# (the download buttons regenerate their data). Confirm before removing.
st.session_state.kerasModelToDownload = downloadKerasModel()
st.session_state.plotsToDownload = create_zip_of_png_files()
@st.fragment()
def downloads():
    """Render the download buttons for the current model and its plots.

    Runs as a Streamlit fragment, so clicking a button reruns only this
    function. Each rerun re-serialises the model via ``downloadKerasModel`` and
    re-zips the working-directory PNGs via ``create_zip_of_png_files``.
    """
    model_bytes = downloadKerasModel()
    st.download_button(
        label="Download Model",
        data=model_bytes,
        file_name="model.keras",
        mime="application/octet-stream",
    )
    plots_zip = create_zip_of_png_files()
    st.download_button(
        label="Download Plots",
        data=plots_zip,
        file_name="images.zip",
        mime="application/zip",
    )
with st.sidebar:
    downloads()

# Display the PNGs the selected plot method saved in the working directory.
# Sorted so multi-image plots appear in a stable order across reruns.
image_files = sorted(glob.glob("*.png"))
if image_files:
    st.image(image_files)
else:
    st.info("The selected plot method did not produce any images.")
|