"""
grammar_improve.py - functions to improve the grammar and spelling of a user's input or the model's output.

"""

import logging
logging.basicConfig(level=logging.INFO)
import math
import pprint as pp
import re
import time

import neuspell
import transformers
from cleantext import clean
from neuspell import BertChecker, SclstmChecker
from symspellpy.symspellpy import SymSpell

from utils import suppress_stdout


def detect_propers(text: str):
    """
    detect_propers - check whether a string contains at least one word-like token

    NOTE(review): despite its name, the pattern ignores capitalisation entirely, so this
    returns True for ANY text containing a word character (letter, digit or underscore).
    It does not actually isolate proper nouns - confirm intent before relying on it.

    Args:
        text (str): [string to be checked]

    Returns:
        [bool]: [True if a word token (optionally with apostrophes or hyphen-joined parts) is found]
    """
    word_token = r"(?:\w+['’])?\w+(?:-(?:\w+['’])?\w+)*"
    found = re.search(word_token, text)
    return found is not None


def fix_punct_spaces(string):
    """
    fix_punct_spaces - normalise spacing around punctuation. For example, "hello , there" -> "hello, there"

    Whitespace before a run of ?!., characters is removed, whitespace inside the run is
    removed, and exactly one space is placed after it; the result is then stripped.
    A single "." or "," sitting directly between two digits (e.g. "3.50", "1,000") is left
    untouched so numbers are not broken apart.

    Parameters
    ----------
    string : str, required, input string to be corrected

    Returns
    -------
    str, corrected string
    """
    fix_spaces = re.compile(r"\s*([?!.,]+(?:\s+[?!.,]+)*)\s*")

    def _respace(match):
        start, end = match.span()
        # a bare separator between two digits is part of a number, not sentence punctuation
        if (
            match.group(0) in (".", ",")
            and start > 0
            and end < len(string)
            and string[start - 1].isdigit()
            and string[end].isdigit()
        ):
            return match.group(0)
        # collapse the punctuation run (dropping any whitespace inside it) and add one space after
        return "".join(match.group(1).split()) + " "

    return fix_spaces.sub(_respace, string).strip()


def split_sentences(text: str):
    """
    split_sentences - split a string into a list of sentences that keep their ending punctuation.

    A split happens on the whitespace that follows ".", "?" or "!", except after
    dotted abbreviations such as "e.g." / "i.e." or a capitalised two-letter
    abbreviation such as "Mr." / "Dr.". A whole run of whitespace between two
    sentences is consumed, so no sentence starts with a space.

    Args:
        text (str): [string to be split]

    Returns:
        [list]: [list of strings]
    """
    # negative lookbehinds guard abbreviations; the positive lookbehind requires a terminator
    return re.split(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=[.?!])\s+", text)


def remove_repeated_words(bot_response):
    """
    remove_repeated_words - remove repeated words from a string, returning only the first instance of each word

    Words are whitespace-separated tokens and are compared exactly (case- and
    punctuation-sensitive). Output words are joined with single spaces.

    Parameters
    ----------
    bot_response : str
        string to remove repeated words from

    Returns
    -------
    str
        string containing the first instance of each word, in original order
    """
    # dict keys preserve insertion order and give O(1) membership checks,
    # avoiding the O(n^2) cost of scanning a list for every word
    return " ".join(dict.fromkeys(bot_response.split()))


def remove_trailing_punctuation(text: str, fuLL_strip=False):
    """
    remove_trailing_punctuation - remove trailing punctuation from a string. Purpose is to seem more natural to end users

    Args:
        text (str): [string to be cleaned]
        fuLL_strip (bool, optional): if True, also remove trailing "?" and "!";
            otherwise only ".,;:" are removed. Defaults to False.

    Returns:
        [str]: [cleaned string]
    """
    chars = "?!.,;:" if fuLL_strip else ".,;:"
    # rstrip, not strip: only the END of the string should lose punctuation
    # (strip() would also eat leading characters, e.g. "...well" -> "well")
    return text.rstrip(chars)


def fix_punct_spacing(text: str):
    """
    fix_punct_spacing - fix spacing around punctuation and collapse repeated non-word characters

    Spacing is normalised by fix_punct_spaces(), which also strips leading/trailing
    whitespace. Afterwards, any run of one identical non-word character is collapsed
    to a single character, e.g. "wait!!" -> "wait!" and "a  b" -> "a b".

    Args:
        text (str): [string to be cleaned]

    Returns:
        [str]: [cleaned string]
    """
    # reuse the shared spacing helper instead of duplicating its regex; it also
    # strips the trailing space that the punctuation substitution appends
    spc_text = fix_punct_spaces(text)
    # (\W)(?=\1) deletes a non-word char whenever the next char is the same one
    return re.sub(r"(\W)(?=\1)", "", spc_text)


def synthesize_grammar(
    corrector: transformers.pipeline,
    message: str,
    num_beams=4,
    length_penalty=0.9,
    repetition_penalty=1.5,
    no_repeat_ngram_size=4,
    verbose=False,
):
    """
    synthesize_grammar - use a SyntaxSynthesizer model to generate a string from a message

    Parameters
    ----------
    corrector : transformers.pipeline, required, which is the SyntaxSynthesizer model already loaded
    message : str, required, which is the message to be corrected
    num_beams : int, optional, by default 4, which is the number of beams to use for the model
    length_penalty : float, optional, by default 0.9, which is the length penalty to use for the model
    repetition_penalty : float, optional, by default 1.5, which is the repetition penalty to use for the model
    no_repeat_ngram_size : int, optional, by default 4, which is the n-gram size to use for the model
    verbose : bool, optional, by default False, which is whether to print the runtime of the model

    Returns
    -------
    str, the first generated sequence ("generated_text" of results[0]) with surrounding whitespace stripped
    """
    st = time.perf_counter()
    input_text = clean(message, lower=False)
    input_len = len(corrector.tokenizer(input_text).input_ids)
    # short inputs only need a tiny floor; longer ones must keep ~20% of their token length
    min_length = 2 if input_len < 64 else int(0.2 * input_len)
    # allow ~10% growth, but never let max_length fall below min_length
    # (a 1-token input would otherwise give max_length=1 < min_length=2)
    max_length = max(int(1.1 * input_len), min_length)
    results = corrector(
        input_text,
        max_length=max_length,
        min_length=min_length,
        num_beams=num_beams,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        early_stopping=True,
        do_sample=False,
        clean_up_tokenization_spaces=True,
    )
    corrected_text = results[0]["generated_text"]
    if verbose:
        rt = round(time.perf_counter() - st, 2)
        print(f"synthesizing took {rt} seconds")
    return corrected_text.strip()


"""
start of SymSpell code
"""


def symspeller(
    my_string: str,
    sym_checker=None,
    max_dist: int = 2,
    prefix_length: int = 7,
    ignore_non_words=True,
    dictionary_path: str = None,
    bigram_path: str = None,
    verbose=False,
):
    """
    symspeller - a wrapper for the SymSpell class from symspellpy; corrects a whole string with lookup_compound

    Parameters
    ----------
        my_string : str, required, the string to be checked (must be non-empty)
        sym_checker : SymSpell, optional, default=None, the SymSpell object to use. If None, a new one is
            built with build_symspell_obj(), which loads the dictionary files - pass an existing object
            when calling this repeatedly
        max_dist : int, optional, default=2, the maximum edit distance per word passed to lookup_compound
        prefix_length : int, optional, default=7, the prefix length used when building a new SymSpell object
        ignore_non_words : bool, optional, default=True, passed through to lookup_compound
        dictionary_path : str, optional, default=None, dictionary file used when building a new SymSpell object
        bigram_path : str, optional, default=None, bigram dictionary file used when building a new SymSpell object
        verbose : bool, optional, default=False, whether to print progress and the suggestions

    Returns
    -------
        str, the term of the top suggestion, or the cleaned input when there are no suggestions

    Raises
    ------
        AssertionError, if my_string is empty (the check is an assert, so it is skipped under python -O)
    """

    assert len(my_string) > 0, "entered string for correction is empty"

    if sym_checker is None:
        # need to create a new class object. user can specify their own dictionary and bigram files
        if verbose:
            print("creating new SymSpell object")
        sym_checker = build_symspell_obj(
            edit_dist=max_dist,
            prefix_length=prefix_length,
            dictionary_path=dictionary_path,
            bigram_path=bigram_path,
        )
    elif verbose:
        print("using existing SymSpell object")
    # max edit distance per lookup (per single word, not per whole input string)
    suggestions = sym_checker.lookup_compound(
        my_string,
        max_edit_distance=max_dist,
        ignore_non_words=ignore_non_words,
        ignore_term_with_digits=True,
        transfer_casing=True,
    )

    if verbose:
        print(f"{len(suggestions)} suggestions found")
        print(f"the original string is:\n\t{my_string}")
        sug_list = [sug.term for sug in suggestions]
        print(f"suggestions:\n\t{sug_list}\n")

    if len(suggestions) < 1:
        # no correction because no suggestions; lower=False keeps the caller's casing,
        # consistent with transfer_casing=True on the lookup above
        return clean(my_string, lower=False)
    # first result is the most likely; use the public `term` property, as the verbose branch does
    return suggestions[0].term


def build_symspell_obj(
    edit_dist=2,
    prefix_length=7,
    dictionary_path=None,
    bigram_path=None,
):
    """
    build_symspell_obj - construct a SymSpell object and load its unigram and bigram dictionaries

    Args:
        edit_dist (int, optional): base edit distance; the SymSpell object is created with
            max_dictionary_edit_distance = edit_dist + 2. Defaults to 2.
        prefix_length (int, optional): prefix length passed to SymSpell. Defaults to 7.
        dictionary_path (str, optional): frequency dictionary loaded with term_index=0, count_index=1.
            Defaults to "symspell_rsc/frequency_dictionary_en_82_765.txt" (relative to the working directory).
        bigram_path (str, optional): bigram dictionary loaded with term_index=0, count_index=2.
            Defaults to "symspell_rsc/frequency_bigramdictionary_en_243_342.txt" (relative to the working directory).

    Returns:
        SymSpell: a SymSpell object with both dictionaries loaded
    """
    if dictionary_path is None:
        dictionary_path = r"symspell_rsc/frequency_dictionary_en_82_765.txt"
    if bigram_path is None:
        bigram_path = r"symspell_rsc/frequency_bigramdictionary_en_243_342.txt"

    checker = SymSpell(
        max_dictionary_edit_distance=edit_dist + 2,
        prefix_length=prefix_length,
    )
    # column layout of each file: where the term lives and where its frequency count lives
    checker.load_dictionary(dictionary_path, term_index=0, count_index=1)
    checker.load_bigram_dictionary(bigram_path, term_index=0, count_index=2)
    return checker


"""
# if using correct_grammar() to check for grammar/spelling errors, use this code to initialize the tokenizer and model objects it takes

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

model_name = 'deep-learning-analytics/GrammarCorrector'
# torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch_device = 'cpu'
gc_tokenizer = T5Tokenizer.from_pretrained(model_name)
gc_model = T5ForConditionalGeneration.from_pretrained(model_name).to(torch_device)

"""


def t5b_correction(prompt: str, korrektor, verbose=False, beams=4):
    """
    t5b_correction - correct a string using a text2textgen pipeline model from transformers

    The prompt is sent to the pipeline as "grammar: {prompt}".

    Parameters
    ----------
    prompt : str, required, input prompt to be corrected
    korrektor : transformers.pipeline, required, pipeline object
    verbose : bool, optional, whether to print the generation length and the result. Defaults to False.
    beams : int, optional, number of beams to use for the correction. Defaults to 4.

    Returns
    -------
    The raw output of korrektor, NOT a plain string. For a text2text-generation pipeline
    this is presumably a list like [{"generated_text": ...}] - index into it to get the text.
    """
    # max_length is counted in tokens but derived from the character count, so it is generous
    p_max_len = int(math.ceil(1.1 * len(prompt)))
    if verbose:
        print(f"setting max_length to {p_max_len}\n")
    gcorr_result = korrektor(
        f"grammar: {prompt}",
        return_text=True,
        clean_up_tokenization_spaces=True,
        num_beams=beams,
        max_length=p_max_len,
        repetition_penalty=1.3,
        length_penalty=0.2,
        no_repeat_ngram_size=2,
    )
    if verbose:
        print(f"grammar correction result: \n\t{gcorr_result}\n")
    return gcorr_result


def all_neuspell_chkrs():
    """
    all_neuspell_chkrs - print and return every attribute name exposed by the neuspell module

    NOTE(review): the list comes straight from dir(neuspell), so besides checker classes
    (e.g. BertChecker, SclstmChecker) it also contains dunder names and may contain other
    helpers; callers wanting only checkers need to filter it.

    Parameters
    ----------
    None

    Returns
    -------
    list of str, the names returned by dir(neuspell)
    """
    available = dir(neuspell)
    print("\navailable checkers:")
    pp.pprint(available, indent=4, compact=True)
    return available


def load_ns_checker(customckr=None, fast=False):
    """
    load_ns_checker - helper function, instantiate a pretrained neuspell checker inside utils.suppress_stdout()

    Args:
        customckr (class, optional): a neuspell checker CLASS, instantiated as customckr(pretrained=True).
            When given, it takes precedence over `fast`. Defaults to None.
        fast (bool, optional): when no customckr is given, load SclstmChecker (faster, less accurate)
            instead of the default BertChecker. Defaults to False.

    Returns:
        the instantiated neuspell checker object; the load time in minutes is printed
    """
    start = time.perf_counter()
    # keep the checker's own console output quiet while it loads
    with suppress_stdout():
        if customckr is not None:
            checker = customckr(pretrained=True)
        elif fast:
            checker = SclstmChecker(pretrained=True)  # quicker, but not as accurate
        else:
            checker = BertChecker(pretrained=True)  # default: best balance of speed and accuracy
    elapsed_min = (time.perf_counter() - start) / 60
    print(f"\n\nloaded checker in {elapsed_min} minutes")
    return checker


def neuspell_correct(input_text: str, checker=None, verbose=False):
    """
    neuspell_correct - correct a string using neuspell, then tidy spacing around punctuation.
                        Note that modifications to the checker are needed if doing list-based corrections.

    Strings shorter than 4 characters are returned unchanged. When no checker is passed,
    a new SclstmChecker(pretrained=True) is created on that call, so pass a preloaded
    checker when correcting many strings.

    Parameters
    ----------
    input_text : str, required, input string to be corrected
    checker : neuspell checker, optional, object exposing .correct(). Defaults to None.
    verbose : bool, optional, whether to print the corrected string. Defaults to False.

    Returns
    -------
    str, corrected string
    """
    too_short = isinstance(input_text, str) and len(input_text) < 4
    if too_short:
        print(f"input text of {input_text} is too short to be corrected")
        return input_text

    if checker is None:
        print("NOTE - no checker provided, loading default checker")
        checker = SclstmChecker(pretrained=True)

    result = fix_punct_spaces(checker.correct(input_text))
    if verbose:
        print(f"neuspell correction result: \n\t{result}\n")
    return result


def grammarpipe(corrector, qphrase: str):
    """
    grammarpipe - THE ORIGINAL ONE USED IN PROJECT AND NEEDS TO BE CHANGED.
                    Correct a string with a text2textgen pipeline model from transformers.

    Strings shorter than 4 characters are returned unchanged. If the pipeline call (or
    reading its output) raises, the error is printed and the cleaned input is returned
    instead, so this is a best-effort correction.

    Args:
        corrector (transformers.pipeline): [transformers pipeline object, already created w/ relevant model]
        qphrase (str): [text to be corrected]
    Returns:
        [str]: [corrected text, or the cleaned input on failure]
    """
    if isinstance(qphrase, str) and len(qphrase) < 4:
        print(f"input text of {qphrase} is too short to be corrected")
        return qphrase
    try:
        outputs = corrector(
            clean(qphrase), return_text=True, clean_up_tokenization_spaces=True
        )
        best = outputs[0]["generated_text"]
    except Exception as err:
        # deliberate fallback: never let a model failure break the caller
        print(f"NOTE - failed to correct with grammarpipe:\n {err}")
        return clean(qphrase)
    return best


def DLA_correct(qphrase: str, tokenizer=None, model=None, **gen_kwargs):
    """
    DLA_correct - an "overhead" function that runs correct_grammar() on each sentence of a string separately

    The string is split with split_sentences(). A single sentence is corrected as-is.
    With several sentences, each is passed through clean() first and the corrected
    sentences are re-joined with " ".

    Args:
        qphrase (str): [string to be corrected]
        tokenizer: tokenizer forwarded to correct_grammar(); required unless qphrase is shorter than 4 characters
        model: model forwarded to correct_grammar(); required unless qphrase is shorter than 4 characters
        **gen_kwargs: extra keyword arguments forwarded to correct_grammar() (e.g. beams, device)

    Returns:
        str, the corrected sentences joined under " " (strings shorter than 4 characters are returned unchanged)

    Raises:
        TypeError: if tokenizer or model is missing for a string that needs correcting
    """
    if isinstance(qphrase, str) and len(qphrase) < 4:
        print(f"input text of {qphrase} is too short to be corrected")
        return qphrase
    if tokenizer is None or model is None:
        # correct_grammar() has no defaults for these, so fail fast with a clear message
        raise TypeError("DLA_correct requires `tokenizer` and `model` to be passed")

    sentences = split_sentences(qphrase)
    if len(sentences) == 1:
        return correct_grammar(sentences[0], tokenizer, model, **gen_kwargs)
    full_cor = [
        correct_grammar(clean(sen), tokenizer, model, **gen_kwargs) for sen in sentences
    ]
    return " ".join(full_cor)


def correct_grammar(
    input_text: str,
    tokenizer,
    model,
    n_results: int = 1,
    beams: int = 8,
    temp=1,
    no_repeat_ngram_size=4,
    rep_penalty=2.5,
    device="cpu",
):
    """
    correct_grammar - correct a string using a seq2seq tokenizer/model pair from transformers.
                        This function is an alternative to the t5b_correction function.

    Parameters
    ----------
    input_text : str, required, input string to be corrected
    tokenizer : transformers.T5Tokenizer, required, tokenizer object, already created w/ relevant model
    model : transformers.T5ForConditionalGeneration, required, model object, already created w/ relevant model
    n_results : int, optional, number of sequences to generate and return. Defaults to 1.
    beams : int, optional, number of beams to use for the correction. Defaults to 8.
    temp : float, optional, temperature passed to model.generate(). Defaults to 1.
    no_repeat_ngram_size : int, optional, size of n-grams that may not repeat in the output. Defaults to 4.
    rep_penalty : float, optional, repetition penalty passed to model.generate(). Defaults to 2.5.
    device : str, optional, device the encoded batch is moved to. Defaults to 'cpu'.

    Returns
    -------
    str, corrected string (or list of strings if n_results > 1).
    Inputs that tokenize to fewer than 4 tokens are returned unchanged.
    """
    st = time.perf_counter()

    if len(tokenizer(input_text).input_ids) < 4:
        logging.info(f"input text of {input_text} is too short to be corrected")
        return input_text
    # derived from the character count (not tokens), capped at 128
    max_length = min(int(math.ceil(len(input_text) * 1.2)), 128)
    batch = tokenizer(
        [input_text],
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    ).to(device)
    # NOTE(review): do_sample is not set here, so `temp` likely has no effect - confirm
    translated = model.generate(
        **batch,
        max_length=max_length,
        min_length=min(10, len(input_text)),
        no_repeat_ngram_size=no_repeat_ngram_size,
        repetition_penalty=rep_penalty,
        num_beams=beams,
        num_return_sequences=n_results,
        temperature=temp,
    )

    # drop special tokens (padding / end-of-sequence markers) so callers get plain text
    tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
    rt_min = (time.perf_counter() - st) / 60
    print(f"\n\ncorrected in {rt_min} minutes")

    if not isinstance(tgt_text, list):
        return tgt_text
    # honour n_results: return every generated sequence when more than one was requested
    return tgt_text if n_results > 1 else tgt_text[0]