---
license: apache-2.0
language:
- en
metrics:
- accuracy
base_model:
- microsoft/resnet-50
- timm/vgg19.tv_in1k
- google/vit-base-patch16-224
pipeline_tag: image-classification
tags:
- Ocular-Toxoplasmosis(FundusImages)
- Retinal-images(Diabetes,Cataract,Glaucoma,Healthy)
- Pytorch
- Transformers
- Image-Classification
- Image_feature_extraction
- Grad-CAM
- XAI-Visualization
---
# Model Card: ROYXAI [Vision Transformer + VGG19 + ResNet50 Ensemble with Grad-CAM]
## Model Description
This model is an ensemble of three deep learning architectures: **Vision Transformer (ViT), VGG19, and ResNet50**. The ensemble approach enhances classification performance on medical image datasets related to ocular diseases. The model also integrates **Grad-CAM** visualization to highlight regions of interest for better interpretability.
## Model Details
- **Model Name**: ROYXAI
- **Developed by**: Avishek Roy Sparsho
- **Framework**: PyTorch
- **Ensemble Method**: Bagging
- **Backbone Models**: Vision Transformer, VGG19, ResNet50
- **Target Task**: Medical Image Classification
- **Supported Classes**:
  - OT
  - Healthy
  - SC_diabetes
  - SC_cataract
  - SC_glucoma
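
The card does not state how the three backbones' predictions are combined. One common choice is to average the per-model softmax probabilities, and the sketch below assumes that rule. `AveragingEnsemble` is a hypothetical name, and the small linear layers stand in for the real ViT, VGG19, and ResNet50 backbones. The `models` attribute mirrors the `ensemble_model.models[i]` indexing used in the Grad-CAM examples in this card.

```python
import torch
import torch.nn as nn


class AveragingEnsemble(nn.Module):
    """Hypothetical ensemble: averages the softmax outputs of its members."""

    def __init__(self, models):
        super().__init__()
        # ModuleList registers each member so .to() and .eval() reach them.
        self.models = nn.ModuleList(models)

    def forward(self, x):
        # Each member returns logits of shape (batch, num_classes).
        probs = [torch.softmax(m(x), dim=1) for m in self.models]
        # Stack to (num_models, batch, num_classes), then average over models.
        return torch.stack(probs, dim=0).mean(dim=0)


# Stand-in members; the real ones would be ViT, VGG19, and ResNet50.
num_classes = 5
members = [
    nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, num_classes))
    for _ in range(3)
]
ensemble = AveragingEnsemble(members).eval()

with torch.no_grad():
    averaged = ensemble(torch.randn(2, 3, 8, 8))

print(averaged.shape)       # torch.Size([2, 5])
print(averaged.sum(dim=1))  # each row sums to 1, up to float error
```

Because the output is a probability distribution per image, `torch.argmax(output, dim=1)` still selects the predicted class.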
## Dataset
- **Dataset Name**: Custom Ocular Disease and its Secondary complications Dataset
- **Dataset Source**: Private Dataset (Medical Images)
- **Dataset Structure**: Images stored in folders based on class labels
- **Preprocessing**:
  - Resized images to 224x224 pixels
  - Normalized using ImageNet mean and standard deviation
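
If the folder layout described above is loaded with `torchvision.datasets.ImageFolder` (an assumption; the card does not name the loader), class indices follow the sorted order of the folder names. The sketch below reproduces that ordering in plain Python so a predicted index can be turned back into a label. Check the order against the actual training code before relying on it.

```python
# Class folder names as listed under "Supported Classes".
class_names = ["OT", "Healthy", "SC_diabetes", "SC_cataract", "SC_glucoma"]

# ImageFolder sorts folder names before assigning indices; mimic that here.
classes = sorted(class_names)
class_to_idx = {name: idx for idx, name in enumerate(classes)}
idx_to_class = {idx: name for name, idx in class_to_idx.items()}

print(class_to_idx)
# {'Healthy': 0, 'OT': 1, 'SC_cataract': 2, 'SC_diabetes': 3, 'SC_glucoma': 4}
```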
## Model Performance
- **Accuracy**: 98% on the test dataset
- **Precision/Recall/F1-score**: Evaluated on the test dataset; per-class values are not listed in this card
- **Overfitting Prevention**: Data augmentation, dropout, and weight regularization
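
When classes are imbalanced, accuracy alone can hide a weak class, so per-class precision, recall, and F1 are worth reporting alongside it. A minimal pure-Python sketch follows. `per_class_metrics` is a hypothetical helper, and the labels are made-up toy data, not results from this model.

```python
from collections import Counter


def per_class_metrics(y_true, y_pred):
    """Return {label: (precision, recall, f1)} for every label seen."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # class p was predicted, but wrongly
            fn[t] += 1  # true class t was missed
    metrics = {}
    for label in sorted(set(y_true) | set(y_pred)):
        predicted = tp[label] + fp[label]
        actual = tp[label] + fn[label]
        precision = tp[label] / predicted if predicted else 0.0
        recall = tp[label] / actual if actual else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        metrics[label] = (precision, recall, f1)
    return metrics


# Toy labels for illustration only.
y_true = ["OT", "OT", "Healthy", "Healthy", "SC_cataract"]
y_pred = ["OT", "Healthy", "Healthy", "Healthy", "SC_cataract"]
for label, (p, r, f1) in per_class_metrics(y_true, y_pred).items():
    print(f"{label}: precision={p:.2f} recall={r:.2f} f1={f1:.2f}")
```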
## Installation and Usage
### Clone the Repository
```bash
git clone https://huggingface.co/Aviroy/ROYXAI
cd ROYXAI
```
### Install Dependencies
```bash
pip install -r requirements.txt
```
### Training the Model
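To train the model from scratch, run the command below. The example passes `--epochs 50`, whereas the Training Configuration section lists 20 training epochs; adjust the flag to match the run you want to reproduce.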
```bash
python train.py --epochs 50 --batch_size 32
```
### Load Pretrained Model
To directly use the trained model:
```python
import torch
from PIL import Image
import torchvision.transforms as transforms
from model import ensemble_model  # Load the trained ensemble model

# Run on GPU when available; the model and the input must share a device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define image transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load and preprocess an image
image_path = "path/to/image.jpg"
image = Image.open(image_path).convert('RGB')
image = transform(image).unsqueeze(0).to(device)  # add a batch dimension

# Perform inference
ensemble_model = ensemble_model.to(device)
ensemble_model.eval()
with torch.no_grad():
    output = ensemble_model(image)
    predicted_class = torch.argmax(output, dim=1).item()

# Print classification result (an integer class index)
print("Predicted Class:", predicted_class)
```
## Grad-CAM Visualization
### Visualizing Attention Maps for Interpretability
#### Vision Transformer (ViT)
```python
from visualization import visualize_gradcam_vit # Function for ViT Grad-CAM
# Generate Grad-CAM visualization
overlay = visualize_gradcam_vit(ensemble_model.models[0], image, target_class=predicted_class)
# Display the Grad-CAM output
import matplotlib.pyplot as plt
plt.imshow(overlay)
plt.axis('off')
plt.title("Grad-CAM for Vision Transformer")
plt.show()
```
#### ResNet50
```python
from visualization import visualize_gradcam # General Grad-CAM function
# Generate Grad-CAM visualization for ResNet50
overlay = visualize_gradcam(ensemble_model.models[2], image, target_class=predicted_class)
# Display the Grad-CAM output
import matplotlib.pyplot as plt
plt.imshow(overlay)
plt.axis('off')
plt.title("Grad-CAM for ResNet50")
plt.show()
```
#### VGG19
```python
from visualization import visualize_gradcam # General Grad-CAM function
# Generate Grad-CAM visualization for VGG19
overlay = visualize_gradcam(ensemble_model.models[1], image, target_class=predicted_class)
# Display the Grad-CAM output
import matplotlib.pyplot as plt
plt.imshow(overlay)
plt.axis('off')
plt.title("Grad-CAM for VGG19")
plt.show()
```
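
The `visualize_gradcam` helpers live in the repository's `visualization` module and are not reproduced in this card. For readers unfamiliar with the technique, the following is a minimal sketch of Grad-CAM for a CNN, not the repository's implementation. `TinyCNN` and `grad_cam` are hypothetical names, and the toy network stands in for VGG19 or ResNet50.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyCNN(nn.Module):
    """Toy stand-in for a CNN backbone (hypothetical)."""

    def __init__(self, num_classes=5):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(8, 16, kernel_size=3, padding=1), nn.ReLU(),
        )
        self.head = nn.Linear(16, num_classes)

    def forward(self, x):
        feats = self.features(x)                  # (B, 16, H/2, W/2)
        return self.head(feats.mean(dim=(2, 3)))  # global average pool


def grad_cam(model, target_layer, image, target_class):
    """Return an (H, W) heatmap in [0, 1] for a single image."""
    captured = {}

    def save_activation(_module, _inputs, output):
        captured["act"] = output

    handle = target_layer.register_forward_hook(save_activation)
    try:
        logits = model(image)
    finally:
        handle.remove()

    score = logits[0, target_class]
    # Gradient of the class score with respect to the captured feature maps.
    grads = torch.autograd.grad(score, captured["act"])[0]
    # Channel weights: gradients averaged over the spatial dimensions.
    weights = grads.mean(dim=(2, 3), keepdim=True)
    cam = F.relu((weights * captured["act"]).sum(dim=1, keepdim=True))
    # Upsample to the input resolution and rescale to [0, 1].
    cam = F.interpolate(cam, size=image.shape[-2:], mode="bilinear",
                        align_corners=False)
    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-8)
    return cam[0, 0].detach()


toy_model = TinyCNN().eval()
dummy_image = torch.randn(1, 3, 32, 32)
heatmap = grad_cam(toy_model, toy_model.features[-1], dummy_image, target_class=0)
print(heatmap.shape)  # torch.Size([32, 32])
```

The heatmap can be drawn over the input with `plt.imshow(heatmap, cmap="jet", alpha=0.4)`. A ViT has no convolutional feature maps, so a ViT variant first has to reshape the patch tokens into a 2D grid (14x14 for `patch16-224`), which is presumably why the card uses a separate `visualize_gradcam_vit` function.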
## Training Configuration
- **Optimizer**: Adam with weight decay
- **Learning Rate Scheduler**: Cosine Annealing LR
- **Loss Function**: Cross-Entropy Loss
- **Batch Size**: 32
- **Training Epochs**: 20
- **Hardware Used**: 2× NVIDIA T4 GPUs, Apple M1 chip, NVIDIA P100 GPU
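
The settings above map onto PyTorch roughly as follows. This is a sketch, not the repository's `train.py`: the learning rate, the weight-decay value, the stand-in model, and the random batches are all assumptions chosen for illustration, since the card does not state them.

```python
import torch
import torch.nn as nn

EPOCHS = 20          # from the card
BATCH_SIZE = 32      # from the card
LR = 1e-4            # assumed; not stated in the card
WEIGHT_DECAY = 1e-4  # assumed; not stated in the card

# Stand-in model: 5 output classes, tiny input for a quick run.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 5))

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# Anneal the learning rate from LR towards 0 over the full run.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

lr_history = []
for epoch in range(EPOCHS):
    # One random batch stands in for a pass over the real DataLoader.
    images = torch.randn(BATCH_SIZE, 3, 8, 8)
    labels = torch.randint(0, 5, (BATCH_SIZE,))

    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.step()

    lr_history.append(optimizer.param_groups[0]["lr"])
    scheduler.step()  # step once per epoch, after the optimizer

print(f"first lr={lr_history[0]:.2e}, last lr={lr_history[-1]:.2e}")
```

Note that `torch.optim.Adam` applies `weight_decay` as an L2 penalty on the gradients. If decoupled weight decay was intended, `torch.optim.AdamW` accepts the same arguments.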
## Limitations & Considerations
- This model is trained on a specific dataset and may not generalize well to other medical image datasets without fine-tuning.
- It is **not a substitute for professional medical diagnosis**.
- The Vision Transformer model is computationally expensive compared to CNNs.
## Citation
If you use this model in your research, please cite:
```
@article{Sparsho2025,
  author = {Avishek Roy Sparsho},
  title = {ROYXAI Model For Proper Visualization of Classified Medical Image},
  journal = {Medical AI Research},
  year = {2025}
}
```
## Acknowledgments
Special thanks to the open-source community and Kaggle for providing medical datasets for deep learning research.
## License
This model is released under the **Apache 2.0 License**. Use it responsibly.