SebastianBravo commited on
Commit
ded489b
1 Parent(s): d6b54da

Added new features

Browse files
Files changed (2) hide show
  1. app.py +150 -88
  2. utils.py +11 -2
app.py CHANGED
@@ -1,43 +1,63 @@
1
  import os
2
  import utils
 
3
  import numpy as np
4
  import gradio as gr
5
  import tensorflow as tf
6
  import matplotlib.pyplot as plt
7
  from ttictoc import tic,toc
 
8
  from urllib.request import urlretrieve
9
 
10
- # '''--------------------------- Preprocesamiento ----------------------------'''
11
- # tic()
12
- # 3D U-Net\
13
  if not os.path.exists("unet.h5"):
14
  urlretrieve("https://dl.dropboxusercontent.com/s/ay5q8caqzlad7h5/unet.h5?dl=0", "unet.h5")
15
-
 
16
  if not os.path.exists("resnet_50_23dataset.pth"):
17
  urlretrieve("https://dl.dropboxusercontent.com/s/otxsgx3e31d5h9i/resnet_50_23dataset.pth?dl=0", "resnet_50_23dataset.pth")
18
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  path_3d_unet = 'unet.h5'
 
 
 
 
20
 
 
 
 
21
  with tf.device("cpu:0"):
22
  model_unet = utils.import_3d_unet(path_3d_unet)
23
 
24
- # # Cargar imagen
25
- # img = utils.load_img('F:/Downloads/ADNI_002_S_0295_MR_MP-RAGE__br_raw_20070525135721811_1_S32678_I55275.nii')
 
26
 
27
- # # Extraer cerebro
28
- # with tf.device("cpu:0"):
29
- # brain = utils.brain_stripping(img, model_unet)
30
- # print(toc())
31
 
32
- # '''---------------------------- Procesamiento ------------------------------'''
33
- # # Med net
34
- # weight_path = 'resnet_50_23dataset.pth'
35
- # device_ids = [0]
36
- # mednet = utils.create_mednet(weight_path, device_ids)
37
 
38
- # # Extraer caracter铆sticas
39
- # features = utils.get_features(brain, mednet)
40
 
 
41
  def load_img(file):
42
  sitk, array = utils.load_img(file.name)
43
 
@@ -66,28 +86,52 @@ def show_img(img, mri_slice, update):
66
 
67
  # return fig, gr.update(visible=True)
68
 
69
- def process_img(img, brain_slice, progress=gr.Progress()):
70
- progress(880,desc="Processing...")
71
 
72
  with tf.device("cpu:0"):
73
  brain = utils.brain_stripping(img, model_unet)
74
 
75
  fig, update_slider, _ = show_img(brain, brain_slice, update=True)
76
 
77
- return brain, fig, update_slider
78
 
79
- def clear():
80
- return gr.File.update(value=None), gr.Plot.update(value=None), gr.update(visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
- # gr.Textbox.update(placeholder='Ingrese nombre del paciente'), gr.Number.update(value=0),
 
83
 
84
- # demo = gr.Interface(fn=load_img,
85
- # inputs=gr.File(file_count="single", file_type=[".nii"]),
86
- # outputs=gr.Plot()
87
- # # outputs='text'
88
- # )
89
 
90
- # theme = gr.themes.Base().load('css_new.json')
91
 
92
  with gr.Blocks(theme=gr.themes.Base()) as demo:
93
  with gr.Row():
@@ -112,11 +156,36 @@ with gr.Blocks(theme=gr.themes.Base()) as demo:
112
  input_sex = gr.Dropdown(["Male", "Female"], label="Sex")
113
 
114
  with gr.Tab("Clinical data"):
115
- input_MMSE = gr.Number(label='MMSE')
116
- input_GDSCALE = gr.Number(label='GDSCALE')
117
- input_CDR = gr.Number(label='Global CDR')
118
- input_FAQ = gr.Number(label='FAQ Total Score')
119
- input_NPI_Q = gr.Number(label='NPI-Q Total Score')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  with gr.Tab("Vital Signs"):
122
  input_Diastolic_blood_pressure = gr.Number(label='Diastolic Blood Pressure(mm Hg)')
@@ -142,33 +211,47 @@ with gr.Blocks(theme=gr.themes.Base()) as demo:
142
  clear_button = gr.Button(value="Clear")
143
 
144
  # Bot贸n para procesar imagen
145
- process_button = gr.Button(value="Procesar", visible=False)
 
 
 
146
 
147
  # Outputs
148
  with gr.Column(variant="panel", scale=1):
149
  gr.Markdown('<h2 style="text-align: center; color:#235784;">MRI visualization</h2>')
150
 
151
- # Plot para im谩gen original
152
- plot_img_original = gr.Plot(label="Imagen MRI original")
 
 
 
 
 
 
 
 
 
 
153
 
154
- # Slider para im谩gen original
155
- mri_slider = gr.Slider(minimum=0,
156
- maximum=192,
157
- value=100,
158
- step=1,
159
- label="MRI Slice",
160
- visible=False)
161
-
162
- # Plot para im谩gen procesada
163
- plot_brain = gr.Plot(label="Imagen MRI procesada", visible=True)
164
 
165
- # Slider para im谩gen procesada
166
- brain_slider = gr.Slider(minimum=0,
167
- maximum=192,
168
- value=100,
169
- step=1,
170
- label="MRI Slice",
171
- visible=False)
 
 
 
 
 
 
172
 
173
  # componentes =
174
 
@@ -177,6 +260,7 @@ with gr.Blocks(theme=gr.themes.Base()) as demo:
177
  original_input_img = gr.State()
178
  brain_img = gr.State()
179
 
 
180
  update_true = gr.State(True)
181
  update_false = gr.State(False)
182
 
@@ -191,10 +275,6 @@ with gr.Blocks(theme=gr.themes.Base()) as demo:
191
  [original_input_img, mri_slider, update_true],
192
  [plot_img_original, mri_slider, process_button])
193
 
194
- # Limpiar campos
195
- clear_button.click(fn=clear,
196
- outputs=[input_file, plot_img_original, mri_slider])
197
-
198
  # Actualizar imagen original
199
  mri_slider.change(show_img,
200
  [original_input_img, mri_slider, update_false],
@@ -203,42 +283,24 @@ with gr.Blocks(theme=gr.themes.Base()) as demo:
203
  # Procesar imagen
204
  process_button.click(fn=process_img,
205
  inputs=[original_input_sitk, brain_slider],
206
- outputs=[brain_img,plot_brain,brain_slider])
207
 
208
  # Actualizar imagen procesada
209
  brain_slider.change(show_img,
210
  [brain_img, brain_slider, update_false],
211
  [plot_brain])
 
 
 
 
 
 
 
 
 
 
212
 
213
 
214
  if __name__ == "__main__":
215
- demo.queue(concurrency_count=20)
216
  demo.launch()
217
-
218
- # # Visualizaci贸n resultados
219
- # mri_slice = 100
220
-
221
- # # Plot Comparaci贸n m谩scaras
222
- # fig, axs = plt.subplots(1,2)
223
- # fig.subplots_adjust(bottom=0.15)
224
- # fig.suptitle('Comparaci贸n M谩scaras Obtenidas')
225
-
226
- # axs[0].set_title('MRI original')
227
- # axs[0].imshow(img[mri_slice,:,:],cmap='gray')
228
-
229
- # axs[1].set_title('Cerebro extraido con 3D U-Net')
230
- # axs[1].imshow(brain[mri_slice,:,:],cmap='gray')
231
-
232
-
233
- # # Slider para cambiar slice
234
- # ax_slider = plt.axes([0.15, 0.05, 0.75, 0.03])
235
- # mri_slice_slider = Slider(ax_slider, 'Slice', 0, 192, 100, valstep=1)
236
-
237
- # def update(val):
238
- # mri_slice = mri_slice_slider.val
239
-
240
- # axs[0].imshow(img[:,:,mri_slice],cmap='gray')
241
- # axs[1].imshow(brain[mri_slice,:,:],cmap='gray')
242
-
243
- # # Actualizar plot comparaci贸n m谩scaras
244
- # mri_slice_slider.on_changed(update)
 
1
  import os
2
  import utils
3
+ import pickle
4
  import numpy as np
5
  import gradio as gr
6
  import tensorflow as tf
7
  import matplotlib.pyplot as plt
8
  from ttictoc import tic,toc
9
+ from keras.models import load_model
10
  from urllib.request import urlretrieve
11
 
12
+ '''--------------------------- Descarga de modelos ----------------------------'''
13
+
14
+ # 3D U-Net
15
  if not os.path.exists("unet.h5"):
16
  urlretrieve("https://dl.dropboxusercontent.com/s/ay5q8caqzlad7h5/unet.h5?dl=0", "unet.h5")
17
+
18
+ # Med3D
19
  if not os.path.exists("resnet_50_23dataset.pth"):
20
  urlretrieve("https://dl.dropboxusercontent.com/s/otxsgx3e31d5h9i/resnet_50_23dataset.pth?dl=0", "resnet_50_23dataset.pth")
21
 
22
+ # Clasificador de im谩gen SVM
23
+ if not os.path.exists("svm_model.pickle"):
24
+ urlretrieve("https://dl.dropboxusercontent.com/s/n3tb3r6oyf06xfx/svm_model.pickle?dl=0", "svm_model.pickle")
25
+
26
+ # Nivel de riesgo
27
+ if not os.path.exists("mlp_probabilidad.h5"):
28
+ urlretrieve("https://dl.dropboxusercontent.com/s/78fjlg374mvjygd/mlp_probabilidad.h5?dl=0", "mlp_probabilidad.h5")
29
+
30
+ # Scaler para scores
31
+ if not os.path.exists("scaler.pickle"):
32
+ urlretrieve("https://dl.dropboxusercontent.com/s/ow6pe4k45r3xkbl/scaler.pickle?dl=0", "scaler.pickle")
33
+
34
  path_3d_unet = 'unet.h5'
35
+ weight_path = 'resnet_50_23dataset.pth'
36
+ svm_path = "svm_model.pickle"
37
+ prob_model_path = "mlp_probabilidad.h5"
38
+ scaler_path = "scaler.pickle"
39
 
40
+
41
+ '''---------------------------- Carga de modelos ------------------------------'''
42
+ # 3D U-Net
43
  with tf.device("cpu:0"):
44
  model_unet = utils.import_3d_unet(path_3d_unet)
45
 
46
+ # MedNet
47
+ device_ids = [7]
48
+ mednet_model = utils.create_mednet(weight_path, device_ids)
49
 
50
+ # SVM model
51
+ svm_model = pickle.load(open(svm_path, 'rb'))
 
 
52
 
53
+ # Nivel de riesgo
54
+ with tf.device("cpu:0"):
55
+ prob_model = load_model(prob_model_path)
 
 
56
 
57
+ # Scaler
58
+ scaler = pickle.load(open(scaler_path, 'rb'))
59
 
60
+ '''-------------------------------- Funciones ---------------------------------'''
61
  def load_img(file):
62
  sitk, array = utils.load_img(file.name)
63
 
 
86
 
87
  # return fig, gr.update(visible=True)
88
 
89
+ def process_img(img, brain_slice):
90
+ # progress(None,desc="Processing...")
91
 
92
  with tf.device("cpu:0"):
93
  brain = utils.brain_stripping(img, model_unet)
94
 
95
  fig, update_slider, _ = show_img(brain, brain_slice, update=True)
96
 
97
+ return brain, fig, update_slider, gr.update(visible=True)
98
 
99
+ def get_diagnosis(brain_img, age, MMSE, GDSCALE, CDR, FAQ, NPI, sex):
100
+ # Extracci贸n de caracter铆sticas de imagen
101
+ features = utils.get_features(brain_img, mednet_model)
102
+
103
+ # Clasificaci贸n de imagen
104
+ label_img = np.array([svm_model.predict(features)])
105
+
106
+ if sex == "Male":
107
+ sex_dum = 1
108
+ else:
109
+ sex_dum = 0
110
+
111
+ scores = np.array([age, MMSE, GDSCALE, CDR, FAQ, NPI, sex_dum, label_img])
112
+
113
+ print(scores)
114
+
115
+ # Normalizaci贸n de scores
116
+ scores_norm = scaler.transform(scores.reshape(1,-1))
117
+
118
+ print(scores_norm)
119
+
120
+ with tf.device("cpu:0"):
121
+ # Probabilidad de tener MCI
122
+ prob = prob_model.predict(scores_norm)[0,0]
123
+
124
+ # Probabilidad de tener MCI
125
+ print(prob)
126
+ diagnosis = f"The patient has a probability of {(100*prob):.2f}% of having MCI"
127
+
128
+ return gr.update(value=diagnosis)
129
 
130
+ def clear():
131
+ return gr.File.update(value=None), gr.Plot.update(value=None), gr.update(visible=False), gr.Plot.update(value=None), gr.update(visible=False), gr.update(value="The diagnosis will show here..."), gr.update(visible=False), gr.update(visible=False)
132
 
 
 
 
 
 
133
 
134
+ '''--------------------------------- Interfaz ---------------------------------'''
135
 
136
  with gr.Blocks(theme=gr.themes.Base()) as demo:
137
  with gr.Row():
 
156
  input_sex = gr.Dropdown(["Male", "Female"], label="Sex")
157
 
158
  with gr.Tab("Clinical data"):
159
+ input_MMSE = gr.Slider(minimum=0,
160
+ maximum=30,
161
+ value=0,
162
+ step=1,
163
+ label="MMSE total score")
164
+
165
+ input_GDSCALE = gr.Slider(minimum=0,
166
+ maximum=12,
167
+ value=0,
168
+ step=1,
169
+ label="GDSCALE total score")
170
+
171
+ input_CDR = gr.Slider(minimum=0,
172
+ maximum=3,
173
+ value=0,
174
+ step=0.5,
175
+ label="Global CDR")
176
+
177
+ input_FAQ = gr.Slider(minimum=0,
178
+ maximum=30,
179
+ value=0,
180
+ step=1,
181
+ label="FAQ total score")
182
+
183
+ input_NPI_Q = gr.Slider(minimum=0,
184
+ maximum=30,
185
+ value=0,
186
+ step=1,
187
+ label="NPI-Q total score")
188
+
189
 
190
  with gr.Tab("Vital Signs"):
191
  input_Diastolic_blood_pressure = gr.Number(label='Diastolic Blood Pressure(mm Hg)')
 
211
  clear_button = gr.Button(value="Clear")
212
 
213
  # Bot贸n para procesar imagen
214
+ process_button = gr.Button(value="Process MRI", visible=False, variant="primary")
215
+
216
+ # Bot贸n para obtener diagnostico
217
+ diagnostic_button = gr.Button(value="Get diagnosis", visible=False, variant="primary")
218
 
219
  # Outputs
220
  with gr.Column(variant="panel", scale=1):
221
  gr.Markdown('<h2 style="text-align: center; color:#235784;">MRI visualization</h2>')
222
 
223
+ with gr.Box():
224
+ gr.Markdown('<h4 style="color:#235784;">Loaded MRI</h4>')
225
+ # Plot para im谩gen original
226
+ plot_img_original = gr.Plot(show_label=False)
227
+
228
+ # Slider para im谩gen original
229
+ mri_slider = gr.Slider(minimum=0,
230
+ maximum=192,
231
+ value=100,
232
+ step=1,
233
+ label="MRI Slice",
234
+ visible=False)
235
 
236
+ with gr.Box():
237
+ gr.Markdown('<h4 style="color:#235784;">Proccessed MRI</h4>')
238
+
239
+ # Plot para im谩gen procesada
240
+ plot_brain = gr.Plot(show_label=False, visible=True)
 
 
 
 
 
241
 
242
+ # Slider para im谩gen procesada
243
+ brain_slider = gr.Slider(minimum=0,
244
+ maximum=192,
245
+ value=100,
246
+ step=1,
247
+ label="MRI Slice",
248
+ visible=False)
249
+
250
+ with gr.Box():
251
+ gr.Markdown('<h2 style="text-align: center; color:#235784;">Diagnosis</h2>')
252
+
253
+ # Texto del diagnostico
254
+ diagnosis_text = gr.Textbox(label="Diagnosis",interactive=False, placeholder="The diagnosis will show here...")
255
 
256
  # componentes =
257
 
 
260
  original_input_img = gr.State()
261
  brain_img = gr.State()
262
 
263
+
264
  update_true = gr.State(True)
265
  update_false = gr.State(False)
266
 
 
275
  [original_input_img, mri_slider, update_true],
276
  [plot_img_original, mri_slider, process_button])
277
 
 
 
 
 
278
  # Actualizar imagen original
279
  mri_slider.change(show_img,
280
  [original_input_img, mri_slider, update_false],
 
283
  # Procesar imagen
284
  process_button.click(fn=process_img,
285
  inputs=[original_input_sitk, brain_slider],
286
+ outputs=[brain_img,plot_brain,brain_slider, diagnostic_button])
287
 
288
  # Actualizar imagen procesada
289
  brain_slider.change(show_img,
290
  [brain_img, brain_slider, update_false],
291
  [plot_brain])
292
+
293
+ # Actualizar diagnostico
294
+ diagnostic_button.click(fn=get_diagnosis,
295
+ inputs=[brain_img, input_age, input_MMSE, input_GDSCALE, input_CDR, input_FAQ, input_NPI_Q, input_sex],
296
+ outputs=[diagnosis_text])
297
+
298
+ # Limpiar campos
299
+ clear_button.click(fn=clear,
300
+ outputs=[input_file, plot_img_original, mri_slider, plot_brain, brain_slider, diagnosis_text, process_button, diagnostic_button])
301
+
302
 
303
 
304
  if __name__ == "__main__":
305
+ # demo.queue(concurrency_count=20)
306
  demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils.py CHANGED
@@ -27,7 +27,6 @@ def import_3d_unet(path_3d_unet):
27
  return (2. * K.sum(flat_y_true * flat_y_pred) + smoothing_factor) / (K.sum(flat_y_true) + K.sum(flat_y_pred) + smoothing_factor)
28
 
29
  # Cargar modelo preentrenado
30
- # with tf.device('/cpu:0'):
31
  model = load_model(path_3d_unet, custom_objects={'dice_coefficient':dice_coefficient, 'iou_score':sm.metrics.IOUScore(threshold=0.5)})
32
  return model
33
 
@@ -179,4 +178,14 @@ def get_features(brain, mednet_model):
179
  features = features.cpu().numpy()
180
 
181
  torch.cuda.empty_cache()
182
- return features
 
 
 
 
 
 
 
 
 
 
 
27
  return (2. * K.sum(flat_y_true * flat_y_pred) + smoothing_factor) / (K.sum(flat_y_true) + K.sum(flat_y_pred) + smoothing_factor)
28
 
29
  # Cargar modelo preentrenado
 
30
  model = load_model(path_3d_unet, custom_objects={'dice_coefficient':dice_coefficient, 'iou_score':sm.metrics.IOUScore(threshold=0.5)})
31
  return model
32
 
 
178
  features = features.cpu().numpy()
179
 
180
  torch.cuda.empty_cache()
181
+ return features
182
+
183
+ # Classify image
184
+ def get_prediction(features, scores, svm_model, dl_model):
185
+ prediction = svm_model.predict(features)
186
+
187
+ # x = np.concatenate((scores, prediction))
188
+
189
+ # prob = dl_model.predict(x)
190
+
191
+ return prediction