zhangfaen commited on
Commit
920b410
·
verified ·
1 Parent(s): 4cd9211

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. configuration_resnet.py +0 -12
  2. modeling_resnet.py +4 -2
configuration_resnet.py CHANGED
@@ -1,7 +1,5 @@
1
  from transformers import PretrainedConfig
2
  from typing import List
3
- from pprint import pprint
4
-
5
 
6
  class ResnetConfig(PretrainedConfig):
7
  model_type = "faen_resnet"
@@ -34,13 +32,3 @@ class ResnetConfig(PretrainedConfig):
34
  self.stem_type = stem_type
35
  self.avg_down = avg_down
36
  super().__init__(**kwargs)
37
-
38
- if __name__ == "__main__":
39
- resnet50d_config = ResnetConfig(block_type="bottleneck", stem_width=32, stem_type="deep", avg_down=True)
40
- print("init a ResnetConfig, it is:\n")
41
- pprint(resnet50d_config)
42
- resnet50d_config.save_pretrained("./")
43
- resnet50d_config = ResnetConfig.from_pretrained("./")
44
- print("\n")
45
- print("saved to file config.json and reload it from config.json and it is:\n")
46
- pprint(resnet50d_config)
 
1
  from transformers import PretrainedConfig
2
  from typing import List
 
 
3
 
4
  class ResnetConfig(PretrainedConfig):
5
  model_type = "faen_resnet"
 
32
  self.stem_type = stem_type
33
  self.avg_down = avg_down
34
  super().__init__(**kwargs)
 
 
 
 
 
 
 
 
 
 
modeling_resnet.py CHANGED
@@ -1,12 +1,13 @@
1
  from transformers import PreTrainedModel
2
  from timm.models.resnet import BasicBlock, Bottleneck, ResNet
3
- from configuration_resnet import ResnetConfig
4
  import torch
5
 
6
  BLOCK_MAPPING = {"basic": BasicBlock, "bottleneck": Bottleneck}
7
 
8
 
9
  class ResnetModel(PreTrainedModel):
 
10
 
11
  def __init__(self, config):
12
  super().__init__(config)
@@ -28,7 +29,8 @@ class ResnetModel(PreTrainedModel):
28
 
29
 
30
  class ResnetModelForImageClassification(PreTrainedModel):
31
-
 
32
  def __init__(self, config):
33
  super().__init__(config)
34
  block_layer = BLOCK_MAPPING[config.block_type]
 
1
  from transformers import PreTrainedModel
2
  from timm.models.resnet import BasicBlock, Bottleneck, ResNet
3
+ from .configuration_resnet import ResnetConfig
4
  import torch
5
 
6
  BLOCK_MAPPING = {"basic": BasicBlock, "bottleneck": Bottleneck}
7
 
8
 
9
  class ResnetModel(PreTrainedModel):
10
+ config_class = ResnetConfig
11
 
12
  def __init__(self, config):
13
  super().__init__(config)
 
29
 
30
 
31
  class ResnetModelForImageClassification(PreTrainedModel):
32
+ config_class = ResnetConfig
33
+
34
  def __init__(self, config):
35
  super().__init__(config)
36
  block_layer = BLOCK_MAPPING[config.block_type]