AbeerTrial's picture
Duplicate from AbeerTrial/SOAPAssist
35b22df
raw
history blame
12 kB
"""File for core data structures."""
import random
import sys
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set, Tuple
from dataclasses_json import DataClassJsonMixin
from gpt_index.data_structs.struct_type import IndexStructType
from gpt_index.schema import BaseDocument
from gpt_index.utils import get_new_int_id
@dataclass
class IndexStruct(BaseDocument, DataClassJsonMixin):
    """Base data struct shared by every LlamaIndex index structure.

    Two fields inherited from ``BaseDocument`` carry index-specific meaning:

    * ``text`` holds a summary of the content of the index struct. It is
      primarily used when composing this index with other indices.
    * ``doc_id`` is a unique identifier for the index struct, used when the
      struct is put in the docstore. Not every index struct needs one: only
      structs that represent a complete data structure (e.g. ``IndexGraph``,
      ``IndexList``) and are used to compose a higher-level index will have
      a ``doc_id``.
    """
@dataclass
class Node(IndexStruct):
    """A generic node of data.

    This is the base struct used by most indices. Unlike other index
    structs, a ``Node`` must always have its ``text`` field set.
    """

    # Position of this node, and the positions of its children
    # (used for GPTTreeIndex).
    index: int = 0
    child_indices: Set[int] = field(default_factory=set)
    # Optional embedding vector for this node.
    embedding: Optional[List[float]] = None
    # Reference document id.
    ref_doc_id: Optional[str] = None
    # Extra, free-form node info.
    node_info: Optional[Dict[str, Any]] = None

    def __post_init__(self) -> None:
        """Run the parent post-init, then require that ``text`` is set.

        Raises:
            ValueError: if ``text`` is ``None``.
        """
        super().__post_init__()
        if self.text is None:
            raise ValueError("text field not set.")

    def get_text(self) -> str:
        """Return the node text, prefixed by ``extra_info_str`` when present.

        The prefix and the text are separated by a blank line.
        """
        body = super().get_text()
        if self.extra_info_str is None:
            return body
        return f"{self.extra_info_str}\n\n{body}"

    @classmethod
    def get_type(cls) -> str:
        """Return the type tag for nodes."""
        # TODO: consolidate with IndexStructType
        return "node"
@dataclass
class IndexGraph(IndexStruct):
    """A graph of ``Node`` objects representing a tree-structured index."""

    # Every node in the graph, keyed by ``Node.index``.
    all_nodes: Dict[int, Node] = field(default_factory=dict)
    # Top-level nodes, keyed by ``Node.index``; insert_under_parent puts
    # nodes inserted without a parent here.
    root_nodes: Dict[int, Node] = field(default_factory=dict)

    @property
    def size(self) -> int:
        """Total number of nodes in the graph."""
        return len(self.all_nodes)

    def get_children(self, parent_node: Optional[Node]) -> Dict[int, Node]:
        """Return the children of ``parent_node``, keyed by index.

        When ``parent_node`` is ``None``, the ``root_nodes`` mapping itself
        is returned.
        """
        if parent_node is None:
            return self.root_nodes
        children: Dict[int, Node] = {}
        for child_idx in parent_node.child_indices:
            children[child_idx] = self.all_nodes[child_idx]
        return children

    def insert_under_parent(self, node: Node, parent_node: Optional[Node]) -> None:
        """Add ``node`` beneath ``parent_node``, or as a root if it is ``None``.

        Raises:
            ValueError: if a node with the same index is already in the graph.
        """
        new_idx = node.index
        if new_idx in self.all_nodes:
            raise ValueError(
                "Cannot insert a new node with the same index as an existing node."
            )
        if parent_node is not None:
            # The parent records the link; the child itself is only stored
            # in all_nodes.
            parent_node.child_indices.add(new_idx)
        else:
            self.root_nodes[new_idx] = node
        self.all_nodes[new_idx] = node

    @classmethod
    def get_type(cls) -> str:
        """Return the type tag for tree index structs."""
        return "tree"
@dataclass
class KeywordTable(IndexStruct):
    """Keyword table mapping each keyword to the ids of its text chunks."""

    # keyword -> set of chunk ids (keys of ``text_chunks``)
    table: Dict[str, Set[int]] = field(default_factory=dict)
    # chunk id -> node holding that chunk
    text_chunks: Dict[int, Node] = field(default_factory=dict)

    def _get_index(self) -> int:
        """Draw random ids in ``[0, sys.maxsize]`` until an unused one appears."""
        candidate = random.randint(0, sys.maxsize)
        while candidate in self.text_chunks:
            candidate = random.randint(0, sys.maxsize)
        return candidate

    def add_node(self, keywords: List[str], node: Node) -> int:
        """Store ``node`` under a fresh chunk id and index it by each keyword.

        Returns:
            The chunk id assigned to ``node``.
        """
        chunk_id = self._get_index()
        for kw in keywords:
            self.table.setdefault(kw, set()).add(chunk_id)
        self.text_chunks[chunk_id] = node
        return chunk_id

    def get_texts(self, keyword: str) -> List[str]:
        """Return the text of every chunk indexed under ``keyword``.

        Raises:
            ValueError: if ``keyword`` is not in the table.
        """
        if keyword not in self.table:
            raise ValueError("Keyword not found in table.")
        chunk_ids = self.table[keyword]
        return [self.text_chunks[chunk_id].get_text() for chunk_id in chunk_ids]

    @property
    def keywords(self) -> Set[str]:
        """All keywords currently in the table."""
        return set(self.table)

    @property
    def size(self) -> int:
        """Number of distinct keywords in the table."""
        return len(self.table)

    @classmethod
    def get_type(cls) -> str:
        """Return the type tag for keyword tables."""
        return "keyword_table"
@dataclass
class IndexList(IndexStruct):
    """An ordered list of nodes."""

    # Nodes in insertion order.
    nodes: List[Node] = field(default_factory=list)

    def add_node(self, node: Node) -> None:
        """Append ``node`` to the end of the list.

        Nothing is returned; the new node's position is ``len(self.nodes) - 1``
        immediately after the call.
        """
        # don't worry about child indices for now, nodes are all in order
        self.nodes.append(node)

    @classmethod
    def get_type(cls) -> str:
        """Return the type tag for list index structs."""
        return "list"
@dataclass
class IndexDict(IndexStruct):
    """A dictionary of nodes addressable by caller-facing string ids.

    Nodes are stored in ``nodes_dict`` under an internal integer id, and
    ``id_map`` translates each string ``text_id`` to that integer id.
    """

    # internal int id -> node
    nodes_dict: Dict[int, Node] = field(default_factory=dict)
    # text_id (str) -> internal int id (a key of nodes_dict)
    id_map: Dict[str, int] = field(default_factory=dict)

    # TODO: temporary hack to store embeddings for simple vector index
    # this should be empty for all other indices
    embeddings_dict: Dict[str, List[float]] = field(default_factory=dict)

    def add_node(
        self,
        node: Node,
        text_id: Optional[str] = None,
    ) -> str:
        """Add ``node`` under a new internal id and return its text id.

        Args:
            node: The node to store.
            text_id: Optional caller-chosen string id. When omitted, the
                string form of the new internal int id is used.

        Returns:
            The text id under which ``node`` can be retrieved.

        Raises:
            ValueError: if ``text_id`` is not a string or already exists.
        """
        # Check the type before the membership test so that non-string
        # (including unhashable) ids get the intended ValueError.
        if text_id is not None and not isinstance(text_id, str):
            raise ValueError("text_id must be a string.")
        if text_id in self.id_map:
            raise ValueError("text_id cannot already exist in index.")

        used_int_ids = set(self.nodes_dict.keys())
        int_id = get_new_int_id(used_int_ids)
        if text_id is None:
            # The default text id is str(int_id). If a caller earlier supplied
            # that same string as an explicit text_id, reusing it would
            # silently overwrite its id_map entry and orphan its node, so draw
            # again until the default id is free. This relies on
            # get_new_int_id avoiding the ids in the set it is given, the same
            # assumption that keeps nodes_dict keys unique.
            while str(int_id) in self.id_map:
                used_int_ids.add(int_id)
                int_id = get_new_int_id(used_int_ids)
            text_id = str(int_id)

        self.id_map[text_id] = int_id
        # don't worry about child indices for now, nodes are all in order
        self.nodes_dict[int_id] = node
        return text_id

    def get_nodes(self, text_ids: List[str]) -> List[Node]:
        """Look up nodes by text id, preserving the order of ``text_ids``.

        Raises:
            ValueError: if an id is not a string, is not in ``id_map``, or
                maps to an int id missing from ``nodes_dict``.
        """
        nodes: List[Node] = []
        for text_id in text_ids:
            # Type check first: id_map only ever receives string keys from
            # add_node, so a membership-first check made this unreachable.
            if not isinstance(text_id, str):
                raise ValueError("text_id must be a string.")
            if text_id not in self.id_map:
                raise ValueError("text_id not found in id_map")
            int_id = self.id_map[text_id]
            if int_id not in self.nodes_dict:
                raise ValueError("int_id not found in nodes_dict")
            nodes.append(self.nodes_dict[int_id])
        return nodes

    def get_node(self, text_id: str) -> Node:
        """Look up a single node by text id (same errors as ``get_nodes``)."""
        return self.get_nodes([text_id])[0]

    def delete(self, doc_id: str) -> None:
        """Remove every node whose ``ref_doc_id`` equals ``doc_id``.

        Both the ``id_map`` entry and the corresponding ``nodes_dict`` entry
        are removed.
        """
        text_ids_to_delete = [
            text_id
            for text_id, int_id in self.id_map.items()
            if self.nodes_dict[int_id].ref_doc_id == doc_id
        ]
        # Remove each text id together with the int id it actually maps to,
        # so both dicts stay consistent even if several text ids share one
        # int id (pop with a default tolerates the second removal).
        for text_id in text_ids_to_delete:
            int_id = self.id_map.pop(text_id)
            self.nodes_dict.pop(int_id, None)

    @classmethod
    def get_type(cls) -> str:
        """Return the struct type tag for index dicts."""
        return IndexStructType.VECTOR_STORE
@dataclass
class KG(IndexStruct):
    """A knowledge graph built from (subject, relationship, object) triplets.

    Keywords (triplet subjects and objects) map to the ids of the nodes the
    triplets were upserted with, and ``rel_map`` records each subject's
    outgoing relationships.
    """

    # keyword -> set of node ids (``Node.get_doc_id()``) added under it
    table: Dict[str, Set[str]] = field(default_factory=dict)
    # node id -> node
    text_chunks: Dict[str, Node] = field(default_factory=dict)
    # subject -> list of (object, relationship) edges. Unidirectional: an
    # edge is stored only under its subject.
    rel_map: Dict[str, List[Tuple[str, str]]] = field(default_factory=dict)
    # triplet string -> embedding
    embedding_dict: Dict[str, List[float]] = field(default_factory=dict)

    def add_to_embedding_dict(self, triplet_str: str, embedding: List[float]) -> None:
        """Store (or overwrite) the embedding for ``triplet_str``."""
        self.embedding_dict[triplet_str] = embedding

    def upsert_triplet(self, triplet: Tuple[str, str, str], node: Node) -> None:
        """Upsert a knowledge triplet, sourced from ``node``, into the graph.

        ``node`` is indexed under both the subject and the object. The
        ``(object, relationship)`` edge is added under the subject only if it
        is not already there, so re-upserting a triplet creates no duplicates.
        """
        subj, relationship, obj = triplet
        self.add_node([subj, obj], node)
        edges = self.rel_map.setdefault(subj, [])
        edge = (obj, relationship)
        if edge not in edges:
            edges.append(edge)

    def add_node(self, keywords: List[str], node: Node) -> None:
        """Index ``node`` (by its doc id) under every keyword in ``keywords``."""
        node_id = node.get_doc_id()
        for kw in keywords:
            self.table.setdefault(kw, set()).add(node_id)
        self.text_chunks[node_id] = node

    def get_rel_map_texts(self, keyword: str) -> List[str]:
        """Return ``str((keyword, rel, obj))`` for each edge out of ``keyword``."""
        # NOTE: return a single node for now
        return [str((keyword, rel, obj)) for obj, rel in self.rel_map.get(keyword, [])]

    def get_rel_map_tuples(self, keyword: str) -> List[Tuple[str, str]]:
        """Return the ``(object, relationship)`` edges out of ``keyword``.

        The stored list itself is returned (not a copy); ``[]`` if none.
        """
        # NOTE: return a single node for now
        return self.rel_map.get(keyword, [])

    def get_node_ids(self, keyword: str, depth: int = 1) -> List[str]:
        """Return ids of nodes indexed under ``keyword`` or its direct objects.

        An id appears once per searched keyword it is indexed under, so the
        result may contain duplicates.

        Raises:
            ValueError: if ``depth`` is greater than 1.
        """
        if depth > 1:
            raise ValueError("Depth > 1 not supported yet.")
        if keyword not in self.table:
            return []
        # Some keywords only appear as objects (leaves) and have no rel_map
        # entry; .get covers that case.
        search_keywords = [keyword]
        search_keywords.extend(child for child, _ in self.rel_map.get(keyword, []))
        node_ids: List[str] = []
        for kw in search_keywords:
            node_ids.extend(self.table.get(kw, set()))
        # TODO: Traverse (with depth > 1)
        return node_ids

    @classmethod
    def get_type(cls) -> str:
        """Return the type tag for knowledge-graph structs."""
        return "kg"
# TODO: remove once we centralize UX around vector index
class SimpleIndexDict(IndexDict):
    """``IndexDict`` variant for the simple vector index."""

    @classmethod
    def get_type(cls) -> str:
        """Return ``IndexStructType.SIMPLE_DICT``."""
        struct_type = IndexStructType.SIMPLE_DICT
        return struct_type
class FaissIndexDict(IndexDict):
    """``IndexDict`` variant for the Faiss vector index."""

    @classmethod
    def get_type(cls) -> str:
        """Return ``IndexStructType.DICT``."""
        struct_type = IndexStructType.DICT
        return struct_type
class WeaviateIndexDict(IndexDict):
    """``IndexDict`` variant for the Weaviate vector index."""

    @classmethod
    def get_type(cls) -> str:
        """Return ``IndexStructType.WEAVIATE``."""
        struct_type = IndexStructType.WEAVIATE
        return struct_type
class PineconeIndexDict(IndexDict):
    """``IndexDict`` variant for the Pinecone vector index."""

    @classmethod
    def get_type(cls) -> str:
        """Return ``IndexStructType.PINECONE``."""
        struct_type = IndexStructType.PINECONE
        return struct_type
class QdrantIndexDict(IndexDict):
    """``IndexDict`` variant for the Qdrant vector index."""

    @classmethod
    def get_type(cls) -> str:
        """Return ``IndexStructType.QDRANT``."""
        struct_type = IndexStructType.QDRANT
        return struct_type
class ChromaIndexDict(IndexDict):
    """``IndexDict`` variant for the Chroma vector index."""

    @classmethod
    def get_type(cls) -> str:
        """Return ``IndexStructType.CHROMA``."""
        struct_type = IndexStructType.CHROMA
        return struct_type
class OpensearchIndexDict(IndexDict):
    """``IndexDict`` variant for the Opensearch vector index."""

    @classmethod
    def get_type(cls) -> str:
        """Return ``IndexStructType.OPENSEARCH``."""
        struct_type = IndexStructType.OPENSEARCH
        return struct_type
class EmptyIndex(IndexStruct):
    """Index struct that holds no data of its own."""

    @classmethod
    def get_type(cls) -> str:
        """Return ``IndexStructType.EMPTY``."""
        struct_type = IndexStructType.EMPTY
        return struct_type