MedMNIST Active Learning Model
This model is designed for image classification tasks within the medical imaging domain, specifically targeting the MedMNIST dataset. It employs a ResNet-50 architecture tailored for 28x28 pixel images and incorporates active learning strategies to enhance performance with limited labeled data.
Model Architecture
- Base Model: ResNet-50
- Modifications:
  - Adjusted the initial convolution layer to accommodate 28x28 input images.
  - Removed the max pooling layer to preserve spatial dimensions.
  - Customized the fully connected layer to output predictions for 9 classes.
Training Procedure
Training Hyperparameters
| Hyperparameter | Value |
|---|---|
| Batch Size | 53 |
| Initial Labeled Size | 3559 |
| Learning Rate | 0.01332344940133225 |
| MC Dropout Passes | 6 |
| Samples to Label | 4430 |
| Weight Decay | 0.00021921795989143406 |
Optimizer Settings
The optimizer used during training was stochastic gradient descent (SGD) with the following settings, combined with a ReduceLROnPlateau learning rate scheduler:
- learning_rate = 0.01332344940133225
- momentum = 0.9
- weight_decay = 0.00021921795989143406
The model was trained with float32 precision.
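The following sketch shows the optimizer and scheduler configuration in PyTorch, using the values listed above. The card does not state the ReduceLROnPlateau settings, so three of them are assumptions: `mode="min"` on validation loss, `factor=0.1`, and `patience=2`. The helper name `make_optimizer_and_scheduler` is also hypothetical, and a small `nn.Linear` stands in for the real network.

```python
import torch
import torch.nn as nn

# Values from the optimizer settings above.
LEARNING_RATE = 0.01332344940133225
MOMENTUM = 0.9
WEIGHT_DECAY = 0.00021921795989143406


def make_optimizer_and_scheduler(model: nn.Module, patience: int = 2,
                                 factor: float = 0.1):
    """SGD with momentum and weight decay plus ReduceLROnPlateau (scheduler settings assumed)."""
    optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE,
                                momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
    # Multiply the LR by `factor` once the monitored value (validation loss, mode="min")
    # has failed to improve for more than `patience` consecutive step() calls.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=factor, patience=patience)
    return optimizer, scheduler


model = nn.Linear(10, 9)  # stand-in for the ResNet; parameters are float32 by default
optimizer, scheduler = make_optimizer_and_scheduler(model)
for val_loss in [1.0, 1.0, 1.0, 1.0]:  # a validation loss that never improves
    scheduler.step(val_loss)           # call once after each validation pass
print(optimizer.param_groups[0]["lr"])
```

With a flat validation loss, the fourth `scheduler.step()` call exceeds the assumed patience of 2. At that point the scheduler cuts the learning rate by a factor of 10.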
Data Augmentation
- Random resized cropping
- Horizontal flipping
- Random rotations
- Color jittering
- Gaussian blur
- RandAugment
Active Learning Strategy
The active learning process was based on a mixed sampling strategy:
- Uncertainty Sampling: Monte Carlo (MC) dropout was used to estimate uncertainty.
- Diversity Sampling: K-means clustering was employed to ensure diverse samples.
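The mixed strategy can be sketched in two stages. First, score the unlabeled pool by predictive entropy averaged over 6 MC dropout passes, matching the table above. Second, run K-means on the most uncertain candidates and take the most uncertain sample from each cluster. This is a hypothetical reconstruction, not the repository's code. The entropy score, the `candidate_factor` pre-filter, clustering on mean softmax outputs rather than learned embeddings, and all function names are assumptions. A tiny MLP with a dropout layer stands in for the ResNet.

```python
import numpy as np
import torch
import torch.nn as nn
from sklearn.cluster import KMeans


def enable_mc_dropout(model: nn.Module) -> None:
    """Put the model in eval mode (fixed BatchNorm) but keep dropout layers stochastic."""
    model.eval()
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()


@torch.no_grad()
def mc_dropout_probs(model: nn.Module, x: torch.Tensor, passes: int = 6) -> torch.Tensor:
    """Softmax outputs from `passes` stochastic forward passes, shape (passes, N, C)."""
    enable_mc_dropout(model)
    return torch.stack([torch.softmax(model(x), dim=1) for _ in range(passes)])


def predictive_entropy(probs: torch.Tensor) -> torch.Tensor:
    """Entropy of the MC-averaged class distribution, one score per sample."""
    mean = probs.mean(dim=0)
    return -(mean * torch.log(mean.clamp_min(1e-12))).sum(dim=1)


def select_uncertain_diverse(uncertainty: np.ndarray, features: np.ndarray,
                             n_select: int, candidate_factor: int = 5,
                             seed: int = 0) -> np.ndarray:
    """Pick up to `n_select` pool indices: top-uncertainty candidates, one per K-means cluster."""
    assert 0 < n_select <= len(uncertainty)
    n_candidates = min(len(uncertainty), n_select * candidate_factor)
    candidates = np.argsort(-uncertainty)[:n_candidates]  # most uncertain first
    kmeans = KMeans(n_clusters=n_select, n_init=10, random_state=seed)
    labels = kmeans.fit_predict(features[candidates])
    chosen = []
    for k in range(n_select):
        members = candidates[labels == k]
        if len(members) > 0:  # guard against a (rare) empty cluster
            chosen.append(members[np.argmax(uncertainty[members])])
    return np.array(chosen)


# Toy demo: random tensors play the role of the unlabeled pool.
torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 28 * 28, 64), nn.ReLU(),
                      nn.Dropout(p=0.5), nn.Linear(64, 9))
pool = torch.randn(200, 3, 28, 28)
probs = mc_dropout_probs(model, pool, passes=6)   # (6, 200, 9)
uncertainty = predictive_entropy(probs).numpy()
features = probs.mean(dim=0).numpy()              # mean softmax used as cluster features
picked = select_uncertain_diverse(uncertainty, features, n_select=10)
print(picked)
```

Pre-filtering by uncertainty before clustering keeps the selected batch informative. Taking one sample per cluster then avoids spending the labeling budget on many near-duplicate uncertain images.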
The model was evaluated on the validation split of PathMNIST, the 9-class colon pathology subset of MedMNIST. Key performance metrics are:
- Accuracy: 94.72%
- Loss: 0.2397
- AUC: 99.73%
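The three reported metrics can be computed from softmax outputs with scikit-learn, as sketched below. This assumes the loss is cross-entropy and the AUC is the macro-averaged one-vs-rest AUC over the classes; the card specifies neither. The `evaluate` helper and the 3-class toy data are hypothetical.

```python
import numpy as np
from sklearn.metrics import accuracy_score, log_loss, roc_auc_score


def evaluate(y_true: np.ndarray, probs: np.ndarray) -> dict:
    """Accuracy, cross-entropy loss and macro one-vs-rest AUC from class probabilities."""
    labels = np.arange(probs.shape[1])
    y_pred = probs.argmax(axis=1)
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "loss": log_loss(y_true, probs, labels=labels),  # mean -log p(true class)
        "auc": roc_auc_score(y_true, probs, multi_class="ovr",
                             average="macro", labels=labels),
    }


# Toy 3-class example (rows are softmax outputs that sum to 1).
y_true = np.array([0, 1, 2, 2, 1, 0])
probs = np.array([
    [0.8, 0.1, 0.1],
    [0.2, 0.7, 0.1],
    [0.1, 0.2, 0.7],
    [0.3, 0.3, 0.4],
    [0.5, 0.4, 0.1],  # misclassified as class 0
    [0.6, 0.3, 0.1],
])
metrics = evaluate(y_true, probs)
print(metrics)
```

On the toy data, one of six predictions is wrong, so accuracy is 5/6. Every class's positives still score above its negatives, so the AUC is 1.0.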
The following plots illustrate the validation loss, validation accuracy, and validation AUC over batches (number of iterations over the dataset) during the active learning process.
All code for this model is available in the following GitHub repository: Allen Cheung Determined_AI_Hackathon.
To use this model:
Install Dependencies: Install the required Python packages using pip:
```bash
pip install torch torchvision medmnist scikit-learn determined
```
Load the Model:
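```python
import torch

from model import ResNet50_28  # model definition from the project's GitHub repository

# Build the 9-class network and load the trained weights onto the CPU.
model = ResNet50_28(num_classes=9)
model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
model.eval()  # inference mode: disables dropout, uses running BatchNorm statistics
```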
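Run Inference:
```python
import torch
from PIL import Image
from torchvision import transforms

# Resize to the 28x28 resolution the model expects, then normalize.
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

image = Image.open("path_to_image.jpg").convert("RGB")  # PathMNIST images are RGB
input_tensor = transform(image).unsqueeze(0)  # add a batch dimension: (1, 3, 28, 28)

with torch.no_grad():
    output = model(input_tensor)
prediction = output.argmax(dim=1).item()
print(f"Predicted class: {prediction}")
```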
This project is licensed under the MIT License.
- MedMNIST Dataset
- Determined AI
- Survey on Deep Active Learning: Wang, H., Jin, Q., Li, S., Liu, S., Wang, M., & Song, Z. (2024). A comprehensive survey on deep active learning in medical image analysis. Medical Image Analysis, 95, 103201.