import copy
import os
import tempfile
import types
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Set, Tuple, Union

import torch
from huggingface_hub.utils import SoftTemporaryDirectory

from setfit.utils import set_docstring

from .. import logging
from ..modeling import SetFitModel
from .aspect_extractor import AspectExtractor


if TYPE_CHECKING:
    from spacy.tokens import Doc

# Module-level logger; used e.g. by ``AbsaModel.from_pretrained`` to warn about mismatched spaCy models.
logger = logging.get_logger(__name__)


@dataclass
class SpanSetFitModel(SetFitModel):
    """SetFit model that classifies spans (aspect candidates) inside spaCy documents.

    Every span is converted into a text input of the form ``"<span text>:<document text>"``,
    after which the regular SetFit ``predict`` is used to classify it.

    Attributes:
        spacy_model: Name of the spaCy model associated with this model. It is listed in
            ``attributes_to_save`` and reported in the model card.
        span_context: Number of tokens before and after each span that are included in the
            span text that gets prepended to the full document text.
    """

    spacy_model: str = "en_core_web_lg"
    span_context: int = 0

    attributes_to_save: Set[str] = field(
        init=False,
        repr=False,
        default_factory=lambda: {"normalize_embeddings", "labels", "span_context", "spacy_model"},
    )

    def prepend_aspects(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> Iterator[str]:
        """Lazily yield one classifier input per aspect span.

        Each span is widened by ``self.span_context`` tokens on both sides (the start is clamped
        at 0), and the text of the widened span is prepended to the full document text,
        separated by ``":"``.

        Args:
            docs: spaCy documents, one per input text.
            aspects_list: For each document, the token slices of its aspect candidates.

        Yields:
            Strings of the form ``"<span text>:<document text>"``, ordered by document and then
            by span.
        """
        for doc, aspects in zip(docs, aspects_list):
            for aspect_slice in aspects:
                start = max(aspect_slice.start - self.span_context, 0)
                stop = aspect_slice.stop + self.span_context
                # TODO: Investigate performance difference of different formats
                yield doc[start:stop].text + ":" + doc.text

    def __call__(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> List[List[Any]]:
        """Classify every aspect span and group the predictions per document.

        Args:
            docs: spaCy documents, one per input text.
            aspects_list: For each document, the token slices of its aspect candidates.

        Returns:
            One list per entry of ``aspects_list`` with the predictions for that document's
            spans, in the same order as the spans.
        """
        inputs_list = list(self.prepend_aspects(docs, aspects_list))
        if not inputs_list:
            # Nothing to classify: don't call ``predict`` with an empty batch.
            return [[] for _ in aspects_list]
        preds = self.predict(inputs_list, as_numpy=True)
        # Predictions are flat; consume them in the same (document, span) order they were built.
        iter_preds = iter(preds)
        return [[next(iter_preds) for _ in aspects] for aspects in aspects_list]

    def create_model_card(self, path: str, model_name: Optional[str] = None) -> None:
        """Creates and saves a model card for a SetFit ABSA span model.

        The card is written to ``<path>/README.md``. It contains ABSA-specific metadata: whether
        this is the aspect or the polarity model, the spaCy model, and the names of the paired
        aspect and polarity models.

        Args:
            path (str): The directory to save the model card to. It is created if missing.
            model_name (str, *optional*): The name of the model (e.g. a Hub repository ID or a local
                path). When given, it is also used to derive the name of the paired aspect/polarity model.
        """
        os.makedirs(path, exist_ok=True)

        if model_name is not None:
            # If the model_path is a folder that exists locally, i.e. when create_model_card is called
            # via push_to_hub, and the path is in a temporary folder, then we only take the last two
            # directories. Both sides are resolved so symlinked temporary directories
            # (e.g. /var -> /private/var on macOS) are still recognised.
            model_path = Path(model_name)
            temp_dir = Path(tempfile.gettempdir()).resolve()
            if model_path.exists() and temp_dir in model_path.resolve().parents:
                model_name = "/".join(model_path.parts[-2:])

        is_aspect = isinstance(self, AspectModel)
        aspect_model = "setfit-absa-aspect"
        polarity_model = "setfit-absa-polarity"
        if model_name is not None:
            if is_aspect:
                aspect_model = model_name
                if model_name.endswith("-aspect"):
                    polarity_model = model_name[: -len("-aspect")] + "-polarity"
            else:
                polarity_model = model_name
                if model_name.endswith("-polarity"):
                    aspect_model = model_name[: -len("-polarity")] + "-aspect"

        # Only once: `absa` is still unset the first time a card is created for this model.
        if self.model_card_data.absa is None and self.model_card_data.model_name:
            from spacy import __version__ as spacy_version

            self.model_card_data.model_name = self.model_card_data.model_name.replace(
                "SetFit", "SetFit Aspect Model" if is_aspect else "SetFit Polarity Model", 1
            )
            self.model_card_data.tags.insert(1, "absa")
            self.model_card_data.version["spacy"] = spacy_version
        self.model_card_data.absa = {
            "is_absa": True,
            "is_aspect": is_aspect,
            "spacy_model": self.spacy_model,
            "aspect_model": aspect_model,
            "polarity_model": polarity_model,
        }
        if self.model_card_data.task_name is None:
            self.model_card_data.task_name = "Aspect Based Sentiment Analysis (ABSA)"
        self.model_card_data.inference = False
        with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f:
            f.write(self.generate_model_card())


# Replace the tail of the inherited ``from_pretrained`` docstring: everything from the
# ``multi_target_strategy`` entry onwards is swapped for the arguments relevant to span models.
docstring = SpanSetFitModel.from_pretrained.__doc__
# ``__doc__`` is ``None`` when Python runs with ``-OO`` (docstrings stripped); skip the rewrite then
# instead of crashing at import time.
cut_index = docstring.find("multi_target_strategy") if docstring is not None else -1
if cut_index != -1:
    docstring = (
        docstring[:cut_index]
        + """model_card_data (`SetFitModelCardData`, *optional*):
                A `SetFitModelCardData` instance storing data such as model language, license, dataset name,
                    etc. to be used in the automatically generated model cards.
            use_differentiable_head (`bool`, *optional*):
                Whether to load SetFit using a differentiable (i.e., Torch) head instead of Logistic Regression.
            normalize_embeddings (`bool`, *optional*):
                Whether to apply normalization on the embeddings produced by the Sentence Transformer body.
            span_context (`int`, defaults to `0`):
                The number of words before and after the span candidate that should be prepended to the full sentence.
                By default, 0 for Aspect models and 3 for Polarity models.
            device (`Union[torch.device, str]`, *optional*):
                The device on which to load the SetFit model, e.g. `"cuda:0"`, `"mps"` or `torch.device("cuda")`."""
    )
    SpanSetFitModel.from_pretrained = set_docstring(SpanSetFitModel.from_pretrained, docstring, cls=SpanSetFitModel)


class AspectModel(SpanSetFitModel):
    """Span model that decides which aspect candidates are real aspects (label ``"aspect"``)."""

    def __call__(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> List[List[slice]]:
        """Keep only the candidate spans predicted as ``"aspect"``.

        Args:
            docs: spaCy documents, one per input text.
            aspects_list: For each document, the token slices of its aspect candidates.

        Returns:
            For each document, the subset of its candidate slices whose prediction equals
            ``"aspect"``, in their original order.
        """
        sentence_preds = super().__call__(docs, aspects_list)
        return [
            [aspect for aspect, pred in zip(aspects, preds) if pred == "aspect"]
            for aspects, preds in zip(aspects_list, sentence_preds)
        ]


# ``set_docstring`` is called with ``cls=SpanSetFitModel`` when customising the ``from_pretrained`` docstring.
# As a consequence, subclasses must re-bind the underlying function to themselves; otherwise
# ``from_pretrained`` would instantiate the wrong class.
setattr(
    AspectModel,
    "from_pretrained",
    types.MethodType(AspectModel.from_pretrained.__func__, AspectModel),
)


@dataclass
class PolarityModel(SpanSetFitModel):
    """Span model whose predictions are reported as the ``"polarity"`` of each aspect span.

    Identical to ``SpanSetFitModel`` except that, by default, 3 tokens of context on either
    side of the span are included in the span text.
    """

    span_context: int = field(default=3)


# Re-bind ``from_pretrained`` so that it instantiates ``PolarityModel`` rather than ``SpanSetFitModel``.
setattr(
    PolarityModel,
    "from_pretrained",
    types.MethodType(PolarityModel.from_pretrained.__func__, PolarityModel),
)


@dataclass
class AbsaModel:
    """End-to-end Aspect Based Sentiment Analysis (ABSA) pipeline.

    The pipeline has three stages:

    1. ``aspect_extractor`` turns raw texts into spaCy documents plus candidate aspect spans,
    2. ``aspect_model`` keeps only the candidates that it predicts to be aspects,
    3. ``polarity_model`` predicts a polarity for every remaining aspect.
    """

    aspect_extractor: AspectExtractor
    aspect_model: AspectModel
    polarity_model: PolarityModel

    def predict(
        self, inputs: Union[str, List[str]]
    ) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
        """Predict the aspects in the input text(s) together with their polarities.

        Args:
            inputs: A single text, or a list of texts.

        Returns:
            For a single text, a list of ``{"span": ..., "polarity": ...}`` dictionaries. For a
            list of texts, one such list per text. Texts without aspects produce an empty list.
        """
        is_str = isinstance(inputs, str)
        inputs_list = [inputs] if is_str else inputs
        docs, aspects_list = self.aspect_extractor(inputs_list)
        if not any(aspects_list):
            # No candidate spans at all: skip both models.
            return [] if is_str else aspects_list

        aspects_list = self.aspect_model(docs, aspects_list)
        if not any(aspects_list):
            # No candidate was accepted as an aspect: skip the polarity model.
            return [] if is_str else aspects_list

        polarity_list = self.polarity_model(docs, aspects_list)
        outputs = [
            [
                {"span": doc[aspect_slice].text, "polarity": polarity}
                for aspect_slice, polarity in zip(aspects, polarities)
            ]
            for doc, aspects, polarities in zip(docs, aspects_list, polarity_list)
        ]
        return outputs[0] if is_str else outputs

    @property
    def device(self) -> torch.device:
        """The device of the aspect model (``to`` keeps both models on the same device)."""
        return self.aspect_model.device

    def to(self, device: Union[str, torch.device]) -> "AbsaModel":
        """Move both the aspect and the polarity model to ``device``.

        Returns:
            This ``AbsaModel``, so that ``model = model.to("cuda")`` works as expected.
        """
        self.aspect_model.to(device)
        self.polarity_model.to(device)
        return self

    def __call__(
        self, inputs: Union[str, List[str]]
    ) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
        """Alias of :meth:`predict`."""
        return self.predict(inputs)

    def save_pretrained(
        self,
        save_directory: Union[str, Path],
        polarity_save_directory: Optional[Union[str, Path]] = None,
        push_to_hub: bool = False,
        **kwargs,
    ) -> None:
        """Save the aspect and polarity models.

        Args:
            save_directory: Directory for the aspect model. If ``polarity_save_directory`` is not
                given, this is used as a base name instead: the models are then saved to the sibling
                directories ``<save_directory>-aspect`` and ``<save_directory>-polarity``.
            polarity_save_directory: Directory for the polarity model.
            push_to_hub: Forwarded to both models' ``save_pretrained``.
            **kwargs: Forwarded to both models' ``save_pretrained``.
        """
        if polarity_save_directory is None:
            base_save_directory = Path(save_directory)
            save_directory = base_save_directory.parent / (base_save_directory.name + "-aspect")
            polarity_save_directory = base_save_directory.parent / (base_save_directory.name + "-polarity")
        self.aspect_model.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
        self.polarity_model.save_pretrained(save_directory=polarity_save_directory, push_to_hub=push_to_hub, **kwargs)

    @staticmethod
    def _split_revision(model_id: str) -> Tuple[str, Optional[str]]:
        """Split ``"model_id@revision"`` into ``(model_id, revision)``.

        Only IDs containing exactly one ``"@"`` are split. Any other ID is returned unchanged,
        with ``None`` as the revision.
        """
        parts = model_id.split("@")
        if len(parts) == 2:
            return parts[0], parts[1]
        return model_id, None

    @classmethod
    def from_pretrained(
        cls,
        model_id: str,
        polarity_model_id: Optional[str] = None,
        spacy_model: Optional[str] = None,
        span_contexts: Tuple[Optional[int], Optional[int]] = (None, None),
        force_download: Optional[bool] = None,
        resume_download: Optional[bool] = None,
        proxies: Optional[Dict] = None,
        token: Optional[Union[str, bool]] = None,
        cache_dir: Optional[str] = None,
        local_files_only: Optional[bool] = None,
        use_differentiable_head: Optional[bool] = None,
        normalize_embeddings: Optional[bool] = None,
        **model_kwargs,
    ) -> "AbsaModel":
        """Load an ``AbsaModel`` from a pretrained aspect model and polarity model.

        Args:
            model_id: ID or path of the aspect model, optionally suffixed with ``@revision``.
            polarity_model_id: ID or path of the polarity model, optionally suffixed with
                ``@revision``. If omitted, ``model_id`` (and its revision) is used for the
                polarity model as well.
            spacy_model: If given, overrides the spaCy model of both models.
            span_contexts: ``span_context`` for the aspect and the polarity model, respectively.
            **model_kwargs: Forwarded to both models' ``from_pretrained``. A ``model_card_data``
                entry is used as-is for the aspect model; the polarity model gets a deep copy.

        Returns:
            The loaded ``AbsaModel``. Its aspect extractor uses the aspect model's spaCy model.
        """
        model_id, revision = cls._split_revision(model_id)
        if spacy_model:
            model_kwargs["spacy_model"] = spacy_model

        # If model_card_data was provided, "separate" the instance between the Aspect and Polarity
        # models. The copy is taken before the Aspect model is loaded, so it does not reflect any
        # changes the Aspect model may make to the instance it receives.
        polarity_model_card_data = model_kwargs.get("model_card_data")
        if polarity_model_card_data:
            polarity_model_card_data = copy.deepcopy(polarity_model_card_data)

        aspect_model = AspectModel.from_pretrained(
            model_id,
            span_context=span_contexts[0],
            revision=revision,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            token=token,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            use_differentiable_head=use_differentiable_head,
            normalize_embeddings=normalize_embeddings,
            labels=["no aspect", "aspect"],
            **model_kwargs,
        )

        if polarity_model_id:
            model_id, revision = cls._split_revision(polarity_model_id)
        model_kwargs.pop("model_card_data", None)
        if polarity_model_card_data:
            model_kwargs["model_card_data"] = polarity_model_card_data
        polarity_model = PolarityModel.from_pretrained(
            model_id,
            span_context=span_contexts[1],
            revision=revision,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            token=token,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            use_differentiable_head=use_differentiable_head,
            normalize_embeddings=normalize_embeddings,
            **model_kwargs,
        )
        if aspect_model.spacy_model != polarity_model.spacy_model:
            logger.warning(
                "The Aspect and Polarity models are configured to use different spaCy models:\n"
                f"* {repr(aspect_model.spacy_model)} for the aspect model, and\n"
                f"* {repr(polarity_model.spacy_model)} for the polarity model.\n"
                f"This model will use {repr(aspect_model.spacy_model)}."
            )

        aspect_extractor = AspectExtractor(spacy_model=aspect_model.spacy_model)

        return cls(aspect_extractor, aspect_model, polarity_model)

    def push_to_hub(self, repo_id: str, polarity_repo_id: Optional[str] = None, **kwargs) -> None:
        """Save both models to a temporary directory and push them to the Hugging Face Hub.

        Args:
            repo_id: Full repository ID (``"organisation/name"``) for the aspect model. If
                ``polarity_repo_id`` is omitted, it is used as a base name instead, and the models
                are pushed as ``<repo_id>-aspect`` and ``<repo_id>-polarity``.
            polarity_repo_id: Full repository ID for the polarity model.
            **kwargs: Forwarded to :meth:`save_pretrained`. ``commit_message`` defaults to
                ``"Add SetFit ABSA model"``.

        Raises:
            ValueError: If a given repository ID does not contain an organisation (no ``"/"``).
        """
        if "/" not in repo_id:
            raise ValueError(
                '`repo_id` must be a full repository ID, including organisation, e.g. "tomaarsen/setfit-absa-restaurant".'
            )
        if polarity_repo_id is not None and "/" not in polarity_repo_id:
            raise ValueError(
                '`polarity_repo_id` must be a full repository ID, including organisation, e.g. "tomaarsen/setfit-absa-restaurant".'
            )
        commit_message = kwargs.pop("commit_message", "Add SetFit ABSA model")

        # Push the files to the repo in a single commit
        with SoftTemporaryDirectory() as tmp_dir:
            save_directory = Path(tmp_dir) / repo_id
            polarity_save_directory = None if polarity_repo_id is None else Path(tmp_dir) / polarity_repo_id
            self.save_pretrained(
                save_directory=save_directory,
                polarity_save_directory=polarity_save_directory,
                push_to_hub=True,
                commit_message=commit_message,
                **kwargs,
            )