from tclogger import logger
from transformers import AutoTokenizer

from constants.models import MODEL_MAP, TOKEN_LIMIT_MAP, TOKEN_RESERVED


class TokenChecker:
    """Count prompt tokens for a model and check them against its token limit."""

    def __init__(self, input_str: str, model: str):
        self.input_str = input_str

        # Fall back to the default model if the requested one is unknown
        if model in MODEL_MAP:
            self.model = model
        else:
            self.model = "nous-mixtral-8x7b"
        self.model_fullname = MODEL_MAP[self.model]

        # Tokenizers of gated repos require authentication; load from ungated
        # mirrors for those models instead
        GATED_MODEL_MAP = {
            "llama3-70b": "NousResearch/Meta-Llama-3-70B",
            "gemma-7b": "unsloth/gemma-7b",
            "mistral-7b": "dfurman/Mistral-7B-Instruct-v0.2",
            "mixtral-8x7b": "dfurman/Mixtral-8x7B-Instruct-v0.1",
        }
        if self.model in GATED_MODEL_MAP:
            self.tokenizer = AutoTokenizer.from_pretrained(GATED_MODEL_MAP[self.model])
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_fullname)

    def count_tokens(self):
        token_count = len(self.tokenizer.encode(self.input_str))
        logger.note(f"Prompt Token Count: {token_count}")
        return token_count

    def get_token_limit(self):
        return TOKEN_LIMIT_MAP[self.model]

    def get_token_redundancy(self):
        # Tokens left over after the prompt and the reserved budget
        return int(self.get_token_limit() - TOKEN_RESERVED - self.count_tokens())

    def check_token_limit(self):
        if self.get_token_redundancy() <= 0:
            raise ValueError(
                f"Prompt exceeded token limit: "
                f"{self.count_tokens()} + {TOKEN_RESERVED} reserved >= {self.get_token_limit()}"
            )
        return True
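

# Minimal usage sketch, assuming `constants.models` provides the maps imported
# above and that the tokenizer weights can be downloaded from Hugging Face.
# The prompt string here is an illustrative placeholder.
if __name__ == "__main__":
    checker = TokenChecker(input_str="Hello, world!", model="nous-mixtral-8x7b")
    checker.count_tokens()
    if checker.check_token_limit():
        logger.note(f"Token redundancy: {checker.get_token_redundancy()}")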