from math import ceil
from config import settings
from mongoengine import connect
from llm_agent import ReactAgent
from fastapi import FastAPI, Response
from chat import (
    ChatSession,
    FeedbackRequest,
    LikeDislikeRequest,
    UpdateSessionNameRequest,
    Message,
    ChatSessionSchema,
    Role,
)


# ASGI application object; every route and event handler in this module is
# registered on it through the ``@app...`` decorators.
app = FastAPI()
# One agent instance created at import time and reused by ``handle_query``
# for every request.
react_agent = ReactAgent()


@app.on_event("startup")
async def load_scheduler_and_DB():
    """Open the MongoEngine connection when the application starts.

    Registers the connection under the ``"default"`` alias using the database
    name and URI from ``settings``. Despite the function's name, nothing
    scheduler-related is started here.
    """
    database_name = settings.DB_NAME
    database_uri = settings.DB_URI
    connect(db=database_name, host=database_uri, alias="default")
    print("Database connection established!!")


@app.post("/query", tags=["Chat-Model"])
def handle_query(chat_schema: ChatSessionSchema):
    """Answer a user query inside an existing chat session.

    Looks up the session named by ``chat_schema.session_id``, hands the result
    of ``get_last_messages()`` to the agent as chat history, then stores both
    the user's query and the agent's reply on the session.

    Args:
        chat_schema: Request body carrying ``session_id`` and ``query``.

    Returns:
        ``{"response": <agent reply>}`` on success, a plain-text 400 response
        when ``session_id`` is missing/empty, or a plain-text 404 response when
        no session with that id exists.
    """
    if not chat_schema.session_id:
        # Without this guard the function fell through to the final return
        # with ``response`` never assigned and failed with UnboundLocalError.
        return Response(
            content="session_id is required",
            status_code=400,
            media_type="text/plain",
        )

    chat_session = ChatSession.objects(id=chat_schema.session_id).first()
    if chat_session is None:
        # ``.first()`` yields None for an unknown id; calling methods on it
        # would raise AttributeError, so report the miss explicitly.
        return Response(
            content="Chat session not found", status_code=404, media_type="text/plain"
        )

    chat_history = chat_session.get_last_messages()
    response = react_agent.handle_query(
        session_id=chat_schema.session_id,
        query=chat_schema.query,
        chat_history=chat_history,
    )

    # Persist the exchange only after the agent has produced a reply, so the
    # history passed to the agent does not already contain the current query.
    chat_session.add_message_with_metadata(
        role=Role.USER.value, content=chat_schema.query
    )
    chat_session.add_message_with_metadata(role=Role.MODEL.value, content=response)

    return {"response": response}


@app.post("/temp_session", tags=["Chat-Session"])
def temp_session() -> dict:
    """Create an empty chat session, save it, and return its identifier.

    Returns:
        A dict with a confirmation ``message`` and the new ``session_id``
        (the value of the session's ``get_id()``).
    """
    new_session = ChatSession()
    new_session.save()

    payload = {"message": "Session created"}
    payload["session_id"] = new_session.get_id()
    return payload


@app.post("/feedback", tags=["Chat-Features"])
def feedback(feedback_request: FeedbackRequest):
    """Attach feedback text to a message of a chat session.

    Args:
        feedback_request: Request body carrying ``session_id``, ``message_id``
            and ``feedback``.

    Returns:
        A success dict, or a plain-text 404 response when the session does not
        exist.
    """
    # ``.first()`` returns None for an unknown id. The previous ``.get()``
    # raised DoesNotExist instead, so the 404 branch below was unreachable and
    # a missing session surfaced as an unhandled server error.
    chat_session = ChatSession.objects(id=feedback_request.session_id).first()
    if chat_session is None:
        return Response(
            content="Chat session not found", status_code=404, media_type="text/plain"
        )

    chat_session.feedback_message(
        feedback_request.message_id, feedback_request.feedback
    )
    return {"message": "Feedback saved successfully"}


@app.post("/like", tags=["Chat-Features"])
def like_message(like_request: LikeDislikeRequest):
    """Mark a message of a chat session as liked.

    Args:
        like_request: Request body carrying ``session_id`` and ``message_id``.

    Returns:
        A success dict, or a plain-text 404 response when the session does not
        exist.
    """
    # ``.first()`` returns None for an unknown id. The previous ``.get()``
    # raised DoesNotExist instead, so the 404 branch below was unreachable and
    # a missing session surfaced as an unhandled server error.
    chat_session = ChatSession.objects(id=like_request.session_id).first()
    if chat_session is None:
        return Response(
            content="Chat session not found", status_code=404, media_type="text/plain"
        )

    chat_session.like_message(like_request.message_id)
    return {"message": "Message liked successfully"}


@app.post("/dislike", tags=["Chat-Features"])
def dislike_message(dislike_request: LikeDislikeRequest):
    """Mark a message of a chat session as disliked.

    Args:
        dislike_request: Request body carrying ``session_id`` and
            ``message_id``.

    Returns:
        A success dict, or a plain-text 404 response when the session does not
        exist.
    """
    # ``.first()`` returns None for an unknown id. The previous ``.get()``
    # raised DoesNotExist instead, so the 404 branch below was unreachable and
    # a missing session surfaced as an unhandled server error.
    chat_session = ChatSession.objects(id=dislike_request.session_id).first()
    if chat_session is None:
        return Response(
            content="Chat session not found", status_code=404, media_type="text/plain"
        )

    chat_session.dislike_message(dislike_request.message_id)
    return {"message": "Message disliked successfully"}


@app.post("/session_name_change", tags=["Chat-Session"])
def update_session_name(request: UpdateSessionNameRequest):
    """Rename a chat session.

    Args:
        request: Request body carrying ``session_id`` and
            ``new_session_name``.

    Returns:
        A success dict once the new name is saved. A plain-text 404 response
        is returned both when the session is missing and when any exception
        is raised during lookup or save.
    """

    def _not_found() -> Response:
        # Shared 404 payload for the "missing" and "any error" outcomes.
        return Response(
            content="Session not found", status_code=404, media_type="text/plain"
        )

    try:
        session = ChatSession.objects(id=request.session_id).first()
        if session:
            session.session_name = request.new_session_name
            session.save()
            return {"message": "Session name updated successfully"}
        return _not_found()
    except Exception:
        return _not_found()


@app.get(
    "/chat_session/{session_id}",
    tags=["Chat-Session"],
)
# Legacy route: the original path used Flask-style ``<session_id>``, which
# FastAPI treats as literal text (``session_id`` then arrives as a query
# parameter). It is kept so existing callers keep working, registered first
# (innermost decorator) so the literal path wins over the ``{session_id}``
# pattern, and hidden from the schema to avoid a duplicate operation entry.
@app.get(
    "/chat_session/<session_id>",
    tags=["Chat-Session"],
    include_in_schema=False,
)
def get_chat_session(
    session_id: str,
    page: int = 1,
    size: int = 20,
):
    """Return one page of a chat session's messages with paging metadata.

    Args:
        session_id: Identifier of the chat session to read.
        page: 1-based page number.
        size: Number of messages per page.

    Returns:
        A dict with ``total_count``, ``total_pages``, ``has_next_page``,
        ``next_page`` and the serialized ``messages`` of the requested page.
        A plain-text 400 response is returned for a non-positive ``page`` or
        ``size``, and a plain-text 404 response when the session cannot be
        loaded.
    """
    if page < 1 or size < 1:
        # size == 0 would divide by zero below; page < 1 would yield a
        # negative slice start and return the wrong messages.
        return Response(
            content="page and size must be positive integers",
            status_code=400,
            media_type="text/plain",
        )

    try:
        chat_session = ChatSession.objects.get(id=session_id)
    except Exception:
        # Exception rather than BaseException, so KeyboardInterrupt and
        # SystemExit are no longer swallowed and reported as a 404.
        return Response(
            content="Chat session not found", status_code=404, media_type="text/plain"
        )

    skip = (page - 1) * size
    message_ids = [
        message.id for message in chat_session.messages[skip : skip + size]  # noqa
    ]

    # An ``id__in`` query does not promise results in the order of the id
    # list, so re-order the fetched documents to follow the session's list.
    messages_by_id = {
        message.id: message for message in Message.objects(id__in=message_ids)
    }
    messages = [
        messages_by_id[message_id]
        for message_id in message_ids
        if message_id in messages_by_id
    ]
    serialized_messages = [
        {
            **message.to_mongo().to_dict(),
            "_id": str(message.id),
            "chat_session": (
                str(message.chat_session.id) if message.chat_session else None
            ),
        }
        for message in messages
    ]

    # Count the messages of the session that is already loaded instead of
    # fetching the session a second time and calling ``count()`` on it.
    total_count = len(chat_session.messages)
    total_pages = ceil(total_count / size)
    has_next_page = page < total_pages
    next_page = page + 1 if has_next_page else None

    return {
        "total_count": total_count,
        "total_pages": total_pages,
        "has_next_page": has_next_page,
        "next_page": next_page,
        "messages": serialized_messages,
    }


@app.delete(
    "/delete_session",
    tags=["Chat-Session"],
)
def delete_session(session_id: str):
    """Delete a chat session by id.

    Args:
        session_id: Identifier of the chat session to delete.

    Returns:
        A success dict once the session is deleted. A plain-text 404 response
        is returned both when the session is missing and when any exception
        is raised during lookup or deletion.
    """
    try:
        chat_session = ChatSession.objects(id=session_id).first()
        if not chat_session:
            return Response(
                content="Chat session not found",
                status_code=404,
                media_type="text/plain",
            )

        chat_session.delete()
        return {"message": "Session deleted successfully"}
    except Exception:
        # Must be returned, not raised: ``Response`` is not an exception type,
        # so ``raise Response(...)`` itself failed with TypeError.
        return Response(
            content="Chat session not found", status_code=404, media_type="text/plain"
        )