jinhai-2012 committed on
Commit
6101699
·
1 Parent(s): fc803e8

Move settings initialization after module init phase (#3438)

Browse files

### What problem does this PR solve?

1. Module init won't connect to the database anymore.
2. Config values in settings now need to be accessed as settings.CONFIG_NAME

### Type of change

- [x] Refactoring

Signed-off-by: jinhai <[email protected]>

agent/component/generate.py CHANGED
@@ -19,7 +19,7 @@ import pandas as pd
19
  from api.db import LLMType
20
  from api.db.services.dialog_service import message_fit_in
21
  from api.db.services.llm_service import LLMBundle
22
- from api.settings import retrievaler
23
  from agent.component.base import ComponentBase, ComponentParamBase
24
 
25
 
@@ -63,18 +63,20 @@ class Generate(ComponentBase):
63
  component_name = "Generate"
64
 
65
  def get_dependent_components(self):
66
- cpnts = [para["component_id"] for para in self._param.parameters if para.get("component_id") and para["component_id"].lower().find("answer") < 0]
 
67
  return cpnts
68
 
69
  def set_cite(self, retrieval_res, answer):
70
  retrieval_res = retrieval_res.dropna(subset=["vector", "content_ltks"]).reset_index(drop=True)
71
  if "empty_response" in retrieval_res.columns:
72
  retrieval_res["empty_response"].fillna("", inplace=True)
73
- answer, idx = retrievaler.insert_citations(answer, [ck["content_ltks"] for _, ck in retrieval_res.iterrows()],
74
- [ck["vector"] for _, ck in retrieval_res.iterrows()],
75
- LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
76
- self._canvas.get_embedding_model()), tkweight=0.7,
77
- vtweight=0.3)
 
78
  doc_ids = set([])
79
  recall_docs = []
80
  for i in idx:
@@ -127,12 +129,14 @@ class Generate(ComponentBase):
127
  else:
128
  if cpn.component_name.lower() == "retrieval":
129
  retrieval_res.append(out)
130
- kwargs[para["key"]] = " - "+"\n - ".join([o if isinstance(o, str) else str(o) for o in out["content"]])
 
131
  self._param.inputs.append({"component_id": para["component_id"], "content": kwargs[para["key"]]})
132
 
133
  if retrieval_res:
134
  retrieval_res = pd.concat(retrieval_res, ignore_index=True)
135
- else: retrieval_res = pd.DataFrame([])
 
136
 
137
  for n, v in kwargs.items():
138
  prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)
 
19
  from api.db import LLMType
20
  from api.db.services.dialog_service import message_fit_in
21
  from api.db.services.llm_service import LLMBundle
22
+ from api import settings
23
  from agent.component.base import ComponentBase, ComponentParamBase
24
 
25
 
 
63
  component_name = "Generate"
64
 
65
  def get_dependent_components(self):
66
+ cpnts = [para["component_id"] for para in self._param.parameters if
67
+ para.get("component_id") and para["component_id"].lower().find("answer") < 0]
68
  return cpnts
69
 
70
  def set_cite(self, retrieval_res, answer):
71
  retrieval_res = retrieval_res.dropna(subset=["vector", "content_ltks"]).reset_index(drop=True)
72
  if "empty_response" in retrieval_res.columns:
73
  retrieval_res["empty_response"].fillna("", inplace=True)
74
+ answer, idx = settings.retrievaler.insert_citations(answer,
75
+ [ck["content_ltks"] for _, ck in retrieval_res.iterrows()],
76
+ [ck["vector"] for _, ck in retrieval_res.iterrows()],
77
+ LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
78
+ self._canvas.get_embedding_model()), tkweight=0.7,
79
+ vtweight=0.3)
80
  doc_ids = set([])
81
  recall_docs = []
82
  for i in idx:
 
129
  else:
130
  if cpn.component_name.lower() == "retrieval":
131
  retrieval_res.append(out)
132
+ kwargs[para["key"]] = " - " + "\n - ".join(
133
+ [o if isinstance(o, str) else str(o) for o in out["content"]])
134
  self._param.inputs.append({"component_id": para["component_id"], "content": kwargs[para["key"]]})
135
 
136
  if retrieval_res:
137
  retrieval_res = pd.concat(retrieval_res, ignore_index=True)
138
+ else:
139
+ retrieval_res = pd.DataFrame([])
140
 
141
  for n, v in kwargs.items():
142
  prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)
agent/component/retrieval.py CHANGED
@@ -21,7 +21,7 @@ import pandas as pd
21
  from api.db import LLMType
22
  from api.db.services.knowledgebase_service import KnowledgebaseService
23
  from api.db.services.llm_service import LLMBundle
24
- from api.settings import retrievaler
25
  from agent.component.base import ComponentBase, ComponentParamBase
26
 
27
 
@@ -67,7 +67,7 @@ class Retrieval(ComponentBase, ABC):
67
  if self._param.rerank_id:
68
  rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
69
 
70
- kbinfos = retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
71
  1, self._param.top_n,
72
  self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
73
  aggs=False, rerank_mdl=rerank_mdl)
 
21
  from api.db import LLMType
22
  from api.db.services.knowledgebase_service import KnowledgebaseService
23
  from api.db.services.llm_service import LLMBundle
24
+ from api import settings
25
  from agent.component.base import ComponentBase, ComponentParamBase
26
 
27
 
 
67
  if self._param.rerank_id:
68
  rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
69
 
70
+ kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
71
  1, self._param.top_n,
72
  self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
73
  aggs=False, rerank_mdl=rerank_mdl)
api/apps/__init__.py CHANGED
@@ -30,8 +30,7 @@ from api.utils import CustomJSONEncoder, commands
30
 
31
  from flask_session import Session
32
  from flask_login import LoginManager
33
- from api.settings import SECRET_KEY
34
- from api.settings import API_VERSION
35
  from api.utils.api_utils import server_error_response
36
  from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
37
 
@@ -78,7 +77,6 @@ app.url_map.strict_slashes = False
78
  app.json_encoder = CustomJSONEncoder
79
  app.errorhandler(Exception)(server_error_response)
80
 
81
-
82
  ## convince for dev and debug
83
  # app.config["LOGIN_DISABLED"] = True
84
  app.config["SESSION_PERMANENT"] = False
@@ -110,7 +108,7 @@ def register_page(page_path):
110
 
111
  page_name = page_path.stem.rstrip("_app")
112
  module_name = ".".join(
113
- page_path.parts[page_path.parts.index("api") : -1] + (page_name,)
114
  )
115
 
116
  spec = spec_from_file_location(module_name, page_path)
@@ -121,7 +119,7 @@ def register_page(page_path):
121
  spec.loader.exec_module(page)
122
  page_name = getattr(page, "page_name", page_name)
123
  url_prefix = (
124
- f"/api/{API_VERSION}" if "/sdk/" in path else f"/{API_VERSION}/{page_name}"
125
  )
126
 
127
  app.register_blueprint(page.manager, url_prefix=url_prefix)
@@ -141,7 +139,7 @@ client_urls_prefix = [
141
 
142
  @login_manager.request_loader
143
  def load_user(web_request):
144
- jwt = Serializer(secret_key=SECRET_KEY)
145
  authorization = web_request.headers.get("Authorization")
146
  if authorization:
147
  try:
 
30
 
31
  from flask_session import Session
32
  from flask_login import LoginManager
33
+ from api import settings
 
34
  from api.utils.api_utils import server_error_response
35
  from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
36
 
 
77
  app.json_encoder = CustomJSONEncoder
78
  app.errorhandler(Exception)(server_error_response)
79
 
 
80
  ## convince for dev and debug
81
  # app.config["LOGIN_DISABLED"] = True
82
  app.config["SESSION_PERMANENT"] = False
 
108
 
109
  page_name = page_path.stem.rstrip("_app")
110
  module_name = ".".join(
111
+ page_path.parts[page_path.parts.index("api"): -1] + (page_name,)
112
  )
113
 
114
  spec = spec_from_file_location(module_name, page_path)
 
119
  spec.loader.exec_module(page)
120
  page_name = getattr(page, "page_name", page_name)
121
  url_prefix = (
122
+ f"/api/{settings.API_VERSION}" if "/sdk/" in path else f"/{settings.API_VERSION}/{page_name}"
123
  )
124
 
125
  app.register_blueprint(page.manager, url_prefix=url_prefix)
 
139
 
140
  @login_manager.request_loader
141
  def load_user(web_request):
142
+ jwt = Serializer(secret_key=settings.SECRET_KEY)
143
  authorization = web_request.headers.get("Authorization")
144
  if authorization:
145
  try:
api/apps/api_app.py CHANGED
@@ -32,7 +32,7 @@ from api.db.services.file_service import FileService
32
  from api.db.services.knowledgebase_service import KnowledgebaseService
33
  from api.db.services.task_service import queue_tasks, TaskService
34
  from api.db.services.user_service import UserTenantService
35
- from api.settings import RetCode, retrievaler
36
  from api.utils import get_uuid, current_timestamp, datetime_format
37
  from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
38
  generate_confirmation_token
@@ -141,7 +141,7 @@ def set_conversation():
141
  objs = APIToken.query(token=token)
142
  if not objs:
143
  return get_json_result(
144
- data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
145
  req = request.json
146
  try:
147
  if objs[0].source == "agent":
@@ -183,7 +183,7 @@ def completion():
183
  objs = APIToken.query(token=token)
184
  if not objs:
185
  return get_json_result(
186
- data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
187
  req = request.json
188
  e, conv = API4ConversationService.get_by_id(req["conversation_id"])
189
  if not e:
@@ -290,8 +290,8 @@ def completion():
290
  API4ConversationService.append_message(conv.id, conv.to_dict())
291
  rename_field(result)
292
  return get_json_result(data=result)
293
-
294
- #******************For dialog******************
295
  conv.message.append(msg[-1])
296
  e, dia = DialogService.get_by_id(conv.dialog_id)
297
  if not e:
@@ -326,7 +326,7 @@ def completion():
326
  resp.headers.add_header("X-Accel-Buffering", "no")
327
  resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
328
  return resp
329
-
330
  answer = None
331
  for ans in chat(dia, msg, **req):
332
  answer = ans
@@ -347,8 +347,8 @@ def get(conversation_id):
347
  objs = APIToken.query(token=token)
348
  if not objs:
349
  return get_json_result(
350
- data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
351
-
352
  try:
353
  e, conv = API4ConversationService.get_by_id(conversation_id)
354
  if not e:
@@ -357,8 +357,8 @@ def get(conversation_id):
357
  conv = conv.to_dict()
358
  if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token:
359
  return get_json_result(data=False, message='Token is not valid for this conversation_id!"',
360
- code=RetCode.AUTHENTICATION_ERROR)
361
-
362
  for referenct_i in conv['reference']:
363
  if referenct_i is None or len(referenct_i) == 0:
364
  continue
@@ -378,7 +378,7 @@ def upload():
378
  objs = APIToken.query(token=token)
379
  if not objs:
380
  return get_json_result(
381
- data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
382
 
383
  kb_name = request.form.get("kb_name").strip()
384
  tenant_id = objs[0].tenant_id
@@ -394,12 +394,12 @@ def upload():
394
 
395
  if 'file' not in request.files:
396
  return get_json_result(
397
- data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
398
 
399
  file = request.files['file']
400
  if file.filename == '':
401
  return get_json_result(
402
- data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
403
 
404
  root_folder = FileService.get_root_folder(tenant_id)
405
  pf_id = root_folder["id"]
@@ -490,17 +490,17 @@ def upload_parse():
490
  objs = APIToken.query(token=token)
491
  if not objs:
492
  return get_json_result(
493
- data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
494
 
495
  if 'file' not in request.files:
496
  return get_json_result(
497
- data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
498
 
499
  file_objs = request.files.getlist('file')
500
  for file_obj in file_objs:
501
  if file_obj.filename == '':
502
  return get_json_result(
503
- data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
504
 
505
  doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
506
  return get_json_result(data=doc_ids)
@@ -513,7 +513,7 @@ def list_chunks():
513
  objs = APIToken.query(token=token)
514
  if not objs:
515
  return get_json_result(
516
- data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
517
 
518
  req = request.json
519
 
@@ -531,7 +531,7 @@ def list_chunks():
531
  )
532
  kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
533
 
534
- res = retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
535
  res = [
536
  {
537
  "content": res_item["content_with_weight"],
@@ -553,7 +553,7 @@ def list_kb_docs():
553
  objs = APIToken.query(token=token)
554
  if not objs:
555
  return get_json_result(
556
- data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
557
 
558
  req = request.json
559
  tenant_id = objs[0].tenant_id
@@ -585,6 +585,7 @@ def list_kb_docs():
585
  except Exception as e:
586
  return server_error_response(e)
587
 
 
588
  @manager.route('/document/infos', methods=['POST'])
589
  @validate_request("doc_ids")
590
  def docinfos():
@@ -592,7 +593,7 @@ def docinfos():
592
  objs = APIToken.query(token=token)
593
  if not objs:
594
  return get_json_result(
595
- data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
596
  req = request.json
597
  doc_ids = req["doc_ids"]
598
  docs = DocumentService.get_by_ids(doc_ids)
@@ -606,7 +607,7 @@ def document_rm():
606
  objs = APIToken.query(token=token)
607
  if not objs:
608
  return get_json_result(
609
- data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
610
 
611
  tenant_id = objs[0].tenant_id
612
  req = request.json
@@ -653,7 +654,7 @@ def document_rm():
653
  errors += str(e)
654
 
655
  if errors:
656
- return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
657
 
658
  return get_json_result(data=True)
659
 
@@ -668,7 +669,7 @@ def completion_faq():
668
  objs = APIToken.query(token=token)
669
  if not objs:
670
  return get_json_result(
671
- data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
672
 
673
  e, conv = API4ConversationService.get_by_id(req["conversation_id"])
674
  if not e:
@@ -805,10 +806,10 @@ def retrieval():
805
  objs = APIToken.query(token=token)
806
  if not objs:
807
  return get_json_result(
808
- data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
809
 
810
  req = request.json
811
- kb_ids = req.get("kb_id",[])
812
  doc_ids = req.get("doc_ids", [])
813
  question = req.get("question")
814
  page = int(req.get("page", 1))
@@ -822,20 +823,21 @@ def retrieval():
822
  embd_nms = list(set([kb.embd_id for kb in kbs]))
823
  if len(embd_nms) != 1:
824
  return get_json_result(
825
- data=False, message='Knowledge bases use different embedding models or does not exist."', code=RetCode.AUTHENTICATION_ERROR)
 
826
 
827
  embd_mdl = TenantLLMService.model_instance(
828
  kbs[0].tenant_id, LLMType.EMBEDDING.value, llm_name=kbs[0].embd_id)
829
  rerank_mdl = None
830
  if req.get("rerank_id"):
831
  rerank_mdl = TenantLLMService.model_instance(
832
- kbs[0].tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
833
  if req.get("keyword", False):
834
  chat_mdl = TenantLLMService.model_instance(kbs[0].tenant_id, LLMType.CHAT)
835
  question += keyword_extraction(chat_mdl, question)
836
- ranks = retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
837
- similarity_threshold, vector_similarity_weight, top,
838
- doc_ids, rerank_mdl=rerank_mdl)
839
  for c in ranks["chunks"]:
840
  if "vector" in c:
841
  del c["vector"]
@@ -843,5 +845,5 @@ def retrieval():
843
  except Exception as e:
844
  if str(e).find("not_found") > 0:
845
  return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
846
- code=RetCode.DATA_ERROR)
847
  return server_error_response(e)
 
32
  from api.db.services.knowledgebase_service import KnowledgebaseService
33
  from api.db.services.task_service import queue_tasks, TaskService
34
  from api.db.services.user_service import UserTenantService
35
+ from api import settings
36
  from api.utils import get_uuid, current_timestamp, datetime_format
37
  from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
38
  generate_confirmation_token
 
141
  objs = APIToken.query(token=token)
142
  if not objs:
143
  return get_json_result(
144
+ data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
145
  req = request.json
146
  try:
147
  if objs[0].source == "agent":
 
183
  objs = APIToken.query(token=token)
184
  if not objs:
185
  return get_json_result(
186
+ data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
187
  req = request.json
188
  e, conv = API4ConversationService.get_by_id(req["conversation_id"])
189
  if not e:
 
290
  API4ConversationService.append_message(conv.id, conv.to_dict())
291
  rename_field(result)
292
  return get_json_result(data=result)
293
+
294
+ # ******************For dialog******************
295
  conv.message.append(msg[-1])
296
  e, dia = DialogService.get_by_id(conv.dialog_id)
297
  if not e:
 
326
  resp.headers.add_header("X-Accel-Buffering", "no")
327
  resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
328
  return resp
329
+
330
  answer = None
331
  for ans in chat(dia, msg, **req):
332
  answer = ans
 
347
  objs = APIToken.query(token=token)
348
  if not objs:
349
  return get_json_result(
350
+ data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
351
+
352
  try:
353
  e, conv = API4ConversationService.get_by_id(conversation_id)
354
  if not e:
 
357
  conv = conv.to_dict()
358
  if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token:
359
  return get_json_result(data=False, message='Token is not valid for this conversation_id!"',
360
+ code=settings.RetCode.AUTHENTICATION_ERROR)
361
+
362
  for referenct_i in conv['reference']:
363
  if referenct_i is None or len(referenct_i) == 0:
364
  continue
 
378
  objs = APIToken.query(token=token)
379
  if not objs:
380
  return get_json_result(
381
+ data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
382
 
383
  kb_name = request.form.get("kb_name").strip()
384
  tenant_id = objs[0].tenant_id
 
394
 
395
  if 'file' not in request.files:
396
  return get_json_result(
397
+ data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
398
 
399
  file = request.files['file']
400
  if file.filename == '':
401
  return get_json_result(
402
+ data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
403
 
404
  root_folder = FileService.get_root_folder(tenant_id)
405
  pf_id = root_folder["id"]
 
490
  objs = APIToken.query(token=token)
491
  if not objs:
492
  return get_json_result(
493
+ data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
494
 
495
  if 'file' not in request.files:
496
  return get_json_result(
497
+ data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
498
 
499
  file_objs = request.files.getlist('file')
500
  for file_obj in file_objs:
501
  if file_obj.filename == '':
502
  return get_json_result(
503
+ data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
504
 
505
  doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
506
  return get_json_result(data=doc_ids)
 
513
  objs = APIToken.query(token=token)
514
  if not objs:
515
  return get_json_result(
516
+ data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
517
 
518
  req = request.json
519
 
 
531
  )
532
  kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
533
 
534
+ res = settings.retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
535
  res = [
536
  {
537
  "content": res_item["content_with_weight"],
 
553
  objs = APIToken.query(token=token)
554
  if not objs:
555
  return get_json_result(
556
+ data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
557
 
558
  req = request.json
559
  tenant_id = objs[0].tenant_id
 
585
  except Exception as e:
586
  return server_error_response(e)
587
 
588
+
589
  @manager.route('/document/infos', methods=['POST'])
590
  @validate_request("doc_ids")
591
  def docinfos():
 
593
  objs = APIToken.query(token=token)
594
  if not objs:
595
  return get_json_result(
596
+ data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
597
  req = request.json
598
  doc_ids = req["doc_ids"]
599
  docs = DocumentService.get_by_ids(doc_ids)
 
607
  objs = APIToken.query(token=token)
608
  if not objs:
609
  return get_json_result(
610
+ data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
611
 
612
  tenant_id = objs[0].tenant_id
613
  req = request.json
 
654
  errors += str(e)
655
 
656
  if errors:
657
+ return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)
658
 
659
  return get_json_result(data=True)
660
 
 
669
  objs = APIToken.query(token=token)
670
  if not objs:
671
  return get_json_result(
672
+ data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
673
 
674
  e, conv = API4ConversationService.get_by_id(req["conversation_id"])
675
  if not e:
 
806
  objs = APIToken.query(token=token)
807
  if not objs:
808
  return get_json_result(
809
+ data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
810
 
811
  req = request.json
812
+ kb_ids = req.get("kb_id", [])
813
  doc_ids = req.get("doc_ids", [])
814
  question = req.get("question")
815
  page = int(req.get("page", 1))
 
823
  embd_nms = list(set([kb.embd_id for kb in kbs]))
824
  if len(embd_nms) != 1:
825
  return get_json_result(
826
+ data=False, message='Knowledge bases use different embedding models or does not exist."',
827
+ code=settings.RetCode.AUTHENTICATION_ERROR)
828
 
829
  embd_mdl = TenantLLMService.model_instance(
830
  kbs[0].tenant_id, LLMType.EMBEDDING.value, llm_name=kbs[0].embd_id)
831
  rerank_mdl = None
832
  if req.get("rerank_id"):
833
  rerank_mdl = TenantLLMService.model_instance(
834
+ kbs[0].tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
835
  if req.get("keyword", False):
836
  chat_mdl = TenantLLMService.model_instance(kbs[0].tenant_id, LLMType.CHAT)
837
  question += keyword_extraction(chat_mdl, question)
838
+ ranks = settings.retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
839
+ similarity_threshold, vector_similarity_weight, top,
840
+ doc_ids, rerank_mdl=rerank_mdl)
841
  for c in ranks["chunks"]:
842
  if "vector" in c:
843
  del c["vector"]
 
845
  except Exception as e:
846
  if str(e).find("not_found") > 0:
847
  return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
848
+ code=settings.RetCode.DATA_ERROR)
849
  return server_error_response(e)
api/apps/canvas_app.py CHANGED
@@ -19,7 +19,7 @@ from functools import partial
19
  from flask import request, Response
20
  from flask_login import login_required, current_user
21
  from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
22
- from api.settings import RetCode
23
  from api.utils import get_uuid
24
  from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
25
  from agent.canvas import Canvas
@@ -36,7 +36,8 @@ def templates():
36
  @login_required
37
  def canvas_list():
38
  return get_json_result(data=sorted([c.to_dict() for c in \
39
- UserCanvasService.query(user_id=current_user.id)], key=lambda x: x["update_time"]*-1)
 
40
  )
41
 
42
 
@@ -45,10 +46,10 @@ def canvas_list():
45
  @login_required
46
  def rm():
47
  for i in request.json["canvas_ids"]:
48
- if not UserCanvasService.query(user_id=current_user.id,id=i):
49
  return get_json_result(
50
  data=False, message='Only owner of canvas authorized for this operation.',
51
- code=RetCode.OPERATING_ERROR)
52
  UserCanvasService.delete_by_id(i)
53
  return get_json_result(data=True)
54
 
@@ -72,7 +73,7 @@ def save():
72
  if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
73
  return get_json_result(
74
  data=False, message='Only owner of canvas authorized for this operation.',
75
- code=RetCode.OPERATING_ERROR)
76
  UserCanvasService.update_by_id(req["id"], req)
77
  return get_json_result(data=req)
78
 
@@ -98,7 +99,7 @@ def run():
98
  if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
99
  return get_json_result(
100
  data=False, message='Only owner of canvas authorized for this operation.',
101
- code=RetCode.OPERATING_ERROR)
102
 
103
  if not isinstance(cvs.dsl, str):
104
  cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
@@ -110,8 +111,8 @@ def run():
110
  if "message" in req:
111
  canvas.messages.append({"role": "user", "content": req["message"], "id": message_id})
112
  if len([m for m in canvas.messages if m["role"] == "user"]) > 1:
113
- #ten = TenantService.get_info_by(current_user.id)[0]
114
- #req["message"] = full_question(ten["tenant_id"], ten["llm_id"], canvas.messages)
115
  pass
116
  canvas.add_user_input(req["message"])
117
  answer = canvas.run(stream=stream)
@@ -122,7 +123,8 @@ def run():
122
  assert answer is not None, "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
123
 
124
  if stream:
125
- assert isinstance(answer, partial), "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
 
126
 
127
  def sse():
128
  nonlocal answer, cvs
@@ -173,7 +175,7 @@ def reset():
173
  if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
174
  return get_json_result(
175
  data=False, message='Only owner of canvas authorized for this operation.',
176
- code=RetCode.OPERATING_ERROR)
177
 
178
  canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
179
  canvas.reset()
 
19
  from flask import request, Response
20
  from flask_login import login_required, current_user
21
  from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
22
+ from api import settings
23
  from api.utils import get_uuid
24
  from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
25
  from agent.canvas import Canvas
 
36
  @login_required
37
  def canvas_list():
38
  return get_json_result(data=sorted([c.to_dict() for c in \
39
+ UserCanvasService.query(user_id=current_user.id)],
40
+ key=lambda x: x["update_time"] * -1)
41
  )
42
 
43
 
 
46
  @login_required
47
  def rm():
48
  for i in request.json["canvas_ids"]:
49
+ if not UserCanvasService.query(user_id=current_user.id, id=i):
50
  return get_json_result(
51
  data=False, message='Only owner of canvas authorized for this operation.',
52
+ code=settings.RetCode.OPERATING_ERROR)
53
  UserCanvasService.delete_by_id(i)
54
  return get_json_result(data=True)
55
 
 
73
  if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
74
  return get_json_result(
75
  data=False, message='Only owner of canvas authorized for this operation.',
76
+ code=settings.RetCode.OPERATING_ERROR)
77
  UserCanvasService.update_by_id(req["id"], req)
78
  return get_json_result(data=req)
79
 
 
99
  if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
100
  return get_json_result(
101
  data=False, message='Only owner of canvas authorized for this operation.',
102
+ code=settings.RetCode.OPERATING_ERROR)
103
 
104
  if not isinstance(cvs.dsl, str):
105
  cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
 
111
  if "message" in req:
112
  canvas.messages.append({"role": "user", "content": req["message"], "id": message_id})
113
  if len([m for m in canvas.messages if m["role"] == "user"]) > 1:
114
+ # ten = TenantService.get_info_by(current_user.id)[0]
115
+ # req["message"] = full_question(ten["tenant_id"], ten["llm_id"], canvas.messages)
116
  pass
117
  canvas.add_user_input(req["message"])
118
  answer = canvas.run(stream=stream)
 
123
  assert answer is not None, "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
124
 
125
  if stream:
126
+ assert isinstance(answer,
127
+ partial), "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
128
 
129
  def sse():
130
  nonlocal answer, cvs
 
175
  if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
176
  return get_json_result(
177
  data=False, message='Only owner of canvas authorized for this operation.',
178
+ code=settings.RetCode.OPERATING_ERROR)
179
 
180
  canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
181
  canvas.reset()
api/apps/chunk_app.py CHANGED
@@ -29,11 +29,12 @@ from api.db.services.llm_service import LLMBundle
29
  from api.db.services.user_service import UserTenantService
30
  from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
31
  from api.db.services.document_service import DocumentService
32
- from api.settings import RetCode, retrievaler, kg_retrievaler, docStoreConn
33
  from api.utils.api_utils import get_json_result
34
  import hashlib
35
  import re
36
 
 
37
  @manager.route('/list', methods=['POST'])
38
  @login_required
39
  @validate_request("doc_id")
@@ -56,7 +57,7 @@ def list_chunk():
56
  }
57
  if "available_int" in req:
58
  query["available_int"] = int(req["available_int"])
59
- sres = retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
60
  res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
61
  for id in sres.ids:
62
  d = {
@@ -72,13 +73,13 @@ def list_chunk():
72
  "positions": json.loads(sres.field[id].get("position_list", "[]")),
73
  }
74
  assert isinstance(d["positions"], list)
75
- assert len(d["positions"])==0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5)
76
  res["chunks"].append(d)
77
  return get_json_result(data=res)
78
  except Exception as e:
79
  if str(e).find("not_found") > 0:
80
  return get_json_result(data=False, message='No chunk found!',
81
- code=RetCode.DATA_ERROR)
82
  return server_error_response(e)
83
 
84
 
@@ -93,7 +94,7 @@ def get():
93
  tenant_id = tenants[0].tenant_id
94
 
95
  kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
96
- chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
97
  if chunk is None:
98
  return server_error_response("Chunk not found")
99
  k = []
@@ -107,7 +108,7 @@ def get():
107
  except Exception as e:
108
  if str(e).find("NotFoundError") >= 0:
109
  return get_json_result(data=False, message='Chunk not found!',
110
- code=RetCode.DATA_ERROR)
111
  return server_error_response(e)
112
 
113
 
@@ -154,7 +155,7 @@ def set():
154
  v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
155
  v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
156
  d["q_%d_vec" % len(v)] = v.tolist()
157
- docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
158
  return get_json_result(data=True)
159
  except Exception as e:
160
  return server_error_response(e)
@@ -169,8 +170,8 @@ def switch():
169
  e, doc = DocumentService.get_by_id(req["doc_id"])
170
  if not e:
171
  return get_data_error_result(message="Document not found!")
172
- if not docStoreConn.update({"id": req["chunk_ids"]}, {"available_int": int(req["available_int"])},
173
- search.index_name(doc.tenant_id), doc.kb_id):
174
  return get_data_error_result(message="Index updating failure")
175
  return get_json_result(data=True)
176
  except Exception as e:
@@ -186,7 +187,7 @@ def rm():
186
  e, doc = DocumentService.get_by_id(req["doc_id"])
187
  if not e:
188
  return get_data_error_result(message="Document not found!")
189
- if not docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id):
190
  return get_data_error_result(message="Index updating failure")
191
  deleted_chunk_ids = req["chunk_ids"]
192
  chunk_number = len(deleted_chunk_ids)
@@ -230,7 +231,7 @@ def create():
230
  v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
231
  v = 0.1 * v[0] + 0.9 * v[1]
232
  d["q_%d_vec" % len(v)] = v.tolist()
233
- docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
234
 
235
  DocumentService.increment_chunk_num(
236
  doc.id, doc.kb_id, c, 1, 0)
@@ -265,7 +266,7 @@ def retrieval_test():
265
  else:
266
  return get_json_result(
267
  data=False, message='Only owner of knowledgebase authorized for this operation.',
268
- code=RetCode.OPERATING_ERROR)
269
 
270
  e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
271
  if not e:
@@ -281,7 +282,7 @@ def retrieval_test():
281
  chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
282
  question += keyword_extraction(chat_mdl, question)
283
 
284
- retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
285
  ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, page, size,
286
  similarity_threshold, vector_similarity_weight, top,
287
  doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
@@ -293,7 +294,7 @@ def retrieval_test():
293
  except Exception as e:
294
  if str(e).find("not_found") > 0:
295
  return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
296
- code=RetCode.DATA_ERROR)
297
  return server_error_response(e)
298
 
299
 
@@ -304,10 +305,10 @@ def knowledge_graph():
304
  tenant_id = DocumentService.get_tenant_id(doc_id)
305
  kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
306
  req = {
307
- "doc_ids":[doc_id],
308
  "knowledge_graph_kwd": ["graph", "mind_map"]
309
  }
310
- sres = retrievaler.search(req, search.index_name(tenant_id), kb_ids)
311
  obj = {"graph": {}, "mind_map": {}}
312
  for id in sres.ids[:2]:
313
  ty = sres.field[id]["knowledge_graph_kwd"]
@@ -336,4 +337,3 @@ def knowledge_graph():
336
  obj[ty] = content_json
337
 
338
  return get_json_result(data=obj)
339
-
 
29
  from api.db.services.user_service import UserTenantService
30
  from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
31
  from api.db.services.document_service import DocumentService
32
+ from api import settings
33
  from api.utils.api_utils import get_json_result
34
  import hashlib
35
  import re
36
 
37
+
38
  @manager.route('/list', methods=['POST'])
39
  @login_required
40
  @validate_request("doc_id")
 
57
  }
58
  if "available_int" in req:
59
  query["available_int"] = int(req["available_int"])
60
+ sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
61
  res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
62
  for id in sres.ids:
63
  d = {
 
73
  "positions": json.loads(sres.field[id].get("position_list", "[]")),
74
  }
75
  assert isinstance(d["positions"], list)
76
+ assert len(d["positions"]) == 0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5)
77
  res["chunks"].append(d)
78
  return get_json_result(data=res)
79
  except Exception as e:
80
  if str(e).find("not_found") > 0:
81
  return get_json_result(data=False, message='No chunk found!',
82
+ code=settings.RetCode.DATA_ERROR)
83
  return server_error_response(e)
84
 
85
 
 
94
  tenant_id = tenants[0].tenant_id
95
 
96
  kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
97
+ chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
98
  if chunk is None:
99
  return server_error_response("Chunk not found")
100
  k = []
 
108
  except Exception as e:
109
  if str(e).find("NotFoundError") >= 0:
110
  return get_json_result(data=False, message='Chunk not found!',
111
+ code=settings.RetCode.DATA_ERROR)
112
  return server_error_response(e)
113
 
114
 
 
155
  v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
156
  v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
157
  d["q_%d_vec" % len(v)] = v.tolist()
158
+ settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
159
  return get_json_result(data=True)
160
  except Exception as e:
161
  return server_error_response(e)
 
170
  e, doc = DocumentService.get_by_id(req["doc_id"])
171
  if not e:
172
  return get_data_error_result(message="Document not found!")
173
+ if not settings.docStoreConn.update({"id": req["chunk_ids"]}, {"available_int": int(req["available_int"])},
174
+ search.index_name(doc.tenant_id), doc.kb_id):
175
  return get_data_error_result(message="Index updating failure")
176
  return get_json_result(data=True)
177
  except Exception as e:
 
187
  e, doc = DocumentService.get_by_id(req["doc_id"])
188
  if not e:
189
  return get_data_error_result(message="Document not found!")
190
+ if not settings.docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id):
191
  return get_data_error_result(message="Index updating failure")
192
  deleted_chunk_ids = req["chunk_ids"]
193
  chunk_number = len(deleted_chunk_ids)
 
231
  v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
232
  v = 0.1 * v[0] + 0.9 * v[1]
233
  d["q_%d_vec" % len(v)] = v.tolist()
234
+ settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
235
 
236
  DocumentService.increment_chunk_num(
237
  doc.id, doc.kb_id, c, 1, 0)
 
266
  else:
267
  return get_json_result(
268
  data=False, message='Only owner of knowledgebase authorized for this operation.',
269
+ code=settings.RetCode.OPERATING_ERROR)
270
 
271
  e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
272
  if not e:
 
282
  chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
283
  question += keyword_extraction(chat_mdl, question)
284
 
285
+ retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
286
  ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, page, size,
287
  similarity_threshold, vector_similarity_weight, top,
288
  doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
 
294
  except Exception as e:
295
  if str(e).find("not_found") > 0:
296
  return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
297
+ code=settings.RetCode.DATA_ERROR)
298
  return server_error_response(e)
299
 
300
 
 
305
  tenant_id = DocumentService.get_tenant_id(doc_id)
306
  kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
307
  req = {
308
+ "doc_ids": [doc_id],
309
  "knowledge_graph_kwd": ["graph", "mind_map"]
310
  }
311
+ sres = settings.retrievaler.search(req, search.index_name(tenant_id), kb_ids)
312
  obj = {"graph": {}, "mind_map": {}}
313
  for id in sres.ids[:2]:
314
  ty = sres.field[id]["knowledge_graph_kwd"]
 
337
  obj[ty] = content_json
338
 
339
  return get_json_result(data=obj)
 
api/apps/conversation_app.py CHANGED
@@ -25,7 +25,7 @@ from api.db import LLMType
25
  from api.db.services.dialog_service import DialogService, ConversationService, chat, ask
26
  from api.db.services.knowledgebase_service import KnowledgebaseService
27
  from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
28
- from api.settings import RetCode, retrievaler
29
  from api.utils.api_utils import get_json_result
30
  from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
31
  from graphrag.mind_map_extractor import MindMapExtractor
@@ -87,7 +87,7 @@ def get():
87
  else:
88
  return get_json_result(
89
  data=False, message='Only owner of conversation authorized for this operation.',
90
- code=RetCode.OPERATING_ERROR)
91
  conv = conv.to_dict()
92
  return get_json_result(data=conv)
93
  except Exception as e:
@@ -110,7 +110,7 @@ def rm():
110
  else:
111
  return get_json_result(
112
  data=False, message='Only owner of conversation authorized for this operation.',
113
- code=RetCode.OPERATING_ERROR)
114
  ConversationService.delete_by_id(cid)
115
  return get_json_result(data=True)
116
  except Exception as e:
@@ -125,7 +125,7 @@ def list_convsersation():
125
  if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
126
  return get_json_result(
127
  data=False, message='Only owner of dialog authorized for this operation.',
128
- code=RetCode.OPERATING_ERROR)
129
  convs = ConversationService.query(
130
  dialog_id=dialog_id,
131
  order_by=ConversationService.model.create_time,
@@ -297,6 +297,7 @@ def thumbup():
297
  def ask_about():
298
  req = request.json
299
  uid = current_user.id
 
300
  def stream():
301
  nonlocal req, uid
302
  try:
@@ -329,8 +330,8 @@ def mindmap():
329
  embd_mdl = TenantLLMService.model_instance(
330
  kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
331
  chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
332
- ranks = retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12,
333
- 0.3, 0.3, aggs=False)
334
  mindmap = MindMapExtractor(chat_mdl)
335
  mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
336
  if "error" in mind_map:
 
25
  from api.db.services.dialog_service import DialogService, ConversationService, chat, ask
26
  from api.db.services.knowledgebase_service import KnowledgebaseService
27
  from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
28
+ from api import settings
29
  from api.utils.api_utils import get_json_result
30
  from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
31
  from graphrag.mind_map_extractor import MindMapExtractor
 
87
  else:
88
  return get_json_result(
89
  data=False, message='Only owner of conversation authorized for this operation.',
90
+ code=settings.RetCode.OPERATING_ERROR)
91
  conv = conv.to_dict()
92
  return get_json_result(data=conv)
93
  except Exception as e:
 
110
  else:
111
  return get_json_result(
112
  data=False, message='Only owner of conversation authorized for this operation.',
113
+ code=settings.RetCode.OPERATING_ERROR)
114
  ConversationService.delete_by_id(cid)
115
  return get_json_result(data=True)
116
  except Exception as e:
 
125
  if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
126
  return get_json_result(
127
  data=False, message='Only owner of dialog authorized for this operation.',
128
+ code=settings.RetCode.OPERATING_ERROR)
129
  convs = ConversationService.query(
130
  dialog_id=dialog_id,
131
  order_by=ConversationService.model.create_time,
 
297
  def ask_about():
298
  req = request.json
299
  uid = current_user.id
300
+
301
  def stream():
302
  nonlocal req, uid
303
  try:
 
330
  embd_mdl = TenantLLMService.model_instance(
331
  kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
332
  chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
333
+ ranks = settings.retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12,
334
+ 0.3, 0.3, aggs=False)
335
  mindmap = MindMapExtractor(chat_mdl)
336
  mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
337
  if "error" in mind_map:
api/apps/dialog_app.py CHANGED
@@ -20,7 +20,7 @@ from api.db.services.dialog_service import DialogService
20
  from api.db import StatusEnum
21
  from api.db.services.knowledgebase_service import KnowledgebaseService
22
  from api.db.services.user_service import TenantService, UserTenantService
23
- from api.settings import RetCode
24
  from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
25
  from api.utils import get_uuid
26
  from api.utils.api_utils import get_json_result
@@ -175,7 +175,7 @@ def rm():
175
  else:
176
  return get_json_result(
177
  data=False, message='Only owner of dialog authorized for this operation.',
178
- code=RetCode.OPERATING_ERROR)
179
  dialog_list.append({"id": id,"status":StatusEnum.INVALID.value})
180
  DialogService.update_many_by_id(dialog_list)
181
  return get_json_result(data=True)
 
20
  from api.db import StatusEnum
21
  from api.db.services.knowledgebase_service import KnowledgebaseService
22
  from api.db.services.user_service import TenantService, UserTenantService
23
+ from api import settings
24
  from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
25
  from api.utils import get_uuid
26
  from api.utils.api_utils import get_json_result
 
175
  else:
176
  return get_json_result(
177
  data=False, message='Only owner of dialog authorized for this operation.',
178
+ code=settings.RetCode.OPERATING_ERROR)
179
  dialog_list.append({"id": id,"status":StatusEnum.INVALID.value})
180
  DialogService.update_many_by_id(dialog_list)
181
  return get_json_result(data=True)
api/apps/document_app.py CHANGED
@@ -34,7 +34,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
34
  from api.utils import get_uuid
35
  from api.db import FileType, TaskStatus, ParserType, FileSource
36
  from api.db.services.document_service import DocumentService, doc_upload_and_parse
37
- from api.settings import RetCode, docStoreConn
38
  from api.utils.api_utils import get_json_result
39
  from rag.utils.storage_factory import STORAGE_IMPL
40
  from api.utils.file_utils import filename_type, thumbnail, get_project_base_directory
@@ -49,16 +49,16 @@ def upload():
49
  kb_id = request.form.get("kb_id")
50
  if not kb_id:
51
  return get_json_result(
52
- data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
53
  if 'file' not in request.files:
54
  return get_json_result(
55
- data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
56
 
57
  file_objs = request.files.getlist('file')
58
  for file_obj in file_objs:
59
  if file_obj.filename == '':
60
  return get_json_result(
61
- data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
62
 
63
  e, kb = KnowledgebaseService.get_by_id(kb_id)
64
  if not e:
@@ -67,7 +67,7 @@ def upload():
67
  err, _ = FileService.upload_document(kb, file_objs, current_user.id)
68
  if err:
69
  return get_json_result(
70
- data=False, message="\n".join(err), code=RetCode.SERVER_ERROR)
71
  return get_json_result(data=True)
72
 
73
 
@@ -78,12 +78,12 @@ def web_crawl():
78
  kb_id = request.form.get("kb_id")
79
  if not kb_id:
80
  return get_json_result(
81
- data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
82
  name = request.form.get("name")
83
  url = request.form.get("url")
84
  if not is_valid_url(url):
85
  return get_json_result(
86
- data=False, message='The URL format is invalid', code=RetCode.ARGUMENT_ERROR)
87
  e, kb = KnowledgebaseService.get_by_id(kb_id)
88
  if not e:
89
  raise LookupError("Can't find this knowledgebase!")
@@ -145,7 +145,7 @@ def create():
145
  kb_id = req["kb_id"]
146
  if not kb_id:
147
  return get_json_result(
148
- data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
149
 
150
  try:
151
  e, kb = KnowledgebaseService.get_by_id(kb_id)
@@ -179,7 +179,7 @@ def list_docs():
179
  kb_id = request.args.get("kb_id")
180
  if not kb_id:
181
  return get_json_result(
182
- data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
183
  tenants = UserTenantService.query(user_id=current_user.id)
184
  for tenant in tenants:
185
  if KnowledgebaseService.query(
@@ -188,7 +188,7 @@ def list_docs():
188
  else:
189
  return get_json_result(
190
  data=False, message='Only owner of knowledgebase authorized for this operation.',
191
- code=RetCode.OPERATING_ERROR)
192
  keywords = request.args.get("keywords", "")
193
 
194
  page_number = int(request.args.get("page", 1))
@@ -218,19 +218,19 @@ def docinfos():
218
  return get_json_result(
219
  data=False,
220
  message='No authorization.',
221
- code=RetCode.AUTHENTICATION_ERROR
222
  )
223
  docs = DocumentService.get_by_ids(doc_ids)
224
  return get_json_result(data=list(docs.dicts()))
225
 
226
 
227
  @manager.route('/thumbnails', methods=['GET'])
228
- #@login_required
229
  def thumbnails():
230
  doc_ids = request.args.get("doc_ids").split(",")
231
  if not doc_ids:
232
  return get_json_result(
233
- data=False, message='Lack of "Document ID"', code=RetCode.ARGUMENT_ERROR)
234
 
235
  try:
236
  docs = DocumentService.get_thumbnails(doc_ids)
@@ -253,13 +253,13 @@ def change_status():
253
  return get_json_result(
254
  data=False,
255
  message='"Status" must be either 0 or 1!',
256
- code=RetCode.ARGUMENT_ERROR)
257
 
258
  if not DocumentService.accessible(req["doc_id"], current_user.id):
259
  return get_json_result(
260
  data=False,
261
  message='No authorization.',
262
- code=RetCode.AUTHENTICATION_ERROR)
263
 
264
  try:
265
  e, doc = DocumentService.get_by_id(req["doc_id"])
@@ -276,7 +276,8 @@ def change_status():
276
  message="Database error (Document update)!")
277
 
278
  status = int(req["status"])
279
- docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
 
280
  return get_json_result(data=True)
281
  except Exception as e:
282
  return server_error_response(e)
@@ -295,7 +296,7 @@ def rm():
295
  return get_json_result(
296
  data=False,
297
  message='No authorization.',
298
- code=RetCode.AUTHENTICATION_ERROR
299
  )
300
 
301
  root_folder = FileService.get_root_folder(current_user.id)
@@ -326,7 +327,7 @@ def rm():
326
  errors += str(e)
327
 
328
  if errors:
329
- return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
330
 
331
  return get_json_result(data=True)
332
 
@@ -341,7 +342,7 @@ def run():
341
  return get_json_result(
342
  data=False,
343
  message='No authorization.',
344
- code=RetCode.AUTHENTICATION_ERROR
345
  )
346
  try:
347
  for id in req["doc_ids"]:
@@ -358,8 +359,8 @@ def run():
358
  e, doc = DocumentService.get_by_id(id)
359
  if not e:
360
  return get_data_error_result(message="Document not found!")
361
- if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
362
- docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
363
 
364
  if str(req["run"]) == TaskStatus.RUNNING.value:
365
  TaskService.filter_delete([Task.doc_id == id])
@@ -383,7 +384,7 @@ def rename():
383
  return get_json_result(
384
  data=False,
385
  message='No authorization.',
386
- code=RetCode.AUTHENTICATION_ERROR
387
  )
388
  try:
389
  e, doc = DocumentService.get_by_id(req["doc_id"])
@@ -394,7 +395,7 @@ def rename():
394
  return get_json_result(
395
  data=False,
396
  message="The extension of file can't be changed",
397
- code=RetCode.ARGUMENT_ERROR)
398
  for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
399
  if d.name == req["name"]:
400
  return get_data_error_result(
@@ -450,7 +451,7 @@ def change_parser():
450
  return get_json_result(
451
  data=False,
452
  message='No authorization.',
453
- code=RetCode.AUTHENTICATION_ERROR
454
  )
455
  try:
456
  e, doc = DocumentService.get_by_id(req["doc_id"])
@@ -483,8 +484,8 @@ def change_parser():
483
  tenant_id = DocumentService.get_tenant_id(req["doc_id"])
484
  if not tenant_id:
485
  return get_data_error_result(message="Tenant not found!")
486
- if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
487
- docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
488
 
489
  return get_json_result(data=True)
490
  except Exception as e:
@@ -509,13 +510,13 @@ def get_image(image_id):
509
  def upload_and_parse():
510
  if 'file' not in request.files:
511
  return get_json_result(
512
- data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
513
 
514
  file_objs = request.files.getlist('file')
515
  for file_obj in file_objs:
516
  if file_obj.filename == '':
517
  return get_json_result(
518
- data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
519
 
520
  doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id)
521
 
@@ -529,7 +530,7 @@ def parse():
529
  if url:
530
  if not is_valid_url(url):
531
  return get_json_result(
532
- data=False, message='The URL format is invalid', code=RetCode.ARGUMENT_ERROR)
533
  download_path = os.path.join(get_project_base_directory(), "logs/downloads")
534
  os.makedirs(download_path, exist_ok=True)
535
  from selenium.webdriver import Chrome, ChromeOptions
@@ -553,7 +554,7 @@ def parse():
553
 
554
  if 'file' not in request.files:
555
  return get_json_result(
556
- data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
557
 
558
  file_objs = request.files.getlist('file')
559
  txt = FileService.parse_docs(file_objs, current_user.id)
 
34
  from api.utils import get_uuid
35
  from api.db import FileType, TaskStatus, ParserType, FileSource
36
  from api.db.services.document_service import DocumentService, doc_upload_and_parse
37
+ from api import settings
38
  from api.utils.api_utils import get_json_result
39
  from rag.utils.storage_factory import STORAGE_IMPL
40
  from api.utils.file_utils import filename_type, thumbnail, get_project_base_directory
 
49
  kb_id = request.form.get("kb_id")
50
  if not kb_id:
51
  return get_json_result(
52
+ data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
53
  if 'file' not in request.files:
54
  return get_json_result(
55
+ data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
56
 
57
  file_objs = request.files.getlist('file')
58
  for file_obj in file_objs:
59
  if file_obj.filename == '':
60
  return get_json_result(
61
+ data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
62
 
63
  e, kb = KnowledgebaseService.get_by_id(kb_id)
64
  if not e:
 
67
  err, _ = FileService.upload_document(kb, file_objs, current_user.id)
68
  if err:
69
  return get_json_result(
70
+ data=False, message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
71
  return get_json_result(data=True)
72
 
73
 
 
78
  kb_id = request.form.get("kb_id")
79
  if not kb_id:
80
  return get_json_result(
81
+ data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
82
  name = request.form.get("name")
83
  url = request.form.get("url")
84
  if not is_valid_url(url):
85
  return get_json_result(
86
+ data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR)
87
  e, kb = KnowledgebaseService.get_by_id(kb_id)
88
  if not e:
89
  raise LookupError("Can't find this knowledgebase!")
 
145
  kb_id = req["kb_id"]
146
  if not kb_id:
147
  return get_json_result(
148
+ data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
149
 
150
  try:
151
  e, kb = KnowledgebaseService.get_by_id(kb_id)
 
179
  kb_id = request.args.get("kb_id")
180
  if not kb_id:
181
  return get_json_result(
182
+ data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
183
  tenants = UserTenantService.query(user_id=current_user.id)
184
  for tenant in tenants:
185
  if KnowledgebaseService.query(
 
188
  else:
189
  return get_json_result(
190
  data=False, message='Only owner of knowledgebase authorized for this operation.',
191
+ code=settings.RetCode.OPERATING_ERROR)
192
  keywords = request.args.get("keywords", "")
193
 
194
  page_number = int(request.args.get("page", 1))
 
218
  return get_json_result(
219
  data=False,
220
  message='No authorization.',
221
+ code=settings.RetCode.AUTHENTICATION_ERROR
222
  )
223
  docs = DocumentService.get_by_ids(doc_ids)
224
  return get_json_result(data=list(docs.dicts()))
225
 
226
 
227
  @manager.route('/thumbnails', methods=['GET'])
228
+ # @login_required
229
  def thumbnails():
230
  doc_ids = request.args.get("doc_ids").split(",")
231
  if not doc_ids:
232
  return get_json_result(
233
+ data=False, message='Lack of "Document ID"', code=settings.RetCode.ARGUMENT_ERROR)
234
 
235
  try:
236
  docs = DocumentService.get_thumbnails(doc_ids)
 
253
  return get_json_result(
254
  data=False,
255
  message='"Status" must be either 0 or 1!',
256
+ code=settings.RetCode.ARGUMENT_ERROR)
257
 
258
  if not DocumentService.accessible(req["doc_id"], current_user.id):
259
  return get_json_result(
260
  data=False,
261
  message='No authorization.',
262
+ code=settings.RetCode.AUTHENTICATION_ERROR)
263
 
264
  try:
265
  e, doc = DocumentService.get_by_id(req["doc_id"])
 
276
  message="Database error (Document update)!")
277
 
278
  status = int(req["status"])
279
+ settings.docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status},
280
+ search.index_name(kb.tenant_id), doc.kb_id)
281
  return get_json_result(data=True)
282
  except Exception as e:
283
  return server_error_response(e)
 
296
  return get_json_result(
297
  data=False,
298
  message='No authorization.',
299
+ code=settings.RetCode.AUTHENTICATION_ERROR
300
  )
301
 
302
  root_folder = FileService.get_root_folder(current_user.id)
 
327
  errors += str(e)
328
 
329
  if errors:
330
+ return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)
331
 
332
  return get_json_result(data=True)
333
 
 
342
  return get_json_result(
343
  data=False,
344
  message='No authorization.',
345
+ code=settings.RetCode.AUTHENTICATION_ERROR
346
  )
347
  try:
348
  for id in req["doc_ids"]:
 
359
  e, doc = DocumentService.get_by_id(id)
360
  if not e:
361
  return get_data_error_result(message="Document not found!")
362
+ if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
363
+ settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
364
 
365
  if str(req["run"]) == TaskStatus.RUNNING.value:
366
  TaskService.filter_delete([Task.doc_id == id])
 
384
  return get_json_result(
385
  data=False,
386
  message='No authorization.',
387
+ code=settings.RetCode.AUTHENTICATION_ERROR
388
  )
389
  try:
390
  e, doc = DocumentService.get_by_id(req["doc_id"])
 
395
  return get_json_result(
396
  data=False,
397
  message="The extension of file can't be changed",
398
+ code=settings.RetCode.ARGUMENT_ERROR)
399
  for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
400
  if d.name == req["name"]:
401
  return get_data_error_result(
 
451
  return get_json_result(
452
  data=False,
453
  message='No authorization.',
454
+ code=settings.RetCode.AUTHENTICATION_ERROR
455
  )
456
  try:
457
  e, doc = DocumentService.get_by_id(req["doc_id"])
 
484
  tenant_id = DocumentService.get_tenant_id(req["doc_id"])
485
  if not tenant_id:
486
  return get_data_error_result(message="Tenant not found!")
487
+ if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
488
+ settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
489
 
490
  return get_json_result(data=True)
491
  except Exception as e:
 
510
  def upload_and_parse():
511
  if 'file' not in request.files:
512
  return get_json_result(
513
+ data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
514
 
515
  file_objs = request.files.getlist('file')
516
  for file_obj in file_objs:
517
  if file_obj.filename == '':
518
  return get_json_result(
519
+ data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
520
 
521
  doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id)
522
 
 
530
  if url:
531
  if not is_valid_url(url):
532
  return get_json_result(
533
+ data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR)
534
  download_path = os.path.join(get_project_base_directory(), "logs/downloads")
535
  os.makedirs(download_path, exist_ok=True)
536
  from selenium.webdriver import Chrome, ChromeOptions
 
554
 
555
  if 'file' not in request.files:
556
  return get_json_result(
557
+ data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
558
 
559
  file_objs = request.files.getlist('file')
560
  txt = FileService.parse_docs(file_objs, current_user.id)
api/apps/file2document_app.py CHANGED
@@ -24,7 +24,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
24
  from api.utils import get_uuid
25
  from api.db import FileType
26
  from api.db.services.document_service import DocumentService
27
- from api.settings import RetCode
28
  from api.utils.api_utils import get_json_result
29
 
30
 
@@ -100,7 +100,7 @@ def rm():
100
  file_ids = req["file_ids"]
101
  if not file_ids:
102
  return get_json_result(
103
- data=False, message='Lack of "Files ID"', code=RetCode.ARGUMENT_ERROR)
104
  try:
105
  for file_id in file_ids:
106
  informs = File2DocumentService.get_by_file_id(file_id)
 
24
  from api.utils import get_uuid
25
  from api.db import FileType
26
  from api.db.services.document_service import DocumentService
27
+ from api import settings
28
  from api.utils.api_utils import get_json_result
29
 
30
 
 
100
  file_ids = req["file_ids"]
101
  if not file_ids:
102
  return get_json_result(
103
+ data=False, message='Lack of "Files ID"', code=settings.RetCode.ARGUMENT_ERROR)
104
  try:
105
  for file_id in file_ids:
106
  informs = File2DocumentService.get_by_file_id(file_id)
api/apps/file_app.py CHANGED
@@ -28,7 +28,7 @@ from api.utils import get_uuid
28
  from api.db import FileType, FileSource
29
  from api.db.services import duplicate_name
30
  from api.db.services.file_service import FileService
31
- from api.settings import RetCode
32
  from api.utils.api_utils import get_json_result
33
  from api.utils.file_utils import filename_type
34
  from rag.utils.storage_factory import STORAGE_IMPL
@@ -46,13 +46,13 @@ def upload():
46
 
47
  if 'file' not in request.files:
48
  return get_json_result(
49
- data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
50
  file_objs = request.files.getlist('file')
51
 
52
  for file_obj in file_objs:
53
  if file_obj.filename == '':
54
  return get_json_result(
55
- data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
56
  file_res = []
57
  try:
58
  for file_obj in file_objs:
@@ -134,7 +134,7 @@ def create():
134
  try:
135
  if not FileService.is_parent_folder_exist(pf_id):
136
  return get_json_result(
137
- data=False, message="Parent Folder Doesn't Exist!", code=RetCode.OPERATING_ERROR)
138
  if FileService.query(name=req["name"], parent_id=pf_id):
139
  return get_data_error_result(
140
  message="Duplicated folder name in the same folder.")
@@ -299,7 +299,7 @@ def rename():
299
  return get_json_result(
300
  data=False,
301
  message="The extension of file can't be changed",
302
- code=RetCode.ARGUMENT_ERROR)
303
  for file in FileService.query(name=req["name"], pf_id=file.parent_id):
304
  if file.name == req["name"]:
305
  return get_data_error_result(
 
28
  from api.db import FileType, FileSource
29
  from api.db.services import duplicate_name
30
  from api.db.services.file_service import FileService
31
+ from api import settings
32
  from api.utils.api_utils import get_json_result
33
  from api.utils.file_utils import filename_type
34
  from rag.utils.storage_factory import STORAGE_IMPL
 
46
 
47
  if 'file' not in request.files:
48
  return get_json_result(
49
+ data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
50
  file_objs = request.files.getlist('file')
51
 
52
  for file_obj in file_objs:
53
  if file_obj.filename == '':
54
  return get_json_result(
55
+ data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
56
  file_res = []
57
  try:
58
  for file_obj in file_objs:
 
134
  try:
135
  if not FileService.is_parent_folder_exist(pf_id):
136
  return get_json_result(
137
+ data=False, message="Parent Folder Doesn't Exist!", code=settings.RetCode.OPERATING_ERROR)
138
  if FileService.query(name=req["name"], parent_id=pf_id):
139
  return get_data_error_result(
140
  message="Duplicated folder name in the same folder.")
 
299
  return get_json_result(
300
  data=False,
301
  message="The extension of file can't be changed",
302
+ code=settings.RetCode.ARGUMENT_ERROR)
303
  for file in FileService.query(name=req["name"], pf_id=file.parent_id):
304
  if file.name == req["name"]:
305
  return get_data_error_result(
api/apps/kb_app.py CHANGED
@@ -26,9 +26,8 @@ from api.utils import get_uuid
26
  from api.db import StatusEnum, FileSource
27
  from api.db.services.knowledgebase_service import KnowledgebaseService
28
  from api.db.db_models import File
29
- from api.settings import RetCode
30
  from api.utils.api_utils import get_json_result
31
- from api.settings import docStoreConn
32
  from rag.nlp import search
33
 
34
 
@@ -68,13 +67,13 @@ def update():
68
  return get_json_result(
69
  data=False,
70
  message='No authorization.',
71
- code=RetCode.AUTHENTICATION_ERROR
72
  )
73
  try:
74
  if not KnowledgebaseService.query(
75
  created_by=current_user.id, id=req["kb_id"]):
76
  return get_json_result(
77
- data=False, message='Only owner of knowledgebase authorized for this operation.', code=RetCode.OPERATING_ERROR)
78
 
79
  e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
80
  if not e:
@@ -113,7 +112,7 @@ def detail():
113
  else:
114
  return get_json_result(
115
  data=False, message='Only owner of knowledgebase authorized for this operation.',
116
- code=RetCode.OPERATING_ERROR)
117
  kb = KnowledgebaseService.get_detail(kb_id)
118
  if not kb:
119
  return get_data_error_result(
@@ -148,14 +147,14 @@ def rm():
148
  return get_json_result(
149
  data=False,
150
  message='No authorization.',
151
- code=RetCode.AUTHENTICATION_ERROR
152
  )
153
  try:
154
  kbs = KnowledgebaseService.query(
155
  created_by=current_user.id, id=req["kb_id"])
156
  if not kbs:
157
  return get_json_result(
158
- data=False, message='Only owner of knowledgebase authorized for this operation.', code=RetCode.OPERATING_ERROR)
159
 
160
  for doc in DocumentService.query(kb_id=req["kb_id"]):
161
  if not DocumentService.remove_document(doc, kbs[0].tenant_id):
@@ -170,7 +169,7 @@ def rm():
170
  message="Database error (Knowledgebase removal)!")
171
  tenants = UserTenantService.query(user_id=current_user.id)
172
  for tenant in tenants:
173
- docStoreConn.deleteIdx(search.index_name(tenant.tenant_id), req["kb_id"])
174
  return get_json_result(data=True)
175
  except Exception as e:
176
  return server_error_response(e)
 
26
  from api.db import StatusEnum, FileSource
27
  from api.db.services.knowledgebase_service import KnowledgebaseService
28
  from api.db.db_models import File
 
29
  from api.utils.api_utils import get_json_result
30
+ from api import settings
31
  from rag.nlp import search
32
 
33
 
 
67
  return get_json_result(
68
  data=False,
69
  message='No authorization.',
70
+ code=settings.RetCode.AUTHENTICATION_ERROR
71
  )
72
  try:
73
  if not KnowledgebaseService.query(
74
  created_by=current_user.id, id=req["kb_id"]):
75
  return get_json_result(
76
+ data=False, message='Only owner of knowledgebase authorized for this operation.', code=settings.RetCode.OPERATING_ERROR)
77
 
78
  e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
79
  if not e:
 
112
  else:
113
  return get_json_result(
114
  data=False, message='Only owner of knowledgebase authorized for this operation.',
115
+ code=settings.RetCode.OPERATING_ERROR)
116
  kb = KnowledgebaseService.get_detail(kb_id)
117
  if not kb:
118
  return get_data_error_result(
 
147
  return get_json_result(
148
  data=False,
149
  message='No authorization.',
150
+ code=settings.RetCode.AUTHENTICATION_ERROR
151
  )
152
  try:
153
  kbs = KnowledgebaseService.query(
154
  created_by=current_user.id, id=req["kb_id"])
155
  if not kbs:
156
  return get_json_result(
157
+ data=False, message='Only owner of knowledgebase authorized for this operation.', code=settings.RetCode.OPERATING_ERROR)
158
 
159
  for doc in DocumentService.query(kb_id=req["kb_id"]):
160
  if not DocumentService.remove_document(doc, kbs[0].tenant_id):
 
169
  message="Database error (Knowledgebase removal)!")
170
  tenants = UserTenantService.query(user_id=current_user.id)
171
  for tenant in tenants:
172
+ settings.docStoreConn.deleteIdx(search.index_name(tenant.tenant_id), req["kb_id"])
173
  return get_json_result(data=True)
174
  except Exception as e:
175
  return server_error_response(e)
api/apps/llm_app.py CHANGED
@@ -19,7 +19,7 @@ import json
19
  from flask import request
20
  from flask_login import login_required, current_user
21
  from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
22
- from api.settings import LIGHTEN
23
  from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
24
  from api.db import StatusEnum, LLMType
25
  from api.db.db_models import TenantLLM
@@ -333,7 +333,7 @@ def my_llms():
333
  @login_required
334
  def list_app():
335
  self_deploied = ["Youdao","FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"]
336
- weighted = ["Youdao","FastEmbed", "BAAI"] if LIGHTEN != 0 else []
337
  model_type = request.args.get("model_type")
338
  try:
339
  objs = TenantLLMService.query(tenant_id=current_user.id)
 
19
  from flask import request
20
  from flask_login import login_required, current_user
21
  from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
22
+ from api import settings
23
  from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
24
  from api.db import StatusEnum, LLMType
25
  from api.db.db_models import TenantLLM
 
333
  @login_required
334
  def list_app():
335
  self_deploied = ["Youdao","FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"]
336
+ weighted = ["Youdao","FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else []
337
  model_type = request.args.get("model_type")
338
  try:
339
  objs = TenantLLMService.query(tenant_id=current_user.id)
api/apps/sdk/chat.py CHANGED
@@ -14,7 +14,7 @@
14
  # limitations under the License.
15
  #
16
  from flask import request
17
- from api.settings import RetCode
18
  from api.db import StatusEnum
19
  from api.db.services.dialog_service import DialogService
20
  from api.db.services.knowledgebase_service import KnowledgebaseService
@@ -44,7 +44,7 @@ def create(tenant_id):
44
  kbs = KnowledgebaseService.get_by_ids(ids)
45
  embd_count = list(set([kb.embd_id for kb in kbs]))
46
  if len(embd_count) != 1:
47
- return get_result(message='Datasets use different embedding models."',code=RetCode.AUTHENTICATION_ERROR)
48
  req["kb_ids"] = ids
49
  # llm
50
  llm = req.get("llm")
@@ -173,7 +173,7 @@ def update(tenant_id,chat_id):
173
  if len(embd_count) != 1 :
174
  return get_result(
175
  message='Datasets use different embedding models."',
176
- code=RetCode.AUTHENTICATION_ERROR)
177
  req["kb_ids"] = ids
178
  llm = req.get("llm")
179
  if llm:
 
14
  # limitations under the License.
15
  #
16
  from flask import request
17
+ from api import settings
18
  from api.db import StatusEnum
19
  from api.db.services.dialog_service import DialogService
20
  from api.db.services.knowledgebase_service import KnowledgebaseService
 
44
  kbs = KnowledgebaseService.get_by_ids(ids)
45
  embd_count = list(set([kb.embd_id for kb in kbs]))
46
  if len(embd_count) != 1:
47
+ return get_result(message='Datasets use different embedding models."',code=settings.RetCode.AUTHENTICATION_ERROR)
48
  req["kb_ids"] = ids
49
  # llm
50
  llm = req.get("llm")
 
173
  if len(embd_count) != 1 :
174
  return get_result(
175
  message='Datasets use different embedding models."',
176
+ code=settings.RetCode.AUTHENTICATION_ERROR)
177
  req["kb_ids"] = ids
178
  llm = req.get("llm")
179
  if llm:
api/apps/sdk/dataset.py CHANGED
@@ -23,7 +23,7 @@ from api.db.services.file_service import FileService
23
  from api.db.services.knowledgebase_service import KnowledgebaseService
24
  from api.db.services.llm_service import TenantLLMService, LLMService
25
  from api.db.services.user_service import TenantService
26
- from api.settings import RetCode
27
  from api.utils import get_uuid
28
  from api.utils.api_utils import (
29
  get_result,
@@ -255,7 +255,7 @@ def delete(tenant_id):
255
  File2DocumentService.delete_by_document_id(doc.id)
256
  if not KnowledgebaseService.delete_by_id(id):
257
  return get_error_data_result(message="Delete dataset error.(Database error)")
258
- return get_result(code=RetCode.SUCCESS)
259
 
260
 
261
  @manager.route("/datasets/<dataset_id>", methods=["PUT"])
@@ -424,7 +424,7 @@ def update(tenant_id, dataset_id):
424
  )
425
  if not KnowledgebaseService.update_by_id(kb.id, req):
426
  return get_error_data_result(message="Update dataset error.(Database error)")
427
- return get_result(code=RetCode.SUCCESS)
428
 
429
 
430
  @manager.route("/datasets", methods=["GET"])
 
23
  from api.db.services.knowledgebase_service import KnowledgebaseService
24
  from api.db.services.llm_service import TenantLLMService, LLMService
25
  from api.db.services.user_service import TenantService
26
+ from api import settings
27
  from api.utils import get_uuid
28
  from api.utils.api_utils import (
29
  get_result,
 
255
  File2DocumentService.delete_by_document_id(doc.id)
256
  if not KnowledgebaseService.delete_by_id(id):
257
  return get_error_data_result(message="Delete dataset error.(Database error)")
258
+ return get_result(code=settings.RetCode.SUCCESS)
259
 
260
 
261
  @manager.route("/datasets/<dataset_id>", methods=["PUT"])
 
424
  )
425
  if not KnowledgebaseService.update_by_id(kb.id, req):
426
  return get_error_data_result(message="Update dataset error.(Database error)")
427
+ return get_result(code=settings.RetCode.SUCCESS)
428
 
429
 
430
  @manager.route("/datasets", methods=["GET"])
api/apps/sdk/dify_retrieval.py CHANGED
@@ -18,7 +18,7 @@ from flask import request, jsonify
18
  from api.db import LLMType, ParserType
19
  from api.db.services.knowledgebase_service import KnowledgebaseService
20
  from api.db.services.llm_service import LLMBundle
21
- from api.settings import retrievaler, kg_retrievaler, RetCode
22
  from api.utils.api_utils import validate_request, build_error_result, apikey_required
23
 
24
 
@@ -37,14 +37,14 @@ def retrieval(tenant_id):
37
 
38
  e, kb = KnowledgebaseService.get_by_id(kb_id)
39
  if not e:
40
- return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND)
41
 
42
  if kb.tenant_id != tenant_id:
43
- return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND)
44
 
45
  embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
46
 
47
- retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
48
  ranks = retr.retrieval(
49
  question,
50
  embd_mdl,
@@ -72,6 +72,6 @@ def retrieval(tenant_id):
72
  if str(e).find("not_found") > 0:
73
  return build_error_result(
74
  message='No chunk found! Check the chunk status please!',
75
- code=RetCode.NOT_FOUND
76
  )
77
- return build_error_result(message=str(e), code=RetCode.SERVER_ERROR)
 
18
  from api.db import LLMType, ParserType
19
  from api.db.services.knowledgebase_service import KnowledgebaseService
20
  from api.db.services.llm_service import LLMBundle
21
+ from api import settings
22
  from api.utils.api_utils import validate_request, build_error_result, apikey_required
23
 
24
 
 
37
 
38
  e, kb = KnowledgebaseService.get_by_id(kb_id)
39
  if not e:
40
+ return build_error_result(message="Knowledgebase not found!", code=settings.RetCode.NOT_FOUND)
41
 
42
  if kb.tenant_id != tenant_id:
43
+ return build_error_result(message="Knowledgebase not found!", code=settings.RetCode.NOT_FOUND)
44
 
45
  embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
46
 
47
+ retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
48
  ranks = retr.retrieval(
49
  question,
50
  embd_mdl,
 
72
  if str(e).find("not_found") > 0:
73
  return build_error_result(
74
  message='No chunk found! Check the chunk status please!',
75
+ code=settings.RetCode.NOT_FOUND
76
  )
77
+ return build_error_result(message=str(e), code=settings.RetCode.SERVER_ERROR)
api/apps/sdk/doc.py CHANGED
@@ -21,7 +21,7 @@ from rag.app.qa import rmPrefix, beAdoc
21
  from rag.nlp import rag_tokenizer
22
  from api.db import LLMType, ParserType
23
  from api.db.services.llm_service import TenantLLMService
24
- from api.settings import kg_retrievaler
25
  import hashlib
26
  import re
27
  from api.utils.api_utils import token_required
@@ -37,11 +37,10 @@ from api.db.services.document_service import DocumentService
37
  from api.db.services.file2document_service import File2DocumentService
38
  from api.db.services.file_service import FileService
39
  from api.db.services.knowledgebase_service import KnowledgebaseService
40
- from api.settings import RetCode, retrievaler
41
  from api.utils.api_utils import construct_json_result, get_parser_config
42
  from rag.nlp import search
43
  from rag.utils import rmSpace
44
- from api.settings import docStoreConn
45
  from rag.utils.storage_factory import STORAGE_IMPL
46
  import os
47
 
@@ -109,13 +108,13 @@ def upload(dataset_id, tenant_id):
109
  """
110
  if "file" not in request.files:
111
  return get_error_data_result(
112
- message="No file part!", code=RetCode.ARGUMENT_ERROR
113
  )
114
  file_objs = request.files.getlist("file")
115
  for file_obj in file_objs:
116
  if file_obj.filename == "":
117
  return get_result(
118
- message="No file selected!", code=RetCode.ARGUMENT_ERROR
119
  )
120
  # total size
121
  total_size = 0
@@ -127,14 +126,14 @@ def upload(dataset_id, tenant_id):
127
  if total_size > MAX_TOTAL_FILE_SIZE:
128
  return get_result(
129
  message=f"Total file size exceeds 10MB limit! ({total_size / (1024 * 1024):.2f} MB)",
130
- code=RetCode.ARGUMENT_ERROR,
131
  )
132
  e, kb = KnowledgebaseService.get_by_id(dataset_id)
133
  if not e:
134
  raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
135
  err, files = FileService.upload_document(kb, file_objs, tenant_id)
136
  if err:
137
- return get_result(message="\n".join(err), code=RetCode.SERVER_ERROR)
138
  # rename key's name
139
  renamed_doc_list = []
140
  for file in files:
@@ -221,12 +220,12 @@ def update_doc(tenant_id, dataset_id, document_id):
221
 
222
  if "name" in req and req["name"] != doc.name:
223
  if (
224
- pathlib.Path(req["name"].lower()).suffix
225
- != pathlib.Path(doc.name.lower()).suffix
226
  ):
227
  return get_result(
228
  message="The extension of file can't be changed",
229
- code=RetCode.ARGUMENT_ERROR,
230
  )
231
  for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
232
  if d.name == req["name"]:
@@ -292,7 +291,7 @@ def update_doc(tenant_id, dataset_id, document_id):
292
  )
293
  if not e:
294
  return get_error_data_result(message="Document not found!")
295
- docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
296
 
297
  return get_result()
298
 
@@ -349,7 +348,7 @@ def download(tenant_id, dataset_id, document_id):
349
  file_stream = STORAGE_IMPL.get(doc_id, doc_location)
350
  if not file_stream:
351
  return construct_json_result(
352
- message="This file is empty.", code=RetCode.DATA_ERROR
353
  )
354
  file = BytesIO(file_stream)
355
  # Use send_file with a proper filename and MIME type
@@ -582,7 +581,7 @@ def delete(tenant_id, dataset_id):
582
  errors += str(e)
583
 
584
  if errors:
585
- return get_result(message=errors, code=RetCode.SERVER_ERROR)
586
 
587
  return get_result()
588
 
@@ -644,7 +643,7 @@ def parse(tenant_id, dataset_id):
644
  info["chunk_num"] = 0
645
  info["token_num"] = 0
646
  DocumentService.update_by_id(id, info)
647
- docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
648
  TaskService.filter_delete([Task.doc_id == id])
649
  e, doc = DocumentService.get_by_id(id)
650
  doc = doc.to_dict()
@@ -708,7 +707,7 @@ def stop_parsing(tenant_id, dataset_id):
708
  )
709
  info = {"run": "2", "progress": 0, "chunk_num": 0}
710
  DocumentService.update_by_id(id, info)
711
- docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
712
  return get_result()
713
 
714
 
@@ -828,8 +827,9 @@ def list_chunks(tenant_id, dataset_id, document_id):
828
 
829
  res = {"total": 0, "chunks": [], "doc": renamed_doc}
830
  origin_chunks = []
831
- if docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
832
- sres = retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
 
833
  res["total"] = sres.total
834
  sign = 0
835
  for id in sres.ids:
@@ -1003,7 +1003,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
1003
  v, c = embd_mdl.encode([doc.name, req["content"]])
1004
  v = 0.1 * v[0] + 0.9 * v[1]
1005
  d["q_%d_vec" % len(v)] = v.tolist()
1006
- docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
1007
 
1008
  DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
1009
  # rename keys
@@ -1078,7 +1078,7 @@ def rm_chunk(tenant_id, dataset_id, document_id):
1078
  condition = {"doc_id": document_id}
1079
  if "chunk_ids" in req:
1080
  condition["id"] = req["chunk_ids"]
1081
- chunk_number = docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
1082
  if chunk_number != 0:
1083
  DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
1084
  if "chunk_ids" in req and chunk_number != len(req["chunk_ids"]):
@@ -1143,7 +1143,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
1143
  schema:
1144
  type: object
1145
  """
1146
- chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
1147
  if chunk is None:
1148
  return get_error_data_result(f"Can't find this chunk {chunk_id}")
1149
  if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
@@ -1187,7 +1187,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
1187
  v, c = embd_mdl.encode([doc.name, d["content_with_weight"]])
1188
  v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
1189
  d["q_%d_vec" % len(v)] = v.tolist()
1190
- docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
1191
  return get_result()
1192
 
1193
 
@@ -1285,7 +1285,7 @@ def retrieval_test(tenant_id):
1285
  if len(embd_nms) != 1:
1286
  return get_result(
1287
  message='Datasets use different embedding models."',
1288
- code=RetCode.AUTHENTICATION_ERROR,
1289
  )
1290
  if "question" not in req:
1291
  return get_error_data_result("`question` is required.")
@@ -1326,7 +1326,7 @@ def retrieval_test(tenant_id):
1326
  chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
1327
  question += keyword_extraction(chat_mdl, question)
1328
 
1329
- retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
1330
  ranks = retr.retrieval(
1331
  question,
1332
  embd_mdl,
@@ -1366,6 +1366,6 @@ def retrieval_test(tenant_id):
1366
  if str(e).find("not_found") > 0:
1367
  return get_result(
1368
  message="No chunk found! Check the chunk status please!",
1369
- code=RetCode.DATA_ERROR,
1370
  )
1371
- return server_error_response(e)
 
21
  from rag.nlp import rag_tokenizer
22
  from api.db import LLMType, ParserType
23
  from api.db.services.llm_service import TenantLLMService
24
+ from api import settings
25
  import hashlib
26
  import re
27
  from api.utils.api_utils import token_required
 
37
  from api.db.services.file2document_service import File2DocumentService
38
  from api.db.services.file_service import FileService
39
  from api.db.services.knowledgebase_service import KnowledgebaseService
40
+ from api import settings
41
  from api.utils.api_utils import construct_json_result, get_parser_config
42
  from rag.nlp import search
43
  from rag.utils import rmSpace
 
44
  from rag.utils.storage_factory import STORAGE_IMPL
45
  import os
46
 
 
108
  """
109
  if "file" not in request.files:
110
  return get_error_data_result(
111
+ message="No file part!", code=settings.RetCode.ARGUMENT_ERROR
112
  )
113
  file_objs = request.files.getlist("file")
114
  for file_obj in file_objs:
115
  if file_obj.filename == "":
116
  return get_result(
117
+ message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR
118
  )
119
  # total size
120
  total_size = 0
 
126
  if total_size > MAX_TOTAL_FILE_SIZE:
127
  return get_result(
128
  message=f"Total file size exceeds 10MB limit! ({total_size / (1024 * 1024):.2f} MB)",
129
+ code=settings.RetCode.ARGUMENT_ERROR,
130
  )
131
  e, kb = KnowledgebaseService.get_by_id(dataset_id)
132
  if not e:
133
  raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
134
  err, files = FileService.upload_document(kb, file_objs, tenant_id)
135
  if err:
136
+ return get_result(message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
137
  # rename key's name
138
  renamed_doc_list = []
139
  for file in files:
 
220
 
221
  if "name" in req and req["name"] != doc.name:
222
  if (
223
+ pathlib.Path(req["name"].lower()).suffix
224
+ != pathlib.Path(doc.name.lower()).suffix
225
  ):
226
  return get_result(
227
  message="The extension of file can't be changed",
228
+ code=settings.RetCode.ARGUMENT_ERROR,
229
  )
230
  for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
231
  if d.name == req["name"]:
 
291
  )
292
  if not e:
293
  return get_error_data_result(message="Document not found!")
294
+ settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
295
 
296
  return get_result()
297
 
 
348
  file_stream = STORAGE_IMPL.get(doc_id, doc_location)
349
  if not file_stream:
350
  return construct_json_result(
351
+ message="This file is empty.", code=settings.RetCode.DATA_ERROR
352
  )
353
  file = BytesIO(file_stream)
354
  # Use send_file with a proper filename and MIME type
 
581
  errors += str(e)
582
 
583
  if errors:
584
+ return get_result(message=errors, code=settings.RetCode.SERVER_ERROR)
585
 
586
  return get_result()
587
 
 
643
  info["chunk_num"] = 0
644
  info["token_num"] = 0
645
  DocumentService.update_by_id(id, info)
646
+ settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
647
  TaskService.filter_delete([Task.doc_id == id])
648
  e, doc = DocumentService.get_by_id(id)
649
  doc = doc.to_dict()
 
707
  )
708
  info = {"run": "2", "progress": 0, "chunk_num": 0}
709
  DocumentService.update_by_id(id, info)
710
+ settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
711
  return get_result()
712
 
713
 
 
827
 
828
  res = {"total": 0, "chunks": [], "doc": renamed_doc}
829
  origin_chunks = []
830
+ if settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
831
+ sres = settings.retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None,
832
+ highlight=True)
833
  res["total"] = sres.total
834
  sign = 0
835
  for id in sres.ids:
 
1003
  v, c = embd_mdl.encode([doc.name, req["content"]])
1004
  v = 0.1 * v[0] + 0.9 * v[1]
1005
  d["q_%d_vec" % len(v)] = v.tolist()
1006
+ settings.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
1007
 
1008
  DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
1009
  # rename keys
 
1078
  condition = {"doc_id": document_id}
1079
  if "chunk_ids" in req:
1080
  condition["id"] = req["chunk_ids"]
1081
+ chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
1082
  if chunk_number != 0:
1083
  DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
1084
  if "chunk_ids" in req and chunk_number != len(req["chunk_ids"]):
 
1143
  schema:
1144
  type: object
1145
  """
1146
+ chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
1147
  if chunk is None:
1148
  return get_error_data_result(f"Can't find this chunk {chunk_id}")
1149
  if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
 
1187
  v, c = embd_mdl.encode([doc.name, d["content_with_weight"]])
1188
  v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
1189
  d["q_%d_vec" % len(v)] = v.tolist()
1190
+ settings.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
1191
  return get_result()
1192
 
1193
 
 
1285
  if len(embd_nms) != 1:
1286
  return get_result(
1287
  message='Datasets use different embedding models."',
1288
+ code=settings.RetCode.AUTHENTICATION_ERROR,
1289
  )
1290
  if "question" not in req:
1291
  return get_error_data_result("`question` is required.")
 
1326
  chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
1327
  question += keyword_extraction(chat_mdl, question)
1328
 
1329
+ retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
1330
  ranks = retr.retrieval(
1331
  question,
1332
  embd_mdl,
 
1366
  if str(e).find("not_found") > 0:
1367
  return get_result(
1368
  message="No chunk found! Check the chunk status please!",
1369
+ code=settings.RetCode.DATA_ERROR,
1370
  )
1371
+ return server_error_response(e)
api/apps/system_app.py CHANGED
@@ -22,7 +22,7 @@ from api.db.db_models import APIToken
22
  from api.db.services.api_service import APITokenService
23
  from api.db.services.knowledgebase_service import KnowledgebaseService
24
  from api.db.services.user_service import UserTenantService
25
- from api.settings import DATABASE_TYPE
26
  from api.utils import current_timestamp, datetime_format
27
  from api.utils.api_utils import (
28
  get_json_result,
@@ -31,7 +31,6 @@ from api.utils.api_utils import (
31
  generate_confirmation_token,
32
  )
33
  from api.versions import get_ragflow_version
34
- from api.settings import docStoreConn
35
  from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE
36
  from timeit import default_timer as timer
37
 
@@ -98,7 +97,7 @@ def status():
98
  res = {}
99
  st = timer()
100
  try:
101
- res["doc_store"] = docStoreConn.health()
102
  res["doc_store"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
103
  except Exception as e:
104
  res["doc_store"] = {
@@ -128,13 +127,13 @@ def status():
128
  try:
129
  KnowledgebaseService.get_by_id("x")
130
  res["database"] = {
131
- "database": DATABASE_TYPE.lower(),
132
  "status": "green",
133
  "elapsed": "{:.1f}".format((timer() - st) * 1000.0),
134
  }
135
  except Exception as e:
136
  res["database"] = {
137
- "database": DATABASE_TYPE.lower(),
138
  "status": "red",
139
  "elapsed": "{:.1f}".format((timer() - st) * 1000.0),
140
  "error": str(e),
 
22
  from api.db.services.api_service import APITokenService
23
  from api.db.services.knowledgebase_service import KnowledgebaseService
24
  from api.db.services.user_service import UserTenantService
25
+ from api import settings
26
  from api.utils import current_timestamp, datetime_format
27
  from api.utils.api_utils import (
28
  get_json_result,
 
31
  generate_confirmation_token,
32
  )
33
  from api.versions import get_ragflow_version
 
34
  from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE
35
  from timeit import default_timer as timer
36
 
 
97
  res = {}
98
  st = timer()
99
  try:
100
+ res["doc_store"] = settings.docStoreConn.health()
101
  res["doc_store"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
102
  except Exception as e:
103
  res["doc_store"] = {
 
127
  try:
128
  KnowledgebaseService.get_by_id("x")
129
  res["database"] = {
130
+ "database": settings.DATABASE_TYPE.lower(),
131
  "status": "green",
132
  "elapsed": "{:.1f}".format((timer() - st) * 1000.0),
133
  }
134
  except Exception as e:
135
  res["database"] = {
136
+ "database": settings.DATABASE_TYPE.lower(),
137
  "status": "red",
138
  "elapsed": "{:.1f}".format((timer() - st) * 1000.0),
139
  "error": str(e),
api/apps/user_app.py CHANGED
@@ -38,20 +38,7 @@ from api.utils import (
38
  datetime_format,
39
  )
40
  from api.db import UserTenantRole, FileType
41
- from api.settings import (
42
- RetCode,
43
- GITHUB_OAUTH,
44
- FEISHU_OAUTH,
45
- CHAT_MDL,
46
- EMBEDDING_MDL,
47
- ASR_MDL,
48
- IMAGE2TEXT_MDL,
49
- PARSERS,
50
- API_KEY,
51
- LLM_FACTORY,
52
- LLM_BASE_URL,
53
- RERANK_MDL,
54
- )
55
  from api.db.services.user_service import UserService, TenantService, UserTenantService
56
  from api.db.services.file_service import FileService
57
  from api.utils.api_utils import get_json_result, construct_response
@@ -90,7 +77,7 @@ def login():
90
  """
91
  if not request.json:
92
  return get_json_result(
93
- data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!"
94
  )
95
 
96
  email = request.json.get("email", "")
@@ -98,7 +85,7 @@ def login():
98
  if not users:
99
  return get_json_result(
100
  data=False,
101
- code=RetCode.AUTHENTICATION_ERROR,
102
  message=f"Email: {email} is not registered!",
103
  )
104
 
@@ -107,7 +94,7 @@ def login():
107
  password = decrypt(password)
108
  except BaseException:
109
  return get_json_result(
110
- data=False, code=RetCode.SERVER_ERROR, message="Fail to crypt password"
111
  )
112
 
113
  user = UserService.query_user(email, password)
@@ -123,7 +110,7 @@ def login():
123
  else:
124
  return get_json_result(
125
  data=False,
126
- code=RetCode.AUTHENTICATION_ERROR,
127
  message="Email and password do not match!",
128
  )
129
 
@@ -150,10 +137,10 @@ def github_callback():
150
  import requests
151
 
152
  res = requests.post(
153
- GITHUB_OAUTH.get("url"),
154
  data={
155
- "client_id": GITHUB_OAUTH.get("client_id"),
156
- "client_secret": GITHUB_OAUTH.get("secret_key"),
157
  "code": request.args.get("code"),
158
  },
159
  headers={"Accept": "application/json"},
@@ -235,11 +222,11 @@ def feishu_callback():
235
  import requests
236
 
237
  app_access_token_res = requests.post(
238
- FEISHU_OAUTH.get("app_access_token_url"),
239
  data=json.dumps(
240
  {
241
- "app_id": FEISHU_OAUTH.get("app_id"),
242
- "app_secret": FEISHU_OAUTH.get("app_secret"),
243
  }
244
  ),
245
  headers={"Content-Type": "application/json; charset=utf-8"},
@@ -249,10 +236,10 @@ def feishu_callback():
249
  return redirect("/?error=%s" % app_access_token_res)
250
 
251
  res = requests.post(
252
- FEISHU_OAUTH.get("user_access_token_url"),
253
  data=json.dumps(
254
  {
255
- "grant_type": FEISHU_OAUTH.get("grant_type"),
256
  "code": request.args.get("code"),
257
  }
258
  ),
@@ -405,11 +392,11 @@ def setting_user():
405
  if request_data.get("password"):
406
  new_password = request_data.get("new_password")
407
  if not check_password_hash(
408
- current_user.password, decrypt(request_data["password"])
409
  ):
410
  return get_json_result(
411
  data=False,
412
- code=RetCode.AUTHENTICATION_ERROR,
413
  message="Password error!",
414
  )
415
 
@@ -438,7 +425,7 @@ def setting_user():
438
  except Exception as e:
439
  logging.exception(e)
440
  return get_json_result(
441
- data=False, message="Update failure!", code=RetCode.EXCEPTION_ERROR
442
  )
443
 
444
 
@@ -497,12 +484,12 @@ def user_register(user_id, user):
497
  tenant = {
498
  "id": user_id,
499
  "name": user["nickname"] + "‘s Kingdom",
500
- "llm_id": CHAT_MDL,
501
- "embd_id": EMBEDDING_MDL,
502
- "asr_id": ASR_MDL,
503
- "parser_ids": PARSERS,
504
- "img2txt_id": IMAGE2TEXT_MDL,
505
- "rerank_id": RERANK_MDL,
506
  }
507
  usr_tenant = {
508
  "tenant_id": user_id,
@@ -522,15 +509,15 @@ def user_register(user_id, user):
522
  "location": "",
523
  }
524
  tenant_llm = []
525
- for llm in LLMService.query(fid=LLM_FACTORY):
526
  tenant_llm.append(
527
  {
528
  "tenant_id": user_id,
529
- "llm_factory": LLM_FACTORY,
530
  "llm_name": llm.llm_name,
531
  "model_type": llm.model_type,
532
- "api_key": API_KEY,
533
- "api_base": LLM_BASE_URL,
534
  }
535
  )
536
 
@@ -582,7 +569,7 @@ def user_add():
582
  return get_json_result(
583
  data=False,
584
  message=f"Invalid email address: {email_address}!",
585
- code=RetCode.OPERATING_ERROR,
586
  )
587
 
588
  # Check if the email address is already used
@@ -590,7 +577,7 @@ def user_add():
590
  return get_json_result(
591
  data=False,
592
  message=f"Email: {email_address} has already registered!",
593
- code=RetCode.OPERATING_ERROR,
594
  )
595
 
596
  # Construct user info data
@@ -625,7 +612,7 @@ def user_add():
625
  return get_json_result(
626
  data=False,
627
  message=f"User registration failure, error: {str(e)}",
628
- code=RetCode.EXCEPTION_ERROR,
629
  )
630
 
631
 
 
38
  datetime_format,
39
  )
40
  from api.db import UserTenantRole, FileType
41
+ from api import settings
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  from api.db.services.user_service import UserService, TenantService, UserTenantService
43
  from api.db.services.file_service import FileService
44
  from api.utils.api_utils import get_json_result, construct_response
 
77
  """
78
  if not request.json:
79
  return get_json_result(
80
+ data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="Unauthorized!"
81
  )
82
 
83
  email = request.json.get("email", "")
 
85
  if not users:
86
  return get_json_result(
87
  data=False,
88
+ code=settings.RetCode.AUTHENTICATION_ERROR,
89
  message=f"Email: {email} is not registered!",
90
  )
91
 
 
94
  password = decrypt(password)
95
  except BaseException:
96
  return get_json_result(
97
+ data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password"
98
  )
99
 
100
  user = UserService.query_user(email, password)
 
110
  else:
111
  return get_json_result(
112
  data=False,
113
+ code=settings.RetCode.AUTHENTICATION_ERROR,
114
  message="Email and password do not match!",
115
  )
116
 
 
137
  import requests
138
 
139
  res = requests.post(
140
+ settings.GITHUB_OAUTH.get("url"),
141
  data={
142
+ "client_id": settings.GITHUB_OAUTH.get("client_id"),
143
+ "client_secret": settings.GITHUB_OAUTH.get("secret_key"),
144
  "code": request.args.get("code"),
145
  },
146
  headers={"Accept": "application/json"},
 
222
  import requests
223
 
224
  app_access_token_res = requests.post(
225
+ settings.FEISHU_OAUTH.get("app_access_token_url"),
226
  data=json.dumps(
227
  {
228
+ "app_id": settings.FEISHU_OAUTH.get("app_id"),
229
+ "app_secret": settings.FEISHU_OAUTH.get("app_secret"),
230
  }
231
  ),
232
  headers={"Content-Type": "application/json; charset=utf-8"},
 
236
  return redirect("/?error=%s" % app_access_token_res)
237
 
238
  res = requests.post(
239
+ settings.FEISHU_OAUTH.get("user_access_token_url"),
240
  data=json.dumps(
241
  {
242
+ "grant_type": settings.FEISHU_OAUTH.get("grant_type"),
243
  "code": request.args.get("code"),
244
  }
245
  ),
 
392
  if request_data.get("password"):
393
  new_password = request_data.get("new_password")
394
  if not check_password_hash(
395
+ current_user.password, decrypt(request_data["password"])
396
  ):
397
  return get_json_result(
398
  data=False,
399
+ code=settings.RetCode.AUTHENTICATION_ERROR,
400
  message="Password error!",
401
  )
402
 
 
425
  except Exception as e:
426
  logging.exception(e)
427
  return get_json_result(
428
+ data=False, message="Update failure!", code=settings.RetCode.EXCEPTION_ERROR
429
  )
430
 
431
 
 
484
  tenant = {
485
  "id": user_id,
486
  "name": user["nickname"] + "‘s Kingdom",
487
+ "llm_id": settings.CHAT_MDL,
488
+ "embd_id": settings.EMBEDDING_MDL,
489
+ "asr_id": settings.ASR_MDL,
490
+ "parser_ids": settings.PARSERS,
491
+ "img2txt_id": settings.IMAGE2TEXT_MDL,
492
+ "rerank_id": settings.RERANK_MDL,
493
  }
494
  usr_tenant = {
495
  "tenant_id": user_id,
 
509
  "location": "",
510
  }
511
  tenant_llm = []
512
+ for llm in LLMService.query(fid=settings.LLM_FACTORY):
513
  tenant_llm.append(
514
  {
515
  "tenant_id": user_id,
516
+ "llm_factory": settings.LLM_FACTORY,
517
  "llm_name": llm.llm_name,
518
  "model_type": llm.model_type,
519
+ "api_key": settings.API_KEY,
520
+ "api_base": settings.LLM_BASE_URL,
521
  }
522
  )
523
 
 
569
  return get_json_result(
570
  data=False,
571
  message=f"Invalid email address: {email_address}!",
572
+ code=settings.RetCode.OPERATING_ERROR,
573
  )
574
 
575
  # Check if the email address is already used
 
577
  return get_json_result(
578
  data=False,
579
  message=f"Email: {email_address} has already registered!",
580
+ code=settings.RetCode.OPERATING_ERROR,
581
  )
582
 
583
  # Construct user info data
 
612
  return get_json_result(
613
  data=False,
614
  message=f"User registration failure, error: {str(e)}",
615
+ code=settings.RetCode.EXCEPTION_ERROR,
616
  )
617
 
618
 
api/db/db_models.py CHANGED
@@ -31,7 +31,7 @@ from peewee import (
31
  )
32
  from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
33
  from api.db import SerializedType, ParserType
34
- from api.settings import DATABASE, SECRET_KEY, DATABASE_TYPE
35
  from api import utils
36
 
37
  def singleton(cls, *args, **kw):
@@ -62,7 +62,7 @@ class TextFieldType(Enum):
62
 
63
 
64
  class LongTextField(TextField):
65
- field_type = TextFieldType[DATABASE_TYPE.upper()].value
66
 
67
 
68
  class JSONField(LongTextField):
@@ -282,9 +282,9 @@ class DatabaseMigrator(Enum):
282
  @singleton
283
  class BaseDataBase:
284
  def __init__(self):
285
- database_config = DATABASE.copy()
286
  db_name = database_config.pop("name")
287
- self.database_connection = PooledDatabase[DATABASE_TYPE.upper()].value(db_name, **database_config)
288
  logging.info('init database on cluster mode successfully')
289
 
290
  class PostgresDatabaseLock:
@@ -385,7 +385,7 @@ class DatabaseLock(Enum):
385
 
386
 
387
  DB = BaseDataBase().database_connection
388
- DB.lock = DatabaseLock[DATABASE_TYPE.upper()].value
389
 
390
 
391
  def close_connection():
@@ -476,7 +476,7 @@ class User(DataBaseModel, UserMixin):
476
  return self.email
477
 
478
  def get_id(self):
479
- jwt = Serializer(secret_key=SECRET_KEY)
480
  return jwt.dumps(str(self.access_token))
481
 
482
  class Meta:
@@ -977,7 +977,7 @@ class CanvasTemplate(DataBaseModel):
977
 
978
  def migrate_db():
979
  with DB.transaction():
980
- migrator = DatabaseMigrator[DATABASE_TYPE.upper()].value(DB)
981
  try:
982
  migrate(
983
  migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="",
 
31
  )
32
  from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
33
  from api.db import SerializedType, ParserType
34
+ from api import settings
35
  from api import utils
36
 
37
  def singleton(cls, *args, **kw):
 
62
 
63
 
64
  class LongTextField(TextField):
65
+ field_type = TextFieldType[settings.DATABASE_TYPE.upper()].value
66
 
67
 
68
  class JSONField(LongTextField):
 
282
  @singleton
283
  class BaseDataBase:
284
  def __init__(self):
285
+ database_config = settings.DATABASE.copy()
286
  db_name = database_config.pop("name")
287
+ self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
288
  logging.info('init database on cluster mode successfully')
289
 
290
  class PostgresDatabaseLock:
 
385
 
386
 
387
  DB = BaseDataBase().database_connection
388
+ DB.lock = DatabaseLock[settings.DATABASE_TYPE.upper()].value
389
 
390
 
391
  def close_connection():
 
476
  return self.email
477
 
478
  def get_id(self):
479
+ jwt = Serializer(secret_key=settings.SECRET_KEY)
480
  return jwt.dumps(str(self.access_token))
481
 
482
  class Meta:
 
977
 
978
  def migrate_db():
979
  with DB.transaction():
980
+ migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
981
  try:
982
  migrate(
983
  migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="",
api/db/init_data.py CHANGED
@@ -29,7 +29,7 @@ from api.db.services.document_service import DocumentService
29
  from api.db.services.knowledgebase_service import KnowledgebaseService
30
  from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
31
  from api.db.services.user_service import TenantService, UserTenantService
32
- from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY, LLM_BASE_URL
33
  from api.utils.file_utils import get_project_base_directory
34
 
35
 
@@ -51,11 +51,11 @@ def init_superuser():
51
  tenant = {
52
  "id": user_info["id"],
53
  "name": user_info["nickname"] + "‘s Kingdom",
54
- "llm_id": CHAT_MDL,
55
- "embd_id": EMBEDDING_MDL,
56
- "asr_id": ASR_MDL,
57
- "parser_ids": PARSERS,
58
- "img2txt_id": IMAGE2TEXT_MDL
59
  }
60
  usr_tenant = {
61
  "tenant_id": user_info["id"],
@@ -64,10 +64,11 @@ def init_superuser():
64
  "role": UserTenantRole.OWNER
65
  }
66
  tenant_llm = []
67
- for llm in LLMService.query(fid=LLM_FACTORY):
68
  tenant_llm.append(
69
- {"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type": llm.model_type,
70
- "api_key": API_KEY, "api_base": LLM_BASE_URL})
 
71
 
72
  if not UserService.save(**user_info):
73
  logging.error("can't init admin.")
@@ -80,7 +81,7 @@ def init_superuser():
80
 
81
  chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
82
  msg = chat_mdl.chat(system="", history=[
83
- {"role": "user", "content": "Hello!"}], gen_conf={})
84
  if msg.find("ERROR: ") == 0:
85
  logging.error(
86
  "'{}' dosen't work. {}".format(
@@ -179,7 +180,7 @@ def init_web_data():
179
  start_time = time.time()
180
 
181
  init_llm_factory()
182
- #if not UserService.get_all().count():
183
  # init_superuser()
184
 
185
  add_graph_templates()
 
29
  from api.db.services.knowledgebase_service import KnowledgebaseService
30
  from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
31
  from api.db.services.user_service import TenantService, UserTenantService
32
+ from api import settings
33
  from api.utils.file_utils import get_project_base_directory
34
 
35
 
 
51
  tenant = {
52
  "id": user_info["id"],
53
  "name": user_info["nickname"] + "‘s Kingdom",
54
+ "llm_id": settings.CHAT_MDL,
55
+ "embd_id": settings.EMBEDDING_MDL,
56
+ "asr_id": settings.ASR_MDL,
57
+ "parser_ids": settings.PARSERS,
58
+ "img2txt_id": settings.IMAGE2TEXT_MDL
59
  }
60
  usr_tenant = {
61
  "tenant_id": user_info["id"],
 
64
  "role": UserTenantRole.OWNER
65
  }
66
  tenant_llm = []
67
+ for llm in LLMService.query(fid=settings.LLM_FACTORY):
68
  tenant_llm.append(
69
+ {"tenant_id": user_info["id"], "llm_factory": settings.LLM_FACTORY, "llm_name": llm.llm_name,
70
+ "model_type": llm.model_type,
71
+ "api_key": settings.API_KEY, "api_base": settings.LLM_BASE_URL})
72
 
73
  if not UserService.save(**user_info):
74
  logging.error("can't init admin.")
 
81
 
82
  chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
83
  msg = chat_mdl.chat(system="", history=[
84
+ {"role": "user", "content": "Hello!"}], gen_conf={})
85
  if msg.find("ERROR: ") == 0:
86
  logging.error(
87
  "'{}' dosen't work. {}".format(
 
180
  start_time = time.time()
181
 
182
  init_llm_factory()
183
+ # if not UserService.get_all().count():
184
  # init_superuser()
185
 
186
  add_graph_templates()
api/db/services/dialog_service.py CHANGED
@@ -27,7 +27,7 @@ from api.db.db_models import Dialog, Conversation,DB
27
  from api.db.services.common_service import CommonService
28
  from api.db.services.knowledgebase_service import KnowledgebaseService
29
  from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
30
- from api.settings import retrievaler, kg_retrievaler
31
  from rag.app.resume import forbidden_select_fields4resume
32
  from rag.nlp.search import index_name
33
  from rag.utils import rmSpace, num_tokens_from_string, encoder
@@ -152,7 +152,7 @@ def chat(dialog, messages, stream=True, **kwargs):
152
  return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
153
 
154
  is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
155
- retr = retrievaler if not is_kg else kg_retrievaler
156
 
157
  questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
158
  attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
@@ -342,7 +342,7 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
342
 
343
  logging.debug(f"{question} get SQL(refined): {sql}")
344
  tried_times += 1
345
- return retrievaler.sql_retrieval(sql, format="json"), sql
346
 
347
  tbl, sql = get_table()
348
  if tbl is None:
@@ -596,7 +596,7 @@ def ask(question, kb_ids, tenant_id):
596
  embd_nms = list(set([kb.embd_id for kb in kbs]))
597
 
598
  is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
599
- retr = retrievaler if not is_kg else kg_retrievaler
600
 
601
  embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
602
  chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
 
27
  from api.db.services.common_service import CommonService
28
  from api.db.services.knowledgebase_service import KnowledgebaseService
29
  from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
30
+ from api import settings
31
  from rag.app.resume import forbidden_select_fields4resume
32
  from rag.nlp.search import index_name
33
  from rag.utils import rmSpace, num_tokens_from_string, encoder
 
152
  return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
153
 
154
  is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
155
+ retr = settings.retrievaler if not is_kg else settings.kg_retrievaler
156
 
157
  questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
158
  attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
 
342
 
343
  logging.debug(f"{question} get SQL(refined): {sql}")
344
  tried_times += 1
345
+ return settings.retrievaler.sql_retrieval(sql, format="json"), sql
346
 
347
  tbl, sql = get_table()
348
  if tbl is None:
 
596
  embd_nms = list(set([kb.embd_id for kb in kbs]))
597
 
598
  is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
599
+ retr = settings.retrievaler if not is_kg else settings.kg_retrievaler
600
 
601
  embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
602
  chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
api/db/services/document_service.py CHANGED
@@ -26,7 +26,7 @@ from io import BytesIO
26
  from peewee import fn
27
 
28
  from api.db.db_utils import bulk_insert_into_db
29
- from api.settings import docStoreConn
30
  from api.utils import current_timestamp, get_format_time, get_uuid
31
  from graphrag.mind_map_extractor import MindMapExtractor
32
  from rag.settings import SVR_QUEUE_NAME
@@ -108,7 +108,7 @@ class DocumentService(CommonService):
108
  @classmethod
109
  @DB.connection_context()
110
  def remove_document(cls, doc, tenant_id):
111
- docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
112
  cls.clear_chunk_num(doc.id)
113
  return cls.delete_by_id(doc.id)
114
 
@@ -553,10 +553,10 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
553
  d["q_%d_vec" % len(v)] = v
554
  for b in range(0, len(cks), es_bulk_size):
555
  if try_create_idx:
556
- if not docStoreConn.indexExist(idxnm, kb_id):
557
- docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
558
  try_create_idx = False
559
- docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
560
 
561
  DocumentService.increment_chunk_num(
562
  doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
 
26
  from peewee import fn
27
 
28
  from api.db.db_utils import bulk_insert_into_db
29
+ from api import settings
30
  from api.utils import current_timestamp, get_format_time, get_uuid
31
  from graphrag.mind_map_extractor import MindMapExtractor
32
  from rag.settings import SVR_QUEUE_NAME
 
108
  @classmethod
109
  @DB.connection_context()
110
  def remove_document(cls, doc, tenant_id):
111
+ settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
112
  cls.clear_chunk_num(doc.id)
113
  return cls.delete_by_id(doc.id)
114
 
 
553
  d["q_%d_vec" % len(v)] = v
554
  for b in range(0, len(cks), es_bulk_size):
555
  if try_create_idx:
556
+ if not settings.docStoreConn.indexExist(idxnm, kb_id):
557
+ settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
558
  try_create_idx = False
559
+ settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
560
 
561
  DocumentService.increment_chunk_num(
562
  doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
api/ragflow_server.py CHANGED
@@ -33,12 +33,10 @@ import traceback
33
  from concurrent.futures import ThreadPoolExecutor
34
 
35
  from werkzeug.serving import run_simple
 
36
  from api.apps import app
37
  from api.db.runtime_config import RuntimeConfig
38
  from api.db.services.document_service import DocumentService
39
- from api.settings import (
40
- HOST, HTTP_PORT
41
- )
42
  from api import utils
43
 
44
  from api.db.db_models import init_database_tables as init_web_db
@@ -72,6 +70,7 @@ if __name__ == '__main__':
72
  f'project base: {utils.file_utils.get_project_base_directory()}'
73
  )
74
  show_configs()
 
75
 
76
  # init db
77
  init_web_db()
@@ -96,7 +95,7 @@ if __name__ == '__main__':
96
  logging.info("run on debug mode")
97
 
98
  RuntimeConfig.init_env()
99
- RuntimeConfig.init_config(JOB_SERVER_HOST=HOST, HTTP_PORT=HTTP_PORT)
100
 
101
  thread = ThreadPoolExecutor(max_workers=1)
102
  thread.submit(update_progress)
@@ -105,8 +104,8 @@ if __name__ == '__main__':
105
  try:
106
  logging.info("RAGFlow HTTP server start...")
107
  run_simple(
108
- hostname=HOST,
109
- port=HTTP_PORT,
110
  application=app,
111
  threaded=True,
112
  use_reloader=RuntimeConfig.DEBUG,
 
33
  from concurrent.futures import ThreadPoolExecutor
34
 
35
  from werkzeug.serving import run_simple
36
+ from api import settings
37
  from api.apps import app
38
  from api.db.runtime_config import RuntimeConfig
39
  from api.db.services.document_service import DocumentService
 
 
 
40
  from api import utils
41
 
42
  from api.db.db_models import init_database_tables as init_web_db
 
70
  f'project base: {utils.file_utils.get_project_base_directory()}'
71
  )
72
  show_configs()
73
+ settings.init_settings()
74
 
75
  # init db
76
  init_web_db()
 
95
  logging.info("run on debug mode")
96
 
97
  RuntimeConfig.init_env()
98
+ RuntimeConfig.init_config(JOB_SERVER_HOST=settings.HOST_IP, HTTP_PORT=settings.HOST_PORT)
99
 
100
  thread = ThreadPoolExecutor(max_workers=1)
101
  thread.submit(update_progress)
 
104
  try:
105
  logging.info("RAGFlow HTTP server start...")
106
  run_simple(
107
+ hostname=settings.HOST_IP,
108
+ port=settings.HOST_PORT,
109
  application=app,
110
  threaded=True,
111
  use_reloader=RuntimeConfig.DEBUG,
api/settings.py CHANGED
@@ -30,114 +30,157 @@ LIGHTEN = int(os.environ.get('LIGHTEN', "0"))
30
 
31
  REQUEST_WAIT_SEC = 2
32
  REQUEST_MAX_WAIT_SEC = 300
33
-
34
- LLM = get_base_config("user_default_llm", {})
35
- LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
36
- LLM_BASE_URL = LLM.get("base_url")
37
-
38
- CHAT_MDL = EMBEDDING_MDL = RERANK_MDL = ASR_MDL = IMAGE2TEXT_MDL = ""
39
- if not LIGHTEN:
40
- default_llm = {
41
- "Tongyi-Qianwen": {
42
- "chat_model": "qwen-plus",
43
- "embedding_model": "text-embedding-v2",
44
- "image2text_model": "qwen-vl-max",
45
- "asr_model": "paraformer-realtime-8k-v1",
46
- },
47
- "OpenAI": {
48
- "chat_model": "gpt-3.5-turbo",
49
- "embedding_model": "text-embedding-ada-002",
50
- "image2text_model": "gpt-4-vision-preview",
51
- "asr_model": "whisper-1",
52
- },
53
- "Azure-OpenAI": {
54
- "chat_model": "gpt-35-turbo",
55
- "embedding_model": "text-embedding-ada-002",
56
- "image2text_model": "gpt-4-vision-preview",
57
- "asr_model": "whisper-1",
58
- },
59
- "ZHIPU-AI": {
60
- "chat_model": "glm-3-turbo",
61
- "embedding_model": "embedding-2",
62
- "image2text_model": "glm-4v",
63
- "asr_model": "",
64
- },
65
- "Ollama": {
66
- "chat_model": "qwen-14B-chat",
67
- "embedding_model": "flag-embedding",
68
- "image2text_model": "",
69
- "asr_model": "",
70
- },
71
- "Moonshot": {
72
- "chat_model": "moonshot-v1-8k",
73
- "embedding_model": "",
74
- "image2text_model": "",
75
- "asr_model": "",
76
- },
77
- "DeepSeek": {
78
- "chat_model": "deepseek-chat",
79
- "embedding_model": "",
80
- "image2text_model": "",
81
- "asr_model": "",
82
- },
83
- "VolcEngine": {
84
- "chat_model": "",
85
- "embedding_model": "",
86
- "image2text_model": "",
87
- "asr_model": "",
88
- },
89
- "BAAI": {
90
- "chat_model": "",
91
- "embedding_model": "BAAI/bge-large-zh-v1.5",
92
- "image2text_model": "",
93
- "asr_model": "",
94
- "rerank_model": "BAAI/bge-reranker-v2-m3",
95
- }
96
- }
97
-
98
- if LLM_FACTORY:
99
- CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"] + f"@{LLM_FACTORY}"
100
- ASR_MDL = default_llm[LLM_FACTORY]["asr_model"] + f"@{LLM_FACTORY}"
101
- IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"] + f"@{LLM_FACTORY}"
102
- EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"] + "@BAAI"
103
- RERANK_MDL = default_llm["BAAI"]["rerank_model"] + "@BAAI"
104
-
105
- API_KEY = LLM.get("api_key", "")
106
- PARSERS = LLM.get(
107
- "parsers",
108
- "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email")
109
-
110
- HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
111
- HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
112
-
113
- SECRET_KEY = get_base_config(
114
- RAG_FLOW_SERVICE_NAME,
115
- {}).get("secret_key", str(date.today()))
116
 
117
  DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql')
118
  DATABASE = decrypt_database_config(name=DATABASE_TYPE)
119
 
120
  # authentication
121
- AUTHENTICATION_CONF = get_base_config("authentication", {})
122
 
123
  # client
124
- CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get(
125
- "client", {}).get(
126
- "switch", False)
127
- HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key")
128
- GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
129
- FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu")
130
-
131
- DOC_ENGINE = os.environ.get('DOC_ENGINE', "elasticsearch")
132
- if DOC_ENGINE == "elasticsearch":
133
- docStoreConn = rag.utils.es_conn.ESConnection()
134
- elif DOC_ENGINE == "infinity":
135
- docStoreConn = rag.utils.infinity_conn.InfinityConnection()
136
- else:
137
- raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
138
-
139
- retrievaler = search.Dealer(docStoreConn)
140
- kg_retrievaler = kg_search.KGSearch(docStoreConn)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
 
143
  class CustomEnum(Enum):
 
30
 
31
  REQUEST_WAIT_SEC = 2
32
  REQUEST_MAX_WAIT_SEC = 300
33
+ LLM = None
34
+ LLM_FACTORY = None
35
+ LLM_BASE_URL = None
36
+ CHAT_MDL = ""
37
+ EMBEDDING_MDL = ""
38
+ RERANK_MDL = ""
39
+ ASR_MDL = ""
40
+ IMAGE2TEXT_MDL = ""
41
+ API_KEY = None
42
+ PARSERS = None
43
+ HOST_IP = None
44
+ HOST_PORT = None
45
+ SECRET_KEY = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql')
48
  DATABASE = decrypt_database_config(name=DATABASE_TYPE)
49
 
50
  # authentication
51
+ AUTHENTICATION_CONF = None
52
 
53
  # client
54
+ CLIENT_AUTHENTICATION = None
55
+ HTTP_APP_KEY = None
56
+ GITHUB_OAUTH = None
57
+ FEISHU_OAUTH = None
58
+
59
+ DOC_ENGINE = None
60
+ docStoreConn = None
61
+
62
+ retrievaler = None
63
+ kg_retrievaler = None
64
+
65
+
66
+ def init_settings():
67
+ global LLM, LLM_FACTORY, LLM_BASE_URL
68
+ LLM = get_base_config("user_default_llm", {})
69
+ LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
70
+ LLM_BASE_URL = LLM.get("base_url")
71
+
72
+ global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
73
+ if not LIGHTEN:
74
+ default_llm = {
75
+ "Tongyi-Qianwen": {
76
+ "chat_model": "qwen-plus",
77
+ "embedding_model": "text-embedding-v2",
78
+ "image2text_model": "qwen-vl-max",
79
+ "asr_model": "paraformer-realtime-8k-v1",
80
+ },
81
+ "OpenAI": {
82
+ "chat_model": "gpt-3.5-turbo",
83
+ "embedding_model": "text-embedding-ada-002",
84
+ "image2text_model": "gpt-4-vision-preview",
85
+ "asr_model": "whisper-1",
86
+ },
87
+ "Azure-OpenAI": {
88
+ "chat_model": "gpt-35-turbo",
89
+ "embedding_model": "text-embedding-ada-002",
90
+ "image2text_model": "gpt-4-vision-preview",
91
+ "asr_model": "whisper-1",
92
+ },
93
+ "ZHIPU-AI": {
94
+ "chat_model": "glm-3-turbo",
95
+ "embedding_model": "embedding-2",
96
+ "image2text_model": "glm-4v",
97
+ "asr_model": "",
98
+ },
99
+ "Ollama": {
100
+ "chat_model": "qwen-14B-chat",
101
+ "embedding_model": "flag-embedding",
102
+ "image2text_model": "",
103
+ "asr_model": "",
104
+ },
105
+ "Moonshot": {
106
+ "chat_model": "moonshot-v1-8k",
107
+ "embedding_model": "",
108
+ "image2text_model": "",
109
+ "asr_model": "",
110
+ },
111
+ "DeepSeek": {
112
+ "chat_model": "deepseek-chat",
113
+ "embedding_model": "",
114
+ "image2text_model": "",
115
+ "asr_model": "",
116
+ },
117
+ "VolcEngine": {
118
+ "chat_model": "",
119
+ "embedding_model": "",
120
+ "image2text_model": "",
121
+ "asr_model": "",
122
+ },
123
+ "BAAI": {
124
+ "chat_model": "",
125
+ "embedding_model": "BAAI/bge-large-zh-v1.5",
126
+ "image2text_model": "",
127
+ "asr_model": "",
128
+ "rerank_model": "BAAI/bge-reranker-v2-m3",
129
+ }
130
+ }
131
+
132
+ if LLM_FACTORY:
133
+ CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"] + f"@{LLM_FACTORY}"
134
+ ASR_MDL = default_llm[LLM_FACTORY]["asr_model"] + f"@{LLM_FACTORY}"
135
+ IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"] + f"@{LLM_FACTORY}"
136
+ EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"] + "@BAAI"
137
+ RERANK_MDL = default_llm["BAAI"]["rerank_model"] + "@BAAI"
138
+
139
+ global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY
140
+ API_KEY = LLM.get("api_key", "")
141
+ PARSERS = LLM.get(
142
+ "parsers",
143
+ "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email")
144
+
145
+ HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
146
+ HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
147
+
148
+ SECRET_KEY = get_base_config(
149
+ RAG_FLOW_SERVICE_NAME,
150
+ {}).get("secret_key", str(date.today()))
151
+
152
+ global AUTHENTICATION_CONF, CLIENT_AUTHENTICATION, HTTP_APP_KEY, GITHUB_OAUTH, FEISHU_OAUTH
153
+ # authentication
154
+ AUTHENTICATION_CONF = get_base_config("authentication", {})
155
+
156
+ # client
157
+ CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get(
158
+ "client", {}).get(
159
+ "switch", False)
160
+ HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key")
161
+ GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
162
+ FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu")
163
+
164
+ global DOC_ENGINE, docStoreConn, retrievaler, kg_retrievaler
165
+ DOC_ENGINE = os.environ.get('DOC_ENGINE', "elasticsearch")
166
+ if DOC_ENGINE == "elasticsearch":
167
+ docStoreConn = rag.utils.es_conn.ESConnection()
168
+ elif DOC_ENGINE == "infinity":
169
+ docStoreConn = rag.utils.infinity_conn.InfinityConnection()
170
+ else:
171
+ raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
172
+
173
+ retrievaler = search.Dealer(docStoreConn)
174
+ kg_retrievaler = kg_search.KGSearch(docStoreConn)
175
+
176
+ def get_host_ip():
177
+ global HOST_IP
178
+ return HOST_IP
179
+
180
+
181
+ def get_host_port():
182
+ global HOST_PORT
183
+ return HOST_PORT
184
 
185
 
186
  class CustomEnum(Enum):
api/utils/api_utils.py CHANGED
@@ -34,11 +34,9 @@ from itsdangerous import URLSafeTimedSerializer
34
  from werkzeug.http import HTTP_STATUS_CODES
35
 
36
  from api.db.db_models import APIToken
37
- from api.settings import (
38
- REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
39
- CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
40
- )
41
- from api.settings import RetCode
42
  from api.utils import CustomJSONEncoder, get_uuid
43
  from api.utils import json_dumps
44
 
@@ -59,13 +57,13 @@ def request(**kwargs):
59
  {}).items()}
60
  prepped = requests.Request(**kwargs).prepare()
61
 
62
- if CLIENT_AUTHENTICATION and HTTP_APP_KEY and SECRET_KEY:
63
  timestamp = str(round(time() * 1000))
64
  nonce = str(uuid1())
65
- signature = b64encode(HMAC(SECRET_KEY.encode('ascii'), b'\n'.join([
66
  timestamp.encode('ascii'),
67
  nonce.encode('ascii'),
68
- HTTP_APP_KEY.encode('ascii'),
69
  prepped.path_url.encode('ascii'),
70
  prepped.body if kwargs.get('json') else b'',
71
  urlencode(
@@ -79,7 +77,7 @@ def request(**kwargs):
79
  prepped.headers.update({
80
  'TIMESTAMP': timestamp,
81
  'NONCE': nonce,
82
- 'APP-KEY': HTTP_APP_KEY,
83
  'SIGNATURE': signature,
84
  })
85
 
@@ -89,7 +87,7 @@ def request(**kwargs):
89
  def get_exponential_backoff_interval(retries, full_jitter=False):
90
  """Calculate the exponential backoff wait time."""
91
  # Will be zero if factor equals 0
92
- countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2 ** retries))
93
  # Full jitter according to
94
  # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
95
  if full_jitter:
@@ -98,7 +96,7 @@ def get_exponential_backoff_interval(retries, full_jitter=False):
98
  return max(0, countdown)
99
 
100
 
101
- def get_data_error_result(code=RetCode.DATA_ERROR,
102
  message='Sorry! Data missing!'):
103
  import re
104
  result_dict = {
@@ -126,8 +124,8 @@ def server_error_response(e):
126
  pass
127
  if len(e.args) > 1:
128
  return get_json_result(
129
- code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
130
- return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
131
 
132
 
133
  def error_response(response_code, message=None):
@@ -168,7 +166,7 @@ def validate_request(*args, **kwargs):
168
  error_string += "required argument values: {}".format(
169
  ",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
170
  return get_json_result(
171
- code=RetCode.ARGUMENT_ERROR, message=error_string)
172
  return func(*_args, **_kwargs)
173
 
174
  return decorated_function
@@ -193,7 +191,7 @@ def send_file_in_mem(data, filename):
193
  return send_file(f, as_attachment=True, attachment_filename=filename)
194
 
195
 
196
- def get_json_result(code=RetCode.SUCCESS, message='success', data=None):
197
  response = {"code": code, "message": message, "data": data}
198
  return jsonify(response)
199
 
@@ -204,7 +202,7 @@ def apikey_required(func):
204
  objs = APIToken.query(token=token)
205
  if not objs:
206
  return build_error_result(
207
- message='API-KEY is invalid!', code=RetCode.FORBIDDEN
208
  )
209
  kwargs['tenant_id'] = objs[0].tenant_id
210
  return func(*args, **kwargs)
@@ -212,14 +210,14 @@ def apikey_required(func):
212
  return decorated_function
213
 
214
 
215
- def build_error_result(code=RetCode.FORBIDDEN, message='success'):
216
  response = {"code": code, "message": message}
217
  response = jsonify(response)
218
  response.status_code = code
219
  return response
220
 
221
 
222
- def construct_response(code=RetCode.SUCCESS,
223
  message='success', data=None, auth=None):
224
  result_dict = {"code": code, "message": message, "data": data}
225
  response_dict = {}
@@ -239,7 +237,7 @@ def construct_response(code=RetCode.SUCCESS,
239
  return response
240
 
241
 
242
- def construct_result(code=RetCode.DATA_ERROR, message='data is missing'):
243
  import re
244
  result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)}
245
  response = {}
@@ -251,7 +249,7 @@ def construct_result(code=RetCode.DATA_ERROR, message='data is missing'):
251
  return jsonify(response)
252
 
253
 
254
- def construct_json_result(code=RetCode.SUCCESS, message='success', data=None):
255
  if data is None:
256
  return jsonify({"code": code, "message": message})
257
  else:
@@ -262,12 +260,12 @@ def construct_error_response(e):
262
  logging.exception(e)
263
  try:
264
  if e.code == 401:
265
- return construct_json_result(code=RetCode.UNAUTHORIZED, message=repr(e))
266
  except BaseException:
267
  pass
268
  if len(e.args) > 1:
269
- return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
270
- return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
271
 
272
 
273
  def token_required(func):
@@ -280,7 +278,7 @@ def token_required(func):
280
  objs = APIToken.query(token=token)
281
  if not objs:
282
  return get_json_result(
283
- data=False, message='Token is not valid!', code=RetCode.AUTHENTICATION_ERROR
284
  )
285
  kwargs['tenant_id'] = objs[0].tenant_id
286
  return func(*args, **kwargs)
@@ -288,7 +286,7 @@ def token_required(func):
288
  return decorated_function
289
 
290
 
291
- def get_result(code=RetCode.SUCCESS, message="", data=None):
292
  if code == 0:
293
  if data is not None:
294
  response = {"code": code, "data": data}
@@ -299,7 +297,7 @@ def get_result(code=RetCode.SUCCESS, message="", data=None):
299
  return jsonify(response)
300
 
301
 
302
- def get_error_data_result(message='Sorry! Data missing!', code=RetCode.DATA_ERROR,
303
  ):
304
  import re
305
  result_dict = {
 
34
  from werkzeug.http import HTTP_STATUS_CODES
35
 
36
  from api.db.db_models import APIToken
37
+ from api import settings
38
+
39
+ from api import settings
 
 
40
  from api.utils import CustomJSONEncoder, get_uuid
41
  from api.utils import json_dumps
42
 
 
57
  {}).items()}
58
  prepped = requests.Request(**kwargs).prepare()
59
 
60
+ if settings.CLIENT_AUTHENTICATION and settings.HTTP_APP_KEY and settings.SECRET_KEY:
61
  timestamp = str(round(time() * 1000))
62
  nonce = str(uuid1())
63
+ signature = b64encode(HMAC(settings.SECRET_KEY.encode('ascii'), b'\n'.join([
64
  timestamp.encode('ascii'),
65
  nonce.encode('ascii'),
66
+ settings.HTTP_APP_KEY.encode('ascii'),
67
  prepped.path_url.encode('ascii'),
68
  prepped.body if kwargs.get('json') else b'',
69
  urlencode(
 
77
  prepped.headers.update({
78
  'TIMESTAMP': timestamp,
79
  'NONCE': nonce,
80
+ 'APP-KEY': settings.HTTP_APP_KEY,
81
  'SIGNATURE': signature,
82
  })
83
 
 
87
  def get_exponential_backoff_interval(retries, full_jitter=False):
88
  """Calculate the exponential backoff wait time."""
89
  # Will be zero if factor equals 0
90
+ countdown = min(settings.REQUEST_MAX_WAIT_SEC, settings.REQUEST_WAIT_SEC * (2 ** retries))
91
  # Full jitter according to
92
  # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
93
  if full_jitter:
 
96
  return max(0, countdown)
97
 
98
 
99
+ def get_data_error_result(code=settings.RetCode.DATA_ERROR,
100
  message='Sorry! Data missing!'):
101
  import re
102
  result_dict = {
 
124
  pass
125
  if len(e.args) > 1:
126
  return get_json_result(
127
+ code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
128
+ return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))
129
 
130
 
131
  def error_response(response_code, message=None):
 
166
  error_string += "required argument values: {}".format(
167
  ",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
168
  return get_json_result(
169
+ code=settings.RetCode.ARGUMENT_ERROR, message=error_string)
170
  return func(*_args, **_kwargs)
171
 
172
  return decorated_function
 
191
  return send_file(f, as_attachment=True, attachment_filename=filename)
192
 
193
 
194
+ def get_json_result(code=settings.RetCode.SUCCESS, message='success', data=None):
195
  response = {"code": code, "message": message, "data": data}
196
  return jsonify(response)
197
 
 
202
  objs = APIToken.query(token=token)
203
  if not objs:
204
  return build_error_result(
205
+ message='API-KEY is invalid!', code=settings.RetCode.FORBIDDEN
206
  )
207
  kwargs['tenant_id'] = objs[0].tenant_id
208
  return func(*args, **kwargs)
 
210
  return decorated_function
211
 
212
 
213
+ def build_error_result(code=settings.RetCode.FORBIDDEN, message='success'):
214
  response = {"code": code, "message": message}
215
  response = jsonify(response)
216
  response.status_code = code
217
  return response
218
 
219
 
220
+ def construct_response(code=settings.RetCode.SUCCESS,
221
  message='success', data=None, auth=None):
222
  result_dict = {"code": code, "message": message, "data": data}
223
  response_dict = {}
 
237
  return response
238
 
239
 
240
+ def construct_result(code=settings.RetCode.DATA_ERROR, message='data is missing'):
241
  import re
242
  result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)}
243
  response = {}
 
249
  return jsonify(response)
250
 
251
 
252
+ def construct_json_result(code=settings.RetCode.SUCCESS, message='success', data=None):
253
  if data is None:
254
  return jsonify({"code": code, "message": message})
255
  else:
 
260
  logging.exception(e)
261
  try:
262
  if e.code == 401:
263
+ return construct_json_result(code=settings.RetCode.UNAUTHORIZED, message=repr(e))
264
  except BaseException:
265
  pass
266
  if len(e.args) > 1:
267
+ return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
268
+ return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))
269
 
270
 
271
  def token_required(func):
 
278
  objs = APIToken.query(token=token)
279
  if not objs:
280
  return get_json_result(
281
+ data=False, message='Token is not valid!', code=settings.RetCode.AUTHENTICATION_ERROR
282
  )
283
  kwargs['tenant_id'] = objs[0].tenant_id
284
  return func(*args, **kwargs)
 
286
  return decorated_function
287
 
288
 
289
+ def get_result(code=settings.RetCode.SUCCESS, message="", data=None):
290
  if code == 0:
291
  if data is not None:
292
  response = {"code": code, "data": data}
 
297
  return jsonify(response)
298
 
299
 
300
+ def get_error_data_result(message='Sorry! Data missing!', code=settings.RetCode.DATA_ERROR,
301
  ):
302
  import re
303
  result_dict = {
deepdoc/parser/pdf_parser.py CHANGED
@@ -24,7 +24,7 @@ import numpy as np
24
  from timeit import default_timer as timer
25
  from pypdf import PdfReader as pdf2_read
26
 
27
- from api.settings import LIGHTEN
28
  from api.utils.file_utils import get_project_base_directory
29
  from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer
30
  from rag.nlp import rag_tokenizer
@@ -41,7 +41,7 @@ class RAGFlowPdfParser:
41
  self.tbl_det = TableStructureRecognizer()
42
 
43
  self.updown_cnt_mdl = xgb.Booster()
44
- if not LIGHTEN:
45
  try:
46
  import torch
47
  if torch.cuda.is_available():
 
24
  from timeit import default_timer as timer
25
  from pypdf import PdfReader as pdf2_read
26
 
27
+ from api import settings
28
  from api.utils.file_utils import get_project_base_directory
29
  from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer
30
  from rag.nlp import rag_tokenizer
 
41
  self.tbl_det = TableStructureRecognizer()
42
 
43
  self.updown_cnt_mdl = xgb.Booster()
44
+ if not settings.LIGHTEN:
45
  try:
46
  import torch
47
  if torch.cuda.is_available():
graphrag/claim_extractor.py CHANGED
@@ -252,13 +252,13 @@ if __name__ == "__main__":
252
 
253
  from api.db import LLMType
254
  from api.db.services.llm_service import LLMBundle
255
- from api.settings import retrievaler
256
  from api.db.services.knowledgebase_service import KnowledgebaseService
257
 
258
  kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
259
 
260
  ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
261
- docs = [d["content_with_weight"] for d in retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])]
262
  info = {
263
  "input_text": docs,
264
  "entity_specs": "organization, person",
 
252
 
253
  from api.db import LLMType
254
  from api.db.services.llm_service import LLMBundle
255
+ from api import settings
256
  from api.db.services.knowledgebase_service import KnowledgebaseService
257
 
258
  kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
259
 
260
  ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
261
+ docs = [d["content_with_weight"] for d in settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])]
262
  info = {
263
  "input_text": docs,
264
  "entity_specs": "organization, person",
graphrag/smoke.py CHANGED
@@ -30,14 +30,14 @@ if __name__ == "__main__":
30
 
31
  from api.db import LLMType
32
  from api.db.services.llm_service import LLMBundle
33
- from api.settings import retrievaler
34
  from api.db.services.knowledgebase_service import KnowledgebaseService
35
 
36
  kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
37
 
38
  ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
39
  docs = [d["content_with_weight"] for d in
40
- retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])]
41
  graph = ex(docs)
42
 
43
  er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT))
 
30
 
31
  from api.db import LLMType
32
  from api.db.services.llm_service import LLMBundle
33
+ from api import settings
34
  from api.db.services.knowledgebase_service import KnowledgebaseService
35
 
36
  kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
37
 
38
  ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
39
  docs = [d["content_with_weight"] for d in
40
+ settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])]
41
  graph = ex(docs)
42
 
43
  er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT))
rag/benchmark.py CHANGED
@@ -23,7 +23,7 @@ from collections import defaultdict
23
  from api.db import LLMType
24
  from api.db.services.llm_service import LLMBundle
25
  from api.db.services.knowledgebase_service import KnowledgebaseService
26
- from api.settings import retrievaler, docStoreConn
27
  from api.utils import get_uuid
28
  from rag.nlp import tokenize, search
29
  from ranx import evaluate
@@ -52,7 +52,7 @@ class Benchmark:
52
  run = defaultdict(dict)
53
  query_list = list(qrels.keys())
54
  for query in query_list:
55
- ranks = retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
56
  0.0, self.vector_similarity_weight)
57
  if len(ranks["chunks"]) == 0:
58
  print(f"deleted query: {query}")
@@ -81,9 +81,9 @@ class Benchmark:
81
  def init_index(self, vector_size: int):
82
  if self.initialized_index:
83
  return
84
- if docStoreConn.indexExist(self.index_name, self.kb_id):
85
- docStoreConn.deleteIdx(self.index_name, self.kb_id)
86
- docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
87
  self.initialized_index = True
88
 
89
  def ms_marco_index(self, file_path, index_name):
@@ -118,13 +118,13 @@ class Benchmark:
118
  docs_count += len(docs)
119
  docs, vector_size = self.embedding(docs)
120
  self.init_index(vector_size)
121
- docStoreConn.insert(docs, self.index_name, self.kb_id)
122
  docs = []
123
 
124
  if docs:
125
  docs, vector_size = self.embedding(docs)
126
  self.init_index(vector_size)
127
- docStoreConn.insert(docs, self.index_name, self.kb_id)
128
  return qrels, texts
129
 
130
  def trivia_qa_index(self, file_path, index_name):
@@ -159,12 +159,12 @@ class Benchmark:
159
  docs_count += len(docs)
160
  docs, vector_size = self.embedding(docs)
161
  self.init_index(vector_size)
162
- docStoreConn.insert(docs,self.index_name)
163
  docs = []
164
 
165
  docs, vector_size = self.embedding(docs)
166
  self.init_index(vector_size)
167
- docStoreConn.insert(docs, self.index_name)
168
  return qrels, texts
169
 
170
  def miracl_index(self, file_path, corpus_path, index_name):
@@ -214,12 +214,12 @@ class Benchmark:
214
  docs_count += len(docs)
215
  docs, vector_size = self.embedding(docs)
216
  self.init_index(vector_size)
217
- docStoreConn.insert(docs, self.index_name)
218
  docs = []
219
 
220
  docs, vector_size = self.embedding(docs)
221
  self.init_index(vector_size)
222
- docStoreConn.insert(docs, self.index_name)
223
  return qrels, texts
224
 
225
  def save_results(self, qrels, run, texts, dataset, file_path):
 
23
  from api.db import LLMType
24
  from api.db.services.llm_service import LLMBundle
25
  from api.db.services.knowledgebase_service import KnowledgebaseService
26
+ from api import settings
27
  from api.utils import get_uuid
28
  from rag.nlp import tokenize, search
29
  from ranx import evaluate
 
52
  run = defaultdict(dict)
53
  query_list = list(qrels.keys())
54
  for query in query_list:
55
+ ranks = settings.retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
56
  0.0, self.vector_similarity_weight)
57
  if len(ranks["chunks"]) == 0:
58
  print(f"deleted query: {query}")
 
81
  def init_index(self, vector_size: int):
82
  if self.initialized_index:
83
  return
84
+ if settings.docStoreConn.indexExist(self.index_name, self.kb_id):
85
+ settings.docStoreConn.deleteIdx(self.index_name, self.kb_id)
86
+ settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
87
  self.initialized_index = True
88
 
89
  def ms_marco_index(self, file_path, index_name):
 
118
  docs_count += len(docs)
119
  docs, vector_size = self.embedding(docs)
120
  self.init_index(vector_size)
121
+ settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
122
  docs = []
123
 
124
  if docs:
125
  docs, vector_size = self.embedding(docs)
126
  self.init_index(vector_size)
127
+ settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
128
  return qrels, texts
129
 
130
  def trivia_qa_index(self, file_path, index_name):
 
159
  docs_count += len(docs)
160
  docs, vector_size = self.embedding(docs)
161
  self.init_index(vector_size)
162
+ settings.docStoreConn.insert(docs,self.index_name)
163
  docs = []
164
 
165
  docs, vector_size = self.embedding(docs)
166
  self.init_index(vector_size)
167
+ settings.docStoreConn.insert(docs, self.index_name)
168
  return qrels, texts
169
 
170
  def miracl_index(self, file_path, corpus_path, index_name):
 
214
  docs_count += len(docs)
215
  docs, vector_size = self.embedding(docs)
216
  self.init_index(vector_size)
217
+ settings.docStoreConn.insert(docs, self.index_name)
218
  docs = []
219
 
220
  docs, vector_size = self.embedding(docs)
221
  self.init_index(vector_size)
222
+ settings.docStoreConn.insert(docs, self.index_name)
223
  return qrels, texts
224
 
225
  def save_results(self, qrels, run, texts, dataset, file_path):
rag/llm/embedding_model.py CHANGED
@@ -28,7 +28,7 @@ from openai import OpenAI
28
  import numpy as np
29
  import asyncio
30
 
31
- from api.settings import LIGHTEN
32
  from api.utils.file_utils import get_home_cache_dir
33
  from rag.utils import num_tokens_from_string, truncate
34
  import google.generativeai as genai
@@ -60,7 +60,7 @@ class DefaultEmbedding(Base):
60
  ^_-
61
 
62
  """
63
- if not LIGHTEN and not DefaultEmbedding._model:
64
  with DefaultEmbedding._model_lock:
65
  from FlagEmbedding import FlagModel
66
  import torch
@@ -248,7 +248,7 @@ class FastEmbed(Base):
248
  threads: Optional[int] = None,
249
  **kwargs,
250
  ):
251
- if not LIGHTEN and not FastEmbed._model:
252
  from fastembed import TextEmbedding
253
  self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
254
 
@@ -294,7 +294,7 @@ class YoudaoEmbed(Base):
294
  _client = None
295
 
296
  def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
297
- if not LIGHTEN and not YoudaoEmbed._client:
298
  from BCEmbedding import EmbeddingModel as qanthing
299
  try:
300
  logging.info("LOADING BCE...")
 
28
  import numpy as np
29
  import asyncio
30
 
31
+ from api import settings
32
  from api.utils.file_utils import get_home_cache_dir
33
  from rag.utils import num_tokens_from_string, truncate
34
  import google.generativeai as genai
 
60
  ^_-
61
 
62
  """
63
+ if not settings.LIGHTEN and not DefaultEmbedding._model:
64
  with DefaultEmbedding._model_lock:
65
  from FlagEmbedding import FlagModel
66
  import torch
 
248
  threads: Optional[int] = None,
249
  **kwargs,
250
  ):
251
+ if not settings.LIGHTEN and not FastEmbed._model:
252
  from fastembed import TextEmbedding
253
  self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
254
 
 
294
  _client = None
295
 
296
  def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
297
+ if not settings.LIGHTEN and not YoudaoEmbed._client:
298
  from BCEmbedding import EmbeddingModel as qanthing
299
  try:
300
  logging.info("LOADING BCE...")
rag/llm/rerank_model.py CHANGED
@@ -23,7 +23,7 @@ import os
23
  from abc import ABC
24
  import numpy as np
25
 
26
- from api.settings import LIGHTEN
27
  from api.utils.file_utils import get_home_cache_dir
28
  from rag.utils import num_tokens_from_string, truncate
29
  import json
@@ -57,7 +57,7 @@ class DefaultRerank(Base):
57
  ^_-
58
 
59
  """
60
- if not LIGHTEN and not DefaultRerank._model:
61
  import torch
62
  from FlagEmbedding import FlagReranker
63
  with DefaultRerank._model_lock:
@@ -121,7 +121,7 @@ class YoudaoRerank(DefaultRerank):
121
  _model_lock = threading.Lock()
122
 
123
  def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
124
- if not LIGHTEN and not YoudaoRerank._model:
125
  from BCEmbedding import RerankerModel
126
  with YoudaoRerank._model_lock:
127
  if not YoudaoRerank._model:
 
23
  from abc import ABC
24
  import numpy as np
25
 
26
+ from api import settings
27
  from api.utils.file_utils import get_home_cache_dir
28
  from rag.utils import num_tokens_from_string, truncate
29
  import json
 
57
  ^_-
58
 
59
  """
60
+ if not settings.LIGHTEN and not DefaultRerank._model:
61
  import torch
62
  from FlagEmbedding import FlagReranker
63
  with DefaultRerank._model_lock:
 
121
  _model_lock = threading.Lock()
122
 
123
  def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
124
+ if not settings.LIGHTEN and not YoudaoRerank._model:
125
  from BCEmbedding import RerankerModel
126
  with YoudaoRerank._model_lock:
127
  if not YoudaoRerank._model:
rag/svr/task_executor.py CHANGED
@@ -16,6 +16,7 @@
16
  import logging
17
  import sys
18
  from api.utils.log_utils import initRootLogger
 
19
  CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
20
  initRootLogger(f"task_executor_{CONSUMER_NO}")
21
  for module in ["pdfminer"]:
@@ -49,9 +50,10 @@ from api.db.services.document_service import DocumentService
49
  from api.db.services.llm_service import LLMBundle
50
  from api.db.services.task_service import TaskService
51
  from api.db.services.file2document_service import File2DocumentService
52
- from api.settings import retrievaler, docStoreConn
53
  from api.db.db_models import close_connection
54
- from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
 
55
  from rag.nlp import search, rag_tokenizer
56
  from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
57
  from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME
@@ -88,6 +90,7 @@ PENDING_TASKS = 0
88
  HEAD_CREATED_AT = ""
89
  HEAD_DETAIL = ""
90
 
 
91
  def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
92
  global PAYLOAD
93
  if prog is not None and prog < 0:
@@ -171,7 +174,8 @@ def build(row):
171
  "From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
172
  except TimeoutError:
173
  callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
174
- logging.exception("Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"]))
 
175
  return
176
  except Exception as e:
177
  if re.search("(No such file|not found)", str(e)):
@@ -188,7 +192,7 @@ def build(row):
188
  logging.info("Chunking({}) {}/{} done".format(timer() - st, row["location"], row["name"]))
189
  except Exception as e:
190
  callback(-1, "Internal server error while chunking: %s" %
191
- str(e).replace("'", ""))
192
  logging.exception("Chunking {}/{} got exception".format(row["location"], row["name"]))
193
  return
194
 
@@ -226,7 +230,8 @@ def build(row):
226
  STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue())
227
  el += timer() - st
228
  except Exception:
229
- logging.exception("Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"]))
 
230
 
231
  d["img_id"] = "{}-{}".format(row["kb_id"], d["id"])
232
  del d["image"]
@@ -241,7 +246,7 @@ def build(row):
241
  d["important_kwd"] = keyword_extraction(chat_mdl, d["content_with_weight"],
242
  row["parser_config"]["auto_keywords"]).split(",")
243
  d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
244
- callback(msg="Keywords generation completed in {:.2f}s".format(timer()-st))
245
 
246
  if row["parser_config"].get("auto_questions", 0):
247
  st = timer()
@@ -255,14 +260,14 @@ def build(row):
255
  d["content_ltks"] += " " + qst
256
  if "content_sm_ltks" in d:
257
  d["content_sm_ltks"] += " " + rag_tokenizer.fine_grained_tokenize(qst)
258
- callback(msg="Question generation completed in {:.2f}s".format(timer()-st))
259
 
260
  return docs
261
 
262
 
263
  def init_kb(row, vector_size: int):
264
  idxnm = search.index_name(row["tenant_id"])
265
- return docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
266
 
267
 
268
  def embedding(docs, mdl, parser_config=None, callback=None):
@@ -313,7 +318,8 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
313
  vector_size = len(vts[0])
314
  vctr_nm = "q_%d_vec" % vector_size
315
  chunks = []
316
- for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm]):
 
317
  chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
318
 
319
  raptor = Raptor(
@@ -384,7 +390,8 @@ def main():
384
  # TODO: exception handler
385
  ## set_progress(r["did"], -1, "ERROR: ")
386
  callback(
387
- msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks), timer() - st)
 
388
  )
389
  st = timer()
390
  try:
@@ -403,18 +410,18 @@ def main():
403
  es_r = ""
404
  es_bulk_size = 4
405
  for b in range(0, len(cks), es_bulk_size):
406
- es_r = docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
407
  if b % 128 == 0:
408
  callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
409
 
410
  logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
411
  if es_r:
412
  callback(-1, "Insert chunk error, detail info please check log file. Please also check ES status!")
413
- docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
414
  logging.error('Insert chunk error: ' + str(es_r))
415
  else:
416
  if TaskService.do_cancel(r["id"]):
417
- docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
418
  continue
419
  callback(msg="Indexing elapsed in {:.2f}s.".format(timer() - st))
420
  callback(1., "Done!")
@@ -435,7 +442,7 @@ def report_status():
435
  if PENDING_TASKS > 0:
436
  head_info = REDIS_CONN.queue_head(SVR_QUEUE_NAME)
437
  if head_info is not None:
438
- seconds = int(head_info[0].split("-")[0])/1000
439
  HEAD_CREATED_AT = datetime.fromtimestamp(seconds).isoformat()
440
  HEAD_DETAIL = head_info[1]
441
 
@@ -452,7 +459,7 @@ def report_status():
452
  REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
453
  logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
454
 
455
- expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60*30)
456
  if expired > 0:
457
  REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
458
  except Exception:
 
16
  import logging
17
  import sys
18
  from api.utils.log_utils import initRootLogger
19
+
20
  CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
21
  initRootLogger(f"task_executor_{CONSUMER_NO}")
22
  for module in ["pdfminer"]:
 
50
  from api.db.services.llm_service import LLMBundle
51
  from api.db.services.task_service import TaskService
52
  from api.db.services.file2document_service import File2DocumentService
53
+ from api import settings
54
  from api.db.db_models import close_connection
55
+ from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
56
+ knowledge_graph, email
57
  from rag.nlp import search, rag_tokenizer
58
  from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
59
  from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME
 
90
  HEAD_CREATED_AT = ""
91
  HEAD_DETAIL = ""
92
 
93
+
94
  def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
95
  global PAYLOAD
96
  if prog is not None and prog < 0:
 
174
  "From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
175
  except TimeoutError:
176
  callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
177
+ logging.exception(
178
+ "Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"]))
179
  return
180
  except Exception as e:
181
  if re.search("(No such file|not found)", str(e)):
 
192
  logging.info("Chunking({}) {}/{} done".format(timer() - st, row["location"], row["name"]))
193
  except Exception as e:
194
  callback(-1, "Internal server error while chunking: %s" %
195
+ str(e).replace("'", ""))
196
  logging.exception("Chunking {}/{} got exception".format(row["location"], row["name"]))
197
  return
198
 
 
230
  STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue())
231
  el += timer() - st
232
  except Exception:
233
+ logging.exception(
234
+ "Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"]))
235
 
236
  d["img_id"] = "{}-{}".format(row["kb_id"], d["id"])
237
  del d["image"]
 
246
  d["important_kwd"] = keyword_extraction(chat_mdl, d["content_with_weight"],
247
  row["parser_config"]["auto_keywords"]).split(",")
248
  d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
249
+ callback(msg="Keywords generation completed in {:.2f}s".format(timer() - st))
250
 
251
  if row["parser_config"].get("auto_questions", 0):
252
  st = timer()
 
260
  d["content_ltks"] += " " + qst
261
  if "content_sm_ltks" in d:
262
  d["content_sm_ltks"] += " " + rag_tokenizer.fine_grained_tokenize(qst)
263
+ callback(msg="Question generation completed in {:.2f}s".format(timer() - st))
264
 
265
  return docs
266
 
267
 
268
  def init_kb(row, vector_size: int):
269
  idxnm = search.index_name(row["tenant_id"])
270
+ return settings.docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
271
 
272
 
273
  def embedding(docs, mdl, parser_config=None, callback=None):
 
318
  vector_size = len(vts[0])
319
  vctr_nm = "q_%d_vec" % vector_size
320
  chunks = []
321
+ for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
322
+ fields=["content_with_weight", vctr_nm]):
323
  chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
324
 
325
  raptor = Raptor(
 
390
  # TODO: exception handler
391
  ## set_progress(r["did"], -1, "ERROR: ")
392
  callback(
393
+ msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks),
394
+ timer() - st)
395
  )
396
  st = timer()
397
  try:
 
410
  es_r = ""
411
  es_bulk_size = 4
412
  for b in range(0, len(cks), es_bulk_size):
413
+ es_r = settings.docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
414
  if b % 128 == 0:
415
  callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
416
 
417
  logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
418
  if es_r:
419
  callback(-1, "Insert chunk error, detail info please check log file. Please also check ES status!")
420
+ settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
421
  logging.error('Insert chunk error: ' + str(es_r))
422
  else:
423
  if TaskService.do_cancel(r["id"]):
424
+ settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
425
  continue
426
  callback(msg="Indexing elapsed in {:.2f}s.".format(timer() - st))
427
  callback(1., "Done!")
 
442
  if PENDING_TASKS > 0:
443
  head_info = REDIS_CONN.queue_head(SVR_QUEUE_NAME)
444
  if head_info is not None:
445
+ seconds = int(head_info[0].split("-")[0]) / 1000
446
  HEAD_CREATED_AT = datetime.fromtimestamp(seconds).isoformat()
447
  HEAD_DETAIL = head_info[1]
448
 
 
459
  REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
460
  logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
461
 
462
+ expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60 * 30)
463
  if expired > 0:
464
  REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
465
  except Exception: