import os
from typing import Generator, List, Tuple

import gradio as gr
from kiwipiepy import Kiwi

from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_community.document_transformers import LongContextReorder

from libs.config import STREAMING
from libs.embeddings import get_embeddings
from libs.retrievers import load_retrievers
from libs.llm import get_llm
from libs.prompt import get_prompt

# Avoid HuggingFace tokenizers parallelism warnings when the process forks.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


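# Tokenize text into morpheme surface forms with the Kiwi Korean morphological analyzer
# (useful, for example, as a tokenizer for keyword-based retrieval over Korean text).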
def kiwi_tokenize(text):
    kiwi = Kiwi()
    return [token.form for token in kiwi.tokenize(text)]


# The embedding model and retriever are built once at import time and shared by all requests.
embeddings = get_embeddings()
retriever = load_retrievers(embeddings)


AVAILABLE_MODELS = {
    "gpt_3_5_turbo": "GPT-3.5 Turbo",
    "gpt_4o": "GPT-4o",
    "claude_3_5_sonnet": "Claude 3.5 Sonnet",
    "gemini_1_5_flash": "Gemini 1.5 Flash",
    "llama3_70b": "Llama3 70b",
}


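# Build the per-request LCEL chain: retrieve documents, reorder them with
# LongContextReorder so the most relevant ones are not "lost in the middle" of the
# context, then fill the prompt, call the selected LLM, and parse the output to a string.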
def create_rag_chain(chat_history: List[Tuple[str, str]], model: str):
    llm = get_llm(streaming=STREAMING).with_config(configurable={"llm": model})
    prompt = get_prompt(chat_history)

    return (
        {
            "context": retriever
            | RunnableLambda(LongContextReorder().transform_documents),
            "question": RunnablePassthrough(),
        }
        | prompt
        | llm
        | StrOutputParser()
    )


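# Chain entry points: each call rebuilds the chain so the prompt reflects the current
# chat history and the model selected in the UI.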
def respond_stream(
    message: str, history: List[Tuple[str, str]], model: str
) -> Generator[str, None, None]:
    rag_chain = create_rag_chain(history, model)
    for chunk in rag_chain.stream(message):
        yield chunk


def respond(message: str, history: List[Tuple[str, str]], model: str) -> str:
    rag_chain = create_rag_chain(history, model)
    return rag_chain.invoke(message)


def get_model_key(label: str) -> str:
    return next(key for key, value in AVAILABLE_MODELS.items() if value == label)


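# Gradio callback: map the dropdown label to its model key, then either stream the answer
# (yielding the accumulated text so the UI shows the response growing) or yield one
# complete answer when streaming is disabled.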
def chat_function(
    message: str, history: List[Tuple[str, str]], model_label: str
) -> Generator[str, None, None]:
    model_key = get_model_key(model_label)
    if STREAMING:
        response = ""
        for chunk in respond_stream(message, history, model_key):
            response += chunk
            yield response
    else:
        response = respond(message, history, model_key)
        yield response


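# Gradio UI: a model dropdown plus a ChatInterface; the dropdown's current value is passed
# to chat_function as an additional input with every message.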
with gr.Blocks() as demo:
    gr.Markdown("# Supreme Court Case Law Assistant")
    gr.Markdown(
        "Hello! I am an AI assistant that answers questions about Supreme Court case law. "
        "Feel free to ask anytime about searching, interpreting, or applying precedents."
    )

    model_dropdown = gr.Dropdown(
        choices=list(AVAILABLE_MODELS.values()),
        label="Model selection",
        value=list(AVAILABLE_MODELS.values())[1],
    )

    chatbot = gr.ChatInterface(
        fn=chat_function,
        additional_inputs=[model_dropdown],
    )


if __name__ == "__main__":
    demo.launch()