"""Chroma vector store.""" | |
import logging | |
import math | |
from typing import Any, List, Optional, cast | |
from gpt_index.data_structs.data_structs import Node | |
from gpt_index.utils import truncate_text | |
from gpt_index.vector_stores.types import ( | |
NodeEmbeddingResult, | |
VectorStore, | |
VectorStoreQueryResult, | |
) | |


class ChromaVectorStore(VectorStore):
    """Chroma vector store.

    In this vector store, embeddings are stored within a ChromaDB collection.

    During query time, the index uses ChromaDB to query for the top
    k most similar nodes.

    Args:
        chroma_collection (chromadb.api.models.Collection.Collection):
            ChromaDB collection instance

    """

    stores_text: bool = True

    def __init__(self, chroma_collection: Any, **kwargs: Any) -> None:
        """Init params."""
        import_err_msg = (
            "`chromadb` package not found, please run `pip install chromadb`"
        )
        try:
            import chromadb  # noqa: F401
        except ImportError:
            raise ImportError(import_err_msg)
        from chromadb.api.models.Collection import Collection

        self._collection = cast(Collection, chroma_collection)

    def config_dict(self) -> dict:
        """Return config dict."""
        return {}

    def add(self, embedding_results: List[NodeEmbeddingResult]) -> List[str]:
        """Add embedding results to index.

        Args:
            embedding_results (List[NodeEmbeddingResult]): list of embedding results

        """
        if not self._collection:
            raise ValueError("Collection not initialized")

        embeddings = []
        metadatas = []
        ids = []
        documents = []
        for result in embedding_results:
            embeddings.append(result.embedding)
            metadatas.append({"document_id": result.doc_id})
            ids.append(result.id)
            documents.append(result.node.get_text())

        self._collection.add(
            embeddings=embeddings,
            ids=ids,
            metadatas=metadatas,
            documents=documents,
        )
        return ids

    def delete(self, doc_id: str, **delete_kwargs: Any) -> None:
        """Delete a document.

        Args:
            doc_id (str): document id

        """
        self._collection.delete(where={"document_id": doc_id})

    def client(self) -> Any:
        """Return client."""
        return self._collection

    def query(
        self,
        query_embedding: List[float],
        similarity_top_k: int,
        doc_ids: Optional[List[str]] = None,
    ) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        Args:
            query_embedding (List[float]): query embedding
            similarity_top_k (int): top k most similar nodes

        """
        results = self._collection.query(
            query_embeddings=query_embedding, n_results=similarity_top_k
        )

        logging.debug(f"> Top {len(results['ids'][0])} nodes:")
        nodes = []
        similarities = []
        ids = []
        # Chroma returns one result list per query embedding; a single query
        # embedding is passed above, so iterate over the first (and only) entry.
        for node_id, text, metadata, distance in zip(
            results["ids"][0],
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
        ):
            node = Node(
                doc_id=node_id,
                text=text,
                extra_info=metadata,
                ref_doc_id=metadata["document_id"],
            )
            nodes.append(node)

            # Map the raw distance to a bounded similarity score in (0, 1].
            similarity_score = 1.0 - math.exp(-distance)
            similarities.append(similarity_score)

            logging.debug(
                f"> [Node {node_id}] [Similarity score: {similarity_score}] "
                f"{truncate_text(str(text), 100)}"
            )
            ids.append(node_id)

        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
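

# --- Illustrative usage sketch ---
# A minimal sketch of one way this store could be exercised end to end, assuming
# an in-memory `chromadb.Client()` and purely hypothetical ids and embeddings
# ("node-1", "doc-1", the 3-dim vectors). In normal use the index builds the
# NodeEmbeddingResult objects itself; one is constructed by hand here, using the
# same field names the `add()` method reads, just to show the add/query flow.
if __name__ == "__main__":
    import chromadb

    chroma_client = chromadb.Client()
    chroma_collection = chroma_client.create_collection("example-collection")
    store = ChromaVectorStore(chroma_collection=chroma_collection)

    # Build a single embedding result by hand (hypothetical values).
    node = Node(text="hello world", doc_id="node-1")
    embedding_result = NodeEmbeddingResult(
        id="node-1",
        node=node,
        embedding=[0.1, 0.2, 0.3],
        doc_id="doc-1",
    )
    added_ids = store.add([embedding_result])
    print(f"added ids: {added_ids}")

    # Query with a raw embedding of the same dimensionality.
    query_result = store.query(query_embedding=[0.1, 0.2, 0.3], similarity_top_k=1)
    print(query_result.ids, query_result.similarities)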