A Robust Prototype-Based Network with Interpretable RBF Classifier Foundations
This repository contains the Deep Classification-by-Component (CBC) models for prototype-based learning on interpretability benchmarks for classification, as described in the paper "A Robust Prototype-Based Network with Interpretable RBF Classifier Foundations".
Model Description
The CBC approach learns components (or prototypes) to create interpretable learning insights. It uses positive and negative reasoning about class predictions, i.e., both the presence and the absence of components contribute evidence for predicting a given class.
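As a rough illustration of this reasoning rule, consider the following minimal sketch based on the general CBC formulation (tensor shapes and names are illustrative, not the repository's API):

```python
import torch

def cbc_class_evidence(d, r_pos, r_neg):
    """Combine component detections into class evidence.

    d:     (batch, K) detection probabilities of K components in the input.
    r_pos: (K, C) positive reasoning weights (presence supports class c).
    r_neg: (K, C) negative reasoning weights (absence supports class c).
    """
    evidence = d @ r_pos + (1.0 - d) @ r_neg
    # Normalize per class so the evidence values are comparable.
    return evidence / (r_pos + r_neg).sum(dim=0).clamp_min(1e-8)
```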
The `deep_cbc` package provides training, evaluation, and visualization scripts for the CBC models in deep settings, with CNN architectures as feature extractor backbones. Further, CBC with only positive reasoning is equivalent to an RBF classification head. Additionally, we provide compatibility support for the PIPNet classification head as well.
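A minimal sketch of what such an RBF classification head looks like (illustrative shapes and parameter names, not the repository's implementation):

```python
import torch

class RBFHead(torch.nn.Module):
    """Toy RBF classification head: class scores are non-negative
    combinations of Gaussian similarities to learned prototypes."""

    def __init__(self, n_features, n_prototypes, n_classes, sigma=1.0):
        super().__init__()
        self.prototypes = torch.nn.Parameter(torch.randn(n_prototypes, n_features))
        self.class_weights = torch.nn.Parameter(torch.rand(n_prototypes, n_classes))
        self.sigma = sigma

    def forward(self, x):                                # x: (batch, n_features)
        d2 = torch.cdist(x, self.prototypes).pow(2)      # squared distances
        sim = torch.exp(-d2 / (2 * self.sigma ** 2))     # RBF similarities in (0, 1]
        return sim @ self.class_weights.relu()           # positive reasoning only
```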
Available and Supported Architectures
We provide two CNN variants for each of the CUB-200-2011, Stanford Cars, and Oxford-IIIT Pet datasets:
- ResNet50 w/ CBC Classification Head: Built on both partially trained and fully trained backbones from the `model_zoo` module in PyTorch.
- ConvNeXt w/ CBC Classification Head: Built on a partially trained `convnext_tiny` backbone from `torchvision`.
Further, both of the above architectures can also be trained with RBF and PIPNet classification heads.
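For orientation, here is how the two backbones can be instantiated as feature extractors via `torchvision` (an illustrative sketch only; the actual backbone wiring and pre-trained checkpoints are handled inside `deep_cbc`):

```python
import torch
from torchvision.models import (ConvNeXt_Tiny_Weights, ResNet50_Weights,
                                convnext_tiny, resnet50)

# ResNet50 feature extractor: replace the final classifier with an
# identity so the network emits 2048-d feature vectors.
resnet_backbone = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
resnet_backbone.fc = torch.nn.Identity()

# ConvNeXt-Tiny feature extractor: swap only the final linear layer so
# the normalization and flatten steps are kept (768-d features).
convnext_backbone = convnext_tiny(weights=ConvNeXt_Tiny_Weights.IMAGENET1K_V1)
convnext_backbone.classifier[2] = torch.nn.Identity()
```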
Performance
All models were trained and evaluated on the CUB-200-2011 (CUB), Stanford Cars (CARS) and Oxford-IIIT Pets (PETS) datasets and below we report the top-1 classification accuracy results on these datasets.
| Model Version | Backbone | CUB | CARS | PETS |
|---|---|---|---|---|
| CBC-C | `convnext_tiny` | 87.8 ± 0.1 % | 93.0 ± 0.1 % | 93.9 ± 0.1 % |
| CBC-R | `resnet50` | 83.3 ± 0.3 % | 92.7 ± 0.1 % | 90.1 ± 0.1 % |
| CBC-R Full | `resnet50` | 82.8 ± 0.3 % | 92.8 ± 0.1 % | 89.5 ± 0.2 % |
Model Features
- 🔍 Interpretable Decision Assistance: The model performs classification using positive and negative reasoning over learnt components (or prototypes), providing interpretable insights into its decision-making.
- 🎯 SotA Accuracy: Achieves state-of-the-art (SotA) performance on classification tasks for the interpretability benchmarks.
- 🚀 Multiple Feature Extractor CNN Backbones: Supports ConvNeXt and ResNet50 feature extractor architecture backbones with CBC heads for interpretable image classification tasks.
- 📊 Visualization and Analysis Tools: Equipped with visualization tools to plot learnt prototype patches and corresponding activation maps alongside the similarity score and detection probability metrics.
Requirements
- python = "^3.9"
- numpy = "1.26.4"
- matplotlib = "3.8.4"
- scikit-learn = "1.4.2"
- scipy = "1.13.0"
- pillow = "10.3.0"
- omegaconf = "2.3.0"
- hydra-core = "1.3.2"
- torch = "2.2.2"
- torchvision = "0.17.2"
- setuptools = "68.2.0"
The basic dependencies for using the models are listed above. Please refer to the GitHub repository for the full dependency list and project setup instructions needed to run experiments with these models.
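For convenience, the pinned versions above can be installed directly with pip (a minimal sketch; the repository itself may manage dependencies with a tool such as Poetry, which the version constraints above suggest):

```bash
pip install numpy==1.26.4 matplotlib==3.8.4 scikit-learn==1.4.2 scipy==1.13.0 \
    pillow==10.3.0 omegaconf==2.3.0 hydra-core==1.3.2 torch==2.2.2 \
    torchvision==0.17.2 setuptools==68.2.0
```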
Limitations and Bias
- ❗ Partial Interpretability Issue: The uninterpretable feature extractor CNN backbone introduces an uninterpretable component into the model. Although we achieve SotA accuracy and demonstrate that the models provide high-quality positive and negative reasoning explanations, we can still only call these methods partially interpretable, because not all learnt prototypes are human-interpretable.
- ❗ Data Bias Issue: These models are trained on the CUB-200-2011, Stanford Cars, and Oxford-IIIT Pet datasets, and the stated model performance would not generalize to other domains.
- ❗ Resolution Constraints Issue: The model backbones are pre-trained at a resolution of 224×224. Although the current data loaders allow the models to accept inputs of other resolutions, performance will be suboptimal owing to the fixed receptive fields the networks learn for a given resolution. A possible improvement on the Stanford Cars dataset is to standardize image sizes as a pre-processing step (see the sketch after this list).
- ❗ Location Misalignment Issue: CNN-based models are not perfectly immune to location misalignment under adversarial attack. Hence, with a black-box feature extractor, the learnt prototype-based networks are also prone to such issues.
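A possible standardization step for the resolution issue above (an illustrative sketch using standard ImageNet statistics; the repository's own data loaders may use different values):

```python
from torchvision import transforms

# Resize and crop every image to the 224x224 resolution the backbones
# were pre-trained on, then normalize with ImageNet statistics.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
```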
Citation
If you use this model in your research, please consider citing:
@article{saralajew2024robust,
title={A Robust Prototype-Based Network with Interpretable RBF Classifier Foundations},
author={Saralajew, Sascha and Rana, Ashish and Villmann, Thomas and Shaker, Ammar},
journal={arXiv preprint arXiv:2412.15499},
year={2024}
}
Acknowledgements
This implementation builds upon the following excellent repositories:
Further, the documentation in the above two repositories provides additional details on data pre-processing, data loaders, model architectures, and visualizations.
License
This project is released under the [MIT] license.
Contact
For any questions or feedback, please:
- Open an issue in the project GitHub repository
- Contact the corresponding author