Spaces:
Runtime error
Runtime error
fix: reformat code for imagenette
Browse files- .gitattributes +3 -0
- app.py +34 -0
- model/att_pool/keras_metadata.pb +3 -0
- model/att_pool/saved_model.pb +3 -0
- model/att_pool/variables/variables.data-00000-of-00001 +0 -0
- model/att_pool/variables/variables.index +0 -0
- model/stem/keras_metadata.pb +3 -0
- model/stem/saved_model.pb +3 -0
- model/stem/variables/variables.data-00000-of-00001 +0 -0
- model/stem/variables/variables.index +0 -0
- model/trunk/keras_metadata.pb +3 -0
- model/trunk/saved_model.pb +3 -0
- model/trunk/variables/variables.data-00000-of-00001 +0 -0
- model/trunk/variables/variables.index +0 -0
- utilities/config.py +8 -0
- utilities/model.py +30 -0
- utilities/visualization.py +45 -0
.gitattributes
CHANGED
|
@@ -25,3 +25,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 26 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
| 27 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 26 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
| 27 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
|
| 29 |
+
# for macOS
|
| 30 |
+
.DS_Store
|
app.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# import the necessary packages
|
| 2 |
+
from utilities import config
|
| 3 |
+
from utilities import model
|
| 4 |
+
from utilities import visualization
|
| 5 |
+
from tensorflow import keras
|
| 6 |
+
import gradio as gr
|
| 7 |
+
|
| 8 |
+
# Load the three saved sub-models from disk. compile=False is passed so that
# load_model does not compile the restored models.
conv_stem = keras.models.load_model(
    config.IMAGENETTE_STEM_PATH,
    compile=False,
)
conv_trunk = keras.models.load_model(
    config.IMAGENETTE_TRUNK_PATH,
    compile=False,
)
conv_attn = keras.models.load_model(
    config.IMAGENETTE_ATTN_PATH,
    compile=False,
)

# Wire the sub-models together into a single patch conv net.
patch_conv_net = model.PatchConvNet(
    stem=conv_stem,
    trunk=conv_trunk,
    attention_pooling=conv_attn,
)

# Callable handed to Gradio: takes the uploaded image and produces the
# attention plot.
plot_attention = visualization.PlotAttention(model=patch_conv_net)

# Build the interface first and launch it as a separate step, so that
# `iface` refers to the Interface itself rather than to whatever launch()
# returns.
# NOTE(review): gr.inputs.Image belongs to Gradio's legacy `gr.inputs`
# namespace; confirm it exists in the Gradio version pinned for this Space.
iface = gr.Interface(
    fn=plot_attention,
    inputs=[gr.inputs.Image(label="Input Image")],
    outputs="image",
)
iface.launch()
|
model/att_pool/keras_metadata.pb
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fefb606f0aeb214dcfa9cf9786955f6b7ecb7bdd116e007c44264af75936bfca
|
| 3 |
+
size 15848
|
model/att_pool/saved_model.pb
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0057c7816b0d297c8cdb02958c1a0044220eea870f923e38ba2c1a7fb01c9448
|
| 3 |
+
size 324550
|
model/att_pool/variables/variables.data-00000-of-00001
ADDED
|
Binary file (1.61 MB). View file
|
|
|
model/att_pool/variables/variables.index
ADDED
|
Binary file (1.38 kB). View file
|
|
|
model/stem/keras_metadata.pb
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e41ea234e1a8f95c77c1308242971d818b50ec3212d3b1f1f34d9042d77f1270
|
| 3 |
+
size 11998
|
model/stem/saved_model.pb
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ee90a781360388d5b7e2035be35b349f4209a6e58881bbc02e2c01cea8130877
|
| 3 |
+
size 96815
|
model/stem/variables/variables.data-00000-of-00001
ADDED
|
Binary file (1.56 MB). View file
|
|
|
model/stem/variables/variables.index
ADDED
|
Binary file (667 Bytes). View file
|
|
|
model/trunk/keras_metadata.pb
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c090252e37268009bf4e4fc8866b2984bfdfbe88341ddb54dea464c9eb365bc4
|
| 3 |
+
size 23883
|
model/trunk/saved_model.pb
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:255bf760cd4df017b1d4ff670f889ac9c5985d9b41d1d2d1d401c3030797164b
|
| 3 |
+
size 359160
|
model/trunk/variables/variables.data-00000-of-00001
ADDED
|
Binary file (2.96 MB). View file
|
|
|
model/trunk/variables/variables.index
ADDED
|
Binary file (733 Bytes). View file
|
|
|
utilities/config.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# import the necessary packages
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
# Root directory that holds the saved model components. The path is
# relative, so it is resolved against the process's working directory.
MODEL_PATH = "model"

# One sub-directory per model component: stem, trunk, attention pooling.
IMAGENETTE_STEM_PATH = os.path.join(MODEL_PATH, "stem")
IMAGENETTE_TRUNK_PATH = os.path.join(MODEL_PATH, "trunk")
IMAGENETTE_ATTN_PATH = os.path.join(MODEL_PATH, "att_pool")
|
utilities/model.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# import the necessary packages
|
| 2 |
+
from tensorflow import keras
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
|
| 5 |
+
# Patch conv
|
| 6 |
+
class PatchConvNet(keras.Model):
    """Chain a stem, a trunk and an attention-pooling block into one model.

    Args:
        stem: Callable applied first, directly to the input images.
        trunk: Callable applied to the output of ``stem``.
        attention_pooling: Callable applied to the output of ``trunk``. Its
            result is unpacked into exactly two values: the predictions and
            the attention weights used for visualization.
        **kwargs: Forwarded unchanged to ``keras.Model.__init__``.
    """

    def __init__(self, stem, trunk, attention_pooling, **kwargs):
        super().__init__(**kwargs)
        # Stored under these exact attribute names so that other code can
        # reach the individual stages on the model instance.
        self.stem = stem
        self.trunk = trunk
        self.attention_pooling = attention_pooling

    # The traced signature is pinned to a rank-4 uint8 tensor whose last
    # dimension is 3 (presumably batch x height x width x RGB; confirm
    # against how the sub-models were exported).
    @tf.function(
        input_signature=[
            tf.TensorSpec(shape=[None, None, None, 3], dtype=tf.uint8)
        ]
    )
    def call(self, images):
        """Run ``images`` through stem, trunk and attention pooling.

        Returns:
            A ``(predictions, viz_weights)`` tuple as produced by the
            attention-pooling block.
        """
        # stem followed by trunk
        features = self.trunk(self.stem(images))
        # the attention pooling block yields both outputs at once
        predictions, viz_weights = self.attention_pooling(features)
        return predictions, viz_weights
|
utilities/visualization.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# import the necessary packages
|
| 2 |
+
from tensorflow.keras import layers
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
class PlotAttention:
    """Callable that plots an image next to its attention-weight overlay.

    Args:
        model: Object exposing ``stem``, ``trunk`` and ``attention_pooling``
            callables; they are invoked in that order on the input image.
    """

    def __init__(self, model):
        self.model = model

    def __call__(self, image):
        """Draw the resized image and the same image with attention overlaid.

        Args:
            image: A single image accepted by
                ``tf.image.convert_image_dtype`` (presumably the
                height x width x channels array Gradio passes in; confirm
                against the caller).

        Returns:
            The ``matplotlib.pyplot`` module, with the newly created
            two-panel figure as its current figure.
        """
        # convert to float32, resize to 224 x 224 and add a leading batch axis
        as_float = tf.image.convert_image_dtype(image, tf.float32)
        resized = tf.image.resize(as_float, (224, 224))
        batch = resized[tf.newaxis, ...]

        # stem -> trunk -> attention pooling; only the weights are needed
        features = self.model.trunk(self.model.stem(batch))
        _, attn_weights = self.model.attention_pooling(features)
        # Reshape leaves the leading axis alone, so add one for it to keep
        attn_weights = attn_weights[tf.newaxis, ...]

        # reshape the visualization weights into a square grid; the last
        # axis is taken to be the number of patches
        num_patches = tf.shape(attn_weights)[-1]
        side = int(math.sqrt(num_patches))
        attn_grid = layers.Reshape((side, side))(attn_weights)

        # there is only one image in the batch
        picture = batch[0]
        heatmap = attn_grid[0]

        # NOTE(review): the figure created here is never closed in this
        # method; confirm the caller releases it, otherwise figures may
        # accumulate in a long-running process.
        _, (left, right) = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

        # left panel: the resized input on its own
        left.imshow(picture)
        left.set_title("Original")
        left.axis("off")

        # right panel: the input with the attention grid stretched over the
        # same extent and blended on top
        base = right.imshow(picture)
        right.imshow(
            heatmap,
            cmap="inferno",
            alpha=0.6,
            extent=base.get_extent(),
        )
        right.set_title("Attended")
        right.axis("off")

        plt.axis("off")
        return plt
|