Spaces:
Sleeping
Sleeping
import torch | |
import timm | |
from torchvision import transforms | |
def create_model(model_name, num_classes):
    """Build a pretrained timm model with a fresh classification head.

    Args:
        model_name: timm model identifier passed to ``timm.create_model``,
            e.g. ``'vit_base_patch16_224_miil_in21k'``.
        num_classes: Number of output classes for the replacement head.

    Returns:
        A ``(model, transform)`` tuple: the model with every parameter
        trainable, and the torchvision preprocessing pipeline for its inputs.
    """
    # Resize directly to the 224x224 input size. A CenterCrop(224) after an
    # exact 224x224 resize would be a no-op, so none is applied.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # mean=0 / std=1 is an identity normalization: it keeps ToTensor's
        # [0, 1] range unchanged.
        # NOTE(review): this is believed to match the MIIL in21k checkpoint's
        # preprocessing; other backbones selected via model_name may expect
        # different statistics -- confirm against model.pretrained_cfg.
        transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
    ])

    # Load the requested pretrained backbone rather than a fixed checkpoint.
    model = timm.create_model(model_name, pretrained=True)
    # Swap in a new num_classes-way linear classifier. Its input width comes
    # from the backbone itself instead of a hard-coded 768 (ViT-B only).
    model.reset_classifier(num_classes)
    # Full fine-tuning: make every parameter trainable.
    model.requires_grad_(True)
    return model, transform