Spaces:
Paused
Paused
""" | |
Based upon the ImageCaptionLoader in LangChain (langchain/document_loaders/image_captions.py),
but works with a preloaded model to avoid slowness in use and CUDA forking issues.
Loader that uses Pix2Struct models to caption images.
""" | |
from typing import List, Union, Any, Tuple | |
from langchain.docstore.document import Document | |
from langchain.document_loaders import ImageCaptionLoader | |
from utils import get_device, clear_torch_cache | |
from PIL import Image | |
class H2OPix2StructLoader(ImageCaptionLoader):
    """Loader that extracts text from images using a Pix2Struct model.

    The processor and model are created on the first call to ``load_model()``
    (or implicitly by ``load()``) and are kept on the instance afterwards, so
    the same loader can be reused for many images without reloading weights.
    """

    def __init__(self, path_images: Union[str, List[str]] = None, model_type="google/pix2struct-textcaps-base",
                 max_new_tokens=50):
        """Create the loader without loading any model weights.

        Args:
            path_images: A single image path or a list of image paths; handed
                to the parent ``ImageCaptionLoader`` constructor.
            model_type: Name passed to ``from_pretrained`` for both the
                processor and the model.
            max_new_tokens: Forwarded to ``generate`` as ``max_new_tokens``.
        """
        super().__init__(path_images)
        self._pix2struct_model = None
        # Initialised up front so the attribute always exists, even before load_model().
        self._pix2struct_processor = None
        self._model_type = model_type
        self._max_new_tokens = max_new_tokens
        # Resolved by set_context(); None means "not decided yet".
        self.device = None

    def set_context(self):
        """Decide whether to run on 'cuda' or 'cpu' and store it in ``self.device``.

        'cuda' is chosen only when ``get_device()`` reports 'cuda' AND torch
        sees at least one usable GPU; otherwise 'cpu' is used.
        """
        self.device = 'cpu'
        if get_device() == 'cuda':
            import torch
            # is_available must be *called*; the bare function object is always truthy.
            n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
            if n_gpus > 0:
                self.context_class = torch.device
                self.device = 'cuda'

    def load_model(self):
        """Load the processor and model, or move an existing model back to the device.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            ValueError: If the ``transformers`` package cannot be imported.
        """
        try:
            from transformers import AutoProcessor, Pix2StructForConditionalGeneration
        except ImportError as err:
            raise ValueError(
                "`transformers` package not found, please install with "
                "`pip install transformers`."
            ) from err

        if self._pix2struct_model is not None:
            # A model is already present (e.g. after unload_model() moved it to CPU).
            # Keep a previously chosen device; only resolve one if none was set yet.
            if self.device is None:
                self.set_context()
            if self._pix2struct_processor is None:
                self._pix2struct_processor = AutoProcessor.from_pretrained(self._model_type)
            self._pix2struct_model = self._pix2struct_model.to(self.device)
            return self

        self.set_context()
        self._pix2struct_processor = AutoProcessor.from_pretrained(self._model_type)
        self._pix2struct_model = Pix2StructForConditionalGeneration.from_pretrained(self._model_type).to(self.device)
        return self

    def unload_model(self):
        """Move the model to CPU (if it supports ``.cpu()``) and call ``clear_torch_cache()``.

        The model object is kept, so a later ``load_model()`` only has to move
        it back to the device instead of reloading the weights.
        """
        if hasattr(self._pix2struct_model, 'cpu'):
            self._pix2struct_model.cpu()
            # NOTE(review): clear_torch_cache comes from utils; presumably it frees cached GPU memory.
            clear_torch_cache()

    def set_image_paths(self, path_images: Union[str, List[str]]):
        """
        Load from a list of image files

        A single string is wrapped into a one-element list; anything else is
        stored as given.
        """
        if isinstance(path_images, str):
            self.image_paths = [path_images]
        else:
            self.image_paths = path_images

    def load(self, prompt=None) -> List[Document]:
        """Caption every image in ``self.image_paths``.

        The model is loaded on demand if it has not been loaded yet.

        Args:
            prompt: Accepted for interface compatibility; currently unused.

        Returns:
            One ``Document`` per image, with the generated text as
            ``page_content`` and ``{"image_path": ...}`` as metadata.
        """
        if self._pix2struct_model is None:
            self.load_model()
        results = []
        for path_image in self.image_paths:
            caption, metadata = self._get_captions_and_metadata(
                processor=self._pix2struct_processor, model=self._pix2struct_model, path_image=path_image
            )
            results.append(Document(page_content=caption, metadata=metadata))
        return results

    def _get_captions_and_metadata(
            self, processor: Any, model: Any, path_image: str) -> Tuple[str, dict]:
        """
        Helper function for getting the captions and metadata of an image

        Args:
            processor: Processor used to encode the image and decode the
                output; falls back to the loader's own processor when None.
            model: Model used for generation; falls back to the loader's own
                model when None.
            path_image: Path of the image file to caption.

        Returns:
            ``(generated_text, metadata)`` where metadata is
            ``{"image_path": path_image}``.

        Raises:
            ValueError: If the image cannot be opened.
        """
        # Honour the arguments instead of silently ignoring them, but keep working
        # for callers that pass None and rely on the instance attributes.
        if processor is None:
            processor = self._pix2struct_processor
        if model is None:
            model = self._pix2struct_model

        try:
            image = Image.open(path_image)
        except Exception as err:
            raise ValueError(f"Could not get image data for {path_image}") from err

        # Close the underlying file as soon as the pixels have been turned into tensors.
        with image:
            inputs = processor(images=image, return_tensors="pt")
        inputs = inputs.to(self.device)
        generated_ids = model.generate(**inputs, max_new_tokens=self._max_new_tokens)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        metadata: dict = {"image_path": path_image}
        return generated_text, metadata