import copy
import random

import torch
import torch.nn as nn
from transformers import CLIPTokenizer
from typing import Any, List, Optional, Union

class TokenizerWrapper:
    """Tokenizer wrapper that lets one placeholder expand to several tokens.

    Only ``CLIPTokenizer`` is supported currently. This wrapper is modified
    from https://github.com/huggingface/diffusers/blob/e51f19aee82c8dd874b715a09dbc521d88835d68/src/diffusers/loaders.py#L358  # noqa.

    Attributes that are not found on the wrapper itself are looked up on the
    wrapped tokenizer (see ``__getattr__``).

    Args:
        tokenizer (CLIPTokenizer): The tokenizer instance to be wrapped.
    """

    def __init__(self, tokenizer: "CLIPTokenizer"):
        self.wrapped = tokenizer
        # Maps each placeholder token to the list of real tokens it expands to.
        self.token_map = {}

    def __getattr__(self, name: str) -> Any:
        """Forward unknown attribute lookups to the wrapped tokenizer.

        Python only calls ``__getattr__`` after the normal lookup failed, so
        the attribute is known to be missing on the wrapper itself.

        Raises:
            AttributeError: If neither the wrapper nor the wrapped tokenizer
                has the attribute.
        """
        # `wrapped` can be missing while the instance is being rebuilt
        # (``__new__`` without ``__init__``, e.g. copy / pickle). Looking it up
        # through `self.wrapped` would recurse forever.
        if name == "wrapped":
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute 'wrapped'"
            )

        try:
            return getattr(self.wrapped, name)
        except AttributeError:
            raise AttributeError(
                f"'{name}' cannot be found in both "
                f"'{self.__class__.__name__}' and "
                f"'{self.__class__.__name__}.wrapped'."
            ) from None

    def try_adding_tokens(self, tokens: Union[str, List[str]], *args, **kwargs):
        """Attempt to add tokens to the tokenizer.

        Args:
            tokens (Union[str, List[str]]): The tokens to be added.
            *args, **kwargs: Forwarded to `self.wrapped.add_tokens`.

        Raises:
            AssertionError: If `self.wrapped.add_tokens` reports that no token
                was added.
        """
        num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs)
        assert num_added_tokens != 0, (
            f"The tokenizer already contains the token {tokens}. Please pass "
            "a different `placeholder_token` that is not already in the "
            "tokenizer."
        )

    def get_token_info(self, token: str) -> dict:
        """Get the information of a token, including its start and end index in
        the current tokenizer.

        The token is tokenized through this wrapper, so a placeholder is
        expanded first. ``start`` is ``input_ids[1]`` and ``end`` is
        ``input_ids[-2] + 1``.
        NOTE(review): this presumes the wrapped tokenizer adds exactly one
        special token at each end of ``input_ids`` -- confirm for tokenizers
        other than CLIP.

        Args:
            token (str): The token to be queried.

        Returns:
            dict: ``{"name": token, "start": start, "end": end}``.
        """
        token_ids = self(token).input_ids
        start, end = token_ids[1], token_ids[-2] + 1
        return {"name": token, "start": start, "end": end}

    def add_placeholder_token(self, placeholder_token: str, *args, num_vec_per_token: int = 1, **kwargs):
        """Add placeholder tokens to the tokenizer.

        With ``num_vec_per_token == 1`` the placeholder itself is added;
        otherwise the tokens ``f"{placeholder_token}_{i}"`` are added.

        Args:
            placeholder_token (str): The placeholder token to be added.
            num_vec_per_token (int, optional): The number of vectors of
                the added placeholder token.
            *args, **kwargs: The arguments for `self.wrapped.add_tokens`.

        Raises:
            ValueError: If an already registered placeholder is a substring
                of `placeholder_token`.
            AssertionError: If the wrapped tokenizer did not add a token.
        """
        # Validate before touching the wrapped tokenizer so that a rejected
        # placeholder does not leave orphan tokens behind.
        for token in self.token_map:
            if token in placeholder_token:
                raise ValueError(
                    f"The tokenizer already has placeholder token {token} "
                    f"that can get confused with {placeholder_token} "
                    "keep placeholder tokens independent"
                )

        if num_vec_per_token == 1:
            output = [placeholder_token]
        else:
            output = [f"{placeholder_token}_{i}" for i in range(num_vec_per_token)]

        for new_token in output:
            self.try_adding_tokens(new_token, *args, **kwargs)

        self.token_map[placeholder_token] = output

    def replace_placeholder_tokens_in_text(
        self, text: Union[str, List[str]], vector_shuffle: bool = False, prop_tokens_to_load: float = 1.0
    ) -> Union[str, List[str]]:
        """Replace the keywords in text with placeholder tokens. This function
        will be called in `self.__call__` and `self.encode`.

        Args:
            text (Union[str, List[str]]): The text to be processed.
            vector_shuffle (bool, optional): Whether to shuffle the vectors.
                Defaults to False.
            prop_tokens_to_load (float, optional): The proportion of tokens to
                be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.

        Returns:
            Union[str, List[str]]: The processed text.
        """
        if isinstance(text, list):
            # Both options must reach every element of the batch.
            return [
                self.replace_placeholder_tokens_in_text(
                    item, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
                )
                for item in text
            ]

        for placeholder_token, all_tokens in self.token_map.items():
            if placeholder_token not in text:
                continue
            # Always keep at least the first token.
            tokens = all_tokens[: 1 + int(len(all_tokens) * prop_tokens_to_load)]
            if vector_shuffle:
                # Shuffle a copy so the registered order stays intact.
                tokens = copy.copy(tokens)
                random.shuffle(tokens)
            text = text.replace(placeholder_token, " ".join(tokens))
        return text

    def replace_text_with_placeholder_tokens(self, text: Union[str, List[str]]) -> Union[str, List[str]]:
        """Replace the placeholder tokens in text with the original keywords.
        This function will be called in `self.decode`.

        Only the full, space-joined token sequence of a placeholder is
        collapsed back to the placeholder.

        Args:
            text (Union[str, List[str]]): The text to be processed.

        Returns:
            Union[str, List[str]]: The processed text.
        """
        if isinstance(text, list):
            return [self.replace_text_with_placeholder_tokens(item) for item in text]

        for placeholder_token, tokens in self.token_map.items():
            merged_tokens = " ".join(tokens)
            if merged_tokens in text:
                text = text.replace(merged_tokens, placeholder_token)
        return text

    def __call__(
        self,
        text: Union[str, List[str]],
        *args,
        vector_shuffle: bool = False,
        prop_tokens_to_load: float = 1.0,
        **kwargs,
    ):
        """The call function of the wrapper.

        Args:
            text (Union[str, List[str]]): The text to be tokenized.
            vector_shuffle (bool, optional): Whether to shuffle the vectors.
                Defaults to False.
            prop_tokens_to_load (float, optional): The proportion of tokens to
                be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0
            *args, **kwargs: The arguments for `self.wrapped.__call__`.

        Returns:
            Whatever `self.wrapped.__call__` returns for the expanded text.
        """
        replaced_text = self.replace_placeholder_tokens_in_text(
            text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
        )

        return self.wrapped.__call__(replaced_text, *args, **kwargs)

    def encode(self, text: Union[str, List[str]], *args, **kwargs):
        """Encode the passed text to token index.

        Placeholders are expanded with the default options (no shuffle, all
        tokens), then the wrapped tokenizer is called.

        Args:
            text (Union[str, List[str]]): The text to be encode.
            *args, **kwargs: The arguments for `self.wrapped.__call__`.
        """
        replaced_text = self.replace_placeholder_tokens_in_text(text)
        return self.wrapped(replaced_text, *args, **kwargs)

    def decode(self, token_ids, return_raw: bool = False, *args, **kwargs) -> Union[str, List[str]]:
        """Decode the token index to text.

        Args:
            token_ids: The token index to be decoded.
            return_raw: If True, return the text exactly as decoded by the
                wrapped tokenizer, without collapsing expanded tokens back to
                their placeholder. Defaults to False.
            *args, **kwargs: The arguments for `self.wrapped.decode`.

        Returns:
            Union[str, List[str]]: The decoded text.
        """
        text = self.wrapped.decode(token_ids, *args, **kwargs)
        if return_raw:
            return text
        return self.replace_text_with_placeholder_tokens(text)

    def __repr__(self):
        """The representation of the wrapper."""
        wrapped_cls = type(self.wrapped)
        prefix = f"Wrapped Module Class: {wrapped_cls}\n"
        prefix += f"Wrapped Module Name: {wrapped_cls.__name__}\n"
        # Only shown when the wrapped tokenizer exposes a non-empty
        # `name_or_path` attribute.
        from_pretrained = getattr(self.wrapped, "name_or_path", None)
        if from_pretrained:
            prefix += f"From Pretrained: {from_pretrained}\n"
        return prefix + super().__repr__()
    

class EmbeddingLayerWithFixes(nn.Module):
    """The revised embedding layer to support external embeddings. This design
    of this class is inspired by https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/22bcc7be428c94e9408f589966c2040187245d81/modules/sd_hijack.py#L224  # noqa.

    Token ids that are ``>= wrapped.weight.shape[0]`` are treated as external:
    they are looked up as id 0 in the wrapped layer and the resulting rows
    are then replaced by the registered external embedding.

    Args:
        wrapped (nn.Embedding): The embedding layer to be wrapped.
        external_embeddings (Union[dict, List[dict]], optional): The external
            embeddings added to this layer. Defaults to None.
    """

    def __init__(self, wrapped: nn.Embedding, external_embeddings: Optional[Union[dict, List[dict]]] = None):
        super().__init__()
        self.wrapped = wrapped
        self.num_embeddings = wrapped.weight.shape[0]

        self.external_embeddings = []
        # Must exist before `add_embeddings` runs, which registers trainable
        # embeddings in this container.
        self.trainable_embeddings = nn.ParameterDict()

        if external_embeddings:
            self.add_embeddings(external_embeddings)

    @property
    def weight(self):
        """Get the weight of wrapped embedding layer."""
        return self.wrapped.weight

    def check_duplicate_names(self, embeddings: List[dict]):
        """Check whether duplicate names exist in list of 'external
        embeddings'.

        Args:
            embeddings (List[dict]): A list of embedding to be check.

        Raises:
            AssertionError: If two entries share the same 'name'.
        """
        names = [emb["name"] for emb in embeddings]
        assert len(names) == len(set(names)), (
            "Found duplicated names in 'external_embeddings'. Name list: " f"'{names}'"
        )

    def check_ids_overlap(self, embeddings):
        """Check whether overlap exist in token ids of 'external_embeddings'.

        Args:
            embeddings (List[dict]): A list of embedding to be check.

        Raises:
            AssertionError: If the ``[start, end)`` ranges of two entries
                overlap.
        """
        ids_range = sorted([emb["start"], emb["end"], emb["name"]] for emb in embeddings)
        # After sorting by 'start', every range must end before the next begins.
        for current, following in zip(ids_range, ids_range[1:]):
            name1, name2 = current[-1], following[-1]
            assert current[1] <= following[0], (
                f"Found ids overlapping between embeddings '{name1}' " f"and '{name2}'."
            )

    def add_embeddings(self, embeddings: Optional[Union[dict, List[dict]]]):
        """Add external embeddings to this layer.

        The new entries are validated together with the already registered
        ones before anything is stored, so a failed check leaves the layer
        unchanged. Entries with ``'trainable': True`` get their tensor wrapped
        in ``nn.Parameter`` (the dict is updated in place) and registered in
        ``self.trainable_embeddings`` under their name.

        Use case (token ids are written by hand here; in practice 'start' and
        'end' come from ``TokenizerWrapper.get_token_info``):

        >>> layer = EmbeddingLayerWithFixes(nn.Embedding(10, 15))
        >>> layer.add_embeddings({
        ...     'name': 'ngapi',  # 'how much' in kiswahili
        ...     'embedding': torch.ones(4, 15) * 2.3,
        ...     'start': 10,
        ...     'end': 14,
        ...     'trainable': False,  # if True, registered as a parameter
        ... })
        >>> input_ids = torch.tensor([[1, 10, 11, 12, 13, 2]])
        >>> out_feat = layer(input_ids)
        >>> assert (out_feat[0, 1:5] == 2.3).all()

        Args:
            embeddings (Union[dict, list[dict]]): The external embeddings to
                be added. Each dict must contain the following 4 fields: 'name'
                (the name of this embedding), 'embedding' (the embedding
                tensor), 'start' (the start token id of this embedding), 'end'
                (the end token id of this embedding). For example:
                `{name: NAME, start: START, end: END, embedding: torch.Tensor}`
                An optional 'trainable' field (default False) is also read.
                ``None`` is accepted and adds nothing.

        Raises:
            AssertionError: If names are duplicated or id ranges overlap.
        """
        if embeddings is None:
            return
        if isinstance(embeddings, dict):
            embeddings = [embeddings]
        embeddings = list(embeddings)

        # Validate first, commit afterwards.
        candidates = self.external_embeddings + embeddings
        self.check_duplicate_names(candidates)
        self.check_ids_overlap(candidates)
        self.external_embeddings.extend(embeddings)

        # set for trainable
        added_trainable_emb_info = []
        for embedding in embeddings:
            if embedding.get("trainable", False):
                name = embedding["name"]
                embedding["embedding"] = torch.nn.Parameter(embedding["embedding"])
                self.trainable_embeddings[name] = embedding["embedding"]
                added_trainable_emb_info.append(name)

        added_emb_info = ", ".join(emb["name"] for emb in embeddings)
        print(f"Successfully add external embeddings: {added_emb_info}.")

        if added_trainable_emb_info:
            added_trainable_emb_info = ", ".join(added_trainable_emb_info)
            print("Successfully add trainable external embeddings: " f"{added_trainable_emb_info}")

    def replace_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Replace external input ids to 0.

        Every id ``>= self.num_embeddings`` is set to 0 on a copy of the
        input so that the wrapped layer can be called without an index
        error.

        Args:
            input_ids (torch.Tensor): The input ids to be replaced.

        Returns:
            torch.Tensor: The replaced input ids.
        """
        input_ids_fwd = input_ids.clone()
        input_ids_fwd[input_ids_fwd >= self.num_embeddings] = 0
        return input_ids_fwd

    def replace_embeddings(
        self, input_ids: torch.Tensor, embedding: torch.Tensor, external_embedding: dict
    ) -> torch.Tensor:
        """Replace external embedding to the embedding layer. Noted that, in
        this function we use `torch.cat` to avoid inplace modification.

        Every occurrence of the id run ``start, start + 1, ..., end - 1`` in
        `input_ids` has its rows in `embedding` replaced by the external
        embedding, including occurrences that directly follow each other.

        Args:
            input_ids (torch.Tensor): The original token ids. Shape like
                [LENGTH, ].
            embedding (torch.Tensor): The embedding of token ids after
                `replace_input_ids` function.
            external_embedding (dict): The external embedding to be replaced.

        Returns:
            torch.Tensor: The replaced embedding.

        Raises:
            AssertionError: If `start` is found but is not followed by the
                complete id run (e.g. the run was truncated).
        """
        name = external_embedding["name"]
        start = external_embedding["start"]
        end = external_embedding["end"]
        ext_emb = external_embedding["embedding"]
        span = end - start
        target_ids_to_replace = list(range(start, end))

        # do not need to replace
        if not (input_ids == start).any():
            return embedding

        seq_len = len(input_ids)
        new_embedding = []
        # `s_idx` marks the first row not yet copied, `e_idx` is the scan
        # position.
        s_idx, e_idx = 0, 0
        while e_idx < seq_len:
            if input_ids[e_idx] != start:
                e_idx += 1
                continue

            if e_idx > s_idx:
                # keep the untouched rows in front of the placeholder
                new_embedding.append(embedding[s_idx:e_idx])

            # check if the next embedding need to replace is valid
            actually_ids_to_replace = [int(i) for i in input_ids[e_idx : e_idx + span]]
            assert actually_ids_to_replace == target_ids_to_replace, (
                f"Invalid 'input_ids' in position: {e_idx} to {e_idx + span}. "
                f"Expect '{target_ids_to_replace}' for embedding "
                f"'{name}' but found '{actually_ids_to_replace}'."
            )

            new_embedding.append(ext_emb)

            s_idx = e_idx + span
            # Resume scanning right after the replaced run so that a directly
            # following placeholder is found too; `max` guarantees progress
            # for an empty id range.
            e_idx = max(s_idx, e_idx + 1)

        if s_idx < seq_len:
            new_embedding.append(embedding[s_idx:])

        return torch.cat(new_embedding, dim=0)

    def forward(self, input_ids: torch.Tensor, external_embeddings: Optional[List[dict]] = None):
        """The forward function.

        Args:
            input_ids (torch.Tensor): The token ids shape like [bz, LENGTH] or
                [LENGTH, ].
            external_embeddings (Optional[List[dict]]): The external
                embeddings. If not passed, only `self.external_embeddings`
                will be used.  Defaults to None. A single dict is accepted
                as well.

        Returns:
            torch.Tensor: The embeddings with a leading batch dimension. A
                1-D `input_ids` is treated as a batch of size one, so the
                batch dimension is present in that case too.
        """
        assert input_ids.ndim in [1, 2]
        if input_ids.ndim == 1:
            input_ids = input_ids.unsqueeze(0)

        if external_embeddings is None and not self.external_embeddings:
            return self.wrapped(input_ids)

        input_ids_fwd = self.replace_input_ids(input_ids)
        inputs_embeds = self.wrapped(input_ids_fwd)

        if external_embeddings is None:
            external_embeddings = []
        elif isinstance(external_embeddings, dict):
            external_embeddings = [external_embeddings]
        embeddings = self.external_embeddings + external_embeddings

        vecs = []
        for input_id, embedding in zip(input_ids, inputs_embeds):
            new_embedding = embedding
            for external_embedding in embeddings:
                new_embedding = self.replace_embeddings(input_id, new_embedding, external_embedding)
            vecs.append(new_embedding)

        return torch.stack(vecs)



def add_tokens(
    tokenizer, text_encoder, placeholder_tokens: list, initialize_tokens: Optional[list] = None, num_vectors_per_token: int = 1
):
    """Add token for training.

    Registers every placeholder in `tokenizer`, makes sure the token
    embedding of `text_encoder` is an `EmbeddingLayerWithFixes`, and adds one
    trainable external embedding per placeholder to it.

    # TODO: support add tokens as dict, then we can load pretrained tokens.

    Args:
        tokenizer: Object providing `add_placeholder_token`,
            `get_token_info` and `__call__(...).input_ids` (as
            `TokenizerWrapper` does).
        text_encoder: Model whose token embedding is reachable as
            `text_encoder.text_model.embeddings.token_embedding`.
        placeholder_tokens (list): The placeholder tokens to add.
        initialize_tokens (list, optional): One token per placeholder whose
            existing embedding row is copied `num_vectors_per_token` times
            as initial value. If None, the embeddings are initialised
            uniformly in [-0.25, 0.25). Defaults to None.
        num_vectors_per_token (int): Number of vectors per placeholder.
            Defaults to 1.

    Raises:
        AssertionError: If `initialize_tokens` is given with a different
            length than `placeholder_tokens`, or if adding a token or an
            embedding fails its own checks.
    """
    if initialize_tokens is not None:
        assert len(initialize_tokens) == len(
            placeholder_tokens
        ), "placeholder_token should be the same length as initialize_token"
    for placeholder in placeholder_tokens:
        tokenizer.add_placeholder_token(placeholder, num_vec_per_token=num_vectors_per_token)

    # Wrap the token embedding only once. Wrapping an existing wrapper again
    # would make the outer layer zero out the ids the inner layer is meant to
    # replace, silently dropping previously added embeddings.
    embeddings_module = text_encoder.text_model.embeddings
    embedding_layer = embeddings_module.token_embedding
    if not isinstance(embedding_layer, EmbeddingLayerWithFixes):
        embedding_layer = EmbeddingLayerWithFixes(embedding_layer)
        embeddings_module.token_embedding = embedding_layer

    weight = embedding_layer.weight

    initialize_embedding = []
    if initialize_tokens is not None:
        for init_token in initialize_tokens:
            # input_ids[0] is skipped; index 1 is taken as the first id of
            # the token itself.
            init_id = tokenizer(init_token).input_ids[1]
            # Detach so the copy is independent of the encoder's autograd
            # graph; `repeat` allocates new storage.
            init_row = weight[init_id].detach()
            initialize_embedding.append(init_row[None, ...].repeat(num_vectors_per_token, 1))
    else:
        len_emb = weight.shape[1]
        for _ in placeholder_tokens:
            init_weight = (torch.rand(num_vectors_per_token, len_emb) - 0.5) / 2.0
            # Match the wrapped layer so the tensors can be concatenated in
            # the forward pass.
            initialize_embedding.append(init_weight.to(device=weight.device, dtype=weight.dtype))

    token_info_all = []
    for placeholder, init_embedding in zip(placeholder_tokens, initialize_embedding):
        token_info = tokenizer.get_token_info(placeholder)
        token_info["embedding"] = init_embedding
        token_info["trainable"] = True
        token_info_all.append(token_info)
    embedding_layer.add_embeddings(token_info_all)