awsaf49 commited on
Commit
37252de
·
2 Parent(s): 4a0cabe 094461a

pull from remote

Browse files
Files changed (1) hide show
  1. gcvit/utils/gradcam.py +71 -0
gcvit/utils/gradcam.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import tensorflow as tf
2
  import matplotlib.cm as cm
3
  import numpy as np
@@ -66,4 +67,74 @@ def get_gradcam_prediction(img, grad_model, process=True, decode=True, pred_inde
66
  overlay = array_to_img(overlay)
67
  # decode prediction
68
  preds_decode = tf.keras.applications.imagenet_utils.decode_predictions(preds.numpy())[0] if decode else preds
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  return preds_decode, overlay
 
1
+ <<<<<<< HEAD
2
  import tensorflow as tf
3
  import matplotlib.cm as cm
4
  import numpy as np
 
67
  overlay = array_to_img(overlay)
68
  # decode prediction
69
  preds_decode = tf.keras.applications.imagenet_utils.decode_predictions(preds.numpy())[0] if decode else preds
70
+ =======
71
+ import tensorflow as tf
72
+ import matplotlib.cm as cm
73
+ import numpy as np
74
+ try:
75
+ from tensorflow.keras.utils import array_to_img, img_to_array
76
+ except:
77
+ from tensorflow.keras.preprocessing.image import array_to_img, img_to_array
78
+
79
+ def process_image(img, size=(224, 224)):
80
+ img_array = tf.keras.applications.imagenet_utils.preprocess_input(img, mode='torch')
81
+ img_array = tf.image.resize(img_array, size,)[None,]
82
+ return img_array
83
+
84
+ def get_gradcam_model(model):
85
+ inp = tf.keras.Input(shape=(224, 224, 3))
86
+ feats = model.forward_features(inp)
87
+ preds = model.forward_head(feats)
88
+ return tf.keras.models.Model(inp, [preds, feats])
89
+
90
+ def get_gradcam_prediction(img, grad_model, process=True, decode=True, pred_index=None, cmap='jet', alpha=0.6):
91
+ """Grad-CAM for a single image
92
+
93
+ Args:
94
+ img (np.ndarray): process or raw image without batch_shape e.g. (224, 224, 3)
95
+ grad_model (tf.keras.Model): model with feature map and prediction
96
+ process (bool, optional): imagenet pre-processing. Defaults to True.
97
+ pred_index (int, optional): for particular calss. Defaults to None.
98
+ cmap (str, optional): colormap. Defaults to 'jet'.
99
+ alpha (float, optional): opacity. Defaults to 0.4.
100
+
101
+ Returns:
102
+ preds_decode: top5 predictions
103
+ heatmap: gradcam heatmap
104
+ """
105
+ # process image for inference
106
+ if process:
107
+ img_array = process_image(img)
108
+ else:
109
+ img_array = tf.convert_to_tensor(img)[None,]
110
+ if img.min()!=img.max():
111
+ img = (img - img.min())/(img.max() - img.min())
112
+ img = np.uint8(img*255.0)
113
+ # get prediction
114
+ with tf.GradientTape(persistent=True) as tape:
115
+ preds, feats = grad_model(img_array)
116
+ if pred_index is None:
117
+ pred_index = tf.argmax(preds[0])
118
+ class_channel = preds[:, pred_index]
119
+ # compute heatmap
120
+ grads = tape.gradient(class_channel, feats)
121
+ pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
122
+ feats = feats[0]
123
+ heatmap = feats @ pooled_grads[..., tf.newaxis]
124
+ heatmap = tf.squeeze(heatmap)
125
+ heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
126
+ heatmap = heatmap.numpy()
127
+ heatmap = np.uint8(255 * heatmap)
128
+ # colorize heatmap
129
+ cmap = cm.get_cmap(cmap)
130
+ colors = cmap(np.arange(256))[:, :3]
131
+ heatmap = colors[heatmap]
132
+ heatmap = array_to_img(heatmap)
133
+ heatmap = heatmap.resize((img.shape[1], img.shape[0]))
134
+ heatmap = img_to_array(heatmap)
135
+ overlay = img + heatmap * alpha
136
+ overlay = array_to_img(overlay)
137
+ # decode prediction
138
+ preds_decode = tf.keras.applications.imagenet_utils.decode_predictions(preds.numpy())[0] if decode else preds
139
+ >>>>>>> 094461a8d383ad2565311ea9a0094b5856887867
140
  return preds_decode, overlay