import logging
from argparse import ArgumentParser
from typing import List

from meerkat import DataPanel, SpacyColumn
from meerkat.logging.utils import set_logging_level
from spacy import load

from align import BertscoreAligner, NGramAligner, StaticEmbeddingAligner, Aligner
from utils import clean_text

# Quiet logging at import time: ask meerkat (via its `set_logging_level` helper,
# whose exact scope is defined by meerkat) to log only at the 'critical' level,
# and restrict this module's own logger to CRITICAL as well.
set_logging_level('critical')
# NOTE(review): `logger` is not used in the code visible here; progress output
# goes through `print`.
logger = logging.getLogger(__name__)
logger.setLevel(logging.CRITICAL)


def _split_alignment_column(
    dataset: DataPanel,
    combined_key: str,
    key_prefix: str,
    targets: List[str],
):
    """Replace the combined alignment column ``combined_key`` with one column per target.

    Row entry ``i`` of the combined column is moved into the new column
    ``'<key_prefix>:<targets[i]>'``. ``dataset`` is modified in place with
    ``add_column`` / ``remove_column``.
    """
    combined = dataset[combined_key]
    for position, target in enumerate(targets):
        dataset.add_column(
            f'{key_prefix}:{target}',
            [row[position] for row in combined],
        )
    dataset.remove_column(combined_key)


def _run_aligners(
    dataset: DataPanel,
    aligners: List[Aligner],
    doc_column: str,
    reference_column: str,
    summary_columns: List[str] = None,
):
    """Run each aligner over the given columns and store the alignments in ``dataset``.

    For every aligner:
      * the document is aligned against the reference (when given) followed by
        every summary column;
      * the reference (when given) is aligned against every summary column.

    Each alignment is first stored in one combined column named
    ``'<AlignerClass>:<source>:<list of targets>'``. When a combination covers
    more than one target, it is split into per-target columns named
    ``'<AlignerClass>:<source>:<target>'`` and the combined column is dropped.

    Args:
        dataset: DataPanel holding the columns to align.
        aligners: aligner objects exposing ``align(source, targets)``.
        doc_column: name of the document column.
        reference_column: name of the reference-summary column, or None.
        summary_columns: names of the other summary columns (None/empty means none).

    Returns:
        The DataPanel with the alignment columns added.
    """
    summary_columns = summary_columns or []
    has_reference = reference_column is not None

    # The document is compared against the reference first (if any), then the summaries.
    targets = ([reference_column] if has_reference else []) + list(summary_columns)

    for aligner in aligners:
        prefix = type(aligner).__name__
        doc_key = f'{prefix}:{doc_column}:{targets}'
        ref_key = f'{prefix}:{reference_column}:{summary_columns}'

        # (document, targets) alignment, stored as one combined column
        dataset = dataset.update(
            lambda row: {
                doc_key: aligner.align(
                    row[doc_column],
                    [row[target] for target in targets],
                ),
            },
        )

        # (reference, summaries) alignment, stored as one combined column
        if has_reference and summary_columns:
            dataset = dataset.update(
                lambda row: {
                    ref_key: aligner.align(
                        row[reference_column],
                        [row[summary] for summary in summary_columns],
                    ),
                },
            )

        # Break multi-target combined columns into one column per comparison
        if len(targets) > 1:
            _split_alignment_column(dataset, doc_key, f'{prefix}:{doc_column}', targets)

        if has_reference and len(summary_columns) > 1:
            _split_alignment_column(
                dataset, ref_key, f'{prefix}:{reference_column}', summary_columns,
            )

    return dataset


def load_nlp():
    """Load and return the spaCy ``en_core_web_lg`` pipeline.

    Raises:
        OSError: if ``spacy.load`` cannot load the model (e.g. it is not
            installed). The message explains how to install it, and the original
            error is chained as the cause.
    """
    try:
        return load('en_core_web_lg')
    except OSError as err:
        # Keep the original spaCy error as __cause__ so the real reason is not lost.
        raise OSError(
            "The 'en_core_web_lg' model is required unless loading from a cached file. "
            "To install: 'python -m spacy download en_core_web_lg'"
        ) from err


def run_workflow(
    jsonl_path: str,
    doc_column: str = None,
    reference_column: str = None,
    summary_columns: List[str] = None,
    bert_aligner_threshold: float = 0.5,
    bert_aligner_top_k: int = 3,
    embedding_aligner_threshold: float = 0.5,
    embedding_aligner_top_k: int = 3,
    processed_dataset_path: str = None,
    n_samples: int = None,
    no_clean: bool = None,
):
    """Preprocess a jsonl dataset and write the result to disk.

    Pipeline: load the jsonl into a DataPanel, optionally clean every text column
    (``clean_text``), run the spaCy pipeline on the preprocessed text, run the
    BERTScore, static-embedding and n-gram aligners, then write the DataPanel to
    ``processed_dataset_path``.

    Args:
        jsonl_path: path of the input jsonl file (required).
        doc_column: document column; defaults to ``'document'``.
        reference_column: reference-summary column; defaults to
            ``'summary:reference'`` if that column exists, otherwise no reference.
        summary_columns: other summary columns; when None/empty, every column whose
            name starts with ``'summary:'`` is used, except ``'summary:reference'``
            and the chosen reference column.
        bert_aligner_threshold: ``threshold`` for ``BertscoreAligner``.
        bert_aligner_top_k: ``top_k`` for ``BertscoreAligner``.
        embedding_aligner_threshold: ``threshold`` for ``StaticEmbeddingAligner``.
        embedding_aligner_top_k: ``top_k`` for ``StaticEmbeddingAligner``.
        processed_dataset_path: where the processed DataPanel is written (required).
        n_samples: if truthy, only the first ``n_samples`` rows are processed.
        no_clean: if True, text is used as-is instead of passing through
            ``clean_text``. When None (default), the value comes from the
            module-level ``args.no_clean`` set by the command-line entry point if
            that exists, otherwise False.

    Returns:
        The processed DataPanel.

    Raises:
        ValueError: if ``jsonl_path`` or ``processed_dataset_path`` is missing, or
            no summary (reference or other) is available.
        OSError: if the spaCy model cannot be loaded (raised by ``load_nlp``).
    """
    if not jsonl_path:
        raise ValueError("'jsonl_path' is required")

    if not processed_dataset_path:
        raise ValueError("Please specify a path to save the dataset.")

    if no_clean is None:
        # Previously this function read the global `args`, which only exists when the
        # file runs as a script; importing and calling it raised NameError. Fall back
        # to the CLI flag when present so the script keeps honoring `--no_clean`.
        no_clean = bool(getattr(globals().get('args'), 'no_clean', False))

    # Load the dataset
    dataset = DataPanel.from_jsonl(jsonl_path)

    if doc_column is None:
        # Assume `doc_column` is called "document"
        doc_column = 'document'
        assert doc_column in dataset.columns, \
            f"`doc_column={doc_column}` is not a column in datapanel."
        print("Assuming `doc_column` is called 'document'.")

    if reference_column is None:
        # Assume `reference_column` is called "summary:reference"
        reference_column = 'summary:reference'
        print("Assuming `reference_column` is called 'summary:reference'.")
        if reference_column not in dataset.columns:
            print("No reference summary loaded")
            reference_column = None

    if not summary_columns:
        # Assume `summary_columns` are prefixed by "summary:". Exclude the reference
        # column (default or user-supplied) so it is not also treated as a summary.
        excluded = {"summary:reference", reference_column}
        summary_columns = [
            col for col in dataset.columns
            if col.startswith("summary:") and col not in excluded
        ]
        print(f"Reading summary columns from datapanel. Found {summary_columns}.")

    if len(summary_columns) == 0 and reference_column is None:
        raise ValueError("At least one summary is required")

    # Restrict to the first `n_samples`
    if n_samples:
        print(f"Restricting to {n_samples} samples.")
        dataset = dataset.head(n_samples)

    print("size of dataset:", len(dataset))

    # Combine the text columns into one list
    text_columns = [doc_column] + ([reference_column] if reference_column else []) + summary_columns

    # Preprocess all the text columns (optionally cleaning whitespace etc.)
    print("Preprocessing text columns")
    dataset = dataset.update(
        lambda x: {
            f'preprocessed_{k}': x[k] if no_clean else clean_text(x[k])
            for k in text_columns
        }
    )

    # Run the Spacy pipeline on all preprocessed text columns
    nlp = load_nlp()

    nlp.add_pipe('sentencizer', before="parser")

    print("Running spacy processing")
    for col in text_columns:
        dataset.add_column(f'spacy:{col}', SpacyColumn.from_docs(nlp.pipe(dataset[f'preprocessed_{col}'])))

    # Run the 3 align pipelines
    bert_aligner = BertscoreAligner(
        threshold=bert_aligner_threshold,
        top_k=bert_aligner_top_k,
    )

    embedding_aligner = StaticEmbeddingAligner(
        threshold=embedding_aligner_threshold,
        top_k=embedding_aligner_top_k,
    )

    ngram_aligner = NGramAligner()

    dataset = _run_aligners(
        dataset=dataset,
        aligners=[bert_aligner, embedding_aligner, ngram_aligner],
        doc_column=f'spacy:{doc_column}',
        reference_column=f'spacy:{reference_column}' if reference_column else None,
        summary_columns=[f'spacy:{col}' for col in summary_columns],
    )

    # Save the dataset
    dataset.write(processed_dataset_path)

    return dataset


def standardize_dataset(
    dataset_name: str,
    dataset_version: str,
    dataset_split: str,
    save_jsonl_path: str,
    doc_column: str = None,
    reference_column: str = None,
    n_samples: int = None,
):
    """Load a dataset from Huggingface and dump it to disk as standardized jsonl.

    The document column is renamed to ``'document'`` and the reference-summary
    column (when there is one) to ``'summary:reference'``, which are the names
    ``run_workflow`` assumes by default.

    Args:
        dataset_name: Huggingface dataset name (required).
        dataset_version: optional dataset version, forwarded to ``get_dataset``.
        dataset_split: split to load (required).
        save_jsonl_path: output jsonl path (required).
        doc_column: document column; when None, a built-in default is used for
            ``'cnn_dailymail'`` and ``'xsum'``.
        reference_column: reference-summary column; may only be given together
            with ``doc_column``. If None while ``doc_column`` is given, no
            reference column is renamed.
        n_samples: if truthy, only the first ``n_samples`` rows are kept.

    Returns:
        The standardized DataPanel (also written to ``save_jsonl_path``).

    Raises:
        ValueError: if a required argument is missing, or ``reference_column`` is
            given without ``doc_column``.
        NotImplementedError: if ``doc_column`` is None and ``dataset_name`` has no
            built-in default columns.
    """
    # Validate this function's own arguments. (The previous version read the CLI
    # `args` global, which raised NameError when called as a library function.)
    if dataset_name is None or dataset_split is None or save_jsonl_path is None:
        raise ValueError(
            'Missing command line argument: `dataset_name` (--dataset), '
            '`dataset_split` (--split) and `save_jsonl_path` (--save_jsonl_path) '
            'are required'
        )

    # Built-in (document, reference) column names for known datasets.
    default_columns = {
        'cnn_dailymail': ('article', 'highlights'),
        'xsum': ('document', 'summary'),
    }

    # Resolve column names before loading so bad arguments fail without a download.
    if doc_column is None:
        if reference_column is not None:
            raise ValueError("You must specify `doc_column` if you specify `reference_column`")
        if dataset_name not in default_columns:
            raise NotImplementedError(
                "Please specify `doc_column`."
            )
        doc_column, reference_column = default_columns[dataset_name]

    # Load the dataset from Huggingface
    dataset = get_dataset(
        dataset_name=dataset_name,
        dataset_version=dataset_version,
        dataset_split=dataset_split
    )
    if n_samples:
        dataset = dataset[:n_samples]

    # Rename the columns. Skip a rename whose target equals the source, because
    # add-then-remove would otherwise delete the column.
    if doc_column != 'document':
        dataset.add_column('document', dataset[doc_column])
        dataset.remove_column(doc_column)
    if reference_column is not None and reference_column != 'summary:reference':
        dataset.add_column('summary:reference', dataset[reference_column])
        dataset.remove_column(reference_column)

    # Save the dataset back to disk
    dataset.to_jsonl(save_jsonl_path)
    return dataset


def get_dataset(
    dataset_name: str = None,
    dataset_version: str = None,
    dataset_split: str = 'test',
    dataset_jsonl: str = None,
):
    """Load a dataset from Huggingface or from a local jsonl file.

    Exactly one of ``dataset_name`` and ``dataset_jsonl`` must be given.

    Args:
        dataset_name: Huggingface dataset name; loaded with ``get_hf_dataset``.
        dataset_version: optional version, forwarded to ``get_hf_dataset``.
        dataset_split: split forwarded to ``get_hf_dataset``.
        dataset_jsonl: path of a jsonl file, loaded with ``DataPanel.from_jsonl``.

    Returns:
        The loaded DataPanel.
    """
    from_hub = dataset_name is not None
    from_file = dataset_jsonl is not None
    assert from_hub ^ from_file, \
        "Specify one of `dataset_name` or `dataset_jsonl`."

    if not from_hub:
        return DataPanel.from_jsonl(json_path=dataset_jsonl)
    return get_hf_dataset(dataset_name, dataset_version, dataset_split)


def get_hf_dataset(name: str, version: str = None, split: str = 'test'):
    """Load split ``split`` of the Huggingface dataset ``name`` as a DataPanel.

    ``version`` is passed to ``DataPanel.from_huggingface`` as the second
    positional argument only when it is truthy.
    """
    positional = (name, version) if version else (name,)
    return DataPanel.from_huggingface(*positional, split=split)


def _build_arg_parser() -> ArgumentParser:
    """Build the command-line parser for standardizing and preprocessing datasets."""
    cli = ArgumentParser()
    add = cli.add_argument

    # Dataset source and column layout
    add('--dataset', type=str, choices=['cnn_dailymail', 'xsum'],
        help="Huggingface dataset name.")
    add('--version', type=str,
        help="Huggingface dataset version.")
    add('--split', type=str, default='test',
        help="Huggingface dataset split.")
    add('--dataset_jsonl', type=str,
        help="Path to a jsonl file for the dataset.")
    add('--save_jsonl_path', type=str,
        help="Path to save the processed jsonl dataset.")
    add('--doc_column', type=str,
        help="Name of the document column in the dataset.")
    add('--reference_column', type=str,
        help="Name of the reference summary column in the dataset.")
    add('--summary_columns', nargs='+', default=[],
        help="Name of other summary columns in/added to the dataset.")

    # Aligner hyperparameters
    add('--bert_aligner_threshold', type=float, default=0.1,
        help="Minimum threshold for BERT alignment.")
    add('--bert_aligner_top_k', type=int, default=10,
        help="Top-k for BERT alignment.")
    add('--embedding_aligner_threshold', type=float, default=0.1,
        help="Minimum threshold for embedding alignment.")
    add('--embedding_aligner_top_k', type=int, default=10,
        help="Top-k for embedding alignment.")
    add('--processed_dataset_path', type=str,
        help="Path to store the final processed dataset.")
    add('--n_samples', type=int,
        help="Number of dataset samples to process.")

    # Which stages to run
    add('--workflow', action='store_true', default=False,
        help="Whether to run the preprocessing workflow.")
    add('--standardize', action='store_true', default=False,
        help="Whether to standardize the dataset and save to jsonl.")
    add('--no_clean', action='store_true', default=False,
        help="Do not clean text (remove extraneous spaces, newlines).")
    return cli


if __name__ == '__main__':
    # `args` must stay a module-level global: other functions in this file read it.
    args = _build_arg_parser().parse_args()

    if args.standardize:
        # Dump a Huggingface dataset to the standardized jsonl layout
        standardize_dataset(
            dataset_name=args.dataset,
            dataset_version=args.version,
            dataset_split=args.split,
            save_jsonl_path=args.save_jsonl_path,
            doc_column=args.doc_column,
            reference_column=args.reference_column,
            n_samples=args.n_samples,
        )

    if args.workflow:
        # Clean, parse and align the jsonl dataset, then save it
        run_workflow(
            jsonl_path=args.dataset_jsonl,
            doc_column=args.doc_column,
            reference_column=args.reference_column,
            summary_columns=args.summary_columns,
            bert_aligner_threshold=args.bert_aligner_threshold,
            bert_aligner_top_k=args.bert_aligner_top_k,
            embedding_aligner_threshold=args.embedding_aligner_threshold,
            embedding_aligner_top_k=args.embedding_aligner_top_k,
            processed_dataset_path=args.processed_dataset_path,
            n_samples=args.n_samples,
        )