# Source: MRI-Image / ResUNet.py (uploaded by TharunSiva)
# Commit c96e8c8 (verified): "util files"
import os
import numpy as np
import pandas as pd
import cv2
import tensorflow as tf
from tensorflow.python.keras import Sequential
from tensorflow.keras import layers, optimizers
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from tensorflow.keras.initializers import glorot_uniform
from tensorflow.keras.utils import plot_model
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint, LearningRateScheduler
import tensorflow.keras.backend as K
# Let's create the model now.
def resblock(X, f):
    """Build a two-branch residual block on top of tensor ``X``.

    Main branch: 1x1 convolution -> batch normalization -> ReLU ->
    3x3 'same'-padded convolution -> batch normalization.
    Shortcut branch: 1x1 convolution -> batch normalization, applied to the
    block input.
    The two branches are summed element-wise and passed through a final ReLU.

    Args:
        X: Input tensor of the block.
        f: Number of filters used by every convolution in the block.

    Returns:
        The output tensor of the residual block.
    """
    block_input = X  # kept so the shortcut branch can project the raw input

    # Main branch. Layers are created in the same order as before (main
    # branch first, then shortcut) so that Keras' automatically numbered
    # layer names do not shift.
    main = Conv2D(filters=f, kernel_size=(1, 1), kernel_initializer='he_normal')(X)
    main = BatchNormalization()(main)
    main = Activation('relu')(main)
    main = Conv2D(
        filters=f,
        kernel_size=(3, 3),
        padding='same',
        kernel_initializer='he_normal',
    )(main)
    main = BatchNormalization()(main)

    # Shortcut branch: project the input to ``f`` channels so it can be added
    # to the main branch.
    shortcut = Conv2D(
        filters=f,
        kernel_size=(1, 1),
        kernel_initializer='he_normal',
    )(block_input)
    shortcut = BatchNormalization()(shortcut)

    # Merge both branches and apply the closing non-linearity.
    merged = Add()([main, shortcut])
    return Activation('relu')(merged)
def upsample_concat(x, skip):
    """Upsample ``x`` by a factor of 2 and concatenate it with ``skip``.

    Args:
        x: Tensor that is passed through ``UpSampling2D((2, 2))``.
        skip: Tensor joined to the upsampled result with ``Concatenate()``.

    Returns:
        The concatenation of the upsampled ``x`` (first) and ``skip`` (second).
    """
    upsampled = UpSampling2D(size=(2, 2))(x)
    return Concatenate()([upsampled, skip])
def load_model(input_shape=(256, 256, 3)):
    """Build the ResUNet segmentation model.

    Despite its name, this function does not load anything from disk: it only
    constructs the network graph and returns an uncompiled model without
    loading any weights.

    Architecture:
        * Stage 1: two 3x3 ReLU convolutions (16 filters), each followed by
          batch normalization, then 2x2 max pooling.
        * Stages 2-4: residual blocks with 32, 64 and 128 filters, each
          followed by 2x2 max pooling.
        * Stage 5 (bottleneck): residual block with 256 filters.
        * Four decoder stages: upsample by 2, concatenate with the matching
          encoder output, then a residual block with 128, 64, 32 and 16
          filters respectively.
        * Output: 1x1 convolution with a single sigmoid-activated channel.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.
            Defaults to ``(256, 256, 3)``, the value that was previously
            hard-coded. Height and width may be ``None``; when given, each
            must be divisible by 16 because the encoder halves the spatial
            size four times and the decoder must match it exactly at every
            skip connection.

    Returns:
        A ``Model`` mapping the input tensor to the sigmoid output tensor.

    Raises:
        ValueError: If ``input_shape`` does not have three entries, or if a
            known height/width is not divisible by 16.
    """
    encoder_filters = (32, 64, 128)
    bottleneck_filters = 256
    decoder_filters = (128, 64, 32, 16)
    # One 2x2 pooling after stage 1 plus one after each encoder res block.
    downsampling_factor = 2 ** (len(encoder_filters) + 1)

    input_shape = tuple(input_shape)
    if len(input_shape) != 3:
        raise ValueError(
            "input_shape must be (height, width, channels), got %r" % (input_shape,)
        )
    for dim_name, dim in (("height", input_shape[0]), ("width", input_shape[1])):
        # Fail early with a clear message instead of a shape mismatch inside
        # Concatenate when an odd intermediate size gets floored by pooling.
        if dim is not None and dim % downsampling_factor != 0:
            raise ValueError(
                "input %s must be divisible by %d, got %r"
                % (dim_name, downsampling_factor, dim)
            )

    X_input = Input(input_shape)  # symbolic input tensor

    # Stage 1: two conv + batch-norm pairs, then pooling.
    conv_1 = X_input
    for _ in range(2):
        conv_1 = Conv2D(
            16, 3, activation='relu', padding='same', kernel_initializer='he_normal'
        )(conv_1)
        conv_1 = BatchNormalization()(conv_1)
    pooled = MaxPool2D((2, 2))(conv_1)

    # Stages 2-4: residual block followed by pooling. The pre-pooling outputs
    # are remembered for the decoder's skip connections.
    skips = [conv_1]
    for filters in encoder_filters:
        encoded = resblock(pooled, filters)
        skips.append(encoded)
        pooled = MaxPool2D((2, 2))(encoded)

    # Stage 5 (bottleneck).
    decoded = resblock(pooled, bottleneck_filters)

    # Decoder: walk the skip connections from deepest to shallowest.
    for skip, filters in zip(reversed(skips), decoder_filters):
        decoded = upsample_concat(decoded, skip)
        decoded = resblock(decoded, filters)

    # Final output: one sigmoid channel per pixel.
    out = Conv2D(
        1, (1, 1), kernel_initializer='he_normal', padding='same', activation='sigmoid'
    )(decoded)

    seg_model = Model(X_input, out)
    return seg_model