Update gptx_tokenizer.py
Browse files- gptx_tokenizer.py +9 -24
gptx_tokenizer.py
CHANGED
|
@@ -7,7 +7,7 @@ from pathlib import Path
|
|
| 7 |
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
|
| 8 |
|
| 9 |
import sentencepiece as spm
|
| 10 |
-
from huggingface_hub import hf_hub_download, list_repo_files
|
| 11 |
from transformers.tokenization_utils import PreTrainedTokenizer
|
| 12 |
from transformers.tokenization_utils_base import TOKENIZER_CONFIG_FILE
|
| 13 |
|
|
@@ -64,29 +64,14 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
|
|
| 64 |
f"<placeholder_tok_{i}>" for i in range(256)
|
| 65 |
]
|
| 66 |
|
| 67 |
-
def find_tokenizer_config(self, config_path: Path, repo_id: str = None) -> Path:
|
| 68 |
-
if
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
# Find the tokenizer config file
|
| 76 |
-
tokenizer_files = [f for f in repo_files if f.endswith('tokenizer_config.json')]
|
| 77 |
-
if not tokenizer_files:
|
| 78 |
-
raise FileNotFoundError(f"No tokenizer_config.json file found in repository {repo_id}")
|
| 79 |
-
|
| 80 |
-
# Use the first tokenizer_config.json file found
|
| 81 |
-
tokenizer_config_file = tokenizer_files[0]
|
| 82 |
-
print(f"Found tokenizer config file: {tokenizer_config_file}")
|
| 83 |
-
|
| 84 |
-
# Download the file
|
| 85 |
-
tokenizer_config_file_or_name = hf_hub_download(repo_id=repo_id, filename=tokenizer_config_file)
|
| 86 |
-
print(f"Downloaded tokenizer config file to: {tokenizer_config_file_or_name}")
|
| 87 |
-
return tokenizer_config_file_or_name
|
| 88 |
-
except Exception as e:
|
| 89 |
-
raise OSError(f"Failed to download tokenizer model: {str(e)}")
|
| 90 |
|
| 91 |
def instantiate_from_file_or_name(self, model_file_or_name: str, repo_id: str = None):
|
| 92 |
"""
|
|
|
|
| 7 |
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
|
| 8 |
|
| 9 |
import sentencepiece as spm
|
| 10 |
+
from huggingface_hub import hf_hub_download, list_repo_files, try_to_load_from_cache
|
| 11 |
from transformers.tokenization_utils import PreTrainedTokenizer
|
| 12 |
from transformers.tokenization_utils_base import TOKENIZER_CONFIG_FILE
|
| 13 |
|
|
|
|
| 64 |
f"<placeholder_tok_{i}>" for i in range(256)
|
| 65 |
]
|
| 66 |
|
| 67 |
+
def find_tokenizer_config(self, config_path: Path, repo_id: str = None) -> Optional[Path]:
    """Resolve the tokenizer config file, preferring local sources.

    Lookup order:
      1. ``config_path`` itself, if it is an existing file.
      2. The local Hugging Face cache for ``repo_id``, looked up by the
         file name of ``config_path``.
      3. ``self._download_config_from_hub(repo_id=repo_id)``.

    Args:
        config_path: Candidate path of the tokenizer config file.
        repo_id: Hub repository id used for the cache lookup and passed
            to the download helper.

    Returns:
        ``config_path`` when it exists on disk, otherwise the cached file
        path, otherwise whatever ``_download_config_from_hub`` returns.
    """
    if os.path.isfile(config_path):
        return config_path

    cached = try_to_load_from_cache(repo_id=repo_id, filename=Path(config_path).name)
    # try_to_load_from_cache may return a truthy sentinel object (meaning the
    # file's non-existence is cached) instead of a path, so a plain truthiness
    # check would hand that sentinel back to the caller as if it were a path.
    # Only a real, non-empty path counts as a cache hit.
    if cached and isinstance(cached, (str, os.PathLike)):
        return cached

    # Not on disk and not usable from the cache: defer to the hub download.
    return self._download_config_from_hub(repo_id=repo_id)
|
| 74 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
def instantiate_from_file_or_name(self, model_file_or_name: str, repo_id: str = None):
|
| 77 |
"""
|