File size: 3,891 Bytes
35b22df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""Chroma Reader."""

from typing import Any, List, Optional, Union

from gpt_index.readers.base import BaseReader
from gpt_index.readers.schema.base import Document


class ChromaReader(BaseReader):
    """Chroma reader.

    Retrieve documents from existing persisted Chroma collections.

    Args:
        collection_name: Name of the persisted collection.
        persist_directory: Directory where the collection is persisted. When
            given, a local ``duckdb+parquet`` client is created; otherwise a
            REST client connecting to ``host``/``port`` is used.
        host: Hostname of the Chroma server (only used when
            ``persist_directory`` is not given).
        port: HTTP port of the Chroma server (only used when
            ``persist_directory`` is not given).

    """

    def __init__(
        self,
        collection_name: str,
        persist_directory: Optional[str] = None,
        host: str = "localhost",
        port: int = 8000,
    ) -> None:
        """Initialize with parameters.

        Raises:
            ImportError: If the ``chromadb`` package is not installed.
            ValueError: If ``collection_name`` is None.
        """
        import_err_msg = (
            "`chromadb` package not found, please run `pip install chromadb`"
        )
        try:
            import chromadb
            from chromadb.config import Settings
        except ImportError as err:
            raise ImportError(import_err_msg) from err

        if collection_name is None:
            raise ValueError("Please provide a collection name.")

        if persist_directory:
            # Local, on-disk collection.
            settings = Settings(
                chroma_db_impl="duckdb+parquet", persist_directory=persist_directory
            )
        else:
            # Remote Chroma server reached over its REST API.
            settings = Settings(
                chroma_api_impl="rest",
                chroma_server_host=host,
                chroma_server_http_port=port,
            )
        self._client = chromadb.Client(settings)
        self._collection = self._client.get_collection(collection_name)

    def create_documents(self, results: Any) -> List[Document]:
        """Create documents from the results.

        Each field of ``results`` (``ids``, ``documents``, ``embeddings``,
        ``metadatas``) holds one list per query, e.g.
        ``results["ids"] == [[id_a, id_b, ...], ...]``. Every hit of every
        query becomes one Document, in query order then rank order. A
        document matched by several queries therefore appears once per
        query.

        Args:
            results: Results from the query.

        Returns:
            List of documents.
        """
        documents = []
        for ids, texts, embeddings, metadatas in zip(
            results["ids"],
            results["documents"],
            results["embeddings"],
            results["metadatas"],
        ):
            # Inner zip walks the individual hits of a single query.
            for doc_id, text, embedding, metadata in zip(
                ids, texts, embeddings, metadatas
            ):
                documents.append(
                    Document(
                        doc_id=doc_id,
                        text=text,
                        embedding=embedding,
                        extra_info=metadata,
                    )
                )

        return documents

    def load_data(
        self,
        query_embedding: Optional[List[float]] = None,
        limit: int = 10,
        where: Optional[dict] = None,
        where_document: Optional[dict] = None,
        query: Optional[Union[str, List[str]]] = None,
    ) -> List[Document]:
        """Load data from the collection.

        ``query_embedding`` takes precedence over ``query`` when both are given.

        Args:
            query_embedding: Embedding to search with. A list of embeddings
                (one per query) is also accepted.
            limit: Number of results to return per query.
            where: Filter results by metadata. {"metadata_field": "is_equal_to_this"}
            where_document: Filter results by document. {"$contains":"search_string"}
            query: Query text, or a list of query texts, to search with.

        Returns:
            List of documents.

        Raises:
            ValueError: If neither ``query_embedding`` nor ``query`` is given.
        """
        where = where or {}
        where_document = where_document or {}
        include = ["metadatas", "documents", "distances", "embeddings"]
        if query_embedding is not None:
            # `query_embeddings` expects one embedding per query; wrap a
            # single flat embedding, pass a list of embeddings through.
            if query_embedding and isinstance(query_embedding[0], (list, tuple)):
                query_embeddings = list(query_embedding)
            else:
                query_embeddings = [query_embedding]
            results = self._collection.query(
                query_embeddings=query_embeddings,
                n_results=limit,
                where=where,
                where_document=where_document,
                include=include,
            )
            return self.create_documents(results)
        elif query is not None:
            query = query if isinstance(query, list) else [query]
            results = self._collection.query(
                query_texts=query,
                n_results=limit,
                where=where,
                where_document=where_document,
                include=include,
            )
            return self.create_documents(results)
        else:
            raise ValueError("Please provide either query embedding or query.")