louiecerv commited on
Commit
7ee901b
·
1 Parent(s): 781fd86

sync with remote

Browse files
Files changed (4) hide show
  1. None +0 -0
  2. app.py +30 -4
  3. model.png +0 -0
  4. model_plot.png +0 -0
None ADDED
Binary file (41.2 kB). View file
 
app.py CHANGED
@@ -98,19 +98,35 @@ test_dataset = val.batch(BATCH_SIZE)
98
 
99
  st.write(train_dataset)
100
 
 
 
 
 
 
 
101
  def display(display_list):
102
  fig = plt.figure(figsize=(10, 10))
103
- title = ['Input Image', 'Label']
104
 
105
  for i in range(len(display_list)):
106
  ax = fig.add_subplot(1, len(display_list), i + 1)
107
  display_resized = tf.reshape(display_list[i], [256, 256])
108
- ax.set_title(title[i])
109
  ax.imshow(display_resized, cmap='gray')
110
  ax.axis('off')
111
 
112
  st.pyplot(fig)
113
 
 
 
 
 
 
 
 
 
 
 
114
  # Streamlit app interface
115
  st.title("Cardiac Images Dataset")
116
 
@@ -143,5 +159,15 @@ model.summary(print_fn=lambda x: model_summary.write(x + '\n'))
143
  # Display the model summary in Streamlit
144
  st.markdown(model_summary.getvalue())
145
 
146
- # plot the model including the sizes of the model
147
- tf.keras.utils.plot_model(model, show_shapes=True)
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  st.write(train_dataset)
100
 
101
+ # function to take a prediction from the model and output an image for display
102
+ def create_mask(pred_mask):
103
+ pred_mask = tf.argmax(pred_mask, axis=-1)
104
+ pred_mask = pred_mask[..., tf.newaxis]
105
+ return pred_mask[0]
106
+
107
  def display(display_list):
108
  fig = plt.figure(figsize=(10, 10))
109
+ title = ['Input Image', 'Label', 'Prediction'] # Updated title list
110
 
111
  for i in range(len(display_list)):
112
  ax = fig.add_subplot(1, len(display_list), i + 1)
113
  display_resized = tf.reshape(display_list[i], [256, 256])
114
+ ax.set_title(title[i]) # No longer out of range
115
  ax.imshow(display_resized, cmap='gray')
116
  ax.axis('off')
117
 
118
  st.pyplot(fig)
119
 
120
+ # helper function to show the image, the label and the prediction
121
+ def show_predictions(dataset=None, num=1):
122
+ if dataset:
123
+ for image, label in dataset.take(num):
124
+ pred_mask = model.predict(image)
125
+ display([image[0], label[0], create_mask(pred_mask)])
126
+ else:
127
+ prediction = create_mask(model.predict(sample_image[tf.newaxis, ...]))
128
+ display([sample_image, sample_label, prediction])
129
+
130
  # Streamlit app interface
131
  st.title("Cardiac Images Dataset")
132
 
 
159
  # Display the model summary in Streamlit
160
  st.markdown(model_summary.getvalue())
161
 
162
+ # Save the model plot
163
+ plot_filename = "model_plot.png"
164
+ tf.keras.utils.plot_model(model, to_file=plot_filename, show_shapes=True)
165
+
166
+ # Streamlit App
167
+ st.title("Model Architecture")
168
+
169
+ # Display the model plot
170
+ st.image(plot_filename, caption="Neural Network Architecture", use_container_width=True)
171
+
172
+ # show a predection, as an example
173
+ show_predictions(test_dataset)
model.png ADDED
model_plot.png ADDED