import warnings
from typing import Dict, List, Optional, Tuple, Union

from lagent.llms.base_llm import AsyncLLMMixin, BaseLLM


class APITemplateParser:
    """Intermediate prompt template parser, specifically for API models.

    The parser turns a dialog (a list of role-tagged dicts) into the message
    list expected by chat-style APIs, using the per-role configuration given
    in the meta template. The role configuration keys read by this class are
    ``role`` (required, unique), ``api_role`` (required when the role is
    used), and the optional ``begin``, ``end`` and ``fallback_role``.

    Args:
        meta_template (List[Dict], optional): The meta template for the
            model, one config dict per role. When falsy, dialogs are
            flattened to a plain string instead.
    """

    def __init__(self, meta_template: Optional[List[Dict]] = None):
        self.meta_template = meta_template
        # Always defined so helper methods never hit a missing attribute,
        # even when no meta template was supplied.
        self.roles: Dict[str, dict] = {}  # maps role name to config
        if meta_template:
            assert isinstance(meta_template, list)
            for item in meta_template:
                assert isinstance(item, dict)
                assert item['role'] not in self.roles, \
                    'role in meta prompt must be unique!'
                # Copy so later edits to the caller's dict do not leak in.
                self.roles[item['role']] = item.copy()

    def __call__(self, dialog: List[Union[str, List]]):
        """Parse the intermediate prompt template, and wrap it with the meta
        template if applicable.

        When the meta template is set and the input is a list, the return
        value is a list containing the full conversation history, where
        consecutive messages of the same API role are merged (joined with a
        newline). Each item looks like:

        .. code-block:: python

            {'role': 'user', 'content': '...'}

        Without a meta template, the contents of all items are concatenated
        into a single string separated by newlines.

        Args:
            dialog (str or List[str or dict or list]): An intermediate prompt
                template (potentially before being wrapped by meta template).
                A plain ``str`` is returned unchanged.

        Returns:
            str or List[dict]: The finalized prompt or a conversation. With a
            meta template, a dialog that yields no role messages produces an
            empty list.

        Raises:
            TypeError: If a nested list inside ``dialog`` contains a bare
                string (raised by :meth:`_prompt2api`).
            KeyError: If a message uses a role missing from the meta
                template.
        """
        assert isinstance(dialog, (str, list))
        if isinstance(dialog, str):
            return dialog

        if not self.meta_template:
            # in case the model does not have any meta template
            prompt = ''
            last_sep = ''
            for item in dialog:
                if isinstance(item, str):
                    if item:
                        prompt += last_sep + item
                elif item.get('content', ''):
                    prompt += last_sep + item.get('content', '')
                # The separator is armed after the first item, even if that
                # item contributed nothing.
                last_sep = '\n'
            return prompt

        messages = []
        for item in dialog:
            if isinstance(item, str):
                if item.strip():
                    # TODO: logger
                    warnings.warn('Non-empty string in prompt template '
                                  'will be ignored in API models.')
                continue
            api_prompts = self._prompt2api(item)
            # A nested list converts to a list of messages; flatten it so
            # every entry of ``messages`` is a single message dict.
            if isinstance(api_prompts, list):
                messages.extend(api_prompts)
            else:
                messages.append(api_prompts)

        # merge the consecutive prompts assigned to the same role
        merged = []
        for message in messages:
            if merged and merged[-1]['role'] == message['role']:
                merged[-1]['content'] += '\n' + message['content']
            else:
                merged.append(message)
        return merged

    def _prompt2api(
            self, prompts: Union[List, Dict, str]
    ) -> Union[str, Dict, List[Dict]]:
        """Convert the prompts to API-style prompts.

        Args:
            prompts (Union[List, Dict, str]): The prompts to be converted.

        Returns:
            Union[str, Dict, List[Dict]]: A ``str`` input is returned as is,
            a ``dict`` is converted to a single API message, and a list is
            converted to a list of API messages.

        Raises:
            TypeError: If ``prompts`` is a list that contains a bare string.
        """
        if isinstance(prompts, str):
            return prompts
        if isinstance(prompts, dict):
            return self._role2api_role(prompts)

        res = []
        for prompt in prompts:
            if isinstance(prompt, str):
                raise TypeError('Mixing str without explicit role is not '
                                'allowed in API models!')
            res.append(self._role2api_role(prompt))
        return res

    def _role2api_role(self, role_prompt: Dict) -> Dict:
        """Convert one role-tagged message into an API message.

        The role's config supplies the ``api_role`` and the optional
        ``begin``/``end`` strings wrapped around the content. If the config
        names a ``fallback_role``, that role's config is used instead.

        Args:
            role_prompt (Dict): A message with a ``role`` key and an optional
                ``content`` key. It is not modified.

        Returns:
            Dict: A copy of ``role_prompt`` whose ``role`` is the API role
            and whose ``content`` is wrapped with ``begin`` and ``end``.

        Raises:
            KeyError: If the role (or its fallback role) is not defined in
                the meta template, or the config lacks ``api_role``.
        """
        merged_prompt = self.roles[role_prompt['role']]
        fallback_role = merged_prompt.get('fallback_role')
        if fallback_role:
            # ``fallback_role`` is a role *name*; look its config up directly.
            merged_prompt = self.roles[fallback_role]
        res = role_prompt.copy()
        res['role'] = merged_prompt['api_role']
        res['content'] = merged_prompt.get('begin', '')
        res['content'] += role_prompt.get('content', '')
        res['content'] += merged_prompt.get('end', '')
        return res


class BaseAPILLM(BaseLLM):
    """Common base for wrappers around API-served language models.

    The constructor only records its configuration on the instance; it does
    not invoke ``BaseLLM.__init__``.

    Args:
        model_type (str): The type of model. Stored as ``model_type``.
        retry (int): Number of retries if the API call fails. Stored as
            ``retry``. Defaults to 2.
        template_parser: Callable (by default the ``APITemplateParser``
            class) that is invoked with ``meta_template`` to build
            ``self.template_parser``. When falsy, no parser attribute is set
            by this constructor.
        meta_template (Dict, optional): The model's meta prompt template if
            needed, in case the requirement of injecting or wrapping of any
            meta instructions.
        max_new_tokens (int): Stored in ``gen_params``. Defaults to 512.
        top_p (float): Stored in ``gen_params``. Defaults to 0.8.
        top_k (int): Stored in ``gen_params``. Defaults to 40.
        temperature (float): Stored in ``gen_params``. Defaults to 0.8.
        repetition_penalty (float): Stored in ``gen_params``.
            Defaults to 0.0.
        stop_words (List[str] or str, optional): Stop words stored in
            ``gen_params``; a single string is wrapped into a one-element
            list.
    """

    is_api: bool = True

    def __init__(self,
                 model_type: str,
                 retry: int = 2,
                 template_parser: 'APITemplateParser' = APITemplateParser,
                 meta_template: Optional[Dict] = None,
                 *,
                 max_new_tokens: int = 512,
                 top_p: float = 0.8,
                 top_k: int = 40,
                 temperature: float = 0.8,
                 repetition_penalty: float = 0.0,
                 stop_words: Union[List[str], str] = None):
        self.retry = retry
        self.model_type = model_type
        self.meta_template = meta_template

        # Build the parser only when a parser factory was supplied.
        if template_parser:
            self.template_parser = template_parser(meta_template)

        # A lone stop word may be passed as a bare string; normalise it to a
        # list. Any other value (including None) is kept as given.
        stop_list = [stop_words] if isinstance(stop_words, str) else stop_words

        # Default generation parameters; special tokens are always kept.
        self.gen_params = {
            'max_new_tokens': max_new_tokens,
            'top_p': top_p,
            'top_k': top_k,
            'temperature': temperature,
            'repetition_penalty': repetition_penalty,
            'stop_words': stop_list,
            'skip_special_tokens': False,
        }


class AsyncBaseAPILLM(AsyncLLMMixin, BaseAPILLM):
    """API model base class combined with ``AsyncLLMMixin``.

    Defines no members of its own: everything comes from ``AsyncLLMMixin``
    (first in the MRO) and ``BaseAPILLM``.
    """