# Spaces: Sleeping
# (Apparently page-status text captured together with the file; not program code.)
import os | |
import torch | |
import torchvision | |
from torch import nn | |
from torchvision import transforms | |
from PIL import Image | |
class CFG:
    """Static configuration shared by the model definition and inference code."""

    # --- Hardware / runtime ---
    DEVICE = 'cpu'
    NUM_DEVICES = torch.cuda.device_count()
    NUM_WORKERS = os.cpu_count()

    # --- Task / optimisation hyper-parameters ---
    NUM_CLASSES = 4
    EPOCHS = 16
    BATCH_SIZE = 32
    LR = 0.001
    APPLY_SHUFFLE = True
    SEED = 768

    # --- Input image geometry ---
    HEIGHT = 224
    WIDTH = 224
    CHANNELS = 3
    # (height, width, channels), built from the three values above.
    IMAGE_SIZE = (HEIGHT, WIDTH, CHANNELS)
class VisionTransformerModel(nn.Module):
    """Backbone network followed by a small MLP classification head.

    The head flattens the backbone output and maps 1000 features to 256 and
    then to ``num_classes``, so the backbone must emit 1000 features per
    sample.

    Args:
        backbone_model: Module whose output is fed to the classification head.
        name: Free-form label stored on the instance as ``self.name``.
        num_classes: Number of outputs produced by the final linear layer.
        device: Device the classification head is moved to.
    """

    def __init__(self, backbone_model, name='vision-transformer',
                 num_classes=CFG.NUM_CLASSES, device=CFG.DEVICE):
        super().__init__()

        # Plain (non-module) attributes.
        self.name = name
        self.device = device
        self.num_classes = num_classes

        # The backbone is registered before the head so the submodule order,
        # and therefore the state_dict key order, is backbone then classifier.
        self.backbone_model = backbone_model

        # Layer positions inside the Sequential determine the state_dict keys
        # (classifier.2.*, classifier.5.*), so the order here is fixed.
        head_layers = [
            nn.Flatten(),
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(in_features=1000, out_features=256, bias=True),
            nn.GELU(),
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(in_features=256, out_features=num_classes, bias=False),
        ]
        head = nn.Sequential(*head_layers)
        self.classifier = head.to(device)

    def forward(self, image):
        """Run ``image`` through the backbone and then the classifier head."""
        features = self.backbone_model(image)
        logits = self.classifier(features)
        return logits
def get_vit_b32_model(
        device: torch.device = CFG.DEVICE) -> nn.Module:
    """Build a frozen, pretrained torchvision ViT backbone.

    NOTE: despite the ``b32`` in the function name, this loads the ViT-L/32
    architecture (``torchvision.models.vit_l_32``) with its ``DEFAULT``
    pretrained weights. The name is kept so existing callers keep working.

    Args:
        device: Device the backbone is moved to. Defaults to ``CFG.DEVICE``.
            (The previous default was ``CFG.NUM_CLASSES``, an integer class
            count, which ``.to()`` would have interpreted as a CUDA device
            index.)

    Returns:
        The backbone module on ``device`` with ``requires_grad`` disabled on
        every parameter.
    """
    # Set the manual seeds
    torch.manual_seed(CFG.SEED)
    torch.cuda.manual_seed(CFG.SEED)

    # Get model weights
    model_weights = torchvision.models.ViT_L_32_Weights.DEFAULT

    # Get model and push to device
    model = torchvision.models.vit_l_32(weights=model_weights).to(device)

    # Freeze model parameters so only layers added on top of it can be trained
    for param in model.parameters():
        param.requires_grad = False

    return model
# Get ViT model (frozen pretrained backbone)
vit_backbone = get_vit_b32_model(CFG.DEVICE)

vit_params = {
    'backbone_model': vit_backbone,
    'name': 'ViT-L-B32',
    'device': CFG.DEVICE,
}

# Generate model and restore the trained weights from the local checkpoint.
# torch.load unpickles the file, so only load checkpoints from a trusted source.
vit_model = VisionTransformerModel(**vit_params)
vit_model.load_state_dict(
    torch.load('vit_model.pth', map_location=torch.device('cpu'))
)

# Switch to inference mode. Modules are created in training mode, which would
# leave the classifier's Dropout layers active and make predictions random.
vit_model.eval()
# Define the image transformation: resize to the configured input size and
# convert the PIL image to a tensor. The size comes from CFG so it cannot
# drift from the rest of the configuration.
transform = transforms.Compose([
    transforms.Resize((CFG.HEIGHT, CFG.WIDTH)),
    transforms.ToTensor(),
])
def predict(image_path):
    """Run the module-level ``vit_model`` on a single image file.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The raw output tensor of ``vit_model`` for a batch containing only
        this image. No softmax or argmax is applied.
    """
    # The context manager closes the underlying file once the pixels have
    # been read; a bare Image.open() would leave the handle open.
    with Image.open(image_path) as image:
        # Force three channels: grayscale, palette and RGBA files would
        # otherwise give 1- or 4-channel tensors, not the configured 3.
        input_tensor = transform(image.convert('RGB'))

    input_batch = input_tensor.unsqueeze(0).to(CFG.DEVICE)  # Add batch dimension

    # Perform inference without tracking gradients. The output is already on
    # the model's device, so no further transfer is needed.
    with torch.no_grad():
        output = vit_model(input_batch)

    return output