"""
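    converse.py - functions for handling the conversation between the user and the bot:
    building the prompt, generating a response, and extracting the bot's lines from the model output.

    Reference for the generation parameters (e.g. no_repeat_ngram_size):
    https://huggingface.co/docs/transformers/v4.15.0/en/main_classes/model#transformers.generation_utils.GenerationMixin.generate.no_repeat_ngram_size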
"""

import logging

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
import pprint as pp
import time

from grammar_improve import remove_trailing_punctuation

from constrained_generation import constrained_generation


def discussion(
    prompt_text: str,
    speaker: str,
    responder: str,
    pipeline,
    timeout=45,
    min_length=8,
    max_length=64,
    top_p=0.95,
    top_k=50,
    temperature=0.7,
    full_text=False,
    length_penalty=0.8,
    no_repeat_ngram_size=2,
    num_return_sequences=1,
    device=-1,
    verbose=False,
    constrained_beam_search=False,
):
    """
    discussion - take a prompt from `speaker` and generate a reply from `responder`. This function is meant to be used in a conversation loop, and is the main function for the bot.

    The prompt sent to the model is (all lower-cased):

        <speaker>:
        <prompt_text>

        <responder>:

    Generation uses `constrained_generation` (beam search) when `constrained_beam_search`
    is True, otherwise `gen_response` (sampling).

    Parameters
    ----------
        prompt_text : str, the prompt to ask the bot, usually the user's question
        speaker : str, the name of the person who is speaking the prompt
        responder : str, the name of the person who is responding to the prompt
        pipeline : transformers.Pipeline, the pipeline to use for generating the response
        timeout : int, optional, the number of seconds to wait before timing out, defaults to 45
        min_length : int, optional, the minimum number of new tokens to generate, defaults to 8
        max_length : int, optional, the maximum number of new tokens to generate, defaults to 64
        top_p : float, optional, the top probability to use for sampling (sampling only), defaults to 0.95
        top_k : int, optional, the top k to use for sampling (sampling only), defaults to 50
        temperature : float, optional, the temperature to use for sampling (sampling only), defaults to 0.7
        full_text : bool, optional, whether to return the full text or just the generated text, defaults to False
        length_penalty : float, optional, length penalty passed to the generator, defaults to 0.8
        no_repeat_ngram_size : int, optional, size of n-grams that may not repeat, defaults to 2
        num_return_sequences : int, optional, the number of sequences to request (sampling only), defaults to 1
        device : int, optional, forwarded to `gen_response` (sampling only), defaults to -1
        verbose : bool, optional, whether to print the prompt and generated text, defaults to False
        constrained_beam_search : bool, optional, use constrained beam search instead of sampling, defaults to False

    Returns
    -------
        dict with keys
            "out_text" : str, the cleaned bot response
            "full_conv" : list of str, the prompt lines followed by the response line and a blank line
    """

    # lazy %-formatting: locals() includes the pipeline, whose repr can be large
    logging.debug("input args: %s", locals())

    p_list = [  # track conversation
        speaker.lower() + ":" + "\n",
        prompt_text.lower() + "\n",
        "\n",
        responder.lower() + ":" + "\n",
    ]
    this_prompt = "".join(p_list)
    if verbose:
        print("overall prompt:\n")
        pp.pprint(this_prompt, indent=4)

    if constrained_beam_search:
        logging.info("generating using constrained beam search ...")
        response = constrained_generation(
            prompt=this_prompt,
            pipeline=pipeline,
            min_generated_tokens=min_length,
            max_generated_tokens=max_length,
            no_repeat_ngram_size=no_repeat_ngram_size,
            length_penalty=length_penalty,
            repetition_penalty=1.0,
            num_beams=4,
            timeout=timeout,
            verbose=False,
            full_text=full_text,
            speaker_name=speaker,
            responder_name=responder,
        )

        bot_dialogue = consolidate_texts(
            name_resp=responder,
            model_resp=response.split("\n"),
            name_spk=speaker,
            verbose=verbose,
            print_debug=True,
        )
    else:
        logging.info("generating using sampling ...")
        bot_dialogue = gen_response(
            this_prompt,
            pipeline,
            speaker,
            responder,
            timeout=timeout,
            min_length=min_length,
            max_length=max_length,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            full_text=full_text,
            no_repeat_ngram_size=no_repeat_ngram_size,
            length_penalty=length_penalty,
            num_return_sequences=num_return_sequences,
            device=device,
            verbose=verbose,
        )
    logging.debug("generation done. bot_dialogue: %s", bot_dialogue)

    # consolidate_texts yields either a list of lines or a bare str (single-line output);
    # multiple lines are merged into one response with ", "
    if isinstance(bot_dialogue, list):
        bot_resp = ", ".join(bot_dialogue)
    else:
        bot_resp = bot_dialogue
    bot_resp = bot_resp.strip()
    # strip trailing punctuation (e.g. ',' '.') via grammar_improve helper
    bot_resp = remove_trailing_punctuation(bot_resp)
    if verbose:
        print("\nfinished!")
        print("\n... bot response:\n")
        pp.pprint(bot_resp)
    p_list.append(bot_resp + "\n")
    p_list.append("\n")

    logging.info("finished generating response:\n\t%s", bot_resp)
    # return the bot response and the full conversation
    return {"out_text": bot_resp, "full_conv": p_list}


def gen_response(
    query: str,
    pipeline,
    speaker: str,
    responder: str,
    timeout=45,
    min_length=12,
    max_length=48,
    top_p=0.95,
    top_k=20,
    temperature=0.5,
    full_text=False,
    num_return_sequences=1,
    length_penalty: float = 0.8,
    repetition_penalty: float = 3.5,
    no_repeat_ngram_size=2,
    device=-1,
    verbose=False,
    **kwargs,
):
    """
    gen_response - a function that takes in a prompt and generates a response using the pipeline. This operates underneath the discussion function.

    Parameters
    ----------
        query : str, the full prompt passed to the pipeline
        pipeline : transformers.Pipeline, the pipeline to use for generating the response
        speaker : str, the name of the person who is speaking the prompt
        responder : str, the name of the person who is responding to the prompt
        timeout : int, optional, passed to the pipeline as `max_time` (seconds), defaults to 45
        min_length : int, optional, the minimum number of new tokens to generate, defaults to 12
        max_length : int, optional, the maximum number of new tokens to generate, defaults to 48.
            Clamped so that prompt + generation fits in 1024 tokens (but never below 8).
        top_p : float, optional, the top probability to use for sampling, defaults to 0.95
        top_k : int, optional, the top k to use for sampling, defaults to 20
        temperature : float, optional, the temperature to use for sampling, defaults to 0.5
        full_text : bool, optional, passed as `return_full_text`, defaults to False
        num_return_sequences : int, optional, the number of sequences to request, defaults to 1.
            Only the first returned sequence is used.
        length_penalty : float, optional, defaults to 0.8
        repetition_penalty : float, optional, defaults to 3.5
        no_repeat_ngram_size : int, optional, size of n-grams that may not repeat, defaults to 2
        device : int, optional, currently unused here (not forwarded to `pipeline`), defaults to -1
        verbose : bool, optional, whether to print timing and the raw generation, defaults to False
        **kwargs : extra keyword arguments forwarded to the pipeline call

    Returns
    -------
        list of str (the responder's lines) or str (when the output was a single line),
        as returned by consolidate_texts
    """
    logging.debug("input args - gen_response() : %s", locals())
    # NOTE(review): assumes a 1024-token model context window — confirm for the loaded model
    context_limit = 1024
    input_len = len(pipeline.tokenizer(query).input_ids)
    if max_length + input_len > context_limit:
        max_length = max(context_limit - input_len, 8)
        logging.warning("max_length too large, setting to %s", max_length)
    # after clamping, min_length must not exceed max_length or the length constraints conflict
    min_length = min(min_length, max_length)
    st = time.perf_counter()

    # lengths are offset by the prompt length: the pipeline's limits count prompt tokens too
    response = pipeline(
        query,
        min_length=min_length + input_len,
        max_length=max_length + input_len,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        num_return_sequences=num_return_sequences,
        max_time=timeout,
        return_full_text=full_text,
        no_repeat_ngram_size=no_repeat_ngram_size,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        clean_up_tokenization_spaces=True,
        remove_invalid_values=True,
        **kwargs,
    )  # the likely better beam-less method
    rt = round(time.perf_counter() - st, 2)
    if verbose:
        print(f"took {rt} sec to respond")
        print("\n[DEBUG] generated:\n")
        pp.pprint(response)  # for debugging
    # process the full result to get the ~bot response~ piece
    # TODO: only the first sequence is used; handle num_return_sequences > 1
    this_result = str(response[0]["generated_text"]).split("\n")

    bot_dialogue = consolidate_texts(
        name_resp=responder,
        model_resp=this_result,
        name_spk=speaker,
        verbose=verbose,
        print_debug=True,
    )
    if verbose:
        print(f"DEBUG: {bot_dialogue} was original response pre-SC")
    return bot_dialogue


def consolidate_texts(
    model_resp: list,
    name_resp: str = None,
    name_spk: str = None,
    verbose=False,
    print_debug=False,
):
    """
    consolidate_texts - extract the responder's lines from a list of model output lines.

    Lines are scanned in order (all name matching is case-insensitive substring matching):
      * a line containing the responder's name is treated as a name header and skipped;
      * a line containing the speaker's name ends the scan;
      * a line containing a colon followed by a space or newline (another "name: text" turn) ends the scan;
      * every other line is kept.

    Parameters:
        model_resp (list): the strings to consolidate (usually model output split into lines); must be non-empty
        name_resp (str): the name of the person who is responding, defaults to "person beta"
        name_spk (str): the name of the person who is speaking, defaults to "person alpha"
        verbose (bool): whether to print the input and the consolidated result
        print_debug (bool): whether to print the line that ended the scan

    Returns:
        str if model_resp has exactly one element (that element, unchanged);
        otherwise a list of the kept lines (possibly empty).
    """
    assert len(model_resp) > 0, "model_resp is empty"
    if len(model_resp) == 1:
        return model_resp[0]
    name_resp = "person beta" if name_resp is None else name_resp
    name_spk = "person alpha" if name_spk is None else name_spk
    if verbose:
        print("====" * 10)
        print(
            f"\n[DEBUG] initial model_resp has {len(model_resp)} lines: \n\t{model_resp}"
        )
        print(
            f" the first element is \n\t{model_resp[0]} and it is {type(model_resp[0])}"
        )

    # compare lower-cased on both sides so header detection is consistent
    resp_key = name_resp.lower()
    spk_key = name_spk.lower()
    fn_resp = []
    for resline in model_resp:
        line_key = resline.lower()
        if resp_key in line_key:
            continue  # the bot's name header: don't add this line to the list
        # the speaker (or some other "name: ..." turn) has started, so the bot is done
        if spk_key in line_key or ": " in resline or ":\n" in resline:
            if print_debug:
                print(f"\nDEBUG: \n\t{resline}\ncaused the break")
            break
        fn_resp.append(resline)
    if verbose:
        print("--" * 10)
        print("\nthe full response is:\n")
        print("\n".join(fn_resp))
        print("--" * 10)
    return fn_resp