# LFUNet/utils/architectures.py
from abc import ABC, abstractmethod
from enum import Enum
from typing import Tuple, Optional
import tensorflow as tf
from tensorflow.keras.layers import (Activation, Add, BatchNormalization,
                                     Concatenate, Conv2D, Conv2DTranspose,
                                     Input, Layer, MaxPool2D, Reshape,
                                     SeparableConv2D, UpSampling2D)
from tensorflow.keras.models import Model
class BaseUNet(ABC):
"""
Base Interface for UNet
"""
def __init__(self, model: Model):
self.model: Model = model
def get_model(self):
return self.model
@staticmethod
@abstractmethod
def build_model(input_size: Tuple[int, int, int], filters: Tuple, kernels: Tuple):
pass
class UNet(Enum):
"""
Enum class defining different architecture types available
"""
DEFAULT = 0
DEFAULT_IMAGENET_EMBEDDING = 1
RESNET = 3
RESIDUAL_ATTENTION_UNET_SEPARABLE_CONV = 4
def build_model(self, input_size: Tuple[int, int, int], filters: Optional[Tuple] = None,
kernels: Optional[Tuple] = None) -> BaseUNet:
# set default filters
if filters is None:
filters = (16, 32, 64, 128, 256)
# set default kernels
if kernels is None:
            kernels = tuple(3 for _ in filters)
# check kernels and filters
if len(filters) != len(kernels):
            raise ValueError('Filter and kernel counts have to match.')
if self == UNet.DEFAULT_IMAGENET_EMBEDDING:
            print('Using default UNet model with ImageNet embedding')
return UNetDefault.build_model(input_size, filters, kernels, use_embedding=True)
elif self == UNet.RESNET:
print('Using UNet Resnet model')
return UNet_resnet.build_model(input_size, filters, kernels)
elif self == UNet.RESIDUAL_ATTENTION_UNET_SEPARABLE_CONV:
print('Using UNet Resnet model with attention mechanism and separable convolutions')
return UNet_ResNet_Attention_SeparableConv.build_model(input_size, filters, kernels)
print('Using default UNet model')
return UNetDefault.build_model(input_size, filters, kernels, use_embedding=False)
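# A minimal usage sketch of the enum dispatcher above (illustrative, not part
# of the library): build a wrapper for 256x256 RGB inputs with the default
# filter/kernel configuration, then unwrap the underlying Keras model.
#
#   unet = UNet.RESNET.build_model(input_size=(256, 256, 3))
#   keras_model = unet.get_model()
#   keras_model.compile(optimizer='adam', loss='mse')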
class Attention(Layer):
def __init__(self, **kwargs):
super(Attention, self).__init__(**kwargs)
def build(self, input_shape):
# Create a trainable weight variable for this layer.
self.kernel = self.add_weight(name='kernel',
shape=(input_shape[-1], 1),
initializer='glorot_normal',
trainable=True)
self.bias = self.add_weight(name='bias',
shape=(1,),
initializer='zeros',
trainable=True)
super(Attention, self).build(input_shape) # Be sure to call this at the end
    def call(self, x):
        # Project each feature vector to a single score and use it as a
        # multiplicative gate. NOTE: with a single score channel, softmax over
        # axis=-1 is constant 1.0, so a sigmoid gate may be intended here.
        attention = tf.nn.softmax(tf.matmul(x, self.kernel) + self.bias, axis=-1)
        return tf.multiply(x, attention)
def compute_output_shape(self, input_shape):
return input_shape
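# Sketch of the Attention layer in isolation (shapes chosen only for
# illustration): the learned gate preserves the shape of its input.
#
#   feats = tf.random.normal((1, 8, 8, 16))
#   gated = Attention()(feats)  # shape (1, 8, 8, 16), same as the input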
class UNet_ResNet_Attention_SeparableConv(BaseUNet):
"""
UNet architecture with resnet blocks, attention mechanism and separable convolutions
"""
@staticmethod
def build_model(input_size: Tuple[int, int, int], filters: Tuple, kernels: Tuple):
p0 = Input(shape=input_size)
conv_outputs = []
first_layer = SeparableConv2D(filters[0], kernels[0], padding='same')(p0)
int_layer = first_layer
for i, f in enumerate(filters):
int_layer, skip = UNet_ResNet_Attention_SeparableConv.down_block(int_layer, f, kernels[i])
conv_outputs.append(skip)
int_layer = UNet_ResNet_Attention_SeparableConv.bottleneck(int_layer, filters[-1], kernels[-1])
conv_outputs = list(reversed(conv_outputs))
reversed_filter = list(reversed(filters))
reversed_kernels = list(reversed(kernels))
for i, f in enumerate(reversed_filter):
if i + 1 < len(reversed_filter):
num_filters_next = reversed_filter[i + 1]
num_kernels_next = reversed_kernels[i + 1]
else:
num_filters_next = f
num_kernels_next = reversed_kernels[i]
int_layer = UNet_ResNet_Attention_SeparableConv.up_block(int_layer, conv_outputs[i], f, num_filters_next, num_kernels_next)
int_layer = Attention()(int_layer)
# concat. with the first layer
int_layer = Concatenate()([first_layer, int_layer])
int_layer = SeparableConv2D(filters[0], kernels[0], padding="same", activation="relu")(int_layer)
outputs = SeparableConv2D(3, (1, 1), padding="same", activation="sigmoid")(int_layer)
model = Model(p0, outputs)
return UNet_ResNet_Attention_SeparableConv(model)
@staticmethod
def down_block(x, num_filters: int = 64, kernel: int = 3):
        # down-sample inputs; Keras forbids strides > 1 together with
        # dilation_rate > 1, so the stride-2 convolution uses no dilation
        x = SeparableConv2D(num_filters, kernel, padding='same', strides=2)(x)
# inner block
out = SeparableConv2D(num_filters, kernel, padding='same')(x)
# out = BatchNormalization()(out)
out = Activation('relu')(out)
out = SeparableConv2D(num_filters, kernel, padding='same')(out)
        # residual connection
        out = Add()([out, x])
# out = BatchNormalization()(out)
return Activation('relu')(out), x
@staticmethod
def up_block(x, skip, num_filters: int = 64, num_filters_next: int = 64, kernel: int = 3):
# add U-Net skip connection - before up-sampling
concat = Concatenate()([x, skip])
# inner block
        out = SeparableConv2D(num_filters, kernel, padding='same', dilation_rate=2)(concat)
# out = BatchNormalization()(out)
out = Activation('relu')(out)
out = SeparableConv2D(num_filters, kernel, padding='same')(out)
        # residual connection
        out = Add()([out, x])
# out = BatchNormalization()(out)
out = Activation('relu')(out)
# up-sample
out = UpSampling2D((2, 2))(out)
out = SeparableConv2D(num_filters_next, kernel, padding='same')(out)
# out = BatchNormalization()(out)
return Activation('relu')(out)
@staticmethod
def bottleneck(x, num_filters: int = 64, kernel: int = 3):
# inner block
        out = SeparableConv2D(num_filters, kernel, padding='same', dilation_rate=2)(x)
# out = BatchNormalization()(out)
out = Activation('relu')(out)
out = SeparableConv2D(num_filters, kernel, padding='same')(out)
out = Add()([out, x])
# out = BatchNormalization()(out)
return Activation('relu')(out)
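# Shape walk-through for the separable variant (sketch, assuming the default
# five filter levels and a 256x256x3 input): each down_block halves the
# spatial resolution, so the bottleneck operates on 8x8 feature maps, and
# each up_block's UpSampling2D doubles it back to the 256x256 output.
#
#   unet = UNet_ResNet_Attention_SeparableConv.build_model(
#       (256, 256, 3), (16, 32, 64, 128, 256), (3, 3, 3, 3, 3))
#   unet.get_model().summary()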
# Class for UNet with Resnet blocks
class UNet_resnet(BaseUNet):
"""
UNet architecture with resnet blocks
"""
@staticmethod
def build_model(input_size: Tuple[int, int, int], filters: Tuple, kernels: Tuple):
p0 = Input(shape=input_size)
conv_outputs = []
first_layer = Conv2D(filters[0], kernels[0], padding='same')(p0)
int_layer = first_layer
for i, f in enumerate(filters):
int_layer, skip = UNet_resnet.down_block(int_layer, f, kernels[i])
conv_outputs.append(skip)
int_layer = UNet_resnet.bottleneck(int_layer, filters[-1], kernels[-1])
conv_outputs = list(reversed(conv_outputs))
reversed_filter = list(reversed(filters))
reversed_kernels = list(reversed(kernels))
for i, f in enumerate(reversed_filter):
if i + 1 < len(reversed_filter):
num_filters_next = reversed_filter[i + 1]
num_kernels_next = reversed_kernels[i + 1]
else:
num_filters_next = f
num_kernels_next = reversed_kernels[i]
int_layer = UNet_resnet.up_block(int_layer, conv_outputs[i], f, num_filters_next, num_kernels_next)
# concat. with the first layer
int_layer = Concatenate()([first_layer, int_layer])
int_layer = Conv2D(filters[0], kernels[0], padding="same", activation="relu")(int_layer)
outputs = Conv2D(3, (1, 1), padding="same", activation="sigmoid")(int_layer)
model = Model(p0, outputs)
return UNet_resnet(model)
@staticmethod
def down_block(x, num_filters: int = 64, kernel: int = 3):
# down-sample inputs
x = Conv2D(num_filters, kernel, padding='same', strides=2)(x)
# inner block
out = Conv2D(num_filters, kernel, padding='same')(x)
# out = BatchNormalization()(out)
out = Activation('relu')(out)
out = Conv2D(num_filters, kernel, padding='same')(out)
        # residual connection
        out = Add()([out, x])
# out = BatchNormalization()(out)
return Activation('relu')(out), x
@staticmethod
def up_block(x, skip, num_filters: int = 64, num_filters_next: int = 64, kernel: int = 3):
        # U-Net skip connection at the block input
        concat = Concatenate()([x, skip])
# inner block
out = Conv2D(num_filters, kernel, padding='same')(concat)
# out = BatchNormalization()(out)
out = Activation('relu')(out)
out = Conv2D(num_filters, kernel, padding='same')(out)
        # residual connection
        out = Add()([out, x])
# out = BatchNormalization()(out)
out = Activation('relu')(out)
        # U-Net skip connection again, right before up-sampling
        concat = Concatenate()([out, skip])
# up-sample
# out = UpSampling2D((2, 2))(concat)
out = Conv2DTranspose(num_filters_next, kernel, padding='same', strides=2)(concat)
out = Conv2D(num_filters_next, kernel, padding='same')(out)
# out = BatchNormalization()(out)
return Activation('relu')(out)
@staticmethod
def bottleneck(x, filters, kernel: int = 3):
x = Conv2D(filters, kernel, padding='same', name='bottleneck')(x)
# x = BatchNormalization()(x)
return Activation('relu')(x)
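# Sketch of the ResNet variant on its own (hypothetical configuration): note
# that its up_block uses a strided Conv2DTranspose for learned up-sampling,
# unlike the separable variant's fixed UpSampling2D.
#
#   unet = UNet_resnet.build_model((256, 256, 3), (16, 32, 64), (3, 3, 3))
#   unet.get_model().summary()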
class UNetDefault(BaseUNet):
"""
    UNet architecture for image segmentation, adapted from the following
    GitHub notebooks:
https://github.com/nikhilroxtomar/UNet-Segmentation-in-Keras-TensorFlow/blob/master/unet-segmentation.ipynb
https://github.com/nikhilroxtomar/Polyp-Segmentation-using-UNET-in-TensorFlow-2.0
"""
@staticmethod
    def build_model(input_size: Tuple[int, int, int], filters: Tuple, kernels: Tuple, use_embedding: bool = True):
        # NOTE: `kernels` is accepted to satisfy the BaseUNet interface, but
        # this architecture uses fixed 3x3 kernels throughout
        p0 = Input(input_size)
        if use_embedding:
            # frozen MobileNetV2 backbone used as an ImageNet feature embedding
            mobilenet_model = tf.keras.applications.MobileNetV2(
                input_shape=input_size, include_top=False, weights='imagenet'
            )
            mobilenet_model.trainable = False
            mn1 = mobilenet_model(p0)
            # reshape the backbone output, (8, 8, 1280) for a 256x256 input,
            # to match the 16x16 bottleneck of a 4-level encoder
            mn1 = Reshape((16, 16, 320))(mn1)
conv_outputs = []
int_layer = p0
for f in filters:
conv_output, int_layer = UNetDefault.down_block(int_layer, f)
conv_outputs.append(conv_output)
int_layer = UNetDefault.bottleneck(int_layer, filters[-1])
if use_embedding:
int_layer = Concatenate()([int_layer, mn1])
conv_outputs = list(reversed(conv_outputs))
for i, f in enumerate(reversed(filters)):
int_layer = UNetDefault.up_block(int_layer, conv_outputs[i], f)
int_layer = Conv2D(filters[0] // 2, 3, padding="same", activation="relu")(int_layer)
outputs = Conv2D(3, (1, 1), padding="same", activation="sigmoid")(int_layer)
model = Model(p0, outputs)
return UNetDefault(model)
@staticmethod
def down_block(x, filters, kernel_size=(3, 3), padding="same", strides=1):
c = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(x)
# c = BatchNormalization()(c)
        p = MaxPool2D(pool_size=(2, 2), strides=(2, 2))(c)
return c, p
@staticmethod
def up_block(x, skip, filters, kernel_size=(3, 3), padding="same", strides=1):
us = UpSampling2D((2, 2))(x)
c = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(us)
# c = BatchNormalization()(c)
concat = Concatenate()([c, skip])
c = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(concat)
# c = BatchNormalization()(c)
return c
@staticmethod
def bottleneck(x, filters, kernel_size=(3, 3), padding="same", strides=1):
c = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(x)
# c = BatchNormalization()(c)
return c
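# Sketch of the ImageNet-embedding path (assumes a 256x256x3 input and exactly
# four filter levels, since the hard-coded Reshape((16, 16, 320)) only matches
# the 16x16 bottleneck of a 4-level encoder at that input size):
#
#   unet = UNetDefault.build_model((256, 256, 3), (16, 32, 64, 128),
#                                  (3, 3, 3, 3), use_embedding=True)
#   unet.get_model().summary()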
if __name__ == "__main__":
    filters = (64, 128, 128, 256, 256, 512)
    kernels = (7, 7, 7, 3, 3, 3)
    input_image_size = (256, 256, 3)
    # Example with the ResNet variant (same factory pattern):
    # unet = UNet_resnet.build_model(input_size=input_image_size, filters=filters, kernels=kernels)
    # unet.get_model().summary()
    # build_model is a static factory returning a BaseUNet wrapper, so the
    # underlying Keras model must be unwrapped with get_model() before calling
    # summary(). The ImageNet embedding is disabled because its hard-coded
    # Reshape assumes a 4-level encoder, while six filter levels are used here.
    unet = UNetDefault.build_model(input_size=input_image_size, filters=filters,
                                   kernels=kernels, use_embedding=False)
    unet.get_model().summary()
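    # The same model can also be built through the enum dispatcher defined
    # above (sketch; uses the default filter/kernel configuration):
    # unet = UNet.DEFAULT.build_model(input_size=input_image_size)
    # unet.get_model().summary()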