KevinHuSh
committed on
Commit
·
a8294f2
1
Parent(s):
452020d
Refine resume parts and fix bugs in retrieval using sql (#66)
Browse files- api/apps/conversation_app.py +53 -39
- api/apps/dialog_app.py +1 -1
- api/apps/document_app.py +7 -6
- api/apps/kb_app.py +3 -0
- api/apps/llm_app.py +27 -3
- api/apps/user_app.py +4 -3
- api/db/db_models.py +2 -1
- api/db/init_data.py +5 -36
- api/db/services/knowledgebase_service.py +5 -2
- api/settings.py +7 -5
- api/utils/file_utils.py +3 -3
- conf/mapping.json +36 -2
- conf/service_conf.yaml +14 -1
- rag/app/book.py +5 -0
- rag/app/laws.py +3 -1
- rag/app/manual.py +3 -1
- rag/app/naive.py +21 -8
- rag/app/paper.py +4 -0
- rag/app/presentation.py +5 -0
- rag/app/qa.py +10 -0
- rag/app/resume.py +34 -14
- rag/app/table.py +15 -2
- rag/llm/__init__.py +1 -1
- rag/llm/chat_model.py +7 -5
- rag/llm/embedding_model.py +3 -3
- rag/nlp/search.py +14 -13
- rag/parser/pdf_parser.py +3 -2
- rag/svr/task_executor.py +5 -4
- rag/utils/es_conn.py +1 -1
api/apps/conversation_app.py
CHANGED
@@ -21,20 +21,21 @@ from api.db.services.dialog_service import DialogService, ConversationService
|
|
21 |
from api.db import LLMType
|
22 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
23 |
from api.db.services.llm_service import LLMService, LLMBundle
|
24 |
-
from api.settings import access_logger
|
25 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
26 |
from api.utils import get_uuid
|
27 |
from api.utils.api_utils import get_json_result
|
|
|
28 |
from rag.llm import ChatModel
|
29 |
from rag.nlp import retrievaler
|
30 |
from rag.nlp.search import index_name
|
31 |
-
from rag.utils import num_tokens_from_string, encoder
|
32 |
|
33 |
|
34 |
@manager.route('/set', methods=['POST'])
|
35 |
@login_required
|
36 |
@validate_request("dialog_id")
|
37 |
-
def
|
38 |
req = request.json
|
39 |
conv_id = req.get("conversation_id")
|
40 |
if conv_id:
|
@@ -96,9 +97,10 @@ def rm():
|
|
96 |
except Exception as e:
|
97 |
return server_error_response(e)
|
98 |
|
|
|
99 |
@manager.route('/list', methods=['GET'])
|
100 |
@login_required
|
101 |
-
def
|
102 |
dialog_id = request.args["dialog_id"]
|
103 |
try:
|
104 |
convs = ConversationService.query(dialog_id=dialog_id)
|
@@ -112,7 +114,7 @@ def message_fit_in(msg, max_length=4000):
|
|
112 |
def count():
|
113 |
nonlocal msg
|
114 |
tks_cnts = []
|
115 |
-
for m in msg:tks_cnts.append({"role": m["role"], "count": num_tokens_from_string(m["content"])})
|
116 |
total = 0
|
117 |
for m in tks_cnts: total += m["count"]
|
118 |
return total
|
@@ -121,22 +123,22 @@ def message_fit_in(msg, max_length=4000):
|
|
121 |
if c < max_length: return c, msg
|
122 |
msg = [m for m in msg if m.role in ["system", "user"]]
|
123 |
c = count()
|
124 |
-
if c < max_length:return c, msg
|
125 |
msg_ = [m for m in msg[:-1] if m.role == "system"]
|
126 |
msg_.append(msg[-1])
|
127 |
msg = msg_
|
128 |
c = count()
|
129 |
-
if c < max_length:return c, msg
|
130 |
ll = num_tokens_from_string(msg_[0].content)
|
131 |
l = num_tokens_from_string(msg_[-1].content)
|
132 |
-
if ll/(ll + l) > 0.8:
|
133 |
m = msg_[0].content
|
134 |
-
m = encoder.decode(encoder.encode(m)[:max_length-l])
|
135 |
msg[0].content = m
|
136 |
return max_length, msg
|
137 |
|
138 |
m = msg_[1].content
|
139 |
-
m = encoder.decode(encoder.encode(m)[:max_length-l])
|
140 |
msg[1].content = m
|
141 |
return max_length, msg
|
142 |
|
@@ -148,8 +150,8 @@ def completion():
|
|
148 |
req = request.json
|
149 |
msg = []
|
150 |
for m in req["messages"]:
|
151 |
-
if m["role"] == "system":continue
|
152 |
-
if m["role"] == "assistant" and not msg:continue
|
153 |
msg.append({"role": m["role"], "content": m["content"]})
|
154 |
try:
|
155 |
e, dia = DialogService.get_by_id(req["dialog_id"])
|
@@ -166,7 +168,7 @@ def chat(dialog, messages, **kwargs):
|
|
166 |
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
167 |
llm = LLMService.query(llm_name=dialog.llm_id)
|
168 |
if not llm:
|
169 |
-
raise LookupError("LLM(%s) not found"%dialog.llm_id)
|
170 |
llm = llm[0]
|
171 |
question = messages[-1]["content"]
|
172 |
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
|
@@ -175,19 +177,21 @@ def chat(dialog, messages, **kwargs):
|
|
175 |
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
176 |
## try to use sql if field mapping is good to go
|
177 |
if field_map:
|
178 |
-
|
|
|
179 |
if markdown_tbl:
|
180 |
return {"answer": markdown_tbl, "retrieval": {"chunks": chunks}}
|
181 |
|
182 |
prompt_config = dialog.prompt_config
|
183 |
for p in prompt_config["parameters"]:
|
184 |
-
if p["key"] == "knowledge":continue
|
185 |
-
if p["key"] not in kwargs and not p["optional"]:raise KeyError("Miss parameter: " + p["key"])
|
186 |
if p["key"] not in kwargs:
|
187 |
-
prompt_config["system"] = prompt_config["system"].replace("{%s}"%p["key"], " ")
|
188 |
|
189 |
-
kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
|
190 |
-
|
|
|
191 |
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
192 |
|
193 |
if not knowledges and prompt_config["empty_response"]:
|
@@ -202,17 +206,17 @@ def chat(dialog, messages, **kwargs):
|
|
202 |
answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
|
203 |
|
204 |
answer = retrievaler.insert_citations(answer,
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
for c in kbinfos["chunks"]:
|
211 |
-
if c.get("vector"):del c["vector"]
|
212 |
return {"answer": answer, "retrieval": kbinfos}
|
213 |
|
214 |
|
215 |
-
def use_sql(question,field_map, tenant_id, chat_mdl):
|
216 |
sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据我的问题写出sql。"
|
217 |
user_promt = """
|
218 |
表名:{};
|
@@ -220,37 +224,47 @@ def use_sql(question,field_map, tenant_id, chat_mdl):
|
|
220 |
{}
|
221 |
|
222 |
问题:{}
|
223 |
-
请写出SQL
|
224 |
""".format(
|
225 |
index_name(tenant_id),
|
226 |
-
"\n".join([f"{k}: {v}" for k,v in field_map.items()]),
|
227 |
question
|
228 |
)
|
229 |
-
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.
|
230 |
-
|
|
|
|
|
231 |
sql = re.sub(r" +", " ", sql)
|
232 |
-
sql = re.sub(r"[;;].*", "", sql)
|
233 |
-
if sql[:len("select ")]
|
234 |
return None, None
|
235 |
-
if sql[:len("select *")]
|
236 |
sql = "select doc_id,docnm_kwd," + sql[6:]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
237 |
|
238 |
-
|
239 |
-
|
|
|
240 |
|
241 |
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
|
242 |
docnm_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
|
243 |
-
clmn_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx|docnm_idx)]
|
244 |
|
245 |
# compose markdown table
|
246 |
-
clmns = "|".join([re.sub(r"
|
247 |
line = "|".join(["------" for _ in range(len(clmn_idx))]) + "|------"
|
248 |
-
rows = ["|".join([str(r[i]) for i in clmn_idx])+"|" for r in tbl["rows"]]
|
249 |
if not docid_idx or not docnm_idx:
|
250 |
access_logger.error("SQL missing field: " + sql)
|
251 |
return "\n".join([clmns, line, "\n".join(rows)]), []
|
252 |
|
253 |
-
rows = "\n".join([r+f"##{ii}$$" for ii,r in enumerate(rows)])
|
254 |
docid_idx = list(docid_idx)[0]
|
255 |
docnm_idx = list(docnm_idx)[0]
|
256 |
return "\n".join([clmns, line, rows]), [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]]
|
|
|
21 |
from api.db import LLMType
|
22 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
23 |
from api.db.services.llm_service import LLMService, LLMBundle
|
24 |
+
from api.settings import access_logger, stat_logger
|
25 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
26 |
from api.utils import get_uuid
|
27 |
from api.utils.api_utils import get_json_result
|
28 |
+
from rag.app.resume import forbidden_select_fields4resume
|
29 |
from rag.llm import ChatModel
|
30 |
from rag.nlp import retrievaler
|
31 |
from rag.nlp.search import index_name
|
32 |
+
from rag.utils import num_tokens_from_string, encoder, rmSpace
|
33 |
|
34 |
|
35 |
@manager.route('/set', methods=['POST'])
|
36 |
@login_required
|
37 |
@validate_request("dialog_id")
|
38 |
+
def set_conversation():
|
39 |
req = request.json
|
40 |
conv_id = req.get("conversation_id")
|
41 |
if conv_id:
|
|
|
97 |
except Exception as e:
|
98 |
return server_error_response(e)
|
99 |
|
100 |
+
|
101 |
@manager.route('/list', methods=['GET'])
|
102 |
@login_required
|
103 |
+
def list_convsersation():
|
104 |
dialog_id = request.args["dialog_id"]
|
105 |
try:
|
106 |
convs = ConversationService.query(dialog_id=dialog_id)
|
|
|
114 |
def count():
|
115 |
nonlocal msg
|
116 |
tks_cnts = []
|
117 |
+
for m in msg: tks_cnts.append({"role": m["role"], "count": num_tokens_from_string(m["content"])})
|
118 |
total = 0
|
119 |
for m in tks_cnts: total += m["count"]
|
120 |
return total
|
|
|
123 |
if c < max_length: return c, msg
|
124 |
msg = [m for m in msg if m.role in ["system", "user"]]
|
125 |
c = count()
|
126 |
+
if c < max_length: return c, msg
|
127 |
msg_ = [m for m in msg[:-1] if m.role == "system"]
|
128 |
msg_.append(msg[-1])
|
129 |
msg = msg_
|
130 |
c = count()
|
131 |
+
if c < max_length: return c, msg
|
132 |
ll = num_tokens_from_string(msg_[0].content)
|
133 |
l = num_tokens_from_string(msg_[-1].content)
|
134 |
+
if ll / (ll + l) > 0.8:
|
135 |
m = msg_[0].content
|
136 |
+
m = encoder.decode(encoder.encode(m)[:max_length - l])
|
137 |
msg[0].content = m
|
138 |
return max_length, msg
|
139 |
|
140 |
m = msg_[1].content
|
141 |
+
m = encoder.decode(encoder.encode(m)[:max_length - l])
|
142 |
msg[1].content = m
|
143 |
return max_length, msg
|
144 |
|
|
|
150 |
req = request.json
|
151 |
msg = []
|
152 |
for m in req["messages"]:
|
153 |
+
if m["role"] == "system": continue
|
154 |
+
if m["role"] == "assistant" and not msg: continue
|
155 |
msg.append({"role": m["role"], "content": m["content"]})
|
156 |
try:
|
157 |
e, dia = DialogService.get_by_id(req["dialog_id"])
|
|
|
168 |
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
169 |
llm = LLMService.query(llm_name=dialog.llm_id)
|
170 |
if not llm:
|
171 |
+
raise LookupError("LLM(%s) not found" % dialog.llm_id)
|
172 |
llm = llm[0]
|
173 |
question = messages[-1]["content"]
|
174 |
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
|
|
|
177 |
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
178 |
## try to use sql if field mapping is good to go
|
179 |
if field_map:
|
180 |
+
stat_logger.info("Use SQL to retrieval.")
|
181 |
+
markdown_tbl, chunks = use_sql(question, field_map, dialog.tenant_id, chat_mdl)
|
182 |
if markdown_tbl:
|
183 |
return {"answer": markdown_tbl, "retrieval": {"chunks": chunks}}
|
184 |
|
185 |
prompt_config = dialog.prompt_config
|
186 |
for p in prompt_config["parameters"]:
|
187 |
+
if p["key"] == "knowledge": continue
|
188 |
+
if p["key"] not in kwargs and not p["optional"]: raise KeyError("Miss parameter: " + p["key"])
|
189 |
if p["key"] not in kwargs:
|
190 |
+
prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
|
191 |
|
192 |
+
kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
|
193 |
+
dialog.similarity_threshold,
|
194 |
+
dialog.vector_similarity_weight, top=1024, aggs=False)
|
195 |
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
196 |
|
197 |
if not knowledges and prompt_config["empty_response"]:
|
|
|
206 |
answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
|
207 |
|
208 |
answer = retrievaler.insert_citations(answer,
|
209 |
+
[ck["content_ltks"] for ck in kbinfos["chunks"]],
|
210 |
+
[ck["vector"] for ck in kbinfos["chunks"]],
|
211 |
+
embd_mdl,
|
212 |
+
tkweight=1 - dialog.vector_similarity_weight,
|
213 |
+
vtweight=dialog.vector_similarity_weight)
|
214 |
for c in kbinfos["chunks"]:
|
215 |
+
if c.get("vector"): del c["vector"]
|
216 |
return {"answer": answer, "retrieval": kbinfos}
|
217 |
|
218 |
|
219 |
+
def use_sql(question, field_map, tenant_id, chat_mdl):
|
220 |
sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据我的问题写出sql。"
|
221 |
user_promt = """
|
222 |
表名:{};
|
|
|
224 |
{}
|
225 |
|
226 |
问题:{}
|
227 |
+
请写出SQL,且只要SQL,不要有其他说明及文字。
|
228 |
""".format(
|
229 |
index_name(tenant_id),
|
230 |
+
"\n".join([f"{k}: {v}" for k, v in field_map.items()]),
|
231 |
question
|
232 |
)
|
233 |
+
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.06})
|
234 |
+
stat_logger.info(f"“{question}” get SQL: {sql}")
|
235 |
+
sql = re.sub(r"[\r\n]+", " ", sql.lower())
|
236 |
+
sql = re.sub(r".*?select ", "select ", sql.lower())
|
237 |
sql = re.sub(r" +", " ", sql)
|
238 |
+
sql = re.sub(r"([;;]|```).*", "", sql)
|
239 |
+
if sql[:len("select ")] != "select ":
|
240 |
return None, None
|
241 |
+
if sql[:len("select *")] != "select *":
|
242 |
sql = "select doc_id,docnm_kwd," + sql[6:]
|
243 |
+
else:
|
244 |
+
flds = []
|
245 |
+
for k in field_map.keys():
|
246 |
+
if k in forbidden_select_fields4resume:continue
|
247 |
+
if len(flds) > 11:break
|
248 |
+
flds.append(k)
|
249 |
+
sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
|
250 |
|
251 |
+
stat_logger.info(f"“{question}” get SQL(refined): {sql}")
|
252 |
+
tbl = retrievaler.sql_retrieval(sql, format="json")
|
253 |
+
if not tbl or len(tbl["rows"]) == 0: return None, None
|
254 |
|
255 |
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
|
256 |
docnm_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
|
257 |
+
clmn_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]
|
258 |
|
259 |
# compose markdown table
|
260 |
+
clmns = "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], f"C{i}")) for i in clmn_idx]) + "|原文"
|
261 |
line = "|".join(["------" for _ in range(len(clmn_idx))]) + "|------"
|
262 |
+
rows = ["|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
|
263 |
if not docid_idx or not docnm_idx:
|
264 |
access_logger.error("SQL missing field: " + sql)
|
265 |
return "\n".join([clmns, line, "\n".join(rows)]), []
|
266 |
|
267 |
+
rows = "\n".join([r + f"##{ii}$$" for ii, r in enumerate(rows)])
|
268 |
docid_idx = list(docid_idx)[0]
|
269 |
docnm_idx = list(docnm_idx)[0]
|
270 |
return "\n".join([clmns, line, rows]), [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]]
|
api/apps/dialog_app.py
CHANGED
@@ -27,7 +27,7 @@ from api.utils.api_utils import get_json_result
|
|
27 |
|
28 |
@manager.route('/set', methods=['POST'])
|
29 |
@login_required
|
30 |
-
def
|
31 |
req = request.json
|
32 |
dialog_id = req.get("dialog_id")
|
33 |
name = req.get("name", "New Dialog")
|
|
|
27 |
|
28 |
@manager.route('/set', methods=['POST'])
|
29 |
@login_required
|
30 |
+
def set_dialog():
|
31 |
req = request.json
|
32 |
dialog_id = req.get("dialog_id")
|
33 |
name = req.get("name", "New Dialog")
|
api/apps/document_app.py
CHANGED
@@ -262,17 +262,18 @@ def rename():
|
|
262 |
return server_error_response(e)
|
263 |
|
264 |
|
265 |
-
@manager.route('/get', methods=['GET'])
|
266 |
-
|
267 |
-
def get():
|
268 |
-
doc_id = request.args["doc_id"]
|
269 |
try:
|
270 |
e, doc = DocumentService.get_by_id(doc_id)
|
271 |
if not e:
|
272 |
return get_data_error_result(retmsg="Document not found!")
|
273 |
|
274 |
-
|
275 |
-
|
|
|
|
|
|
|
276 |
except Exception as e:
|
277 |
return server_error_response(e)
|
278 |
|
|
|
262 |
return server_error_response(e)
|
263 |
|
264 |
|
265 |
+
@manager.route('/get/<doc_id>', methods=['GET'])
|
266 |
+
def get(doc_id):
|
|
|
|
|
267 |
try:
|
268 |
e, doc = DocumentService.get_by_id(doc_id)
|
269 |
if not e:
|
270 |
return get_data_error_result(retmsg="Document not found!")
|
271 |
|
272 |
+
response = flask.make_response(MINIO.get(doc.kb_id, doc.location))
|
273 |
+
ext = re.search(r"\.([^.]+)$", doc.name)
|
274 |
+
if ext:
|
275 |
+
response.headers.set('Content-Type', 'application/%s'%ext.group(1))
|
276 |
+
return response
|
277 |
except Exception as e:
|
278 |
return server_error_response(e)
|
279 |
|
api/apps/kb_app.py
CHANGED
@@ -38,6 +38,9 @@ def create():
|
|
38 |
req["id"] = get_uuid()
|
39 |
req["tenant_id"] = current_user.id
|
40 |
req["created_by"] = current_user.id
|
|
|
|
|
|
|
41 |
if not KnowledgebaseService.save(**req): return get_data_error_result()
|
42 |
return get_json_result(data={"kb_id": req["id"]})
|
43 |
except Exception as e:
|
|
|
38 |
req["id"] = get_uuid()
|
39 |
req["tenant_id"] = current_user.id
|
40 |
req["created_by"] = current_user.id
|
41 |
+
e, t = TenantService.get_by_id(current_user.id)
|
42 |
+
if not e: return get_data_error_result(retmsg="Tenant not found.")
|
43 |
+
req["embd_id"] = t.embd_id
|
44 |
if not KnowledgebaseService.save(**req): return get_data_error_result()
|
45 |
return get_json_result(data={"kb_id": req["id"]})
|
46 |
except Exception as e:
|
api/apps/llm_app.py
CHANGED
@@ -21,11 +21,12 @@ from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, L
|
|
21 |
from api.db.services.user_service import TenantService, UserTenantService
|
22 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
23 |
from api.utils import get_uuid, get_format_time
|
24 |
-
from api.db import StatusEnum, UserTenantRole
|
25 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
26 |
from api.db.db_models import Knowledgebase, TenantLLM
|
27 |
from api.settings import stat_logger, RetCode
|
28 |
from api.utils.api_utils import get_json_result
|
|
|
29 |
|
30 |
|
31 |
@manager.route('/factories', methods=['GET'])
|
@@ -43,16 +44,37 @@ def factories():
|
|
43 |
@validate_request("llm_factory", "api_key")
|
44 |
def set_api_key():
|
45 |
req = request.json
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
llm = {
|
47 |
"tenant_id": current_user.id,
|
48 |
"llm_factory": req["llm_factory"],
|
49 |
"api_key": req["api_key"]
|
50 |
}
|
51 |
-
# TODO: Test api_key
|
52 |
for n in ["model_type", "llm_name"]:
|
53 |
if n in req: llm[n] = req[n]
|
54 |
|
55 |
-
TenantLLM.
|
56 |
return get_json_result(data=True)
|
57 |
|
58 |
|
@@ -69,6 +91,7 @@ def my_llms():
|
|
69 |
@manager.route('/list', methods=['GET'])
|
70 |
@login_required
|
71 |
def list():
|
|
|
72 |
try:
|
73 |
objs = TenantLLMService.query(tenant_id=current_user.id)
|
74 |
mdlnms = set([o.to_dict()["llm_name"] for o in objs if o.api_key])
|
@@ -79,6 +102,7 @@ def list():
|
|
79 |
|
80 |
res = {}
|
81 |
for m in llms:
|
|
|
82 |
if m["fid"] not in res: res[m["fid"]] = []
|
83 |
res[m["fid"]].append(m)
|
84 |
|
|
|
21 |
from api.db.services.user_service import TenantService, UserTenantService
|
22 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
23 |
from api.utils import get_uuid, get_format_time
|
24 |
+
from api.db import StatusEnum, UserTenantRole, LLMType
|
25 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
26 |
from api.db.db_models import Knowledgebase, TenantLLM
|
27 |
from api.settings import stat_logger, RetCode
|
28 |
from api.utils.api_utils import get_json_result
|
29 |
+
from rag.llm import EmbeddingModel, CvModel, ChatModel
|
30 |
|
31 |
|
32 |
@manager.route('/factories', methods=['GET'])
|
|
|
44 |
@validate_request("llm_factory", "api_key")
|
45 |
def set_api_key():
|
46 |
req = request.json
|
47 |
+
# test if api key works
|
48 |
+
msg = ""
|
49 |
+
for llm in LLMService.query(fid=req["llm_factory"]):
|
50 |
+
if llm.model_type == LLMType.EMBEDDING.value:
|
51 |
+
mdl = EmbeddingModel[req["llm_factory"]](
|
52 |
+
req["api_key"], llm.llm_name)
|
53 |
+
try:
|
54 |
+
arr, tc = mdl.encode(["Test if the api key is available"])
|
55 |
+
if len(arr[0]) == 0 or tc ==0: raise Exception("Fail")
|
56 |
+
except Exception as e:
|
57 |
+
msg += f"\nFail to access embedding model({llm.llm_name}) using this api key."
|
58 |
+
elif llm.model_type == LLMType.CHAT.value:
|
59 |
+
mdl = ChatModel[req["llm_factory"]](
|
60 |
+
req["api_key"], llm.llm_name)
|
61 |
+
try:
|
62 |
+
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
|
63 |
+
if not tc: raise Exception(m)
|
64 |
+
except Exception as e:
|
65 |
+
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(e)
|
66 |
+
|
67 |
+
if msg: return get_data_error_result(retmsg=msg)
|
68 |
+
|
69 |
llm = {
|
70 |
"tenant_id": current_user.id,
|
71 |
"llm_factory": req["llm_factory"],
|
72 |
"api_key": req["api_key"]
|
73 |
}
|
|
|
74 |
for n in ["model_type", "llm_name"]:
|
75 |
if n in req: llm[n] = req[n]
|
76 |
|
77 |
+
TenantLLMService.filter_update([TenantLLM.tenant_id==llm["tenant_id"], TenantLLM.llm_factory==llm["llm_factory"]], llm)
|
78 |
return get_json_result(data=True)
|
79 |
|
80 |
|
|
|
91 |
@manager.route('/list', methods=['GET'])
|
92 |
@login_required
|
93 |
def list():
|
94 |
+
model_type = request.args.get("model_type")
|
95 |
try:
|
96 |
objs = TenantLLMService.query(tenant_id=current_user.id)
|
97 |
mdlnms = set([o.to_dict()["llm_name"] for o in objs if o.api_key])
|
|
|
102 |
|
103 |
res = {}
|
104 |
for m in llms:
|
105 |
+
if model_type and m["model_type"] != model_type: continue
|
106 |
if m["fid"] not in res: res[m["fid"]] = []
|
107 |
res[m["fid"]].append(m)
|
108 |
|
api/apps/user_app.py
CHANGED
@@ -24,7 +24,8 @@ from api.db.services.llm_service import TenantLLMService, LLMService
|
|
24 |
from api.utils.api_utils import server_error_response, validate_request
|
25 |
from api.utils import get_uuid, get_format_time, decrypt, download_img
|
26 |
from api.db import UserTenantRole, LLMType
|
27 |
-
from api.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS
|
|
|
28 |
from api.db.services.user_service import UserService, TenantService, UserTenantService
|
29 |
from api.settings import stat_logger
|
30 |
from api.utils.api_utils import get_json_result, cors_reponse
|
@@ -204,8 +205,8 @@ def user_register(user_id, user):
|
|
204 |
"role": UserTenantRole.OWNER
|
205 |
}
|
206 |
tenant_llm = []
|
207 |
-
for llm in LLMService.query(fid=
|
208 |
-
tenant_llm.append({"tenant_id": user_id, "llm_factory":
|
209 |
|
210 |
if not UserService.save(**user):return
|
211 |
TenantService.save(**tenant)
|
|
|
24 |
from api.utils.api_utils import server_error_response, validate_request
|
25 |
from api.utils import get_uuid, get_format_time, decrypt, download_img
|
26 |
from api.db import UserTenantRole, LLMType
|
27 |
+
from api.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, API_KEY, \
|
28 |
+
LLM_FACTORY
|
29 |
from api.db.services.user_service import UserService, TenantService, UserTenantService
|
30 |
from api.settings import stat_logger
|
31 |
from api.utils.api_utils import get_json_result, cors_reponse
|
|
|
205 |
"role": UserTenantRole.OWNER
|
206 |
}
|
207 |
tenant_llm = []
|
208 |
+
for llm in LLMService.query(fid=LLM_FACTORY):
|
209 |
+
tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY})
|
210 |
|
211 |
if not UserService.save(**user):return
|
212 |
TenantService.save(**tenant)
|
api/db/db_models.py
CHANGED
@@ -465,7 +465,8 @@ class Knowledgebase(DataBaseModel):
|
|
465 |
tenant_id = CharField(max_length=32, null=False)
|
466 |
name = CharField(max_length=128, null=False, help_text="KB name", index=True)
|
467 |
description = TextField(null=True, help_text="KB description")
|
468 |
-
|
|
|
469 |
created_by = CharField(max_length=32, null=False)
|
470 |
doc_num = IntegerField(default=0)
|
471 |
token_num = IntegerField(default=0)
|
|
|
465 |
tenant_id = CharField(max_length=32, null=False)
|
466 |
name = CharField(max_length=128, null=False, help_text="KB name", index=True)
|
467 |
description = TextField(null=True, help_text="KB description")
|
468 |
+
embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID")
|
469 |
+
permission = CharField(max_length=16, null=False, help_text="me|team", default="me")
|
470 |
created_by = CharField(max_length=32, null=False)
|
471 |
doc_num = IntegerField(default=0)
|
472 |
token_num = IntegerField(default=0)
|
api/db/init_data.py
CHANGED
@@ -46,11 +46,6 @@ def init_llm_factory():
|
|
46 |
"logo": "",
|
47 |
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
48 |
"status": "1",
|
49 |
-
},{
|
50 |
-
"name": "Infiniflow",
|
51 |
-
"logo": "",
|
52 |
-
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
53 |
-
"status": "1",
|
54 |
},{
|
55 |
"name": "智普AI",
|
56 |
"logo": "",
|
@@ -135,59 +130,33 @@ def init_llm_factory():
|
|
135 |
"model_type": LLMType.SPEECH2TEXT.value
|
136 |
},{
|
137 |
"fid": factory_infos[1]["name"],
|
138 |
-
"llm_name": "
|
139 |
-
"tags": "LLM,CHAT,IMAGE2TEXT",
|
140 |
-
"max_tokens": 765,
|
141 |
-
"model_type": LLMType.IMAGE2TEXT.value
|
142 |
-
},
|
143 |
-
# ----------------------- Infiniflow -----------------------
|
144 |
-
{
|
145 |
-
"fid": factory_infos[2]["name"],
|
146 |
-
"llm_name": "gpt-3.5-turbo",
|
147 |
-
"tags": "LLM,CHAT,4K",
|
148 |
-
"max_tokens": 4096,
|
149 |
-
"model_type": LLMType.CHAT.value
|
150 |
-
},{
|
151 |
-
"fid": factory_infos[2]["name"],
|
152 |
-
"llm_name": "text-embedding-ada-002",
|
153 |
-
"tags": "TEXT EMBEDDING,8K",
|
154 |
-
"max_tokens": 8191,
|
155 |
-
"model_type": LLMType.EMBEDDING.value
|
156 |
-
},{
|
157 |
-
"fid": factory_infos[2]["name"],
|
158 |
-
"llm_name": "whisper-1",
|
159 |
-
"tags": "SPEECH2TEXT",
|
160 |
-
"max_tokens": 25*1024*1024,
|
161 |
-
"model_type": LLMType.SPEECH2TEXT.value
|
162 |
-
},{
|
163 |
-
"fid": factory_infos[2]["name"],
|
164 |
-
"llm_name": "gpt-4-vision-preview",
|
165 |
"tags": "LLM,CHAT,IMAGE2TEXT",
|
166 |
"max_tokens": 765,
|
167 |
"model_type": LLMType.IMAGE2TEXT.value
|
168 |
},
|
169 |
# ---------------------- ZhipuAI ----------------------
|
170 |
{
|
171 |
-
"fid": factory_infos[
|
172 |
"llm_name": "glm-3-turbo",
|
173 |
"tags": "LLM,CHAT,",
|
174 |
"max_tokens": 128 * 1000,
|
175 |
"model_type": LLMType.CHAT.value
|
176 |
}, {
|
177 |
-
"fid": factory_infos[
|
178 |
"llm_name": "glm-4",
|
179 |
"tags": "LLM,CHAT,",
|
180 |
"max_tokens": 128 * 1000,
|
181 |
"model_type": LLMType.CHAT.value
|
182 |
}, {
|
183 |
-
"fid": factory_infos[
|
184 |
"llm_name": "glm-4v",
|
185 |
"tags": "LLM,CHAT,IMAGE2TEXT",
|
186 |
"max_tokens": 2000,
|
187 |
"model_type": LLMType.IMAGE2TEXT.value
|
188 |
},
|
189 |
{
|
190 |
-
"fid": factory_infos[
|
191 |
"llm_name": "embedding-2",
|
192 |
"tags": "TEXT EMBEDDING",
|
193 |
"max_tokens": 512,
|
|
|
46 |
"logo": "",
|
47 |
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
48 |
"status": "1",
|
|
|
|
|
|
|
|
|
|
|
49 |
},{
|
50 |
"name": "智普AI",
|
51 |
"logo": "",
|
|
|
130 |
"model_type": LLMType.SPEECH2TEXT.value
|
131 |
},{
|
132 |
"fid": factory_infos[1]["name"],
|
133 |
+
"llm_name": "qwen-vl-max",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
"tags": "LLM,CHAT,IMAGE2TEXT",
|
135 |
"max_tokens": 765,
|
136 |
"model_type": LLMType.IMAGE2TEXT.value
|
137 |
},
|
138 |
# ---------------------- ZhipuAI ----------------------
|
139 |
{
|
140 |
+
"fid": factory_infos[2]["name"],
|
141 |
"llm_name": "glm-3-turbo",
|
142 |
"tags": "LLM,CHAT,",
|
143 |
"max_tokens": 128 * 1000,
|
144 |
"model_type": LLMType.CHAT.value
|
145 |
}, {
|
146 |
+
"fid": factory_infos[2]["name"],
|
147 |
"llm_name": "glm-4",
|
148 |
"tags": "LLM,CHAT,",
|
149 |
"max_tokens": 128 * 1000,
|
150 |
"model_type": LLMType.CHAT.value
|
151 |
}, {
|
152 |
+
"fid": factory_infos[2]["name"],
|
153 |
"llm_name": "glm-4v",
|
154 |
"tags": "LLM,CHAT,IMAGE2TEXT",
|
155 |
"max_tokens": 2000,
|
156 |
"model_type": LLMType.IMAGE2TEXT.value
|
157 |
},
|
158 |
{
|
159 |
+
"fid": factory_infos[2]["name"],
|
160 |
"llm_name": "embedding-2",
|
161 |
"tags": "TEXT EMBEDDING",
|
162 |
"max_tokens": 512,
|
api/db/services/knowledgebase_service.py
CHANGED
@@ -77,9 +77,12 @@ class KnowledgebaseService(CommonService):
|
|
77 |
if isinstance(v, dict):
|
78 |
assert isinstance(old[k], dict)
|
79 |
dfs_update(old[k], v)
|
|
|
|
|
|
|
80 |
else: old[k] = v
|
81 |
dfs_update(m.parser_config, config)
|
82 |
-
cls.update_by_id(id, m.parser_config)
|
83 |
|
84 |
|
85 |
@classmethod
|
@@ -88,6 +91,6 @@ class KnowledgebaseService(CommonService):
|
|
88 |
conf = {}
|
89 |
for k in cls.get_by_ids(ids):
|
90 |
if k.parser_config and "field_map" in k.parser_config:
|
91 |
-
conf.update(k.parser_config)
|
92 |
return conf
|
93 |
|
|
|
77 |
if isinstance(v, dict):
|
78 |
assert isinstance(old[k], dict)
|
79 |
dfs_update(old[k], v)
|
80 |
+
if isinstance(v, list):
|
81 |
+
assert isinstance(old[k], list)
|
82 |
+
old[k] = list(set(old[k]+v))
|
83 |
else: old[k] = v
|
84 |
dfs_update(m.parser_config, config)
|
85 |
+
cls.update_by_id(id, {"parser_config": m.parser_config})
|
86 |
|
87 |
|
88 |
@classmethod
|
|
|
91 |
conf = {}
|
92 |
for k in cls.get_by_ids(ids):
|
93 |
if k.parser_config and "field_map" in k.parser_config:
|
94 |
+
conf.update(k.parser_config["field_map"])
|
95 |
return conf
|
96 |
|
api/settings.py
CHANGED
@@ -43,12 +43,14 @@ REQUEST_MAX_WAIT_SEC = 300
|
|
43 |
|
44 |
USE_REGISTRY = get_base_config("use_registry")
|
45 |
|
46 |
-
LLM = get_base_config("
|
47 |
-
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
50 |
PARSERS = LLM.get("parsers", "general:General,qa:Q&A,resume:Resume,naive:Naive,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
|
51 |
-
IMAGE2TEXT_MDL = LLM.get("image2text_model", "gpt-4-vision-preview")
|
52 |
|
53 |
# distribution
|
54 |
DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
|
|
|
43 |
|
44 |
USE_REGISTRY = get_base_config("use_registry")
|
45 |
|
46 |
+
LLM = get_base_config("user_default_llm", {})
|
47 |
+
LLM_FACTORY=LLM.get("factory", "通义千问")
|
48 |
+
CHAT_MDL = LLM.get("chat_model", "qwen-plus")
|
49 |
+
EMBEDDING_MDL = LLM.get("embedding_model", "text-embedding-v2")
|
50 |
+
ASR_MDL = LLM.get("asr_model", "paraformer-realtime-8k-v1")
|
51 |
+
IMAGE2TEXT_MDL = LLM.get("image2text_model", "qwen-vl-max")
|
52 |
+
API_KEY = LLM.get("api_key", "infiniflow API Key")
|
53 |
PARSERS = LLM.get("parsers", "general:General,qa:Q&A,resume:Resume,naive:Naive,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
|
|
|
54 |
|
55 |
# distribution
|
56 |
DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
|
api/utils/file_utils.py
CHANGED
@@ -164,10 +164,10 @@ def thumbnail(filename, blob):
|
|
164 |
buffered = BytesIO()
|
165 |
Image.frombytes("RGB", [pix.width, pix.height],
|
166 |
pix.samples).save(buffered, format="png")
|
167 |
-
return "data:image/png;base64," + base64.b64encode(buffered.getvalue())
|
168 |
|
169 |
if re.match(r".*\.(jpg|jpeg|png|tif|gif|icon|ico|webp)$", filename):
|
170 |
-
return ("data:image/%s;base64,"%filename.split(".")[-1]) + base64.b64encode(Image.open(BytesIO(blob)).thumbnail((30, 30)).tobytes())
|
171 |
|
172 |
if re.match(r".*\.(ppt|pptx)$", filename):
|
173 |
import aspose.slides as slides
|
@@ -176,7 +176,7 @@ def thumbnail(filename, blob):
|
|
176 |
with slides.Presentation(BytesIO(blob)) as presentation:
|
177 |
buffered = BytesIO()
|
178 |
presentation.slides[0].get_thumbnail(0.03, 0.03).save(buffered, drawing.imaging.ImageFormat.png)
|
179 |
-
return "data:image/png;base64," + base64.b64encode(buffered.getvalue())
|
180 |
except Exception as e:
|
181 |
pass
|
182 |
|
|
|
164 |
buffered = BytesIO()
|
165 |
Image.frombytes("RGB", [pix.width, pix.height],
|
166 |
pix.samples).save(buffered, format="png")
|
167 |
+
return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8")
|
168 |
|
169 |
if re.match(r".*\.(jpg|jpeg|png|tif|gif|icon|ico|webp)$", filename):
|
170 |
+
return ("data:image/%s;base64,"%filename.split(".")[-1]) + base64.b64encode(Image.open(BytesIO(blob)).thumbnail((30, 30)).tobytes()).decode("utf-8")
|
171 |
|
172 |
if re.match(r".*\.(ppt|pptx)$", filename):
|
173 |
import aspose.slides as slides
|
|
|
176 |
with slides.Presentation(BytesIO(blob)) as presentation:
|
177 |
buffered = BytesIO()
|
178 |
presentation.slides[0].get_thumbnail(0.03, 0.03).save(buffered, drawing.imaging.ImageFormat.png)
|
179 |
+
return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8")
|
180 |
except Exception as e:
|
181 |
pass
|
182 |
|
conf/mapping.json
CHANGED
@@ -118,11 +118,45 @@
|
|
118 |
},
|
119 |
{
|
120 |
"dense_vector": {
|
121 |
-
"match": "*
|
122 |
"mapping": {
|
123 |
"type": "dense_vector",
|
124 |
"index": true,
|
125 |
-
"similarity": "cosine"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
}
|
127 |
}
|
128 |
},
|
|
|
118 |
},
|
119 |
{
|
120 |
"dense_vector": {
|
121 |
+
"match": "*_512_vec",
|
122 |
"mapping": {
|
123 |
"type": "dense_vector",
|
124 |
"index": true,
|
125 |
+
"similarity": "cosine",
|
126 |
+
"dims": 512
|
127 |
+
}
|
128 |
+
}
|
129 |
+
},
|
130 |
+
{
|
131 |
+
"dense_vector": {
|
132 |
+
"match": "*_768_vec",
|
133 |
+
"mapping": {
|
134 |
+
"type": "dense_vector",
|
135 |
+
"index": true,
|
136 |
+
"similarity": "cosine",
|
137 |
+
"dims": 768
|
138 |
+
}
|
139 |
+
}
|
140 |
+
},
|
141 |
+
{
|
142 |
+
"dense_vector": {
|
143 |
+
"match": "*_1024_vec",
|
144 |
+
"mapping": {
|
145 |
+
"type": "dense_vector",
|
146 |
+
"index": true,
|
147 |
+
"similarity": "cosine",
|
148 |
+
"dims": 1024
|
149 |
+
}
|
150 |
+
}
|
151 |
+
},
|
152 |
+
{
|
153 |
+
"dense_vector": {
|
154 |
+
"match": "*_1536_vec",
|
155 |
+
"mapping": {
|
156 |
+
"type": "dense_vector",
|
157 |
+
"index": true,
|
158 |
+
"similarity": "cosine",
|
159 |
+
"dims": 1536
|
160 |
}
|
161 |
}
|
162 |
},
|
conf/service_conf.yaml
CHANGED
@@ -11,7 +11,7 @@ permission:
|
|
11 |
dataset: false
|
12 |
ragflow:
|
13 |
# you must set real ip address, 127.0.0.1 and 0.0.0.0 is not supported
|
14 |
-
host:
|
15 |
http_port: 9380
|
16 |
database:
|
17 |
name: 'rag_flow'
|
@@ -21,6 +21,19 @@ database:
|
|
21 |
port: 5455
|
22 |
max_connections: 100
|
23 |
stale_timeout: 30
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
oauth:
|
25 |
github:
|
26 |
client_id: 302129228f0d96055bee
|
|
|
11 |
dataset: false
|
12 |
ragflow:
|
13 |
# you must set real ip address, 127.0.0.1 and 0.0.0.0 is not supported
|
14 |
+
host: 0.0.0.0
|
15 |
http_port: 9380
|
16 |
database:
|
17 |
name: 'rag_flow'
|
|
|
21 |
port: 5455
|
22 |
max_connections: 100
|
23 |
stale_timeout: 30
|
24 |
+
minio:
|
25 |
+
user: 'rag_flow'
|
26 |
+
passwd: 'infini_rag_flow'
|
27 |
+
host: '123.60.95.134:9000'
|
28 |
+
es:
|
29 |
+
hosts: 'http://123.60.95.134:9200'
|
30 |
+
user_default_llm:
|
31 |
+
factory: '通义千问'
|
32 |
+
chat_model: 'qwen-plus'
|
33 |
+
embedding_model: 'text-embedding-v2'
|
34 |
+
asr_model: 'paraformer-realtime-8k-v1'
|
35 |
+
image2text_model: 'qwen-vl-max'
|
36 |
+
api_key: 'sk-xxxxxxxxxxxxx'
|
37 |
oauth:
|
38 |
github:
|
39 |
client_id: 302129228f0d96055bee
|
rag/app/book.py
CHANGED
@@ -39,6 +39,11 @@ class Pdf(HuParser):
|
|
39 |
|
40 |
|
41 |
def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
42 |
doc = {
|
43 |
"docnm_kwd": filename,
|
44 |
"title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
|
|
39 |
|
40 |
|
41 |
def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
|
42 |
+
"""
|
43 |
+
Supported file formats are docx, pdf, txt.
|
44 |
+
Since a book is long and not all the parts are useful, if it's a PDF,
|
45 |
+
please set up the page ranges for every book in order to eliminate negative effects and save computing time.
|
46 |
+
"""
|
47 |
doc = {
|
48 |
"docnm_kwd": filename,
|
49 |
"title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
rag/app/laws.py
CHANGED
@@ -2,7 +2,6 @@ import copy
|
|
2 |
import re
|
3 |
from io import BytesIO
|
4 |
from docx import Document
|
5 |
-
import numpy as np
|
6 |
from rag.parser import bullets_category, is_english, tokenize, remove_contents_table, hierarchical_merge, \
|
7 |
make_colon_as_title
|
8 |
from rag.nlp import huqie
|
@@ -59,6 +58,9 @@ class Pdf(HuParser):
|
|
59 |
|
60 |
|
61 |
def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
|
|
|
|
|
|
|
62 |
doc = {
|
63 |
"docnm_kwd": filename,
|
64 |
"title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
|
|
2 |
import re
|
3 |
from io import BytesIO
|
4 |
from docx import Document
|
|
|
5 |
from rag.parser import bullets_category, is_english, tokenize, remove_contents_table, hierarchical_merge, \
|
6 |
make_colon_as_title
|
7 |
from rag.nlp import huqie
|
|
|
58 |
|
59 |
|
60 |
def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
|
61 |
+
"""
|
62 |
+
Supported file formats are docx, pdf, txt.
|
63 |
+
"""
|
64 |
doc = {
|
65 |
"docnm_kwd": filename,
|
66 |
"title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
rag/app/manual.py
CHANGED
@@ -58,8 +58,10 @@ class Pdf(HuParser):
|
|
58 |
|
59 |
|
60 |
def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
|
|
|
|
|
|
|
61 |
pdf_parser = None
|
62 |
-
paper = {}
|
63 |
|
64 |
if re.search(r"\.pdf$", filename, re.IGNORECASE):
|
65 |
pdf_parser = Pdf()
|
|
|
58 |
|
59 |
|
60 |
def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
|
61 |
+
"""
|
62 |
+
Only pdf is supported.
|
63 |
+
"""
|
64 |
pdf_parser = None
|
|
|
65 |
|
66 |
if re.search(r"\.pdf$", filename, re.IGNORECASE):
|
67 |
pdf_parser = Pdf()
|
rag/app/naive.py
CHANGED
@@ -6,6 +6,7 @@ from rag.nlp import huqie
|
|
6 |
from rag.parser.pdf_parser import HuParser
|
7 |
from rag.settings import cron_logger
|
8 |
|
|
|
9 |
class Pdf(HuParser):
|
10 |
def __call__(self, filename, binary=None, from_page=0,
|
11 |
to_page=100000, zoomin=3, callback=None):
|
@@ -20,12 +21,18 @@ class Pdf(HuParser):
|
|
20 |
start = timer()
|
21 |
self._layouts_paddle(zoomin)
|
22 |
callback(0.77, "Layout analysis finished")
|
23 |
-
cron_logger.info("paddle layouts:".format((timer()-start)/(self.total_page+0.1)))
|
24 |
self._naive_vertical_merge()
|
25 |
return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes]
|
26 |
|
27 |
|
28 |
def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
doc = {
|
30 |
"docnm_kwd": filename,
|
31 |
"title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
@@ -41,24 +48,26 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
|
|
41 |
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
42 |
pdf_parser = Pdf()
|
43 |
sections = pdf_parser(filename if not binary else binary,
|
44 |
-
|
45 |
elif re.search(r"\.txt$", filename, re.IGNORECASE):
|
46 |
callback(0.1, "Start to parse.")
|
47 |
txt = ""
|
48 |
-
if binary:
|
|
|
49 |
else:
|
50 |
with open(filename, "r") as f:
|
51 |
while True:
|
52 |
l = f.readline()
|
53 |
-
if not l:break
|
54 |
txt += l
|
55 |
sections = txt.split("\n")
|
56 |
-
sections = [(l,"") for l in sections if l]
|
57 |
callback(0.8, "Finish parsing.")
|
58 |
-
else:
|
|
|
59 |
|
60 |
-
parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "
|
61 |
-
cks = naive_merge(sections, parser_config["chunk_token_num"], parser_config["
|
62 |
eng = is_english(cks)
|
63 |
res = []
|
64 |
# wrap up to es documents
|
@@ -75,6 +84,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
|
|
75 |
|
76 |
if __name__ == "__main__":
|
77 |
import sys
|
|
|
|
|
78 |
def dummy(a, b):
|
79 |
pass
|
|
|
|
|
80 |
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
|
|
|
6 |
from rag.parser.pdf_parser import HuParser
|
7 |
from rag.settings import cron_logger
|
8 |
|
9 |
+
|
10 |
class Pdf(HuParser):
|
11 |
def __call__(self, filename, binary=None, from_page=0,
|
12 |
to_page=100000, zoomin=3, callback=None):
|
|
|
21 |
start = timer()
|
22 |
self._layouts_paddle(zoomin)
|
23 |
callback(0.77, "Layout analysis finished")
|
24 |
+
cron_logger.info("paddle layouts:".format((timer() - start) / (self.total_page + 0.1)))
|
25 |
self._naive_vertical_merge()
|
26 |
return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes]
|
27 |
|
28 |
|
29 |
def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
|
30 |
+
"""
|
31 |
+
Supported file formats are docx, pdf, txt.
|
32 |
+
This method applies the naive way to chunk files.
|
33 |
+
Successive text will be sliced into pieces using 'delimiter'.
|
34 |
+
Next, these successive pieces are merged into chunks whose token number is no more than 'Max token number'.
|
35 |
+
"""
|
36 |
doc = {
|
37 |
"docnm_kwd": filename,
|
38 |
"title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
|
|
48 |
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
49 |
pdf_parser = Pdf()
|
50 |
sections = pdf_parser(filename if not binary else binary,
|
51 |
+
from_page=from_page, to_page=to_page, callback=callback)
|
52 |
elif re.search(r"\.txt$", filename, re.IGNORECASE):
|
53 |
callback(0.1, "Start to parse.")
|
54 |
txt = ""
|
55 |
+
if binary:
|
56 |
+
txt = binary.decode("utf-8")
|
57 |
else:
|
58 |
with open(filename, "r") as f:
|
59 |
while True:
|
60 |
l = f.readline()
|
61 |
+
if not l: break
|
62 |
txt += l
|
63 |
sections = txt.split("\n")
|
64 |
+
sections = [(l, "") for l in sections if l]
|
65 |
callback(0.8, "Finish parsing.")
|
66 |
+
else:
|
67 |
+
raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")
|
68 |
|
69 |
+
parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimiter": "\n!?。;!?"})
|
70 |
+
cks = naive_merge(sections, parser_config["chunk_token_num"], parser_config["delimiter"])
|
71 |
eng = is_english(cks)
|
72 |
res = []
|
73 |
# wrap up to es documents
|
|
|
84 |
|
85 |
if __name__ == "__main__":
|
86 |
import sys
|
87 |
+
|
88 |
+
|
89 |
def dummy(a, b):
|
90 |
pass
|
91 |
+
|
92 |
+
|
93 |
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
|
rag/app/paper.py
CHANGED
@@ -129,6 +129,10 @@ class Pdf(HuParser):
|
|
129 |
|
130 |
|
131 |
def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
|
|
|
|
|
|
|
|
|
132 |
pdf_parser = None
|
133 |
if re.search(r"\.pdf$", filename, re.IGNORECASE):
|
134 |
pdf_parser = Pdf()
|
|
|
129 |
|
130 |
|
131 |
def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
|
132 |
+
"""
|
133 |
+
Only pdf is supported.
|
134 |
+
The abstract of the paper will be sliced as an entire chunk, and will not be sliced partly.
|
135 |
+
"""
|
136 |
pdf_parser = None
|
137 |
if re.search(r"\.pdf$", filename, re.IGNORECASE):
|
138 |
pdf_parser = Pdf()
|
rag/app/presentation.py
CHANGED
@@ -94,6 +94,11 @@ class Pdf(HuParser):
|
|
94 |
|
95 |
|
96 |
def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
97 |
doc = {
|
98 |
"docnm_kwd": filename,
|
99 |
"title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
|
|
94 |
|
95 |
|
96 |
def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
|
97 |
+
"""
|
98 |
+
The supported file formats are pdf, pptx.
|
99 |
+
Every page will be treated as a chunk. And the thumbnail of every page will be stored.
|
100 |
+
PPT files will be parsed using this method automatically; setting it up for every PPT file is not necessary.
|
101 |
+
"""
|
102 |
doc = {
|
103 |
"docnm_kwd": filename,
|
104 |
"title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
rag/app/qa.py
CHANGED
@@ -70,7 +70,17 @@ def beAdoc(d, q, a, eng):
|
|
70 |
|
71 |
|
72 |
def chunk(filename, binary=None, callback=None, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
73 |
|
|
|
|
|
|
|
|
|
|
|
74 |
res = []
|
75 |
if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
|
76 |
callback(0.1, "Start to parse.")
|
|
|
70 |
|
71 |
|
72 |
def chunk(filename, binary=None, callback=None, **kwargs):
|
73 |
+
"""
|
74 |
+
Excel and csv(txt) format files are supported.
|
75 |
+
If the file is in excel format, there should be 2 columns, question and answer, without a header.
|
76 |
+
And question column is ahead of answer column.
|
77 |
+
And it's O.K. if it has multiple sheets, as long as the columns are rightly composed.
|
78 |
|
79 |
+
If it's in csv format, it should be UTF-8 encoded. Use TAB as delimiter to separate question and answer.
|
80 |
+
|
81 |
+
All the deformed lines will be ignored.
|
82 |
+
Every pair of Q&A will be treated as a chunk.
|
83 |
+
"""
|
84 |
res = []
|
85 |
if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
|
86 |
callback(0.1, "Start to parse.")
|
rag/app/resume.py
CHANGED
@@ -4,24 +4,34 @@ import os
|
|
4 |
import re
|
5 |
import requests
|
6 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
|
|
7 |
from rag.nlp import huqie
|
8 |
|
9 |
from rag.settings import cron_logger
|
10 |
from rag.utils import rmSpace
|
11 |
|
|
|
|
|
|
|
12 |
|
13 |
def chunk(filename, binary=None, callback=None, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
if not re.search(r"\.(pdf|doc|docx|txt)$", filename, flags=re.IGNORECASE):
|
15 |
raise NotImplementedError("file type not supported yet(pdf supported)")
|
16 |
|
17 |
url = os.environ.get("INFINIFLOW_SERVER")
|
18 |
-
if not url:
|
19 |
-
raise EnvironmentError(
|
20 |
-
"Please set environment variable: 'INFINIFLOW_SERVER'")
|
21 |
token = os.environ.get("INFINIFLOW_TOKEN")
|
22 |
-
if not token:
|
23 |
-
|
24 |
-
"
|
|
|
25 |
|
26 |
if not binary:
|
27 |
with open(filename, "rb") as f:
|
@@ -44,22 +54,28 @@ def chunk(filename, binary=None, callback=None, **kwargs):
|
|
44 |
|
45 |
callback(0.2, "Resume parsing is going on...")
|
46 |
resume = remote_call()
|
|
|
|
|
|
|
47 |
callback(0.6, "Done parsing. Chunking...")
|
48 |
print(json.dumps(resume, ensure_ascii=False, indent=2))
|
49 |
|
50 |
field_map = {
|
51 |
"name_kwd": "姓名/名字",
|
|
|
52 |
"gender_kwd": "性别(男,女)",
|
53 |
"age_int": "年龄/岁/年纪",
|
54 |
"phone_kwd": "电话/手机/微信",
|
55 |
"email_tks": "email/e-mail/邮箱",
|
56 |
"position_name_tks": "职位/职能/岗位/职责",
|
57 |
-
"
|
|
|
|
|
58 |
|
59 |
-
"
|
60 |
"first_degree_kwd": "第一学历(高中,职高,硕士,本科,博士,初中,中技,中专,专科,专升本,MPA,MBA,EMBA)",
|
|
|
61 |
"first_major_tks": "第一学历专业",
|
62 |
-
"first_school_name_tks": "第一学历毕业学校",
|
63 |
"edu_first_fea_kwd": "第一学历标签(211,留学,双一流,985,海外知名,重点大学,中专,专升本,专科,本科,大专)",
|
64 |
|
65 |
"degree_kwd": "过往学历(高中,职高,硕士,本科,博士,初中,中技,中专,专科,专升本,MPA,MBA,EMBA)",
|
@@ -68,14 +84,14 @@ def chunk(filename, binary=None, callback=None, **kwargs):
|
|
68 |
"sch_rank_kwd": "学校标签(顶尖学校,精英学校,优质学校,一般学校)",
|
69 |
"edu_fea_kwd": "教育标签(211,留学,双一流,985,海外知名,重点大学,中专,专升本,专科,本科,大专)",
|
70 |
|
71 |
-
"work_exp_flt": "工作年限/工作年份/N年经验/毕业了多少年",
|
72 |
-
"birth_dt": "生日/出生年份",
|
73 |
"corp_nm_tks": "就职过的公司/之前的公司/上过班的公司",
|
74 |
-
"corporation_name_tks": "最近就职(上班)的公司/上一家公司",
|
75 |
"edu_end_int": "毕业年份",
|
76 |
-
"
|
77 |
-
|
|
|
|
|
78 |
}
|
|
|
79 |
titles = []
|
80 |
for n in ["name_kwd", "gender_kwd", "position_name_tks", "age_int"]:
|
81 |
v = resume.get(n, "")
|
@@ -105,6 +121,10 @@ def chunk(filename, binary=None, callback=None, **kwargs):
|
|
105 |
doc["content_ltks"] = huqie.qie(doc["content_with_weight"])
|
106 |
doc["content_sm_ltks"] = huqie.qieqie(doc["content_ltks"])
|
107 |
for n, _ in field_map.items():
|
|
|
|
|
|
|
|
|
108 |
doc[n] = resume[n]
|
109 |
|
110 |
print(doc)
|
|
|
4 |
import re
|
5 |
import requests
|
6 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
7 |
+
from api.settings import stat_logger
|
8 |
from rag.nlp import huqie
|
9 |
|
10 |
from rag.settings import cron_logger
|
11 |
from rag.utils import rmSpace
|
12 |
|
13 |
+
forbidden_select_fields4resume = [
|
14 |
+
"name_pinyin_kwd", "edu_first_fea_kwd", "degree_kwd", "sch_rank_kwd", "edu_fea_kwd"
|
15 |
+
]
|
16 |
|
17 |
def chunk(filename, binary=None, callback=None, **kwargs):
|
18 |
+
"""
|
19 |
+
The supported file formats are pdf, docx and txt.
|
20 |
+
To maximize the effectiveness, parse the resume correctly,
|
21 |
+
please visit https://github.com/infiniflow/ragflow, and sign in to our demo web-site
|
22 |
+
to get a token. It's FREE!
|
23 |
+
Set INFINIFLOW_SERVER and INFINIFLOW_TOKEN in '.env' file or
|
24 |
+
using 'export' to set both environment variables: INFINIFLOW_SERVER and INFINIFLOW_TOKEN in docker container.
|
25 |
+
"""
|
26 |
if not re.search(r"\.(pdf|doc|docx|txt)$", filename, flags=re.IGNORECASE):
|
27 |
raise NotImplementedError("file type not supported yet(pdf supported)")
|
28 |
|
29 |
url = os.environ.get("INFINIFLOW_SERVER")
|
|
|
|
|
|
|
30 |
token = os.environ.get("INFINIFLOW_TOKEN")
|
31 |
+
if not url or not token:
|
32 |
+
stat_logger.warning(
|
33 |
+
"INFINIFLOW_SERVER is not specified. To maximize the effectiveness, please visit https://github.com/infiniflow/ragflow, and sign in the our demo web site to get token. It's FREE! Using 'export' to set both environment variables: INFINIFLOW_SERVER and INFINIFLOW_TOKEN.")
|
34 |
+
return []
|
35 |
|
36 |
if not binary:
|
37 |
with open(filename, "rb") as f:
|
|
|
54 |
|
55 |
callback(0.2, "Resume parsing is going on...")
|
56 |
resume = remote_call()
|
57 |
+
if len(resume.keys()) < 7:
|
58 |
+
callback(-1, "Resume is not successfully parsed.")
|
59 |
+
return []
|
60 |
callback(0.6, "Done parsing. Chunking...")
|
61 |
print(json.dumps(resume, ensure_ascii=False, indent=2))
|
62 |
|
63 |
field_map = {
|
64 |
"name_kwd": "姓名/名字",
|
65 |
+
"name_pinyin_kwd": "姓名拼音/名字拼音",
|
66 |
"gender_kwd": "性别(男,女)",
|
67 |
"age_int": "年龄/岁/年纪",
|
68 |
"phone_kwd": "电话/手机/微信",
|
69 |
"email_tks": "email/e-mail/邮箱",
|
70 |
"position_name_tks": "职位/职能/岗位/职责",
|
71 |
+
"expect_city_names_tks": "期望城市",
|
72 |
+
"work_exp_flt": "工作年限/工作年份/N年经验/毕业了多少年",
|
73 |
+
"corporation_name_tks": "最近就职(上班)的公司/上一家公司",
|
74 |
|
75 |
+
"first_school_name_tks": "第一学历毕业学校",
|
76 |
"first_degree_kwd": "第一学历(高中,职高,硕士,本科,博士,初中,中技,中专,专科,专升本,MPA,MBA,EMBA)",
|
77 |
+
"highest_degree_kwd": "最高学历(高中,职高,硕士,本科,博士,初中,中技,中专,专科,专升本,MPA,MBA,EMBA)",
|
78 |
"first_major_tks": "第一学历专业",
|
|
|
79 |
"edu_first_fea_kwd": "第一学历标签(211,留学,双一流,985,海外知名,重点大学,中专,专升本,专科,本科,大专)",
|
80 |
|
81 |
"degree_kwd": "过往学历(高中,职高,硕士,本科,博士,初中,中技,中专,专科,专升本,MPA,MBA,EMBA)",
|
|
|
84 |
"sch_rank_kwd": "学校标签(顶尖学校,精英学校,优质学校,一般学校)",
|
85 |
"edu_fea_kwd": "教育标签(211,留学,双一流,985,海外知名,重点大学,中专,专升本,专科,本科,大专)",
|
86 |
|
|
|
|
|
87 |
"corp_nm_tks": "就职过的公司/之前的公司/上过班的公司",
|
|
|
88 |
"edu_end_int": "毕业年份",
|
89 |
+
"industry_name_tks": "所在行业",
|
90 |
+
|
91 |
+
"birth_dt": "生日/出生年份",
|
92 |
+
"expect_position_name_tks": "期望职位/期望职能/期望岗位",
|
93 |
}
|
94 |
+
|
95 |
titles = []
|
96 |
for n in ["name_kwd", "gender_kwd", "position_name_tks", "age_int"]:
|
97 |
v = resume.get(n, "")
|
|
|
121 |
doc["content_ltks"] = huqie.qie(doc["content_with_weight"])
|
122 |
doc["content_sm_ltks"] = huqie.qieqie(doc["content_ltks"])
|
123 |
for n, _ in field_map.items():
|
124 |
+
if n not in resume:continue
|
125 |
+
if isinstance(resume[n], list) and (len(resume[n]) == 1 or n not in forbidden_select_fields4resume):
|
126 |
+
resume[n] = resume[n][0]
|
127 |
+
if n.find("_tks")>0: resume[n] = huqie.qieqie(resume[n])
|
128 |
doc[n] = resume[n]
|
129 |
|
130 |
print(doc)
|
rag/app/table.py
CHANGED
@@ -100,7 +100,20 @@ def column_data_type(arr):
|
|
100 |
|
101 |
|
102 |
def chunk(filename, binary=None, callback=None, **kwargs):
|
103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
|
105 |
callback(0.1, "Start to parse.")
|
106 |
excel_parser = Excel()
|
@@ -155,7 +168,7 @@ def chunk(filename, binary=None, callback=None, **kwargs):
|
|
155 |
del df[n]
|
156 |
clmns = df.columns.values
|
157 |
txts = list(copy.deepcopy(clmns))
|
158 |
-
py_clmns = [PY.get_pinyins(
|
159 |
clmn_tys = []
|
160 |
for j in range(len(clmns)):
|
161 |
cln, ty = column_data_type(df[clmns[j]])
|
|
|
100 |
|
101 |
|
102 |
def chunk(filename, binary=None, callback=None, **kwargs):
|
103 |
+
"""
|
104 |
+
Excel and csv(txt) format files are supported.
|
105 |
+
For csv or txt file, the delimiter between columns is TAB.
|
106 |
+
The first line must be column headers.
|
107 |
+
Column headers must be meaningful terms in order for our NLP model to understand them.
|
108 |
+
It's good to enumerate some synonyms using slash '/' to separate, and even better to
|
109 |
+
enumerate values using brackets like 'gender/sex(male, female)'.
|
110 |
+
Here are some examples for headers:
|
111 |
+
1. supplier/vendor\tcolor(yellow, red, brown)\tgender/sex(male, female)\tsize(M,L,XL,XXL)
|
112 |
+
2. 姓名/名字\t电话/手机/微信\t最高学历(高中,职高,硕士,本科,博士,初中,中技,中专,专科,专升本,MPA,MBA,EMBA)
|
113 |
+
|
114 |
+
Every row in table will be treated as a chunk.
|
115 |
+
"""
|
116 |
+
|
117 |
if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
|
118 |
callback(0.1, "Start to parse.")
|
119 |
excel_parser = Excel()
|
|
|
168 |
del df[n]
|
169 |
clmns = df.columns.values
|
170 |
txts = list(copy.deepcopy(clmns))
|
171 |
+
py_clmns = [PY.get_pinyins(re.sub(r"(/.*|([^()]+?)|\([^()]+?\))", "", n), '_')[0] for n in clmns]
|
172 |
clmn_tys = []
|
173 |
for j in range(len(clmns)):
|
174 |
cln, ty = column_data_type(df[clmns[j]])
|
rag/llm/__init__.py
CHANGED
@@ -21,7 +21,7 @@ from .cv_model import *
|
|
21 |
EmbeddingModel = {
|
22 |
"Infiniflow": HuEmbedding,
|
23 |
"OpenAI": OpenAIEmbed,
|
24 |
-
"通义千问": QWenEmbed,
|
25 |
}
|
26 |
|
27 |
|
|
|
21 |
EmbeddingModel = {
|
22 |
"Infiniflow": HuEmbedding,
|
23 |
"OpenAI": OpenAIEmbed,
|
24 |
+
"通义千问": HuEmbedding, #QWenEmbed,
|
25 |
}
|
26 |
|
27 |
|
rag/llm/chat_model.py
CHANGED
@@ -32,7 +32,7 @@ class GptTurbo(Base):
|
|
32 |
self.model_name = model_name
|
33 |
|
34 |
def chat(self, system, history, gen_conf):
|
35 |
-
history.insert(0, {"role": "system", "content": system})
|
36 |
res = self.client.chat.completions.create(
|
37 |
model=self.model_name,
|
38 |
messages=history,
|
@@ -49,11 +49,12 @@ class QWenChat(Base):
|
|
49 |
|
50 |
def chat(self, system, history, gen_conf):
|
51 |
from http import HTTPStatus
|
52 |
-
history.insert(0, {"role": "system", "content": system})
|
53 |
response = Generation.call(
|
54 |
self.model_name,
|
55 |
messages=history,
|
56 |
-
result_format='message'
|
|
|
57 |
)
|
58 |
if response.status_code == HTTPStatus.OK:
|
59 |
return response.output.choices[0]['message']['content'], response.usage.output_tokens
|
@@ -68,10 +69,11 @@ class ZhipuChat(Base):
|
|
68 |
|
69 |
def chat(self, system, history, gen_conf):
|
70 |
from http import HTTPStatus
|
71 |
-
history.insert(0, {"role": "system", "content": system})
|
72 |
response = self.client.chat.completions.create(
|
73 |
self.model_name,
|
74 |
-
messages=history
|
|
|
75 |
)
|
76 |
if response.status_code == HTTPStatus.OK:
|
77 |
return response.output.choices[0]['message']['content'], response.usage.completion_tokens
|
|
|
32 |
self.model_name = model_name
|
33 |
|
34 |
def chat(self, system, history, gen_conf):
|
35 |
+
if system: history.insert(0, {"role": "system", "content": system})
|
36 |
res = self.client.chat.completions.create(
|
37 |
model=self.model_name,
|
38 |
messages=history,
|
|
|
49 |
|
50 |
def chat(self, system, history, gen_conf):
|
51 |
from http import HTTPStatus
|
52 |
+
if system: history.insert(0, {"role": "system", "content": system})
|
53 |
response = Generation.call(
|
54 |
self.model_name,
|
55 |
messages=history,
|
56 |
+
result_format='message',
|
57 |
+
**gen_conf
|
58 |
)
|
59 |
if response.status_code == HTTPStatus.OK:
|
60 |
return response.output.choices[0]['message']['content'], response.usage.output_tokens
|
|
|
69 |
|
70 |
def chat(self, system, history, gen_conf):
|
71 |
from http import HTTPStatus
|
72 |
+
if system: history.insert(0, {"role": "system", "content": system})
|
73 |
response = self.client.chat.completions.create(
|
74 |
self.model_name,
|
75 |
+
messages=history,
|
76 |
+
**gen_conf
|
77 |
)
|
78 |
if response.status_code == HTTPStatus.OK:
|
79 |
return response.output.choices[0]['message']['content'], response.usage.completion_tokens
|
rag/llm/embedding_model.py
CHANGED
@@ -100,11 +100,11 @@ class QWenEmbed(Base):
|
|
100 |
input=texts[i:i+batch_size],
|
101 |
text_type="document"
|
102 |
)
|
103 |
-
embds = [[]
|
104 |
for e in resp["output"]["embeddings"]:
|
105 |
embds[e["text_index"]] = e["embedding"]
|
106 |
res.extend(embds)
|
107 |
-
token_count += resp["usage"]["
|
108 |
return np.array(res), token_count
|
109 |
|
110 |
def encode_queries(self, text):
|
@@ -113,7 +113,7 @@ class QWenEmbed(Base):
|
|
113 |
input=text[:2048],
|
114 |
text_type="query"
|
115 |
)
|
116 |
-
return np.array(resp["output"]["embeddings"][0]["embedding"]), resp["usage"]["
|
117 |
|
118 |
|
119 |
from zhipuai import ZhipuAI
|
|
|
100 |
input=texts[i:i+batch_size],
|
101 |
text_type="document"
|
102 |
)
|
103 |
+
embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
|
104 |
for e in resp["output"]["embeddings"]:
|
105 |
embds[e["text_index"]] = e["embedding"]
|
106 |
res.extend(embds)
|
107 |
+
token_count += resp["usage"]["total_tokens"]
|
108 |
return np.array(res), token_count
|
109 |
|
110 |
def encode_queries(self, text):
|
|
|
113 |
input=text[:2048],
|
114 |
text_type="query"
|
115 |
)
|
116 |
+
return np.array(resp["output"]["embeddings"][0]["embedding"]), resp["usage"]["total_tokens"]
|
117 |
|
118 |
|
119 |
from zhipuai import ZhipuAI
|
rag/nlp/search.py
CHANGED
@@ -92,7 +92,7 @@ class Dealer:
|
|
92 |
assert emb_mdl, "No embedding model selected"
|
93 |
s["knn"] = self._vector(
|
94 |
qst, emb_mdl, req.get(
|
95 |
-
"similarity", 0.
|
96 |
s["knn"]["filter"] = bqry.to_dict()
|
97 |
if "highlight" in s:
|
98 |
del s["highlight"]
|
@@ -106,7 +106,7 @@ class Dealer:
|
|
106 |
bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
|
107 |
s["query"] = bqry.to_dict()
|
108 |
s["knn"]["filter"] = bqry.to_dict()
|
109 |
-
s["knn"]["similarity"] = 0.
|
110 |
res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
|
111 |
|
112 |
kwds = set([])
|
@@ -171,7 +171,7 @@ class Dealer:
|
|
171 |
continue
|
172 |
if not isinstance(v, type("")):
|
173 |
m[n] = str(m[n])
|
174 |
-
m[n] = rmSpace(m[n])
|
175 |
|
176 |
if m:
|
177 |
res[d["id"]] = m
|
@@ -303,21 +303,22 @@ class Dealer:
|
|
303 |
|
304 |
return ranks
|
305 |
|
306 |
-
def sql_retrieval(self, sql, fetch_size=128):
|
307 |
sql = re.sub(r"[ ]+", " ", sql)
|
|
|
|
|
308 |
replaces = []
|
309 |
-
for r in re.finditer(r" ([a-z_]+_l?tks like |
|
310 |
-
fld, v = r.group(1), r.group(
|
311 |
-
|
312 |
-
|
313 |
-
match = " MATCH({}, '{}', 'operator=OR;fuzziness=AUTO:1,3;minimum_should_match=30%') ".format(fld, huqie.qie(v))
|
314 |
-
replaces.append((r.group(1)+r.group(2), match))
|
315 |
|
316 |
-
for p, r in replaces: sql.replace(p, r)
|
|
|
317 |
|
318 |
try:
|
319 |
-
tbl = self.es.sql(sql, fetch_size)
|
320 |
return tbl
|
321 |
except Exception as e:
|
322 |
-
es_logger(f"SQL failure: {sql} =>" + str(e))
|
323 |
|
|
|
92 |
assert emb_mdl, "No embedding model selected"
|
93 |
s["knn"] = self._vector(
|
94 |
qst, emb_mdl, req.get(
|
95 |
+
"similarity", 0.1), ps)
|
96 |
s["knn"]["filter"] = bqry.to_dict()
|
97 |
if "highlight" in s:
|
98 |
del s["highlight"]
|
|
|
106 |
bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
|
107 |
s["query"] = bqry.to_dict()
|
108 |
s["knn"]["filter"] = bqry.to_dict()
|
109 |
+
s["knn"]["similarity"] = 0.17
|
110 |
res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
|
111 |
|
112 |
kwds = set([])
|
|
|
171 |
continue
|
172 |
if not isinstance(v, type("")):
|
173 |
m[n] = str(m[n])
|
174 |
+
if n.find("tks")>0: m[n] = rmSpace(m[n])
|
175 |
|
176 |
if m:
|
177 |
res[d["id"]] = m
|
|
|
303 |
|
304 |
return ranks
|
305 |
|
306 |
+
def sql_retrieval(self, sql, fetch_size=128, format="json"):
|
307 |
sql = re.sub(r"[ ]+", " ", sql)
|
308 |
+
sql = sql.replace("%", "")
|
309 |
+
es_logger.info(f"Get es sql: {sql}")
|
310 |
replaces = []
|
311 |
+
for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
|
312 |
+
fld, v = r.group(1), r.group(3)
|
313 |
+
match = " MATCH({}, '{}', 'operator=OR;fuzziness=AUTO:1,3;minimum_should_match=30%') ".format(fld, huqie.qieqie(huqie.qie(v)))
|
314 |
+
replaces.append(("{}{}'{}'".format(r.group(1), r.group(2), r.group(3)), match))
|
|
|
|
|
315 |
|
316 |
+
for p, r in replaces: sql = sql.replace(p, r, 1)
|
317 |
+
es_logger.info(f"To es: {sql}")
|
318 |
|
319 |
try:
|
320 |
+
tbl = self.es.sql(sql, fetch_size, format)
|
321 |
return tbl
|
322 |
except Exception as e:
|
323 |
+
es_logger.error(f"SQL failure: {sql} =>" + str(e))
|
324 |
|
rag/parser/pdf_parser.py
CHANGED
@@ -53,9 +53,10 @@ class HuParser:
|
|
53 |
|
54 |
def __remote_call(self, species, images, thr=0.7):
|
55 |
url = os.environ.get("INFINIFLOW_SERVER")
|
56 |
-
if not url:raise EnvironmentError("Please set environment variable: 'INFINIFLOW_SERVER'")
|
57 |
token = os.environ.get("INFINIFLOW_TOKEN")
|
58 |
-
if not
|
|
|
|
|
59 |
|
60 |
def convert_image_to_bytes(PILimage):
|
61 |
image = BytesIO()
|
|
|
53 |
|
54 |
def __remote_call(self, species, images, thr=0.7):
|
55 |
url = os.environ.get("INFINIFLOW_SERVER")
|
|
|
56 |
token = os.environ.get("INFINIFLOW_TOKEN")
|
57 |
+
if not url or not token:
|
58 |
+
logging.warning("INFINIFLOW_SERVER is not specified. To maximize the effectiveness, please visit https://github.com/infiniflow/ragflow, and sign in the our demo web site to get token. It's FREE! Using 'export' to set both environment variables: INFINIFLOW_SERVER and INFINIFLOW_TOKEN.")
|
59 |
+
return []
|
60 |
|
61 |
def convert_image_to_bytes(PILimage):
|
62 |
image = BytesIO()
|
rag/svr/task_executor.py
CHANGED
@@ -47,7 +47,7 @@ from api.utils.file_utils import get_project_base_directory
|
|
47 |
BATCH_SIZE = 64
|
48 |
|
49 |
FACTORY = {
|
50 |
-
ParserType.GENERAL.value:
|
51 |
ParserType.PAPER.value: paper,
|
52 |
ParserType.BOOK.value: book,
|
53 |
ParserType.PRESENTATION.value: presentation,
|
@@ -119,8 +119,8 @@ def build(row, cvmdl):
|
|
119 |
chunker = FACTORY[row["parser_id"].lower()]
|
120 |
try:
|
121 |
cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
|
122 |
-
cks = chunker.chunk(row["name"], MINIO.get(row["kb_id"], row["location"]), row["from_page"], row["to_page"],
|
123 |
-
callback, kb_id=row["kb_id"], parser_config=row["parser_config"])
|
124 |
except Exception as e:
|
125 |
if re.search("(No such file|not found)", str(e)):
|
126 |
callback(-1, "Can not find file <%s>" % row["doc_name"])
|
@@ -129,7 +129,7 @@ def build(row, cvmdl):
|
|
129 |
|
130 |
cron_logger.warn("Chunkking {}/{}: {}".format(row["location"], row["name"], str(e)))
|
131 |
|
132 |
-
return
|
133 |
|
134 |
callback(msg="Finished slicing files. Start to embedding the content.")
|
135 |
|
@@ -211,6 +211,7 @@ def main(comm, mod):
|
|
211 |
|
212 |
st_tm = timer()
|
213 |
cks = build(r, cv_mdl)
|
|
|
214 |
if not cks:
|
215 |
tmf.write(str(r["update_time"]) + "\n")
|
216 |
callback(1., "No chunk! Done!")
|
|
|
47 |
BATCH_SIZE = 64
|
48 |
|
49 |
FACTORY = {
|
50 |
+
ParserType.GENERAL.value: manual,
|
51 |
ParserType.PAPER.value: paper,
|
52 |
ParserType.BOOK.value: book,
|
53 |
ParserType.PRESENTATION.value: presentation,
|
|
|
119 |
chunker = FACTORY[row["parser_id"].lower()]
|
120 |
try:
|
121 |
cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
|
122 |
+
cks = chunker.chunk(row["name"], binary = MINIO.get(row["kb_id"], row["location"]), from_page=row["from_page"], to_page=row["to_page"],
|
123 |
+
callback = callback, kb_id=row["kb_id"], parser_config=row["parser_config"])
|
124 |
except Exception as e:
|
125 |
if re.search("(No such file|not found)", str(e)):
|
126 |
callback(-1, "Can not find file <%s>" % row["doc_name"])
|
|
|
129 |
|
130 |
cron_logger.warn("Chunkking {}/{}: {}".format(row["location"], row["name"], str(e)))
|
131 |
|
132 |
+
return
|
133 |
|
134 |
callback(msg="Finished slicing files. Start to embedding the content.")
|
135 |
|
|
|
211 |
|
212 |
st_tm = timer()
|
213 |
cks = build(r, cv_mdl)
|
214 |
+
if cks is None:continue
|
215 |
if not cks:
|
216 |
tmf.write(str(r["update_time"]) + "\n")
|
217 |
callback(1., "No chunk! Done!")
|
rag/utils/es_conn.py
CHANGED
@@ -241,7 +241,7 @@ class HuEs:
|
|
241 |
es_logger.error("ES search timeout for 3 times!")
|
242 |
raise Exception("ES search timeout.")
|
243 |
|
244 |
-
def sql(self, sql, fetch_size=128, format="json", timeout=
|
245 |
for i in range(3):
|
246 |
try:
|
247 |
res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, request_timeout=timeout)
|
|
|
241 |
es_logger.error("ES search timeout for 3 times!")
|
242 |
raise Exception("ES search timeout.")
|
243 |
|
244 |
+
def sql(self, sql, fetch_size=128, format="json", timeout="2s"):
|
245 |
for i in range(3):
|
246 |
try:
|
247 |
res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, request_timeout=timeout)
|