KevinHuSh
commited on
Commit
·
a49657b
1
Parent(s):
13080d4
add self-rag (#1070)
Browse files### What problem does this PR solve?
#1069
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- api/apps/api_app.py +24 -20
- api/apps/canvas_app.py +112 -0
- api/apps/conversation_app.py +3 -2
- api/apps/dialog_app.py +2 -2
- api/db/services/canvas_service.py +26 -0
- api/db/services/dialog_service.py +59 -10
- deepdoc/parser/pdf_parser.py +2 -0
- rag/llm/rerank_model.py +0 -1
- rag/nlp/query.py +6 -3
api/apps/api_app.py
CHANGED
@@ -198,15 +198,18 @@ def completion():
|
|
198 |
else: conv.reference[-1] = ans["reference"]
|
199 |
conv.message[-1] = {"role": "assistant", "content": ans["answer"]}
|
200 |
|
|
|
|
|
|
|
|
|
|
|
201 |
def stream():
|
202 |
nonlocal dia, msg, req, conv
|
203 |
try:
|
204 |
for ans in chat(dia, msg, True, **req):
|
205 |
fillin_conv(ans)
|
206 |
-
|
207 |
-
|
208 |
-
chunk_i.pop('docnm_kwd')
|
209 |
-
yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
210 |
API4ConversationService.append_message(conv.id, conv.to_dict())
|
211 |
except Exception as e:
|
212 |
yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
|
@@ -554,23 +557,24 @@ def completion_faq():
|
|
554 |
"content": ""
|
555 |
}
|
556 |
]
|
557 |
-
|
558 |
-
|
559 |
-
|
560 |
-
fillin_conv(ans)
|
561 |
-
API4ConversationService.append_message(conv.id, conv.to_dict())
|
562 |
-
|
563 |
-
chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
|
564 |
-
for chunk_idx in chunk_idxs[:1]:
|
565 |
-
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
|
566 |
-
try:
|
567 |
-
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
|
568 |
-
response = MINIO.get(bkt, nm)
|
569 |
-
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
|
570 |
-
data.append(data_type_picture)
|
571 |
-
except Exception as e:
|
572 |
-
return server_error_response(e)
|
573 |
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
574 |
|
575 |
response = {"code": 200, "msg": "success", "data": data}
|
576 |
return response
|
|
|
198 |
else: conv.reference[-1] = ans["reference"]
|
199 |
conv.message[-1] = {"role": "assistant", "content": ans["answer"]}
|
200 |
|
201 |
+
def rename_field(ans):
|
202 |
+
for chunk_i in ans['reference'].get('chunks', []):
|
203 |
+
chunk_i['doc_name'] = chunk_i['docnm_kwd']
|
204 |
+
chunk_i.pop('docnm_kwd')
|
205 |
+
|
206 |
def stream():
|
207 |
nonlocal dia, msg, req, conv
|
208 |
try:
|
209 |
for ans in chat(dia, msg, True, **req):
|
210 |
fillin_conv(ans)
|
211 |
+
rename_field(rename_field)
|
212 |
+
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
|
|
|
|
213 |
API4ConversationService.append_message(conv.id, conv.to_dict())
|
214 |
except Exception as e:
|
215 |
yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
|
|
|
557 |
"content": ""
|
558 |
}
|
559 |
]
|
560 |
+
ans = ""
|
561 |
+
for a in chat(dia, msg, stream=False, **req):
|
562 |
+
ans = a
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
563 |
break
|
564 |
+
data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
|
565 |
+
fillin_conv(ans)
|
566 |
+
API4ConversationService.append_message(conv.id, conv.to_dict())
|
567 |
+
|
568 |
+
chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
|
569 |
+
for chunk_idx in chunk_idxs[:1]:
|
570 |
+
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
|
571 |
+
try:
|
572 |
+
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
|
573 |
+
response = MINIO.get(bkt, nm)
|
574 |
+
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
|
575 |
+
data.append(data_type_picture)
|
576 |
+
except Exception as e:
|
577 |
+
return server_error_response(e)
|
578 |
|
579 |
response = {"code": 200, "msg": "success", "data": data}
|
580 |
return response
|
api/apps/canvas_app.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
#
|
16 |
+
import json
|
17 |
+
|
18 |
+
from flask import request
|
19 |
+
from flask_login import login_required, current_user
|
20 |
+
|
21 |
+
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
|
22 |
+
from api.utils import get_uuid
|
23 |
+
from api.utils.api_utils import get_json_result, server_error_response, validate_request
|
24 |
+
from graph.canvas import Canvas
|
25 |
+
|
26 |
+
|
27 |
+
@manager.route('/templates', methods=['GET'])
|
28 |
+
@login_required
|
29 |
+
def templates():
|
30 |
+
return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.get_all()])
|
31 |
+
|
32 |
+
|
33 |
+
@manager.route('/list', methods=['GET'])
|
34 |
+
@login_required
|
35 |
+
def canvas_list():
|
36 |
+
|
37 |
+
return get_json_result(data=[c.to_dict() for c in UserCanvasService.query(user_id=current_user.id)])
|
38 |
+
|
39 |
+
|
40 |
+
@manager.route('/rm', methods=['POST'])
|
41 |
+
@validate_request("canvas_ids")
|
42 |
+
@login_required
|
43 |
+
def rm():
|
44 |
+
for i in request.json["canvas_ids"]:
|
45 |
+
UserCanvasService.delete_by_id(i)
|
46 |
+
return get_json_result(data=True)
|
47 |
+
|
48 |
+
|
49 |
+
@manager.route('/set', methods=['POST'])
|
50 |
+
@validate_request("dsl", "title")
|
51 |
+
@login_required
|
52 |
+
def save():
|
53 |
+
req = request.json
|
54 |
+
req["user_id"] = current_user.id
|
55 |
+
if not isinstance(req["dsl"], str):req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
|
56 |
+
try:
|
57 |
+
Canvas(req["dsl"])
|
58 |
+
except Exception as e:
|
59 |
+
return server_error_response(e)
|
60 |
+
|
61 |
+
req["dsl"] = json.loads(req["dsl"])
|
62 |
+
if "id" not in req:
|
63 |
+
req["id"] = get_uuid()
|
64 |
+
if not UserCanvasService.save(**req):
|
65 |
+
return server_error_response("Fail to save canvas.")
|
66 |
+
else:
|
67 |
+
UserCanvasService.update_by_id(req["id"], req)
|
68 |
+
|
69 |
+
return get_json_result(data=req)
|
70 |
+
|
71 |
+
|
72 |
+
@manager.route('/get/<canvas_id>', methods=['GET'])
|
73 |
+
@login_required
|
74 |
+
def get(canvas_id):
|
75 |
+
e, c = UserCanvasService.get_by_id(canvas_id)
|
76 |
+
if not e:
|
77 |
+
return server_error_response("canvas not found.")
|
78 |
+
return get_json_result(data=c.to_dict())
|
79 |
+
|
80 |
+
|
81 |
+
@manager.route('/run', methods=['POST'])
|
82 |
+
@validate_request("id", "dsl")
|
83 |
+
@login_required
|
84 |
+
def run():
|
85 |
+
req = request.json
|
86 |
+
if not isinstance(req["dsl"], str): req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
|
87 |
+
try:
|
88 |
+
canvas = Canvas(req["dsl"], current_user.id)
|
89 |
+
ans = canvas.run()
|
90 |
+
req["dsl"] = json.loads(str(canvas))
|
91 |
+
UserCanvasService.update_by_id(req["id"], dsl=req["dsl"])
|
92 |
+
return get_json_result(data=req["dsl"])
|
93 |
+
except Exception as e:
|
94 |
+
return server_error_response(e)
|
95 |
+
|
96 |
+
|
97 |
+
@manager.route('/reset', methods=['POST'])
|
98 |
+
@validate_request("canvas_id")
|
99 |
+
@login_required
|
100 |
+
def reset():
|
101 |
+
req = request.json
|
102 |
+
try:
|
103 |
+
user_canvas = UserCanvasService.get_by_id(req["canvas_id"])
|
104 |
+
canvas = Canvas(req["dsl"], current_user.id)
|
105 |
+
canvas.reset()
|
106 |
+
req["dsl"] = json.loads(str(canvas))
|
107 |
+
UserCanvasService.update_by_id(req["canvas_id"], dsl=req["dsl"])
|
108 |
+
return get_json_result(data=req["dsl"])
|
109 |
+
except Exception as e:
|
110 |
+
return server_error_response(e)
|
111 |
+
|
112 |
+
|
api/apps/conversation_app.py
CHANGED
@@ -13,7 +13,8 @@
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
-
from
|
|
|
17 |
from flask_login import login_required
|
18 |
from api.db.services.dialog_service import DialogService, ConversationService, chat
|
19 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
@@ -121,7 +122,7 @@ def completion():
|
|
121 |
e, conv = ConversationService.get_by_id(req["conversation_id"])
|
122 |
if not e:
|
123 |
return get_data_error_result(retmsg="Conversation not found!")
|
124 |
-
conv.message.append(msg[-1])
|
125 |
e, dia = DialogService.get_by_id(conv.dialog_id)
|
126 |
if not e:
|
127 |
return get_data_error_result(retmsg="Dialog not found!")
|
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
+
from copy import deepcopy
|
17 |
+
from flask import request, Response
|
18 |
from flask_login import login_required
|
19 |
from api.db.services.dialog_service import DialogService, ConversationService, chat
|
20 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
|
|
122 |
e, conv = ConversationService.get_by_id(req["conversation_id"])
|
123 |
if not e:
|
124 |
return get_data_error_result(retmsg="Conversation not found!")
|
125 |
+
conv.message.append(deepcopy(msg[-1]))
|
126 |
e, dia = DialogService.get_by_id(conv.dialog_id)
|
127 |
if not e:
|
128 |
return get_data_error_result(retmsg="Dialog not found!")
|
api/apps/dialog_app.py
CHANGED
@@ -31,8 +31,8 @@ def set_dialog():
|
|
31 |
req = request.json
|
32 |
dialog_id = req.get("dialog_id")
|
33 |
name = req.get("name", "New Dialog")
|
34 |
-
icon = req.get("icon", "")
|
35 |
description = req.get("description", "A helpful Dialog")
|
|
|
36 |
top_n = req.get("top_n", 6)
|
37 |
top_k = req.get("top_k", 1024)
|
38 |
rerank_id = req.get("rerank_id", "")
|
@@ -92,7 +92,7 @@ def set_dialog():
|
|
92 |
"rerank_id": rerank_id,
|
93 |
"similarity_threshold": similarity_threshold,
|
94 |
"vector_similarity_weight": vector_similarity_weight,
|
95 |
-
"icon": icon
|
96 |
}
|
97 |
if not DialogService.save(**dia):
|
98 |
return get_data_error_result(retmsg="Fail to new a dialog!")
|
|
|
31 |
req = request.json
|
32 |
dialog_id = req.get("dialog_id")
|
33 |
name = req.get("name", "New Dialog")
|
|
|
34 |
description = req.get("description", "A helpful Dialog")
|
35 |
+
icon = req.get("icon", "")
|
36 |
top_n = req.get("top_n", 6)
|
37 |
top_k = req.get("top_k", 1024)
|
38 |
rerank_id = req.get("rerank_id", "")
|
|
|
92 |
"rerank_id": rerank_id,
|
93 |
"similarity_threshold": similarity_threshold,
|
94 |
"vector_similarity_weight": vector_similarity_weight,
|
95 |
+
"icon": icon
|
96 |
}
|
97 |
if not DialogService.save(**dia):
|
98 |
return get_data_error_result(retmsg="Fail to new a dialog!")
|
api/db/services/canvas_service.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
#
|
16 |
+
from datetime import datetime
|
17 |
+
import peewee
|
18 |
+
from api.db.db_models import DB, API4Conversation, APIToken, Dialog, CanvasTemplate, UserCanvas
|
19 |
+
from api.db.services.common_service import CommonService
|
20 |
+
|
21 |
+
|
22 |
+
class CanvasTemplateService(CommonService):
|
23 |
+
model = CanvasTemplate
|
24 |
+
|
25 |
+
class UserCanvasService(CommonService):
|
26 |
+
model = UserCanvas
|
api/db/services/dialog_service.py
CHANGED
@@ -23,6 +23,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
|
|
23 |
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
|
24 |
from api.settings import chat_logger, retrievaler
|
25 |
from rag.app.resume import forbidden_select_fields4resume
|
|
|
26 |
from rag.nlp.search import index_name
|
27 |
from rag.utils import rmSpace, num_tokens_from_string, encoder
|
28 |
|
@@ -80,7 +81,8 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
80 |
if not llm:
|
81 |
raise LookupError("LLM(%s) not found" % dialog.llm_id)
|
82 |
max_tokens = 1024
|
83 |
-
else:
|
|
|
84 |
kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
|
85 |
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
86 |
if len(embd_nms) != 1:
|
@@ -124,6 +126,16 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
124 |
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
|
125 |
top=1024, aggs=False, rerank_mdl=rerank_mdl)
|
126 |
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
chat_logger.info(
|
128 |
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
129 |
|
@@ -136,7 +148,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
136 |
|
137 |
msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
|
138 |
msg.extend([{"role": m["role"], "content": m["content"]}
|
139 |
-
|
140 |
used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
|
141 |
assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
|
142 |
|
@@ -150,9 +162,9 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
150 |
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
|
151 |
answer, idx = retrievaler.insert_citations(answer,
|
152 |
[ck["content_ltks"]
|
153 |
-
|
154 |
[ck["vector"]
|
155 |
-
|
156 |
embd_mdl,
|
157 |
tkweight=1 - dialog.vector_similarity_weight,
|
158 |
vtweight=dialog.vector_similarity_weight)
|
@@ -166,7 +178,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
166 |
for c in refs["chunks"]:
|
167 |
if c.get("vector"):
|
168 |
del c["vector"]
|
169 |
-
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api")>=0:
|
170 |
answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
|
171 |
return {"answer": answer, "reference": refs}
|
172 |
|
@@ -204,7 +216,7 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
|
204 |
def get_table():
|
205 |
nonlocal sys_prompt, user_promt, question, tried_times
|
206 |
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {
|
207 |
-
|
208 |
print(user_promt, sql)
|
209 |
chat_logger.info(f"“{question}”==>{user_promt} get SQL: {sql}")
|
210 |
sql = re.sub(r"[\r\n]+", " ", sql.lower())
|
@@ -273,17 +285,19 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
|
273 |
|
274 |
# compose markdown table
|
275 |
clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
|
276 |
-
|
|
|
277 |
|
278 |
line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
|
279 |
-
|
280 |
|
281 |
rows = ["|" +
|
282 |
"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
|
283 |
"|" for r in tbl["rows"]]
|
284 |
if quota:
|
285 |
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
|
286 |
-
else:
|
|
|
287 |
rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
|
288 |
|
289 |
if not docid_idx or not docnm_idx:
|
@@ -303,5 +317,40 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
|
303 |
return {
|
304 |
"answer": "\n".join([clmns, line, rows]),
|
305 |
"reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
|
306 |
-
"doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
|
|
|
307 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
|
24 |
from api.settings import chat_logger, retrievaler
|
25 |
from rag.app.resume import forbidden_select_fields4resume
|
26 |
+
from rag.nlp.rag_tokenizer import is_chinese
|
27 |
from rag.nlp.search import index_name
|
28 |
from rag.utils import rmSpace, num_tokens_from_string, encoder
|
29 |
|
|
|
81 |
if not llm:
|
82 |
raise LookupError("LLM(%s) not found" % dialog.llm_id)
|
83 |
max_tokens = 1024
|
84 |
+
else:
|
85 |
+
max_tokens = llm[0].max_tokens
|
86 |
kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
|
87 |
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
88 |
if len(embd_nms) != 1:
|
|
|
126 |
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
|
127 |
top=1024, aggs=False, rerank_mdl=rerank_mdl)
|
128 |
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
129 |
+
#self-rag
|
130 |
+
if dialog.prompt_config.get("self_rag") and not relevant(dialog.tenant_id, dialog.llm_id, questions[-1], knowledges):
|
131 |
+
questions[-1] = rewrite(dialog.tenant_id, dialog.llm_id, questions[-1])
|
132 |
+
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
|
133 |
+
dialog.similarity_threshold,
|
134 |
+
dialog.vector_similarity_weight,
|
135 |
+
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
|
136 |
+
top=1024, aggs=False, rerank_mdl=rerank_mdl)
|
137 |
+
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
138 |
+
|
139 |
chat_logger.info(
|
140 |
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
141 |
|
|
|
148 |
|
149 |
msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
|
150 |
msg.extend([{"role": m["role"], "content": m["content"]}
|
151 |
+
for m in messages if m["role"] != "system"])
|
152 |
used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
|
153 |
assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
|
154 |
|
|
|
162 |
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
|
163 |
answer, idx = retrievaler.insert_citations(answer,
|
164 |
[ck["content_ltks"]
|
165 |
+
for ck in kbinfos["chunks"]],
|
166 |
[ck["vector"]
|
167 |
+
for ck in kbinfos["chunks"]],
|
168 |
embd_mdl,
|
169 |
tkweight=1 - dialog.vector_similarity_weight,
|
170 |
vtweight=dialog.vector_similarity_weight)
|
|
|
178 |
for c in refs["chunks"]:
|
179 |
if c.get("vector"):
|
180 |
del c["vector"]
|
181 |
+
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
|
182 |
answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
|
183 |
return {"answer": answer, "reference": refs}
|
184 |
|
|
|
216 |
def get_table():
|
217 |
nonlocal sys_prompt, user_promt, question, tried_times
|
218 |
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {
|
219 |
+
"temperature": 0.06})
|
220 |
print(user_promt, sql)
|
221 |
chat_logger.info(f"“{question}”==>{user_promt} get SQL: {sql}")
|
222 |
sql = re.sub(r"[\r\n]+", " ", sql.lower())
|
|
|
285 |
|
286 |
# compose markdown table
|
287 |
clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
|
288 |
+
tbl["columns"][i]["name"])) for i in
|
289 |
+
clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
|
290 |
|
291 |
line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
|
292 |
+
("|------|" if docid_idx and docid_idx else "")
|
293 |
|
294 |
rows = ["|" +
|
295 |
"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
|
296 |
"|" for r in tbl["rows"]]
|
297 |
if quota:
|
298 |
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
|
299 |
+
else:
|
300 |
+
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
|
301 |
rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
|
302 |
|
303 |
if not docid_idx or not docnm_idx:
|
|
|
317 |
return {
|
318 |
"answer": "\n".join([clmns, line, rows]),
|
319 |
"reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
|
320 |
+
"doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
|
321 |
+
doc_aggs.items()]}
|
322 |
}
|
323 |
+
|
324 |
+
|
325 |
+
def relevant(tenant_id, llm_id, question, contents: list):
|
326 |
+
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
|
327 |
+
prompt = """
|
328 |
+
You are a grader assessing relevance of a retrieved document to a user question.
|
329 |
+
It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
|
330 |
+
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
|
331 |
+
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.
|
332 |
+
No other words needed except 'yes' or 'no'.
|
333 |
+
"""
|
334 |
+
if not contents:return False
|
335 |
+
contents = "Documents: \n" + " - ".join(contents)
|
336 |
+
contents = f"Question: {question}\n" + contents
|
337 |
+
if num_tokens_from_string(contents) >= chat_mdl.max_length - 4:
|
338 |
+
contents = encoder.decode(encoder.encode(contents)[:chat_mdl.max_length - 4])
|
339 |
+
ans = chat_mdl.chat(prompt, [{"role": "user", "content": contents}], {"temperature": 0.01})
|
340 |
+
if ans.lower().find("yes") >= 0: return True
|
341 |
+
return False
|
342 |
+
|
343 |
+
|
344 |
+
def rewrite(tenant_id, llm_id, question):
|
345 |
+
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
|
346 |
+
prompt = """
|
347 |
+
You are an expert at query expansion to generate a paraphrasing of a question.
|
348 |
+
I can't retrieval relevant information from the knowledge base by using user's question directly.
|
349 |
+
You need to expand or paraphrase user's question by multiple ways such as using synonyms words/phrase,
|
350 |
+
writing the abbreviation in its entirety, adding some extra descriptions or explanations,
|
351 |
+
changing the way of expression, translating the original question into another language (English/Chinese), etc.
|
352 |
+
And return 5 versions of question and one is from translation.
|
353 |
+
Just list the question. No other words are needed.
|
354 |
+
"""
|
355 |
+
ans = chat_mdl.chat(prompt, [{"role": "user", "content": question}], {"temperature": 0.8})
|
356 |
+
return ans
|
deepdoc/parser/pdf_parser.py
CHANGED
@@ -1021,6 +1021,8 @@ class RAGFlowPdfParser:
|
|
1021 |
|
1022 |
self.page_cum_height = np.cumsum(self.page_cum_height)
|
1023 |
assert len(self.page_cum_height) == len(self.page_images) + 1
|
|
|
|
|
1024 |
|
1025 |
def __call__(self, fnm, need_image=True, zoomin=3, return_html=False):
|
1026 |
self.__images__(fnm, zoomin)
|
|
|
1021 |
|
1022 |
self.page_cum_height = np.cumsum(self.page_cum_height)
|
1023 |
assert len(self.page_cum_height) == len(self.page_images) + 1
|
1024 |
+
if len(self.boxes) == 0 and zoomin < 9: self.__images__(fnm, zoomin * 3, page_from,
|
1025 |
+
page_to, callback)
|
1026 |
|
1027 |
def __call__(self, fnm, need_image=True, zoomin=3, return_html=False):
|
1028 |
self.__images__(fnm, zoomin)
|
rag/llm/rerank_model.py
CHANGED
@@ -129,4 +129,3 @@ class YoudaoRerank(DefaultRerank):
|
|
129 |
return np.array(res), token_count
|
130 |
|
131 |
|
132 |
-
|
|
|
129 |
return np.array(res), token_count
|
130 |
|
131 |
|
|
rag/nlp/query.py
CHANGED
@@ -48,7 +48,7 @@ class EsQueryer:
|
|
48 |
@staticmethod
|
49 |
def rmWWW(txt):
|
50 |
patts = [
|
51 |
-
(r"是*(
|
52 |
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
|
53 |
(r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down) ", " ")
|
54 |
]
|
@@ -68,7 +68,9 @@ class EsQueryer:
|
|
68 |
if not self.isChinese(txt):
|
69 |
tks = rag_tokenizer.tokenize(txt).split(" ")
|
70 |
tks_w = self.tw.weights(tks)
|
71 |
-
tks_w = [(re.sub(r"[ \\\"']
|
|
|
|
|
72 |
q = ["{}^{:.4f}".format(tk, w) for tk, w in tks_w if tk]
|
73 |
for i in range(1, len(tks_w)):
|
74 |
q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2))
|
@@ -118,7 +120,8 @@ class EsQueryer:
|
|
118 |
if sm:
|
119 |
tk = f"{tk} OR \"%s\" OR (\"%s\"~2)^0.5" % (
|
120 |
" ".join(sm), " ".join(sm))
|
121 |
-
|
|
|
122 |
|
123 |
tms = " ".join([f"({t})^{w}" for t, w in tms])
|
124 |
|
|
|
48 |
@staticmethod
|
49 |
def rmWWW(txt):
|
50 |
patts = [
|
51 |
+
(r"是*(什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*", ""),
|
52 |
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
|
53 |
(r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down) ", " ")
|
54 |
]
|
|
|
68 |
if not self.isChinese(txt):
|
69 |
tks = rag_tokenizer.tokenize(txt).split(" ")
|
70 |
tks_w = self.tw.weights(tks)
|
71 |
+
tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
|
72 |
+
tks_w = [(re.sub(r"^[a-z0-9]$", "", tk), w) for tk, w in tks_w if tk]
|
73 |
+
tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
|
74 |
q = ["{}^{:.4f}".format(tk, w) for tk, w in tks_w if tk]
|
75 |
for i in range(1, len(tks_w)):
|
76 |
q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2))
|
|
|
120 |
if sm:
|
121 |
tk = f"{tk} OR \"%s\" OR (\"%s\"~2)^0.5" % (
|
122 |
" ".join(sm), " ".join(sm))
|
123 |
+
if tk.strip():
|
124 |
+
tms.append((tk, w))
|
125 |
|
126 |
tms = " ".join([f"({t})^{w}" for t, w in tms])
|
127 |
|