---
license: mit
base_model:
- torchvision/convnext_tiny
- pytorch/resnet50
metrics:
- accuracy
tags:
- Interpretability
- Explainable AI
- XAI
- Classification
- CNN
- Convolutional Neural Networks
---
# A Robust Prototype-Based Network with Interpretable RBF Classifier Foundations
This repository contains the deep Classification-by-Components (CBC) models for
prototype-based interpretability benchmarks on classification, as described in the paper
"A Robust Prototype-Based Network with Interpretable RBF Classifier Foundations".
## Model Description
The CBC approach learns components (or prototypes) that make the model's decisions interpretable.
It reasons about class predictions with both positive and negative reasoning,
i.e., both the presence and the absence of components contribute evidence for
predicting a given class.
The [`deep_cbc`](https://github.com/si-cim/cbc-aaai-2025) package provides training, evaluation,
and visualization scripts for the CBC models in deep settings, with CNN architectures as
feature-extractor backbones. Further, a CBC head restricted to positive reasoning is equivalent
to an RBF classification head. Additionally, we provide compatibility support with the PIPNet
classification head.
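The positive/negative reasoning described above can be sketched in a few lines. This is a minimal illustration, not the `deep_cbc` API: the function name `cbc_class_probabilities` and the toy weights are assumptions; the idea is that each class's reasoning weights (positive plus negative) sum to one, so the class score is a convex combination of detection evidence.

```python
import torch

def cbc_class_probabilities(d, r_pos, r_neg):
    """Combine component detection probabilities with positive and
    negative reasoning weights to score each class.

    d:     (batch, n_components) detection probabilities in [0, 1]
    r_pos: (n_classes, n_components) positive reasoning weights
    r_neg: (n_classes, n_components) negative reasoning weights
    """
    # A detected component supports classes that reason positively about
    # it; an absent component supports classes that reason negatively.
    return d @ r_pos.t() + (1.0 - d) @ r_neg.t()

# Toy example: 2 classes, 3 components.
d = torch.tensor([[0.9, 0.1, 0.8]])
r_pos = torch.tensor([[0.5, 0.0, 0.5],
                      [0.0, 0.4, 0.0]])
r_neg = torch.tensor([[0.0, 0.0, 0.0],
                      [0.0, 0.0, 0.6]])
probs = cbc_class_probabilities(d, r_pos, r_neg)
# Class 0 gathers strong positive evidence from components 0 and 2.
```

Because each row of `r_pos + r_neg` sums to one, every class score stays in [0, 1] without a softmax.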
### Available and Supported Architectures
We provide two CNN variants for each of the CUB-200-2011, Stanford Cars, and
Oxford-IIIT Pets datasets:
- **ResNet50 w/ CBC Classification Head**: Built on both a partially trained and a fully trained
backbone from the `model_zoo` module in `pytorch`.
- **ConvNeXt w/ CBC Classification Head**: Built on a partially trained `convnext_tiny`
backbone from `torchvision`.
Further, both architectures can also be trained with an RBF or PIPNet classification
head.
## Performance
All models were trained and evaluated on the CUB-200-2011 (CUB), Stanford Cars (CARS), and
Oxford-IIIT Pets (PETS) datasets; below we report the top-1 classification accuracy
on these datasets.
| Model Version | Backbone | CUB | CARS | PETS |
|---------------|-----------------|--------------|--------------|--------------|
| CBC-C | `convnext_tiny` | 87.8 ± 0.1 % | 93.0 ± 0.1 % | 93.9 ± 0.1 % |
| CBC-R | `resnet50` | 83.3 ± 0.3 % | 92.7 ± 0.1 % | 90.1 ± 0.1 % |
| CBC-R Full | `resnet50` | 82.8 ± 0.3 % | 92.8 ± 0.1 % | 89.5 ± 0.2 % |
## Model Features
- 🔍 **Interpretable Decision Assistance:** The model performs classification through
positive and negative reasoning over learnt components (or prototypes), providing
interpretable insights into its decision-making.
- 🎯 **SotA Accuracy:** Achieves state-of-the-art (SotA) performance on classification tasks for the interpretability benchmarks.
- 🚀 **Multiple Feature Extractor CNN Backbones:** Supports ConvNeXt and ResNet50 feature extractor
architecture backbones with CBC heads for interpretable image classification tasks.
- 📊 **Visualization and Analysis Tools:** Equipped with visualization tools to plot learnt prototype patches and
corresponding activation maps alongside the similarity score and detection probability metrics.
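The activation-map plotting mentioned above typically upsamples a component's coarse similarity map back to input resolution before overlaying it. A minimal sketch of that step, with an assumed 7×7 map size (the exact tooling lives in the `deep_cbc` visualization scripts):

```python
import torch
import torch.nn.functional as F

# A component's similarity map over the backbone's spatial grid
# (7x7 assumed here, matching a 224x224 input through a ResNet50).
sim_map = torch.rand(1, 1, 7, 7)

# Bilinearly upsample to input resolution to use as a heatmap overlay.
heat = F.interpolate(sim_map, size=(224, 224),
                     mode="bilinear", align_corners=False)
```

The resulting `heat` tensor can then be alpha-blended over the input image with standard matplotlib calls.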
## Requirements
- python = "^3.9"
- numpy = "1.26.4"
- matplotlib = "3.8.4"
- scikit-learn = "1.4.2"
- scipy = "1.13.0"
- pillow = "10.3.0"
- omegaconf = "2.3.0"
- hydra-core = "1.3.2"
- torch = "2.2.2"
- torchvision = "0.17.2"
- setuptools = "68.2.0"
The basic dependencies for using the models are listed above. Please refer to the
[GitHub repository](https://github.com/si-cim/cbc-aaai-2025) for detailed dependencies
and project setup instructions to run experiments with the above models.
## Limitations and Bias
- ❗ **Partial Interpretability Issue:** The feature-extractor CNN backbone introduces
an uninterpretable component into the model. Although we achieve SotA accuracy and demonstrate
that the models provide high-quality positive and negative reasoning explanations, these methods
can only be called partially interpretable, because not all learnt prototypes
are human-interpretable.
- ❗ **Data Bias Issue:** These models are trained on the CUB-200-2011, Stanford Cars, and Oxford-IIIT Pets datasets,
and the stated performance may not generalize to other domains.
- ❗ **Resolution Constraints Issue:** The model backbones are pre-trained at a resolution of 224×224.
Although the current data loaders accept images of other resolutions, performance
will be suboptimal, owing to the fixed receptive fields the networks learn for a given resolution.
One possible improvement on the Stanford Cars dataset is to standardize image sizes as
a pre-processing step.
- ❗ **Location Misalignment Issue:** CNN-based models are not immune to
location misalignment under adversarial attack. Hence, with a black-box feature extractor, the
learnt prototype-based networks are also prone to such issues.
## Citation
If you use this model in your research, please consider citing:
```bibtex
@article{saralajew2024robust,
  title={A Robust Prototype-Based Network with Interpretable RBF Classifier Foundations},
  author={Saralajew, Sascha and Rana, Ashish and Villmann, Thomas and Shaker, Ammar},
  journal={arXiv preprint arXiv:2412.15499},
  year={2024}
}
```
## Acknowledgements
This implementation builds upon the following excellent repositories:
- [PIPNet](https://github.com/M-Nauta/PIPNet)
- [ProtoPNet](https://github.com/cfchen-duke/ProtoPNet)
These repositories also serve as additional documentation on the data pre-processing,
data loaders, model architectures, and visualizations.
## License
This project is released under the MIT license.
## Contact
For any questions or feedback, please:
1. Open an issue in the project [GitHub repository](https://github.com/si-cim/cbc-aaai-2025)
2. Contact the corresponding author