SebastianBravo commited on
Commit
ca41ad4
1 Parent(s): f9cd822

initial commit

Browse files
Files changed (3) hide show
  1. app.py +185 -0
  2. resnet.py +263 -0
  3. utils.py +182 -0
app.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import utils
2
+ import gradio as gr
3
+ import tensorflow as tf
4
+ import matplotlib.pyplot as plt
5
+ from ttictoc import tic,toc
6
+
7
+ # '''--------------------------- Preprocesamiento ----------------------------'''
8
+ # tic()
9
+ # 3D U-Net
10
+ path_3d_unet = 'F:/Desktop/Universidad/Semestres/NovenoSemestre/Proyecto_de_Grado/Codigo/3D_U-Net/outputs/checkpoints/model.49-0.97.h5'
11
+
12
+ with tf.device("cpu:0"):
13
+ model_unet = utils.import_3d_unet(path_3d_unet)
14
+
15
+ # # Cargar imagen
16
+ # img = utils.load_img('F:/Downloads/ADNI_002_S_0295_MR_MP-RAGE__br_raw_20070525135721811_1_S32678_I55275.nii')
17
+
18
+ # # Extraer cerebro
19
+ # with tf.device("cpu:0"):
20
+ # brain = utils.brain_stripping(img, model_unet)
21
+ # print(toc())
22
+
23
+ # '''---------------------------- Procesamiento ------------------------------'''
24
+ # # Med net
25
+ # weight_path = 'F:/Desktop/Universidad/Semestres/NovenoSemestre/Proyecto_de_Grado/Codigo/Procesamiento/mednet_weights/pretrain/resnet_50_23dataset.pth'
26
+ # device_ids = [0]
27
+ # mednet = utils.create_mednet(weight_path, device_ids)
28
+
29
+ # # Extraer caracter铆sticas
30
+ # features = utils.get_features(brain, mednet)
31
+
32
+ def load_img(file):
33
+ sitk, array = utils.load_img(file.name)
34
+ return sitk, array
35
+
36
+ def show_img(img, mri_slice):
37
+ fig = plt.figure()
38
+ plt.imshow(img[:,:,mri_slice], cmap='gray')
39
+
40
+ return fig, gr.update(visible=True)
41
+
42
+ def show_brain(brain, brain_slice):
43
+ fig = plt.figure()
44
+ plt.imshow(brain[brain_slice,:,:], cmap='gray')
45
+
46
+ return fig, gr.update(visible=True)
47
+
48
+ def process_img(img, brain_slice):
49
+ with tf.device("cpu:0"):
50
+ brain = utils.brain_stripping(img, model_unet)
51
+
52
+ fig, update = show_brain(brain, brain_slice)
53
+
54
+ return brain, fig, update
55
+
56
+ def clear():
57
+ return gr.File.update(value=None), gr.Plot.update(value=None), gr.update(visible=False)
58
+
59
+ # gr.Textbox.update(placeholder='Ingrese nombre del paciente'), gr.Number.update(value=0),
60
+
61
+ # demo = gr.Interface(fn=load_img,
62
+ # inputs=gr.File(file_count="single", file_type=[".nii"]),
63
+ # outputs=gr.Plot()
64
+ # # outputs='text'
65
+ # )
66
+
67
+ with gr.Blocks() as demo:
68
+ with gr.Row():
69
+ gr.Markdown("""
70
+ # SIMCI
71
+ Interfaz de SIMCI
72
+ """)
73
+
74
+ # Inputs
75
+ with gr.Row():
76
+ with gr.Column(scale=1):
77
+ # Objeto para subir archivo nifti
78
+ input_name = gr.Textbox(placeholder='Ingrese nombre del paciente', label='Nombre')
79
+ input_age = gr.Number(label='Edad')
80
+
81
+
82
+ input_file = gr.File(file_count="single", file_type=[".nii"], label="Archivo Imagen MRI")
83
+
84
+ with gr.Row():
85
+ # Bot贸n para cargar imagen
86
+ load_img_button = gr.Button(value="Load")
87
+
88
+ # Bot贸n para borrar
89
+ clear_button = gr.Button(value="Clear")
90
+
91
+ # Bot贸n para procesar imagen
92
+ process_button = gr.Button(value="Procesar")
93
+
94
+ # Outputs
95
+ with gr.Column(scale=1):
96
+ # Plot para im谩gen original
97
+ plot_img_original = gr.Plot(label="Imagen MRI original")
98
+
99
+ # Slider para im谩gen original
100
+ mri_slider = gr.Slider(minimum=0,
101
+ maximum=166,
102
+ value=100,
103
+ step=1,
104
+ label="MRI Slice",
105
+ visible=False)
106
+
107
+ # Plot para im谩gen procesada
108
+ plot_brain = gr.Plot(label="Imagen MRI procesada")
109
+
110
+ # Slider para im谩gen procesada
111
+ brain_slider = gr.Slider(minimum=0,
112
+ maximum=192,
113
+ value=100,
114
+ step=1,
115
+ label="MRI Slice",
116
+ visible=False)
117
+
118
+ # componentes =
119
+
120
+ # Variables
121
+ original_input_sitk = gr.State()
122
+ original_input_img = gr.State()
123
+ brain_img = gr.State()
124
+
125
+ # Cambios
126
+ # Cargar imagen nueva
127
+ input_file.change(load_img,
128
+ input_file,
129
+ [original_input_sitk, original_input_img])
130
+
131
+ # Mostrar imagen nueva
132
+ load_img_button.click(show_img,
133
+ [original_input_img, mri_slider],
134
+ [plot_img_original, mri_slider])
135
+
136
+ # Limpiar campos
137
+ clear_button.click(fn=clear,
138
+ outputs=[input_file, plot_img_original, mri_slider])
139
+
140
+ # Actualizar imagen original
141
+ mri_slider.change(show_img,
142
+ [original_input_img, mri_slider],
143
+ [plot_img_original,mri_slider])
144
+
145
+ # Procesar imagen
146
+ process_button.click(fn=process_img,
147
+ inputs=[original_input_sitk, brain_slider],
148
+ outputs=[brain_img,plot_brain,brain_slider])
149
+
150
+ # Actualizar imagen procesada
151
+ brain_slider.change(show_brain,
152
+ [brain_img, brain_slider],
153
+ [plot_brain,brain_slider])
154
+
155
+
156
+ if __name__ == "__main__":
157
+ demo.launch()
158
+
159
+ # # Visualizaci贸n resultados
160
+ # mri_slice = 100
161
+
162
+ # # Plot Comparaci贸n m谩scaras
163
+ # fig, axs = plt.subplots(1,2)
164
+ # fig.subplots_adjust(bottom=0.15)
165
+ # fig.suptitle('Comparaci贸n M谩scaras Obtenidas')
166
+
167
+ # axs[0].set_title('MRI original')
168
+ # axs[0].imshow(img[mri_slice,:,:],cmap='gray')
169
+
170
+ # axs[1].set_title('Cerebro extraido con 3D U-Net')
171
+ # axs[1].imshow(brain[mri_slice,:,:],cmap='gray')
172
+
173
+
174
+ # # Slider para cambiar slice
175
+ # ax_slider = plt.axes([0.15, 0.05, 0.75, 0.03])
176
+ # mri_slice_slider = Slider(ax_slider, 'Slice', 0, 192, 100, valstep=1)
177
+
178
+ # def update(val):
179
+ # mri_slice = mri_slice_slider.val
180
+
181
+ # axs[0].imshow(img[:,:,mri_slice],cmap='gray')
182
+ # axs[1].imshow(brain[mri_slice,:,:],cmap='gray')
183
+
184
+ # # Actualizar plot comparaci贸n m谩scaras
185
+ # mri_slice_slider.on_changed(update)
resnet.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.autograd import Variable
5
+ import math
6
+ from functools import partial
7
+
8
+ __all__ = [
9
+ 'ResNet', 'resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
10
+ 'resnet152', 'resnet200'
11
+ ]
12
+
13
+
14
+ def conv3x3x3(in_planes, out_planes, stride=1, dilation=1):
15
+ # 3x3x3 convolution with padding
16
+ return nn.Conv3d(
17
+ in_planes,
18
+ out_planes,
19
+ kernel_size=3,
20
+ dilation=dilation,
21
+ stride=stride,
22
+ padding=dilation,
23
+ bias=False)
24
+
25
+
26
+ def downsample_basic_block(x, planes, stride, no_cuda=False):
27
+ out = F.avg_pool3d(x, kernel_size=1, stride=stride)
28
+ zero_pads = torch.Tensor(
29
+ out.size(0), planes - out.size(1), out.size(2), out.size(3),
30
+ out.size(4)).zero_()
31
+ if not no_cuda:
32
+ if isinstance(out.data, torch.cuda.FloatTensor):
33
+ zero_pads = zero_pads.cuda()
34
+
35
+ out = Variable(torch.cat([out.data, zero_pads], dim=1))
36
+
37
+ return out
38
+
39
+
40
+ class BasicBlock(nn.Module):
41
+ expansion = 1
42
+
43
+ def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
44
+ super(BasicBlock, self).__init__()
45
+ self.conv1 = conv3x3x3(inplanes, planes, stride=stride, dilation=dilation)
46
+ self.bn1 = nn.BatchNorm3d(planes)
47
+ self.relu = nn.ReLU(inplace=True)
48
+ self.conv2 = conv3x3x3(planes, planes, dilation=dilation)
49
+ self.bn2 = nn.BatchNorm3d(planes)
50
+ self.downsample = downsample
51
+ self.stride = stride
52
+ self.dilation = dilation
53
+
54
+ def forward(self, x):
55
+ residual = x
56
+
57
+ out = self.conv1(x)
58
+ out = self.bn1(out)
59
+ out = self.relu(out)
60
+ out = self.conv2(out)
61
+ out = self.bn2(out)
62
+
63
+ if self.downsample is not None:
64
+ residual = self.downsample(x)
65
+
66
+ out += residual
67
+ out = self.relu(out)
68
+
69
+ return out
70
+
71
+
72
+ class Bottleneck(nn.Module):
73
+ expansion = 4
74
+
75
+ def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
76
+ super(Bottleneck, self).__init__()
77
+ self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
78
+ self.bn1 = nn.BatchNorm3d(planes)
79
+ self.conv2 = nn.Conv3d(
80
+ planes, planes, kernel_size=3, stride=stride, dilation=dilation, padding=dilation, bias=False)
81
+ self.bn2 = nn.BatchNorm3d(planes)
82
+ self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False)
83
+ self.bn3 = nn.BatchNorm3d(planes * 4)
84
+ self.relu = nn.ReLU(inplace=True)
85
+ self.downsample = downsample
86
+ self.stride = stride
87
+ self.dilation = dilation
88
+
89
+ def forward(self, x):
90
+ residual = x
91
+
92
+ out = self.conv1(x)
93
+ out = self.bn1(out)
94
+ out = self.relu(out)
95
+
96
+ out = self.conv2(out)
97
+ out = self.bn2(out)
98
+ out = self.relu(out)
99
+
100
+ out = self.conv3(out)
101
+ out = self.bn3(out)
102
+
103
+ if self.downsample is not None:
104
+ residual = self.downsample(x)
105
+
106
+ out += residual
107
+ out = self.relu(out)
108
+
109
+ return out
110
+
111
+
112
+ class ResNet(nn.Module):
113
+
114
+ def __init__(self,
115
+ block,
116
+ layers,
117
+ sample_input_D,
118
+ sample_input_H,
119
+ sample_input_W,
120
+ num_seg_classes,
121
+ shortcut_type='B',
122
+ no_cuda = False):
123
+ self.inplanes = 64
124
+ self.no_cuda = no_cuda
125
+ super(ResNet, self).__init__()
126
+ self.conv1 = nn.Conv3d(
127
+ 1,
128
+ 64,
129
+ kernel_size=7,
130
+ stride=(2, 2, 2),
131
+ padding=(3, 3, 3),
132
+ bias=False)
133
+
134
+ self.bn1 = nn.BatchNorm3d(64)
135
+ self.relu = nn.ReLU(inplace=True)
136
+ self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)
137
+ self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type)
138
+ self.layer2 = self._make_layer(
139
+ block, 128, layers[1], shortcut_type, stride=2)
140
+ self.layer3 = self._make_layer(
141
+ block, 256, layers[2], shortcut_type, stride=1, dilation=2)
142
+ self.layer4 = self._make_layer(
143
+ block, 512, layers[3], shortcut_type, stride=1, dilation=4)
144
+
145
+ self.conv_seg = nn.Sequential(
146
+ nn.ConvTranspose3d(
147
+ 512 * block.expansion,
148
+ 32,
149
+ 2,
150
+ stride=2
151
+ ),
152
+ nn.BatchNorm3d(32),
153
+ nn.ReLU(inplace=True),
154
+ nn.Conv3d(
155
+ 32,
156
+ 32,
157
+ kernel_size=3,
158
+ stride=(1, 1, 1),
159
+ padding=(1, 1, 1),
160
+ bias=False),
161
+ nn.BatchNorm3d(32),
162
+ nn.ReLU(inplace=True),
163
+ nn.Conv3d(
164
+ 32,
165
+ num_seg_classes,
166
+ kernel_size=1,
167
+ stride=(1, 1, 1),
168
+ bias=False)
169
+ )
170
+
171
+ for m in self.modules():
172
+ if isinstance(m, nn.Conv3d):
173
+ m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out')
174
+ elif isinstance(m, nn.BatchNorm3d):
175
+ m.weight.data.fill_(1)
176
+ m.bias.data.zero_()
177
+
178
+ def _make_layer(self, block, planes, blocks, shortcut_type, stride=1, dilation=1):
179
+ downsample = None
180
+ if stride != 1 or self.inplanes != planes * block.expansion:
181
+ if shortcut_type == 'A':
182
+ downsample = partial(
183
+ downsample_basic_block,
184
+ planes=planes * block.expansion,
185
+ stride=stride,
186
+ no_cuda=self.no_cuda)
187
+ else:
188
+ downsample = nn.Sequential(
189
+ nn.Conv3d(
190
+ self.inplanes,
191
+ planes * block.expansion,
192
+ kernel_size=1,
193
+ stride=stride,
194
+ bias=False), nn.BatchNorm3d(planes * block.expansion))
195
+
196
+ layers = []
197
+ layers.append(block(self.inplanes, planes, stride=stride, dilation=dilation, downsample=downsample))
198
+ self.inplanes = planes * block.expansion
199
+ for i in range(1, blocks):
200
+ layers.append(block(self.inplanes, planes, dilation=dilation))
201
+
202
+ return nn.Sequential(*layers)
203
+
204
+ def forward(self, x):
205
+ x = self.conv1(x)
206
+ x = self.bn1(x)
207
+ x = self.relu(x)
208
+ x = self.maxpool(x)
209
+ x = self.layer1(x)
210
+ x = self.layer2(x)
211
+ x = self.layer3(x)
212
+ x = self.layer4(x)
213
+ x = self.conv_seg(x)
214
+
215
+ return x
216
+
217
+ def resnet10(**kwargs):
218
+ """Constructs a ResNet-18 model.
219
+ """
220
+ model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs)
221
+ return model
222
+
223
+
224
+ def resnet18(**kwargs):
225
+ """Constructs a ResNet-18 model.
226
+ """
227
+ model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
228
+ return model
229
+
230
+
231
+ def resnet34(**kwargs):
232
+ """Constructs a ResNet-34 model.
233
+ """
234
+ model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
235
+ return model
236
+
237
+
238
+ def resnet50(**kwargs):
239
+ """Constructs a ResNet-50 model.
240
+ """
241
+ model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
242
+ return model
243
+
244
+
245
+ def resnet101(**kwargs):
246
+ """Constructs a ResNet-101 model.
247
+ """
248
+ model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
249
+ return model
250
+
251
+
252
+ def resnet152(**kwargs):
253
+ """Constructs a ResNet-101 model.
254
+ """
255
+ model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
256
+ return model
257
+
258
+
259
+ def resnet200(**kwargs):
260
+ """Constructs a ResNet-101 model.
261
+ """
262
+ model = ResNet(Bottleneck, [3, 24, 36, 3], **kwargs)
263
+ return model
utils.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import os
2
+ import torch
3
+ import resnet
4
+ import numpy as np
5
+ import tensorflow as tf
6
+ # import nibabel as nib
7
+ import SimpleITK as sitk
8
+ import segmentation_models_3D as sm
9
+ from torch import nn
10
+ # from ttictoc import tic,toc
11
+ from skimage import morphology
12
+ from keras import backend as K
13
+ from scipy import ndimage as ndi
14
+ from keras.models import load_model
15
+ from patchify import patchify, unpatchify
16
+
17
+ # from matplotlib import pyplot as plt
18
+ # from matplotlib.widgets import Slider
19
+
20
+ # Funci贸n que retorna modelo 3D U-Net para extracci贸n de cerebro
21
+ def import_3d_unet(path_3d_unet):
22
+ # M茅tricas de desempe帽o
23
+ def dice_coefficient(y_true, y_pred):
24
+ smoothing_factor = 1
25
+ flat_y_true = K.flatten(y_true)
26
+ flat_y_pred = K.flatten(y_pred)
27
+ return (2. * K.sum(flat_y_true * flat_y_pred) + smoothing_factor) / (K.sum(flat_y_true) + K.sum(flat_y_pred) + smoothing_factor)
28
+
29
+ # Cargar modelo preentrenado
30
+ # with tf.device('/cpu:0'):
31
+ model = load_model(path_3d_unet, custom_objects={'dice_coefficient':dice_coefficient, 'iou_score':sm.metrics.IOUScore(threshold=0.5)})
32
+ return model
33
+
34
+
35
+ # Funci贸n que caraga imagen en formato nifti, aplica filtro N4 y normaliza imagen
36
+ def load_img(path):
37
+ # Lectura de MRI T1 formato nifti
38
+ inputImage = sitk.ReadImage(path, sitk.sitkFloat32)
39
+
40
+ return inputImage, sitk.GetArrayFromImage(inputImage).astype(np.float32)
41
+
42
+ # Funci贸n que remueve
43
+ def brain_stripping(inputImage, model_unet):
44
+ """----------------------Preprocesamiento imagen MRI-----------------------"""
45
+ image = inputImage
46
+
47
+ # N4 Bias Field Correction
48
+ maskImage = sitk.OtsuThreshold(inputImage, 0, 1, 200)
49
+ corrector = sitk.N4BiasFieldCorrectionImageFilter()
50
+ corrected_image = corrector.Execute(image, maskImage)
51
+ log_bias_field = corrector.GetLogBiasFieldAsImage(inputImage)
52
+ corrected_image_full_resolution = inputImage / sitk.Exp(log_bias_field)
53
+
54
+ #Normalizaci贸n
55
+ image_normalized = sitk.GetArrayFromImage(corrected_image_full_resolution)
56
+ image_normalized = (image_normalized-np.min(image_normalized))/(np.max(image_normalized)-np.min(image_normalized))
57
+ image_normalized = image_normalized.astype(np.float32)
58
+
59
+ # Redimenci贸n
60
+ mri_image = np.transpose(image_normalized)
61
+ mri_image = np.append(mri_image, np.zeros((192-mri_image.shape[0],256,256,)), axis=0)
62
+
63
+ # Rotaci贸n
64
+ mri_image = mri_image.astype(np.float32)
65
+ mri_image = np.rot90(mri_image, axes=(1,2))
66
+
67
+ # Volume sampling
68
+ mri_patches = patchify(mri_image, (64, 64, 64), step=64)
69
+
70
+ """--------------------Predicci贸n de m谩scara de cerebro--------------------"""
71
+ # M谩scara de cerebro para cada vol煤men
72
+ mask_patches = []
73
+
74
+ for i in range(mri_patches.shape[0]):
75
+ for j in range(mri_patches.shape[1]):
76
+ for k in range(mri_patches.shape[2]):
77
+ single_patch = np.expand_dims(mri_patches[i,j,k,:,:,:], axis=0)
78
+ single_patch_prediction = model_unet.predict(single_patch)
79
+ single_patch_prediction_th = (single_patch_prediction[0,:,:,:,0] > 0.5).astype(np.uint8)
80
+ mask_patches.append(single_patch_prediction_th)
81
+
82
+ # Conversi贸n a numpy array
83
+ predicted_patches = np.array(mask_patches)
84
+
85
+ # Reshape para proceso de reconstrucci贸n
86
+ predicted_patches_reshaped = np.reshape(predicted_patches,
87
+ (mri_patches.shape[0], mri_patches.shape[1], mri_patches.shape[2],
88
+ mri_patches.shape[3], mri_patches.shape[4], mri_patches.shape[5]) )
89
+
90
+ # Reconstrucci贸n m谩scara
91
+ reconstructed_mask = unpatchify(predicted_patches_reshaped, mri_image.shape)
92
+
93
+ # Suavizado m谩scara
94
+ corrected_mask = ndi.binary_closing(reconstructed_mask, structure=morphology.ball(2)).astype(np.uint8)
95
+
96
+ # Eliminaci贸n de volumenes ruido
97
+ no_noise_mask = corrected_mask.copy()
98
+ mask_labeled = morphology.label(corrected_mask, background=0, connectivity=3)
99
+ label_count = np.unique(mask_labeled, return_counts=True)
100
+ brain_label = np.argmax(label_count[1][1:]) + 1
101
+
102
+ no_noise_mask[np.where(mask_labeled != brain_label)] = 0
103
+
104
+ # Elimicaci贸n huecos y hendiduras
105
+ filled_mask = ndi.binary_closing(no_noise_mask, structure=morphology.ball(12)).astype(np.uint8)
106
+
107
+ """-------------------------Extracci贸n de cerebro--------------------------"""
108
+ # Aplicar m谩scara a imagen mri
109
+ mri_brain = np.multiply(mri_image,filled_mask)
110
+
111
+ return mri_brain
112
+
113
+ # Funci贸n que retorna modelo MedNet
114
+ def create_mednet(weight_path, device_ids):
115
+ # Clase para agregar capa totalmente conectada
116
+ class simci_net(nn.Module):
117
+ def __init__(self):
118
+ super(simci_net, self).__init__()
119
+
120
+ self.pretrained_model = resnet.resnet50(sample_input_D=192, sample_input_H=256, sample_input_W=256, num_seg_classes=2, no_cuda = False)
121
+ self.pretrained_model.conv_seg = nn.Sequential(nn.AdaptiveMaxPool3d(output_size=(1, 1, 1)),
122
+ nn.Flatten(start_dim=1))
123
+
124
+
125
+ def forward(self, x):
126
+ x = self.pretrained_model(x)
127
+
128
+ return x
129
+
130
+ # Path con pesos preentrenados
131
+ weight_path = weight_path
132
+
133
+ # Lista de GPUs para utilizar
134
+ device_ids = device_ids
135
+
136
+ # Generar red
137
+ simci_model = simci_net()
138
+
139
+ # Distribuir en varias GPUs
140
+ simci_model = torch.nn.DataParallel(simci_model, device_ids = device_ids)
141
+ simci_model.to(f'cuda:{simci_model.device_ids[0]}')
142
+
143
+ # Diccionario state
144
+ net_dict = simci_model.state_dict()
145
+
146
+ # Cargar pesos
147
+ weight = torch.load(weight_path, map_location=torch.device(f'cuda:{simci_model.device_ids[0]}'))
148
+
149
+ # Transferencia de aprendizaje
150
+ pretrain_dict = {}
151
+
152
+ for k, v in weight['state_dict'].items():
153
+ if k.replace("module.", "module.pretrained_model.") in net_dict.keys():
154
+ pretrain_dict[k.replace("module.", "module.pretrained_model.")] = v
155
+
156
+ # pretrain_dict = {k.replace("module.", ""): v for k, v in weight['state_dict'].items() if k.replace("module.", "") in net_dict.keys()}
157
+ net_dict.update(pretrain_dict)
158
+ simci_model.load_state_dict(net_dict)
159
+
160
+ # Bloqueo de parametros mednet
161
+ for param in simci_model.module.pretrained_model.parameters():
162
+ param.requires_grad = False
163
+
164
+ simci_model.eval() # Modelo en modo evaluaci贸n
165
+
166
+ return simci_model
167
+
168
+ # Funci贸n que extrae caracter铆sticas de cerebro
169
+ def get_features(brain, mednet_model):
170
+ with torch.no_grad():
171
+ # Convertir a tensor
172
+ data = torch.from_numpy(np.expand_dims(np.expand_dims(brain,axis=0), axis=0))
173
+
174
+ # Enviar imagen a GPU
175
+ data = data.to(f'cuda:{mednet_model.device_ids[0]}')
176
+
177
+ # Extraer Caracter铆sticas
178
+ features = mednet_model(data) # Forward
179
+ features = features.cpu().numpy()
180
+
181
+ torch.cuda.empty_cache()
182
+ return features