import asyncio
import copy
import logging
from dataclasses import asdict
from typing import List, Optional, Union

import aiohttp

from lagent.llms.base_llm import AsyncLLMMixin, BaseLLM
from lagent.schema import ModelStatusCode
from lagent.utils.util import filter_suffix


class TritonClient(BaseLLM):
    """TritonClient is a wrapper of TritonClient for LLM.

    Args:
        tritonserver_addr (str): the address in format "ip:port" of
            triton inference server
        model_name (str): the name of the model
        session_len (int): the context size
        log_level (str): log level, one of
            [CRITICAL, ERROR, WARNING, INFO, DEBUG]. Defaults to 'WARNING'.
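
    Example:
        A minimal usage sketch; the server address and model name below
        are placeholders, not values shipped with this repo::

            chatbot = TritonClient(
                tritonserver_addr='0.0.0.0:33337',
                model_name='internlm2-chat-7b')
            print(chatbot.generate('Hello, please introduce yourself.'))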
    """

    def __init__(self,
                 tritonserver_addr: str,
                 model_name: str,
                 session_len: int = 32768,
                 log_level: str = 'WARNING',
                 **kwargs):
        super().__init__(path=None, **kwargs)
        try:
            from lmdeploy.serve.turbomind.chatbot import Chatbot, StatusCode
        except Exception as e:
            logging.error(f'{e}')
            raise RuntimeError('turbomind.chatbot is unavailable: it was '
                               'removed from lmdeploy in v0.5.2')
        self.state_map = {
            StatusCode.TRITON_STREAM_END: ModelStatusCode.END,
            StatusCode.TRITON_SERVER_ERR: ModelStatusCode.SERVER_ERR,
            StatusCode.TRITON_SESSION_CLOSED: ModelStatusCode.SESSION_CLOSED,
            StatusCode.TRITON_STREAM_ING: ModelStatusCode.STREAM_ING,
            StatusCode.TRITON_SESSION_OUT_OF_LIMIT:
            ModelStatusCode.SESSION_OUT_OF_LIMIT,
            StatusCode.TRITON_SESSION_INVALID_ARG:
            ModelStatusCode.SESSION_INVALID_ARG,
            StatusCode.TRITON_SESSION_READY: ModelStatusCode.SESSION_READY
        }
        self.chatbot = Chatbot(
            tritonserver_addr=tritonserver_addr,
            model_name=model_name,
            session_len=session_len,
            log_level=log_level,
            **kwargs)

    def generate(self,
                 inputs: Union[str, List[str]],
                 session_id: int = 2967,
                 request_id: str = '',
                 sequence_start: bool = True,
                 sequence_end: bool = True,
                 skip_special_tokens: bool = False,
                 **kwargs):
        """Start a new round conversation of a session. Return the chat
        completions in non-stream mode.

        Args:
            inputs (str, List[str]): user's prompt(s) in this round
            session_id (int): the unique id of a session
            request_id (str): the unique id of this round of conversation
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            skip_special_tokens (bool): whether or not to remove special
                tokens during decoding. Defaults to False.
        Returns:
            (a list of/batched) text/chat completion
        """
        from lmdeploy.serve.turbomind.chatbot import Session, get_logger
        if isinstance(inputs, str):
            inputs = [inputs]
        prompt = inputs

        assert isinstance(session_id, int), \
            f'INT session id is required, but got {type(session_id)}'

        self.chatbot.cfg = self._update_gen_params(**kwargs)
        max_new_tokens = self.chatbot.cfg.max_new_tokens

        logger = get_logger('service.ft', log_level=self.chatbot.log_level)
        logger.info(f'session {session_id}, request_id {request_id}, '
                    f'max_out_len {max_new_tokens}')

        if self.chatbot._session is None:
            sequence_start = True
            self.chatbot._session = Session(session_id=session_id)
        elif self.chatbot._session.status == 0:
            logger.error(f'session {session_id} has ended. Please set '
                         f'`sequence_start` to True if you want to restart it')
            return ''

        self.chatbot._session.status = 1
        self.chatbot._session.request_id = request_id
        self.chatbot._session.response = ''

        status, res, _ = None, '', 0
        for status, res, _ in self.chatbot._stream_infer(
                self.chatbot._session,
                prompt,
                max_new_tokens,
                sequence_start,
                sequence_end,
                skip_special_tokens=skip_special_tokens):
            status = self.state_map.get(status)
            if status < ModelStatusCode.END:
                return ''
            elif status == ModelStatusCode.END:
                self.chatbot._session.histories = (
                    self.chatbot._session.histories +
                    self.chatbot._session.prompt +
                    self.chatbot._session.response)
                # remove stop_words
                res = filter_suffix(res, self.gen_params.get('stop_words'))
                return res

    def stream_chat(self,
                    inputs: List[dict],
                    session_id: int = 2967,
                    request_id: str = '',
                    sequence_start: bool = True,
                    sequence_end: bool = True,
                    skip_special_tokens: bool = False,
                    **kwargs):
        """Start a new round conversation of a session. Return the chat
        completions in stream mode.

        Args:
            session_id (int): the unique id of a session
            inputs (List[dict]): user's inputs in this round of conversation
            request_id (str): the unique id of this round of conversation
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            skip_special_tokens (bool): whether or not to remove special
                tokens during decoding. Defaults to False.
        Yields:
            tuple(Status, str, int): status, text/chat completion,
            generated token number
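
        Example:
            A streaming sketch, assuming `chatbot` is a constructed
            `TritonClient` and messages follow the role/content
            convention::

                for status, text, _ in chatbot.stream_chat(
                        [dict(role='user', content='hi')]):
                    print(text)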
        """
        from lmdeploy.serve.turbomind.chatbot import Session, get_logger
        assert isinstance(session_id, int), \
            f'INT session id is required, but got {type(session_id)}'

        self.chatbot.cfg = self._update_gen_params(**kwargs)
        max_new_tokens = self.chatbot.cfg.max_new_tokens

        logger = get_logger('service.ft', log_level=self.chatbot.log_level)
        logger.info(f'session {session_id}, request_id {request_id}, '
                    f'max_out_len {max_new_tokens}')

        if self.chatbot._session is None:
            sequence_start = True
            self.chatbot._session = Session(session_id=session_id)
        elif self.chatbot._session.status == 0:
            logger.error(f'session {session_id} has ended. Please set '
                         f'`sequence_start` to True if you want to restart it')
            # a plain `return value` inside a generator never reaches the
            # caller; yield the closed state instead
            yield ModelStatusCode.SESSION_CLOSED, '', 0
            return

        self.chatbot._session.status = 1
        self.chatbot._session.request_id = request_id
        self.chatbot._session.response = ''

        prompt = self.template_parser(inputs)
        status, res, _ = None, '', 0
        for status, res, _ in self.chatbot._stream_infer(
                self.chatbot._session,
                prompt,
                max_new_tokens,
                sequence_start,
                sequence_end,
                skip_special_tokens=skip_special_tokens):
            status = self.state_map.get(status)
            # The stop symbol also appears in the output of the last STREAM_ING state.
            res = filter_suffix(res, self.gen_params.get('stop_words'))
            if status < ModelStatusCode.END:
                # `return value` inside a generator never reaches the caller;
                # yield the error state before stopping
                yield status, res, _
                return
            elif status == ModelStatusCode.END:
                self.chatbot._session.histories = (
                    self.chatbot._session.histories +
                    self.chatbot._session.prompt +
                    self.chatbot._session.response)
                yield status, res, _
                break
            else:
                yield status, res, _

    def _update_gen_params(self, **kwargs):
        import mmengine
        new_gen_params = self.update_gen_params(**kwargs)
        self.gen_params['stop_words'] = new_gen_params.pop('stop_words')
        stop_words = self.chatbot._stop_words(
            self.gen_params.get('stop_words'))
        cfg = mmengine.Config(
            dict(
                session_len=self.chatbot.model.session_len,
                stop_words=stop_words,
                bad_words=self.chatbot.cfg.bad_words,
                **new_gen_params))
        return cfg


class LMDeployPipeline(BaseLLM):
    """

    Args:
        path (str): The path to the model.
            It could be one of the following options:
                - i) A local directory path of a turbomind model which is
                    converted by `lmdeploy convert` command or download
                    from ii) and iii).
                - ii) The model_id of a lmdeploy-quantized model hosted
                    inside a model repo on huggingface.co, such as
                    "InternLM/internlm-chat-20b-4bit",
                    "lmdeploy/llama2-chat-70b-4bit", etc.
                - iii) The model_id of a model hosted inside a model repo
                    on huggingface.co, such as "internlm/internlm-chat-7b",
                    "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
                    and so on.
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on.
        tp (int): tensor parallel
        pipeline_cfg (dict): config of pipeline
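
    Example:
        A minimal sketch of offline inference; the model id below is a
        placeholder::

            llm = LMDeployPipeline(path='internlm/internlm2-chat-7b')
            print(llm.generate('Hello, please introduce yourself.'))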
    """

    def __init__(self,
                 path: str,
                 model_name: Optional[str] = None,
                 tp: int = 1,
                 pipeline_cfg=dict(),
                 **kwargs):
        import lmdeploy
        from lmdeploy import ChatTemplateConfig, TurbomindEngineConfig, pipeline, version_info

        self.str_version = lmdeploy.__version__
        self.version = version_info
        self.do_sample = kwargs.pop('do_sample', None)
        if self.do_sample is not None and self.version < (0, 6, 0):
            raise RuntimeError(
                '`do_sample` parameter is not supported by lmdeploy until '
                f'v0.6.0, but currently using lmdeploy {self.str_version}')
        super().__init__(path=path, **kwargs)
        backend_config = copy.deepcopy(pipeline_cfg)
        backend_config.update(tp=tp)
        backend_config = {
            k: v
            for k, v in backend_config.items()
            if hasattr(TurbomindEngineConfig, k)
        }
        backend_config = TurbomindEngineConfig(**backend_config)
        chat_template_config = ChatTemplateConfig(
            model_name=model_name) if model_name else None
        self.model = pipeline(
            model_path=self.path,
            backend_config=backend_config,
            chat_template_config=chat_template_config,
            log_level='WARNING')

    def generate(self,
                 inputs: Union[str, List[str]],
                 do_preprocess: Optional[bool] = None,
                 skip_special_tokens: bool = False,
                 return_dict: bool = False,
                 **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_preprocess (bool): whether to pre-process the messages.
                Defaults to True, which means the chat_template will be
                applied.
            skip_special_tokens (bool): whether or not to remove special
                tokens during decoding. Defaults to False.
            return_dict (bool): whether to return each response as a dict
                instead of plain text. Defaults to False.
        Returns:
            (a list of/batched) text/chat completion
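
        Example:
            A batched-completion sketch, assuming `llm` is a constructed
            `LMDeployPipeline`::

                texts = llm.generate(['1+1=', '2+2='])
                assert isinstance(texts, list) and len(texts) == 2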
        """
        from lmdeploy.messages import GenerationConfig
        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False
        prompt = inputs
        do_sample = kwargs.pop('do_sample', None)
        gen_params = self.update_gen_params(**kwargs)

        if do_sample is None:
            do_sample = self.do_sample
        if do_sample is not None and self.version < (0, 6, 0):
            raise RuntimeError(
                '`do_sample` parameter is not supported by lmdeploy until '
                f'v0.6.0, but currently using lmdeploy {self.str_version}')
        if self.version >= (0, 6, 0):
            if do_sample is None:
                do_sample = gen_params['top_k'] > 1 or gen_params[
                    'temperature'] > 0
            gen_params.update(do_sample=do_sample)

        gen_config = GenerationConfig(
            skip_special_tokens=skip_special_tokens, **gen_params)
        response = self.model.batch_infer(
            prompt, gen_config=gen_config, do_preprocess=do_preprocess)
        texts = [resp.text for resp in response]
        # remove stop_words
        texts = filter_suffix(texts, self.gen_params.get('stop_words'))
        for resp, text in zip(response, texts):
            resp.text = text
        if batched:
            return [asdict(resp)
                    for resp in response] if return_dict else texts
        return asdict(response[0]) if return_dict else texts[0]


class LMDeployServer(BaseLLM):
    """

    Args:
        path (str): The path to the model.
            It could be one of the following options:
                - i) A local directory path of a turbomind model which is
                    converted by `lmdeploy convert` command or download from
                    ii) and iii).
                - ii) The model_id of a lmdeploy-quantized model hosted
                    inside a model repo on huggingface.co, such as
                    "InternLM/internlm-chat-20b-4bit",
                    "lmdeploy/llama2-chat-70b-4bit", etc.
                - iii) The model_id of a model hosted inside a model repo
                    on huggingface.co, such as "internlm/internlm-chat-7b",
                    "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
                    and so on.
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on.
        server_name (str): host ip for serving
        server_port (int): server port
        tp (int): tensor parallel
        log_level (str): set log level whose value among
            [CRITICAL, ERROR, WARNING, INFO, DEBUG]
        serve_cfg (dict): extra config passed to `lmdeploy.serve`
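
    Example:
        A minimal sketch that launches a local api_server; the model id
        and port below are placeholders::

            llm = LMDeployServer(
                path='internlm/internlm2-chat-7b',
                server_port=23333)
            print(llm.generate('Hello!'))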
    """

    def __init__(self,
                 path: str,
                 model_name: Optional[str] = None,
                 server_name: str = '0.0.0.0',
                 server_port: int = 23333,
                 tp: int = 1,
                 log_level: str = 'WARNING',
                 serve_cfg=dict(),
                 **kwargs):
        super().__init__(path=path, **kwargs)
        self.model_name = model_name
        # TODO get_logger issue in multi processing
        import lmdeploy
        self.client = lmdeploy.serve(
            model_path=self.path,
            model_name=model_name,
            server_name=server_name,
            server_port=server_port,
            tp=tp,
            log_level=log_level,
            **serve_cfg)

    def generate(self,
                 inputs: Union[str, List[str]],
                 session_id: int = 2967,
                 sequence_start: bool = True,
                 sequence_end: bool = True,
                 ignore_eos: bool = False,
                 skip_special_tokens: Optional[bool] = False,
                 timeout: int = 30,
                 **kwargs) -> List[str]:
        """Start a new round conversation of a session. Return the chat
        completions in non-stream mode.

        Args:
            inputs (str, List[str]): user's prompt(s) in this round
            session_id (int): the unique id of a session
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            ignore_eos (bool): indicator for ignoring eos
            skip_special_tokens (bool): whether or not to remove special
                tokens during decoding. Defaults to False.
            timeout (int): max time to wait for response
        Returns:
            (a list of/batched) text/chat completion
        """

        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False

        gen_params = self.update_gen_params(**kwargs)
        max_new_tokens = gen_params.pop('max_new_tokens')
        gen_params.update(max_tokens=max_new_tokens)

        resp = [''] * len(inputs)
        for text in self.client.completions_v1(
                self.model_name,
                inputs,
                session_id=session_id,
                sequence_start=sequence_start,
                sequence_end=sequence_end,
                stream=False,
                ignore_eos=ignore_eos,
                skip_special_tokens=skip_special_tokens,
                timeout=timeout,
                **gen_params):
            resp = [
                resp[i] + item['text']
                for i, item in enumerate(text['choices'])
            ]
        # remove stop_words
        resp = filter_suffix(resp, self.gen_params.get('stop_words'))
        if not batched:
            return resp[0]
        return resp

    def stream_chat(self,
                    inputs: List[dict],
                    session_id: int = 0,
                    sequence_start: bool = True,
                    sequence_end: bool = True,
                    stream: bool = True,
                    ignore_eos: bool = False,
                    skip_special_tokens: Optional[bool] = False,
                    timeout: int = 30,
                    **kwargs):
        """Start a new round conversation of a session. Return the chat
        completions in stream mode.

        Args:
            session_id (int): the unique id of a session
            inputs (List[dict]): user's inputs in this round of conversation
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            stream (bool): return in a streaming format if enabled
            ignore_eos (bool): indicator for ignoring eos
            skip_special_tokens (bool): whether or not to remove special
                tokens during decoding. Defaults to False.
            timeout (int): max time to wait for response
        Yields:
            tuple(Status, str, int): status, text/chat completion,
            generated token number
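
        Example:
            A streaming sketch, assuming `llm` is a constructed
            `LMDeployServer`::

                for status, text, _ in llm.stream_chat(
                        [dict(role='user', content='hi')]):
                    print(text)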
        """
        gen_params = self.update_gen_params(**kwargs)
        max_new_tokens = gen_params.pop('max_new_tokens')
        gen_params.update(max_tokens=max_new_tokens)
        prompt = self.template_parser(inputs)

        resp = ''
        finished = False
        stop_words = self.gen_params.get('stop_words') or []
        for text in self.client.completions_v1(
                self.model_name,
                prompt,
                session_id=session_id,
                sequence_start=sequence_start,
                sequence_end=sequence_end,
                stream=stream,
                ignore_eos=ignore_eos,
                skip_special_tokens=skip_special_tokens,
                timeout=timeout,
                **gen_params):
            resp += text['choices'][0]['text']
            if not resp:
                continue
            # remove stop_words
            for sw in stop_words:
                if sw in resp:
                    resp = filter_suffix(resp, stop_words)
                    finished = True
                    break
            yield ModelStatusCode.STREAM_ING, resp, None
            if finished:
                break
        yield ModelStatusCode.END, resp, None


class LMDeployClient(LMDeployServer):
    """

    Args:
        url (str): communicating address 'http://<ip>:<port>' of
            api_server
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on.
    """

    def __init__(self, url: str, model_name: str, **kwargs):
        BaseLLM.__init__(self, path=url, **kwargs)
        from lmdeploy.serve.openai.api_client import APIClient
        self.client = APIClient(url)
        self.model_name = model_name


class AsyncLMDeployPipeline(AsyncLLMMixin, LMDeployPipeline):
    """

    Args:
        path (str): The path to the model.
            It could be one of the following options:
                - i) A local directory path of a turbomind model which is
                    converted by `lmdeploy convert` command or download
                    from ii) and iii).
                - ii) The model_id of a lmdeploy-quantized model hosted
                    inside a model repo on huggingface.co, such as
                    "InternLM/internlm-chat-20b-4bit",
                    "lmdeploy/llama2-chat-70b-4bit", etc.
                - iii) The model_id of a model hosted inside a model repo
                    on huggingface.co, such as "internlm/internlm-chat-7b",
                    "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
                    and so on.
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on.
        tp (int): tensor parallel
        pipeline_cfg (dict): config of pipeline
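
    Example:
        A minimal async sketch; the model id below is a placeholder::

            import asyncio

            llm = AsyncLMDeployPipeline(path='internlm/internlm2-chat-7b')
            print(asyncio.run(llm.generate('Hello!')))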
    """

    async def generate(self,
                       inputs: Union[str, List[str]],
                       session_ids: Optional[Union[int, List[int]]] = None,
                       do_preprocess: Optional[bool] = None,
                       skip_special_tokens: bool = False,
                       return_dict: bool = False,
                       **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            session_ids (int, List[int]): session id(s); defaults to
                sequential ids starting from 0.
            do_preprocess (bool): whether to pre-process the messages.
                Defaults to True, which means the chat_template will be
                applied.
            skip_special_tokens (bool): whether or not to remove special
                tokens during decoding. Defaults to False.
            return_dict (bool): whether to return each response as a dict
                instead of plain text. Defaults to False.
        Returns:
            (a list of/batched) text/chat completion
        """
        from lmdeploy.messages import GenerationConfig, Response

        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False
        if session_ids is None:
            session_ids = list(range(len(inputs)))
        elif isinstance(session_ids, (int, str)):
            session_ids = [session_ids]
        assert len(inputs) == len(session_ids)

        prompt = inputs
        gen_params = self.update_gen_params(**kwargs)
        gen_config = GenerationConfig(
            skip_special_tokens=skip_special_tokens, **gen_params)

        async def _inner_generate(uid, text):
            resp = Response('', 0, 0, uid)
            async for out in self.model.generate(
                    text,
                    uid,
                    gen_config,
                    stream_response=True,
                    sequence_start=True,
                    sequence_end=True,
                    do_preprocess=do_preprocess,
                    **kwargs):
                resp.text += out.response
                resp.generate_token_len = out.generate_token_len
                resp.input_token_len = out.input_token_len
                resp.finish_reason = out.finish_reason
                if out.token_ids:
                    resp.token_ids.extend(out.token_ids)
                if out.logprobs:
                    if resp.logprobs is None:
                        resp.logprobs = []
                    resp.logprobs.extend(out.logprobs)
            return resp

        response = await asyncio.gather(*[
            _inner_generate(sid, inp) for sid, inp in zip(session_ids, prompt)
        ])
        texts = [resp.text for resp in response]
        # remove stop_words
        texts = filter_suffix(texts, self.gen_params.get('stop_words'))
        for resp, text in zip(response, texts):
            resp.text = text
        if batched:
            return [asdict(resp)
                    for resp in response] if return_dict else texts
        return asdict(response[0]) if return_dict else texts[0]


class AsyncLMDeployServer(AsyncLLMMixin, LMDeployServer):
    """

    Args:
        path (str): The path to the model.
            It could be one of the following options:
                - i) A local directory path of a turbomind model which is
                    converted by `lmdeploy convert` command or download from
                    ii) and iii).
                - ii) The model_id of a lmdeploy-quantized model hosted
                    inside a model repo on huggingface.co, such as
                    "InternLM/internlm-chat-20b-4bit",
                    "lmdeploy/llama2-chat-70b-4bit", etc.
                - iii) The model_id of a model hosted inside a model repo
                    on huggingface.co, such as "internlm/internlm-chat-7b",
                    "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
                    and so on.
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on.
        server_name (str): host ip for serving
        server_port (int): server port
        tp (int): tensor parallel
        log_level (str): set log level whose value among
            [CRITICAL, ERROR, WARNING, INFO, DEBUG]
    """

    async def generate(
        self,
        inputs: Union[str, List[str]],
        session_ids: Optional[Union[int, List[int]]] = None,
        sequence_start: bool = True,
        sequence_end: bool = True,
        ignore_eos: bool = False,
        skip_special_tokens: Optional[bool] = False,
        timeout: int = 30,
        **kwargs,
    ):
        """Start a new round conversation of a session. Return the chat
        completions in non-stream mode.

        Args:
            inputs (str, List[str]): user's prompt(s) in this round
            session_ids (int, List[int]): session id(s)
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            ignore_eos (bool): indicator for ignoring eos
            skip_special_tokens (bool): whether or not to remove special
                tokens during decoding. Defaults to False.
            timeout (int): max time to wait for response
        Returns:
            (a list of/batched) text/chat completion
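
        Example:
            An async sketch, assuming `llm` is a constructed
            `AsyncLMDeployServer`::

                import asyncio

                text = asyncio.run(llm.generate('1+1='))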
        """
        from lmdeploy.serve.openai.api_client import json_loads

        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False

        gen_params = self.update_gen_params(**kwargs)
        max_new_tokens = gen_params.pop('max_new_tokens')
        gen_params.update(max_tokens=max_new_tokens)

        responses = [''] * len(inputs)
        pload = dict(
            model=self.model_name,
            prompt=inputs,
            sequence_start=sequence_start,
            sequence_end=sequence_end,
            stream=False,
            ignore_eos=ignore_eos,
            skip_special_tokens=skip_special_tokens,
            timeout=timeout,
            **gen_params)
        async with aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(3 * 3600)) as session:
            async with session.post(
                    self.client.completions_v1_url,
                    headers=self.client.headers,
                    json=pload) as resp:
                async for chunk in resp.content:
                    if chunk:
                        decoded = chunk.decode('utf-8')
                        output = json_loads(decoded)
                        responses = [
                            response + item['text'] for response, item in zip(
                                responses, output['choices'])
                        ]
        # remove stop_words
        responses = filter_suffix(responses, self.gen_params.get('stop_words'))
        if not batched:
            return responses[0]
        return responses

    async def stream_chat(
        self,
        inputs: List[dict],
        session_id: Optional[int] = None,
        sequence_start: bool = True,
        sequence_end: bool = True,
        stream: bool = True,
        ignore_eos: bool = False,
        skip_special_tokens: Optional[bool] = False,
        timeout: int = 30,
        **kwargs,
    ):
        """Start a new round conversation of a session. Return the chat
        completions in stream mode.

        Args:
            inputs (List[dict]): user's inputs in this round of conversation
            session_id (int): session id
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            stream (bool): return in a streaming format if enabled
            ignore_eos (bool): indicator for ignoring eos
            skip_special_tokens (bool): whether or not to remove special
                tokens during decoding. Defaults to False.
            timeout (int): max time to wait for response
        Yields:
            tuple(Status, str, int): status, text/chat completion,
            generated token number
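
        Example:
            An async streaming sketch, assuming `llm` is a constructed
            `AsyncLMDeployServer`::

                import asyncio

                async def demo():
                    async for status, text, _ in llm.stream_chat(
                            [dict(role='user', content='hi')]):
                        print(text)

                asyncio.run(demo())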
        """
        from lmdeploy.serve.openai.api_client import json_loads

        gen_params = self.update_gen_params(**kwargs)
        max_new_tokens = gen_params.pop('max_new_tokens')
        gen_params.update(max_tokens=max_new_tokens)
        prompt = self.template_parser(inputs)

        response = ''
        finished = False
        stop_words = self.gen_params.get('stop_words') or []

        pload = dict(
            model=self.model_name,
            prompt=prompt,
            sequence_start=sequence_start,
            sequence_end=sequence_end,
            stream=stream,
            ignore_eos=ignore_eos,
            skip_special_tokens=skip_special_tokens,
            timeout=timeout,
            **gen_params)
        async with aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(3 * 3600)) as session:
            async with session.post(
                    self.client.completions_v1_url,
                    headers=self.client.headers,
                    json=pload) as resp:
                async for chunk in resp.content:
                    if chunk:
                        decoded = chunk.decode('utf-8')
                        if not decoded.strip() or decoded.rstrip(
                        ) == 'data: [DONE]':
                            continue
                        if decoded[:6] == 'data: ':
                            decoded = decoded[6:]
                        output = json_loads(decoded)
                        response += output['choices'][0]['text']
                        if not response:
                            continue
                        # remove stop_words
                        for sw in stop_words:
                            if sw in response:
                                response = filter_suffix(response, stop_words)
                                finished = True
                                break
                        yield ModelStatusCode.STREAM_ING, response, None
                        if finished:
                            break
                yield ModelStatusCode.END, response, None


class AsyncLMDeployClient(AsyncLMDeployServer):
    """

    Args:
        url (str): communicating address 'http://<ip>:<port>' of
            api_server
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on.
    """

    def __init__(self, url: str, model_name: str, **kwargs):
        BaseLLM.__init__(self, path=url, **kwargs)
        from lmdeploy.serve.openai.api_client import APIClient
        self.client = APIClient(url)
        self.model_name = model_name