File size: 7,261 Bytes
147fbb0
 
 
 
 
 
 
 
19f8d94
147fbb0
 
500c37b
147fbb0
 
500c37b
147fbb0
 
 
 
 
 
 
 
 
 
 
 
 
500c37b
 
 
 
 
 
 
 
 
 
147fbb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
500c37b
 
 
 
 
 
19f8d94
500c37b
147fbb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19f8d94
 
 
 
 
 
 
 
 
 
500c37b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147fbb0
 
 
19f8d94
 
 
147fbb0
 
 
 
 
 
 
 
 
 
 
 
 
 
500c37b
 
 
 
 
 
 
 
 
 
 
 
 
 
147fbb0
 
 
 
19f8d94
147fbb0
500c37b
 
 
 
 
 
19f8d94
500c37b
 
19f8d94
 
 
 
 
 
 
500c37b
19f8d94
 
 
 
500c37b
19f8d94
 
500c37b
19f8d94
 
 
500c37b
 
19f8d94
147fbb0
19f8d94
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import gradio as gr
import timm
import torch
from PIL import Image
import requests
from io import BytesIO
import numpy as np
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from timm.data import create_transform
from timm.data import infer_imagenet_subset, ImageNetInfo

# Names of timm models that have pretrained weights (per timm.list_pretrained());
# these populate the model picker in the UI.
MODELS = timm.list_pretrained()

# (display name, pytorch-grad-cam class) pairs, listed in dropdown order.
_CAM_METHOD_ENTRIES = (
    ("GradCAM", GradCAM),
    ("HiResCAM", HiResCAM),
    ("ScoreCAM", ScoreCAM),
    ("GradCAM++", GradCAMPlusPlus),
    ("AblationCAM", AblationCAM),
    ("XGradCAM", XGradCAM),
    ("EigenCAM", EigenCAM),
    ("FullGrad", FullGrad),
)
# Display name -> CAM class. dict() keeps insertion order, which is the order
# the CAM method dropdown shows.
CAM_METHODS = dict(_CAM_METHOD_ENTRIES)

class CustomDatasetInfo:
    """Minimal class-index -> label lookup for models that are not ImageNet variants.

    Provides the ``index_to_description`` method this app calls on dataset-info
    objects.

    Args:
        label_names: Class names, indexable by class index (a list, or a dict
            keyed by int index).
        label_descriptions: Optional human-readable descriptions. Either a
            sequence parallel to ``label_names`` (indexed by class index), or a
            mapping keyed by label name (the form timm uses for
            ``pretrained_cfg['label_descriptions']``). Falls back to
            ``label_names`` when ``None`` or empty.
    """

    def __init__(self, label_names, label_descriptions=None):
        self.label_names = label_names
        self.label_descriptions = label_descriptions or label_names

    def index_to_description(self, index, detailed=False):
        """Return the display text for class ``index``.

        With ``detailed=False`` the plain label name is returned. With
        ``detailed=True`` the description is returned instead:
        - mapping descriptions are looked up by label name, then by index,
          falling back to the label name when neither key is present;
        - sequence descriptions are indexed by ``index``.
        """
        descriptions = self.label_descriptions
        if detailed and descriptions:
            if isinstance(descriptions, dict):
                # timm keys descriptions by label *name*; indexing the mapping
                # with the int class index (as a list would be) raises KeyError.
                name = self.label_names[index]
                if name in descriptions:
                    return descriptions[name]
                return descriptions.get(index, name)
            return descriptions[index]
        return self.label_names[index]

def load_model(model_name):
    """Create the timm model ``model_name`` with pretrained weights, switched to eval mode.

    ``nn.Module.eval()`` returns the module itself, so the call can be chained.
    """
    return timm.create_model(model_name, pretrained=True).eval()

def process_image(image_path, model):
    """Load an image from a local path or http(s) URL and preprocess it for ``model``.

    The image is converted to 3-channel RGB, then run through the evaluation
    transform built from ``model.pretrained_cfg`` (input size, crop pct,
    mean/std, interpolation).

    Args:
        image_path: Local file path, or a URL starting with ``http://`` or ``https://``.
        model: A timm model whose ``pretrained_cfg`` has ``input_size``,
            ``crop_pct``, ``mean``, ``std`` and ``interpolation`` keys.

    Returns:
        The transformed image tensor with a leading batch dimension of size 1.

    Raises:
        requests.HTTPError: if downloading a URL returns an error status.
    """
    if image_path.startswith(('http://', 'https://')):
        # Bound the wait so an unresponsive host cannot hang the request handler.
        response = requests.get(image_path, timeout=30)
        # Fail loudly on 4xx/5xx rather than handing an error page to PIL.
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_path

    # Force RGB: RGBA, grayscale and palette images would otherwise produce a
    # tensor whose channel count does not match an RGB model's mean/std
    # normalization. convert() loads the pixel data, so the handle opened by
    # Image.open can be closed immediately.
    with Image.open(source) as img:
        image = img.convert('RGB')

    config = model.pretrained_cfg
    transform = create_transform(
        input_size=config['input_size'],
        crop_pct=config['crop_pct'],
        mean=config['mean'],
        std=config['std'],
        interpolation=config['interpolation'],
        is_training=False
    )

    return transform(image).unsqueeze(0)

def get_cam_image(model, image, target_layer, cam_method, target_class):
    """Compute a class-activation map and blend it over the de-normalized input.

    Args:
        model: The classifier being explained.
        image: Normalized input batch; assumed (1, 3, H, W) as produced by
            ``process_image`` (the de-normalization below reshapes mean/std to 3 channels).
        target_layer: Module whose activations the CAM method inspects.
        cam_method: Key into ``CAM_METHODS``.
        target_class: Class index to explain; ``None`` or ``"highest scoring"``
            passes ``targets=None`` to the CAM.

    Returns:
        A ``PIL.Image`` containing the heatmap overlay.
    """
    explain_top_class = target_class is None or target_class == "highest scoring"
    targets = None if explain_top_class else [ClassifierOutputTarget(target_class)]

    cam_cls = CAM_METHODS[cam_method]
    cam = cam_cls(model=model, target_layers=[target_layer])
    heatmaps = cam(input_tensor=image, targets=targets)

    cfg = model.pretrained_cfg
    channel_mean = torch.tensor(cfg['mean']).view(3, 1, 1)
    channel_std = torch.tensor(cfg['std']).view(3, 1, 1)
    # Invert the normalization, move channels last (HWC) and clamp to [0, 1]
    # because show_cam_on_image expects a float RGB image in that range.
    denormalized = image.squeeze(0) * channel_std + channel_mean
    rgb_img = np.clip(denormalized.permute(1, 2, 0).cpu().numpy(), 0, 1)

    overlay = show_cam_on_image(rgb_img, heatmaps[0, :], use_rgb=True)
    return Image.fromarray(overlay)

def get_feature_info(model):
    """Return the ``'module'`` names listed in ``model.feature_info``, in order.

    ``feature_info`` is expected to be an iterable of mappings that each carry
    a ``'module'`` key. A model without that attribute yields an empty list.
    """
    if not hasattr(model, 'feature_info'):
        return []
    return [entry['module'] for entry in model.feature_info]

def get_target_layer(model, target_layer_name):
    """Look up a submodule of ``model`` by its dotted name.

    Args:
        model: The ``torch.nn.Module`` to search.
        target_layer_name: Dotted submodule path, or ``None`` / ``""`` when no
            explicit layer was chosen.

    Returns:
        The submodule, or ``None`` if no name was given or the name does not
        resolve (a warning is printed in the latter case).
    """
    # Treat "" like None: get_submodule("") returns the model itself, whose
    # output is the logits rather than a feature map, so it is not a usable
    # CAM target layer.
    if not target_layer_name:
        return None

    try:
        return model.get_submodule(target_layer_name)
    except AttributeError:
        print(f"WARNING: Layer '{target_layer_name}' not found in the model.")
        return None

def get_class_names(model):
    """Return a label-lookup object for ``model``'s output classes.

    Resolution order:
      1. ``label_names`` from ``model.pretrained_cfg`` (together with any
         ``label_descriptions``), wrapped in ``CustomDatasetInfo``;
      2. ``ImageNetInfo`` for the subset ``infer_imagenet_subset`` reports;
      3. placeholder names ``LABEL_0`` .. ``LABEL_{num_classes-1}`` wrapped in
         ``CustomDatasetInfo``.
    """
    cfg = model.pretrained_cfg
    names = cfg.get("label_names", None)
    descriptions = cfg.get("label_descriptions", None)

    if names is None:
        subset = infer_imagenet_subset(model)
        if subset:
            return ImageNetInfo(subset)
        names = [f"LABEL_{idx}" for idx in range(model.num_classes)]

    return CustomDatasetInfo(
        label_names=names,
        label_descriptions=descriptions,
    )

def _resolve_target_layer(model, feature_module):
    """Pick the CAM layer: the requested module, else the last timm feature module, else the last Conv2d."""
    target_layer = get_target_layer(model, feature_module)
    if target_layer is not None:
        return target_layer

    feature_info = get_feature_info(model)
    if feature_info:
        target_layer = get_target_layer(model, feature_info[-1])
        print(f"Using last feature module: {feature_info[-1]}")
        return target_layer

    # No feature metadata: scan modules from the end for the last Conv2d.
    for name, module in reversed(list(model.named_modules())):
        if isinstance(module, torch.nn.Conv2d):
            print(f"Fallback: Using last convolutional layer: {name}")
            return module
    return None


def _parse_target_class(target_class):
    """Turn a class-dropdown value such as ``"281: tabby"`` into ``281``; ``None`` means highest scoring."""
    # An unset dropdown yields None/"" (e.g. Explain clicked before the class
    # list was populated); treat it like "highest scoring" instead of crashing.
    if not target_class or target_class == "highest scoring":
        return None
    return int(target_class.split(':', 1)[0])


def _format_top_predictions(model, image, k=5):
    """Run ``model`` on ``image`` and return its top-``k`` classes as ``"index: label (prob%)"`` lines."""
    with torch.no_grad():
        logits = model(image)
    probabilities = logits.squeeze(0).softmax(dim=0)
    # topk raises if k exceeds the number of classes, so clamp it.
    values, indices = torch.topk(probabilities, min(k, probabilities.numel()))
    dataset_info = get_class_names(model)
    return "\n".join(
        f"{idx}: {dataset_info.index_to_description(idx, detailed=True)} ({prob:.2%})"
        for idx, prob in zip(indices.tolist(), values.tolist())
    )


def explain_image(model_name, image_path, cam_method, feature_module, target_class):
    """Produce a CAM explanation and the top predictions for one image.

    Args:
        model_name: timm model name, loaded with pretrained weights.
        image_path: Local path or http(s) URL of the image.
        cam_method: Key into ``CAM_METHODS``.
        feature_module: Optional submodule name to use as the CAM target layer.
        target_class: Class-dropdown value ``"<index>: <label>"``, or
            ``"highest scoring"`` / ``None`` / ``""`` to explain the top class.

    Returns:
        ``(cam_image, predictions)``: the heatmap overlay and up to five
        newline-separated ``"<index>: <label> (<prob>%)"`` lines.

    Raises:
        gr.Error: if the model, image, or CAM method has not been provided.
        ValueError: if no suitable target layer can be found.
    """
    # Validate UI inputs up front so the user sees a clear message instead of an
    # AttributeError/KeyError from deep inside preprocessing or CAM lookup.
    if not model_name:
        raise gr.Error("Please select a model.")
    if not image_path:
        raise gr.Error("Please upload an image.")
    if cam_method not in CAM_METHODS:
        raise gr.Error("Please select a CAM method.")

    model = load_model(model_name)
    image = process_image(image_path, model)

    target_layer = _resolve_target_layer(model, feature_module)
    if target_layer is None:
        raise ValueError("Could not find a suitable target layer.")

    target_class_index = _parse_target_class(target_class)
    cam_image = get_cam_image(model, image, target_layer, cam_method, target_class_index)

    return cam_image, _format_top_predictions(model, image)

def update_feature_modules(model_name):
    """Rebuild the feature-module dropdown for ``model_name``.

    Choices are the names returned by ``get_feature_info``. The last one is
    preselected, or no value is selected when the list is empty.
    """
    modules = get_feature_info(load_model(model_name))
    if modules:
        selected = modules[-1]
    else:
        selected = None
    return gr.Dropdown(choices=modules, value=selected)

def update_class_dropdown(model_name):
    """Rebuild the target-class dropdown for ``model_name``.

    The first choice is ``"highest scoring"`` (also the preselected value),
    followed by ``"<index>: <description>"`` for each of the model's classes.
    """
    net = load_model(model_name)
    info = get_class_names(net)
    choices = ["highest scoring"]
    choices.extend(
        f"{idx}: {info.index_to_description(idx, detailed=True)}"
        for idx in range(net.num_classes)
    )
    return gr.Dropdown(choices=choices, value="highest scoring")

# Build the Gradio UI. Components are placed in the order they are created
# inside the Row/Column contexts, so statement order here defines the layout.
with gr.Blocks() as demo:
    gr.Markdown("# Explainable AI with timm models. NOTE: This is a WIP but some models are functioning.")
    gr.Markdown("Upload an image, select a model, CAM method, and optionally a specific feature module and target class to visualize the explanation.")
    
    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            # Model choices come from MODELS (timm.list_pretrained()).
            model_dropdown = gr.Dropdown(choices=MODELS, label="Select Model")
            image_input = gr.Image(type="filepath", label="Upload Image")
            cam_method_dropdown = gr.Dropdown(choices=list(CAM_METHODS.keys()), label="Select CAM Method")
            # The next two dropdowns start without choices; they are filled by
            # the model_dropdown.change handlers registered below.
            feature_module_dropdown = gr.Dropdown(label="Select Feature Module (optional)")
            class_dropdown = gr.Dropdown(label="Select Target Class (optional)")
            explain_button = gr.Button("Explain Image")
        
        # Right column: outputs.
        with gr.Column():
            output_image = gr.Image(type="pil", label="Explained Image")
            prediction_text = gr.Textbox(label="Top 5 Predictions")
    
    # Selecting a model repopulates both dependent dropdowns. Each handler calls
    # load_model on its own, so the model is instantiated once per handler.
    model_dropdown.change(fn=update_feature_modules, inputs=[model_dropdown], outputs=[feature_module_dropdown])
    model_dropdown.change(fn=update_class_dropdown, inputs=[model_dropdown], outputs=[class_dropdown])
    
    # The inputs list is passed positionally, so its order must match
    # explain_image(model_name, image_path, cam_method, feature_module, target_class).
    explain_button.click(
        fn=explain_image,
        inputs=[model_dropdown, image_input, cam_method_dropdown, feature_module_dropdown, class_dropdown],
        outputs=[output_image, prediction_text]
    )

# Launch the app. This runs at module level, i.e. whenever the file is executed or imported.
demo.launch()