import asyncio
from typing import Callable, Optional, Union

import huggingface_hub
import semchunk
import tiktoken
import tokenizers
from datasets import Dataset, concatenate_datasets, load_dataset
from rich.progress import track
from transformers import PreTrainedTokenizer

# Accepted kinds of first argument to `semchunk.chunkerify` (see
# `SemanticChunker.__init__`): a string such as "o200k_base", a tiktoken
# encoding, a Hugging Face `PreTrainedTokenizer`, a `tokenizers.Tokenizer`,
# or a plain callable mapping a string to an integer token count.
TOKENIZER_OR_TOKEN_COUNTER = Union[
    str, tiktoken.Encoding, PreTrainedTokenizer, tokenizers.Tokenizer, Callable[[str], int]
]


class SemanticChunker:
    """
    SemanticChunker is a class that chunks documents into smaller segments and
    publishes them as datasets.

    This class uses the `semchunk` library to break down large documents into
    smaller, manageable chunks based on a specified tokenizer or token counter.
    This is particularly useful for processing large text datasets where
    smaller segments are needed for analysis or other operations.

    !!! example "Example Usage"
        ```python
        from medrag_multi_modal.semantic_chunking import SemanticChunker


        chunker = SemanticChunker(chunk_size=256)
        chunker.chunk(
            document_dataset="geekyrakshit/grays-anatomy-test",
            chunk_dataset_repo_id="geekyrakshit/grays-anatomy-chunks-test",
        )
        ```

    Args:
        tokenizer_or_token_counter (TOKENIZER_OR_TOKEN_COUNTER): The tokenizer or
            token counter to be used for chunking. Defaults to "o200k_base".
        chunk_size (Optional[int]): The size of each chunk. If not specified, it is
            left to `semchunk.chunkerify` to determine.
        max_token_chars (Optional[int]): The maximum number of characters per token.
            If not specified, the default value from `semchunk` will be used.
        memoize (bool): Whether to memoize the chunking process for efficiency.
            Default is True.

    Attributes:
        chunker: The chunker returned by `semchunk.chunkerify`; it is applied to
            the text of every document passed to `chunk`.
    """

    def __init__(
        self,
        tokenizer_or_token_counter: TOKENIZER_OR_TOKEN_COUNTER = "o200k_base",
        chunk_size: Optional[int] = None,
        max_token_chars: Optional[int] = None,
        memoize: bool = True,
    ) -> None:
        # Built once and reused for every document processed by `chunk`.
        self.chunker = semchunk.chunkerify(
            tokenizer_or_token_counter,
            chunk_size=chunk_size,
            max_token_chars=max_token_chars,
            memoize=memoize,
        )

    def _chunk_document(self, idx: int, document: dict) -> list:
        """Split one document's text and return one row dict per chunk.

        Each row starts with ``document_idx`` (set to ``idx``) and ``text`` (the
        chunk), followed by every other field of ``document`` in its original
        order. Any ``document_idx``/``text`` field already on the document is
        replaced by these values.

        Raises:
            KeyError: If ``document`` has no ``"text"`` field.
        """
        # Metadata is identical for every chunk of this document, so build it once.
        metadata = {
            key: value
            for key, value in document.items()
            if key not in ("document_idx", "text")
        }
        # Calling the chunker with a single string yields that string's chunks.
        return [
            {"document_idx": idx, "text": piece, **metadata}
            for piece in self.chunker(str(document["text"]))
        ]

    def chunk(
        self,
        document_dataset: Union[Dataset, str],
        chunk_dataset_repo_id: Optional[str] = None,
        overwrite_dataset: bool = False,
    ) -> Dataset:
        """
        Chunks a document dataset into smaller segments and publishes them as a new dataset.

        This function takes a document dataset, either as a HuggingFace Dataset object or a string
        representing the dataset repository ID, and chunks the documents into smaller segments using
        the specified chunker. The resulting chunks are then optionally published to a HuggingFace
        dataset repository.

        Args:
            document_dataset (Union[Dataset, str]): The document dataset to be chunked. It can be either
                a HuggingFace Dataset object or a string representing the dataset repository ID, in
                which case its "corpus" split is loaded. Every document must have a "text" field.
            chunk_dataset_repo_id (Optional[str]): The repository ID of the HuggingFace dataset to publish
                the chunks to (split "chunks"), if provided. Defaults to None.
            overwrite_dataset (bool): Whether to overwrite the existing dataset if it exists. When False
                and the repository exists, the new chunks are prepended to the chunks already stored
                there and the combined dataset is pushed. Defaults to False.

        Returns:
            Dataset: A HuggingFace Dataset object containing the chunks, ordered by source document.
                In append mode (repository exists and ``overwrite_dataset`` is False) it also contains
                the previously published chunks, after the new ones.
        """
        if isinstance(document_dataset, str):
            document_dataset = load_dataset(document_dataset, split="corpus")
        documents = document_dataset.to_list()

        # Chunking is synchronous, CPU-bound work, so it runs in a plain loop.
        # `track` then advances as each document is actually chunked, and the
        # method is safe to call from code already running an event loop
        # (where `asyncio.run` would raise RuntimeError).
        chunks = []
        for idx, document in track(
            enumerate(documents),
            total=len(documents),
            description="Chunking documents",
        ):
            chunks.extend(self._chunk_document(idx, document))

        dataset = Dataset.from_list(chunks)
        if chunk_dataset_repo_id:
            repo_exists = huggingface_hub.repo_exists(
                chunk_dataset_repo_id, repo_type="dataset"
            )
            if repo_exists and not overwrite_dataset:
                # NOTE(review): document_idx restarts at 0 on every call, so if the
                # existing rows were produced by this method, document_idx does not
                # uniquely identify a source document across appended runs.
                dataset = concatenate_datasets(
                    [
                        dataset,
                        load_dataset(chunk_dataset_repo_id, split="chunks"),
                    ]
                )
            dataset.push_to_hub(repo_id=chunk_dataset_repo_id, split="chunks")

        return dataset