import os
import utils
import numpy as np
import gradio as gr
import tensorflow as tf
import matplotlib.pyplot as plt
from ttictoc import tic,toc
from urllib.request import urlretrieve
# '''--------------------------- Preprocesamiento ----------------------------'''
# tic()
# 3D U-Net
def _download_if_missing(url, path):
    """Download ``url`` to ``path`` unless ``path`` already exists.

    The payload is first written to ``path + ".part"`` and renamed to
    ``path`` only after ``urlretrieve`` returns. An interrupted download
    therefore cannot leave a truncated file behind that the existence check
    would accept as complete on the next start.
    """
    if os.path.exists(path):
        return
    partial_path = path + ".part"
    try:
        urlretrieve(url, partial_path)
        # os.replace is an atomic rename on POSIX and overwrites any target.
        os.replace(partial_path, path)
    finally:
        # The partial file is still present only if the download or rename failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)


# 3D U-Net weights used for brain stripping.
path_3d_unet = 'unet.h5'
_download_if_missing("https://dl.dropboxusercontent.com/s/ay5q8caqzlad7h5/unet.h5?dl=0", path_3d_unet)
# ResNet-50 weights consumed by utils.create_mednet in the (currently
# commented-out) feature-extraction code.
_download_if_missing("https://dl.dropboxusercontent.com/s/otxsgx3e31d5h9i/resnet_50_23dataset.pth?dl=0", "resnet_50_23dataset.pth")
# Load the U-Net with TensorFlow pinned to the CPU.
with tf.device("cpu:0"):
    model_unet = utils.import_3d_unet(path_3d_unet)
# # Cargar imagen
# img = utils.load_img('F:/Downloads/ADNI_002_S_0295_MR_MP-RAGE__br_raw_20070525135721811_1_S32678_I55275.nii')
# # Extraer cerebro
# with tf.device("cpu:0"):
# brain = utils.brain_stripping(img, model_unet)
# print(toc())
# '''---------------------------- Procesamiento ------------------------------'''
# # Med net
# weight_path = 'resnet_50_23dataset.pth'
# device_ids = [0]
# mednet = utils.create_mednet(weight_path, device_ids)
# # Extraer características
# features = utils.get_features(brain, mednet)
def _prepare_volume(array, depth=192):
    """Reorient a loaded MRI array and zero-pad it to ``depth`` slices.

    Steps, in order:
    1. Reverse the axis order (``np.transpose`` with no ``axes``).
    2. Append all-zero slices at the end of axis 0 until it has ``depth``
       entries.
    3. Cast to float32.
    4. Rotate every slice by 90 degrees in the plane of axes 1-2.

    Presumably the 3D U-Net expects 192-slice inputs; TODO confirm against
    utils.brain_stripping.

    Raises:
        ValueError: if the volume already has more than ``depth`` slices
            along its (reversed) first axis.
    """
    volume = np.transpose(array)
    missing = depth - volume.shape[0]
    if missing < 0:
        raise ValueError(
            f"MRI volume has {volume.shape[0]} slices along the first axis; "
            f"at most {depth} are supported"
        )
    # Zero-pad only the end of axis 0; the other axes keep their size.
    pad_width = [(0, missing)] + [(0, 0)] * (volume.ndim - 1)
    volume = np.pad(volume, pad_width).astype(np.float32)
    return np.rot90(volume, axes=(1, 2))


def load_img(file):
    """Load an uploaded MRI file and prepare it for display.

    Args:
        file: value of the ``gr.File`` component. This is a tempfile-like
            object with a ``.name`` path, a plain path string, or ``None``
            when the upload was cleared.

    Returns:
        ``(sitk, mri_image)``: the first value returned by ``utils.load_img``
        and the reoriented, zero-padded float32 array. Returns
        ``(None, None)`` when ``file`` is ``None``.
    """
    if file is None:
        # Clearing the File widget fires its change event with None.
        return None, None
    path = getattr(file, "name", file)
    sitk, array = utils.load_img(path)
    return sitk, _prepare_volume(array)
def show_img(img, mri_slice, update):
    """Render one slice of a 3D volume as a grayscale matplotlib figure.

    Args:
        img: volume indexed as ``img[slice, :, :]``, or ``None`` when nothing
            has been loaded yet.
        mri_slice: index along axis 0, usually from a Gradio slider. It is
            cast to ``int`` and clamped to the valid range.
        update: when truthy, two ``gr.update(visible=True)`` values are also
            returned so the caller can reveal extra components.

    Returns:
        The figure, or ``None`` when ``img`` is ``None``. When ``update`` is
        truthy, the figure is followed by two Gradio updates.
    """
    if img is None:
        # Nothing to draw: empty plot, other components left unchanged.
        return (None, gr.update(), gr.update()) if update else None
    last = len(img) - 1
    index = min(max(int(mri_slice), 0), last)
    fig, ax = plt.subplots()
    ax.imshow(img[index, :, :], cmap='gray')
    # Drop pyplot's reference so figures created on every slider move do not
    # accumulate; Gradio still renders the returned Figure object.
    plt.close(fig)
    if update:
        return fig, gr.update(visible=True), gr.update(visible=True)
    return fig
# def show_brain(brain, brain_slice):
# fig = plt.figure()
# plt.imshow(brain[brain_slice,:,:], cmap='gray')
# return fig, gr.update(visible=True)
def process_img(img, brain_slice, progress=gr.Progress()):
    """Brain-strip the loaded MRI with the 3D U-Net and plot one slice.

    Args:
        img: first value returned by ``utils.load_img``, taken from the
            ``original_input_sitk`` state. ``None`` when no file is loaded.
        brain_slice: index of the slice to display.
        progress: progress tracker injected by Gradio.

    Returns:
        ``(brain, fig, slider_update)``:
        - ``brain``: the stripped volume, kept in state.
        - ``fig``: the rendered slice.
        - ``slider_update``: an update that makes the slider visible.

    Raises:
        gr.Error: if no MRI has been loaded.
    """
    if img is None:
        raise gr.Error("Load an MRI file before processing it.")
    # gr.Progress takes a completion fraction in [0, 1] (or a (done, total) tuple).
    progress(0, desc="Processing...")
    with tf.device("cpu:0"):
        brain = utils.brain_stripping(img, model_unet)
    fig, update_slider, _ = show_img(brain, brain_slice, update=True)
    return brain, fig, update_slider
def clear():
    """Empty the file picker and the original-MRI plot, and hide that plot's slider.

    Returns the three updates in the order of the Clear button's outputs:
    file widget, original plot, original slider.
    """
    file_reset = gr.File.update(value=None)
    plot_reset = gr.Plot.update(value=None)
    slider_hidden = gr.update(visible=False)
    return file_reset, plot_reset, slider_hidden
# gr.Textbox.update(placeholder='Ingrese nombre del paciente'), gr.Number.update(value=0),
# demo = gr.Interface(fn=load_img,
# inputs=gr.File(file_count="single", file_type=[".nii"]),
# outputs=gr.Plot()
# # outputs='text'
# )
# theme = gr.themes.Base().load('css_new.json')
def _reset_after_clear():
    """Return the values that wipe stored volumes and hide the controls that depend on them.

    The order matches the outputs of the second ``clear_button.click``
    listener:
    1. original SimpleITK state
    2. original array state
    3. brain state
    4. processed plot
    5. processed-slice slider
    6. process button
    """
    return None, None, None, None, gr.update(visible=False), gr.update(visible=False)


with gr.Blocks(theme=gr.themes.Base()) as demo:
    # ----- Header -----
    with gr.Row():
        # gr.HTML(r"""
        # """)
        # NOTE(review): the markup of this header appears to have been lost
        # from this copy of the file (only the delimiters survive); restore it.
        gr.HTML(r"""
""")
        # gr.Markdown("""
        # # SIMCI
        # Interfaz de SIMCI
        # """)
    with gr.Row():
        # ----- Inputs -----
        with gr.Column(variant="panel", scale=1):
            gr.Markdown('Patient Information\n')
            with gr.Tab("Personal data"):
                input_name = gr.Textbox(placeholder='Enter the patient name', label='Patient name')
                input_age = gr.Number(label='Age')
                # NOTE(review): gr.Number drops leading zeros and '+' prefixes;
                # a Textbox is probably a better fit for phone numbers.
                input_phone_num = gr.Number(label='Phone number')
                input_emer_name = gr.Textbox(placeholder='Enter the emergency contact name', label='Emergency contact name')
                input_emer_phone_num = gr.Number(label='Emergency contact phone number')
                input_sex = gr.Dropdown(["Male", "Female"], label="Sex")
            with gr.Tab("Clinical data"):
                input_MMSE = gr.Number(label='MMSE')
                input_GDSCALE = gr.Number(label='GDSCALE')
                input_CDR = gr.Number(label='Global CDR')
                input_FAQ = gr.Number(label='FAQ Total Score')
                input_NPI_Q = gr.Number(label='NPI-Q Total Score')
            with gr.Tab("Vital Signs"):
                input_Diastolic_blood_pressure = gr.Number(label='Diastolic Blood Pressure(mm Hg)')
                input_Systolic_blood_pressure = gr.Number(label='Systolic Blood Pressure(mm Hg)')
                input_Body_heigth = gr.Number(label='Body height (cm)')
                input_Body_weight = gr.Number(label='Body weight (kg)')
                input_Heart_rate = gr.Number(label='Heart rate (bpm)')
                input_Respiratory_rate = gr.Number(label='Respiratory rate (bpm)')
                input_Body_temperature = gr.Number(label='Body temperature (°C)')
                input_Pluse_oximetry = gr.Number(label='Pulse oximetry (%)')
            with gr.Tab("Medications"):
                input_medications = gr.Textbox(label='Medications', lines=5)
                input_allergies = gr.Textbox(label='Allergies', lines=5)
            # Upload widget for the NIfTI (.nii) MRI file
            input_file = gr.File(file_count="single", label="MRI Image File (.nii)")
            with gr.Row():
                # Displays the uploaded volume
                load_img_button = gr.Button(value="Load")
                # Clears the upload and the plots
                clear_button = gr.Button(value="Clear")
                # Runs brain stripping; revealed by the Load button
                process_button = gr.Button(value="Procesar", visible=False)
        # ----- Outputs -----
        with gr.Column(variant="panel", scale=1):
            gr.Markdown('MRI visualization\n')
            # Plot of the original MRI
            plot_img_original = gr.Plot(label="Imagen MRI original")
            # Slice selector for the original MRI. load_img pads every volume
            # to 192 slices along axis 0, so valid indices are 0-191.
            mri_slider = gr.Slider(minimum=0,
                                   maximum=191,
                                   value=100,
                                   step=1,
                                   label="MRI Slice",
                                   visible=False)
            # Plot of the brain-stripped MRI
            plot_brain = gr.Plot(label="Imagen MRI procesada", visible=True)
            # Slice selector for the processed MRI. This assumes
            # utils.brain_stripping keeps the 192-slice depth; TODO confirm.
            brain_slider = gr.Slider(minimum=0,
                                     maximum=191,
                                     value=100,
                                     step=1,
                                     label="MRI Slice",
                                     visible=False)
    # ----- Per-session state -----
    original_input_sitk = gr.State()
    original_input_img = gr.State()
    brain_img = gr.State()
    update_true = gr.State(True)
    update_false = gr.State(False)
    # ----- Event wiring -----
    # New upload (or cleared upload): load the volume into state
    input_file.change(load_img,
                      input_file,
                      [original_input_sitk, original_input_img])
    # Show the uploaded volume and reveal its slider and the Process button
    load_img_button.click(show_img,
                          [original_input_img, mri_slider, update_true],
                          [plot_img_original, mri_slider, process_button])
    # Clear the upload and the original-image plot
    clear_button.click(fn=clear,
                       outputs=[input_file, plot_img_original, mri_slider])
    # Also drop stored volumes and the processed result, so nothing stale
    # survives a Clear
    clear_button.click(fn=_reset_after_clear,
                       outputs=[original_input_sitk, original_input_img, brain_img,
                                plot_brain, brain_slider, process_button])
    # Redraw the original image when its slider moves
    mri_slider.change(show_img,
                      [original_input_img, mri_slider, update_false],
                      [plot_img_original])
    # Brain-strip the loaded volume
    process_button.click(fn=process_img,
                         inputs=[original_input_sitk, brain_slider],
                         outputs=[brain_img, plot_brain, brain_slider])
    # Redraw the processed image when its slider moves
    brain_slider.change(show_img,
                        [brain_img, brain_slider, update_false],
                        [plot_brain])
def _serve():
    """Enable Gradio's request queue (concurrency_count=20), then launch the app."""
    demo.queue(concurrency_count=20)
    demo.launch()


if __name__ == "__main__":
    _serve()
# # Visualización resultados
# mri_slice = 100
# # Plot Comparación máscaras
# fig, axs = plt.subplots(1,2)
# fig.subplots_adjust(bottom=0.15)
# fig.suptitle('Comparación Máscaras Obtenidas')
# axs[0].set_title('MRI original')
# axs[0].imshow(img[mri_slice,:,:],cmap='gray')
# axs[1].set_title('Cerebro extraido con 3D U-Net')
# axs[1].imshow(brain[mri_slice,:,:],cmap='gray')
# # Slider para cambiar slice
# ax_slider = plt.axes([0.15, 0.05, 0.75, 0.03])
# mri_slice_slider = Slider(ax_slider, 'Slice', 0, 192, 100, valstep=1)
# def update(val):
# mri_slice = mri_slice_slider.val
# axs[0].imshow(img[:,:,mri_slice],cmap='gray')
# axs[1].imshow(brain[mri_slice,:,:],cmap='gray')
# # Actualizar plot comparación máscaras
# mri_slice_slider.on_changed(update)