# SIMCI / utils.py
# Last commit ded489b by SebastianBravo: "Added new features"
# import os
import torch
import resnet
import numpy as np
import tensorflow as tf
# import nibabel as nib
import SimpleITK as sitk
import segmentation_models_3D as sm
from torch import nn
# from ttictoc import tic,toc
from skimage import morphology
from keras import backend as K
from scipy import ndimage as ndi
from keras.models import load_model
from patchify import patchify, unpatchify
# from matplotlib import pyplot as plt
# from matplotlib.widgets import Slider
def import_3d_unet(path_3d_unet):
    """Load the pre-trained 3D U-Net used for brain extraction.

    The saved Keras model references two custom metrics that ``load_model``
    must be given in order to deserialize it: a smoothed Dice coefficient and
    the ``segmentation_models_3D`` IoU score with a 0.5 threshold.

    Args:
        path_3d_unet: Path to the saved Keras model.

    Returns:
        The loaded Keras model.
    """
    def dice_coefficient(y_true, y_pred):
        # Smoothed Dice: (2 * sum(t * p) + 1) / (sum(t) + sum(p) + 1).
        # The +1 keeps the ratio defined (== 1) when both inputs are all zero.
        smooth = 1
        truth = K.flatten(y_true)
        guess = K.flatten(y_pred)
        overlap = K.sum(truth * guess)
        return (2. * overlap + smooth) / (K.sum(truth) + K.sum(guess) + smooth)

    custom = {
        'dice_coefficient': dice_coefficient,
        'iou_score': sm.metrics.IOUScore(threshold=0.5),
    }
    return load_model(path_3d_unet, custom_objects=custom)
def load_img(path):
    """Read a T1 MRI volume (e.g. NIfTI) with 32-bit float pixels.

    Only reading is done here; N4 bias-field correction and intensity
    normalization are performed later, in ``brain_stripping``.

    Args:
        path: Path to an image file readable by ``SimpleITK.ReadImage``.

    Returns:
        Tuple ``(image, voxels)``: the SimpleITK image (pixel type
        ``sitkFloat32``) and a float32 NumPy copy of its voxel data.
    """
    image = sitk.ReadImage(path, sitk.sitkFloat32)
    voxels = sitk.GetArrayFromImage(image).astype(np.float32)
    return image, voxels
def _n4_normalize(input_image):
    """Apply N4 bias-field correction, then min-max scale voxels to [0, 1] (float32)."""
    # Otsu mask handed to N4 (positional args: inside value 0, outside value 1,
    # 200 histogram bins).
    mask_image = sitk.OtsuThreshold(input_image, 0, 1, 200)
    corrector = sitk.N4BiasFieldCorrectionImageFilter()
    # Execute() estimates the bias field; its corrected output is not needed
    # because the estimated log bias field is divided out of the input below.
    corrector.Execute(input_image, mask_image)
    log_bias_field = corrector.GetLogBiasFieldAsImage(input_image)
    corrected = input_image / sitk.Exp(log_bias_field)

    voxels = sitk.GetArrayFromImage(corrected)
    lo, hi = np.min(voxels), np.max(voxels)
    if hi > lo:
        voxels = (voxels - lo) / (hi - lo)
    else:
        # Constant image: the min-max ratio would be 0/0 (NaN everywhere).
        voxels = np.zeros_like(voxels)
    return voxels.astype(np.float32)


def _pad_and_orient(volume, target_shape=(192, 256, 256)):
    """Reverse axis order, zero-pad each axis at its end up to ``target_shape``
    and rotate 90 degrees in the plane of axes 1 and 2 (float32 result)."""
    oriented = np.transpose(volume)
    if oriented.ndim != len(target_shape) or any(
            size > limit for size, limit in zip(oriented.shape, target_shape)):
        raise ValueError(
            f"volume of shape {oriented.shape} (after axis reversal) does not "
            f"fit the {target_shape} network grid")
    pad_width = [(0, limit - size) for size, limit in zip(oriented.shape, target_shape)]
    padded = np.pad(oriented, pad_width, mode='constant').astype(np.float32)
    return np.rot90(padded, axes=(1, 2))


def _predict_mask_patches(mri_image, model_unet, patch_size=64, threshold=0.5):
    """Predict a binary mask on non-overlapping cubic patches and stitch it back."""
    patches = patchify(mri_image, (patch_size,) * 3, step=patch_size)
    predictions = []
    # np.ndindex walks the patch grid in C (row-major) order, so the flat list
    # reshapes straight back onto the grid below.
    for grid_index in np.ndindex(*patches.shape[:3]):
        single_patch = np.expand_dims(patches[grid_index], axis=0)
        probabilities = model_unet.predict(single_patch)
        predictions.append((probabilities[0, :, :, :, 0] > threshold).astype(np.uint8))
    mask_patches = np.reshape(np.array(predictions), patches.shape)
    return unpatchify(mask_patches, mri_image.shape)


def _clean_brain_mask(mask):
    """Smooth the mask, keep only its largest connected component, fill gaps."""
    smoothed = ndi.binary_closing(mask, structure=morphology.ball(2)).astype(np.uint8)
    labels = morphology.label(smoothed, background=0, connectivity=3)
    # Voxel count per label with background zeroed, so argmax picks the largest
    # foreground component. With no foreground at all it picks label 0, which
    # leaves the (already empty) mask unchanged instead of crashing.
    sizes = np.bincount(labels.ravel())
    sizes[0] = 0
    largest = np.argmax(sizes)
    largest_only = smoothed.copy()
    largest_only[labels != largest] = 0
    # Large closing to remove holes and clefts left in the brain mask.
    return ndi.binary_closing(largest_only, structure=morphology.ball(12)).astype(np.uint8)


def brain_stripping(inputImage, model_unet):
    """Remove non-brain tissue (skull stripping) from an MRI volume.

    Pipeline:
      1. N4 bias-field correction and min-max normalization to [0, 1].
      2. Axis reversal, zero-padding to a 192x256x256 grid and a 90-degree
         rotation in the plane of axes 1 and 2.
      3. Brain-mask prediction with the 3D U-Net on non-overlapping 64^3
         patches (probability threshold 0.5), stitched back together.
      4. Mask clean-up: closing with a radius-2 ball, keep the largest
         connected component (connectivity=3), closing with a radius-12 ball.
      5. Multiply the preprocessed volume by the cleaned mask.

    Args:
        inputImage: SimpleITK image, e.g. the first element returned by
            ``load_img`` (float32 pixels).
        model_unet: Keras model (e.g. from ``import_3d_unet``); its ``predict``
            is called on arrays of shape (1, 64, 64, 64) and the channel 0 of
            its 5-D output is thresholded.

    Returns:
        float32 NumPy array of shape (192, 256, 256): the preprocessed volume
        with voxels outside the predicted brain set to 0.

    Raises:
        ValueError: if the axis-reversed volume is not 3-D or is larger than
            192x256x256 along any axis.
    """
    normalized = _n4_normalize(inputImage)
    mri_image = _pad_and_orient(normalized)
    raw_mask = _predict_mask_patches(mri_image, model_unet)
    brain_mask = _clean_brain_mask(raw_mask)
    return np.multiply(mri_image, brain_mask)
def create_mednet(weight_path, device_ids):
    """Build the frozen MedicalNet ResNet-50 feature extractor.

    A ResNet-50 from the project ``resnet`` module (configured for
    192x256x256 inputs) gets its ``conv_seg`` head replaced by a global 3-D
    max pool plus flatten, so every input volume yields one flat feature
    vector. The network is wrapped in ``torch.nn.DataParallel`` and moved to
    the first listed GPU. It is then initialised from a pretrained checkpoint,
    frozen and switched to eval mode.

    Args:
        weight_path: Path to a checkpoint dict whose ``'state_dict'`` entry
            holds the pretrained weights (keys presumably prefixed
            ``"module."``, as saved from a DataParallel model).
        device_ids: GPU indices forwarded to ``DataParallel``. The first one
            hosts the parameters and is the ``map_location`` for the checkpoint.

    Returns:
        The ``DataParallel``-wrapped model in eval mode, with every backbone
        parameter set to ``requires_grad = False``.
    """
    class simci_net(nn.Module):
        # ResNet-50 backbone whose segmentation head is swapped for
        # AdaptiveMaxPool3d(1) + Flatten, turning it into a feature extractor.
        def __init__(self):
            super().__init__()
            self.pretrained_model = resnet.resnet50(
                sample_input_D=192,
                sample_input_H=256,
                sample_input_W=256,
                num_seg_classes=2,
                no_cuda=False,
            )
            self.pretrained_model.conv_seg = nn.Sequential(
                nn.AdaptiveMaxPool3d(output_size=(1, 1, 1)),
                nn.Flatten(start_dim=1),
            )

        def forward(self, x):
            return self.pretrained_model(x)

    model = torch.nn.DataParallel(simci_net(), device_ids=device_ids)
    primary_device = f'cuda:{model.device_ids[0]}'
    model.to(primary_device)

    own_state = model.state_dict()
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(weight_path, map_location=torch.device(primary_device))

    # Inside the wrapper the backbone sits under "module.pretrained_model.",
    # so re-prefix checkpoint keys and keep only those this model actually has
    # (e.g. weights of the replaced conv_seg head are dropped).
    renamed = {
        key.replace("module.", "module.pretrained_model."): tensor
        for key, tensor in checkpoint['state_dict'].items()
    }
    own_state.update({key: tensor for key, tensor in renamed.items() if key in own_state})
    model.load_state_dict(own_state)

    # Transfer learning: the pretrained backbone is not trained further.
    for weight in model.module.pretrained_model.parameters():
        weight.requires_grad = False
    model.eval()
    return model
def get_features(brain, mednet_model):
    """Extract the MedNet feature vector of a skull-stripped brain volume.

    Args:
        brain: 3-D array-like volume (e.g. the output of ``brain_stripping``).
            It is converted to a C-contiguous float32 array and wrapped as a
            tensor of shape ``(1, 1, *brain.shape)``.
        mednet_model: Model such as the one returned by ``create_mednet``.
            When it exposes a non-empty ``device_ids`` (``DataParallel``), the
            input goes to ``cuda:<device_ids[0]>``. Otherwise it goes to the
            device of the model's first parameter, or to the CPU if the model
            has no parameters.

    Returns:
        NumPy array with the model output for the single input volume.
    """
    # torch.from_numpy rejects negative strides (views made by np.rot90 or
    # np.flip), and a float64 tensor would not match the default float32
    # weights. For an already contiguous float32 array this makes no copy.
    volume = np.ascontiguousarray(brain, dtype=np.float32)

    device_ids = getattr(mednet_model, 'device_ids', None)
    if device_ids:
        device = f'cuda:{device_ids[0]}'
    else:
        first_param = next(iter(mednet_model.parameters()), None)
        device = first_param.device if first_param is not None else 'cpu'

    with torch.no_grad():
        data = torch.from_numpy(volume[np.newaxis, np.newaxis]).to(device)
        features = mednet_model(data).cpu().numpy()
        # Drop the device-side input so empty_cache() can actually release it.
        del data
    torch.cuda.empty_cache()
    return features
def get_prediction(features, scores, svm_model, dl_model):
    """Classify extracted brain features with the SVM.

    Args:
        features: Input accepted by ``svm_model.predict`` (e.g. the array
            returned by ``get_features``).
        scores: Currently unused.
        dl_model: Currently unused.

    Returns:
        Exactly what ``svm_model.predict(features)`` returns.
    """
    # NOTE(review): `scores` and `dl_model` are accepted but ignored. They are
    # reserved for a not-yet-implemented second stage that would concatenate
    # `scores` with the SVM prediction and feed it to `dl_model`. They are kept
    # in the signature so existing callers keep working.
    return svm_model.predict(features)