vittoriopippi
commited on
Commit
·
160658a
1
Parent(s):
af10767
Change imports
Browse files- modeling_vatrpp.py +5 -0
- models/model.py +1 -1
modeling_vatrpp.py
CHANGED
@@ -10,6 +10,7 @@ from .data.dataset import FolderDataset
|
|
10 |
from .models.model import VATr
|
11 |
from .models.util.vision import detect_text_bounds
|
12 |
from torchvision.transforms.functional import to_pil_image
|
|
|
13 |
|
14 |
|
15 |
def get_long_tail_chars():
|
@@ -26,6 +27,10 @@ class VATrPP(PreTrainedModel):
|
|
26 |
|
27 |
def __init__(self, config: VATrPPConfig) -> None:
|
28 |
super().__init__(config)
|
|
|
|
|
|
|
|
|
29 |
self.model = VATr(config)
|
30 |
self.model.eval()
|
31 |
|
|
|
10 |
from .models.model import VATr
|
11 |
from .models.util.vision import detect_text_bounds
|
12 |
from torchvision.transforms.functional import to_pil_image
|
13 |
+
from huggingface_hub import hf_hub_download
|
14 |
|
15 |
|
16 |
def get_long_tail_chars():
|
|
|
27 |
|
28 |
def __init__(self, config: VATrPPConfig) -> None:
|
29 |
super().__init__(config)
|
30 |
+
|
31 |
+
config.english_words_path = hf_hub_download(repo_id="blowing-up-groundhogs/vatrpp", filename=config.english_words_path)
|
32 |
+
config.mytext_path = hf_hub_download(repo_id="blowing-up-groundhogs/vatrpp", filename='mytext.txt')
|
33 |
+
|
34 |
self.model = VATr(config)
|
35 |
self.model.eval()
|
36 |
|
models/model.py
CHANGED
@@ -260,7 +260,7 @@ class VATr(nn.Module):
|
|
260 |
|
261 |
self.epoch = 0
|
262 |
|
263 |
-
with open(
|
264 |
self.text = f.read()
|
265 |
self.text = self.text.replace('\n', ' ')
|
266 |
self.text = self.text.replace('\n', ' ')
|
|
|
260 |
|
261 |
self.epoch = 0
|
262 |
|
263 |
+
with open(args.mytext_path, 'r', encoding='utf-8') as f:
|
264 |
self.text = f.read()
|
265 |
self.text = self.text.replace('\n', ' ')
|
266 |
self.text = self.text.replace('\n', ' ')
|