import streamlit as st
import tensorflow as tf
import os
import requests
import tempfile
import matplotlib.pyplot as plt
from io import StringIO
import datetime
from tensorboard import program
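
# Streamlit app: downloads cardiac image TFRecords from a Hugging Face dataset,
# builds a dense baseline and a convolutional model for per-pixel (2-class)
# segmentation, and shows predictions, training curves, and TensorBoard logs.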

try:
    # Check if a GPU is available
    gpu = len(tf.config.list_physical_devices('GPU')) > 0
    if gpu:
        st.write("GPU is available!")  # Inform the user
        # Set TensorFlow to use the GPU if available (optional, usually automatic)
        # You can specify which GPU if you have multiple:
        # tf.config.set_visible_devices(tf.config.list_physical_devices('GPU')[0], 'GPU')  # Use the first GPU
        # or
        # tf.config.experimental.set_memory_growth(tf.config.list_physical_devices('GPU')[0], True)  # Memory growth for the first GPU
        # or
        # strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])  # Use multiple GPUs
    else:
        st.write("GPU is not available. Using CPU.")
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # Force CPU usage (optional)
except RuntimeError as e:
    st.write(f"Error checking GPU: {e}")
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # Force CPU usage if there is a runtime error

def run_tensorboard(log_dir):
    # Start TensorBoard
    tb = program.TensorBoard()
    tb.configure(argv=[None, '--logdir', log_dir])
    url = tb.launch()
    return url

# Constants for dataset information
TRAIN_FILE = "train_images.tfrecords"
VAL_FILE = "val_images.tfrecords"
TRAIN_URL = "https://huggingface.co/datasets/louiecerv/cardiac_images/resolve/main/train_images.tfrecords"
VAL_URL = "https://huggingface.co/datasets/louiecerv/cardiac_images/resolve/main/val_images.tfrecords"

# Use a persistent temp directory
tmpdir = tempfile.gettempdir()

# Function to download a file with progress display
def download_file(url, local_filename, target_dir):
    os.makedirs(target_dir, exist_ok=True)
    filepath = os.path.join(target_dir, local_filename)
    if os.path.exists(filepath):
        st.write(f"File already exists: {filepath}")
        return filepath
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        total_size = int(r.headers.get('content-length', 0))
        progress_bar = st.empty()  # Create a placeholder
        with open(filepath, 'wb') as f:
            downloaded_size = 0
            for chunk in r.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)
                    downloaded_size += len(chunk)
                    if total_size > 0:  # Guard against a missing Content-Length header
                        progress_percent = min(100, int(downloaded_size / total_size * 100))
                        progress_bar.progress(progress_percent, text=f"Downloading {local_filename}...")
    return filepath

# Download only if files are missing
train_file_path = download_file(TRAIN_URL, TRAIN_FILE, tmpdir)
val_file_path = download_file(VAL_URL, VAL_FILE, tmpdir)

# Dictionary describing the fields stored in each TFRecord example
image_feature_description = {
    'height': tf.io.FixedLenFeature([], tf.int64),
    'width': tf.io.FixedLenFeature([], tf.int64),
    'depth': tf.io.FixedLenFeature([], tf.int64),
    'name': tf.io.FixedLenFeature([], tf.string),
    'image_raw': tf.io.FixedLenFeature([], tf.string),
    'label_raw': tf.io.FixedLenFeature([], tf.string),
}
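
# Only 'image_raw' and 'label_raw' are used downstream; they hold the serialized
# pixel data and segmentation mask bytes that read_and_decode() turns back into tensors.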

# Helper function to parse the image and label data from TFRecord
def _parse_image_function(example_proto):
    return tf.io.parse_single_example(example_proto, image_feature_description)


# Function to read and decode an example from the dataset
def read_and_decode(example):
    image_raw = tf.io.decode_raw(example['image_raw'], tf.int64)
    image_raw.set_shape([65536])
    image = tf.reshape(image_raw, [256, 256, 1])
    image = tf.cast(image, tf.float32) * (1. / 1024)
    label_raw = tf.io.decode_raw(example['label_raw'], tf.uint8)
    label_raw.set_shape([65536])
    label = tf.reshape(label_raw, [256, 256, 1])
    return image, label
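
# Each decoded example is a 256x256x1 image (65536 = 256 * 256 values, scaled to
# floats by the 1/1024 factor above) plus a 256x256x1 integer mask whose per-pixel
# class IDs serve as targets for the sparse cross-entropy loss used below.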

# Load and parse datasets
raw_training_dataset = tf.data.TFRecordDataset(train_file_path)
raw_val_dataset = tf.data.TFRecordDataset(val_file_path)
parsed_training_dataset = raw_training_dataset.map(_parse_image_function)
parsed_val_dataset = raw_val_dataset.map(_parse_image_function)

# Prepare datasets
tf_autotune = tf.data.experimental.AUTOTUNE
train = parsed_training_dataset.map(read_and_decode, num_parallel_calls=tf_autotune)
val = parsed_val_dataset.map(read_and_decode)

BUFFER_SIZE = 10
BATCH_SIZE = 1
train_dataset = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
train_dataset = train_dataset.prefetch(buffer_size=tf_autotune)
test_dataset = val.batch(BATCH_SIZE)
st.write(train_dataset)
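
# Note: .repeat() makes the training dataset infinite, which is why model.fit()
# below is given an explicit steps_per_epoch; without it an epoch would never end.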

# Function to take a prediction from the model and output an image for display
def create_mask(pred_mask):
    pred_mask = tf.argmax(pred_mask, axis=-1)
    pred_mask = pred_mask[..., tf.newaxis]
    return pred_mask[0]
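
# The argmax over the last axis picks the higher-scoring of the two output
# channels for every pixel, turning the class scores into a binary mask.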

def display(display_list):
    fig = plt.figure(figsize=(10, 10))
    title = ['Input Image', 'Label', 'Prediction']
    for i in range(len(display_list)):
        ax = fig.add_subplot(1, len(display_list), i + 1)
        display_resized = tf.reshape(display_list[i], [256, 256])
        ax.set_title(title[i])
        ax.imshow(display_resized, cmap='gray')
        ax.axis('off')
    st.pyplot(fig)

# Helper function to show the image, the label and the prediction
def show_predictions(dataset=None, num=1):
    if dataset:
        for image, label in dataset.take(num):
            pred_mask = model.predict(image)
            display([image[0], label[0], create_mask(pred_mask)])
    else:
        prediction = create_mask(model.predict(sample_image[tf.newaxis, ...]))
        display([sample_image, sample_label, prediction])

# Define a callback that shows image predictions on the test set
class DisplayCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        show_predictions()
        st.write('\nSample Prediction after epoch {}\n'.format(epoch + 1))

# Streamlit app interface
st.title("Cardiac Images Dataset")

# Display sample images
for image, label in train.take(2):
    sample_image, sample_label = image, label
    display([sample_image, sample_label])

tf.keras.backend.clear_session()

# Set up the model architecture: a dense baseline that predicts two class
# scores for every pixel of the 256x256 input
model = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(256, 256, 1)),  # Define input shape
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(256 * 256 * 2),  # Linear outputs; treated as per-pixel logits by the loss
    tf.keras.layers.Reshape((256, 256, 2))
])
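
# Rough parameter count: Flatten gives 65536 inputs, so Dense(64) holds about
# 4.2M weights and Dense(256*256*2) about 8.5M, roughly 12.7M parameters in total.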

# Specify how to train the model: the optimizer, the loss function and metrics
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),  # Applies a per-pixel softmax to the logits from the Reshape layer
    metrics=['accuracy'])

# Capture the model summary
model_summary = StringIO()
model.summary(print_fn=lambda x: model_summary.write(x + '\n'))

# Display the model summary in Streamlit (st.text keeps the fixed-width layout intact)
st.text(model_summary.getvalue())

try:
    # Save the model plot
    plot_filename = "model_plot.png"
    tf.keras.utils.plot_model(model, to_file=plot_filename, show_shapes=True)
except Exception as e:
    st.error(f"An error occurred: {e}")

# Streamlit App
st.title("Model Architecture")

# Display the model plot only if it was generated successfully
if os.path.exists(plot_filename):
    st.image(plot_filename, caption="Neural Network Architecture", use_container_width=True)

# Show a prediction, as an example
show_predictions(test_dataset)

# Set up a TensorBoard callback
logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)
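
# The callback writes event files into a time-stamped subfolder of logs/, which is
# the same directory the "Show TensorBoard" button below points TensorBoard at.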
if st.button("Train Model"): | |
# setup and run the model | |
EPOCHS = 20 | |
STEPS_PER_EPOCH = len(list(parsed_training_dataset)) | |
VALIDATION_STEPS = 26 | |
model_history = model.fit(train_dataset, epochs=EPOCHS, | |
steps_per_epoch=STEPS_PER_EPOCH, | |
validation_steps=VALIDATION_STEPS, | |
validation_data=test_dataset, | |
callbacks=[DisplayCallback(), tensorboard_callback]) | |
# output model statistics | |
loss = model_history.history['loss'] | |
val_loss = model_history.history['val_loss'] | |
accuracy = model_history.history['accuracy'] | |
val_accuracy = model_history.history['val_accuracy'] | |
epochs = range(EPOCHS) | |
st.title('Training and Validation Loss') # Optional title for the Streamlit app | |
fig, ax = plt.subplots() # Create a figure and an axes object | |
ax.plot(epochs, loss, 'r', label='Training loss') | |
ax.plot(epochs, val_loss, 'bo', label='Validation loss') | |
ax.set_title('Training and Validation Loss') #Set title for the axes | |
ax.set_xlabel('Epoch') | |
ax.set_ylabel('Loss Value') | |
ax.set_ylim([0, 1]) | |
ax.legend() | |
st.pyplot(fig) # Display the plot in Streamlit | |
if st.button("Evaluate Model"): | |
# Evaluate the model | |
evaluation_results = model.evaluate(test_dataset, verbose=0) # Set verbose=0 to suppress console output | |
# Assuming model.metrics_names provides labels for evaluation_results | |
results_dict = dict(zip(model.metrics_names, evaluation_results)) | |
st.subheader("Model Evaluation Results") | |
# Display each metric and its corresponding value | |
for metric, value in results_dict.items(): | |
st.write(f"**{metric.capitalize()}:** {value:.4f}") | |
if st.button("Show TensorBoard"): | |
# Create a log directory for TensorBoard | |
log_dir = "logs" | |
if not os.path.exists(log_dir): | |
os.makedirs(log_dir) | |
# Run TensorBoard | |
url = run_tensorboard(log_dir) | |
# Display TensorBoard in an iframe | |
st.markdown(f"<iframe src='{url}' width='100%' height='800'></iframe>", unsafe_allow_html=True) | |
if st.button("CNN"): | |
tf.keras.backend.clear_session() | |
inputs = tf.keras.Input(shape=(256, 256, 1), name="InputLayer") | |
x = tf.keras.layers.Conv2D(filters=100, kernel_size=5, strides=2, padding="same", | |
activation="relu", name="Conv1")(inputs) | |
x = tf.keras.layers.MaxPool2D(pool_size=2, strides=2, padding="same")(x) | |
x = tf.keras.layers.Conv2D(filters=200, kernel_size=5, strides=2, padding="same", | |
activation="relu", name="Conv2")(x) | |
x = tf.keras.layers.MaxPool2D(pool_size=2, strides=2, padding="same")(x) | |
x = tf.keras.layers.Conv2D(filters=300, kernel_size=3, strides=1, padding="same", | |
activation="relu", name="Conv3")(x) | |
x = tf.keras.layers.Conv2D(filters=300, kernel_size=3, strides=1, padding="same", | |
activation="relu", name="Conv4")(x) | |
x = tf.keras.layers.Conv2D(filters=2, kernel_size=1, strides=1, padding="same", | |
activation="relu", name="Conv5")(x) | |
outputs = tf.keras.layers.Conv2DTranspose(filters=2, kernel_size=31, strides=16, | |
padding="same", activation="softmax", | |
name="UpSampling")(x) | |

    model = tf.keras.Model(inputs=inputs, outputs=outputs, name="CNN_Segmentation")
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),  # The softmax in the UpSampling layer already yields probabilities
        metrics=['accuracy']
    )

    # Capture and display the model summary
    model_summary = StringIO()
    model.summary(print_fn=lambda x: model_summary.write(x + '\n'))
    st.text(model_summary.getvalue())

    # Plot the model, including the sizes of the layer outputs, and display it
    try:
        cnn_plot_filename = "cnn_model_plot.png"
        tf.keras.utils.plot_model(model, to_file=cnn_plot_filename, show_shapes=True)
        st.image(cnn_plot_filename, caption="CNN Architecture", use_container_width=True)
    except Exception as e:
        st.error(f"An error occurred: {e}")

    # Show a prediction, as an example
    show_predictions(test_dataset)

    # Initialize new log directories for the new task
    logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
    tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)

    # Set up and run the model
    EPOCHS = 20
    STEPS_PER_EPOCH = len(list(parsed_training_dataset))
    VALIDATION_STEPS = 26
    model_history = model.fit(train_dataset, epochs=EPOCHS,
                              steps_per_epoch=STEPS_PER_EPOCH,
                              validation_steps=VALIDATION_STEPS,
                              validation_data=test_dataset,
                              callbacks=[DisplayCallback(), tensorboard_callback])

    # Output model statistics
    loss = model_history.history['loss']
    val_loss = model_history.history['val_loss']
    accuracy = model_history.history['accuracy']
    val_accuracy = model_history.history['val_accuracy']
    epochs = range(EPOCHS)

    st.title('Training and Validation Loss')  # Optional title for the Streamlit app
    fig, ax = plt.subplots()  # Create a figure and an axes object
    ax.plot(epochs, loss, 'r', label='Training loss')
    ax.plot(epochs, val_loss, 'bo', label='Validation loss')
    ax.set_title('Training and Validation Loss')  # Set title for the axes
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss Value')
    ax.set_ylim([0, 1])
    ax.legend()
    st.pyplot(fig)  # Display the plot in Streamlit