KevinHuSh committed on
Commit
028fe40
·
1 Parent(s): df17cda

add stream chat (#811)

Browse files

### What problem does this PR solve?

#709
### Type of change

- [x] New Feature (non-breaking change which adds functionality)

api/apps/api_app.py CHANGED
@@ -13,10 +13,11 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
 
16
  import os
17
  import re
18
  from datetime import datetime, timedelta
19
- from flask import request
20
  from flask_login import login_required, current_user
21
 
22
  from api.db import FileType, ParserType
@@ -31,11 +32,11 @@ from api.settings import RetCode
31
  from api.utils import get_uuid, current_timestamp, datetime_format
32
  from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request
33
  from itsdangerous import URLSafeTimedSerializer
34
- from api.db.services.task_service import TaskService, queue_tasks
35
  from api.utils.file_utils import filename_type, thumbnail
36
  from rag.utils.minio_conn import MINIO
37
- from api.db.db_models import Task
38
- from api.db.services.file2document_service import File2DocumentService
39
  def generate_confirmation_token(tenent_id):
40
  serializer = URLSafeTimedSerializer(tenent_id)
41
  return "ragflow-" + serializer.dumps(get_uuid(), salt=tenent_id)[2:34]
@@ -164,6 +165,7 @@ def completion():
164
  e, conv = API4ConversationService.get_by_id(req["conversation_id"])
165
  if not e:
166
  return get_data_error_result(retmsg="Conversation not found!")
 
167
 
168
  msg = []
169
  for m in req["messages"]:
@@ -180,13 +182,45 @@ def completion():
180
  return get_data_error_result(retmsg="Dialog not found!")
181
  del req["conversation_id"]
182
  del req["messages"]
183
- ans = chat(dia, msg, **req)
184
  if not conv.reference:
185
  conv.reference = []
186
- conv.reference.append(ans["reference"])
187
- conv.message.append({"role": "assistant", "content": ans["answer"]})
188
- API4ConversationService.append_message(conv.id, conv.to_dict())
189
- return get_json_result(data=ans)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  except Exception as e:
191
  return server_error_response(e)
192
 
@@ -229,7 +263,6 @@ def upload():
229
  return get_json_result(
230
  data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR)
231
 
232
-
233
  file = request.files['file']
234
  if file.filename == '':
235
  return get_json_result(
@@ -253,7 +286,6 @@ def upload():
253
  location += "_"
254
  blob = request.files['file'].read()
255
  MINIO.put(kb_id, location, blob)
256
-
257
  doc = {
258
  "id": get_uuid(),
259
  "kb_id": kb.id,
@@ -266,42 +298,11 @@ def upload():
266
  "size": len(blob),
267
  "thumbnail": thumbnail(filename, blob)
268
  }
269
-
270
- form_data=request.form
271
- if "parser_id" in form_data.keys():
272
- if request.form.get("parser_id").strip() in list(vars(ParserType).values())[1:-3]:
273
- doc["parser_id"] = request.form.get("parser_id").strip()
274
  if doc["type"] == FileType.VISUAL:
275
  doc["parser_id"] = ParserType.PICTURE.value
276
  if re.search(r"\.(ppt|pptx|pages)$", filename):
277
  doc["parser_id"] = ParserType.PRESENTATION.value
278
-
279
- doc_result = DocumentService.insert(doc)
280
-
281
  except Exception as e:
282
  return server_error_response(e)
283
-
284
- if "run" in form_data.keys():
285
- if request.form.get("run").strip() == "1":
286
- try:
287
- info = {"run": 1, "progress": 0}
288
- info["progress_msg"] = ""
289
- info["chunk_num"] = 0
290
- info["token_num"] = 0
291
- DocumentService.update_by_id(doc["id"], info)
292
- # if str(req["run"]) == TaskStatus.CANCEL.value:
293
- tenant_id = DocumentService.get_tenant_id(doc["id"])
294
- if not tenant_id:
295
- return get_data_error_result(retmsg="Tenant not found!")
296
-
297
- #e, doc = DocumentService.get_by_id(doc["id"])
298
- TaskService.filter_delete([Task.doc_id == doc["id"]])
299
- e, doc = DocumentService.get_by_id(doc["id"])
300
- doc = doc.to_dict()
301
- doc["tenant_id"] = tenant_id
302
- bucket, name = File2DocumentService.get_minio_address(doc_id=doc["id"])
303
- queue_tasks(doc, bucket, name)
304
- except Exception as e:
305
- return server_error_response(e)
306
-
307
- return get_json_result(data=doc_result.to_json())
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
16
+ import json
17
  import os
18
  import re
19
  from datetime import datetime, timedelta
20
+ from flask import request, Response
21
  from flask_login import login_required, current_user
22
 
23
  from api.db import FileType, ParserType
 
32
  from api.utils import get_uuid, current_timestamp, datetime_format
33
  from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request
34
  from itsdangerous import URLSafeTimedSerializer
35
+
36
  from api.utils.file_utils import filename_type, thumbnail
37
  from rag.utils.minio_conn import MINIO
38
+
39
+
40
  def generate_confirmation_token(tenent_id):
41
  serializer = URLSafeTimedSerializer(tenent_id)
42
  return "ragflow-" + serializer.dumps(get_uuid(), salt=tenent_id)[2:34]
 
165
  e, conv = API4ConversationService.get_by_id(req["conversation_id"])
166
  if not e:
167
  return get_data_error_result(retmsg="Conversation not found!")
168
+ if "quote" not in req: req["quote"] = False
169
 
170
  msg = []
171
  for m in req["messages"]:
 
182
  return get_data_error_result(retmsg="Dialog not found!")
183
  del req["conversation_id"]
184
  del req["messages"]
185
+
186
  if not conv.reference:
187
  conv.reference = []
188
+ conv.message.append({"role": "assistant", "content": ""})
189
+ conv.reference.append({"chunks": [], "doc_aggs": []})
190
+
191
+ def fillin_conv(ans):
192
+ nonlocal conv
193
+ if not conv.reference:
194
+ conv.reference.append(ans["reference"])
195
+ else: conv.reference[-1] = ans["reference"]
196
+ conv.message[-1] = {"role": "assistant", "content": ans["answer"]}
197
+
198
+ def stream():
199
+ nonlocal dia, msg, req, conv
200
+ try:
201
+ for ans in chat(dia, msg, True, **req):
202
+ fillin_conv(ans)
203
+ yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
204
+ API4ConversationService.append_message(conv.id, conv.to_dict())
205
+ except Exception as e:
206
+ yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
207
+ "data": {"answer": "**ERROR**: "+str(e), "reference": []}},
208
+ ensure_ascii=False) + "\n\n"
209
+ yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
210
+
211
+ if req.get("stream", True):
212
+ resp = Response(stream(), mimetype="text/event-stream")
213
+ resp.headers.add_header("Cache-control", "no-cache")
214
+ resp.headers.add_header("Connection", "keep-alive")
215
+ resp.headers.add_header("X-Accel-Buffering", "no")
216
+ resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
217
+ return resp
218
+ else:
219
+ ans = chat(dia, msg, False, **req)
220
+ fillin_conv(ans)
221
+ API4ConversationService.append_message(conv.id, conv.to_dict())
222
+ return get_json_result(data=ans)
223
+
224
  except Exception as e:
225
  return server_error_response(e)
226
 
 
263
  return get_json_result(
264
  data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR)
265
 
 
266
  file = request.files['file']
267
  if file.filename == '':
268
  return get_json_result(
 
286
  location += "_"
287
  blob = request.files['file'].read()
288
  MINIO.put(kb_id, location, blob)
 
289
  doc = {
290
  "id": get_uuid(),
291
  "kb_id": kb.id,
 
298
  "size": len(blob),
299
  "thumbnail": thumbnail(filename, blob)
300
  }
 
 
 
 
 
301
  if doc["type"] == FileType.VISUAL:
302
  doc["parser_id"] = ParserType.PICTURE.value
303
  if re.search(r"\.(ppt|pptx|pages)$", filename):
304
  doc["parser_id"] = ParserType.PRESENTATION.value
305
+ doc = DocumentService.insert(doc)
306
+ return get_json_result(data=doc.to_json())
 
307
  except Exception as e:
308
  return server_error_response(e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
api/apps/conversation_app.py CHANGED
@@ -13,12 +13,13 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
16
- from flask import request
17
  from flask_login import login_required
18
  from api.db.services.dialog_service import DialogService, ConversationService, chat
19
  from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
20
  from api.utils import get_uuid
21
  from api.utils.api_utils import get_json_result
 
22
 
23
 
24
  @manager.route('/set', methods=['POST'])
@@ -103,9 +104,12 @@ def list_convsersation():
103
 
104
  @manager.route('/completion', methods=['POST'])
105
  @login_required
106
- @validate_request("conversation_id", "messages")
107
  def completion():
108
  req = request.json
 
 
 
109
  msg = []
110
  for m in req["messages"]:
111
  if m["role"] == "system":
@@ -123,13 +127,45 @@ def completion():
123
  return get_data_error_result(retmsg="Dialog not found!")
124
  del req["conversation_id"]
125
  del req["messages"]
126
- ans = chat(dia, msg, **req)
127
  if not conv.reference:
128
  conv.reference = []
129
- conv.reference.append(ans["reference"])
130
- conv.message.append({"role": "assistant", "content": ans["answer"]})
131
- ConversationService.update_by_id(conv.id, conv.to_dict())
132
- return get_json_result(data=ans)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  except Exception as e:
134
  return server_error_response(e)
135
 
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
16
+ from flask import request, Response, jsonify
17
  from flask_login import login_required
18
  from api.db.services.dialog_service import DialogService, ConversationService, chat
19
  from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
20
  from api.utils import get_uuid
21
  from api.utils.api_utils import get_json_result
22
+ import json
23
 
24
 
25
  @manager.route('/set', methods=['POST'])
 
104
 
105
  @manager.route('/completion', methods=['POST'])
106
  @login_required
107
+ #@validate_request("conversation_id", "messages")
108
  def completion():
109
  req = request.json
110
+ #req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
111
+ # {"role": "user", "content": "上海有吗?"}
112
+ #]}
113
  msg = []
114
  for m in req["messages"]:
115
  if m["role"] == "system":
 
127
  return get_data_error_result(retmsg="Dialog not found!")
128
  del req["conversation_id"]
129
  del req["messages"]
130
+
131
  if not conv.reference:
132
  conv.reference = []
133
+ conv.message.append({"role": "assistant", "content": ""})
134
+ conv.reference.append({"chunks": [], "doc_aggs": []})
135
+
136
+ def fillin_conv(ans):
137
+ nonlocal conv
138
+ if not conv.reference:
139
+ conv.reference.append(ans["reference"])
140
+ else: conv.reference[-1] = ans["reference"]
141
+ conv.message[-1] = {"role": "assistant", "content": ans["answer"]}
142
+
143
+ def stream():
144
+ nonlocal dia, msg, req, conv
145
+ try:
146
+ for ans in chat(dia, msg, True, **req):
147
+ fillin_conv(ans)
148
+ yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
149
+ ConversationService.update_by_id(conv.id, conv.to_dict())
150
+ except Exception as e:
151
+ yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
152
+ "data": {"answer": "**ERROR**: "+str(e), "reference": []}},
153
+ ensure_ascii=False) + "\n\n"
154
+ yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
155
+
156
+ if req.get("stream", True):
157
+ resp = Response(stream(), mimetype="text/event-stream")
158
+ resp.headers.add_header("Cache-control", "no-cache")
159
+ resp.headers.add_header("Connection", "keep-alive")
160
+ resp.headers.add_header("X-Accel-Buffering", "no")
161
+ resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
162
+ return resp
163
+
164
+ else:
165
+ ans = chat(dia, msg, False, **req)
166
+ fillin_conv(ans)
167
+ ConversationService.update_by_id(conv.id, conv.to_dict())
168
+ return get_json_result(data=ans)
169
  except Exception as e:
170
  return server_error_response(e)
171
 
api/apps/system_app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License
15
+ #
16
+ from flask_login import login_required
17
+
18
+ from api.db.services.knowledgebase_service import KnowledgebaseService
19
+ from api.utils.api_utils import get_json_result
20
+ from api.versions import get_rag_version
21
+ from rag.settings import SVR_QUEUE_NAME
22
+ from rag.utils.es_conn import ELASTICSEARCH
23
+ from rag.utils.minio_conn import MINIO
24
+ from timeit import default_timer as timer
25
+
26
+ from rag.utils.redis_conn import REDIS_CONN
27
+
28
+
29
+ @manager.route('/version', methods=['GET'])
30
+ @login_required
31
+ def version():
32
+ return get_json_result(data=get_rag_version())
33
+
34
+
35
+ @manager.route('/status', methods=['GET'])
36
+ @login_required
37
+ def status():
38
+ res = {}
39
+ st = timer()
40
+ try:
41
+ res["es"] = ELASTICSEARCH.health()
42
+ res["es"]["elapsed"] = "{:.1f}".format((timer() - st)*1000.)
43
+ except Exception as e:
44
+ res["es"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)}
45
+
46
+ st = timer()
47
+ try:
48
+ MINIO.health()
49
+ res["minio"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.)}
50
+ except Exception as e:
51
+ res["minio"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)}
52
+
53
+ st = timer()
54
+ try:
55
+ KnowledgebaseService.get_by_id("x")
56
+ res["mysql"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.)}
57
+ except Exception as e:
58
+ res["mysql"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)}
59
+
60
+ st = timer()
61
+ try:
62
+ qinfo = REDIS_CONN.health(SVR_QUEUE_NAME)
63
+ res["redis"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.), "pending": qinfo["pending"]}
64
+ except Exception as e:
65
+ res["redis"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)}
66
+
67
+ return get_json_result(data=res)
api/db/services/dialog_service.py CHANGED
@@ -14,6 +14,7 @@
14
  # limitations under the License.
15
  #
16
  import re
 
17
 
18
  from api.db import LLMType
19
  from api.db.db_models import Dialog, Conversation
@@ -71,7 +72,7 @@ def message_fit_in(msg, max_length=4000):
71
  return max_length, msg
72
 
73
 
74
- def chat(dialog, messages, **kwargs):
75
  assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
76
  llm = LLMService.query(llm_name=dialog.llm_id)
77
  if not llm:
@@ -82,7 +83,10 @@ def chat(dialog, messages, **kwargs):
82
  else: max_tokens = llm[0].max_tokens
83
  kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
84
  embd_nms = list(set([kb.embd_id for kb in kbs]))
85
- assert len(embd_nms) == 1, "Knowledge bases use different embedding models."
 
 
 
86
 
87
  questions = [m["content"] for m in messages if m["role"] == "user"]
88
  embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
@@ -94,7 +98,9 @@ def chat(dialog, messages, **kwargs):
94
  if field_map:
95
  chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
96
  ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
97
- if ans: return ans
 
 
98
 
99
  for p in prompt_config["parameters"]:
100
  if p["key"] == "knowledge":
@@ -118,8 +124,9 @@ def chat(dialog, messages, **kwargs):
118
  "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
119
 
120
  if not knowledges and prompt_config.get("empty_response"):
121
- return {
122
- "answer": prompt_config["empty_response"], "reference": kbinfos}
 
123
 
124
  kwargs["knowledge"] = "\n".join(knowledges)
125
  gen_conf = dialog.llm_setting
@@ -130,33 +137,45 @@ def chat(dialog, messages, **kwargs):
130
  gen_conf["max_tokens"] = min(
131
  gen_conf["max_tokens"],
132
  max_tokens - used_token_count)
133
- answer = chat_mdl.chat(
134
- prompt_config["system"].format(
135
- **kwargs), msg, gen_conf)
136
- chat_logger.info("User: {}|Assistant: {}".format(
137
- msg[-1]["content"], answer))
138
-
139
- if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
140
- answer, idx = retrievaler.insert_citations(answer,
141
- [ck["content_ltks"]
142
- for ck in kbinfos["chunks"]],
143
- [ck["vector"]
144
- for ck in kbinfos["chunks"]],
145
- embd_mdl,
146
- tkweight=1 - dialog.vector_similarity_weight,
147
- vtweight=dialog.vector_similarity_weight)
148
- idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
149
- recall_docs = [
150
- d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
151
- if not recall_docs: recall_docs = kbinfos["doc_aggs"]
152
- kbinfos["doc_aggs"] = recall_docs
153
-
154
- for c in kbinfos["chunks"]:
155
- if c.get("vector"):
156
- del c["vector"]
157
- if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api")>=0:
158
- answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
159
- return {"answer": answer, "reference": kbinfos}
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
 
162
  def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
 
14
  # limitations under the License.
15
  #
16
  import re
17
+ from copy import deepcopy
18
 
19
  from api.db import LLMType
20
  from api.db.db_models import Dialog, Conversation
 
72
  return max_length, msg
73
 
74
 
75
+ def chat(dialog, messages, stream=True, **kwargs):
76
  assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
77
  llm = LLMService.query(llm_name=dialog.llm_id)
78
  if not llm:
 
83
  else: max_tokens = llm[0].max_tokens
84
  kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
85
  embd_nms = list(set([kb.embd_id for kb in kbs]))
86
+ if len(embd_nms) != 1:
87
+ if stream:
88
+ yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
89
+ return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
90
 
91
  questions = [m["content"] for m in messages if m["role"] == "user"]
92
  embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
 
98
  if field_map:
99
  chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
100
  ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
101
+ if ans:
102
+ yield ans
103
+ return
104
 
105
  for p in prompt_config["parameters"]:
106
  if p["key"] == "knowledge":
 
124
  "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
125
 
126
  if not knowledges and prompt_config.get("empty_response"):
127
+ if stream:
128
+ yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
129
+ return {"answer": prompt_config["empty_response"], "reference": kbinfos}
130
 
131
  kwargs["knowledge"] = "\n".join(knowledges)
132
  gen_conf = dialog.llm_setting
 
137
  gen_conf["max_tokens"] = min(
138
  gen_conf["max_tokens"],
139
  max_tokens - used_token_count)
140
+
141
+ def decorate_answer(answer):
142
+ nonlocal prompt_config, knowledges, kwargs, kbinfos
143
+ if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
144
+ answer, idx = retrievaler.insert_citations(answer,
145
+ [ck["content_ltks"]
146
+ for ck in kbinfos["chunks"]],
147
+ [ck["vector"]
148
+ for ck in kbinfos["chunks"]],
149
+ embd_mdl,
150
+ tkweight=1 - dialog.vector_similarity_weight,
151
+ vtweight=dialog.vector_similarity_weight)
152
+ idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
153
+ recall_docs = [
154
+ d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
155
+ if not recall_docs: recall_docs = kbinfos["doc_aggs"]
156
+ kbinfos["doc_aggs"] = recall_docs
157
+
158
+ refs = deepcopy(kbinfos)
159
+ for c in refs["chunks"]:
160
+ if c.get("vector"):
161
+ del c["vector"]
162
+ if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api")>=0:
163
+ answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
164
+ return {"answer": answer, "reference": refs}
165
+
166
+ if stream:
167
+ answer = ""
168
+ for ans in chat_mdl.chat_streamly(prompt_config["system"].format(**kwargs), msg, gen_conf):
169
+ answer = ans
170
+ yield {"answer": answer, "reference": {}}
171
+ yield decorate_answer(answer)
172
+ else:
173
+ answer = chat_mdl.chat(
174
+ prompt_config["system"].format(
175
+ **kwargs), msg, gen_conf)
176
+ chat_logger.info("User: {}|Assistant: {}".format(
177
+ msg[-1]["content"], answer))
178
+ return decorate_answer(answer)
179
 
180
 
181
  def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
api/db/services/document_service.py CHANGED
@@ -43,7 +43,7 @@ class DocumentService(CommonService):
43
  docs = cls.model.select().where(
44
  (cls.model.kb_id == kb_id),
45
  (fn.LOWER(cls.model.name).contains(keywords.lower()))
46
- )
47
  else:
48
  docs = cls.model.select().where(cls.model.kb_id == kb_id)
49
  count = docs.count()
@@ -75,7 +75,7 @@ class DocumentService(CommonService):
75
  def delete(cls, doc):
76
  e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
77
  if not KnowledgebaseService.update_by_id(
78
- kb.id, {"doc_num": kb.doc_num - 1}):
79
  raise RuntimeError("Database error (Knowledgebase)!")
80
  return cls.delete_by_id(doc.id)
81
 
 
43
  docs = cls.model.select().where(
44
  (cls.model.kb_id == kb_id),
45
  (fn.LOWER(cls.model.name).contains(keywords.lower()))
46
+ )
47
  else:
48
  docs = cls.model.select().where(cls.model.kb_id == kb_id)
49
  count = docs.count()
 
75
  def delete(cls, doc):
76
  e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
77
  if not KnowledgebaseService.update_by_id(
78
+ kb.id, {"doc_num": max(0, kb.doc_num - 1)}):
79
  raise RuntimeError("Database error (Knowledgebase)!")
80
  return cls.delete_by_id(doc.id)
81
 
api/db/services/llm_service.py CHANGED
@@ -172,8 +172,18 @@ class LLMBundle(object):
172
 
173
  def chat(self, system, history, gen_conf):
174
  txt, used_tokens = self.mdl.chat(system, history, gen_conf)
175
- if TenantLLMService.increase_usage(
176
  self.tenant_id, self.llm_type, used_tokens, self.llm_name):
177
  database_logger.error(
178
  "Can't update token usage for {}/CHAT".format(self.tenant_id))
179
  return txt
 
 
 
 
 
 
 
 
 
 
 
172
 
173
  def chat(self, system, history, gen_conf):
174
  txt, used_tokens = self.mdl.chat(system, history, gen_conf)
175
+ if not TenantLLMService.increase_usage(
176
  self.tenant_id, self.llm_type, used_tokens, self.llm_name):
177
  database_logger.error(
178
  "Can't update token usage for {}/CHAT".format(self.tenant_id))
179
  return txt
180
+
181
+ def chat_streamly(self, system, history, gen_conf):
182
+ for txt in self.mdl.chat_streamly(system, history, gen_conf):
183
+ if isinstance(txt, int):
184
+ if not TenantLLMService.increase_usage(
185
+ self.tenant_id, self.llm_type, txt, self.llm_name):
186
+ database_logger.error(
187
+ "Can't update token usage for {}/CHAT".format(self.tenant_id))
188
+ return
189
+ yield txt
api/utils/api_utils.py CHANGED
@@ -25,7 +25,6 @@ from flask import (
25
  from werkzeug.http import HTTP_STATUS_CODES
26
 
27
  from api.utils import json_dumps
28
- from api.versions import get_rag_version
29
  from api.settings import RetCode
30
  from api.settings import (
31
  REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
@@ -84,9 +83,6 @@ def request(**kwargs):
84
  return sess.send(prepped, stream=stream, timeout=timeout)
85
 
86
 
87
- rag_version = get_rag_version() or ''
88
-
89
-
90
  def get_exponential_backoff_interval(retries, full_jitter=False):
91
  """Calculate the exponential backoff wait time."""
92
  # Will be zero if factor equals 0
 
25
  from werkzeug.http import HTTP_STATUS_CODES
26
 
27
  from api.utils import json_dumps
 
28
  from api.settings import RetCode
29
  from api.settings import (
30
  REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
 
83
  return sess.send(prepped, stream=stream, timeout=timeout)
84
 
85
 
 
 
 
86
  def get_exponential_backoff_interval(retries, full_jitter=False):
87
  """Calculate the exponential backoff wait time."""
88
  # Will be zero if factor equals 0
rag/llm/chat_model.py CHANGED
@@ -20,7 +20,6 @@ from openai import OpenAI
20
  import openai
21
  from ollama import Client
22
  from rag.nlp import is_english
23
- from rag.utils import num_tokens_from_string
24
 
25
 
26
  class Base(ABC):
@@ -44,6 +43,31 @@ class Base(ABC):
44
  except openai.APIError as e:
45
  return "**ERROR**: " + str(e), 0
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  class GptTurbo(Base):
49
  def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
@@ -97,6 +121,35 @@ class QWenChat(Base):
97
 
98
  return "**ERROR**: " + response.message, tk_count
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  class ZhipuChat(Base):
102
  def __init__(self, key, model_name="glm-3-turbo", **kwargs):
@@ -122,6 +175,34 @@ class ZhipuChat(Base):
122
  except Exception as e:
123
  return "**ERROR**: " + str(e), 0
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  class OllamaChat(Base):
127
  def __init__(self, key, model_name, **kwargs):
@@ -148,3 +229,28 @@ class OllamaChat(Base):
148
  except Exception as e:
149
  return "**ERROR**: " + str(e), 0
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  import openai
21
  from ollama import Client
22
  from rag.nlp import is_english
 
23
 
24
 
25
  class Base(ABC):
 
43
  except openai.APIError as e:
44
  return "**ERROR**: " + str(e), 0
45
 
46
+ def chat_streamly(self, system, history, gen_conf):
47
+ if system:
48
+ history.insert(0, {"role": "system", "content": system})
49
+ ans = ""
50
+ total_tokens = 0
51
+ try:
52
+ response = self.client.chat.completions.create(
53
+ model=self.model_name,
54
+ messages=history,
55
+ stream=True,
56
+ **gen_conf)
57
+ for resp in response:
58
+ if not resp.choices[0].delta.content:continue
59
+ ans += resp.choices[0].delta.content
60
+ total_tokens += 1
61
+ if resp.choices[0].finish_reason == "length":
62
+ ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
63
+ [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
64
+ yield ans
65
+
66
+ except openai.APIError as e:
67
+ yield ans + "\n**ERROR**: " + str(e)
68
+
69
+ yield total_tokens
70
+
71
 
72
  class GptTurbo(Base):
73
  def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
 
121
 
122
  return "**ERROR**: " + response.message, tk_count
123
 
124
+ def chat_streamly(self, system, history, gen_conf):
125
+ from http import HTTPStatus
126
+ if system:
127
+ history.insert(0, {"role": "system", "content": system})
128
+ ans = ""
129
+ try:
130
+ response = Generation.call(
131
+ self.model_name,
132
+ messages=history,
133
+ result_format='message',
134
+ stream=True,
135
+ **gen_conf
136
+ )
137
+ tk_count = 0
138
+ for resp in response:
139
+ if resp.status_code == HTTPStatus.OK:
140
+ ans = resp.output.choices[0]['message']['content']
141
+ tk_count = resp.usage.total_tokens
142
+ if resp.output.choices[0].get("finish_reason", "") == "length":
143
+ ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
144
+ [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
145
+ yield ans
146
+ else:
147
+ yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access")<0 else "Out of credit. Please set the API key in **settings > Model providers.**"
148
+ except Exception as e:
149
+ yield ans + "\n**ERROR**: " + str(e)
150
+
151
+ yield tk_count
152
+
153
 
154
  class ZhipuChat(Base):
155
  def __init__(self, key, model_name="glm-3-turbo", **kwargs):
 
175
  except Exception as e:
176
  return "**ERROR**: " + str(e), 0
177
 
178
+ def chat_streamly(self, system, history, gen_conf):
179
+ if system:
180
+ history.insert(0, {"role": "system", "content": system})
181
+ if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
182
+ if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
183
+ ans = ""
184
+ try:
185
+ response = self.client.chat.completions.create(
186
+ model=self.model_name,
187
+ messages=history,
188
+ stream=True,
189
+ **gen_conf
190
+ )
191
+ tk_count = 0
192
+ for resp in response:
193
+ if not resp.choices[0].delta.content:continue
194
+ delta = resp.choices[0].delta.content
195
+ ans += delta
196
+ tk_count = resp.usage.total_tokens if response.usage else 0
197
+ if resp.output.choices[0].finish_reason == "length":
198
+ ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
199
+ [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
200
+ yield ans
201
+ except Exception as e:
202
+ yield ans + "\n**ERROR**: " + str(e)
203
+
204
+ yield tk_count
205
+
206
 
207
  class OllamaChat(Base):
208
  def __init__(self, key, model_name, **kwargs):
 
229
  except Exception as e:
230
  return "**ERROR**: " + str(e), 0
231
 
232
+ def chat_streamly(self, system, history, gen_conf):
233
+ if system:
234
+ history.insert(0, {"role": "system", "content": system})
235
+ options = {}
236
+ if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
237
+ if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
238
+ if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
239
+ if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
240
+ if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
241
+ ans = ""
242
+ try:
243
+ response = self.client.chat(
244
+ model=self.model_name,
245
+ messages=history,
246
+ stream=True,
247
+ options=options
248
+ )
249
+ for resp in response:
250
+ if resp["done"]:
251
+ return resp["prompt_eval_count"] + resp["eval_count"]
252
+ ans = resp["message"]["content"]
253
+ yield ans
254
+ except Exception as e:
255
+ yield ans + "\n**ERROR**: " + str(e)
256
+ yield 0
rag/svr/task_executor.py CHANGED
@@ -80,7 +80,7 @@ def set_progress(task_id, from_page=0, to_page=-1,
80
 
81
  if to_page > 0:
82
  if msg:
83
- msg = f"Page({from_page+1}~{to_page+1}): " + msg
84
  d = {"progress_msg": msg}
85
  if prog is not None:
86
  d["progress"] = prog
@@ -124,7 +124,7 @@ def get_minio_binary(bucket, name):
124
  def build(row):
125
  if row["size"] > DOC_MAXIMUM_SIZE:
126
  set_progress(row["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
127
- (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
128
  return []
129
 
130
  callback = partial(
@@ -138,12 +138,12 @@ def build(row):
138
  bucket, name = File2DocumentService.get_minio_address(doc_id=row["doc_id"])
139
  binary = get_minio_binary(bucket, name)
140
  cron_logger.info(
141
- "From minio({}) {}/{}".format(timer()-st, row["location"], row["name"]))
142
  cks = chunker.chunk(row["name"], binary=binary, from_page=row["from_page"],
143
  to_page=row["to_page"], lang=row["language"], callback=callback,
144
  kb_id=row["kb_id"], parser_config=row["parser_config"], tenant_id=row["tenant_id"])
145
  cron_logger.info(
146
- "Chunkking({}) {}/{}".format(timer()-st, row["location"], row["name"]))
147
  except TimeoutError as e:
148
  callback(-1, f"Internal server error: Fetch file timeout. Could you try it again.")
149
  cron_logger.error(
@@ -173,7 +173,7 @@ def build(row):
173
  d.update(ck)
174
  md5 = hashlib.md5()
175
  md5.update((ck["content_with_weight"] +
176
- str(d["doc_id"])).encode("utf-8"))
177
  d["_id"] = md5.hexdigest()
178
  d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
179
  d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
@@ -261,7 +261,7 @@ def main():
261
 
262
  st = timer()
263
  cks = build(r)
264
- cron_logger.info("Build chunks({}): {:.2f}".format(r["name"], timer()-st))
265
  if cks is None:
266
  continue
267
  if not cks:
@@ -271,7 +271,7 @@ def main():
271
  ## set_progress(r["did"], -1, "ERROR: ")
272
  callback(
273
  msg="Finished slicing files(%d). Start to embedding the content." %
274
- len(cks))
275
  st = timer()
276
  try:
277
  tk_count = embedding(cks, embd_mdl, r["parser_config"], callback)
@@ -279,19 +279,19 @@ def main():
279
  callback(-1, "Embedding error:{}".format(str(e)))
280
  cron_logger.error(str(e))
281
  tk_count = 0
282
- cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer()-st))
283
 
284
- callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer()-st))
285
  init_kb(r)
286
  chunk_count = len(set([c["_id"] for c in cks]))
287
  st = timer()
288
  es_r = ""
289
  for b in range(0, len(cks), 32):
290
- es_r = ELASTICSEARCH.bulk(cks[b:b+32], search.index_name(r["tenant_id"]))
291
  if b % 128 == 0:
292
  callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
293
 
294
- cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer()-st))
295
  if es_r:
296
  callback(-1, "Index failure!")
297
  ELASTICSEARCH.deleteByQuery(
@@ -307,8 +307,7 @@ def main():
307
  r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
308
  cron_logger.info(
309
  "Chunk doc({}), token({}), chunks({}), elapsed:{:.2f}".format(
310
- r["id"], tk_count, len(cks), timer()-st))
311
-
312
 
313
 
314
  if __name__ == "__main__":
 
80
 
81
  if to_page > 0:
82
  if msg:
83
+ msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
84
  d = {"progress_msg": msg}
85
  if prog is not None:
86
  d["progress"] = prog
 
124
  def build(row):
125
  if row["size"] > DOC_MAXIMUM_SIZE:
126
  set_progress(row["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
127
+ (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
128
  return []
129
 
130
  callback = partial(
 
138
  bucket, name = File2DocumentService.get_minio_address(doc_id=row["doc_id"])
139
  binary = get_minio_binary(bucket, name)
140
  cron_logger.info(
141
+ "From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
142
  cks = chunker.chunk(row["name"], binary=binary, from_page=row["from_page"],
143
  to_page=row["to_page"], lang=row["language"], callback=callback,
144
  kb_id=row["kb_id"], parser_config=row["parser_config"], tenant_id=row["tenant_id"])
145
  cron_logger.info(
146
+ "Chunkking({}) {}/{}".format(timer() - st, row["location"], row["name"]))
147
  except TimeoutError as e:
148
  callback(-1, f"Internal server error: Fetch file timeout. Could you try it again.")
149
  cron_logger.error(
 
173
  d.update(ck)
174
  md5 = hashlib.md5()
175
  md5.update((ck["content_with_weight"] +
176
+ str(d["doc_id"])).encode("utf-8"))
177
  d["_id"] = md5.hexdigest()
178
  d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
179
  d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
 
261
 
262
  st = timer()
263
  cks = build(r)
264
+ cron_logger.info("Build chunks({}): {}".format(r["name"], timer() - st))
265
  if cks is None:
266
  continue
267
  if not cks:
 
271
  ## set_progress(r["did"], -1, "ERROR: ")
272
  callback(
273
  msg="Finished slicing files(%d). Start to embedding the content." %
274
+ len(cks))
275
  st = timer()
276
  try:
277
  tk_count = embedding(cks, embd_mdl, r["parser_config"], callback)
 
279
  callback(-1, "Embedding error:{}".format(str(e)))
280
  cron_logger.error(str(e))
281
  tk_count = 0
282
+ cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
283
 
284
+ callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))
285
  init_kb(r)
286
  chunk_count = len(set([c["_id"] for c in cks]))
287
  st = timer()
288
  es_r = ""
289
  for b in range(0, len(cks), 32):
290
+ es_r = ELASTICSEARCH.bulk(cks[b:b + 32], search.index_name(r["tenant_id"]))
291
  if b % 128 == 0:
292
  callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
293
 
294
+ cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
295
  if es_r:
296
  callback(-1, "Index failure!")
297
  ELASTICSEARCH.deleteByQuery(
 
307
  r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
308
  cron_logger.info(
309
  "Chunk doc({}), token({}), chunks({}), elapsed:{:.2f}".format(
310
+ r["id"], tk_count, len(cks), timer() - st))
 
311
 
312
 
313
  if __name__ == "__main__":
rag/utils/es_conn.py CHANGED
@@ -43,6 +43,9 @@ class ESConnection:
43
  v = v["number"].split(".")[0]
44
  return int(v) >= 7
45
 
 
 
 
46
  def upsert(self, df, idxnm=""):
47
  res = []
48
  for d in df:
 
43
  v = v["number"].split(".")[0]
44
  return int(v) >= 7
45
 
46
+ def health(self):
47
+ return dict(self.es.cluster.health())
48
+
49
  def upsert(self, df, idxnm=""):
50
  res = []
51
  for d in df:
rag/utils/minio_conn.py CHANGED
@@ -34,6 +34,16 @@ class RAGFlowMinio(object):
34
  del self.conn
35
  self.conn = None
36
 
 
 
 
 
 
 
 
 
 
 
37
  def put(self, bucket, fnm, binary):
38
  for _ in range(3):
39
  try:
 
34
  del self.conn
35
  self.conn = None
36
 
37
+ def health(self):
38
+ bucket, fnm, binary = "_t@@@1", "_t@@@1", b"_t@@@1"
39
+ if not self.conn.bucket_exists(bucket):
40
+ self.conn.make_bucket(bucket)
41
+ r = self.conn.put_object(bucket, fnm,
42
+ BytesIO(binary),
43
+ len(binary)
44
+ )
45
+ return r
46
+
47
  def put(self, bucket, fnm, binary):
48
  for _ in range(3):
49
  try:
rag/utils/redis_conn.py CHANGED
@@ -44,6 +44,10 @@ class RedisDB:
44
  logging.warning("Redis can't be connected.")
45
  return self.REDIS
46
 
 
 
 
 
47
  def is_alive(self):
48
  return self.REDIS is not None
49
 
 
44
  logging.warning("Redis can't be connected.")
45
  return self.REDIS
46
 
47
+ def health(self, queue_name):
48
+ self.REDIS.ping()
49
+ return self.REDIS.xinfo_groups(queue_name)[0]
50
+
51
  def is_alive(self):
52
  return self.REDIS is not None
53