Update app.py
Browse files
app.py
CHANGED
@@ -6,7 +6,7 @@ import requests
|
|
6 |
from io import BytesIO
|
7 |
import numpy as np
|
8 |
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
|
9 |
-
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
|
10 |
from pytorch_grad_cam.utils.image import show_cam_on_image
|
11 |
from timm.data import create_transform
|
12 |
|
@@ -51,7 +51,7 @@ def process_image(image_path, model):
|
|
51 |
return tensor
|
52 |
|
53 |
def get_cam_image(model, image, target_layer, cam_method):
|
54 |
-
cam = CAM_METHODS[cam_method](model=model, target_layers=[target_layer]
|
55 |
grayscale_cam = cam(input_tensor=image)
|
56 |
|
57 |
config = model.pretrained_cfg
|
@@ -69,14 +69,23 @@ def get_feature_info(model):
|
|
69 |
else:
|
70 |
return []
|
71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
def explain_image(model_name, image_path, cam_method, feature_module):
|
73 |
model = load_model(model_name)
|
74 |
image = process_image(image_path, model)
|
75 |
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
else:
|
80 |
# Fallback to the last feature module or last convolutional layer
|
81 |
feature_info = get_feature_info(model)
|
82 |
if feature_info:
|
@@ -99,22 +108,29 @@ def explain_image(model_name, image_path, cam_method, feature_module):
|
|
99 |
def update_feature_modules(model_name):
|
100 |
model = load_model(model_name)
|
101 |
feature_modules = get_feature_info(model)
|
102 |
-
return gr.Dropdown
|
103 |
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
gr.
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
)
|
117 |
-
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
|
120 |
-
|
|
|
6 |
from io import BytesIO
|
7 |
import numpy as np
|
8 |
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
|
9 |
+
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
|
10 |
from pytorch_grad_cam.utils.image import show_cam_on_image
|
11 |
from timm.data import create_transform
|
12 |
|
|
|
51 |
return tensor
|
52 |
|
53 |
def get_cam_image(model, image, target_layer, cam_method):
|
54 |
+
cam = CAM_METHODS[cam_method](model=model, target_layers=[target_layer])
|
55 |
grayscale_cam = cam(input_tensor=image)
|
56 |
|
57 |
config = model.pretrained_cfg
|
|
|
69 |
else:
|
70 |
return []
|
71 |
|
72 |
+
def get_target_layer(model, target_layer_name):
|
73 |
+
if target_layer_name is None:
|
74 |
+
return None
|
75 |
+
|
76 |
+
try:
|
77 |
+
return model.get_submodule(target_layer_name)
|
78 |
+
except AttributeError:
|
79 |
+
print(f"WARNING: Layer '{target_layer_name}' not found in the model.")
|
80 |
+
return None
|
81 |
+
|
82 |
def explain_image(model_name, image_path, cam_method, feature_module):
|
83 |
model = load_model(model_name)
|
84 |
image = process_image(image_path, model)
|
85 |
|
86 |
+
target_layer = get_target_layer(model, feature_module)
|
87 |
+
|
88 |
+
if target_layer is None:
|
|
|
89 |
# Fallback to the last feature module or last convolutional layer
|
90 |
feature_info = get_feature_info(model)
|
91 |
if feature_info:
|
|
|
108 |
def update_feature_modules(model_name):
|
109 |
model = load_model(model_name)
|
110 |
feature_modules = get_feature_info(model)
|
111 |
+
return gr.Dropdown(choices=feature_modules, value=feature_modules[-1] if feature_modules else None)
|
112 |
|
113 |
+
with gr.Blocks() as demo:
|
114 |
+
gr.Markdown("# Explainable AI with timm models")
|
115 |
+
gr.Markdown("Upload an image, select a model, CAM method, and optionally a specific feature module to visualize the explanation.")
|
116 |
+
|
117 |
+
with gr.Row():
|
118 |
+
with gr.Column():
|
119 |
+
model_dropdown = gr.Dropdown(choices=MODELS, label="Select Model")
|
120 |
+
image_input = gr.Image(type="filepath", label="Upload Image")
|
121 |
+
cam_method_dropdown = gr.Dropdown(choices=list(CAM_METHODS.keys()), label="Select CAM Method")
|
122 |
+
feature_module_dropdown = gr.Dropdown(label="Select Feature Module (optional)")
|
123 |
+
explain_button = gr.Button("Explain Image")
|
124 |
+
|
125 |
+
with gr.Column():
|
126 |
+
output_image = gr.Image(type="pil", label="Explained Image")
|
127 |
+
|
128 |
+
model_dropdown.change(fn=update_feature_modules, inputs=[model_dropdown], outputs=[feature_module_dropdown])
|
129 |
+
|
130 |
+
explain_button.click(
|
131 |
+
fn=explain_image,
|
132 |
+
inputs=[model_dropdown, image_input, cam_method_dropdown, feature_module_dropdown],
|
133 |
+
outputs=[output_image]
|
134 |
+
)
|
135 |
|
136 |
+
demo.launch()
|