rwightman HF staff commited on
Commit
19f8d94
·
verified ·
1 Parent(s): 265971f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -23
app.py CHANGED
@@ -6,7 +6,7 @@ import requests
6
  from io import BytesIO
7
  import numpy as np
8
  from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
9
- from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget, get_target_layer
10
  from pytorch_grad_cam.utils.image import show_cam_on_image
11
  from timm.data import create_transform
12
 
@@ -51,7 +51,7 @@ def process_image(image_path, model):
51
  return tensor
52
 
53
  def get_cam_image(model, image, target_layer, cam_method):
54
- cam = CAM_METHODS[cam_method](model=model, target_layers=[target_layer], use_cuda=torch.cuda.is_available())
55
  grayscale_cam = cam(input_tensor=image)
56
 
57
  config = model.pretrained_cfg
@@ -69,14 +69,23 @@ def get_feature_info(model):
69
  else:
70
  return []
71
 
 
 
 
 
 
 
 
 
 
 
72
  def explain_image(model_name, image_path, cam_method, feature_module):
73
  model = load_model(model_name)
74
  image = process_image(image_path, model)
75
 
76
- if feature_module:
77
- target_layer = get_target_layer(model, feature_module)
78
- print(f"Using feature module: {feature_module}")
79
- else:
80
  # Fallback to the last feature module or last convolutional layer
81
  feature_info = get_feature_info(model)
82
  if feature_info:
@@ -99,22 +108,29 @@ def explain_image(model_name, image_path, cam_method, feature_module):
99
  def update_feature_modules(model_name):
100
  model = load_model(model_name)
101
  feature_modules = get_feature_info(model)
102
- return gr.Dropdown.update(choices=feature_modules, value=feature_modules[-1] if feature_modules else None)
103
 
104
- iface = gr.Interface(
105
- fn=explain_image,
106
- inputs=[
107
- gr.Dropdown(choices=MODELS, label="Select Model"),
108
- gr.Image(type="filepath", label="Upload Image"),
109
- gr.Dropdown(choices=list(CAM_METHODS.keys()), label="Select CAM Method"),
110
- gr.Dropdown(label="Select Feature Module (optional)")
111
- ],
112
- outputs=gr.Image(type="pil", label="Explained Image"),
113
- title="Explainable AI with timm models",
114
- description="Upload an image, select a model, CAM method, and optionally a specific feature module to visualize the explanation.",
115
- allow_flagging="never"
116
- )
117
-
118
- iface.load(update_feature_modules, inputs=[iface.inputs[0]], outputs=[iface.inputs[3]])
 
 
 
 
 
 
 
119
 
120
- iface.launch()
 
6
  from io import BytesIO
7
  import numpy as np
8
  from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
9
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
10
  from pytorch_grad_cam.utils.image import show_cam_on_image
11
  from timm.data import create_transform
12
 
 
51
  return tensor
52
 
53
  def get_cam_image(model, image, target_layer, cam_method):
54
+ cam = CAM_METHODS[cam_method](model=model, target_layers=[target_layer])
55
  grayscale_cam = cam(input_tensor=image)
56
 
57
  config = model.pretrained_cfg
 
69
  else:
70
  return []
71
 
72
+ def get_target_layer(model, target_layer_name):
73
+ if target_layer_name is None:
74
+ return None
75
+
76
+ try:
77
+ return model.get_submodule(target_layer_name)
78
+ except AttributeError:
79
+ print(f"WARNING: Layer '{target_layer_name}' not found in the model.")
80
+ return None
81
+
82
  def explain_image(model_name, image_path, cam_method, feature_module):
83
  model = load_model(model_name)
84
  image = process_image(image_path, model)
85
 
86
+ target_layer = get_target_layer(model, feature_module)
87
+
88
+ if target_layer is None:
 
89
  # Fallback to the last feature module or last convolutional layer
90
  feature_info = get_feature_info(model)
91
  if feature_info:
 
108
  def update_feature_modules(model_name):
109
  model = load_model(model_name)
110
  feature_modules = get_feature_info(model)
111
+ return gr.Dropdown(choices=feature_modules, value=feature_modules[-1] if feature_modules else None)
112
 
113
+ with gr.Blocks() as demo:
114
+ gr.Markdown("# Explainable AI with timm models")
115
+ gr.Markdown("Upload an image, select a model, CAM method, and optionally a specific feature module to visualize the explanation.")
116
+
117
+ with gr.Row():
118
+ with gr.Column():
119
+ model_dropdown = gr.Dropdown(choices=MODELS, label="Select Model")
120
+ image_input = gr.Image(type="filepath", label="Upload Image")
121
+ cam_method_dropdown = gr.Dropdown(choices=list(CAM_METHODS.keys()), label="Select CAM Method")
122
+ feature_module_dropdown = gr.Dropdown(label="Select Feature Module (optional)")
123
+ explain_button = gr.Button("Explain Image")
124
+
125
+ with gr.Column():
126
+ output_image = gr.Image(type="pil", label="Explained Image")
127
+
128
+ model_dropdown.change(fn=update_feature_modules, inputs=[model_dropdown], outputs=[feature_module_dropdown])
129
+
130
+ explain_button.click(
131
+ fn=explain_image,
132
+ inputs=[model_dropdown, image_input, cam_method_dropdown, feature_module_dropdown],
133
+ outputs=[output_image]
134
+ )
135
 
136
+ demo.launch()