Commit 6101699 · Parent: fc803e8
Move settings initialization after module init phase (#3438)
### What problem does this PR solve?

1. Module init no longer connects to the database.
2. Config values in `settings` must now be accessed as `settings.CONFIG_NAME`.
### Type of change
- [x] Refactoring
Signed-off-by: jinhai <[email protected]>
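
Point 2 is the heart of the refactor: instead of binding individual names from `api.settings` at import time, callers import the module and resolve attributes at call time, after an explicit initialization step has populated them. A self-contained toy sketch of why the late lookup matters (the names below are illustrative, not RAGFlow's actual API):

```python
import types

# Toy stand-in for api/settings.py: values start empty and are filled in
# by an explicit init step during server startup, not at import time.
settings = types.SimpleNamespace(retrievaler=None)

def init_settings():
    # In the real code this is where DB connections, retrievers, etc. are built.
    settings.retrievaler = "dense-retriever"  # placeholder object

# A `from settings import retrievaler`-style binding is frozen at import time:
frozen = settings.retrievaler

init_settings()
print(frozen)                # None -- a stale binding, the old failure mode
print(settings.retrievaler)  # "dense-retriever" -- late lookup sees the init
```

Files changed (33):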
- agent/component/generate.py +13 -9
- agent/component/retrieval.py +2 -2
- api/apps/__init__.py +4 -6
- api/apps/api_app.py +33 -31
- api/apps/canvas_app.py +12 -10
- api/apps/chunk_app.py +17 -17
- api/apps/conversation_app.py +7 -6
- api/apps/dialog_app.py +2 -2
- api/apps/document_app.py +31 -30
- api/apps/file2document_app.py +2 -2
- api/apps/file_app.py +5 -5
- api/apps/kb_app.py +7 -8
- api/apps/llm_app.py +2 -2
- api/apps/sdk/chat.py +3 -3
- api/apps/sdk/dataset.py +3 -3
- api/apps/sdk/dify_retrieval.py +6 -6
- api/apps/sdk/doc.py +25 -25
- api/apps/system_app.py +4 -5
- api/apps/user_app.py +29 -42
- api/db/db_models.py +7 -7
- api/db/init_data.py +12 -11
- api/db/services/dialog_service.py +4 -4
- api/db/services/document_service.py +5 -5
- api/ragflow_server.py +5 -6
- api/settings.py +144 -101
- api/utils/api_utils.py +24 -26
- deepdoc/parser/pdf_parser.py +2 -2
- graphrag/claim_extractor.py +2 -2
- graphrag/smoke.py +2 -2
- rag/benchmark.py +11 -11
- rag/llm/embedding_model.py +4 -4
- rag/llm/rerank_model.py +3 -3
- rag/svr/task_executor.py +22 -15
agent/component/generate.py CHANGED

@@ -19,7 +19,7 @@ import pandas as pd
 from api.db import LLMType
 from api.db.services.dialog_service import message_fit_in
 from api.db.services.llm_service import LLMBundle
-from api.settings import retrievaler
+from api import settings
 from agent.component.base import ComponentBase, ComponentParamBase

@@ -63,18 +63,20 @@ class Generate(ComponentBase):
     component_name = "Generate"

     def get_dependent_components(self):
-        cpnts = [para["component_id"] for para in self._param.parameters if para.get("component_id") and para["component_id"].lower().find("answer") < 0]
+        cpnts = [para["component_id"] for para in self._param.parameters if
+                 para.get("component_id") and para["component_id"].lower().find("answer") < 0]
         return cpnts

     def set_cite(self, retrieval_res, answer):
         retrieval_res = retrieval_res.dropna(subset=["vector", "content_ltks"]).reset_index(drop=True)
         if "empty_response" in retrieval_res.columns:
             retrieval_res["empty_response"].fillna("", inplace=True)
-        answer, idx = retrievaler.insert_citations(answer,
-                                                   [ck["content_ltks"] for _, ck in retrieval_res.iterrows()],
-                                                   [ck["vector"] for _, ck in retrieval_res.iterrows()],
-                                                   LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
-                                                             self._canvas.get_embedding_model()), tkweight=0.7, vtweight=0.3)
+        answer, idx = settings.retrievaler.insert_citations(answer,
+                                                            [ck["content_ltks"] for _, ck in retrieval_res.iterrows()],
+                                                            [ck["vector"] for _, ck in retrieval_res.iterrows()],
+                                                            LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
+                                                                      self._canvas.get_embedding_model()), tkweight=0.7,
+                                                            vtweight=0.3)
         doc_ids = set([])
         recall_docs = []
         for i in idx:

@@ -127,12 +129,14 @@ class Generate(ComponentBase):
             else:
                 if cpn.component_name.lower() == "retrieval":
                     retrieval_res.append(out)
-                kwargs[para["key"]] = " - "+"\n - ".join([o if isinstance(o, str) else str(o) for o in out["content"]])
+                kwargs[para["key"]] = " - " + "\n - ".join(
+                    [o if isinstance(o, str) else str(o) for o in out["content"]])
                 self._param.inputs.append({"component_id": para["component_id"], "content": kwargs[para["key"]]})

         if retrieval_res:
             retrieval_res = pd.concat(retrieval_res, ignore_index=True)
-        else:
+        else:
+            retrieval_res = pd.DataFrame([])

         for n, v in kwargs.items():
             prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)

agent/component/retrieval.py CHANGED

@@ -21,7 +21,7 @@ import pandas as pd
 from api.db import LLMType
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
-from api.settings import retrievaler
+from api import settings
 from agent.component.base import ComponentBase, ComponentParamBase

@@ -67,7 +67,7 @@ class Retrieval(ComponentBase, ABC):
         if self._param.rerank_id:
             rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)

-        kbinfos = retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
+        kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
                                         1, self._param.top_n,
                                         self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
                                         aggs=False, rerank_mdl=rerank_mdl)

api/apps/__init__.py CHANGED

@@ -30,8 +30,7 @@ from api.utils import CustomJSONEncoder, commands

 from flask_session import Session
 from flask_login import LoginManager
-from api.settings import SECRET_KEY
-from api.settings import API_VERSION
+from api import settings
 from api.utils.api_utils import server_error_response
 from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer

@@ -78,7 +77,6 @@ app.url_map.strict_slashes = False
 app.json_encoder = CustomJSONEncoder
 app.errorhandler(Exception)(server_error_response)

-
 ## convince for dev and debug
 # app.config["LOGIN_DISABLED"] = True
 app.config["SESSION_PERMANENT"] = False

@@ -110,7 +108,7 @@ def register_page(page_path):

     page_name = page_path.stem.rstrip("_app")
     module_name = ".".join(
-        page_path.parts[page_path.parts.index("api"):-1] + (page_name,)
+        page_path.parts[page_path.parts.index("api"): -1] + (page_name,)
     )

     spec = spec_from_file_location(module_name, page_path)

@@ -121,7 +119,7 @@ def register_page(page_path):
     spec.loader.exec_module(page)
     page_name = getattr(page, "page_name", page_name)
     url_prefix = (
-        f"/api/{API_VERSION}" if "/sdk/" in path else f"/{API_VERSION}/{page_name}"
+        f"/api/{settings.API_VERSION}" if "/sdk/" in path else f"/{settings.API_VERSION}/{page_name}"
     )

     app.register_blueprint(page.manager, url_prefix=url_prefix)

@@ -141,7 +139,7 @@ client_urls_prefix = [

 @login_manager.request_loader
 def load_user(web_request):
-    jwt = Serializer(secret_key=SECRET_KEY)
+    jwt = Serializer(secret_key=settings.SECRET_KEY)
     authorization = web_request.headers.get("Authorization")
     if authorization:
         try:

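
Besides the `settings.*` lookups, `register_page` reworks how the dotted module name is derived from the page file's path. A quick standalone check of the slicing logic, using a hypothetical page path:

```python
from pathlib import Path

page_path = Path("api/apps/sdk/doc_app.py")  # hypothetical example path
page_name = page_path.stem.rstrip("_app")    # "doc"
module_name = ".".join(
    page_path.parts[page_path.parts.index("api"): -1] + (page_name,)
)
print(module_name)  # api.apps.sdk.doc
```

Note that `rstrip("_app")` strips a trailing run of the characters `_`, `a`, `p` rather than the literal suffix (a page named `map_app.py` would collapse to `m`); the diff leaves that pre-existing behaviour unchanged.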
api/apps/api_app.py CHANGED

@@ -32,7 +32,7 @@ from api.db.services.file_service import FileService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.task_service import queue_tasks, TaskService
 from api.db.services.user_service import UserTenantService
-from api.settings import RetCode, retrievaler
+from api import settings
 from api.utils import get_uuid, current_timestamp, datetime_format
 from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
     generate_confirmation_token

@@ -141,7 +141,7 @@ def set_conversation():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
     req = request.json
     try:
         if objs[0].source == "agent":

@@ -183,7 +183,7 @@ def completion():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
     req = request.json
     e, conv = API4ConversationService.get_by_id(req["conversation_id"])
     if not e:

@@ -290,8 +290,8 @@ def completion():
             API4ConversationService.append_message(conv.id, conv.to_dict())
             rename_field(result)
             return get_json_result(data=result)
-
-    #******************For dialog******************
+
+    # ******************For dialog******************
     conv.message.append(msg[-1])
     e, dia = DialogService.get_by_id(conv.dialog_id)
     if not e:

@@ -326,7 +326,7 @@ def completion():
         resp.headers.add_header("X-Accel-Buffering", "no")
         resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
         return resp
-
+
     answer = None
     for ans in chat(dia, msg, **req):
         answer = ans

@@ -347,8 +347,8 @@ def get(conversation_id):
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
-
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
+
     try:
         e, conv = API4ConversationService.get_by_id(conversation_id)
         if not e:

@@ -357,8 +357,8 @@ def get(conversation_id):
         conv = conv.to_dict()
         if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token:
             return get_json_result(data=False, message='Token is not valid for this conversation_id!"',
-                                   code=RetCode.AUTHENTICATION_ERROR)
-
+                                   code=settings.RetCode.AUTHENTICATION_ERROR)
+
         for referenct_i in conv['reference']:
             if referenct_i is None or len(referenct_i) == 0:
                 continue

@@ -378,7 +378,7 @@ def upload():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)

     kb_name = request.form.get("kb_name").strip()
     tenant_id = objs[0].tenant_id

@@ -394,12 +394,12 @@ def upload():

     if 'file' not in request.files:
         return get_json_result(
-            data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)

     file = request.files['file']
     if file.filename == '':
         return get_json_result(
-            data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)

     root_folder = FileService.get_root_folder(tenant_id)
     pf_id = root_folder["id"]

@@ -490,17 +490,17 @@ def upload_parse():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)

     if 'file' not in request.files:
         return get_json_result(
-            data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)

     file_objs = request.files.getlist('file')
     for file_obj in file_objs:
         if file_obj.filename == '':
             return get_json_result(
-                data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
+                data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)

     doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
     return get_json_result(data=doc_ids)

@@ -513,7 +513,7 @@ def list_chunks():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)

     req = request.json

@@ -531,7 +531,7 @@ def list_chunks():
         )
         kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)

-        res = retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
+        res = settings.retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
         res = [
             {
                 "content": res_item["content_with_weight"],

@@ -553,7 +553,7 @@ def list_kb_docs():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)

     req = request.json
     tenant_id = objs[0].tenant_id

@@ -585,6 +585,7 @@ def list_kb_docs():
     except Exception as e:
         return server_error_response(e)

+
 @manager.route('/document/infos', methods=['POST'])
 @validate_request("doc_ids")
 def docinfos():

@@ -592,7 +593,7 @@ def docinfos():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
     req = request.json
     doc_ids = req["doc_ids"]
     docs = DocumentService.get_by_ids(doc_ids)

@@ -606,7 +607,7 @@ def document_rm():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)

     tenant_id = objs[0].tenant_id
     req = request.json

@@ -653,7 +654,7 @@ def document_rm():
             errors += str(e)

     if errors:
-        return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
+        return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)

     return get_json_result(data=True)

@@ -668,7 +669,7 @@ def completion_faq():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)

     e, conv = API4ConversationService.get_by_id(req["conversation_id"])
     if not e:

@@ -805,10 +806,10 @@ def retrieval():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)

     req = request.json
-    kb_ids = req.get("kb_id",[])
+    kb_ids = req.get("kb_id", [])
     doc_ids = req.get("doc_ids", [])
     question = req.get("question")
     page = int(req.get("page", 1))

@@ -822,20 +823,21 @@ def retrieval():
         embd_nms = list(set([kb.embd_id for kb in kbs]))
         if len(embd_nms) != 1:
             return get_json_result(
-                data=False, message='Knowledge bases use different embedding models or does not exist."', code=RetCode.AUTHENTICATION_ERROR)
+                data=False, message='Knowledge bases use different embedding models or does not exist."',
+                code=settings.RetCode.AUTHENTICATION_ERROR)

         embd_mdl = TenantLLMService.model_instance(
             kbs[0].tenant_id, LLMType.EMBEDDING.value, llm_name=kbs[0].embd_id)
         rerank_mdl = None
         if req.get("rerank_id"):
             rerank_mdl = TenantLLMService.model_instance(
-                kbs[0].tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
+                kbs[0].tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
         if req.get("keyword", False):
             chat_mdl = TenantLLMService.model_instance(kbs[0].tenant_id, LLMType.CHAT)
             question += keyword_extraction(chat_mdl, question)
-        ranks = retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
-                                      similarity_threshold, vector_similarity_weight, top,
-                                      doc_ids, rerank_mdl=rerank_mdl)
+        ranks = settings.retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
+                                               similarity_threshold, vector_similarity_weight, top,
+                                               doc_ids, rerank_mdl=rerank_mdl)
         for c in ranks["chunks"]:
             if "vector" in c:
                 del c["vector"]

@@ -843,5 +845,5 @@ def retrieval():
     except Exception as e:
         if str(e).find("not_found") > 0:
             return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
-                                   code=RetCode.DATA_ERROR)
+                                   code=settings.RetCode.DATA_ERROR)
         return server_error_response(e)

api/apps/canvas_app.py CHANGED

@@ -19,7 +19,7 @@ from functools import partial
 from flask import request, Response
 from flask_login import login_required, current_user
 from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
-from api.settings import RetCode
+from api import settings
 from api.utils import get_uuid
 from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
 from agent.canvas import Canvas

@@ -36,7 +36,8 @@ def templates():
 @login_required
 def canvas_list():
     return get_json_result(data=sorted([c.to_dict() for c in \
-        UserCanvasService.query(user_id=current_user.id)], key=lambda x: x["update_time"]*-1)
+        UserCanvasService.query(user_id=current_user.id)],
+                           key=lambda x: x["update_time"] * -1)
     )

@@ -45,10 +46,10 @@ def canvas_list():
 @login_required
 def rm():
     for i in request.json["canvas_ids"]:
-        if not UserCanvasService.query(user_id=current_user.id,id=i):
+        if not UserCanvasService.query(user_id=current_user.id, id=i):
             return get_json_result(
                 data=False, message='Only owner of canvas authorized for this operation.',
-                code=RetCode.OPERATING_ERROR)
+                code=settings.RetCode.OPERATING_ERROR)
         UserCanvasService.delete_by_id(i)
     return get_json_result(data=True)

@@ -72,7 +73,7 @@ def save():
     if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
         return get_json_result(
             data=False, message='Only owner of canvas authorized for this operation.',
-            code=RetCode.OPERATING_ERROR)
+            code=settings.RetCode.OPERATING_ERROR)
     UserCanvasService.update_by_id(req["id"], req)
     return get_json_result(data=req)

@@ -98,7 +99,7 @@ def run():
     if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
         return get_json_result(
             data=False, message='Only owner of canvas authorized for this operation.',
-            code=RetCode.OPERATING_ERROR)
+            code=settings.RetCode.OPERATING_ERROR)

     if not isinstance(cvs.dsl, str):
         cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)

@@ -110,8 +111,8 @@ def run():
     if "message" in req:
         canvas.messages.append({"role": "user", "content": req["message"], "id": message_id})
         if len([m for m in canvas.messages if m["role"] == "user"]) > 1:
-            #ten = TenantService.get_info_by(current_user.id)[0]
-            #req["message"] = full_question(ten["tenant_id"], ten["llm_id"], canvas.messages)
+            # ten = TenantService.get_info_by(current_user.id)[0]
+            # req["message"] = full_question(ten["tenant_id"], ten["llm_id"], canvas.messages)
             pass
         canvas.add_user_input(req["message"])
     answer = canvas.run(stream=stream)

@@ -122,7 +123,8 @@ def run():
     assert answer is not None, "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."

     if stream:
-        assert isinstance(answer, partial), "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
+        assert isinstance(answer,
+                          partial), "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."

         def sse():
             nonlocal answer, cvs

@@ -173,7 +175,7 @@ def reset():
     if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
         return get_json_result(
             data=False, message='Only owner of canvas authorized for this operation.',
-            code=RetCode.OPERATING_ERROR)
+            code=settings.RetCode.OPERATING_ERROR)

     canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
     canvas.reset()

api/apps/chunk_app.py CHANGED

@@ -29,11 +29,12 @@ from api.db.services.llm_service import LLMBundle
 from api.db.services.user_service import UserTenantService
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.db.services.document_service import DocumentService
-from api.settings import RetCode, retrievaler, kg_retrievaler, docStoreConn
+from api import settings
 from api.utils.api_utils import get_json_result
 import hashlib
 import re

+
 @manager.route('/list', methods=['POST'])
 @login_required
 @validate_request("doc_id")

@@ -56,7 +57,7 @@ def list_chunk():
         }
         if "available_int" in req:
             query["available_int"] = int(req["available_int"])
-        sres = retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
+        sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
         res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
         for id in sres.ids:
             d = {

@@ -72,13 +73,13 @@ def list_chunk():
                 "positions": json.loads(sres.field[id].get("position_list", "[]")),
             }
             assert isinstance(d["positions"], list)
-            assert len(d["positions"])==0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5)
+            assert len(d["positions"]) == 0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5)
             res["chunks"].append(d)
         return get_json_result(data=res)
     except Exception as e:
         if str(e).find("not_found") > 0:
             return get_json_result(data=False, message='No chunk found!',
-                                   code=RetCode.DATA_ERROR)
+                                   code=settings.RetCode.DATA_ERROR)
         return server_error_response(e)

@@ -93,7 +94,7 @@ def get():
         tenant_id = tenants[0].tenant_id

         kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
-        chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
+        chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
         if chunk is None:
             return server_error_response("Chunk not found")
         k = []

@@ -107,7 +108,7 @@ def get():
     except Exception as e:
         if str(e).find("NotFoundError") >= 0:
             return get_json_result(data=False, message='Chunk not found!',
-                                   code=RetCode.DATA_ERROR)
+                                   code=settings.RetCode.DATA_ERROR)
         return server_error_response(e)

@@ -154,7 +155,7 @@ def set():
         v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
         v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
         d["q_%d_vec" % len(v)] = v.tolist()
-        docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
+        settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
         return get_json_result(data=True)
     except Exception as e:
         return server_error_response(e)

@@ -169,8 +170,8 @@ def switch():
         e, doc = DocumentService.get_by_id(req["doc_id"])
         if not e:
             return get_data_error_result(message="Document not found!")
-        if not docStoreConn.update({"id": req["chunk_ids"]}, {"available_int": int(req["available_int"])},
-                                   search.index_name(doc.tenant_id), doc.kb_id):
+        if not settings.docStoreConn.update({"id": req["chunk_ids"]}, {"available_int": int(req["available_int"])},
+                                            search.index_name(doc.tenant_id), doc.kb_id):
             return get_data_error_result(message="Index updating failure")
         return get_json_result(data=True)
     except Exception as e:

@@ -186,7 +187,7 @@ def rm():
         e, doc = DocumentService.get_by_id(req["doc_id"])
         if not e:
             return get_data_error_result(message="Document not found!")
-        if not docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id):
+        if not settings.docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id):
             return get_data_error_result(message="Index updating failure")
         deleted_chunk_ids = req["chunk_ids"]
         chunk_number = len(deleted_chunk_ids)

@@ -230,7 +231,7 @@ def create():
         v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
         v = 0.1 * v[0] + 0.9 * v[1]
         d["q_%d_vec" % len(v)] = v.tolist()
-        docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
+        settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)

         DocumentService.increment_chunk_num(
             doc.id, doc.kb_id, c, 1, 0)

@@ -265,7 +266,7 @@ def retrieval_test():
         else:
             return get_json_result(
                 data=False, message='Only owner of knowledgebase authorized for this operation.',
-                code=RetCode.OPERATING_ERROR)
+                code=settings.RetCode.OPERATING_ERROR)

         e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
         if not e:

@@ -281,7 +282,7 @@ def retrieval_test():
             chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
             question += keyword_extraction(chat_mdl, question)

-        retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
+        retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
         ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, page, size,
                                similarity_threshold, vector_similarity_weight, top,
                                doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))

@@ -293,7 +294,7 @@ def retrieval_test():
     except Exception as e:
         if str(e).find("not_found") > 0:
             return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
-                                   code=RetCode.DATA_ERROR)
+                                   code=settings.RetCode.DATA_ERROR)
         return server_error_response(e)

@@ -304,10 +305,10 @@ def knowledge_graph():
     tenant_id = DocumentService.get_tenant_id(doc_id)
     kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
     req = {
-        "doc_ids":[doc_id],
+        "doc_ids": [doc_id],
         "knowledge_graph_kwd": ["graph", "mind_map"]
     }
-    sres = retrievaler.search(req, search.index_name(tenant_id), kb_ids)
+    sres = settings.retrievaler.search(req, search.index_name(tenant_id), kb_ids)
    obj = {"graph": {}, "mind_map": {}}
    for id in sres.ids[:2]:
        ty = sres.field[id]["knowledge_graph_kwd"]

@@ -336,4 +337,3 @@ def knowledge_graph():
     obj[ty] = content_json

     return get_json_result(data=obj)
-

api/apps/conversation_app.py CHANGED

@@ -25,7 +25,7 @@ from api.db import LLMType
 from api.db.services.dialog_service import DialogService, ConversationService, chat, ask
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
-from api.settings import RetCode, retrievaler
+from api import settings
 from api.utils.api_utils import get_json_result
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from graphrag.mind_map_extractor import MindMapExtractor

@@ -87,7 +87,7 @@ def get():
         else:
             return get_json_result(
                 data=False, message='Only owner of conversation authorized for this operation.',
-                code=RetCode.OPERATING_ERROR)
+                code=settings.RetCode.OPERATING_ERROR)
         conv = conv.to_dict()
         return get_json_result(data=conv)
     except Exception as e:

@@ -110,7 +110,7 @@ def rm():
             else:
                 return get_json_result(
                     data=False, message='Only owner of conversation authorized for this operation.',
-                    code=RetCode.OPERATING_ERROR)
+                    code=settings.RetCode.OPERATING_ERROR)
             ConversationService.delete_by_id(cid)
         return get_json_result(data=True)
     except Exception as e:

@@ -125,7 +125,7 @@ def list_convsersation():
     if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
         return get_json_result(
             data=False, message='Only owner of dialog authorized for this operation.',
-            code=RetCode.OPERATING_ERROR)
+            code=settings.RetCode.OPERATING_ERROR)
     convs = ConversationService.query(
         dialog_id=dialog_id,
         order_by=ConversationService.model.create_time,

@@ -297,6 +297,7 @@ def thumbup():
 def ask_about():
     req = request.json
     uid = current_user.id
+
     def stream():
         nonlocal req, uid
         try:

@@ -329,8 +330,8 @@ def mindmap():
         embd_mdl = TenantLLMService.model_instance(
             kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
         chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
-        ranks = retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12,
-                                      0.3, 0.3, aggs=False)
+        ranks = settings.retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12,
+                                               0.3, 0.3, aggs=False)
         mindmap = MindMapExtractor(chat_mdl)
         mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
         if "error" in mind_map:

api/apps/dialog_app.py CHANGED

@@ -20,7 +20,7 @@ from api.db.services.dialog_service import DialogService
 from api.db import StatusEnum
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.user_service import TenantService, UserTenantService
-from api.settings import RetCode
+from api import settings
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.utils import get_uuid
 from api.utils.api_utils import get_json_result

@@ -175,7 +175,7 @@ def rm():
         else:
             return get_json_result(
                 data=False, message='Only owner of dialog authorized for this operation.',
-                code=RetCode.OPERATING_ERROR)
+                code=settings.RetCode.OPERATING_ERROR)
         dialog_list.append({"id": id,"status":StatusEnum.INVALID.value})
     DialogService.update_many_by_id(dialog_list)
     return get_json_result(data=True)

api/apps/document_app.py CHANGED

@@ -34,7 +34,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
 from api.utils import get_uuid
 from api.db import FileType, TaskStatus, ParserType, FileSource
 from api.db.services.document_service import DocumentService, doc_upload_and_parse
-from api.settings import RetCode, docStoreConn
+from api import settings
 from api.utils.api_utils import get_json_result
 from rag.utils.storage_factory import STORAGE_IMPL
 from api.utils.file_utils import filename_type, thumbnail, get_project_base_directory

@@ -49,16 +49,16 @@ def upload():
     kb_id = request.form.get("kb_id")
     if not kb_id:
         return get_json_result(
-            data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
     if 'file' not in request.files:
         return get_json_result(
-            data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)

     file_objs = request.files.getlist('file')
     for file_obj in file_objs:
         if file_obj.filename == '':
             return get_json_result(
-                data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
+                data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)

     e, kb = KnowledgebaseService.get_by_id(kb_id)
     if not e:

@@ -67,7 +67,7 @@ def upload():
     err, _ = FileService.upload_document(kb, file_objs, current_user.id)
     if err:
         return get_json_result(
-            data=False, message="\n".join(err), code=RetCode.SERVER_ERROR)
+            data=False, message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
     return get_json_result(data=True)

@@ -78,12 +78,12 @@ def web_crawl():
     kb_id = request.form.get("kb_id")
     if not kb_id:
         return get_json_result(
-            data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
     name = request.form.get("name")
     url = request.form.get("url")
     if not is_valid_url(url):
         return get_json_result(
-            data=False, message='The URL format is invalid', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR)
     e, kb = KnowledgebaseService.get_by_id(kb_id)
     if not e:
         raise LookupError("Can't find this knowledgebase!")

@@ -145,7 +145,7 @@ def create():
     kb_id = req["kb_id"]
     if not kb_id:
         return get_json_result(
-            data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)

     try:
         e, kb = KnowledgebaseService.get_by_id(kb_id)

@@ -179,7 +179,7 @@ def list_docs():
     kb_id = request.args.get("kb_id")
     if not kb_id:
         return get_json_result(
-            data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
     tenants = UserTenantService.query(user_id=current_user.id)
     for tenant in tenants:
         if KnowledgebaseService.query(

@@ -188,7 +188,7 @@ def list_docs():
         else:
             return get_json_result(
                 data=False, message='Only owner of knowledgebase authorized for this operation.',
-                code=RetCode.OPERATING_ERROR)
+                code=settings.RetCode.OPERATING_ERROR)
     keywords = request.args.get("keywords", "")

     page_number = int(request.args.get("page", 1))

@@ -218,19 +218,19 @@ def docinfos():
         return get_json_result(
             data=False,
             message='No authorization.',
-            code=RetCode.AUTHENTICATION_ERROR
+            code=settings.RetCode.AUTHENTICATION_ERROR
         )
     docs = DocumentService.get_by_ids(doc_ids)
     return get_json_result(data=list(docs.dicts()))


 @manager.route('/thumbnails', methods=['GET'])
-#@login_required
+# @login_required
 def thumbnails():
     doc_ids = request.args.get("doc_ids").split(",")
     if not doc_ids:
         return get_json_result(
-            data=False, message='Lack of "Document ID"', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='Lack of "Document ID"', code=settings.RetCode.ARGUMENT_ERROR)

     try:
         docs = DocumentService.get_thumbnails(doc_ids)

@@ -253,13 +253,13 @@ def change_status():
         return get_json_result(
             data=False,
             message='"Status" must be either 0 or 1!',
-            code=RetCode.ARGUMENT_ERROR)
+            code=settings.RetCode.ARGUMENT_ERROR)

     if not DocumentService.accessible(req["doc_id"], current_user.id):
         return get_json_result(
             data=False,
             message='No authorization.',
-            code=RetCode.AUTHENTICATION_ERROR)
+            code=settings.RetCode.AUTHENTICATION_ERROR)

     try:
         e, doc = DocumentService.get_by_id(req["doc_id"])

@@ -276,7 +276,8 @@ def change_status():
                 message="Database error (Document update)!")

         status = int(req["status"])
-        docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
+        settings.docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status},
+                                     search.index_name(kb.tenant_id), doc.kb_id)
         return get_json_result(data=True)
     except Exception as e:
         return server_error_response(e)

@@ -295,7 +296,7 @@ def rm():
         return get_json_result(
             data=False,
             message='No authorization.',
-            code=RetCode.AUTHENTICATION_ERROR
+            code=settings.RetCode.AUTHENTICATION_ERROR
         )

     root_folder = FileService.get_root_folder(current_user.id)

@@ -326,7 +327,7 @@ def rm():
             errors += str(e)

     if errors:
-        return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
+        return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)

     return get_json_result(data=True)

@@ -341,7 +342,7 @@ def run():
         return get_json_result(
             data=False,
             message='No authorization.',
-            code=RetCode.AUTHENTICATION_ERROR
+            code=settings.RetCode.AUTHENTICATION_ERROR
         )
     try:
         for id in req["doc_ids"]:

@@ -358,8 +359,8 @@ def run():
             e, doc = DocumentService.get_by_id(id)
             if not e:
                 return get_data_error_result(message="Document not found!")
-            if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
-                docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
+            if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+                settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)

             if str(req["run"]) == TaskStatus.RUNNING.value:
                 TaskService.filter_delete([Task.doc_id == id])

@@ -383,7 +384,7 @@ def rename():
         return get_json_result(
             data=False,
             message='No authorization.',
-            code=RetCode.AUTHENTICATION_ERROR
+            code=settings.RetCode.AUTHENTICATION_ERROR
         )
     try:
         e, doc = DocumentService.get_by_id(req["doc_id"])

@@ -394,7 +395,7 @@ def rename():
             return get_json_result(
                 data=False,
                 message="The extension of file can't be changed",
-                code=RetCode.ARGUMENT_ERROR)
+                code=settings.RetCode.ARGUMENT_ERROR)
         for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
             if d.name == req["name"]:
                 return get_data_error_result(

@@ -450,7 +451,7 @@ def change_parser():
         return get_json_result(
             data=False,
             message='No authorization.',
-            code=RetCode.AUTHENTICATION_ERROR
+            code=settings.RetCode.AUTHENTICATION_ERROR
         )
     try:
         e, doc = DocumentService.get_by_id(req["doc_id"])

@@ -483,8 +484,8 @@ def change_parser():
             tenant_id = DocumentService.get_tenant_id(req["doc_id"])
             if not tenant_id:
                 return get_data_error_result(message="Tenant not found!")
-            if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
-                docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
+            if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+                settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)

         return get_json_result(data=True)
     except Exception as e:

@@ -509,13 +510,13 @@ def get_image(image_id):
 def upload_and_parse():
     if 'file' not in request.files:
         return get_json_result(
-            data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)

     file_objs = request.files.getlist('file')
     for file_obj in file_objs:
         if file_obj.filename == '':
             return get_json_result(
-                data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
+                data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)

     doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id)

@@ -529,7 +530,7 @@ def parse():
     if url:
         if not is_valid_url(url):
             return get_json_result(
-                data=False, message='The URL format is invalid', code=RetCode.ARGUMENT_ERROR)
+                data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR)
         download_path = os.path.join(get_project_base_directory(), "logs/downloads")
         os.makedirs(download_path, exist_ok=True)
         from selenium.webdriver import Chrome, ChromeOptions

@@ -553,7 +554,7 @@ def parse():

     if 'file' not in request.files:
         return get_json_result(
-            data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)

     file_objs = request.files.getlist('file')
     txt = FileService.parse_docs(file_objs, current_user.id)
|
300 |
)
|
301 |
|
302 |
root_folder = FileService.get_root_folder(current_user.id)
|
|
|
327 |
errors += str(e)
|
328 |
|
329 |
if errors:
|
330 |
+
return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)
|
331 |
|
332 |
return get_json_result(data=True)
|
333 |
|
|
|
342 |
return get_json_result(
|
343 |
data=False,
|
344 |
message='No authorization.',
|
345 |
+
code=settings.RetCode.AUTHENTICATION_ERROR
|
346 |
)
|
347 |
try:
|
348 |
for id in req["doc_ids"]:
|
|
|
359 |
e, doc = DocumentService.get_by_id(id)
|
360 |
if not e:
|
361 |
return get_data_error_result(message="Document not found!")
|
362 |
+
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
363 |
+
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
|
364 |
|
365 |
if str(req["run"]) == TaskStatus.RUNNING.value:
|
366 |
TaskService.filter_delete([Task.doc_id == id])
|
|
|
384 |
return get_json_result(
|
385 |
data=False,
|
386 |
message='No authorization.',
|
387 |
+
code=settings.RetCode.AUTHENTICATION_ERROR
|
388 |
)
|
389 |
try:
|
390 |
e, doc = DocumentService.get_by_id(req["doc_id"])
|
|
|
395 |
return get_json_result(
|
396 |
data=False,
|
397 |
message="The extension of file can't be changed",
|
398 |
+
code=settings.RetCode.ARGUMENT_ERROR)
|
399 |
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
|
400 |
if d.name == req["name"]:
|
401 |
return get_data_error_result(
|
|
|
451 |
return get_json_result(
|
452 |
data=False,
|
453 |
message='No authorization.',
|
454 |
+
code=settings.RetCode.AUTHENTICATION_ERROR
|
455 |
)
|
456 |
try:
|
457 |
e, doc = DocumentService.get_by_id(req["doc_id"])
|
|
|
484 |
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
485 |
if not tenant_id:
|
486 |
return get_data_error_result(message="Tenant not found!")
|
487 |
+
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
488 |
+
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
489 |
|
490 |
return get_json_result(data=True)
|
491 |
except Exception as e:
|
|
|
510 |
def upload_and_parse():
|
511 |
if 'file' not in request.files:
|
512 |
return get_json_result(
|
513 |
+
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
|
514 |
|
515 |
file_objs = request.files.getlist('file')
|
516 |
for file_obj in file_objs:
|
517 |
if file_obj.filename == '':
|
518 |
return get_json_result(
|
519 |
+
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
|
520 |
|
521 |
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id)
|
522 |
|
|
|
530 |
if url:
|
531 |
if not is_valid_url(url):
|
532 |
return get_json_result(
|
533 |
+
data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR)
|
534 |
download_path = os.path.join(get_project_base_directory(), "logs/downloads")
|
535 |
os.makedirs(download_path, exist_ok=True)
|
536 |
from selenium.webdriver import Chrome, ChromeOptions
|
|
|
554 |
|
555 |
if 'file' not in request.files:
|
556 |
return get_json_result(
|
557 |
+
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
|
558 |
|
559 |
file_objs = request.files.getlist('file')
|
560 |
txt = FileService.parse_docs(file_objs, current_user.id)
|
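Pattern note: every hunk in these app files makes the same two-part change, replacing an eager "from api.settings import X" with "from api import settings" plus attribute access "settings.X". A minimal sketch of why that matters follows; the toy module, the init_settings() body, and index_ready() are illustrative assumptions, not code from this PR.

# settings.py -- toy stand-in for api/settings.py; importing it is side-effect free.
docStoreConn = None  # bound later by init_settings(), not at import time

def init_settings():
    # Runs once at server startup, after all modules are imported, so merely
    # importing an app module no longer opens a doc-store connection.
    global docStoreConn
    docStoreConn = object()  # real code would construct the doc-store client here

# handler.py -- mirrors the "+" lines: keep the module, resolve names per call.
import settings

def index_ready() -> bool:
    # Looked up at call time, so this sees whatever init_settings() installed;
    # "from settings import docStoreConn" would have frozen the None above.
    return settings.docStoreConn is not None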
api/apps/file2document_app.py
CHANGED
@@ -24,7 +24,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
|
|
24 |
from api.utils import get_uuid
|
25 |
from api.db import FileType
|
26 |
from api.db.services.document_service import DocumentService
|
27 |
-
from api.settings import RetCode
|
28 |
from api.utils.api_utils import get_json_result
|
29 |
|
30 |
|
@@ -100,7 +100,7 @@ def rm():
|
|
100 |
file_ids = req["file_ids"]
|
101 |
if not file_ids:
|
102 |
return get_json_result(
|
103 |
-
data=False, message='Lack of "Files ID"', code=RetCode.ARGUMENT_ERROR)
|
104 |
try:
|
105 |
for file_id in file_ids:
|
106 |
informs = File2DocumentService.get_by_file_id(file_id)
|
|
|
24 |
from api.utils import get_uuid
|
25 |
from api.db import FileType
|
26 |
from api.db.services.document_service import DocumentService
|
27 |
+
from api import settings
|
28 |
from api.utils.api_utils import get_json_result
|
29 |
|
30 |
|
|
|
100 |
file_ids = req["file_ids"]
|
101 |
if not file_ids:
|
102 |
return get_json_result(
|
103 |
+
data=False, message='Lack of "Files ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
104 |
try:
|
105 |
for file_id in file_ids:
|
106 |
informs = File2DocumentService.get_by_file_id(file_id)
|
api/apps/file_app.py
CHANGED
@@ -28,7 +28,7 @@ from api.utils import get_uuid
|
|
28 |
from api.db import FileType, FileSource
|
29 |
from api.db.services import duplicate_name
|
30 |
from api.db.services.file_service import FileService
|
31 |
-
from api.settings import RetCode
|
32 |
from api.utils.api_utils import get_json_result
|
33 |
from api.utils.file_utils import filename_type
|
34 |
from rag.utils.storage_factory import STORAGE_IMPL
|
@@ -46,13 +46,13 @@ def upload():
|
|
46 |
|
47 |
if 'file' not in request.files:
|
48 |
return get_json_result(
|
49 |
-
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
|
50 |
file_objs = request.files.getlist('file')
|
51 |
|
52 |
for file_obj in file_objs:
|
53 |
if file_obj.filename == '':
|
54 |
return get_json_result(
|
55 |
-
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
|
56 |
file_res = []
|
57 |
try:
|
58 |
for file_obj in file_objs:
|
@@ -134,7 +134,7 @@ def create():
|
|
134 |
try:
|
135 |
if not FileService.is_parent_folder_exist(pf_id):
|
136 |
return get_json_result(
|
137 |
-
data=False, message="Parent Folder Doesn't Exist!", code=RetCode.OPERATING_ERROR)
|
138 |
if FileService.query(name=req["name"], parent_id=pf_id):
|
139 |
return get_data_error_result(
|
140 |
message="Duplicated folder name in the same folder.")
|
@@ -299,7 +299,7 @@ def rename():
|
|
299 |
return get_json_result(
|
300 |
data=False,
|
301 |
message="The extension of file can't be changed",
|
302 |
-
code=RetCode.ARGUMENT_ERROR)
|
303 |
for file in FileService.query(name=req["name"], pf_id=file.parent_id):
|
304 |
if file.name == req["name"]:
|
305 |
return get_data_error_result(
|
|
|
28 |
from api.db import FileType, FileSource
|
29 |
from api.db.services import duplicate_name
|
30 |
from api.db.services.file_service import FileService
|
31 |
+
from api import settings
|
32 |
from api.utils.api_utils import get_json_result
|
33 |
from api.utils.file_utils import filename_type
|
34 |
from rag.utils.storage_factory import STORAGE_IMPL
|
|
|
46 |
|
47 |
if 'file' not in request.files:
|
48 |
return get_json_result(
|
49 |
+
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
|
50 |
file_objs = request.files.getlist('file')
|
51 |
|
52 |
for file_obj in file_objs:
|
53 |
if file_obj.filename == '':
|
54 |
return get_json_result(
|
55 |
+
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
|
56 |
file_res = []
|
57 |
try:
|
58 |
for file_obj in file_objs:
|
|
|
134 |
try:
|
135 |
if not FileService.is_parent_folder_exist(pf_id):
|
136 |
return get_json_result(
|
137 |
+
data=False, message="Parent Folder Doesn't Exist!", code=settings.RetCode.OPERATING_ERROR)
|
138 |
if FileService.query(name=req["name"], parent_id=pf_id):
|
139 |
return get_data_error_result(
|
140 |
message="Duplicated folder name in the same folder.")
|
|
|
299 |
return get_json_result(
|
300 |
data=False,
|
301 |
message="The extension of file can't be changed",
|
302 |
+
code=settings.RetCode.ARGUMENT_ERROR)
|
303 |
for file in FileService.query(name=req["name"], pf_id=file.parent_id):
|
304 |
if file.name == req["name"]:
|
305 |
return get_data_error_result(
|
api/apps/kb_app.py
CHANGED
@@ -26,9 +26,8 @@ from api.utils import get_uuid
|
|
26 |
from api.db import StatusEnum, FileSource
|
27 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
28 |
from api.db.db_models import File
|
29 |
-
from api.settings import RetCode
|
30 |
from api.utils.api_utils import get_json_result
|
31 |
-
from api.settings import docStoreConn
|
32 |
from rag.nlp import search
|
33 |
|
34 |
|
@@ -68,13 +67,13 @@ def update():
|
|
68 |
return get_json_result(
|
69 |
data=False,
|
70 |
message='No authorization.',
|
71 |
-
code=RetCode.AUTHENTICATION_ERROR
|
72 |
)
|
73 |
try:
|
74 |
if not KnowledgebaseService.query(
|
75 |
created_by=current_user.id, id=req["kb_id"]):
|
76 |
return get_json_result(
|
77 |
-
data=False, message='Only owner of knowledgebase authorized for this operation.', code=RetCode.OPERATING_ERROR)
|
78 |
|
79 |
e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
|
80 |
if not e:
|
@@ -113,7 +112,7 @@ def detail():
|
|
113 |
else:
|
114 |
return get_json_result(
|
115 |
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
116 |
-
code=RetCode.OPERATING_ERROR)
|
117 |
kb = KnowledgebaseService.get_detail(kb_id)
|
118 |
if not kb:
|
119 |
return get_data_error_result(
|
@@ -148,14 +147,14 @@ def rm():
|
|
148 |
return get_json_result(
|
149 |
data=False,
|
150 |
message='No authorization.',
|
151 |
-
code=RetCode.AUTHENTICATION_ERROR
|
152 |
)
|
153 |
try:
|
154 |
kbs = KnowledgebaseService.query(
|
155 |
created_by=current_user.id, id=req["kb_id"])
|
156 |
if not kbs:
|
157 |
return get_json_result(
|
158 |
-
data=False, message='Only owner of knowledgebase authorized for this operation.', code=RetCode.OPERATING_ERROR)
|
159 |
|
160 |
for doc in DocumentService.query(kb_id=req["kb_id"]):
|
161 |
if not DocumentService.remove_document(doc, kbs[0].tenant_id):
|
@@ -170,7 +169,7 @@ def rm():
|
|
170 |
message="Database error (Knowledgebase removal)!")
|
171 |
tenants = UserTenantService.query(user_id=current_user.id)
|
172 |
for tenant in tenants:
|
173 |
-
docStoreConn.deleteIdx(search.index_name(tenant.tenant_id), req["kb_id"])
|
174 |
return get_json_result(data=True)
|
175 |
except Exception as e:
|
176 |
return server_error_response(e)
|
|
|
26 |
from api.db import StatusEnum, FileSource
|
27 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
28 |
from api.db.db_models import File
|
|
|
29 |
from api.utils.api_utils import get_json_result
|
30 |
+
from api import settings
|
31 |
from rag.nlp import search
|
32 |
|
33 |
|
|
|
67 |
return get_json_result(
|
68 |
data=False,
|
69 |
message='No authorization.',
|
70 |
+
code=settings.RetCode.AUTHENTICATION_ERROR
|
71 |
)
|
72 |
try:
|
73 |
if not KnowledgebaseService.query(
|
74 |
created_by=current_user.id, id=req["kb_id"]):
|
75 |
return get_json_result(
|
76 |
+
data=False, message='Only owner of knowledgebase authorized for this operation.', code=settings.RetCode.OPERATING_ERROR)
|
77 |
|
78 |
e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
|
79 |
if not e:
|
|
|
112 |
else:
|
113 |
return get_json_result(
|
114 |
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
115 |
+
code=settings.RetCode.OPERATING_ERROR)
|
116 |
kb = KnowledgebaseService.get_detail(kb_id)
|
117 |
if not kb:
|
118 |
return get_data_error_result(
|
|
|
147 |
return get_json_result(
|
148 |
data=False,
|
149 |
message='No authorization.',
|
150 |
+
code=settings.RetCode.AUTHENTICATION_ERROR
|
151 |
)
|
152 |
try:
|
153 |
kbs = KnowledgebaseService.query(
|
154 |
created_by=current_user.id, id=req["kb_id"])
|
155 |
if not kbs:
|
156 |
return get_json_result(
|
157 |
+
data=False, message='Only owner of knowledgebase authorized for this operation.', code=settings.RetCode.OPERATING_ERROR)
|
158 |
|
159 |
for doc in DocumentService.query(kb_id=req["kb_id"]):
|
160 |
if not DocumentService.remove_document(doc, kbs[0].tenant_id):
|
|
|
169 |
message="Database error (Knowledgebase removal)!")
|
170 |
tenants = UserTenantService.query(user_id=current_user.id)
|
171 |
for tenant in tenants:
|
172 |
+
settings.docStoreConn.deleteIdx(search.index_name(tenant.tenant_id), req["kb_id"])
|
173 |
return get_json_result(data=True)
|
174 |
except Exception as e:
|
175 |
return server_error_response(e)
|
api/apps/llm_app.py
CHANGED
@@ -19,7 +19,7 @@ import json
|
|
19 |
from flask import request
|
20 |
from flask_login import login_required, current_user
|
21 |
from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
|
22 |
-
from api.settings import LIGHTEN
|
23 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
24 |
from api.db import StatusEnum, LLMType
|
25 |
from api.db.db_models import TenantLLM
|
@@ -333,7 +333,7 @@ def my_llms():
|
|
333 |
@login_required
|
334 |
def list_app():
|
335 |
self_deploied = ["Youdao","FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"]
|
336 |
-
weighted = ["Youdao","FastEmbed", "BAAI"] if LIGHTEN != 0 else []
|
337 |
model_type = request.args.get("model_type")
|
338 |
try:
|
339 |
objs = TenantLLMService.query(tenant_id=current_user.id)
|
|
|
19 |
from flask import request
|
20 |
from flask_login import login_required, current_user
|
21 |
from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
|
22 |
+
from api import settings
|
23 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
24 |
from api.db import StatusEnum, LLMType
|
25 |
from api.db.db_models import TenantLLM
|
|
|
333 |
@login_required
|
334 |
def list_app():
|
335 |
self_deploied = ["Youdao","FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"]
|
336 |
+
weighted = ["Youdao","FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else []
|
337 |
model_type = request.args.get("model_type")
|
338 |
try:
|
339 |
objs = TenantLLMService.query(tenant_id=current_user.id)
|
api/apps/sdk/chat.py
CHANGED
@@ -14,7 +14,7 @@
|
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
from flask import request
|
17 |
-
from api.settings import RetCode
|
18 |
from api.db import StatusEnum
|
19 |
from api.db.services.dialog_service import DialogService
|
20 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
@@ -44,7 +44,7 @@ def create(tenant_id):
|
|
44 |
kbs = KnowledgebaseService.get_by_ids(ids)
|
45 |
embd_count = list(set([kb.embd_id for kb in kbs]))
|
46 |
if len(embd_count) != 1:
|
47 |
-
return get_result(message='Datasets use different embedding models."',code=RetCode.AUTHENTICATION_ERROR)
|
48 |
req["kb_ids"] = ids
|
49 |
# llm
|
50 |
llm = req.get("llm")
|
@@ -173,7 +173,7 @@ def update(tenant_id,chat_id):
|
|
173 |
if len(embd_count) != 1 :
|
174 |
return get_result(
|
175 |
message='Datasets use different embedding models."',
|
176 |
-
code=RetCode.AUTHENTICATION_ERROR)
|
177 |
req["kb_ids"] = ids
|
178 |
llm = req.get("llm")
|
179 |
if llm:
|
|
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
from flask import request
|
17 |
+
from api import settings
|
18 |
from api.db import StatusEnum
|
19 |
from api.db.services.dialog_service import DialogService
|
20 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
|
|
44 |
kbs = KnowledgebaseService.get_by_ids(ids)
|
45 |
embd_count = list(set([kb.embd_id for kb in kbs]))
|
46 |
if len(embd_count) != 1:
|
47 |
+
return get_result(message='Datasets use different embedding models."',code=settings.RetCode.AUTHENTICATION_ERROR)
|
48 |
req["kb_ids"] = ids
|
49 |
# llm
|
50 |
llm = req.get("llm")
|
|
|
173 |
if len(embd_count) != 1 :
|
174 |
return get_result(
|
175 |
message='Datasets use different embedding models."',
|
176 |
+
code=settings.RetCode.AUTHENTICATION_ERROR)
|
177 |
req["kb_ids"] = ids
|
178 |
llm = req.get("llm")
|
179 |
if llm:
|
api/apps/sdk/dataset.py
CHANGED
@@ -23,7 +23,7 @@ from api.db.services.file_service import FileService
|
|
23 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
24 |
from api.db.services.llm_service import TenantLLMService, LLMService
|
25 |
from api.db.services.user_service import TenantService
|
26 |
-
from api.settings import RetCode
|
27 |
from api.utils import get_uuid
|
28 |
from api.utils.api_utils import (
|
29 |
get_result,
|
@@ -255,7 +255,7 @@ def delete(tenant_id):
|
|
255 |
File2DocumentService.delete_by_document_id(doc.id)
|
256 |
if not KnowledgebaseService.delete_by_id(id):
|
257 |
return get_error_data_result(message="Delete dataset error.(Database error)")
|
258 |
-
return get_result(code=RetCode.SUCCESS)
|
259 |
|
260 |
|
261 |
@manager.route("/datasets/<dataset_id>", methods=["PUT"])
|
@@ -424,7 +424,7 @@ def update(tenant_id, dataset_id):
|
|
424 |
)
|
425 |
if not KnowledgebaseService.update_by_id(kb.id, req):
|
426 |
return get_error_data_result(message="Update dataset error.(Database error)")
|
427 |
-
return get_result(code=RetCode.SUCCESS)
|
428 |
|
429 |
|
430 |
@manager.route("/datasets", methods=["GET"])
|
|
|
23 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
24 |
from api.db.services.llm_service import TenantLLMService, LLMService
|
25 |
from api.db.services.user_service import TenantService
|
26 |
+
from api import settings
|
27 |
from api.utils import get_uuid
|
28 |
from api.utils.api_utils import (
|
29 |
get_result,
|
|
|
255 |
File2DocumentService.delete_by_document_id(doc.id)
|
256 |
if not KnowledgebaseService.delete_by_id(id):
|
257 |
return get_error_data_result(message="Delete dataset error.(Database error)")
|
258 |
+
return get_result(code=settings.RetCode.SUCCESS)
|
259 |
|
260 |
|
261 |
@manager.route("/datasets/<dataset_id>", methods=["PUT"])
|
|
|
424 |
)
|
425 |
if not KnowledgebaseService.update_by_id(kb.id, req):
|
426 |
return get_error_data_result(message="Update dataset error.(Database error)")
|
427 |
+
return get_result(code=settings.RetCode.SUCCESS)
|
428 |
|
429 |
|
430 |
@manager.route("/datasets", methods=["GET"])
|
api/apps/sdk/dify_retrieval.py
CHANGED
@@ -18,7 +18,7 @@ from flask import request, jsonify
|
|
18 |
from api.db import LLMType, ParserType
|
19 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
20 |
from api.db.services.llm_service import LLMBundle
|
21 |
-
from api.settings import RetCode, retrievaler, kg_retrievaler
|
22 |
from api.utils.api_utils import validate_request, build_error_result, apikey_required
|
23 |
|
24 |
|
@@ -37,14 +37,14 @@ def retrieval(tenant_id):
|
|
37 |
|
38 |
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
39 |
if not e:
|
40 |
-
return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND)
|
41 |
|
42 |
if kb.tenant_id != tenant_id:
|
43 |
-
return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND)
|
44 |
|
45 |
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
46 |
|
47 |
-
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
|
48 |
ranks = retr.retrieval(
|
49 |
question,
|
50 |
embd_mdl,
|
@@ -72,6 +72,6 @@ def retrieval(tenant_id):
|
|
72 |
if str(e).find("not_found") > 0:
|
73 |
return build_error_result(
|
74 |
message='No chunk found! Check the chunk status please!',
|
75 |
-
code=RetCode.NOT_FOUND
|
76 |
)
|
77 |
-
return build_error_result(message=str(e), code=RetCode.SERVER_ERROR)
|
|
|
18 |
from api.db import LLMType, ParserType
|
19 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
20 |
from api.db.services.llm_service import LLMBundle
|
21 |
+
from api import settings
|
22 |
from api.utils.api_utils import validate_request, build_error_result, apikey_required
|
23 |
|
24 |
|
|
|
37 |
|
38 |
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
39 |
if not e:
|
40 |
+
return build_error_result(message="Knowledgebase not found!", code=settings.RetCode.NOT_FOUND)
|
41 |
|
42 |
if kb.tenant_id != tenant_id:
|
43 |
+
return build_error_result(message="Knowledgebase not found!", code=settings.RetCode.NOT_FOUND)
|
44 |
|
45 |
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
46 |
|
47 |
+
retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
|
48 |
ranks = retr.retrieval(
|
49 |
question,
|
50 |
embd_mdl,
|
|
|
72 |
if str(e).find("not_found") > 0:
|
73 |
return build_error_result(
|
74 |
message='No chunk found! Check the chunk status please!',
|
75 |
+
code=settings.RetCode.NOT_FOUND
|
76 |
)
|
77 |
+
return build_error_result(message=str(e), code=settings.RetCode.SERVER_ERROR)
|
api/apps/sdk/doc.py
CHANGED
@@ -21,7 +21,7 @@ from rag.app.qa import rmPrefix, beAdoc
|
|
21 |
from rag.nlp import rag_tokenizer
|
22 |
from api.db import LLMType, ParserType
|
23 |
from api.db.services.llm_service import TenantLLMService
|
24 |
-
from api.settings import RetCode
|
25 |
import hashlib
|
26 |
import re
|
27 |
from api.utils.api_utils import token_required
|
@@ -37,11 +37,10 @@ from api.db.services.document_service import DocumentService
|
|
37 |
from api.db.services.file2document_service import File2DocumentService
|
38 |
from api.db.services.file_service import FileService
|
39 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
40 |
-
from api.settings import retrievaler, kg_retrievaler
|
41 |
from api.utils.api_utils import construct_json_result, get_parser_config
|
42 |
from rag.nlp import search
|
43 |
from rag.utils import rmSpace
|
44 |
-
from api.settings import docStoreConn
|
45 |
from rag.utils.storage_factory import STORAGE_IMPL
|
46 |
import os
|
47 |
|
@@ -109,13 +108,13 @@ def upload(dataset_id, tenant_id):
|
|
109 |
"""
|
110 |
if "file" not in request.files:
|
111 |
return get_error_data_result(
|
112 |
-
message="No file part!", code=RetCode.ARGUMENT_ERROR
|
113 |
)
|
114 |
file_objs = request.files.getlist("file")
|
115 |
for file_obj in file_objs:
|
116 |
if file_obj.filename == "":
|
117 |
return get_result(
|
118 |
-
message="No file selected!", code=RetCode.ARGUMENT_ERROR
|
119 |
)
|
120 |
# total size
|
121 |
total_size = 0
|
@@ -127,14 +126,14 @@ def upload(dataset_id, tenant_id):
|
|
127 |
if total_size > MAX_TOTAL_FILE_SIZE:
|
128 |
return get_result(
|
129 |
message=f"Total file size exceeds 10MB limit! ({total_size / (1024 * 1024):.2f} MB)",
|
130 |
-
code=RetCode.ARGUMENT_ERROR,
|
131 |
)
|
132 |
e, kb = KnowledgebaseService.get_by_id(dataset_id)
|
133 |
if not e:
|
134 |
raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
|
135 |
err, files = FileService.upload_document(kb, file_objs, tenant_id)
|
136 |
if err:
|
137 |
-
return get_result(message="\n".join(err), code=RetCode.SERVER_ERROR)
|
138 |
# rename key's name
|
139 |
renamed_doc_list = []
|
140 |
for file in files:
|
@@ -221,12 +220,12 @@ def update_doc(tenant_id, dataset_id, document_id):
|
|
221 |
|
222 |
if "name" in req and req["name"] != doc.name:
|
223 |
if (
|
224 |
-
pathlib.Path(req["name"].lower()).suffix
|
225 |
-
!= pathlib.Path(doc.name.lower()).suffix
|
226 |
):
|
227 |
return get_result(
|
228 |
message="The extension of file can't be changed",
|
229 |
-
code=RetCode.ARGUMENT_ERROR,
|
230 |
)
|
231 |
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
|
232 |
if d.name == req["name"]:
|
@@ -292,7 +291,7 @@ def update_doc(tenant_id, dataset_id, document_id):
|
|
292 |
)
|
293 |
if not e:
|
294 |
return get_error_data_result(message="Document not found!")
|
295 |
-
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
|
296 |
|
297 |
return get_result()
|
298 |
|
@@ -349,7 +348,7 @@ def download(tenant_id, dataset_id, document_id):
|
|
349 |
file_stream = STORAGE_IMPL.get(doc_id, doc_location)
|
350 |
if not file_stream:
|
351 |
return construct_json_result(
|
352 |
-
message="This file is empty.", code=RetCode.DATA_ERROR
|
353 |
)
|
354 |
file = BytesIO(file_stream)
|
355 |
# Use send_file with a proper filename and MIME type
|
@@ -582,7 +581,7 @@ def delete(tenant_id, dataset_id):
|
|
582 |
errors += str(e)
|
583 |
|
584 |
if errors:
|
585 |
-
return get_result(message=errors, code=RetCode.SERVER_ERROR)
|
586 |
|
587 |
return get_result()
|
588 |
|
@@ -644,7 +643,7 @@ def parse(tenant_id, dataset_id):
|
|
644 |
info["chunk_num"] = 0
|
645 |
info["token_num"] = 0
|
646 |
DocumentService.update_by_id(id, info)
|
647 |
-
docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
|
648 |
TaskService.filter_delete([Task.doc_id == id])
|
649 |
e, doc = DocumentService.get_by_id(id)
|
650 |
doc = doc.to_dict()
|
@@ -708,7 +707,7 @@ def stop_parsing(tenant_id, dataset_id):
|
|
708 |
)
|
709 |
info = {"run": "2", "progress": 0, "chunk_num": 0}
|
710 |
DocumentService.update_by_id(id, info)
|
711 |
-
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
|
712 |
return get_result()
|
713 |
|
714 |
|
@@ -828,8 +827,9 @@ def list_chunks(tenant_id, dataset_id, document_id):
|
|
828 |
|
829 |
res = {"total": 0, "chunks": [], "doc": renamed_doc}
|
830 |
origin_chunks = []
|
831 |
-
if docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
|
832 |
-
sres = retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
|
833 |
res["total"] = sres.total
|
834 |
sign = 0
|
835 |
for id in sres.ids:
|
@@ -1003,7 +1003,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
|
|
1003 |
v, c = embd_mdl.encode([doc.name, req["content"]])
|
1004 |
v = 0.1 * v[0] + 0.9 * v[1]
|
1005 |
d["q_%d_vec" % len(v)] = v.tolist()
|
1006 |
-
docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
|
1007 |
|
1008 |
DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
|
1009 |
# rename keys
|
@@ -1078,7 +1078,7 @@ def rm_chunk(tenant_id, dataset_id, document_id):
|
|
1078 |
condition = {"doc_id": document_id}
|
1079 |
if "chunk_ids" in req:
|
1080 |
condition["id"] = req["chunk_ids"]
|
1081 |
-
chunk_number = docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
|
1082 |
if chunk_number != 0:
|
1083 |
DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
|
1084 |
if "chunk_ids" in req and chunk_number != len(req["chunk_ids"]):
|
@@ -1143,7 +1143,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
|
1143 |
schema:
|
1144 |
type: object
|
1145 |
"""
|
1146 |
-
chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
|
1147 |
if chunk is None:
|
1148 |
return get_error_data_result(f"Can't find this chunk {chunk_id}")
|
1149 |
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
@@ -1187,7 +1187,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
|
1187 |
v, c = embd_mdl.encode([doc.name, d["content_with_weight"]])
|
1188 |
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
1189 |
d["q_%d_vec" % len(v)] = v.tolist()
|
1190 |
-
docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
|
1191 |
return get_result()
|
1192 |
|
1193 |
|
@@ -1285,7 +1285,7 @@ def retrieval_test(tenant_id):
|
|
1285 |
if len(embd_nms) != 1:
|
1286 |
return get_result(
|
1287 |
message='Datasets use different embedding models."',
|
1288 |
-
code=RetCode.AUTHENTICATION_ERROR,
|
1289 |
)
|
1290 |
if "question" not in req:
|
1291 |
return get_error_data_result("`question` is required.")
|
@@ -1326,7 +1326,7 @@ def retrieval_test(tenant_id):
|
|
1326 |
chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
|
1327 |
question += keyword_extraction(chat_mdl, question)
|
1328 |
|
1329 |
-
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
|
1330 |
ranks = retr.retrieval(
|
1331 |
question,
|
1332 |
embd_mdl,
|
@@ -1366,6 +1366,6 @@ def retrieval_test(tenant_id):
|
|
1366 |
if str(e).find("not_found") > 0:
|
1367 |
return get_result(
|
1368 |
message="No chunk found! Check the chunk status please!",
|
1369 |
-
code=RetCode.DATA_ERROR,
|
1370 |
)
|
1371 |
-
return server_error_response(e)
|
|
|
21 |
from rag.nlp import rag_tokenizer
|
22 |
from api.db import LLMType, ParserType
|
23 |
from api.db.services.llm_service import TenantLLMService
|
24 |
+
from api import settings
|
25 |
import hashlib
|
26 |
import re
|
27 |
from api.utils.api_utils import token_required
|
|
|
37 |
from api.db.services.file2document_service import File2DocumentService
|
38 |
from api.db.services.file_service import FileService
|
39 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
40 |
+
from api import settings
|
41 |
from api.utils.api_utils import construct_json_result, get_parser_config
|
42 |
from rag.nlp import search
|
43 |
from rag.utils import rmSpace
|
|
|
44 |
from rag.utils.storage_factory import STORAGE_IMPL
|
45 |
import os
|
46 |
|
|
|
108 |
"""
|
109 |
if "file" not in request.files:
|
110 |
return get_error_data_result(
|
111 |
+
message="No file part!", code=settings.RetCode.ARGUMENT_ERROR
|
112 |
)
|
113 |
file_objs = request.files.getlist("file")
|
114 |
for file_obj in file_objs:
|
115 |
if file_obj.filename == "":
|
116 |
return get_result(
|
117 |
+
message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR
|
118 |
)
|
119 |
# total size
|
120 |
total_size = 0
|
|
|
126 |
if total_size > MAX_TOTAL_FILE_SIZE:
|
127 |
return get_result(
|
128 |
message=f"Total file size exceeds 10MB limit! ({total_size / (1024 * 1024):.2f} MB)",
|
129 |
+
code=settings.RetCode.ARGUMENT_ERROR,
|
130 |
)
|
131 |
e, kb = KnowledgebaseService.get_by_id(dataset_id)
|
132 |
if not e:
|
133 |
raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
|
134 |
err, files = FileService.upload_document(kb, file_objs, tenant_id)
|
135 |
if err:
|
136 |
+
return get_result(message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
|
137 |
# rename key's name
|
138 |
renamed_doc_list = []
|
139 |
for file in files:
|
|
|
220 |
|
221 |
if "name" in req and req["name"] != doc.name:
|
222 |
if (
|
223 |
+
pathlib.Path(req["name"].lower()).suffix
|
224 |
+
!= pathlib.Path(doc.name.lower()).suffix
|
225 |
):
|
226 |
return get_result(
|
227 |
message="The extension of file can't be changed",
|
228 |
+
code=settings.RetCode.ARGUMENT_ERROR,
|
229 |
)
|
230 |
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
|
231 |
if d.name == req["name"]:
|
|
|
291 |
)
|
292 |
if not e:
|
293 |
return get_error_data_result(message="Document not found!")
|
294 |
+
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
|
295 |
|
296 |
return get_result()
|
297 |
|
|
|
348 |
file_stream = STORAGE_IMPL.get(doc_id, doc_location)
|
349 |
if not file_stream:
|
350 |
return construct_json_result(
|
351 |
+
message="This file is empty.", code=settings.RetCode.DATA_ERROR
|
352 |
)
|
353 |
file = BytesIO(file_stream)
|
354 |
# Use send_file with a proper filename and MIME type
|
|
|
581 |
errors += str(e)
|
582 |
|
583 |
if errors:
|
584 |
+
return get_result(message=errors, code=settings.RetCode.SERVER_ERROR)
|
585 |
|
586 |
return get_result()
|
587 |
|
|
|
643 |
info["chunk_num"] = 0
|
644 |
info["token_num"] = 0
|
645 |
DocumentService.update_by_id(id, info)
|
646 |
+
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
|
647 |
TaskService.filter_delete([Task.doc_id == id])
|
648 |
e, doc = DocumentService.get_by_id(id)
|
649 |
doc = doc.to_dict()
|
|
|
707 |
)
|
708 |
info = {"run": "2", "progress": 0, "chunk_num": 0}
|
709 |
DocumentService.update_by_id(id, info)
|
710 |
+
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
|
711 |
return get_result()
|
712 |
|
713 |
|
|
|
827 |
|
828 |
res = {"total": 0, "chunks": [], "doc": renamed_doc}
|
829 |
origin_chunks = []
|
830 |
+
if settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
|
831 |
+
sres = settings.retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None,
|
832 |
+
highlight=True)
|
833 |
res["total"] = sres.total
|
834 |
sign = 0
|
835 |
for id in sres.ids:
|
|
|
1003 |
v, c = embd_mdl.encode([doc.name, req["content"]])
|
1004 |
v = 0.1 * v[0] + 0.9 * v[1]
|
1005 |
d["q_%d_vec" % len(v)] = v.tolist()
|
1006 |
+
settings.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
|
1007 |
|
1008 |
DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
|
1009 |
# rename keys
|
|
|
1078 |
condition = {"doc_id": document_id}
|
1079 |
if "chunk_ids" in req:
|
1080 |
condition["id"] = req["chunk_ids"]
|
1081 |
+
chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
|
1082 |
if chunk_number != 0:
|
1083 |
DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
|
1084 |
if "chunk_ids" in req and chunk_number != len(req["chunk_ids"]):
|
|
|
1143 |
schema:
|
1144 |
type: object
|
1145 |
"""
|
1146 |
+
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
|
1147 |
if chunk is None:
|
1148 |
return get_error_data_result(f"Can't find this chunk {chunk_id}")
|
1149 |
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
|
|
1187 |
v, c = embd_mdl.encode([doc.name, d["content_with_weight"]])
|
1188 |
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
1189 |
d["q_%d_vec" % len(v)] = v.tolist()
|
1190 |
+
settings.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
|
1191 |
return get_result()
|
1192 |
|
1193 |
|
|
|
1285 |
if len(embd_nms) != 1:
|
1286 |
return get_result(
|
1287 |
message='Datasets use different embedding models."',
|
1288 |
+
code=settings.RetCode.AUTHENTICATION_ERROR,
|
1289 |
)
|
1290 |
if "question" not in req:
|
1291 |
return get_error_data_result("`question` is required.")
|
|
|
1326 |
chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
|
1327 |
question += keyword_extraction(chat_mdl, question)
|
1328 |
|
1329 |
+
retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
|
1330 |
ranks = retr.retrieval(
|
1331 |
question,
|
1332 |
embd_mdl,
|
|
|
1366 |
if str(e).find("not_found") > 0:
|
1367 |
return get_result(
|
1368 |
message="No chunk found! Check the chunk status please!",
|
1369 |
+
code=settings.RetCode.DATA_ERROR,
|
1370 |
)
|
1371 |
+
return server_error_response(e)
|
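The indexExist/delete pairing recurs across the handlers above; the helper below is a hypothetical consolidation, not a function from this PR, though the settings.docStoreConn methods and search.index_name() are exactly the calls visible in the hunks.

from api import settings
from rag.nlp import search

def drop_doc_chunks(tenant_id: str, kb_id: str, doc_id: str) -> None:
    # Delete every indexed chunk of one document, guarding on index
    # existence first, as the handlers above do.
    idx = search.index_name(tenant_id)
    if settings.docStoreConn.indexExist(idx, kb_id):
        settings.docStoreConn.delete({"doc_id": doc_id}, idx, kb_id)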
api/apps/system_app.py
CHANGED
@@ -22,7 +22,7 @@ from api.db.db_models import APIToken
|
|
22 |
from api.db.services.api_service import APITokenService
|
23 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
24 |
from api.db.services.user_service import UserTenantService
|
25 |
-
from api
|
26 |
from api.utils import current_timestamp, datetime_format
|
27 |
from api.utils.api_utils import (
|
28 |
get_json_result,
|
@@ -31,7 +31,6 @@ from api.utils.api_utils import (
|
|
31 |
generate_confirmation_token,
|
32 |
)
|
33 |
from api.versions import get_ragflow_version
|
34 |
-
from api.settings import docStoreConn
|
35 |
from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE
|
36 |
from timeit import default_timer as timer
|
37 |
|
@@ -98,7 +97,7 @@ def status():
|
|
98 |
res = {}
|
99 |
st = timer()
|
100 |
try:
|
101 |
-
res["doc_store"] = docStoreConn.health()
|
102 |
res["doc_store"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
|
103 |
except Exception as e:
|
104 |
res["doc_store"] = {
|
@@ -128,13 +127,13 @@ def status():
|
|
128 |
try:
|
129 |
KnowledgebaseService.get_by_id("x")
|
130 |
res["database"] = {
|
131 |
-
"database": DATABASE_TYPE.lower(),
|
132 |
"status": "green",
|
133 |
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
|
134 |
}
|
135 |
except Exception as e:
|
136 |
res["database"] = {
|
137 |
-
"database": DATABASE_TYPE.lower(),
|
138 |
"status": "red",
|
139 |
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
|
140 |
"error": str(e),
|
|
|
22 |
from api.db.services.api_service import APITokenService
|
23 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
24 |
from api.db.services.user_service import UserTenantService
|
25 |
+
from api import settings
|
26 |
from api.utils import current_timestamp, datetime_format
|
27 |
from api.utils.api_utils import (
|
28 |
get_json_result,
|
|
|
31 |
generate_confirmation_token,
|
32 |
)
|
33 |
from api.versions import get_ragflow_version
|
|
|
34 |
from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE
|
35 |
from timeit import default_timer as timer
|
36 |
|
|
|
97 |
res = {}
|
98 |
st = timer()
|
99 |
try:
|
100 |
+
res["doc_store"] = settings.docStoreConn.health()
|
101 |
res["doc_store"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
|
102 |
except Exception as e:
|
103 |
res["doc_store"] = {
|
|
|
127 |
try:
|
128 |
KnowledgebaseService.get_by_id("x")
|
129 |
res["database"] = {
|
130 |
+
"database": settings.DATABASE_TYPE.lower(),
|
131 |
"status": "green",
|
132 |
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
|
133 |
}
|
134 |
except Exception as e:
|
135 |
res["database"] = {
|
136 |
+
"database": settings.DATABASE_TYPE.lower(),
|
137 |
"status": "red",
|
138 |
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
|
139 |
"error": str(e),
|
api/apps/user_app.py
CHANGED
@@ -38,20 +38,7 @@ from api.utils import (
|
|
38 |
datetime_format,
|
39 |
)
|
40 |
from api.db import UserTenantRole, FileType
|
41 |
-
from api.settings import (
|
42 |
-
RetCode,
|
43 |
-
GITHUB_OAUTH,
|
44 |
-
FEISHU_OAUTH,
|
45 |
-
CHAT_MDL,
|
46 |
-
EMBEDDING_MDL,
|
47 |
-
ASR_MDL,
|
48 |
-
IMAGE2TEXT_MDL,
|
49 |
-
PARSERS,
|
50 |
-
API_KEY,
|
51 |
-
LLM_FACTORY,
|
52 |
-
LLM_BASE_URL,
|
53 |
-
RERANK_MDL,
|
54 |
-
)
|
55 |
from api.db.services.user_service import UserService, TenantService, UserTenantService
|
56 |
from api.db.services.file_service import FileService
|
57 |
from api.utils.api_utils import get_json_result, construct_response
|
@@ -90,7 +77,7 @@ def login():
|
|
90 |
"""
|
91 |
if not request.json:
|
92 |
return get_json_result(
|
93 |
-
data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!"
|
94 |
)
|
95 |
|
96 |
email = request.json.get("email", "")
|
@@ -98,7 +85,7 @@ def login():
|
|
98 |
if not users:
|
99 |
return get_json_result(
|
100 |
data=False,
|
101 |
-
code=RetCode.AUTHENTICATION_ERROR,
|
102 |
message=f"Email: {email} is not registered!",
|
103 |
)
|
104 |
|
@@ -107,7 +94,7 @@ def login():
|
|
107 |
password = decrypt(password)
|
108 |
except BaseException:
|
109 |
return get_json_result(
|
110 |
-
data=False, code=RetCode.SERVER_ERROR, message="Fail to crypt password"
|
111 |
)
|
112 |
|
113 |
user = UserService.query_user(email, password)
|
@@ -123,7 +110,7 @@ def login():
|
|
123 |
else:
|
124 |
return get_json_result(
|
125 |
data=False,
|
126 |
-
code=RetCode.AUTHENTICATION_ERROR,
|
127 |
message="Email and password do not match!",
|
128 |
)
|
129 |
|
@@ -150,10 +137,10 @@ def github_callback():
|
|
150 |
import requests
|
151 |
|
152 |
res = requests.post(
|
153 |
-
GITHUB_OAUTH.get("url"),
|
154 |
data={
|
155 |
-
"client_id": GITHUB_OAUTH.get("client_id"),
|
156 |
-
"client_secret": GITHUB_OAUTH.get("secret_key"),
|
157 |
"code": request.args.get("code"),
|
158 |
},
|
159 |
headers={"Accept": "application/json"},
|
@@ -235,11 +222,11 @@ def feishu_callback():
|
|
235 |
import requests
|
236 |
|
237 |
app_access_token_res = requests.post(
|
238 |
-
FEISHU_OAUTH.get("app_access_token_url"),
|
239 |
data=json.dumps(
|
240 |
{
|
241 |
-
"app_id": FEISHU_OAUTH.get("app_id"),
|
242 |
-
"app_secret": FEISHU_OAUTH.get("app_secret"),
|
243 |
}
|
244 |
),
|
245 |
headers={"Content-Type": "application/json; charset=utf-8"},
|
@@ -249,10 +236,10 @@ def feishu_callback():
|
|
249 |
return redirect("/?error=%s" % app_access_token_res)
|
250 |
|
251 |
res = requests.post(
|
252 |
-
FEISHU_OAUTH.get("user_access_token_url"),
|
253 |
data=json.dumps(
|
254 |
{
|
255 |
-
"grant_type": FEISHU_OAUTH.get("grant_type"),
|
256 |
"code": request.args.get("code"),
|
257 |
}
|
258 |
),
|
@@ -405,11 +392,11 @@ def setting_user():
|
|
405 |
if request_data.get("password"):
|
406 |
new_password = request_data.get("new_password")
|
407 |
if not check_password_hash(
|
408 |
-
current_user.password, decrypt(request_data["password"])
|
409 |
):
|
410 |
return get_json_result(
|
411 |
data=False,
|
412 |
-
code=RetCode.AUTHENTICATION_ERROR,
|
413 |
message="Password error!",
|
414 |
)
|
415 |
|
@@ -438,7 +425,7 @@ def setting_user():
|
|
438 |
except Exception as e:
|
439 |
logging.exception(e)
|
440 |
return get_json_result(
|
441 |
-
data=False, message="Update failure!", code=RetCode.EXCEPTION_ERROR
|
442 |
)
|
443 |
|
444 |
|
@@ -497,12 +484,12 @@ def user_register(user_id, user):
|
|
497 |
tenant = {
|
498 |
"id": user_id,
|
499 |
"name": user["nickname"] + "‘s Kingdom",
|
500 |
-
"llm_id": CHAT_MDL,
|
501 |
-
"embd_id": EMBEDDING_MDL,
|
502 |
-
"asr_id": ASR_MDL,
|
503 |
-
"parser_ids": PARSERS,
|
504 |
-
"img2txt_id": IMAGE2TEXT_MDL,
|
505 |
-
"rerank_id": RERANK_MDL,
|
506 |
}
|
507 |
usr_tenant = {
|
508 |
"tenant_id": user_id,
|
@@ -522,15 +509,15 @@ def user_register(user_id, user):
|
|
522 |
"location": "",
|
523 |
}
|
524 |
tenant_llm = []
|
525 |
-
for llm in LLMService.query(fid=LLM_FACTORY):
|
526 |
tenant_llm.append(
|
527 |
{
|
528 |
"tenant_id": user_id,
|
529 |
-
"llm_factory": LLM_FACTORY,
|
530 |
"llm_name": llm.llm_name,
|
531 |
"model_type": llm.model_type,
|
532 |
-
"api_key": API_KEY,
|
533 |
-
"api_base": LLM_BASE_URL,
|
534 |
}
|
535 |
)
|
536 |
|
@@ -582,7 +569,7 @@ def user_add():
|
|
582 |
return get_json_result(
|
583 |
data=False,
|
584 |
message=f"Invalid email address: {email_address}!",
|
585 |
-
code=RetCode.OPERATING_ERROR,
|
586 |
)
|
587 |
|
588 |
# Check if the email address is already used
|
@@ -590,7 +577,7 @@ def user_add():
|
|
590 |
return get_json_result(
|
591 |
data=False,
|
592 |
message=f"Email: {email_address} has already registered!",
|
593 |
-
code=RetCode.OPERATING_ERROR,
|
594 |
)
|
595 |
|
596 |
# Construct user info data
|
@@ -625,7 +612,7 @@ def user_add():
|
|
625 |
return get_json_result(
|
626 |
data=False,
|
627 |
message=f"User registration failure, error: {str(e)}",
|
628 |
-
code=RetCode.EXCEPTION_ERROR,
|
629 |
)
|
630 |
|
631 |
|
|
|
38 |
datetime_format,
|
39 |
)
|
40 |
from api.db import UserTenantRole, FileType
|
41 |
+
from api import settings
|
42 |
from api.db.services.user_service import UserService, TenantService, UserTenantService
|
43 |
from api.db.services.file_service import FileService
|
44 |
from api.utils.api_utils import get_json_result, construct_response
|
|
|
77 |
"""
|
78 |
if not request.json:
|
79 |
return get_json_result(
|
80 |
+
data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="Unauthorized!"
|
81 |
)
|
82 |
|
83 |
email = request.json.get("email", "")
|
|
|
85 |
if not users:
|
86 |
return get_json_result(
|
87 |
data=False,
|
88 |
+
code=settings.RetCode.AUTHENTICATION_ERROR,
|
89 |
message=f"Email: {email} is not registered!",
|
90 |
)
|
91 |
|
|
|
94 |
password = decrypt(password)
|
95 |
except BaseException:
|
96 |
return get_json_result(
|
97 |
+
data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password"
|
98 |
)
|
99 |
|
100 |
user = UserService.query_user(email, password)
|
|
|
110 |
else:
|
111 |
return get_json_result(
|
112 |
data=False,
|
113 |
+
code=settings.RetCode.AUTHENTICATION_ERROR,
|
114 |
message="Email and password do not match!",
|
115 |
)
|
116 |
|
|
|
137 |
import requests
|
138 |
|
139 |
res = requests.post(
|
140 |
+
settings.GITHUB_OAUTH.get("url"),
|
141 |
data={
|
142 |
+
"client_id": settings.GITHUB_OAUTH.get("client_id"),
|
143 |
+
"client_secret": settings.GITHUB_OAUTH.get("secret_key"),
|
144 |
"code": request.args.get("code"),
|
145 |
},
|
146 |
headers={"Accept": "application/json"},
|
|
|
222 |
import requests
|
223 |
|
224 |
app_access_token_res = requests.post(
|
225 |
+
settings.FEISHU_OAUTH.get("app_access_token_url"),
|
226 |
data=json.dumps(
|
227 |
{
|
228 |
+
"app_id": settings.FEISHU_OAUTH.get("app_id"),
|
229 |
+
"app_secret": settings.FEISHU_OAUTH.get("app_secret"),
|
230 |
}
|
231 |
),
|
232 |
headers={"Content-Type": "application/json; charset=utf-8"},
|
|
|
236 |
return redirect("/?error=%s" % app_access_token_res)
|
237 |
|
238 |
res = requests.post(
|
239 |
+
settings.FEISHU_OAUTH.get("user_access_token_url"),
|
240 |
data=json.dumps(
|
241 |
{
|
242 |
+
"grant_type": settings.FEISHU_OAUTH.get("grant_type"),
|
243 |
"code": request.args.get("code"),
|
244 |
}
|
245 |
),
|
|
|
392 |
if request_data.get("password"):
|
393 |
new_password = request_data.get("new_password")
|
394 |
if not check_password_hash(
|
395 |
+
current_user.password, decrypt(request_data["password"])
|
396 |
):
|
397 |
return get_json_result(
|
398 |
data=False,
|
399 |
+
code=settings.RetCode.AUTHENTICATION_ERROR,
|
400 |
message="Password error!",
|
401 |
)
|
402 |
|
|
|
425 |
except Exception as e:
|
426 |
logging.exception(e)
|
427 |
return get_json_result(
|
428 |
+
data=False, message="Update failure!", code=settings.RetCode.EXCEPTION_ERROR
|
429 |
)
|
430 |
|
431 |
|
|
|
484 |
tenant = {
|
485 |
"id": user_id,
|
486 |
"name": user["nickname"] + "‘s Kingdom",
|
487 |
+
"llm_id": settings.CHAT_MDL,
|
488 |
+
"embd_id": settings.EMBEDDING_MDL,
|
489 |
+
"asr_id": settings.ASR_MDL,
|
490 |
+
"parser_ids": settings.PARSERS,
|
491 |
+
"img2txt_id": settings.IMAGE2TEXT_MDL,
|
492 |
+
"rerank_id": settings.RERANK_MDL,
|
493 |
}
|
494 |
usr_tenant = {
|
495 |
"tenant_id": user_id,
|
|
|
509 |
"location": "",
|
510 |
}
|
511 |
tenant_llm = []
|
512 |
+
for llm in LLMService.query(fid=settings.LLM_FACTORY):
|
513 |
tenant_llm.append(
|
514 |
{
|
515 |
"tenant_id": user_id,
|
516 |
+
"llm_factory": settings.LLM_FACTORY,
|
517 |
"llm_name": llm.llm_name,
|
518 |
"model_type": llm.model_type,
|
519 |
+
"api_key": settings.API_KEY,
|
520 |
+
"api_base": settings.LLM_BASE_URL,
|
521 |
}
|
522 |
)
|
523 |
|
|
|
569 |
return get_json_result(
|
570 |
data=False,
|
571 |
message=f"Invalid email address: {email_address}!",
|
572 |
+
code=settings.RetCode.OPERATING_ERROR,
|
573 |
)
|
574 |
|
575 |
# Check if the email address is already used
|
|
|
577 |
return get_json_result(
|
578 |
data=False,
|
579 |
message=f"Email: {email_address} has already registered!",
|
580 |
+
code=settings.RetCode.OPERATING_ERROR,
|
581 |
)
|
582 |
|
583 |
# Construct user info data
|
|
|
612 |
return get_json_result(
|
613 |
data=False,
|
614 |
message=f"User registration failure, error: {str(e)}",
|
615 |
+
code=settings.RetCode.EXCEPTION_ERROR,
|
616 |
)
|
617 |
|
618 |
|
api/db/db_models.py
CHANGED
@@ -31,7 +31,7 @@ from peewee import (
|
|
31 |
)
|
32 |
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
|
33 |
from api.db import SerializedType, ParserType
|
34 |
-
from api.settings import DATABASE, DATABASE_TYPE, SECRET_KEY
|
35 |
from api import utils
|
36 |
|
37 |
def singleton(cls, *args, **kw):
|
@@ -62,7 +62,7 @@ class TextFieldType(Enum):
|
|
62 |
|
63 |
|
64 |
class LongTextField(TextField):
|
65 |
-
field_type = TextFieldType[DATABASE_TYPE.upper()].value
|
66 |
|
67 |
|
68 |
class JSONField(LongTextField):
|
@@ -282,9 +282,9 @@ class DatabaseMigrator(Enum):
|
|
282 |
@singleton
|
283 |
class BaseDataBase:
|
284 |
def __init__(self):
|
285 |
-
database_config = DATABASE.copy()
|
286 |
db_name = database_config.pop("name")
|
287 |
-
self.database_connection = PooledDatabase[DATABASE_TYPE.upper()].value(db_name, **database_config)
|
288 |
logging.info('init database on cluster mode successfully')
|
289 |
|
290 |
class PostgresDatabaseLock:
|
@@ -385,7 +385,7 @@ class DatabaseLock(Enum):
|
|
385 |
|
386 |
|
387 |
DB = BaseDataBase().database_connection
|
388 |
-
DB.lock = DatabaseLock[DATABASE_TYPE.upper()].value
|
389 |
|
390 |
|
391 |
def close_connection():
|
@@ -476,7 +476,7 @@ class User(DataBaseModel, UserMixin):
|
|
476 |
return self.email
|
477 |
|
478 |
def get_id(self):
|
479 |
-
jwt = Serializer(secret_key=SECRET_KEY)
|
480 |
return jwt.dumps(str(self.access_token))
|
481 |
|
482 |
class Meta:
|
@@ -977,7 +977,7 @@ class CanvasTemplate(DataBaseModel):
|
|
977 |
|
978 |
def migrate_db():
|
979 |
with DB.transaction():
|
980 |
-
migrator = DatabaseMigrator[DATABASE_TYPE.upper()].value(DB)
|
981 |
try:
|
982 |
migrate(
|
983 |
migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="",
|
|
|
31 |
)
|
32 |
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
|
33 |
from api.db import SerializedType, ParserType
|
34 |
+
from api import settings
|
35 |
from api import utils
|
36 |
|
37 |
def singleton(cls, *args, **kw):
|
|
|
62 |
|
63 |
|
64 |
class LongTextField(TextField):
|
65 |
+
field_type = TextFieldType[settings.DATABASE_TYPE.upper()].value
|
66 |
|
67 |
|
68 |
class JSONField(LongTextField):
|
|
|
282 |
@singleton
|
283 |
class BaseDataBase:
|
284 |
def __init__(self):
|
285 |
+
database_config = settings.DATABASE.copy()
|
286 |
db_name = database_config.pop("name")
|
287 |
+
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
|
288 |
logging.info('init database on cluster mode successfully')
|
289 |
|
290 |
class PostgresDatabaseLock:
|
|
|
385 |
|
386 |
|
387 |
DB = BaseDataBase().database_connection
|
388 |
+
DB.lock = DatabaseLock[settings.DATABASE_TYPE.upper()].value
|
389 |
|
390 |
|
391 |
def close_connection():
|
|
|
476 |
return self.email
|
477 |
|
478 |
def get_id(self):
|
479 |
+
jwt = Serializer(secret_key=settings.SECRET_KEY)
|
480 |
return jwt.dumps(str(self.access_token))
|
481 |
|
482 |
class Meta:
|
|
|
977 |
|
978 |
def migrate_db():
|
979 |
with DB.transaction():
|
980 |
+
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
|
981 |
try:
|
982 |
migrate(
|
983 |
migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="",
|
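One subtlety in db_models.py: LongTextField.field_type, DB, and DB.lock all read settings.DATABASE_TYPE while the module is being imported. Switching to attribute access defers the name lookup, not the execution time, so settings must be populated before db_models is first imported. The sketch below only illustrates that ordering constraint; init_settings() is assumed to be the initializer added in api/settings.py, and boot() is hypothetical.

from api import settings

def boot():
    # Populate settings first; db_models evaluates settings.DATABASE_TYPE in
    # class bodies and at module level, so it must be imported afterwards.
    # Handler-level uses (settings.docStoreConn, settings.RetCode) resolve
    # per call and could tolerate a later init.
    settings.init_settings()  # assumed initializer name
    from api.db import db_models  # deferred import runs the class bodies now
    return db_models.DB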
api/db/init_data.py
CHANGED
@@ -29,7 +29,7 @@ from api.db.services.document_service import DocumentService
|
|
29 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
30 |
from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
|
31 |
from api.db.services.user_service import TenantService, UserTenantService
|
32 |
-
from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY, LLM_BASE_URL
|
33 |
from api.utils.file_utils import get_project_base_directory
|
34 |
|
35 |
|
@@ -51,11 +51,11 @@ def init_superuser():
|
|
51 |
tenant = {
|
52 |
"id": user_info["id"],
|
53 |
"name": user_info["nickname"] + "‘s Kingdom",
|
54 |
-
"llm_id": CHAT_MDL,
|
55 |
-
"embd_id": EMBEDDING_MDL,
|
56 |
-
"asr_id": ASR_MDL,
|
57 |
-
"parser_ids": PARSERS,
|
58 |
-
"img2txt_id": IMAGE2TEXT_MDL
|
59 |
}
|
60 |
usr_tenant = {
|
61 |
"tenant_id": user_info["id"],
|
@@ -64,10 +64,11 @@ def init_superuser():
|
|
64 |
"role": UserTenantRole.OWNER
|
65 |
}
|
66 |
tenant_llm = []
|
67 |
-
for llm in LLMService.query(fid=LLM_FACTORY):
|
68 |
tenant_llm.append(
|
69 |
-
{"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name,
|
70 |
-
"
|
|
|
71 |
|
72 |
if not UserService.save(**user_info):
|
73 |
logging.error("can't init admin.")
|
@@ -80,7 +81,7 @@ def init_superuser():
|
|
80 |
|
81 |
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
|
82 |
msg = chat_mdl.chat(system="", history=[
|
83 |
-
{"role": "user", "content": "Hello!"}], gen_conf={})
|
84 |
if msg.find("ERROR: ") == 0:
|
85 |
logging.error(
|
86 |
"'{}' dosen't work. {}".format(
|
@@ -179,7 +180,7 @@ def init_web_data():
|
|
179 |
start_time = time.time()
|
180 |
|
181 |
init_llm_factory()
|
182 |
-
#if not UserService.get_all().count():
|
183 |
# init_superuser()
|
184 |
|
185 |
add_graph_templates()
|
|
|
29 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
30 |
from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
|
31 |
from api.db.services.user_service import TenantService, UserTenantService
|
32 |
+
from api import settings
|
33 |
from api.utils.file_utils import get_project_base_directory
|
34 |
|
35 |
|
|
|
51 |
tenant = {
|
52 |
"id": user_info["id"],
|
53 |
"name": user_info["nickname"] + "‘s Kingdom",
|
54 |
+
"llm_id": settings.CHAT_MDL,
|
55 |
+
"embd_id": settings.EMBEDDING_MDL,
|
56 |
+
"asr_id": settings.ASR_MDL,
|
57 |
+
"parser_ids": settings.PARSERS,
|
58 |
+
"img2txt_id": settings.IMAGE2TEXT_MDL
|
59 |
}
|
60 |
usr_tenant = {
|
61 |
"tenant_id": user_info["id"],
|
|
|
64 |
"role": UserTenantRole.OWNER
|
65 |
}
|
66 |
tenant_llm = []
|
67 |
+
for llm in LLMService.query(fid=settings.LLM_FACTORY):
|
68 |
tenant_llm.append(
|
69 |
+
{"tenant_id": user_info["id"], "llm_factory": settings.LLM_FACTORY, "llm_name": llm.llm_name,
|
70 |
+
"model_type": llm.model_type,
|
71 |
+
"api_key": settings.API_KEY, "api_base": settings.LLM_BASE_URL})
|
72 |
|
73 |
if not UserService.save(**user_info):
|
74 |
logging.error("can't init admin.")
|
|
|
81 |
|
82 |
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
|
83 |
msg = chat_mdl.chat(system="", history=[
|
84 |
+
{"role": "user", "content": "Hello!"}], gen_conf={})
|
85 |
if msg.find("ERROR: ") == 0:
|
86 |
logging.error(
|
87 |
"'{}' dosen't work. {}".format(
|
|
|
180 |
start_time = time.time()
|
181 |
|
182 |
init_llm_factory()
|
183 |
+
# if not UserService.get_all().count():
|
184 |
# init_superuser()
|
185 |
|
186 |
add_graph_templates()
|
api/db/services/dialog_service.py
CHANGED
@@ -27,7 +27,7 @@ from api.db.db_models import Dialog, Conversation,DB
|
|
27 |
from api.db.services.common_service import CommonService
|
28 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
29 |
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
|
30 |
-
from api.settings import retrievaler, kg_retrievaler
|
31 |
from rag.app.resume import forbidden_select_fields4resume
|
32 |
from rag.nlp.search import index_name
|
33 |
from rag.utils import rmSpace, num_tokens_from_string, encoder
|
@@ -152,7 +152,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
152 |
return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
|
153 |
|
154 |
is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
|
155 |
-
retr = retrievaler if not is_kg else kg_retrievaler
|
156 |
|
157 |
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
|
158 |
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
|
@@ -342,7 +342,7 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
|
342 |
|
343 |
logging.debug(f"{question} get SQL(refined): {sql}")
|
344 |
tried_times += 1
|
345 |
-
return retrievaler.sql_retrieval(sql, format="json"), sql
|
346 |
|
347 |
tbl, sql = get_table()
|
348 |
if tbl is None:
|
@@ -596,7 +596,7 @@ def ask(question, kb_ids, tenant_id):
|
|
596 |
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
597 |
|
598 |
is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
|
599 |
-
retr = retrievaler if not is_kg else kg_retrievaler
|
600 |
|
601 |
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
|
602 |
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
|
|
|
27 |
from api.db.services.common_service import CommonService
|
28 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
29 |
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
|
30 |
+
from api import settings
|
31 |
from rag.app.resume import forbidden_select_fields4resume
|
32 |
from rag.nlp.search import index_name
|
33 |
from rag.utils import rmSpace, num_tokens_from_string, encoder
|
|
|
152 |
return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
|
153 |
|
154 |
is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
|
155 |
+
retr = settings.retrievaler if not is_kg else settings.kg_retrievaler
|
156 |
|
157 |
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
|
158 |
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
|
|
|
342 |
|
343 |
logging.debug(f"{question} get SQL(refined): {sql}")
|
344 |
tried_times += 1
|
345 |
+
return settings.retrievaler.sql_retrieval(sql, format="json"), sql
|
346 |
|
347 |
tbl, sql = get_table()
|
348 |
if tbl is None:
|
|
|
596 |
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
597 |
|
598 |
is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
|
599 |
+
retr = settings.retrievaler if not is_kg else settings.kg_retrievaler
|
600 |
|
601 |
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
|
602 |
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
|
api/db/services/document_service.py
CHANGED
@@ -26,7 +26,7 @@ from io import BytesIO
|
|
26 |
from peewee import fn
|
27 |
|
28 |
from api.db.db_utils import bulk_insert_into_db
|
29 |
-
from api.settings import docStoreConn
|
30 |
from api.utils import current_timestamp, get_format_time, get_uuid
|
31 |
from graphrag.mind_map_extractor import MindMapExtractor
|
32 |
from rag.settings import SVR_QUEUE_NAME
|
@@ -108,7 +108,7 @@ class DocumentService(CommonService):
|
|
108 |
@classmethod
|
109 |
@DB.connection_context()
|
110 |
def remove_document(cls, doc, tenant_id):
|
111 |
-
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
112 |
cls.clear_chunk_num(doc.id)
|
113 |
return cls.delete_by_id(doc.id)
|
114 |
|
@@ -553,10 +553,10 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|
553 |
d["q_%d_vec" % len(v)] = v
|
554 |
for b in range(0, len(cks), es_bulk_size):
|
555 |
if try_create_idx:
|
556 |
-
if not docStoreConn.indexExist(idxnm, kb_id):
|
557 |
-
docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
|
558 |
try_create_idx = False
|
559 |
-
docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
|
560 |
|
561 |
DocumentService.increment_chunk_num(
|
562 |
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
|
|
|
26 |
from peewee import fn
|
27 |
|
28 |
from api.db.db_utils import bulk_insert_into_db
|
29 |
+
from api import settings
|
30 |
from api.utils import current_timestamp, get_format_time, get_uuid
|
31 |
from graphrag.mind_map_extractor import MindMapExtractor
|
32 |
from rag.settings import SVR_QUEUE_NAME
|
|
|
108 |
@classmethod
|
109 |
@DB.connection_context()
|
110 |
def remove_document(cls, doc, tenant_id):
|
111 |
+
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
112 |
cls.clear_chunk_num(doc.id)
|
113 |
return cls.delete_by_id(doc.id)
|
114 |
|
|
|
553 |
d["q_%d_vec" % len(v)] = v
|
554 |
for b in range(0, len(cks), es_bulk_size):
|
555 |
if try_create_idx:
|
556 |
+
if not settings.docStoreConn.indexExist(idxnm, kb_id):
|
557 |
+
settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
|
558 |
try_create_idx = False
|
559 |
+
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
|
560 |
|
561 |
DocumentService.increment_chunk_num(
|
562 |
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
|
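`doc_upload_and_parse` keeps its indexing logic but now reaches the document store through `settings.docStoreConn`, which only exists after `init_settings()` has run. The loop itself is a create-index-at-most-once, insert-in-fixed-batches pattern; a runnable sketch under that reading (`FakeConn` is a stub; the method names `indexExist`/`createIdx`/`insert` follow the diff):

```python
# Sketch of the batched-insert pattern above (stubbed behavior).
ES_BULK_SIZE = 4

class FakeConn:
    def __init__(self): self.created = False
    def indexExist(self, idxnm, kb_id): return self.created
    def createIdx(self, idxnm, kb_id, vector_size): self.created = True
    def insert(self, chunks, idxnm, kb_id): print("insert", len(chunks))

def upload_chunks(conn, chunks, idxnm, kb_id, vector_size):
    try_create_idx = True
    for b in range(0, len(chunks), ES_BULK_SIZE):
        if try_create_idx:
            # create the index at most once, and only if it is missing
            if not conn.indexExist(idxnm, kb_id):
                conn.createIdx(idxnm, kb_id, vector_size)
            try_create_idx = False
        conn.insert(chunks[b:b + ES_BULK_SIZE], idxnm, kb_id)

upload_chunks(FakeConn(), [{"id": i} for i in range(10)], "idx", "kb", 1024)
```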
api/ragflow_server.py
CHANGED
@@ -33,12 +33,10 @@ import traceback
|
|
33 |
from concurrent.futures import ThreadPoolExecutor
|
34 |
|
35 |
from werkzeug.serving import run_simple
|
|
|
36 |
from api.apps import app
|
37 |
from api.db.runtime_config import RuntimeConfig
|
38 |
from api.db.services.document_service import DocumentService
|
39 |
-
from api.settings import (
|
40 |
-
HOST, HTTP_PORT
|
41 |
-
)
|
42 |
from api import utils
|
43 |
|
44 |
from api.db.db_models import init_database_tables as init_web_db
|
@@ -72,6 +70,7 @@ if __name__ == '__main__':
|
|
72 |
f'project base: {utils.file_utils.get_project_base_directory()}'
|
73 |
)
|
74 |
show_configs()
|
|
|
75 |
|
76 |
# init db
|
77 |
init_web_db()
|
@@ -96,7 +95,7 @@ if __name__ == '__main__':
|
|
96 |
logging.info("run on debug mode")
|
97 |
|
98 |
RuntimeConfig.init_env()
|
99 |
-
RuntimeConfig.init_config(JOB_SERVER_HOST=HOST, HTTP_PORT=HTTP_PORT)
|
100 |
|
101 |
thread = ThreadPoolExecutor(max_workers=1)
|
102 |
thread.submit(update_progress)
|
@@ -105,8 +104,8 @@ if __name__ == '__main__':
|
|
105 |
try:
|
106 |
logging.info("RAGFlow HTTP server start...")
|
107 |
run_simple(
|
108 |
-
hostname=HOST,
|
109 |
-
port=HTTP_PORT,
|
110 |
application=app,
|
111 |
threaded=True,
|
112 |
use_reloader=RuntimeConfig.DEBUG,
|
|
|
33 |
from concurrent.futures import ThreadPoolExecutor
|
34 |
|
35 |
from werkzeug.serving import run_simple
|
36 |
+
from api import settings
|
37 |
from api.apps import app
|
38 |
from api.db.runtime_config import RuntimeConfig
|
39 |
from api.db.services.document_service import DocumentService
|
|
|
|
|
|
|
40 |
from api import utils
|
41 |
|
42 |
from api.db.db_models import init_database_tables as init_web_db
|
|
|
70 |
f'project base: {utils.file_utils.get_project_base_directory()}'
|
71 |
)
|
72 |
show_configs()
|
73 |
+
settings.init_settings()
|
74 |
|
75 |
# init db
|
76 |
init_web_db()
|
|
|
95 |
logging.info("run on debug mode")
|
96 |
|
97 |
RuntimeConfig.init_env()
|
98 |
+
RuntimeConfig.init_config(JOB_SERVER_HOST=settings.HOST_IP, HTTP_PORT=settings.HOST_PORT)
|
99 |
|
100 |
thread = ThreadPoolExecutor(max_workers=1)
|
101 |
thread.submit(update_progress)
|
|
|
104 |
try:
|
105 |
logging.info("RAGFlow HTTP server start...")
|
106 |
run_simple(
|
107 |
+
hostname=settings.HOST_IP,
|
108 |
+
port=settings.HOST_PORT,
|
109 |
application=app,
|
110 |
threaded=True,
|
111 |
use_reloader=RuntimeConfig.DEBUG,
|
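The entry point makes the new initialization order explicit: `settings.init_settings()` runs after the read-only config display and before any database or document-store work, which is the point of this PR — importing `api.settings` no longer opens connections; only this call does. A self-contained sketch of that boot order (everything here is a stub except the ordering):

```python
class settings:                # stand-in for "from api import settings"
    HOST_IP = None
    HOST_PORT = None
    docStoreConn = None

    @staticmethod
    def init_settings():
        settings.HOST_IP, settings.HOST_PORT = "127.0.0.1", 9380
        settings.docStoreConn = object()  # real code builds ES/Infinity here

def show_configs():
    # read-only: safe to call before init_settings()
    print("configs shown")

def init_web_db():
    # anything touching connections must come after init_settings()
    assert settings.docStoreConn is not None

if __name__ == "__main__":
    show_configs()
    settings.init_settings()
    init_web_db()
    print(f"would serve on {settings.HOST_IP}:{settings.HOST_PORT}")
```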
api/settings.py
CHANGED
@@ -30,114 +30,157 @@ LIGHTEN = int(os.environ.get('LIGHTEN', "0"))
|
|
30 |
|
31 |
REQUEST_WAIT_SEC = 2
|
32 |
REQUEST_MAX_WAIT_SEC = 300
|
33 |
-
LLM = get_base_config("user_default_llm", {})
-
LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
-
LLM_BASE_URL = LLM.get("base_url")
-
default_llm = {
-
"Tongyi-Qianwen": {
-
"chat_model": "qwen-plus",
-
"embedding_model": "text-embedding-v2",
-
"image2text_model": "qwen-vl-max",
-
"asr_model": "paraformer-realtime-8k-v1",
-
},
|
47 |
-
"OpenAI": {
|
48 |
-
"chat_model": "gpt-3.5-turbo",
|
49 |
-
"embedding_model": "text-embedding-ada-002",
|
50 |
-
"image2text_model": "gpt-4-vision-preview",
|
51 |
-
"asr_model": "whisper-1",
|
52 |
-
},
|
53 |
-
"Azure-OpenAI": {
|
54 |
-
"chat_model": "gpt-35-turbo",
|
55 |
-
"embedding_model": "text-embedding-ada-002",
|
56 |
-
"image2text_model": "gpt-4-vision-preview",
|
57 |
-
"asr_model": "whisper-1",
|
58 |
-
},
|
59 |
-
"ZHIPU-AI": {
|
60 |
-
"chat_model": "glm-3-turbo",
|
61 |
-
"embedding_model": "embedding-2",
|
62 |
-
"image2text_model": "glm-4v",
|
63 |
-
"asr_model": "",
|
64 |
-
},
|
65 |
-
"Ollama": {
|
66 |
-
"chat_model": "qwen-14B-chat",
|
67 |
-
"embedding_model": "flag-embedding",
|
68 |
-
"image2text_model": "",
|
69 |
-
"asr_model": "",
|
70 |
-
},
|
71 |
-
"Moonshot": {
|
72 |
-
"chat_model": "moonshot-v1-8k",
|
73 |
-
"embedding_model": "",
|
74 |
-
"image2text_model": "",
|
75 |
-
"asr_model": "",
|
76 |
-
},
|
77 |
-
"DeepSeek": {
|
78 |
-
"chat_model": "deepseek-chat",
|
79 |
-
"embedding_model": "",
|
80 |
-
"image2text_model": "",
|
81 |
-
"asr_model": "",
|
82 |
-
},
|
83 |
-
"VolcEngine": {
|
84 |
-
"chat_model": "",
|
85 |
-
"embedding_model": "",
|
86 |
-
"image2text_model": "",
|
87 |
-
"asr_model": "",
|
88 |
-
},
|
89 |
-
"BAAI": {
|
90 |
-
"chat_model": "",
|
91 |
-
"embedding_model": "BAAI/bge-large-zh-v1.5",
|
92 |
-
"image2text_model": "",
|
93 |
-
"asr_model": "",
|
94 |
-
"rerank_model": "BAAI/bge-reranker-v2-m3",
|
95 |
-
}
|
96 |
-
}
|
97 |
-
|
98 |
-
if LLM_FACTORY:
|
99 |
-
CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"] + f"@{LLM_FACTORY}"
|
100 |
-
ASR_MDL = default_llm[LLM_FACTORY]["asr_model"] + f"@{LLM_FACTORY}"
|
101 |
-
IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"] + f"@{LLM_FACTORY}"
|
102 |
-
EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"] + "@BAAI"
|
103 |
-
RERANK_MDL = default_llm["BAAI"]["rerank_model"] + "@BAAI"
|
104 |
-
|
105 |
-
API_KEY = LLM.get("api_key", "")
|
106 |
-
PARSERS = LLM.get(
|
107 |
-
"parsers",
|
108 |
-
"naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email")
|
109 |
-
|
110 |
-
HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
|
111 |
-
HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
|
112 |
-
|
113 |
-
SECRET_KEY = get_base_config(
|
114 |
-
RAG_FLOW_SERVICE_NAME,
|
115 |
-
{}).get("secret_key", str(date.today()))
|
116 |
|
117 |
DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql')
|
118 |
DATABASE = decrypt_database_config(name=DATABASE_TYPE)
|
119 |
|
120 |
# authentication
|
121 |
-
AUTHENTICATION_CONF = get_base_config("authentication", {})
|
122 |
|
123 |
# client
|
124 |
-
CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get(
|
125 |
-
"client", {}).get(
|
126 |
-
"switch", False)
|
127 |
-
HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key")
|
128 |
-
GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
|
129 |
-
FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu")
|
130 |
-
|
131 |
-
DOC_ENGINE = os.environ.get('DOC_ENGINE', "elasticsearch")
|
132 |
-
if DOC_ENGINE == "elasticsearch":
|
133 |
-
docStoreConn = rag.utils.es_conn.ESConnection()
|
134 |
-
elif DOC_ENGINE == "infinity":
|
135 |
-
docStoreConn = rag.utils.infinity_conn.InfinityConnection()
|
136 |
-
else:
|
137 |
-
raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
|
138 |
-
|
139 |
-
retrievaler = search.Dealer(docStoreConn)
|
140 |
-
kg_retrievaler = kg_search.KGSearch(docStoreConn)
|
141 |
|
142 |
|
143 |
class CustomEnum(Enum):
|
|
|
30 |
|
31 |
REQUEST_WAIT_SEC = 2
|
32 |
REQUEST_MAX_WAIT_SEC = 300
|
33 |
+
LLM = None
|
34 |
+
LLM_FACTORY = None
|
35 |
+
LLM_BASE_URL = None
|
36 |
+
CHAT_MDL = ""
|
37 |
+
EMBEDDING_MDL = ""
|
38 |
+
RERANK_MDL = ""
|
39 |
+
ASR_MDL = ""
|
40 |
+
IMAGE2TEXT_MDL = ""
|
41 |
+
API_KEY = None
|
42 |
+
PARSERS = None
|
43 |
+
HOST_IP = None
|
44 |
+
HOST_PORT = None
|
45 |
+
SECRET_KEY = None
|
46 |
|
47 |
DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql')
|
48 |
DATABASE = decrypt_database_config(name=DATABASE_TYPE)
|
49 |
|
50 |
# authentication
|
51 |
+
AUTHENTICATION_CONF = None
|
52 |
|
53 |
# client
|
54 |
+
CLIENT_AUTHENTICATION = None
|
55 |
+
HTTP_APP_KEY = None
|
56 |
+
GITHUB_OAUTH = None
|
57 |
+
FEISHU_OAUTH = None
|
58 |
+
|
59 |
+
DOC_ENGINE = None
|
60 |
+
docStoreConn = None
|
61 |
+
|
62 |
+
retrievaler = None
|
63 |
+
kg_retrievaler = None
|
64 |
+
|
65 |
+
|
66 |
+
def init_settings():
|
67 |
+
global LLM, LLM_FACTORY, LLM_BASE_URL
|
68 |
+
LLM = get_base_config("user_default_llm", {})
|
69 |
+
LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
|
70 |
+
LLM_BASE_URL = LLM.get("base_url")
|
71 |
+
|
72 |
+
global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
|
73 |
+
if not LIGHTEN:
|
74 |
+
default_llm = {
|
75 |
+
"Tongyi-Qianwen": {
|
76 |
+
"chat_model": "qwen-plus",
|
77 |
+
"embedding_model": "text-embedding-v2",
|
78 |
+
"image2text_model": "qwen-vl-max",
|
79 |
+
"asr_model": "paraformer-realtime-8k-v1",
|
80 |
+
},
|
81 |
+
"OpenAI": {
|
82 |
+
"chat_model": "gpt-3.5-turbo",
|
83 |
+
"embedding_model": "text-embedding-ada-002",
|
84 |
+
"image2text_model": "gpt-4-vision-preview",
|
85 |
+
"asr_model": "whisper-1",
|
86 |
+
},
|
87 |
+
"Azure-OpenAI": {
|
88 |
+
"chat_model": "gpt-35-turbo",
|
89 |
+
"embedding_model": "text-embedding-ada-002",
|
90 |
+
"image2text_model": "gpt-4-vision-preview",
|
91 |
+
"asr_model": "whisper-1",
|
92 |
+
},
|
93 |
+
"ZHIPU-AI": {
|
94 |
+
"chat_model": "glm-3-turbo",
|
95 |
+
"embedding_model": "embedding-2",
|
96 |
+
"image2text_model": "glm-4v",
|
97 |
+
"asr_model": "",
|
98 |
+
},
|
99 |
+
"Ollama": {
|
100 |
+
"chat_model": "qwen-14B-chat",
|
101 |
+
"embedding_model": "flag-embedding",
|
102 |
+
"image2text_model": "",
|
103 |
+
"asr_model": "",
|
104 |
+
},
|
105 |
+
"Moonshot": {
|
106 |
+
"chat_model": "moonshot-v1-8k",
|
107 |
+
"embedding_model": "",
|
108 |
+
"image2text_model": "",
|
109 |
+
"asr_model": "",
|
110 |
+
},
|
111 |
+
"DeepSeek": {
|
112 |
+
"chat_model": "deepseek-chat",
|
113 |
+
"embedding_model": "",
|
114 |
+
"image2text_model": "",
|
115 |
+
"asr_model": "",
|
116 |
+
},
|
117 |
+
"VolcEngine": {
|
118 |
+
"chat_model": "",
|
119 |
+
"embedding_model": "",
|
120 |
+
"image2text_model": "",
|
121 |
+
"asr_model": "",
|
122 |
+
},
|
123 |
+
"BAAI": {
|
124 |
+
"chat_model": "",
|
125 |
+
"embedding_model": "BAAI/bge-large-zh-v1.5",
|
126 |
+
"image2text_model": "",
|
127 |
+
"asr_model": "",
|
128 |
+
"rerank_model": "BAAI/bge-reranker-v2-m3",
|
129 |
+
}
|
130 |
+
}
|
131 |
+
|
132 |
+
if LLM_FACTORY:
|
133 |
+
CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"] + f"@{LLM_FACTORY}"
|
134 |
+
ASR_MDL = default_llm[LLM_FACTORY]["asr_model"] + f"@{LLM_FACTORY}"
|
135 |
+
IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"] + f"@{LLM_FACTORY}"
|
136 |
+
EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"] + "@BAAI"
|
137 |
+
RERANK_MDL = default_llm["BAAI"]["rerank_model"] + "@BAAI"
|
138 |
+
|
139 |
+
global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY
|
140 |
+
API_KEY = LLM.get("api_key", "")
|
141 |
+
PARSERS = LLM.get(
|
142 |
+
"parsers",
|
143 |
+
"naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email")
|
144 |
+
|
145 |
+
HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
|
146 |
+
HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
|
147 |
+
|
148 |
+
SECRET_KEY = get_base_config(
|
149 |
+
RAG_FLOW_SERVICE_NAME,
|
150 |
+
{}).get("secret_key", str(date.today()))
|
151 |
+
|
152 |
+
global AUTHENTICATION_CONF, CLIENT_AUTHENTICATION, HTTP_APP_KEY, GITHUB_OAUTH, FEISHU_OAUTH
|
153 |
+
# authentication
|
154 |
+
AUTHENTICATION_CONF = get_base_config("authentication", {})
|
155 |
+
|
156 |
+
# client
|
157 |
+
CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get(
|
158 |
+
"client", {}).get(
|
159 |
+
"switch", False)
|
160 |
+
HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key")
|
161 |
+
GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
|
162 |
+
FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu")
|
163 |
+
|
164 |
+
global DOC_ENGINE, docStoreConn, retrievaler, kg_retrievaler
|
165 |
+
DOC_ENGINE = os.environ.get('DOC_ENGINE', "elasticsearch")
|
166 |
+
if DOC_ENGINE == "elasticsearch":
|
167 |
+
docStoreConn = rag.utils.es_conn.ESConnection()
|
168 |
+
elif DOC_ENGINE == "infinity":
|
169 |
+
docStoreConn = rag.utils.infinity_conn.InfinityConnection()
|
170 |
+
else:
|
171 |
+
raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
|
172 |
+
|
173 |
+
retrievaler = search.Dealer(docStoreConn)
|
174 |
+
kg_retrievaler = kg_search.KGSearch(docStoreConn)
|
175 |
+
|
176 |
+
def get_host_ip():
|
177 |
+
global HOST_IP
|
178 |
+
return HOST_IP
|
179 |
+
|
180 |
+
|
181 |
+
def get_host_port():
|
182 |
+
global HOST_PORT
|
183 |
+
return HOST_PORT
|
184 |
|
185 |
|
186 |
class CustomEnum(Enum):
|
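`init_settings()` still selects the document-store backend from the `DOC_ENGINE` environment variable, but the connection is now constructed at call time rather than at import time, and `retrievaler`/`kg_retrievaler` are built on top of it. A runnable sketch of the select-and-construct step (the two connection classes are dummies standing in for `rag.utils.es_conn.ESConnection` and `rag.utils.infinity_conn.InfinityConnection`):

```python
import os

class ESConnection:          # dummy for rag.utils.es_conn.ESConnection
    pass

class InfinityConnection:    # dummy for rag.utils.infinity_conn.InfinityConnection
    pass

docStoreConn = None

def init_doc_store():
    """Pick the backend from DOC_ENGINE and construct it once, at call time."""
    global docStoreConn
    engine = os.environ.get("DOC_ENGINE", "elasticsearch")
    if engine == "elasticsearch":
        docStoreConn = ESConnection()
    elif engine == "infinity":
        docStoreConn = InfinityConnection()
    else:
        raise Exception(f"Not supported doc engine: {engine}")

init_doc_store()
print(type(docStoreConn).__name__)  # ESConnection unless DOC_ENGINE=infinity
```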
api/utils/api_utils.py
CHANGED
@@ -34,11 +34,9 @@ from itsdangerous import URLSafeTimedSerializer
|
|
34 |
from werkzeug.http import HTTP_STATUS_CODES
|
35 |
|
36 |
from api.db.db_models import APIToken
|
37 |
-
from api.settings import (
|
38 |
-
REQUEST_WAIT_SEC, REQUEST_MAX_WAIT_SEC,
|
39 |
-
CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY,
|
40 |
-
)
|
41 |
-
from api.settings import RetCode
|
42 |
from api.utils import CustomJSONEncoder, get_uuid
|
43 |
from api.utils import json_dumps
|
44 |
|
@@ -59,13 +57,13 @@ def request(**kwargs):
|
|
59 |
{}).items()}
|
60 |
prepped = requests.Request(**kwargs).prepare()
|
61 |
|
62 |
-
if CLIENT_AUTHENTICATION and HTTP_APP_KEY and SECRET_KEY:
|
63 |
timestamp = str(round(time() * 1000))
|
64 |
nonce = str(uuid1())
|
65 |
-
signature = b64encode(HMAC(SECRET_KEY.encode('ascii'), b'\n'.join([
|
66 |
timestamp.encode('ascii'),
|
67 |
nonce.encode('ascii'),
|
68 |
-
HTTP_APP_KEY.encode('ascii'),
|
69 |
prepped.path_url.encode('ascii'),
|
70 |
prepped.body if kwargs.get('json') else b'',
|
71 |
urlencode(
|
@@ -79,7 +77,7 @@ def request(**kwargs):
|
|
79 |
prepped.headers.update({
|
80 |
'TIMESTAMP': timestamp,
|
81 |
'NONCE': nonce,
|
82 |
-
'APP-KEY': HTTP_APP_KEY,
|
83 |
'SIGNATURE': signature,
|
84 |
})
|
85 |
|
@@ -89,7 +87,7 @@ def request(**kwargs):
|
|
89 |
def get_exponential_backoff_interval(retries, full_jitter=False):
|
90 |
"""Calculate the exponential backoff wait time."""
|
91 |
# Will be zero if factor equals 0
|
92 |
-
countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2 ** retries))
|
93 |
# Full jitter according to
|
94 |
# https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
|
95 |
if full_jitter:
|
@@ -98,7 +96,7 @@ def get_exponential_backoff_interval(retries, full_jitter=False):
|
|
98 |
return max(0, countdown)
|
99 |
|
100 |
|
101 |
-
def get_data_error_result(code=RetCode.DATA_ERROR,
|
102 |
message='Sorry! Data missing!'):
|
103 |
import re
|
104 |
result_dict = {
|
@@ -126,8 +124,8 @@ def server_error_response(e):
|
|
126 |
pass
|
127 |
if len(e.args) > 1:
|
128 |
return get_json_result(
|
129 |
-
code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
|
130 |
-
return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
|
131 |
|
132 |
|
133 |
def error_response(response_code, message=None):
|
@@ -168,7 +166,7 @@ def validate_request(*args, **kwargs):
|
|
168 |
error_string += "required argument values: {}".format(
|
169 |
",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
|
170 |
return get_json_result(
|
171 |
-
code=RetCode.ARGUMENT_ERROR, message=error_string)
|
172 |
return func(*_args, **_kwargs)
|
173 |
|
174 |
return decorated_function
|
@@ -193,7 +191,7 @@ def send_file_in_mem(data, filename):
|
|
193 |
return send_file(f, as_attachment=True, attachment_filename=filename)
|
194 |
|
195 |
|
196 |
-
def get_json_result(code=RetCode.SUCCESS, message='success', data=None):
|
197 |
response = {"code": code, "message": message, "data": data}
|
198 |
return jsonify(response)
|
199 |
|
@@ -204,7 +202,7 @@ def apikey_required(func):
|
|
204 |
objs = APIToken.query(token=token)
|
205 |
if not objs:
|
206 |
return build_error_result(
|
207 |
-
message='API-KEY is invalid!', code=RetCode.FORBIDDEN
|
208 |
)
|
209 |
kwargs['tenant_id'] = objs[0].tenant_id
|
210 |
return func(*args, **kwargs)
|
@@ -212,14 +210,14 @@ def apikey_required(func):
|
|
212 |
return decorated_function
|
213 |
|
214 |
|
215 |
-
def build_error_result(code=RetCode.FORBIDDEN, message='success'):
|
216 |
response = {"code": code, "message": message}
|
217 |
response = jsonify(response)
|
218 |
response.status_code = code
|
219 |
return response
|
220 |
|
221 |
|
222 |
-
def construct_response(code=RetCode.SUCCESS,
|
223 |
message='success', data=None, auth=None):
|
224 |
result_dict = {"code": code, "message": message, "data": data}
|
225 |
response_dict = {}
|
@@ -239,7 +237,7 @@ def construct_response(code=RetCode.SUCCESS,
|
|
239 |
return response
|
240 |
|
241 |
|
242 |
-
def construct_result(code=RetCode.DATA_ERROR, message='data is missing'):
|
243 |
import re
|
244 |
result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)}
|
245 |
response = {}
|
@@ -251,7 +249,7 @@ def construct_result(code=RetCode.DATA_ERROR, message='data is missing'):
|
|
251 |
return jsonify(response)
|
252 |
|
253 |
|
254 |
-
def construct_json_result(code=RetCode.SUCCESS, message='success', data=None):
|
255 |
if data is None:
|
256 |
return jsonify({"code": code, "message": message})
|
257 |
else:
|
@@ -262,12 +260,12 @@ def construct_error_response(e):
|
|
262 |
logging.exception(e)
|
263 |
try:
|
264 |
if e.code == 401:
|
265 |
-
return construct_json_result(code=RetCode.UNAUTHORIZED, message=repr(e))
|
266 |
except BaseException:
|
267 |
pass
|
268 |
if len(e.args) > 1:
|
269 |
-
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
|
270 |
-
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
|
271 |
|
272 |
|
273 |
def token_required(func):
|
@@ -280,7 +278,7 @@ def token_required(func):
|
|
280 |
objs = APIToken.query(token=token)
|
281 |
if not objs:
|
282 |
return get_json_result(
|
283 |
-
data=False, message='Token is not valid!', code=RetCode.AUTHENTICATION_ERROR
|
284 |
)
|
285 |
kwargs['tenant_id'] = objs[0].tenant_id
|
286 |
return func(*args, **kwargs)
|
@@ -288,7 +286,7 @@ def token_required(func):
|
|
288 |
return decorated_function
|
289 |
|
290 |
|
291 |
-
def get_result(code=RetCode.SUCCESS, message="", data=None):
|
292 |
if code == 0:
|
293 |
if data is not None:
|
294 |
response = {"code": code, "data": data}
|
@@ -299,7 +297,7 @@ def get_result(code=RetCode.SUCCESS, message="", data=None):
|
|
299 |
return jsonify(response)
|
300 |
|
301 |
|
302 |
-
def get_error_data_result(message='Sorry! Data missing!', code=RetCode.DATA_ERROR,
|
303 |
):
|
304 |
import re
|
305 |
result_dict = {
|
|
|
34 |
from werkzeug.http import HTTP_STATUS_CODES
|
35 |
|
36 |
from api.db.db_models import APIToken
|
37 |
+
from api import settings
|
38 |
+
|
39 |
+
from api import settings
|
|
|
|
|
40 |
from api.utils import CustomJSONEncoder, get_uuid
|
41 |
from api.utils import json_dumps
|
42 |
|
|
|
57 |
{}).items()}
|
58 |
prepped = requests.Request(**kwargs).prepare()
|
59 |
|
60 |
+
if settings.CLIENT_AUTHENTICATION and settings.HTTP_APP_KEY and settings.SECRET_KEY:
|
61 |
timestamp = str(round(time() * 1000))
|
62 |
nonce = str(uuid1())
|
63 |
+
signature = b64encode(HMAC(settings.SECRET_KEY.encode('ascii'), b'\n'.join([
|
64 |
timestamp.encode('ascii'),
|
65 |
nonce.encode('ascii'),
|
66 |
+
settings.HTTP_APP_KEY.encode('ascii'),
|
67 |
prepped.path_url.encode('ascii'),
|
68 |
prepped.body if kwargs.get('json') else b'',
|
69 |
urlencode(
|
|
|
77 |
prepped.headers.update({
|
78 |
'TIMESTAMP': timestamp,
|
79 |
'NONCE': nonce,
|
80 |
+
'APP-KEY': settings.HTTP_APP_KEY,
|
81 |
'SIGNATURE': signature,
|
82 |
})
|
83 |
|
|
|
87 |
def get_exponential_backoff_interval(retries, full_jitter=False):
|
88 |
"""Calculate the exponential backoff wait time."""
|
89 |
# Will be zero if factor equals 0
|
90 |
+
countdown = min(settings.REQUEST_MAX_WAIT_SEC, settings.REQUEST_WAIT_SEC * (2 ** retries))
|
91 |
# Full jitter according to
|
92 |
# https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
|
93 |
if full_jitter:
|
|
|
96 |
return max(0, countdown)
|
97 |
|
98 |
|
99 |
+
def get_data_error_result(code=settings.RetCode.DATA_ERROR,
|
100 |
message='Sorry! Data missing!'):
|
101 |
import re
|
102 |
result_dict = {
|
|
|
124 |
pass
|
125 |
if len(e.args) > 1:
|
126 |
return get_json_result(
|
127 |
+
code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
|
128 |
+
return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))
|
129 |
|
130 |
|
131 |
def error_response(response_code, message=None):
|
|
|
166 |
error_string += "required argument values: {}".format(
|
167 |
",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
|
168 |
return get_json_result(
|
169 |
+
code=settings.RetCode.ARGUMENT_ERROR, message=error_string)
|
170 |
return func(*_args, **_kwargs)
|
171 |
|
172 |
return decorated_function
|
|
|
191 |
return send_file(f, as_attachment=True, attachment_filename=filename)
|
192 |
|
193 |
|
194 |
+
def get_json_result(code=settings.RetCode.SUCCESS, message='success', data=None):
|
195 |
response = {"code": code, "message": message, "data": data}
|
196 |
return jsonify(response)
|
197 |
|
|
|
202 |
objs = APIToken.query(token=token)
|
203 |
if not objs:
|
204 |
return build_error_result(
|
205 |
+
message='API-KEY is invalid!', code=settings.RetCode.FORBIDDEN
|
206 |
)
|
207 |
kwargs['tenant_id'] = objs[0].tenant_id
|
208 |
return func(*args, **kwargs)
|
|
|
210 |
return decorated_function
|
211 |
|
212 |
|
213 |
+
def build_error_result(code=settings.RetCode.FORBIDDEN, message='success'):
|
214 |
response = {"code": code, "message": message}
|
215 |
response = jsonify(response)
|
216 |
response.status_code = code
|
217 |
return response
|
218 |
|
219 |
|
220 |
+
def construct_response(code=settings.RetCode.SUCCESS,
|
221 |
message='success', data=None, auth=None):
|
222 |
result_dict = {"code": code, "message": message, "data": data}
|
223 |
response_dict = {}
|
|
|
237 |
return response
|
238 |
|
239 |
|
240 |
+
def construct_result(code=settings.RetCode.DATA_ERROR, message='data is missing'):
|
241 |
import re
|
242 |
result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)}
|
243 |
response = {}
|
|
|
249 |
return jsonify(response)
|
250 |
|
251 |
|
252 |
+
def construct_json_result(code=settings.RetCode.SUCCESS, message='success', data=None):
|
253 |
if data is None:
|
254 |
return jsonify({"code": code, "message": message})
|
255 |
else:
|
|
|
260 |
logging.exception(e)
|
261 |
try:
|
262 |
if e.code == 401:
|
263 |
+
return construct_json_result(code=settings.RetCode.UNAUTHORIZED, message=repr(e))
|
264 |
except BaseException:
|
265 |
pass
|
266 |
if len(e.args) > 1:
|
267 |
+
return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
|
268 |
+
return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))
|
269 |
|
270 |
|
271 |
def token_required(func):
|
|
|
278 |
objs = APIToken.query(token=token)
|
279 |
if not objs:
|
280 |
return get_json_result(
|
281 |
+
data=False, message='Token is not valid!', code=settings.RetCode.AUTHENTICATION_ERROR
|
282 |
)
|
283 |
kwargs['tenant_id'] = objs[0].tenant_id
|
284 |
return func(*args, **kwargs)
|
|
|
286 |
return decorated_function
|
287 |
|
288 |
|
289 |
+
def get_result(code=settings.RetCode.SUCCESS, message="", data=None):
|
290 |
if code == 0:
|
291 |
if data is not None:
|
292 |
response = {"code": code, "data": data}
|
|
|
297 |
return jsonify(response)
|
298 |
|
299 |
|
300 |
+
def get_error_data_result(message='Sorry! Data missing!', code=settings.RetCode.DATA_ERROR,
|
301 |
):
|
302 |
import re
|
303 |
result_dict = {
|
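One subtlety in this file: defaults such as `code=settings.RetCode.SUCCESS` are evaluated once, at import. That is safe only because `RetCode` is a class defined unconditionally in `api/settings.py`; a default referencing one of the globals that `init_settings()` fills in later (say `settings.HTTP_APP_KEY`) would freeze the pre-init value. A small demonstration of the pitfall (all names hypothetical):

```python
class settings:           # stand-in module: one static and one late value
    class RetCode:        # exists as soon as the module is imported
        SUCCESS = 0
    APP_KEY = None        # only filled in by init_settings()

def ok(code=settings.RetCode.SUCCESS):  # fine: class attribute already set
    return code

def bad(key=settings.APP_KEY):          # frozen: captures None at definition
    return key

def good(key=None):                     # late lookup: reads the current value
    return settings.APP_KEY if key is None else key

settings.APP_KEY = "ragflow-app"        # simulate init_settings()
print(ok(), bad(), good())              # 0 None ragflow-app
```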
deepdoc/parser/pdf_parser.py
CHANGED
@@ -24,7 +24,7 @@ import numpy as np
|
|
24 |
from timeit import default_timer as timer
|
25 |
from pypdf import PdfReader as pdf2_read
|
26 |
|
27 |
-
from api.settings import LIGHTEN
|
28 |
from api.utils.file_utils import get_project_base_directory
|
29 |
from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer
|
30 |
from rag.nlp import rag_tokenizer
|
@@ -41,7 +41,7 @@ class RAGFlowPdfParser:
|
|
41 |
self.tbl_det = TableStructureRecognizer()
|
42 |
|
43 |
self.updown_cnt_mdl = xgb.Booster()
|
44 |
-
if not LIGHTEN:
|
45 |
try:
|
46 |
import torch
|
47 |
if torch.cuda.is_available():
|
|
|
24 |
from timeit import default_timer as timer
|
25 |
from pypdf import PdfReader as pdf2_read
|
26 |
|
27 |
+
from api import settings
|
28 |
from api.utils.file_utils import get_project_base_directory
|
29 |
from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer
|
30 |
from rag.nlp import rag_tokenizer
|
|
|
41 |
self.tbl_det = TableStructureRecognizer()
|
42 |
|
43 |
self.updown_cnt_mdl = xgb.Booster()
|
44 |
+
if not settings.LIGHTEN:
|
45 |
try:
|
46 |
import torch
|
47 |
if torch.cuda.is_available():
|
graphrag/claim_extractor.py
CHANGED
@@ -252,13 +252,13 @@ if __name__ == "__main__":
|
|
252 |
|
253 |
from api.db import LLMType
|
254 |
from api.db.services.llm_service import LLMBundle
|
255 |
-
from api.settings import retrievaler
|
256 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
257 |
|
258 |
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
|
259 |
|
260 |
ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
|
261 |
-
docs = [d["content_with_weight"] for d in retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])]
|
262 |
info = {
|
263 |
"input_text": docs,
|
264 |
"entity_specs": "organization, person",
|
|
|
252 |
|
253 |
from api.db import LLMType
|
254 |
from api.db.services.llm_service import LLMBundle
|
255 |
+
from api import settings
|
256 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
257 |
|
258 |
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
|
259 |
|
260 |
ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
|
261 |
+
docs = [d["content_with_weight"] for d in settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])]
|
262 |
info = {
|
263 |
"input_text": docs,
|
264 |
"entity_specs": "organization, person",
|
graphrag/smoke.py
CHANGED
@@ -30,14 +30,14 @@ if __name__ == "__main__":
|
|
30 |
|
31 |
from api.db import LLMType
|
32 |
from api.db.services.llm_service import LLMBundle
|
33 |
-
from api.settings import retrievaler
|
34 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
35 |
|
36 |
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
|
37 |
|
38 |
ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
|
39 |
docs = [d["content_with_weight"] for d in
|
40 |
-
retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])]
|
41 |
graph = ex(docs)
|
42 |
|
43 |
er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT))
|
|
|
30 |
|
31 |
from api.db import LLMType
|
32 |
from api.db.services.llm_service import LLMBundle
|
33 |
+
from api import settings
|
34 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
35 |
|
36 |
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
|
37 |
|
38 |
ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
|
39 |
docs = [d["content_with_weight"] for d in
|
40 |
+
settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])]
|
41 |
graph = ex(docs)
|
42 |
|
43 |
er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT))
|
rag/benchmark.py
CHANGED
@@ -23,7 +23,7 @@ from collections import defaultdict
|
|
23 |
from api.db import LLMType
|
24 |
from api.db.services.llm_service import LLMBundle
|
25 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
26 |
-
from api.settings import retrievaler, docStoreConn
|
27 |
from api.utils import get_uuid
|
28 |
from rag.nlp import tokenize, search
|
29 |
from ranx import evaluate
|
@@ -52,7 +52,7 @@ class Benchmark:
|
|
52 |
run = defaultdict(dict)
|
53 |
query_list = list(qrels.keys())
|
54 |
for query in query_list:
|
55 |
-
ranks = retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
|
56 |
0.0, self.vector_similarity_weight)
|
57 |
if len(ranks["chunks"]) == 0:
|
58 |
print(f"deleted query: {query}")
|
@@ -81,9 +81,9 @@ class Benchmark:
|
|
81 |
def init_index(self, vector_size: int):
|
82 |
if self.initialized_index:
|
83 |
return
|
84 |
-
if docStoreConn.indexExist(self.index_name, self.kb_id):
|
85 |
-
docStoreConn.deleteIdx(self.index_name, self.kb_id)
|
86 |
-
docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
|
87 |
self.initialized_index = True
|
88 |
|
89 |
def ms_marco_index(self, file_path, index_name):
|
@@ -118,13 +118,13 @@ class Benchmark:
|
|
118 |
docs_count += len(docs)
|
119 |
docs, vector_size = self.embedding(docs)
|
120 |
self.init_index(vector_size)
|
121 |
-
docStoreConn.insert(docs, self.index_name, self.kb_id)
|
122 |
docs = []
|
123 |
|
124 |
if docs:
|
125 |
docs, vector_size = self.embedding(docs)
|
126 |
self.init_index(vector_size)
|
127 |
-
docStoreConn.insert(docs, self.index_name, self.kb_id)
|
128 |
return qrels, texts
|
129 |
|
130 |
def trivia_qa_index(self, file_path, index_name):
|
@@ -159,12 +159,12 @@ class Benchmark:
|
|
159 |
docs_count += len(docs)
|
160 |
docs, vector_size = self.embedding(docs)
|
161 |
self.init_index(vector_size)
|
162 |
-
docStoreConn.insert(docs,self.index_name)
|
163 |
docs = []
|
164 |
|
165 |
docs, vector_size = self.embedding(docs)
|
166 |
self.init_index(vector_size)
|
167 |
-
docStoreConn.insert(docs, self.index_name)
|
168 |
return qrels, texts
|
169 |
|
170 |
def miracl_index(self, file_path, corpus_path, index_name):
|
@@ -214,12 +214,12 @@ class Benchmark:
|
|
214 |
docs_count += len(docs)
|
215 |
docs, vector_size = self.embedding(docs)
|
216 |
self.init_index(vector_size)
|
217 |
-
docStoreConn.insert(docs, self.index_name)
|
218 |
docs = []
|
219 |
|
220 |
docs, vector_size = self.embedding(docs)
|
221 |
self.init_index(vector_size)
|
222 |
-
docStoreConn.insert(docs, self.index_name)
|
223 |
return qrels, texts
|
224 |
|
225 |
def save_results(self, qrels, run, texts, dataset, file_path):
|
|
|
23 |
from api.db import LLMType
|
24 |
from api.db.services.llm_service import LLMBundle
|
25 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
26 |
+
from api import settings
|
27 |
from api.utils import get_uuid
|
28 |
from rag.nlp import tokenize, search
|
29 |
from ranx import evaluate
|
|
|
52 |
run = defaultdict(dict)
|
53 |
query_list = list(qrels.keys())
|
54 |
for query in query_list:
|
55 |
+
ranks = settings.retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
|
56 |
0.0, self.vector_similarity_weight)
|
57 |
if len(ranks["chunks"]) == 0:
|
58 |
print(f"deleted query: {query}")
|
|
|
81 |
def init_index(self, vector_size: int):
|
82 |
if self.initialized_index:
|
83 |
return
|
84 |
+
if settings.docStoreConn.indexExist(self.index_name, self.kb_id):
|
85 |
+
settings.docStoreConn.deleteIdx(self.index_name, self.kb_id)
|
86 |
+
settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
|
87 |
self.initialized_index = True
|
88 |
|
89 |
def ms_marco_index(self, file_path, index_name):
|
|
|
118 |
docs_count += len(docs)
|
119 |
docs, vector_size = self.embedding(docs)
|
120 |
self.init_index(vector_size)
|
121 |
+
settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
|
122 |
docs = []
|
123 |
|
124 |
if docs:
|
125 |
docs, vector_size = self.embedding(docs)
|
126 |
self.init_index(vector_size)
|
127 |
+
settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
|
128 |
return qrels, texts
|
129 |
|
130 |
def trivia_qa_index(self, file_path, index_name):
|
|
|
159 |
docs_count += len(docs)
|
160 |
docs, vector_size = self.embedding(docs)
|
161 |
self.init_index(vector_size)
|
162 |
+
settings.docStoreConn.insert(docs,self.index_name)
|
163 |
docs = []
|
164 |
|
165 |
docs, vector_size = self.embedding(docs)
|
166 |
self.init_index(vector_size)
|
167 |
+
settings.docStoreConn.insert(docs, self.index_name)
|
168 |
return qrels, texts
|
169 |
|
170 |
def miracl_index(self, file_path, corpus_path, index_name):
|
|
|
214 |
docs_count += len(docs)
|
215 |
docs, vector_size = self.embedding(docs)
|
216 |
self.init_index(vector_size)
|
217 |
+
settings.docStoreConn.insert(docs, self.index_name)
|
218 |
docs = []
|
219 |
|
220 |
docs, vector_size = self.embedding(docs)
|
221 |
self.init_index(vector_size)
|
222 |
+
settings.docStoreConn.insert(docs, self.index_name)
|
223 |
return qrels, texts
|
224 |
|
225 |
def save_results(self, qrels, run, texts, dataset, file_path):
|
rag/llm/embedding_model.py
CHANGED
@@ -28,7 +28,7 @@ from openai import OpenAI
|
|
28 |
import numpy as np
|
29 |
import asyncio
|
30 |
|
31 |
-
from api.settings import LIGHTEN
|
32 |
from api.utils.file_utils import get_home_cache_dir
|
33 |
from rag.utils import num_tokens_from_string, truncate
|
34 |
import google.generativeai as genai
|
@@ -60,7 +60,7 @@ class DefaultEmbedding(Base):
|
|
60 |
^_-
|
61 |
|
62 |
"""
|
63 |
-
if not LIGHTEN and not DefaultEmbedding._model:
|
64 |
with DefaultEmbedding._model_lock:
|
65 |
from FlagEmbedding import FlagModel
|
66 |
import torch
|
@@ -248,7 +248,7 @@ class FastEmbed(Base):
|
|
248 |
threads: Optional[int] = None,
|
249 |
**kwargs,
|
250 |
):
|
251 |
-
if not LIGHTEN and not FastEmbed._model:
|
252 |
from fastembed import TextEmbedding
|
253 |
self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
|
254 |
|
@@ -294,7 +294,7 @@ class YoudaoEmbed(Base):
|
|
294 |
_client = None
|
295 |
|
296 |
def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
|
297 |
-
if not LIGHTEN and not YoudaoEmbed._client:
|
298 |
from BCEmbedding import EmbeddingModel as qanthing
|
299 |
try:
|
300 |
logging.info("LOADING BCE...")
|
|
|
28 |
import numpy as np
|
29 |
import asyncio
|
30 |
|
31 |
+
from api import settings
|
32 |
from api.utils.file_utils import get_home_cache_dir
|
33 |
from rag.utils import num_tokens_from_string, truncate
|
34 |
import google.generativeai as genai
|
|
|
60 |
^_-
|
61 |
|
62 |
"""
|
63 |
+
if not settings.LIGHTEN and not DefaultEmbedding._model:
|
64 |
with DefaultEmbedding._model_lock:
|
65 |
from FlagEmbedding import FlagModel
|
66 |
import torch
|
|
|
248 |
threads: Optional[int] = None,
|
249 |
**kwargs,
|
250 |
):
|
251 |
+
if not settings.LIGHTEN and not FastEmbed._model:
|
252 |
from fastembed import TextEmbedding
|
253 |
self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
|
254 |
|
|
|
294 |
_client = None
|
295 |
|
296 |
def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
|
297 |
+
if not settings.LIGHTEN and not YoudaoEmbed._client:
|
298 |
from BCEmbedding import EmbeddingModel as qanthing
|
299 |
try:
|
300 |
logging.info("LOADING BCE...")
|
rag/llm/rerank_model.py
CHANGED
@@ -23,7 +23,7 @@ import os
|
|
23 |
from abc import ABC
|
24 |
import numpy as np
|
25 |
|
26 |
-
from api.settings import LIGHTEN
|
27 |
from api.utils.file_utils import get_home_cache_dir
|
28 |
from rag.utils import num_tokens_from_string, truncate
|
29 |
import json
|
@@ -57,7 +57,7 @@ class DefaultRerank(Base):
|
|
57 |
^_-
|
58 |
|
59 |
"""
|
60 |
-
if not LIGHTEN and not DefaultRerank._model:
|
61 |
import torch
|
62 |
from FlagEmbedding import FlagReranker
|
63 |
with DefaultRerank._model_lock:
|
@@ -121,7 +121,7 @@ class YoudaoRerank(DefaultRerank):
|
|
121 |
_model_lock = threading.Lock()
|
122 |
|
123 |
def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
|
124 |
-
if not LIGHTEN and not YoudaoRerank._model:
|
125 |
from BCEmbedding import RerankerModel
|
126 |
with YoudaoRerank._model_lock:
|
127 |
if not YoudaoRerank._model:
|
|
|
23 |
from abc import ABC
|
24 |
import numpy as np
|
25 |
|
26 |
+
from api import settings
|
27 |
from api.utils.file_utils import get_home_cache_dir
|
28 |
from rag.utils import num_tokens_from_string, truncate
|
29 |
import json
|
|
|
57 |
^_-
|
58 |
|
59 |
"""
|
60 |
+
if not settings.LIGHTEN and not DefaultRerank._model:
|
61 |
import torch
|
62 |
from FlagEmbedding import FlagReranker
|
63 |
with DefaultRerank._model_lock:
|
|
|
121 |
_model_lock = threading.Lock()
|
122 |
|
123 |
def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
|
124 |
+
if not settings.LIGHTEN and not YoudaoRerank._model:
|
125 |
from BCEmbedding import RerankerModel
|
126 |
with YoudaoRerank._model_lock:
|
127 |
if not YoudaoRerank._model:
|
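The `LIGHTEN` change in the embedding and rerank models leaves their lazy-singleton loading intact: the heavy model is constructed once, under a class-level lock, and only when the lighten build flag is off. Reduced to a runnable sketch (`load_flag_model` stands in for the `FlagModel`/`FlagReranker` construction):

```python
import threading

LIGHTEN = 0  # stand-in for settings.LIGHTEN

def load_flag_model(name):
    # placeholder for the expensive FlagEmbedding model construction
    return f"<model {name}>"

class DefaultEmbedding:
    _model = None
    _model_lock = threading.Lock()

    def __init__(self, model_name="BAAI/bge-large-zh-v1.5"):
        if not LIGHTEN and not DefaultEmbedding._model:
            with DefaultEmbedding._model_lock:
                if not DefaultEmbedding._model:  # re-check inside the lock
                    DefaultEmbedding._model = load_flag_model(model_name)
        self._model = DefaultEmbedding._model

a, b = DefaultEmbedding(), DefaultEmbedding()
print(a._model is b._model)  # True: one shared instance
```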
rag/svr/task_executor.py
CHANGED
@@ -16,6 +16,7 @@
|
|
16 |
import logging
|
17 |
import sys
|
18 |
from api.utils.log_utils import initRootLogger
|
|
|
19 |
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
|
20 |
initRootLogger(f"task_executor_{CONSUMER_NO}")
|
21 |
for module in ["pdfminer"]:
|
@@ -49,9 +50,10 @@ from api.db.services.document_service import DocumentService
|
|
49 |
from api.db.services.llm_service import LLMBundle
|
50 |
from api.db.services.task_service import TaskService
|
51 |
from api.db.services.file2document_service import File2DocumentService
|
52 |
-
from api.settings import retrievaler, docStoreConn
|
53 |
from api.db.db_models import close_connection
|
54 |
-
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
|
|
|
55 |
from rag.nlp import search, rag_tokenizer
|
56 |
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
|
57 |
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME
|
@@ -88,6 +90,7 @@ PENDING_TASKS = 0
|
|
88 |
HEAD_CREATED_AT = ""
|
89 |
HEAD_DETAIL = ""
|
90 |
|
|
|
91 |
def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
|
92 |
global PAYLOAD
|
93 |
if prog is not None and prog < 0:
|
@@ -171,7 +174,8 @@ def build(row):
|
|
171 |
"From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
|
172 |
except TimeoutError:
|
173 |
callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
|
174 |
-
logging.exception("Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"]))
|
|
|
175 |
return
|
176 |
except Exception as e:
|
177 |
if re.search("(No such file|not found)", str(e)):
|
@@ -188,7 +192,7 @@ def build(row):
|
|
188 |
logging.info("Chunking({}) {}/{} done".format(timer() - st, row["location"], row["name"]))
|
189 |
except Exception as e:
|
190 |
callback(-1, "Internal server error while chunking: %s" %
|
191 |
-
str(e).replace("'", ""))
|
192 |
logging.exception("Chunking {}/{} got exception".format(row["location"], row["name"]))
|
193 |
return
|
194 |
|
@@ -226,7 +230,8 @@ def build(row):
|
|
226 |
STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue())
|
227 |
el += timer() - st
|
228 |
except Exception:
|
229 |
-
logging.exception("Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"]))
|
|
|
230 |
|
231 |
d["img_id"] = "{}-{}".format(row["kb_id"], d["id"])
|
232 |
del d["image"]
|
@@ -241,7 +246,7 @@ def build(row):
|
|
241 |
d["important_kwd"] = keyword_extraction(chat_mdl, d["content_with_weight"],
|
242 |
row["parser_config"]["auto_keywords"]).split(",")
|
243 |
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
|
244 |
-
callback(msg="Keywords generation completed in {:.2f}s".format(timer()-st))
|
245 |
|
246 |
if row["parser_config"].get("auto_questions", 0):
|
247 |
st = timer()
|
@@ -255,14 +260,14 @@ def build(row):
|
|
255 |
d["content_ltks"] += " " + qst
|
256 |
if "content_sm_ltks" in d:
|
257 |
d["content_sm_ltks"] += " " + rag_tokenizer.fine_grained_tokenize(qst)
|
258 |
-
callback(msg="Question generation completed in {:.2f}s".format(timer()-st))
|
259 |
|
260 |
return docs
|
261 |
|
262 |
|
263 |
def init_kb(row, vector_size: int):
|
264 |
idxnm = search.index_name(row["tenant_id"])
|
265 |
-
return docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
|
266 |
|
267 |
|
268 |
def embedding(docs, mdl, parser_config=None, callback=None):
|
@@ -313,7 +318,8 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
|
|
313 |
vector_size = len(vts[0])
|
314 |
vctr_nm = "q_%d_vec" % vector_size
|
315 |
chunks = []
|
316 |
-
for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm]):
|
|
|
317 |
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
|
318 |
|
319 |
raptor = Raptor(
|
@@ -384,7 +390,8 @@ def main():
|
|
384 |
# TODO: exception handler
|
385 |
## set_progress(r["did"], -1, "ERROR: ")
|
386 |
callback(
|
387 |
-
msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks),
|
|
|
388 |
)
|
389 |
st = timer()
|
390 |
try:
|
@@ -403,18 +410,18 @@ def main():
|
|
403 |
es_r = ""
|
404 |
es_bulk_size = 4
|
405 |
for b in range(0, len(cks), es_bulk_size):
|
406 |
-
es_r = docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
|
407 |
if b % 128 == 0:
|
408 |
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
|
409 |
|
410 |
logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
|
411 |
if es_r:
|
412 |
callback(-1, "Insert chunk error, detail info please check log file. Please also check ES status!")
|
413 |
-
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
|
414 |
logging.error('Insert chunk error: ' + str(es_r))
|
415 |
else:
|
416 |
if TaskService.do_cancel(r["id"]):
|
417 |
-
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
|
418 |
continue
|
419 |
callback(msg="Indexing elapsed in {:.2f}s.".format(timer() - st))
|
420 |
callback(1., "Done!")
|
@@ -435,7 +442,7 @@ def report_status():
|
|
435 |
if PENDING_TASKS > 0:
|
436 |
head_info = REDIS_CONN.queue_head(SVR_QUEUE_NAME)
|
437 |
if head_info is not None:
|
438 |
-
seconds = int(head_info[0].split("-")[0])/1000
|
439 |
HEAD_CREATED_AT = datetime.fromtimestamp(seconds).isoformat()
|
440 |
HEAD_DETAIL = head_info[1]
|
441 |
|
@@ -452,7 +459,7 @@ def report_status():
|
|
452 |
REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
|
453 |
logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
|
454 |
|
455 |
-
expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60*30)
|
456 |
if expired > 0:
|
457 |
REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
|
458 |
except Exception:
|
|
|
16 |
import logging
|
17 |
import sys
|
18 |
from api.utils.log_utils import initRootLogger
|
19 |
+
|
20 |
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
|
21 |
initRootLogger(f"task_executor_{CONSUMER_NO}")
|
22 |
for module in ["pdfminer"]:
|
|
|
50 |
from api.db.services.llm_service import LLMBundle
|
51 |
from api.db.services.task_service import TaskService
|
52 |
from api.db.services.file2document_service import File2DocumentService
|
53 |
+
from api import settings
|
54 |
from api.db.db_models import close_connection
|
55 |
+
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
|
56 |
+
knowledge_graph, email
|
57 |
from rag.nlp import search, rag_tokenizer
|
58 |
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
|
59 |
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME
|
|
|
90 |
HEAD_CREATED_AT = ""
|
91 |
HEAD_DETAIL = ""
|
92 |
|
93 |
+
|
94 |
def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
|
95 |
global PAYLOAD
|
96 |
if prog is not None and prog < 0:
|
|
|
174 |
"From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
|
175 |
except TimeoutError:
|
176 |
callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
|
177 |
+
logging.exception(
|
178 |
+
"Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"]))
|
179 |
return
|
180 |
except Exception as e:
|
181 |
if re.search("(No such file|not found)", str(e)):
|
|
|
192 |
logging.info("Chunking({}) {}/{} done".format(timer() - st, row["location"], row["name"]))
|
193 |
except Exception as e:
|
194 |
callback(-1, "Internal server error while chunking: %s" %
|
195 |
+
str(e).replace("'", ""))
|
196 |
logging.exception("Chunking {}/{} got exception".format(row["location"], row["name"]))
|
197 |
return
|
198 |
|
|
|
230 |
STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue())
|
231 |
el += timer() - st
|
232 |
except Exception:
|
233 |
+
logging.exception(
|
234 |
+
"Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"]))
|
235 |
|
236 |
d["img_id"] = "{}-{}".format(row["kb_id"], d["id"])
|
237 |
del d["image"]
|
|
|
246 |
d["important_kwd"] = keyword_extraction(chat_mdl, d["content_with_weight"],
|
247 |
row["parser_config"]["auto_keywords"]).split(",")
|
248 |
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
|
249 |
+
callback(msg="Keywords generation completed in {:.2f}s".format(timer() - st))
|
250 |
|
251 |
if row["parser_config"].get("auto_questions", 0):
|
252 |
st = timer()
|
|
|
260 |
d["content_ltks"] += " " + qst
|
261 |
if "content_sm_ltks" in d:
|
262 |
d["content_sm_ltks"] += " " + rag_tokenizer.fine_grained_tokenize(qst)
|
263 |
+
callback(msg="Question generation completed in {:.2f}s".format(timer() - st))
|
264 |
|
265 |
return docs
|
266 |
|
267 |
|
268 |
def init_kb(row, vector_size: int):
|
269 |
idxnm = search.index_name(row["tenant_id"])
|
270 |
+
return settings.docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
|
271 |
|
272 |
|
273 |
def embedding(docs, mdl, parser_config=None, callback=None):
|
|
|
318 |
vector_size = len(vts[0])
|
319 |
vctr_nm = "q_%d_vec" % vector_size
|
320 |
chunks = []
|
321 |
+
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
|
322 |
+
fields=["content_with_weight", vctr_nm]):
|
323 |
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
|
324 |
|
325 |
raptor = Raptor(
|
|
|
390 |
# TODO: exception handler
|
391 |
## set_progress(r["did"], -1, "ERROR: ")
|
392 |
callback(
|
393 |
+
msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks),
|
394 |
+
timer() - st)
|
395 |
)
|
396 |
st = timer()
|
397 |
try:
|
|
|
410 |
es_r = ""
|
411 |
es_bulk_size = 4
|
412 |
for b in range(0, len(cks), es_bulk_size):
|
413 |
+
es_r = settings.docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
|
414 |
if b % 128 == 0:
|
415 |
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
|
416 |
|
417 |
logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
|
418 |
if es_r:
|
419 |
callback(-1, "Insert chunk error, detail info please check log file. Please also check ES status!")
|
420 |
+
settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
|
421 |
logging.error('Insert chunk error: ' + str(es_r))
|
422 |
else:
|
423 |
if TaskService.do_cancel(r["id"]):
|
424 |
+
settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
|
425 |
continue
|
426 |
callback(msg="Indexing elapsed in {:.2f}s.".format(timer() - st))
|
427 |
callback(1., "Done!")
|
|
|
442 |
if PENDING_TASKS > 0:
|
443 |
head_info = REDIS_CONN.queue_head(SVR_QUEUE_NAME)
|
444 |
if head_info is not None:
|
445 |
+
seconds = int(head_info[0].split("-")[0]) / 1000
|
446 |
HEAD_CREATED_AT = datetime.fromtimestamp(seconds).isoformat()
|
447 |
HEAD_DETAIL = head_info[1]
|
448 |
|
|
|
459 |
REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
|
460 |
logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
|
461 |
|
462 |
+
expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60 * 30)
|
463 |
if expired > 0:
|
464 |
REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
|
465 |
except Exception:
|
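`report_status` keeps its heartbeat bookkeeping in a Redis sorted set: each beat is added with the current timestamp as its score, then entries older than 30 minutes are counted and popped. The same idea with a plain dict standing in for Redis (a sketch; the real code goes through `REDIS_CONN.zadd`/`zcount`/`zpopmin`):

```python
import time

heartbeats = {}  # member -> score; stands in for a Redis sorted set

def zadd(member, score): heartbeats[member] = score
def zcount(lo, hi): return sum(1 for s in heartbeats.values() if lo <= s <= hi)
def zpopmin(n):
    for member, _ in sorted(heartbeats.items(), key=lambda kv: kv[1])[:n]:
        del heartbeats[member]

def report_status_once(consumer_name):
    now = time.time()
    zadd(f"{consumer_name}@{now}", now)  # record this heartbeat
    expired = zcount(0, now - 60 * 30)   # beats older than 30 minutes
    if expired > 0:
        zpopmin(expired)                 # drop the stale ones

report_status_once("task_executor_0")
print(len(heartbeats))  # 1
```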