awsaf49 commited on
Commit
21fa3d7
·
1 Parent(s): 37252de

fix conflict

Browse files
Files changed (2) hide show
  1. gcvit/utils/gradcam.py +0 -71
  2. setup.py +0 -50
gcvit/utils/gradcam.py CHANGED
@@ -1,73 +1,3 @@
1
- <<<<<<< HEAD
2
- import tensorflow as tf
3
- import matplotlib.cm as cm
4
- import numpy as np
5
- try:
6
- from tensorflow.keras.utils import array_to_img, img_to_array
7
- except:
8
- from tensorflow.keras.preprocessing.image import array_to_img, img_to_array
9
-
10
- def process_image(img, size=(224, 224)):
11
- img_array = tf.keras.applications.imagenet_utils.preprocess_input(img, mode='torch')
12
- img_array = tf.image.resize(img_array, size,)[None,]
13
- return img_array
14
-
15
- def get_gradcam_model(model):
16
- inp = tf.keras.Input(shape=(224, 224, 3))
17
- feats = model.forward_features(inp)
18
- preds = model.forward_head(feats)
19
- return tf.keras.models.Model(inp, [preds, feats])
20
-
21
- def get_gradcam_prediction(img, grad_model, process=True, decode=True, pred_index=None, cmap='jet', alpha=0.4):
22
- """Grad-CAM for a single image
23
-
24
- Args:
25
- img (np.ndarray): process or raw image without batch_shape e.g. (224, 224, 3)
26
- grad_model (tf.keras.Model): model with feature map and prediction
27
- process (bool, optional): imagenet pre-processing. Defaults to True.
28
- pred_index (int, optional): for particular calss. Defaults to None.
29
- cmap (str, optional): colormap. Defaults to 'jet'.
30
- alpha (float, optional): opacity. Defaults to 0.4.
31
-
32
- Returns:
33
- preds_decode: top5 predictions
34
- heatmap: gradcam heatmap
35
- """
36
- # process image for inference
37
- if process:
38
- img_array = process_image(img)
39
- else:
40
- img_array = tf.convert_to_tensor(img)[None,]
41
- if img.min()!=img.max():
42
- img = (img - img.min())/(img.max() - img.min())
43
- img = np.uint8(img*255.0)
44
- # get prediction
45
- with tf.GradientTape(persistent=True) as tape:
46
- preds, feats = grad_model(img_array)
47
- if pred_index is None:
48
- pred_index = tf.argmax(preds[0])
49
- class_channel = preds[:, pred_index]
50
- # compute heatmap
51
- grads = tape.gradient(class_channel, feats)
52
- pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
53
- feats = feats[0]
54
- heatmap = feats @ pooled_grads[..., tf.newaxis]
55
- heatmap = tf.squeeze(heatmap)
56
- heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
57
- heatmap = heatmap.numpy()
58
- heatmap = np.uint8(255 * heatmap)
59
- # colorize heatmap
60
- cmap = cm.get_cmap(cmap)
61
- colors = cmap(np.arange(256))[:, :3]
62
- heatmap = colors[heatmap]
63
- heatmap = array_to_img(heatmap)
64
- heatmap = heatmap.resize((img.shape[1], img.shape[0]))
65
- heatmap = img_to_array(heatmap)
66
- overlay = img + heatmap * alpha
67
- overlay = array_to_img(overlay)
68
- # decode prediction
69
- preds_decode = tf.keras.applications.imagenet_utils.decode_predictions(preds.numpy())[0] if decode else preds
70
- =======
71
  import tensorflow as tf
72
  import matplotlib.cm as cm
73
  import numpy as np
@@ -136,5 +66,4 @@ def get_gradcam_prediction(img, grad_model, process=True, decode=True, pred_inde
136
  overlay = array_to_img(overlay)
137
  # decode prediction
138
  preds_decode = tf.keras.applications.imagenet_utils.decode_predictions(preds.numpy())[0] if decode else preds
139
- >>>>>>> 094461a8d383ad2565311ea9a0094b5856887867
140
  return preds_decode, overlay
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import tensorflow as tf
2
  import matplotlib.cm as cm
3
  import numpy as np
 
66
  overlay = array_to_img(overlay)
67
  # decode prediction
68
  preds_decode = tf.keras.applications.imagenet_utils.decode_predictions(preds.numpy())[0] if decode else preds
 
69
  return preds_decode, overlay
setup.py DELETED
@@ -1,50 +0,0 @@
1
- from setuptools import setup, find_packages
2
- from codecs import open
3
- from os import path
4
-
5
- here = path.abspath(path.dirname(__file__))
6
-
7
- # Get the long description from the README file
8
- with open(path.join(here, "README.md"), encoding="utf-8") as f:
9
- long_description = f.read()
10
-
11
- with open(path.join(here, 'requirements.txt')) as f:
12
- install_requires = [x for x in f.read().splitlines() if len(x)]
13
-
14
- exec(open("gcvit/version.py").read())
15
-
16
- setup(
17
- name="gcvit",
18
- version=__version__,
19
- description="Tensorflow 2.0 Implementation of GCViT: Global Context Vision Transformer. https://github.com/awsaf49/gcvit-tf",
20
- long_description=long_description,
21
- long_description_content_type="text/markdown",
22
- url="https://github.com/awsaf49/gcvit-tf",
23
- author="Awsaf",
24
- author_email="[email protected]",
25
- classifiers=[
26
- # How mature is this project? Common values are
27
- # 3 - Alpha
28
- # 4 - Beta
29
- # 5 - Production/Stable
30
- "Development Status :: 3 - Alpha",
31
- "Intended Audience :: Developers",
32
- "Intended Audience :: Science/Research",
33
- "License :: OSI Approved :: Apache Software License",
34
- "Programming Language :: Python :: 3.6",
35
- "Programming Language :: Python :: 3.7",
36
- "Programming Language :: Python :: 3.8",
37
- "Topic :: Scientific/Engineering",
38
- "Topic :: Scientific/Engineering :: Artificial Intelligence",
39
- "Topic :: Software Development",
40
- "Topic :: Software Development :: Libraries",
41
- "Topic :: Software Development :: Libraries :: Python Modules",
42
- ],
43
- # Note that this is a string of words separated by whitespace, not a list.
44
- keywords="tensorflow computer_vision image classification transformer",
45
- packages=find_packages(exclude=["tests"]),
46
- include_package_data=True,
47
- install_requires=install_requires,
48
- python_requires=">=3.6",
49
- license="MIT",
50
- )