# SOAPAssistV00 / gpt_index /docstore.py
# AbeerTrial's picture
# Duplicate from AbeerTrial/SOAPAssist
# 35b22df
# raw
# history blame
# 4.2 kB
"""Document store."""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Type, Union
from dataclasses_json import DataClassJsonMixin
from gpt_index.data_structs.data_structs import IndexStruct
from gpt_index.readers.schema.base import Document
# Type alias for the objects this module stores: either an index struct
# or a plain document.
DOC_TYPE = Union[IndexStruct, Document]
# Key added to each serialized document dict to record the document's type
# (the value of its get_type()), so the matching class can be selected when
# the dict is loaded back.
TYPE_KEY = "__type__"
@dataclass
class DocumentStore(DataClassJsonMixin):
    """In-memory store of documents and index structs keyed by doc id.

    Entries can be converted to and from plain dictionaries. Each serialized
    entry carries its type name under ``TYPE_KEY`` so that the matching class
    can be chosen again when the dictionary is loaded.
    """

    # Mapping from doc id to the stored document or index struct.
    docs: Dict[str, DOC_TYPE] = field(default_factory=dict)

    def serialize_to_dict(self) -> Dict[str, Any]:
        """Serialize the store to a plain dict.

        Returns:
            ``{"docs": {doc_id: doc_dict}}`` where each ``doc_dict`` is the
            entry's ``to_dict()`` output with its ``get_type()`` value added
            under ``TYPE_KEY``.
        """
        docs_dict: Dict[str, Any] = {}
        for doc_id, doc in self.docs.items():
            doc_dict = doc.to_dict()
            doc_dict[TYPE_KEY] = doc.get_type()
            docs_dict[doc_id] = doc_dict
        return {"docs": docs_dict}

    def contains_index_struct(self, exclude_ids: Optional[List[str]] = None) -> bool:
        """Check whether the store holds any index struct.

        Args:
            exclude_ids: doc ids to ignore while searching. ``None`` (the
                default) ignores nothing.

        Returns:
            True if at least one stored ``IndexStruct`` has a doc id that is
            not listed in ``exclude_ids``.
        """
        excluded = set(exclude_ids or ())
        return any(
            isinstance(doc, IndexStruct) and doc.get_doc_id() not in excluded
            for doc in self.docs.values()
        )

    @classmethod
    def load_from_dict(
        cls,
        docs_dict: Dict[str, Any],
        type_to_struct: Optional[Dict[str, Type[IndexStruct]]] = None,
    ) -> "DocumentStore":
        """Build a store from a dict shaped like ``serialize_to_dict`` output.

        The input dictionary is left unmodified, so the same payload can be
        loaded more than once.

        Args:
            docs_dict: dict with a ``"docs"`` key mapping doc ids to
                serialized entries.
            type_to_struct: mapping from type name to the ``IndexStruct``
                class used to decode entries of that type. Needed only when
                the payload contains entries that are not plain documents.

        Returns:
            A new store holding the decoded entries.

        Raises:
            ValueError: if a non-document entry is found and
                ``type_to_struct`` is None or has no class for its type.
        """
        docs_obj_dict: Dict[str, DOC_TYPE] = {}
        for doc_id, doc_dict in docs_dict["docs"].items():
            # Read the type tag without popping it: removing it from the
            # caller's dict would make a second load of the same payload
            # decode every entry as a plain Document.
            doc_type = doc_dict.get(TYPE_KEY)
            payload = {
                key: value for key, value in doc_dict.items() if key != TYPE_KEY
            }
            # Entries without a type tag are treated as plain documents.
            if doc_type == "Document" or doc_type is None:
                doc: DOC_TYPE = Document.from_dict(payload)
            else:
                if type_to_struct is None:
                    raise ValueError(
                        "type_to_struct must be provided if type is index struct."
                    )
                if doc_type not in type_to_struct:
                    raise ValueError(
                        f"doc_type {doc_type} not found in type_to_struct. "
                        "Make sure that it was registered in the index registry."
                    )
                doc = type_to_struct[doc_type].from_dict(payload)
            docs_obj_dict[doc_id] = doc
        return cls(docs=docs_obj_dict)

    @classmethod
    def from_documents(cls, docs: List[DOC_TYPE]) -> "DocumentStore":
        """Create a new store populated with ``docs``.

        Raises:
            ValueError: under the same conditions as ``add_documents`` with
                ``allow_update=False``.
        """
        obj = cls()
        obj.add_documents(docs)
        return obj

    def update_docstore(self, other: "DocumentStore") -> None:
        """Merge the entries of ``other`` into this store.

        Entries in ``other`` replace entries here that share the same doc id.
        """
        self.docs.update(other.docs)

    def add_documents(self, docs: List[DOC_TYPE], allow_update: bool = False) -> None:
        """Add documents to the store.

        Documents are added one at a time, so entries that precede a failing
        document in ``docs`` are already stored when the error is raised.

        Args:
            docs: documents or index structs to add; each must have a doc id.
            allow_update: when True, an entry whose doc id is already present
                is overwritten instead of raising.

        Raises:
            ValueError: if a document has no doc id, or if its doc id already
                exists and ``allow_update`` is False.
        """
        for doc in docs:
            if doc.is_doc_id_none:
                raise ValueError("doc_id not set")
            doc_id = doc.get_doc_id()
            # An existing entry is only overwritten when allow_update is True.
            if not allow_update and self.document_exists(doc_id):
                raise ValueError(
                    f"doc_id {doc_id} already exists. "
                    "Set allow_update to True to overwrite."
                )
            self.docs[doc_id] = doc

    def get_document(self, doc_id: str, raise_error: bool = True) -> Optional[DOC_TYPE]:
        """Get a document from the store.

        Args:
            doc_id: id of the entry to look up.
            raise_error: when True, a missing id raises instead of returning
                None.

        Returns:
            The stored entry, or None if it is missing and ``raise_error`` is
            False.

        Raises:
            ValueError: if ``doc_id`` is not found and ``raise_error`` is True.
        """
        doc = self.docs.get(doc_id, None)
        if doc is None and raise_error:
            raise ValueError(f"doc_id {doc_id} not found.")
        return doc

    def document_exists(self, doc_id: str) -> bool:
        """Check whether an entry with ``doc_id`` is in the store."""
        return doc_id in self.docs

    def delete_document(
        self, doc_id: str, raise_error: bool = True
    ) -> Optional[DOC_TYPE]:
        """Remove a document from the store and return it.

        Args:
            doc_id: id of the entry to remove.
            raise_error: when True, a missing id raises instead of returning
                None.

        Returns:
            The removed entry, or None if it was missing and ``raise_error``
            is False.

        Raises:
            ValueError: if ``doc_id`` is not found and ``raise_error`` is True.
        """
        doc = self.docs.pop(doc_id, None)
        if doc is None and raise_error:
            raise ValueError(f"doc_id {doc_id} not found.")
        return doc

    def __len__(self) -> int:
        """Return the number of stored entries."""
        return len(self.docs)