KevinHuSh
commited on
Commit
·
028fe40
1
Parent(s):
df17cda
add stream chat (#811)
Browse files### What problem does this PR solve?
#709
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- api/apps/api_app.py +45 -44
- api/apps/conversation_app.py +43 -7
- api/apps/system_app.py +67 -0
- api/db/services/dialog_service.py +51 -32
- api/db/services/document_service.py +2 -2
- api/db/services/llm_service.py +11 -1
- api/utils/api_utils.py +0 -4
- rag/llm/chat_model.py +107 -1
- rag/svr/task_executor.py +12 -13
- rag/utils/es_conn.py +3 -0
- rag/utils/minio_conn.py +10 -0
- rag/utils/redis_conn.py +4 -0
api/apps/api_app.py
CHANGED
@@ -13,10 +13,11 @@
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
|
|
16 |
import os
|
17 |
import re
|
18 |
from datetime import datetime, timedelta
|
19 |
-
from flask import request
|
20 |
from flask_login import login_required, current_user
|
21 |
|
22 |
from api.db import FileType, ParserType
|
@@ -31,11 +32,11 @@ from api.settings import RetCode
|
|
31 |
from api.utils import get_uuid, current_timestamp, datetime_format
|
32 |
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request
|
33 |
from itsdangerous import URLSafeTimedSerializer
|
34 |
-
|
35 |
from api.utils.file_utils import filename_type, thumbnail
|
36 |
from rag.utils.minio_conn import MINIO
|
37 |
-
|
38 |
-
|
39 |
def generate_confirmation_token(tenent_id):
|
40 |
serializer = URLSafeTimedSerializer(tenent_id)
|
41 |
return "ragflow-" + serializer.dumps(get_uuid(), salt=tenent_id)[2:34]
|
@@ -164,6 +165,7 @@ def completion():
|
|
164 |
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
|
165 |
if not e:
|
166 |
return get_data_error_result(retmsg="Conversation not found!")
|
|
|
167 |
|
168 |
msg = []
|
169 |
for m in req["messages"]:
|
@@ -180,13 +182,45 @@ def completion():
|
|
180 |
return get_data_error_result(retmsg="Dialog not found!")
|
181 |
del req["conversation_id"]
|
182 |
del req["messages"]
|
183 |
-
|
184 |
if not conv.reference:
|
185 |
conv.reference = []
|
186 |
-
conv.
|
187 |
-
conv.
|
188 |
-
|
189 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
190 |
except Exception as e:
|
191 |
return server_error_response(e)
|
192 |
|
@@ -229,7 +263,6 @@ def upload():
|
|
229 |
return get_json_result(
|
230 |
data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR)
|
231 |
|
232 |
-
|
233 |
file = request.files['file']
|
234 |
if file.filename == '':
|
235 |
return get_json_result(
|
@@ -253,7 +286,6 @@ def upload():
|
|
253 |
location += "_"
|
254 |
blob = request.files['file'].read()
|
255 |
MINIO.put(kb_id, location, blob)
|
256 |
-
|
257 |
doc = {
|
258 |
"id": get_uuid(),
|
259 |
"kb_id": kb.id,
|
@@ -266,42 +298,11 @@ def upload():
|
|
266 |
"size": len(blob),
|
267 |
"thumbnail": thumbnail(filename, blob)
|
268 |
}
|
269 |
-
|
270 |
-
form_data=request.form
|
271 |
-
if "parser_id" in form_data.keys():
|
272 |
-
if request.form.get("parser_id").strip() in list(vars(ParserType).values())[1:-3]:
|
273 |
-
doc["parser_id"] = request.form.get("parser_id").strip()
|
274 |
if doc["type"] == FileType.VISUAL:
|
275 |
doc["parser_id"] = ParserType.PICTURE.value
|
276 |
if re.search(r"\.(ppt|pptx|pages)$", filename):
|
277 |
doc["parser_id"] = ParserType.PRESENTATION.value
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
except Exception as e:
|
282 |
return server_error_response(e)
|
283 |
-
|
284 |
-
if "run" in form_data.keys():
|
285 |
-
if request.form.get("run").strip() == "1":
|
286 |
-
try:
|
287 |
-
info = {"run": 1, "progress": 0}
|
288 |
-
info["progress_msg"] = ""
|
289 |
-
info["chunk_num"] = 0
|
290 |
-
info["token_num"] = 0
|
291 |
-
DocumentService.update_by_id(doc["id"], info)
|
292 |
-
# if str(req["run"]) == TaskStatus.CANCEL.value:
|
293 |
-
tenant_id = DocumentService.get_tenant_id(doc["id"])
|
294 |
-
if not tenant_id:
|
295 |
-
return get_data_error_result(retmsg="Tenant not found!")
|
296 |
-
|
297 |
-
#e, doc = DocumentService.get_by_id(doc["id"])
|
298 |
-
TaskService.filter_delete([Task.doc_id == doc["id"]])
|
299 |
-
e, doc = DocumentService.get_by_id(doc["id"])
|
300 |
-
doc = doc.to_dict()
|
301 |
-
doc["tenant_id"] = tenant_id
|
302 |
-
bucket, name = File2DocumentService.get_minio_address(doc_id=doc["id"])
|
303 |
-
queue_tasks(doc, bucket, name)
|
304 |
-
except Exception as e:
|
305 |
-
return server_error_response(e)
|
306 |
-
|
307 |
-
return get_json_result(data=doc_result.to_json())
|
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
+
import json
|
17 |
import os
|
18 |
import re
|
19 |
from datetime import datetime, timedelta
|
20 |
+
from flask import request, Response
|
21 |
from flask_login import login_required, current_user
|
22 |
|
23 |
from api.db import FileType, ParserType
|
|
|
32 |
from api.utils import get_uuid, current_timestamp, datetime_format
|
33 |
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request
|
34 |
from itsdangerous import URLSafeTimedSerializer
|
35 |
+
|
36 |
from api.utils.file_utils import filename_type, thumbnail
|
37 |
from rag.utils.minio_conn import MINIO
|
38 |
+
|
39 |
+
|
40 |
def generate_confirmation_token(tenent_id):
|
41 |
serializer = URLSafeTimedSerializer(tenent_id)
|
42 |
return "ragflow-" + serializer.dumps(get_uuid(), salt=tenent_id)[2:34]
|
|
|
165 |
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
|
166 |
if not e:
|
167 |
return get_data_error_result(retmsg="Conversation not found!")
|
168 |
+
if "quote" not in req: req["quote"] = False
|
169 |
|
170 |
msg = []
|
171 |
for m in req["messages"]:
|
|
|
182 |
return get_data_error_result(retmsg="Dialog not found!")
|
183 |
del req["conversation_id"]
|
184 |
del req["messages"]
|
185 |
+
|
186 |
if not conv.reference:
|
187 |
conv.reference = []
|
188 |
+
conv.message.append({"role": "assistant", "content": ""})
|
189 |
+
conv.reference.append({"chunks": [], "doc_aggs": []})
|
190 |
+
|
191 |
+
def fillin_conv(ans):
|
192 |
+
nonlocal conv
|
193 |
+
if not conv.reference:
|
194 |
+
conv.reference.append(ans["reference"])
|
195 |
+
else: conv.reference[-1] = ans["reference"]
|
196 |
+
conv.message[-1] = {"role": "assistant", "content": ans["answer"]}
|
197 |
+
|
198 |
+
def stream():
|
199 |
+
nonlocal dia, msg, req, conv
|
200 |
+
try:
|
201 |
+
for ans in chat(dia, msg, True, **req):
|
202 |
+
fillin_conv(ans)
|
203 |
+
yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
204 |
+
API4ConversationService.append_message(conv.id, conv.to_dict())
|
205 |
+
except Exception as e:
|
206 |
+
yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
|
207 |
+
"data": {"answer": "**ERROR**: "+str(e), "reference": []}},
|
208 |
+
ensure_ascii=False) + "\n\n"
|
209 |
+
yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
|
210 |
+
|
211 |
+
if req.get("stream", True):
|
212 |
+
resp = Response(stream(), mimetype="text/event-stream")
|
213 |
+
resp.headers.add_header("Cache-control", "no-cache")
|
214 |
+
resp.headers.add_header("Connection", "keep-alive")
|
215 |
+
resp.headers.add_header("X-Accel-Buffering", "no")
|
216 |
+
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
217 |
+
return resp
|
218 |
+
else:
|
219 |
+
ans = chat(dia, msg, False, **req)
|
220 |
+
fillin_conv(ans)
|
221 |
+
API4ConversationService.append_message(conv.id, conv.to_dict())
|
222 |
+
return get_json_result(data=ans)
|
223 |
+
|
224 |
except Exception as e:
|
225 |
return server_error_response(e)
|
226 |
|
|
|
263 |
return get_json_result(
|
264 |
data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR)
|
265 |
|
|
|
266 |
file = request.files['file']
|
267 |
if file.filename == '':
|
268 |
return get_json_result(
|
|
|
286 |
location += "_"
|
287 |
blob = request.files['file'].read()
|
288 |
MINIO.put(kb_id, location, blob)
|
|
|
289 |
doc = {
|
290 |
"id": get_uuid(),
|
291 |
"kb_id": kb.id,
|
|
|
298 |
"size": len(blob),
|
299 |
"thumbnail": thumbnail(filename, blob)
|
300 |
}
|
|
|
|
|
|
|
|
|
|
|
301 |
if doc["type"] == FileType.VISUAL:
|
302 |
doc["parser_id"] = ParserType.PICTURE.value
|
303 |
if re.search(r"\.(ppt|pptx|pages)$", filename):
|
304 |
doc["parser_id"] = ParserType.PRESENTATION.value
|
305 |
+
doc = DocumentService.insert(doc)
|
306 |
+
return get_json_result(data=doc.to_json())
|
|
|
307 |
except Exception as e:
|
308 |
return server_error_response(e)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
api/apps/conversation_app.py
CHANGED
@@ -13,12 +13,13 @@
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
-
from flask import request
|
17 |
from flask_login import login_required
|
18 |
from api.db.services.dialog_service import DialogService, ConversationService, chat
|
19 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
20 |
from api.utils import get_uuid
|
21 |
from api.utils.api_utils import get_json_result
|
|
|
22 |
|
23 |
|
24 |
@manager.route('/set', methods=['POST'])
|
@@ -103,9 +104,12 @@ def list_convsersation():
|
|
103 |
|
104 |
@manager.route('/completion', methods=['POST'])
|
105 |
@login_required
|
106 |
-
|
107 |
def completion():
|
108 |
req = request.json
|
|
|
|
|
|
|
109 |
msg = []
|
110 |
for m in req["messages"]:
|
111 |
if m["role"] == "system":
|
@@ -123,13 +127,45 @@ def completion():
|
|
123 |
return get_data_error_result(retmsg="Dialog not found!")
|
124 |
del req["conversation_id"]
|
125 |
del req["messages"]
|
126 |
-
|
127 |
if not conv.reference:
|
128 |
conv.reference = []
|
129 |
-
conv.
|
130 |
-
conv.
|
131 |
-
|
132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
except Exception as e:
|
134 |
return server_error_response(e)
|
135 |
|
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
+
from flask import request, Response, jsonify
|
17 |
from flask_login import login_required
|
18 |
from api.db.services.dialog_service import DialogService, ConversationService, chat
|
19 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
20 |
from api.utils import get_uuid
|
21 |
from api.utils.api_utils import get_json_result
|
22 |
+
import json
|
23 |
|
24 |
|
25 |
@manager.route('/set', methods=['POST'])
|
|
|
104 |
|
105 |
@manager.route('/completion', methods=['POST'])
|
106 |
@login_required
|
107 |
+
#@validate_request("conversation_id", "messages")
|
108 |
def completion():
|
109 |
req = request.json
|
110 |
+
#req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
|
111 |
+
# {"role": "user", "content": "上海有吗?"}
|
112 |
+
#]}
|
113 |
msg = []
|
114 |
for m in req["messages"]:
|
115 |
if m["role"] == "system":
|
|
|
127 |
return get_data_error_result(retmsg="Dialog not found!")
|
128 |
del req["conversation_id"]
|
129 |
del req["messages"]
|
130 |
+
|
131 |
if not conv.reference:
|
132 |
conv.reference = []
|
133 |
+
conv.message.append({"role": "assistant", "content": ""})
|
134 |
+
conv.reference.append({"chunks": [], "doc_aggs": []})
|
135 |
+
|
136 |
+
def fillin_conv(ans):
|
137 |
+
nonlocal conv
|
138 |
+
if not conv.reference:
|
139 |
+
conv.reference.append(ans["reference"])
|
140 |
+
else: conv.reference[-1] = ans["reference"]
|
141 |
+
conv.message[-1] = {"role": "assistant", "content": ans["answer"]}
|
142 |
+
|
143 |
+
def stream():
|
144 |
+
nonlocal dia, msg, req, conv
|
145 |
+
try:
|
146 |
+
for ans in chat(dia, msg, True, **req):
|
147 |
+
fillin_conv(ans)
|
148 |
+
yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
149 |
+
ConversationService.update_by_id(conv.id, conv.to_dict())
|
150 |
+
except Exception as e:
|
151 |
+
yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
|
152 |
+
"data": {"answer": "**ERROR**: "+str(e), "reference": []}},
|
153 |
+
ensure_ascii=False) + "\n\n"
|
154 |
+
yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
|
155 |
+
|
156 |
+
if req.get("stream", True):
|
157 |
+
resp = Response(stream(), mimetype="text/event-stream")
|
158 |
+
resp.headers.add_header("Cache-control", "no-cache")
|
159 |
+
resp.headers.add_header("Connection", "keep-alive")
|
160 |
+
resp.headers.add_header("X-Accel-Buffering", "no")
|
161 |
+
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
162 |
+
return resp
|
163 |
+
|
164 |
+
else:
|
165 |
+
ans = chat(dia, msg, False, **req)
|
166 |
+
fillin_conv(ans)
|
167 |
+
ConversationService.update_by_id(conv.id, conv.to_dict())
|
168 |
+
return get_json_result(data=ans)
|
169 |
except Exception as e:
|
170 |
return server_error_response(e)
|
171 |
|
api/apps/system_app.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License
|
15 |
+
#
|
16 |
+
from flask_login import login_required
|
17 |
+
|
18 |
+
from api.db.services.knowledgebase_service import KnowledgebaseService
|
19 |
+
from api.utils.api_utils import get_json_result
|
20 |
+
from api.versions import get_rag_version
|
21 |
+
from rag.settings import SVR_QUEUE_NAME
|
22 |
+
from rag.utils.es_conn import ELASTICSEARCH
|
23 |
+
from rag.utils.minio_conn import MINIO
|
24 |
+
from timeit import default_timer as timer
|
25 |
+
|
26 |
+
from rag.utils.redis_conn import REDIS_CONN
|
27 |
+
|
28 |
+
|
29 |
+
@manager.route('/version', methods=['GET'])
|
30 |
+
@login_required
|
31 |
+
def version():
|
32 |
+
return get_json_result(data=get_rag_version())
|
33 |
+
|
34 |
+
|
35 |
+
@manager.route('/status', methods=['GET'])
|
36 |
+
@login_required
|
37 |
+
def status():
|
38 |
+
res = {}
|
39 |
+
st = timer()
|
40 |
+
try:
|
41 |
+
res["es"] = ELASTICSEARCH.health()
|
42 |
+
res["es"]["elapsed"] = "{:.1f}".format((timer() - st)*1000.)
|
43 |
+
except Exception as e:
|
44 |
+
res["es"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)}
|
45 |
+
|
46 |
+
st = timer()
|
47 |
+
try:
|
48 |
+
MINIO.health()
|
49 |
+
res["minio"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.)}
|
50 |
+
except Exception as e:
|
51 |
+
res["minio"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)}
|
52 |
+
|
53 |
+
st = timer()
|
54 |
+
try:
|
55 |
+
KnowledgebaseService.get_by_id("x")
|
56 |
+
res["mysql"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.)}
|
57 |
+
except Exception as e:
|
58 |
+
res["mysql"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)}
|
59 |
+
|
60 |
+
st = timer()
|
61 |
+
try:
|
62 |
+
qinfo = REDIS_CONN.health(SVR_QUEUE_NAME)
|
63 |
+
res["redis"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.), "pending": qinfo["pending"]}
|
64 |
+
except Exception as e:
|
65 |
+
res["redis"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)}
|
66 |
+
|
67 |
+
return get_json_result(data=res)
|
api/db/services/dialog_service.py
CHANGED
@@ -14,6 +14,7 @@
|
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
import re
|
|
|
17 |
|
18 |
from api.db import LLMType
|
19 |
from api.db.db_models import Dialog, Conversation
|
@@ -71,7 +72,7 @@ def message_fit_in(msg, max_length=4000):
|
|
71 |
return max_length, msg
|
72 |
|
73 |
|
74 |
-
def chat(dialog, messages, **kwargs):
|
75 |
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
76 |
llm = LLMService.query(llm_name=dialog.llm_id)
|
77 |
if not llm:
|
@@ -82,7 +83,10 @@ def chat(dialog, messages, **kwargs):
|
|
82 |
else: max_tokens = llm[0].max_tokens
|
83 |
kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
|
84 |
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
85 |
-
|
|
|
|
|
|
|
86 |
|
87 |
questions = [m["content"] for m in messages if m["role"] == "user"]
|
88 |
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
|
@@ -94,7 +98,9 @@ def chat(dialog, messages, **kwargs):
|
|
94 |
if field_map:
|
95 |
chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
|
96 |
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
|
97 |
-
if ans:
|
|
|
|
|
98 |
|
99 |
for p in prompt_config["parameters"]:
|
100 |
if p["key"] == "knowledge":
|
@@ -118,8 +124,9 @@ def chat(dialog, messages, **kwargs):
|
|
118 |
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
119 |
|
120 |
if not knowledges and prompt_config.get("empty_response"):
|
121 |
-
|
122 |
-
"answer": prompt_config["empty_response"], "reference": kbinfos}
|
|
|
123 |
|
124 |
kwargs["knowledge"] = "\n".join(knowledges)
|
125 |
gen_conf = dialog.llm_setting
|
@@ -130,33 +137,45 @@ def chat(dialog, messages, **kwargs):
|
|
130 |
gen_conf["max_tokens"] = min(
|
131 |
gen_conf["max_tokens"],
|
132 |
max_tokens - used_token_count)
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
if
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
|
161 |
|
162 |
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
|
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
import re
|
17 |
+
from copy import deepcopy
|
18 |
|
19 |
from api.db import LLMType
|
20 |
from api.db.db_models import Dialog, Conversation
|
|
|
72 |
return max_length, msg
|
73 |
|
74 |
|
75 |
+
def chat(dialog, messages, stream=True, **kwargs):
|
76 |
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
77 |
llm = LLMService.query(llm_name=dialog.llm_id)
|
78 |
if not llm:
|
|
|
83 |
else: max_tokens = llm[0].max_tokens
|
84 |
kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
|
85 |
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
86 |
+
if len(embd_nms) != 1:
|
87 |
+
if stream:
|
88 |
+
yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
|
89 |
+
return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
|
90 |
|
91 |
questions = [m["content"] for m in messages if m["role"] == "user"]
|
92 |
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
|
|
|
98 |
if field_map:
|
99 |
chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
|
100 |
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
|
101 |
+
if ans:
|
102 |
+
yield ans
|
103 |
+
return
|
104 |
|
105 |
for p in prompt_config["parameters"]:
|
106 |
if p["key"] == "knowledge":
|
|
|
124 |
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
125 |
|
126 |
if not knowledges and prompt_config.get("empty_response"):
|
127 |
+
if stream:
|
128 |
+
yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
|
129 |
+
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
|
130 |
|
131 |
kwargs["knowledge"] = "\n".join(knowledges)
|
132 |
gen_conf = dialog.llm_setting
|
|
|
137 |
gen_conf["max_tokens"] = min(
|
138 |
gen_conf["max_tokens"],
|
139 |
max_tokens - used_token_count)
|
140 |
+
|
141 |
+
def decorate_answer(answer):
|
142 |
+
nonlocal prompt_config, knowledges, kwargs, kbinfos
|
143 |
+
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
|
144 |
+
answer, idx = retrievaler.insert_citations(answer,
|
145 |
+
[ck["content_ltks"]
|
146 |
+
for ck in kbinfos["chunks"]],
|
147 |
+
[ck["vector"]
|
148 |
+
for ck in kbinfos["chunks"]],
|
149 |
+
embd_mdl,
|
150 |
+
tkweight=1 - dialog.vector_similarity_weight,
|
151 |
+
vtweight=dialog.vector_similarity_weight)
|
152 |
+
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
153 |
+
recall_docs = [
|
154 |
+
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
155 |
+
if not recall_docs: recall_docs = kbinfos["doc_aggs"]
|
156 |
+
kbinfos["doc_aggs"] = recall_docs
|
157 |
+
|
158 |
+
refs = deepcopy(kbinfos)
|
159 |
+
for c in refs["chunks"]:
|
160 |
+
if c.get("vector"):
|
161 |
+
del c["vector"]
|
162 |
+
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api")>=0:
|
163 |
+
answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
|
164 |
+
return {"answer": answer, "reference": refs}
|
165 |
+
|
166 |
+
if stream:
|
167 |
+
answer = ""
|
168 |
+
for ans in chat_mdl.chat_streamly(prompt_config["system"].format(**kwargs), msg, gen_conf):
|
169 |
+
answer = ans
|
170 |
+
yield {"answer": answer, "reference": {}}
|
171 |
+
yield decorate_answer(answer)
|
172 |
+
else:
|
173 |
+
answer = chat_mdl.chat(
|
174 |
+
prompt_config["system"].format(
|
175 |
+
**kwargs), msg, gen_conf)
|
176 |
+
chat_logger.info("User: {}|Assistant: {}".format(
|
177 |
+
msg[-1]["content"], answer))
|
178 |
+
return decorate_answer(answer)
|
179 |
|
180 |
|
181 |
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
api/db/services/document_service.py
CHANGED
@@ -43,7 +43,7 @@ class DocumentService(CommonService):
|
|
43 |
docs = cls.model.select().where(
|
44 |
(cls.model.kb_id == kb_id),
|
45 |
(fn.LOWER(cls.model.name).contains(keywords.lower()))
|
46 |
-
|
47 |
else:
|
48 |
docs = cls.model.select().where(cls.model.kb_id == kb_id)
|
49 |
count = docs.count()
|
@@ -75,7 +75,7 @@ class DocumentService(CommonService):
|
|
75 |
def delete(cls, doc):
|
76 |
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
|
77 |
if not KnowledgebaseService.update_by_id(
|
78 |
-
kb.id, {"doc_num": kb.doc_num - 1}):
|
79 |
raise RuntimeError("Database error (Knowledgebase)!")
|
80 |
return cls.delete_by_id(doc.id)
|
81 |
|
|
|
43 |
docs = cls.model.select().where(
|
44 |
(cls.model.kb_id == kb_id),
|
45 |
(fn.LOWER(cls.model.name).contains(keywords.lower()))
|
46 |
+
)
|
47 |
else:
|
48 |
docs = cls.model.select().where(cls.model.kb_id == kb_id)
|
49 |
count = docs.count()
|
|
|
75 |
def delete(cls, doc):
|
76 |
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
|
77 |
if not KnowledgebaseService.update_by_id(
|
78 |
+
kb.id, {"doc_num": max(0, kb.doc_num - 1)}):
|
79 |
raise RuntimeError("Database error (Knowledgebase)!")
|
80 |
return cls.delete_by_id(doc.id)
|
81 |
|
api/db/services/llm_service.py
CHANGED
@@ -172,8 +172,18 @@ class LLMBundle(object):
|
|
172 |
|
173 |
def chat(self, system, history, gen_conf):
|
174 |
txt, used_tokens = self.mdl.chat(system, history, gen_conf)
|
175 |
-
if TenantLLMService.increase_usage(
|
176 |
self.tenant_id, self.llm_type, used_tokens, self.llm_name):
|
177 |
database_logger.error(
|
178 |
"Can't update token usage for {}/CHAT".format(self.tenant_id))
|
179 |
return txt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
|
173 |
def chat(self, system, history, gen_conf):
|
174 |
txt, used_tokens = self.mdl.chat(system, history, gen_conf)
|
175 |
+
if not TenantLLMService.increase_usage(
|
176 |
self.tenant_id, self.llm_type, used_tokens, self.llm_name):
|
177 |
database_logger.error(
|
178 |
"Can't update token usage for {}/CHAT".format(self.tenant_id))
|
179 |
return txt
|
180 |
+
|
181 |
+
def chat_streamly(self, system, history, gen_conf):
|
182 |
+
for txt in self.mdl.chat_streamly(system, history, gen_conf):
|
183 |
+
if isinstance(txt, int):
|
184 |
+
if not TenantLLMService.increase_usage(
|
185 |
+
self.tenant_id, self.llm_type, txt, self.llm_name):
|
186 |
+
database_logger.error(
|
187 |
+
"Can't update token usage for {}/CHAT".format(self.tenant_id))
|
188 |
+
return
|
189 |
+
yield txt
|
api/utils/api_utils.py
CHANGED
@@ -25,7 +25,6 @@ from flask import (
|
|
25 |
from werkzeug.http import HTTP_STATUS_CODES
|
26 |
|
27 |
from api.utils import json_dumps
|
28 |
-
from api.versions import get_rag_version
|
29 |
from api.settings import RetCode
|
30 |
from api.settings import (
|
31 |
REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
|
@@ -84,9 +83,6 @@ def request(**kwargs):
|
|
84 |
return sess.send(prepped, stream=stream, timeout=timeout)
|
85 |
|
86 |
|
87 |
-
rag_version = get_rag_version() or ''
|
88 |
-
|
89 |
-
|
90 |
def get_exponential_backoff_interval(retries, full_jitter=False):
|
91 |
"""Calculate the exponential backoff wait time."""
|
92 |
# Will be zero if factor equals 0
|
|
|
25 |
from werkzeug.http import HTTP_STATUS_CODES
|
26 |
|
27 |
from api.utils import json_dumps
|
|
|
28 |
from api.settings import RetCode
|
29 |
from api.settings import (
|
30 |
REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
|
|
|
83 |
return sess.send(prepped, stream=stream, timeout=timeout)
|
84 |
|
85 |
|
|
|
|
|
|
|
86 |
def get_exponential_backoff_interval(retries, full_jitter=False):
|
87 |
"""Calculate the exponential backoff wait time."""
|
88 |
# Will be zero if factor equals 0
|
rag/llm/chat_model.py
CHANGED
@@ -20,7 +20,6 @@ from openai import OpenAI
|
|
20 |
import openai
|
21 |
from ollama import Client
|
22 |
from rag.nlp import is_english
|
23 |
-
from rag.utils import num_tokens_from_string
|
24 |
|
25 |
|
26 |
class Base(ABC):
|
@@ -44,6 +43,31 @@ class Base(ABC):
|
|
44 |
except openai.APIError as e:
|
45 |
return "**ERROR**: " + str(e), 0
|
46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
class GptTurbo(Base):
|
49 |
def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
|
@@ -97,6 +121,35 @@ class QWenChat(Base):
|
|
97 |
|
98 |
return "**ERROR**: " + response.message, tk_count
|
99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
class ZhipuChat(Base):
|
102 |
def __init__(self, key, model_name="glm-3-turbo", **kwargs):
|
@@ -122,6 +175,34 @@ class ZhipuChat(Base):
|
|
122 |
except Exception as e:
|
123 |
return "**ERROR**: " + str(e), 0
|
124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
|
126 |
class OllamaChat(Base):
|
127 |
def __init__(self, key, model_name, **kwargs):
|
@@ -148,3 +229,28 @@ class OllamaChat(Base):
|
|
148 |
except Exception as e:
|
149 |
return "**ERROR**: " + str(e), 0
|
150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
import openai
|
21 |
from ollama import Client
|
22 |
from rag.nlp import is_english
|
|
|
23 |
|
24 |
|
25 |
class Base(ABC):
|
|
|
43 |
except openai.APIError as e:
|
44 |
return "**ERROR**: " + str(e), 0
|
45 |
|
46 |
+
def chat_streamly(self, system, history, gen_conf):
|
47 |
+
if system:
|
48 |
+
history.insert(0, {"role": "system", "content": system})
|
49 |
+
ans = ""
|
50 |
+
total_tokens = 0
|
51 |
+
try:
|
52 |
+
response = self.client.chat.completions.create(
|
53 |
+
model=self.model_name,
|
54 |
+
messages=history,
|
55 |
+
stream=True,
|
56 |
+
**gen_conf)
|
57 |
+
for resp in response:
|
58 |
+
if not resp.choices[0].delta.content:continue
|
59 |
+
ans += resp.choices[0].delta.content
|
60 |
+
total_tokens += 1
|
61 |
+
if resp.choices[0].finish_reason == "length":
|
62 |
+
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
|
63 |
+
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
64 |
+
yield ans
|
65 |
+
|
66 |
+
except openai.APIError as e:
|
67 |
+
yield ans + "\n**ERROR**: " + str(e)
|
68 |
+
|
69 |
+
yield total_tokens
|
70 |
+
|
71 |
|
72 |
class GptTurbo(Base):
|
73 |
def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
|
|
|
121 |
|
122 |
return "**ERROR**: " + response.message, tk_count
|
123 |
|
124 |
+
def chat_streamly(self, system, history, gen_conf):
|
125 |
+
from http import HTTPStatus
|
126 |
+
if system:
|
127 |
+
history.insert(0, {"role": "system", "content": system})
|
128 |
+
ans = ""
|
129 |
+
try:
|
130 |
+
response = Generation.call(
|
131 |
+
self.model_name,
|
132 |
+
messages=history,
|
133 |
+
result_format='message',
|
134 |
+
stream=True,
|
135 |
+
**gen_conf
|
136 |
+
)
|
137 |
+
tk_count = 0
|
138 |
+
for resp in response:
|
139 |
+
if resp.status_code == HTTPStatus.OK:
|
140 |
+
ans = resp.output.choices[0]['message']['content']
|
141 |
+
tk_count = resp.usage.total_tokens
|
142 |
+
if resp.output.choices[0].get("finish_reason", "") == "length":
|
143 |
+
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
|
144 |
+
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
145 |
+
yield ans
|
146 |
+
else:
|
147 |
+
yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access")<0 else "Out of credit. Please set the API key in **settings > Model providers.**"
|
148 |
+
except Exception as e:
|
149 |
+
yield ans + "\n**ERROR**: " + str(e)
|
150 |
+
|
151 |
+
yield tk_count
|
152 |
+
|
153 |
|
154 |
class ZhipuChat(Base):
|
155 |
def __init__(self, key, model_name="glm-3-turbo", **kwargs):
|
|
|
175 |
except Exception as e:
|
176 |
return "**ERROR**: " + str(e), 0
|
177 |
|
178 |
+
def chat_streamly(self, system, history, gen_conf):
|
179 |
+
if system:
|
180 |
+
history.insert(0, {"role": "system", "content": system})
|
181 |
+
if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
|
182 |
+
if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
|
183 |
+
ans = ""
|
184 |
+
try:
|
185 |
+
response = self.client.chat.completions.create(
|
186 |
+
model=self.model_name,
|
187 |
+
messages=history,
|
188 |
+
stream=True,
|
189 |
+
**gen_conf
|
190 |
+
)
|
191 |
+
tk_count = 0
|
192 |
+
for resp in response:
|
193 |
+
if not resp.choices[0].delta.content:continue
|
194 |
+
delta = resp.choices[0].delta.content
|
195 |
+
ans += delta
|
196 |
+
tk_count = resp.usage.total_tokens if response.usage else 0
|
197 |
+
if resp.output.choices[0].finish_reason == "length":
|
198 |
+
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
|
199 |
+
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
200 |
+
yield ans
|
201 |
+
except Exception as e:
|
202 |
+
yield ans + "\n**ERROR**: " + str(e)
|
203 |
+
|
204 |
+
yield tk_count
|
205 |
+
|
206 |
|
207 |
class OllamaChat(Base):
|
208 |
def __init__(self, key, model_name, **kwargs):
|
|
|
229 |
except Exception as e:
|
230 |
return "**ERROR**: " + str(e), 0
|
231 |
|
232 |
+
def chat_streamly(self, system, history, gen_conf):
    """Stream a chat completion from a local Ollama server.

    Yields the accumulated answer text after every received chunk and, as
    the final item, the total token count (int) — matching the contract of
    the other ``chat_streamly`` implementations in this module.

    :param system: optional system prompt; when given, it is inserted at
        the head of ``history`` (the list is mutated in place).
    :param history: list of ``{"role", "content"}`` message dicts.
    :param gen_conf: OpenAI-style generation options, translated to
        Ollama's option names below.
    """
    if system:
        history.insert(0, {"role": "system", "content": system})
    # Translate OpenAI-style generation options to Ollama's option names.
    options = {}
    if "temperature" in gen_conf:
        options["temperature"] = gen_conf["temperature"]
    if "max_tokens" in gen_conf:
        options["num_predict"] = gen_conf["max_tokens"]
    if "top_p" in gen_conf:
        # BUG FIX: the original wrote this value into "top_k".
        options["top_p"] = gen_conf["top_p"]
    if "presence_penalty" in gen_conf:
        options["presence_penalty"] = gen_conf["presence_penalty"]
    if "frequency_penalty" in gen_conf:
        options["frequency_penalty"] = gen_conf["frequency_penalty"]
    ans = ""
    tk_count = 0
    try:
        response = self.client.chat(
            model=self.model_name,
            messages=history,
            stream=True,
            options=options
        )
        for resp in response:
            if resp["done"]:
                # `return value` inside a generator raises StopIteration and
                # the caller never sees the count; record it and fall through
                # to the final yield instead.
                tk_count = resp["prompt_eval_count"] + resp["eval_count"]
                break
            # Ollama streams incremental pieces; accumulate so each yield
            # carries the full answer so far (like the other providers —
            # the except handler below already assumes accumulation).
            ans += resp["message"]["content"]
            yield ans
    except Exception as e:
        yield ans + "\n**ERROR**: " + str(e)
    yield tk_count
|
rag/svr/task_executor.py
CHANGED
@@ -80,7 +80,7 @@ def set_progress(task_id, from_page=0, to_page=-1,
|
|
80 |
|
81 |
if to_page > 0:
|
82 |
if msg:
|
83 |
-
msg = f"Page({from_page+1}~{to_page+1}): " + msg
|
84 |
d = {"progress_msg": msg}
|
85 |
if prog is not None:
|
86 |
d["progress"] = prog
|
@@ -124,7 +124,7 @@ def get_minio_binary(bucket, name):
|
|
124 |
def build(row):
|
125 |
if row["size"] > DOC_MAXIMUM_SIZE:
|
126 |
set_progress(row["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
|
127 |
-
|
128 |
return []
|
129 |
|
130 |
callback = partial(
|
@@ -138,12 +138,12 @@ def build(row):
|
|
138 |
bucket, name = File2DocumentService.get_minio_address(doc_id=row["doc_id"])
|
139 |
binary = get_minio_binary(bucket, name)
|
140 |
cron_logger.info(
|
141 |
-
"From minio({}) {}/{}".format(timer()-st, row["location"], row["name"]))
|
142 |
cks = chunker.chunk(row["name"], binary=binary, from_page=row["from_page"],
|
143 |
to_page=row["to_page"], lang=row["language"], callback=callback,
|
144 |
kb_id=row["kb_id"], parser_config=row["parser_config"], tenant_id=row["tenant_id"])
|
145 |
cron_logger.info(
|
146 |
-
"Chunkking({}) {}/{}".format(timer()-st, row["location"], row["name"]))
|
147 |
except TimeoutError as e:
|
148 |
callback(-1, f"Internal server error: Fetch file timeout. Could you try it again.")
|
149 |
cron_logger.error(
|
@@ -173,7 +173,7 @@ def build(row):
|
|
173 |
d.update(ck)
|
174 |
md5 = hashlib.md5()
|
175 |
md5.update((ck["content_with_weight"] +
|
176 |
-
|
177 |
d["_id"] = md5.hexdigest()
|
178 |
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
|
179 |
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
|
@@ -261,7 +261,7 @@ def main():
|
|
261 |
|
262 |
st = timer()
|
263 |
cks = build(r)
|
264 |
-
cron_logger.info("Build chunks({}): {
|
265 |
if cks is None:
|
266 |
continue
|
267 |
if not cks:
|
@@ -271,7 +271,7 @@ def main():
|
|
271 |
## set_progress(r["did"], -1, "ERROR: ")
|
272 |
callback(
|
273 |
msg="Finished slicing files(%d). Start to embedding the content." %
|
274 |
-
|
275 |
st = timer()
|
276 |
try:
|
277 |
tk_count = embedding(cks, embd_mdl, r["parser_config"], callback)
|
@@ -279,19 +279,19 @@ def main():
|
|
279 |
callback(-1, "Embedding error:{}".format(str(e)))
|
280 |
cron_logger.error(str(e))
|
281 |
tk_count = 0
|
282 |
-
cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer()-st))
|
283 |
|
284 |
-
callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer()-st))
|
285 |
init_kb(r)
|
286 |
chunk_count = len(set([c["_id"] for c in cks]))
|
287 |
st = timer()
|
288 |
es_r = ""
|
289 |
for b in range(0, len(cks), 32):
|
290 |
-
es_r = ELASTICSEARCH.bulk(cks[b:b+32], search.index_name(r["tenant_id"]))
|
291 |
if b % 128 == 0:
|
292 |
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
|
293 |
|
294 |
-
cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer()-st))
|
295 |
if es_r:
|
296 |
callback(-1, "Index failure!")
|
297 |
ELASTICSEARCH.deleteByQuery(
|
@@ -307,8 +307,7 @@ def main():
|
|
307 |
r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
|
308 |
cron_logger.info(
|
309 |
"Chunk doc({}), token({}), chunks({}), elapsed:{:.2f}".format(
|
310 |
-
r["id"], tk_count, len(cks), timer()-st))
|
311 |
-
|
312 |
|
313 |
|
314 |
if __name__ == "__main__":
|
|
|
80 |
|
81 |
if to_page > 0:
|
82 |
if msg:
|
83 |
+
msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
|
84 |
d = {"progress_msg": msg}
|
85 |
if prog is not None:
|
86 |
d["progress"] = prog
|
|
|
124 |
def build(row):
|
125 |
if row["size"] > DOC_MAXIMUM_SIZE:
|
126 |
set_progress(row["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
|
127 |
+
(int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
|
128 |
return []
|
129 |
|
130 |
callback = partial(
|
|
|
138 |
bucket, name = File2DocumentService.get_minio_address(doc_id=row["doc_id"])
|
139 |
binary = get_minio_binary(bucket, name)
|
140 |
cron_logger.info(
|
141 |
+
"From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
|
142 |
cks = chunker.chunk(row["name"], binary=binary, from_page=row["from_page"],
|
143 |
to_page=row["to_page"], lang=row["language"], callback=callback,
|
144 |
kb_id=row["kb_id"], parser_config=row["parser_config"], tenant_id=row["tenant_id"])
|
145 |
cron_logger.info(
|
146 |
+
"Chunkking({}) {}/{}".format(timer() - st, row["location"], row["name"]))
|
147 |
except TimeoutError as e:
|
148 |
callback(-1, f"Internal server error: Fetch file timeout. Could you try it again.")
|
149 |
cron_logger.error(
|
|
|
173 |
d.update(ck)
|
174 |
md5 = hashlib.md5()
|
175 |
md5.update((ck["content_with_weight"] +
|
176 |
+
str(d["doc_id"])).encode("utf-8"))
|
177 |
d["_id"] = md5.hexdigest()
|
178 |
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
|
179 |
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
|
|
|
261 |
|
262 |
st = timer()
|
263 |
cks = build(r)
|
264 |
+
cron_logger.info("Build chunks({}): {}".format(r["name"], timer() - st))
|
265 |
if cks is None:
|
266 |
continue
|
267 |
if not cks:
|
|
|
271 |
## set_progress(r["did"], -1, "ERROR: ")
|
272 |
callback(
|
273 |
msg="Finished slicing files(%d). Start to embedding the content." %
|
274 |
+
len(cks))
|
275 |
st = timer()
|
276 |
try:
|
277 |
tk_count = embedding(cks, embd_mdl, r["parser_config"], callback)
|
|
|
279 |
callback(-1, "Embedding error:{}".format(str(e)))
|
280 |
cron_logger.error(str(e))
|
281 |
tk_count = 0
|
282 |
+
cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
|
283 |
|
284 |
+
callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))
|
285 |
init_kb(r)
|
286 |
chunk_count = len(set([c["_id"] for c in cks]))
|
287 |
st = timer()
|
288 |
es_r = ""
|
289 |
for b in range(0, len(cks), 32):
|
290 |
+
es_r = ELASTICSEARCH.bulk(cks[b:b + 32], search.index_name(r["tenant_id"]))
|
291 |
if b % 128 == 0:
|
292 |
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
|
293 |
|
294 |
+
cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
|
295 |
if es_r:
|
296 |
callback(-1, "Index failure!")
|
297 |
ELASTICSEARCH.deleteByQuery(
|
|
|
307 |
r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
|
308 |
cron_logger.info(
|
309 |
"Chunk doc({}), token({}), chunks({}), elapsed:{:.2f}".format(
|
310 |
+
r["id"], tk_count, len(cks), timer() - st))
|
|
|
311 |
|
312 |
|
313 |
if __name__ == "__main__":
|
rag/utils/es_conn.py
CHANGED
@@ -43,6 +43,9 @@ class ESConnection:
|
|
43 |
v = v["number"].split(".")[0]
|
44 |
return int(v) >= 7
|
45 |
|
|
|
|
|
|
|
46 |
def upsert(self, df, idxnm=""):
|
47 |
res = []
|
48 |
for d in df:
|
|
|
43 |
v = v["number"].split(".")[0]
|
44 |
return int(v) >= 7
|
45 |
|
46 |
+
def health(self):
    """Return the Elasticsearch cluster health report as a plain dict."""
    report = self.es.cluster.health()
    return dict(report)
|
48 |
+
|
49 |
def upsert(self, df, idxnm=""):
|
50 |
res = []
|
51 |
for d in df:
|
rag/utils/minio_conn.py
CHANGED
@@ -34,6 +34,16 @@ class RAGFlowMinio(object):
|
|
34 |
del self.conn
|
35 |
self.conn = None
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
def put(self, bucket, fnm, binary):
|
38 |
for _ in range(3):
|
39 |
try:
|
|
|
34 |
del self.conn
|
35 |
self.conn = None
|
36 |
|
37 |
+
def health(self):
    """Probe MinIO by writing a tiny sentinel object.

    Ensures the sentinel bucket exists, uploads a fixed payload, and
    returns the put_object result so callers can inspect it.
    """
    sentinel = "_t@@@1"
    payload = b"_t@@@1"
    if not self.conn.bucket_exists(sentinel):
        self.conn.make_bucket(sentinel)
    result = self.conn.put_object(sentinel, sentinel, BytesIO(payload), len(payload))
    return result
|
46 |
+
|
47 |
def put(self, bucket, fnm, binary):
|
48 |
for _ in range(3):
|
49 |
try:
|
rag/utils/redis_conn.py
CHANGED
@@ -44,6 +44,10 @@ class RedisDB:
|
|
44 |
logging.warning("Redis can't be connected.")
|
45 |
return self.REDIS
|
46 |
|
|
|
|
|
|
|
|
|
47 |
def is_alive(self):
|
48 |
return self.REDIS is not None
|
49 |
|
|
|
44 |
logging.warning("Redis can't be connected.")
|
45 |
return self.REDIS
|
46 |
|
47 |
+
def health(self, queue_name):
    """Ping Redis, then return info for queue_name's first consumer group.

    Raises if the server is unreachable, or IndexError when the stream has
    no consumer groups (same behavior as before).
    """
    self.REDIS.ping()
    groups = self.REDIS.xinfo_groups(queue_name)
    return groups[0]
|
50 |
+
|
51 |
def is_alive(self):
    """Report whether a Redis connection object has been established."""
    connected = self.REDIS is not None
    return connected
|
53 |
|