""".. _goal_function:

GoalFunction Class
===========================================================
"""


from abc import ABC, abstractmethod

import lru
import numpy as np
import torch

from textattack.goal_function_results.goal_function_result import (
    GoalFunctionResultStatus,
)
from textattack.shared import validators
from textattack.shared.utils import ReprMixin


class GoalFunction(ReprMixin, ABC):
    """Evaluates how well a perturbed attacked_text object is achieving a
    specified goal.

    Args:
        model_wrapper (:class:`~textattack.models.wrappers.ModelWrapper`):
            The victim model to attack.
        maximizable(:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether the goal function is maximizable, as opposed to a boolean result of success or failure.
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether to keep model outputs in an LRU cache so that repeated inputs are not re-queried.
        query_budget (:obj:`float`, `optional`, defaults to :obj:`float("inf")`):
            The maximum number of model queries allowed.
        model_batch_size (:obj:`int`, `optional`, defaults to :obj:`32`):
            The number of inputs passed to the model in a single call.
        model_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**20`):
            The maximum number of items to keep in the model results cache at once.
    """

    def __init__(
        self,
        model_wrapper,
        maximizable=False,
        use_cache=True,
        query_budget=float("inf"),
        model_batch_size=32,
        model_cache_size=2**20,
    ):
        validators.validate_model_goal_function_compatibility(
            self.__class__, model_wrapper.model.__class__
        )
        self.model = model_wrapper
        self.maximizable = maximizable
        self.use_cache = use_cache
        self.query_budget = query_budget
        self.batch_size = model_batch_size
        # The cache only exists when caching is enabled; every access below is
        # guarded by ``self.use_cache``.
        self._call_model_cache = lru.LRU(model_cache_size) if use_cache else None

    def clear_cache(self):
        """Empties the model output cache (no-op when caching is disabled)."""
        if self.use_cache:
            self._call_model_cache.clear()

    def init_attack_example(self, attacked_text, ground_truth_output):
        """Called before attacking ``attacked_text`` to 'reset' the goal
        function and set properties for this example.

        Resets the query counter and returns the ``(result, search_over)``
        pair obtained by evaluating the unperturbed text with
        ``check_skip=True``.
        """
        self.initial_attacked_text = attacked_text
        self.ground_truth_output = ground_truth_output
        self.num_queries = 0
        return self.get_result(attacked_text, check_skip=True)

    def get_output(self, attacked_text):
        """Returns output for display based on the result of calling the
        model."""
        return self._get_displayed_output(self._call_model([attacked_text])[0])

    def get_result(self, attacked_text, **kwargs):
        """A helper method that queries ``self.get_results`` with a single
        ``AttackedText`` object.

        Returns ``(result, search_over)``; ``result`` is ``None`` when the
        query budget left no room to evaluate the text.
        """
        results, search_over = self.get_results([attacked_text], **kwargs)
        result = results[0] if len(results) else None
        return result, search_over

    def get_results(self, attacked_text_list, check_skip=False):
        """For each attacked_text object in attacked_text_list, returns a
        result consisting of whether or not the goal has been achieved, the
        output for display purposes, and a score.

        Additionally returns whether the search is over due to the query
        budget.

        When the budget is finite, only as many texts as there are queries
        left are evaluated; the remainder of the list is dropped.
        """
        results = []
        budget = self.query_budget
        if budget < float("inf"):
            # Slice bounds must be integers (the budget may be given as a
            # float), and a negative bound would drop items from the *end* of
            # the list instead of yielding nothing, so clamp at zero.
            budget = int(budget)
            queries_left = max(budget - self.num_queries, 0)
            attacked_text_list = attacked_text_list[:queries_left]
        self.num_queries += len(attacked_text_list)
        model_outputs = self._call_model(attacked_text_list)
        result_type = self._goal_function_result_type()
        for attacked_text, raw_output in zip(attacked_text_list, model_outputs):
            displayed_output = self._get_displayed_output(raw_output)
            goal_status = self._get_goal_status(
                raw_output, attacked_text, check_skip=check_skip
            )
            goal_function_score = self._get_score(raw_output, attacked_text)
            results.append(
                result_type(
                    attacked_text,
                    raw_output,
                    displayed_output,
                    goal_status,
                    goal_function_score,
                    self.num_queries,
                    self.ground_truth_output,
                )
            )
        # ``>=`` rather than ``==`` so the search is still reported as over if
        # the counter has somehow moved past the budget.
        return results, self.num_queries >= budget

    def _get_goal_status(self, model_output, attacked_text, check_skip=False):
        """Maps a model output to a ``GoalFunctionResultStatus``.

        Precedence: SKIPPED (only when ``check_skip`` is set), then
        MAXIMIZING for maximizable goal functions, then SUCCEEDED or
        SEARCHING depending on ``_is_goal_complete``.
        """
        if check_skip and self._should_skip(model_output, attacked_text):
            return GoalFunctionResultStatus.SKIPPED
        if self.maximizable:
            return GoalFunctionResultStatus.MAXIMIZING
        if self._is_goal_complete(model_output, attacked_text):
            return GoalFunctionResultStatus.SUCCEEDED
        return GoalFunctionResultStatus.SEARCHING

    @abstractmethod
    def _is_goal_complete(self, model_output, attacked_text):
        """Returns whether ``model_output`` satisfies the goal."""
        raise NotImplementedError()

    def _should_skip(self, model_output, attacked_text):
        """Returns whether the example should be skipped; by default this is
        the case when the goal is already complete."""
        return self._is_goal_complete(model_output, attacked_text)

    @abstractmethod
    def _get_score(self, model_output, attacked_text):
        """Returns the goal function score for ``model_output``."""
        raise NotImplementedError()

    def _get_displayed_output(self, raw_output):
        """Returns the output to display; by default the raw output itself."""
        return raw_output

    @abstractmethod
    def _goal_function_result_type(self):
        """Returns the class of this goal function's results."""
        raise NotImplementedError()

    @abstractmethod
    def _process_model_outputs(self, inputs, outputs):
        """Processes and validates a list of model outputs.

        This is a task-dependent operation. For example, classification
        outputs need to make sure they have a softmax applied.
        """
        raise NotImplementedError()

    def _call_model_uncached(self, attacked_text_list):
        """Queries model and returns outputs for a list of AttackedText
        objects.

        Inputs are sent to the model in chunks of ``self.batch_size`` and the
        collected predictions are handed to ``_process_model_outputs``.
        """
        if not len(attacked_text_list):
            return []

        inputs = [at.tokenizer_input for at in attacked_text_list]
        outputs = []
        for start in range(0, len(inputs), self.batch_size):
            batch = inputs[start : start + self.batch_size]
            batch_preds = self.model(batch)

            # Some seq-to-seq models will return a single string as a prediction
            # for a single-string list. Wrap these in a list.
            if isinstance(batch_preds, str):
                batch_preds = [batch_preds]

            # Get PyTorch tensors off of other devices.
            if isinstance(batch_preds, torch.Tensor):
                batch_preds = batch_preds.cpu()

            if isinstance(batch_preds, list):
                outputs.extend(batch_preds)
            elif isinstance(batch_preds, np.ndarray):
                outputs.append(torch.tensor(batch_preds))
            else:
                outputs.append(batch_preds)

        # Per-batch tensors are joined along the first dimension so that
        # ``outputs`` has one row per input.
        if isinstance(outputs[0], torch.Tensor):
            outputs = torch.cat(outputs, dim=0)

        assert len(inputs) == len(
            outputs
        ), f"Got {len(outputs)} outputs for {len(inputs)} inputs"

        return self._process_model_outputs(attacked_text_list, outputs)

    def _call_model(self, attacked_text_list):
        """Gets predictions for a list of ``AttackedText`` objects.

        Gets prediction from cache if possible. If prediction is not in
        the cache, queries model and stores prediction in cache.
        """
        if not self.use_cache:
            return self._call_model_uncached(attacked_text_list)

        # Outputs for this call are collected locally so that the final
        # lookup cannot fail if storing new entries evicts older ones from
        # the bounded cache.
        known_outputs = {}
        uncached_list = []
        for text in attacked_text_list:
            if text in self._call_model_cache:
                cached_output = self._call_model_cache[text]
                # Re-write value in cache. This moves the key to the top of the
                # LRU cache and prevents the unlikely event that the text
                # is overwritten when we store the inputs from `uncached_list`.
                self._call_model_cache[text] = cached_output
                known_outputs[text] = cached_output
            else:
                uncached_list.append(text)

        fresh_outputs = self._call_model_uncached(uncached_list)
        for text, output in zip(uncached_list, fresh_outputs):
            self._call_model_cache[text] = output
            known_outputs[text] = output

        return [known_outputs[text] for text in attacked_text_list]

    def extra_repr_keys(self):
        """Returns the attribute names to include in the repr: only
        non-default ``query_budget`` and ``maximizable`` settings."""
        attrs = []
        if self.query_budget < float("inf"):
            attrs.append("query_budget")
        if self.maximizable:
            attrs.append("maximizable")
        return attrs

    def __getstate__(self):
        """Returns the instance state for pickling.

        The cache object itself is not stored; it is replaced by the value
        of its ``get_size()`` so that ``__setstate__`` can build a fresh one.
        """
        state = self.__dict__.copy()
        if self.use_cache:
            state["_call_model_cache"] = self._call_model_cache.get_size()
        return state

    def __setstate__(self, state):
        """Restores the instance state, recreating an empty cache from the
        size stored by ``__getstate__``."""
        self.__dict__ = state
        if self.use_cache:
            self._call_model_cache = lru.LRU(state["_call_model_cache"])