import time
import uuid
from functools import partial
from typing import (
    Dict,
    Any,
    AsyncIterator,
)

import anyio
from fastapi import APIRouter, Depends
from fastapi import HTTPException, Request
from loguru import logger
from openai.types.chat import (
    ChatCompletionMessage,
    ChatCompletion,
    ChatCompletionChunk,
)
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.completion_usage import CompletionUsage
from sse_starlette import EventSourceResponse
from text_generation.types import StreamResponse, Response

from api.core.tgi import TGIEngine
from api.models import GENERATE_ENGINE
from api.utils.compat import model_dump
from api.utils.protocol import Role, ChatCompletionCreateParams
from api.utils.request import (
    check_api_key,
    handle_request,
    get_event_publisher,
)

chat_router = APIRouter(prefix="/chat")


def get_engine():
    yield GENERATE_ENGINE


@chat_router.post("/completions", dependencies=[Depends(check_api_key)])
async def create_chat_completion(
    request: ChatCompletionCreateParams,
    raw_request: Request,
    engine: TGIEngine = Depends(get_engine),
):
    if (not request.messages) or request.messages[-1]["role"] == Role.ASSISTANT:
        raise HTTPException(status_code=400, detail="Invalid request")

    request = await handle_request(request, engine.prompt_adapter.stop)
    request.max_tokens = request.max_tokens or 512

    prompt = engine.apply_chat_template(request.messages)
    include = {
        "temperature",
        "best_of",
        "repetition_penalty",
        "typical_p",
        "watermark",
    }
    params = model_dump(request, include=include)
    params.update(
        dict(
            prompt=prompt,
            do_sample=request.temperature > 1e-5,
            max_new_tokens=request.max_tokens,
            stop_sequences=request.stop,
            top_p=request.top_p if request.top_p < 1.0 else 0.99,
        )
    )
    logger.debug(f"==== request ====\n{params}")

    request_id: str = f"chatcmpl-{str(uuid.uuid4())}"

    if request.stream:
        generator = engine.generate_stream(**params)
        iterator = create_chat_completion_stream(generator, params, request_id)
        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
            recv_chan,
            data_sender_callable=partial(
                get_event_publisher,
                request=raw_request,
                inner_send_chan=send_chan,
                iterator=iterator,
            ),
        )

    response: Response = await engine.generate(**params)
    finish_reason = response.details.finish_reason.value
    finish_reason = "length" if finish_reason == "length" else "stop"

    message = ChatCompletionMessage(role="assistant", content=response.generated_text)

    choice = Choice(
        index=0,
        message=message,
        finish_reason=finish_reason,
        logprobs=None,
    )

    num_prompt_tokens = len(response.details.prefill)
    num_generated_tokens = response.details.generated_tokens
    usage = CompletionUsage(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=num_generated_tokens,
        total_tokens=num_prompt_tokens + num_generated_tokens,
    )
    return ChatCompletion(
        id=request_id,
        choices=[choice],
        created=int(time.time()),
        model=request.model,
        object="chat.completion",
        usage=usage,
    )


async def create_chat_completion_stream(
    generator: AsyncIterator[StreamResponse], params: Dict[str, Any], request_id: str
) -> AsyncIterator[ChatCompletionChunk]:
    # First chunk with role
    choice = ChunkChoice(
        index=0,
        delta=ChoiceDelta(role="assistant", content=""),
        finish_reason=None,
        logprobs=None,
    )
    yield ChatCompletionChunk(
        id=request_id,
        choices=[choice],
        created=int(time.time()),
        model=params.get("model", "llm"),
        object="chat.completion.chunk",
    )
    async for output in generator:
        output: StreamResponse
        if output.token.special:
            continue

        choice = ChunkChoice(
            index=0,
            delta=ChoiceDelta(content=output.token.text),
            finish_reason=None,
            logprobs=None,
        )
        yield ChatCompletionChunk(
            id=request_id,
            choices=[choice],
            created=int(time.time()),
            model=params.get("model", "llm"),
            object="chat.completion.chunk",
        )

    choice = ChunkChoice(
        index=0,
        delta=ChoiceDelta(),
        finish_reason="stop",
        logprobs=None,
    )
    yield ChatCompletionChunk(
        id=request_id,
        choices=[choice],
        created=int(time.time()),
        model=params.get("model", "llm"),
        object="chat.completion.chunk",
    )