AbeerTrial's picture
Duplicate from AbeerTrial/SOAPAssist
35b22df
"""General utils functions."""
import random
import sys
import time
import traceback
import uuid
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable, Generator, List, Optional, Set, Type, cast
class GlobalsHelper:
    """Helper to retrieve globals.

    Lazily builds and caches values that can be expensive to load
    (e.g. a tokenizer or the stopword list). Each value is created on first
    access of the corresponding property and stored on the instance, so
    later accesses return the cached object.

    """

    # Cached tokenizer callable; None until `tokenizer` is first accessed.
    _tokenizer: Optional[Callable[[str], List]] = None
    # Cached English stopword list; None until `stopwords` is first accessed.
    _stopwords: Optional[List[str]] = None

    @property
    def tokenizer(self) -> Callable[[str], List]:
        """Get tokenizer.

        Returns:
            Callable[[str], List]: The cached tokenizer function.

        Raises:
            ImportError: If the backing package (`tiktoken` on Python >= 3.9,
                `transformers` otherwise) is not installed.

        """
        if self._tokenizer is None:
            # if python version >= 3.9, then use tiktoken
            # else use GPT2TokenizerFast
            if sys.version_info >= (3, 9):
                tiktoken_import_err = (
                    "`tiktoken` package not found, please run `pip install tiktoken`"
                )
                try:
                    import tiktoken
                except ImportError as err:
                    # Chain the original failure so its details are not lost.
                    raise ImportError(tiktoken_import_err) from err
                enc = tiktoken.get_encoding("gpt2")
                self._tokenizer = cast(Callable[[str], List], enc.encode)
            else:
                try:
                    import transformers
                except ImportError as err:
                    raise ImportError(
                        "`transformers` package not found, "
                        "please run `pip install transformers`"
                    ) from err

                tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")

                def tokenizer_fn(text: str) -> List:
                    return tokenizer(text)["input_ids"]

                self._tokenizer = tokenizer_fn

        return self._tokenizer

    @property
    def stopwords(self) -> List[str]:
        """Get stopwords.

        Returns:
            List[str]: The cached list of English stopwords from nltk.

        Raises:
            ImportError: If the `nltk` package is not installed.

        """
        if self._stopwords is None:
            try:
                import nltk
                from nltk.corpus import stopwords
            except ImportError as err:
                raise ImportError(
                    "`nltk` package not found, please run `pip install nltk`"
                ) from err
            # Only download when the corpus is not already available locally,
            # so repeated process starts do not trigger a download attempt.
            try:
                nltk.data.find("corpora/stopwords")
            except LookupError:
                nltk.download("stopwords")
            self._stopwords = stopwords.words("english")
        return self._stopwords
# Module-level GlobalsHelper instance, created once at import time.
globals_helper = GlobalsHelper()
def get_new_id(d: Set) -> str:
    """Generate a UUID4 string that is not already present in ``d``.

    Args:
        d (Set): Collection of IDs already in use.

    Returns:
        str: A freshly generated ID not contained in ``d``.

    """
    candidate = str(uuid.uuid4())
    # Regenerate until the candidate does not clash with an existing ID.
    while candidate in d:
        candidate = str(uuid.uuid4())
    return candidate
def get_new_int_id(d: Set) -> int:
    """Generate a random integer ID that is not already present in ``d``.

    Args:
        d (Set): Collection of integer IDs already in use.

    Returns:
        int: A random integer in ``[0, sys.maxsize]`` not contained in ``d``.

    """
    candidate = random.randint(0, sys.maxsize)
    # Draw again until the candidate does not clash with an existing ID.
    while candidate in d:
        candidate = random.randint(0, sys.maxsize)
    return candidate
@contextmanager
def temp_set_attrs(obj: Any, **kwargs: Any) -> Generator:
    """Temporarily override attributes on an object.

    Context manager that sets each keyword argument as an attribute on
    ``obj`` for the duration of the ``with`` block and puts the previous
    values back on exit, including when the block raises.

    Taken from: https://tinyurl.com/2p89xymh

    Args:
        obj (Any): Object whose attributes are overridden.
        **kwargs (Any): Attribute names mapped to their temporary values.
            Every named attribute must already exist on ``obj``.

    """
    # Snapshot all current values first, before any attribute is modified.
    saved = {}
    for name in kwargs:
        saved[name] = getattr(obj, name)

    for name, temporary in kwargs.items():
        setattr(obj, name, temporary)

    try:
        yield
    finally:
        for name, original in saved.items():
            setattr(obj, name, original)
@dataclass
class ErrorToRetry:
    """Exception types that should be retried.

    Args:
        exception_cls (Type[Exception]): Class of exception.
        check_fn (Optional[Callable[[Any], bool]]):
            A function that takes an exception instance as input and returns
            whether to retry.

    """

    exception_cls: Type[Exception]
    check_fn: Optional[Callable[[Any], bool]] = None


def retry_on_exceptions_with_backoff(
    lambda_fn: Callable,
    errors_to_retry: List[ErrorToRetry],
    max_tries: int = 10,
    min_backoff_secs: float = 0.5,
    max_backoff_secs: float = 60.0,
) -> Any:
    """Execute lambda function with retries and exponential backoff.

    Each caught failure has its traceback printed via
    ``traceback.print_exc()`` before the retry decision is made.

    Args:
        lambda_fn (Callable): Function to be called and output we want.
        errors_to_retry (List[ErrorToRetry]): List of errors to retry.
            At least one needs to be provided.
        max_tries (int): Maximum number of tries, including the first. Defaults to 10.
        min_backoff_secs (float): Minimum amount of backoff time between attempts.
            Defaults to 0.5.
        max_backoff_secs (float): Maximum amount of backoff time between attempts.
            Defaults to 60.

    Returns:
        Any: Whatever ``lambda_fn`` returns on its first successful call.

    Raises:
        ValueError: If ``errors_to_retry`` is empty.
        Exception: The caught exception is re-raised when ``max_tries`` is
            reached or when the matching ``check_fn`` returns a falsy value.
            Exceptions not listed in ``errors_to_retry`` propagate immediately.

    """
    if not errors_to_retry:
        raise ValueError("At least one error to retry needs to be provided")

    error_checks = {
        error_to_retry.exception_cls: error_to_retry.check_fn
        for error_to_retry in errors_to_retry
    }
    exception_class_tuples = tuple(error_checks.keys())

    backoff_secs = min_backoff_secs
    tries = 0

    while True:
        try:
            return lambda_fn()
        except exception_class_tuples as e:
            traceback.print_exc()
            tries += 1
            if tries >= max_tries:
                raise
            # The except clause also catches subclasses of the registered
            # classes, so look up the check by walking the MRO: the exact
            # class wins if registered, otherwise the nearest registered
            # ancestor. An exact-class lookup would silently skip the check
            # for subclass instances.
            check_fn = None
            for exc_cls in type(e).__mro__:
                if exc_cls in error_checks:
                    check_fn = error_checks[exc_cls]
                    break
            if check_fn and not check_fn(e):
                raise
            time.sleep(backoff_secs)
            # Double the delay for the next attempt, capped at the maximum.
            backoff_secs = min(backoff_secs * 2, max_backoff_secs)
def truncate_text(text: str, max_length: int) -> str:
    """Truncate text to a maximum length.

    Text that already fits within ``max_length`` is returned unchanged.
    Longer text is cut and suffixed with ``"..."`` so that the result is
    exactly ``max_length`` characters long.

    Args:
        text (str): Text to truncate.
        max_length (int): Maximum length of the returned string.

    Returns:
        str: ``text`` itself if it fits, otherwise its truncated form.

    """
    if len(text) <= max_length:
        return text
    if max_length < 3:
        # No room for the ellipsis; a negative slice bound would otherwise
        # keep almost the whole text. Fall back to a hard cut.
        return text[: max(max_length, 0)]
    return text[: max_length - 3] + "..."