refactor: Update llm.py and prompt.py files
app.py +60 -16
libs/llm.py +7 -7
app.py
CHANGED
@@ -6,7 +6,7 @@ from langchain_core.output_parsers import StrOutputParser
 from langchain_core.runnables import RunnablePassthrough, RunnableLambda
 from langchain_community.document_transformers import LongContextReorder
 
-from libs.config import
+from libs.config import STREAMING
 from libs.embeddings import get_embeddings
 from libs.retrievers import load_retrievers
 from libs.llm import get_llm
@@ -22,8 +22,17 @@ embeddings = get_embeddings()
 retriever = load_retrievers(embeddings)
 
 
-def create_rag_chain(chat_history):
-
+def kiwi_tokenize(text):
+    kiwi = Kiwi()
+    return [token.form for token in kiwi.tokenize(text)]
+
+
+embeddings = get_embeddings()
+retriever = load_retrievers(embeddings)
+
+
+def create_rag_chain(chat_history, model):
+    llm = get_llm(streaming=STREAMING).with_config(configurable={"llm": model})
     prompt = get_prompt(chat_history)
 
     return (
@@ -33,29 +42,64 @@ def create_rag_chain(chat_history):
             "question": RunnablePassthrough(),
         }
         | prompt
-        | llm
+        | llm
         | StrOutputParser()
     )
 
 
-def respond_stream(message, history):
-    rag_chain = create_rag_chain(history)
-    response = ""
+def respond_stream(message, history, model):
+    rag_chain = create_rag_chain(history, model)
     for chunk in rag_chain.stream(message):
-        response += chunk
-        yield response
+        yield chunk
 
 
-def respond(message, history):
-    rag_chain = create_rag_chain(history)
+def respond(message, history, model):
+    rag_chain = create_rag_chain(history, model)
     return rag_chain.invoke(message)
 
 
-
-
-
-
-
+# List of available models (key: model identifier, value: label shown to the user)
+AVAILABLE_MODELS = {
+    "gpt_3_5_turbo": "GPT-3.5 Turbo",
+    "gpt_4o": "GPT-4o",
+    "claude_3_5_sonnet": "Claude 3.5 Sonnet",
+    "gemini_1_5_flash": "Gemini 1.5 Flash",
+    "llama3_70b": "Llama3 70b",
+}
+
+
+def get_model_key(label):
+    return next(key for key, value in AVAILABLE_MODELS.items() if value == label)
+
+
+def chat_function(message, history, model_label):
+    model_key = get_model_key(model_label)
+    if STREAMING:
+        response = ""
+        for chunk in respond_stream(message, history, model_key):
+            response += chunk
+            yield response
+    else:
+        response = respond(message, history, model_key)
+        yield response
+
+
+with gr.Blocks() as demo:
+    gr.Markdown("# Supreme Court Precedent Consultation Assistant")
+    gr.Markdown(
+        "Hello! I am an AI consultation assistant that answers questions about Supreme Court precedents. If you have any questions about precedent search, interpretation, or application, feel free to ask anytime."
+    )
+
+    model_dropdown = gr.Dropdown(
+        choices=list(AVAILABLE_MODELS.values()),
+        label="Select a model",
+        value=list(AVAILABLE_MODELS.values())[1],
+    )
+
+    chatbot = gr.ChatInterface(
+        fn=chat_function,
+        additional_inputs=[model_dropdown],
+    )
 
 if __name__ == "__main__":
     demo.launch()
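For context on the new wiring: gr.ChatInterface forwards each entry of additional_inputs as an extra positional argument, so the dropdown's current label arrives as model_label in chat_function, and because chat_function yields in both branches, Gradio consumes it the same way whether or not STREAMING is set. Below is a hypothetical smoke test of that call path outside Gradio, a minimal sketch assuming this file imports as app and that the provider API keys and retriever data are configured; it is not part of the commit.

# Hypothetical smoke test for the new call path; assumes this file imports
# as `app` and that API keys and the retriever index are in place.
from app import AVAILABLE_MODELS, chat_function

history = []  # gr.ChatInterface passes prior (user, assistant) pairs here
label = AVAILABLE_MODELS["gpt_4o"]  # "GPT-4o", the dropdown's default value

# chat_function is a generator in both branches; each yielded value is the
# full response accumulated so far, which is what ChatInterface expects
# when it streams partial answers into the chat window.
final = ""
for partial in chat_function("What does this precedent hold?", history, label):
    final = partial
print(final)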
libs/llm.py
CHANGED
@@ -16,32 +16,32 @@ class StreamCallback(BaseCallbackHandler):
 
 def get_llm(streaming=True):
     return ChatOpenAI(
-        model="gpt-
+        model="gpt-4o",
         temperature=0,
         streaming=streaming,
         callbacks=[StreamCallback()],
     ).configurable_alternatives(
         ConfigurableField(id="llm"),
-        default_key="
-
-            model="claude-3-
+        default_key="gpt_4o",
+        claude_3_5_sonnet=ChatAnthropic(
+            model="claude-3-5-sonnet-20240620",
             temperature=0,
             streaming=streaming,
             callbacks=[StreamCallback()],
         ),
-
+        gpt_3_5_turbo=ChatOpenAI(
             model="gpt-3.5-turbo",
             temperature=0,
             streaming=streaming,
             callbacks=[StreamCallback()],
         ),
-
+        gemini_1_5_flash=GoogleGenerativeAI(
             model="gemini-1.5-flash",
             temperature=0,
             streaming=streaming,
             callbacks=[StreamCallback()],
         ),
-
+        llama3_70b=ChatGroq(
             model_name="llama3-70b-8192",
             temperature=0,
             streaming=streaming,
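This configurable_alternatives call is what app.py's with_config(configurable={"llm": model}) keys into: ConfigurableField(id="llm") names the config slot, the keyword names (claude_3_5_sonnet, gpt_3_5_turbo, gemini_1_5_flash, llama3_70b) are the accepted values, and default_key="gpt_4o" maps the unconfigured case to the base ChatOpenAI. A minimal selection sketch, assuming libs/llm.py as above and provider API keys in the environment for whichever alternative is chosen:

# Minimal sketch of per-call model selection via configurable_alternatives.
from libs.llm import get_llm

llm = get_llm(streaming=False)

# With no config, the default alternative runs: ChatOpenAI with model="gpt-4o".
print(llm.invoke("ping").content)

# with_config(configurable={"llm": <key>}) switches alternatives per call;
# the key must match one of the keyword names above, e.g. "llama3_70b" (ChatGroq).
print(llm.with_config(configurable={"llm": "llama3_70b"}).invoke("ping").content)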