import asyncio
from typing import List, Union

from lagent.llms.base_llm import AsyncBaseLLM, BaseLLM
from lagent.utils.util import filter_suffix


def asdict_completion(output):
    """Collect the completion fields of a vLLM output object into a dict.

    Args:
        output: Object exposing the attributes ``text``, ``token_ids``,
            ``cumulative_logprob``, ``logprobs``, ``finish_reason`` and
            ``stop_reason`` (presumably a ``vllm`` ``CompletionOutput``).

    Returns:
        dict: Mapping from each of those attribute names, in that order, to
        the attribute's value. Raises ``AttributeError`` if any is missing.
    """
    field_names = ('text', 'token_ids', 'cumulative_logprob', 'logprobs',
                   'finish_reason', 'stop_reason')
    snapshot = {}
    for name in field_names:
        snapshot[name] = getattr(output, name)
    return snapshot


class VllmModel(BaseLLM):
    """A synchronous wrapper around a vLLM ``LLM`` instance.

    Args:
        path (str): The path to the model.
            It could be one of the following options:
                    - i) A local directory path of a huggingface model.
                    - ii) The model_id of a model hosted inside a model repo
                        on huggingface.co, such as "internlm/internlm-chat-7b",
                        "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
                        and so on.
        tp (int): Tensor-parallel size, forwarded to vLLM as
            ``tensor_parallel_size``. Defaults to 1.
        vllm_cfg (dict, optional): Other kwargs for vllm model
            initialization. Defaults to None (no extra kwargs).
        **kwargs: Forwarded to ``BaseLLM.__init__`` (presumably default
            generation parameters; confirm against BaseLLM).
    """

    def __init__(self, path: str, tp: int = 1, vllm_cfg=None, **kwargs):

        super().__init__(path=path, **kwargs)
        from vllm import LLM

        # ``None`` rather than a ``dict()`` default so no mutable object is
        # shared between calls/instances.
        vllm_cfg = {} if vllm_cfg is None else vllm_cfg
        self.model = LLM(
            model=self.path,
            trust_remote_code=True,
            tensor_parallel_size=tp,
            **vllm_cfg)

    def generate(self,
                 inputs: Union[str, List[str]],
                 do_preprocess: bool = None,
                 skip_special_tokens: bool = False,
                 return_dict: bool = False,
                 **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_preprocess (bool): whether pre-process the messages. Default to
                True, which means chat_template will be applied. Note: this
                argument is accepted for interface compatibility but is not
                used by this implementation.
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be False.
            return_dict (bool): If True, return the completion fields (see
                ``asdict_completion``) instead of plain text. Default False.
            **kwargs: Per-call overrides merged by ``update_gen_params``;
                must resolve ``max_new_tokens`` and ``stop_words``.

        Returns:
            (a list of/batched) text/chat completion: a list when ``inputs``
            is a list, otherwise a single item.
        """
        from vllm import SamplingParams

        batched = not isinstance(inputs, str)
        prompts = inputs if batched else [inputs]
        gen_params = self.update_gen_params(**kwargs)
        max_new_tokens = gen_params.pop('max_new_tokens')
        stop_words = gen_params.pop('stop_words')

        sampling_config = SamplingParams(
            skip_special_tokens=skip_special_tokens,
            max_tokens=max_new_tokens,
            stop=stop_words,
            **gen_params)
        response = self.model.generate(
            prompts, sampling_params=sampling_config)
        # Only the first candidate of each request is used.
        outputs = [resp.outputs[0] for resp in response]
        # Strip the stop words that were actually used for THIS request
        # (possibly overridden via kwargs), not the instance defaults.
        texts = filter_suffix([out.text for out in outputs], stop_words)
        for out, text in zip(outputs, texts):
            out.text = text

        if return_dict:
            results = [asdict_completion(out) for out in outputs]
        else:
            results = texts
        return results if batched else results[0]


class AsyncVllmModel(AsyncBaseLLM):
    """An asynchronous wrapper around a vLLM ``AsyncLLMEngine``.

    Args:
        path (str): The path to the model.
            It could be one of the following options:
                    - i) A local directory path of a huggingface model.
                    - ii) The model_id of a model hosted inside a model repo
                        on huggingface.co, such as "internlm/internlm-chat-7b",
                        "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
                        and so on.
        tp (int): Tensor-parallel size, forwarded to vLLM as
            ``tensor_parallel_size``. Defaults to 1.
        vllm_cfg (dict, optional): Other kwargs for ``AsyncEngineArgs``.
            Defaults to None (no extra kwargs).
        **kwargs: Forwarded to ``AsyncBaseLLM.__init__`` (presumably default
            generation parameters; confirm against AsyncBaseLLM).
    """

    def __init__(self, path: str, tp: int = 1, vllm_cfg=None, **kwargs):
        super().__init__(path=path, **kwargs)
        from vllm import AsyncEngineArgs, AsyncLLMEngine

        # ``None`` rather than a ``dict()`` default so no mutable object is
        # shared between calls/instances.
        vllm_cfg = {} if vllm_cfg is None else vllm_cfg
        engine_args = AsyncEngineArgs(
            model=self.path,
            trust_remote_code=True,
            tensor_parallel_size=tp,
            **vllm_cfg)
        self.model = AsyncLLMEngine.from_engine_args(engine_args)

    async def generate(self,
                       inputs: Union[str, List[str]],
                       session_ids: Union[int, List[int]] = None,
                       do_preprocess: bool = None,
                       skip_special_tokens: bool = False,
                       return_dict: bool = False,
                       **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            session_ids (Union[int, List[int]], optional): ids used as vLLM
                request ids, one per input. They must be unique among all
                requests in flight on the engine. When None, unique ids are
                generated for this call.
            do_preprocess (bool): whether pre-process the messages. Default to
                True, which means chat_template will be applied. Note: this
                argument is accepted for interface compatibility but is not
                used by this implementation.
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be False.
            return_dict (bool): If True, return the completion fields (see
                ``asdict_completion``) instead of plain text. Default False.
            **kwargs: Per-call overrides merged by ``update_gen_params``;
                must resolve ``max_new_tokens`` and ``stop_words``.

        Returns:
            (a list of/batched) text/chat completion: a list when ``inputs``
            is a list, otherwise a single item.
        """
        import uuid

        from vllm import SamplingParams

        batched = not isinstance(inputs, str)
        prompts = inputs if batched else [inputs]
        if session_ids is None:
            # The engine rejects duplicate in-flight request ids, so plain
            # positional ids (0, 1, ...) would collide as soon as two
            # ``generate`` calls overlap. Tag them with a per-call uuid.
            call_tag = uuid.uuid4().hex
            session_ids = [f'{call_tag}-{i}' for i in range(len(prompts))]
        elif isinstance(session_ids, (int, str)):
            session_ids = [session_ids]
        assert len(prompts) == len(session_ids), (
            f'got {len(prompts)} inputs but {len(session_ids)} session_ids')

        gen_params = self.update_gen_params(**kwargs)
        max_new_tokens = gen_params.pop('max_new_tokens')
        stop_words = gen_params.pop('stop_words')

        sampling_config = SamplingParams(
            skip_special_tokens=skip_special_tokens,
            max_tokens=max_new_tokens,
            stop=stop_words,
            **gen_params)

        async def _inner_generate(uid, text):
            # Drain the stream and keep only the final snapshot of the
            # first candidate.
            final = None
            async for out in self.model.generate(
                    text, sampling_params=sampling_config, request_id=uid):
                final = out.outputs[0]
            if final is None:
                raise RuntimeError(
                    f'vLLM produced no output for request {uid!r}')
            return final

        outputs = await asyncio.gather(*[
            _inner_generate(sid, prompt)
            for sid, prompt in zip(session_ids, prompts)
        ])
        # Strip the stop words that were actually used for THIS request
        # (possibly overridden via kwargs), not the instance defaults.
        texts = filter_suffix([out.text for out in outputs], stop_words)
        for out, text in zip(outputs, texts):
            out.text = text

        if return_dict:
            results = [asdict_completion(out) for out in outputs]
        else:
            results = texts
        return results if batched else results[0]