Spaces:
Running
Running
from typing import Dict, Any, Optional, List | |
import chromadb | |
from chromadb.config import Settings | |
from chromadb.api.types import ( | |
Where, | |
GetResult, | |
QueryResult, | |
) | |
from ..embedding_provider import EmbeddingProvider | |
from .vector_store import VectorStore | |
class ChromaDB(VectorStore):
    """
    Vector-store implementation backed by a persistent ChromaDB client.

    The instance holds one ``chromadb.PersistentClient`` (``self.client``) and
    one collection handle (``self.collection``). When an ``EmbeddingProvider``
    is supplied, documents and queries are embedded here and the vectors are
    passed to Chroma explicitly. Otherwise raw texts are passed and Chroma
    embeds them itself.

    See more:
        https://github.com/chroma-core/chroma
    """

    # Collection opened when no collection name is given.
    _DEFAULT_COLLECTION = "default_collection"

    def __init__(
        self,
        configs: Optional[Dict[str, Any]] = None,
        db_path: str = ".chromadb",
        embedding_function: Optional[EmbeddingProvider] = None,
        collection_name: Optional[str] = None,
    ) -> None:
        """
        Open (or create) a persistent Chroma database and collection.

        Args:
            configs: Arbitrary configuration stored on ``self.configs``.
                Defaults to a fresh empty dict.
            db_path: Directory where ``chromadb.PersistentClient`` stores data.
            embedding_function: Optional provider used to embed documents and
                queries before they reach Chroma.
            collection_name: Collection to open. When ``None``,
                ``"default_collection"`` is used.
        """
        self.client = chromadb.PersistentClient(path=db_path)
        # A fresh dict per instance: a literal ``{}`` default would be one
        # object shared by every instance created without ``configs``.
        self.configs = configs if configs is not None else {}
        self.embedding_function = embedding_function
        self._collection_name = collection_name
        self.collection = self.client.get_or_create_collection(
            name=self.collection_name or self._DEFAULT_COLLECTION
        )

    @property
    def db_path(self) -> str:
        """Directory the current client persists to."""
        return self.client.get_settings().persist_directory

    @db_path.setter
    def db_path(self, value: str) -> None:
        """Re-open the client at ``value`` and re-acquire the collection there."""
        self.client = chromadb.PersistentClient(path=value)
        self.collection = self.client.get_or_create_collection(
            name=self.collection_name or self._DEFAULT_COLLECTION
        )

    @property
    def collection_name(self) -> Optional[str]:
        """Collection name supplied by the caller (``None`` means the default)."""
        return self._collection_name

    @collection_name.setter
    def collection_name(self, value: Optional[str]) -> None:
        """Rename the active collection in Chroma and record the new name."""
        # Rename first, so a failed ``modify`` leaves the cached name unchanged.
        self.collection.modify(name=value)
        self._collection_name = value

    def add_data(
        self,
        documents: List[str],
        ids: List[str],
        metadatas: Optional[List[Dict[str, Any]]] = None,
        **optional_kwargs,
    ) -> None:
        """
        Add documents to the collection, embedding them first when possible.

        Args:
            documents (List[str]): Documents to add.
            ids (List[str]): One id per document.
            metadatas (Optional[List[Dict[str, Any]]]): Optional per-document
                metadata. An empty list is treated as "no metadata".
            **optional_kwargs: Forwarded to ``collection.add``.

        Raises:
            Exception: Errors from the embedding provider or from
                ``collection.add`` are printed and re-raised unchanged.
        """
        try:
            params = {"documents": documents, "ids": ids, **optional_kwargs}
            # Collapse an empty metadata list to None ("no metadata").
            params["metadatas"] = metadatas or None
            if self.embedding_function:
                params["embeddings"] = self.embedding_function.embed_documents(documents)
            self.collection.add(**params)
        except Exception as e:
            print(f"Error adding data to collection: {e}")
            raise

    @staticmethod
    def _as_embedding_batch(vectors: Any) -> List[Any]:
        """Return ``vectors`` as a list of embeddings, wrapping a lone flat vector."""
        import numbers

        batch = list(vectors)
        # A first element that is a number means this is one flat vector,
        # not a batch of vectors.
        if batch and isinstance(batch[0], numbers.Number):
            return [batch]
        return batch

    def search(
        self,
        query_text: Optional[List[str]] = None,
        query_embedding: Optional[List[List[float]]] = None,
        n_results: int = 10,
        **optional_kwargs,
    ) -> QueryResult:
        """
        Query the collection for the documents nearest to the given queries.

        If only ``query_text`` is given, it is embedded with
        ``embedding_function`` when one is configured. Otherwise it is sent to
        Chroma as ``query_texts``. If only ``query_embedding`` is given, it is
        sent as ``query_embeddings``. If both are given, the texts are embedded
        and their vectors are appended after ``query_embedding`` to form one
        batch of queries; this requires ``embedding_function``.

        Args:
            query_text (Optional[List[str]]): Query strings.
            query_embedding (Optional[List[List[float]]]): Precomputed query
                embeddings.
            n_results (int): Number of results to return per query.
            **optional_kwargs: Forwarded to ``collection.query``.

        Returns:
            QueryResult: The result of the query.

        Raises:
            ValueError: Neither query is provided (``None`` or empty), or both
                are provided without an embedding function.
            Exception: Errors from the embedding provider or from
                ``collection.query`` are printed and re-raised.
        """
        try:
            # Explicit None/len checks: empty inputs count as "not provided",
            # and array-like embeddings must not go through plain truthiness.
            has_text = query_text is not None and len(query_text) > 0
            has_embedding = query_embedding is not None and len(query_embedding) > 0
            if not has_text and not has_embedding:
                raise ValueError("Either query_text or query_embedding must be provided.")

            params = {"n_results": n_results, **optional_kwargs}
            if has_text and not has_embedding:
                if self.embedding_function:
                    # NOTE(review): assumes embed_query accepts the list of
                    # texts as given — confirm against EmbeddingProvider.
                    params["query_embeddings"] = self.embedding_function.embed_query(query_text)
                else:
                    # Chroma's keyword is ``query_texts`` (plural).
                    params["query_texts"] = query_text
            elif has_embedding and not has_text:
                params["query_embeddings"] = query_embedding
            else:
                if not self.embedding_function:
                    raise ValueError(
                        "query_text and query_embedding can only be combined "
                        "when an embedding_function is configured."
                    )
                text_embeddings = self.embedding_function.embed_query(query_text)
                # Build a new list: ``list.extend`` returns None and would
                # also mutate the caller's ``query_embedding``.
                params["query_embeddings"] = (
                    self._as_embedding_batch(query_embedding)
                    + self._as_embedding_batch(text_embeddings)
                )
            return self.collection.query(**params)
        except Exception as e:
            print(f"Error querying data from collection: {e}")
            raise

    def query_by_id_or_metadata(
        self,
        ids: Optional[List[str]] = None,
        where: Optional[Where] = None,
        n_results: int = 10,
        **optional_kwargs,
    ) -> GetResult:
        """
        Fetch stored records by id and/or metadata filter (no similarity search).

        Args:
            ids (Optional[List[str]]): Ids to fetch.
            where (Optional[Where]): Metadata filter.
            n_results (int): Maximum number of records to return. Sent to
                ``collection.get`` as ``limit``. An explicit ``limit`` in
                ``optional_kwargs`` takes precedence.
            **optional_kwargs: Forwarded to ``collection.get``.

        Returns:
            GetResult: The matching records.

        Raises:
            ValueError: Neither ``ids`` nor ``where`` is provided (``None`` or
                empty).
            Exception: Errors from ``collection.get`` are printed and re-raised.
        """
        try:
            # Empty ids/where would otherwise yield an unfiltered fetch.
            if not ids and not where:
                raise ValueError("Either ids or where must be provided.")
            # ``Collection.get`` has no ``n_results`` parameter; it uses ``limit``.
            params = {"limit": n_results, **optional_kwargs}
            if ids:
                params["ids"] = ids
            if where:
                params["where"] = where
            return self.collection.get(**params)
        except Exception as e:
            print(f"Error querying data from collection: {e}")
            raise

    def delete_collection(self, collection_name: Optional[str] = None) -> None:
        """
        Delete a collection from the database (best effort: never raises).

        Args:
            collection_name (Optional[str]): Name of the collection to delete.
                Uses the instance's ``collection_name`` if not provided.

        Failures, including a missing name, are printed rather than raised.
        """
        target_collection = collection_name or self.collection_name
        if not target_collection:
            print("Error deleting collection: No collection name provided")
            return
        try:
            self.client.delete_collection(name=target_collection)
        except Exception as e:
            print(f"Error deleting collection: {e}")
            return
        # NOTE(review): if this was the collection held in ``self.collection``,
        # that handle is not refreshed; later calls on it presumably fail.
        print(f"Collection '{target_collection}' deleted successfully.")