KevinHuSh committed on
Commit
a8294f2
·
1 Parent(s): 452020d

Refine resume parts and fix bugs in retrieval using sql (#66)

api/apps/conversation_app.py CHANGED
@@ -21,20 +21,21 @@ from api.db.services.dialog_service import DialogService, ConversationService
21
  from api.db import LLMType
22
  from api.db.services.knowledgebase_service import KnowledgebaseService
23
  from api.db.services.llm_service import LLMService, LLMBundle
24
- from api.settings import access_logger
25
  from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
26
  from api.utils import get_uuid
27
  from api.utils.api_utils import get_json_result
 
28
  from rag.llm import ChatModel
29
  from rag.nlp import retrievaler
30
  from rag.nlp.search import index_name
31
- from rag.utils import num_tokens_from_string, encoder
32
 
33
 
34
  @manager.route('/set', methods=['POST'])
35
  @login_required
36
  @validate_request("dialog_id")
37
- def set():
38
  req = request.json
39
  conv_id = req.get("conversation_id")
40
  if conv_id:
@@ -96,9 +97,10 @@ def rm():
96
  except Exception as e:
97
  return server_error_response(e)
98
 
 
99
  @manager.route('/list', methods=['GET'])
100
  @login_required
101
- def list():
102
  dialog_id = request.args["dialog_id"]
103
  try:
104
  convs = ConversationService.query(dialog_id=dialog_id)
@@ -112,7 +114,7 @@ def message_fit_in(msg, max_length=4000):
112
  def count():
113
  nonlocal msg
114
  tks_cnts = []
115
- for m in msg:tks_cnts.append({"role": m["role"], "count": num_tokens_from_string(m["content"])})
116
  total = 0
117
  for m in tks_cnts: total += m["count"]
118
  return total
@@ -121,22 +123,22 @@ def message_fit_in(msg, max_length=4000):
121
  if c < max_length: return c, msg
122
  msg = [m for m in msg if m.role in ["system", "user"]]
123
  c = count()
124
- if c < max_length:return c, msg
125
  msg_ = [m for m in msg[:-1] if m.role == "system"]
126
  msg_.append(msg[-1])
127
  msg = msg_
128
  c = count()
129
- if c < max_length:return c, msg
130
  ll = num_tokens_from_string(msg_[0].content)
131
  l = num_tokens_from_string(msg_[-1].content)
132
- if ll/(ll + l) > 0.8:
133
  m = msg_[0].content
134
- m = encoder.decode(encoder.encode(m)[:max_length-l])
135
  msg[0].content = m
136
  return max_length, msg
137
 
138
  m = msg_[1].content
139
- m = encoder.decode(encoder.encode(m)[:max_length-l])
140
  msg[1].content = m
141
  return max_length, msg
142
 
@@ -148,8 +150,8 @@ def completion():
148
  req = request.json
149
  msg = []
150
  for m in req["messages"]:
151
- if m["role"] == "system":continue
152
- if m["role"] == "assistant" and not msg:continue
153
  msg.append({"role": m["role"], "content": m["content"]})
154
  try:
155
  e, dia = DialogService.get_by_id(req["dialog_id"])
@@ -166,7 +168,7 @@ def chat(dialog, messages, **kwargs):
166
  assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
167
  llm = LLMService.query(llm_name=dialog.llm_id)
168
  if not llm:
169
- raise LookupError("LLM(%s) not found"%dialog.llm_id)
170
  llm = llm[0]
171
  question = messages[-1]["content"]
172
  embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
@@ -175,19 +177,21 @@ def chat(dialog, messages, **kwargs):
175
  field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
176
  ## try to use sql if field mapping is good to go
177
  if field_map:
178
- markdown_tbl,chunks = use_sql(question, field_map, dialog.tenant_id, chat_mdl)
 
179
  if markdown_tbl:
180
  return {"answer": markdown_tbl, "retrieval": {"chunks": chunks}}
181
 
182
  prompt_config = dialog.prompt_config
183
  for p in prompt_config["parameters"]:
184
- if p["key"] == "knowledge":continue
185
- if p["key"] not in kwargs and not p["optional"]:raise KeyError("Miss parameter: " + p["key"])
186
  if p["key"] not in kwargs:
187
- prompt_config["system"] = prompt_config["system"].replace("{%s}"%p["key"], " ")
188
 
189
- kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, dialog.similarity_threshold,
190
- dialog.vector_similarity_weight, top=1024, aggs=False)
 
191
  knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
192
 
193
  if not knowledges and prompt_config["empty_response"]:
@@ -202,17 +206,17 @@ def chat(dialog, messages, **kwargs):
202
  answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
203
 
204
  answer = retrievaler.insert_citations(answer,
205
- [ck["content_ltks"] for ck in kbinfos["chunks"]],
206
- [ck["vector"] for ck in kbinfos["chunks"]],
207
- embd_mdl,
208
- tkweight=1-dialog.vector_similarity_weight,
209
- vtweight=dialog.vector_similarity_weight)
210
  for c in kbinfos["chunks"]:
211
- if c.get("vector"):del c["vector"]
212
  return {"answer": answer, "retrieval": kbinfos}
213
 
214
 
215
- def use_sql(question,field_map, tenant_id, chat_mdl):
216
  sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据我的问题写出sql。"
217
  user_promt = """
218
  表名:{};
@@ -220,37 +224,47 @@ def use_sql(question,field_map, tenant_id, chat_mdl):
220
  {}
221
 
222
  问题:{}
223
- 请写出SQL
224
  """.format(
225
  index_name(tenant_id),
226
- "\n".join([f"{k}: {v}" for k,v in field_map.items()]),
227
  question
228
  )
229
- sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.1})
230
- sql = re.sub(r".*?select ", "select ", sql, flags=re.IGNORECASE)
 
 
231
  sql = re.sub(r" +", " ", sql)
232
- sql = re.sub(r"[;;].*", "", sql)
233
- if sql[:len("select ")].lower() != "select ":
234
  return None, None
235
- if sql[:len("select *")].lower() != "select *":
236
  sql = "select doc_id,docnm_kwd," + sql[6:]
237
 
238
- tbl = retrievaler.sql_retrieval(sql)
239
- if not tbl: return None, None
 
240
 
241
  docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
242
  docnm_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
243
- clmn_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx|docnm_idx)]
244
 
245
  # compose markdown table
246
- clmns = "|".join([re.sub(r"/.*", "", field_map.get(tbl["columns"][i]["name"], f"C{i}")) for i in clmn_idx]) + "|原文"
247
  line = "|".join(["------" for _ in range(len(clmn_idx))]) + "|------"
248
- rows = ["|".join([str(r[i]) for i in clmn_idx])+"|" for r in tbl["rows"]]
249
  if not docid_idx or not docnm_idx:
250
  access_logger.error("SQL missing field: " + sql)
251
  return "\n".join([clmns, line, "\n".join(rows)]), []
252
 
253
- rows = "\n".join([r+f"##{ii}$$" for ii,r in enumerate(rows)])
254
  docid_idx = list(docid_idx)[0]
255
  docnm_idx = list(docnm_idx)[0]
256
  return "\n".join([clmns, line, rows]), [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]]
 
21
  from api.db import LLMType
22
  from api.db.services.knowledgebase_service import KnowledgebaseService
23
  from api.db.services.llm_service import LLMService, LLMBundle
24
+ from api.settings import access_logger, stat_logger
25
  from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
26
  from api.utils import get_uuid
27
  from api.utils.api_utils import get_json_result
28
+ from rag.app.resume import forbidden_select_fields4resume
29
  from rag.llm import ChatModel
30
  from rag.nlp import retrievaler
31
  from rag.nlp.search import index_name
32
+ from rag.utils import num_tokens_from_string, encoder, rmSpace
33
 
34
 
35
  @manager.route('/set', methods=['POST'])
36
  @login_required
37
  @validate_request("dialog_id")
38
+ def set_conversation():
39
  req = request.json
40
  conv_id = req.get("conversation_id")
41
  if conv_id:
 
97
  except Exception as e:
98
  return server_error_response(e)
99
 
100
+
101
  @manager.route('/list', methods=['GET'])
102
  @login_required
103
+ def list_conversation():
104
  dialog_id = request.args["dialog_id"]
105
  try:
106
  convs = ConversationService.query(dialog_id=dialog_id)
 
114
  def count():
115
  nonlocal msg
116
  tks_cnts = []
117
+ for m in msg: tks_cnts.append({"role": m["role"], "count": num_tokens_from_string(m["content"])})
118
  total = 0
119
  for m in tks_cnts: total += m["count"]
120
  return total
 
123
  if c < max_length: return c, msg
124
  msg = [m for m in msg if m.role in ["system", "user"]]
125
  c = count()
126
+ if c < max_length: return c, msg
127
  msg_ = [m for m in msg[:-1] if m.role == "system"]
128
  msg_.append(msg[-1])
129
  msg = msg_
130
  c = count()
131
+ if c < max_length: return c, msg
132
  ll = num_tokens_from_string(msg_[0].content)
133
  l = num_tokens_from_string(msg_[-1].content)
134
+ if ll / (ll + l) > 0.8:
135
  m = msg_[0].content
136
+ m = encoder.decode(encoder.encode(m)[:max_length - l])
137
  msg[0].content = m
138
  return max_length, msg
139
 
140
  m = msg_[1].content
141
+ m = encoder.decode(encoder.encode(m)[:max_length - l])
142
  msg[1].content = m
143
  return max_length, msg
144
 
 
150
  req = request.json
151
  msg = []
152
  for m in req["messages"]:
153
+ if m["role"] == "system": continue
154
+ if m["role"] == "assistant" and not msg: continue
155
  msg.append({"role": m["role"], "content": m["content"]})
156
  try:
157
  e, dia = DialogService.get_by_id(req["dialog_id"])
 
168
  assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
169
  llm = LLMService.query(llm_name=dialog.llm_id)
170
  if not llm:
171
+ raise LookupError("LLM(%s) not found" % dialog.llm_id)
172
  llm = llm[0]
173
  question = messages[-1]["content"]
174
  embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
 
177
  field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
178
  ## try to use sql if field mapping is good to go
179
  if field_map:
180
+ stat_logger.info("Use SQL to retrieval.")
181
+ markdown_tbl, chunks = use_sql(question, field_map, dialog.tenant_id, chat_mdl)
182
  if markdown_tbl:
183
  return {"answer": markdown_tbl, "retrieval": {"chunks": chunks}}
184
 
185
  prompt_config = dialog.prompt_config
186
  for p in prompt_config["parameters"]:
187
+ if p["key"] == "knowledge": continue
188
+ if p["key"] not in kwargs and not p["optional"]: raise KeyError("Miss parameter: " + p["key"])
189
  if p["key"] not in kwargs:
190
+ prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
191
 
192
+ kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
193
+ dialog.similarity_threshold,
194
+ dialog.vector_similarity_weight, top=1024, aggs=False)
195
  knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
196
 
197
  if not knowledges and prompt_config["empty_response"]:
 
206
  answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
207
 
208
  answer = retrievaler.insert_citations(answer,
209
+ [ck["content_ltks"] for ck in kbinfos["chunks"]],
210
+ [ck["vector"] for ck in kbinfos["chunks"]],
211
+ embd_mdl,
212
+ tkweight=1 - dialog.vector_similarity_weight,
213
+ vtweight=dialog.vector_similarity_weight)
214
  for c in kbinfos["chunks"]:
215
+ if c.get("vector"): del c["vector"]
216
  return {"answer": answer, "retrieval": kbinfos}
217
 
218
 
219
+ def use_sql(question, field_map, tenant_id, chat_mdl):
220
  sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据我的问题写出sql。"
221
  user_promt = """
222
  表名:{};
 
224
  {}
225
 
226
  问题:{}
227
+ 请写出SQL,且只要SQL,不要有其他说明及文字。
228
  """.format(
229
  index_name(tenant_id),
230
+ "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
231
  question
232
  )
233
+ sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.06})
234
+ stat_logger.info(f"“{question}” get SQL: {sql}")
235
+ sql = re.sub(r"[\r\n]+", " ", sql.lower())
236
+ sql = re.sub(r".*?select ", "select ", sql.lower())
237
  sql = re.sub(r" +", " ", sql)
238
+ sql = re.sub(r"([;;]|```).*", "", sql)
239
+ if sql[:len("select ")] != "select ":
240
  return None, None
241
+ if sql[:len("select *")] != "select *":
242
  sql = "select doc_id,docnm_kwd," + sql[6:]
243
+ else:
244
+ flds = []
245
+ for k in field_map.keys():
246
+ if k in forbidden_select_fields4resume:continue
247
+ if len(flds) > 11:break
248
+ flds.append(k)
249
+ sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
250
 
251
+ stat_logger.info(f"“{question}” get SQL(refined): {sql}")
252
+ tbl = retrievaler.sql_retrieval(sql, format="json")
253
+ if not tbl or len(tbl["rows"]) == 0: return None, None
254
 
255
  docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
256
  docnm_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
257
+ clmn_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]
258
 
259
  # compose markdown table
260
+ clmns = "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], f"C{i}")) for i in clmn_idx]) + "|原文"
261
  line = "|".join(["------" for _ in range(len(clmn_idx))]) + "|------"
262
+ rows = ["|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
263
  if not docid_idx or not docnm_idx:
264
  access_logger.error("SQL missing field: " + sql)
265
  return "\n".join([clmns, line, "\n".join(rows)]), []
266
 
267
+ rows = "\n".join([r + f"##{ii}$$" for ii, r in enumerate(rows)])
268
  docid_idx = list(docid_idx)[0]
269
  docnm_idx = list(docnm_idx)[0]
270
  return "\n".join([clmns, line, rows]), [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]]
api/apps/dialog_app.py CHANGED
@@ -27,7 +27,7 @@ from api.utils.api_utils import get_json_result
27
 
28
  @manager.route('/set', methods=['POST'])
29
  @login_required
30
- def set():
31
  req = request.json
32
  dialog_id = req.get("dialog_id")
33
  name = req.get("name", "New Dialog")
 
27
 
28
  @manager.route('/set', methods=['POST'])
29
  @login_required
30
+ def set_dialog():
31
  req = request.json
32
  dialog_id = req.get("dialog_id")
33
  name = req.get("name", "New Dialog")
api/apps/document_app.py CHANGED
@@ -262,17 +262,18 @@ def rename():
262
  return server_error_response(e)
263
 
264
 
265
- @manager.route('/get', methods=['GET'])
266
- @login_required
267
- def get():
268
- doc_id = request.args["doc_id"]
269
  try:
270
  e, doc = DocumentService.get_by_id(doc_id)
271
  if not e:
272
  return get_data_error_result(retmsg="Document not found!")
273
 
274
- blob = MINIO.get(doc.kb_id, doc.location)
275
- return get_json_result(data={"base64": base64.b64decode(blob)})
 
 
 
276
  except Exception as e:
277
  return server_error_response(e)
278
 
 
262
  return server_error_response(e)
263
 
264
 
265
+ @manager.route('/get/<doc_id>', methods=['GET'])
266
+ def get(doc_id):
 
 
267
  try:
268
  e, doc = DocumentService.get_by_id(doc_id)
269
  if not e:
270
  return get_data_error_result(retmsg="Document not found!")
271
 
272
+ response = flask.make_response(MINIO.get(doc.kb_id, doc.location))
273
+ ext = re.search(r"\.([^.]+)$", doc.name)
274
+ if ext:
275
+ response.headers.set('Content-Type', 'application/%s'%ext.group(1))
276
+ return response
277
  except Exception as e:
278
  return server_error_response(e)
279
 
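
The /get endpoint switches from returning a base64 JSON payload to streaming the raw blob with a Content-Type derived from the file extension. A minimal Flask sketch of the same pattern; lookup_name() and fetch_blob() are hypothetical stand-ins for the DocumentService and MINIO calls:

```python
import re
import flask

app = flask.Flask(__name__)

def lookup_name(doc_id):          # stand-in for DocumentService.get_by_id()
    return "report.pdf"

def fetch_blob(doc_id):           # stand-in for MINIO.get(kb_id, location)
    return b"%PDF-1.4 ..."

@app.route("/get/<doc_id>")
def get(doc_id):
    response = flask.make_response(fetch_blob(doc_id))
    ext = re.search(r"\.([^.]+)$", lookup_name(doc_id))
    if ext:
        # 'application/<ext>' is the patch's rough guess; mimetypes.guess_type()
        # would be the stricter choice.
        response.headers.set("Content-Type", "application/%s" % ext.group(1))
    return response
```
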
api/apps/kb_app.py CHANGED
@@ -38,6 +38,9 @@ def create():
38
  req["id"] = get_uuid()
39
  req["tenant_id"] = current_user.id
40
  req["created_by"] = current_user.id
 
 
 
41
  if not KnowledgebaseService.save(**req): return get_data_error_result()
42
  return get_json_result(data={"kb_id": req["id"]})
43
  except Exception as e:
 
38
  req["id"] = get_uuid()
39
  req["tenant_id"] = current_user.id
40
  req["created_by"] = current_user.id
41
+ e, t = TenantService.get_by_id(current_user.id)
42
+ if not e: return get_data_error_result(retmsg="Tenant not found.")
43
+ req["embd_id"] = t.embd_id
44
  if not KnowledgebaseService.save(**req): return get_data_error_result()
45
  return get_json_result(data={"kb_id": req["id"]})
46
  except Exception as e:
api/apps/llm_app.py CHANGED
@@ -21,11 +21,12 @@ from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, L
21
  from api.db.services.user_service import TenantService, UserTenantService
22
  from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
23
  from api.utils import get_uuid, get_format_time
24
- from api.db import StatusEnum, UserTenantRole
25
  from api.db.services.knowledgebase_service import KnowledgebaseService
26
  from api.db.db_models import Knowledgebase, TenantLLM
27
  from api.settings import stat_logger, RetCode
28
  from api.utils.api_utils import get_json_result
 
29
 
30
 
31
  @manager.route('/factories', methods=['GET'])
@@ -43,16 +44,37 @@ def factories():
43
  @validate_request("llm_factory", "api_key")
44
  def set_api_key():
45
  req = request.json
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  llm = {
47
  "tenant_id": current_user.id,
48
  "llm_factory": req["llm_factory"],
49
  "api_key": req["api_key"]
50
  }
51
- # TODO: Test api_key
52
  for n in ["model_type", "llm_name"]:
53
  if n in req: llm[n] = req[n]
54
 
55
- TenantLLM.insert(**llm).on_conflict("replace").execute()
56
  return get_json_result(data=True)
57
 
58
 
@@ -69,6 +91,7 @@ def my_llms():
69
  @manager.route('/list', methods=['GET'])
70
  @login_required
71
  def list():
 
72
  try:
73
  objs = TenantLLMService.query(tenant_id=current_user.id)
74
  mdlnms = set([o.to_dict()["llm_name"] for o in objs if o.api_key])
@@ -79,6 +102,7 @@ def list():
79
 
80
  res = {}
81
  for m in llms:
 
82
  if m["fid"] not in res: res[m["fid"]] = []
83
  res[m["fid"]].append(m)
84
 
 
21
  from api.db.services.user_service import TenantService, UserTenantService
22
  from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
23
  from api.utils import get_uuid, get_format_time
24
+ from api.db import StatusEnum, UserTenantRole, LLMType
25
  from api.db.services.knowledgebase_service import KnowledgebaseService
26
  from api.db.db_models import Knowledgebase, TenantLLM
27
  from api.settings import stat_logger, RetCode
28
  from api.utils.api_utils import get_json_result
29
+ from rag.llm import EmbeddingModel, CvModel, ChatModel
30
 
31
 
32
  @manager.route('/factories', methods=['GET'])
 
44
  @validate_request("llm_factory", "api_key")
45
  def set_api_key():
46
  req = request.json
47
+ # test if api key works
48
+ msg = ""
49
+ for llm in LLMService.query(fid=req["llm_factory"]):
50
+ if llm.model_type == LLMType.EMBEDDING.value:
51
+ mdl = EmbeddingModel[req["llm_factory"]](
52
+ req["api_key"], llm.llm_name)
53
+ try:
54
+ arr, tc = mdl.encode(["Test if the api key is available"])
55
+ if len(arr[0]) == 0 or tc ==0: raise Exception("Fail")
56
+ except Exception as e:
57
+ msg += f"\nFail to access embedding model({llm.llm_name}) using this api key."
58
+ elif llm.model_type == LLMType.CHAT.value:
59
+ mdl = ChatModel[req["llm_factory"]](
60
+ req["api_key"], llm.llm_name)
61
+ try:
62
+ m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
63
+ if not tc: raise Exception(m)
64
+ except Exception as e:
65
+ msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(e)
66
+
67
+ if msg: return get_data_error_result(retmsg=msg)
68
+
69
  llm = {
70
  "tenant_id": current_user.id,
71
  "llm_factory": req["llm_factory"],
72
  "api_key": req["api_key"]
73
  }
 
74
  for n in ["model_type", "llm_name"]:
75
  if n in req: llm[n] = req[n]
76
 
77
+ TenantLLMService.filter_update([TenantLLM.tenant_id==llm["tenant_id"], TenantLLM.llm_factory==llm["llm_factory"]], llm)
78
  return get_json_result(data=True)
79
 
80
 
 
91
  @manager.route('/list', methods=['GET'])
92
  @login_required
93
  def list():
94
+ model_type = request.args.get("model_type")
95
  try:
96
  objs = TenantLLMService.query(tenant_id=current_user.id)
97
  mdlnms = set([o.to_dict()["llm_name"] for o in objs if o.api_key])
 
102
 
103
  res = {}
104
  for m in llms:
105
+ if model_type and m["model_type"] != model_type: continue
106
  if m["fid"] not in res: res[m["fid"]] = []
107
  res[m["fid"]].append(m)
108
 
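
set_api_key() now verifies the key with one cheap call per model type before persisting it, collecting failures into a message instead of raising. A sketch of the probe, reusing the (result, token_count) return convention of the repo's EmbeddingModel and ChatModel wrappers:

```python
def probe_api_key(embd_mdl, chat_mdl):
    """Return '' when the key works for both model types, else the error text."""
    msg = ""
    try:
        arr, tc = embd_mdl.encode(["Test if the api key is available"])
        if len(arr[0]) == 0 or tc == 0:
            raise Exception("empty embedding")
    except Exception as e:
        msg += f"\nFailed to access embedding model using this api key. {e}"
    try:
        ans, tc = chat_mdl.chat(None, [{"role": "user", "content": "Hello!"}],
                                {"temperature": 0.9})
        if not tc:
            raise Exception(ans)
    except Exception as e:
        msg += f"\nFailed to access chat model using this api key. {e}"
    return msg   # caller returns get_data_error_result(retmsg=msg) when non-empty
```
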
api/apps/user_app.py CHANGED
@@ -24,7 +24,8 @@ from api.db.services.llm_service import TenantLLMService, LLMService
24
  from api.utils.api_utils import server_error_response, validate_request
25
  from api.utils import get_uuid, get_format_time, decrypt, download_img
26
  from api.db import UserTenantRole, LLMType
27
- from api.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS
 
28
  from api.db.services.user_service import UserService, TenantService, UserTenantService
29
  from api.settings import stat_logger
30
  from api.utils.api_utils import get_json_result, cors_reponse
@@ -204,8 +205,8 @@ def user_register(user_id, user):
204
  "role": UserTenantRole.OWNER
205
  }
206
  tenant_llm = []
207
- for llm in LLMService.query(fid="Infiniflow"):
208
- tenant_llm.append({"tenant_id": user_id, "llm_factory": "Infiniflow", "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": "infiniflow API Key"})
209
 
210
  if not UserService.save(**user):return
211
  TenantService.save(**tenant)
 
24
  from api.utils.api_utils import server_error_response, validate_request
25
  from api.utils import get_uuid, get_format_time, decrypt, download_img
26
  from api.db import UserTenantRole, LLMType
27
+ from api.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, API_KEY, \
28
+ LLM_FACTORY
29
  from api.db.services.user_service import UserService, TenantService, UserTenantService
30
  from api.settings import stat_logger
31
  from api.utils.api_utils import get_json_result, cors_reponse
 
205
  "role": UserTenantRole.OWNER
206
  }
207
  tenant_llm = []
208
+ for llm in LLMService.query(fid=LLM_FACTORY):
209
+ tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY})
210
 
211
  if not UserService.save(**user):return
212
  TenantService.save(**tenant)
api/db/db_models.py CHANGED
@@ -465,7 +465,8 @@ class Knowledgebase(DataBaseModel):
465
  tenant_id = CharField(max_length=32, null=False)
466
  name = CharField(max_length=128, null=False, help_text="KB name", index=True)
467
  description = TextField(null=True, help_text="KB description")
468
- permission = CharField(max_length=16, null=False, help_text="me|team")
 
469
  created_by = CharField(max_length=32, null=False)
470
  doc_num = IntegerField(default=0)
471
  token_num = IntegerField(default=0)
 
465
  tenant_id = CharField(max_length=32, null=False)
466
  name = CharField(max_length=128, null=False, help_text="KB name", index=True)
467
  description = TextField(null=True, help_text="KB description")
468
+ embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID")
469
+ permission = CharField(max_length=16, null=False, help_text="me|team", default="me")
470
  created_by = CharField(max_length=32, null=False)
471
  doc_num = IntegerField(default=0)
472
  token_num = IntegerField(default=0)
api/db/init_data.py CHANGED
@@ -46,11 +46,6 @@ def init_llm_factory():
46
  "logo": "",
47
  "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
48
  "status": "1",
49
- },{
50
- "name": "Infiniflow",
51
- "logo": "",
52
- "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
53
- "status": "1",
54
  },{
55
  "name": "智普AI",
56
  "logo": "",
@@ -135,59 +130,33 @@ def init_llm_factory():
135
  "model_type": LLMType.SPEECH2TEXT.value
136
  },{
137
  "fid": factory_infos[1]["name"],
138
- "llm_name": "qwen_vl_chat_v1",
139
- "tags": "LLM,CHAT,IMAGE2TEXT",
140
- "max_tokens": 765,
141
- "model_type": LLMType.IMAGE2TEXT.value
142
- },
143
- # ----------------------- Infiniflow -----------------------
144
- {
145
- "fid": factory_infos[2]["name"],
146
- "llm_name": "gpt-3.5-turbo",
147
- "tags": "LLM,CHAT,4K",
148
- "max_tokens": 4096,
149
- "model_type": LLMType.CHAT.value
150
- },{
151
- "fid": factory_infos[2]["name"],
152
- "llm_name": "text-embedding-ada-002",
153
- "tags": "TEXT EMBEDDING,8K",
154
- "max_tokens": 8191,
155
- "model_type": LLMType.EMBEDDING.value
156
- },{
157
- "fid": factory_infos[2]["name"],
158
- "llm_name": "whisper-1",
159
- "tags": "SPEECH2TEXT",
160
- "max_tokens": 25*1024*1024,
161
- "model_type": LLMType.SPEECH2TEXT.value
162
- },{
163
- "fid": factory_infos[2]["name"],
164
- "llm_name": "gpt-4-vision-preview",
165
  "tags": "LLM,CHAT,IMAGE2TEXT",
166
  "max_tokens": 765,
167
  "model_type": LLMType.IMAGE2TEXT.value
168
  },
169
  # ---------------------- ZhipuAI ----------------------
170
  {
171
- "fid": factory_infos[3]["name"],
172
  "llm_name": "glm-3-turbo",
173
  "tags": "LLM,CHAT,",
174
  "max_tokens": 128 * 1000,
175
  "model_type": LLMType.CHAT.value
176
  }, {
177
- "fid": factory_infos[3]["name"],
178
  "llm_name": "glm-4",
179
  "tags": "LLM,CHAT,",
180
  "max_tokens": 128 * 1000,
181
  "model_type": LLMType.CHAT.value
182
  }, {
183
- "fid": factory_infos[3]["name"],
184
  "llm_name": "glm-4v",
185
  "tags": "LLM,CHAT,IMAGE2TEXT",
186
  "max_tokens": 2000,
187
  "model_type": LLMType.IMAGE2TEXT.value
188
  },
189
  {
190
- "fid": factory_infos[3]["name"],
191
  "llm_name": "embedding-2",
192
  "tags": "TEXT EMBEDDING",
193
  "max_tokens": 512,
 
46
  "logo": "",
47
  "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
48
  "status": "1",
 
 
 
 
 
49
  },{
50
  "name": "智普AI",
51
  "logo": "",
 
130
  "model_type": LLMType.SPEECH2TEXT.value
131
  },{
132
  "fid": factory_infos[1]["name"],
133
+ "llm_name": "qwen-vl-max",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  "tags": "LLM,CHAT,IMAGE2TEXT",
135
  "max_tokens": 765,
136
  "model_type": LLMType.IMAGE2TEXT.value
137
  },
138
  # ---------------------- ZhipuAI ----------------------
139
  {
140
+ "fid": factory_infos[2]["name"],
141
  "llm_name": "glm-3-turbo",
142
  "tags": "LLM,CHAT,",
143
  "max_tokens": 128 * 1000,
144
  "model_type": LLMType.CHAT.value
145
  }, {
146
+ "fid": factory_infos[2]["name"],
147
  "llm_name": "glm-4",
148
  "tags": "LLM,CHAT,",
149
  "max_tokens": 128 * 1000,
150
  "model_type": LLMType.CHAT.value
151
  }, {
152
+ "fid": factory_infos[2]["name"],
153
  "llm_name": "glm-4v",
154
  "tags": "LLM,CHAT,IMAGE2TEXT",
155
  "max_tokens": 2000,
156
  "model_type": LLMType.IMAGE2TEXT.value
157
  },
158
  {
159
+ "fid": factory_infos[2]["name"],
160
  "llm_name": "embedding-2",
161
  "tags": "TEXT EMBEDDING",
162
  "max_tokens": 512,
api/db/services/knowledgebase_service.py CHANGED
@@ -77,9 +77,12 @@ class KnowledgebaseService(CommonService):
77
  if isinstance(v, dict):
78
  assert isinstance(old[k], dict)
79
  dfs_update(old[k], v)
80
  else: old[k] = v
81
  dfs_update(m.parser_config, config)
82
- cls.update_by_id(id, m.parser_config)
83
 
84
 
85
  @classmethod
@@ -88,6 +91,6 @@ class KnowledgebaseService(CommonService):
88
  conf = {}
89
  for k in cls.get_by_ids(ids):
90
  if k.parser_config and "field_map" in k.parser_config:
91
- conf.update(k.parser_config)
92
  return conf
93
 
 
77
  if isinstance(v, dict):
78
  assert isinstance(old[k], dict)
79
  dfs_update(old[k], v)
80
+ if isinstance(v, list):
81
+ assert isinstance(old[k], list)
82
+ old[k] = list(set(old[k]+v))
83
  else: old[k] = v
84
  dfs_update(m.parser_config, config)
85
+ cls.update_by_id(id, {"parser_config": m.parser_config})
86
 
87
 
88
  @classmethod
 
91
  conf = {}
92
  for k in cls.get_by_ids(ids):
93
  if k.parser_config and "field_map" in k.parser_config:
94
+ conf.update(k.parser_config["field_map"])
95
  return conf
96
 
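
dfs_update() now merges lists by set union. One caution: the new list branch is an `if` that follows the dict branch, so after recursing into a nested dict, control still reaches the final `else` and overwrites the merged value; an `elif` chain avoids that. A standalone sketch in the elif form (the sample keys are illustrative only):

```python
def dfs_update(old, new):
    # Merge 'new' into 'old': dicts recurse, lists union, scalars overwrite.
    for k, v in new.items():
        if k not in old:
            old[k] = v
        elif isinstance(v, dict):
            assert isinstance(old[k], dict)
            dfs_update(old[k], v)
        elif isinstance(v, list):
            assert isinstance(old[k], list)
            old[k] = list(set(old[k] + v))   # union; original order not kept
        else:
            old[k] = v

cfg = {"field_map": {"a": "A"}, "pages": [1, 2]}
dfs_update(cfg, {"field_map": {"b": "B"}, "pages": [2, 3]})
print(cfg)   # field_map gains 'b'; pages becomes the union of [1, 2] and [2, 3]
```
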
api/settings.py CHANGED
@@ -43,12 +43,14 @@ REQUEST_MAX_WAIT_SEC = 300
43
 
44
  USE_REGISTRY = get_base_config("use_registry")
45
 
46
- LLM = get_base_config("llm", {})
47
- CHAT_MDL = LLM.get("chat_model", "gpt-3.5-turbo")
48
- EMBEDDING_MDL = LLM.get("embedding_model", "text-embedding-ada-002")
49
- ASR_MDL = LLM.get("asr_model", "whisper-1")
50
  PARSERS = LLM.get("parsers", "general:General,qa:Q&A,resume:Resume,naive:Naive,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
51
- IMAGE2TEXT_MDL = LLM.get("image2text_model", "gpt-4-vision-preview")
52
 
53
  # distribution
54
  DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
 
43
 
44
  USE_REGISTRY = get_base_config("use_registry")
45
 
46
+ LLM = get_base_config("user_default_llm", {})
47
+ LLM_FACTORY=LLM.get("factory", "通义千问")
48
+ CHAT_MDL = LLM.get("chat_model", "qwen-plus")
49
+ EMBEDDING_MDL = LLM.get("embedding_model", "text-embedding-v2")
50
+ ASR_MDL = LLM.get("asr_model", "paraformer-realtime-8k-v1")
51
+ IMAGE2TEXT_MDL = LLM.get("image2text_model", "qwen-vl-max")
52
+ API_KEY = LLM.get("api_key", "infiniflow API Key")
53
  PARSERS = LLM.get("parsers", "general:General,qa:Q&A,resume:Resume,naive:Naive,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
 
54
 
55
  # distribution
56
  DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
api/utils/file_utils.py CHANGED
@@ -164,10 +164,10 @@ def thumbnail(filename, blob):
164
  buffered = BytesIO()
165
  Image.frombytes("RGB", [pix.width, pix.height],
166
  pix.samples).save(buffered, format="png")
167
- return "data:image/png;base64," + base64.b64encode(buffered.getvalue())
168
 
169
  if re.match(r".*\.(jpg|jpeg|png|tif|gif|icon|ico|webp)$", filename):
170
- return ("data:image/%s;base64,"%filename.split(".")[-1]) + base64.b64encode(Image.open(BytesIO(blob)).thumbnail((30, 30)).tobytes())
171
 
172
  if re.match(r".*\.(ppt|pptx)$", filename):
173
  import aspose.slides as slides
@@ -176,7 +176,7 @@ def thumbnail(filename, blob):
176
  with slides.Presentation(BytesIO(blob)) as presentation:
177
  buffered = BytesIO()
178
  presentation.slides[0].get_thumbnail(0.03, 0.03).save(buffered, drawing.imaging.ImageFormat.png)
179
- return "data:image/png;base64," + base64.b64encode(buffered.getvalue())
180
  except Exception as e:
181
  pass
182
 
 
164
  buffered = BytesIO()
165
  Image.frombytes("RGB", [pix.width, pix.height],
166
  pix.samples).save(buffered, format="png")
167
+ return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8")
168
 
169
  if re.match(r".*\.(jpg|jpeg|png|tif|gif|icon|ico|webp)$", filename):
170
+ return ("data:image/%s;base64,"%filename.split(".")[-1]) + base64.b64encode(Image.open(BytesIO(blob)).thumbnail((30, 30)).tobytes()).decode("utf-8")
171
 
172
  if re.match(r".*\.(ppt|pptx)$", filename):
173
  import aspose.slides as slides
 
176
  with slides.Presentation(BytesIO(blob)) as presentation:
177
  buffered = BytesIO()
178
  presentation.slides[0].get_thumbnail(0.03, 0.03).save(buffered, drawing.imaging.ImageFormat.png)
179
+ return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8")
180
  except Exception as e:
181
  pass
182
 
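
Both fixes in this file address the same pitfall: base64.b64encode() returns bytes, so splicing it into a data URI string requires .decode("utf-8"). One pitfall remains in the jpg/png branch: PIL's Image.thumbnail() resizes in place and returns None, so chaining .tobytes() after it raises at runtime. A sketch of both points (Pillow assumed):

```python
import base64
from io import BytesIO
from PIL import Image

def to_data_uri(png_bytes: bytes) -> str:
    # b64encode returns bytes; the decode is exactly what this patch adds.
    return "data:image/png;base64," + base64.b64encode(png_bytes).decode("utf-8")

def image_thumbnail(blob: bytes) -> str:
    img = Image.open(BytesIO(blob))
    img.thumbnail((30, 30))          # in-place; returns None, so don't chain it
    buffered = BytesIO()
    img.save(buffered, format="PNG")
    return to_data_uri(buffered.getvalue())
```
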
conf/mapping.json CHANGED
@@ -118,11 +118,45 @@
118
  },
119
  {
120
  "dense_vector": {
121
- "match": "*_vec",
122
  "mapping": {
123
  "type": "dense_vector",
124
  "index": true,
125
- "similarity": "cosine"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  }
127
  }
128
  },
 
118
  },
119
  {
120
  "dense_vector": {
121
+ "match": "*_512_vec",
122
  "mapping": {
123
  "type": "dense_vector",
124
  "index": true,
125
+ "similarity": "cosine",
126
+ "dims": 512
127
+ }
128
+ }
129
+ },
130
+ {
131
+ "dense_vector": {
132
+ "match": "*_768_vec",
133
+ "mapping": {
134
+ "type": "dense_vector",
135
+ "index": true,
136
+ "similarity": "cosine",
137
+ "dims": 768
138
+ }
139
+ }
140
+ },
141
+ {
142
+ "dense_vector": {
143
+ "match": "*_1024_vec",
144
+ "mapping": {
145
+ "type": "dense_vector",
146
+ "index": true,
147
+ "similarity": "cosine",
148
+ "dims": 1024
149
+ }
150
+ }
151
+ },
152
+ {
153
+ "dense_vector": {
154
+ "match": "*_1536_vec",
155
+ "mapping": {
156
+ "type": "dense_vector",
157
+ "index": true,
158
+ "similarity": "cosine",
159
+ "dims": 1536
160
  }
161
  }
162
  },
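
The four dense_vector dynamic templates differ only in the dims suffix and could be generated mechanically. A sketch, assuming vector fields follow the '*_<dims>_vec' naming used in this mapping:

```python
import json

def vector_templates(dims_list=(512, 768, 1024, 1536)):
    return [{
        "dense_vector": {
            "match": f"*_{dims}_vec",
            "mapping": {"type": "dense_vector", "index": True,
                        "similarity": "cosine", "dims": dims},
        }
    } for dims in dims_list]

print(json.dumps(vector_templates(), indent=2))
```
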
conf/service_conf.yaml CHANGED
@@ -11,7 +11,7 @@ permission:
11
  dataset: false
12
  ragflow:
13
  # you must set real ip address, 127.0.0.1 and 0.0.0.0 is not supported
14
- host: 127.0.0.1
15
  http_port: 9380
16
  database:
17
  name: 'rag_flow'
@@ -21,6 +21,19 @@ database:
21
  port: 5455
22
  max_connections: 100
23
  stale_timeout: 30
24
  oauth:
25
  github:
26
  client_id: 302129228f0d96055bee
 
11
  dataset: false
12
  ragflow:
13
  # you must set real ip address, 127.0.0.1 and 0.0.0.0 is not supported
14
+ host: 0.0.0.0
15
  http_port: 9380
16
  database:
17
  name: 'rag_flow'
 
21
  port: 5455
22
  max_connections: 100
23
  stale_timeout: 30
24
+ minio:
25
+ user: 'rag_flow'
26
+ passwd: 'infini_rag_flow'
27
+ host: '123.60.95.134:9000'
28
+ es:
29
+ hosts: 'http://123.60.95.134:9200'
30
+ user_default_llm:
31
+ factory: '通义千问'
32
+ chat_model: 'qwen-plus'
33
+ embedding_model: 'text-embedding-v2'
34
+ asr_model: 'paraformer-realtime-8k-v1'
35
+ image2text_model: 'qwen-vl-max'
36
+ api_key: 'sk-xxxxxxxxxxxxx'
37
  oauth:
38
  github:
39
  client_id: 302129228f0d96055bee
rag/app/book.py CHANGED
@@ -39,6 +39,11 @@ class Pdf(HuParser):
39
 
40
 
41
  def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
42
  doc = {
43
  "docnm_kwd": filename,
44
  "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
 
39
 
40
 
41
  def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
42
+ """
43
+ Supported file formats are docx, pdf, txt.
44
+ Since a book is long and not all the parts are useful, if it's a PDF,
45
+ please set up the page ranges for every book in order to eliminate negative effects and save computing time.
46
+ """
47
  doc = {
48
  "docnm_kwd": filename,
49
  "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
rag/app/laws.py CHANGED
@@ -2,7 +2,6 @@ import copy
2
  import re
3
  from io import BytesIO
4
  from docx import Document
5
- import numpy as np
6
  from rag.parser import bullets_category, is_english, tokenize, remove_contents_table, hierarchical_merge, \
7
  make_colon_as_title
8
  from rag.nlp import huqie
@@ -59,6 +58,9 @@ class Pdf(HuParser):
59
 
60
 
61
  def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
62
  doc = {
63
  "docnm_kwd": filename,
64
  "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
 
2
  import re
3
  from io import BytesIO
4
  from docx import Document
 
5
  from rag.parser import bullets_category, is_english, tokenize, remove_contents_table, hierarchical_merge, \
6
  make_colon_as_title
7
  from rag.nlp import huqie
 
58
 
59
 
60
  def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
61
+ """
62
+ Supported file formats are docx, pdf, txt.
63
+ """
64
  doc = {
65
  "docnm_kwd": filename,
66
  "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
rag/app/manual.py CHANGED
@@ -58,8 +58,10 @@ class Pdf(HuParser):
58
 
59
 
60
  def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
61
  pdf_parser = None
62
- paper = {}
63
 
64
  if re.search(r"\.pdf$", filename, re.IGNORECASE):
65
  pdf_parser = Pdf()
 
58
 
59
 
60
  def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
61
+ """
62
+ Only pdf is supported.
63
+ """
64
  pdf_parser = None
 
65
 
66
  if re.search(r"\.pdf$", filename, re.IGNORECASE):
67
  pdf_parser = Pdf()
rag/app/naive.py CHANGED
@@ -6,6 +6,7 @@ from rag.nlp import huqie
6
  from rag.parser.pdf_parser import HuParser
7
  from rag.settings import cron_logger
8
 
 
9
  class Pdf(HuParser):
10
  def __call__(self, filename, binary=None, from_page=0,
11
  to_page=100000, zoomin=3, callback=None):
@@ -20,12 +21,18 @@ class Pdf(HuParser):
20
  start = timer()
21
  self._layouts_paddle(zoomin)
22
  callback(0.77, "Layout analysis finished")
23
- cron_logger.info("paddle layouts:".format((timer()-start)/(self.total_page+0.1)))
24
  self._naive_vertical_merge()
25
  return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes]
26
 
27
 
28
  def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
29
  doc = {
30
  "docnm_kwd": filename,
31
  "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
@@ -41,24 +48,26 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
41
  elif re.search(r"\.pdf$", filename, re.IGNORECASE):
42
  pdf_parser = Pdf()
43
  sections = pdf_parser(filename if not binary else binary,
44
- from_page=from_page, to_page=to_page, callback=callback)
45
  elif re.search(r"\.txt$", filename, re.IGNORECASE):
46
  callback(0.1, "Start to parse.")
47
  txt = ""
48
- if binary:txt = binary.decode("utf-8")
 
49
  else:
50
  with open(filename, "r") as f:
51
  while True:
52
  l = f.readline()
53
- if not l:break
54
  txt += l
55
  sections = txt.split("\n")
56
- sections = [(l,"") for l in sections if l]
57
  callback(0.8, "Finish parsing.")
58
- else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")
 
59
 
60
- parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimer": "\n。;!?"})
61
- cks = naive_merge(sections, parser_config["chunk_token_num"], parser_config["delimer"])
62
  eng = is_english(cks)
63
  res = []
64
  # wrap up to es documents
@@ -75,6 +84,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
75
 
76
  if __name__ == "__main__":
77
  import sys
 
 
78
  def dummy(a, b):
79
  pass
 
 
80
  chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
 
6
  from rag.parser.pdf_parser import HuParser
7
  from rag.settings import cron_logger
8
 
9
+
10
  class Pdf(HuParser):
11
  def __call__(self, filename, binary=None, from_page=0,
12
  to_page=100000, zoomin=3, callback=None):
 
21
  start = timer()
22
  self._layouts_paddle(zoomin)
23
  callback(0.77, "Layout analysis finished")
24
+ cron_logger.info("paddle layouts:".format((timer() - start) / (self.total_page + 0.1)))
25
  self._naive_vertical_merge()
26
  return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes]
27
 
28
 
29
  def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
30
+ """
31
+ Supported file formats are docx, pdf, txt.
32
+ This method applies the naive way to chunk files.
33
+ Successive text will be sliced into pieces using 'delimiter'.
34
+ Next, these pieces are merged into chunks whose token count is no more than 'Max token number'.
35
+ """
36
  doc = {
37
  "docnm_kwd": filename,
38
  "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
 
48
  elif re.search(r"\.pdf$", filename, re.IGNORECASE):
49
  pdf_parser = Pdf()
50
  sections = pdf_parser(filename if not binary else binary,
51
+ from_page=from_page, to_page=to_page, callback=callback)
52
  elif re.search(r"\.txt$", filename, re.IGNORECASE):
53
  callback(0.1, "Start to parse.")
54
  txt = ""
55
+ if binary:
56
+ txt = binary.decode("utf-8")
57
  else:
58
  with open(filename, "r") as f:
59
  while True:
60
  l = f.readline()
61
+ if not l: break
62
  txt += l
63
  sections = txt.split("\n")
64
+ sections = [(l, "") for l in sections if l]
65
  callback(0.8, "Finish parsing.")
66
+ else:
67
+ raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")
68
 
69
+ parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimiter": "\n!?。;!?"})
70
+ cks = naive_merge(sections, parser_config["chunk_token_num"], parser_config["delimiter"])
71
  eng = is_english(cks)
72
  res = []
73
  # wrap up to es documents
 
84
 
85
  if __name__ == "__main__":
86
  import sys
87
+
88
+
89
  def dummy(a, b):
90
  pass
91
+
92
+
93
  chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
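
For reference, a sketch of what naive_merge(sections, chunk_token_num, delimiter) plausibly does: split on any delimiter character, then greedily pack the pieces into chunks under the token budget. num_tokens() here is a stand-in, not the repo's tokenizer:

```python
import re

def num_tokens(s):                     # rough stand-in for num_tokens_from_string
    return max(1, len(s) // 2)

def naive_merge(sections, chunk_token_num=128, delimiter="\n!?。;!?"):
    pieces = []
    for text, _tag in sections:
        pieces.extend(p for p in re.split("[%s]" % re.escape(delimiter), text)
                      if p.strip())
    chunks, buf, tk = [], [], 0
    for p in pieces:
        t = num_tokens(p)
        if buf and tk + t > chunk_token_num:
            chunks.append(" ".join(buf))   # budget exceeded: flush current chunk
            buf, tk = [], 0
        buf.append(p)
        tk += t
    if buf:
        chunks.append(" ".join(buf))
    return chunks

print(naive_merge([("第一段。第二段。第三段。", "")], chunk_token_num=2))
```
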
rag/app/paper.py CHANGED
@@ -129,6 +129,10 @@ class Pdf(HuParser):
129
 
130
 
131
  def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
132
  pdf_parser = None
133
  if re.search(r"\.pdf$", filename, re.IGNORECASE):
134
  pdf_parser = Pdf()
 
129
 
130
 
131
  def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
132
+ """
133
+ Only pdf is supported.
134
+ The abstract of the paper is kept as one whole chunk and will not be split.
135
+ """
136
  pdf_parser = None
137
  if re.search(r"\.pdf$", filename, re.IGNORECASE):
138
  pdf_parser = Pdf()
rag/app/presentation.py CHANGED
@@ -94,6 +94,11 @@ class Pdf(HuParser):
94
 
95
 
96
  def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
97
  doc = {
98
  "docnm_kwd": filename,
99
  "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
 
94
 
95
 
96
  def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
97
+ """
98
+ The supported file formats are pdf, pptx.
99
+ Every page will be treated as a chunk, and the thumbnail of every page will be stored.
100
+ PPT files are parsed by this method automatically; setting up every PPT file individually is not necessary.
101
+ """
102
  doc = {
103
  "docnm_kwd": filename,
104
  "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
rag/app/qa.py CHANGED
@@ -70,7 +70,17 @@ def beAdoc(d, q, a, eng):
70
 
71
 
72
  def chunk(filename, binary=None, callback=None, **kwargs):
73
74
  res = []
75
  if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
76
  callback(0.1, "Start to parse.")
 
70
 
71
 
72
  def chunk(filename, binary=None, callback=None, **kwargs):
73
+ """
74
+ Excel and csv(txt) format files are supported.
75
+ If the file is in excel format, there should be two columns, question and answer, without headers.
76
+ The question column precedes the answer column.
77
+ Multiple sheets are acceptable as long as the columns are correctly composed.
78
 
79
+ If it's in csv format, it should be UTF-8 encoded. Use TAB as delimiter to separate question and answer.
80
+
81
+ All malformed lines will be ignored.
82
+ Every pair of Q&A will be treated as a chunk.
83
+ """
84
  res = []
85
  if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
86
  callback(0.1, "Start to parse.")
rag/app/resume.py CHANGED
@@ -4,24 +4,34 @@ import os
4
  import re
5
  import requests
6
  from api.db.services.knowledgebase_service import KnowledgebaseService
 
7
  from rag.nlp import huqie
8
 
9
  from rag.settings import cron_logger
10
  from rag.utils import rmSpace
11
 
 
 
 
12
 
13
  def chunk(filename, binary=None, callback=None, **kwargs):
14
  if not re.search(r"\.(pdf|doc|docx|txt)$", filename, flags=re.IGNORECASE):
15
  raise NotImplementedError("file type not supported yet(pdf supported)")
16
 
17
  url = os.environ.get("INFINIFLOW_SERVER")
18
- if not url:
19
- raise EnvironmentError(
20
- "Please set environment variable: 'INFINIFLOW_SERVER'")
21
  token = os.environ.get("INFINIFLOW_TOKEN")
22
- if not token:
23
- raise EnvironmentError(
24
- "Please set environment variable: 'INFINIFLOW_TOKEN'")
 
25
 
26
  if not binary:
27
  with open(filename, "rb") as f:
@@ -44,22 +54,28 @@ def chunk(filename, binary=None, callback=None, **kwargs):
44
 
45
  callback(0.2, "Resume parsing is going on...")
46
  resume = remote_call()
47
  callback(0.6, "Done parsing. Chunking...")
48
  print(json.dumps(resume, ensure_ascii=False, indent=2))
49
 
50
  field_map = {
51
  "name_kwd": "姓名/名字",
 
52
  "gender_kwd": "性别(男,女)",
53
  "age_int": "年龄/岁/年纪",
54
  "phone_kwd": "电话/手机/微信",
55
  "email_tks": "email/e-mail/邮箱",
56
  "position_name_tks": "职位/职能/岗位/职责",
57
- "expect_position_name_tks": "期望职位/期望职能/期望岗位",
 
 
58
 
59
- "hightest_degree_kwd": "最高学历(高中,职高,硕士,本科,博士,初中,中技,中专,专科,专升本,MPA,MBA,EMBA)",
60
  "first_degree_kwd": "第一学历(高中,职高,硕士,本科,博士,初中,中技,中专,专科,专升本,MPA,MBA,EMBA)",
 
61
  "first_major_tks": "第一学历专业",
62
- "first_school_name_tks": "第一学历毕业学校",
63
  "edu_first_fea_kwd": "第一学历标签(211,留学,双一流,985,海外知名,重点大学,中专,专升本,专科,本科,大专)",
64
 
65
  "degree_kwd": "过往学历(高中,职高,硕士,本科,博士,初中,中技,中专,专科,专升本,MPA,MBA,EMBA)",
@@ -68,14 +84,14 @@ def chunk(filename, binary=None, callback=None, **kwargs):
68
  "sch_rank_kwd": "学校标签(顶尖学校,精英学校,优质学校,一般学校)",
69
  "edu_fea_kwd": "教育标签(211,留学,双一流,985,海外知名,重点大学,中专,专升本,专科,本科,大专)",
70
 
71
- "work_exp_flt": "工作年限/工作年份/N年经验/毕业了多少年",
72
- "birth_dt": "生日/出生年份",
73
  "corp_nm_tks": "就职过的公司/之前的公司/上过班的公司",
74
- "corporation_name_tks": "最近就职(上班)的公司/上一家公司",
75
  "edu_end_int": "毕业年份",
76
- "expect_city_names_tks": "期望城市",
77
- "industry_name_tks": "所在行业"
 
 
78
  }
 
79
  titles = []
80
  for n in ["name_kwd", "gender_kwd", "position_name_tks", "age_int"]:
81
  v = resume.get(n, "")
@@ -105,6 +121,10 @@ def chunk(filename, binary=None, callback=None, **kwargs):
105
  doc["content_ltks"] = huqie.qie(doc["content_with_weight"])
106
  doc["content_sm_ltks"] = huqie.qieqie(doc["content_ltks"])
107
  for n, _ in field_map.items():
108
  doc[n] = resume[n]
109
 
110
  print(doc)
 
4
  import re
5
  import requests
6
  from api.db.services.knowledgebase_service import KnowledgebaseService
7
+ from api.settings import stat_logger
8
  from rag.nlp import huqie
9
 
10
  from rag.settings import cron_logger
11
  from rag.utils import rmSpace
12
 
13
+ forbidden_select_fields4resume = [
14
+ "name_pinyin_kwd", "edu_first_fea_kwd", "degree_kwd", "sch_rank_kwd", "edu_fea_kwd"
15
+ ]
16
 
17
  def chunk(filename, binary=None, callback=None, **kwargs):
18
+ """
19
+ The supported file formats are pdf, docx and txt.
20
+ To maximize the effectiveness, parse the resume correctly,
21
+ please visit https://github.com/infiniflow/ragflow, and sign in the our demo web-site
22
+ to get token. It's FREE!
23
+ Set INFINIFLOW_SERVER and INFINIFLOW_TOKEN in '.env' file or
24
+ using 'export' to set both environment variables: INFINIFLOW_SERVER and INFINIFLOW_TOKEN in docker container.
25
+ """
26
  if not re.search(r"\.(pdf|doc|docx|txt)$", filename, flags=re.IGNORECASE):
27
  raise NotImplementedError("file type not supported yet(pdf supported)")
28
 
29
  url = os.environ.get("INFINIFLOW_SERVER")
30
  token = os.environ.get("INFINIFLOW_TOKEN")
31
+ if not url or not token:
32
+ stat_logger.warning(
33
+ "INFINIFLOW_SERVER is not specified. To maximize the effectiveness, please visit https://github.com/infiniflow/ragflow, and sign in the our demo web site to get token. It's FREE! Using 'export' to set both environment variables: INFINIFLOW_SERVER and INFINIFLOW_TOKEN.")
34
+ return []
35
 
36
  if not binary:
37
  with open(filename, "rb") as f:
 
54
 
55
  callback(0.2, "Resume parsing is going on...")
56
  resume = remote_call()
57
+ if len(resume.keys()) < 7:
58
+ callback(-1, "Resume is not successfully parsed.")
59
+ return []
60
  callback(0.6, "Done parsing. Chunking...")
61
  print(json.dumps(resume, ensure_ascii=False, indent=2))
62
 
63
  field_map = {
64
  "name_kwd": "姓名/名字",
65
+ "name_pinyin_kwd": "姓名拼音/名字拼音",
66
  "gender_kwd": "性别(男,女)",
67
  "age_int": "年龄/岁/年纪",
68
  "phone_kwd": "电话/手机/微信",
69
  "email_tks": "email/e-mail/邮箱",
70
  "position_name_tks": "职位/职能/岗位/职责",
71
+ "expect_city_names_tks": "期望城市",
72
+ "work_exp_flt": "工作年限/工作年份/N年经验/毕业了多少年",
73
+ "corporation_name_tks": "最近就职(上班)的公司/上一家公司",
74
 
75
+ "first_school_name_tks": "第一学历毕业学校",
76
  "first_degree_kwd": "第一学历(高中,职高,硕士,本科,博士,初中,中技,中专,专科,专升本,MPA,MBA,EMBA)",
77
+ "highest_degree_kwd": "最高学历(高中,职高,硕士,本科,博士,初中,中技,中专,专科,专升本,MPA,MBA,EMBA)",
78
  "first_major_tks": "第一学历专业",
 
79
  "edu_first_fea_kwd": "第一学历标签(211,留学,双一流,985,海外知名,重点大学,中专,专升本,专科,本科,大专)",
80
 
81
  "degree_kwd": "过往学历(高中,职高,硕士,本科,博士,初中,中技,中专,专科,专升本,MPA,MBA,EMBA)",
 
84
  "sch_rank_kwd": "学校标签(顶尖学校,精英学校,优质学校,一般学校)",
85
  "edu_fea_kwd": "教育标签(211,留学,双一流,985,海外知名,重点大学,中专,专升本,专科,本科,大专)",
86
 
 
 
87
  "corp_nm_tks": "就职过的公司/之前的公司/上过班的公司",
 
88
  "edu_end_int": "毕业年份",
89
+ "industry_name_tks": "所在行业",
90
+
91
+ "birth_dt": "生日/出生年份",
92
+ "expect_position_name_tks": "期望职位/期望职能/期望岗位",
93
  }
94
+
95
  titles = []
96
  for n in ["name_kwd", "gender_kwd", "position_name_tks", "age_int"]:
97
  v = resume.get(n, "")
 
121
  doc["content_ltks"] = huqie.qie(doc["content_with_weight"])
122
  doc["content_sm_ltks"] = huqie.qieqie(doc["content_ltks"])
123
  for n, _ in field_map.items():
124
+ if n not in resume:continue
125
+ if isinstance(resume[n], list) and (len(resume[n]) == 1 or n not in forbidden_select_fields4resume):
126
+ resume[n] = resume[n][0]
127
+ if n.find("_tks")>0: resume[n] = huqie.qieqie(resume[n])
128
  doc[n] = resume[n]
129
 
130
  print(doc)
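
The new post-processing loop flattens single-valued lists, keeps the forbidden multi-valued fields as lists, and re-tokenizes *_tks fields. A standalone sketch, with qieqie() standing in for huqie.qieqie:

```python
forbidden_select_fields4resume = [
    "name_pinyin_kwd", "edu_first_fea_kwd", "degree_kwd", "sch_rank_kwd", "edu_fea_kwd"]

def qieqie(s):                      # stand-in for the fine-grained tokenizer
    return " ".join(s.split())

def normalize_fields(resume, field_map):
    doc = {}
    for n in field_map:
        if n not in resume:
            continue
        v = resume[n]
        # Flatten lists unless the field is deliberately multi-valued.
        if isinstance(v, list) and (len(v) == 1 or n not in forbidden_select_fields4resume):
            v = v[0]
        if n.find("_tks") > 0:      # tokenized fields get re-segmented
            v = qieqie(v)
        doc[n] = v
    return doc

print(normalize_fields({"name_kwd": ["张三"], "degree_kwd": ["本科", "硕士"]},
                       {"name_kwd": "姓名", "degree_kwd": "过往学历"}))
```
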
rag/app/table.py CHANGED
@@ -100,7 +100,20 @@ def column_data_type(arr):
100
 
101
 
102
  def chunk(filename, binary=None, callback=None, **kwargs):
103
- dfs = []
104
  if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
105
  callback(0.1, "Start to parse.")
106
  excel_parser = Excel()
@@ -155,7 +168,7 @@ def chunk(filename, binary=None, callback=None, **kwargs):
155
  del df[n]
156
  clmns = df.columns.values
157
  txts = list(copy.deepcopy(clmns))
158
- py_clmns = [PY.get_pinyins(n)[0].replace("-", "_") for n in clmns]
159
  clmn_tys = []
160
  for j in range(len(clmns)):
161
  cln, ty = column_data_type(df[clmns[j]])
 
100
 
101
 
102
  def chunk(filename, binary=None, callback=None, **kwargs):
103
+ """
104
+ Excel and csv(txt) format files are supported.
105
+ For csv or txt file, the delimiter between columns is TAB.
106
+ The first line must be column headers.
107
+ Column headers must be meaningful terms in order for our NLP model to understand them.
108
+ It's good to enumerate some synonyms, separated by slash '/', and even better to
109
+ enumerate values using brackets like 'gender/sex(male, female)'.
110
+ Here are some examples for headers:
111
+ 1. supplier/vendor\tcolor(yellow, red, brown)\tgender/sex(male, female)\tsize(M,L,XL,XXL)
112
+ 2. 姓名/名字\t电话/手机/微信\t最高学历(高中,职高,硕士,本科,博士,初中,中技,中专,专科,专升本,MPA,MBA,EMBA)
113
+
114
+ Every row in the table will be treated as a chunk.
115
+ """
116
+
117
  if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
118
  callback(0.1, "Start to parse.")
119
  excel_parser = Excel()
 
168
  del df[n]
169
  clmns = df.columns.values
170
  txts = list(copy.deepcopy(clmns))
171
+ py_clmns = [PY.get_pinyins(re.sub(r"(/.*|([^()]+?)|\([^()]+?\))", "", n), '_')[0] for n in clmns]
172
  clmn_tys = []
173
  for j in range(len(clmns)):
174
  cln, ty = column_data_type(df[clmns[j]])
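
The pinyin column naming now strips the synonym tail after '/' and any parenthesized enumeration (ASCII or full-width) before transliterating the header. A sketch of that clean-up using a simplified version of the patch's regex:

```python
import re

def clean_header(h):
    # Drop '/...' synonym tails, '(...)' and full-width '(...)' enumerations.
    return re.sub(r"(/.*|\([^()]*\)|（[^（）]*）)", "", h)

print(clean_header("gender/sex(male, female)"))   # -> gender
print(clean_header("最高学历（高中，硕士，本科）"))    # -> 最高学历
print(clean_header("size(M,L,XL,XXL)"))           # -> size
```
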
rag/llm/__init__.py CHANGED
@@ -21,7 +21,7 @@ from .cv_model import *
21
  EmbeddingModel = {
22
  "Infiniflow": HuEmbedding,
23
  "OpenAI": OpenAIEmbed,
24
- "通义千问": QWenEmbed,
25
  }
26
 
27
 
 
21
  EmbeddingModel = {
22
  "Infiniflow": HuEmbedding,
23
  "OpenAI": OpenAIEmbed,
24
+ "通义千问": HuEmbedding, #QWenEmbed,
25
  }
26
 
27
 
rag/llm/chat_model.py CHANGED
@@ -32,7 +32,7 @@ class GptTurbo(Base):
32
  self.model_name = model_name
33
 
34
  def chat(self, system, history, gen_conf):
35
- history.insert(0, {"role": "system", "content": system})
36
  res = self.client.chat.completions.create(
37
  model=self.model_name,
38
  messages=history,
@@ -49,11 +49,12 @@ class QWenChat(Base):
49
 
50
  def chat(self, system, history, gen_conf):
51
  from http import HTTPStatus
52
- history.insert(0, {"role": "system", "content": system})
53
  response = Generation.call(
54
  self.model_name,
55
  messages=history,
56
- result_format='message'
 
57
  )
58
  if response.status_code == HTTPStatus.OK:
59
  return response.output.choices[0]['message']['content'], response.usage.output_tokens
@@ -68,10 +69,11 @@ class ZhipuChat(Base):
68
 
69
  def chat(self, system, history, gen_conf):
70
  from http import HTTPStatus
71
- history.insert(0, {"role": "system", "content": system})
72
  response = self.client.chat.completions.create(
73
  self.model_name,
74
- messages=history
 
75
  )
76
  if response.status_code == HTTPStatus.OK:
77
  return response.output.choices[0]['message']['content'], response.usage.completion_tokens
 
32
  self.model_name = model_name
33
 
34
  def chat(self, system, history, gen_conf):
35
+ if system: history.insert(0, {"role": "system", "content": system})
36
  res = self.client.chat.completions.create(
37
  model=self.model_name,
38
  messages=history,
 
49
 
50
  def chat(self, system, history, gen_conf):
51
  from http import HTTPStatus
52
+ if system: history.insert(0, {"role": "system", "content": system})
53
  response = Generation.call(
54
  self.model_name,
55
  messages=history,
56
+ result_format='message',
57
+ **gen_conf
58
  )
59
  if response.status_code == HTTPStatus.OK:
60
  return response.output.choices[0]['message']['content'], response.usage.output_tokens
 
69
 
70
  def chat(self, system, history, gen_conf):
71
  from http import HTTPStatus
72
+ if system: history.insert(0, {"role": "system", "content": system})
73
  response = self.client.chat.completions.create(
74
  self.model_name,
75
+ messages=history,
76
+ **gen_conf
77
  )
78
  if response.status_code == HTTPStatus.OK:
79
  return response.output.choices[0]['message']['content'], response.usage.completion_tokens
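
All three chat() implementations gain the same two changes: the system prompt is prepended only when one is given (the API-key probe above calls mdl.chat(None, ...)), and gen_conf is now actually forwarded to the backend via **gen_conf. The shared guard, as a sketch:

```python
def build_history(system, history):
    msgs = list(history)
    if system:   # skip a null/empty system turn instead of inserting it
        msgs.insert(0, {"role": "system", "content": system})
    return msgs

print(build_history(None, [{"role": "user", "content": "hi"}]))
print(build_history("You are a DBA.", [{"role": "user", "content": "hi"}]))
```

Note that the originals mutate the caller's history in place via insert(0, ...); copying first, as above, avoids the system turn accumulating across repeated calls.
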
rag/llm/embedding_model.py CHANGED
@@ -100,11 +100,11 @@ class QWenEmbed(Base):
100
  input=texts[i:i+batch_size],
101
  text_type="document"
102
  )
103
- embds = [[]] * len(resp["output"]["embeddings"])
104
  for e in resp["output"]["embeddings"]:
105
  embds[e["text_index"]] = e["embedding"]
106
  res.extend(embds)
107
- token_count += resp["usage"]["input_tokens"]
108
  return np.array(res), token_count
109
 
110
  def encode_queries(self, text):
@@ -113,7 +113,7 @@ class QWenEmbed(Base):
113
  input=text[:2048],
114
  text_type="query"
115
  )
116
- return np.array(resp["output"]["embeddings"][0]["embedding"]), resp["usage"]["input_tokens"]
117
 
118
 
119
  from zhipuai import ZhipuAI
 
100
  input=texts[i:i+batch_size],
101
  text_type="document"
102
  )
103
+ embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
104
  for e in resp["output"]["embeddings"]:
105
  embds[e["text_index"]] = e["embedding"]
106
  res.extend(embds)
107
+ token_count += resp["usage"]["total_tokens"]
108
  return np.array(res), token_count
109
 
110
  def encode_queries(self, text):
 
113
  input=text[:2048],
114
  text_type="query"
115
  )
116
+ return np.array(resp["output"]["embeddings"][0]["embedding"]), resp["usage"]["total_tokens"]
117
 
118
 
119
  from zhipuai import ZhipuAI
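
Why the patch replaces `embds = [[]] * len(...)` with a comprehension: multiplying a one-element list creates N references to the same inner list, so in-place mutation leaks across slots. The loop here only reassigns slots, so the old form happened to work, but the comprehension removes the trap:

```python
shared = [[]] * 3
shared[0].append("x")
print(shared)                 # [['x'], ['x'], ['x']]  -- three aliases of one list

fresh = [[] for _ in range(3)]
fresh[0].append("x")
print(fresh)                  # [['x'], [], []]
```
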
rag/nlp/search.py CHANGED
@@ -92,7 +92,7 @@ class Dealer:
92
  assert emb_mdl, "No embedding model selected"
93
  s["knn"] = self._vector(
94
  qst, emb_mdl, req.get(
95
- "similarity", 0.4), ps)
96
  s["knn"]["filter"] = bqry.to_dict()
97
  if "highlight" in s:
98
  del s["highlight"]
@@ -106,7 +106,7 @@ class Dealer:
106
  bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
107
  s["query"] = bqry.to_dict()
108
  s["knn"]["filter"] = bqry.to_dict()
109
- s["knn"]["similarity"] = 0.7
110
  res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
111
 
112
  kwds = set([])
@@ -171,7 +171,7 @@ class Dealer:
171
  continue
172
  if not isinstance(v, type("")):
173
  m[n] = str(m[n])
174
- m[n] = rmSpace(m[n])
175
 
176
  if m:
177
  res[d["id"]] = m
@@ -303,21 +303,22 @@ class Dealer:
303
 
304
  return ranks
305
 
306
- def sql_retrieval(self, sql, fetch_size=128):
307
  sql = re.sub(r"[ ]+", " ", sql)
 
 
308
  replaces = []
309
- for r in re.finditer(r" ([a-z_]+_l?tks like |[a-z_]+_l?tks ?= ?)'([^']+)'", sql):
310
- fld, v = r.group(1), r.group(2)
311
- fld = re.sub(r" ?(like|=)$", "", fld).lower()
312
- if v[0] == "%%": v = v[1:-1]
313
- match = " MATCH({}, '{}', 'operator=OR;fuzziness=AUTO:1,3;minimum_should_match=30%') ".format(fld, huqie.qie(v))
314
- replaces.append((r.group(1)+r.group(2), match))
315
 
316
- for p, r in replaces: sql.replace(p, r)
 
317
 
318
  try:
319
- tbl = self.es.sql(sql, fetch_size)
320
  return tbl
321
  except Exception as e:
322
- es_logger(f"SQL failure: {sql} =>" + str(e))
323
 
 
92
  assert emb_mdl, "No embedding model selected"
93
  s["knn"] = self._vector(
94
  qst, emb_mdl, req.get(
95
+ "similarity", 0.1), ps)
96
  s["knn"]["filter"] = bqry.to_dict()
97
  if "highlight" in s:
98
  del s["highlight"]
 
106
  bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
107
  s["query"] = bqry.to_dict()
108
  s["knn"]["filter"] = bqry.to_dict()
109
+ s["knn"]["similarity"] = 0.17
110
  res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
111
 
112
  kwds = set([])
 
171
  continue
172
  if not isinstance(v, type("")):
173
  m[n] = str(m[n])
174
+ if n.find("tks")>0: m[n] = rmSpace(m[n])
175
 
176
  if m:
177
  res[d["id"]] = m
 
303
 
304
  return ranks
305
 
306
+ def sql_retrieval(self, sql, fetch_size=128, format="json"):
307
  sql = re.sub(r"[ ]+", " ", sql)
308
+ sql = sql.replace("%", "")
309
+ es_logger.info(f"Get es sql: {sql}")
310
  replaces = []
311
+ for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
312
+ fld, v = r.group(1), r.group(3)
313
+ match = " MATCH({}, '{}', 'operator=OR;fuzziness=AUTO:1,3;minimum_should_match=30%') ".format(fld, huqie.qieqie(huqie.qie(v)))
314
+ replaces.append(("{}{}'{}'".format(r.group(1), r.group(2), r.group(3)), match))
 
 
315
 
316
+ for p, r in replaces: sql = sql.replace(p, r, 1)
317
+ es_logger.info(f"To es: {sql}")
318
 
319
  try:
320
+ tbl = self.es.sql(sql, fetch_size, format)
321
  return tbl
322
  except Exception as e:
323
+ es_logger.error(f"SQL failure: {sql} =>" + str(e))
324
 
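
sql_retrieval() now rewrites equality and LIKE predicates on tokenized *_tks fields into Elasticsearch SQL MATCH() calls, and it also fixes a silent bug: the old `sql.replace(p, r)` discarded its return value, so no replacement ever happened. A runnable sketch of the rewrite, with tokenize() standing in for huqie.qieqie(huqie.qie(...)):

```python
import re

def tokenize(v):                       # stand-in for huqie.qieqie(huqie.qie(v))
    return " ".join(v)

def rewrite(sql):
    replaces = []
    for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
        fld, v = r.group(1), r.group(3)
        match = " MATCH({}, '{}', 'operator=OR;fuzziness=AUTO:1,3;minimum_should_match=30%') ".format(
            fld, tokenize(v))
        replaces.append(("{}{}'{}'".format(r.group(1), r.group(2), r.group(3)), match))
    for p, m in replaces:
        sql = sql.replace(p, m, 1)     # reassign: str.replace returns a new string
    return sql

print(rewrite("select doc_id from t where corp_nm_tks = '华为'"))
```
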
rag/parser/pdf_parser.py CHANGED
@@ -53,9 +53,10 @@ class HuParser:
53
 
54
  def __remote_call(self, species, images, thr=0.7):
55
  url = os.environ.get("INFINIFLOW_SERVER")
56
- if not url:raise EnvironmentError("Please set environment variable: 'INFINIFLOW_SERVER'")
57
  token = os.environ.get("INFINIFLOW_TOKEN")
58
- if not token:raise EnvironmentError("Please set environment variable: 'INFINIFLOW_TOKEN'")
 
 
59
 
60
  def convert_image_to_bytes(PILimage):
61
  image = BytesIO()
 
53
 
54
  def __remote_call(self, species, images, thr=0.7):
55
  url = os.environ.get("INFINIFLOW_SERVER")
 
56
  token = os.environ.get("INFINIFLOW_TOKEN")
57
+ if not url or not token:
58
+ logging.warning("INFINIFLOW_SERVER is not specified. To maximize the effectiveness, please visit https://github.com/infiniflow/ragflow, and sign in the our demo web site to get token. It's FREE! Using 'export' to set both environment variables: INFINIFLOW_SERVER and INFINIFLOW_TOKEN.")
59
+ return []
60
 
61
  def convert_image_to_bytes(PILimage):
62
  image = BytesIO()
rag/svr/task_executor.py CHANGED
@@ -47,7 +47,7 @@ from api.utils.file_utils import get_project_base_directory
47
  BATCH_SIZE = 64
48
 
49
  FACTORY = {
50
- ParserType.GENERAL.value: laws,
51
  ParserType.PAPER.value: paper,
52
  ParserType.BOOK.value: book,
53
  ParserType.PRESENTATION.value: presentation,
@@ -119,8 +119,8 @@ def build(row, cvmdl):
119
  chunker = FACTORY[row["parser_id"].lower()]
120
  try:
121
  cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
122
- cks = chunker.chunk(row["name"], MINIO.get(row["kb_id"], row["location"]), row["from_page"], row["to_page"],
123
- callback, kb_id=row["kb_id"], parser_config=row["parser_config"])
124
  except Exception as e:
125
  if re.search("(No such file|not found)", str(e)):
126
  callback(-1, "Can not find file <%s>" % row["doc_name"])
@@ -129,7 +129,7 @@ def build(row, cvmdl):
129
 
130
  cron_logger.warn("Chunkking {}/{}: {}".format(row["location"], row["name"], str(e)))
131
 
132
- return []
133
 
134
  callback(msg="Finished slicing files. Start to embedding the content.")
135
 
@@ -211,6 +211,7 @@ def main(comm, mod):
211
 
212
  st_tm = timer()
213
  cks = build(r, cv_mdl)
 
214
  if not cks:
215
  tmf.write(str(r["update_time"]) + "\n")
216
  callback(1., "No chunk! Done!")
 
47
  BATCH_SIZE = 64
48
 
49
  FACTORY = {
50
+ ParserType.GENERAL.value: manual,
51
  ParserType.PAPER.value: paper,
52
  ParserType.BOOK.value: book,
53
  ParserType.PRESENTATION.value: presentation,
 
119
  chunker = FACTORY[row["parser_id"].lower()]
120
  try:
121
  cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
122
+ cks = chunker.chunk(row["name"], binary = MINIO.get(row["kb_id"], row["location"]), from_page=row["from_page"], to_page=row["to_page"],
123
+ callback = callback, kb_id=row["kb_id"], parser_config=row["parser_config"])
124
  except Exception as e:
125
  if re.search("(No such file|not found)", str(e)):
126
  callback(-1, "Can not find file <%s>" % row["doc_name"])
 
129
 
130
  cron_logger.warn("Chunkking {}/{}: {}".format(row["location"], row["name"], str(e)))
131
 
132
+ return
133
 
134
  callback(msg="Finished slicing files. Start to embedding the content.")
135
 
 
211
 
212
  st_tm = timer()
213
  cks = build(r, cv_mdl)
214
+ if cks is None:continue
215
  if not cks:
216
  tmf.write(str(r["update_time"]) + "\n")
217
  callback(1., "No chunk! Done!")
rag/utils/es_conn.py CHANGED
@@ -241,7 +241,7 @@ class HuEs:
241
  es_logger.error("ES search timeout for 3 times!")
242
  raise Exception("ES search timeout.")
243
 
244
- def sql(self, sql, fetch_size=128, format="json", timeout=2):
245
  for i in range(3):
246
  try:
247
  res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, request_timeout=timeout)
 
241
  es_logger.error("ES search timeout for 3 times!")
242
  raise Exception("ES search timeout.")
243
 
244
+ def sql(self, sql, fetch_size=128, format="json", timeout="2s"):
245
  for i in range(3):
246
  try:
247
  res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, request_timeout=timeout)