Kevin Hu
committed on
Commit
·
3e144e5
1
Parent(s):
847f564
refine agent (#2787)
Browse files
### What problem does this PR solve?
### Type of change
- [ ] Bug Fix (non-breaking change which fixes an issue)
- [ ] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [x] Performance Improvement
- [ ] Other (please describe):
- agent/component/categorize.py +1 -1
- agent/component/generate.py +1 -1
- agent/component/rewrite.py +39 -7
- api/apps/canvas_app.py +1 -1
agent/component/categorize.py
CHANGED
@@ -73,7 +73,7 @@ class Categorize(Generate, ABC):
|
|
73 |
|
74 |
def _run(self, history, **kwargs):
|
75 |
input = self.get_input()
|
76 |
-
input = "Question: " + (
|
77 |
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
|
78 |
ans = chat_mdl.chat(self._param.get_prompt(), [{"role": "user", "content": input}],
|
79 |
self._param.gen_conf())
|
|
|
73 |
|
74 |
def _run(self, history, **kwargs):
|
75 |
input = self.get_input()
|
76 |
+
input = "Question: " + (list(input["content"])[-1] if "content" in input else "") + "\tCategory: "
|
77 |
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
|
78 |
ans = chat_mdl.chat(self._param.get_prompt(), [{"role": "user", "content": input}],
|
79 |
self._param.gen_conf())
|
agent/component/generate.py
CHANGED
@@ -101,7 +101,7 @@ class Generate(ComponentBase):
|
|
101 |
prompt = self._param.prompt
|
102 |
|
103 |
retrieval_res = self.get_input()
|
104 |
-
input = (" - "
|
105 |
for para in self._param.parameters:
|
106 |
cpn = self._canvas.get_component(para["component_id"])["obj"]
|
107 |
_, out = cpn.output(allow_partial=False)
|
|
|
101 |
prompt = self._param.prompt
|
102 |
|
103 |
retrieval_res = self.get_input()
|
104 |
+
input = (" - "+"\n - ".join([c for c in retrieval_res["content"] if isinstance(c, str)])) if "content" in retrieval_res else ""
|
105 |
for para in self._param.parameters:
|
106 |
cpn = self._canvas.get_component(para["component_id"])["obj"]
|
107 |
_, out = cpn.output(allow_partial=False)
|
agent/component/rewrite.py
CHANGED
@@ -33,7 +33,7 @@ class RewriteQuestionParam(GenerateParam):
|
|
33 |
def check(self):
|
34 |
super().check()
|
35 |
|
36 |
-
def get_prompt(self):
|
37 |
self.prompt = """
|
38 |
You are an expert at query expansion to generate a paraphrasing of a question.
|
39 |
I can't retrieve relevant information from the knowledge base by using the user's question directly.
|
@@ -43,6 +43,40 @@ class RewriteQuestionParam(GenerateParam):
|
|
43 |
And return 5 versions of question and one is from translation.
|
44 |
Just list the question. No other words are needed.
|
45 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
return self.prompt
|
47 |
|
48 |
|
@@ -56,14 +90,12 @@ class RewriteQuestion(Generate, ABC):
|
|
56 |
self._loop = 0
|
57 |
raise Exception("Sorry! Nothing relevant found.")
|
58 |
self._loop += 1
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
q += c
|
63 |
-
break
|
64 |
|
65 |
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
|
66 |
-
ans = chat_mdl.chat(self._param.get_prompt(), [{"role": "user", "content":
|
67 |
self._param.gen_conf())
|
68 |
self._canvas.history.pop()
|
69 |
self._canvas.history.append(("user", ans))
|
|
|
33 |
def check(self):
|
34 |
super().check()
|
35 |
|
36 |
+
def get_prompt(self, conv):
|
37 |
self.prompt = """
|
38 |
You are an expert at query expansion to generate a paraphrasing of a question.
|
39 |
I can't retrieve relevant information from the knowledge base by using the user's question directly.
|
|
|
43 |
And return 5 versions of question and one is from translation.
|
44 |
Just list the question. No other words are needed.
|
45 |
"""
|
46 |
+
return f"""
|
47 |
+
Role: A helpful assistant
|
48 |
+
Task: Generate a full user question that would follow the conversation.
|
49 |
+
Requirements & Restrictions:
|
50 |
+
- Text generated MUST be in the same language of the original user's question.
|
51 |
+
- If the user's latest question is complete, don't do anything, just return the original question.
|
52 |
+
- DON'T generate anything except a refined question.
|
53 |
+
|
54 |
+
######################
|
55 |
+
-Examples-
|
56 |
+
######################
|
57 |
+
# Example 1
|
58 |
+
## Conversation
|
59 |
+
USER: What is the name of Donald Trump's father?
|
60 |
+
ASSISTANT: Fred Trump.
|
61 |
+
USER: And his mother?
|
62 |
+
###############
|
63 |
+
Output: What's the name of Donald Trump's mother?
|
64 |
+
------------
|
65 |
+
# Example 2
|
66 |
+
## Conversation
|
67 |
+
USER: What is the name of Donald Trump's father?
|
68 |
+
ASSISTANT: Fred Trump.
|
69 |
+
USER: And his mother?
|
70 |
+
ASSISTANT: Mary Trump.
|
71 |
+
User: What's her full name?
|
72 |
+
###############
|
73 |
+
Output: What's the full name of Donald Trump's mother Mary Trump?
|
74 |
+
######################
|
75 |
+
# Real Data
|
76 |
+
## Conversation
|
77 |
+
{conv}
|
78 |
+
###############
|
79 |
+
"""
|
80 |
return self.prompt
|
81 |
|
82 |
|
|
|
90 |
self._loop = 0
|
91 |
raise Exception("Sorry! Nothing relevant found.")
|
92 |
self._loop += 1
|
93 |
+
|
94 |
+
conv = self._canvas.get_history(4)
|
95 |
+
conv = "\n".join(conv)
|
|
|
|
|
96 |
|
97 |
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
|
98 |
+
ans = chat_mdl.chat(self._param.get_prompt(conv), [{"role": "user", "content": "Output: "}],
|
99 |
self._param.gen_conf())
|
100 |
self._canvas.history.pop()
|
101 |
self._canvas.history.append(("user", ans))
|
api/apps/canvas_app.py
CHANGED
@@ -112,7 +112,7 @@ def run():
|
|
112 |
canvas.messages.append({"role": "user", "content": req["message"], "id": message_id})
|
113 |
if len([m for m in canvas.messages if m["role"] == "user"]) > 1:
|
114 |
ten = TenantService.get_by_user_id(current_user.id)[0]
|
115 |
-
req["message"] = full_question(ten["tenant_id"], ten["llm_id"], canvas.messages)
|
116 |
canvas.add_user_input(req["message"])
|
117 |
answer = canvas.run(stream=stream)
|
118 |
print(canvas)
|
|
|
112 |
canvas.messages.append({"role": "user", "content": req["message"], "id": message_id})
|
113 |
if len([m for m in canvas.messages if m["role"] == "user"]) > 1:
|
114 |
ten = TenantService.get_by_user_id(current_user.id)[0]
|
115 |
+
#req["message"] = full_question(ten["tenant_id"], ten["llm_id"], canvas.messages)
|
116 |
canvas.add_user_input(req["message"])
|
117 |
answer = canvas.run(stream=stream)
|
118 |
print(canvas)
|