KevinHuSh committed
Commit a49657b · Parent: 13080d4

add self-rag (#1070)


### What problem does this PR solve?

#1069

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
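
For orientation, a minimal sketch of the control flow this PR wires into `chat()` in `api/db/services/dialog_service.py`: when a dialog's `prompt_config` has `self_rag` enabled, the first retrieval pass is graded by the chat model (`relevant()`); if it is judged irrelevant, the question is paraphrased (`rewrite()`) and retrieval runs once more. The callables below are stand-ins for the real `retrievaler.retrieval()`, `relevant()` and `rewrite()` further down in this diff, so treat it as an illustration only.

```python
# Illustrative sketch of the self-RAG gate added to chat(); the three callables
# are placeholders for retrievaler.retrieval(), relevant() and rewrite() from this PR.
def retrieve_with_self_rag(question, retrieve, relevant, rewrite, self_rag=True):
    chunks = retrieve(question)                       # first retrieval pass
    if self_rag and not relevant(question, chunks):   # grade the retrieval with the chat LLM
        question = rewrite(question)                  # expand/paraphrase the question
        chunks = retrieve(question)                   # retrieve once more with the rewritten query
    return question, chunks
```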

api/apps/api_app.py CHANGED
@@ -198,15 +198,18 @@ def completion():
         else: conv.reference[-1] = ans["reference"]
         conv.message[-1] = {"role": "assistant", "content": ans["answer"]}
 
+    def rename_field(ans):
+        for chunk_i in ans['reference'].get('chunks', []):
+            chunk_i['doc_name'] = chunk_i['docnm_kwd']
+            chunk_i.pop('docnm_kwd')
+
     def stream():
         nonlocal dia, msg, req, conv
         try:
             for ans in chat(dia, msg, True, **req):
                 fillin_conv(ans)
-                for chunk_i in ans['reference'].get('chunks', []):
-                    chunk_i['doc_name'] = chunk_i['docnm_kwd']
-                    chunk_i.pop('docnm_kwd')
-                yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
+                rename_field(ans)
+                yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
             API4ConversationService.append_message(conv.id, conv.to_dict())
         except Exception as e:
             yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
@@ -554,23 +557,24 @@ def completion_faq():
             "content": ""
         }
     ]
-    for ans in chat(dia, msg, stream=False, **req):
-        # answer = ans
-        data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
-        fillin_conv(ans)
-        API4ConversationService.append_message(conv.id, conv.to_dict())
-
-        chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
-        for chunk_idx in chunk_idxs[:1]:
-            if ans["reference"]["chunks"][chunk_idx]["img_id"]:
-                try:
-                    bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
-                    response = MINIO.get(bkt, nm)
-                    data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
-                    data.append(data_type_picture)
-                except Exception as e:
-                    return server_error_response(e)
+    ans = ""
+    for a in chat(dia, msg, stream=False, **req):
+        ans = a
         break
+    data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
+    fillin_conv(ans)
+    API4ConversationService.append_message(conv.id, conv.to_dict())
+
+    chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
+    for chunk_idx in chunk_idxs[:1]:
+        if ans["reference"]["chunks"][chunk_idx]["img_id"]:
+            try:
+                bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
+                response = MINIO.get(bkt, nm)
+                data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
+                data.append(data_type_picture)
+            except Exception as e:
+                return server_error_response(e)
 
     response = {"code": 200, "msg": "success", "data": data}
     return response
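
The streaming branch above emits Server-Sent-Events-style lines, each prefixed with `data:` and carrying a JSON object with `retcode`, `retmsg` and `data`; reference chunks now expose `doc_name` instead of `docnm_kwd` via `rename_field()`. A hypothetical client might consume it like this (the endpoint URL, port and auth header are placeholders, not part of this diff):

```python
import json
import requests  # assumed HTTP client, not part of this PR

# Hypothetical endpoint and token; adjust to your deployment.
resp = requests.post("http://localhost:9380/v1/api/completion",
                     headers={"Authorization": "Bearer <API_TOKEN>"},
                     json={"conversation_id": "<id>",
                           "messages": [{"role": "user", "content": "hi"}]},
                     stream=True)
for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data:"):
        continue
    payload = json.loads(line[len("data:"):])
    if payload["retcode"] != 0:
        raise RuntimeError(payload["retmsg"])
    ans = payload["data"]
    if not isinstance(ans, dict):  # skip any non-dict sentinel payloads
        continue
    print(ans["answer"],
          [c.get("doc_name") for c in ans.get("reference", {}).get("chunks", [])])
```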
api/apps/canvas_app.py ADDED
@@ -0,0 +1,112 @@
+#
+#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+import json
+
+from flask import request
+from flask_login import login_required, current_user
+
+from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
+from api.utils import get_uuid
+from api.utils.api_utils import get_json_result, server_error_response, validate_request
+from graph.canvas import Canvas
+
+
+@manager.route('/templates', methods=['GET'])
+@login_required
+def templates():
+    return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.get_all()])
+
+
+@manager.route('/list', methods=['GET'])
+@login_required
+def canvas_list():
+    return get_json_result(data=[c.to_dict() for c in UserCanvasService.query(user_id=current_user.id)])
+
+
+@manager.route('/rm', methods=['POST'])
+@validate_request("canvas_ids")
+@login_required
+def rm():
+    for i in request.json["canvas_ids"]:
+        UserCanvasService.delete_by_id(i)
+    return get_json_result(data=True)
+
+
+@manager.route('/set', methods=['POST'])
+@validate_request("dsl", "title")
+@login_required
+def save():
+    req = request.json
+    req["user_id"] = current_user.id
+    if not isinstance(req["dsl"], str): req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
+    try:
+        Canvas(req["dsl"])
+    except Exception as e:
+        return server_error_response(e)
+
+    req["dsl"] = json.loads(req["dsl"])
+    if "id" not in req:
+        req["id"] = get_uuid()
+        if not UserCanvasService.save(**req):
+            return server_error_response("Fail to save canvas.")
+    else:
+        UserCanvasService.update_by_id(req["id"], req)
+
+    return get_json_result(data=req)
+
+
+@manager.route('/get/<canvas_id>', methods=['GET'])
+@login_required
+def get(canvas_id):
+    e, c = UserCanvasService.get_by_id(canvas_id)
+    if not e:
+        return server_error_response("canvas not found.")
+    return get_json_result(data=c.to_dict())
+
+
+@manager.route('/run', methods=['POST'])
+@validate_request("id", "dsl")
+@login_required
+def run():
+    req = request.json
+    if not isinstance(req["dsl"], str): req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
+    try:
+        canvas = Canvas(req["dsl"], current_user.id)
+        ans = canvas.run()
+        req["dsl"] = json.loads(str(canvas))
+        UserCanvasService.update_by_id(req["id"], dsl=req["dsl"])
+        return get_json_result(data=req["dsl"])
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.route('/reset', methods=['POST'])
+@validate_request("canvas_id")
+@login_required
+def reset():
+    req = request.json
+    try:
+        user_canvas = UserCanvasService.get_by_id(req["canvas_id"])
+        canvas = Canvas(req["dsl"], current_user.id)
+        canvas.reset()
+        req["dsl"] = json.loads(str(canvas))
+        UserCanvasService.update_by_id(req["canvas_id"], dsl=req["dsl"])
+        return get_json_result(data=req["dsl"])
+    except Exception as e:
+        return server_error_response(e)
+
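
The new blueprint exposes list/get/set/rm plus `run` and `reset` for canvases. A hedged usage sketch against these routes (the `/v1/canvas` prefix, host, response envelope key and session handling are assumptions; the handlers above require an authenticated Flask-Login session):

```python
import requests  # assumed HTTP client; cookies must carry an authenticated flask_login session

BASE = "http://localhost:9380/v1/canvas"   # hypothetical prefix for this blueprint
session = requests.Session()               # assume login has already populated session cookies

# Create/update a canvas; "dsl" may be a dict or a JSON string, "title" is required by @validate_request.
dsl = {"components": {}, "history": [], "path": [], "answer": []}   # hypothetical minimal DSL shape
saved = session.post(f"{BASE}/set", json={"title": "demo", "dsl": dsl}).json()["data"]

# Execute the canvas (persists the updated DSL), then fetch it back by id.
ran = session.post(f"{BASE}/run", json={"id": saved["id"], "dsl": saved["dsl"]}).json()
fetched = session.get(f"{BASE}/get/{saved['id']}").json()
```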
api/apps/conversation_app.py CHANGED
@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from flask import request, Response, jsonify
+from copy import deepcopy
+from flask import request, Response
 from flask_login import login_required
 from api.db.services.dialog_service import DialogService, ConversationService, chat
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
@@ -121,7 +122,7 @@ def completion():
     e, conv = ConversationService.get_by_id(req["conversation_id"])
     if not e:
         return get_data_error_result(retmsg="Conversation not found!")
-    conv.message.append(msg[-1])
+    conv.message.append(deepcopy(msg[-1]))
     e, dia = DialogService.get_by_id(conv.dialog_id)
     if not e:
         return get_data_error_result(retmsg="Dialog not found!")
api/apps/dialog_app.py CHANGED
@@ -31,8 +31,8 @@ def set_dialog():
     req = request.json
     dialog_id = req.get("dialog_id")
     name = req.get("name", "New Dialog")
-    icon = req.get("icon", "")
     description = req.get("description", "A helpful Dialog")
+    icon = req.get("icon", "")
     top_n = req.get("top_n", 6)
     top_k = req.get("top_k", 1024)
     rerank_id = req.get("rerank_id", "")
@@ -92,7 +92,7 @@ def set_dialog():
         "rerank_id": rerank_id,
         "similarity_threshold": similarity_threshold,
         "vector_similarity_weight": vector_similarity_weight,
-        "icon": icon,
+        "icon": icon
     }
     if not DialogService.save(**dia):
         return get_data_error_result(retmsg="Fail to new a dialog!")
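
With `icon` now read right after `description`, the request body consumed by `set_dialog()` is unchanged. For reference, a representative payload built from the keys this handler reads (values are examples only; the defaults for the last two are not shown in this diff):

```python
# Example request body for set_dialog(); keys mirror the req.get(...) calls and the "dia" dict above.
payload = {
    "name": "New Dialog",
    "description": "A helpful Dialog",
    "icon": "",                        # defaults to "" when omitted
    "top_n": 6,
    "top_k": 1024,
    "rerank_id": "",
    "similarity_threshold": 0.1,       # example value
    "vector_similarity_weight": 0.3,   # example value
}
```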
api/db/services/canvas_service.py ADDED
@@ -0,0 +1,26 @@
+#
+#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+from datetime import datetime
+import peewee
+from api.db.db_models import DB, API4Conversation, APIToken, Dialog, CanvasTemplate, UserCanvas
+from api.db.services.common_service import CommonService
+
+
+class CanvasTemplateService(CommonService):
+    model = CanvasTemplate
+
+class UserCanvasService(CommonService):
+    model = UserCanvas
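
These two classes are intentionally thin: all persistence behaviour comes from `CommonService`, parameterised by `model`. The canvas endpoints above already exercise the inherited helpers; a condensed sketch of that pattern (helper names taken from the calls in `canvas_app.py` above, field availability assumed from the handlers):

```python
from api.db.services.canvas_service import UserCanvasService
from api.utils import get_uuid

# Inherited CommonService helpers, as used by canvas_app.py above.
canvas_id = get_uuid()
UserCanvasService.save(id=canvas_id, user_id="<tenant>", title="demo", dsl={})  # insert a row
found, row = UserCanvasService.get_by_id(canvas_id)                             # (exists, record) pair
rows = UserCanvasService.query(user_id="<tenant>")                              # filter by column
UserCanvasService.update_by_id(canvas_id, {"title": "renamed"})                 # partial update
UserCanvasService.delete_by_id(canvas_id)                                       # remove
```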
api/db/services/dialog_service.py CHANGED
@@ -23,6 +23,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
 from api.settings import chat_logger, retrievaler
 from rag.app.resume import forbidden_select_fields4resume
+from rag.nlp.rag_tokenizer import is_chinese
 from rag.nlp.search import index_name
 from rag.utils import rmSpace, num_tokens_from_string, encoder
 
@@ -80,7 +81,8 @@ def chat(dialog, messages, stream=True, **kwargs):
         if not llm:
             raise LookupError("LLM(%s) not found" % dialog.llm_id)
         max_tokens = 1024
-    else: max_tokens = llm[0].max_tokens
+    else:
+        max_tokens = llm[0].max_tokens
     kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
     embd_nms = list(set([kb.embd_id for kb in kbs]))
     if len(embd_nms) != 1:
@@ -124,6 +126,16 @@ def chat(dialog, messages, stream=True, **kwargs):
                                         doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
                                         top=1024, aggs=False, rerank_mdl=rerank_mdl)
         knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
+        # self-rag
+        if dialog.prompt_config.get("self_rag") and not relevant(dialog.tenant_id, dialog.llm_id, questions[-1], knowledges):
+            questions[-1] = rewrite(dialog.tenant_id, dialog.llm_id, questions[-1])
+            kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
+                                            dialog.similarity_threshold,
+                                            dialog.vector_similarity_weight,
+                                            doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
+                                            top=1024, aggs=False, rerank_mdl=rerank_mdl)
+            knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
+
     chat_logger.info(
         "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
 
@@ -136,7 +148,7 @@ def chat(dialog, messages, stream=True, **kwargs):
 
     msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
     msg.extend([{"role": m["role"], "content": m["content"]}
-                for m in messages if m["role"] != "system"])
+               for m in messages if m["role"] != "system"])
     used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
     assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
 
@@ -150,9 +162,9 @@ def chat(dialog, messages, stream=True, **kwargs):
     if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
         answer, idx = retrievaler.insert_citations(answer,
                                                    [ck["content_ltks"]
-                                                    for ck in kbinfos["chunks"]],
+                                                   for ck in kbinfos["chunks"]],
                                                    [ck["vector"]
-                                                    for ck in kbinfos["chunks"]],
+                                                   for ck in kbinfos["chunks"]],
                                                    embd_mdl,
                                                    tkweight=1 - dialog.vector_similarity_weight,
                                                    vtweight=dialog.vector_similarity_weight)
@@ -166,7 +178,7 @@ def chat(dialog, messages, stream=True, **kwargs):
     for c in refs["chunks"]:
         if c.get("vector"):
             del c["vector"]
-    if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api")>=0:
+    if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
         answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
     return {"answer": answer, "reference": refs}
 
@@ -204,7 +216,7 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
     def get_table():
         nonlocal sys_prompt, user_promt, question, tried_times
         sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {
-            "temperature": 0.06})
+                            "temperature": 0.06})
         print(user_promt, sql)
         chat_logger.info(f"“{question}”==>{user_promt} get SQL: {sql}")
         sql = re.sub(r"[\r\n]+", " ", sql.lower())
@@ -273,17 +285,19 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
 
     # compose markdown table
     clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
-                                                                        tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
+                                                                        tbl["columns"][i]["name"])) for i in
+                                                                        clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
 
     line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
-        ("|------|" if docid_idx and docid_idx else "")
+           ("|------|" if docid_idx and docid_idx else "")
 
     rows = ["|" +
             "|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
             "|" for r in tbl["rows"]]
     if quota:
         rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
-    else: rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
+    else:
+        rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
     rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
 
     if not docid_idx or not docnm_idx:
@@ -303,5 +317,40 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
     return {
         "answer": "\n".join([clmns, line, rows]),
         "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
-                      "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]}
+                      "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
+                                   doc_aggs.items()]}
     }
+
+
+def relevant(tenant_id, llm_id, question, contents: list):
+    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
+    prompt = """
+        You are a grader assessing relevance of a retrieved document to a user question.
+        It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
+        If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
+        Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.
+        No other words needed except 'yes' or 'no'.
+    """
+    if not contents: return False
+    contents = "Documents: \n" + " - ".join(contents)
+    contents = f"Question: {question}\n" + contents
+    if num_tokens_from_string(contents) >= chat_mdl.max_length - 4:
+        contents = encoder.decode(encoder.encode(contents)[:chat_mdl.max_length - 4])
+    ans = chat_mdl.chat(prompt, [{"role": "user", "content": contents}], {"temperature": 0.01})
+    if ans.lower().find("yes") >= 0: return True
+    return False
+
+
+def rewrite(tenant_id, llm_id, question):
+    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
+    prompt = """
+        You are an expert at query expansion to generate a paraphrasing of a question.
+        I can't retrieval relevant information from the knowledge base by using user's question directly.
+        You need to expand or paraphrase user's question by multiple ways such as using synonyms words/phrase,
+        writing the abbreviation in its entirety, adding some extra descriptions or explanations,
+        changing the way of expression, translating the original question into another language (English/Chinese), etc.
+        And return 5 versions of question and one is from translation.
+        Just list the question. No other words are needed.
+    """
+    ans = chat_mdl.chat(prompt, [{"role": "user", "content": question}], {"temperature": 0.8})
+    return ans
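
The gate only fires when the dialog's `prompt_config` carries a truthy `self_rag` flag, so existing dialogs keep their behaviour. A hedged sketch of opting in (the `system` template and any other keys are illustrative, not defined in this diff; `chat()` only inspects `self_rag`):

```python
# Hypothetical prompt_config for a dialog that opts into the new gate.
prompt_config = {
    "system": "You are an assistant. Use this knowledge: {knowledge}",  # example template, not from this diff
    "self_rag": True,
}

# Condensed path from the hunk above once the flag is set and the first retrieval is judged irrelevant:
#   questions[-1] = rewrite(dialog.tenant_id, dialog.llm_id, questions[-1])
#   kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1,
#                                   dialog.top_n, dialog.similarity_threshold,
#                                   dialog.vector_similarity_weight, top=1024, aggs=False)
#   knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
```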
deepdoc/parser/pdf_parser.py CHANGED
@@ -1021,6 +1021,8 @@ class RAGFlowPdfParser:
 
         self.page_cum_height = np.cumsum(self.page_cum_height)
         assert len(self.page_cum_height) == len(self.page_images) + 1
+        if len(self.boxes) == 0 and zoomin < 9: self.__images__(fnm, zoomin * 3, page_from,
+                                                                page_to, callback)
 
     def __call__(self, fnm, need_image=True, zoomin=3, return_html=False):
         self.__images__(fnm, zoomin)
rag/llm/rerank_model.py CHANGED
@@ -129,4 +129,3 @@ class YoudaoRerank(DefaultRerank):
         return np.array(res), token_count
 
 
-
rag/nlp/query.py CHANGED
@@ -48,7 +48,7 @@ class EsQueryer:
     @staticmethod
     def rmWWW(txt):
         patts = [
-            (r"是*(什么样的|哪家|一下|那家|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*", ""),
+            (r"是*(什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*", ""),
             (r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
             (r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down) ", " ")
         ]
@@ -68,7 +68,9 @@ class EsQueryer:
         if not self.isChinese(txt):
             tks = rag_tokenizer.tokenize(txt).split(" ")
             tks_w = self.tw.weights(tks)
-            tks_w = [(re.sub(r"[ \\\"']+", "", tk), w) for tk, w in tks_w]
+            tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
+            tks_w = [(re.sub(r"^[a-z0-9]$", "", tk), w) for tk, w in tks_w if tk]
+            tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
             q = ["{}^{:.4f}".format(tk, w) for tk, w in tks_w if tk]
             for i in range(1, len(tks_w)):
                 q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2))
@@ -118,7 +120,8 @@ class EsQueryer:
             if sm:
                 tk = f"{tk} OR \"%s\" OR (\"%s\"~2)^0.5" % (
                     " ".join(sm), " ".join(sm))
-            tms.append((tk, w))
+            if tk.strip():
+                tms.append((tk, w))
 
         tms = " ".join([f"({t})^{w}" for t, w in tms])
 
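
The three new substitutions strip characters that would otherwise leak into the weighted query string (quotes, backslashes and `^`), drop single-character alphanumeric leftovers, and remove a leading `+`/`-`, while the new `tk.strip()` guard keeps empty terms out of `tms`. A small self-contained demonstration of that cleanup (a standalone re-implementation for illustration, not an import from `rag.nlp.query`):

```python
import re

# Re-implementation of the token cleanup added above, for illustration only.
def clean_tokens(tks_w):
    tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]         # strip spaces, quotes, backslashes, '^'
    tks_w = [(re.sub(r"^[a-z0-9]$", "", tk), w) for tk, w in tks_w if tk]  # drop single-character leftovers
    tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]      # drop a leading '+' or '-'
    return [(tk, w) for tk, w in tks_w if tk]                              # mirror the tk.strip() guard

print(clean_tokens([('c++"', 0.8), ("a", 0.1), ("-testing", 0.5), ("rag^flow", 0.7)]))
# -> [('c++', 0.8), ('testing', 0.5), ('ragflow', 0.7)]
```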