#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import hashlib
import json
import os
import random
import re
import traceback
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from datetime import datetime
from io import BytesIO
from elasticsearch_dsl import Q
from peewee import fn
from api.db.db_utils import bulk_insert_into_db
from api.settings import stat_logger
from api.utils import current_timestamp, get_format_time, get_uuid
from api.utils.file_utils import get_project_base_directory
from graphrag.mind_map_extractor import MindMapExtractor
from rag.settings import SVR_QUEUE_NAME
from rag.utils.es_conn import ELASTICSEARCH
from rag.utils.storage_factory import STORAGE_IMPL
from rag.nlp import search, rag_tokenizer
from api.db import FileType, TaskStatus, ParserType, LLMType
from api.db.db_models import DB, Knowledgebase, Tenant, Task, UserTenant
from api.db.db_models import Document
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db import StatusEnum
from rag.utils.redis_conn import REDIS_CONN
class DocumentService(CommonService):
model = Document
@classmethod
@DB.connection_context()
def get_list(cls, kb_id, page_number, items_per_page,
orderby, desc, keywords, id):
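        """Return a page of documents in a knowledgebase.

        Optionally filters by document id and by a case-insensitive name
        keyword, orders by `orderby` (ascending or descending), and
        paginates the result. Returns (rows as dicts, row count).
        """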
        docs = cls.model.select().where(cls.model.kb_id == kb_id)
        if id:
            docs = docs.where(cls.model.id == id)
if keywords:
docs = docs.where(
fn.LOWER(cls.model.name).contains(keywords.lower())
)
if desc:
docs = docs.order_by(cls.model.getter_by(orderby).desc())
else:
docs = docs.order_by(cls.model.getter_by(orderby).asc())
docs = docs.paginate(page_number, items_per_page)
count = docs.count()
return list(docs.dicts()), count
@classmethod
@DB.connection_context()
def get_by_kb_id(cls, kb_id, page_number, items_per_page,
orderby, desc, keywords):
if keywords:
docs = cls.model.select().where(
(cls.model.kb_id == kb_id),
(fn.LOWER(cls.model.name).contains(keywords.lower()))
)
else:
docs = cls.model.select().where(cls.model.kb_id == kb_id)
count = docs.count()
if desc:
docs = docs.order_by(cls.model.getter_by(orderby).desc())
else:
docs = docs.order_by(cls.model.getter_by(orderby).asc())
docs = docs.paginate(page_number, items_per_page)
return list(docs.dicts()), count
@classmethod
@DB.connection_context()
def list_documents_in_dataset(cls, dataset_id, offset, count, order_by, descend, keywords):
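        """List documents in a dataset with offset/count slicing.

        Applies an optional case-insensitive name keyword filter, orders by
        `order_by` according to the string flag `descend`, then slices the
        result; `count == -1` means "all remaining". Returns (docs, total).
        Raises IndexError when `offset` falls outside the result range.
        """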
if keywords:
docs = cls.model.select().where(
(cls.model.kb_id == dataset_id),
(fn.LOWER(cls.model.name).contains(keywords.lower()))
)
else:
docs = cls.model.select().where(cls.model.kb_id == dataset_id)
total = docs.count()
if descend == 'True':
docs = docs.order_by(cls.model.getter_by(order_by).desc())
if descend == 'False':
docs = docs.order_by(cls.model.getter_by(order_by).asc())
docs = list(docs.dicts())
docs_length = len(docs)
if offset < 0 or offset > docs_length:
raise IndexError("Offset is out of the valid range.")
if count == -1:
return docs[offset:], total
return docs[offset:offset + count], total
@classmethod
@DB.connection_context()
def insert(cls, doc):
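        """Persist a new document row and bump the knowledgebase's doc_num.

        Raises RuntimeError if the insert, the re-read, or the
        knowledgebase update fails.
        """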
if not cls.save(**doc):
raise RuntimeError("Database error (Document)!")
e, doc = cls.get_by_id(doc["id"])
if not e:
raise RuntimeError("Database error (Document retrieval)!")
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
if not KnowledgebaseService.update_by_id(
kb.id, {"doc_num": kb.doc_num + 1}):
raise RuntimeError("Database error (Knowledgebase)!")
return doc
@classmethod
@DB.connection_context()
def remove_document(cls, doc, tenant_id):
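        """Delete a document: drop its chunks from Elasticsearch, roll back
        the knowledgebase counters, then remove the database row."""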
ELASTICSEARCH.deleteByQuery(
Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
cls.clear_chunk_num(doc.id)
return cls.delete_by_id(doc.id)
@classmethod
@DB.connection_context()
def get_newly_uploaded(cls):
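        """Fetch recently uploaded documents that still need parsing.

        Returns valid, non-virtual documents in RUNNING state with zero
        progress that were updated within the last 600 seconds (assuming
        millisecond timestamps), joined with the tenant's model settings.
        """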
fields = [
cls.model.id,
cls.model.kb_id,
cls.model.parser_id,
cls.model.parser_config,
cls.model.name,
cls.model.type,
cls.model.location,
cls.model.size,
Knowledgebase.tenant_id,
Tenant.embd_id,
Tenant.img2txt_id,
Tenant.asr_id,
cls.model.update_time]
docs = cls.model.select(*fields) \
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
.where(
cls.model.status == StatusEnum.VALID.value,
~(cls.model.type == FileType.VIRTUAL.value),
cls.model.progress == 0,
cls.model.update_time >= current_timestamp() - 1000 * 600,
cls.model.run == TaskStatus.RUNNING.value)\
.order_by(cls.model.update_time.asc())
return list(docs.dicts())
@classmethod
@DB.connection_context()
def get_unfinished_docs(cls):
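        """Return documents whose parsing progress is strictly between 0 and 1."""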
fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg, cls.model.run]
docs = cls.model.select(*fields) \
.where(
cls.model.status == StatusEnum.VALID.value,
~(cls.model.type == FileType.VIRTUAL.value),
cls.model.progress < 1,
cls.model.progress > 0)
return list(docs.dicts())
@classmethod
@DB.connection_context()
def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
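        """Add token/chunk counts and processing duration to a document,
        then propagate the token/chunk increments to its knowledgebase.
        Raises LookupError if the document row does not exist."""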
num = cls.model.update(token_num=cls.model.token_num + token_num,
chunk_num=cls.model.chunk_num + chunk_num,
process_duation=cls.model.process_duation + duation).where(
cls.model.id == doc_id).execute()
if num == 0:
            raise LookupError(
                "Document not found; it was expected to exist.")
num = Knowledgebase.update(
token_num=Knowledgebase.token_num +
token_num,
chunk_num=Knowledgebase.chunk_num +
chunk_num).where(
Knowledgebase.id == kb_id).execute()
return num
@classmethod
@DB.connection_context()
def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
num = cls.model.update(token_num=cls.model.token_num - token_num,
chunk_num=cls.model.chunk_num - chunk_num,
process_duation=cls.model.process_duation + duation).where(
cls.model.id == doc_id).execute()
if num == 0:
            raise LookupError(
                "Document not found; it was expected to exist.")
num = Knowledgebase.update(
token_num=Knowledgebase.token_num -
token_num,
chunk_num=Knowledgebase.chunk_num -
chunk_num
).where(
Knowledgebase.id == kb_id).execute()
return num
@classmethod
@DB.connection_context()
def clear_chunk_num(cls, doc_id):
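        """Subtract a document's token/chunk counts from its knowledgebase
        and decrement the knowledgebase's doc_num (used before deletion)."""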
doc = cls.model.get_by_id(doc_id)
        assert doc, "Can't find document in database."
num = Knowledgebase.update(
token_num=Knowledgebase.token_num -
doc.token_num,
chunk_num=Knowledgebase.chunk_num -
doc.chunk_num,
doc_num=Knowledgebase.doc_num-1
).where(
Knowledgebase.id == doc.kb_id).execute()
return num
@classmethod
@DB.connection_context()
def get_tenant_id(cls, doc_id):
docs = cls.model.select(
Knowledgebase.tenant_id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
return
return docs[0]["tenant_id"]
@classmethod
@DB.connection_context()
def get_tenant_id_by_name(cls, name):
docs = cls.model.select(
Knowledgebase.tenant_id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
return
return docs[0]["tenant_id"]
@classmethod
@DB.connection_context()
def accessible(cls, doc_id, user_id):
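        """Return True if the user belongs to the tenant that owns the
        document's knowledgebase."""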
docs = cls.model.select(
cls.model.id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)
).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1)
docs = docs.dicts()
if not docs:
return False
return True
@classmethod
@DB.connection_context()
def accessible4deletion(cls, doc_id, user_id):
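        """Return True if the user created the knowledgebase that holds the
        document, i.e. is allowed to delete it."""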
docs = cls.model.select(
cls.model.id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)
).where(cls.model.id == doc_id, Knowledgebase.created_by == user_id).paginate(0, 1)
docs = docs.dicts()
if not docs:
return False
return True
@classmethod
@DB.connection_context()
def get_embd_id(cls, doc_id):
docs = cls.model.select(
Knowledgebase.embd_id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
return
return docs[0]["embd_id"]
@classmethod
@DB.connection_context()
def get_doc_id_by_doc_name(cls, doc_name):
fields = [cls.model.id]
doc_id = cls.model.select(*fields) \
.where(cls.model.name == doc_name)
doc_id = doc_id.dicts()
if not doc_id:
return
return doc_id[0]["id"]
@classmethod
@DB.connection_context()
def get_thumbnails(cls, docids):
fields = [cls.model.id, cls.model.kb_id, cls.model.thumbnail]
return list(cls.model.select(
*fields).where(cls.model.id.in_(docids)).dicts())
@classmethod
@DB.connection_context()
def update_parser_config(cls, id, config):
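        """Deep-merge `config` into the document's stored parser_config and
        persist the merged result."""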
e, d = cls.get_by_id(id)
if not e:
raise LookupError(f"Document({id}) not found.")
def dfs_update(old, new):
for k, v in new.items():
if k not in old:
old[k] = v
continue
if isinstance(v, dict):
assert isinstance(old[k], dict)
dfs_update(old[k], v)
else:
old[k] = v
dfs_update(d.parser_config, config)
cls.update_by_id(id, {"parser_config": d.parser_config})
@classmethod
@DB.connection_context()
def get_doc_count(cls, tenant_id):
docs = cls.model.select(cls.model.id).join(Knowledgebase,
on=(Knowledgebase.id == cls.model.kb_id)).where(
Knowledgebase.tenant_id == tenant_id)
return len(docs)
@classmethod
@DB.connection_context()
def begin2parse(cls, docid):
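        """Mark a document as dispatched for parsing: a tiny random progress
        value, a "Task dispatched..." message, and the start timestamp."""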
cls.update_by_id(
docid, {"progress": random.random() * 1 / 100.,
"progress_msg": "Task dispatched...",
"process_begin_at": get_format_time()
})
@classmethod
@DB.connection_context()
def update_progress(cls):
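        """Aggregate task progress into each unfinished document.

        Averages the progress of a document's tasks, collects progress
        messages, marks the document FAIL or DONE when all tasks finish,
        queues a RAPTOR task when configured and not yet run, and persists
        progress, status and processing duration.
        """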
docs = cls.get_unfinished_docs()
for d in docs:
try:
tsks = Task.query(doc_id=d["id"], order_by=Task.create_time)
if not tsks:
continue
msg = []
prg = 0
finished = True
bad = 0
e, doc = DocumentService.get_by_id(d["id"])
                status = doc.run  # TaskStatus.RUNNING.value
for t in tsks:
if 0 <= t.progress < 1:
finished = False
prg += t.progress if t.progress >= 0 else 0
if t.progress_msg not in msg:
msg.append(t.progress_msg)
if t.progress == -1:
bad += 1
prg /= len(tsks)
if finished and bad:
prg = -1
status = TaskStatus.FAIL.value
elif finished:
if d["parser_config"].get("raptor", {}).get("use_raptor") and d["progress_msg"].lower().find(" raptor")<0:
queue_raptor_tasks(d)
prg = 0.98 * len(tsks)/(len(tsks)+1)
msg.append("------ RAPTOR -------")
else:
status = TaskStatus.DONE.value
msg = "\n".join(msg)
info = {
"process_duation": datetime.timestamp(
datetime.now()) -
d["process_begin_at"].timestamp(),
"run": status}
if prg != 0:
info["progress"] = prg
if msg:
info["progress_msg"] = msg
cls.update_by_id(d["id"], info)
except Exception as e:
if str(e).find("'0'") < 0:
stat_logger.error("fetch task exception:" + str(e))
@classmethod
@DB.connection_context()
def get_kb_doc_count(cls, kb_id):
return len(cls.model.select(cls.model.id).where(
cls.model.kb_id == kb_id).dicts())
@classmethod
@DB.connection_context()
def do_cancel(cls, doc_id):
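        """Return True if the document has been cancelled or its progress
        marks a failure."""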
try:
_, doc = DocumentService.get_by_id(doc_id)
return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
        except Exception:
pass
return False
def queue_raptor_tasks(doc):
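    """Create a RAPTOR task for the document and push it onto the Redis task queue."""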
def new_task():
nonlocal doc
return {
"id": get_uuid(),
"doc_id": doc["id"],
"from_page": 0,
"to_page": -1,
"progress_msg": "Start to do RAPTOR (Recursive Abstractive Processing For Tree-Organized Retrieval)."
}
task = new_task()
bulk_insert_into_db(Task, [task], True)
task["type"] = "raptor"
    assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis status."
def doc_upload_and_parse(conversation_id, file_objs, user_id):
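    """Upload files into the knowledgebase bound to a conversation and index them.

    Chunks each file with the parser matching its type, embeds the chunks,
    builds a mind-map chunk per document (except pictures), bulk-indexes
    everything into Elasticsearch and updates the per-document chunk/token
    counters. Returns the ids of the newly created documents.
    """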
from rag.app import presentation, picture, naive, audio, email
from api.db.services.dialog_service import ConversationService, DialogService
from api.db.services.file_service import FileService
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from api.db.services.api_service import API4ConversationService
e, conv = ConversationService.get_by_id(conversation_id)
if not e:
e, conv = API4ConversationService.get_by_id(conversation_id)
assert e, "Conversation not found!"
e, dia = DialogService.get_by_id(conv.dialog_id)
kb_id = dia.kb_ids[0]
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
raise LookupError("Can't find this knowledgebase!")
idxnm = search.index_name(kb.tenant_id)
if not ELASTICSEARCH.indexExist(idxnm):
ELASTICSEARCH.createIdx(idxnm, json.load(
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
err, files = FileService.upload_document(kb, file_objs, user_id)
assert not err, "\n".join(err)
def dummy(prog=None, msg=""):
pass
FACTORY = {
ParserType.PRESENTATION.value: presentation,
ParserType.PICTURE.value: picture,
ParserType.AUDIO.value: audio,
ParserType.EMAIL.value: email
}
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": False}
exe = ThreadPoolExecutor(max_workers=12)
threads = []
doc_nm = {}
for d, blob in files:
doc_nm[d["id"]] = d["name"]
for d, blob in files:
kwargs = {
"callback": dummy,
"parser_config": parser_config,
"from_page": 0,
"to_page": 100000,
"tenant_id": kb.tenant_id,
"lang": kb.language
}
threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs))
for (docinfo, _), th in zip(files, threads):
docs = []
doc = {
"doc_id": docinfo["id"],
"kb_id": [kb.id]
}
for ck in th.result():
d = deepcopy(doc)
d.update(ck)
md5 = hashlib.md5()
md5.update((ck["content_with_weight"] +
str(d["doc_id"])).encode("utf-8"))
d["_id"] = md5.hexdigest()
d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.now().timestamp()
if not d.get("image"):
docs.append(d)
continue
output_buffer = BytesIO()
if isinstance(d["image"], bytes):
output_buffer = BytesIO(d["image"])
else:
d["image"].save(output_buffer, format='JPEG')
STORAGE_IMPL.put(kb.id, d["_id"], output_buffer.getvalue())
d["img_id"] = "{}-{}".format(kb.id, d["_id"])
del d["image"]
docs.append(d)
parser_ids = {d["id"]: d["parser_id"] for d, _ in files}
docids = [d["id"] for d, _ in files]
chunk_counts = {id: 0 for id in docids}
token_counts = {id: 0 for id in docids}
es_bulk_size = 64
def embedding(doc_id, cnts, batch_size=16):
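        """Embed chunk contents in batches with the knowledgebase's embedding
        model, accumulating per-document chunk and token counts."""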
nonlocal embd_mdl, chunk_counts, token_counts
vects = []
for i in range(0, len(cnts), batch_size):
vts, c = embd_mdl.encode(cnts[i: i + batch_size])
vects.extend(vts.tolist())
chunk_counts[doc_id] += len(cnts[i:i + batch_size])
token_counts[doc_id] += c
return vects
_, tenant = TenantService.get_by_id(kb.tenant_id)
llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
for doc_id in docids:
cks = [c for c in docs if c["doc_id"] == doc_id]
if parser_ids[doc_id] != ParserType.PICTURE.value:
mindmap = MindMapExtractor(llm_bdl)
try:
mind_map = json.dumps(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output,
ensure_ascii=False, indent=2)
                if len(mind_map) < 32:
                    raise Exception("Mind map content is too short: " + mind_map)
cks.append({
"id": get_uuid(),
"doc_id": doc_id,
"kb_id": [kb.id],
"docnm_kwd": doc_nm[doc_id],
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
"content_ltks": "",
"content_with_weight": mind_map,
"knowledge_graph_kwd": "mind_map"
})
except Exception as e:
                stat_logger.error("Mind map generation error: %s", traceback.format_exc())
vects = embedding(doc_id, [c["content_with_weight"] for c in cks])
assert len(cks) == len(vects)
for i, d in enumerate(cks):
v = vects[i]
d["q_%d_vec" % len(v)] = v
for b in range(0, len(cks), es_bulk_size):
ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], idxnm)
DocumentService.increment_chunk_num(
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
return [d["id"] for d,_ in files]