Kevin Hu committed on
Commit
a3da325
·
1 Parent(s): 3fb798a

Fix raptor reusable issue. (#4063)

Browse files

### What problem does this PR solve?

#4045

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

api/db/services/document_service.py CHANGED
@@ -344,6 +344,8 @@ class DocumentService(CommonService):
344
  old[k] = v
345
 
346
  dfs_update(d.parser_config, config)
 
 
347
  cls.update_by_id(id, {"parser_config": d.parser_config})
348
 
349
  @classmethod
@@ -432,6 +434,11 @@ class DocumentService(CommonService):
432
 
433
 
434
  def queue_raptor_tasks(doc):
 
 
 
 
 
435
  def new_task():
436
  nonlocal doc
437
  return {
@@ -443,6 +450,9 @@ def queue_raptor_tasks(doc):
443
  }
444
 
445
  task = new_task()
 
 
 
446
  bulk_insert_into_db(Task, [task], True)
447
  task["type"] = "raptor"
448
  assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
 
344
  old[k] = v
345
 
346
  dfs_update(d.parser_config, config)
347
+ if not config.get("raptor") and d.parser_config.get("raptor"):
348
+ del d.parser_config["raptor"]
349
  cls.update_by_id(id, {"parser_config": d.parser_config})
350
 
351
  @classmethod
 
434
 
435
 
436
  def queue_raptor_tasks(doc):
437
+ chunking_config = DocumentService.get_chunking_config(doc["id"])
438
+ hasher = xxhash.xxh64()
439
+ for field in sorted(chunking_config.keys()):
440
+ hasher.update(str(chunking_config[field]).encode("utf-8"))
441
+
442
  def new_task():
443
  nonlocal doc
444
  return {
 
450
  }
451
 
452
  task = new_task()
453
+ for field in ["doc_id", "from_page", "to_page"]:
454
+ hasher.update(str(task.get(field, "")).encode("utf-8"))
455
+ task["digest"] = hasher.hexdigest()
456
  bulk_insert_into_db(Task, [task], True)
457
  task["type"] = "raptor"
458
  assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
api/db/services/task_service.py CHANGED
@@ -34,15 +34,17 @@ from rag.utils.redis_conn import REDIS_CONN
34
  from api import settings
35
  from rag.nlp import search
36
 
 
37
  def trim_header_by_lines(text: str, max_length) -> str:
38
  len_text = len(text)
39
  if len_text <= max_length:
40
  return text
41
  for i in range(len_text):
42
  if text[i] == '\n' and len_text - i <= max_length:
43
- return text[i+1:]
44
  return text
45
 
 
46
  class TaskService(CommonService):
47
  model = Task
48
 
@@ -73,10 +75,10 @@ class TaskService(CommonService):
73
  ]
74
  docs = (
75
  cls.model.select(*fields)
76
- .join(Document, on=(cls.model.doc_id == Document.id))
77
- .join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
78
- .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
79
- .where(cls.model.id == task_id)
80
  )
81
  docs = list(docs.dicts())
82
  if not docs:
@@ -111,7 +113,7 @@ class TaskService(CommonService):
111
  ]
112
  tasks = (
113
  cls.model.select(*fields).order_by(cls.model.from_page.asc(), cls.model.create_time.desc())
114
- .where(cls.model.doc_id == doc_id)
115
  )
116
  tasks = list(tasks.dicts())
117
  if not tasks:
@@ -131,18 +133,18 @@ class TaskService(CommonService):
131
  cls.model.select(
132
  *[Document.id, Document.kb_id, Document.location, File.parent_id]
133
  )
134
- .join(Document, on=(cls.model.doc_id == Document.id))
135
- .join(
136
  File2Document,
137
  on=(File2Document.document_id == Document.id),
138
  join_type=JOIN.LEFT_OUTER,
139
  )
140
- .join(
141
  File,
142
  on=(File2Document.file_id == File.id),
143
  join_type=JOIN.LEFT_OUTER,
144
  )
145
- .where(
146
  Document.status == StatusEnum.VALID.value,
147
  Document.run == TaskStatus.RUNNING.value,
148
  ~(Document.type == FileType.VIRTUAL.value),
@@ -212,8 +214,8 @@ def queue_tasks(doc: dict, bucket: str, name: str):
212
  if doc["parser_id"] == "paper":
213
  page_size = doc["parser_config"].get("task_page_size", 22)
214
  if doc["parser_id"] in ["one", "knowledge_graph"] or not do_layout:
215
- page_size = 10**9
216
- page_ranges = doc["parser_config"].get("pages") or [(1, 10**5)]
217
  for s, e in page_ranges:
218
  s -= 1
219
  s = max(0, s)
@@ -257,7 +259,8 @@ def queue_tasks(doc: dict, bucket: str, name: str):
257
  if task["chunk_ids"]:
258
  chunk_ids.extend(task["chunk_ids"].split())
259
  if chunk_ids:
260
- settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(chunking_config["tenant_id"]), chunking_config["kb_id"])
 
261
  DocumentService.update_by_id(doc["id"], {"chunk_num": ck_num})
262
 
263
  bulk_insert_into_db(Task, tsks, True)
@@ -271,7 +274,8 @@ def queue_tasks(doc: dict, bucket: str, name: str):
271
 
272
 
273
  def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
274
- idx = bisect.bisect_left(prev_tasks, task.get("from_page", 0), key=lambda x: x.get("from_page",0))
 
275
  if idx >= len(prev_tasks):
276
  return 0
277
  prev_task = prev_tasks[idx]
@@ -286,4 +290,4 @@ def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config:
286
  task["progress_msg"] += "reused previous task's chunks."
287
  prev_task["chunk_ids"] = ""
288
 
289
- return len(task["chunk_ids"].split())
 
34
  from api import settings
35
  from rag.nlp import search
36
 
37
+
38
  def trim_header_by_lines(text: str, max_length) -> str:
39
  len_text = len(text)
40
  if len_text <= max_length:
41
  return text
42
  for i in range(len_text):
43
  if text[i] == '\n' and len_text - i <= max_length:
44
+ return text[i + 1:]
45
  return text
46
 
47
+
48
  class TaskService(CommonService):
49
  model = Task
50
 
 
75
  ]
76
  docs = (
77
  cls.model.select(*fields)
78
+ .join(Document, on=(cls.model.doc_id == Document.id))
79
+ .join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
80
+ .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
81
+ .where(cls.model.id == task_id)
82
  )
83
  docs = list(docs.dicts())
84
  if not docs:
 
113
  ]
114
  tasks = (
115
  cls.model.select(*fields).order_by(cls.model.from_page.asc(), cls.model.create_time.desc())
116
+ .where(cls.model.doc_id == doc_id)
117
  )
118
  tasks = list(tasks.dicts())
119
  if not tasks:
 
133
  cls.model.select(
134
  *[Document.id, Document.kb_id, Document.location, File.parent_id]
135
  )
136
+ .join(Document, on=(cls.model.doc_id == Document.id))
137
+ .join(
138
  File2Document,
139
  on=(File2Document.document_id == Document.id),
140
  join_type=JOIN.LEFT_OUTER,
141
  )
142
+ .join(
143
  File,
144
  on=(File2Document.file_id == File.id),
145
  join_type=JOIN.LEFT_OUTER,
146
  )
147
+ .where(
148
  Document.status == StatusEnum.VALID.value,
149
  Document.run == TaskStatus.RUNNING.value,
150
  ~(Document.type == FileType.VIRTUAL.value),
 
214
  if doc["parser_id"] == "paper":
215
  page_size = doc["parser_config"].get("task_page_size", 22)
216
  if doc["parser_id"] in ["one", "knowledge_graph"] or not do_layout:
217
+ page_size = 10 ** 9
218
+ page_ranges = doc["parser_config"].get("pages") or [(1, 10 ** 5)]
219
  for s, e in page_ranges:
220
  s -= 1
221
  s = max(0, s)
 
259
  if task["chunk_ids"]:
260
  chunk_ids.extend(task["chunk_ids"].split())
261
  if chunk_ids:
262
+ settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(chunking_config["tenant_id"]),
263
+ chunking_config["kb_id"])
264
  DocumentService.update_by_id(doc["id"], {"chunk_num": ck_num})
265
 
266
  bulk_insert_into_db(Task, tsks, True)
 
274
 
275
 
276
  def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
277
+ idx = bisect.bisect_left(prev_tasks, (task.get("from_page", 0), task.get("digest", "")),
278
+ key=lambda x: (x.get("from_page", 0), x.get("digest", "")))
279
  if idx >= len(prev_tasks):
280
  return 0
281
  prev_task = prev_tasks[idx]
 
290
  task["progress_msg"] += "reused previous task's chunks."
291
  prev_task["chunk_ids"] = ""
292
 
293
+ return len(task["chunk_ids"].split())
graphrag/utils.py CHANGED
@@ -78,7 +78,7 @@ def get_llm_cache(llmnm, txt, history, genconf):
78
  bin = REDIS_CONN.get(k)
79
  if not bin:
80
  return
81
- return bin.decode("utf-8")
82
 
83
 
84
  def set_llm_cache(llmnm, txt, v: str, history, genconf):
 
78
  bin = REDIS_CONN.get(k)
79
  if not bin:
80
  return
81
+ return bin
82
 
83
 
84
  def set_llm_cache(llmnm, txt, v: str, history, genconf):