Spaces:
Sleeping
Sleeping
from typing import Optional, Union, TypeVar, List, Dict, Any, Tuple, cast | |
from numpy.typing import NDArray | |
import numpy as np | |
from typing_extensions import Literal, TypedDict, Protocol | |
import chromadb.errors as errors | |
from chromadb.types import ( | |
Metadata, | |
UpdateMetadata, | |
Vector, | |
LiteralValue, | |
LogicalOperator, | |
WhereOperator, | |
OperatorExpression, | |
Where, | |
WhereDocumentOperator, | |
WhereDocument, | |
) | |
from inspect import signature | |
from tenacity import retry | |
# Names exported by a star-import of this module.
__all__ = [
    "Metadata",
    "Where",
    "WhereDocument",
    "UpdateCollectionMetadata",
]
# Generic alias: either a single value of type T or a list of such values.
T = TypeVar("T")
OneOrMany = Union[T, List[T]]

# URIs
URI = str
URIs = List[URI]


def maybe_cast_one_to_many_uri(target: OneOrMany[URI]) -> URIs:
    """Normalize a single URI or a list of URIs to a list of URIs.

    A bare ``str`` is wrapped in a one-element list. Any other value is
    returned as-is (same object, contents not inspected).
    """
    as_list = [target] if isinstance(target, str) else target
    return cast(URIs, as_list)
# IDs
ID = str
IDs = List[ID]


def maybe_cast_one_to_many_ids(target: OneOrMany[ID]) -> IDs:
    """Normalize a single ID or a list of IDs to a list of IDs.

    A bare ``str`` is wrapped in a one-element list. Any other value is
    returned as-is (same object, contents not inspected).
    """
    as_list = [target] if isinstance(target, str) else target
    return cast(IDs, as_list)
# Embeddings
# One embedding is a chromadb ``Vector``; ``Embeddings`` is a list of them.
Embedding = Vector
Embeddings = List[Embedding]
def maybe_cast_one_to_many_embedding(target: OneOrMany[Embedding]) -> Embeddings:
    """Normalize a single embedding or a list of embeddings to a list of embeddings.

    A non-empty list whose first element is an ``int`` or ``float`` is treated
    as one embedding and wrapped in a one-element list. Anything else,
    including an empty list, is returned unchanged so that later validation
    can report it.

    Args:
        target: One embedding (a flat list of numbers) or a list of embeddings.

    Returns:
        The input as a list of embeddings.
    """
    # The truthiness check guards the target[0] lookup: an empty list used to
    # raise IndexError here before any validation could produce a clear error.
    if isinstance(target, list) and target and isinstance(target[0], (int, float)):
        # One embedding
        return cast("Embeddings", [target])
    # Already a sequence (or empty)
    return cast("Embeddings", target)
# Metadatas
# A list of per-record ``Metadata`` values.
Metadatas = List[Metadata]
def maybe_cast_one_to_many_metadata(target: OneOrMany[Metadata]) -> Metadatas:
    """Normalize a single metadata dict or a list of them to a list.

    A ``dict`` is wrapped in a one-element list. Any other value is returned
    as-is (same object, contents not inspected).
    """
    as_list = [target] if isinstance(target, dict) else target
    return cast("Metadatas", as_list)
# Collection-level metadata: an arbitrary string-keyed dict, and the alias
# used for the value passed when updating it.
CollectionMetadata = Dict[str, Any]
UpdateCollectionMetadata = UpdateMetadata
# Documents
Document = str
Documents = List[Document]


def is_document(target: Any) -> bool:
    """Return True when ``target`` is a ``str``, otherwise False."""
    return isinstance(target, str)
def maybe_cast_one_to_many_document(target: OneOrMany[Document]) -> Documents:
    """Normalize a single document or a list of documents to a list.

    A value accepted by ``is_document`` is wrapped in a one-element list.
    Any other value is returned as-is (same object, contents not inspected).
    """
    as_list = [target] if is_document(target) else target
    return cast(Documents, as_list)
# Images
# np.float64 is used instead of np.float_: the latter was only an alias of
# np.float64 and was removed in NumPy 2.0, where referencing it raises
# AttributeError at import time.
ImageDType = Union[np.uint, np.int_, np.float64]
Image = NDArray[ImageDType]
Images = List[Image]


def is_image(target: Any) -> bool:
    """Return True when ``target`` is a NumPy array with at least two dimensions.

    Args:
        target: Any object.

    Returns:
        False for non-``np.ndarray`` values and for arrays with fewer than two
        dimensions; True otherwise.
    """
    return isinstance(target, np.ndarray) and target.ndim >= 2
def maybe_cast_one_to_many_image(target: OneOrMany[Image]) -> Images:
    """Normalize a single image or a list of images to a list.

    A value accepted by ``is_image`` is wrapped in a one-element list.
    Any other value is returned as-is (same object, contents not inspected).
    """
    as_list = [target] if is_image(target) else target
    return cast(Images, as_list)
# Type variable constrained to the per-record value types declared above.
Parameter = TypeVar("Parameter", Document, Image, Embedding, Metadata, ID)

# This should just be List[Literal["documents", "embeddings", "metadatas", "distances"]]
# However, this provokes an incompatibility with the Overrides library and Python 3.7,
# so each accepted value is spelled as its own Literal inside a Union.
Include = List[
    Union[
        Literal["documents"],
        Literal["embeddings"],
        Literal["metadatas"],
        Literal["distances"],
        Literal["uris"],
        Literal["data"],
    ]
]

# Re-export types from chromadb.types (self-assignments bind the imported
# names as attributes of this module).
LiteralValue = LiteralValue
LogicalOperator = LogicalOperator
WhereOperator = WhereOperator
OperatorExpression = OperatorExpression
Where = Where
WhereDocumentOperator = WhereDocumentOperator

# Inputs an embedding function accepts, and its contravariant type variable.
Embeddable = Union[Documents, Images]
D = TypeVar("D", bound=Embeddable, contravariant=True)

# Output of a data loader (entries may be None), and its covariant type variable.
Loadable = List[Optional[Image]]
L = TypeVar("L", covariant=True, bound=Loadable)
class GetResult(TypedDict):
    """Typed dict with a flat list of ids and optional flat lists for the other fields."""

    ids: IDs
    embeddings: Optional[Embeddings]
    documents: Optional[Documents]
    uris: Optional[URIs]
    data: Optional[Loadable]
    metadatas: Optional[Metadatas]
class QueryResult(TypedDict):
    """Typed dict whose fields are lists of lists; all but ``ids`` may be None.

    NOTE(review): presumably the outer list has one entry per query — confirm
    against the callers that build this structure.
    """

    ids: List[IDs]
    embeddings: Optional[List[Embeddings]]
    documents: Optional[List[Documents]]
    uris: Optional[List[URIs]]
    data: Optional[List[Loadable]]
    metadatas: Optional[List[Metadatas]]
    distances: Optional[List[List[float]]]
class IndexMetadata(TypedDict):
    """Typed dict describing an index: its dimensionality, element counts and creation time."""

    dimensionality: int
    # Number of elements currently in the index (total = additions - deletes).
    curr_elements: int
    # Auto-incrementing ID of the last inserted element. It never decreases,
    # so it doubles as a count of the total historical size and should grow
    # by 1 on every add. Assumed not to overflow.
    total_elements_added: int
    time_created: float
class EmbeddingFunction(Protocol[D]):
    """Protocol for callables that map an input of type ``D`` to embeddings.

    When a subclass is created, its ``__call__`` is replaced by a wrapper that
    passes the result through ``maybe_cast_one_to_many_embedding`` and
    ``validate_embeddings`` before returning it.
    """

    def __call__(self, input: D) -> Embeddings:
        ...

    def __init_subclass__(cls) -> None:
        super().__init_subclass__()
        # getattr raises AttributeError if no __call__ can be found on the
        # subclass, since one is expected to be defined.
        original_call = getattr(cls, "__call__")

        # The wrapper keeps the name ``__call__`` and the parameter names
        # ``self`` / ``input`` so the installed method looks like the
        # protocol's own __call__.
        def __call__(self: EmbeddingFunction[D], input: D) -> Embeddings:
            raw = original_call(self, input)
            normalized = maybe_cast_one_to_many_embedding(raw)
            return validate_embeddings(normalized)

        setattr(cls, "__call__", __call__)

    def embed_with_retries(self, input: D, **retry_kwargs: Dict) -> Embeddings:
        """Call this embedding function on ``input`` through tenacity's ``retry``.

        ``retry_kwargs`` are forwarded unchanged to ``retry(...)``.
        """
        retrying = retry(**retry_kwargs)
        return retrying(self.__call__)(input)
def validate_embedding_function(
    embedding_function: EmbeddingFunction[Embeddable],
) -> None:
    """Check that an embedding function's ``__call__`` has the protocol's parameter names.

    The parameter names of ``embedding_function.__class__.__call__`` are
    compared with those of ``EmbeddingFunction.__call__``.

    Raises:
        ValueError: If the two sets of parameter names differ.
    """
    actual_params = signature(embedding_function.__class__.__call__).parameters.keys()
    expected_params = signature(EmbeddingFunction.__call__).parameters.keys()
    if actual_params != expected_params:
        message = (
            f"Expected EmbeddingFunction.__call__ to have the following signature: {expected_params}, got {actual_params}\n"
            "Please see https://docs.trychroma.com/embeddings for details of the EmbeddingFunction interface.\n"
            "Please note the recent change to the EmbeddingFunction interface: https://docs.trychroma.com/migration#migration-to-0416---november-7-2023 \n"
        )
        raise ValueError(message)
class DataLoader(Protocol[L]):
    """Protocol for callables that take a list of URIs and return a value of type ``L``."""

    def __call__(self, uris: URIs) -> L:
        ...
def validate_ids(ids: IDs) -> IDs:
    """Check that ``ids`` is a non-empty list of unique strings and return it.

    Raises:
        ValueError: If ``ids`` is not a list, is empty, or contains a non-str.
        errors.DuplicateIDError: If any ID occurs more than once. The message
            lists all duplicates when there are fewer than ten, otherwise a
            sample of them.
    """
    if not isinstance(ids, list):
        raise ValueError(f"Expected IDs to be a list, got {ids}")
    if not ids:
        raise ValueError(f"Expected IDs to be a non-empty list, got {ids}")

    unique_ids = set()
    duplicated = set()
    for candidate in ids:
        if not isinstance(candidate, str):
            raise ValueError(f"Expected ID to be a str, got {candidate}")
        if candidate in unique_ids:
            duplicated.add(candidate)
        else:
            unique_ids.add(candidate)

    if not duplicated:
        return ids

    n_dups = len(duplicated)
    if n_dups < 10:
        example_string = ", ".join(duplicated)
        message = f"Expected IDs to be unique, found duplicates of: {example_string}"
    else:
        # Sample at most the first eleven duplicates (in set iteration order)
        # and show the first five and last five of that sample.
        examples = list(duplicated)[:11]
        example_string = f"{', '.join(examples[:5])}, ..., {', '.join(examples[-5:])}"
        message = f"Expected IDs to be unique, found {n_dups} duplicated IDs: {example_string}"
    raise errors.DuplicateIDError(message)
def validate_metadata(metadata: Metadata) -> Metadata:
    """Check that ``metadata`` is None or a non-empty dict of str keys to str, int, float or bool values.

    Returns:
        ``metadata`` unchanged.

    Raises:
        ValueError: If ``metadata`` is neither a dict nor None, is an empty
            dict, or holds a value of an unsupported type.
        TypeError: If a key is not a str.
    """
    if metadata is None:
        return metadata
    if not isinstance(metadata, dict):
        raise ValueError(f"Expected metadata to be a dict or None, got {metadata}")
    if not metadata:
        raise ValueError(f"Expected metadata to be a non-empty dict, got {metadata}")

    for key, value in metadata.items():
        if not isinstance(key, str):
            raise TypeError(
                f"Expected metadata key to be a str, got {key} which is a {type(key)}"
            )
        # bool is listed explicitly even though isinstance(True, int) is True.
        if not isinstance(value, (bool, str, int, float)):
            raise ValueError(
                f"Expected metadata value to be a str, int, float or bool, got {value} which is a {type(value)}"
            )
    return metadata
def validate_update_metadata(metadata: UpdateMetadata) -> UpdateMetadata:
    """Check that ``metadata`` is None or a non-empty dict of str keys to str, int, float, bool or None values.

    Returns:
        ``metadata`` unchanged.

    Raises:
        ValueError: If ``metadata`` is neither a dict nor None, is an empty
            dict, has a non-str key, or holds a value of an unsupported type.
    """
    if metadata is None:
        return metadata
    if not isinstance(metadata, dict):
        raise ValueError(f"Expected metadata to be a dict or None, got {metadata}")
    if not metadata:
        raise ValueError(f"Expected metadata to be a non-empty dict, got {metadata}")

    for key, value in metadata.items():
        if not isinstance(key, str):
            raise ValueError(f"Expected metadata key to be a str, got {key}")
        # bool is listed explicitly even though isinstance(True, int) is True;
        # None is accepted as a value here.
        if not isinstance(value, (bool, str, int, float, type(None))):
            raise ValueError(
                f"Expected metadata value to be a str, int, or float, got {value}"
            )
    return metadata
def validate_metadatas(metadatas: Metadatas) -> Metadatas:
    """Check that ``metadatas`` is a list and run ``validate_metadata`` on every entry.

    Returns:
        ``metadatas`` unchanged.

    Raises:
        ValueError: If ``metadatas`` is not a list. Errors raised by
            ``validate_metadata`` for individual entries propagate.
    """
    if isinstance(metadatas, list):
        for entry in metadatas:
            validate_metadata(entry)
        return metadatas
    raise ValueError(f"Expected metadatas to be a list, got {metadatas}")
def validate_where(where: Where) -> Where:
    """
    Check that ``where`` is a well-formed filter expression and return it unchanged.

    ``where`` must be a dict with exactly one entry. Its value is either a
    str/int/float, an operator expression (a one-entry dict such as
    ``{"$gt": 3}``), or, for the keys ``$and`` / ``$or``, a list of at least
    two nested where expressions, each validated recursively. The top-level
    keys ``$in`` and ``$nin`` skip the value type check.

    Raises:
        ValueError: On the first rule that is violated.
    """
    if not isinstance(where, dict):
        raise ValueError(f"Expected where to be a dict, got {where}")
    if len(where) != 1:
        raise ValueError(f"Expected where to have exactly one operator, got {where}")

    logical_keys = ("$and", "$or")
    numeric_operators = ("$gt", "$gte", "$lt", "$lte")
    list_operators = ("$in", "$nin")
    known_operators = numeric_operators + ("$ne", "$eq") + list_operators

    for key, value in where.items():
        if not isinstance(key, str):
            raise ValueError(f"Expected where key to be a str, got {key}")

        is_special_key = key in logical_keys or key in list_operators
        if not is_special_key and not isinstance(value, (str, int, float, dict)):
            raise ValueError(
                f"Expected where value to be a str, int, float, or operator expression, got {value}"
            )

        if key in logical_keys:
            if not isinstance(value, list):
                raise ValueError(
                    f"Expected where value for $and or $or to be a list of where expressions, got {value}"
                )
            if len(value) <= 1:
                raise ValueError(
                    f"Expected where value for $and or $or to be a list with at least two where expressions, got {value}"
                )
            for nested_where in value:
                validate_where(nested_where)

        # Everything below applies only to operator expressions.
        if not isinstance(value, dict):
            continue

        # Ensure there is only one operator
        if len(value) != 1:
            raise ValueError(
                f"Expected operator expression to have exactly one operator, got {value}"
            )

        for operator, operand in value.items():
            # Only numbers can be compared with gt, gte, lt, lte
            if operator in numeric_operators and not isinstance(operand, (int, float)):
                raise ValueError(
                    f"Expected operand value to be an int or a float for operator {operator}, got {operand}"
                )
            if operator in list_operators and not isinstance(operand, list):
                raise ValueError(
                    f"Expected operand value to be an list for operator {operator}, got {operand}"
                )
            if operator not in known_operators:
                raise ValueError(
                    f"Expected where operator to be one of $gt, $gte, $lt, $lte, $ne, $eq, $in, $nin, "
                    f"got {operator}"
                )
            if not isinstance(operand, (str, int, float, list)):
                raise ValueError(
                    f"Expected where operand value to be a str, int, float, or list of those type, got {operand}"
                )
            if isinstance(operand, list):
                # A list operand must be non-empty, and every item must be an
                # instance of the first item's type.
                homogeneous = bool(operand) and all(
                    isinstance(item, type(operand[0])) for item in operand
                )
                if not homogeneous:
                    raise ValueError(
                        f"Expected where operand value to be a non-empty list, and all values to obe of the same type "
                        f"got {operand}"
                    )
    return where
def validate_where_document(where_document: WhereDocument) -> WhereDocument:
    """
    Check that ``where_document`` is a well-formed document filter and return it unchanged.

    ``where_document`` must be a dict with exactly one entry whose key is one
    of ``$contains``, ``$not_contains``, ``$and`` or ``$or``. For ``$and`` and
    ``$or`` the value must be a list of at least two nested where-document
    expressions, each validated recursively. For ``$contains`` and
    ``$not_contains`` the value must be a non-empty str.

    Raises:
        ValueError: On the first rule that is violated.
    """
    if not isinstance(where_document, dict):
        raise ValueError(
            f"Expected where document to be a dictionary, got {where_document}"
        )
    if len(where_document) != 1:
        raise ValueError(
            f"Expected where document to have exactly one operator, got {where_document}"
        )
    for operator, operand in where_document.items():
        if operator not in ("$contains", "$not_contains", "$and", "$or"):
            # The message lists every accepted operator, including
            # $not_contains, which was previously left out.
            raise ValueError(
                f"Expected where document operator to be one of $contains, $not_contains, $and, $or, got {operator}"
            )
        if operator in ("$and", "$or"):
            if not isinstance(operand, list):
                raise ValueError(
                    f"Expected document value for $and or $or to be a list of where document expressions, got {operand}"
                )
            if len(operand) <= 1:
                raise ValueError(
                    f"Expected document value for $and or $or to be a list with at least two where document expressions, got {operand}"
                )
            for where_document_expression in operand:
                validate_where_document(where_document_expression)
        # Value belongs to a $contains / $not_contains operator. The message
        # names the operator actually used rather than always "$contains".
        elif not isinstance(operand, str):
            raise ValueError(
                f"Expected where document operand value for operator {operator} to be a str, got {operand}"
            )
        elif len(operand) == 0:
            raise ValueError(
                f"Expected where document operand value for operator {operator} to be a non-empty str"
            )
    return where_document
def validate_include(include: Include, allow_distances: bool) -> Include:
    """Check that ``include`` is a list of allowed field names and return it.

    Since get does not allow distances, ``allow_distances`` controls whether
    "distances" is accepted in addition to "embeddings", "documents",
    "metadatas", "uris" and "data".

    Raises:
        ValueError: If ``include`` is not a list, or an item is not a str or
            is not one of the allowed values.
    """
    if not isinstance(include, list):
        raise ValueError(f"Expected include to be a list, got {include}")

    # The allowed names do not depend on the item, so build them once.
    allowed_values = ["embeddings", "documents", "metadatas", "uris", "data"]
    if allow_distances:
        allowed_values.append("distances")

    for item in include:
        if not isinstance(item, str):
            raise ValueError(f"Expected include item to be a str, got {item}")
        if item not in allowed_values:
            raise ValueError(
                f"Expected include item to be one of {', '.join(allowed_values)}, got {item}"
            )
    return include
def validate_n_results(n_results: int) -> int:
    """Check that ``n_results`` is a positive integer and return it.

    hnswlib does not allow ``n_results`` to be negative.

    Raises:
        ValueError: If ``n_results`` is not an int.
        TypeError: If ``n_results`` is zero or negative.
    """
    # Check Number of requested results
    if isinstance(n_results, int):
        if n_results > 0:
            return n_results
        raise TypeError(
            f"Number of requested results {n_results}, cannot be negative, or zero."
        )
    raise ValueError(
        f"Expected requested number of results to be a int, got {n_results}"
    )
def validate_embeddings(embeddings: Embeddings) -> Embeddings:
    """Check that ``embeddings`` is a non-empty list of non-empty lists of ints or floats.

    Bools are rejected even though ``bool`` is a subclass of ``int``.

    Returns:
        ``embeddings`` unchanged.

    Raises:
        ValueError: On the first rule that is violated.
    """
    if not isinstance(embeddings, list):
        raise ValueError(f"Expected embeddings to be a list, got {embeddings}")
    if not embeddings:
        raise ValueError(
            f"Expected embeddings to be a list with at least one item, got {embeddings}"
        )
    # Every entry must be a list before any entry's contents are inspected.
    if any(not isinstance(vector, list) for vector in embeddings):
        raise ValueError(
            f"Expected each embedding in the embeddings to be a list, got {embeddings}"
        )
    for position, vector in enumerate(embeddings):
        if not vector:
            raise ValueError(
                f"Expected each embedding in the embeddings to be a non-empty list, got empty embedding at pos {position}"
            )
        for value in vector:
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                raise ValueError(
                    f"Expected each value in the embedding to be a int or float, got {embeddings}"
                )
    return embeddings
def validate_batch(
    batch: Tuple[
        IDs,
        Optional[Embeddings],
        Optional[Metadatas],
        Optional[Documents],
        Optional[URIs],
    ],
    limits: Dict[str, Any],
) -> None:
    """Check that the batch does not exceed the configured maximum batch size.

    The batch size is the length of the first tuple element (the IDs); it is
    compared with ``limits["max_batch_size"]``.

    Raises:
        ValueError: If the batch holds more IDs than the maximum allows.
    """
    batch_size = len(batch[0])
    max_batch_size = limits["max_batch_size"]
    if batch_size > max_batch_size:
        raise ValueError(
            f"Batch size {batch_size} exceeds maximum batch size {max_batch_size}"
        )