import importlib import multiprocessing from typing import Optional, Sequence, List import numpy as np from chromadb.api.types import URI, DataLoader, Image from concurrent.futures import ThreadPoolExecutor class ImageLoader(DataLoader[List[Optional[Image]]]): def __init__(self, max_workers: int = multiprocessing.cpu_count()) -> None: try: self._PILImage = importlib.import_module("PIL.Image") self._max_workers = max_workers except ImportError: raise ValueError( "The PIL python package is not installed. Please install it with `pip install pillow`" ) def _load_image(self, uri: Optional[URI]) -> Optional[Image]: return np.array(self._PILImage.open(uri)) if uri is not None else None def __call__(self, uris: Sequence[Optional[URI]]) -> List[Optional[Image]]: with ThreadPoolExecutor(max_workers=self._max_workers) as executor: return list(executor.map(self._load_image, uris))