Kevin Hu
committed on
Commit
·
1b1a5b7
1
Parent(s):
faaabea
Support iframe chatbot. (#3961)
Browse files
### What problem does this PR solve?
#3909
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- agent/canvas.py +4 -1
- agent/component/base.py +16 -1
- agent/component/generate.py +8 -0
- api/apps/canvas_app.py +20 -0
- api/apps/conversation_app.py +36 -16
- api/apps/sdk/session.py +10 -0
- api/db/services/canvas_service.py +16 -28
- api/db/services/conversation_service.py +21 -18
- api/db/services/dialog_service.py +42 -77
agent/canvas.py
CHANGED
@@ -330,4 +330,7 @@ class Canvas(ABC):
|
|
330 |
q["value"] = v
|
331 |
|
332 |
def get_preset_param(self):
|
333 |
-
return self.components["begin"]["obj"]._param.query
|
|
|
|
|
|
|
|
330 |
q["value"] = v
|
331 |
|
332 |
def get_preset_param(self):
|
333 |
+
return self.components["begin"]["obj"]._param.query
|
334 |
+
|
335 |
+
def get_component_input_elements(self, cpnnm):
|
336 |
+
return self.components["begin"]["obj"].get_input_elements()
|
agent/component/base.py
CHANGED
@@ -476,7 +476,7 @@ class ComponentBase(ABC):
|
|
476 |
self._param.inputs.append({"component_id": q["component_id"],
|
477 |
"content": "\n".join(
|
478 |
[str(d["content"]) for d in outs[-1].to_dict('records')])})
|
479 |
-
elif q
|
480 |
self._param.inputs.append({"component_id": None, "content": q["value"]})
|
481 |
outs.append(pd.DataFrame([{"content": q["value"]}]))
|
482 |
if outs:
|
@@ -526,6 +526,21 @@ class ComponentBase(ABC):
|
|
526 |
|
527 |
return df
|
528 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
529 |
def get_stream_input(self):
|
530 |
reversed_cpnts = []
|
531 |
if len(self._canvas.path) > 1:
|
|
|
476 |
self._param.inputs.append({"component_id": q["component_id"],
|
477 |
"content": "\n".join(
|
478 |
[str(d["content"]) for d in outs[-1].to_dict('records')])})
|
479 |
+
elif q.get("value"):
|
480 |
self._param.inputs.append({"component_id": None, "content": q["value"]})
|
481 |
outs.append(pd.DataFrame([{"content": q["value"]}]))
|
482 |
if outs:
|
|
|
526 |
|
527 |
return df
|
528 |
|
529 |
+
def get_input_elements(self):
|
530 |
+
assert self._param.query, "Please identify input parameters firstly."
|
531 |
+
eles = []
|
532 |
+
for q in self._param.query:
|
533 |
+
if q.get("component_id"):
|
534 |
+
if q["component_id"].split("@")[0].lower().find("begin") >= 0:
|
535 |
+
cpn_id, key = q["component_id"].split("@")
|
536 |
+
eles.extend(self._canvas.get_component(cpn_id)["obj"]._param.query)
|
537 |
+
continue
|
538 |
+
|
539 |
+
eles.append({"key": q["key"], "component_id": q["component_id"]})
|
540 |
+
else:
|
541 |
+
eles.append({"key": q["key"]})
|
542 |
+
return eles
|
543 |
+
|
544 |
def get_stream_input(self):
|
545 |
reversed_cpnts = []
|
546 |
if len(self._canvas.path) > 1:
|
agent/component/generate.py
CHANGED
@@ -17,6 +17,7 @@ import re
|
|
17 |
from functools import partial
|
18 |
import pandas as pd
|
19 |
from api.db import LLMType
|
|
|
20 |
from api.db.services.dialog_service import message_fit_in
|
21 |
from api.db.services.llm_service import LLMBundle
|
22 |
from api import settings
|
@@ -104,9 +105,16 @@ class Generate(ComponentBase):
|
|
104 |
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
|
105 |
answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
|
106 |
res = {"content": answer, "reference": reference}
|
|
|
107 |
|
108 |
return res
|
109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
def _run(self, history, **kwargs):
|
111 |
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
|
112 |
prompt = self._param.prompt
|
|
|
17 |
from functools import partial
|
18 |
import pandas as pd
|
19 |
from api.db import LLMType
|
20 |
+
from api.db.services.conversation_service import structure_answer
|
21 |
from api.db.services.dialog_service import message_fit_in
|
22 |
from api.db.services.llm_service import LLMBundle
|
23 |
from api import settings
|
|
|
105 |
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
|
106 |
answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
|
107 |
res = {"content": answer, "reference": reference}
|
108 |
+
res = structure_answer(None, res, "", "")
|
109 |
|
110 |
return res
|
111 |
|
112 |
+
def get_input_elements(self):
|
113 |
+
if self._param.parameters:
|
114 |
+
return self._param.parameters
|
115 |
+
|
116 |
+
return [{"key": "input"}]
|
117 |
+
|
118 |
def _run(self, history, **kwargs):
|
119 |
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
|
120 |
prompt = self._param.prompt
|
api/apps/canvas_app.py
CHANGED
@@ -186,6 +186,26 @@ def reset():
|
|
186 |
return server_error_response(e)
|
187 |
|
188 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
@manager.route('/test_db_connect', methods=['POST']) # noqa: F821
|
190 |
@validate_request("db_type", "database", "username", "host", "port", "password")
|
191 |
@login_required
|
|
|
186 |
return server_error_response(e)
|
187 |
|
188 |
|
189 |
+
@manager.route('/input_elements', methods=['GET']) # noqa: F821
|
190 |
+
@validate_request("id", "component_id")
|
191 |
+
@login_required
|
192 |
+
def input_elements():
|
193 |
+
req = request.json
|
194 |
+
try:
|
195 |
+
e, user_canvas = UserCanvasService.get_by_id(req["id"])
|
196 |
+
if not e:
|
197 |
+
return get_data_error_result(message="canvas not found.")
|
198 |
+
if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
|
199 |
+
return get_json_result(
|
200 |
+
data=False, message='Only owner of canvas authorized for this operation.',
|
201 |
+
code=RetCode.OPERATING_ERROR)
|
202 |
+
|
203 |
+
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
|
204 |
+
return get_json_result(data=canvas.get_component_input_elements(req["component_id"]))
|
205 |
+
except Exception as e:
|
206 |
+
return server_error_response(e)
|
207 |
+
|
208 |
+
|
209 |
@manager.route('/test_db_connect', methods=['POST']) # noqa: F821
|
210 |
@validate_request("db_type", "database", "username", "host", "port", "password")
|
211 |
@login_required
|
api/apps/conversation_app.py
CHANGED
@@ -18,7 +18,7 @@ import re
|
|
18 |
import traceback
|
19 |
from copy import deepcopy
|
20 |
|
21 |
-
from api.db.services.conversation_service import ConversationService
|
22 |
from api.db.services.user_service import UserTenantService
|
23 |
from flask import request, Response
|
24 |
from flask_login import login_required, current_user
|
@@ -90,6 +90,21 @@ def get():
|
|
90 |
return get_json_result(
|
91 |
data=False, message='Only owner of conversation authorized for this operation.',
|
92 |
code=settings.RetCode.OPERATING_ERROR)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
conv = conv.to_dict()
|
94 |
return get_json_result(data=conv)
|
95 |
except Exception as e:
|
@@ -132,6 +147,7 @@ def list_convsersation():
|
|
132 |
dialog_id=dialog_id,
|
133 |
order_by=ConversationService.model.create_time,
|
134 |
reverse=True)
|
|
|
135 |
convs = [d.to_dict() for d in convs]
|
136 |
return get_json_result(data=convs)
|
137 |
except Exception as e:
|
@@ -164,24 +180,29 @@ def completion():
|
|
164 |
|
165 |
if not conv.reference:
|
166 |
conv.reference = []
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
|
|
|
|
179 |
|
|
|
|
|
|
|
180 |
def stream():
|
181 |
nonlocal dia, msg, req, conv
|
182 |
try:
|
183 |
for ans in chat(dia, msg, True, **req):
|
184 |
-
|
185 |
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
186 |
ConversationService.update_by_id(conv.id, conv.to_dict())
|
187 |
except Exception as e:
|
@@ -202,8 +223,7 @@ def completion():
|
|
202 |
else:
|
203 |
answer = None
|
204 |
for ans in chat(dia, msg, **req):
|
205 |
-
answer = ans
|
206 |
-
fillin_conv(ans)
|
207 |
ConversationService.update_by_id(conv.id, conv.to_dict())
|
208 |
break
|
209 |
return get_json_result(data=answer)
|
|
|
18 |
import traceback
|
19 |
from copy import deepcopy
|
20 |
|
21 |
+
from api.db.services.conversation_service import ConversationService, structure_answer
|
22 |
from api.db.services.user_service import UserTenantService
|
23 |
from flask import request, Response
|
24 |
from flask_login import login_required, current_user
|
|
|
90 |
return get_json_result(
|
91 |
data=False, message='Only owner of conversation authorized for this operation.',
|
92 |
code=settings.RetCode.OPERATING_ERROR)
|
93 |
+
|
94 |
+
def get_value(d, k1, k2):
|
95 |
+
return d.get(k1, d.get(k2))
|
96 |
+
|
97 |
+
for ref in conv.reference:
|
98 |
+
ref["chunks"] = [{
|
99 |
+
"id": get_value(ck, "chunk_id", "id"),
|
100 |
+
"content": get_value(ck, "content", "content_with_weight"),
|
101 |
+
"document_id": get_value(ck, "doc_id", "document_id"),
|
102 |
+
"document_name": get_value(ck, "docnm_kwd", "document_name"),
|
103 |
+
"dataset_id": get_value(ck, "kb_id", "dataset_id"),
|
104 |
+
"image_id": get_value(ck, "image_id", "img_id"),
|
105 |
+
"positions": get_value(ck, "positions", "position_int"),
|
106 |
+
} for ck in ref.get("chunks", [])]
|
107 |
+
|
108 |
conv = conv.to_dict()
|
109 |
return get_json_result(data=conv)
|
110 |
except Exception as e:
|
|
|
147 |
dialog_id=dialog_id,
|
148 |
order_by=ConversationService.model.create_time,
|
149 |
reverse=True)
|
150 |
+
|
151 |
convs = [d.to_dict() for d in convs]
|
152 |
return get_json_result(data=convs)
|
153 |
except Exception as e:
|
|
|
180 |
|
181 |
if not conv.reference:
|
182 |
conv.reference = []
|
183 |
+
else:
|
184 |
+
def get_value(d, k1, k2):
|
185 |
+
return d.get(k1, d.get(k2))
|
186 |
+
|
187 |
+
for ref in conv.reference:
|
188 |
+
ref["chunks"] = [{
|
189 |
+
"id": get_value(ck, "chunk_id", "id"),
|
190 |
+
"content": get_value(ck, "content", "content_with_weight"),
|
191 |
+
"document_id": get_value(ck, "doc_id", "document_id"),
|
192 |
+
"document_name": get_value(ck, "docnm_kwd", "document_name"),
|
193 |
+
"dataset_id": get_value(ck, "kb_id", "dataset_id"),
|
194 |
+
"image_id": get_value(ck, "image_id", "img_id"),
|
195 |
+
"positions": get_value(ck, "positions", "position_int"),
|
196 |
+
} for ck in ref.get("chunks", [])]
|
197 |
|
198 |
+
if not conv.reference:
|
199 |
+
conv.reference = []
|
200 |
+
conv.reference.append({"chunks": [], "doc_aggs": []})
|
201 |
def stream():
|
202 |
nonlocal dia, msg, req, conv
|
203 |
try:
|
204 |
for ans in chat(dia, msg, True, **req):
|
205 |
+
ans = structure_answer(conv, ans, message_id, conv.id)
|
206 |
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
207 |
ConversationService.update_by_id(conv.id, conv.to_dict())
|
208 |
except Exception as e:
|
|
|
223 |
else:
|
224 |
answer = None
|
225 |
for ans in chat(dia, msg, **req):
|
226 |
+
answer = structure_answer(conv, ans, message_id, req["conversation_id"])
|
|
|
227 |
ConversationService.update_by_id(conv.id, conv.to_dict())
|
228 |
break
|
229 |
return get_json_result(data=answer)
|
api/apps/sdk/session.py
CHANGED
@@ -112,6 +112,11 @@ def update(tenant_id, chat_id, session_id):
|
|
112 |
@token_required
|
113 |
def chat_completion(tenant_id, chat_id):
|
114 |
req = request.json
|
|
|
|
|
|
|
|
|
|
|
115 |
if req.get("stream", True):
|
116 |
resp = Response(rag_completion(tenant_id, chat_id, **req), mimetype="text/event-stream")
|
117 |
resp.headers.add_header("Cache-control", "no-cache")
|
@@ -133,6 +138,11 @@ def chat_completion(tenant_id, chat_id):
|
|
133 |
@token_required
|
134 |
def agent_completions(tenant_id, agent_id):
|
135 |
req = request.json
|
|
|
|
|
|
|
|
|
|
|
136 |
if req.get("stream", True):
|
137 |
resp = Response(agent_completion(tenant_id, agent_id, **req), mimetype="text/event-stream")
|
138 |
resp.headers.add_header("Cache-control", "no-cache")
|
|
|
112 |
@token_required
|
113 |
def chat_completion(tenant_id, chat_id):
|
114 |
req = request.json
|
115 |
+
if not DialogService.query(tenant_id=tenant_id,id=chat_id,status=StatusEnum.VALID.value):
|
116 |
+
return get_error_data_result(f"You don't own the chat {chat_id}")
|
117 |
+
if req.get("session_id"):
|
118 |
+
if not ConversationService.query(id=req["session_id"],dialog_id=chat_id):
|
119 |
+
return get_error_data_result(f"You don't own the session {req['session_id']}")
|
120 |
if req.get("stream", True):
|
121 |
resp = Response(rag_completion(tenant_id, chat_id, **req), mimetype="text/event-stream")
|
122 |
resp.headers.add_header("Cache-control", "no-cache")
|
|
|
138 |
@token_required
|
139 |
def agent_completions(tenant_id, agent_id):
|
140 |
req = request.json
|
141 |
+
if not UserCanvasService.query(user_id=tenant_id,id=agent_id):
|
142 |
+
return get_error_data_result(f"You don't own the agent {agent_id}")
|
143 |
+
if req.get("session_id"):
|
144 |
+
if not API4ConversationService.query(id=req["session_id"],dialog_id=agent_id):
|
145 |
+
return get_error_data_result(f"You don't own the session {req['session_id']}")
|
146 |
if req.get("stream", True):
|
147 |
resp = Response(agent_completion(tenant_id, agent_id, **req), mimetype="text/event-stream")
|
148 |
resp.headers.add_header("Cache-control", "no-cache")
|
api/db/services/canvas_service.py
CHANGED
@@ -14,6 +14,7 @@
|
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
import json
|
|
|
17 |
from uuid import uuid4
|
18 |
from agent.canvas import Canvas
|
19 |
from api.db.db_models import DB, CanvasTemplate, UserCanvas, API4Conversation
|
@@ -58,6 +59,8 @@ def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kw
|
|
58 |
if not isinstance(cvs.dsl, str):
|
59 |
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
60 |
canvas = Canvas(cvs.dsl, tenant_id)
|
|
|
|
|
61 |
|
62 |
if not session_id:
|
63 |
session_id = get_uuid()
|
@@ -84,40 +87,24 @@ def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kw
|
|
84 |
return
|
85 |
conv = API4Conversation(**conv)
|
86 |
else:
|
87 |
-
session_id = session_id
|
88 |
e, conv = API4ConversationService.get_by_id(session_id)
|
89 |
assert e, "Session not found!"
|
90 |
canvas = Canvas(json.dumps(conv.dsl), tenant_id)
|
91 |
-
|
92 |
-
|
93 |
-
conv.message
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
if m["role"] == "system":
|
104 |
-
continue
|
105 |
-
if m["role"] == "assistant" and not msg:
|
106 |
-
continue
|
107 |
-
msg.append(m)
|
108 |
-
if not msg[-1].get("id"):
|
109 |
-
msg[-1]["id"] = get_uuid()
|
110 |
-
message_id = msg[-1]["id"]
|
111 |
-
|
112 |
-
if not conv.reference:
|
113 |
-
conv.reference = []
|
114 |
-
conv.message.append({"role": "assistant", "content": "", "id": message_id})
|
115 |
-
conv.reference.append({"chunks": [], "doc_aggs": []})
|
116 |
|
117 |
final_ans = {"reference": [], "content": ""}
|
118 |
|
119 |
-
canvas.add_user_input(msg[-1]["content"])
|
120 |
-
|
121 |
if stream:
|
122 |
try:
|
123 |
for ans in canvas.run(stream=stream):
|
@@ -141,6 +128,7 @@ def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kw
|
|
141 |
conv.dsl = json.loads(str(canvas))
|
142 |
API4ConversationService.append_message(conv.id, conv.to_dict())
|
143 |
except Exception as e:
|
|
|
144 |
conv.dsl = json.loads(str(canvas))
|
145 |
API4ConversationService.append_message(conv.id, conv.to_dict())
|
146 |
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
|
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
import json
|
17 |
+
import traceback
|
18 |
from uuid import uuid4
|
19 |
from agent.canvas import Canvas
|
20 |
from api.db.db_models import DB, CanvasTemplate, UserCanvas, API4Conversation
|
|
|
59 |
if not isinstance(cvs.dsl, str):
|
60 |
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
61 |
canvas = Canvas(cvs.dsl, tenant_id)
|
62 |
+
canvas.reset()
|
63 |
+
message_id = str(uuid4())
|
64 |
|
65 |
if not session_id:
|
66 |
session_id = get_uuid()
|
|
|
87 |
return
|
88 |
conv = API4Conversation(**conv)
|
89 |
else:
|
|
|
90 |
e, conv = API4ConversationService.get_by_id(session_id)
|
91 |
assert e, "Session not found!"
|
92 |
canvas = Canvas(json.dumps(conv.dsl), tenant_id)
|
93 |
+
canvas.messages.append({"role": "user", "content": question, "id": message_id})
|
94 |
+
canvas.add_user_input(question)
|
95 |
+
if not conv.message:
|
96 |
+
conv.message = []
|
97 |
+
conv.message.append({
|
98 |
+
"role": "user",
|
99 |
+
"content": question,
|
100 |
+
"id": message_id
|
101 |
+
})
|
102 |
+
if not conv.reference:
|
103 |
+
conv.reference = []
|
104 |
+
conv.reference.append({"chunks": [], "doc_aggs": []})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
|
106 |
final_ans = {"reference": [], "content": ""}
|
107 |
|
|
|
|
|
108 |
if stream:
|
109 |
try:
|
110 |
for ans in canvas.run(stream=stream):
|
|
|
128 |
conv.dsl = json.loads(str(canvas))
|
129 |
API4ConversationService.append_message(conv.id, conv.to_dict())
|
130 |
except Exception as e:
|
131 |
+
traceback.print_exc()
|
132 |
conv.dsl = json.loads(str(canvas))
|
133 |
API4ConversationService.append_message(conv.id, conv.to_dict())
|
134 |
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
api/db/services/conversation_service.py
CHANGED
@@ -21,7 +21,6 @@ from api.db.services.common_service import CommonService
|
|
21 |
from api.db.services.dialog_service import DialogService, chat
|
22 |
from api.utils import get_uuid
|
23 |
import json
|
24 |
-
from copy import deepcopy
|
25 |
|
26 |
|
27 |
class ConversationService(CommonService):
|
@@ -49,30 +48,35 @@ def structure_answer(conv, ans, message_id, session_id):
|
|
49 |
reference = ans["reference"]
|
50 |
if not isinstance(reference, dict):
|
51 |
reference = {}
|
52 |
-
|
53 |
-
if not conv.reference:
|
54 |
-
conv.reference.append(temp_reference)
|
55 |
-
else:
|
56 |
-
conv.reference[-1] = temp_reference
|
57 |
-
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
|
58 |
|
|
|
|
|
59 |
chunk_list = [{
|
60 |
-
"id": chunk
|
61 |
-
"content": chunk
|
62 |
-
"document_id": chunk
|
63 |
-
"document_name": chunk
|
64 |
-
"dataset_id": chunk
|
65 |
-
"image_id": chunk
|
66 |
-
"
|
67 |
-
"vector_similarity": chunk["vector_similarity"],
|
68 |
-
"term_similarity": chunk["term_similarity"],
|
69 |
-
"positions": chunk["positions"],
|
70 |
} for chunk in reference.get("chunks", [])]
|
71 |
|
72 |
reference["chunks"] = chunk_list
|
73 |
ans["id"] = message_id
|
74 |
ans["session_id"] = session_id
|
75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
return ans
|
77 |
|
78 |
|
@@ -199,7 +203,6 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg
|
|
199 |
|
200 |
if not conv.reference:
|
201 |
conv.reference = []
|
202 |
-
conv.message.append({"role": "assistant", "content": "", "id": message_id})
|
203 |
conv.reference.append({"chunks": [], "doc_aggs": []})
|
204 |
|
205 |
if stream:
|
|
|
21 |
from api.db.services.dialog_service import DialogService, chat
|
22 |
from api.utils import get_uuid
|
23 |
import json
|
|
|
24 |
|
25 |
|
26 |
class ConversationService(CommonService):
|
|
|
48 |
reference = ans["reference"]
|
49 |
if not isinstance(reference, dict):
|
50 |
reference = {}
|
51 |
+
ans["reference"] = {}
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
+
def get_value(d, k1, k2):
|
54 |
+
return d.get(k1, d.get(k2))
|
55 |
chunk_list = [{
|
56 |
+
"id": get_value(chunk, "chunk_id", "id"),
|
57 |
+
"content": get_value(chunk, "content", "content_with_weight"),
|
58 |
+
"document_id": get_value(chunk, "doc_id", "document_id"),
|
59 |
+
"document_name": get_value(chunk, "docnm_kwd", "document_name"),
|
60 |
+
"dataset_id": get_value(chunk, "kb_id", "dataset_id"),
|
61 |
+
"image_id": get_value(chunk, "image_id", "img_id"),
|
62 |
+
"positions": get_value(chunk, "positions", "position_int"),
|
|
|
|
|
|
|
63 |
} for chunk in reference.get("chunks", [])]
|
64 |
|
65 |
reference["chunks"] = chunk_list
|
66 |
ans["id"] = message_id
|
67 |
ans["session_id"] = session_id
|
68 |
|
69 |
+
if not conv:
|
70 |
+
return ans
|
71 |
+
|
72 |
+
if not conv.message:
|
73 |
+
conv.message = []
|
74 |
+
if not conv.message or conv.message[-1].get("role", "") != "assistant":
|
75 |
+
conv.message.append({"role": "assistant", "content": ans["answer"], "id": message_id})
|
76 |
+
else:
|
77 |
+
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
|
78 |
+
if conv.reference:
|
79 |
+
conv.reference[-1] = reference
|
80 |
return ans
|
81 |
|
82 |
|
|
|
203 |
|
204 |
if not conv.reference:
|
205 |
conv.reference = []
|
|
|
206 |
conv.reference.append({"chunks": [], "doc_aggs": []})
|
207 |
|
208 |
if stream:
|
api/db/services/dialog_service.py
CHANGED
@@ -18,6 +18,7 @@ import binascii
|
|
18 |
import os
|
19 |
import json
|
20 |
import re
|
|
|
21 |
from copy import deepcopy
|
22 |
from timeit import default_timer as timer
|
23 |
import datetime
|
@@ -108,6 +109,32 @@ def llm_id2llm_type(llm_id):
|
|
108 |
return llm["model_type"].strip(",")[-1]
|
109 |
|
110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
def chat(dialog, messages, stream=True, **kwargs):
|
112 |
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
113 |
st = timer()
|
@@ -195,32 +222,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
195 |
dialog.vector_similarity_weight,
|
196 |
doc_ids=attachments,
|
197 |
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
|
198 |
-
|
199 |
-
# Group chunks by document ID
|
200 |
-
doc_chunks = {}
|
201 |
-
for ck in kbinfos["chunks"]:
|
202 |
-
doc_id = ck["doc_id"]
|
203 |
-
if doc_id not in doc_chunks:
|
204 |
-
doc_chunks[doc_id] = []
|
205 |
-
doc_chunks[doc_id].append(ck["content_with_weight"])
|
206 |
-
|
207 |
-
# Create knowledges list with grouped chunks
|
208 |
-
knowledges = []
|
209 |
-
for doc_id, chunks in doc_chunks.items():
|
210 |
-
# Find the corresponding document name
|
211 |
-
doc_name = next((d["doc_name"] for d in kbinfos.get("doc_aggs", []) if d["doc_id"] == doc_id), doc_id)
|
212 |
-
|
213 |
-
# Create a header for the document
|
214 |
-
doc_knowledge = f"Document: {doc_name} \nContains the following relevant fragments:\n"
|
215 |
-
|
216 |
-
# Add numbered fragments
|
217 |
-
for i, chunk in enumerate(chunks, 1):
|
218 |
-
doc_knowledge += f"{i}. {chunk}\n"
|
219 |
-
|
220 |
-
knowledges.append(doc_knowledge)
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
logging.debug(
|
225 |
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
226 |
retrieval_tm = timer()
|
@@ -603,7 +605,6 @@ def tts(tts_mdl, text):
|
|
603 |
|
604 |
def ask(question, kb_ids, tenant_id):
|
605 |
kbs = KnowledgebaseService.get_by_ids(kb_ids)
|
606 |
-
tenant_ids = [kb.tenant_id for kb in kbs]
|
607 |
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
608 |
|
609 |
is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
|
@@ -612,45 +613,9 @@ def ask(question, kb_ids, tenant_id):
|
|
612 |
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
|
613 |
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
|
614 |
max_tokens = chat_mdl.max_length
|
615 |
-
|
616 |
kbinfos = retr.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
|
617 |
-
knowledges =
|
618 |
-
|
619 |
-
used_token_count = 0
|
620 |
-
chunks_num = 0
|
621 |
-
for i, c in enumerate(knowledges):
|
622 |
-
used_token_count += num_tokens_from_string(c)
|
623 |
-
if max_tokens * 0.97 < used_token_count:
|
624 |
-
knowledges = knowledges[:i]
|
625 |
-
chunks_num = chunks_num + 1
|
626 |
-
break
|
627 |
-
|
628 |
-
# Group chunks by document ID
|
629 |
-
doc_chunks = {}
|
630 |
-
counter_chunks = 0
|
631 |
-
for ck in kbinfos["chunks"]:
|
632 |
-
if counter_chunks < chunks_num:
|
633 |
-
counter_chunks = counter_chunks + 1
|
634 |
-
doc_id = ck["doc_id"]
|
635 |
-
if doc_id not in doc_chunks:
|
636 |
-
doc_chunks[doc_id] = []
|
637 |
-
doc_chunks[doc_id].append(ck["content_with_weight"])
|
638 |
-
|
639 |
-
# Create knowledges list with grouped chunks
|
640 |
-
knowledges = []
|
641 |
-
for doc_id, chunks in doc_chunks.items():
|
642 |
-
# Find the corresponding document name
|
643 |
-
doc_name = next((d["doc_name"] for d in kbinfos.get("doc_aggs", []) if d["doc_id"] == doc_id), doc_id)
|
644 |
-
|
645 |
-
# Create a header for the document
|
646 |
-
doc_knowledge = f"Document: {doc_name} \nContains the following relevant fragments:\n"
|
647 |
-
|
648 |
-
# Add numbered fragments
|
649 |
-
for i, chunk in enumerate(chunks, 1):
|
650 |
-
doc_knowledge += f"{i}. {chunk}\n"
|
651 |
-
|
652 |
-
knowledges.append(doc_knowledge)
|
653 |
-
|
654 |
prompt = """
|
655 |
Role: You're a smart assistant. Your name is Miss R.
|
656 |
Task: Summarize the information from knowledge bases and answer user's question.
|
@@ -660,25 +625,25 @@ def ask(question, kb_ids, tenant_id):
|
|
660 |
- Answer with markdown format text.
|
661 |
- Answer in language of user's question.
|
662 |
- DO NOT make things up, especially for numbers.
|
663 |
-
|
664 |
### Information from knowledge bases
|
665 |
%s
|
666 |
-
|
667 |
The above is information from knowledge bases.
|
668 |
-
|
669 |
-
"""%"\n".join(knowledges)
|
670 |
msg = [{"role": "user", "content": question}]
|
671 |
|
672 |
def decorate_answer(answer):
|
673 |
nonlocal knowledges, kbinfos, prompt
|
674 |
answer, idx = retr.insert_citations(answer,
|
675 |
-
|
676 |
-
|
677 |
-
|
678 |
-
|
679 |
-
|
680 |
-
|
681 |
-
|
682 |
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
683 |
recall_docs = [
|
684 |
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
@@ -691,7 +656,7 @@ def ask(question, kb_ids, tenant_id):
|
|
691 |
del c["vector"]
|
692 |
|
693 |
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
|
694 |
-
answer += " Please set LLM API-Key in 'User Setting -> Model
|
695 |
return {"answer": answer, "reference": refs}
|
696 |
|
697 |
answer = ""
|
|
|
18 |
import os
|
19 |
import json
|
20 |
import re
|
21 |
+
from collections import defaultdict
|
22 |
from copy import deepcopy
|
23 |
from timeit import default_timer as timer
|
24 |
import datetime
|
|
|
109 |
return llm["model_type"].strip(",")[-1]
|
110 |
|
111 |
|
112 |
+
def kb_prompt(kbinfos, max_tokens):
|
113 |
+
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
114 |
+
used_token_count = 0
|
115 |
+
chunks_num = 0
|
116 |
+
for i, c in enumerate(knowledges):
|
117 |
+
used_token_count += num_tokens_from_string(c)
|
118 |
+
chunks_num += 1
|
119 |
+
if max_tokens * 0.97 < used_token_count:
|
120 |
+
knowledges = knowledges[:i]
|
121 |
+
break
|
122 |
+
|
123 |
+
doc2chunks = defaultdict(list)
|
124 |
+
for i, ck in enumerate(kbinfos["chunks"]):
|
125 |
+
if i >= chunks_num:
|
126 |
+
break
|
127 |
+
doc2chunks["docnm_kwd"].append(ck["content_with_weight"])
|
128 |
+
|
129 |
+
knowledges = []
|
130 |
+
for nm, chunks in doc2chunks.items():
|
131 |
+
txt = f"Document: {nm} \nContains the following relevant fragments:\n"
|
132 |
+
for i, chunk in enumerate(chunks, 1):
|
133 |
+
txt += f"{i}. {chunk}\n"
|
134 |
+
knowledges.append(txt)
|
135 |
+
return knowledges
|
136 |
+
|
137 |
+
|
138 |
def chat(dialog, messages, stream=True, **kwargs):
|
139 |
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
140 |
st = timer()
|
|
|
222 |
dialog.vector_similarity_weight,
|
223 |
doc_ids=attachments,
|
224 |
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
|
225 |
+
knowledges = kb_prompt(kbinfos, max_tokens)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
226 |
logging.debug(
|
227 |
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
228 |
retrieval_tm = timer()
|
|
|
605 |
|
606 |
def ask(question, kb_ids, tenant_id):
|
607 |
kbs = KnowledgebaseService.get_by_ids(kb_ids)
|
|
|
608 |
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
609 |
|
610 |
is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
|
|
|
613 |
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
|
614 |
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
|
615 |
max_tokens = chat_mdl.max_length
|
616 |
+
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
617 |
kbinfos = retr.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
|
618 |
+
knowledges = kb_prompt(kbinfos, max_tokens)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
619 |
prompt = """
|
620 |
Role: You're a smart assistant. Your name is Miss R.
|
621 |
Task: Summarize the information from knowledge bases and answer user's question.
|
|
|
625 |
- Answer with markdown format text.
|
626 |
- Answer in language of user's question.
|
627 |
- DO NOT make things up, especially for numbers.
|
628 |
+
|
629 |
### Information from knowledge bases
|
630 |
%s
|
631 |
+
|
632 |
The above is information from knowledge bases.
|
633 |
+
|
634 |
+
""" % "\n".join(knowledges)
|
635 |
msg = [{"role": "user", "content": question}]
|
636 |
|
637 |
def decorate_answer(answer):
|
638 |
nonlocal knowledges, kbinfos, prompt
|
639 |
answer, idx = retr.insert_citations(answer,
|
640 |
+
[ck["content_ltks"]
|
641 |
+
for ck in kbinfos["chunks"]],
|
642 |
+
[ck["vector"]
|
643 |
+
for ck in kbinfos["chunks"]],
|
644 |
+
embd_mdl,
|
645 |
+
tkweight=0.7,
|
646 |
+
vtweight=0.3)
|
647 |
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
648 |
recall_docs = [
|
649 |
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
|
|
656 |
del c["vector"]
|
657 |
|
658 |
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
|
659 |
+
answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
|
660 |
return {"answer": answer, "reference": refs}
|
661 |
|
662 |
answer = ""
|