File size: 7,261 Bytes
147fbb0 19f8d94 147fbb0 500c37b 147fbb0 500c37b 147fbb0 500c37b 147fbb0 500c37b 19f8d94 500c37b 147fbb0 19f8d94 500c37b 147fbb0 19f8d94 147fbb0 500c37b 147fbb0 19f8d94 147fbb0 500c37b 19f8d94 500c37b 19f8d94 500c37b 19f8d94 500c37b 19f8d94 500c37b 19f8d94 500c37b 19f8d94 147fbb0 19f8d94 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
import gradio as gr
import timm
import torch
from PIL import Image
import requests
from io import BytesIO
import numpy as np
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from timm.data import create_transform
from timm.data import infer_imagenet_subset, ImageNetInfo
# Model names offered in the "Select Model" dropdown. The list is fetched once,
# at import time, from timm.list_pretrained().
MODELS = timm.list_pretrained()
# Display name -> pytorch_grad_cam explainer class. The keys are the choices of
# the "Select CAM Method" dropdown; get_cam_image() looks the selected key up
# here and instantiates the class.
CAM_METHODS = {
    "GradCAM": GradCAM,
    "HiResCAM": HiResCAM,
    "ScoreCAM": ScoreCAM,
    "GradCAM++": GradCAMPlusPlus,
    "AblationCAM": AblationCAM,
    "XGradCAM": XGradCAM,
    "EigenCAM": EigenCAM,
    "FullGrad": FullGrad
}
class CustomDatasetInfo:
    """Label lookup for models that carry their own class labels.

    Args:
        label_names: Sequence (or index-keyed mapping) of class names,
            addressed by class index.
        label_descriptions: Optional longer descriptions. May be a sequence
            addressed by class index, or a dict keyed by class index or by
            label name. When omitted or empty, the label names are used.
    """

    def __init__(self, label_names, label_descriptions=None):
        self.label_names = label_names
        # An empty/None description set falls back to the plain names.
        self.label_descriptions = label_descriptions or label_names

    def index_to_description(self, index, detailed=False):
        """Return the text for class ``index``.

        Args:
            index: Integer class index.
            detailed: When true, prefer the long description over the name.

        Returns:
            The description if ``detailed`` is set and one is available,
            otherwise the label name.
        """
        if not (detailed and self.label_descriptions):
            return self.label_names[index]

        descriptions = self.label_descriptions
        if isinstance(descriptions, dict):
            # Index-keyed dicts keep working exactly as before.
            if index in descriptions:
                return descriptions[index]
            # Name-keyed dicts: resolve index -> name -> description and fall
            # back to the bare name when no description exists for it.
            name = self.label_names[index]
            return descriptions.get(name, name)
        return descriptions[index]
def load_model(model_name):
    """Create the pretrained timm model ``model_name`` and return it in eval mode."""
    # nn.Module.eval() returns the module itself, so the call can be chained.
    return timm.create_model(model_name, pretrained=True).eval()
def process_image(image_path, model):
    """Load an image and turn it into a model-ready input batch.

    Args:
        image_path: Local file path, or an ``http://`` / ``https://`` URL.
        model: Model whose ``pretrained_cfg`` supplies the preprocessing
            settings (input size, crop percentage, mean, std, interpolation).

    Returns:
        The transformed image tensor with a leading batch dimension added.

    Raises:
        ValueError: If ``image_path`` is empty or ``None``.
        requests.HTTPError: If a URL download returns an error status.
    """
    if not image_path:
        raise ValueError("No image provided.")

    if image_path.startswith(("http://", "https://")):
        # Bound the wait and fail loudly on 4xx/5xx instead of handing an
        # error page to PIL.
        response = requests.get(image_path, timeout=30)
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_path

    # Force three channels: the normalisation below uses 3-element mean/std,
    # so RGBA, palette or grayscale inputs would otherwise fail. convert()
    # loads the pixel data, which lets the file handle be closed right away.
    with Image.open(source) as opened:
        image = opened.convert("RGB")

    config = model.pretrained_cfg
    transform = create_transform(
        input_size=config['input_size'],
        crop_pct=config['crop_pct'],
        mean=config['mean'],
        std=config['std'],
        interpolation=config['interpolation'],
        is_training=False,
    )
    tensor = transform(image).unsqueeze(0)
    return tensor
def get_cam_image(model, image, target_layer, cam_method, target_class):
    """Render a CAM heatmap for ``image`` blended over the de-normalised input.

    Args:
        model: Model to explain; its ``pretrained_cfg`` supplies mean/std.
        image: Normalised input batch (the first element is visualised).
        target_layer: Module whose activations the CAM method uses.
        cam_method: Key into ``CAM_METHODS``.
        target_class: Class index to explain, or ``None`` /
            ``"highest scoring"`` to pass no explicit target to the CAM call.

    Returns:
        A PIL image of the heatmap overlay.
    """
    explicit_class = target_class is not None and target_class != "highest scoring"
    cam_targets = [ClassifierOutputTarget(target_class)] if explicit_class else None

    explainer_cls = CAM_METHODS[cam_method]
    explainer = explainer_cls(model=model, target_layers=[target_layer])
    heatmaps = explainer(input_tensor=image, targets=cam_targets)

    # Undo the preprocessing normalisation so the overlay shows natural colours.
    cfg = model.pretrained_cfg
    channel_mean = torch.tensor(cfg['mean']).view(3, 1, 1)
    channel_std = torch.tensor(cfg['std']).view(3, 1, 1)
    restored = image.squeeze(0) * channel_std + channel_mean
    rgb_pixels = np.clip(restored.permute(1, 2, 0).cpu().numpy(), 0, 1)

    overlay = show_cam_on_image(rgb_pixels, heatmaps[0, :], use_rgb=True)
    return Image.fromarray(overlay)
def get_feature_info(model):
    """Return the ``'module'`` entry of every item in ``model.feature_info``.

    Returns an empty list when the model has no ``feature_info`` attribute.
    """
    if not hasattr(model, 'feature_info'):
        return []
    module_names = []
    for entry in model.feature_info:
        module_names.append(entry['module'])
    return module_names
def get_target_layer(model, target_layer_name):
    """Resolve a dotted submodule name to the module object.

    Args:
        model: Object exposing ``get_submodule(name)``.
        target_layer_name: Dotted submodule path, or ``None`` / empty when
            the caller made no selection.

    Returns:
        The submodule, or ``None`` when no name was given or the name does
        not exist (a warning is printed in the latter case).
    """
    # An empty name must count as "no selection": get_submodule("") would
    # return the model itself, which is not a usable CAM target layer.
    if not target_layer_name:
        return None
    try:
        return model.get_submodule(target_layer_name)
    except AttributeError:
        print(f"WARNING: Layer '{target_layer_name}' not found in the model.")
        return None
def get_class_names(model):
    """Build the label-lookup object for ``model``.

    Labels from ``model.pretrained_cfg`` take precedence. Without them, an
    inferred ImageNet subset is used if one is found; failing that, generic
    ``LABEL_<i>`` names are generated for every class.
    """
    cfg = model.pretrained_cfg
    names = cfg.get("label_names", None)
    descriptions = cfg.get("label_descriptions", None)

    if names is None:
        subset = infer_imagenet_subset(model)
        if subset:
            return ImageNetInfo(subset)
        names = [f"LABEL_{i}" for i in range(model.num_classes)]

    return CustomDatasetInfo(
        label_names=names,
        label_descriptions=descriptions,
    )
def explain_image(model_name, image_path, cam_method, feature_module, target_class):
    """Run the selected CAM method on an image and list the top predictions.

    Args:
        model_name: timm model name to load.
        image_path: Local path or URL of the image to explain.
        cam_method: Key into ``CAM_METHODS``.
        feature_module: Dotted name of the layer to explain, or ``None`` /
            empty to pick one automatically.
        target_class: ``"<index>: <description>"`` as produced for the class
            dropdown, or ``None`` / ``"highest scoring"`` for no explicit
            target.

    Returns:
        Tuple of (CAM overlay as a PIL image, newline-joined prediction text).

    Raises:
        ValueError: If the CAM method is unknown or no target layer is found.
    """
    if cam_method not in CAM_METHODS:
        raise ValueError(f"Unknown CAM method: {cam_method!r}")

    model = load_model(model_name)
    image = process_image(image_path, model)

    # Layer resolution order: explicit choice, then the last feature_info
    # module, then the last Conv2d. Each step runs only if the previous one
    # produced nothing, so a stale feature_info name still reaches the
    # convolution fallback.
    target_layer = get_target_layer(model, feature_module)
    if target_layer is None:
        feature_info = get_feature_info(model)
        if feature_info:
            target_layer = get_target_layer(model, feature_info[-1])
            if target_layer is not None:
                print(f"Using last feature module: {feature_info[-1]}")
    if target_layer is None:
        for name, module in reversed(list(model.named_modules())):
            if isinstance(module, torch.nn.Conv2d):
                target_layer = module
                print(f"Fallback: Using last convolutional layer: {name}")
                break
    if target_layer is None:
        raise ValueError("Could not find a suitable target layer.")

    # The dropdown may be untouched (None) or cleared (""); both mean that no
    # specific class was requested.
    if target_class in (None, "", "highest scoring"):
        target_class_index = None
    else:
        target_class_index = int(str(target_class).split(':', 1)[0])

    cam_image = get_cam_image(model, image, target_layer, cam_method, target_class_index)

    with torch.no_grad():
        out = model(image)
    probabilities = out.squeeze(0).softmax(dim=0)
    # Never ask for more entries than the model has classes.
    top_k = min(5, probabilities.numel())
    values, indices = torch.topk(probabilities, top_k)

    dataset_info = get_class_names(model)
    labels = []
    for class_index, probability in zip(indices.tolist(), values.tolist()):
        # Plain Python ints/floats here: formatting a 0-dim tensor would
        # render as "tensor(123)" in the output text.
        description = dataset_info.index_to_description(class_index, detailed=True)
        labels.append(f"{class_index}: {description} ({probability:.2%})")
    return cam_image, "\n".join(labels)
def update_feature_modules(model_name):
    """Return a Dropdown update listing the feature modules of ``model_name``.

    The last module is pre-selected; with no modules the value is ``None``.
    """
    module_names = get_feature_info(load_model(model_name))
    if module_names:
        preselected = module_names[-1]
    else:
        preselected = None
    return gr.Dropdown(choices=module_names, value=preselected)
def update_class_dropdown(model_name):
    """Return a Dropdown update listing every class of ``model_name``.

    Entries are formatted ``"<index>: <description>"``, preceded by the
    default ``"highest scoring"`` entry, which is pre-selected.
    """
    model = load_model(model_name)
    dataset_info = get_class_names(model)
    choices = ["highest scoring"]
    for class_index in range(model.num_classes):
        description = dataset_info.index_to_description(class_index, detailed=True)
        choices.append(f"{class_index}: {description}")
    return gr.Dropdown(choices=choices, value="highest scoring")
# Gradio UI: inputs in the left column, CAM overlay and predictions on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Explainable AI with timm models. NOTE: This is a WIP but some models are functioning.")
    gr.Markdown("Upload an image, select a model, CAM method, and optionally a specific feature module and target class to visualize the explanation.")
    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(choices=MODELS, label="Select Model")
            image_input = gr.Image(type="filepath", label="Upload Image")
            cam_method_dropdown = gr.Dropdown(choices=list(CAM_METHODS.keys()), label="Select CAM Method")
            feature_module_dropdown = gr.Dropdown(label="Select Feature Module (optional)")
            class_dropdown = gr.Dropdown(label="Select Target Class (optional)")
            explain_button = gr.Button("Explain Image")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Explained Image")
            prediction_text = gr.Textbox(label="Top 5 Predictions")
    # Selecting a model repopulates both model-dependent dropdowns.
    model_dropdown.change(fn=update_feature_modules, inputs=[model_dropdown], outputs=[feature_module_dropdown])
    model_dropdown.change(fn=update_class_dropdown, inputs=[model_dropdown], outputs=[class_dropdown])
    explain_button.click(
        fn=explain_image,
        inputs=[model_dropdown, image_input, cam_method_dropdown, feature_module_dropdown, class_dropdown],
        outputs=[output_image, prediction_text]
    )
# The stray trailing "|" after this call was a syntax error and has been removed.
demo.launch()