#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
from peewee import Expression
from elasticsearch_dsl import Q

from rag.utils.es_conn import ELASTICSEARCH
from rag.utils.minio_conn import MINIO
from rag.nlp import search
from api.db import FileType, TaskStatus, StatusEnum
from api.db.db_models import DB, Knowledgebase, Tenant, Document
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService


class DocumentService(CommonService):
    """CRUD and bookkeeping helpers for ``Document`` rows.

    Besides plain persistence, these helpers keep the owning
    ``Knowledgebase`` counters (``doc_num``, ``token_num``, ``chunk_num``)
    in sync and clean up the external stores (Elasticsearch chunk index,
    MinIO object storage) when a document is removed.
    """

    model = Document

    @classmethod
    @DB.connection_context()
    def get_by_kb_id(cls, kb_id, page_number, items_per_page,
                     orderby, desc, keywords):
        """Return one page of a knowledgebase's documents plus the total count.

        Args:
            kb_id: id of the owning knowledgebase.
            page_number: 1-based page index.
            items_per_page: page size.
            orderby: model field name to sort by.
            desc: True for descending order.
            keywords: optional substring filter on the document name.

        Returns:
            Tuple ``(rows, count)``: the requested page as a list of dicts
            and the total number of matching rows.
        """
        if keywords:
            # f-strings need no %-escaping; the original "%%...%%" produced
            # doubled LIKE wildcards (equivalent match, but confusing).
            docs = cls.model.select().where(
                cls.model.kb_id == kb_id,
                cls.model.name.like(f"%{keywords}%"))
        else:
            docs = cls.model.select().where(cls.model.kb_id == kb_id)
        count = docs.count()
        if desc:
            docs = docs.order_by(cls.model.getter_by(orderby).desc())
        else:
            docs = docs.order_by(cls.model.getter_by(orderby).asc())
        docs = docs.paginate(page_number, items_per_page)
        return list(docs.dicts()), count

    @classmethod
    @DB.connection_context()
    def insert(cls, doc):
        """Persist a new document and bump the knowledgebase's doc counter.

        Args:
            doc: dict of ``Document`` fields; must contain "id" and "kb_id".

        Returns:
            The freshly stored ``Document`` model instance.

        Raises:
            RuntimeError: on any database failure.
        """
        if not cls.save(**doc):
            raise RuntimeError("Database error (Document)!")
        e, doc = cls.get_by_id(doc["id"])
        if not e:
            raise RuntimeError("Database error (Document retrieval)!")
        e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
        # Guard the lookup explicitly instead of crashing with an
        # AttributeError on kb.id when the knowledgebase is missing.
        if not e:
            raise RuntimeError("Database error (Knowledgebase retrieval)!")
        if not KnowledgebaseService.update_by_id(
                kb.id, {"doc_num": kb.doc_num + 1}):
            raise RuntimeError("Database error (Knowledgebase)!")
        return doc

    @classmethod
    @DB.connection_context()
    def delete(cls, doc):
        """Delete a document row and decrement the knowledgebase's doc counter.

        Args:
            doc: the ``Document`` model instance to delete.

        Returns:
            Number of rows deleted (0 or 1).

        Raises:
            RuntimeError: if the knowledgebase cannot be found or updated.
        """
        e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
        # Same explicit guard as insert(): avoid AttributeError on a
        # missing knowledgebase.
        if not e:
            raise RuntimeError("Database error (Knowledgebase retrieval)!")
        if not KnowledgebaseService.update_by_id(
                kb.id, {"doc_num": kb.doc_num - 1}):
            raise RuntimeError("Database error (Knowledgebase)!")
        return cls.delete_by_id(doc.id)

    @classmethod
    @DB.connection_context()
    def remove_document(cls, doc, tenant_id):
        """Fully remove a document: ES chunks, counters, DB row, stored blob.

        Args:
            doc: the ``Document`` model instance to remove.
            tenant_id: tenant owning the document (selects the ES index).

        Returns:
            Number of DB rows deleted (0 or 1).

        Raises:
            RuntimeError: if the database row could not be deleted.
        """
        ELASTICSEARCH.deleteByQuery(
            Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
        # Roll the kb/doc token and chunk counters back before the row
        # disappears (negative deltas).
        cls.increment_chunk_num(
            doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, 0)
        deleted = cls.delete(doc)
        if not deleted:
            raise RuntimeError("Database error (Document removal)!")
        MINIO.rm(doc.kb_id, doc.location)
        # The original issued a redundant second delete_by_id() here, which
        # always returned 0 (falsy) on success; return the real count.
        return deleted

    @classmethod
    @DB.connection_context()
    def get_newly_uploaded(cls, tm, mod=0, comm=1, items_per_page=64):
        """List documents waiting to be parsed, for one worker shard.

        Args:
            tm: minimum ``update_time``; older rows are skipped.
            mod: this worker's shard index.
            comm: total number of worker shards.
            items_per_page: maximum number of rows returned.

        Returns:
            Up to ``items_per_page`` pending rows (oldest first), joined with
            the tenant's model ids, as a list of dicts.
        """
        fields = [
            cls.model.id,
            cls.model.kb_id,
            cls.model.parser_id,
            cls.model.parser_config,
            cls.model.name,
            cls.model.type,
            cls.model.location,
            cls.model.size,
            Knowledgebase.tenant_id,
            Tenant.embd_id,
            Tenant.img2txt_id,
            Tenant.asr_id,
            cls.model.update_time]
        docs = cls.model.select(*fields) \
            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
            .where(
                cls.model.status == StatusEnum.VALID.value,
                ~(cls.model.type == FileType.VIRTUAL.value),
                cls.model.progress == 0,
                cls.model.update_time >= tm,
                cls.model.run == TaskStatus.RUNNING.value,
                # "%%" is the driver-escaped modulo operator: shard rows by
                # create_time % comm == mod so workers split the backlog.
                (Expression(cls.model.create_time, "%%", comm) == mod)) \
            .order_by(cls.model.update_time.asc()) \
            .paginate(1, items_per_page)
        return list(docs.dicts())

    @classmethod
    @DB.connection_context()
    def get_unfinished_docs(cls):
        """Return id and start time of documents still being processed.

        A document counts as unfinished when 0 < progress < 1.
        """
        fields = [cls.model.id, cls.model.process_begin_at]
        docs = cls.model.select(*fields) \
            .where(
                cls.model.status == StatusEnum.VALID.value,
                ~(cls.model.type == FileType.VIRTUAL.value),
                cls.model.progress < 1,
                cls.model.progress > 0)
        return list(docs.dicts())

    @classmethod
    @DB.connection_context()
    def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
        """Add deltas to a document's (and its kb's) token/chunk counters.

        Args:
            doc_id: document to update.
            kb_id: owning knowledgebase to update.
            token_num: signed token-count delta.
            chunk_num: signed chunk-count delta.
            duation: signed processing-duration delta ("duation" [sic] kept
                for interface compatibility; it mirrors the DB column
                ``process_duation``).

        Returns:
            Number of knowledgebase rows updated.

        Raises:
            LookupError: if no document with ``doc_id`` exists.
        """
        num = cls.model.update(
            token_num=cls.model.token_num + token_num,
            chunk_num=cls.model.chunk_num + chunk_num,
            process_duation=cls.model.process_duation + duation).where(
            cls.model.id == doc_id).execute()
        if num == 0:
            raise LookupError(
                "Document not found which is supposed to be there")
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num + token_num,
            chunk_num=Knowledgebase.chunk_num + chunk_num).where(
            Knowledgebase.id == kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def get_tenant_id(cls, doc_id):
        """Resolve the tenant that owns a document, or None if not found.

        Only knowledgebases with VALID status are considered.
        """
        docs = cls.model.select(
            Knowledgebase.tenant_id).join(
            Knowledgebase, on=(
                Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.id == doc_id,
            Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["tenant_id"]

    @classmethod
    @DB.connection_context()
    def get_thumbnails(cls, docids):
        """Return ``{id, thumbnail}`` dicts for the given document ids."""
        fields = [cls.model.id, cls.model.thumbnail]
        return list(cls.model.select(
            *fields).where(cls.model.id.in_(docids)).dicts())

    @classmethod
    @DB.connection_context()
    def update_parser_config(cls, id, config):
        """Deep-merge ``config`` into a document's stored parser_config.

        New keys are added; nested dicts are merged recursively; scalar
        values in ``config`` overwrite the stored ones.

        Raises:
            LookupError: if no document with ``id`` exists.
        """
        e, d = cls.get_by_id(id)
        if not e:
            raise LookupError(f"Document({id}) not found.")

        def dfs_update(old, new):
            # Recursive merge: descend into dicts, overwrite scalars.
            for k, v in new.items():
                if k not in old:
                    old[k] = v
                    continue
                if isinstance(v, dict):
                    assert isinstance(old[k], dict)
                    dfs_update(old[k], v)
                else:
                    old[k] = v

        dfs_update(d.parser_config, config)
        cls.update_by_id(id, {"parser_config": d.parser_config})

    @classmethod
    @DB.connection_context()
    def get_doc_count(cls, tenant_id):
        """Count all documents owned by a tenant.

        Uses a server-side COUNT(*) instead of materializing every id row
        (the original ``len(query)`` fetched them all).
        """
        return cls.model.select(cls.model.id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
            Knowledgebase.tenant_id == tenant_id).count()