Upload tokenization_dart.py
Browse files- tokenization_dart.py +10 -23
tokenization_dart.py
CHANGED
|
@@ -1,7 +1,5 @@
|
|
| 1 |
import logging
|
| 2 |
-
import
|
| 3 |
-
from typing import Dict, List
|
| 4 |
-
from pydantic.dataclasses import dataclass
|
| 5 |
|
| 6 |
from transformers import PreTrainedTokenizerFast
|
| 7 |
from tokenizers.decoders import Decoder
|
|
@@ -39,35 +37,24 @@ PROMPT_TEMPLATE = (
|
|
| 39 |
"{{ '</character>' }}"
|
| 40 |
|
| 41 |
"{{ '<general>' }}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
"{% if 'general' not in messages or messages['general'] is none %}"
|
| 43 |
"{{ '' }}"
|
| 44 |
"{% else %}"
|
| 45 |
"{{ messages['general'] }}"
|
| 46 |
"{% endif %}"
|
|
|
|
| 47 |
).strip()
|
| 48 |
# fmt: on
|
| 49 |
|
| 50 |
|
| 51 |
-
@dataclass
|
| 52 |
-
class Category:
|
| 53 |
-
name: str
|
| 54 |
-
bos_token_id: int
|
| 55 |
-
eos_token_id: int
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
@dataclass
|
| 59 |
-
class TagCategoryConfig:
|
| 60 |
-
categories: Dict[str, Category]
|
| 61 |
-
category_to_token_ids: Dict[str, List[int]]
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
def load_tag_category_config(config_json: str):
|
| 65 |
-
with open(config_json, "rb") as file:
|
| 66 |
-
config: TagCategoryConfig = TagCategoryConfig(**json.loads(file.read()))
|
| 67 |
-
|
| 68 |
-
return config
|
| 69 |
-
|
| 70 |
-
|
| 71 |
class DartDecoder:
|
| 72 |
def __init__(self, special_tokens: List[str]):
|
| 73 |
self.special_tokens = list(special_tokens)
|
|
|
|
| 1 |
import logging
|
| 2 |
+
from typing import List
|
|
|
|
|
|
|
| 3 |
|
| 4 |
from transformers import PreTrainedTokenizerFast
|
| 5 |
from tokenizers.decoders import Decoder
|
|
|
|
| 37 |
"{{ '</character>' }}"
|
| 38 |
|
| 39 |
"{{ '<general>' }}"
|
| 40 |
+
# length token
|
| 41 |
+
"{% if 'length' not in messages or messages['length'] is none %}"
|
| 42 |
+
"{{ '<|long|>' }}"
|
| 43 |
+
"{% else %}"
|
| 44 |
+
"{{ messages['length'] }}"
|
| 45 |
+
"{% endif %}"
|
| 46 |
+
|
| 47 |
+
# general token
|
| 48 |
"{% if 'general' not in messages or messages['general'] is none %}"
|
| 49 |
"{{ '' }}"
|
| 50 |
"{% else %}"
|
| 51 |
"{{ messages['general'] }}"
|
| 52 |
"{% endif %}"
|
| 53 |
+
"{{ '<|input_end|>' }}"
|
| 54 |
).strip()
|
| 55 |
# fmt: on
|
| 56 |
|
| 57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
class DartDecoder:
|
| 59 |
def __init__(self, special_tokens: List[str]):
|
| 60 |
self.special_tokens = list(special_tokens)
|