---
license: apache-2.0
language:
- en
metrics:
- accuracy
base_model:
- microsoft/resnet-50
- timm/vgg19.tv_in1k
- google/vit-base-patch16-224
- xai-org/grok-1
pipeline_tag: image-classification
tags:
- Ocular-Toxoplasmosis(FundusImages)
- Retinal-images(Diabetics,Cataract,Glaucoma,Healthy)
- Pytorch
- Transformers
- Image-Classification
- Image_feature_extraction
- Grad-CAM
- XAI-Visualization
---
# Model Card: ROYXAI [Vision Transformer + VGG19 + ResNet50 Ensemble with Grad-CAM]
## Model Description
This model is an ensemble of three deep learning architectures: **Vision Transformer (ViT), VGG19, and ResNet50**. The ensemble approach enhances classification performance on medical image datasets related to ocular diseases. The model also integrates **Grad-CAM** visualization to highlight regions of interest for better interpretability.
## Model Details
- **Model Name**: ROYXAI
- **Developed by**: Avishek Roy Sparsho
- **Framework**: PyTorch
- **Ensemble Method**: Bagging
- **Backbone Models**: Vision Transformer, VGG19, ResNet50
- **Target Task**: Medical Image Classification
- **Supported Classes**:
  - OT
  - Healthy
  - SC_diabetes
  - SC_cataract
  - SC_glucoma
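One way to picture the ensemble is as a wrapper that runs each backbone and combines their class probabilities. The sketch below is only an illustration: the `SoftVotingEnsemble` class, the averaging rule, and the tiny stand-in backbones are assumptions, not the contents of this repository's `model.py`. The `models` attribute mirrors the `ensemble_model.models[i]` indexing used in the Grad-CAM examples.

```python
import torch
import torch.nn as nn

NUM_CLASSES = 5  # OT, Healthy, SC_diabetes, SC_cataract, SC_glucoma


class SoftVotingEnsemble(nn.Module):
    """Hypothetical wrapper that averages the softmax outputs of its members."""

    def __init__(self, models):
        super().__init__()
        # ModuleList registers the members so .to(device) and .eval() reach them
        self.models = nn.ModuleList(models)

    def forward(self, x):
        probs = [torch.softmax(m(x), dim=1) for m in self.models]
        return torch.stack(probs, dim=0).mean(dim=0)


def tiny_backbone(num_classes=NUM_CLASSES):
    # Stand-in for ViT / VGG19 / ResNet50 so the sketch runs in milliseconds
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, num_classes),
    )


torch.manual_seed(0)
ensemble = SoftVotingEnsemble([tiny_backbone() for _ in range(3)])
ensemble.eval()
with torch.no_grad():
    probs = ensemble(torch.randn(2, 3, 224, 224))
print(probs.shape)  # one row of class probabilities per input image
```

Averaging probabilities rather than raw logits keeps each backbone's vote on the same scale, which matters when the members have very different architectures.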
## Dataset
- **Dataset Name**: Custom dataset of ocular disease and its secondary complications
- **Dataset Source**: Private Dataset (Medical Images)
- **Dataset Structure**: Images stored in folders based on class labels
- **Preprocessing**:
  - Resized images to 224x224 pixels
  - Normalized using ImageNet mean and standard deviation
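To make the normalization step concrete, the following sketch reproduces the arithmetic in plain NumPy: scale pixel values to [0, 1], subtract the per-channel ImageNet mean, and divide by the per-channel standard deviation. The `normalize_image` helper and the gray test image are illustrative assumptions; the mean and standard deviation values are the ones used in the inference example in this card.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize_image(rgb_uint8):
    """Scale an HxWx3 uint8 image to [0, 1], then standardize each channel."""
    scaled = rgb_uint8.astype(np.float32) / 255.0
    normalized = (scaled - IMAGENET_MEAN) / IMAGENET_STD
    # Reorder to channels-first (3, H, W), the layout PyTorch models expect
    return normalized.transpose(2, 0, 1)


# A mid-gray 224x224 image stands in for an already-resized fundus photo
gray = np.full((224, 224, 3), 128, dtype=np.uint8)
tensor_like = normalize_image(gray)
print(tensor_like.shape)
print(tensor_like[:, 0, 0])  # one normalized value per channel
```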
## Model Performance
- **Accuracy**: 98% on the test dataset
- **Precision/Recall/F1-score**: Evaluated alongside accuracy; per-class values are not reported in this card
- **Overfitting Prevention**: Data augmentation, dropout, and weight regularization
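For readers who want to compute the per-class metrics themselves, here is a minimal sketch in plain Python. The `per_class_metrics` helper is hypothetical and the labels are made up for illustration; they are not this model's predictions or results.

```python
NUM_CLASSES = 5  # number of supported classes listed in this card


def per_class_metrics(y_true, y_pred, num_classes):
    """Return {class_index: (precision, recall, f1)} from two label lists."""
    metrics = {}
    for c in range(num_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        metrics[c] = (precision, recall, f1)
    return metrics


# Made-up labels for illustration only
y_true = [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
y_pred = [0, 0, 1, 0, 2, 2, 3, 3, 4, 2]

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
report = per_class_metrics(y_true, y_pred, NUM_CLASSES)
for idx, (p, r, f1) in report.items():
    print(f"class {idx}: precision={p:.2f} recall={r:.2f} f1={f1:.2f}")
print(f"accuracy={accuracy:.2f}")
```

In a medical setting, per-class recall is usually worth reporting next to accuracy, because a high overall accuracy can hide a class that is frequently missed.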
## Installation and Usage
### Clone the Repository
```bash
git clone https://huggingface.co/Aviroy/ROYXAI
cd ROYXAI
```
### Install Dependencies
```bash
pip install -r requirements.txt
```
### Training the Model
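To train the model from scratch, run the command below. It sets 50 epochs, whereas the Training Configuration section reports 20; adjust `--epochs` to match the run you want to reproduce.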
```bash
python train.py --epochs 50 --batch_size 32
```
### Load Pretrained Model
To directly use the trained model:
```python
import torch
from PIL import Image
import torchvision.transforms as transforms

from model import ensemble_model  # Trained ensemble model defined in this repository

# Select the device once and keep the model and the input on it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define image transformations (same preprocessing as described in the Dataset section)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load and preprocess an image
image_path = "path/to/image.jpg"
image = Image.open(image_path).convert('RGB')
image = transform(image).unsqueeze(0).to(device)  # Add a batch dimension

# Perform inference
ensemble_model = ensemble_model.to(device)
ensemble_model.eval()
with torch.no_grad():
    output = ensemble_model(image)
    predicted_class = torch.argmax(output, dim=1).item()

# Print classification result (an integer class index)
print("Predicted Class:", predicted_class)
```
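The inference snippet prints an integer index rather than a class name. The sketch below shows one way to map it back to a label. It assumes the labels were indexed the way torchvision's `ImageFolder` does it, that is, by sorting the folder names as strings; this card does not state the actual order, so treat `IDX_TO_CLASS` and `index_to_label` as hypothetical and verify them against the training code.

```python
# Class names as listed under Supported Classes
CLASS_NAMES = ["OT", "Healthy", "SC_diabetes", "SC_cataract", "SC_glucoma"]

# Assumption: indices follow the sorted folder names, as ImageFolder would assign them
IDX_TO_CLASS = dict(enumerate(sorted(CLASS_NAMES)))


def index_to_label(predicted_class):
    """Map an integer prediction to a class name, failing loudly if unknown."""
    if predicted_class not in IDX_TO_CLASS:
        raise ValueError(f"Unexpected class index: {predicted_class}")
    return IDX_TO_CLASS[predicted_class]


for idx in range(len(CLASS_NAMES)):
    print(idx, index_to_label(idx))
```

If the training code exposes the real mapping (for example `class_to_idx` on an `ImageFolder` dataset), prefer that over this reconstruction.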
## Grad-CAM Visualization
### Visualizing Attention Maps for Interpretability
#### Vision Transformer (ViT)
```python
import matplotlib.pyplot as plt
from visualization import visualize_gradcam_vit  # Function for ViT Grad-CAM

# Generate Grad-CAM visualization (call it outside torch.no_grad(), since Grad-CAM needs gradients)
overlay = visualize_gradcam_vit(ensemble_model.models[0], image, target_class=predicted_class)

# Display the Grad-CAM output
plt.imshow(overlay)
plt.axis('off')
plt.title("Grad-CAM for Vision Transformer")
plt.show()
```
#### ResNet50
```python
import matplotlib.pyplot as plt
from visualization import visualize_gradcam  # General Grad-CAM function

# Generate Grad-CAM visualization for ResNet50
overlay = visualize_gradcam(ensemble_model.models[2], image, target_class=predicted_class)

# Display the Grad-CAM output
plt.imshow(overlay)
plt.axis('off')
plt.title("Grad-CAM for ResNet50")
plt.show()
```
#### VGG19
```python
import matplotlib.pyplot as plt
from visualization import visualize_gradcam  # General Grad-CAM function

# Generate Grad-CAM visualization for VGG19
overlay = visualize_gradcam(ensemble_model.models[1], image, target_class=predicted_class)

# Display the Grad-CAM output
plt.imshow(overlay)
plt.axis('off')
plt.title("Grad-CAM for VGG19")
plt.show()
```
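The `visualization` module is not reproduced in this card, so the following is a minimal sketch of what a Grad-CAM routine for a CNN typically computes: capture the activations of a late convolutional block, take the gradient of the target class score with respect to them, average that gradient per channel to get weights, and form a ReLU-ed weighted sum. The `grad_cam` function and the tiny CNN are assumptions for illustration, not the repository's implementation. Transformer backbones need an extra step to reshape patch tokens into a 2D grid, which is omitted here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def grad_cam(model, target_layer, image, target_class):
    """Return an HxW heatmap in [0, 1] for `target_class` (hypothetical helper)."""
    captured = {}
    handle = target_layer.register_forward_hook(
        lambda module, inputs, output: captured.update(acts=output))
    try:
        logits = model(image)  # must run with gradients enabled
    finally:
        handle.remove()

    acts = captured["acts"]  # shape (1, C, h, w)
    grads = torch.autograd.grad(logits[0, target_class], acts)[0]
    weights = grads.mean(dim=(2, 3), keepdim=True)  # per-channel importance
    cam = F.relu((weights * acts).sum(dim=1, keepdim=True)).detach()
    cam = F.interpolate(cam, size=tuple(image.shape[-2:]),
                        mode="bilinear", align_corners=False)[0, 0]
    # Rescale to [0, 1]; the epsilon guards against a constant map
    return (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)


torch.manual_seed(0)
cnn = nn.Sequential(  # tiny stand-in for a CNN backbone
    nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(8, 16, kernel_size=3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 5),
).eval()

image = torch.randn(1, 3, 64, 64)
heatmap = grad_cam(cnn, cnn[4], image, target_class=2)  # cnn[4]: last ReLU
print(heatmap.shape)
```

The resulting heatmap can be blended with the input image (for example with `plt.imshow(heatmap, alpha=0.5)` on top of the original) to produce an overlay similar to the ones shown above.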
## Training Configuration
- **Optimizer**: Adam with weight decay
- **Learning Rate Scheduler**: Cosine Annealing LR
- **Loss Function**: Cross-Entropy Loss
- **Batch Size**: 32
- **Training Epochs**: 20
- **Hardware Used**: 2x NVIDIA T4 GPU, Apple M1 chip, NVIDIA P100 GPU
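The configuration above could be wired up in PyTorch roughly as follows. This is a sketch, not the repository's `train.py`: the learning rate and weight-decay values are assumptions (the card does not state them), the linear model and random batches are stand-ins, and "Adam with weight decay" is read here as `torch.optim.Adam` with its `weight_decay` argument, although `AdamW` would be another reasonable reading.

```python
import torch
import torch.nn as nn

EPOCHS = 20            # from the card
BATCH_SIZE = 32        # from the card
LEARNING_RATE = 1e-4   # assumption: not stated in the card
WEIGHT_DECAY = 1e-4    # assumption: not stated in the card

torch.manual_seed(0)
model = nn.Linear(16, 5)  # stand-in for the ensemble, 5 output classes
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE,
                             weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

for epoch in range(EPOCHS):
    # One dummy batch per epoch; a real loop would iterate over a DataLoader
    inputs = torch.randn(BATCH_SIZE, 16)
    targets = torch.randint(0, 5, (BATCH_SIZE,))

    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
    scheduler.step()  # cosine schedule advances once per epoch

final_lr = scheduler.get_last_lr()[0]
print(f"final loss={loss.item():.4f}, final lr={final_lr:.2e}")
```

With `T_max` equal to the number of epochs, the learning rate decays along half a cosine period and reaches its minimum at the end of training.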
## Limitations & Considerations
- This model is trained on a specific dataset and may not generalize well to other medical image datasets without fine-tuning.
- It is **not a substitute for professional medical diagnosis**.
- The Vision Transformer model is computationally expensive compared to CNNs.
## Citation
If you use this model in your research, please cite:
```bibtex
@article{Sparsho2025,
  author  = {Avishek Roy Sparsho},
  title   = {ROYXAI Model For Proper Visualization of Classified Medical Image},
  journal = {Medical AI Research},
  year    = {2025}
}
```
## Acknowledgments
Special thanks to the open-source community and Kaggle for providing medical datasets for deep learning research.
## License
This model is released under the **Apache 2.0 License**. Use it responsibly.