Kevin Hu commited on
Commit
43b4969
·
1 Parent(s): 40a1db3

refine components retrieval and rewrite (#2818)

Browse files

### What problem does this PR solve?


### Type of change

- [x] Performance Improvement

agent/component/retrieval.py CHANGED
@@ -50,12 +50,15 @@ class Retrieval(ComponentBase, ABC):
50
  component_name = "Retrieval"
51
 
52
  def _run(self, history, **kwargs):
53
- query = []
54
- for role, cnt in history[::-1][:self._param.message_history_window_size]:
55
- if role != "user":continue
56
- query.append(cnt)
57
- # query = "\n".join(query)
58
- query = query[0]
 
 
 
59
  kbs = KnowledgebaseService.get_by_ids(self._param.kb_ids)
60
  if not kbs:
61
  raise ValueError("Can't find knowledgebases by {}".format(self._param.kb_ids))
 
50
  component_name = "Retrieval"
51
 
52
  def _run(self, history, **kwargs):
53
+ # query = []
54
+ # for role, cnt in history[::-1][:self._param.message_history_window_size]:
55
+ # if role != "user":continue
56
+ # query.append(cnt)
57
+ # # query = "\n".join(query)
58
+ # query = query[0]
59
+ query = self.get_input()
60
+ query = str(query["content"][0]) if "content" in query else ""
61
+
62
  kbs = KnowledgebaseService.get_by_ids(self._param.kb_ids)
63
  if not kbs:
64
  raise ValueError("Can't find knowledgebases by {}".format(self._param.kb_ids))
agent/component/rewrite.py CHANGED
@@ -91,7 +91,11 @@ class RewriteQuestion(Generate, ABC):
91
  raise Exception("Sorry! Nothing relevant found.")
92
  self._loop += 1
93
 
94
- conv = self._canvas.get_history(4)
 
 
 
 
95
  conv = "\n".join(conv)
96
 
97
  chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
 
91
  raise Exception("Sorry! Nothing relevant found.")
92
  self._loop += 1
93
 
94
+ hist = self._canvas.get_history(4)
95
+ conv = []
96
+ for m in hist:
97
+ if m["role"] not in ["user", "assistant"]: continue
98
+ conv.append("{}: {}".format(m["role"].upper(), m["content"]))
99
  conv = "\n".join(conv)
100
 
101
  chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)