fix conflict
Browse files- gcvit/utils/gradcam.py +0 -71
- setup.py +0 -50
gcvit/utils/gradcam.py
CHANGED
@@ -1,73 +1,3 @@
|
|
1 |
-
<<<<<<< HEAD
|
2 |
-
import tensorflow as tf
|
3 |
-
import matplotlib.cm as cm
|
4 |
-
import numpy as np
|
5 |
-
try:
|
6 |
-
from tensorflow.keras.utils import array_to_img, img_to_array
|
7 |
-
except:
|
8 |
-
from tensorflow.keras.preprocessing.image import array_to_img, img_to_array
|
9 |
-
|
10 |
-
def process_image(img, size=(224, 224)):
|
11 |
-
img_array = tf.keras.applications.imagenet_utils.preprocess_input(img, mode='torch')
|
12 |
-
img_array = tf.image.resize(img_array, size,)[None,]
|
13 |
-
return img_array
|
14 |
-
|
15 |
-
def get_gradcam_model(model):
|
16 |
-
inp = tf.keras.Input(shape=(224, 224, 3))
|
17 |
-
feats = model.forward_features(inp)
|
18 |
-
preds = model.forward_head(feats)
|
19 |
-
return tf.keras.models.Model(inp, [preds, feats])
|
20 |
-
|
21 |
-
def get_gradcam_prediction(img, grad_model, process=True, decode=True, pred_index=None, cmap='jet', alpha=0.4):
|
22 |
-
"""Grad-CAM for a single image
|
23 |
-
|
24 |
-
Args:
|
25 |
-
img (np.ndarray): process or raw image without batch_shape e.g. (224, 224, 3)
|
26 |
-
grad_model (tf.keras.Model): model with feature map and prediction
|
27 |
-
process (bool, optional): imagenet pre-processing. Defaults to True.
|
28 |
-
pred_index (int, optional): for particular calss. Defaults to None.
|
29 |
-
cmap (str, optional): colormap. Defaults to 'jet'.
|
30 |
-
alpha (float, optional): opacity. Defaults to 0.4.
|
31 |
-
|
32 |
-
Returns:
|
33 |
-
preds_decode: top5 predictions
|
34 |
-
heatmap: gradcam heatmap
|
35 |
-
"""
|
36 |
-
# process image for inference
|
37 |
-
if process:
|
38 |
-
img_array = process_image(img)
|
39 |
-
else:
|
40 |
-
img_array = tf.convert_to_tensor(img)[None,]
|
41 |
-
if img.min()!=img.max():
|
42 |
-
img = (img - img.min())/(img.max() - img.min())
|
43 |
-
img = np.uint8(img*255.0)
|
44 |
-
# get prediction
|
45 |
-
with tf.GradientTape(persistent=True) as tape:
|
46 |
-
preds, feats = grad_model(img_array)
|
47 |
-
if pred_index is None:
|
48 |
-
pred_index = tf.argmax(preds[0])
|
49 |
-
class_channel = preds[:, pred_index]
|
50 |
-
# compute heatmap
|
51 |
-
grads = tape.gradient(class_channel, feats)
|
52 |
-
pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
|
53 |
-
feats = feats[0]
|
54 |
-
heatmap = feats @ pooled_grads[..., tf.newaxis]
|
55 |
-
heatmap = tf.squeeze(heatmap)
|
56 |
-
heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
|
57 |
-
heatmap = heatmap.numpy()
|
58 |
-
heatmap = np.uint8(255 * heatmap)
|
59 |
-
# colorize heatmap
|
60 |
-
cmap = cm.get_cmap(cmap)
|
61 |
-
colors = cmap(np.arange(256))[:, :3]
|
62 |
-
heatmap = colors[heatmap]
|
63 |
-
heatmap = array_to_img(heatmap)
|
64 |
-
heatmap = heatmap.resize((img.shape[1], img.shape[0]))
|
65 |
-
heatmap = img_to_array(heatmap)
|
66 |
-
overlay = img + heatmap * alpha
|
67 |
-
overlay = array_to_img(overlay)
|
68 |
-
# decode prediction
|
69 |
-
preds_decode = tf.keras.applications.imagenet_utils.decode_predictions(preds.numpy())[0] if decode else preds
|
70 |
-
=======
|
71 |
import tensorflow as tf
|
72 |
import matplotlib.cm as cm
|
73 |
import numpy as np
|
@@ -136,5 +66,4 @@ def get_gradcam_prediction(img, grad_model, process=True, decode=True, pred_inde
|
|
136 |
overlay = array_to_img(overlay)
|
137 |
# decode prediction
|
138 |
preds_decode = tf.keras.applications.imagenet_utils.decode_predictions(preds.numpy())[0] if decode else preds
|
139 |
-
>>>>>>> 094461a8d383ad2565311ea9a0094b5856887867
|
140 |
return preds_decode, overlay
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import tensorflow as tf
|
2 |
import matplotlib.cm as cm
|
3 |
import numpy as np
|
|
|
66 |
overlay = array_to_img(overlay)
|
67 |
# decode prediction
|
68 |
preds_decode = tf.keras.applications.imagenet_utils.decode_predictions(preds.numpy())[0] if decode else preds
|
|
|
69 |
return preds_decode, overlay
|
setup.py
DELETED
@@ -1,50 +0,0 @@
|
|
1 |
-
from setuptools import setup, find_packages
|
2 |
-
from codecs import open
|
3 |
-
from os import path
|
4 |
-
|
5 |
-
here = path.abspath(path.dirname(__file__))
|
6 |
-
|
7 |
-
# Get the long description from the README file
|
8 |
-
with open(path.join(here, "README.md"), encoding="utf-8") as f:
|
9 |
-
long_description = f.read()
|
10 |
-
|
11 |
-
with open(path.join(here, 'requirements.txt')) as f:
|
12 |
-
install_requires = [x for x in f.read().splitlines() if len(x)]
|
13 |
-
|
14 |
-
exec(open("gcvit/version.py").read())
|
15 |
-
|
16 |
-
setup(
|
17 |
-
name="gcvit",
|
18 |
-
version=__version__,
|
19 |
-
description="Tensorflow 2.0 Implementation of GCViT: Global Context Vision Transformer. https://github.com/awsaf49/gcvit-tf",
|
20 |
-
long_description=long_description,
|
21 |
-
long_description_content_type="text/markdown",
|
22 |
-
url="https://github.com/awsaf49/gcvit-tf",
|
23 |
-
author="Awsaf",
|
24 |
-
author_email="[email protected]",
|
25 |
-
classifiers=[
|
26 |
-
# How mature is this project? Common values are
|
27 |
-
# 3 - Alpha
|
28 |
-
# 4 - Beta
|
29 |
-
# 5 - Production/Stable
|
30 |
-
"Development Status :: 3 - Alpha",
|
31 |
-
"Intended Audience :: Developers",
|
32 |
-
"Intended Audience :: Science/Research",
|
33 |
-
"License :: OSI Approved :: Apache Software License",
|
34 |
-
"Programming Language :: Python :: 3.6",
|
35 |
-
"Programming Language :: Python :: 3.7",
|
36 |
-
"Programming Language :: Python :: 3.8",
|
37 |
-
"Topic :: Scientific/Engineering",
|
38 |
-
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
39 |
-
"Topic :: Software Development",
|
40 |
-
"Topic :: Software Development :: Libraries",
|
41 |
-
"Topic :: Software Development :: Libraries :: Python Modules",
|
42 |
-
],
|
43 |
-
# Note that this is a string of words separated by whitespace, not a list.
|
44 |
-
keywords="tensorflow computer_vision image classification transformer",
|
45 |
-
packages=find_packages(exclude=["tests"]),
|
46 |
-
include_package_data=True,
|
47 |
-
install_requires=install_requires,
|
48 |
-
python_requires=">=3.6",
|
49 |
-
license="MIT",
|
50 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|