import streamlit as st
from tensorflow import keras
import os
import matplotlib.pyplot as plt
from io import BytesIO
from NNVisualiser import NNVisualiser
import glob
import inspect
from tensorflow.keras.models import save_model
import tempfile
import re
import zipfile
import io
# Function to create a ZIP file of all PNG files
def create_zip_of_png_files(folder=None):
    """Bundle every ``.png`` file in *folder* into an in-memory ZIP archive.

    Args:
        folder: Directory to scan (not recursively). Defaults to the current
            working directory.

    Returns:
        An ``io.BytesIO`` holding the ZIP archive, rewound to position 0 so it
        can be read (or handed to a download widget) directly.
    """
    if folder is None:
        folder = os.getcwd()
    # Sorted for a deterministic archive order; skip directories named *.png,
    # which ZipFile.write would otherwise add as empty directory entries.
    png_files = sorted(
        name
        for name in os.listdir(folder)
        if name.endswith('.png') and os.path.isfile(os.path.join(folder, name))
    )
    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, 'w') as zip_file:
        for png_file in png_files:
            zip_file.write(os.path.join(folder, png_file), arcname=png_file)
    # Writing leaves the cursor at the end of the archive; rewind so a
    # consumer calling .read() gets the whole ZIP instead of b"".
    zip_buffer.seek(0)
    return zip_buffer
def generate_title_from_method_name(method_name):
    """Turn a camelCase plot method name into a human-readable title.

    A leading "plot" is dropped, the remainder is split on capital letters
    (each word is one capital followed by lowercase letters; other characters
    are discarded), and the words are wrapped as "Plotting <words> Plot ".
    """
    prefix = "plot"
    if method_name.startswith(prefix):
        stem = method_name[len(prefix):]
    else:
        stem = method_name
    camel_parts = re.findall(r'[A-Z][a-z]*', stem)
    return "Plotting {} Plot ".format(" ".join(camel_parts))
def downloadKerasModel():
    """Serialise the module-level ``model`` to ``.keras`` bytes for download.

    Returns:
        The raw bytes of the saved ``.keras`` file.
    """
    # NOTE(review): the original assignment to model_data was truncated in the
    # source; this save-then-read body is a reconstruction — confirm intent.
    # Reserve a unique path, then close the handle before saving so the file
    # is not held open while Keras writes to it (matters on Windows).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".keras") as tmp_file:
        tmp_path = tmp_file.name
    try:
        # `model` is the Keras model loaded at module level from the sidebar
        # selection; `save_model` is imported from tensorflow.keras.models.
        save_model(model, tmp_path)
        with open(tmp_path, "rb") as saved_file:
            model_data = saved_file.read()
    finally:
        # delete=False means nothing else removes the temp file.
        os.remove(tmp_path)
    return model_data
# Build a nested dict of the folder hierarchy up to max_depth levels (default 7), excluding files and hidden folders
# @st.cache_data
def generate_folder_hierarchy(root_folder, max_depth=7):
    """Return the directory tree under *root_folder* as nested dicts.

    Each key is a folder name and each value is a dict of its subfolders;
    folders without (included) subfolders map to ``{}``. Files are ignored,
    as are folders whose name starts with a dot (e.g. ``.git``) or is exactly
    ``'1'``.

    Args:
        root_folder: Directory to scan.
        max_depth: Deepest level to include; direct children of
            *root_folder* are level 1.

    Returns:
        Nested ``dict[str, dict]``; sibling keys are inserted in sorted order.
    """
    folder_dict = {}
    for dirpath, dirnames, _filenames in os.walk(root_folder):
        rel_path = os.path.relpath(dirpath, root_folder)
        # The root itself is level 0; 'a' is level 1, 'a/b' level 2, ...
        depth = 0 if rel_path == os.curdir else rel_path.count(os.sep) + 1
        # Assigning to dirnames[:] in place steers os.walk (top-down mode):
        # it only descends into, and in the order of, what we leave here.
        if depth < max_depth:
            # NOTE(review): '1' is presumably a version-number subfolder that
            # should not become a dropdown level — confirm.
            dirnames[:] = sorted(
                d for d in dirnames if not d.startswith('.') and d != '1'
            )
        else:
            # Children would exceed max_depth: stop descending entirely.
            dirnames[:] = []
        if depth == 0:
            continue
        # Parents are visited before children, but setdefault makes the
        # nesting correct regardless.
        node = folder_dict
        for part in rel_path.split(os.sep):
            node = node.setdefault(part, {})
    return folder_dict
def getPlotMethods():
    """List the names of NNVisualiser's functions whose names start with 'plot'.

    Names come back in alphabetical order, because inspect.getmembers sorts
    members by name.
    """
    plot_names = []
    for member_name, _member in inspect.getmembers(NNVisualiser, inspect.isfunction):
        if member_name.startswith('plot'):
            plot_names.append(member_name)
    return plot_names
# Example usage
# Root of the model repository tree (currently the working directory).
root_folder = os.getcwd()
folder_hierarchy = generate_folder_hierarchy(root_folder)

# Streamlit app: page header
st.title("Repository : Simple ANN Models with UAT Architecture")
st.write("A Collection of ANN Models with a 1-xReLU-1 Architecture for Basic 1D Functions on Bounded Intervals")
# col1, col2, col3 = st.columns([4, 3, 3])
# with col1:
# # Level 1: Initialisation dropdown
# initialisation = st.selectbox("Select Initialisation", list(folder_hierarchy.keys()))
# with col2:
# # Level 2: Sample size dropdown, based on selected initialisation
# sampleSize = st.selectbox("Select Sample Size", list(folder_hierarchy[initialisation].keys()))
# with col3:
# # Level 3: Batch size dropdown, based on selected sample size
# batchSize = st.selectbox("Select Batch Size", list(folder_hierarchy[initialisation][sampleSize].keys()))
# col4, col5, col6 = st.columns([3, 4, 3])
# with col4:
# # Level 4: Epochs count dropdown, based on selected batch size
# epochs = st.selectbox("Select Epochs Count", list(folder_hierarchy[initialisation][sampleSize][batchSize].keys()))
# with col5:
# # Level 5: Functions list dropdown, based on selected epochs count
# functions = st.selectbox("Select Neurons Count", list(folder_hierarchy[initialisation][sampleSize][batchSize][epochs].keys()))
# with col6:
# # Level 6: Neurons count dropdown, based on selected function
# neurons = st.selectbox("Select Neurons Count", list(folder_hierarchy[initialisation][sampleSize][batchSize][epochs][functions].keys()))
def _select_level(label, node):
    """Show a sidebar selectbox over the subfolder names in *node*.

    Returns ``(choice, node[choice])``. If *node* has no subfolders the
    selectbox yields None, so the run is stopped with a warning instead of
    crashing with ``KeyError: None`` on the next lookup.
    """
    choice = st.sidebar.selectbox(label, list(node.keys()))
    if choice is None:
        st.warning(f"No folders found for: {label}")
        st.stop()
    return choice, node[choice]


# Drill down the repository tree one level per sidebar dropdown.
repo, node = _select_level("Select Model Repository", folder_hierarchy)
initialisation, node = _select_level("Select Initialisation", node)
sampleSize, node = _select_level("Select Sample Size", node)
batchSize, node = _select_level("Select Batch Size", node)
epochs, node = _select_level("Select Epochs Count", node)
functions, node = _select_level("Select Function", node)
neurons, _ = _select_level("Select Neurons Count", node)

# Display the selected values
st.write(f"You selected: {repo} : {initialisation} : {sampleSize} : {batchSize} : {epochs} : {functions} : {neurons}")

# Resolve against root_folder (the folder the hierarchy was scanned from)
# rather than os.getcwd(), so the dropdowns and the loaded path stay in sync.
modelPath = os.path.join(root_folder, repo, initialisation, sampleSize, batchSize, epochs, functions, neurons)
model = keras.models.load_model(modelPath)
visualiser = NNVisualiser(model)
# Function to get layer and neuron information
def get_layer_info(model):
    """Summarise each layer of *model*.

    Args:
        model: Any object exposing a ``layers`` sequence (e.g. a Keras model).

    Returns:
        A list with one dict per layer, in ``model.layers`` order, holding
        ``'index'`` (position), ``'type'`` (class name) and ``'units'`` (the
        layer's ``units`` attribute, or ``None`` if it has none).
    """
    return [
        {
            'index': index,
            'type': layer.__class__.__name__,
            'units': getattr(layer, 'units', None),  # Number of neurons
        }
        for index, layer in enumerate(model.layers)
    ]
layer_info = get_layer_info(model)
# Split the per-layer summaries into two parallel lists; position i of each
# describes the i-th entry of layer_info.
layer_indices = []
neuron_counts = []
for layer_entry in layer_info:
    layer_indices.append(layer_entry['index'])
    neuron_counts.append(layer_entry['units'])
# Dropdown for selecting layer index
#selected_layer_index = st.sidebar.selectbox("Select Layer Index", layer_indices)
# Find the number of neurons for the selected layer
#selected_layer_units = neuron_counts[selected_layer_index]
# Dropdown for selecting neuron index in the selected layer
#neuron_indices = list(range(selected_layer_units))
#selected_neuron_index = st.sidebar.selectbox("Select Neuron Index", neuron_indices)
# Dropdown for selecting plots from NNVisualiser
plotMethods = getPlotMethods()
selectedPlotMethod = st.sidebar.selectbox("Select Plot", plotMethods)

# Remove PNGs left in the working directory by a previous run, so the PNGs
# found there later (and zipped for download) come from this selection only.
image_files = glob.glob("*.png")
for file in image_files:
    try:
        os.remove(file)
    except OSError:
        # Best-effort cleanup: report and keep going with the other files.
        st.write("Error in removing previous plots")
st.session_state.title_text = generate_title_from_method_name(selectedPlotMethod)
# Resolve the selected NNVisualiser plot method on this visualiser instance.
method = getattr(visualiser, selectedPlotMethod, None)
if method is not None:
    # The method name decides which extra arguments the plot takes:
    # '...Neuron...' -> (layer index, neuron index), '...Layer...' ->
    # (layer index,), anything else -> no arguments.
    if 'Neuron' in selectedPlotMethod:
        selected_layer_index = st.sidebar.selectbox("Select Layer Index", layer_indices)
        # Find the number of neurons for the selected layer
        # NOTE(review): units is None for layers without a `units` attribute,
        # which would make range() fail — confirm all layers here have units.
        selected_layer_units = neuron_counts[selected_layer_index]
        # Dropdown for selecting neuron index in the selected layer
        neuron_indices = list(range(selected_layer_units))
        selected_neuron_index = st.sidebar.selectbox("Select Neuron Index", neuron_indices)
        params = (selected_layer_index, selected_neuron_index)
    elif 'Layer' in selectedPlotMethod:
        selected_layer_index = st.sidebar.selectbox("Select Layer Index", layer_indices)
        params = (selected_layer_index,)
    else:
        # Whole-network plots take no index arguments.
        params = ()
    # The plot methods draw directly without returning a figure.
    # NOTE(review): this call was absent from the visible source and is
    # reconstructed from that intent — confirm.
    method(*params)

# Prepare the download payloads for this run: model bytes and a ZIP of PNGs.
st.session_state.kerasModelToDownload = downloadKerasModel()
st.session_state.plotsToDownload = create_zip_of_png_files()
def downloads():
    """Render download buttons for the loaded model and the current plots.

    NOTE(review): the ``st.download_button(`` calls were missing from the
    pasted source; this body is reconstructed from the surviving arguments
    and the commented-out draft — confirm intent.
    """
    st.download_button(
        label="Download Model",
        data=downloadKerasModel(),
        file_name="model.keras",
        mime="application/octet-stream",
    )
    st.download_button(
        label="Download Plots",
        data=create_zip_of_png_files(),
        # A concrete name; the draft's empty string gave no usable filename.
        file_name="plots.zip",
        mime="application/zip",
    )
# column = st.columns (2)
# column[0].download_button(
# label="Download Model",
# data = downloadKerasModel(),
# file_name="model.keras",
# mime="application/octet-stream"
# );
# column[1].download_button(
# label="Download Plots",
# data=create_zip_of_png_files(),
# file_name="",
# mime="application/zip"
# );
# NOTE(review): the body of this `with` block and the image display were
# missing from the pasted source; reconstructed — confirm intent.
with st.sidebar:
    downloads()
# visualiser.plotFlowForNetwork();
image_files = glob.glob("*.png")
# Show every PNG the selected plot method left in the working directory.
for image_file in sorted(image_files):
    st.image(image_file)
# if st.sidebar.button("Download Keras model"):
# downloadKerasModel()
# if st.sidebar.download_button(
# label="Download Keras Model",
# data = downloadKerasModel(),
# file_name="model.keras",
# mime="application/octet-stream"
# ):
# st.sidebar.success(f"Model Downloaded Successfully")
# # Button to create and download the ZIP file
# if st.sidebar.download_button(
# label="Download Plots",
# data=create_zip_of_png_files(),
# file_name="",
# mime="application/zip"
# ):
# st.sidebar.success(f"Plots Downloaded Successfully")