from keras.layers import Conv1D, Conv2D, Conv3D |
from keras.layers import Conv2D, LayerNormalization, Layer |
from tensorflow.keras.layers import Reshape, Activation, Softmax, Permute, Add, Dot |
from keras.optimizers import RMSprop |
from tensorflow.keras import layers, models, regularizers |
from tensorflow.keras.applications import EfficientNetV2B0 |
from tensorflow.keras.callbacks import ModelCheckpoint |
from keras import ops |
import tensorflow as tf |
import numpy as np |
from sklearn.metrics import accuracy_score,precision_score,recall_score,f1_score,cohen_kappa_score,roc_auc_score,classification_report |
from gcg import config |
from gcg.utils import CustomException, logging |
import sys |
class GlobalContextAttention(tf.keras.layers.Layer):
    """Global-context attention block.

    For each sample the layer:
      1. scores every spatial/temporal position with a 1x1 convolution,
      2. softmax-normalises those scores over the positions,
      3. pools the input features with the resulting weights into one value
         per channel (the "global context"),
      4. passes the context through a 1x1 bottleneck transform
         (conv -> ReLU -> conv -> ``transform_activation``), and
      5. adds the transformed context back onto every input position
         (broadcast by the ``Add`` layer).

    Accepts rank-3 (temporal), rank-4 (spatial) or rank-5 (spatio-temporal)
    inputs. The channel axis follows ``tf.keras.backend.image_data_format()``
    (axis 1 for ``channels_first``, the last axis otherwise). The output has
    the same shape as the input.
    """

    def __init__(self, reduction_ratio=8, transform_activation='linear', **kwargs):
        """
        Initializes the GlobalContextAttention layer.
        Args:
            reduction_ratio (int): Reduces the input filters by this factor for the
                bottleneck block of the transform submodule.
            transform_activation (str): Activation function to apply to the output
                of the transform block.
            **kwargs: Additional keyword arguments for the Layer class.
        """
        super(GlobalContextAttention, self).__init__(**kwargs)
        self.reduction_ratio = reduction_ratio
        self.transform_activation = transform_activation

    def build(self, input_shape):
        """
        Builds the layer by creating all sub-layers.
        Args:
            input_shape: Shape of the input tensor (rank 3, 4 or 5, batch first).
        Raises:
            ValueError: If the input rank is not 3, 4 or 5.
        """
        self.channel_dim = 1 if tf.keras.backend.image_data_format() == 'channels_first' else -1
        self.rank = len(input_shape)
        if self.rank not in [3, 4, 5]:
            raise ValueError('Input dimension has to be either 3 (temporal), 4 (spatial), or 5 (spatio-temporal)')
        self.channels = input_shape[self.channel_dim]

        # One attention logit per position.
        self.conv_context = self._convND_layer(1)
        # Clamp to >= 1 filter so inputs narrower than reduction_ratio do not
        # produce a zero-filter convolution.
        self.conv_transform_bottleneck = self._convND_layer(max(1, self.channels // self.reduction_ratio))
        self.conv_transform_output = self._convND_layer(self.channels)

        spatial_axis = self._get_flat_spatial_dim()
        # Normalise the logits over the flattened positions (not over the
        # singleton channel axis of the context tensor).
        self.softmax = Softmax(axis=spatial_axis)
        # Contract the positions of the features with the attention weights:
        # (B, C, S) x (B, 1, S) or (B, S, C) x (B, S, 1) -> (B, C, 1).
        self.dot = Dot(axes=(spatial_axis, spatial_axis))

        # Stateless reshapes created once here instead of on every call().
        self.flatten_input = Reshape(self._flatten_target_shape(self.channels))
        self.flatten_context = Reshape(self._flatten_target_shape(1))
        self.expand_context = Reshape(self._expand_target_shape())

        self.activation_relu = Activation('relu')
        self.activation_transform = Activation(self.transform_activation)
        self.add = Add()
        super(GlobalContextAttention, self).build(input_shape)

    def call(self, inputs):
        """
        Performs the forward pass of the layer.
        Args:
            inputs: Input tensor.
        Returns:
            Tensor of the same shape as ``inputs`` with the transformed global
            context added to every position.
        """
        input_flat = self.flatten_input(inputs)        # (B, C, S) or (B, S, C)
        context = self.conv_context(inputs)
        context = self.flatten_context(context)        # (B, 1, S) or (B, S, 1)
        context = self.softmax(context)                # weights sum to 1 over S
        context = self.dot([input_flat, context])      # (B, C, 1)
        context = self.expand_context(context)         # (B, C, 1, ..) or (B, 1, .., C)
        transform = self.conv_transform_bottleneck(context)
        transform = self.activation_relu(transform)
        transform = self.conv_transform_output(transform)
        transform = self.activation_transform(transform)
        return self.add([inputs, transform])

    def _convND_layer(self, filters):
        """
        Creates a bias-free 1x1 Conv1D, Conv2D, or Conv3D layer matching the input rank.
        Args:
            filters (int): Number of filters for the convolutional layer.
        Returns:
            A Conv1D, Conv2D, or Conv3D layer.
        """
        conv_cls = {3: Conv1D, 4: Conv2D, 5: Conv3D}[self.rank]
        return conv_cls(filters, kernel_size=1, padding='same', use_bias=False, kernel_initializer='he_normal')

    def _flatten_target_shape(self, channels):
        """
        Returns the non-batch target shape that merges all spatial axes into one.
        Args:
            channels (int): Size of the channel axis of the tensor being flattened.
        Returns:
            ``(channels, -1)`` for channels_first, ``(-1, channels)`` otherwise.
        """
        return (channels, -1) if self.channel_dim == 1 else (-1, channels)

    def _expand_target_shape(self):
        """
        Returns the non-batch target shape that turns the (B, C, 1) context into
        a tensor with singleton spatial axes that broadcasts against the input.
        """
        ones = (1,) * (self.rank - 2)
        return (self.channels, *ones) if self.channel_dim == 1 else (*ones, self.channels)

    def _get_flat_spatial_dim(self):
        """
        Returns the axis holding the merged spatial positions in a flattened
        (rank-3, batch-first) tensor: 2 for channels_first, 1 otherwise.
        """
        return 2 if self.channel_dim == 1 else 1

    def get_config(self):
        """
        Returns the configuration of the layer for serialization.
        Returns:
            A dictionary containing the layer configuration.
        """
        config = super(GlobalContextAttention, self).get_config()
        config.update({
            'reduction_ratio': self.reduction_ratio,
            'transform_activation': self.transform_activation,
        })
        return config
class AttentionGate(Layer):
    """Additive attention gate.

    Takes ``[xl, g]``, where ``xl`` is a feature map and ``g`` is a gating
    signal with the same spatial size. It computes
    ``sigmoid(psi(LayerNorm(relu(conv_xl(xl) + conv_g(g) + bxg))) + bpsi)``,
    which gives one coefficient per position, and returns ``xl`` scaled by
    that coefficient. The output has ``xl``'s shape.

    NOTE(review): this assumes channels-last inputs. ``LayerNormalization``
    uses axis=-1, and both biases broadcast over the last axis.
    """

    def __init__(self, filters, **kwargs):
        """
        Args:
            filters (int): Channel width of the intermediate attention map.
            **kwargs: Additional keyword arguments for the Layer class.
        """
        super(AttentionGate, self).__init__(**kwargs)
        self.filters = filters

    def build(self, input_shape):
        """
        Creates and builds the sub-layers and bias weights.
        Args:
            input_shape: ``[xl_shape, g_shape]``.
        """
        xl_shape, g_shape = input_shape
        # psi and layer_norm operate on the intermediate map, which keeps
        # xl's leading dims but has `filters` channels, not xl's channels.
        att_shape = tuple(xl_shape)[:-1] + (self.filters,)

        self.conv_xl = Conv2D(self.filters, kernel_size=(1, 1), strides=(1, 1), padding='same', activation='relu')
        self.conv_g = Conv2D(self.filters, kernel_size=(1, 1), strides=(1, 1), padding='same', activation='relu')
        self.psi = Conv2D(1, kernel_size=(1, 1), strides=(1, 1), padding='same', activation='linear')
        self.layer_norm = LayerNormalization(axis=-1)
        self.conv_xl.build(xl_shape)
        self.conv_g.build(g_shape)
        self.psi.build(att_shape)
        self.layer_norm.build(att_shape)
        self.bxg = self.add_weight(name='bxg',
                                   shape=(self.filters,),
                                   initializer='zeros',
                                   trainable=True)
        self.bpsi = self.add_weight(name='bpsi',
                                    shape=(1,),
                                    initializer='zeros',
                                    trainable=True)
        super(AttentionGate, self).build(input_shape)

    def call(self, inputs):
        """
        Args:
            inputs: ``[xl, g]`` tensors.
        Returns:
            ``xl`` multiplied by the per-position attention coefficients.
        """
        xl, g = inputs
        # keras.ops is used instead of tf.keras.backend.relu/sigmoid, which
        # do not exist in Keras 3.
        att = ops.relu(self.conv_xl(xl) + self.conv_g(g) + self.bxg)
        att = self.layer_norm(att)
        att = ops.sigmoid(self.psi(att) + self.bpsi)   # (..., 1), values in (0, 1)
        return att * xl

    def compute_output_shape(self, input_shape):
        """The gated output keeps the shape of ``xl``."""
        return input_shape[0]

    def get_config(self):
        """Returns the layer configuration for serialization."""
        config = super(AttentionGate, self).get_config()
        config.update({'filters': self.filters})
        return config
class GCRMSprop(RMSprop):
    """RMSprop with gradient centralization.

    Before each update, every gradient of rank > 1 (for example Dense or Conv
    kernels) has its mean over all axes except the last subtracted. Rank-0/1
    gradients and missing (``None``) gradients pass through unchanged.
    """

    @staticmethod
    def _centralize(grad):
        """Return ``grad`` minus its mean over every axis but the last."""
        if grad is None or len(grad.shape) <= 1:
            return grad
        axis = tuple(range(len(grad.shape) - 1))
        return grad - ops.mean(grad, axis=axis, keepdims=True)

    def apply(self, grads, trainable_variables=None):
        """Centralize ``grads``, then delegate to ``RMSprop.apply``.

        Keras 3 optimizers route ``apply_gradients`` (the method ``Model.fit``
        calls) through ``apply``. They never call ``get_gradients``, so this
        override is what actually enables centralization during training.
        """
        centered = [self._centralize(g) for g in grads]
        return super().apply(centered, trainable_variables)

    def get_gradients(self, loss, params):
        """Legacy hook: return centralized gradients of ``loss`` w.r.t. ``params``.

        This only works when the parent optimizer provides ``get_gradients``
        (legacy Keras 2 optimizers). Keras 3 optimizers do not, so there it
        raises ``AttributeError``.
        """
        return [self._centralize(g) for g in super().get_gradients(loss, params)]
def build_model(input_shape, num_classes):
    """Build and compile the EfficientNetV2B0 + attention classifier.

    The pipeline runs an ImageNet-initialised EfficientNetV2B0 backbone
    (without its top) into GlobalContextAttention. An AttentionGate then gates
    the backbone features with the context features. The head is global
    average pooling, followed by two Dense/BatchNorm/Dropout stages
    (512 units with dropout 0.3, then 256 units with dropout 0.2), and a
    softmax output layer.

    Args:
        input_shape: Input shape passed to ``EfficientNetV2B0``.
        num_classes: Number of units in the final softmax layer.

    Returns:
        A Keras ``Model`` compiled with ``GCRMSprop(learning_rate=1e-4)``,
        categorical cross-entropy loss, and an accuracy metric.

    Raises:
        CustomException: Wraps any exception raised while building.
    """
    try:
        logging.info("Loading weights of EfficientNetV2B0...")
        backbone = EfficientNetV2B0(weights='imagenet', include_top=False, input_shape=input_shape)
        features = backbone.output

        logging.info("Initializing Global Context Attention...")
        gc_features = GlobalContextAttention()(features)

        logging.info("Initializing AttentionGate...")
        gated = AttentionGate(features.shape[-1])([features, gc_features])

        head = layers.GlobalAveragePooling2D()(gated)
        # Two identical Dense -> BatchNorm -> Dropout stages, widest first.
        for units, drop_rate in ((512, 0.3), (256, 0.2)):
            head = layers.Dense(units, activation='relu', kernel_regularizer=regularizers.l2(0.005))(head)
            head = layers.BatchNormalization()(head)
            head = layers.Dropout(drop_rate)(head)
        predictions = layers.Dense(num_classes, activation='softmax')(head)

        model = models.Model(inputs=backbone.input, outputs=predictions)
        model.compile(
            optimizer=GCRMSprop(learning_rate=1e-4),
            loss='categorical_crossentropy',
            metrics=['accuracy'],
        )
        logging.info("Model Built Successfully!")
        return model
    except Exception as e:
        raise CustomException(e, sys)
def train_model(model, X_train, X_test, y_train, y_test):
    """Fit ``model`` and checkpoint the best epoch by validation accuracy.

    ``X_test``/``y_test`` serve as the validation set. A ``ModelCheckpoint``
    writes to ``config.model_path`` only when ``val_accuracy`` improves.
    Batch size and epoch count come from ``config.batch_size`` and
    ``config.EPOCHS``.

    Returns:
        The object returned by ``model.fit`` (the training history).

    Raises:
        CustomException: Wraps any exception raised during training.
    """
    try:
        best_ckpt = ModelCheckpoint(
            config.model_path,
            monitor='val_accuracy',
            verbose=1,
            save_best_only=True,
            mode='max',
        )
        logging.info(f"Training network for {config.EPOCHS} epochs...")
        return model.fit(
            X_train,
            y_train,
            batch_size=config.batch_size,
            validation_data=(X_test, y_test),
            epochs=config.EPOCHS,
            callbacks=[best_ckpt],
        )
    except Exception as e:
        raise CustomException(e, sys)
def evaluate_model(model, X_test, y_test):
    """Score ``model`` on a held-out set and log classification metrics.

    Metrics are accuracy, macro precision/recall/F1, quadratic-weighted Cohen's
    kappa, and ROC AUC (macro one-vs-rest for multiclass). The function also
    logs a per-class report.

    Args:
        model: Fitted model whose ``predict`` returns per-class scores with
            classes on the last axis.
        X_test: Inputs passed to ``model.predict``.
        y_test: One-hot targets (classes on the last axis) or a 1-D array of
            integer class ids.

    Returns:
        dict mapping metric name to its value (fractions, not percentages).

    Raises:
        CustomException: Wraps any exception raised during evaluation.
    """
    try:
        y_score = np.asarray(model.predict(X_test))
        y_pred = np.argmax(y_score, axis=-1)
        y_test = np.asarray(y_test)
        # Accept integer labels as well as one-hot targets.
        Y_test = y_test if y_test.ndim == 1 else np.argmax(y_test, axis=-1)

        # For two classes sklearn takes the binary code path, which expects a
        # 1-D positive-class score. Passing the (n, 2) softmax output raises.
        auc_input = y_score[:, 1] if y_score.shape[-1] == 2 else y_score

        metrics = {
            'accuracy': accuracy_score(Y_test, y_pred),
            'macro_precision': precision_score(Y_test, y_pred, average='macro'),
            'macro_recall': recall_score(Y_test, y_pred, average='macro'),
            'macro_f1': f1_score(Y_test, y_pred, average='macro'),
            'quadratic_kappa': cohen_kappa_score(Y_test, y_pred, weights='quadratic'),
            'roc_auc': roc_auc_score(Y_test, auc_input, average='macro', multi_class='ovr'),
        }

        logging.info(f"Accuracy: {round(metrics['accuracy']*100,2)}")
        logging.info(f"Macro Precision: {round(metrics['macro_precision']*100,2)}")
        logging.info(f"Macro Recall: {round(metrics['macro_recall']*100,2)}")
        logging.info(f"Macro F1-Score: {round(metrics['macro_f1']*100,2)}")
        logging.info(f"Quadratic Kappa Score: {round(metrics['quadratic_kappa']*100,2)}")
        logging.info(f"ROC AUC Score: {round(metrics['roc_auc']*100,2)}")
        logging.info(classification_report(Y_test, y_pred, digits=4))
        return metrics
    except Exception as e:
        raise CustomException(e, sys)