import os
import gradio as gr
from kiwipiepy import Kiwi
from typing import List, Tuple, Generator
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_community.document_transformers import LongContextReorder
from libs.config import STREAMING
from libs.embeddings import get_embeddings
from libs.retrievers import load_retrievers
from libs.llm import get_llm
from libs.prompt import get_prompt

# Disable HuggingFace tokenizers parallelism to avoid fork-related deadlock warnings.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def kiwi_tokenize(text):
    """Tokenize Korean text into morpheme surface forms using Kiwi."""
    kiwi = Kiwi()
    return [token.form for token in kiwi.tokenize(text)]
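
# kiwi_tokenize is not referenced elsewhere in this file; presumably it serves as
# the tokenizer for a keyword retriever (e.g. BM25) built in libs.retrievers,
# where Korean morpheme splitting works far better than whitespace splitting.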

# Build the embedding model and the document retriever once at startup.
embeddings = get_embeddings()
retriever = load_retrievers(embeddings)

# Available models (key: model identifier, value: label shown to the user)
AVAILABLE_MODELS = {
    # "gpt_3_5_turbo": "GPT-3.5 Turbo",
    "gpt_4o": "GPT-4o",
    "gemini_1_5_flash": "Gemini 1.5 Flash",
    "claude_3_5_sonnet": "Claude 3.5 Sonnet",
    # "llama3_70b": "Llama3 70b",
}
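
# Each key is later passed via .with_config(configurable={"llm": <key>}), so it
# presumably names a configured LLM alternative registered in libs.llm.get_llm.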


def create_rag_chain(chat_history: List[Tuple[str, str]], model: str):
    """Build a RAG chain for the selected model, with prior turns baked into the prompt."""
    # Convert Gradio-style (user, assistant) tuples into LangChain messages.
    langchain_messages = []
    for human, ai in chat_history:
        langchain_messages.append(HumanMessage(content=human))
        langchain_messages.append(AIMessage(content=ai))
    llm = get_llm(streaming=STREAMING).with_config(configurable={"llm": model})
    prompt = get_prompt().partial(history=langchain_messages)
    return (
        {
            # LongContextReorder places the most relevant documents at the start
            # and end of the context, countering the "lost in the middle" effect.
            "context": retriever
            | RunnableLambda(LongContextReorder().transform_documents),
            "question": RunnablePassthrough(),
        }
        | prompt
        | llm
        | StrOutputParser()
    )
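
# Illustrative flow of a single call:
#   rag_chain.invoke(question)
#     -> retriever fetches documents -> LongContextReorder rearranges them
#     -> prompt is filled with {context}, {question} and the partial {history}
#     -> the selected LLM generates -> StrOutputParser returns a plain str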


def get_model_key(label: str) -> str:
    """Map a user-facing model label back to its internal model identifier."""
    return next(key for key, value in AVAILABLE_MODELS.items() if value == label)


def respond_stream(
    message: str, history: List[Tuple[str, str]], model: str
) -> Generator[str, None, None]:
    rag_chain = create_rag_chain(history, model)
    for chunk in rag_chain.stream(message):
        yield chunk


def respond(message: str, history: List[Tuple[str, str]], model: str) -> str:
    rag_chain = create_rag_chain(history, model)
    return rag_chain.invoke(message)


def validate_input(message: str) -> bool:
    """Check whether the input message is valid (non-empty after stripping whitespace)."""
    return bool(message.strip())


def chat_function(
    message: str, history: List[Tuple[str, str]], model_label: str
) -> Generator[str, None, None]:
    if not validate_input(message):
        yield "Please enter a message."
        return
    model_key = get_model_key(model_label)
    if STREAMING:
        # Gradio re-renders the full message on each yield, so accumulate chunks
        # and yield the response-so-far rather than each individual chunk.
        response = ""
        for chunk in respond_stream(message, history, model_key):
            response += chunk
            yield response
    else:
        response = respond(message, history, model_key)
        yield response
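
# Quick sanity check outside the UI (assumes the retriever and LLM credentials
# are configured; the label must be one of AVAILABLE_MODELS' values):
#   for partial in chat_function("Tell me about used-car fraud precedents", [], "GPT-4o"):
#       print(partial)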


with gr.Blocks(
    fill_height=True,
) as demo:
    gr.Markdown("# Supreme Court Case Law Consultation Assistant")
    gr.Markdown(
        "Hello! I am an AI consultation assistant that answers questions about "
        "Supreme Court precedents. Feel free to ask at any time about searching, "
        "interpreting, or applying case law."
    )
    model_dropdown = gr.Dropdown(
        choices=list(AVAILABLE_MODELS.values()),
        label="Model selection",
        value=list(AVAILABLE_MODELS.values())[0],
    )
    chatbot = gr.ChatInterface(
        fn=chat_function,
        autofocus=True,
        fill_height=True,
        multimodal=False,
        examples=[
            [
                "I bought a used car that turned out to be defective, and repairs "
                "cost 5 million won. Can I hold the seller legally responsible? "
                "Please introduce similar precedents.",
                "GPT-4o",
            ],
            [
                "I purchased about 2,000 pyeong of land, but it turned out that "
                "no house could be built on it. Please tell me about earlier "
                "precedents involving similar real estate fraud.",
                "GPT-4o",
            ],
            [
                "An acquaintance swung a knife as a joke and left a roughly 20 cm "
                "gash on my arm. He claims it was a joke; please tell me about "
                "similar precedents involving bodily injury.",
                "GPT-4o",
            ],
        ],
        additional_inputs=[model_dropdown],
    )
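
    # Values from `additional_inputs` are passed to fn positionally after
    # (message, history), which is why chat_function takes model_label third.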


if __name__ == "__main__":
    demo.launch()
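    # For LAN access or a temporary public URL, demo.launch(server_name="0.0.0.0")
    # or demo.launch(share=True) are standard Gradio options.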