File size: 1,390 Bytes
5c75b25 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
from transformers import PreTrainedModel
from timm.models.resnet import BasicBlock, Bottleneck, ResNet
from transformers import PretrainedConfig
from typing import List
import torch
import timm
class ViTMAEConfig(PretrainedConfig):
    """Hugging Face configuration for a timm-backed MAE Vision Transformer.

    Attributes:
        model_name: timm model identifier; the classification model in this
            module passes it to ``timm.create_model``.
        num_classes: Number of output classes requested from timm.
    """

    model_type = "vit_mae_custom"

    def __init__(
        self,
        model_name: str = 'timm/vit_base_patch16_224.mae',
        num_classes: int = 1000,
        **kwargs,
    ):
        # Record the custom fields first, then hand every remaining keyword
        # argument to PretrainedConfig.
        self.model_name = model_name
        self.num_classes = num_classes
        super().__init__(**kwargs)
# 'timm/vit_huge_patch14_224.mae'
# class ViTMAEModel(PreTrainedModel):
# config_class = ViTMAEConfig
# def __init__(self, config):
# super().__init__(config)
# self.model = timm.create_model(config.model_name, num_classes=config.num_classes, pretrained=True)
# def forward(self, tensor):
# return self.model.forward_features(tensor)
class ViTMAEModelForImageClassification(PreTrainedModel):
    """Image classifier that wraps a timm model as a Hugging Face model.

    The underlying timm model is built from ``config.model_name`` with
    ``config.num_classes`` outputs and ``pretrained=True``.
    """

    config_class = ViTMAEConfig

    def __init__(self, config):
        super().__init__(config)
        # pretrained=True asks timm to load pretrained weights for the
        # named model when this object is constructed.
        self.model = timm.create_model(
            config.model_name,
            num_classes=config.num_classes,
            pretrained=True,
        )

    def forward(self, tensor, labels=None):
        """Compute logits and, when labels are given, the classification loss.

        Args:
            tensor: Input batch forwarded unchanged to the timm model
                (presumably ``(batch, channels, height, width)`` images -
                TODO confirm against callers).
            labels: Optional class-index targets. When provided, a
                cross-entropy loss against ``logits`` is also returned.

        Returns:
            ``{"logits": logits}``, or ``{"loss": loss, "logits": logits}``
            when ``labels`` is not ``None``.
        """
        logits = self.model(tensor)
        if labels is None:
            return {"logits": logits}
        # ``torch.nn`` has no ``cross_entropy``; the functional loss lives in
        # ``torch.nn.functional``. The old call raised AttributeError.
        loss = torch.nn.functional.cross_entropy(logits, labels)
        return {"loss": loss, "logits": logits}