# File size: 569 Bytes
# c4a1f55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
import timm
from torchvision import transforms

def create_model(model_name, num_classes):
    """Build a pretrained ViT classifier together with its image preprocessing.

    Args:
        model_name: Currently unused. The checkpoint is always
            ``vit_base_patch16_224_miil_in21k`` regardless of this value.
            NOTE(review): kept so existing callers keep working -- confirm
            whether it was meant to select the architecture.
        num_classes: Number of output units for the replacement
            classification head.

    Returns:
        A ``(model, transform)`` tuple: the timm model with a freshly
        initialised linear head and all parameters trainable, and the
        torchvision preprocessing pipeline.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        # The resize above already yields 224x224, so this crop should be a
        # no-op; it is kept so the pipeline stays identical to the original.
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Zero mean and unit std make this an identity normalisation.
        # NOTE(review): presumably matches the checkpoint's pretraining
        # statistics -- confirm against timm's data config for this model.
        transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
    ])

    # Load the pretrained model.
    model = timm.create_model('vit_base_patch16_224_miil_in21k', pretrained=True)

    # Size the new head from the pretrained head it replaces instead of
    # hard-coding the embedding width (768 for this checkpoint).
    model.head = torch.nn.Linear(model.head.in_features, num_classes)

    # Fine-tune everything: mark every parameter, including the new head,
    # as trainable.
    model.requires_grad_(True)
    return model, transform