vittoriopippi commited on
Commit
f253188
·
1 Parent(s): b622c9d

Error with model_type

Browse files
Files changed (2) hide show
  1. configuration_vatrpp.py +2 -1
  2. modeling_vatrpp.py +2 -6
configuration_vatrpp.py CHANGED
@@ -1,5 +1,6 @@
1
- from transformers import PretrainedConfig
2
 
 
3
  class VATrPPConfig(PretrainedConfig):
4
  model_type = "vatrpp"
5
 
 
1
+ from transformers import PretrainedConfig, AutoConfig
2
 
3
+ @AutoConfig.register("vatrpp")
4
  class VATrPPConfig(PretrainedConfig):
5
  model_type = "vatrpp"
6
 
modeling_vatrpp.py CHANGED
@@ -1,4 +1,4 @@
1
- from transformers import PreTrainedModel
2
  from .configuration_vatrpp import VATrPPConfig
3
  import json
4
  import os
@@ -20,7 +20,6 @@ from util.misc import FakeArgs
20
  from util.text import TextGenerator
21
  from util.vision import detect_text_bounds
22
  from torchvision.transforms.functional import to_pil_image
23
- from transformers import CONFIG_MAPPING, MODEL_MAPPING
24
 
25
 
26
  def get_long_tail_chars():
@@ -31,6 +30,7 @@ def get_long_tail_chars():
31
 
32
  return chars
33
 
 
34
  class VATrPP(PreTrainedModel):
35
  config_class = VATrPPConfig
36
 
@@ -130,7 +130,3 @@ class VATrPP(PreTrainedModel):
130
  x_pos += word.shape[1] + gap_width
131
 
132
  return result
133
-
134
-
135
- CONFIG_MAPPING.register("vatrpp", VATrPPConfig)
136
- MODEL_MAPPING.register(VATrPPConfig, VATrPP)
 
1
+ from transformers import PreTrainedModel, AutoModel
2
  from .configuration_vatrpp import VATrPPConfig
3
  import json
4
  import os
 
20
  from util.text import TextGenerator
21
  from util.vision import detect_text_bounds
22
  from torchvision.transforms.functional import to_pil_image
 
23
 
24
 
25
  def get_long_tail_chars():
 
30
 
31
  return chars
32
 
33
+ @AutoModel.register(VATrPPConfig)
34
  class VATrPP(PreTrainedModel):
35
  config_class = VATrPPConfig
36
 
 
130
  x_pos += word.shape[1] + gap_width
131
 
132
  return result