Spaces:
Paused
Paused
# import os | |
import torch | |
import resnet | |
import numpy as np | |
import tensorflow as tf | |
# import nibabel as nib | |
import SimpleITK as sitk | |
import segmentation_models_3D as sm | |
from torch import nn | |
# from ttictoc import tic,toc | |
from skimage import morphology | |
from keras import backend as K | |
from scipy import ndimage as ndi | |
from keras.models import load_model | |
from patchify import patchify, unpatchify | |
# from matplotlib import pyplot as plt | |
# from matplotlib.widgets import Slider | |
# Función que retorna modelo 3D U-Net para extracción de cerebro
def import_3d_unet(path_3d_unet):
    """Load the pretrained 3D U-Net used for brain extraction.

    Args:
        path_3d_unet: path to the saved Keras model.

    Returns:
        The Keras model returned by ``load_model``. Two custom metric objects
        (``'dice_coefficient'`` and ``'iou_score'``) are supplied so that Keras
        can resolve those names while deserializing the model.
    """
    def dice_coefficient(y_true, y_pred):
        """Smoothed Dice coefficient over the flattened tensors (smoothing = 1)."""
        smooth = 1
        truth = K.flatten(y_true)
        guess = K.flatten(y_pred)
        numerator = 2. * K.sum(truth * guess) + smooth
        denominator = K.sum(truth) + K.sum(guess) + smooth
        return numerator / denominator

    custom_objects = {
        'dice_coefficient': dice_coefficient,
        'iou_score': sm.metrics.IOUScore(threshold=0.5),
    }
    return load_model(path_3d_unet, custom_objects=custom_objects)
# Función que carga una imagen en formato NIfTI y retorna la imagen SimpleITK junto con su arreglo numpy (float32); el filtro N4 y la normalización se aplican en brain_stripping
def load_img(path):
    """Read an MRI volume (a T1 NIfTI file in this project) with SimpleITK.

    Args:
        path: file path understood by ``sitk.ReadImage``.

    Returns:
        Tuple ``(image, array)``:
        - ``image`` is the ``sitk.Image`` read with pixel type ``sitkFloat32``.
        - ``array`` is a float32 numpy copy of its voxels, in the axis order
          used by ``sitk.GetArrayFromImage``.
    """
    image = sitk.ReadImage(path, sitk.sitkFloat32)
    voxels = sitk.GetArrayFromImage(image)
    return image, voxels.astype(np.float32)
# Función que remueve el cráneo de la imagen MRI (extracción de cerebro) usando la U-Net 3D
def brain_stripping(inputImage, model_unet):
    """Skull-strip an MRI volume with a patch-based 3D U-Net.

    Pipeline:
    1. N4 bias-field correction, then min-max normalization to [0, 1].
    2. Transpose, zero-pad axis 0 to 192 slices and rotate 90 degrees in the
       (1, 2) plane.
    3. Predict a binary brain mask on non-overlapping 64^3 patches.
    4. Clean the mask: small closing, keep the largest connected component,
       then a large closing.
    5. Multiply the volume by the mask.

    Args:
        inputImage: ``sitk.Image`` (e.g. the first value returned by ``load_img``).
        model_unet: Keras model (see ``import_3d_unet``). Its ``predict`` receives
            a batch of (64, 64, 64) patches, and channel 0 of its output is
            thresholded at 0.5.

    Returns:
        float32 numpy array of shape (192, 256, 256), in the transposed/padded/
        rotated orientation, with voxels outside the brain mask set to 0.

    Raises:
        ValueError: if the transposed volume is not (<= 192, 256, 256), or if
            the U-Net predicts no brain voxels at all.
    """
    mri_image = _prepare_volume(_n4_normalize(inputImage))
    brain_mask = _clean_brain_mask(_predict_brain_mask(mri_image, model_unet))
    # float32 volume * uint8 mask -> float32 skull-stripped volume.
    return np.multiply(mri_image, brain_mask)


def _n4_normalize(inputImage):
    """Apply N4 bias-field correction and min-max scale the voxels to [0, 1] (float32)."""
    # Binary mask for N4: OtsuThreshold(image, insideValue=0, outsideValue=1, bins=200).
    mask_image = sitk.OtsuThreshold(inputImage, 0, 1, 200)
    corrector = sitk.N4BiasFieldCorrectionImageFilter()
    # Execute() is needed to estimate the bias field. The corrected image is
    # rebuilt below by dividing the input by exp(log bias field).
    corrector.Execute(inputImage, mask_image)
    log_bias_field = corrector.GetLogBiasFieldAsImage(inputImage)
    corrected = sitk.GetArrayFromImage(inputImage / sitk.Exp(log_bias_field))
    low = np.min(corrected)
    span = np.max(corrected) - low
    if span == 0:
        # Constant volume: avoid 0/0 (which would turn every voxel into NaN).
        return np.zeros_like(corrected, dtype=np.float32)
    return ((corrected - low) / span).astype(np.float32)


def _prepare_volume(volume, target_shape=(192, 256, 256)):
    """Transpose ``volume``, zero-pad axis 0 to ``target_shape[0]`` and rotate it in the (1, 2) plane."""
    transposed = np.transpose(volume)
    if (transposed.ndim != 3
            or transposed.shape[1:] != tuple(target_shape[1:])
            or transposed.shape[0] > target_shape[0]):
        raise ValueError(
            f"expected a transposed volume of shape (<= {target_shape[0]}, "
            f"{target_shape[1]}, {target_shape[2]}), got {transposed.shape}")
    padding = np.zeros(
        (target_shape[0] - transposed.shape[0],) + transposed.shape[1:],
        dtype=np.float32)
    padded = np.concatenate((transposed, padding), axis=0).astype(np.float32)
    return np.rot90(padded, axes=(1, 2))


def _predict_brain_mask(mri_image, model_unet, patch_size=64, threshold=0.5,
                        batch_size=1):
    """Predict a uint8 {0, 1} brain mask for ``mri_image`` from non-overlapping cubic patches."""
    patches = patchify(mri_image, (patch_size,) * 3, step=patch_size)
    grid_shape = patches.shape
    # Flatten the patch grid in C order (same order as nested i, j, k loops)
    # and run a single predict() call instead of one call per patch.
    # batch_size=1 keeps the per-step GPU memory of per-patch prediction;
    # raise it to trade memory for speed.
    batch = patches.reshape((-1,) + grid_shape[3:])
    prediction = model_unet.predict(batch, batch_size=batch_size)
    # Threshold channel 0 of the output (presumably the brain probability).
    mask_patches = (prediction[:, :, :, :, 0] > threshold).astype(np.uint8)
    return unpatchify(mask_patches.reshape(grid_shape), mri_image.shape)


def _clean_brain_mask(mask):
    """Smooth ``mask``, keep only its largest connected component and fill holes."""
    # Small closing (ball radius 2) smooths the patch-wise prediction.
    smoothed = ndi.binary_closing(
        mask, structure=morphology.ball(2)).astype(np.uint8)
    labels = morphology.label(smoothed, background=0, connectivity=3)
    # Voxel count per label, indexed by label value; index 0 is the background.
    sizes = np.bincount(labels.ravel())
    if sizes.size < 2:
        raise ValueError("the U-Net did not predict any brain voxels")
    brain_label = np.argmax(sizes[1:]) + 1
    largest = (labels == brain_label).astype(np.uint8)
    # Large closing (ball radius 12) fills holes and clefts in the brain mask.
    return ndi.binary_closing(
        largest, structure=morphology.ball(12)).astype(np.uint8)
# Función que retorna modelo MedNet
def create_mednet(weight_path, device_ids):
    """Build the frozen MedNet (ResNet-50) feature extractor on GPU.

    The ``conv_seg`` head of ``resnet.resnet50`` is replaced by a global
    ``AdaptiveMaxPool3d`` + ``Flatten``, so the network outputs one flat
    feature vector per input volume. Matching pretrained weights are copied
    into the backbone. All backbone parameters are then frozen, and the model
    is put in eval mode.

    Args:
        weight_path: checkpoint file loaded with ``torch.load``. It must contain
            a ``'state_dict'`` whose keys use the ``"module."`` prefix.
        device_ids: GPU ids passed to ``torch.nn.DataParallel``. The model is
            moved to ``cuda:<first id>``.

    Returns:
        The ``torch.nn.DataParallel``-wrapped model, in eval mode.
    """
    class simci_net(nn.Module):
        """ResNet-50 backbone whose segmentation head is swapped for global max pooling."""

        def __init__(self):
            super().__init__()
            self.pretrained_model = resnet.resnet50(
                sample_input_D=192, sample_input_H=256, sample_input_W=256,
                num_seg_classes=2, no_cuda=False)
            self.pretrained_model.conv_seg = nn.Sequential(
                nn.AdaptiveMaxPool3d(output_size=(1, 1, 1)),
                nn.Flatten(start_dim=1))

        def forward(self, x):
            return self.pretrained_model(x)

    # Wrap the model for (possibly) several GPUs; it lives on the first one.
    simci_model = torch.nn.DataParallel(simci_net(), device_ids=device_ids)
    device = torch.device(f'cuda:{simci_model.device_ids[0]}')
    simci_model.to(device)

    net_dict = simci_model.state_dict()
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(weight_path, map_location=device)
    # Transfer learning: checkpoint keys "module.<name>" map to
    # "module.pretrained_model.<name>" inside this wrapper. Keys with no
    # counterpart in this model are skipped.
    pretrain_dict = {}
    for key, value in checkpoint['state_dict'].items():
        new_key = key.replace("module.", "module.pretrained_model.")
        if new_key in net_dict:
            pretrain_dict[new_key] = value
    if not pretrain_dict:
        import warnings
        warnings.warn(
            f"no weights from {weight_path!r} matched the MedNet model; "
            "the backbone keeps its random initialization",
            stacklevel=2)
    net_dict.update(pretrain_dict)
    simci_model.load_state_dict(net_dict)

    # Freeze the backbone: the model is used only as a feature extractor.
    for param in simci_model.module.pretrained_model.parameters():
        param.requires_grad = False
    simci_model.eval()
    return simci_model
# Función que extrae características del cerebro con MedNet
def get_features(brain, mednet_model):
    """Run ``mednet_model`` on a single brain volume and return its output as numpy.

    Args:
        brain: 3-D volume (numpy array or array-like), e.g. the output of
            ``brain_stripping``. It is converted to a C-contiguous float32 array.
        mednet_model: ``torch.nn.DataParallel`` model (see ``create_mednet``).
            The input is sent to ``cuda:<device_ids[0]>``.

    Returns:
        numpy array holding the model output for a batch of one volume.
        For the ``create_mednet`` model this is presumably (1, n_features).
    """
    device = f'cuda:{mednet_model.device_ids[0]}'
    # torch.from_numpy rejects negative-stride views (e.g. from np.rot90), and
    # PyTorch parameters default to float32. No copy is made when ``brain`` is
    # already a contiguous float32 array.
    volume = np.ascontiguousarray(brain, dtype=np.float32)
    # Add batch and channel axes: (D, H, W) -> (1, 1, D, H, W).
    volume = volume.reshape((1, 1) + volume.shape)
    with torch.no_grad():
        data = torch.from_numpy(volume).to(device)
        features = mednet_model(data).cpu().numpy()
        # Drop the GPU input before emptying the cache. While ``data`` is still
        # referenced, its memory cannot be released.
        del data
    torch.cuda.empty_cache()
    return features
# Classify image | |
def get_prediction(features, scores, svm_model, dl_model):
    """Classify ``features`` with the SVM model.

    Args:
        features: input accepted by ``svm_model.predict``.
        scores: currently unused; kept for interface compatibility.
        svm_model: any object exposing ``predict``.
        dl_model: currently unused; kept for interface compatibility.

    Returns:
        Exactly what ``svm_model.predict(features)`` returns.
    """
    # NOTE: a disabled second stage concatenated ``scores`` with the SVM
    # prediction and passed the result to ``dl_model.predict``. Until it is
    # implemented, ``scores`` and ``dl_model`` are ignored.
    return svm_model.predict(features)