from abc import ABC, abstractmethod
from enum import Enum
from typing import Tuple, Optional

import tensorflow as tf
from tensorflow.keras.layers import (Activation, Add, Concatenate, Conv2D,
                                     Conv2DTranspose, Input, Layer, MaxPool2D,
                                     Reshape, SeparableConv2D, UpSampling2D)
from tensorflow.keras.models import Model


class BaseUNet(ABC):
    """
    Base interface for UNet variants: wraps a built Keras model and defines
    the ``build_model`` factory every architecture has to implement.
    """

    def __init__(self, model: Model):
        self.model: Model = model

    def get_model(self) -> Model:
        """Return the wrapped Keras model."""
        return self.model

    @staticmethod
    @abstractmethod
    def build_model(input_size: Tuple[int, int, int], filters: Tuple, kernels: Tuple):
        pass
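

# A concrete variant implements build_model as a static factory and returns an
# instance wrapping the finished Keras model, so a caller can do, for example:
#
#   unet = UNet_resnet.build_model((256, 256, 3), (16, 32), (3, 3))
#   keras_model = unet.get_model()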


class UNet(Enum):
    """
    Enum defining the available architecture variants.
    """
    DEFAULT = 0
    DEFAULT_IMAGENET_EMBEDDING = 1
    RESNET = 3
    RESIDUAL_ATTENTION_UNET_SEPARABLE_CONV = 4

    def build_model(self, input_size: Tuple[int, int, int], filters: Optional[Tuple] = None,
                    kernels: Optional[Tuple] = None) -> BaseUNet:
        # Fall back to five encoder stages with 3x3 kernels throughout.
        if filters is None:
            filters = (16, 32, 64, 128, 256)
        if kernels is None:
            kernels = tuple(3 for _ in filters)
        if len(filters) != len(kernels):
            raise ValueError('Number of kernels and filters has to match.')

        if self == UNet.DEFAULT_IMAGENET_EMBEDDING:
            print('Using default UNet model with ImageNet embedding')
            return UNetDefault.build_model(input_size, filters, kernels, use_embedding=True)
        elif self == UNet.RESNET:
            print('Using UNet ResNet model')
            return UNet_resnet.build_model(input_size, filters, kernels)
        elif self == UNet.RESIDUAL_ATTENTION_UNET_SEPARABLE_CONV:
            print('Using UNet ResNet model with attention mechanism and separable convolutions')
            return UNet_ResNet_Attention_SeparableConv.build_model(input_size, filters, kernels)

        print('Using default UNet model')
        return UNetDefault.build_model(input_size, filters, kernels, use_embedding=False)
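

# Example (sketch): selecting a variant through the enum; the compile settings
# here are illustrative, not tuned values from this project.
#
#   unet = UNet.RESNET.build_model((256, 256, 3))
#   unet.get_model().compile(optimizer='adam', loss='binary_crossentropy')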


class Attention(Layer):
    """
    Learned gate that scores every spatial position from its channel vector
    and rescales the input by the normalized scores.
    """

    def __init__(self, **kwargs):
        super(Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        # One scalar score per spatial position, computed from the channels.
        self.kernel = self.add_weight(name='kernel',
                                      shape=(input_shape[-1], 1),
                                      initializer='glorot_normal',
                                      trainable=True)
        self.bias = self.add_weight(name='bias',
                                    shape=(1,),
                                    initializer='zeros',
                                    trainable=True)
        super(Attention, self).build(input_shape)

    def call(self, x):
        scores = tf.matmul(x, self.kernel) + self.bias  # (batch, H, W, 1)
        # Normalize the scores across all spatial positions. A softmax over
        # the trailing size-1 axis would always evaluate to 1 and make the
        # gate a no-op.
        shape = tf.shape(scores)
        attention = tf.nn.softmax(tf.reshape(scores, [shape[0], -1]), axis=-1)
        attention = tf.reshape(attention, shape)
        return tf.multiply(x, attention)

    def compute_output_shape(self, input_shape):
        return input_shape
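

# Example (sketch): the gate accepts any 4-D feature map and preserves its
# shape, so it can be dropped in after an arbitrary convolution, e.g.:
#
#   gated = Attention()(SeparableConv2D(32, 3, padding='same')(some_features))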


class UNet_ResNet_Attention_SeparableConv(BaseUNet):
    """
    UNet architecture with ResNet blocks, an attention mechanism and
    separable convolutions.
    """

    @staticmethod
    def build_model(input_size: Tuple[int, int, int], filters: Tuple, kernels: Tuple):
        p0 = Input(shape=input_size)
        conv_outputs = []
        first_layer = SeparableConv2D(filters[0], kernels[0], padding='same')(p0)
        int_layer = first_layer
        # Encoder: each down block halves the resolution and hands back the
        # pre-residual activation as the skip connection.
        for i, f in enumerate(filters):
            int_layer, skip = UNet_ResNet_Attention_SeparableConv.down_block(int_layer, f, kernels[i])
            conv_outputs.append(skip)

        int_layer = UNet_ResNet_Attention_SeparableConv.bottleneck(int_layer, filters[-1], kernels[-1])

        # Decoder: walk the stages in reverse, fusing each skip connection and
        # gating the result with the attention layer.
        conv_outputs = list(reversed(conv_outputs))
        reversed_filters = list(reversed(filters))
        reversed_kernels = list(reversed(kernels))
        for i, f in enumerate(reversed_filters):
            if i + 1 < len(reversed_filters):
                num_filters_next = reversed_filters[i + 1]
                num_kernels_next = reversed_kernels[i + 1]
            else:
                num_filters_next = f
                num_kernels_next = reversed_kernels[i]
            int_layer = UNet_ResNet_Attention_SeparableConv.up_block(
                int_layer, conv_outputs[i], f, num_filters_next, num_kernels_next)
            int_layer = Attention()(int_layer)

        int_layer = Concatenate()([first_layer, int_layer])
        int_layer = SeparableConv2D(filters[0], kernels[0], padding="same", activation="relu")(int_layer)
        outputs = SeparableConv2D(3, (1, 1), padding="same", activation="sigmoid")(int_layer)
        model = Model(p0, outputs)
        return UNet_ResNet_Attention_SeparableConv(model)

    @staticmethod
    def down_block(x, num_filters: int = 64, kernel: int = 3):
        # Strided separable convolution halves the spatial resolution. Keras
        # rejects strides > 1 combined with dilation_rate > 1, so the
        # downsampling convolution is not dilated.
        x = SeparableConv2D(num_filters, kernel, padding='same', strides=2)(x)

        # Residual block on the downsampled feature map.
        out = SeparableConv2D(num_filters, kernel, padding='same')(x)
        out = Activation('relu')(out)
        out = SeparableConv2D(num_filters, kernel, padding='same')(out)
        out = Add()([out, x])

        return Activation('relu')(out), x

    @staticmethod
    def up_block(x, skip, num_filters: int = 64, num_filters_next: int = 64, kernel: int = 3):
        # Fuse the decoder state with the encoder skip connection.
        concat = Concatenate()([x, skip])

        # Residual block with a dilated first convolution.
        out = SeparableConv2D(num_filters, kernel, padding='same', dilation_rate=2)(concat)
        out = Activation('relu')(out)
        out = SeparableConv2D(num_filters, kernel, padding='same')(out)
        out = Add()([out, x])
        out = Activation('relu')(out)

        # Upsample towards the next (shallower) stage's filter width.
        out = UpSampling2D((2, 2))(out)
        out = SeparableConv2D(num_filters_next, kernel, padding='same')(out)

        return Activation('relu')(out)

    @staticmethod
    def bottleneck(x, num_filters: int = 64, kernel: int = 3):
        # Residual block at the lowest resolution, again with dilation to
        # widen the receptive field.
        out = SeparableConv2D(num_filters, kernel, padding='same', dilation_rate=2)(x)
        out = Activation('relu')(out)
        out = SeparableConv2D(num_filters, kernel, padding='same')(out)
        out = Add()([out, x])

        return Activation('relu')(out)
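

# Note (sketch): with len(filters) stride-2 encoder stages and matching 2x
# upsampling steps, the skip shapes line up when both spatial input sides are
# divisible by 2 ** len(filters), e.g.:
#
#   unet = UNet_ResNet_Attention_SeparableConv.build_model(
#       (256, 256, 3), (16, 32, 64), (3, 3, 3))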


class UNet_resnet(BaseUNet):
    """
    UNet architecture with ResNet blocks.
    """

    @staticmethod
    def build_model(input_size: Tuple[int, int, int], filters: Tuple, kernels: Tuple):
        p0 = Input(shape=input_size)
        conv_outputs = []
        first_layer = Conv2D(filters[0], kernels[0], padding='same')(p0)
        int_layer = first_layer
        # Encoder: collect the skip connections while downsampling.
        for i, f in enumerate(filters):
            int_layer, skip = UNet_resnet.down_block(int_layer, f, kernels[i])
            conv_outputs.append(skip)

        int_layer = UNet_resnet.bottleneck(int_layer, filters[-1], kernels[-1])

        # Decoder: mirror the encoder stages in reverse order.
        conv_outputs = list(reversed(conv_outputs))
        reversed_filters = list(reversed(filters))
        reversed_kernels = list(reversed(kernels))
        for i, f in enumerate(reversed_filters):
            if i + 1 < len(reversed_filters):
                num_filters_next = reversed_filters[i + 1]
                num_kernels_next = reversed_kernels[i + 1]
            else:
                num_filters_next = f
                num_kernels_next = reversed_kernels[i]
            int_layer = UNet_resnet.up_block(int_layer, conv_outputs[i], f, num_filters_next, num_kernels_next)

        int_layer = Concatenate()([first_layer, int_layer])
        int_layer = Conv2D(filters[0], kernels[0], padding="same", activation="relu")(int_layer)
        outputs = Conv2D(3, (1, 1), padding="same", activation="sigmoid")(int_layer)
        model = Model(p0, outputs)
        return UNet_resnet(model)

    @staticmethod
    def down_block(x, num_filters: int = 64, kernel: int = 3):
        # Strided convolution halves the spatial resolution.
        x = Conv2D(num_filters, kernel, padding='same', strides=2)(x)

        # Residual block on the downsampled feature map.
        out = Conv2D(num_filters, kernel, padding='same')(x)
        out = Activation('relu')(out)
        out = Conv2D(num_filters, kernel, padding='same')(out)
        out = Add()([out, x])

        return Activation('relu')(out), x

    @staticmethod
    def up_block(x, skip, num_filters: int = 64, num_filters_next: int = 64, kernel: int = 3):
        # Fuse the decoder state with the encoder skip connection.
        concat = Concatenate()([x, skip])

        # Residual block.
        out = Conv2D(num_filters, kernel, padding='same')(concat)
        out = Activation('relu')(out)
        out = Conv2D(num_filters, kernel, padding='same')(out)
        out = Add()([out, x])
        out = Activation('relu')(out)

        # The skip connection is concatenated a second time before the
        # transposed convolution performs the 2x upsampling.
        concat = Concatenate()([out, skip])
        out = Conv2DTranspose(num_filters_next, kernel, padding='same', strides=2)(concat)
        out = Conv2D(num_filters_next, kernel, padding='same')(out)

        return Activation('relu')(out)

    @staticmethod
    def bottleneck(x, filters, kernel: int = 3):
        # Single named convolution at the lowest resolution.
        x = Conv2D(filters, kernel, padding='same', name='bottleneck')(x)

        return Activation('relu')(x)


class UNetDefault(BaseUNet):
    """
    UNet architecture for image segmentation, following these GitHub
    notebooks:
    https://github.com/nikhilroxtomar/UNet-Segmentation-in-Keras-TensorFlow/blob/master/unet-segmentation.ipynb
    https://github.com/nikhilroxtomar/Polyp-Segmentation-using-UNET-in-TensorFlow-2.0
    """

    @staticmethod
    def build_model(input_size: Tuple[int, int, int], filters: Tuple, kernels: Tuple, use_embedding: bool = True):
        # Note: kernels is accepted for interface compatibility; this variant
        # uses its blocks' fixed 3x3 kernels.
        p0 = Input(input_size)

        if use_embedding:
            # Frozen MobileNetV2 backbone supplies an ImageNet feature
            # embedding that is concatenated onto the bottleneck.
            mobilenet_model = tf.keras.applications.MobileNetV2(
                input_shape=input_size, include_top=False, weights='imagenet'
            )
            mobilenet_model.trainable = False
            mn1 = mobilenet_model(p0)
            # Redistribute the backbone output to a 16x16 grid; for a
            # 256x256x3 input its 8x8x1280 tensor holds exactly 16*16*320 values.
            mn1 = Reshape((16, 16, 320))(mn1)

        conv_outputs = []
        int_layer = p0

        # Encoder: keep each stage's pre-pooling activation for the skips.
        for f in filters:
            conv_output, int_layer = UNetDefault.down_block(int_layer, f)
            conv_outputs.append(conv_output)

        int_layer = UNetDefault.bottleneck(int_layer, filters[-1])

        if use_embedding:
            int_layer = Concatenate()([int_layer, mn1])

        # Decoder: mirror the encoder stages in reverse order.
        conv_outputs = list(reversed(conv_outputs))
        for i, f in enumerate(reversed(filters)):
            int_layer = UNetDefault.up_block(int_layer, conv_outputs[i], f)

        int_layer = Conv2D(filters[0] // 2, 3, padding="same", activation="relu")(int_layer)
        outputs = Conv2D(3, (1, 1), padding="same", activation="sigmoid")(int_layer)
        model = Model(p0, outputs)
        return UNetDefault(model)

    @staticmethod
    def down_block(x, filters, kernel_size=(3, 3), padding="same", strides=1):
        # Convolution followed by 2x2 max pooling; the pre-pooling activation
        # is returned for the decoder's skip connection.
        c = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(x)
        p = MaxPool2D((2, 2), (2, 2))(c)
        return c, p

    @staticmethod
    def up_block(x, skip, filters, kernel_size=(3, 3), padding="same", strides=1):
        # Upsample, then fuse with the skip connection from the encoder.
        us = UpSampling2D((2, 2))(x)
        c = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(us)
        concat = Concatenate()([c, skip])
        c = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(concat)
        return c

    @staticmethod
    def bottleneck(x, filters, kernel_size=(3, 3), padding="same", strides=1):
        c = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(x)
        return c
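

# Note (sketch): the ImageNet-embedding path assumes the bottleneck comes out
# at 16x16, e.g. a 256x256x3 input with four pooling stages:
#
#   unet = UNet.DEFAULT_IMAGENET_EMBEDDING.build_model((256, 256, 3), (16, 32, 64, 128))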


if __name__ == "__main__":
    filters = (64, 128, 128, 256, 256, 512)
    kernels = (7, 7, 7, 3, 3, 3)
    input_image_size = (256, 256, 3)

    # Six pooling stages shrink 256x256 down to 4x4, which is incompatible
    # with the 16x16 ImageNet embedding, so the embedding is disabled here.
    unet = UNetDefault.build_model(input_size=input_image_size, filters=filters,
                                   kernels=kernels, use_embedding=False)
    unet.get_model().summary()
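
    # The same wrapper is available through the enum dispatch, e.g.:
    #   UNet.DEFAULT.build_model(input_size=input_image_size, filters=filters, kernels=kernels)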