from typing import List, cast

import numpy as np
import torch
from colpali_engine.models import ColPali
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from colpali_engine.utils.torch_utils import ListDataset
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm


class ColpaliManager:
    def __init__(self, device: str = "cpu", model_name: str = "vidore/colpali-v1.2"):
        print(f"Initializing ColpaliManager with device {device} and model {model_name}")

        self.device = torch.device(device)

        # Load the ColPali model. float32 is used because half precision is
        # poorly supported on CPU (e.g. macOS); no device map is needed when
        # targeting a single device.
        self.model = ColPali.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            device_map=None,
        ).to(self.device).eval()

        self.processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))

    def get_images(self, paths: List[str]) -> List[Image.Image]:
        return [Image.open(path) for path in paths]

    def process_images(self, image_paths: List[str], batch_size: int = 5) -> List[np.ndarray]:
        """Embed image files, returning one multi-vector embedding per image."""
        print(f"Processing {len(image_paths)} image paths")

        images = self.get_images(image_paths)

        dataloader = DataLoader(
            dataset=ListDataset[Image.Image](images),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=lambda x: self.processor.process_images(x),
        )

        embeddings = []
        for batch_doc in tqdm(dataloader):
            with torch.no_grad():
                batch_doc = {k: v.to(self.device) for k, v in batch_doc.items()}
                embeddings_batch = self.model(**batch_doc)
            # Move each batch to CPU as we go so embeddings do not accumulate
            # on the compute device.
            embeddings.extend(list(torch.unbind(embeddings_batch.to("cpu"))))

        return [embedding.float().numpy() for embedding in embeddings]

    def process_text(self, texts: List[str], batch_size: int = 1) -> List[np.ndarray]:
        """Embed query strings, returning one multi-vector embedding per query."""
        print(f"Processing {len(texts)} texts")

        dataloader = DataLoader(
            dataset=ListDataset[str](texts),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=lambda x: self.processor.process_queries(x),
        )

        embeddings = []
        for batch_query in dataloader:
            with torch.no_grad():
                batch_query = {k: v.to(self.device) for k, v in batch_query.items()}
                embeddings_batch = self.model(**batch_query)
            embeddings.extend(list(torch.unbind(embeddings_batch.to("cpu"))))

        return [embedding.float().numpy() for embedding in embeddings]
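

# --- Usage sketch -----------------------------------------------------------
# A minimal, hypothetical example of wiring the manager into a retrieval flow;
# "page1.png", "page2.png", and the query string are placeholders. The scoring
# function below writes out the standard ColPali late-interaction (MaxSim)
# formula by hand -- for each query token, take the max dot product over all
# document tokens, then sum -- so it does not depend on any particular
# colpali_engine scoring helper.

def maxsim_score(query_emb: np.ndarray, doc_emb: np.ndarray) -> float:
    # query_emb: (n_query_tokens, dim); doc_emb: (n_doc_tokens, dim)
    sim = query_emb @ doc_emb.T          # token-level dot products
    return float(sim.max(axis=1).sum())  # best doc token per query token, summed


if __name__ == "__main__":
    manager = ColpaliManager(device="cpu")

    doc_embeddings = manager.process_images(["page1.png", "page2.png"])
    query_embeddings = manager.process_text(["What was revenue in 2023?"])

    # Rank the document pages for the first query by late-interaction score.
    scores = [maxsim_score(query_embeddings[0], d) for d in doc_embeddings]
    best = int(np.argmax(scores))
    print(f"Best matching page: index {best}, score {scores[best]:.2f}")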