import os
import gradio as gr
from kiwipiepy import Kiwi
from typing import Generator, List, Tuple

from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_community.document_transformers import LongContextReorder

from libs.config import STREAMING
from libs.embeddings import get_embeddings
from libs.retrievers import load_retrievers
from libs.llm import get_llm
from libs.prompt import get_prompt

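# Silence the huggingface/tokenizers fork-after-parallelism warning; the
# library disables its own parallelism in that case to avoid deadlocks.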
os.environ["TOKENIZERS_PARALLELISM"] = "false"


# Kiwi (a Korean morphological analyzer) is expensive to construct, so build
# it once at import time instead of on every call.
kiwi = Kiwi()


def kiwi_tokenize(text: str) -> List[str]:
    return [token.form for token in kiwi.tokenize(text)]


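# Build the embedding model and retriever once at startup; they are shared
# across all chat sessions.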
embeddings = get_embeddings()
retriever = load_retrievers(embeddings)


# Available models (key: model identifier, value: label shown to the user)
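# Each key is passed to get_llm(...).with_config(configurable={"llm": key}),
# so it presumably must match an LLM alternative configured in libs.llm.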
AVAILABLE_MODELS = {
    # "gpt_3_5_turbo": "GPT-3.5 Turbo",
    "gpt_4o": "GPT-4o",
    "gemini_1_5_flash": "Gemini 1.5 Flash",
    "claude_3_5_sonnet": "Claude 3.5 Sonnet",
    # "llama3_70b": "Llama3 70b",
}


def create_rag_chain(chat_history: List[Tuple[str, str]], model: str):
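    # Replay the Gradio (user, assistant) history as LangChain messages so the
    # prompt template can include prior turns.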
    langchain_messages = []
    for human, ai in chat_history:
        langchain_messages.append(HumanMessage(content=human))
        langchain_messages.append(AIMessage(content=ai))

    llm = get_llm(streaming=STREAMING).with_config(configurable={"llm": model})
    prompt = get_prompt().partial(history=langchain_messages)

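    # LCEL pipeline: the retriever fetches relevant documents and
    # LongContextReorder places the most relevant ones at the start and end of
    # the context (mitigating "lost in the middle"), then prompt -> LLM -> str.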
    return (
        {
            "context": retriever
            | RunnableLambda(LongContextReorder().transform_documents),
            "question": RunnablePassthrough(),
        }
        | prompt
        | llm
        | StrOutputParser()
    )


def respond_stream(
    message: str, history: List[Tuple[str, str]], model: str
) -> Generator[str, None, None]:
    rag_chain = create_rag_chain(history, model)
    for chunk in rag_chain.stream(message):
        yield chunk


def respond(message: str, history: List[Tuple[str, str]], model: str) -> str:
    rag_chain = create_rag_chain(history, model)
    return rag_chain.invoke(message)


def get_model_key(label: str) -> str:
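    # Reverse lookup from display label to model key; labels are assumed unique.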
    return next(key for key, value in AVAILABLE_MODELS.items() if value == label)


def validate_input(message: str) -> bool:
    """μž…λ ₯된 λ©”μ‹œμ§€κ°€ μœ νš¨ν•œμ§€ κ²€μ‚¬ν•©λ‹ˆλ‹€."""
    return bool(message.strip())


def chat_function(
    message: str, history: List[Tuple[str, str]], model_label: str
) -> Generator[str, None, None]:
    if not validate_input(message):
        yield "λ©”μ‹œμ§€λ₯Ό μž…λ ₯ν•΄μ£Όμ„Έμš”."
        return

    model_key = get_model_key(model_label)
    if STREAMING:
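        # gr.ChatInterface re-renders the message with each yielded value, so
        # yield the accumulated text rather than individual chunks.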
        response = ""
        for chunk in respond_stream(message, history, model_key):
            response += chunk
            yield response
    else:
        response = respond(message, history, model_key)
        yield response


with gr.Blocks(fill_height=True) as demo:
    gr.Markdown("# λŒ€λ²•μ› νŒλ‘€ 상담 λ„μš°λ―Έ")
    gr.Markdown(
        "μ•ˆλ…•ν•˜μ„Έμš”! λŒ€λ²•μ› νŒλ‘€μ— κ΄€ν•œ μ§ˆλ¬Έμ— λ‹΅λ³€ν•΄λ“œλ¦¬λŠ” AI 상담 λ„μš°λ―Έμž…λ‹ˆλ‹€. νŒλ‘€ 검색, 해석, 적용 등에 λŒ€ν•΄ κΆκΈˆν•˜μ‹  점이 있으면 μ–Έμ œλ“  λ¬Όμ–΄λ³΄μ„Έμš”."
    )

    model_dropdown = gr.Dropdown(
        choices=list(AVAILABLE_MODELS.values()),
        label="λͺ¨λΈ 선택",
        value=list(AVAILABLE_MODELS.values())[0],
    )

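    # additional_inputs passes the dropdown's current value to chat_function
    # as its third argument (model_label).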
    chatbot = gr.ChatInterface(
        fn=chat_function,
        autofocus=True,
        fill_height=True,
        multimodal=False,
        examples=[
            [
                "쀑고차 거래λ₯Ό ν–ˆλŠ”λ° λΆˆλŸ‰μœΌλ‘œ μ°¨ μˆ˜λ¦¬μ— 500λ§Œμ›μ΄ λ“€μ—ˆμŠ΅λ‹ˆλ‹€. νŒλ§€μžμ—κ²Œ 법적 μ±…μž„μ„ 물을 수 μžˆλ‚˜μš”? λΉ„μŠ·ν•œ νŒλ‘€λ₯Ό μ†Œκ°œν•΄μ£Όμ„Έμš”.",
                "GPT-4o",
            ],
            [
                "μ•½ 2천 ν‰μ˜ 농지λ₯Ό κ΅¬λ§€ν–ˆλŠ”λ°, μ•Œκ³  λ³΄λ‹ˆ 주택을 지을 수 μ—†λŠ” λ•…μ΄μ—ˆμŠ΅λ‹ˆλ‹€. 이와 μœ μ‚¬ν•œ 뢀동산 사기 κ΄€λ ¨ νŒλ‘€λ₯Ό μ•Œλ €μ£Όμ„Έμš”.",
                "GPT-4o",
            ],
            [
                "지인이 μž₯λ‚œμœΌλ‘œ νœ˜λ‘λ₯Έ 칼에 νŒ”μ΄ 20cm κ°€λŸ‰ μ°”λ ΈμŠ΅λ‹ˆλ‹€. μž₯λ‚œμ΄λΌκ³  μ£Όμž₯ν•˜λŠ”λ°, 이와 μœ μ‚¬ν•œ 상해 κ΄€λ ¨ νŒλ‘€λ₯Ό μ•Œλ €μ£Όμ„Έμš”.",
                "GPT-4o",
            ],
        ],
        additional_inputs=[model_dropdown],
    )

if __name__ == "__main__":
    demo.launch()