Upload folder using huggingface_hub
Browse files- configuration_resnet.py +0 -12
- modeling_resnet.py +4 -2
configuration_resnet.py
CHANGED
@@ -1,7 +1,5 @@
|
|
1 |
from transformers import PretrainedConfig
|
2 |
from typing import List
|
3 |
-
from pprint import pprint
|
4 |
-
|
5 |
|
6 |
class ResnetConfig(PretrainedConfig):
|
7 |
model_type = "faen_resnet"
|
@@ -34,13 +32,3 @@ class ResnetConfig(PretrainedConfig):
|
|
34 |
self.stem_type = stem_type
|
35 |
self.avg_down = avg_down
|
36 |
super().__init__(**kwargs)
|
37 |
-
|
38 |
-
if __name__ == "__main__":
|
39 |
-
resnet50d_config = ResnetConfig(block_type="bottleneck", stem_width=32, stem_type="deep", avg_down=True)
|
40 |
-
print("init a ResnetConfig, it is:\n")
|
41 |
-
pprint(resnet50d_config)
|
42 |
-
resnet50d_config.save_pretrained("./")
|
43 |
-
resnet50d_config = ResnetConfig.from_pretrained("./")
|
44 |
-
print("\n")
|
45 |
-
print("saved to file config.json and reload it from config.json and it is:\n")
|
46 |
-
pprint(resnet50d_config)
|
|
|
1 |
from transformers import PretrainedConfig
|
2 |
from typing import List
|
|
|
|
|
3 |
|
4 |
class ResnetConfig(PretrainedConfig):
|
5 |
model_type = "faen_resnet"
|
|
|
32 |
self.stem_type = stem_type
|
33 |
self.avg_down = avg_down
|
34 |
super().__init__(**kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modeling_resnet.py
CHANGED
@@ -1,12 +1,13 @@
|
|
1 |
from transformers import PreTrainedModel
|
2 |
from timm.models.resnet import BasicBlock, Bottleneck, ResNet
|
3 |
-
from configuration_resnet import ResnetConfig
|
4 |
import torch
|
5 |
|
6 |
BLOCK_MAPPING = {"basic": BasicBlock, "bottleneck": Bottleneck}
|
7 |
|
8 |
|
9 |
class ResnetModel(PreTrainedModel):
|
|
|
10 |
|
11 |
def __init__(self, config):
|
12 |
super().__init__(config)
|
@@ -28,7 +29,8 @@ class ResnetModel(PreTrainedModel):
|
|
28 |
|
29 |
|
30 |
class ResnetModelForImageClassification(PreTrainedModel):
|
31 |
-
|
|
|
32 |
def __init__(self, config):
|
33 |
super().__init__(config)
|
34 |
block_layer = BLOCK_MAPPING[config.block_type]
|
|
|
1 |
from transformers import PreTrainedModel
|
2 |
from timm.models.resnet import BasicBlock, Bottleneck, ResNet
|
3 |
+
from .configuration_resnet import ResnetConfig
|
4 |
import torch
|
5 |
|
6 |
BLOCK_MAPPING = {"basic": BasicBlock, "bottleneck": Bottleneck}
|
7 |
|
8 |
|
9 |
class ResnetModel(PreTrainedModel):
|
10 |
+
config_class = ResnetConfig
|
11 |
|
12 |
def __init__(self, config):
|
13 |
super().__init__(config)
|
|
|
29 |
|
30 |
|
31 |
class ResnetModelForImageClassification(PreTrainedModel):
|
32 |
+
config_class = ResnetConfig
|
33 |
+
|
34 |
def __init__(self, config):
|
35 |
super().__init__(config)
|
36 |
block_layer = BLOCK_MAPPING[config.block_type]
|