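"""Collect tokenizer special-token metadata from a Hugging Face model directory.

SpecialVocab gathers special token ids (bos/eos/unk/...), add_*_token flags,
BPE merges and the chat template from tokenizer.json, tokenizer_config.json,
config.json and merges.txt, and writes them to a GGUF file via GGUFWriter.
"""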
from __future__ import annotations

import json
import os
import sys
from pathlib import Path
from typing import Any, Callable

from .gguf_writer import GGUFWriter


class SpecialVocab:
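    """Reads special-token metadata for a model and adds it to a GGUFWriter."""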
    merges: list[str]
    add_special_token: dict[str, bool]
    special_token_ids: dict[str, int]
    chat_template: str | None

    def __init__(
        self, path: str | os.PathLike[str], load_merges: bool = False,
        special_token_types: tuple[str, ...] | None = None,
        n_vocab: int | None = None,
    ):
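        """Load special-token metadata from the model directory at path.

        load_merges also collects BPE merges, special_token_types overrides the
        default ('bos', 'eos', 'unk', 'sep', 'pad', 'cls', 'mask'), and n_vocab,
        when given, is used to range-check token ids.
        """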
        self.special_token_ids = {}
        self.add_special_token = {}
        self.n_vocab = n_vocab
        self.load_merges = load_merges
        self.merges = []
        self.chat_template = None
        if special_token_types is not None:
            self.special_token_types = special_token_types
        else:
            self.special_token_types = ('bos', 'eos', 'unk', 'sep', 'pad', 'cls', 'mask')
        self._load(Path(path))

    def __repr__(self) -> str:
        return '<SpecialVocab with {} merges, special tokens {}, add special tokens {}>'.format(
            len(self.merges), self.special_token_ids or "unset", self.add_special_token or "unset",
        )

    def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
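        """Write the collected merges, special token ids, add_*_token flags and
        chat template to gw, warning about anything that cannot be written."""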
        if self.merges:
            if not quiet:
                print(f'gguf: Adding {len(self.merges)} merge(s).')
            gw.add_token_merges(self.merges)
        elif self.load_merges:
            print(
                'gguf: WARNING: Adding merges requested but no merges found, output may be non-functional.',
                file = sys.stderr,
            )
        for typ, tokid in self.special_token_ids.items():
            id_handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
            if id_handler is None:
                print(
                    f'gguf: WARNING: No handler for special token type {typ} with id {tokid} - skipping',
                    file = sys.stderr,
                )
                continue
            if not quiet:
                print(f'gguf: Setting special token type {typ} to {tokid}')
            id_handler(tokid)
        for typ, value in self.add_special_token.items():
            add_handler: Callable[[bool], None] | None = getattr(gw, f'add_add_{typ}_token', None)
            if add_handler is None:
                print(
                    f'gguf: WARNING: No handler for add_{typ}_token with value {value} - skipping',
                    file = sys.stderr,
                )
                continue
            if not quiet:
                print(f'gguf: Setting add_{typ}_token to {value}')
            add_handler(value)
        if self.chat_template is not None:
            if not quiet:
                print(f'gguf: Setting chat_template to {self.chat_template}')
            gw.add_chat_template(self.chat_template)

    def _load(self, path: Path) -> None:
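        """Try tokenizer.json / tokenizer_config.json and config.json first, then
        fall back to merges.txt if merges were requested but not found."""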
        self._try_load_from_tokenizer_json(path)
        self._try_load_from_config_json(path)
        if self.load_merges and not self.merges:
            self._try_load_merges_txt(path)

    def _try_load_merges_txt(self, path: Path) -> bool:
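        """Read GPT-2 style BPE merges from merges.txt, skipping an optional
        leading '#' header line and warning about malformed entries."""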
        merges_file = path / 'merges.txt'
        if not merges_file.is_file():
            return False
        with open(merges_file, 'r', encoding = 'utf-8') as fp:
            first_line = next(fp, '').strip()
            if not first_line.startswith('#'):
                fp.seek(0)
                line_num = 0
            else:
                line_num = 1
            merges = []
            for line in fp:
                line_num += 1
                line = line.strip()
                if not line:
                    continue
                parts = line.split(None, 3)
                if len(parts) != 2:
                    print(
                        f'gguf: WARNING: {merges_file.name}: Line {line_num}: Entry malformed, ignoring',
                        file = sys.stderr,
                    )
                    continue
                merges.append(f'{parts[0]} {parts[1]}')
        self.merges = merges
        return True

    def _set_special_token(self, typ: str, tid: Any) -> None:
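        """Record the id for special token type typ if it is a plausible int id:
        negative ids raise, ids at or above n_vocab are skipped with a warning,
        and an already-recorded type is not overwritten."""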
        if not isinstance(tid, int):
            return
        if tid < 0:
            raise ValueError(f'invalid value for special token type {typ}: {tid}')
        if self.n_vocab is None or tid < self.n_vocab:
            if typ in self.special_token_ids:
                return
            self.special_token_ids[typ] = tid
            return
        print(
            f'gguf: WARNING: Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping',
            file = sys.stderr,
        )

    def _try_load_from_tokenizer_json(self, path: Path) -> bool:
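        """Collect merges and added tokens from tokenizer.json, then read the chat
        template, add_*_token flags and special token definitions from
        tokenizer_config.json, resolving token ids via the added_tokens list."""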
        tokenizer_file = path / 'tokenizer.json'
        if tokenizer_file.is_file():
            with open(tokenizer_file, encoding = 'utf-8') as f:
                tokenizer = json.load(f)
            if self.load_merges:
                merges = tokenizer.get('model', {}).get('merges')
                if isinstance(merges, list) and merges and isinstance(merges[0], str):
                    self.merges = merges
            added_tokens = tokenizer.get('added_tokens', {})
        else:
            added_tokens = {}
        tokenizer_config_file = path / 'tokenizer_config.json'
        if not tokenizer_config_file.is_file():
            return True
        with open(tokenizer_config_file, encoding = 'utf-8') as f:
            tokenizer_config = json.load(f)
        chat_template = tokenizer_config.get('chat_template')
        if chat_template is None or isinstance(chat_template, (str, list)):
            self.chat_template = chat_template
        else:
            print(
                f'gguf: WARNING: Bad type for chat_template field in {tokenizer_config_file!r} - ignoring',
                file = sys.stderr
            )
        for typ in self.special_token_types:
            add_entry = tokenizer_config.get(f'add_{typ}_token')
            if isinstance(add_entry, bool):
                self.add_special_token[typ] = add_entry
            entry = tokenizer_config.get(f'{typ}_token')
            if isinstance(entry, str):
                tc_content = entry
            elif isinstance(entry, dict):
                entry_content = entry.get('content')
                if not isinstance(entry_content, str):
                    continue
                tc_content = entry_content
            else:
                continue
            # We only need the first match here.
            maybe_token_id = next(
                (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content),
                None,
            )
            self._set_special_token(typ, maybe_token_id)
        return True

    def _try_load_from_config_json(self, path: Path) -> bool:
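        """Read {typ}_token_id entries (e.g. bos_token_id) directly from config.json."""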
        config_file = path / 'config.json'
        if not config_file.is_file():
            return False
        with open(config_file, encoding = 'utf-8') as f:
            config = json.load(f)
        for typ in self.special_token_types:
            self._set_special_token(typ, config.get(f'{typ}_token_id'))
        return True
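

# Minimal usage sketch (assumes `writer` is an already-configured GGUFWriter and
# 'path/to/model' is a local Hugging Face model directory; values are illustrative):
#
#   special_vocab = SpecialVocab('path/to/model', load_merges=True, n_vocab=32000)
#   special_vocab.add_to_gguf(writer)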