teddyllm's picture
Upload 50 files
b917edb verified
from typing import Dict, Any, Optional, List
import chromadb
from chromadb.config import Settings
from chromadb.api.types import (
Where,
GetResult,
QueryResult,
)
from ..embedding_provider import EmbeddingProvider
from .vector_store import VectorStore
class ChromaDB(VectorStore):
    """
    Vector-store implementation backed by a persistent ChromaDB client.

    The instance owns one ChromaDB client rooted at ``db_path`` and one active
    collection, which is created if it does not already exist. When an
    ``embedding_function`` is supplied, embeddings are computed locally before
    data is sent to Chroma; otherwise raw documents / query texts are passed
    and Chroma embeds them with the collection's configured embedding function.

    See more:
    https://github.com/chroma-core/chroma
    """

    # Name used for the active collection when no collection_name is given.
    _DEFAULT_COLLECTION_NAME = "default_collection"

    def __init__(
        self,
        configs: Optional[Dict[str, Any]] = None,
        db_path: str = ".chromadb",
        embedding_function: Optional[EmbeddingProvider] = None,
        collection_name: Optional[str] = None,
    ) -> None:
        """
        Open (or create) a persistent Chroma database and its active collection.

        Args:
            configs (Optional[Dict[str, Any]]): Arbitrary configuration stored on
                ``self.configs``. A fresh dict is created per instance when omitted.
            db_path (str): Directory of the persistent Chroma database.
            embedding_function (Optional[EmbeddingProvider]): Provider used to
                embed documents and queries locally. When None, Chroma embeds.
            collection_name (Optional[str]): Collection to open; when None,
                ``"default_collection"`` is used.
        """
        # Avoid a shared mutable default: each instance gets its own dict.
        self.configs = configs if configs is not None else {}
        self.embedding_function = embedding_function
        self._collection_name = collection_name
        # Must run after _collection_name is set: _connect reads it.
        self._connect(db_path)

    def _connect(self, path: str) -> None:
        """Create a persistent client at ``path`` and get-or-create the active collection."""
        self.client = chromadb.PersistentClient(path=path)
        self.collection = self.client.get_or_create_collection(
            name=self.collection_name or self._DEFAULT_COLLECTION_NAME
        )

    @property
    def db_path(self) -> str:
        """Directory the current client persists to."""
        return self.client.get_settings().persist_directory

    @db_path.setter
    def db_path(self, value: str) -> None:
        """Reconnect to a database at ``value`` and reopen the active collection there."""
        self._connect(value)

    @property
    def collection_name(self) -> Optional[str]:
        """Name the collection was opened / renamed with (None means the default name)."""
        return self._collection_name

    @collection_name.setter
    def collection_name(self, value: Optional[str]) -> None:
        """Rename the active collection in Chroma and remember the new name."""
        # Rename first so a failed modify() does not leave the cached name out of sync.
        self.collection.modify(name=value)
        self._collection_name = value

    def add_data(
        self,
        documents: List[str],
        ids: List[str],
        metadatas: Optional[List[Dict[str, Any]]] = None,
        **optional_kwargs
    ) -> None:
        """
        Add documents to the active collection.

        Args:
            documents (List[str]): List of documents to add.
            ids (List[str]): List of ids for the documents.
            metadatas (Optional[List[Dict[str, Any]]]): List of metadata for the documents.
            **optional_kwargs: Additional keyword arguments (see collection.add for more).

        Raises:
            Exception: Any error from embedding or from ``collection.add`` is
                printed and re-raised.
        """
        try:
            params: Dict[str, Any] = {
                "documents": documents,
                "ids": ids,
                **optional_kwargs,
            }
            # An empty metadata list is normalised to None (i.e. "no metadata").
            params["metadatas"] = metadatas or None
            # Embed locally when a provider is configured; otherwise Chroma embeds.
            if self.embedding_function:
                params["embeddings"] = self.embedding_function.embed_documents(documents)
            self.collection.add(**params)
        except Exception as e:
            print(f"Error adding data to collection: {e}")
            raise

    @staticmethod
    def _to_embedding_list(embeddings: Any) -> List[Any]:
        """Return ``embeddings`` as a list of vectors, wrapping a single flat vector."""
        import numbers  # local import: only needed for this normalisation

        items = list(embeddings)
        if items and isinstance(items[0], numbers.Number):
            # A single vector (e.g. List[float]) was given; wrap it so it can be
            # concatenated with other vectors.
            return [items]
        return items

    def search(
        self,
        query_text: Optional[List[str]] = None,
        query_embedding: Optional[List[List[float]]] = None,
        n_results: int = 10,
        **optional_kwargs
    ) -> QueryResult:
        """
        Query the collection for similar documents.

        Args:
            query_text (Optional[List[str]]): List of query texts.
            query_embedding (Optional[List[List[float]]]): List of query embeddings.
            n_results (int): Number of results to return.
            **optional_kwargs: Additional keyword arguments (see collection.query for more).

        Returns:
            QueryResult: The result of the query.

        Raises:
            ValueError: If neither a non-empty ``query_text`` nor a non-empty
                ``query_embedding`` is given, or if both are given without an
                ``embedding_function`` (Chroma cannot mix texts and embeddings).
        """
        try:
            # len()-based checks avoid ambiguous truthiness for array-like inputs.
            has_text = query_text is not None and len(query_text) > 0
            has_embedding = query_embedding is not None and len(query_embedding) > 0
            if not has_text and not has_embedding:
                raise ValueError("Either query_text or query_embedding must be provided.")

            params: Dict[str, Any] = {
                "n_results": n_results,
                **optional_kwargs,
            }
            if has_text and not has_embedding:
                if self.embedding_function:
                    # NOTE(review): the whole list is passed to embed_query; confirm
                    # EmbeddingProvider.embed_query accepts a list of texts.
                    params["query_embeddings"] = self.embedding_function.embed_query(query_text)
                else:
                    # Chroma's keyword is `query_texts` (plural).
                    params["query_texts"] = query_text
            elif has_embedding and not has_text:
                params["query_embeddings"] = query_embedding
            else:
                # Both given: embed the texts and query with all vectors together.
                if not self.embedding_function:
                    raise ValueError(
                        "Combining query_text and query_embedding requires an embedding_function."
                    )
                text_embeddings = self.embedding_function.embed_query(query_text)
                # Build a new list: list.extend() returns None and would mutate
                # the caller's query_embedding.
                params["query_embeddings"] = (
                    self._to_embedding_list(query_embedding)
                    + self._to_embedding_list(text_embeddings)
                )
            return self.collection.query(**params)
        except Exception as e:
            print(f"Error querying data from collection: {e}")
            raise

    def query_by_id_or_metadata(
        self,
        ids: Optional[List[str]] = None,
        where: Optional[Where] = None,
        n_results: Optional[int] = 10,
        **optional_kwargs
    ) -> GetResult:
        """
        Fetch documents by id and/or metadata filter (no similarity ranking).

        Args:
            ids (Optional[List[str]]): List of ids to fetch.
            where (Optional[Where]): Metadata filter.
            n_results (Optional[int]): Maximum number of results; forwarded to
                ``collection.get`` as ``limit`` unless ``limit`` is passed in
                ``optional_kwargs``. None means no limit.
            **optional_kwargs: Additional keyword arguments (see collection.get for more).

        Returns:
            GetResult: The fetched records.

        Raises:
            ValueError: If neither non-empty ``ids`` nor a non-empty ``where`` is given.
        """
        try:
            # Empty ids/where count as "not provided"; otherwise get() would
            # silently return the whole collection.
            if not ids and not where:
                raise ValueError("Either ids or where must be provided.")
            params: Dict[str, Any] = dict(optional_kwargs)
            # Collection.get has no `n_results` parameter; its cap is `limit`.
            params.setdefault("limit", n_results)
            if ids:
                params["ids"] = ids
            if where:
                params["where"] = where
            return self.collection.get(**params)
        except Exception as e:
            print(f"Error querying data from collection: {e}")
            raise

    def delete_collection(self, collection_name: Optional[str] = None) -> None:
        """
        Delete a specific collection from the ChromaDB (best effort).

        Args:
            collection_name (Optional[str]): Name of collection to delete.
                Defaults to this instance's collection, including the default
                name used when the instance was opened without a collection_name.

        Errors are printed and not re-raised. If the instance's own collection is
        deleted, ``self.collection`` still refers to the deleted collection.
        """
        try:
            # Mirror the fallback used when the collection was opened.
            target_collection = (
                collection_name or self.collection_name or self._DEFAULT_COLLECTION_NAME
            )
            self.client.delete_collection(name=target_collection)
            print(f"Collection '{target_collection}' deleted successfully.")
        except Exception as e:
            print(f"Error deleting collection: {e}")