import os
import re
import multiprocessing
from pathlib import Path
from typing import Dict, List, Optional

from datasets import load_dataset, Dataset
from transformers import AutoTokenizer


os.environ["TOKENIZERS_PARALLELISM"] = "false"


DATASET_NAME_PATTERN = re.compile(r"[^a-zA-Z0-9]")


def download_dataset(
    ds_name: str,
    ds_config: Optional[str] = None,
    ds_split: str = "train",
):
    """
    Download a dataset from the HuggingFace Hub and save it to disk as raw
    parquet chunks via `chunk_and_save_dataset`.

    Args:
        ds_name (`str`):
            The name of the dataset to load.
        ds_config (`str`, *optional*, Defaults to `None`):
            The configuration of the dataset to load.
        ds_split (`str`, *optional*, Defaults to `"train"`):
            The split of the dataset to load.

    Returns:
        len(ds) (`int`):
            The number of rows in the dataset.
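
    Example (illustrative; assumes the dataset id exists on the Hub and that
    `/data` is writable):

        n_rows = download_dataset("imdb", ds_split="train")
        print(f"saved {n_rows} raw rows under /data/imdb")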
    """
    if ds_name == "wikipedia":
        ds = load_wikipedia(ds_name, ds_config)
    else:
        if ds_config == "":
            ds_config = None
        ds = load_dataset(ds_name, ds_config, split=ds_split)

    chunk_and_save_dataset(
        ds, ds_name=ds_name, ds_config=ds_config, suffix=f"_{ds_split}_raw"
    )

    return len(ds)


def load_wikipedia(ds_name: str, ds_config: Optional[str] = None):
    """
    Stream the wikipedia dataset from the HuggingFace Hub and materialize it
    as a regular `Dataset` containing only the `text` column.

    Args:
        ds_name (`str`):
            The name of the dataset to load. Must be `"wikipedia"`.
        ds_config (`str`, *optional*, Defaults to `None`):
            The configuration of the dataset to load.

    Returns:
        ds (`datasets.Dataset`):
            The dataset materialized from the stream, with only a `text` column.
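
    Example (illustrative; `"20220301.en"` is one of the published English
    wikipedia snapshot configs):

        ds = load_wikipedia("wikipedia", "20220301.en")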
    """
    ds = load_dataset(ds_name, ds_config, streaming=True, split="train")

    def gen():
        for example in ds:
            yield {"text": example["text"]}

    return Dataset.from_generator(gen)


def chunk_and_save_dataset(
    ds: Dataset,
    chunk_size: int = 20_000,
    ds_name: Optional[str] = None,
    ds_config: Optional[str] = None,
    suffix: str = "",
):
    """
    Chunk a dataset into smaller datasets of size `chunk_size`.
    The name of the dataset will be used to create a folder in `/data`.

    Args:
        ds (`Dataset`):
            The dataset to chunk.
        chunk_size (`int`, *optional*, Defaults to `20_000`):
            The number of rows per chunk.
        ds_name (`str`, *optional*, Defaults to `None`):
            The name of the dataset; combined with `ds_config` to build the
            output folder name under `/data`.
        ds_config (`str`, *optional*, Defaults to `None`):
            The configuration of the dataset.
        suffix (`str`, *optional*, Defaults to `""`):
            The suffix to add to the dataset name.


    Returns:
        `None`:
            The chunks are written to disk as parquet files.
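
    Example (illustrative; writes /data/imdb/chunk_0_train_raw,
    /data/imdb/chunk_1_train_raw, ...):

        chunk_and_save_dataset(ds, ds_name="imdb", suffix="_train_raw")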
    """

    if ds_config is None:
        ds_config = ""

    folder = Path("/data") / DATASET_NAME_PATTERN.sub("", ds_name + ds_config)
    folder.mkdir(exist_ok=True, parents=True)

    # Each chunk is written as a parquet file named `chunk_{n}{suffix}` (no
    # extension); the globs in `tokenize_dataset` and `load_tokenized_dataset`
    # depend on this naming scheme.
    for chunk_num, start_idx in enumerate(range(0, len(ds), chunk_size)):
        end_idx = min(start_idx + chunk_size, len(ds))

        temp = ds.select(range(start_idx, end_idx))

        temp.to_parquet(str(folder / f"chunk_{chunk_num}{suffix}"))


def tokenize_dataset(
    ds_name: str,
    ds_config: Optional[str] = None,
    ds_split: str = "train",
    model_name: Optional[str] = None,
    opt_level: Optional[str] = None,
    column_name: str = "text",
    num2skip: int = 0,
    num2embed: int = -1,
):
    """
    Tokenize the dataset with the model's tokenizer, sort rows by token length
    to minimize padding (unless padding to `max_length`), and save the result
    to disk as parquet chunks.

    Args:
        ds_name (`str`):
            The name of the dataset to load.

        ds_config (`str`, *optional*, Defaults to `None`):
            The configuration of the dataset to load.

        ds_split (`str`, *optional*, Defaults to `"train"`):
            The split of the dataset to load.

        model_name (`str`, *optional*, Defaults to `None`):
            The name of the model to use for tokenization.

        opt_level (`str`, *optional*, Defaults to `None`):
            The optimization level to use for tokenization.

        column_name (`str`, *optional*, Defaults to `"text"`):
            The column to tokenize.

        num2skip (`int`, *optional*, Defaults to `0`):
            The number of rows to skip from the start of the dataset.

        num2embed (`int`, *optional*, Defaults to `-1`):
            The number of rows to embed; `-1` means all rows.

    Returns:
        `None`:
            The tokenized dataset is written to disk as parquet chunks.
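
    Example (illustrative; assumes `download_dataset("imdb")` has already been
    run and that the model id exists on the Hub):

        tokenize_dataset(
            "imdb",
            ds_split="train",
            model_name="bert-base-uncased",
            column_name="text",
        )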
    """

    # TODO: option for controlling length for models that can go shorter/longer than 512

    folder = Path("/data") / DATASET_NAME_PATTERN.sub(
        "", ds_name + (ds_config or "")
    )
    files = list(map(str, folder.glob(f"chunk_*_{ds_split}_raw")))

    ds = load_dataset("parquet", data_files=files, split="train")

    if num2embed == -1:
        num2embed = len(ds)
    # clamp so we never select past the end of the dataset
    end_idx = min(num2skip + num2embed, len(ds))
    ds = ds.select(range(num2skip, end_idx))

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # "O4" is assumed to require fixed-length inputs, so pad everything to
    # max_length; otherwise skip padding here and rely on the length sort below.
    padding = "max_length" if opt_level == "O4" else False
    max_length = 512

    def tokenize(
        examples: Dict[str, List[str]],
    ):
        tokenized = tokenizer(
            examples[column_name],
            truncation=True,
            padding=padding,
            max_length=max_length,
        )
        tokenized["length"] = [len(x) for x in tokenized["input_ids"]]

        return tokenized

    tds = ds.map(
        tokenize,
        batched=True,
        batch_size=1000,
        remove_columns=list(set(ds.column_names) - {column_name}),
        num_proc=multiprocessing.cpu_count(),
        desc="Tokenizing",
    )

    # sort to minimize padding
    if padding != "max_length":
        tds = tds.sort("length")

    chunk_and_save_dataset(
        tds, ds_name=ds_name, ds_config=ds_config, suffix=f"_{ds_split}_tokenized"
    )


def load_tokenized_dataset(
    ds_name: str,
    ds_config: Optional[str] = None,
    ds_split: str = "train",
):
    """
    Load a tokenized dataset from disk.

    Args:
        ds_name (`str`):
            The name of the dataset to load.

        ds_config (`str`, *optional*, Defaults to `None`):
            The configuration of the dataset to load.

        ds_split (`str`, *optional*, Defaults to `"train"`):
            The split of the dataset to load.

    Returns:
        ds (`Dataset`):
            The tokenized dataset loaded from the parquet chunks on disk.
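
    Example (illustrative; assumes `tokenize_dataset` has already been run for
    this dataset and split):

        tds = load_tokenized_dataset("imdb", ds_split="train")
        print(tds.column_names)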
    """

    folder = Path("/data") / DATASET_NAME_PATTERN.sub(
        "", ds_name + (ds_config or "")
    )
    files = list(map(str, folder.glob(f"chunk_*_{ds_split}_tokenized")))

    return load_dataset("parquet", data_files=files, split="train")
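

if __name__ == "__main__":
    # Minimal end-to-end sketch, not part of the library proper. Assumes
    # network access, a writable /data volume, and that the example dataset
    # ("imdb") and model ("bert-base-uncased") ids exist on the Hub.
    n_rows = download_dataset("imdb", ds_split="train")
    print(f"downloaded and chunked {n_rows} rows")

    tokenize_dataset(
        "imdb",
        ds_split="train",
        model_name="bert-base-uncased",
        column_name="text",
    )

    print(load_tokenized_dataset("imdb", ds_split="train"))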