A Robust Prototype-Based Network with Interpretable RBF Classifier Foundations

This repository contains the Deep Classification-by-Components (CBC) models for interpretable prototype-based image classification benchmarks, as described in the paper "A Robust Prototype-Based Network with Interpretable RBF Classifier Foundations".

Model Description

The CBC approach learns components (or prototypes) that make the model's decisions interpretable. It uses positive and negative reasoning to explain class predictions, i.e., both the presence and the absence of specific components can serve as evidence for a class.
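One way to make this reasoning concrete is the probabilistic rule from the original CBC formulation (Saralajew et al., 2019). Each component k has a detection probability d_k. For each class c, the component is assigned positive, negative or indefinite reasoning probabilities that sum to one. The class probability is the evidence from positive and negative reasoning, divided by the total weight of the reasoning that is not indefinite. The NumPy sketch below is illustrative only. The function name and the (3, K, C) array layout are assumptions, and the deep CBC head in this repository may parameterize the reasoning differently.

```python
import numpy as np

def cbc_class_probabilities(detections, reasoning):
    """Sketch of the CBC reasoning rule for a single input (hypothetical helper).

    detections: (K,) detection probabilities d_k in [0, 1], one per component.
    reasoning:  (3, K, C) positive, negative and indefinite reasoning
                probabilities per component/class pair, summing to 1 over axis 0.
    Returns:    (C,) class probabilities.
    """
    r_pos, r_neg, r_ind = reasoning
    d = detections[:, None]  # (K, 1), broadcast over classes
    # Positive reasoning rewards detected components;
    # negative reasoning rewards components that are absent.
    evidence = (r_pos * d + r_neg * (1.0 - d)).sum(axis=0)  # (C,)
    # Normalize by the total weight of components that take part in the reasoning.
    relevance = (1.0 - r_ind).sum(axis=0)  # (C,)
    return evidence / relevance

# Two components, two classes: class 0 needs component 0 present and component 1 absent.
reasoning = np.zeros((3, 2, 2))
reasoning[0, 0, 0] = 1.0  # positive: component 0 -> class 0
reasoning[1, 1, 0] = 1.0  # negative: component 1 -> class 0
reasoning[1, 0, 1] = 1.0  # negative: component 0 -> class 1
reasoning[0, 1, 1] = 1.0  # positive: component 1 -> class 1
print(cbc_class_probabilities(np.array([1.0, 0.0]), reasoning))  # [1. 0.]
```

In this toy setup, the pattern "component 0 detected, component 1 absent" fully supports class 0. Class 0 gains positive evidence from component 0 and negative-reasoning evidence from the missing component 1.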

The deep_cbc package provides training, evaluation and visualization scripts for deep CBC models that use CNN architectures as feature-extractor backbones. A CBC head restricted to positive reasoning is equivalent to an RBF classification head. The package also supports the PIPNet classification head.
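Suppose the CBC head uses only positive reasoning. Then the negative reasoning probabilities are zero, and the remaining mass of each (component, class) pair is indefinite. In that case the CBC normalizer is just the sum of the positive reasoning weights. The class score therefore becomes a normalized weighted sum of kernel detections, which is an RBF network. The sketch below assumes a Gaussian kernel on squared Euclidean distances and a single pooled feature vector. The kernel choice, the sigma parameter and the name rbf_head are illustrative assumptions, not the deep_cbc implementation.

```python
import numpy as np

def rbf_head(features, prototypes, weights, sigma=1.0):
    """Normalized RBF head = CBC with positive reasoning only (hypothetical sketch).

    features:   (D,) pooled feature vector from the backbone.
    prototypes: (K, D) learnt prototype vectors.
    weights:    (K, C) nonnegative component-to-class weights.
    """
    sq_dist = ((prototypes - features) ** 2).sum(axis=1)  # (K,)
    detections = np.exp(-sq_dist / (2.0 * sigma ** 2))    # Gaussian kernel, values in (0, 1]
    # Normalizing each class column to sum to 1 turns the weights into
    # positive-reasoning probabilities; the CBC denominator then equals 1.
    r_pos = weights / weights.sum(axis=0, keepdims=True)
    return detections @ r_pos  # (C,) weighted average of detections per class

prototypes = np.array([[0.0, 0.0], [10.0, 10.0]])
scores = rbf_head(np.array([0.0, 0.0]), prototypes, np.eye(2))
print(scores.round(3))  # [1. 0.]
```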

Available and Supported Architectures

We provide two CNN architectures for each of the CUB-200-2011, Stanford Cars and Oxford-IIIT Pet datasets:

  • ResNet50 w/ CBC Classification Head: Built on a ResNet50 backbone from PyTorch's model_zoo module, in both partially trained and fully trained variants.
  • ConvNeXt w/ CBC Classification Head: Built on a partially trained convnext_tiny backbone from torchvision.

Both architectures can also be trained with an RBF or a PIPNet classification head.

Performance

All models were trained and evaluated on the CUB-200-2011 (CUB), Stanford Cars (CARS) and Oxford-IIIT Pet (PETS) datasets. The table below reports top-1 classification accuracy on each dataset.

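| Model Version | Backbone      | CUB          | CARS         | PETS         |
|---------------|---------------|--------------|--------------|--------------|
| CBC-C         | convnext_tiny | 87.8 ± 0.1 % | 93.0 ± 0.1 % | 93.9 ± 0.1 % |
| CBC-R         | resnet50      | 83.3 ± 0.3 % | 92.7 ± 0.1 % | 90.1 ± 0.1 % |
| CBC-R Full    | resnet50      | 82.8 ± 0.3 % | 92.8 ± 0.1 % | 89.5 ± 0.2 % |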
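The following minimal sketch shows how top-1 accuracy is computed, and how several runs could be summarized as a mean ± spread. It assumes that ± denotes the sample standard deviation over independent training runs, which this card does not state. The helper names top1_accuracy and summarize_runs are hypothetical.

```python
import numpy as np

def top1_accuracy(logits, labels):
    """Fraction of samples whose highest-scoring class equals the label.

    logits: (N, C) class scores; labels: (N,) integer class indices.
    """
    return float((np.argmax(logits, axis=1) == labels).mean())

def summarize_runs(accuracies):
    """Mean and sample standard deviation, in percent, across runs (assumed meaning of ±)."""
    acc = np.asarray(accuracies, dtype=float) * 100.0
    return float(acc.mean()), float(acc.std(ddof=1))

logits = np.array([[2.0, 1.0], [0.0, 3.0], [5.0, 1.0]])
labels = np.array([0, 1, 1])
acc = top1_accuracy(logits, labels)  # 2 of 3 predictions correct
mean, std = summarize_runs([0.80, 0.82])
```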

Model Features

  • 🔍 Interpretable Decision Assistance: The model classifies by positive and negative reasoning over learnt components (or prototypes), giving interpretable insight into each decision.
  • 🎯 SotA Accuracy: Achieves SotA classification accuracy on the CUB-200-2011, Stanford Cars and Oxford-IIIT Pet interpretability benchmarks.
  • 🚀 Multiple Feature Extractor CNN Backbones: Supports ConvNeXt and ResNet50 feature-extractor backbones with CBC heads for interpretable image classification.
  • 📊 Visualization and Analysis Tools: Includes tools to plot learnt prototype patches and their activation maps, together with similarity scores and detection probabilities.
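The similarity scores, detection probabilities and activation maps listed above can be illustrated with a minimal sketch. It compares every spatial position of a backbone feature map with each prototype and keeps the resulting similarity map for plotting. It then max-pools that map to a per-prototype detection probability, and records where the best-matching patch lies. The Gaussian kernel, the max-pooling and the name prototype_activation_maps are assumptions for illustration. The visualization scripts in deep_cbc may compute these quantities differently, for example with another similarity measure or with upsampling to image resolution.

```python
import numpy as np

def prototype_activation_maps(feature_map, prototypes, sigma=1.0):
    """Hypothetical sketch of per-prototype similarity maps and detections.

    feature_map: (D, H, W) backbone output for one image.
    prototypes:  (K, D) learnt component vectors.
    Returns similarity maps (K, H, W), detection probabilities (K,),
    and the (row, col) of the best-matching location for each prototype.
    """
    d, h, w = feature_map.shape
    patches = feature_map.reshape(d, h * w).T  # (H*W, D), one vector per location
    sq_dist = ((patches[None, :, :] - prototypes[:, None, :]) ** 2).sum(axis=-1)  # (K, H*W)
    sim = np.exp(-sq_dist / (2.0 * sigma ** 2))  # Gaussian similarity in (0, 1]
    sim_maps = sim.reshape(-1, h, w)             # activation maps for plotting
    detections = sim.max(axis=1)                 # max-pool over spatial positions
    rows, cols = np.unravel_index(sim.argmax(axis=1), (h, w))
    locations = list(zip(rows.tolist(), cols.tolist()))
    return sim_maps, detections, locations
```

The location of the maximum is what a patch visualization would crop around. The full similarity map is what would be overlaid on the input image as a heatmap.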

Requirements

  • python = "^3.9"
  • numpy = "1.26.4"
  • matplotlib = "3.8.4"
  • scikit-learn = "1.4.2"
  • scipy = "1.13.0"
  • pillow = "10.3.0"
  • omegaconf = "2.3.0"
  • hydra-core = "1.3.2"
  • torch = "2.2.2"
  • torchvision = "0.17.2"
  • setuptools = "68.2.0"

The dependencies above are the minimum needed to use the models. Please refer to the GitHub repository for the full dependency list and the project setup instructions for running experiments with these models.

Limitations and Bias

  • Partial Interpretability Issue: The feature-extractor CNN backbone is not interpretable, which introduces an uninterpretable component into the model. We achieve SotA accuracy and show that the models provide high-quality positive and negative reasoning explanations. Even so, these methods can only be called partially interpretable, because not all learnt prototypes are human-interpretable.
  • Data Bias Issue: These models are trained on the CUB-200-2011, Stanford Cars and Oxford-IIIT Pet datasets, and the reported performance should not be expected to generalize to other domains.
  • Resolution Constraints Issue: The model backbones are pre-trained at a resolution of 224×224. The current data loaders accept images of other resolutions, but performance will be suboptimal because the receptive fields the networks learn are tied to a given resolution. One possible improvement on the Stanford Cars dataset is to standardize image sizes as a pre-processing step.
  • Location Misalignment Issue: CNN-based models are not immune to location misalignment under adversarial attacks. Because the feature extractor is a black box, the learnt prototype-based networks are also prone to such issues.

Citation

If you use this model in your research, please consider citing:

@article{saralajew2024robust,
  title={A Robust Prototype-Based Network with Interpretable RBF Classifier Foundations},
  author={Saralajew, Sascha and Rana, Ashish and Villmann, Thomas and Shaker, Ammar},
  journal={arXiv preprint arXiv:2412.15499},
  year={2024}
}

Acknowledgements

This implementation builds upon the following excellent repositories:

These two repositories also provide additional documentation on data pre-processing, data loaders, model architectures and visualizations.

License

This project is released under the MIT license.

Contact

For any questions or feedback, please:

  1. Open an issue in the project GitHub repository
  2. Contact the corresponding author