Spaces:
Runtime error
Runtime error
"""General utils functions.""" | |
import random | |
import sys | |
import time | |
import traceback | |
import uuid | |
from contextlib import contextmanager | |
from dataclasses import dataclass | |
from typing import Any, Callable, Generator, List, Optional, Set, Type, cast | |
class GlobalsHelper:
    """Helper to retrieve globals.

    Helpful for global caching of certain variables that can be expensive to
    load (e.g. tokenization). Each accessor builds its value on the first
    call, stores it on the instance and returns the stored value afterwards.

    NOTE(review): no ``@property`` decorator is visible on the accessors, so
    they are plain methods and must be called (``helper.tokenizer()``) --
    confirm against the call sites.
    """

    _tokenizer: Optional[Callable[[str], List]] = None
    _stopwords: Optional[List[str]] = None

    def tokenizer(self) -> Callable[[str], List]:
        """Get the tokenizer, creating and caching it on first use.

        On Python >= 3.9 this is the ``encode`` method of tiktoken's "gpt2"
        encoding; on older interpreters it wraps
        ``transformers.GPT2TokenizerFast`` and returns its ``"input_ids"``.

        Returns:
            Callable[[str], List]: Function mapping a string to a list.

        Raises:
            ImportError: If the backing package is not installed. The
                original import failure is attached as ``__cause__``.
        """
        if self._tokenizer is not None:
            return self._tokenizer

        # if python version >= 3.9, then use tiktoken
        # else use GPT2TokenizerFast
        if sys.version_info >= (3, 9):
            tiktoken_import_err = (
                "`tiktoken` package not found, please run `pip install tiktoken`"
            )
            try:
                import tiktoken
            except ImportError as err:
                # Chain explicitly so the root cause stays in the traceback.
                raise ImportError(tiktoken_import_err) from err
            enc = tiktoken.get_encoding("gpt2")
            self._tokenizer = cast(Callable[[str], List], enc.encode)
        else:
            try:
                import transformers
            except ImportError as err:
                raise ImportError(
                    "`transformers` package not found, "
                    "please run `pip install transformers`"
                ) from err
            tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")

            def tokenizer_fn(text: str) -> List:
                return tokenizer(text)["input_ids"]

            self._tokenizer = tokenizer_fn
        return self._tokenizer

    def stopwords(self) -> List[str]:
        """Get the English stopwords, downloading and caching on first use.

        Returns:
            List[str]: Result of ``nltk.corpus.stopwords.words("english")``.

        Raises:
            ImportError: If ``nltk`` is not installed. The original import
                failure is attached as ``__cause__``.
        """
        if self._stopwords is None:
            try:
                import nltk
                from nltk.corpus import stopwords
            except ImportError as err:
                raise ImportError(
                    "`nltk` package not found, please run `pip install nltk`"
                ) from err
            # Fetch the corpus before reading it; only reached on first use.
            nltk.download("stopwords")
            self._stopwords = stopwords.words("english")
        return self._stopwords
# Module-level shared instance: values cached on it by GlobalsHelper's
# accessors are reused by every caller that goes through this object.
globals_helper = GlobalsHelper()
def get_new_id(d: Set) -> str:
    """Generate a UUID4 string that is not already present in ``d``.

    Args:
        d (Set): Identifiers that are already taken. It is not modified.

    Returns:
        str: A newly generated UUID4 string that is not a member of ``d``.
    """
    candidate = str(uuid.uuid4())
    # Draw again if the fresh id collides with an existing one.
    while candidate in d:
        candidate = str(uuid.uuid4())
    return candidate
def get_new_int_id(d: Set) -> int:
    """Generate a random integer ID that is not already present in ``d``.

    Args:
        d (Set): Identifiers that are already taken. It is not modified.

    Returns:
        int: A value in ``[0, sys.maxsize]`` that is not a member of ``d``.
    """
    candidate = random.randint(0, sys.maxsize)
    # Draw again if the fresh id collides with an existing one.
    while candidate in d:
        candidate = random.randint(0, sys.maxsize)
    return candidate
@contextmanager
def temp_set_attrs(obj: Any, **kwargs: Any) -> Generator:
    """Temporary setter.

    Context manager that sets attributes on an object for the duration of a
    ``with`` block and restores their previous values on exit, including
    when the block raises.
    Taken from: https://tinyurl.com/2p89xymh

    Args:
        obj (Any): Object whose attributes are temporarily overridden.
        **kwargs (Any): Attribute names mapped to their temporary values.
            Every named attribute must already exist on ``obj``, because
            its current value is read before being replaced.

    Yields:
        None. ``obj`` is modified in place.
    """
    # Snapshot the current values first so they can be put back afterwards.
    prev_values = {k: getattr(obj, k) for k in kwargs}
    for k, v in kwargs.items():
        setattr(obj, k, v)
    try:
        yield
    finally:
        # Restore even if the body of the ``with`` block raised.
        for k, v in prev_values.items():
            setattr(obj, k, v)
@dataclass
class ErrorToRetry:
    """Exception types that should be retried.

    Args:
        exception_cls (Type[Exception]): Class of exception.
        check_fn (Optional[Callable[[Any], bool]]):
            A function that takes an exception instance as input and returns
            whether to retry. When omitted, no extra check is made.
    """

    exception_cls: Type[Exception]
    check_fn: Optional[Callable[[Any], bool]] = None
def retry_on_exceptions_with_backoff(
    lambda_fn: Callable,
    errors_to_retry: "List[ErrorToRetry]",
    max_tries: int = 10,
    min_backoff_secs: float = 0.5,
    max_backoff_secs: float = 60.0,
) -> Any:
    """Execute lambda function with retries and exponential backoff.

    Exceptions whose type is not covered by ``errors_to_retry`` propagate
    immediately. Each caught failure has its traceback printed before the
    retry decision is made.

    Args:
        lambda_fn (Callable): Function to be called and output we want.
        errors_to_retry (List[ErrorToRetry]): List of errors to retry.
            At least one needs to be provided. An entry also applies to
            subclasses of its ``exception_cls``; when several entries match,
            the most specific class wins.
        max_tries (int): Maximum number of tries, including the first. Defaults to 10.
        min_backoff_secs (float): Minimum amount of backoff time between attempts.
            Defaults to 0.5.
        max_backoff_secs (float): Maximum amount of backoff time between attempts.
            Defaults to 60.

    Returns:
        Any: Whatever ``lambda_fn`` returns on its first successful call.

    Raises:
        ValueError: If ``errors_to_retry`` is empty.
        Exception: The last exception from ``lambda_fn`` once ``max_tries``
            is reached or the matching ``check_fn`` returns a falsy value.
    """
    if not errors_to_retry:
        raise ValueError("At least one error to retry needs to be provided")

    error_checks = {
        error_to_retry.exception_cls: error_to_retry.check_fn
        for error_to_retry in errors_to_retry
    }
    exception_class_tuples = tuple(error_checks.keys())

    backoff_secs = min_backoff_secs
    tries = 0

    while True:
        try:
            return lambda_fn()
        except exception_class_tuples as e:
            traceback.print_exc()
            tries += 1
            if tries >= max_tries:
                raise
            # The except clause also catches subclasses of the registered
            # classes, so look the check up along the MRO (most specific
            # first) instead of by exact class only.
            check_fn = next(
                (
                    error_checks[cls]
                    for cls in type(e).__mro__
                    if cls in error_checks
                ),
                None,
            )
            if check_fn and not check_fn(e):
                raise
            time.sleep(backoff_secs)
            # Double the wait each time, capped at the configured maximum.
            backoff_secs = min(backoff_secs * 2, max_backoff_secs)
def truncate_text(text: str, max_length: int) -> str:
    """Truncate text to a maximum length.

    Text that already fits is returned unchanged. Longer text is cut so
    that the result, including a trailing ``"..."``, is exactly
    ``max_length`` characters long.

    Args:
        text (str): Text to shorten.
        max_length (int): Maximum length of the returned string.

    Returns:
        str: ``text`` itself, or its shortened form ending in ``"..."``.
            When ``max_length`` leaves no room for any text, only as much
            of the ellipsis as fits is returned.
    """
    ellipsis = "..."
    if len(text) <= max_length:
        return text
    if max_length <= len(ellipsis):
        # A negative slice bound would otherwise cut from the end of the
        # text and produce a result longer than the limit.
        return ellipsis[: max(max_length, 0)]
    return text[: max_length - len(ellipsis)] + ellipsis