from colpali_engine.models import ColPali
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
import torch
import numpy as np
from typing import List, cast
from tqdm import tqdm
from PIL import Image
# Run inference on CPU (e.g. on macOS without CUDA)
device = torch.device("cpu")

model_name = "vidore/colpali-v1.2"

# Module-level model and processor (mirrors the setup done in ColpaliManager below)
model = ColPali.from_pretrained(
    model_name,
    torch_dtype=torch.float32,  # float32 for CPU inference
    device_map=None,            # no device map needed on CPU
).to(device).eval()

processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
class ColpaliManager:
    """Wraps a ColPali model and processor for embedding document images and text queries."""

    def __init__(self, device: str = "cpu", model_name: str = "vidore/colpali-v1.2"):
        print(f"Initializing ColpaliManager with device {device} and model {model_name}")

        self.device = torch.device(device)

        self.model = ColPali.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            device_map=None,
        ).to(self.device).eval()

        self.processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))

    def get_images(self, paths: List[str]) -> List[Image.Image]:
        """Open each image path with PIL."""
        return [Image.open(path) for path in paths]
    def process_images(self, image_paths: List[str], batch_size: int = 5) -> List[np.ndarray]:
        """Embed document images; returns one multi-vector embedding (numpy array) per image."""
        print(f"Processing {len(image_paths)} image_paths")

        images = self.get_images(image_paths)

        dataloader = DataLoader(
            dataset=ListDataset[Image.Image](images),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=lambda x: self.processor.process_images(x),
        )

        embeddings = []
        for batch_doc in tqdm(dataloader):
            with torch.no_grad():
                batch_doc = {k: v.to(self.device) for k, v in batch_doc.items()}
                embeddings_batch = self.model(**batch_doc)
            # Move results to CPU as they accumulate
            embeddings.extend(list(torch.unbind(embeddings_batch.to("cpu"))))

        return [embedding.float().cpu().numpy() for embedding in embeddings]
    def process_text(self, texts: List[str], batch_size: int = 1) -> List[np.ndarray]:
        """Embed text queries; returns one multi-vector embedding (numpy array) per query."""
        print(f"Processing {len(texts)} texts")

        dataloader = DataLoader(
            dataset=ListDataset[str](texts),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=lambda x: self.processor.process_queries(x),
        )

        embeddings = []
        for batch_query in dataloader:
            with torch.no_grad():
                batch_query = {k: v.to(self.device) for k, v in batch_query.items()}
                embeddings_batch = self.model(**batch_query)
            # Move results to CPU as they accumulate
            embeddings.extend(list(torch.unbind(embeddings_batch.to("cpu"))))

        return [embedding.float().cpu().numpy() for embedding in embeddings]
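

# --- Hypothetical usage sketch (not part of the original module) ---
# Assumes two local sample files, "page_1.png" and "page_2.png". It shows how the
# embeddings returned above could be compared against a query with a plain
# late-interaction (MaxSim) sum; this scoring loop is written out here for
# illustration and is not an API provided by colpali_engine.
if __name__ == "__main__":
    manager = ColpaliManager(device="cpu")

    doc_embeddings = manager.process_images(["page_1.png", "page_2.png"])
    query_embeddings = manager.process_text(["What is the total revenue?"])

    query = torch.from_numpy(query_embeddings[0])  # shape: (n_query_tokens, dim)
    scores = []
    for doc in doc_embeddings:
        doc = torch.from_numpy(doc)  # shape: (n_doc_tokens, dim)
        sim = query @ doc.T  # dot-product similarity for every query/document token pair
        # MaxSim: take the best-matching document token per query token, then sum over query tokens
        scores.append(sim.max(dim=1).values.sum().item())

    print("Late-interaction scores per page:", scores)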