# ViT_Food101_Demo / model.py
# Author: kietnt0603 - commit c4a1f55 ("first commit"), 569 bytes.
import torch
import timm
from torchvision import transforms
def create_model(model_name, num_classes):
    """Build a pretrained timm model with a fresh classifier head and its preprocessing.

    Args:
        model_name: timm model identifier passed to ``timm.create_model``
            (e.g. ``'vit_base_patch16_224_miil_in21k'``).
        num_classes: Number of output classes for the new classification head.

    Returns:
        A ``(model, transform)`` tuple:
        - ``model``: the pretrained model with a newly initialised classifier
          sized for ``num_classes``, with every parameter trainable.
        - ``transform``: a torchvision pipeline that resizes an image to the
          model's input size, converts it to a tensor and normalises it with
          the model's pretraining mean/std.

    Raises:
        RuntimeError: raised by timm when ``model_name`` is not a known model.
    """
    from timm.data import resolve_data_config

    # Honour the caller's choice of backbone instead of a hard-coded name.
    model = timm.create_model(model_name, pretrained=True)

    # Replace the pretrained classifier with a fresh one for our label set.
    # reset_classifier reads the feature width from the model itself, so this
    # is not tied to ViT-Base's 768-dim embedding. It also keeps
    # model.num_classes in sync.
    model.reset_classifier(num_classes)

    # Fine-tune the whole network, not only the new head.
    model.requires_grad_(True)

    # Input size and normalisation come from the model's pretrained config.
    # For the MIIL in21k ViT this is 224x224 with mean 0 / std 1, which makes
    # the normalisation an identity.
    data_cfg = resolve_data_config({}, model=model)
    height, width = data_cfg['input_size'][-2:]
    transform = transforms.Compose([
        # Resizing straight to the target size makes a later centre crop a no-op.
        transforms.Resize((height, width)),
        transforms.ToTensor(),
        transforms.Normalize(mean=data_cfg['mean'], std=data_cfg['std']),
    ])
    return model, transform