# law-bot / app.py
import os
import gradio as gr
from kiwipiepy import Kiwi
from typing import List, Tuple, Generator
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_community.document_transformers import LongContextReorder
from libs.config import STREAMING
from libs.embeddings import get_embeddings
from libs.retrievers import load_retrievers
from libs.llm import get_llm
from libs.prompt import get_prompt
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def kiwi_tokenize(text):
    # Tokenize Korean text with Kiwi, keeping only the surface form of each token.
    kiwi = Kiwi()
    return [token.form for token in kiwi.tokenize(text)]
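

# Note (assumption): kiwi_tokenize is not referenced elsewhere in this file;
# presumably it is intended as the tokenizer for a keyword-based retriever
# (e.g. BM25) built in libs/retrievers.py.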
embeddings = get_embeddings()
retriever = load_retrievers(embeddings)
# Available models (key: model identifier, value: label shown to the user)
AVAILABLE_MODELS = {
    "gpt_3_5_turbo": "GPT-3.5 Turbo",
    "gpt_4o": "GPT-4o",
    "claude_3_5_sonnet": "Claude 3.5 Sonnet",
    "gemini_1_5_flash": "Gemini 1.5 Flash",
    "llama3_70b": "Llama3 70b",
}
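
# Note (assumption): these keys presumably match the alternative names registered
# for the configurable "llm" field in libs/llm.py, so adding a model here also
# requires registering it there.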


def create_rag_chain(chat_history: List[Tuple[str, str]], model: str):
    # Select the requested model at runtime via the configurable "llm" field.
    llm = get_llm(streaming=STREAMING).with_config(configurable={"llm": model})
    prompt = get_prompt(chat_history)
    return (
        {
            # Retrieve documents, then reorder them so the most relevant ones sit
            # at the start and end of the context (lost-in-the-middle mitigation).
            "context": retriever
            | RunnableLambda(LongContextReorder().transform_documents),
            "question": RunnablePassthrough(),
        }
        | prompt
        | llm
        | StrOutputParser()
    )
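
# Note (assumption): the `with_config(configurable={"llm": model})` call above relies
# on libs/llm.py exposing the models through LangChain's `configurable_alternatives`,
# roughly along these lines (a sketch, not the actual implementation):
#
#     ChatOpenAI(model="gpt-3.5-turbo", streaming=streaming).configurable_alternatives(
#         ConfigurableField(id="llm"),
#         default_key="gpt_3_5_turbo",
#         gpt_4o=ChatOpenAI(model="gpt-4o", streaming=streaming),
#         claude_3_5_sonnet=ChatAnthropic(model="claude-3-5-sonnet-20240620", streaming=streaming),
#         ...
#     )
#
# The keys in AVAILABLE_MODELS must match the alternative keys registered there.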


def respond_stream(
    message: str, history: List[Tuple[str, str]], model: str
) -> Generator[str, None, None]:
    # Stream the answer chunk by chunk.
    rag_chain = create_rag_chain(history, model)
    for chunk in rag_chain.stream(message):
        yield chunk


def respond(message: str, history: List[Tuple[str, str]], model: str) -> str:
    # Non-streaming: return the full answer at once.
    rag_chain = create_rag_chain(history, model)
    return rag_chain.invoke(message)


def get_model_key(label: str) -> str:
    # Map a user-facing label back to its model identifier.
    return next(key for key, value in AVAILABLE_MODELS.items() if value == label)


def chat_function(
    message: str, history: List[Tuple[str, str]], model_label: str
) -> Generator[str, None, None]:
    model_key = get_model_key(model_label)
    if STREAMING:
        # gr.ChatInterface renders each yielded value as the full message so far,
        # so accumulate the chunks before yielding.
        response = ""
        for chunk in respond_stream(message, history, model_key):
            response += chunk
            yield response
    else:
        response = respond(message, history, model_key)
        yield response


with gr.Blocks() as demo:
    gr.Markdown("# λŒ€λ²•μ› νŒλ‘€ 상담 λ„μš°λ―Έ")  # "Supreme Court Case Law Consultation Assistant"
    gr.Markdown(
        "μ•ˆλ…•ν•˜μ„Έμš”! λŒ€λ²•μ› νŒλ‘€μ— κ΄€ν•œ μ§ˆλ¬Έμ— λ‹΅λ³€ν•΄λ“œλ¦¬λŠ” AI 상담 λ„μš°λ―Έμž…λ‹ˆλ‹€. νŒλ‘€ 검색, 해석, 적용 등에 λŒ€ν•΄ κΆκΈˆν•˜μ‹  점이 있으면 μ–Έμ œλ“  λ¬Όμ–΄λ³΄μ„Έμš”."
    )  # Greeting: an AI assistant that answers questions about Supreme Court precedents.
    model_dropdown = gr.Dropdown(
        choices=list(AVAILABLE_MODELS.values()),
        label="λͺ¨λΈ 선택",  # "Model selection"
        value=list(AVAILABLE_MODELS.values())[1],  # default: "GPT-4o"
    )
    chatbot = gr.ChatInterface(
        fn=chat_function,
        additional_inputs=[model_dropdown],
    )


if __name__ == "__main__":
    demo.launch()