anpigon committed on
Commit
273182d
·
1 Parent(s): 363462e

refactor: Update llm.py and prompt.py files

Browse files
Files changed (2) hide show
  1. app.py +60 -16
  2. libs/llm.py +7 -7
app.py CHANGED
@@ -6,7 +6,7 @@ from langchain_core.output_parsers import StrOutputParser
6
  from langchain_core.runnables import RunnablePassthrough, RunnableLambda
7
  from langchain_community.document_transformers import LongContextReorder
8
 
9
- from libs.config import LLM_MODEL, STREAMING
10
  from libs.embeddings import get_embeddings
11
  from libs.retrievers import load_retrievers
12
  from libs.llm import get_llm
@@ -22,8 +22,17 @@ embeddings = get_embeddings()
22
  retriever = load_retrievers(embeddings)
23
 
24
 
25
- def create_rag_chain(chat_history):
26
- llm = get_llm(streaming=STREAMING)
 
 
 
 
 
 
 
 
 
27
  prompt = get_prompt(chat_history)
28
 
29
  return (
@@ -33,29 +42,64 @@ def create_rag_chain(chat_history):
33
  "question": RunnablePassthrough(),
34
  }
35
  | prompt
36
- | llm.with_config(configurable={"llm": LLM_MODEL})
37
  | StrOutputParser()
38
  )
39
 
40
 
41
- def respond_stream(message, history):
42
- rag_chain = create_rag_chain(history)
43
- response = ""
44
  for chunk in rag_chain.stream(message):
45
- response += chunk
46
- yield response
47
 
48
 
49
- def respond(message, history):
50
- rag_chain = create_rag_chain(history)
51
  return rag_chain.invoke(message)
52
 
53
 
54
- demo = gr.ChatInterface(
55
- respond_stream if STREAMING else respond,
56
- title="λŒ€λ²•μ› νŒλ‘€ 상담 λ„μš°λ―Έ",
57
- description="μ•ˆλ…•ν•˜μ„Έμš”! λŒ€λ²•μ› νŒλ‘€μ— κ΄€ν•œ μ§ˆλ¬Έμ— λ‹΅λ³€ν•΄λ“œλ¦¬λŠ” AI 상담 λ„μš°λ―Έμž…λ‹ˆλ‹€. νŒλ‘€ 검색, 해석, 적용 등에 λŒ€ν•΄ κΆκΈˆν•˜μ‹  점이 있으면 μ–Έμ œλ“  λ¬Όμ–΄λ³΄μ„Έμš”.",
58
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  if __name__ == "__main__":
61
  demo.launch()
 
6
  from langchain_core.runnables import RunnablePassthrough, RunnableLambda
7
  from langchain_community.document_transformers import LongContextReorder
8
 
9
+ from libs.config import STREAMING
10
  from libs.embeddings import get_embeddings
11
  from libs.retrievers import load_retrievers
12
  from libs.llm import get_llm
 
22
  retriever = load_retrievers(embeddings)
23
 
24
 
25
+ def kiwi_tokenize(text):
26
+ kiwi = Kiwi()
27
+ return [token.form for token in kiwi.tokenize(text)]
28
+
29
+
30
+ embeddings = get_embeddings()
31
+ retriever = load_retrievers(embeddings)
32
+
33
+
34
+ def create_rag_chain(chat_history, model):
35
+ llm = get_llm(streaming=STREAMING).with_config(configurable={"llm": model})
36
  prompt = get_prompt(chat_history)
37
 
38
  return (
 
42
  "question": RunnablePassthrough(),
43
  }
44
  | prompt
45
+ | llm
46
  | StrOutputParser()
47
  )
48
 
49
 
50
+ def respond_stream(message, history, model):
51
+ rag_chain = create_rag_chain(history, model)
 
52
  for chunk in rag_chain.stream(message):
53
+ yield chunk
 
54
 
55
 
56
+ def respond(message, history, model):
57
+ rag_chain = create_rag_chain(history, model)
58
  return rag_chain.invoke(message)
59
 
60
 
61
+ # μ‚¬μš© κ°€λŠ₯ν•œ λͺ¨λΈ λͺ©λ‘ (key: λͺ¨λΈ μ‹λ³„μž, value: μ‚¬μš©μžμ—κ²Œ ν‘œμ‹œν•  λ ˆμ΄λΈ”)
62
+ AVAILABLE_MODELS = {
63
+ "gpt_3_5_turbo": "GPT-3.5 Turbo",
64
+ "gpt_4o": "GPT-4o",
65
+ "claude_3_5_sonnet": "Claude 3.5 Sonnet",
66
+ "gemini_1_5_flash": "Gemini 1.5 Flash",
67
+ "llama3_70b": "Llama3 70b",
68
+ }
69
+
70
+
71
+ def get_model_key(label):
72
+ return next(key for key, value in AVAILABLE_MODELS.items() if value == label)
73
+
74
+
75
+ def chat_function(message, history, model_label):
76
+ model_key = get_model_key(model_label)
77
+ if STREAMING:
78
+ response = ""
79
+ for chunk in respond_stream(message, history, model_key):
80
+ response += chunk
81
+ yield response
82
+ else:
83
+ response = respond(message, history, model_key)
84
+ yield response
85
+
86
+
87
+ with gr.Blocks() as demo:
88
+ gr.Markdown("# λŒ€λ²•μ› νŒλ‘€ 상담 λ„μš°λ―Έ")
89
+ gr.Markdown(
90
+ "μ•ˆλ…•ν•˜μ„Έμš”! λŒ€λ²•μ› νŒλ‘€μ— κ΄€ν•œ μ§ˆλ¬Έμ— λ‹΅λ³€ν•΄λ“œλ¦¬λŠ” AI 상담 λ„μš°λ―Έμž…λ‹ˆλ‹€. νŒλ‘€ 검색, 해석, 적용 등에 λŒ€ν•΄ κΆκΈˆν•˜μ‹  점이 있으면 μ–Έμ œλ“  λ¬Όμ–΄λ³΄μ„Έμš”."
91
+ )
92
+
93
+ model_dropdown = gr.Dropdown(
94
+ choices=list(AVAILABLE_MODELS.values()),
95
+ label="λͺ¨λΈ 선택",
96
+ value=list(AVAILABLE_MODELS.values())[1],
97
+ )
98
+
99
+ chatbot = gr.ChatInterface(
100
+ fn=chat_function,
101
+ additional_inputs=[model_dropdown],
102
+ )
103
 
104
  if __name__ == "__main__":
105
  demo.launch()
libs/llm.py CHANGED
@@ -16,32 +16,32 @@ class StreamCallback(BaseCallbackHandler):
16
 
17
  def get_llm(streaming=True):
18
  return ChatOpenAI(
19
- model="gpt-4",
20
  temperature=0,
21
  streaming=streaming,
22
  callbacks=[StreamCallback()],
23
  ).configurable_alternatives(
24
  ConfigurableField(id="llm"),
25
- default_key="gpt4",
26
- claude=ChatAnthropic(
27
- model="claude-3-opus-20240229",
28
  temperature=0,
29
  streaming=streaming,
30
  callbacks=[StreamCallback()],
31
  ),
32
- gpt3=ChatOpenAI(
33
  model="gpt-3.5-turbo",
34
  temperature=0,
35
  streaming=streaming,
36
  callbacks=[StreamCallback()],
37
  ),
38
- gemini=GoogleGenerativeAI(
39
  model="gemini-1.5-flash",
40
  temperature=0,
41
  streaming=streaming,
42
  callbacks=[StreamCallback()],
43
  ),
44
- llama3=ChatGroq(
45
  model_name="llama3-70b-8192",
46
  temperature=0,
47
  streaming=streaming,
 
16
 
17
  def get_llm(streaming=True):
18
  return ChatOpenAI(
19
+ model="gpt-4o",
20
  temperature=0,
21
  streaming=streaming,
22
  callbacks=[StreamCallback()],
23
  ).configurable_alternatives(
24
  ConfigurableField(id="llm"),
25
+ default_key="gpt_4o",
26
+ claude_3_5_sonnet=ChatAnthropic(
27
+ model="claude-3-5-sonnet-20240620",
28
  temperature=0,
29
  streaming=streaming,
30
  callbacks=[StreamCallback()],
31
  ),
32
+ gpt_3_5_turbo=ChatOpenAI(
33
  model="gpt-3.5-turbo",
34
  temperature=0,
35
  streaming=streaming,
36
  callbacks=[StreamCallback()],
37
  ),
38
+ gemini_1_5_flash=GoogleGenerativeAI(
39
  model="gemini-1.5-flash",
40
  temperature=0,
41
  streaming=streaming,
42
  callbacks=[StreamCallback()],
43
  ),
44
+ llama3_70b=ChatGroq(
45
  model_name="llama3-70b-8192",
46
  temperature=0,
47
  streaming=streaming,