from keras.layers import Conv1D, Conv2D, Conv3D, LayerNormalization, Layer
from tensorflow.keras.layers import Reshape, Activation, Softmax, Permute, Add, Dot
from keras.optimizers import RMSprop
from tensorflow.keras import layers, models, regularizers
from tensorflow.keras.applications import EfficientNetV2B0
from tensorflow.keras.callbacks import ModelCheckpoint
from keras import ops
import tensorflow as tf
import numpy as np
from sklearn.metrics import accuracy_score,precision_score,recall_score,f1_score,cohen_kappa_score,roc_auc_score,classification_report
from gcg import config
from gcg.utils import CustomException, logging
import sys
class GlobalContextAttention(tf.keras.layers.Layer):
def __init__(self, reduction_ratio=8, transform_activation='linear', **kwargs):
"""
Initializes the GlobalContextAttention layer.
Args:
reduction_ratio (int): Reduces the input filters by this factor for the
bottleneck block of the transform submodule.
transform_activation (str): Activation function to apply to the output
of the transform block.
**kwargs: Additional keyword arguments for the Layer class.
"""
super(GlobalContextAttention, self).__init__(**kwargs)
self.reduction_ratio = reduction_ratio
self.transform_activation = transform_activation
def build(self, input_shape):
"""
Builds the layer by initializing weights and sub-layers.
Args:
input_shape: Shape of the input tensor.
"""
self.channel_dim = 1 if tf.keras.backend.image_data_format() == 'channels_first' else -1
self.rank = len(input_shape)
# Validate input rank
if self.rank not in [3, 4, 5]:
raise ValueError('Input dimension has to be either 3 (temporal), 4 (spatial), or 5 (spatio-temporal)')
# Calculate the number of channels
self.channels = input_shape[self.channel_dim]
# Initialize sub-layers
self.conv_context = self._convND_layer(1) # Context modelling block
self.conv_transform_bottleneck = self._convND_layer(self.channels // self.reduction_ratio) # Transform bottleneck
self.conv_transform_output = self._convND_layer(self.channels) # Transform output block
# Softmax and Dot layers
self.softmax = Softmax(axis=self._get_flat_spatial_dim())
self.dot = Dot(axes=(1, 1)) # Dot product over the flattened spatial dimensions
# Activation layers
self.activation_relu = Activation('relu')
self.activation_transform = Activation(self.transform_activation)
# Add layer for final output
self.add = Add()
super(GlobalContextAttention, self).build(input_shape)
def call(self, inputs):
"""
Performs the forward pass of the layer.
Args:
inputs: Input tensor.
Returns:
Output tensor with global context attention applied.
"""
# Context Modelling Block
input_flat = self._spatial_flattenND(inputs) # [B, spatial_dims, C]
context = self.conv_context(inputs) # [B, spatial_dims, 1]
context = self._spatial_flattenND(context) # [B, spatial_dims, 1]
context = self.softmax(context) # Apply softmax over spatial_dims
context = self.dot([input_flat, context]) # [B, C, 1]
context = self._spatial_expandND(context) # [B, C, 1, 1, ...]
# Transform Block
transform = self.conv_transform_bottleneck(context) # [B, C // R, 1, 1, ...]
transform = self.activation_relu(transform)
transform = self.conv_transform_output(transform) # [B, C, 1, 1, ...]
transform = self.activation_transform(transform)
# Apply context transform
        out = self.add([inputs, transform])  # Broadcast add; same shape as inputs
return out
def _convND_layer(self, filters):
"""
Creates a Conv1D, Conv2D, or Conv3D layer based on the input rank.
Args:
filters (int): Number of filters for the convolutional layer.
Returns:
A Conv1D, Conv2D, or Conv3D layer.
"""
if self.rank == 3:
return Conv1D(filters, kernel_size=1, padding='same', use_bias=False, kernel_initializer='he_normal')
elif self.rank == 4:
return Conv2D(filters, kernel_size=1, padding='same', use_bias=False, kernel_initializer='he_normal')
elif self.rank == 5:
return Conv3D(filters, kernel_size=1, padding='same', use_bias=False, kernel_initializer='he_normal')
def _spatial_flattenND(self, ip):
"""
Flattens the spatial dimensions of the input tensor.
Args:
ip: Input tensor.
Returns:
Flattened tensor.
"""
if self.rank == 3:
return ip # Identity op for rank 3
else:
shape = (ip.shape[self.channel_dim], -1) if self.channel_dim == 1 else (-1, ip.shape[-1])
return Reshape(shape)(ip)
def _spatial_expandND(self, ip):
"""
Expands the spatial dimensions of the input tensor.
Args:
ip: Input tensor.
Returns:
Expanded tensor.
"""
if self.rank == 3:
            return Permute((2, 1))(ip)  # Transpose [B, C, 1] -> [B, 1, C] for rank-3 (temporal) inputs
else:
shape = (-1, *(1 for _ in range(self.rank - 2))) if self.channel_dim == 1 else (*(1 for _ in range(self.rank - 2)), -1)
return Reshape(shape)(ip)
def _get_flat_spatial_dim(self):
"""
Returns the axis for flattening spatial dimensions.
Returns:
Axis for flattening.
"""
return 1 if self.channel_dim == 1 else -1
def get_config(self):
"""
Returns the configuration of the layer for serialization.
Returns:
A dictionary containing the layer configuration.
"""
config = super(GlobalContextAttention, self).get_config()
config.update({
'reduction_ratio': self.reduction_ratio,
'transform_activation': self.transform_activation,
})
return config
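
# Usage sketch for GlobalContextAttention (shapes below are illustrative, not taken
# from this repo's config):
#
#   fmaps = tf.keras.Input(shape=(7, 7, 1280))                # [B, H, W, C] feature map
#   refined = GlobalContextAttention(reduction_ratio=8)(fmaps)
#   # `refined` keeps the input shape; the transformed global context is broadcast-added.
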
class AttentionGate(Layer):
def __init__(self, filters, **kwargs):
self.filters = filters
super(AttentionGate, self).__init__(**kwargs)
def build(self, input_shape):
# Create trainable parameters for attention gate
self.conv_xl = Conv2D(self.filters, kernel_size=(1, 1), strides=(1, 1), padding='same', activation='relu')
self.conv_g = Conv2D(self.filters, kernel_size=(1, 1), strides=(1, 1), padding='same', activation='relu')
self.psi = Conv2D(1, kernel_size=(1, 1), strides=(1, 1), padding='same', activation='linear')
self.layer_norm = LayerNormalization(axis=-1)
# Build the child layers
self.conv_xl.build(input_shape[0]) # Build conv_xl with the shape of xl
self.conv_g.build(input_shape[1]) # Build conv_g with the shape of g
self.psi.build(input_shape[0]) # Build psi with the shape of xl
self.layer_norm.build(input_shape[0]) # Build layer_norm with the shape of xl
# Add trainable weights
self.bxg = self.add_weight(name='bxg',
shape=(self.filters,),
initializer='zeros',
trainable=True)
self.bpsi = self.add_weight(name='bpsi',
shape=(1,),
initializer='zeros',
trainable=True)
super(AttentionGate, self).build(input_shape)
def call(self, inputs):
xl, g = inputs
# Apply convolutional operations
xl_conv = self.conv_xl(xl)
g_conv = self.conv_g(g)
# Compute additive attention
att = tf.keras.backend.relu(xl_conv + g_conv + self.bxg)
att = self.layer_norm(att) # Add LayerNormalization
att = self.psi(att) + self.bpsi
att = tf.keras.backend.sigmoid(att)
# Apply attention gate
x_hat = att * xl
return x_hat
def compute_output_shape(self, input_shape):
return input_shape[0]
def get_config(self):
config = super(AttentionGate, self).get_config()
config.update({'filters': self.filters})
return config
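
# Usage sketch for AttentionGate (hypothetical tensors; in this module it gates the
# backbone features `xl` with the global-context features `g`):
#
#   xl = tf.keras.Input(shape=(7, 7, 1280))                   # local features
#   g = GlobalContextAttention()(xl)                          # gating signal
#   gated = AttentionGate(filters=1280)([xl, g])              # same shape as `xl`
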
class GCRMSprop(RMSprop):
def get_gradients(self, loss, params):
        # Override get_gradients() to apply Gradient Centralization: subtract the
        # mean over all but the last axis from each multi-dimensional gradient.
grads = []
        gradients = super().get_gradients(loss, params)
for grad in gradients:
grad_len = len(grad.shape)
if grad_len > 1:
axis = list(range(grad_len - 1))
                grad -= ops.mean(grad, axis=axis, keepdims=True)
grads.append(grad)
return grads
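
# Illustration of the centralization step on a single gradient tensor (assumes the
# Keras 3 `ops` API imported above; values are random and for demonstration only):
#
#   g = ops.convert_to_tensor(np.random.randn(3, 3, 16, 32).astype("float32"))
#   g_centered = g - ops.mean(g, axis=[0, 1, 2], keepdims=True)  # zero mean per output filter
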
# Build the model
def build_model(input_shape, num_classes):
try:
logging.info("Loading weights of EfficientNetV2B0...")
base_model = EfficientNetV2B0(weights='imagenet', include_top=False, input_shape=input_shape)
fmaps = base_model.output
logging.info("Initializing Global Context Attention...")
context_fmaps = GlobalContextAttention()(fmaps)
logging.info("Initializing AttentionGate...")
att_fmaps = AttentionGate(fmaps.shape[-1])([fmaps, context_fmaps])
x = layers.GlobalAveragePooling2D()(att_fmaps)
# First Dense Layer
x = layers.Dense(512, activation='relu', kernel_regularizer=regularizers.l2(0.005))(x)
x = layers.BatchNormalization()(x)
x = layers.Dropout(0.3)(x)
# Second Dense Layer
x = layers.Dense(256, activation='relu', kernel_regularizer=regularizers.l2(0.005))(x)
x = layers.BatchNormalization()(x)
x = layers.Dropout(0.2)(x)
output = layers.Dense(num_classes, activation='softmax')(x)
model = models.Model(inputs=base_model.input, outputs=output)
model.compile(optimizer=GCRMSprop(learning_rate=1e-4), loss='categorical_crossentropy', metrics=['accuracy'])
logging.info("Model Built Successfully!")
return model
except Exception as e:
raise CustomException(e, sys)
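
# Example call (input shape and class count are placeholders; the real values come
# from gcg.config and the dataset):
#
#   model = build_model(input_shape=(224, 224, 3), num_classes=5)
#   model.summary()
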
# Train the model
def train_model(model, X_train, X_test, y_train, y_test):
try:
# Define the necessary callbacks
checkpoint = ModelCheckpoint(config.model_path, monitor='val_accuracy', verbose=1, save_best_only=True, mode='max')
callbacks = [checkpoint]
logging.info(f"Training network for {config.EPOCHS} epochs...")
hist = model.fit(X_train, y_train, batch_size=config.batch_size,
validation_data=(X_test, y_test),
epochs=config.EPOCHS, callbacks=callbacks)
return hist
except Exception as e:
raise CustomException(e, sys)
# Evaluate the model
def evaluate_model(model, X_test, y_test):
try:
y_score = model.predict(X_test)
y_pred = np.argmax(y_score, axis=-1)
Y_test = np.argmax(y_test, axis=-1)
acc = accuracy_score(Y_test,y_pred)
mpre = precision_score(Y_test,y_pred,average='macro')
mrecall = recall_score(Y_test,y_pred,average='macro')
mf1 = f1_score(Y_test,y_pred,average='macro')
kappa = cohen_kappa_score(Y_test,y_pred,weights='quadratic')
auc = roc_auc_score(Y_test, y_score, average='macro', multi_class='ovr')
logging.info(f"Accuracy: {round(acc*100,2)}")
logging.info(f"Macro Precision: {round(mpre*100,2)}")
logging.info(f"Macro Recall: {round(mrecall*100,2)}")
logging.info(f"Macro F1-Score: {round(mf1*100,2)}")
logging.info(f"Quadratic Kappa Score: {round(kappa*100,2)}")
logging.info(f"ROC AUC Score: {round(auc*100,2)}")
logging.info(classification_report(Y_test, y_pred, digits=4))
except Exception as e:
raise CustomException(e, sys)
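
# End-to-end sketch, assuming X_train/X_test are preprocessed image arrays and
# y_train/y_test are one-hot encoded labels (variable names are illustrative):
#
#   model = build_model(input_shape=X_train.shape[1:], num_classes=y_train.shape[1])
#   hist = train_model(model, X_train, X_test, y_train, y_test)
#   evaluate_model(model, X_test, y_test)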