# MRI-Image / vit.py
import os
import torch
import torchvision
from torch import nn
from torchvision import transforms
from PIL import Image


class CFG:
    DEVICE = 'cpu'
    NUM_DEVICES = torch.cuda.device_count()
    NUM_WORKERS = os.cpu_count()
    NUM_CLASSES = 4
    EPOCHS = 16
    BATCH_SIZE = 32
    LR = 0.001
    APPLY_SHUFFLE = True
    SEED = 768
    HEIGHT = 224
    WIDTH = 224
    CHANNELS = 3
    IMAGE_SIZE = (224, 224, 3)


class VisionTransformerModel(nn.Module):
    def __init__(self, backbone_model, name='vision-transformer',
                 num_classes=CFG.NUM_CLASSES, device=CFG.DEVICE):
        super(VisionTransformerModel, self).__init__()
        self.backbone_model = backbone_model
        self.device = device
        self.num_classes = num_classes
        self.name = name
        # The torchvision ViT backbone outputs 1000 ImageNet logits,
        # hence in_features=1000 for the first layer of the head.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(in_features=1000, out_features=256, bias=True),
            nn.GELU(),
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(in_features=256, out_features=num_classes, bias=False)
        ).to(device)

    def forward(self, image):
        # Run the frozen backbone, then the trainable classification head
        vit_output = self.backbone_model(image)
        return self.classifier(vit_output)


def get_vit_l32_model(
        device: torch.device = CFG.DEVICE) -> nn.Module:
    """Load a pretrained ViT-L/32 backbone and freeze its parameters."""
    # Set the manual seeds for reproducibility
    torch.manual_seed(CFG.SEED)
    torch.cuda.manual_seed(CFG.SEED)

    # Get pretrained ImageNet weights
    model_weights = (
        torchvision
        .models
        .ViT_L_32_Weights
        .DEFAULT
    )

    # Get model and push to device
    model = (
        torchvision.models.vit_l_32(
            weights=model_weights
        )
    ).to(device)

    # Freeze model parameters so only the classifier head trains
    for param in model.parameters():
        param.requires_grad = False

    return model
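

# A small helper (a sketch, not part of the original file) to verify that the
# freeze above worked: it reports trainable vs. total parameter counts.
def count_trainable_params(model: nn.Module) -> tuple:
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total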


# Get the frozen ViT backbone
vit_backbone = get_vit_l32_model(CFG.DEVICE)

vit_params = {
    'backbone_model': vit_backbone,
    'name': 'ViT-L-32',
    'device': CFG.DEVICE
}

# Build the model and load the fine-tuned weights
vit_model = VisionTransformerModel(**vit_params)
vit_model.load_state_dict(
    torch.load('vit_model.pth', map_location=torch.device('cpu'))
)
vit_model.eval()  # disable dropout for deterministic inference

# Define the image transformation
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
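
# Note (assumption, not in the original): torchvision weights also ship their
# canonical preprocessing, e.g.
#   preprocess = torchvision.models.ViT_L_32_Weights.DEFAULT.transforms()
# which adds ImageNet normalization. The fine-tuned head here was presumably
# trained with the simpler Resize + ToTensor pipeline above, so that is kept.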


def predict(image_path):
    # Convert to RGB: MRI slices are often grayscale, but the backbone
    # expects CFG.CHANNELS (3) input channels
    image = Image.open(image_path).convert('RGB')
    input_tensor = transform(image)
    input_batch = input_tensor.unsqueeze(0).to(CFG.DEVICE)  # add batch dimension

    # Perform inference without tracking gradients
    with torch.no_grad():
        output = vit_model(input_batch)

    # Return raw logits; apply softmax/argmax downstream as needed
    return output
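

# Usage sketch (assumptions: 'sample_mri.jpg' is a hypothetical path; the
# index-to-label mapping is application-specific and not defined here).
if __name__ == '__main__':
    logits = predict('sample_mri.jpg')  # shape: (1, CFG.NUM_CLASSES)
    probs = torch.softmax(logits, dim=1)
    pred_idx = int(probs.argmax(dim=1).item())
    print(f'Predicted class index: {pred_idx} '
          f'(confidence {probs[0, pred_idx].item():.3f})')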