import os
import gradio as gr
from kiwipiepy import Kiwi
from typing import List, Tuple, Generator, Union

from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_community.document_transformers import LongContextReorder

from libs.config import STREAMING
from libs.embeddings import get_embeddings
from libs.retrievers import load_retrievers
from libs.llm import get_llm
from libs.prompt import get_prompt

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def kiwi_tokenize(text):
    kiwi = Kiwi()
    return [token.form for token in kiwi.tokenize(text)]


embeddings = get_embeddings()
retriever = load_retrievers(embeddings)


# 사용 가능한 모델 목록 (key: 모델 식별자, value: 사용자에게 표시할 레이블)
AVAILABLE_MODELS = {
    # "gpt_3_5_turbo": "GPT-3.5 Turbo",
    "gpt_4o": "GPT-4o",
    "gemini_1_5_flash": "Gemini 1.5 Flash",
    "claude_3_5_sonnet": "Claude 3.5 Sonnet",
    # "llama3_70b": "Llama3 70b",
}


def create_rag_chain(chat_history: List[Tuple[str, str]], model: str):
    langchain_messages = []
    for human, ai in chat_history:
        langchain_messages.append(HumanMessage(content=human))
        langchain_messages.append(AIMessage(content=ai))

    llm = get_llm(streaming=STREAMING).with_config(configurable={"llm": model})
    prompt = get_prompt().partial(history=langchain_messages)

    return (
        {
            "context": retriever
            | RunnableLambda(LongContextReorder().transform_documents),
            "question": RunnablePassthrough(),
        }
        | prompt
        | llm
        | StrOutputParser()
    )


def get_model_key(label):
    return next(key for key, value in AVAILABLE_MODELS.items() if value == label)


def respond_stream(
    message: str, history: List[Tuple[str, str]], model: str
) -> Generator[str, None, None]:
    rag_chain = create_rag_chain(history, model)
    for chunk in rag_chain.stream(message):
        yield chunk


def respond(message: str, history: List[Tuple[str, str]], model: str) -> str:
    rag_chain = create_rag_chain(history, model)
    return rag_chain.invoke(message)


def get_model_key(label: str) -> str:
    return next(key for key, value in AVAILABLE_MODELS.items() if value == label)


def validate_input(message: str) -> bool:
    """입력된 메시지가 유효한지 검사합니다."""
    return bool(message.strip())


def chat_function(
    message: str, history: List[Tuple[str, str]], model_label: str
) -> Generator[str, None, None]:
    if not validate_input(message):
        yield "메시지를 입력해주세요."
        return

    model_key = get_model_key(model_label)
    if STREAMING:
        response = ""
        for chunk in respond_stream(message, history, model_key):
            response += chunk
            yield response
    else:
        response = respond(message, history, model_key)
        yield response


with gr.Blocks(
    fill_height=True,
) as demo:
    gr.Markdown("# 대법원 판례 상담 도우미")
    gr.Markdown(
        "안녕하세요! 대법원 판례에 관한 질문에 답변해드리는 AI 상담 도우미입니다. 판례 검색, 해석, 적용 등에 대해 궁금하신 점이 있으면 언제든 물어보세요."
    )

    model_dropdown = gr.Dropdown(
        choices=list(AVAILABLE_MODELS.values()),
        label="모델 선택",
        value=list(AVAILABLE_MODELS.values())[0],
    )

    chatbot = gr.ChatInterface(
        fn=chat_function,
        autofocus=True,
        fill_height=True,
        multimodal=False,
        examples=[
            [
                "중고차 거래를 했는데 불량으로 차 수리에 500만원이 들었습니다. 판매자에게 법적 책임을 물을 수 있나요? 비슷한 판례를 소개해주세요.",
                "GPT-4o",
            ],
            [
                "약 2천 평의 농지를 구매했는데, 알고 보니 주택을 지을 수 없는 땅이었습니다. 이와 유사한 부동산 사기 관련 판례를 알려주세요.",
                "GPT-4o",
            ],
            [
                "지인이 장난으로 휘두른 칼에 팔이 20cm 가량 찔렸습니다. 장난이라고 주장하는데, 이와 유사한 상해 관련 판례를 알려주세요.",
                "GPT-4o",
            ],
        ],
        additional_inputs=[model_dropdown],
    )

if __name__ == "__main__":
    demo.launch()