from typing import TYPE_CHECKING, Optional, Tuple, Any, Union

import numpy as np
from pydantic import BaseModel, PrivateAttr

from uuid import UUID
import chromadb.utils.embedding_functions as ef

from chromadb.api.types import (
    URI,
    CollectionMetadata,
    DataLoader,
    Embedding,
    Embeddings,
    Embeddable,
    Include,
    Loadable,
    Metadata,
    Metadatas,
    Document,
    Documents,
    Image,
    Images,
    URIs,
    Where,
    IDs,
    EmbeddingFunction,
    GetResult,
    QueryResult,
    ID,
    OneOrMany,
    WhereDocument,
    maybe_cast_one_to_many_ids,
    maybe_cast_one_to_many_embedding,
    maybe_cast_one_to_many_metadata,
    maybe_cast_one_to_many_document,
    maybe_cast_one_to_many_image,
    maybe_cast_one_to_many_uri,
    validate_ids,
    validate_include,
    validate_metadata,
    validate_metadatas,
    validate_where,
    validate_where_document,
    validate_n_results,
    validate_embeddings,
    validate_embedding_function,
)
import logging

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from chromadb.api import ServerAPI


class Collection(BaseModel):
    """Client-side handle to a single collection.

    Each public method validates and normalizes its arguments, computes
    embeddings with the configured embedding function (and data loader, for
    URIs) when they are not supplied, and then delegates to the ``ServerAPI``
    client stored on the instance.
    """

    name: str
    id: UUID
    metadata: Optional[CollectionMetadata] = None
    tenant: Optional[str] = None
    database: Optional[str] = None
    _client: "ServerAPI" = PrivateAttr()
    _embedding_function: Optional[EmbeddingFunction[Embeddable]] = PrivateAttr()
    _data_loader: Optional[DataLoader[Loadable]] = PrivateAttr()

    def __init__(
        self,
        client: "ServerAPI",
        name: str,
        id: UUID,
        # NOTE: this default instance is created once, when the class body is
        # evaluated, and is shared by every Collection that relies on it.
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = ef.DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
        tenant: Optional[str] = None,
        database: Optional[str] = None,
        metadata: Optional[CollectionMetadata] = None,
    ):
        """Create a collection handle bound to ``client``.

        Args:
            client: The API client every operation is delegated to.
            name: The collection name.
            id: The collection id.
            embedding_function: Used to compute embeddings when they are not
                passed explicitly. ``None`` disables automatic embedding.
            data_loader: Used to load data from URIs. Optional.
            tenant: The tenant stored on the model. Optional.
            database: The database stored on the model. Optional.
            metadata: The collection metadata stored on the model. Optional.
        """
        super().__init__(
            name=name, metadata=metadata, id=id, tenant=tenant, database=database
        )
        self._client = client

        # Check to make sure the embedding function has the right signature,
        # as defined by the EmbeddingFunction protocol
        if embedding_function is not None:
            validate_embedding_function(embedding_function)

        self._embedding_function = embedding_function
        self._data_loader = data_loader

    def __repr__(self) -> str:
        """Return a short representation containing only the collection name."""
        return f"Collection(name={self.name})"

    def count(self) -> int:
        """The total number of embeddings added to the database

        Returns:
            int: The total number of embeddings added to the database

        """
        return self._client._count(collection_id=self.id)

    def add(
        self,
        ids: OneOrMany[ID],
        embeddings: Optional[
            Union[
                OneOrMany[Embedding],
                OneOrMany[np.ndarray],
            ]
        ] = None,
        metadatas: Optional[OneOrMany[Metadata]] = None,
        documents: Optional[OneOrMany[Document]] = None,
        images: Optional[OneOrMany[Image]] = None,
        uris: Optional[OneOrMany[URI]] = None,
    ) -> None:
        """Add embeddings to the data store.

        Args:
            ids: The ids of the embeddings you wish to add
            embeddings: The embeddings to add. If None, embeddings will be
                computed from the documents, the images, or the data loaded
                from the uris, using the embedding_function set for the
                Collection. Optional.
            metadatas: The metadata to associate with the embeddings. When
                querying, you can filter on this metadata. Optional.
            documents: The documents to associate with the embeddings. Optional.
            images: The images to associate with the embeddings. Optional.
            uris: The uris of the images to associate with the embeddings.
                Optional.

        Returns:
            None

        Raises:
            ValueError: If none of embeddings, documents, images, or uris is
                provided
            ValueError: If the length of embeddings, metadatas, documents,
                images, or uris doesn't match the length of ids
            ValueError: If embeddings must be computed but no embedding
                function is set
            ValueError: If embeddings must be computed from uris but no data
                loader is set
            ValueError: If you provide both documents and images

        """
        (
            ids,
            embeddings,
            metadatas,
            documents,
            images,
            uris,
        ) = self._validate_embedding_set(
            ids, embeddings, metadatas, documents, images, uris
        )

        # We need to compute the embeddings if they're not provided
        if embeddings is None:
            embeddings = self._compute_embeddings(documents, images, uris)

        self._client._add(ids, self.id, embeddings, metadatas, documents, uris)

    def get(
        self,
        ids: Optional[OneOrMany[ID]] = None,
        where: Optional[Where] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        where_document: Optional[WhereDocument] = None,
        include: Include = ["metadatas", "documents"],
    ) -> GetResult:
        """Get embeddings and their associate data from the data store. If no
        ids or where filter is provided returns all embeddings up to limit
        starting at offset.

        Args:
            ids: The ids of the embeddings to get. Optional.
            where: A Where type dict used to filter results by. E.g.
                `{"$and": [{"color": "red"}, {"price": {"$gte": 4.20}}]}`.
                Optional.
            limit: The number of documents to return. Optional.
            offset: The offset to start returning results from. Useful for
                paging results with limit. Optional.
            where_document: A WhereDocument type dict used to filter by the
                documents. E.g. `{$contains: {"text": "hello"}}`. Optional.
            include: A list of what to include in the results. Can contain
                `"embeddings"`, `"metadatas"`, `"documents"`. Ids are always
                included. Defaults to `["metadatas", "documents"]`. Optional.

        Returns:
            GetResult: A GetResult object containing the results.

        Raises:
            ValueError: If `"data"` is requested but no data loader is set

        """
        valid_where = validate_where(where) if where else None
        valid_where_document = (
            validate_where_document(where_document) if where_document else None
        )
        valid_ids = validate_ids(maybe_cast_one_to_many_ids(ids)) if ids else None
        # Work on a copy: "uris" may be appended below, and that must never
        # leak into the caller's list or into the shared default argument.
        valid_include = list(validate_include(include, allow_distances=False))

        if "data" in include and self._data_loader is None:
            raise ValueError(
                "You must set a data loader on the collection if loading from URIs."
            )

        # We need to include uris in the result from the API to load datas
        if "data" in include and "uris" not in include:
            valid_include.append("uris")

        get_results = self._client._get(
            self.id,
            valid_ids,
            valid_where,
            None,
            limit,
            offset,
            where_document=valid_where_document,
            include=valid_include,
        )

        if (
            "data" in include
            and self._data_loader is not None
            and get_results["uris"] is not None
        ):
            get_results["data"] = self._data_loader(get_results["uris"])

        # Remove URIs from the result if they weren't requested
        if "uris" not in include:
            get_results["uris"] = None

        return get_results

    def peek(self, limit: int = 10) -> GetResult:
        """Get the first few results in the database up to limit

        Args:
            limit: The number of results to return.

        Returns:
            GetResult: A GetResult object containing the results.
        """
        return self._client._peek(self.id, limit)

    def query(
        self,
        query_embeddings: Optional[
            Union[
                OneOrMany[Embedding],
                OneOrMany[np.ndarray],
            ]
        ] = None,
        query_texts: Optional[OneOrMany[Document]] = None,
        query_images: Optional[OneOrMany[Image]] = None,
        query_uris: Optional[OneOrMany[URI]] = None,
        n_results: int = 10,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
        include: Include = ["metadatas", "documents", "distances"],
    ) -> QueryResult:
        """Get the n_results nearest neighbor embeddings for provided
        query_embeddings, query_texts, query_images, or query_uris.

        Args:
            query_embeddings: The embeddings to get the closest neighbors of.
                Optional.
            query_texts: The document texts to get the closest neighbors of.
                Optional.
            query_images: The images to get the closest neighbors of. Optional.
            query_uris: The uris whose loaded data is embedded to get the
                closest neighbors of. Requires a data loader. Optional.
            n_results: The number of neighbors to return for each
                query_embedding or query_texts. Optional.
            where: A Where type dict used to filter results by. E.g.
                `{"$and": [{"color": "red"}, {"price": {"$gte": 4.20}}]}`.
                Optional.
            where_document: A WhereDocument type dict used to filter by the
                documents. E.g. `{$contains: {"text": "hello"}}`. Optional.
            include: A list of what to include in the results. Can contain
                `"embeddings"`, `"metadatas"`, `"documents"`, `"distances"`.
                Ids are always included. Defaults to
                `["metadatas", "documents", "distances"]`. Optional.

        Returns:
            QueryResult: A QueryResult object containing the results.

        Raises:
            ValueError: If you don't provide exactly one of query_embeddings,
                query_texts, query_images, or query_uris
            ValueError: If query_uris is provided but no data loader is set
            ValueError: If embeddings must be computed but no embedding
                function is set

        """
        # Users must provide exactly one of query_embeddings, query_texts,
        # query_images, or query_uris. Count explicitly: chaining `^` would
        # also accept any odd number of inputs (e.g. three of them).
        provided_inputs = sum(
            value is not None
            for value in (query_embeddings, query_texts, query_images, query_uris)
        )
        if provided_inputs != 1:
            raise ValueError(
                "You must provide one of query_embeddings, query_texts, query_images, or query_uris."
            )

        valid_where = validate_where(where) if where else {}
        valid_where_document = (
            validate_where_document(where_document) if where_document else {}
        )
        valid_query_embeddings = (
            validate_embeddings(
                self._normalize_embeddings(
                    maybe_cast_one_to_many_embedding(query_embeddings)
                )
            )
            if query_embeddings is not None
            else None
        )
        valid_query_texts = (
            maybe_cast_one_to_many_document(query_texts)
            if query_texts is not None
            else None
        )
        valid_query_images = (
            maybe_cast_one_to_many_image(query_images)
            if query_images is not None
            else None
        )
        valid_query_uris = (
            maybe_cast_one_to_many_uri(query_uris) if query_uris is not None else None
        )
        # Work on a copy: "uris" may be appended below, and that must never
        # leak into the caller's list or into the shared default argument.
        valid_include = list(validate_include(include, allow_distances=True))
        valid_n_results = validate_n_results(n_results)

        # If query_embeddings are not provided, we need to compute them from
        # the inputs
        if valid_query_embeddings is None:
            if query_texts is not None:
                valid_query_embeddings = self._embed(input=valid_query_texts)
            elif query_images is not None:
                valid_query_embeddings = self._embed(input=valid_query_images)
            else:
                if valid_query_uris is None:
                    raise ValueError(
                        "You must provide either query_embeddings, query_texts, query_images, or query_uris."
                    )
                if self._data_loader is None:
                    raise ValueError(
                        "You must set a data loader on the collection if loading from URIs."
                    )
                valid_query_embeddings = self._embed(
                    self._data_loader(valid_query_uris)
                )

        # We need uris in the result from the API to load the data
        if "data" in include and "uris" not in include:
            valid_include.append("uris")

        query_results = self._client._query(
            collection_id=self.id,
            query_embeddings=valid_query_embeddings,
            n_results=valid_n_results,
            where=valid_where,
            where_document=valid_where_document,
            # Send the validated list so the "uris" entry added above
            # actually reaches the client.
            include=valid_include,
        )

        if (
            "data" in include
            and self._data_loader is not None
            and query_results["uris"] is not None
        ):
            query_results["data"] = [
                self._data_loader(uris) for uris in query_results["uris"]
            ]

        # Remove URIs from the result if they weren't requested
        if "uris" not in include:
            query_results["uris"] = None

        return query_results

    def modify(
        self, name: Optional[str] = None, metadata: Optional[CollectionMetadata] = None
    ) -> None:
        """Modify the collection name or metadata

        Args:
            name: The updated name for the collection. Optional.
            metadata: The updated metadata for the collection. Optional.

        Returns:
            None

        Raises:
            ValueError: If metadata contains the `"hnsw:space"` key
        """
        if metadata is not None:
            validate_metadata(metadata)
            if "hnsw:space" in metadata:
                raise ValueError(
                    "Changing the distance function of a collection once it is created is not supported currently."
                )

        self._client._modify(id=self.id, new_name=name, new_metadata=metadata)
        # Mirror the change locally only for truthy values, so an omitted
        # argument leaves the current attribute untouched.
        if name:
            self.name = name
        if metadata:
            self.metadata = metadata

    def update(
        self,
        ids: OneOrMany[ID],
        embeddings: Optional[
            Union[
                OneOrMany[Embedding],
                OneOrMany[np.ndarray],
            ]
        ] = None,
        metadatas: Optional[OneOrMany[Metadata]] = None,
        documents: Optional[OneOrMany[Document]] = None,
        images: Optional[OneOrMany[Image]] = None,
        uris: Optional[OneOrMany[URI]] = None,
    ) -> None:
        """Update the embeddings, metadatas or documents for provided ids.

        Args:
            ids: The ids of the embeddings to update
            embeddings: The embeddings to update. If None, embeddings will be
                computed based on the documents or images using the
                embedding_function set for the Collection. Optional.
            metadatas: The metadata to associate with the embeddings. When
                querying, you can filter on this metadata. Optional.
            documents: The documents to associate with the embeddings. Optional.
            images: The images to associate with the embeddings. Optional.
            uris: The uris to associate with the embeddings. They are passed
                to the client as-is; no embeddings are computed from them
                here. Optional.

        Returns:
            None
        """
        (
            ids,
            embeddings,
            metadatas,
            documents,
            images,
            uris,
        ) = self._validate_embedding_set(
            ids,
            embeddings,
            metadatas,
            documents,
            images,
            uris,
            require_embeddings_or_data=False,
        )

        # An update may touch only metadata, so embeddings legitimately stay
        # None when neither documents nor images are given.
        if embeddings is None:
            if documents is not None:
                embeddings = self._embed(input=documents)
            elif images is not None:
                embeddings = self._embed(input=images)

        self._client._update(self.id, ids, embeddings, metadatas, documents, uris)

    def upsert(
        self,
        ids: OneOrMany[ID],
        embeddings: Optional[
            Union[
                OneOrMany[Embedding],
                OneOrMany[np.ndarray],
            ]
        ] = None,
        metadatas: Optional[OneOrMany[Metadata]] = None,
        documents: Optional[OneOrMany[Document]] = None,
        images: Optional[OneOrMany[Image]] = None,
        uris: Optional[OneOrMany[URI]] = None,
    ) -> None:
        """Update the embeddings, metadatas or documents for provided ids, or
        create them if they don't exist.

        Args:
            ids: The ids of the embeddings to update
            embeddings: The embeddings to add. If None, embeddings will be
                computed from the documents, the images, or the data loaded
                from the uris, using the embedding_function set for the
                Collection. Optional.
            metadatas: The metadata to associate with the embeddings. When
                querying, you can filter on this metadata. Optional.
            documents: The documents to associate with the embeddings. Optional.
            images: The images to associate with the embeddings. Optional.
            uris: The uris to associate with the embeddings. Optional.

        Returns:
            None

        Raises:
            ValueError: If none of embeddings, documents, images, or uris is
                provided
            ValueError: If embeddings must be computed from uris but no data
                loader is set
        """
        (
            ids,
            embeddings,
            metadatas,
            documents,
            images,
            uris,
        ) = self._validate_embedding_set(
            ids, embeddings, metadatas, documents, images, uris
        )

        # Same resolution order as `add`; in particular a uris-only upsert
        # goes through the data loader instead of embedding `None`.
        if embeddings is None:
            embeddings = self._compute_embeddings(documents, images, uris)

        self._client._upsert(
            collection_id=self.id,
            ids=ids,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            uris=uris,
        )

    def delete(
        self,
        ids: Optional[IDs] = None,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
    ) -> None:
        """Delete the embeddings based on ids and/or a where filter

        Args:
            ids: The ids of the embeddings to delete
            where: A Where type dict used to filter the deletion by. E.g.
                `{"$and": [{"color": "red"}, {"price": {"$gte": 4.20}}]}`.
                Optional.
            where_document: A WhereDocument type dict used to filter the
                deletion by the document content. E.g.
                `{$contains: {"text": "hello"}}`. Optional.

        Returns:
            None

        Raises:
            ValueError: If you don't provide either ids, where, or
                where_document. NOTE(review): this method does not perform
                that check itself; presumably the client does — confirm.
        """
        ids = validate_ids(maybe_cast_one_to_many_ids(ids)) if ids else None
        where = validate_where(where) if where else None
        where_document = (
            validate_where_document(where_document) if where_document else None
        )

        self._client._delete(self.id, ids, where, where_document)

    def _validate_embedding_set(
        self,
        ids: OneOrMany[ID],
        embeddings: Optional[
            Union[
                OneOrMany[Embedding],
                OneOrMany[np.ndarray],
            ]
        ],
        metadatas: Optional[OneOrMany[Metadata]],
        documents: Optional[OneOrMany[Document]],
        images: Optional[OneOrMany[Image]] = None,
        uris: Optional[OneOrMany[URI]] = None,
        require_embeddings_or_data: bool = True,
    ) -> Tuple[
        IDs,
        Optional[Embeddings],
        Optional[Metadatas],
        Optional[Documents],
        Optional[Images],
        Optional[URIs],
    ]:
        """Cast every argument to its "many" form and cross-check them.

        Args:
            ids: The record ids; always required.
            embeddings: Embeddings to normalize and validate. Optional.
            metadatas: Metadatas to validate. Optional.
            documents: Documents. Optional.
            images: Images. Optional.
            uris: URIs. Optional.
            require_embeddings_or_data: When True, at least one of
                embeddings, documents, images, or uris must be provided.

        Returns:
            The tuple `(ids, embeddings, metadatas, documents, images, uris)`
            with each provided value cast to a list and the others None.

        Raises:
            ValueError: If required data is missing, if both documents and
                images are provided, or if any provided list differs in
                length from ids.
        """
        valid_ids = validate_ids(maybe_cast_one_to_many_ids(ids))
        valid_embeddings = (
            validate_embeddings(
                self._normalize_embeddings(maybe_cast_one_to_many_embedding(embeddings))
            )
            if embeddings is not None
            else None
        )
        valid_metadatas = (
            validate_metadatas(maybe_cast_one_to_many_metadata(metadatas))
            if metadatas is not None
            else None
        )
        valid_documents = (
            maybe_cast_one_to_many_document(documents)
            if documents is not None
            else None
        )
        valid_images = (
            maybe_cast_one_to_many_image(images) if images is not None else None
        )
        valid_uris = maybe_cast_one_to_many_uri(uris) if uris is not None else None

        # Check that one of embeddings, documents, images, or uris is provided
        if require_embeddings_or_data:
            if (
                valid_embeddings is None
                and valid_documents is None
                and valid_images is None
                and valid_uris is None
            ):
                raise ValueError(
                    "You must provide embeddings, documents, images, or uris."
                )

        # Only one of documents or images can be provided
        if valid_documents is not None and valid_images is not None:
            raise ValueError("You can only provide documents or images, not both.")

        # Check that, if they're provided, the lengths of the arrays match the
        # length of ids
        if valid_embeddings is not None and len(valid_embeddings) != len(valid_ids):
            raise ValueError(
                f"Number of embeddings {len(valid_embeddings)} must match number of ids {len(valid_ids)}"
            )
        if valid_metadatas is not None and len(valid_metadatas) != len(valid_ids):
            raise ValueError(
                f"Number of metadatas {len(valid_metadatas)} must match number of ids {len(valid_ids)}"
            )
        if valid_documents is not None and len(valid_documents) != len(valid_ids):
            raise ValueError(
                f"Number of documents {len(valid_documents)} must match number of ids {len(valid_ids)}"
            )
        if valid_images is not None and len(valid_images) != len(valid_ids):
            raise ValueError(
                f"Number of images {len(valid_images)} must match number of ids {len(valid_ids)}"
            )
        if valid_uris is not None and len(valid_uris) != len(valid_ids):
            raise ValueError(
                f"Number of uris {len(valid_uris)} must match number of ids {len(valid_ids)}"
            )

        return (
            valid_ids,
            valid_embeddings,
            valid_metadatas,
            valid_documents,
            valid_images,
            valid_uris,
        )

    @staticmethod
    def _normalize_embeddings(
        embeddings: Union[
            OneOrMany[Embedding],
            OneOrMany[np.ndarray],
        ]
    ) -> Embeddings:
        """Convert numpy embeddings into plain Python lists.

        A top-level ``np.ndarray`` is converted with ``tolist()``. A list that
        contains ``np.ndarray`` elements has each of those elements converted.
        Any other value is returned unchanged (same object).
        """
        if isinstance(embeddings, np.ndarray):
            return embeddings.tolist()
        # The signature also admits a list of arrays; convert those elements
        # so downstream validation sees plain lists of numbers.
        if isinstance(embeddings, list) and any(
            isinstance(item, np.ndarray) for item in embeddings
        ):
            return [
                item.tolist() if isinstance(item, np.ndarray) else item
                for item in embeddings
            ]
        return embeddings

    def _compute_embeddings(
        self,
        documents: Optional[Documents],
        images: Optional[Images],
        uris: Optional[URIs],
    ) -> Embeddings:
        """Embed the first available input: documents, then images, then uris.

        URIs are first resolved through the collection's data loader.

        Raises:
            ValueError: If all three inputs are None, or if uris must be
                loaded but no data loader is set.
        """
        if documents is not None:
            return self._embed(input=documents)
        if images is not None:
            return self._embed(input=images)
        if uris is None:
            raise ValueError(
                "You must provide either embeddings, documents, images, or uris."
            )
        if self._data_loader is None:
            raise ValueError(
                "You must set a data loader on the collection if loading from URIs."
            )
        return self._embed(self._data_loader(uris))

    def _embed(self, input: Any) -> Embeddings:
        """Run the collection's embedding function on ``input``.

        Raises:
            ValueError: If no embedding function is set on the collection.
        """
        if self._embedding_function is None:
            raise ValueError(
                "You must provide an embedding function to compute embeddings. "
                "https://docs.trychroma.com/embeddings"
            )
        return self._embedding_function(input=input)