Kevin Hu committed on
Commit
1b1a5b7
·
1 Parent(s): faaabea

Support iframe chatbot. (#3961)

Browse files

### What problem does this PR solve?

#3909

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

agent/canvas.py CHANGED
@@ -330,4 +330,7 @@ class Canvas(ABC):
330
  q["value"] = v
331
 
332
  def get_preset_param(self):
333
- return self.components["begin"]["obj"]._param.query
 
 
 
 
330
  q["value"] = v
331
 
332
  def get_preset_param(self):
333
+ return self.components["begin"]["obj"]._param.query
334
+
335
+ def get_component_input_elements(self, cpnnm):
336
+ return self.components["begin"]["obj"].get_input_elements()
agent/component/base.py CHANGED
@@ -476,7 +476,7 @@ class ComponentBase(ABC):
476
  self._param.inputs.append({"component_id": q["component_id"],
477
  "content": "\n".join(
478
  [str(d["content"]) for d in outs[-1].to_dict('records')])})
479
- elif q["value"]:
480
  self._param.inputs.append({"component_id": None, "content": q["value"]})
481
  outs.append(pd.DataFrame([{"content": q["value"]}]))
482
  if outs:
@@ -526,6 +526,21 @@ class ComponentBase(ABC):
526
 
527
  return df
528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
529
  def get_stream_input(self):
530
  reversed_cpnts = []
531
  if len(self._canvas.path) > 1:
 
476
  self._param.inputs.append({"component_id": q["component_id"],
477
  "content": "\n".join(
478
  [str(d["content"]) for d in outs[-1].to_dict('records')])})
479
+ elif q.get("value"):
480
  self._param.inputs.append({"component_id": None, "content": q["value"]})
481
  outs.append(pd.DataFrame([{"content": q["value"]}]))
482
  if outs:
 
526
 
527
  return df
528
 
529
+ def get_input_elements(self):
530
+ assert self._param.query, "Please identify input parameters firstly."
531
+ eles = []
532
+ for q in self._param.query:
533
+ if q.get("component_id"):
534
+ if q["component_id"].split("@")[0].lower().find("begin") >= 0:
535
+ cpn_id, key = q["component_id"].split("@")
536
+ eles.extend(self._canvas.get_component(cpn_id)["obj"]._param.query)
537
+ continue
538
+
539
+ eles.append({"key": q["key"], "component_id": q["component_id"]})
540
+ else:
541
+ eles.append({"key": q["key"]})
542
+ return eles
543
+
544
  def get_stream_input(self):
545
  reversed_cpnts = []
546
  if len(self._canvas.path) > 1:
agent/component/generate.py CHANGED
@@ -17,6 +17,7 @@ import re
17
  from functools import partial
18
  import pandas as pd
19
  from api.db import LLMType
 
20
  from api.db.services.dialog_service import message_fit_in
21
  from api.db.services.llm_service import LLMBundle
22
  from api import settings
@@ -104,9 +105,16 @@ class Generate(ComponentBase):
104
  if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
105
  answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
106
  res = {"content": answer, "reference": reference}
 
107
 
108
  return res
109
 
 
 
 
 
 
 
110
  def _run(self, history, **kwargs):
111
  chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
112
  prompt = self._param.prompt
 
17
  from functools import partial
18
  import pandas as pd
19
  from api.db import LLMType
20
+ from api.db.services.conversation_service import structure_answer
21
  from api.db.services.dialog_service import message_fit_in
22
  from api.db.services.llm_service import LLMBundle
23
  from api import settings
 
105
  if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
106
  answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
107
  res = {"content": answer, "reference": reference}
108
+ res = structure_answer(None, res, "", "")
109
 
110
  return res
111
 
112
+ def get_input_elements(self):
113
+ if self._param.parameters:
114
+ return self._param.parameters
115
+
116
+ return [{"key": "input"}]
117
+
118
  def _run(self, history, **kwargs):
119
  chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
120
  prompt = self._param.prompt
api/apps/canvas_app.py CHANGED
@@ -186,6 +186,26 @@ def reset():
186
  return server_error_response(e)
187
 
188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  @manager.route('/test_db_connect', methods=['POST']) # noqa: F821
190
  @validate_request("db_type", "database", "username", "host", "port", "password")
191
  @login_required
 
186
  return server_error_response(e)
187
 
188
 
189
+ @manager.route('/input_elements', methods=['GET']) # noqa: F821
190
+ @validate_request("id", "component_id")
191
+ @login_required
192
+ def input_elements():
193
+ req = request.json
194
+ try:
195
+ e, user_canvas = UserCanvasService.get_by_id(req["id"])
196
+ if not e:
197
+ return get_data_error_result(message="canvas not found.")
198
+ if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
199
+ return get_json_result(
200
+ data=False, message='Only owner of canvas authorized for this operation.',
201
+ code=RetCode.OPERATING_ERROR)
202
+
203
+ canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
204
+ return get_json_result(data=canvas.get_component_input_elements(req["component_id"]))
205
+ except Exception as e:
206
+ return server_error_response(e)
207
+
208
+
209
  @manager.route('/test_db_connect', methods=['POST']) # noqa: F821
210
  @validate_request("db_type", "database", "username", "host", "port", "password")
211
  @login_required
api/apps/conversation_app.py CHANGED
@@ -18,7 +18,7 @@ import re
18
  import traceback
19
  from copy import deepcopy
20
 
21
- from api.db.services.conversation_service import ConversationService
22
  from api.db.services.user_service import UserTenantService
23
  from flask import request, Response
24
  from flask_login import login_required, current_user
@@ -90,6 +90,21 @@ def get():
90
  return get_json_result(
91
  data=False, message='Only owner of conversation authorized for this operation.',
92
  code=settings.RetCode.OPERATING_ERROR)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  conv = conv.to_dict()
94
  return get_json_result(data=conv)
95
  except Exception as e:
@@ -132,6 +147,7 @@ def list_convsersation():
132
  dialog_id=dialog_id,
133
  order_by=ConversationService.model.create_time,
134
  reverse=True)
 
135
  convs = [d.to_dict() for d in convs]
136
  return get_json_result(data=convs)
137
  except Exception as e:
@@ -164,24 +180,29 @@ def completion():
164
 
165
  if not conv.reference:
166
  conv.reference = []
167
- conv.message.append({"role": "assistant", "content": "", "id": message_id})
168
- conv.reference.append({"chunks": [], "doc_aggs": []})
169
-
170
- def fillin_conv(ans):
171
- nonlocal conv, message_id
172
- if not conv.reference:
173
- conv.reference.append(ans["reference"])
174
- else:
175
- conv.reference[-1] = ans["reference"]
176
- conv.message[-1] = {"role": "assistant", "content": ans["answer"],
177
- "id": message_id, "prompt": ans.get("prompt", "")}
178
- ans["id"] = message_id
 
 
179
 
 
 
 
180
  def stream():
181
  nonlocal dia, msg, req, conv
182
  try:
183
  for ans in chat(dia, msg, True, **req):
184
- fillin_conv(ans)
185
  yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
186
  ConversationService.update_by_id(conv.id, conv.to_dict())
187
  except Exception as e:
@@ -202,8 +223,7 @@ def completion():
202
  else:
203
  answer = None
204
  for ans in chat(dia, msg, **req):
205
- answer = ans
206
- fillin_conv(ans)
207
  ConversationService.update_by_id(conv.id, conv.to_dict())
208
  break
209
  return get_json_result(data=answer)
 
18
  import traceback
19
  from copy import deepcopy
20
 
21
+ from api.db.services.conversation_service import ConversationService, structure_answer
22
  from api.db.services.user_service import UserTenantService
23
  from flask import request, Response
24
  from flask_login import login_required, current_user
 
90
  return get_json_result(
91
  data=False, message='Only owner of conversation authorized for this operation.',
92
  code=settings.RetCode.OPERATING_ERROR)
93
+
94
+ def get_value(d, k1, k2):
95
+ return d.get(k1, d.get(k2))
96
+
97
+ for ref in conv.reference:
98
+ ref["chunks"] = [{
99
+ "id": get_value(ck, "chunk_id", "id"),
100
+ "content": get_value(ck, "content", "content_with_weight"),
101
+ "document_id": get_value(ck, "doc_id", "document_id"),
102
+ "document_name": get_value(ck, "docnm_kwd", "document_name"),
103
+ "dataset_id": get_value(ck, "kb_id", "dataset_id"),
104
+ "image_id": get_value(ck, "image_id", "img_id"),
105
+ "positions": get_value(ck, "positions", "position_int"),
106
+ } for ck in ref.get("chunks", [])]
107
+
108
  conv = conv.to_dict()
109
  return get_json_result(data=conv)
110
  except Exception as e:
 
147
  dialog_id=dialog_id,
148
  order_by=ConversationService.model.create_time,
149
  reverse=True)
150
+
151
  convs = [d.to_dict() for d in convs]
152
  return get_json_result(data=convs)
153
  except Exception as e:
 
180
 
181
  if not conv.reference:
182
  conv.reference = []
183
+ else:
184
+ def get_value(d, k1, k2):
185
+ return d.get(k1, d.get(k2))
186
+
187
+ for ref in conv.reference:
188
+ ref["chunks"] = [{
189
+ "id": get_value(ck, "chunk_id", "id"),
190
+ "content": get_value(ck, "content", "content_with_weight"),
191
+ "document_id": get_value(ck, "doc_id", "document_id"),
192
+ "document_name": get_value(ck, "docnm_kwd", "document_name"),
193
+ "dataset_id": get_value(ck, "kb_id", "dataset_id"),
194
+ "image_id": get_value(ck, "image_id", "img_id"),
195
+ "positions": get_value(ck, "positions", "position_int"),
196
+ } for ck in ref.get("chunks", [])]
197
 
198
+ if not conv.reference:
199
+ conv.reference = []
200
+ conv.reference.append({"chunks": [], "doc_aggs": []})
201
  def stream():
202
  nonlocal dia, msg, req, conv
203
  try:
204
  for ans in chat(dia, msg, True, **req):
205
+ ans = structure_answer(conv, ans, message_id, conv.id)
206
  yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
207
  ConversationService.update_by_id(conv.id, conv.to_dict())
208
  except Exception as e:
 
223
  else:
224
  answer = None
225
  for ans in chat(dia, msg, **req):
226
+ answer = structure_answer(conv, ans, message_id, req["conversation_id"])
 
227
  ConversationService.update_by_id(conv.id, conv.to_dict())
228
  break
229
  return get_json_result(data=answer)
api/apps/sdk/session.py CHANGED
@@ -112,6 +112,11 @@ def update(tenant_id, chat_id, session_id):
112
  @token_required
113
  def chat_completion(tenant_id, chat_id):
114
  req = request.json
 
 
 
 
 
115
  if req.get("stream", True):
116
  resp = Response(rag_completion(tenant_id, chat_id, **req), mimetype="text/event-stream")
117
  resp.headers.add_header("Cache-control", "no-cache")
@@ -133,6 +138,11 @@ def chat_completion(tenant_id, chat_id):
133
  @token_required
134
  def agent_completions(tenant_id, agent_id):
135
  req = request.json
 
 
 
 
 
136
  if req.get("stream", True):
137
  resp = Response(agent_completion(tenant_id, agent_id, **req), mimetype="text/event-stream")
138
  resp.headers.add_header("Cache-control", "no-cache")
 
112
  @token_required
113
  def chat_completion(tenant_id, chat_id):
114
  req = request.json
115
+ if not DialogService.query(tenant_id=tenant_id,id=chat_id,status=StatusEnum.VALID.value):
116
+ return get_error_data_result(f"You don't own the chat {chat_id}")
117
+ if req.get("session_id"):
118
+ if not ConversationService.query(id=req["session_id"],dialog_id=chat_id):
119
+ return get_error_data_result(f"You don't own the session {req['session_id']}")
120
  if req.get("stream", True):
121
  resp = Response(rag_completion(tenant_id, chat_id, **req), mimetype="text/event-stream")
122
  resp.headers.add_header("Cache-control", "no-cache")
 
138
  @token_required
139
  def agent_completions(tenant_id, agent_id):
140
  req = request.json
141
+ if not UserCanvasService.query(user_id=tenant_id,id=agent_id):
142
+ return get_error_data_result(f"You don't own the agent {agent_id}")
143
+ if req.get("session_id"):
144
+ if not API4ConversationService.query(id=req["session_id"],dialog_id=agent_id):
145
+ return get_error_data_result(f"You don't own the session {req['session_id']}")
146
  if req.get("stream", True):
147
  resp = Response(agent_completion(tenant_id, agent_id, **req), mimetype="text/event-stream")
148
  resp.headers.add_header("Cache-control", "no-cache")
api/db/services/canvas_service.py CHANGED
@@ -14,6 +14,7 @@
14
  # limitations under the License.
15
  #
16
  import json
 
17
  from uuid import uuid4
18
  from agent.canvas import Canvas
19
  from api.db.db_models import DB, CanvasTemplate, UserCanvas, API4Conversation
@@ -58,6 +59,8 @@ def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kw
58
  if not isinstance(cvs.dsl, str):
59
  cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
60
  canvas = Canvas(cvs.dsl, tenant_id)
 
 
61
 
62
  if not session_id:
63
  session_id = get_uuid()
@@ -84,40 +87,24 @@ def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kw
84
  return
85
  conv = API4Conversation(**conv)
86
  else:
87
- session_id = session_id
88
  e, conv = API4ConversationService.get_by_id(session_id)
89
  assert e, "Session not found!"
90
  canvas = Canvas(json.dumps(conv.dsl), tenant_id)
91
-
92
- if not conv.message:
93
- conv.message = []
94
- messages = conv.message
95
- question = {
96
- "role": "user",
97
- "content": question,
98
- "id": str(uuid4())
99
- }
100
- messages.append(question)
101
- msg = []
102
- for m in messages:
103
- if m["role"] == "system":
104
- continue
105
- if m["role"] == "assistant" and not msg:
106
- continue
107
- msg.append(m)
108
- if not msg[-1].get("id"):
109
- msg[-1]["id"] = get_uuid()
110
- message_id = msg[-1]["id"]
111
-
112
- if not conv.reference:
113
- conv.reference = []
114
- conv.message.append({"role": "assistant", "content": "", "id": message_id})
115
- conv.reference.append({"chunks": [], "doc_aggs": []})
116
 
117
  final_ans = {"reference": [], "content": ""}
118
 
119
- canvas.add_user_input(msg[-1]["content"])
120
-
121
  if stream:
122
  try:
123
  for ans in canvas.run(stream=stream):
@@ -141,6 +128,7 @@ def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kw
141
  conv.dsl = json.loads(str(canvas))
142
  API4ConversationService.append_message(conv.id, conv.to_dict())
143
  except Exception as e:
 
144
  conv.dsl = json.loads(str(canvas))
145
  API4ConversationService.append_message(conv.id, conv.to_dict())
146
  yield "data:" + json.dumps({"code": 500, "message": str(e),
 
14
  # limitations under the License.
15
  #
16
  import json
17
+ import traceback
18
  from uuid import uuid4
19
  from agent.canvas import Canvas
20
  from api.db.db_models import DB, CanvasTemplate, UserCanvas, API4Conversation
 
59
  if not isinstance(cvs.dsl, str):
60
  cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
61
  canvas = Canvas(cvs.dsl, tenant_id)
62
+ canvas.reset()
63
+ message_id = str(uuid4())
64
 
65
  if not session_id:
66
  session_id = get_uuid()
 
87
  return
88
  conv = API4Conversation(**conv)
89
  else:
 
90
  e, conv = API4ConversationService.get_by_id(session_id)
91
  assert e, "Session not found!"
92
  canvas = Canvas(json.dumps(conv.dsl), tenant_id)
93
+ canvas.messages.append({"role": "user", "content": question, "id": message_id})
94
+ canvas.add_user_input(question)
95
+ if not conv.message:
96
+ conv.message = []
97
+ conv.message.append({
98
+ "role": "user",
99
+ "content": question,
100
+ "id": message_id
101
+ })
102
+ if not conv.reference:
103
+ conv.reference = []
104
+ conv.reference.append({"chunks": [], "doc_aggs": []})
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  final_ans = {"reference": [], "content": ""}
107
 
 
 
108
  if stream:
109
  try:
110
  for ans in canvas.run(stream=stream):
 
128
  conv.dsl = json.loads(str(canvas))
129
  API4ConversationService.append_message(conv.id, conv.to_dict())
130
  except Exception as e:
131
+ traceback.print_exc()
132
  conv.dsl = json.loads(str(canvas))
133
  API4ConversationService.append_message(conv.id, conv.to_dict())
134
  yield "data:" + json.dumps({"code": 500, "message": str(e),
api/db/services/conversation_service.py CHANGED
@@ -21,7 +21,6 @@ from api.db.services.common_service import CommonService
21
  from api.db.services.dialog_service import DialogService, chat
22
  from api.utils import get_uuid
23
  import json
24
- from copy import deepcopy
25
 
26
 
27
  class ConversationService(CommonService):
@@ -49,30 +48,35 @@ def structure_answer(conv, ans, message_id, session_id):
49
  reference = ans["reference"]
50
  if not isinstance(reference, dict):
51
  reference = {}
52
- temp_reference = deepcopy(ans["reference"])
53
- if not conv.reference:
54
- conv.reference.append(temp_reference)
55
- else:
56
- conv.reference[-1] = temp_reference
57
- conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
58
 
 
 
59
  chunk_list = [{
60
- "id": chunk["chunk_id"],
61
- "content": chunk.get("content") if chunk.get("content") else chunk.get("content_with_content"),
62
- "document_id": chunk["doc_id"],
63
- "document_name": chunk["docnm_kwd"],
64
- "dataset_id": chunk["kb_id"],
65
- "image_id": chunk["image_id"],
66
- "similarity": chunk["similarity"],
67
- "vector_similarity": chunk["vector_similarity"],
68
- "term_similarity": chunk["term_similarity"],
69
- "positions": chunk["positions"],
70
  } for chunk in reference.get("chunks", [])]
71
 
72
  reference["chunks"] = chunk_list
73
  ans["id"] = message_id
74
  ans["session_id"] = session_id
75
 
 
 
 
 
 
 
 
 
 
 
 
76
  return ans
77
 
78
 
@@ -199,7 +203,6 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg
199
 
200
  if not conv.reference:
201
  conv.reference = []
202
- conv.message.append({"role": "assistant", "content": "", "id": message_id})
203
  conv.reference.append({"chunks": [], "doc_aggs": []})
204
 
205
  if stream:
 
21
  from api.db.services.dialog_service import DialogService, chat
22
  from api.utils import get_uuid
23
  import json
 
24
 
25
 
26
  class ConversationService(CommonService):
 
48
  reference = ans["reference"]
49
  if not isinstance(reference, dict):
50
  reference = {}
51
+ ans["reference"] = {}
 
 
 
 
 
52
 
53
+ def get_value(d, k1, k2):
54
+ return d.get(k1, d.get(k2))
55
  chunk_list = [{
56
+ "id": get_value(chunk, "chunk_id", "id"),
57
+ "content": get_value(chunk, "content", "content_with_weight"),
58
+ "document_id": get_value(chunk, "doc_id", "document_id"),
59
+ "document_name": get_value(chunk, "docnm_kwd", "document_name"),
60
+ "dataset_id": get_value(chunk, "kb_id", "dataset_id"),
61
+ "image_id": get_value(chunk, "image_id", "img_id"),
62
+ "positions": get_value(chunk, "positions", "position_int"),
 
 
 
63
  } for chunk in reference.get("chunks", [])]
64
 
65
  reference["chunks"] = chunk_list
66
  ans["id"] = message_id
67
  ans["session_id"] = session_id
68
 
69
+ if not conv:
70
+ return ans
71
+
72
+ if not conv.message:
73
+ conv.message = []
74
+ if not conv.message or conv.message[-1].get("role", "") != "assistant":
75
+ conv.message.append({"role": "assistant", "content": ans["answer"], "id": message_id})
76
+ else:
77
+ conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
78
+ if conv.reference:
79
+ conv.reference[-1] = reference
80
  return ans
81
 
82
 
 
203
 
204
  if not conv.reference:
205
  conv.reference = []
 
206
  conv.reference.append({"chunks": [], "doc_aggs": []})
207
 
208
  if stream:
api/db/services/dialog_service.py CHANGED
@@ -18,6 +18,7 @@ import binascii
18
  import os
19
  import json
20
  import re
 
21
  from copy import deepcopy
22
  from timeit import default_timer as timer
23
  import datetime
@@ -108,6 +109,32 @@ def llm_id2llm_type(llm_id):
108
  return llm["model_type"].strip(",")[-1]
109
 
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  def chat(dialog, messages, stream=True, **kwargs):
112
  assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
113
  st = timer()
@@ -195,32 +222,7 @@ def chat(dialog, messages, stream=True, **kwargs):
195
  dialog.vector_similarity_weight,
196
  doc_ids=attachments,
197
  top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
198
-
199
- # Group chunks by document ID
200
- doc_chunks = {}
201
- for ck in kbinfos["chunks"]:
202
- doc_id = ck["doc_id"]
203
- if doc_id not in doc_chunks:
204
- doc_chunks[doc_id] = []
205
- doc_chunks[doc_id].append(ck["content_with_weight"])
206
-
207
- # Create knowledges list with grouped chunks
208
- knowledges = []
209
- for doc_id, chunks in doc_chunks.items():
210
- # Find the corresponding document name
211
- doc_name = next((d["doc_name"] for d in kbinfos.get("doc_aggs", []) if d["doc_id"] == doc_id), doc_id)
212
-
213
- # Create a header for the document
214
- doc_knowledge = f"Document: {doc_name} \nContains the following relevant fragments:\n"
215
-
216
- # Add numbered fragments
217
- for i, chunk in enumerate(chunks, 1):
218
- doc_knowledge += f"{i}. {chunk}\n"
219
-
220
- knowledges.append(doc_knowledge)
221
-
222
-
223
-
224
  logging.debug(
225
  "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
226
  retrieval_tm = timer()
@@ -603,7 +605,6 @@ def tts(tts_mdl, text):
603
 
604
  def ask(question, kb_ids, tenant_id):
605
  kbs = KnowledgebaseService.get_by_ids(kb_ids)
606
- tenant_ids = [kb.tenant_id for kb in kbs]
607
  embd_nms = list(set([kb.embd_id for kb in kbs]))
608
 
609
  is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
@@ -612,45 +613,9 @@ def ask(question, kb_ids, tenant_id):
612
  embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
613
  chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
614
  max_tokens = chat_mdl.max_length
615
-
616
  kbinfos = retr.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
617
- knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
618
-
619
- used_token_count = 0
620
- chunks_num = 0
621
- for i, c in enumerate(knowledges):
622
- used_token_count += num_tokens_from_string(c)
623
- if max_tokens * 0.97 < used_token_count:
624
- knowledges = knowledges[:i]
625
- chunks_num = chunks_num + 1
626
- break
627
-
628
- # Group chunks by document ID
629
- doc_chunks = {}
630
- counter_chunks = 0
631
- for ck in kbinfos["chunks"]:
632
- if counter_chunks < chunks_num:
633
- counter_chunks = counter_chunks + 1
634
- doc_id = ck["doc_id"]
635
- if doc_id not in doc_chunks:
636
- doc_chunks[doc_id] = []
637
- doc_chunks[doc_id].append(ck["content_with_weight"])
638
-
639
- # Create knowledges list with grouped chunks
640
- knowledges = []
641
- for doc_id, chunks in doc_chunks.items():
642
- # Find the corresponding document name
643
- doc_name = next((d["doc_name"] for d in kbinfos.get("doc_aggs", []) if d["doc_id"] == doc_id), doc_id)
644
-
645
- # Create a header for the document
646
- doc_knowledge = f"Document: {doc_name} \nContains the following relevant fragments:\n"
647
-
648
- # Add numbered fragments
649
- for i, chunk in enumerate(chunks, 1):
650
- doc_knowledge += f"{i}. {chunk}\n"
651
-
652
- knowledges.append(doc_knowledge)
653
-
654
  prompt = """
655
  Role: You're a smart assistant. Your name is Miss R.
656
  Task: Summarize the information from knowledge bases and answer user's question.
@@ -660,25 +625,25 @@ def ask(question, kb_ids, tenant_id):
660
  - Answer with markdown format text.
661
  - Answer in language of user's question.
662
  - DO NOT make things up, especially for numbers.
663
-
664
  ### Information from knowledge bases
665
  %s
666
-
667
  The above is information from knowledge bases.
668
-
669
- """%"\n".join(knowledges)
670
  msg = [{"role": "user", "content": question}]
671
 
672
  def decorate_answer(answer):
673
  nonlocal knowledges, kbinfos, prompt
674
  answer, idx = retr.insert_citations(answer,
675
- [ck["content_ltks"]
676
- for ck in kbinfos["chunks"]],
677
- [ck["vector"]
678
- for ck in kbinfos["chunks"]],
679
- embd_mdl,
680
- tkweight=0.7,
681
- vtweight=0.3)
682
  idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
683
  recall_docs = [
684
  d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
@@ -691,7 +656,7 @@ def ask(question, kb_ids, tenant_id):
691
  del c["vector"]
692
 
693
  if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
694
- answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
695
  return {"answer": answer, "reference": refs}
696
 
697
  answer = ""
 
18
  import os
19
  import json
20
  import re
21
+ from collections import defaultdict
22
  from copy import deepcopy
23
  from timeit import default_timer as timer
24
  import datetime
 
109
  return llm["model_type"].strip(",")[-1]
110
 
111
 
112
+ def kb_prompt(kbinfos, max_tokens):
113
+ knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
114
+ used_token_count = 0
115
+ chunks_num = 0
116
+ for i, c in enumerate(knowledges):
117
+ used_token_count += num_tokens_from_string(c)
118
+ chunks_num += 1
119
+ if max_tokens * 0.97 < used_token_count:
120
+ knowledges = knowledges[:i]
121
+ break
122
+
123
+ doc2chunks = defaultdict(list)
124
+ for i, ck in enumerate(kbinfos["chunks"]):
125
+ if i >= chunks_num:
126
+ break
127
+ doc2chunks["docnm_kwd"].append(ck["content_with_weight"])
128
+
129
+ knowledges = []
130
+ for nm, chunks in doc2chunks.items():
131
+ txt = f"Document: {nm} \nContains the following relevant fragments:\n"
132
+ for i, chunk in enumerate(chunks, 1):
133
+ txt += f"{i}. {chunk}\n"
134
+ knowledges.append(txt)
135
+ return knowledges
136
+
137
+
138
  def chat(dialog, messages, stream=True, **kwargs):
139
  assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
140
  st = timer()
 
222
  dialog.vector_similarity_weight,
223
  doc_ids=attachments,
224
  top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
225
+ knowledges = kb_prompt(kbinfos, max_tokens)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  logging.debug(
227
  "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
228
  retrieval_tm = timer()
 
605
 
606
  def ask(question, kb_ids, tenant_id):
607
  kbs = KnowledgebaseService.get_by_ids(kb_ids)
 
608
  embd_nms = list(set([kb.embd_id for kb in kbs]))
609
 
610
  is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
 
613
  embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
614
  chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
615
  max_tokens = chat_mdl.max_length
616
+ tenant_ids = list(set([kb.tenant_id for kb in kbs]))
617
  kbinfos = retr.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
618
+ knowledges = kb_prompt(kbinfos, max_tokens)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
619
  prompt = """
620
  Role: You're a smart assistant. Your name is Miss R.
621
  Task: Summarize the information from knowledge bases and answer user's question.
 
625
  - Answer with markdown format text.
626
  - Answer in language of user's question.
627
  - DO NOT make things up, especially for numbers.
628
+
629
  ### Information from knowledge bases
630
  %s
631
+
632
  The above is information from knowledge bases.
633
+
634
+ """ % "\n".join(knowledges)
635
  msg = [{"role": "user", "content": question}]
636
 
637
  def decorate_answer(answer):
638
  nonlocal knowledges, kbinfos, prompt
639
  answer, idx = retr.insert_citations(answer,
640
+ [ck["content_ltks"]
641
+ for ck in kbinfos["chunks"]],
642
+ [ck["vector"]
643
+ for ck in kbinfos["chunks"]],
644
+ embd_mdl,
645
+ tkweight=0.7,
646
+ vtweight=0.3)
647
  idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
648
  recall_docs = [
649
  d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
 
656
  del c["vector"]
657
 
658
  if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
659
+ answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
660
  return {"answer": answer, "reference": refs}
661
 
662
  answer = ""