import numpy as np |
import cv2 |
import matplotlib.pyplot as plt |
import tensorflow as tf |
from tensorflow.keras.preprocessing import image |
from tensorflow.keras import Model |
from gcg.utils import logging |
def grad_cam(model, img,
             layer_name="block5_conv3", label_name=None,
             category_id=None):
    """Get a heatmap by Grad-CAM.

    Args:
        model: A model object, built from tf.keras 2.X.
        img: An image ndarray. A batch axis is prepended before it is
            passed to the model.
        layer_name: A string, name of the layer in `model` whose output
            is used to build the heatmap.
        label_name: A list or None.
            When given, the name of the explained class is printed;
            it should be a list of all label names.
        category_id: An integer, index of the class.
            Default is the category with the highest score in the prediction.

    Return:
        A heatmap ndarray (without color), divided by its maximum value
        (or by 1e-10 when the maximum is 0) and squeezed.

    Raises:
        ValueError: If no gradient of the class score with respect to the
            output of `layer_name` can be computed.
    """
    img_tensor = np.expand_dims(img, axis=0)

    conv_layer = model.get_layer(layer_name)
    # `model.inputs` is already a list; wrapping it in another list declares
    # a nested input structure that does not match the single tensor passed
    # to the model below.
    heatmap_model = Model(model.inputs, [conv_layer.output, model.output])

    with tf.GradientTape() as gtape:
        conv_output, predictions = heatmap_model(img_tensor)
        if category_id is None:
            category_id = np.argmax(predictions[0])
        if label_name is not None:
            print(label_name[category_id])
        output = predictions[:, category_id]

    grads = gtape.gradient(output, conv_output)
    if grads is None:
        raise ValueError(
            f"No gradient between the class score and layer '{layer_name}'.")

    # One weight per channel: gradient averaged over batch and spatial axes.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    heatmap = tf.reduce_mean(tf.multiply(pooled_grads, conv_output), axis=-1)

    # Keep only positive evidence, then scale by the maximum.
    heatmap = np.maximum(heatmap, 0)
    max_heat = np.max(heatmap)
    if max_heat == 0:
        # Avoid dividing by zero for an all-zero map.
        max_heat = 1e-10
    heatmap /= max_heat

    return np.squeeze(heatmap)
def grad_cam_plus(model, img,
                  layer_name="block5_conv3", label_name=None,
                  category_id=None):
    """Get a heatmap by Grad-CAM++.

    Args:
        model: A model object, built from tf.keras 2.X.
        img: An image ndarray. A batch axis is prepended before it is
            passed to the model.
        layer_name: A string, name of the layer in `model` whose output
            is used to build the heatmap.
        label_name: A list or None.
            When given, the name of the explained class is printed;
            it should be a list of all label names.
        category_id: An integer, index of the class.
            Default is the category with the highest score in the prediction.

    Return:
        A heatmap ndarray (without color), divided by its maximum value
        (or by 1e-10 when the maximum is 0).
    """
    img_tensor = np.expand_dims(img, axis=0)

    conv_layer = model.get_layer(layer_name)
    # `model.inputs` is already a list; wrapping it in another list declares
    # a nested input structure that does not match the single tensor passed
    # to the model below.
    heatmap_model = Model(model.inputs, [conv_layer.output, model.output])

    # Each gradient is taken while the enclosing tape is still recording so
    # that the next-order derivative can be computed from it.
    with tf.GradientTape() as gtape1:
        with tf.GradientTape() as gtape2:
            with tf.GradientTape() as gtape3:
                conv_output, predictions = heatmap_model(img_tensor)
                if category_id is None:
                    category_id = np.argmax(predictions[0])
                if label_name is not None:
                    print(label_name[category_id])
                output = predictions[:, category_id]
                conv_first_grad = gtape3.gradient(output, conv_output)
            conv_second_grad = gtape2.gradient(conv_first_grad, conv_output)
        conv_third_grad = gtape1.gradient(conv_second_grad, conv_output)

    global_sum = np.sum(conv_output, axis=(0, 1, 2))

    alpha_num = conv_second_grad[0]
    alpha_denom = conv_second_grad[0] * 2.0 + conv_third_grad[0] * global_sum
    # Replace exact zeros so the division below is defined.
    alpha_denom = np.where(alpha_denom != 0.0, alpha_denom, 1e-10)

    alphas = alpha_num / alpha_denom
    alpha_normalization_constant = np.sum(alphas, axis=(0, 1))
    # A channel whose alphas sum to zero would turn into NaN when divided by
    # that sum, and the NaN would then spread to the whole map through the
    # channel sum below; dividing such channels by 1 leaves them unchanged.
    alpha_normalization_constant = np.where(
        alpha_normalization_constant != 0.0,
        alpha_normalization_constant,
        np.ones_like(alpha_normalization_constant))
    alphas /= alpha_normalization_constant

    weights = np.maximum(conv_first_grad[0], 0.0)

    deep_linearization_weights = np.sum(weights * alphas, axis=(0, 1))
    grad_cam_map = np.sum(deep_linearization_weights * conv_output[0], axis=2)

    # Keep only positive evidence, then scale by the maximum.
    heatmap = np.maximum(grad_cam_map, 0)
    max_heat = np.max(heatmap)
    if max_heat == 0:
        # Avoid dividing by zero for an all-zero map.
        max_heat = 1e-10
    heatmap /= max_heat

    return heatmap
def preprocess_image(img_path, image_size=(512, 512, 3)):
    """Load an image file, resize it and convert it to an array.

    No pixel normalization is applied here; the array is returned exactly
    as produced by `image.img_to_array`.

    Args:
        img_path: A string, path of the image file to load.
        image_size: A tuple or None. Only the first two entries are used,
            as the resize target handed to `image.load_img`; further
            entries (such as a channel count) are ignored.
            None is passed through unchanged.

    Return:
        An image array.
    """
    # `load_img` documents `target_size` as a 2-tuple, so drop extra entries.
    target_size = None if image_size is None else tuple(image_size[:2])
    img = image.load_img(img_path, target_size=target_size)
    return image.img_to_array(img)
def show_GradCAM(img, heatmap, alpha=0.4, save_path=None, return_array=False):
    """Show the image, the colored heatmap and their superposition.

    Args:
        img: image array with values on a 0-255 scale; it is displayed with
            matplotlib, so its channels are treated as RGB.
        heatmap: image array, get it by calling grad_cam().
        alpha: float, weight of the heatmap in the superimposed image.
        save_path: string or None, where to save the figure when given.
        return_array: bool, return a superimposed image array or not.

    Return:
        None or image array (uint8, same channel order as `img`).
    """
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    heatmap = (heatmap * 255).astype("uint8")

    # OpenCV color maps come back in BGR order, while matplotlib (and the
    # image blended below) use RGB; convert so red and blue are not swapped.
    heatmap_colored = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)

    superimposed_img = heatmap_colored * alpha + img
    superimposed_img = np.clip(superimposed_img, 0, 255).astype("uint8")

    # matplotlib clips floating-point RGB data to [0, 1]; show the 0-255
    # image as uint8 so it is not washed out.
    display_img = np.clip(img, 0, 255).astype("uint8")

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    fig.subplots_adjust(wspace=0, hspace=0)

    axes[0].imshow(display_img)
    axes[0].set_title('Original Image')
    axes[0].axis('off')

    axes[1].imshow(heatmap_colored)
    axes[1].set_title('Heatmap')
    axes[1].axis('off')

    axes[2].imshow(superimposed_img)
    axes[2].set_title('Superimposed Image')
    axes[2].axis('off')

    plt.tight_layout()

    if save_path:
        fig.savefig(save_path)
        logging.info(f"Saved combined visualization to {save_path}")

    if return_array:
        return superimposed_img
    return None