AbeerTrial's picture
Duplicate from AbeerTrial/SOAPAssist
35b22df
raw
history blame
3.89 kB
"""Chroma Reader."""
from typing import Any, List, Optional, Union
from gpt_index.readers.base import BaseReader
from gpt_index.readers.schema.base import Document
class ChromaReader(BaseReader):
    """Chroma reader.

    Retrieve documents from existing persisted Chroma collections.

    Args:
        collection_name: Name of the persisted collection.
        persist_directory: Directory where the collection is persisted. When
            given, a local ``duckdb+parquet`` client is created; otherwise the
            reader connects to a Chroma server through the REST API.
        host: Chroma server host; only used when ``persist_directory`` is
            not set.
        port: Chroma server HTTP port; only used when ``persist_directory``
            is not set.
    """

    def __init__(
        self,
        collection_name: str,
        persist_directory: Optional[str] = None,
        host: str = "localhost",
        port: int = 8000,
    ) -> None:
        """Initialize with parameters.

        Raises:
            ImportError: If the ``chromadb`` package is not installed.
            ValueError: If ``collection_name`` is missing or empty.
        """
        import_err_msg = (
            "`chromadb` package not found, please run `pip install chromadb`"
        )
        try:
            import chromadb  # noqa: F401
        except ImportError as err:
            raise ImportError(import_err_msg) from err
        if not collection_name:
            raise ValueError("Please provide a collection name.")
        from chromadb.config import Settings

        if persist_directory:
            # Local, on-disk collection.
            settings = Settings(
                chroma_db_impl="duckdb+parquet", persist_directory=persist_directory
            )
        else:
            # Remote collection served by a running Chroma server.
            settings = Settings(
                chroma_api_impl="rest",
                chroma_server_host=host,
                chroma_server_http_port=port,
            )
        self._client = chromadb.Client(settings)
        self._collection = self._client.get_collection(collection_name)

    def create_documents(self, results: Any) -> List[Document]:
        """Create documents from the results.

        Chroma query results are nested: each field ("ids", "documents",
        "embeddings", "metadatas") holds one list per query, and each of those
        holds one entry per matching hit. Every hit of every query becomes a
        ``Document``.

        Args:
            results: Results from the query (a mapping as returned by
                ``Collection.query``).

        Returns:
            List of documents, ordered by query and then by rank within each
            query.
        """
        ids_per_query = results["ids"]

        def column(key: str) -> List[List[Any]]:
            # A field that was not requested via ``include`` comes back as
            # None; pad it with one None per hit so the zip below still lines up.
            values = results.get(key)
            if values is None:
                return [[None] * len(ids) for ids in ids_per_query]
            return values

        documents = []
        for ids, texts, embeddings, metadatas in zip(
            ids_per_query,
            column("documents"),
            column("embeddings"),
            column("metadatas"),
        ):
            # Emit every hit, not just the first, so ``limit`` is honoured.
            for doc_id, text, embedding, metadata in zip(
                ids, texts, embeddings, metadatas
            ):
                documents.append(
                    Document(
                        doc_id=doc_id,
                        text=text,
                        embedding=embedding,
                        extra_info=metadata,
                    )
                )
        return documents

    def load_data(
        self,
        query_embedding: Optional[List[float]] = None,
        limit: int = 10,
        where: Optional[dict] = None,
        where_document: Optional[dict] = None,
        query: Optional[Union[str, List[str]]] = None,
    ) -> List[Document]:
        """Load data from the collection.

        Exactly one search input is used: ``query_embedding`` takes precedence
        over ``query`` when both are given.

        Args:
            query_embedding: Embedding to search with. Passed to Chroma's
                ``query_embeddings`` unchanged.
            limit: Number of results to return per query.
            where: Filter results by metadata. {"metadata_field": "is_equal_to_this"}
            where_document: Filter results by document. {"$contains":"search_string"}
            query: Query text, or list of query texts, to search with.

        Returns:
            List of documents.

        Raises:
            ValueError: If neither ``query_embedding`` nor ``query`` is given.
        """
        if query_embedding is not None:
            search_input: dict = {"query_embeddings": query_embedding}
        elif query is not None:
            search_input = {
                "query_texts": query if isinstance(query, list) else [query]
            }
        else:
            raise ValueError("Please provide either query embedding or query.")

        results = self._collection.query(
            n_results=limit,
            where=where or {},
            where_document=where_document or {},
            include=["metadatas", "documents", "distances", "embeddings"],
            **search_input,
        )
        return self.create_documents(results)