---
license: apache-2.0
language:
- en
metrics:
- accuracy
base_model:
- microsoft/resnet-50
- timm/vgg19.tv_in1k
- google/vit-base-patch16-224
- xai-org/grok-1
pipeline_tag: image-classification
tags:
- Ocular-Toxoplasmosis(FundusImages)
- Retinal-images(Diabetics,Cataract,Gulocoma,Healthy)
- Pytorch
- Transformers
- Image-Classification
- Image_feature_extraction
- Grad-CAM
- XAI-Visualization
---

# Model Card: ROYXAI [Vision Transformer + VGG19 + ResNet50 Ensemble with Grad-CAM]

## Model Description

This model is an ensemble of three deep learning architectures: **Vision Transformer (ViT), VGG19, and ResNet50**. The ensemble approach enhances classification performance on medical image datasets related to ocular diseases. The model also integrates **Grad-CAM** visualization to highlight regions of interest for better interpretability.

## Model Details

- **Model Name**: ROYXAI
- **Developed by**: Avishek Roy Sparsho
- **Framework**: PyTorch
- **Ensemble Method**: Bagging
- **Backbone Models**: Vision Transformer, VGG19, ResNet50
- **Target Task**: Medical Image Classification
- **Supported Classes**:
  - OT
  - Healthy
  - SC_diabetes
  - SC_cataract
  - SC_glucoma

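The card lists the ensemble method as bagging but does not say how the three members' outputs are combined. One common choice is soft voting: average the per-class probabilities from each member and take the highest-scoring class. The sketch below illustrates that arithmetic in plain Python. The function `soft_vote`, the example probabilities, and the index order of `CLASS_NAMES` are illustrative assumptions, not the repository's actual code.

```python
from typing import List, Sequence, Tuple

# Class names as listed in this card; the index order is an assumption.
CLASS_NAMES = ["OT", "Healthy", "SC_diabetes", "SC_cataract", "SC_glucoma"]


def soft_vote(member_probs: Sequence[Sequence[float]]) -> Tuple[int, List[float]]:
    """Average per-class probabilities across members and pick the top class."""
    if not member_probs:
        raise ValueError("need at least one member")
    n_classes = len(member_probs[0])
    if any(len(p) != n_classes for p in member_probs):
        raise ValueError("all members must score the same number of classes")
    n_members = len(member_probs)
    averaged = [sum(p[c] for p in member_probs) / n_members for c in range(n_classes)]
    best = max(range(n_classes), key=averaged.__getitem__)
    return best, averaged


# Hypothetical outputs from ViT, VGG19, and ResNet50 for one image.
vit_probs = [0.70, 0.10, 0.10, 0.05, 0.05]
vgg_probs = [0.40, 0.45, 0.05, 0.05, 0.05]
resnet_probs = [0.60, 0.20, 0.10, 0.05, 0.05]

index, averaged = soft_vote([vit_probs, vgg_probs, resnet_probs])
print(CLASS_NAMES[index], round(averaged[index], 3))
```

In this toy case VGG19 alone would pick a different class, but the averaged scores follow the two members that agree, which is the behavior an ensemble is meant to provide.
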
## Dataset

- **Dataset Name**: Custom Ocular Disease and its Secondary complications Dataset
- **Dataset Source**: Private Dataset (Medical Images)
- **Dataset Structure**: Images stored in folders based on class labels
- **Preprocessing**:
  - Resized images to 224x224 pixels
  - Normalized using ImageNet mean and standard deviation

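Because the images are organized one folder per class, labels can be derived from directory names. The sketch below indexes such a tree with the standard library, assigning label indices by sorted folder name (the same convention `torchvision.datasets.ImageFolder` uses). The root path `data/train`, the extension set, and the helper name `index_image_folders` are hypothetical.

```python
from pathlib import Path
from typing import Dict, List, Tuple

IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}  # assumed; adjust to the dataset


def index_image_folders(root: Path) -> Tuple[List[Tuple[Path, int]], Dict[str, int]]:
    """Return (image path, label index) pairs and the class-to-index mapping."""
    class_names = sorted(p.name for p in root.iterdir() if p.is_dir())
    class_to_idx = {name: i for i, name in enumerate(class_names)}
    samples = [
        (path, class_to_idx[name])
        for name in class_names
        for path in sorted((root / name).iterdir())
        if path.suffix.lower() in IMAGE_EXTENSIONS
    ]
    return samples, class_to_idx


if __name__ == "__main__":
    root = Path("data/train")  # hypothetical location of the class folders
    if root.is_dir():
        samples, class_to_idx = index_image_folders(root)
        print(class_to_idx)
        print(f"{len(samples)} images indexed")
```

With the class names in this card, sorted order would be Healthy, OT, SC_cataract, SC_diabetes, SC_glucoma, which differs from the order in which the classes are listed, so it is worth confirming the index mapping that was actually used in training.
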
## Model Performance

- **Accuracy**: 98% on the test dataset
- **Precision/Recall/F1-score**: Evaluated and optimized for medical diagnosis
- **Overfitting Prevention**: Data augmentation, dropout, and weight regularization

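The card reports accuracy and mentions precision, recall, and F1 without giving values. For reference, the sketch below computes per-class precision, recall, and F1 from true and predicted label indices using only the standard library. The function name `per_class_metrics` and the toy labels are illustrative; they are not results from this model.

```python
from typing import Dict, List, Sequence


def per_class_metrics(
    y_true: Sequence[int], y_pred: Sequence[int], n_classes: int
) -> List[Dict[str, float]]:
    """Compute one-vs-rest precision, recall, and F1 for each class index."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    results = []
    for c in range(n_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = (
            2 * precision * recall / (precision + recall)
            if precision + recall
            else 0.0
        )
        results.append({"precision": precision, "recall": recall, "f1": f1})
    return results


# Toy labels for three classes (not real evaluation data).
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
for class_index, scores in enumerate(per_class_metrics(y_true, y_pred, 3)):
    print(class_index, {k: round(v, 3) for k, v in scores.items()})
```

Per-class recall is worth reporting alongside accuracy, since a high overall accuracy can hide weak performance on a less frequent class.
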
## Installation and Usage

### Clone the Repository

```bash
git clone https://huggingface.co/Aviroy/ROYXAI
cd ROYXAI
```

### Install Dependencies

```bash
pip install -r requirements.txt
```

### Training the Model

To train the model from scratch, run:

```bash
python train.py --epochs 50 --batch_size 32
```

### Load Pretrained Model

To directly use the trained model:

```python
import torch
from PIL import Image
import torchvision.transforms as transforms

from model import ensemble_model  # Load the trained ensemble model

# Run on GPU when available; the model and the input must be on the same device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ensemble_model = ensemble_model.to(device)

# Define image transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load and preprocess an image
image_path = "path/to/image.jpg"
image = Image.open(image_path).convert('RGB')
image = transform(image).unsqueeze(0).to(device)  # add a batch dimension

# Perform inference
ensemble_model.eval()
with torch.no_grad():
    output = ensemble_model(image)
    predicted_class = torch.argmax(output, dim=1).item()

# Print classification result
print("Predicted Class:", predicted_class)
```

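The inference snippet prints an integer index. To report a class name and a confidence instead, the raw scores can be passed through a softmax and looked up in a label list. The sketch below does this in plain Python on a flat list of scores such as `output[0].tolist()`, assuming the model returns unnormalized logits. The index order of `CLASS_NAMES` is an assumption and must match the mapping used in training; `describe_prediction` is a hypothetical helper.

```python
import math
from typing import List, Sequence, Tuple

# Assumed index order; verify against the label mapping used in training.
CLASS_NAMES = ["OT", "Healthy", "SC_diabetes", "SC_cataract", "SC_glucoma"]


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a flat list of scores."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def describe_prediction(logits: Sequence[float]) -> Tuple[str, float]:
    """Return the top class name and its softmax probability."""
    if len(logits) != len(CLASS_NAMES):
        raise ValueError("expected one score per class")
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return CLASS_NAMES[best], probs[best]


# Hypothetical logits for one image, e.g. obtained via output[0].tolist().
label, confidence = describe_prediction([2.1, 0.3, -1.0, 0.2, -0.5])
print(f"Predicted class: {label} ({confidence:.1%})")
```

If the ensemble already returns probabilities rather than logits, the softmax step should be skipped; the top class is the same either way, but the reported confidence would otherwise be distorted.
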
## Grad-CAM Visualization

### Visualizing Attention Maps for Interpretability

The snippets in this section reuse `ensemble_model`, `image`, and `predicted_class` from the loading example above.

#### Vision Transformer (ViT)

```python
import matplotlib.pyplot as plt

from visualization import visualize_gradcam_vit  # Function for ViT Grad-CAM

# Generate Grad-CAM visualization (models[0] is the ViT member)
overlay = visualize_gradcam_vit(ensemble_model.models[0], image, target_class=predicted_class)

# Display the Grad-CAM output
plt.imshow(overlay)
plt.axis('off')
plt.title("Grad-CAM for Vision Transformer")
plt.show()
```

#### ResNet50

```python
import matplotlib.pyplot as plt

from visualization import visualize_gradcam  # General Grad-CAM function

# Generate Grad-CAM visualization for ResNet50 (models[2])
overlay = visualize_gradcam(ensemble_model.models[2], image, target_class=predicted_class)

# Display the Grad-CAM output
plt.imshow(overlay)
plt.axis('off')
plt.title("Grad-CAM for ResNet50")
plt.show()
```

#### VGG19

```python
import matplotlib.pyplot as plt

from visualization import visualize_gradcam  # General Grad-CAM function

# Generate Grad-CAM visualization for VGG19 (models[1])
overlay = visualize_gradcam(ensemble_model.models[1], image, target_class=predicted_class)

# Display the Grad-CAM output
plt.imshow(overlay)
plt.axis('off')
plt.title("Grad-CAM for VGG19")
plt.show()
```

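The `visualize_gradcam` helpers themselves are not shown in this card. For orientation, the core Grad-CAM computation for a convolutional layer is: average each channel's gradient over the spatial dimensions to get a channel weight, take the weighted sum of the activation maps, apply ReLU, and normalize. The sketch below performs that arithmetic on small nested lists. It is not the repository's implementation, and the activations and gradients are made-up numbers.

```python
from typing import List

Map2D = List[List[float]]


def grad_cam(activations: List[Map2D], gradients: List[Map2D]) -> Map2D:
    """Compute a normalized Grad-CAM heatmap from per-channel maps.

    activations[k][i][j]: feature map k of the target layer.
    gradients[k][i][j]: gradient of the class score w.r.t. that activation.
    """
    height, width = len(activations[0]), len(activations[0][0])
    # Channel weights: global average of each channel's gradient.
    weights = [sum(map(sum, g)) / (height * width) for g in gradients]
    # Weighted sum over channels, followed by ReLU.
    cam = [
        [
            max(0.0, sum(w * a[i][j] for w, a in zip(weights, activations)))
            for j in range(width)
        ]
        for i in range(height)
    ]
    # Scale to [0, 1] so the map can be overlaid on the image.
    peak = max(max(row) for row in cam)
    return [[v / peak for v in row] for row in cam] if peak > 0 else cam


# Two 2x2 channels with made-up values.
activations = [[[1.0, 0.0], [0.0, 2.0]], [[0.0, 3.0], [1.0, 0.0]]]
gradients = [[[0.5, 0.5], [0.5, 0.5]], [[-1.0, -1.0], [-1.0, -1.0]]]
heatmap = grad_cam(activations, gradients)
print(heatmap)
```

In practice the heatmap is upsampled to the 224x224 input size before being blended with the image. For the ViT member, the patch tokens would first need to be arranged into a 2D grid, which is presumably why a separate `visualize_gradcam_vit` helper exists.
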
## Training Configuration

- **Optimizer**: Adam with weight decay
- **Learning Rate Scheduler**: Cosine Annealing LR
- **Loss Function**: Cross-Entropy Loss
- **Batch Size**: 32
- **Training Epochs**: 20
- **Hardware Used**: 2x NVIDIA T4 GPUs, Apple M1 chip, NVIDIA P100 GPU

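Cosine annealing lowers the learning rate from its initial value toward a minimum along half a cosine wave. The sketch below evaluates the closed-form schedule (the one PyTorch's `CosineAnnealingLR` follows when stepped once per epoch) over the 20 epochs listed above. The initial rate `1e-4` and `eta_min=0` are assumptions, since the card does not state them.

```python
import math


def cosine_annealing_lr(
    epoch: int, base_lr: float, t_max: int, eta_min: float = 0.0
) -> float:
    """Learning rate at `epoch` under cosine annealing over `t_max` epochs."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


BASE_LR = 1e-4  # assumed; not stated in this card
EPOCHS = 20     # from the training configuration above

schedule = [cosine_annealing_lr(e, BASE_LR, EPOCHS) for e in range(EPOCHS + 1)]
for epoch in (0, 5, 10, 15, 20):
    print(f"epoch {epoch:2d}: lr = {schedule[epoch]:.2e}")
```

The rate starts at the base value, reaches half of it at the midpoint, and approaches `eta_min` at the final epoch, so most of the fine adjustment happens late in training.
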
## Limitations & Considerations

- This model is trained on a specific dataset and may not generalize well to other medical image datasets without fine-tuning.
- It is **not a substitute for professional medical diagnosis**.
- The Vision Transformer model is computationally expensive compared to CNNs.

## Citation

If you use this model in your research, please cite:

```
@article{Sparsho2025,
  author  = {Avishek Roy Sparsho},
  title   = {ROYXAI Model For Proper Visualization of Classified Medical Image},
  journal = {Medical AI Research},
  year    = {2025}
}
```

## Acknowledgments

Special thanks to the open-source community and Kaggle for providing medical datasets for deep learning research.

## License

This model is released under the **Apache 2.0 License**. Use it responsibly.