import asyncio
from typing import (
    Optional,
    List,
    Dict,
    Any,
    AsyncIterator,
    Union,
)

from fastapi import HTTPException
from loguru import logger
from openai.types.chat import ChatCompletionMessageParam
from transformers import PreTrainedTokenizer
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams

from api.adapter import get_prompt_adapter
from api.generation import build_qwen_chat_input


class VllmEngine:
    """
    Wrapper around a vLLM ``AsyncLLMEngine`` that turns request parameters
    (a prompt string or chat messages plus sampling options) into a vLLM
    generation call.
    """

    def __init__(
        self,
        model: AsyncLLMEngine,
        tokenizer: PreTrainedTokenizer,
        model_name: str,
        prompt_name: Optional[str] = None,
        context_len: Optional[int] = -1,
    ):
        """
        Initializes the VLLMEngine object.

        Args:
            model: The AsyncLLMEngine object.
            tokenizer: The PreTrainedTokenizer object.
            model_name: The name of the model; stored lower-cased and used for
                substring checks such as ``"qwen"`` / ``"chatglm3"``.
            prompt_name: The name of the prompt template (optional); stored
                lower-cased.
            context_len: Context length used for Qwen models only. ``None`` or a
                non-positive value falls back to 8192. Other models use the
                ``max_model_len`` reported by the engine's model config.

        Note:
            The model config is fetched with ``asyncio.run``, which raises
            ``RuntimeError`` when called while an event loop is already running
            in this thread, so construct this object outside of async code.
        """
        self.model = model
        self.model_name = model_name.lower()
        self.tokenizer = tokenizer
        self.prompt_name = prompt_name.lower() if prompt_name is not None else None
        self.prompt_adapter = get_prompt_adapter(self.model_name, prompt_name=self.prompt_name)

        model_config = asyncio.run(self.model.get_model_config())
        if "qwen" in self.model_name:
            # `context_len` is Optional, so check for None before comparing.
            has_context_len = context_len is not None and context_len > 0
            self.max_model_len = context_len if has_context_len else 8192
        else:
            self.max_model_len = model_config.max_model_len

    def apply_chat_template(
        self,
        messages: List[ChatCompletionMessageParam],
        max_tokens: Optional[int] = 256,
        functions: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> Union[str, List[int]]:
        """
        Applies a chat template to the given messages and returns the processed output.

        Args:
            messages: A list of ChatCompletionMessageParam objects representing the chat messages.
            max_tokens: The maximum number of tokens to generate; forwarded to the
                Qwen input builder (optional, default=256).
            functions: A dictionary or list of dictionaries describing callable functions (optional).
            tools: A list of dictionaries describing the tools that may be used (optional).

        Returns:
            Union[str, List[int]]: A prompt string, or a list of token ids for the
            ChatGLM3 branch and (presumably, depending on ``build_qwen_chat_input``)
            the Qwen branch.
        """
        if self.prompt_adapter.function_call_available:
            messages = self.prompt_adapter.postprocess_messages(
                messages, functions, tools,
            )
            if functions or tools:
                logger.debug(f"==== Messages with tools ====\n{messages}")

        if "chatglm3" in self.model_name:
            # The last message is the current query; everything before it is history.
            query, role = messages[-1]["content"], messages[-1]["role"]
            return self.tokenizer.build_chat_input(
                query, history=messages[:-1], role=role
            )["input_ids"][0].tolist()
        elif "qwen" in self.model_name:
            return build_qwen_chat_input(
                self.tokenizer,
                messages,
                self.max_model_len,
                max_tokens,
                functions,
                tools,
            )
        else:
            return self.prompt_adapter.apply_chat_template(messages)

    def convert_to_inputs(
        self,
        prompt: Optional[str] = None,
        token_ids: Optional[List[int]] = None,
        max_tokens: Optional[int] = 256,
    ) -> List[int]:
        """
        Returns prompt token ids left-truncated so that the prompt plus
        ``max_tokens`` generated tokens fit into ``self.max_model_len``.

        Args:
            prompt: Text to tokenize; used only when ``token_ids`` is empty or None.
            token_ids: Pre-tokenized prompt; takes precedence over ``prompt``.
            max_tokens: Number of tokens reserved for generation. ``None``
                reserves nothing, so the input may use the whole context.

        Returns:
            The input token ids as a new list, with the oldest tokens dropped
            first when truncation is needed. If ``max_tokens`` leaves no room
            for input at all, the ids are returned untruncated and a warning
            is logged.
        """
        input_ids = token_ids or self.tokenizer(prompt).input_ids

        reserved = max_tokens if max_tokens is not None else 0
        max_input_tokens = self.max_model_len - reserved
        if max_input_tokens <= 0:
            # A non-positive slice bound would either keep everything (`[-0:]`)
            # or silently drop tokens from the *start* of the prompt. Keep the
            # input intact and leave the over-long request to the engine.
            logger.warning(
                f"max_tokens={max_tokens} leaves no room for input within "
                f"max_model_len={self.max_model_len}; input is not truncated"
            )
            return list(input_ids)
        return input_ids[-max_input_tokens:]

    def generate(self, params: Dict[str, Any], request_id: str) -> AsyncIterator:
        """
        Starts a generation request on the vLLM engine.

        If ``params["prompt_or_messages"]`` is a list, it is treated as chat
        messages and rendered with ``apply_chat_template``. Otherwise it is
        used as the prompt string.

        Args:
            params (Dict[str, Any]): Request parameters. Reads
                ``prompt_or_messages``, ``max_tokens``, ``functions``, ``tools``
                and the sampling options passed to ``SamplingParams``.
            request_id (str): The ID of the request.

        Returns:
            AsyncIterator: The iterator returned by ``AsyncLLMEngine.generate``.

        Raises:
            HTTPException: With status 400 when building the sampling
                parameters or starting the generation raises ``ValueError``.
        """
        max_tokens = params.get("max_tokens", 256)
        prompt_or_messages = params.get("prompt_or_messages")
        if isinstance(prompt_or_messages, list):
            prompt_or_messages = self.apply_chat_template(
                prompt_or_messages,
                max_tokens,
                functions=params.get("functions"),
                tools=params.get("tools"),
            )

        # After templating we hold either token ids (list) or a prompt string.
        if isinstance(prompt_or_messages, list):
            prompt, token_ids = None, prompt_or_messages
        else:
            prompt, token_ids = prompt_or_messages, None

        token_ids = self.convert_to_inputs(prompt, token_ids, max_tokens=max_tokens)
        try:
            sampling_params = SamplingParams(
                n=params.get("n", 1),
                presence_penalty=params.get("presence_penalty", 0.),
                frequency_penalty=params.get("frequency_penalty", 0.),
                temperature=params.get("temperature", 0.9),
                top_p=params.get("top_p", 0.8),
                stop=params.get("stop", []),
                stop_token_ids=params.get("stop_token_ids", []),
                # Reuse the same budget that was used to truncate the input.
                max_tokens=max_tokens,
                repetition_penalty=params.get("repetition_penalty", 1.03),
                min_p=params.get("min_p", 0.0),
                best_of=params.get("best_of", 1),
                ignore_eos=params.get("ignore_eos", False),
                use_beam_search=params.get("use_beam_search", False),
                skip_special_tokens=params.get("skip_special_tokens", True),
                spaces_between_special_tokens=params.get("spaces_between_special_tokens", True),
            )
            result_generator = self.model.generate(
                prompt if isinstance(prompt, str) else None,
                sampling_params,
                request_id,
                token_ids,
            )
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e)) from e

        return result_generator

    @property
    def stop(self):
        """
        Gets the stop property of the prompt adapter.

        Returns:
            The stop property of the prompt adapter, or None if it does not exist.
        """
        return getattr(self.prompt_adapter, "stop", None)