from typing import Tuple

import timm
import torch
from timm.data import resolve_data_config
from torchvision import transforms


def create_model(
    model_name: str, num_classes: int
) -> Tuple[torch.nn.Module, transforms.Compose]:
    """Build a pretrained timm model with a fresh classifier and its input transform.

    Args:
        model_name: timm model identifier passed to ``timm.create_model``
            (e.g. ``"vit_base_patch16_224_miil_in21k"``).
        num_classes: Number of outputs for the newly initialised
            classification head. For timm models ``0`` typically removes the
            head so the model returns pooled features.

    Returns:
        ``(model, transform)``. ``model`` has pretrained weights, a newly
        initialised classifier with ``num_classes`` outputs, and all
        parameters trainable. ``transform`` resizes an image to the model's
        pretrained input size, converts it to a tensor, and normalizes it with
        the pretrained weights' mean/std.
    """
    # Honour the requested architecture (a hard-coded ViT name used to
    # silently override ``model_name``).
    model = timm.create_model(model_name, pretrained=True)

    # Replace the classifier through timm's architecture-agnostic API rather
    # than assigning ``model.head`` with a hard-coded 768 input width, which
    # only fits ViT-Base-style models.
    model.reset_classifier(num_classes)

    # Make every parameter trainable (full fine-tuning).
    model.requires_grad_(True)

    # Take input size and normalization stats from the pretrained config of
    # the chosen weights: the MIIL in21k ViT uses mean 0 / std 1, whereas most
    # ImageNet weights expect the standard ImageNet statistics.
    data_config = resolve_data_config({}, model=model)
    _, height, width = data_config["input_size"]
    transform = transforms.Compose([
        # Resizing straight to (height, width) makes a following center crop
        # of the same size a no-op, so none is applied.
        transforms.Resize((height, width)),
        transforms.ToTensor(),
        transforms.Normalize(mean=data_config["mean"], std=data_config["std"]),
    ])
    return model, transform