"""Module containing PromptTokenizingStrategy and Prompter classes"""

import abc
import copy
import functools
import logging
from typing import Dict, List, Tuple, Union

from fastchat.conversation import Conversation
from transformers import BatchEncoding, PreTrainedTokenizer

from axolotl.monkeypatch.fastchat_conversation_turns import (
    add_get_turns_to_conversation,
)
from axolotl.prompters import IGNORE_TOKEN_ID

# Module-wide logger; the strategies below use it to emit warnings.
LOG = logging.getLogger("axolotl")

# Label id constant. NOTE(review): not referenced in the visible part of this file -
# the strategies below use a literal -100 or the imported IGNORE_TOKEN_ID instead.
# Presumably imported by other modules; confirm before consolidating.
IGNORE_INDEX = -100
# Default LLaMA special-token strings. NOTE(review): unused in the visible part of
# this file; presumably consumed by other modules - confirm.
LLAMA_DEFAULT_PAD_TOKEN = "<pad>"  # nosec
LLAMA_DEFAULT_EOS_TOKEN = "</s>"  # nosec
LLAMA_DEFAULT_BOS_TOKEN = "<s>"  # nosec
LLAMA_DEFAULT_UNK_TOKEN = "<unk>"  # nosec

# Import-time side effect: runs the fastchat monkeypatch imported above.
# NOTE(review): presumably adds a `get_turns` method to fastchat's Conversation -
# confirm in axolotl.monkeypatch.fastchat_conversation_turns.
add_get_turns_to_conversation()


class InvalidDataException(Exception):
    """
    Signals that a dataset row could not be processed because its data is invalid.
    """


class PromptTokenizingStrategy(abc.ABC):
    """
    Abstract class for tokenizing strategies

    Stores the prompter, tokenizer and length limit shared by the concrete
    strategies, and provides the common ``_tokenize`` helper together with
    per-instance cached lookups of the optional ``<|USER|>`` and
    ``<|ASSISTANT|>`` tokens.
    """

    def __init__(
        self,
        prompter,
        tokenizer,
        train_on_inputs: bool = False,
        sequence_len: int = 2048,
    ):
        """
        Store the collaborators and limits used by ``tokenize_prompt``.

        ``sequence_len`` is kept under two names (``sequence_len`` and
        ``max_length``); ``_tokenize`` passes ``max_length`` to the tokenizer
        as the truncation limit.
        """
        self.prompter = prompter
        self.tokenizer: PreTrainedTokenizer = tokenizer
        self.train_on_inputs = train_on_inputs
        self.sequence_len = sequence_len
        self.max_length = sequence_len

    @abc.abstractmethod
    def tokenize_prompt(self, prompt):
        """Tokenize a single dataset row; concrete strategies must implement this."""

    @property
    def supports_batched(self):
        """Batched-input flag; the base implementation always returns ``False``."""
        return False

    def _lookup_special_token(self, token: str) -> Union[int, bool]:
        """
        Return the id of ``token`` in the tokenizer vocabulary, or ``False``.

        ``False`` is returned when the lookup raises ``KeyError``, when the
        result is not a single ``int``, or when the result is the tokenizer's
        unknown-token id. Hugging Face tokenizers map out-of-vocabulary strings
        to ``unk_token_id``, so without the last check a missing special token
        would be reported as present whenever the unknown id is non-zero.
        """
        try:
            token_id = self.tokenizer.convert_tokens_to_ids(token)
        except KeyError:
            return False
        if not isinstance(token_id, int):
            return False
        unk_token_id = getattr(self.tokenizer, "unk_token_id", None)
        if unk_token_id is not None and token_id == unk_token_id:
            return False
        return token_id

    # The two lookups are cached with cached_property rather than lru_cache on the
    # methods: lru_cache would key a class-level cache on ``self`` and keep every
    # strategy instance alive for the lifetime of the cache.
    @functools.cached_property
    def _user_token_id(self) -> Union[int, bool]:
        """Cached result of looking up ``<|USER|>`` for this instance."""
        return self._lookup_special_token("<|USER|>")

    @functools.cached_property
    def _assistant_token_id(self) -> Union[int, bool]:
        """Cached result of looking up ``<|ASSISTANT|>`` for this instance."""
        return self._lookup_special_token("<|ASSISTANT|>")

    def _get_user_token(self):
        """Return the ``<|USER|>`` token id, or ``False`` when the tokenizer lacks it."""
        return self._user_token_id

    def _get_assistant_token(self):
        """Return the ``<|ASSISTANT|>`` token id, or ``False`` when the tokenizer lacks it."""
        return self._assistant_token_id

    def _tokenize(
        self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
    ) -> "BatchEncoding":
        """
        Tokenize ``prompt`` and attach ``labels`` as a copy of ``input_ids``.

        An empty prompt skips the tokenizer and yields empty ``input_ids`` and
        ``attention_mask``. With ``add_eos_token`` the EOS id is appended when
        the sequence is non-empty, does not already end with it and is still
        shorter than ``max_length``. With ``strip_bos_token`` a leading BOS id
        is removed.
        """
        if not prompt:
            LOG.warning("Empty text requested for tokenization.")
            result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
        else:
            result = self.tokenizer(
                prompt,
                truncation=True,
                max_length=self.max_length,
                padding=False,
                return_tensors=None,
            )

        if len(result["input_ids"]) == 0:
            LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
        elif (
            add_eos_token
            and result["input_ids"][-1] != self.tokenizer.eos_token_id
            and len(result["input_ids"]) < self.max_length
        ):
            # Only append EOS while there is room left under the length limit.
            result["input_ids"].append(self.tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        if (
            strip_bos_token
            and len(result["input_ids"]) > 0
            and result["input_ids"][0] == self.tokenizer.bos_token_id
        ):
            result["input_ids"] = result["input_ids"][1:]
            result["attention_mask"] = result["attention_mask"][1:]

        result["labels"] = result["input_ids"].copy()
        return result


class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for instruction-based prompts.
    """

    def parse_instruction_fields(
        self, prompt
    ) -> Union[Tuple[str, str, str], Tuple[str, str, str, str]]:
        """Split a dataset row into its prompt fields; subclasses must override."""
        raise NotImplementedError

    def tokenize_prompt(self, prompt):
        """
        Tokenize the user prompt and the response separately and join them.

        The user part is tokenized without EOS; the response is tokenized with
        EOS and with a leading BOS stripped. Unless ``train_on_inputs`` is set,
        the labels of the user part are replaced by -100.
        """
        instruction, input_text, response = self.parse_instruction_fields(prompt)
        prompt_parts = self.prompter.build_prompt(instruction, input_text)
        user_prompt = next(iter(prompt_parts))

        combined = self._tokenize(user_prompt, add_eos_token=False)
        if not self.train_on_inputs:
            # TODO this could be sped up using numpy array slicing
            combined["labels"] = [-100] * len(combined["input_ids"])

        response_tokens = self._tokenize(
            response, strip_bos_token=True, add_eos_token=True
        )
        # The response ids are appended to the labels as well, so the response
        # part is never masked.
        for target_key, source_key in (
            ("input_ids", "input_ids"),
            ("attention_mask", "attention_mask"),
            ("labels", "input_ids"),
        ):
            combined[target_key] += response_tokens[source_key]

        return combined

    def _build_full_prompt(
        self, instruction, input, response  # pylint: disable=redefined-builtin
    ):
        """Return the first prompt the prompter builds from all three fields."""
        built = self.prompter.build_prompt(
            instruction,
            input,
            response,
        )
        return next(iter(built))


class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for Alpaca prompts.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """Read ``instruction``, optional ``input`` (default ``""``) and ``output``."""
        instruction = prompt["instruction"]
        context = prompt["input"] if "input" in prompt else ""
        response = prompt["output"]
        return instruction, context, response


class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for Alpaca Multiple Choice prompts.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """
        Read ``question``, render ``choices`` as one quoted bullet per line, and
        use ``solution`` as the response, falling back to ``explanation``.
        """
        question = prompt["question"]
        choice_lines = [f'- "{choice}"' for choice in prompt["choices"]]
        if "solution" in prompt:
            answer = prompt["solution"]
        else:
            answer = prompt["explanation"]
        return question, "\n".join(choice_lines), answer


class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for Jeopardy prompts.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """Read ``question`` and ``category``; prefix ``answer`` with ``"what is "``."""
        question = prompt["question"]
        category = prompt["category"]
        response = "what is " + prompt["answer"]
        return question, category, response


class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for OpenAssistant prompts.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """Read ``INSTRUCTION`` and ``RESPONSE``; the input field is always ``""``."""
        instruction = prompt["INSTRUCTION"]
        response = prompt["RESPONSE"]
        return instruction, "", response


class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for SummarizeTLDR prompts.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """Use ``article`` as the instruction and ``summary`` as the response; no input."""
        article = prompt["article"]
        summary = prompt["summary"]
        return article, "", summary


class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for GPTeacher prompts.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """Read ``instruction``, optional ``input`` (default ``""``) and ``response``."""
        instruction = prompt["instruction"]
        context = prompt["input"] if "input" in prompt else ""
        response = prompt["response"]
        return instruction, context, response


class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for NomicGPT4All prompts.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """Use ``prompt`` as the instruction and ``response`` as the response; no input."""
        instruction = prompt["prompt"]
        response = prompt["response"]
        return instruction, "", response


class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for Reflection prompts.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]:
        """
        Split a dataset row into ``(instruction, input, output, reflection,
        corrected)``; subclasses must override.
        """
        raise NotImplementedError

    def tokenize_prompt(self, prompt):
        """
        Tokenize the full reflection prompt in one pass.

        Unless ``train_on_inputs`` is set, the user part (instruction + input)
        is tokenized again on its own to measure its length, and that many
        leading labels are replaced by -100.
        """
        (
            instruction,
            input,  # pylint: disable=redefined-builtin
            output,
            reflection,
            corrected,
        ) = self.parse_instruction_fields(prompt)
        full_prompt = self._build_full_prompt(
            instruction, input, output, reflection, corrected
        )
        tokenized_full_prompt = self._tokenize(full_prompt)
        if not self.train_on_inputs:
            user_prompt = next(
                iter(
                    self.prompter.build_prompt(
                        instruction,
                        input,
                    )
                )
            )
            tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
            full_labels = tokenized_full_prompt["labels"]
            # Clamp the masked prefix to the label length so the labels can never
            # end up longer than input_ids.
            user_prompt_len = min(
                len(tokenized_user_prompt["input_ids"]), len(full_labels)
            )
            # TODO this could be sped up using numpy array slicing
            tokenized_full_prompt["labels"] = [
                -100
            ] * user_prompt_len + full_labels[user_prompt_len:]

        return tokenized_full_prompt

    def _build_full_prompt(
        self, instruction, input, output, reflection, corrected
    ):  # pylint: disable=redefined-builtin
        """Return the first prompt the prompter builds from all five fields."""
        return next(
            iter(
                self.prompter.build_prompt(
                    instruction,
                    input,
                    output,
                    reflection,
                    corrected,
                )
            )
        )

    def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
        """
        Tokenize ``prompt`` with truncation at ``sequence_len`` and attach
        ``labels`` as a copy of ``input_ids``.

        EOS is appended only when requested, when the sequence is non-empty,
        does not already end with EOS and is shorter than ``sequence_len``.
        The non-empty check prevents an ``IndexError`` when the tokenizer
        returns no tokens. ``strip_bos_token`` removes a leading BOS id.
        """
        result = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.sequence_len,
            padding=False,
            return_tensors=None,
        )
        if (
            add_eos_token
            and len(result["input_ids"]) > 0
            and result["input_ids"][-1] != self.tokenizer.eos_token_id
            and len(result["input_ids"]) < self.sequence_len
        ):
            result["input_ids"].append(self.tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        if (
            strip_bos_token
            and len(result["input_ids"]) > 0
            and result["input_ids"][0] == self.tokenizer.bos_token_id
        ):
            result["input_ids"] = result["input_ids"][1:]
            result["attention_mask"] = result["attention_mask"][1:]

        result["labels"] = result["input_ids"].copy()
        return result


class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
    """
    Tokenizing strategy for Alpaca Reflection prompts.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]:
        """
        Read ``instruction``, optional ``input`` (default ``""``), ``output``,
        ``reflection`` and ``corrected`` from the row.
        """
        instruction = prompt["instruction"]
        context = prompt["input"] if "input" in prompt else ""
        output = prompt["output"]
        reflection = prompt["reflection"]
        corrected = prompt["corrected"]
        return instruction, context, output, reflection, corrected


class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for ShareGPT prompts.
    """

    def get_conversation_thread(self, prompt):
        """Return the value stored under the row's ``conversations`` key."""
        return prompt["conversations"]

    def tokenize_prompt(self, prompt):
        """
        Tokenize a conversation turn by turn and concatenate the pieces.

        Each part yielded by the prompter is expected to be a ``(role, text)``
        tuple. User turns and the leading role-less part are fully masked in
        the labels; assistant turns keep their token ids as labels and get an
        EOS token. Parts that are not tuples are skipped with a warning, so
        they can neither re-append the previous turn nor fail on unset locals.

        Raises ``InvalidDataException`` when a ``KeyError``, ``AssertionError``
        or ``IndexError`` occurs while processing the row.
        """
        result, current_len = tokenize_prompt_default()
        user_token = self._get_user_token()
        assistant_token = self._get_assistant_token()
        conversation: Conversation = (
            self.prompter._conversation  # pylint: disable=protected-access
        )
        try:
            for part in self.prompter.build_prompt(
                self.get_conversation_thread(prompt)
            ):
                if not isinstance(part, tuple):
                    LOG.warning(f"skipping non-tuple conversation part: {part!r}")
                    continue

                if conversation.roles[0] in part[0]:
                    # When a dedicated user token exists, the textual role prefix
                    # is dropped and the token id is prepended instead.
                    turn = part[0] + part[1] if not user_token else part[1]
                    # this is still the user query, we should
                    if not part[1].strip():
                        LOG.warning(f"user turn has empty text: {prompt}")
                    res = self._tokenize(
                        turn,
                        add_eos_token=False,
                        strip_bos_token=True,
                    )
                    if user_token:
                        res["input_ids"] = [user_token, *res["input_ids"]]
                    # everything from this is masked out from the labels
                    labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
                elif conversation.roles[1] in part[0]:
                    # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
                    turn = part[0] + part[1] if not assistant_token else part[1]
                    # this should be the assistant response, should end with an eos token
                    if not part[1].strip():
                        LOG.warning(f"assistant turn has empty text: {prompt}")
                    res = self._tokenize(
                        turn,
                        add_eos_token=True,
                        strip_bos_token=True,
                    )
                    if assistant_token:
                        res["input_ids"] = [
                            assistant_token,
                            *res["input_ids"],
                        ]
                    # not masked out from labels
                    labels = copy.deepcopy(res["input_ids"])
                elif part[0] == "":
                    turn = part[1]
                    # this is only ever the first part, should include the bos token and the user query
                    res = self._tokenize(
                        turn, add_eos_token=False, strip_bos_token=False
                    )
                    # everything from this is masked out from the labels
                    labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
                else:
                    LOG.warning(f"unhandled role: {part[0]}")
                    continue

                # The attention mask is rebuilt from input_ids by the helper, so a
                # prepended role token does not need a matching mask entry here.
                # pylint: disable=duplicate-code
                result, current_len = parse_tokenized_to_result(
                    result,
                    current_len,
                    res,
                    labels,
                    pad_token_id=self.tokenizer.pad_token_id,
                )
            return result
        except (KeyError, AssertionError, IndexError) as err:
            raise InvalidDataException(str(err)) from err

    def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
        """
        Tokenize one turn with truncation at ``sequence_len`` and attach
        ``labels`` as a copy of ``input_ids``.

        Whitespace-only text skips the tokenizer and yields empty lists. EOS is
        appended only when requested, when the sequence is non-empty, does not
        already end with EOS and is shorter than ``sequence_len``.
        ``strip_bos_token`` removes a leading BOS id.
        """
        if not prompt.strip():
            LOG.warning("Empty text requested for tokenization.")
            result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
        else:
            result = self.tokenizer(
                prompt,
                truncation=True,
                max_length=self.sequence_len,
                padding=False,
                return_tensors=None,
            )
        if (
            add_eos_token
            and len(result["input_ids"]) > 0
            and result["input_ids"][-1] != self.tokenizer.eos_token_id
            and len(result["input_ids"]) < self.sequence_len
        ):
            result["input_ids"].append(self.tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        if (
            strip_bos_token
            and len(result["input_ids"]) > 0
            and result["input_ids"][0] == self.tokenizer.bos_token_id
        ):
            result["input_ids"] = result["input_ids"][1:]
            result["attention_mask"] = result["attention_mask"][1:]

        result["labels"] = result["input_ids"].copy()
        return result


def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
    """
    Build the empty starting state for tokenizing a prompt: a dict with fresh
    empty ``input_ids``, ``attention_mask`` and ``labels`` lists, and a length of 0
    """

    empty_result: Dict[str, List[int]] = {
        field: [] for field in ("input_ids", "attention_mask", "labels")
    }
    return empty_result, 0


def parse_tokenized_to_result(
    result: Dict[str, List[int]],
    current_len: int,
    res: Dict[str, List[int]],
    labels: List[int],
    pad_token_id: Union[int, None] = None,
) -> Tuple[Dict[str, List[int]], int]:
    """
    Write one tokenized piece into ``result`` starting at ``current_len``.

    ``input_ids`` come from ``res``; the attention mask is derived from them
    (0 where the id equals ``pad_token_id``, 1 otherwise) and ``labels`` are
    stored as given. ``result`` is modified in place and returned together
    with the advanced length.
    """

    token_ids = res["input_ids"]
    next_len = current_len + len(token_ids)
    window = slice(current_len, next_len)

    mask = [0 if token_id == pad_token_id else 1 for token_id in token_ids]

    result["input_ids"][window] = token_ids
    result["attention_mask"][window] = mask
    result["labels"][window] = labels

    return result, next_len