# chroma/chromadb/utils/data_loaders.py
# Author: badalsahani — commit 287a0bc ("feat: chroma initial deploy"), 1.01 kB
import importlib
import multiprocessing
from typing import Optional, Sequence, List
import numpy as np
from chromadb.api.types import URI, DataLoader, Image
from concurrent.futures import ThreadPoolExecutor
class ImageLoader(DataLoader[List[Optional[Image]]]):
    """Load images referenced by URIs into numpy arrays using a thread pool.

    Pillow is imported lazily when the loader is constructed, so it is only
    required by callers that actually instantiate this class.
    """

    def __init__(self, max_workers: int = multiprocessing.cpu_count()) -> None:
        """Create the loader.

        Args:
            max_workers: Number of worker threads used by ``__call__``. Note
                that the default expression is evaluated once, when this
                ``def`` statement runs (i.e. at module import time).

        Raises:
            ValueError: If the ``PIL`` (Pillow) package cannot be imported.
        """
        # Keep the try body minimal: only the import can raise ImportError.
        try:
            self._PILImage = importlib.import_module("PIL.Image")
        except ImportError as err:
            raise ValueError(
                "The PIL python package is not installed. Please install it with `pip install pillow`"
            ) from err
        self._max_workers = max_workers

    def _load_image(self, uri: Optional[URI]) -> Optional[Image]:
        """Open the image at ``uri`` and return its pixels as a numpy array.

        Returns ``None`` when ``uri`` is ``None``. Errors raised by Pillow
        (e.g. a missing or unreadable file) propagate to the caller.
        """
        if uri is None:
            return None
        # Use the image as a context manager so the underlying file handle is
        # closed once np.array has copied the pixel data; a bare
        # ``Image.open`` keeps the file open until garbage collection.
        with self._PILImage.open(uri) as img:
            return np.array(img)

    def __call__(self, uris: Sequence[Optional[URI]]) -> List[Optional[Image]]:
        """Load every URI in ``uris`` concurrently.

        Returns:
            A list with the same length and order as ``uris``. ``None``
            entries yield ``None``. If loading any image raises, that
            exception is re-raised here when its result is collected.
        """
        with ThreadPoolExecutor(max_workers=self._max_workers) as executor:
            return list(executor.map(self._load_image, uris))