from typing import List, cast

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from PIL import Image

from colpali_engine.models import ColPali
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from colpali_engine.utils.torch_utils import ListDataset

# Default to CPU so the script runs anywhere (e.g. macOS without CUDA)
device = torch.device("cpu")
model_name = "vidore/colpali-v1.2"

# Module-level model and processor for standalone use; ColpaliManager below
# loads its own independent copies.
model = ColPali.from_pretrained(
    model_name,
    torch_dtype=torch.float32,  # Use float32 for CPU
    device_map=None,  # No device map needed for CPU
).to(device).eval()

processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))


class ColpaliManager:
    """Loads a ColPali model and processor and exposes helpers for embedding
    document-page images and text queries as multi-vector embeddings."""

    def __init__(self, device: str = "cpu", model_name: str = "vidore/colpali-v1.2"):
        print(f"Initializing ColpaliManager with device {device} and model {model_name}")

        self.device = torch.device(device)
        self.model = ColPali.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            device_map=None,
        ).to(self.device).eval()

        self.processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))

    def get_images(self, paths: List[str]) -> List[Image.Image]:
        # Convert to RGB so grayscale or RGBA files don't break the processor.
        return [Image.open(path).convert("RGB") for path in paths]

    def process_images(self, image_paths: List[str], batch_size: int = 5) -> List[np.ndarray]:
        print(f"Processing {len(image_paths)} images")
        images = self.get_images(image_paths)

        dataloader = DataLoader(
            dataset=ListDataset[Image.Image](images),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=lambda x: self.processor.process_images(x),
        )

        embeddings: List[torch.Tensor] = []
        for batch_doc in tqdm(dataloader):
            with torch.no_grad():
                batch_doc = {k: v.to(self.device) for k, v in batch_doc.items()}
                embeddings_batch = self.model(**batch_doc)
            # Unbind on CPU so each per-image tensor is detached from device memory.
            embeddings.extend(list(torch.unbind(embeddings_batch.cpu())))

        # Each entry is a (num_patches, dim) multi-vector page embedding.
        return [embedding.float().numpy() for embedding in embeddings]

    def process_text(self, texts: List[str], batch_size: int = 1) -> List[np.ndarray]:
        print(f"Processing {len(texts)} texts")

        dataloader = DataLoader(
            dataset=ListDataset[str](texts),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=lambda x: self.processor.process_queries(x),
        )

        embeddings: List[torch.Tensor] = []
        for batch_query in tqdm(dataloader):
            with torch.no_grad():
                batch_query = {k: v.to(self.device) for k, v in batch_query.items()}
                embeddings_batch = self.model(**batch_query)
            embeddings.extend(list(torch.unbind(embeddings_batch.cpu())))

        # Each entry is a (num_query_tokens, dim) multi-vector query embedding.
        return [embedding.float().numpy() for embedding in embeddings]
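

# Minimal usage sketch. Assumptions: "page_1.png"/"page_2.png" are placeholder
# paths, and score_multi_vector is the late-interaction (MaxSim) scorer exposed
# by colpali_engine's processors; adapt both to your setup.
if __name__ == "__main__":
    manager = ColpaliManager()

    # Embed document-page images and a text query.
    image_embeddings = manager.process_images(["page_1.png", "page_2.png"])
    query_embeddings = manager.process_text(["What was the total revenue in 2023?"])

    # Score every query against every page; result has shape (num_queries, num_pages).
    scores = manager.processor.score_multi_vector(
        [torch.from_numpy(q) for q in query_embeddings],
        [torch.from_numpy(d) for d in image_embeddings],
    )
    print(scores)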