import copy
import logging
from typing import Dict, List, Optional, Union

from lagent.schema import ModelStatusCode
from .base_api import APITemplateParser
from .base_llm import BaseLLM

# Module-level logger used for the warnings emitted by the wrappers below.
logger = logging.getLogger(__name__)


class HFTransformer(BaseLLM):
    """Model wrapper around HuggingFace general models.

    Adapted from Internlm (https://github.com/InternLM/InternLM/blob/main/
        chat/web_demo.py)

    Args:
        path (str): The name or path to HuggingFace's model.
        tokenizer_path (str, optional): The path to the tokenizer. Defaults
            to None, in which case ``path`` is used.
        tokenizer_kwargs (dict, optional): Keyword arguments for the
            tokenizer. Defaults to None, treated as ``{}``.
        tokenizer_only (bool): If True, only the tokenizer will be initialized.
            Defaults to False.
        model_kwargs (dict, optional): Keyword arguments for the model, used
            in loader. Defaults to None, treated as
            ``dict(device_map='auto')``.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        stop_words_id (Union[List[int], int], optional): Extra token id(s)
            that end generation. When both ``stop_words`` and
            ``stop_words_id`` are given, only ``stop_words_id`` is used.
            Defaults to None.
    """

    def __init__(self,
                 path: str,
                 tokenizer_path: Optional[str] = None,
                 tokenizer_kwargs: Optional[dict] = None,
                 tokenizer_only: bool = False,
                 model_kwargs: Optional[dict] = None,
                 meta_template: Optional[Dict] = None,
                 stop_words_id: Union[List[int], int, None] = None,
                 **kwargs):
        super().__init__(
            path=path,
            tokenizer_only=tokenizer_only,
            meta_template=meta_template,
            **kwargs)
        # Fresh dicts per instance: shared ``dict()`` defaults would leak
        # state between instances, and ``_load_model`` adds keys.
        tokenizer_kwargs = ({} if tokenizer_kwargs is None else
                            dict(tokenizer_kwargs))
        model_kwargs = (dict(device_map='auto')
                        if model_kwargs is None else dict(model_kwargs))

        if isinstance(stop_words_id, int):
            stop_words_id = [stop_words_id]
        self.gen_params.update(stop_words_id=stop_words_id)
        if self.gen_params.get('stop_words') is not None and \
                self.gen_params['stop_words_id'] is not None:
            logger.warning('Both stop_words and stop_words_id are specified, '
                           'only stop_words_id will be used.')

        # ``_load_tokenizer`` only sets this in one branch; initialise it so
        # ``stream_generate`` can always read it safely.
        self.gcfg = None
        self._load_tokenizer(
            path=path,
            tokenizer_path=tokenizer_path,
            tokenizer_kwargs=tokenizer_kwargs)
        if not tokenizer_only:
            self._load_model(path=path, model_kwargs=model_kwargs)

        from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList  # noqa: E501
        self.logits_processor = LogitsProcessorList()
        self.stopping_criteria = StoppingCriteriaList()
        self.prefix_allowed_tokens_fn = None

        # Extra ids that end generation besides the model's eos token(s);
        # explicit ids take precedence over ids derived from ``stop_words``.
        if self.gen_params.get('stop_words_id'):
            self.additional_eos_token_id = list(
                self.gen_params['stop_words_id'])
        elif self.gen_params.get('stop_words'):
            # Use the last token id of each stop word's encoding.
            self.additional_eos_token_id = [
                self.tokenizer(sw)['input_ids'][-1]
                for sw in self.gen_params['stop_words']
            ]
        else:
            self.additional_eos_token_id = []

    def _load_tokenizer(self, path: str, tokenizer_path: Optional[str],
                        tokenizer_kwargs: dict):
        """Load the tokenizer and make sure it has a pad token.

        Falls back to the eos token, then to the ``pad_token_id`` of the
        model's ``GenerationConfig`` (cached as ``self.gcfg``).

        Raises:
            ValueError: if no pad token id can be determined.
        """
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path if tokenizer_path else path,
            trust_remote_code=True,
            **tokenizer_kwargs)

        if self.tokenizer.pad_token_id is None:
            if self.tokenizer.eos_token is not None:
                logger.warning(
                    f'Using eos_token_id {self.tokenizer.eos_token} '
                    'as pad_token_id.')
                self.tokenizer.pad_token = self.tokenizer.eos_token
            else:
                from transformers.generation import GenerationConfig
                self.gcfg = GenerationConfig.from_pretrained(path)

                if self.gcfg.pad_token_id is not None:
                    logger.warning(
                        f'Using pad_token_id {self.gcfg.pad_token_id} '
                        'as pad_token_id.')
                    self.tokenizer.pad_token_id = self.gcfg.pad_token_id
                else:
                    raise ValueError(
                        'pad_token_id is not set for this tokenizer. Try to '
                        'set pad_token_id via passing '
                        '`pad_token_id={PAD_TOKEN_ID}` in model_cfg.')

    def _load_model(self, path: str, model_kwargs: dict):
        """Load the model in eval mode; ``model_kwargs`` is not modified.

        ``torch_dtype`` defaults to ``torch.float16`` when not supplied.
        """
        import torch
        from transformers import AutoModel
        # New dict instead of ``setdefault`` so the caller's dict is untouched.
        model_kwargs = {'torch_dtype': torch.float16, **model_kwargs}
        self.model = AutoModel.from_pretrained(
            path, trust_remote_code=True, **model_kwargs)
        self.model.eval()

    def tokenize(self, inputs: str):
        """Encode a single string and return its input ids as nested lists."""
        assert isinstance(inputs, str)
        inputs = self.tokenizer(
            inputs, return_tensors='pt', return_length=True)
        return inputs['input_ids'].tolist()

    def generate(
        self,
        inputs: Union[str, List[str]],
        do_sample: bool = True,
        **kwargs,
    ):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_sample (bool): do sampling if enabled
        Returns:
            (a list of/batched) text/chat completion
        """
        response = None
        for _status, chunk, _ in self.stream_generate(inputs, do_sample,
                                                      **kwargs):
            response = chunk
        return response

    def _resolve_eos_token_ids(self, generation_config) -> List[int]:
        """Collect every token id that should terminate generation.

        Uses ``generation_config.eos_token_id``. If that is None, it falls
        back to ``self.gcfg`` when ``_load_tokenizer`` cached one. It then
        appends ``self.additional_eos_token_id``. The result is always a
        fresh list, so no config object is mutated across calls.
        """
        eos_token_id = generation_config.eos_token_id
        if eos_token_id is None:
            gcfg = getattr(self, 'gcfg', None)
            if gcfg is not None:
                eos_token_id = gcfg.eos_token_id
        if eos_token_id is None:
            eos_token_ids = []
        elif isinstance(eos_token_id, int):
            eos_token_ids = [eos_token_id]
        else:
            eos_token_ids = list(eos_token_id)
        if self.additional_eos_token_id is not None:
            eos_token_ids.extend(self.additional_eos_token_id)
        return eos_token_ids

    def _decode_until_eos(self, input_ids, input_length: List[int],
                          eos_token_ids: set) -> List[str]:
        """Decode each row's generated ids, cut at its first stop token.

        Args:
            input_ids: 2-D tensor of prompt + generated ids, one row per
                sample.
            input_length (List[int]): number of leading ids to skip per row.
            eos_token_ids (set): ids ending a sequence; the first one found
                and everything after it are dropped.
        Returns:
            List[str]: one decoded string per row.
        """
        output_token_ids = []
        for row, prompt_len in zip(input_ids.cpu().tolist(), input_length):
            generated = row[prompt_len:]
            first_eos_idx = next(
                (idx for idx, token_id in enumerate(generated)
                 if token_id in eos_token_ids), None)
            if first_eos_idx is not None:
                generated = generated[:first_eos_idx]
            output_token_ids.append(generated)
        return self.tokenizer.batch_decode(output_token_ids)

    def stream_generate(
        self,
        inputs: Union[str, List[str]],
        do_sample: bool = True,
        **kwargs,
    ):
        """Return the chat completions in stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_sample (bool): sample from the distribution if True,
                otherwise take the argmax token.
        Yields:
            tuple(ModelStatusCode, str | List[str], None): the status
            (``STREAM_ING`` per step, then ``END``), the completion(s)
            decoded so far (a list only for list input), and a placeholder
            that is always None.
        """
        import torch
        from torch import nn
        with torch.no_grad():
            batched = not isinstance(inputs, str)
            if not batched:
                inputs = [inputs]
            inputs = self.tokenizer(
                inputs, padding=True, return_tensors='pt', return_length=True)
            # Plain ints on the CPU: only used to slice off the prompt.
            input_length = inputs['length'].tolist()
            for k, v in inputs.items():
                inputs[k] = v.cuda()
            input_ids = inputs['input_ids']
            attention_mask = inputs['attention_mask']
            batch_size = input_ids.shape[0]
            input_ids_seq_length = input_ids.shape[-1]
            # Copy so per-call overrides never leak into the model's config.
            generation_config = copy.deepcopy(self.model.generation_config)
            new_gen_params = self.update_gen_params(**kwargs)
            generation_config.update(**new_gen_params)
            generation_config.update(**kwargs)
            model_kwargs = generation_config.to_dict()
            model_kwargs['attention_mask'] = attention_mask

            eos_token_id = self._resolve_eos_token_ids(generation_config)
            eos_token_id_tensor = torch.tensor(
                eos_token_id, dtype=torch.long, device=input_ids.device)
            eos_token_set = set(eos_token_id)
            generation_config.max_length = (
                generation_config.max_new_tokens + input_ids_seq_length)

            logits_processor = self.model._get_logits_processor(
                generation_config=generation_config,
                input_ids_seq_length=input_ids_seq_length,
                encoder_input_ids=input_ids,
                prefix_allowed_tokens_fn=self.prefix_allowed_tokens_fn,
                logits_processor=self.logits_processor,
            )
            stopping_criteria = self.model._get_stopping_criteria(
                generation_config=generation_config,
                stopping_criteria=self.stopping_criteria)
            logits_warper = self.model._get_logits_warper(generation_config)

            unfinished_sequences = input_ids.new(batch_size).fill_(1)
            scores = None
            while True:
                model_inputs = self.model.prepare_inputs_for_generation(
                    input_ids, **model_kwargs)
                # forward pass to get next token
                outputs = self.model(
                    **model_inputs,
                    return_dict=True,
                    output_attentions=False,
                    output_hidden_states=False,
                )
                next_token_logits = outputs.logits[:, -1, :]

                # pre-process distribution
                next_token_scores = logits_processor(input_ids,
                                                     next_token_logits)
                next_token_scores = logits_warper(input_ids, next_token_scores)

                # sample or pick greedily
                probs = nn.functional.softmax(next_token_scores, dim=-1)
                if do_sample:
                    next_tokens = torch.multinomial(
                        probs, num_samples=1).squeeze(1)
                else:
                    next_tokens = torch.argmax(probs, dim=-1)

                # update generated ids and model inputs for the next step
                input_ids = torch.cat([input_ids, next_tokens[:, None]],
                                      dim=-1)
                model_kwargs = self.model._update_model_kwargs_for_generation(  # noqa: E501
                    outputs,
                    model_kwargs,
                    is_encoder_decoder=False)
                # A row stays unfinished while its new token is not a stop id.
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(
                        eos_token_id_tensor.unsqueeze(1)).prod(dim=0))

                response = self._decode_until_eos(input_ids, input_length,
                                                  eos_token_set)
                if not batched:
                    response = response[0]
                yield ModelStatusCode.STREAM_ING, response, None
                # stop when each sentence is finished,
                # or if we exceed the maximum length
                if (unfinished_sequences.max() == 0
                        or stopping_criteria(input_ids, scores)):
                    break
            yield ModelStatusCode.END, response, None

    def stream_chat(
        self,
        inputs: List[dict],
        do_sample: bool = True,
        **kwargs,
    ):
        """Return the chat completions in stream mode.

        Args:
            inputs (List[dict]): input messages to be completed.
            do_sample (bool): do sampling if enabled
        Returns:
            the text/chat completion
        """
        prompt = self.template_parser(inputs)
        yield from self.stream_generate(prompt, do_sample, **kwargs)


class HFTransformerCasualLM(HFTransformer):
    """``HFTransformer`` variant that loads weights with
    ``AutoModelForCausalLM`` instead of ``AutoModel``."""

    def _load_model(self, path: str, model_kwargs: dict):
        """Load the causal-LM model in eval mode.

        Args:
            path (str): model name or path passed to ``from_pretrained``.
            model_kwargs (dict): extra loader kwargs. ``torch_dtype``
                defaults to ``torch.float16`` when absent. The dict itself
                is not modified.
        """
        import torch
        from transformers import AutoModelForCausalLM
        # Build a new dict rather than ``setdefault`` so the caller's dict
        # (possibly a shared default) is never mutated.
        model_kwargs = {'torch_dtype': torch.float16, **model_kwargs}
        self.model = AutoModelForCausalLM.from_pretrained(
            path, trust_remote_code=True, **model_kwargs)
        self.model.eval()


class HFTransformerChat(HFTransformerCasualLM):
    """Causal-LM wrapper that answers through the model's own ``chat``.

    Defaults to ``APITemplateParser``. ``chat`` indexes the parsed prompt
    as a list of message dicts with a ``'content'`` key.
    """

    def __init__(self, template_parser=APITemplateParser, **kwargs):
        super().__init__(template_parser=template_parser, **kwargs)

    def chat(self,
             inputs: Union[List[dict], List[List[dict]]],
             do_sample: bool = True,
             **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): one conversation
                (a list of message dicts) or a batch of conversations.
            do_sample (bool): do sampling if enabled. NOTE(review): neither
                ``do_sample`` nor ``kwargs`` is forwarded to
                ``self.model.chat``; confirm whether that is intended.
        Returns:
            str | List[str]: the reply, or one reply per conversation for
            batched input. An empty string is returned when
            ``self.model.chat`` raises.
        """
        # Batched input: handle each conversation with a plain loop.
        if isinstance(inputs[0], list):
            return [
                self.chat(conversation, do_sample, **kwargs)
                for conversation in inputs
            ]
        prompt = self.template_parser(inputs)
        # The last message is the query; everything before it is history.
        query = prompt[-1]['content']
        history = prompt[:-1]
        try:
            response, _ = self.model.chat(
                self.tokenizer, query, history=history)
        except Exception as e:
            # Deliberate best-effort (e.g. over-length input): log and
            # return an empty reply instead of failing the whole call.
            logger.warning(str(e))
            response = ''
        return response