Spaces:
Runtime error
Runtime error
| import tensorflow as tf | |
| from tensorflow.keras import layers | |
def spatial_attention_block(input_tensor, kernel_size=(1, 1)):
    """Reweight every spatial position of ``input_tensor`` by a learned gate.

    Channel-wise max and mean statistics are stacked into a two-channel
    descriptor. A single-filter convolution projects it, and a sigmoid turns
    it into a per-position gate in (0, 1) that multiplies the input.

    Args:
        input_tensor: Feature map with channels on the last axis (presumably
            NHWC, batch first -- TODO confirm against callers).
        kernel_size: Kernel of the projecting convolution. The default 1x1
            matches the original behaviour. Larger kernels are zero-padded
            ("same") so the gate keeps the input's spatial size.

    Returns:
        ``input_tensor`` multiplied by the spatial gate. The gate has one
        channel and broadcasts over the input's channels.
    """
    # Both statistics reduce over the channel axis and keep it as size 1,
    # so they can be concatenated back along that axis.
    max_pooling = tf.reduce_max(input_tensor, axis=-1, keepdims=True)
    average_pooling = tf.reduce_mean(input_tensor, axis=-1, keepdims=True)
    # The previous version stored these under swapped names. The [max, mean]
    # order is kept unchanged so that existing Conv2D weights still apply.
    concatenated = layers.Concatenate(axis=-1)([max_pooling, average_pooling])
    feature_map = layers.Conv2D(1, kernel_size=kernel_size, padding="same")(
        concatenated
    )
    feature_map = tf.nn.sigmoid(feature_map)
    return input_tensor * feature_map
def channel_attention_block(input_tensor, reduction_ratio=8):
    """Reweight each channel of ``input_tensor`` (squeeze-and-excitation style).

    The input is globally average-pooled into one value per channel. That
    descriptor passes through a ReLU bottleneck 1x1 convolution and a sigmoid
    1x1 convolution back to the full channel count. The resulting per-channel
    gate multiplies the input.

    Args:
        input_tensor: Feature map with channels on the last axis (presumably
            NHWC -- TODO confirm against callers). Its channel dimension must
            be statically known.
        reduction_ratio: Factor by which the bottleneck shrinks the channel
            count. The default of 8 matches the original behaviour.

    Returns:
        ``input_tensor`` multiplied by the per-channel gate.

    Raises:
        ValueError: If ``reduction_ratio`` is smaller than 1.
    """
    if reduction_ratio < 1:
        raise ValueError(f"reduction_ratio must be >= 1, got {reduction_ratio}")
    channels = list(input_tensor.shape)[-1]
    # Clamp to at least one filter. Otherwise inputs with fewer channels than
    # reduction_ratio would get a zero-filter bottleneck convolution.
    bottleneck_filters = max(1, channels // reduction_ratio)

    average_pooling = layers.GlobalAveragePooling2D()(input_tensor)
    # Restore a 1x1 spatial grid so the 1x1 convolutions can operate on it.
    # The output broadcasts against the input in the final multiply.
    feature_descriptor = layers.Reshape((1, 1, channels))(average_pooling)
    feature_activations = layers.Conv2D(
        filters=bottleneck_filters, kernel_size=(1, 1), activation="relu"
    )(feature_descriptor)
    feature_activations = layers.Conv2D(
        filters=channels, kernel_size=(1, 1), activation="sigmoid"
    )(feature_activations)
    return input_tensor * feature_activations
def dual_attention_unit_block(input_tensor):
    """Dual attention unit: conv body, parallel channel/spatial attention, residual.

    The input goes through two 3x3 same-padded convolutions (ReLU, then
    linear). The result is fed to the channel-attention and spatial-attention
    branches. Their outputs are concatenated, fused back to the input's
    channel count by a 1x1 convolution, and added to the input.

    Args:
        input_tensor: Feature map with channels on the last axis (presumably
            NHWC -- TODO confirm against callers).

    Returns:
        A tensor with the same shape as ``input_tensor``.
    """
    *_, num_channels = input_tensor.shape

    # Convolutional body: the first conv has a ReLU, the second is linear
    # (default activation).
    body = input_tensor
    for activation in ("relu", None):
        body = layers.Conv2D(
            num_channels, kernel_size=(3, 3), padding="same", activation=activation
        )(body)

    # The list literal evaluates left to right. The channel branch is built
    # first, keeping the layer-creation order (and auto-generated names).
    attended = [
        channel_attention_block(body),
        spatial_attention_block(body),
    ]
    fused = layers.Conv2D(num_channels, kernel_size=(1, 1))(
        layers.Concatenate(axis=-1)(attended)
    )
    # Residual connection. The 1x1 fuse restored the input channel count.
    return layers.Add()([input_tensor, fused])