"""General utils functions.""" import random import sys import time import traceback import uuid from contextlib import contextmanager from dataclasses import dataclass from itertools import islice from typing import ( Any, Callable, Generator, List, Optional, Set, Type, cast, Union, Iterable, ) class GlobalsHelper: """Helper to retrieve globals. Helpful for global caching of certain variables that can be expensive to load. (e.g. tokenization) """ _tokenizer: Optional[Callable[[str], List]] = None _stopwords: Optional[List[str]] = None @property def tokenizer(self) -> Callable[[str], List]: """Get tokenizer.""" if self._tokenizer is None: # if python version >= 3.9, then use tiktoken # else use GPT2TokenizerFast if sys.version_info >= (3, 9): tiktoken_import_err = ( "`tiktoken` package not found, please run `pip install tiktoken`" ) try: import tiktoken except ImportError: raise ImportError(tiktoken_import_err) enc = tiktoken.get_encoding("gpt2") self._tokenizer = cast(Callable[[str], List], enc.encode) else: try: import transformers except ImportError: raise ImportError( "`transformers` package not found, " "please run `pip install transformers`" ) tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2") def tokenizer_fn(text: str) -> List: return tokenizer(text)["input_ids"] self._tokenizer = tokenizer_fn return self._tokenizer @property def stopwords(self) -> List[str]: """Get stopwords.""" if self._stopwords is None: try: import nltk from nltk.corpus import stopwords except ImportError: raise ImportError( "`nltk` package not found, please run `pip install nltk`" ) try: nltk.data.find("corpora/stopwords") except LookupError: nltk.download("stopwords") self._stopwords = stopwords.words("english") return self._stopwords globals_helper = GlobalsHelper() def get_new_id(d: Set) -> str: """Get a new ID.""" while True: new_id = str(uuid.uuid4()) if new_id not in d: break return new_id def get_new_int_id(d: Set) -> int: """Get a new integer ID.""" while True: new_id = random.randint(0, sys.maxsize) if new_id not in d: break return new_id @contextmanager def temp_set_attrs(obj: Any, **kwargs: Any) -> Generator: """Temporary setter. Utility class for setting a temporary value for an attribute on a class. Taken from: https://tinyurl.com/2p89xymh """ prev_values = {k: getattr(obj, k) for k in kwargs} for k, v in kwargs.items(): setattr(obj, k, v) try: yield finally: for k, v in prev_values.items(): setattr(obj, k, v) @dataclass class ErrorToRetry: """Exception types that should be retried. Args: exception_cls (Type[Exception]): Class of exception. check_fn (Optional[Callable[[Any]], bool]]): A function that takes an exception instance as input and returns whether to retry. """ exception_cls: Type[Exception] check_fn: Optional[Callable[[Any], bool]] = None def retry_on_exceptions_with_backoff( lambda_fn: Callable, errors_to_retry: List[ErrorToRetry], max_tries: int = 10, min_backoff_secs: float = 0.5, max_backoff_secs: float = 60.0, ) -> Any: """Execute lambda function with retries and exponential backoff. Args: lambda_fn (Callable): Function to be called and output we want. errors_to_retry (List[ErrorToRetry]): List of errors to retry. At least one needs to be provided. max_tries (int): Maximum number of tries, including the first. Defaults to 10. min_backoff_secs (float): Minimum amount of backoff time between attempts. Defaults to 0.5. max_backoff_secs (float): Maximum amount of backoff time between attempts. Defaults to 60. """ if not errors_to_retry: raise ValueError("At least one error to retry needs to be provided") error_checks = { error_to_retry.exception_cls: error_to_retry.check_fn for error_to_retry in errors_to_retry } exception_class_tuples = tuple(error_checks.keys()) backoff_secs = min_backoff_secs tries = 0 while True: try: return lambda_fn() except exception_class_tuples as e: traceback.print_exc() tries += 1 if tries >= max_tries: raise check_fn = error_checks.get(e.__class__) if check_fn and not check_fn(e): raise time.sleep(backoff_secs) backoff_secs = min(backoff_secs * 2, max_backoff_secs) def truncate_text(text: str, max_length: int) -> str: """Truncate text to a maximum length.""" return text[: max_length - 3] + "..." def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable: """Iterate over an iterable in batches. >>> list(iter_batch([1,2,3,4,5], 3)) [[1, 2, 3], [4, 5]] """ source_iter = iter(iterable) while source_iter: b = list(islice(source_iter, size)) if len(b) == 0: break yield b