Spaces:
Runtime error
Runtime error
File size: 3,934 Bytes
8a58cf3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
"""Chroma vector store."""
import logging
import math
from typing import Any, List, Optional, cast
from gpt_index.data_structs.data_structs import Node
from gpt_index.utils import truncate_text
from gpt_index.vector_stores.types import (
NodeEmbeddingResult,
VectorStore,
VectorStoreQueryResult,
)
class ChromaVectorStore(VectorStore):
    """Chroma vector store.

    In this vector store, embeddings are stored within a ChromaDB collection.

    During query time, the index uses ChromaDB to query for the top
    k most similar nodes.

    Args:
        chroma_collection (chromadb.api.models.Collection.Collection):
            ChromaDB collection instance

    """

    # Node text is stored in Chroma as ``documents`` (see ``add``), so
    # ``query`` can rebuild full nodes from the Chroma response alone.
    stores_text: bool = True

    def __init__(self, chroma_collection: Any, **kwargs: Any) -> None:
        """Init params.

        Args:
            chroma_collection (Any): a ``chromadb`` ``Collection`` instance.
            **kwargs: accepted for interface compatibility; not used.

        Raises:
            ImportError: if the ``chromadb`` package is not installed.
        """
        import_err_msg = (
            "`chromadb` package not found, please run `pip install chromadb`"
        )
        try:
            import chromadb  # noqa: F401
        except ImportError as err:
            # Keep the original import failure as the cause for debugging.
            raise ImportError(import_err_msg) from err
        from chromadb.api.models.Collection import Collection

        # ``cast`` only informs the type checker; no runtime validation.
        self._collection = cast(Collection, chroma_collection)

    @property
    def config_dict(self) -> dict:
        """Return config dict (this store has no serializable config)."""
        return {}

    def add(self, embedding_results: List[NodeEmbeddingResult]) -> List[str]:
        """Add embedding results to index.

        Args:
            embedding_results (List[NodeEmbeddingResult]): list of embedding
                results to insert into the Chroma collection.

        Returns:
            List[str]: ids of the inserted embeddings, in input order.

        Raises:
            ValueError: if the collection is not initialized.
        """
        if not self._collection:
            raise ValueError("Collection not initialized")

        # Nothing to insert: skip the Chroma call instead of sending empty
        # lists (NOTE(review): some chromadb versions reject an empty id list).
        if not embedding_results:
            return []

        embeddings = []
        metadatas = []
        ids = []
        documents = []
        for result in embedding_results:
            embeddings.append(result.embedding)
            # ``document_id`` is the key ``delete`` filters on and the value
            # ``query`` restores as the node's ``ref_doc_id``.
            metadatas.append({"document_id": result.doc_id})
            ids.append(result.id)
            documents.append(result.node.get_text())

        self._collection.add(
            embeddings=embeddings,
            ids=ids,
            metadatas=metadatas,
            documents=documents,
        )
        return ids

    def delete(self, doc_id: str, **delete_kwargs: Any) -> None:
        """Delete every embedding belonging to a document.

        Args:
            doc_id (str): document id (matched against the ``document_id``
                metadata written by ``add``).
            **delete_kwargs: accepted for interface compatibility; not used.
        """
        self._collection.delete(where={"document_id": doc_id})

    @property
    def client(self) -> Any:
        """Return client (the underlying Chroma collection)."""
        return self._collection

    def query(
        self,
        query_embedding: List[float],
        similarity_top_k: int,
        doc_ids: Optional[List[str]] = None,
    ) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        Args:
            query_embedding (List[float]): query embedding
            similarity_top_k (int): top k most similar nodes
            doc_ids (Optional[List[str]]): currently ignored; results are not
                filtered by document id.

        Returns:
            VectorStoreQueryResult: up to ``similarity_top_k`` nodes with their
            similarity scores and ids, in the order Chroma returns them.
        """
        results = self._collection.query(
            query_embeddings=query_embedding, n_results=similarity_top_k
        )

        # Chroma answers with one inner list per query embedding. A single
        # embedding is queried here, so all hits are in the first inner list.
        # Zipping over the outer lists would yield only one hit in total.
        hit_ids = results["ids"][0]
        hit_texts = results["documents"][0]
        hit_metadatas = results["metadatas"][0]
        hit_distances = results["distances"][0]

        logging.debug(f"> Top {len(hit_texts)} nodes:")

        nodes = []
        similarities = []
        ids = []
        for node_id, text, metadata, distance in zip(
            hit_ids, hit_texts, hit_metadatas, hit_distances
        ):
            node = Node(
                doc_id=node_id,
                text=text,
                extra_info=metadata,
                ref_doc_id=metadata["document_id"],
            )
            nodes.append(node)

            # Map a distance (0 = identical, larger = farther) to a score in
            # (0, 1] that decreases as distance grows. The previous
            # ``1 - exp(-d)`` was inverted: an exact match scored 0.
            similarity_score = math.exp(-distance)
            similarities.append(similarity_score)

            logging.debug(
                f"> [Node {node_id}] [Similarity score: {similarity_score}] "
                f"{truncate_text(str(text), 100)}"
            )
            ids.append(node_id)

        return VectorStoreQueryResult(
            nodes=nodes, similarities=similarities, ids=ids
        )
|