ajaye2 commited on
Commit
1f6a0a1
·
1 Parent(s): aff4488

Upload model

Browse files
Files changed (3) hide show
  1. ViTMAETimm.py +42 -0
  2. config.json +14 -0
  3. pytorch_model.bin +3 -0
ViTMAETimm.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+ from timm.models.resnet import BasicBlock, Bottleneck, ResNet
3
+
4
+ from transformers import PretrainedConfig
5
+ from typing import List
6
+ import torch
7
+ import timm
8
+
9
+
10
+ class ViTMAEConfig(PretrainedConfig):
11
+ model_type = "vit_mae_custom"
12
+
13
+ def __init__( self, model_name='timm/vit_base_patch16_224.mae', num_classes: int = 1000, **kwargs ):
14
+ self.model_name = model_name
15
+ self.num_classes = num_classes
16
+ super().__init__(**kwargs)
17
+
18
+ # 'timm/vit_huge_patch14_224.mae'
19
+ # class ViTMAEModel(PreTrainedModel):
20
+ # config_class = ViTMAEConfig
21
+
22
+ # def __init__(self, config):
23
+ # super().__init__(config)
24
+ # self.model = timm.create_model(config.model_name, num_classes=config.num_classes, pretrained=True)
25
+
26
+ # def forward(self, tensor):
27
+ # return self.model.forward_features(tensor)
28
+
29
+ class ViTMAEModelForImageClassification(PreTrainedModel):
30
+ config_class = ViTMAEConfig
31
+
32
+ def __init__(self, config):
33
+ super().__init__(config)
34
+
35
+ self.model = timm.create_model(config.model_name, num_classes=config.num_classes, pretrained=True)
36
+
37
+ def forward(self, tensor, labels=None):
38
+ logits = self.model(tensor)
39
+ if labels is not None:
40
+ loss = torch.nn.cross_entropy(logits, labels)
41
+ return {"loss": loss, "logits": logits}
42
+ return {"logits": logits}
config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ViTMAEModelForImageClassification"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "ViTMAETimm.ViTMAEConfig",
7
+ "AutoModelForImageClassification": "ViTMAETimm.ViTMAEModelForImageClassification"
8
+ },
9
+ "model_name": "timm/vit_large_patch16_224.mae",
10
+ "model_type": "vit_mae_custom",
11
+ "num_classes": 2,
12
+ "torch_dtype": "float32",
13
+ "transformers_version": "4.32.0.dev0"
14
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d6d0e9bb111e2a9860bd20d52a8c5f959e5e8659c932ef6b3c835b3c33b36c18
3
+ size 1213306285