import os
from collections.abc import Iterator, Sequence

import numpy as np
import pypdfium2 as pdfium
import torch
import tqdm
from PIL import Image
from sqlitedict import SqliteDict
from voyager import Index, Space

from model import encode_images, encode_queries


def iter_batch(
    X: Sequence, batch_size: int, tqdm_bar: bool = True, desc: str = ""
) -> Iterator[Sequence]:
    """Yield consecutive slices of ``X`` holding at most ``batch_size`` elements.

    Parameters
    ----------
    X
        The elements to batch (any sliceable sequence, e.g. a list of images).
    batch_size
        Maximum number of elements per batch; must be at least 1.
    tqdm_bar
        Whether to display a tqdm progress bar over the batches.
    desc
        Description shown next to the progress bar.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1 (raised when iteration starts,
        since this is a generator).
    """
    if batch_size < 1:
        # A negative step would make range() empty and silently drop every element.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    batchs = [X[pos : pos + batch_size] for pos in range(0, len(X), batch_size)]

    if tqdm_bar:
        # Use the exact batch count: `1 + len(X) // batch_size` over-counts
        # whenever len(X) is a multiple of batch_size.
        yield from tqdm.tqdm(
            iterable=batchs,
            position=0,
            total=len(batchs),
            desc=desc,
        )
    else:
        yield from batchs


class Voyager:
    """Voyager index over document pages for approximate nearest-neighbour search.

    Pages (each page of a PDF rendered to an image, or a standalone image file)
    are encoded with ``encode_images`` and stored in a cosine-space Voyager
    index saved on disk. A companion SQLite file maps every Voyager item id to
    a record ``{"path", "image", "page_number"}`` describing that page.

    Parameters
    ----------
    index_folder
        Folder holding the ``.voyager`` index and its SQLite mapping; created
        if missing.
    index_name
        Base name of the collection, used to derive both file names.
    override
        Whether to discard an existing collection with the same name.
    embedding_size
        The number of dimensions of the embeddings.
    M
        Number of connections per node in the index graph (forwarded to
        ``voyager.Index``).
    ef_construction
        The number of candidates to evaluate during the construction of the index.
    ef_search
        The number of candidates to evaluate during the search.
    """

    def __init__(
        self,
        index_folder: str = "indexes",
        index_name: str = "base_collection",
        override: bool = False,
        embedding_size: int = 128,
        M: int = 64,
        ef_construction: int = 200,
        ef_search: int = 200,
    ) -> None:
        self.ef_search = ef_search

        # exist_ok avoids the race between an existence check and creation.
        os.makedirs(index_folder, exist_ok=True)

        self.index_path = os.path.join(index_folder, f"{index_name}.voyager")
        self.page_ids_to_data_path = os.path.join(
            index_folder, f"{index_name}_page_ids_to_data.sqlite"
        )

        self.index = self._create_collection(
            index_path=self.index_path,
            embedding_size=embedding_size,
            M=M,
            ef_constructions=ef_construction,
            override=override,
        )

    def _load_page_ids_to_data(self) -> SqliteDict:
        """Open the SQLite mapping from Voyager item ids to page records.

        The caller is responsible for closing the returned ``SqliteDict``.
        """
        return SqliteDict(self.page_ids_to_data_path, outer_stack=False)

    def _create_collection(
        self,
        index_path: str,
        embedding_size: int,
        M: int,
        ef_constructions: int,
        override: bool,
    ) -> Index:
        """Load the Voyager index at ``index_path``, or create a fresh one.

        Parameters
        ----------
        index_path
            The path to the index.
        embedding_size
            The size of the embeddings.
        M
            Number of connections per node in the index graph.
        ef_constructions
            The number of candidates to evaluate during the construction of the index.
        override
            Whether to discard the collection if it already exists.

        Returns
        -------
        Index
            The loaded or newly created (and saved) index.
        """
        if os.path.exists(index_path) and not override:
            return Index.load(index_path)

        # A fresh index is about to be created, so any existing id -> page
        # mapping refers to ids of a discarded (or missing) index and is stale.
        for stale_path in (index_path, self.page_ids_to_data_path):
            if os.path.exists(stale_path):
                os.remove(stale_path)

        index = Index(
            Space.Cosine,
            num_dimensions=embedding_size,
            M=M,
            ef_construction=ef_constructions,
        )
        index.save(index_path)

        # Opening and closing the SqliteDict creates the (empty) mapping file.
        self._load_page_ids_to_data().close()
        return index

    @staticmethod
    def _load_pages(paths: list[str]) -> tuple[list, list[tuple[str, int]]]:
        """Load every page of ``paths`` as a PIL image.

        PDFs are rendered page by page; any other path is opened as a single
        image with page number 0.

        Returns
        -------
        tuple
            The images, and an aligned list of ``(path, page_number)`` pairs.
        """
        images = []
        sources = []
        for path in paths:
            if path.lower().endswith(".pdf"):
                pdf = pdfium.PdfDocument(path)
                try:
                    for page_number in range(len(pdf)):
                        page = pdf.get_page(page_number)
                        images.append(page.render(scale=1, rotation=0).to_pil())
                        sources.append((path, page_number))
                finally:
                    # Release the document even if rendering a page fails.
                    pdf.close()
            else:
                images.append(Image.open(path))
                sources.append((path, 0))
        return images, sources

    def add_documents(
        self,
        paths: str | list[str],
        batch_size: int = 1,
    ) -> "Voyager":
        """Add documents to the index and persist both the index and the mapping.

        Note that ``batch_size`` is the number of pages to encode at once, not
        the number of documents.

        Parameters
        ----------
        paths
            A path, or list of paths, to PDF files or image files.
        batch_size
            Number of pages passed to ``encode_images`` per call.

        Returns
        -------
        Voyager
            ``self``, to allow chaining.
        """
        if isinstance(paths, str):
            paths = [paths]

        images, sources = self._load_pages(paths)
        if not images:
            # Nothing to encode; skip calling Voyager with an empty batch.
            return self

        embeddings = []
        for batch in iter_batch(
            X=images, batch_size=batch_size, desc=f"Encoding pages (bs={batch_size})"
        ):
            embeddings.extend(encode_images(batch))

        embeddings_ids = self.index.add_items(embeddings)

        page_ids_to_data = self._load_page_ids_to_data()
        try:
            # strict=True surfaces a mismatch between encoded and loaded pages
            # instead of silently mis-mapping ids.
            for embedding_id, image, (path, page_number) in zip(
                embeddings_ids, images, sources, strict=True
            ):
                page_ids_to_data[embedding_id] = {
                    "path": path,
                    "image": image,
                    "page_number": page_number,
                }
            page_ids_to_data.commit()
        finally:
            page_ids_to_data.close()

        self.index.save(self.index_path)

        return self

    def __call__(
        self,
        queries: np.ndarray | torch.Tensor,
        k: int = 10,
    ) -> dict:
        """Query the index for the nearest pages of ``queries``.

        Parameters
        ----------
        queries
            The queries; they are passed through ``encode_queries`` before
            searching.
        k
            The number of nearest neighbors to return per query; clamped to
            the number of stored pages.

        Returns
        -------
        dict
            ``"documents"``: for each query, the stored page records in the
            order returned by Voyager. ``"distances"``: the Voyager distances
            reshaped to ``(n_queries, -1, k)``.

        Raises
        ------
        ValueError
            If no documents have been added yet.
        """
        queries_embeddings = encode_queries(queries)

        page_ids_to_data = self._load_page_ids_to_data()
        try:
            n_documents = len(page_ids_to_data)
            # Check before querying: `indices` has one row per query, so its
            # length cannot reveal an empty index.
            if n_documents == 0:
                raise ValueError("Index is empty, add documents before querying.")
            k = min(k, n_documents)

            n_queries = len(queries_embeddings)
            indices, distances = self.index.query(
                queries_embeddings, k, query_ef=self.ef_search
            )

            documents = [
                [page_ids_to_data[str(indice)] for indice in query_indices]
                for query_indices in indices
            ]
        finally:
            # Close the connection on every path, including errors.
            page_ids_to_data.close()

        return {
            "documents": documents,
            "distances": distances.reshape(n_queries, -1, k),
        }