# Multimodal_RAG/colpali_manager.py
from typing import List, cast

import numpy as np
import torch
from colpali_engine.models import ColPali, ColPaliProcessor
from colpali_engine.utils.torch_utils import ListDataset
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
# Defaults: run on CPU (e.g. on macOS without CUDA). The model is loaded
# lazily inside ColpaliManager so that importing this module stays cheap.
DEFAULT_DEVICE = "cpu"
DEFAULT_MODEL_NAME = "vidore/colpali-v1.2"
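# Optional helper, a sketch and not part of the original module: pick the best
# available torch backend instead of hard-coding "cpu". Uses only the standard
# torch.cuda / torch.backends.mps availability checks.
def pick_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        return "mps"  # Apple Silicon GPU
    return "cpu"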
class ColpaliManager:
    """Wraps a ColPali checkpoint and turns images and text queries into
    multi-vector embeddings for late-interaction retrieval."""

    def __init__(self, device: str = DEFAULT_DEVICE, model_name: str = DEFAULT_MODEL_NAME):
        print(f"Initializing ColpaliManager with device {device} and model {model_name}")
        self.device = torch.device(device)
        self.model = ColPali.from_pretrained(
            model_name,
            torch_dtype=torch.float32,  # float32: half precision is poorly supported on CPU
            device_map=None,  # no sharding needed on a single device
        ).to(self.device).eval()
        self.processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
    def get_images(self, paths: List[str]) -> List[Image.Image]:
        # Convert to RGB so grayscale or RGBA files do not trip up the processor.
        return [Image.open(path).convert("RGB") for path in paths]
    def process_images(self, image_paths: List[str], batch_size: int = 5) -> List[np.ndarray]:
        """Embed image files; returns one (n_tokens, dim) array per image."""
        print(f"Processing {len(image_paths)} image_paths")
        images = self.get_images(image_paths)
        dataloader = DataLoader(
            dataset=ListDataset[Image.Image](images),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=lambda x: self.processor.process_images(x),
        )
        embeddings = []
        for batch_doc in tqdm(dataloader):
            with torch.no_grad():
                batch_doc = {k: v.to(self.device) for k, v in batch_doc.items()}
                embeddings_batch = self.model(**batch_doc)
            # Split the batch into one embedding tensor per image.
            embeddings.extend(list(torch.unbind(embeddings_batch)))
        return [embedding.float().cpu().numpy() for embedding in embeddings]
    def process_text(self, texts: List[str], batch_size: int = 1) -> List[np.ndarray]:
        """Embed text queries; returns one (n_tokens, dim) array per query."""
        print(f"Processing {len(texts)} texts")
        dataloader = DataLoader(
            dataset=ListDataset[str](texts),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=lambda x: self.processor.process_queries(x),
        )
        embeddings = []
        for batch_query in dataloader:
            with torch.no_grad():
                batch_query = {k: v.to(self.device) for k, v in batch_query.items()}
                embeddings_batch = self.model(**batch_query)
            # Split the batch into one embedding tensor per query.
            embeddings.extend(list(torch.unbind(embeddings_batch)))
        return [embedding.float().cpu().numpy() for embedding in embeddings]
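
# Minimal usage sketch. Assumptions: "page1.png"/"page2.png" are placeholder
# paths, and scoring uses processor.score_multi_vector, the late-interaction
# (MaxSim) scorer shipped with colpali_engine: each query token vector is
# matched against every page token vector and the per-token maxima are summed.
if __name__ == "__main__":
    manager = ColpaliManager()
    page_embeddings = manager.process_images(["page1.png", "page2.png"])  # placeholder paths
    query_embeddings = manager.process_text(["What does the chart on page 1 show?"])
    # score_multi_vector takes lists of tensors and returns an
    # (n_queries, n_pages) score matrix; higher means a better match.
    scores = manager.processor.score_multi_vector(
        [torch.from_numpy(q) for q in query_embeddings],
        [torch.from_numpy(p) for p in page_embeddings],
    )
    print(scores)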