import ast
import math
import os
import pickle
import random

import pandas as pd

SPEAKER_ROLE_MAP = {"Agent": 0, "Visitor": 1}

LABEL_MAP = {
    "Curiosity": 0,
    "Obscene": 1,
    "Informative": 2,
    "Openness": 3,
    "Acceptance": 4,
    "Interest": 5,
    "Greeting": 6,
    "Disapproval": 7,
    "Denial": 8,
    "Anxious": 9,
    "Uninterested": 10,
    "Remorse": 11,
    "Confused": 12,
    "Accusatory": 13,
    "Annoyed": 14,
}


def process_user_input(user_input: str):
    """Parse the user input and return a list of rows, where each row is a list
    with format `[<conversation_id>, <speaker>, <message>]`.

    Args:
        user_input (str): the input of the user, where each line has the format
        `<speaker>:<message>`. Only one message per line.

    Returns:
        dict: a dictionary indicating whether the input was successfully processed
        and, if so, the processed data of the input.
    """
    if user_input is None or user_input == "":
        return {"success": False, "message": "Input must not be an empty string!"}

    data = []
    for line in user_input.split("\n"):
        if line == "":
            continue
        try:
            speaker, message = line.split(":", 1)

            if speaker not in ("Agent", "Visitor"):
                return {"success": False, "message": f"Invalid speaker {speaker}"}

            # Assuming there's only one input conversation,
            # give it a dummy conversation id of epik_0
            data.append(["epik_0", speaker, message])
        except ValueError:
            # raised when a line has no ":" separator to unpack
            return {"success": False, "message": "Invalid Input"}

    return {
        "success": True,
        "message": "Success",
        "data": data,
    }
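
# Illustrative example: a two-line conversation. Text after the first ":" is
# kept verbatim, so a space after the colon stays in the message.
#   process_user_input("Agent: Hi!\nVisitor: Hello")
#   -> {"success": True, "message": "Success",
#       "data": [["epik_0", "Agent", " Hi!"], ["epik_0", "Visitor", " Hello"]]}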


def encode_speaker_role(role):
    return SPEAKER_ROLE_MAP.get(role, 1)


def decode_speaker_role(role_numeric):
    for role, numeric_val in SPEAKER_ROLE_MAP.items():
        if role_numeric == numeric_val:
            return role

    return "Unknown Speaker"


def encode_sentiment_label(label):
    return LABEL_MAP.get(label, -1)


def decode_numeric_label(label_numeric):
    for label, numeric_val in LABEL_MAP.items():
        if label_numeric == numeric_val:
            return label

    return "Unknow Label"


def prepare_csv(data: list[list], output_path: str, with_label: bool = False):
    """
    Process and group the speakers, messages, and labels (if any) by conversation
    id. This function is useful to prepare the necessary csv file before converting
    it into a pickle file.

    Args:
        data (list[list]): A list containing the rows of a dataframe. Each row
        contains values representing the conversation id, speaker role, message
        (, and label if any) in this order.
        output_path (str): path to write the csv file.
        with_label (bool, optional): Whether the input data contains labels (i.e.,
        for training) or not (i.e., for making predictions on a new sample).
        Defaults to False.
    """
    columns = ["ConversationId", "ParticipantRole", "Text"]

    if with_label:
        columns += ["Label"]

    df = pd.DataFrame(data=data, columns=columns)

    # encode the participant role
    df["ParticipantRoleEncoded"] = df["ParticipantRole"].apply(encode_speaker_role)

    # encode the labels
    if with_label:
        df["LabelNumeric"] = df["Label"].apply(encode_sentiment_label)
    else:
        # Give the new input dummy labels to match the model input shape
        df["LabelNumeric"] = -1

    # group the data into list based on conversation id
    agg_params = {"Label": list} if with_label else {}
    agg_params.update(
        {
            "ParticipantRole": list,
            "ParticipantRoleEncoded": list,
            "Text": list,
            "LabelNumeric": list,
        }
    )
    grouped_df = df.groupby("ConversationId").agg(agg_params).reset_index()

    # use utf-8 so non-ASCII characters in messages do not raise at write time
    grouped_df.to_csv(output_path, index=False, encoding="utf-8")

    return grouped_df
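
# Illustrative example (without labels): two messages from one conversation
# collapse into a single grouped row of per-column lists.
#   prepare_csv([["epik_0", "Agent", "Hi"], ["epik_0", "Visitor", "Hello"]], "out.csv")
#   -> ConversationId "epik_0", ParticipantRole ["Agent", "Visitor"],
#      ParticipantRoleEncoded [0, 1], Text ["Hi", "Hello"], LabelNumeric [-1, -1]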


def convert_to_pickle(
    source: str,
    dest: str,
    index_col: str = None,
    list_type_columns: list = None,
    order: list = None,
    exclude: list = None,
    single_tuple: bool = False,
):
    """Convert a csv file into a pickle file storing a tuple of per-column dicts
    (one dict per column, each keyed by the index).

    Args:
        source (str): path to the csv file
        dest (str): the location where the pickle file will be stored
        index_col (str, optional): the column with unique ids that serves as
        index. Defaults to None.
        list_type_columns (list, optional): columns stored as string
        representations of lists that should be parsed back into lists. Defaults
        to an empty list.
        order (list, optional): specify the order for one or many columns from
        left to right, followed by the columns not in `order`. Defaults to an
        empty list.
        exclude (list, optional): columns to be excluded from the result.
        Defaults to an empty list.
        single_tuple (bool): whether or not to keep the output wrapped in a tuple
        when there is only one single column. Defaults to False.
    """
    # resolve defaults here to avoid mutable default arguments
    list_type_columns = list_type_columns if list_type_columns is not None else []
    order = order if order is not None else []
    exclude = exclude if exclude is not None else []

    df = pd.read_csv(source)
    df = df.drop(columns=exclude)

    # convert columns from string representations of lists back into lists;
    # ast.literal_eval only accepts Python literals, unlike eval
    for col in list_type_columns:
        if col in df.columns:
            df[col] = df[col].fillna("[]").apply(ast.literal_eval)

    if index_col is not None:
        df = df.set_index(index_col)

    # reorder the columns: those in `order` first, the rest after
    if order:
        left = df[order]
        right = df[[col for col in df.columns if col not in order]]
        df = pd.concat([left, right], axis=1)

    # build one dict per column, keyed by the index
    output = ()
    for col in df.columns:
        output += (df[col].to_dict(),)

    # unwrap a single-column result unless a tuple was explicitly requested
    if not single_tuple and len(output) == 1:
        output = output[0]

    with open(dest, "wb") as f:
        pickle.dump(output, f)
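
# Illustrative call (the file names are hypothetical): unpack a grouped csv
# produced by prepare_csv into a (texts_by_id, labels_by_id) tuple of dicts
# keyed by ConversationId.
#   convert_to_pickle(
#       "grouped.csv",
#       "grouped.pkl",
#       index_col="ConversationId",
#       list_type_columns=["Text", "LabelNumeric"],
#       order=["Text", "LabelNumeric"],
#       exclude=["ParticipantRole", "ParticipantRoleEncoded"],
#   )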


def split_and_save_ids(
    ids, train_ratio=0.8, test_ratio=0.1, valid_ratio=0.1, dir=".", seed=None
):
    """
    Randomly split a list of IDs into training, testing, and validation sets and save them to text files.

    Args:
        ids (list): List of IDs to be split.
        train_ratio (float): Ratio of IDs for the training set (default is 0.8).
        test_ratio (float): Ratio of IDs for the testing set (default is 0.1).
        valid_ratio (float): Ratio of IDs for the validation set (default is 0.1).
        dir (str): the directory in which to save the id files (default is ".").
        seed (int): Seed for randomization (default is None).

    Returns:
        train_set (list): List of IDs in the training set.
        test_set (list): List of IDs in the testing set.
        valid_set (list): List of IDs in the validation set.
    """

    # Check that the ratios add up to 1.0 (with a tolerance for float rounding)
    assert math.isclose(
        train_ratio + test_ratio + valid_ratio, 1.0
    ), "Ratios should add up to 1.0"

    # Set random seed for reproducibility
    if seed is not None:
        random.seed(seed)

    # Shuffle a copy of the IDs so the caller's list is not mutated in place
    ids = list(ids)
    random.shuffle(ids)

    # Calculate the split points
    train_split = int(len(ids) * train_ratio)
    test_split = train_split + int(len(ids) * test_ratio)

    # Split the IDs
    train_set = ids[:train_split]
    test_set = ids[train_split:test_split]
    valid_set = ids[test_split:]

    # Save the sets to text files
    def save_to_txt(file_path, id_set):
        with open(file_path, "w") as file:
            id_strings = [str(conv_id) for conv_id in id_set]
            file.write("\n".join(id_strings))

    save_to_txt(os.path.join(dir, "train_set.txt"), train_set)
    save_to_txt(os.path.join(dir, "test_set.txt"), test_set)
    save_to_txt(os.path.join(dir, "validation_set.txt"), valid_set)

    return train_set, test_set, valid_set
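
# Illustrative call: a seeded 80/10/10 split of ten ids, written to the current
# directory as train_set.txt, test_set.txt, and validation_set.txt.
#   train, test, valid = split_and_save_ids([f"epik_{i}" for i in range(10)], seed=42)
#   -> 8 train ids, 1 test id, 1 validation id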


def merge_pkl_with_ids(pickle_src: str, ids_files: list, dir: str = "."):
    """Merge an existing pickle file with id files, appending one field per ids
    file (e.g., train_ids, test_ids, and valid_ids).

    Args:
        pickle_src (str): the path to the pickle file
        ids_files (list): list of files that contain ids. Example:
        ["train_set.txt", "test_set.txt", "validation_set.txt"]. Each file should
        contain one single unique id on each line.
        dir (str, optional): the directory for ids_files. Defaults to ".".
    """
    ids_set = ()
    for filename in ids_files:
        ids = []
        path = os.path.join(dir, filename)
        with open(path, "r") as file:
            for line in file:
                ids.append(line.strip())

        ids_set += (ids,)

    with open(pickle_src, "rb") as file:
        data = pickle.load(file)
        data += ids_set
        file.close()

    with open(pickle_src, "wb") as file:
        pickle.dump(data, file)
        file.close()
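

# A minimal end-to-end sketch of the helpers above; the sample conversation and
# file names are illustrative, not part of any real dataset.
if __name__ == "__main__":
    parsed = process_user_input("Agent: Hi, how can I help?\nVisitor: Just browsing.")
    if parsed["success"]:
        # group the single parsed conversation and write it to csv
        prepare_csv(parsed["data"], "sample.csv")

        # pack the grouped csv into a tuple of per-conversation dicts
        convert_to_pickle(
            "sample.csv",
            "sample.pkl",
            index_col="ConversationId",
            list_type_columns=["ParticipantRoleEncoded", "Text", "LabelNumeric"],
            exclude=["ParticipantRole"],
        )

        # split ids, save them to text files, and append them to the pickle
        split_and_save_ids(["epik_0"], seed=0)
        merge_pkl_with_ids(
            "sample.pkl", ["train_set.txt", "test_set.txt", "validation_set.txt"]
        )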