# ragflow/rag/svr/task_executor.py
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# from beartype import BeartypeConf
# from beartype.claw import beartype_all # <-- you didn't sign up for this
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
import sys
from api.utils.log_utils import initRootLogger
from graphrag.utils import get_llm_cache, set_llm_cache
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
CONSUMER_NAME = "task_executor_" + CONSUMER_NO
initRootLogger(CONSUMER_NAME)
import logging
import os
from datetime import datetime
import json
import xxhash
import copy
import re
import time
import threading
from functools import partial
from io import BytesIO
from multiprocessing.context import TimeoutError
from timeit import default_timer as timer
import tracemalloc
import numpy as np
from peewee import DoesNotExist
from api.db import LLMType, ParserType, TaskStatus
from api.db.services.dialog_service import keyword_extraction, question_proposal
from api.db.services.document_service import DocumentService
from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import TaskService
from api.db.services.file2document_service import File2DocumentService
from api import settings
from api.versions import get_ragflow_version
from api.db.db_models import close_connection
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
    knowledge_graph, email
from rag.nlp import search, rag_tokenizer
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings
from rag.utils import num_tokens_from_string
from rag.utils.redis_conn import REDIS_CONN, Payload
from rag.utils.storage_factory import STORAGE_IMPL
BATCH_SIZE = 64
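
# Dispatch table: maps a task's parser_id to the chunker module used by build_chunks().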
FACTORY = {
    "general": naive,
    ParserType.NAIVE.value: naive,
    ParserType.PAPER.value: paper,
    ParserType.BOOK.value: book,
    ParserType.PRESENTATION.value: presentation,
    ParserType.MANUAL.value: manual,
    ParserType.LAWS.value: laws,
    ParserType.QA.value: qa,
    ParserType.TABLE.value: table,
    ParserType.RESUME.value: resume,
    ParserType.PICTURE.value: picture,
    ParserType.ONE.value: one,
    ParserType.AUDIO.value: audio,
    ParserType.EMAIL.value: email,
    ParserType.KG.value: knowledge_graph
}

CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
PAYLOAD: Payload | None = None
BOOT_AT = datetime.now().isoformat()
PENDING_TASKS = 0
LAG_TASKS = 0
mt_lock = threading.Lock()
DONE_TASKS = 0
FAILED_TASKS = 0
CURRENT_TASK = None
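
# Raised from set_progress() when a task has been cancelled so the worker can unwind.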
class TaskCanceledException(Exception):
    def __init__(self, msg):
        self.msg = msg
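
# Persist progress for a task. A negative `prog` marks the message as an error; if the
# task has been cancelled, the pending Redis payload is acked and TaskCanceledException
# is raised so the caller unwinds.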
def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
    global PAYLOAD
    if prog is not None and prog < 0:
        msg = "[ERROR]" + msg
    try:
        cancel = TaskService.do_cancel(task_id)
    except DoesNotExist:
        logging.warning(f"set_progress task {task_id} is unknown")
        if PAYLOAD:
            PAYLOAD.ack()
            PAYLOAD = None
        return

    if cancel:
        msg += " [Canceled]"
        prog = -1

    if to_page > 0:
        if msg:
            msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
    d = {"progress_msg": msg}
    if prog is not None:
        d["progress"] = prog

    logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}")
    try:
        TaskService.update_progress(task_id, d)
    except DoesNotExist:
        logging.warning(f"set_progress task {task_id} is unknown")
        if PAYLOAD:
            PAYLOAD.ack()
            PAYLOAD = None
        return

    close_connection()
    if cancel and PAYLOAD:
        PAYLOAD.ack()
        PAYLOAD = None
        raise TaskCanceledException(msg)
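
# Pull the next task from the Redis queue, preferring any message left unacked by a
# previous consumer run. Returns None (after a short sleep) when the queue is empty,
# or when the task is unknown or its document has been cancelled.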
def collect():
    global CONSUMER_NAME, PAYLOAD, DONE_TASKS, FAILED_TASKS
    try:
        PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
        if not PAYLOAD:
            PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
        if not PAYLOAD:
            time.sleep(1)
            return None
    except Exception:
        logging.exception("Get task event from queue exception")
        return None

    msg = PAYLOAD.get_message()
    if not msg:
        return None

    task = None
    canceled = False
    try:
        task = TaskService.get_task(msg["id"])
        if task:
            _, doc = DocumentService.get_by_id(task["doc_id"])
            canceled = doc.run == TaskStatus.CANCEL.value or doc.progress < 0
    except DoesNotExist:
        pass
    except Exception:
        logging.exception("collect get_task exception")
    if not task or canceled:
        state = "is unknown" if not task else "has been cancelled"
        with mt_lock:
            DONE_TASKS += 1
        logging.info(f"collect task {msg['id']} {state}")
        return None

    if msg.get("type", "") == "raptor":
        task["task_type"] = "raptor"
    return task

def get_storage_binary(bucket, name):
    return STORAGE_IMPL.get(bucket, name)
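
# Chunking pipeline for one task: fetch the source file from storage, split it with the
# parser-specific chunker, persist chunk images to storage, and optionally enrich every
# chunk with LLM-generated keywords/questions, reusing cached LLM results when present.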
def build_chunks(task, progress_callback):
    if task["size"] > DOC_MAXIMUM_SIZE:
        set_progress(task["id"], prog=-1, msg="File size exceeds the limit (<= %dMB)" %
                     (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
        return []
    chunker = FACTORY[task["parser_id"].lower()]
    try:
        st = timer()
        bucket, name = File2DocumentService.get_storage_address(doc_id=task["doc_id"])
        binary = get_storage_binary(bucket, name)
        logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"]))
    except TimeoutError:
        progress_callback(-1, "Internal server error: Fetching the file from minio timed out. Please try again.")
        logging.exception("Minio {}/{} got timeout: Fetch file from minio timeout.".format(task["location"], task["name"]))
        raise
    except Exception as e:
        if re.search("(No such file|not found)", str(e)):
            progress_callback(-1, "Cannot find file <%s> in minio. Please try again." % task["name"])
        else:
            progress_callback(-1, "Get file from minio: %s" % str(e).replace("'", ""))
        logging.exception("Chunking {}/{} got exception".format(task["location"], task["name"]))
        raise

    try:
        cks = chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
                            to_page=task["to_page"], lang=task["language"], callback=progress_callback,
                            kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"])
        logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
    except TaskCanceledException:
        raise
    except Exception as e:
        progress_callback(-1, "Internal server error while chunking: %s" % str(e).replace("'", ""))
        logging.exception("Chunking {}/{} got exception".format(task["location"], task["name"]))
        raise
    docs = []
    doc = {
        "doc_id": task["doc_id"],
        "kb_id": str(task["kb_id"])
    }
    if task["pagerank"]:
        doc["pagerank_fea"] = int(task["pagerank"])
    el = 0
    for ck in cks:
        d = copy.deepcopy(doc)
        d.update(ck)
        d["id"] = xxhash.xxh64((ck["content_with_weight"] + str(d["doc_id"])).encode("utf-8")).hexdigest()
        d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
        d["create_timestamp_flt"] = datetime.now().timestamp()
        if not d.get("image"):
            _ = d.pop("image", None)
            d["img_id"] = ""
            docs.append(d)
            continue

        try:
            output_buffer = BytesIO()
            if isinstance(d["image"], bytes):
                output_buffer = BytesIO(d["image"])
            else:
                d["image"].save(output_buffer, format='JPEG')

            st = timer()
            STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue())
            el += timer() - st
        except Exception:
            logging.exception("Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"]))
            raise

        d["img_id"] = "{}-{}".format(task["kb_id"], d["id"])
        del d["image"]
        docs.append(d)
    logging.info("MINIO PUT({}):{}".format(task["name"], el))
if task["parser_config"].get("auto_keywords", 0):
st = timer()
progress_callback(msg="Start to generate keywords for every chunk ...")
chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
for d in docs:
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords",
{"topn": task["parser_config"]["auto_keywords"]})
if not cached:
cached = keyword_extraction(chat_mdl, d["content_with_weight"],
task["parser_config"]["auto_keywords"])
if cached:
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords",
{"topn": task["parser_config"]["auto_keywords"]})
d["important_kwd"] = cached.split(",")
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
progress_callback(msg="Keywords generation completed in {:.2f}s".format(timer() - st))
if task["parser_config"].get("auto_questions", 0):
st = timer()
progress_callback(msg="Start to generate questions for every chunk ...")
chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
for d in docs:
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question",
{"topn": task["parser_config"]["auto_questions"]})
if not cached:
cached = question_proposal(chat_mdl, d["content_with_weight"], task["parser_config"]["auto_questions"])
if cached:
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question",
{"topn": task["parser_config"]["auto_questions"]})
d["question_kwd"] = cached.split("\n")
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
progress_callback(msg="Question generation completed in {:.2f}s".format(timer() - st))
return docs
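
# Make sure the doc-store index for this tenant/dataset exists with the given vector size.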
def init_kb(row, vector_size: int):
    idxnm = search.index_name(row["tenant_id"])
    return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
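
# Embed titles and contents in batches of 16, then combine them per chunk as
# title_w * title_vec + (1 - title_w) * content_vec (title_w defaults to 0.1).
# Returns the total token count and the embedding dimension.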
def embedding(docs, mdl, parser_config=None, callback=None):
    if parser_config is None:
        parser_config = {}
    batch_size = 16
    tts, cnts = [], []
    for d in docs:
        tts.append(d.get("docnm_kwd", "Title"))
        c = "\n".join(d.get("question_kwd", []))
        if not c:
            c = d["content_with_weight"]
        c = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", c)
        cnts.append(c)

    tk_count = 0
    if len(tts) == len(cnts):
        tts_ = np.array([])
        for i in range(0, len(tts), batch_size):
            vts, c = mdl.encode(tts[i: i + batch_size])
            if len(tts_) == 0:
                tts_ = vts
            else:
                tts_ = np.concatenate((tts_, vts), axis=0)
            tk_count += c
            callback(prog=0.6 + 0.1 * (i + 1) / len(tts), msg="")
        tts = tts_

    cnts_ = np.array([])
    for i in range(0, len(cnts), batch_size):
        vts, c = mdl.encode(cnts[i: i + batch_size])
        if len(cnts_) == 0:
            cnts_ = vts
        else:
            cnts_ = np.concatenate((cnts_, vts), axis=0)
        tk_count += c
        callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="")
    cnts = cnts_

    title_w = float(parser_config.get("filename_embd_weight", 0.1))
    vects = (title_w * tts + (1 - title_w) *
             cnts) if len(tts) == len(cnts) else cnts

    assert len(vects) == len(docs)
    vector_size = 0
    for i, d in enumerate(docs):
        v = vects[i].tolist()
        vector_size = len(v)
        d["q_%d_vec" % len(v)] = v
    return tk_count, vector_size
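
# RAPTOR pass over an already-indexed document: read back its chunks and vectors,
# cluster and summarize them recursively, and return only the newly generated summary
# chunks plus their token count and the vector size.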
def run_raptor(row, chat_mdl, embd_mdl, callback=None):
    vts, _ = embd_mdl.encode(["ok"])
    vector_size = len(vts[0])
    vctr_nm = "q_%d_vec" % vector_size
    chunks = []
    for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
                                             fields=["content_with_weight", vctr_nm]):
        chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))

    raptor = Raptor(
        row["parser_config"]["raptor"].get("max_cluster", 64),
        chat_mdl,
        embd_mdl,
        row["parser_config"]["raptor"]["prompt"],
        row["parser_config"]["raptor"]["max_token"],
        row["parser_config"]["raptor"]["threshold"]
    )
    original_length = len(chunks)
    chunks = raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
    doc = {
        "doc_id": row["doc_id"],
        "kb_id": [str(row["kb_id"])],
        "docnm_kwd": row["name"],
        "title_tks": rag_tokenizer.tokenize(row["name"])
    }
    if row["pagerank"]:
        doc["pagerank_fea"] = int(row["pagerank"])
    res = []
    tk_count = 0
    for content, vctr in chunks[original_length:]:
        d = copy.deepcopy(doc)
        d["id"] = xxhash.xxh64((content + str(d["doc_id"])).encode("utf-8")).hexdigest()
        d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
        d["create_timestamp_flt"] = datetime.now().timestamp()
        d[vctr_nm] = vctr.tolist()
        d["content_with_weight"] = content
        d["content_ltks"] = rag_tokenizer.tokenize(content)
        d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
        res.append(d)
        tk_count += num_tokens_from_string(content)
    return res, tk_count, vector_size
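
# Process one task end to end: bind the embedding model, run either RAPTOR or the
# standard chunking + embedding path, then bulk-insert the chunks into the doc store
# (4 per batch), tracking inserted chunk ids so a failed task can be cleaned up.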
def do_handle_task(task):
    task_id = task["id"]
    task_from_page = task["from_page"]
    task_to_page = task["to_page"]
    task_tenant_id = task["tenant_id"]
    task_embedding_id = task["embd_id"]
    task_language = task["language"]
    task_llm_id = task["llm_id"]
    task_dataset_id = task["kb_id"]
    task_doc_id = task["doc_id"]
    task_document_name = task["name"]
    task_parser_config = task["parser_config"]

    # prepare the progress callback function
    progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)

    try:
        task_canceled = TaskService.do_cancel(task_id)
    except DoesNotExist:
        logging.warning(f"task {task_id} is unknown")
        return
    if task_canceled:
        progress_callback(-1, msg="Task has been canceled.")
        return
    try:
        # bind embedding model
        embedding_model = LLMBundle(task_tenant_id, LLMType.EMBEDDING, llm_name=task_embedding_id, lang=task_language)
    except Exception as e:
        error_message = f'Failed to bind embedding model: {str(e)}'
        progress_callback(-1, msg=error_message)
        logging.exception(error_message)
        raise

    # Either use RAPTOR or the standard chunking method
    if task.get("task_type", "") == "raptor":
        try:
            # bind LLM for raptor
            chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
            # run RAPTOR
            chunks, token_count, vector_size = run_raptor(task, chat_model, embedding_model, progress_callback)
        except TaskCanceledException:
            raise
        except Exception as e:
            error_message = f'Failed to bind the LLM or run RAPTOR: {str(e)}'
            progress_callback(-1, msg=error_message)
            logging.exception(error_message)
            raise
    else:
        # Standard chunking method
        start_ts = timer()
        chunks = build_chunks(task, progress_callback)
        logging.info("Build document {}: {:.2f}s".format(task_document_name, timer() - start_ts))
        if chunks is None:
            return
        if not chunks:
            progress_callback(1., msg=f"No chunk built from {task_document_name}")
            return
        # TODO: exception handler
        ## set_progress(task["did"], -1, "ERROR: ")
        progress_callback(msg="Generated {} chunks".format(len(chunks)))
        start_ts = timer()
        try:
            token_count, vector_size = embedding(chunks, embedding_model, task_parser_config, progress_callback)
        except Exception as e:
            error_message = "Generate embedding error: {}".format(str(e))
            progress_callback(-1, error_message)
            logging.exception(error_message)
            token_count = 0
            raise
        progress_message = "Embedding chunks ({:.2f}s)".format(timer() - start_ts)
        logging.info(progress_message)
        progress_callback(msg=progress_message)
# logging.info(f"task_executor init_kb index {search.index_name(task_tenant_id)} embedding_model {embedding_model.llm_name} vector length {vector_size}")
init_kb(task, vector_size)
chunk_count = len(set([chunk["id"] for chunk in chunks]))
start_ts = timer()
doc_store_result = ""
es_bulk_size = 4
for b in range(0, len(chunks), es_bulk_size):
doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id), task_dataset_id)
if b % 128 == 0:
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
if doc_store_result:
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
progress_callback(-1, msg=error_message)
raise Exception(error_message)
chunk_ids = [chunk["id"] for chunk in chunks[:b + es_bulk_size]]
chunk_ids_str = " ".join(chunk_ids)
try:
TaskService.update_chunk_ids(task["id"], chunk_ids_str)
except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
doc_store_result = settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id)
return
logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page, task_to_page, len(chunks), timer() - start_ts))
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
time_cost = timer() - start_ts
progress_callback(prog=1.0, msg="Done ({:.2f}s)".format(time_cost))
logging.info("Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page, task_to_page, len(chunks), token_count, time_cost))
def handle_task():
    global PAYLOAD, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
    task = collect()
    if task:
        try:
            logging.info(f"handle_task begin for task {json.dumps(task)}")
            with mt_lock:
                CURRENT_TASK = copy.deepcopy(task)
            do_handle_task(task)
            with mt_lock:
                DONE_TASKS += 1
                CURRENT_TASK = None
            logging.info(f"handle_task done for task {json.dumps(task)}")
        except TaskCanceledException:
            with mt_lock:
                DONE_TASKS += 1
                CURRENT_TASK = None
            try:
                set_progress(task["id"], prog=-1, msg="handle_task got TaskCanceledException")
            except Exception:
                pass
            logging.debug("handle_task got TaskCanceledException", exc_info=True)
        except Exception:
            with mt_lock:
                FAILED_TASKS += 1
                CURRENT_TASK = None
            try:
                set_progress(task["id"], prog=-1, msg="handle_task got exception, please check log")
            except Exception:
                pass
            logging.exception(f"handle_task got exception for task {json.dumps(task)}")
    if PAYLOAD:
        PAYLOAD.ack()
        PAYLOAD = None
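
# Heartbeat thread: every 30s, publish a JSON status snapshot to a per-consumer Redis
# zset (and trim entries older than 30 minutes) so the cluster can monitor executors.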
def report_status():
    global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
    REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
    while True:
        try:
            now = datetime.now()
            group_info = REDIS_CONN.queue_info(SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
            if group_info is not None:
                PENDING_TASKS = int(group_info.get("pending", 0))
                LAG_TASKS = int(group_info.get("lag", 0))

            with mt_lock:
                heartbeat = json.dumps({
                    "name": CONSUMER_NAME,
                    "now": now.isoformat(),
                    "boot_at": BOOT_AT,
                    "pending": PENDING_TASKS,
                    "lag": LAG_TASKS,
                    "done": DONE_TASKS,
                    "failed": FAILED_TASKS,
                    "current": CURRENT_TASK,
                })
            REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
            logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")

            expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60 * 30)
            if expired > 0:
                REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
        except Exception:
            logging.exception("report_status got exception")
        time.sleep(30)
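
# Log the top memory-allocation deltas between two tracemalloc snapshots; when
# dump_full is set, also log the top absolute consumers of the newer snapshot.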
def analyze_heap(snapshot1: tracemalloc.Snapshot, snapshot2: tracemalloc.Snapshot, snapshot_id: int, dump_full: bool):
    msg = ""
    if dump_full:
        stats2 = snapshot2.statistics('lineno')
        msg += f"{CONSUMER_NAME} memory usage of snapshot {snapshot_id}:\n"
        for stat in stats2[:10]:
            msg += f"{stat}\n"
    stats1_vs_2 = snapshot2.compare_to(snapshot1, 'lineno')
    msg += f"{CONSUMER_NAME} memory usage increase from snapshot {snapshot_id - 1} to snapshot {snapshot_id}:\n"
    for stat in stats1_vs_2[:10]:
        msg += f"{stat}\n"
    msg += f"{CONSUMER_NAME} detailed traceback for the top memory consumers:\n"
    for stat in stats1_vs_2[:3]:
        msg += '\n'.join(stat.traceback.format())
    logging.info(msg)
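
# Entry point: print the banner and settings, start the heartbeat thread, optionally
# enable tracemalloc sampling via TRACE_MALLOC_DELTA / TRACE_MALLOC_FULL, then loop
# handle_task() forever.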
def main():
    logging.info(r"""
  ______           __      ______                     __
 /_  __/___ ______/ /__   / ____/  _____  _______  __/ /_____  _____
  / / / __ `/ ___/ //_/  / __/ | |/_/ _ \/ ___/ / / / __/ __ \/ ___/
 / / / /_/ (__  ) ,<    / /____>  </  __/ /__/ /_/ / /_/ /_/ / /
/_/  \__,_/____/_/|_|  /_____/_/|_|\___/\___/\__,_/\__/\____/_/
    """)
    logging.info(f'TaskExecutor: RAGFlow version: {get_ragflow_version()}')
    settings.init_settings()
    print_rag_settings()
    background_thread = threading.Thread(target=report_status)
    background_thread.daemon = True
    background_thread.start()

    TRACE_MALLOC_DELTA = int(os.environ.get('TRACE_MALLOC_DELTA', "0"))
    TRACE_MALLOC_FULL = int(os.environ.get('TRACE_MALLOC_FULL', "0"))
    if TRACE_MALLOC_DELTA > 0:
        if TRACE_MALLOC_FULL < TRACE_MALLOC_DELTA:
            TRACE_MALLOC_FULL = TRACE_MALLOC_DELTA
        tracemalloc.start()
        snapshot1 = tracemalloc.take_snapshot()

    while True:
        handle_task()
        num_tasks = DONE_TASKS + FAILED_TASKS
        if TRACE_MALLOC_DELTA > 0 and num_tasks > 0 and num_tasks % TRACE_MALLOC_DELTA == 0:
            snapshot2 = tracemalloc.take_snapshot()
            analyze_heap(snapshot1, snapshot2, int(num_tasks / TRACE_MALLOC_DELTA), num_tasks % TRACE_MALLOC_FULL == 0)
            snapshot1 = snapshot2
            snapshot2 = None


if __name__ == "__main__":
    main()