from keras.layers import Conv1D, Conv2D, Conv3D, LayerNormalization, Layer
from tensorflow.keras.layers import Reshape, Activation, Softmax, Permute, Add, Dot
from keras.optimizers import RMSprop
from tensorflow.keras import layers, models, regularizers
from tensorflow.keras.applications import EfficientNetV2B0
from tensorflow.keras.callbacks import ModelCheckpoint
from keras import ops
import tensorflow as tf
import numpy as np
from sklearn.metrics import (accuracy_score, precision_score, recall_score, f1_score,
                             cohen_kappa_score, roc_auc_score, classification_report)
from gcg import config
from gcg.utils import CustomException, logging
import sys
class GlobalContextAttention(tf.keras.layers.Layer):
def __init__(self, reduction_ratio=8, transform_activation='linear', **kwargs):
"""
Initializes the GlobalContextAttention layer.
Args:
reduction_ratio (int): Reduces the input filters by this factor for the
bottleneck block of the transform submodule.
transform_activation (str): Activation function to apply to the output
of the transform block.
**kwargs: Additional keyword arguments for the Layer class.
"""
super(GlobalContextAttention, self).__init__(**kwargs)
self.reduction_ratio = reduction_ratio
self.transform_activation = transform_activation
def build(self, input_shape):
"""
Builds the layer by initializing weights and sub-layers.
Args:
input_shape: Shape of the input tensor.
"""
self.channel_dim = 1 if tf.keras.backend.image_data_format() == 'channels_first' else -1
self.rank = len(input_shape)
# Validate input rank
if self.rank not in [3, 4, 5]:
raise ValueError('Input dimension has to be either 3 (temporal), 4 (spatial), or 5 (spatio-temporal)')
# Calculate the number of channels
self.channels = input_shape[self.channel_dim]
# Initialize sub-layers
self.conv_context = self._convND_layer(1) # Context modelling block
self.conv_transform_bottleneck = self._convND_layer(self.channels // self.reduction_ratio) # Transform bottleneck
self.conv_transform_output = self._convND_layer(self.channels) # Transform output block
# Softmax and Dot layers
self.softmax = Softmax(axis=self._get_flat_spatial_dim())
        self.dot = Dot(axes=self._get_flat_spatial_dim())  # contract over the flattened spatial axis
# Activation layers
self.activation_relu = Activation('relu')
self.activation_transform = Activation(self.transform_activation)
# Add layer for final output
self.add = Add()
super(GlobalContextAttention, self).build(input_shape)
def call(self, inputs):
"""
Performs the forward pass of the layer.
Args:
inputs: Input tensor.
Returns:
Output tensor with global context attention applied.
"""
# Context Modelling Block
input_flat = self._spatial_flattenND(inputs) # [B, spatial_dims, C]
context = self.conv_context(inputs) # [B, spatial_dims, 1]
context = self._spatial_flattenND(context) # [B, spatial_dims, 1]
context = self.softmax(context) # Apply softmax over spatial_dims
context = self.dot([input_flat, context]) # [B, C, 1]
        context = self._spatial_expandND(context)  # [B, 1, ..., 1, C] (channels_last) or [B, C, 1, ..., 1]
# Transform Block
transform = self.conv_transform_bottleneck(context) # [B, C // R, 1, 1, ...]
transform = self.activation_relu(transform)
transform = self.conv_transform_output(transform) # [B, C, 1, 1, ...]
transform = self.activation_transform(transform)
# Apply context transform
out = self.add([inputs, transform]) # [B, spatial_dims, C]
return out
def _convND_layer(self, filters):
"""
Creates a Conv1D, Conv2D, or Conv3D layer based on the input rank.
Args:
filters (int): Number of filters for the convolutional layer.
Returns:
A Conv1D, Conv2D, or Conv3D layer.
"""
if self.rank == 3:
return Conv1D(filters, kernel_size=1, padding='same', use_bias=False, kernel_initializer='he_normal')
elif self.rank == 4:
return Conv2D(filters, kernel_size=1, padding='same', use_bias=False, kernel_initializer='he_normal')
elif self.rank == 5:
return Conv3D(filters, kernel_size=1, padding='same', use_bias=False, kernel_initializer='he_normal')
def _spatial_flattenND(self, ip):
"""
Flattens the spatial dimensions of the input tensor.
Args:
ip: Input tensor.
Returns:
Flattened tensor.
"""
if self.rank == 3:
return ip # Identity op for rank 3
else:
shape = (ip.shape[self.channel_dim], -1) if self.channel_dim == 1 else (-1, ip.shape[-1])
return Reshape(shape)(ip)
def _spatial_expandND(self, ip):
"""
Expands the spatial dimensions of the input tensor.
Args:
ip: Input tensor.
Returns:
Expanded tensor.
"""
if self.rank == 3:
            return Permute((2, 1))(ip)  # [B, C, 1] -> [B, 1, C] so it broadcasts over the temporal axis
else:
shape = (-1, *(1 for _ in range(self.rank - 2))) if self.channel_dim == 1 else (*(1 for _ in range(self.rank - 2)), -1)
return Reshape(shape)(ip)
    def _get_flat_spatial_dim(self):
        """
        Returns the axis of the flattened spatial dimension, i.e. the axis over which
        the softmax (and the dot product) operate.
        Returns:
            Axis index of the flattened spatial dimension.
        """
        return -1 if self.channel_dim == 1 else 1
def get_config(self):
"""
Returns the configuration of the layer for serialization.
Returns:
A dictionary containing the layer configuration.
"""
config = super(GlobalContextAttention, self).get_config()
config.update({
'reduction_ratio': self.reduction_ratio,
'transform_activation': self.transform_activation,
})
return config
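
# Illustrative usage sketch (not called anywhere in this module). The 7x7x1280 feature-map
# shape is only an assumption, roughly matching EfficientNetV2B0's output for 224x224 inputs.
def _demo_global_context_attention():
    feats = tf.keras.Input(shape=(7, 7, 1280))
    # The block is residual (inputs + transform), so the output shape equals the input shape.
    refined = GlobalContextAttention(reduction_ratio=8)(feats)
    return models.Model(feats, refined)
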
class AttentionGate(Layer):
def __init__(self, filters, **kwargs):
self.filters = filters
super(AttentionGate, self).__init__(**kwargs)
def build(self, input_shape):
# Create trainable parameters for attention gate
self.conv_xl = Conv2D(self.filters, kernel_size=(1, 1), strides=(1, 1), padding='same', activation='relu')
self.conv_g = Conv2D(self.filters, kernel_size=(1, 1), strides=(1, 1), padding='same', activation='relu')
self.psi = Conv2D(1, kernel_size=(1, 1), strides=(1, 1), padding='same', activation='linear')
self.layer_norm = LayerNormalization(axis=-1)
# Build the child layers
self.conv_xl.build(input_shape[0]) # Build conv_xl with the shape of xl
self.conv_g.build(input_shape[1]) # Build conv_g with the shape of g
self.psi.build(input_shape[0]) # Build psi with the shape of xl
self.layer_norm.build(input_shape[0]) # Build layer_norm with the shape of xl
# Add trainable weights
self.bxg = self.add_weight(name='bxg',
shape=(self.filters,),
initializer='zeros',
trainable=True)
self.bpsi = self.add_weight(name='bpsi',
shape=(1,),
initializer='zeros',
trainable=True)
super(AttentionGate, self).build(input_shape)
def call(self, inputs):
xl, g = inputs
# Apply convolutional operations
xl_conv = self.conv_xl(xl)
g_conv = self.conv_g(g)
# Compute additive attention
        att = ops.relu(xl_conv + g_conv + self.bxg)
        att = self.layer_norm(att)  # normalize the additive attention features
        att = self.psi(att) + self.bpsi
        att = ops.sigmoid(att)
# Apply attention gate
x_hat = att * xl
return x_hat
def compute_output_shape(self, input_shape):
return input_shape[0]
def get_config(self):
config = super(AttentionGate, self).get_config()
config.update({'filters': self.filters})
return config
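
# Illustrative usage sketch (not called anywhere in this module). AttentionGate expects two
# feature maps [xl, g] of the same spatial size; the shapes below are assumptions, and
# `filters` is set to xl's channel count so the gate output matches xl's shape.
def _demo_attention_gate():
    xl = tf.keras.Input(shape=(7, 7, 1280))
    g = tf.keras.Input(shape=(7, 7, 1280))
    gated = AttentionGate(filters=1280)([xl, g])
    return models.Model([xl, g], gated)
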
class GCRMSprop(RMSprop):
def get_gradients(self, loss, params):
        # Override the Keras get_gradients() hook so that every gradient with rank > 1 is
        # centralized: its mean over all axes except the last is subtracted before the
        # update is applied (Gradient Centralization).
        grads = []
        gradients = super().get_gradients(loss, params)
        for grad in gradients:
            grad_len = len(grad.shape)
            if grad_len > 1:
                axis = list(range(grad_len - 1))
                grad -= ops.mean(grad, axis=axis, keepdims=True)
grads.append(grad)
return grads
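
# Minimal sketch (illustrative only) of the centralization step above: for a toy conv-kernel
# gradient, the mean over every axis except the last (output channels) is subtracted, so each
# output-channel slice of the result has zero mean.
def _demo_gradient_centralization():
    grad = ops.convert_to_tensor(np.random.randn(3, 3, 16, 32).astype("float32"))
    centered = grad - ops.mean(grad, axis=list(range(len(grad.shape) - 1)), keepdims=True)
    return centered
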
# Build the model
def build_model(input_shape, num_classes):
try:
logging.info("Loading weights of EfficientNetV2B0...")
base_model = EfficientNetV2B0(weights='imagenet', include_top=False, input_shape=input_shape)
fmaps = base_model.output
logging.info("Initializing Global Context Attention...")
context_fmaps = GlobalContextAttention()(fmaps)
logging.info("Initializing AttentionGate...")
att_fmaps = AttentionGate(fmaps.shape[-1])([fmaps, context_fmaps])
x = layers.GlobalAveragePooling2D()(att_fmaps)
# First Dense Layer
x = layers.Dense(512, activation='relu', kernel_regularizer=regularizers.l2(0.005))(x)
x = layers.BatchNormalization()(x)
x = layers.Dropout(0.3)(x)
# Second Dense Layer
x = layers.Dense(256, activation='relu', kernel_regularizer=regularizers.l2(0.005))(x)
x = layers.BatchNormalization()(x)
x = layers.Dropout(0.2)(x)
output = layers.Dense(num_classes, activation='softmax')(x)
model = models.Model(inputs=base_model.input, outputs=output)
model.compile(optimizer=GCRMSprop(learning_rate=1e-4), loss='categorical_crossentropy', metrics=['accuracy'])
logging.info("Model Built Successfully!")
return model
except Exception as e:
raise CustomException(e, sys)
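
# Illustrative call (not executed on import). The 224x224x3 input shape and 5-class output are
# assumptions, not values defined by this module; the first call downloads ImageNet weights.
def _demo_build_model():
    model = build_model(input_shape=(224, 224, 3), num_classes=5)
    model.summary()
    return model
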
# Train the model
def train_model(model, X_train, X_test, y_train, y_test):
try:
# Define the necessary callbacks
checkpoint = ModelCheckpoint(config.model_path, monitor='val_accuracy', verbose=1, save_best_only=True, mode='max')
callbacks = [checkpoint]
logging.info(f"Training network for {config.EPOCHS} epochs...")
hist = model.fit(X_train, y_train, batch_size=config.batch_size,
validation_data=(X_test, y_test),
epochs=config.EPOCHS, callbacks=callbacks)
return hist
except Exception as e:
raise CustomException(e, sys)
# Evaluate the model
def evaluate_model(model, X_test, y_test):
try:
y_score = model.predict(X_test)
y_pred = np.argmax(y_score, axis=-1)
Y_test = np.argmax(y_test, axis=-1)
        acc = accuracy_score(Y_test, y_pred)
        mpre = precision_score(Y_test, y_pred, average='macro')
        mrecall = recall_score(Y_test, y_pred, average='macro')
        mf1 = f1_score(Y_test, y_pred, average='macro')
        kappa = cohen_kappa_score(Y_test, y_pred, weights='quadratic')
        auc = roc_auc_score(Y_test, y_score, average='macro', multi_class='ovr')
logging.info(f"Accuracy: {round(acc*100,2)}")
logging.info(f"Macro Precision: {round(mpre*100,2)}")
logging.info(f"Macro Recall: {round(mrecall*100,2)}")
logging.info(f"Macro F1-Score: {round(mf1*100,2)}")
logging.info(f"Quadratic Kappa Score: {round(kappa*100,2)}")
logging.info(f"ROC AUC Score: {round(auc*100,2)}")
logging.info(classification_report(Y_test, y_pred, digits=4))
except Exception as e:
raise CustomException(e, sys)
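
# End-to-end sketch with synthetic data (illustrative only): the shapes, sample counts, and
# class count are assumptions, and the real batch size, epoch count, and checkpoint path come
# from gcg.config via train_model().
def _demo_pipeline(num_classes=5):
    X_train = np.random.rand(20, 224, 224, 3).astype("float32")
    X_test = np.random.rand(10, 224, 224, 3).astype("float32")
    y_train = tf.keras.utils.to_categorical(np.arange(20) % num_classes, num_classes)
    y_test = tf.keras.utils.to_categorical(np.arange(10) % num_classes, num_classes)
    model = build_model(input_shape=(224, 224, 3), num_classes=num_classes)
    hist = train_model(model, X_train, X_test, y_train, y_test)
    evaluate_model(model, X_test, y_test)
    return hist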