# law-bot / app.py
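"""Gradio chat app that answers questions about Korean Supreme Court case law
through a LangChain RAG pipeline (retriever -> prompt -> LLM)."""
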
import os
from typing import Generator, List, Tuple

import gradio as gr
from kiwipiepy import Kiwi
from langchain_community.document_transformers import LongContextReorder
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

from libs.config import STREAMING
from libs.embeddings import get_embeddings
from libs.llm import get_llm
from libs.prompt import get_prompt
from libs.retrievers import load_retrievers

# Silence the fork warning from the Hugging Face tokenizers used by the embeddings.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Kiwi start-up is relatively costly, so build the Korean morphological
# analyzer once at import time instead of once per call.
kiwi = Kiwi()


def kiwi_tokenize(text: str) -> List[str]:
    """Split Korean text into morpheme surface forms."""
    return [token.form for token in kiwi.tokenize(text)]

embeddings = get_embeddings()
retriever = load_retrievers(embeddings)
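
# load_retrievers (libs/retrievers.py, not shown here) builds the retriever
# from the embeddings; a plausible setup (an assumption, not confirmed by this
# file) is an ensemble of a dense vector store with a BM25 retriever that uses
# kiwi_tokenize as its tokenizer.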

# Available models (key: model identifier, value: label shown to the user).
AVAILABLE_MODELS = {
    # "gpt_3_5_turbo": "GPT-3.5 Turbo",
    "gpt_4o": "GPT-4o",
    "gemini_1_5_flash": "Gemini 1.5 Flash",
    "claude_3_5_sonnet": "Claude 3.5 Sonnet",
    # "llama3_70b": "Llama3 70b",
}
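
# The keys above must match the configuration keys that libs/llm.py exposes
# through `with_config(configurable={"llm": ...})`. A plausible implementation
# (an assumption; libs/llm.py is not shown here) is LangChain's
# `configurable_alternatives`, e.g.:
#
#     ChatOpenAI(model="gpt-4o").configurable_alternatives(
#         ConfigurableField(id="llm"),
#         default_key="gpt_4o",
#         gemini_1_5_flash=ChatGoogleGenerativeAI(model="gemini-1.5-flash"),
#         claude_3_5_sonnet=ChatAnthropic(model="claude-3-5-sonnet-20240620"),
#     )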


def create_rag_chain(chat_history: List[Tuple[str, str]], model: str):
    """Build an LCEL chain that answers a question given prior chat turns."""
    # Convert Gradio's (human, ai) tuples into LangChain message objects.
    langchain_messages = []
    for human, ai in chat_history:
        langchain_messages.append(HumanMessage(content=human))
        langchain_messages.append(AIMessage(content=ai))
    llm = get_llm(streaming=STREAMING).with_config(configurable={"llm": model})
    prompt = get_prompt().partial(history=langchain_messages)
    # The dict below is coerced into a RunnableParallel: the incoming question
    # string is routed to the retriever (becoming "context") and also passed
    # through unchanged (as "question") before the prompt formats both.
    return (
        {
            # Optional: reorder retrieved documents to counter the
            # "lost in the middle" effect, e.g.
            #   "context": retriever
            #   | RunnableLambda(LongContextReorder().transform_documents),
            "context": retriever,
            "question": RunnablePassthrough(),
        }
        | prompt
        | llm
        | StrOutputParser()
    )
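
# Usage sketch: create_rag_chain(history, "gpt_4o").invoke("question") returns
# the full answer as a string, while .stream("question") yields it as text
# chunks (see respond() and respond_stream() below).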


def respond_stream(
    message: str, history: List[Tuple[str, str]], model: str
) -> Generator[str, None, None]:
    """Stream the answer to `message` as incremental text chunks."""
    rag_chain = create_rag_chain(history, model)
    yield from rag_chain.stream(message)


def respond(message: str, history: List[Tuple[str, str]], model: str) -> str:
    """Return the full answer to `message` in one shot (no streaming)."""
    rag_chain = create_rag_chain(history, model)
    return rag_chain.invoke(message)


def get_model_key(label: str) -> str:
    """Map a user-facing model label back to its AVAILABLE_MODELS key."""
    return next(key for key, value in AVAILABLE_MODELS.items() if value == label)


def chat_function(
    message: str, history: List[Tuple[str, str]], model_label: str
) -> Generator[str, None, None]:
    """Gradio ChatInterface callback; always a generator so both modes work."""
    model_key = get_model_key(model_label)
    if STREAMING:
        # gr.ChatInterface expects the full response-so-far on every yield,
        # so accumulate chunks rather than yielding deltas.
        response = ""
        for chunk in respond_stream(message, history, model_key):
            response += chunk
            yield response
    else:
        yield respond(message, history, model_key)


with gr.Blocks(fill_height=True) as demo:
    # "Supreme Court Precedent Consultation Assistant"
    gr.Markdown("# λŒ€λ²•μ› νŒλ‘€ 상담 λ„μš°λ―Έ")
    # "Hello! I'm an AI assistant that answers questions about Supreme Court
    # precedents. Feel free to ask about finding, interpreting, or applying
    # case law."
    gr.Markdown(
        "μ•ˆλ…•ν•˜μ„Έμš”! λŒ€λ²•μ› νŒλ‘€μ— κ΄€ν•œ μ§ˆλ¬Έμ— λ‹΅λ³€ν•΄λ“œλ¦¬λŠ” AI 상담 λ„μš°λ―Έμž…λ‹ˆλ‹€. "
        "νŒλ‘€ 검색, 해석, 적용 등에 λŒ€ν•΄ κΆκΈˆν•˜μ‹  점이 있으면 μ–Έμ œλ“  λ¬Όμ–΄λ³΄μ„Έμš”."
    )
    model_dropdown = gr.Dropdown(
        choices=list(AVAILABLE_MODELS.values()),
        label="λͺ¨λΈ 선택",  # "Select model"
        value=list(AVAILABLE_MODELS.values())[0],
    )
    chatbot = gr.ChatInterface(
        fn=chat_function,
        autofocus=True,
        fill_height=True,
        multimodal=False,
        # Each example row is [message, model-dropdown value]: one entry per
        # additional input, and the dropdown takes the label, not the key.
        examples=[
            [
                # "I bought a used car, but it was defective and repairs cost
                # 5,000,000 won. Is the seller a fraudster? Show me a few
                # similar cases!"
                "쀑고차 거래λ₯Ό ν–ˆμ–΄. 그런데 λΆˆλŸ‰μ„ λ°›μ•„ μ°¨μˆ˜λ¦¬μ— 500λ§Œμ›μ΄ λ“€μ—ˆμ–΄. νŒλ§€μžλŠ” 사기꾼인가? λΉ„μŠ·ν•œ 사둀λ₯Ό λͺ‡ 개 μ†Œκ°œν•΄μ€˜!!",
                "GPT-4o",
            ],
            [
                # "I bought about 2,000 pyeong of farmland, only to find out
                # no house can be built on it. Tell me about similar frauds!"
                "논밭은 μ•½ 2μ²œν‰μ„ μƒ€λŠ”λ°, μ•Œκ³  λ³΄λ‹ˆ 집을 지을 수 μ—†λŠ” 땅이야. 이런 사기와 λΉ„μŠ·ν•œ κ±Έ μ•Œλ €μ€˜!",
                "GPT-4o",
            ],
            [
                # "An acquaintance, claiming to be joking, swung a knife he was
                # holding and left a roughly 20 cm cut on my arm. He says it
                # was a joke; tell me about similar cases!"
                "지인이 μž₯λ‚œν•˜λ‹€κ°€ λ“€μ˜€ μžˆλŠ” 칼을 νœ˜λ‘˜λŸ¬ λ‚΄ νŒ”μ΄ 20cmκ°€λŸ‰ μžμƒμ„ μž…μ—ˆμ–΄. μžκΈ°λŠ” μž₯λ‚œμ΄λΌλŠ”λ°, λΉ„μŠ·ν•œ 사둀λ₯Ό μ•Œλ €μ€˜!",
                "GPT-4o",
            ],
        ],
        additional_inputs=[model_dropdown],
    )


if __name__ == "__main__":
    demo.launch()