chroma / chromadb /api /types.py
badalsahani's picture
feat: chroma initial deploy
287a0bc
from typing import Optional, Union, TypeVar, List, Dict, Any, Tuple, cast
from numpy.typing import NDArray
import numpy as np
from typing_extensions import Literal, TypedDict, Protocol
import chromadb.errors as errors
from chromadb.types import (
Metadata,
UpdateMetadata,
Vector,
LiteralValue,
LogicalOperator,
WhereOperator,
OperatorExpression,
Where,
WhereDocumentOperator,
WhereDocument,
)
from inspect import signature
from tenacity import retry
# Re-export types from chromadb.types
# NOTE: __all__ limits what `from <this module> import *` exposes to these four names.
__all__ = ["Metadata", "Where", "WhereDocument", "UpdateCollectionMetadata"]

# Generic helper alias: either a single T or a list of T.
T = TypeVar("T")
OneOrMany = Union[T, List[T]]

# URIs
# A URI is a plain string; URIs is a list of them.
URI = str
URIs = List[URI]
def maybe_cast_one_to_many_uri(target: OneOrMany[URI]) -> URIs:
    """Normalize a single URI or a list of URIs into a list of URIs.

    Args:
        target: Either one URI string or something that is already a
            sequence of URIs.

    Returns:
        A one-element list wrapping ``target`` when it is a ``str``;
        otherwise ``target`` itself, unchanged.
    """
    # A bare string is one URI; anything else is passed through as-is.
    normalized = [target] if isinstance(target, str) else target
    return cast(URIs, normalized)
# IDs
# An ID is a plain string; IDs is a list of them.
ID = str
IDs = List[ID]
def maybe_cast_one_to_many_ids(target: OneOrMany[ID]) -> IDs:
    """Normalize a single ID or a list of IDs into a list of IDs.

    Args:
        target: Either one ID string or something that is already a
            sequence of IDs.

    Returns:
        A one-element list wrapping ``target`` when it is a ``str``;
        otherwise ``target`` itself, unchanged.
    """
    # A bare string is one ID; anything else is passed through as-is.
    normalized = [target] if isinstance(target, str) else target
    return cast(IDs, normalized)
# Embeddings
# An Embedding is a single Vector (imported from chromadb.types);
# Embeddings is a list of them.
Embedding = Vector
Embeddings = List[Embedding]
def maybe_cast_one_to_many_embedding(target: OneOrMany[Embedding]) -> Embeddings:
    """Normalize a single embedding or a list of embeddings into a list.

    A list whose first element is an ``int`` or ``float`` is treated as one
    embedding and wrapped in an outer list. Everything else, including an
    empty list, is returned unchanged.

    Args:
        target: One embedding (a flat list of numbers) or a list of embeddings.

    Returns:
        ``[target]`` when ``target`` looks like a single embedding,
        otherwise ``target`` itself.
    """
    # Guard against an empty list before peeking at the first element, so an
    # empty input is passed through (and can be rejected by validation with a
    # meaningful message) instead of raising IndexError here.
    is_single_embedding = (
        isinstance(target, list)
        and len(target) > 0
        and isinstance(target[0], (int, float))
    )
    if is_single_embedding:
        return cast(Embeddings, [target])
    # Already a sequence of embeddings (or empty / not a list at all).
    return cast(Embeddings, target)
# Metadatas
# A list of Metadata values (Metadata is imported from chromadb.types).
Metadatas = List[Metadata]
def maybe_cast_one_to_many_metadata(target: OneOrMany[Metadata]) -> Metadatas:
    """Normalize a single metadata dict or a list of them into a list.

    Args:
        target: Either one metadata ``dict`` or something that is already a
            sequence of metadata entries.

    Returns:
        A one-element list wrapping ``target`` when it is a ``dict``;
        otherwise ``target`` itself, unchanged.
    """
    # A dict is one metadata record; anything else is passed through as-is.
    normalized = [target] if isinstance(target, dict) else target
    return cast(Metadatas, normalized)
# Collection-level metadata: a dict with string keys and arbitrary values.
CollectionMetadata = Dict[str, Any]
# Alias of UpdateMetadata (imported from chromadb.types).
UpdateCollectionMetadata = UpdateMetadata

# Documents
# A Document is a plain string; Documents is a list of them.
Document = str
Documents = List[Document]
def is_document(target: Any) -> bool:
    """Return ``True`` when ``target`` is a ``str`` (i.e. a single document)."""
    return isinstance(target, str)
def maybe_cast_one_to_many_document(target: OneOrMany[Document]) -> Documents:
    """Normalize a single document or a list of documents into a list.

    Args:
        target: Either one document or something that is already a sequence
            of documents.

    Returns:
        A one-element list wrapping ``target`` when ``is_document(target)``
        is true; otherwise ``target`` itself, unchanged.
    """
    # One document gets wrapped; anything else is passed through as-is.
    normalized = [target] if is_document(target) else target
    return cast(Documents, normalized)
# Images
# Scalar types allowed for image arrays.
# NOTE: np.float64 is used instead of np.float_ because the np.float_ alias was
# removed in NumPy 2.0; on NumPy 1.x np.float_ is the very same type as np.float64.
ImageDType = Union[np.uint, np.int_, np.float64]
# An Image is a NumPy array of one of the dtypes above; Images is a list of them.
Image = NDArray[ImageDType]
Images = List[Image]
def is_image(target: Any) -> bool:
    """Return ``True`` when ``target`` is a NumPy array with at least two dimensions."""
    return isinstance(target, np.ndarray) and len(target.shape) >= 2
def maybe_cast_one_to_many_image(target: OneOrMany[Image]) -> Images:
    """Normalize a single image or a list of images into a list.

    Args:
        target: Either one image or something that is already a sequence
            of images.

    Returns:
        A one-element list wrapping ``target`` when ``is_image(target)`` is
        true; otherwise ``target`` itself, unchanged.
    """
    # One image gets wrapped; anything else is passed through as-is.
    normalized = [target] if is_image(target) else target
    return cast(Images, normalized)
# TypeVar constrained to the per-record field types defined above.
Parameter = TypeVar("Parameter", Document, Image, Embedding, Metadata, ID)

# This should just be List[Literal["documents", "embeddings", "metadatas", "distances"]]
# However, this provokes an incompatibility with the Overrides library and Python 3.7
Include = List[
    Union[
        Literal["documents"],
        Literal["embeddings"],
        Literal["metadatas"],
        Literal["distances"],
        Literal["uris"],
        Literal["data"],
    ]
]

# Re-export types from chromadb.types
LiteralValue = LiteralValue
LogicalOperator = LogicalOperator
WhereOperator = WhereOperator
OperatorExpression = OperatorExpression
Where = Where
WhereDocumentOperator = WhereDocumentOperator

# Either a list of documents or a list of images.
Embeddable = Union[Documents, Images]
# Contravariant TypeVar bounded by Embeddable; parameterizes EmbeddingFunction below.
D = TypeVar("D", bound=Embeddable, contravariant=True)

# A list of images in which individual entries may be None.
Loadable = List[Optional[Image]]
# Covariant TypeVar bounded by Loadable; parameterizes DataLoader below.
L = TypeVar("L", covariant=True, bound=Loadable)
class GetResult(TypedDict):
    """Typed dictionary holding one flat list per field.

    Only ``ids`` is non-optional; every other field may be ``None``.
    NOTE(review): presumably the payload of a "get" call, with all lists
    index-aligned with ``ids`` -- confirm against the caller.
    """

    # IDs of the returned records.
    ids: List[ID]
    # One embedding per record, or None.
    embeddings: Optional[List[Embedding]]
    # One document per record, or None.
    documents: Optional[List[Document]]
    # One URI per record, or None.
    uris: Optional[URIs]
    # Loaded data (a list of optional images), or None.
    data: Optional[Loadable]
    # One metadata entry per record, or None.
    metadatas: Optional[List[Metadata]]
class QueryResult(TypedDict):
    """Typed dictionary holding one list of lists per field.

    Only ``ids`` is non-optional; every other field may be ``None``.
    NOTE(review): presumably the payload of a "query" call, with one inner
    list per query -- confirm against the caller.
    """

    # Nested lists of record IDs.
    ids: List[IDs]
    # Nested lists of embeddings, or None.
    embeddings: Optional[List[List[Embedding]]]
    # Nested lists of documents, or None.
    documents: Optional[List[List[Document]]]
    # Nested lists of URIs, or None.
    uris: Optional[List[List[URI]]]
    # A list of Loadable values (each a list of optional images), or None.
    data: Optional[List[Loadable]]
    # Nested lists of metadata entries, or None.
    metadatas: Optional[List[List[Metadata]]]
    # Nested lists of float distances, or None.
    distances: Optional[List[List[float]]]
class IndexMetadata(TypedDict):
    """Typed dictionary describing an index's size counters and creation time."""

    # NOTE(review): presumably the vector dimension of the index -- confirm.
    dimensionality: int
    # The current number of elements in the index (total = additions - deletes)
    curr_elements: int
    # The auto-incrementing ID of the last inserted element, never decreases so
    # can be used as a count of total historical size. Should increase by 1 every add.
    # Assume cannot overflow
    total_elements_added: int
    # Creation time as a float (NOTE(review): unit/epoch not visible here -- confirm).
    time_created: float
class EmbeddingFunction(Protocol[D]):
    """Protocol for callables that turn an ``Embeddable`` input into embeddings.

    Every subclass has its ``__call__`` replaced at class-creation time by a
    wrapper that normalizes and validates the returned embeddings (see
    ``__init_subclass__``).
    """

    def __call__(self, input: D) -> Embeddings:
        """Compute embeddings for ``input``; implemented by concrete classes."""
        ...

    def __init_subclass__(cls) -> None:
        """Wrap the subclass's ``__call__`` so its result is always validated."""
        super().__init_subclass__()
        # Raise an exception if __call__ is not defined since it is expected to be defined
        call = getattr(cls, "__call__")

        def __call__(self: EmbeddingFunction[D], input: D) -> Embeddings:
            # Run the subclass's own implementation, then coerce a single
            # embedding to a list of embeddings and validate the result.
            result = call(self, input)
            return validate_embeddings(maybe_cast_one_to_many_embedding(result))

        # Install the validating wrapper in place of the original __call__.
        setattr(cls, "__call__", __call__)

    def embed_with_retries(self, input: D, **retry_kwargs: Dict) -> Embeddings:
        """Call this embedding function through tenacity's ``retry`` decorator.

        Args:
            input: The value passed to ``__call__``.
            **retry_kwargs: Keyword arguments forwarded unchanged to
                ``tenacity.retry``.

        Returns:
            Whatever the retried ``__call__`` returns.
        """
        return retry(**retry_kwargs)(self.__call__)(input)
def validate_embedding_function(
    embedding_function: EmbeddingFunction[Embeddable],
) -> None:
    """Check that an embedding function's ``__call__`` matches the protocol.

    Only the parameter names of ``__call__`` are compared against those of
    ``EmbeddingFunction.__call__``.

    Args:
        embedding_function: The instance whose class-level ``__call__`` is inspected.

    Raises:
        ValueError: If the parameter names differ from the protocol's.
    """
    implemented = signature(embedding_function.__class__.__call__).parameters.keys()
    expected = signature(EmbeddingFunction.__call__).parameters.keys()
    if implemented == expected:
        return
    raise ValueError(
        f"Expected EmbeddingFunction.__call__ to have the following signature: {expected}, got {implemented}\n"
        "Please see https://docs.trychroma.com/embeddings for details of the EmbeddingFunction interface.\n"
        "Please note the recent change to the EmbeddingFunction interface: https://docs.trychroma.com/migration#migration-to-0416---november-7-2023 \n"
    )
class DataLoader(Protocol[L]):
    """Protocol for callables that take a list of URIs and return a ``Loadable``."""

    def __call__(self, uris: URIs) -> L:
        """Return the loaded data for ``uris``; implemented by concrete classes."""
        ...
def validate_ids(ids: IDs) -> IDs:
    """Validate that ``ids`` is a non-empty list of unique strings.

    Args:
        ids: The IDs to check.

    Returns:
        ``ids`` itself, unchanged, when it is valid.

    Raises:
        ValueError: If ``ids`` is not a list, is empty, or contains a non-str.
        errors.DuplicateIDError: If any ID occurs more than once.
    """
    if not isinstance(ids, list):
        raise ValueError(f"Expected IDs to be a list, got {ids}")
    if len(ids) == 0:
        raise ValueError(f"Expected IDs to be a non-empty list, got {ids}")

    seen = set()
    # A dict keeps insertion order, so duplicates are reported in the order
    # they were first repeated and the error message is reproducible from run
    # to run (set iteration order varies with string hash randomization).
    dups = {}
    for id_ in ids:
        if not isinstance(id_, str):
            raise ValueError(f"Expected ID to be a str, got {id_}")
        if id_ in seen:
            dups[id_] = None
        else:
            seen.add(id_)

    if dups:
        n_dups = len(dups)
        if n_dups < 10:
            example_string = ", ".join(dups)
            message = (
                f"Expected IDs to be unique, found duplicates of: {example_string}"
            )
        else:
            # Only sample the first 11 duplicates; the message shows the first
            # five and the last five of that sample.
            examples = list(dups)[:11]
            example_string = (
                f"{', '.join(examples[:5])}, ..., {', '.join(examples[-5:])}"
            )
            message = f"Expected IDs to be unique, found {n_dups} duplicated IDs: {example_string}"
        raise errors.DuplicateIDError(message)
    return ids
def validate_metadata(metadata: Metadata) -> Metadata:
    """Validate a metadata mapping of str keys to str, int, float or bool values.

    ``None`` is accepted and returned as-is.

    Args:
        metadata: The metadata dict to check, or ``None``.

    Returns:
        ``metadata`` itself, unchanged.

    Raises:
        ValueError: If ``metadata`` is neither a dict nor ``None``, is an empty
            dict, or holds a value that is not a str, int, float or bool.
        TypeError: If a key is not a str.
    """
    if metadata is None:
        return metadata
    if not isinstance(metadata, dict):
        raise ValueError(f"Expected metadata to be a dict or None, got {metadata}")
    if len(metadata) == 0:
        raise ValueError(f"Expected metadata to be a non-empty dict, got {metadata}")

    scalar_types = (str, int, float)
    for name, val in metadata.items():
        if not isinstance(name, str):
            raise TypeError(
                f"Expected metadata key to be a str, got {name} which is a {type(name)}"
            )
        # bool is a subclass of int, so it is spelled out explicitly as accepted.
        if isinstance(val, bool) or isinstance(val, scalar_types):
            continue
        raise ValueError(
            f"Expected metadata value to be a str, int, float or bool, got {val} which is a {type(val)}"
        )
    return metadata
def validate_update_metadata(metadata: UpdateMetadata) -> UpdateMetadata:
    """Validate an update-metadata mapping.

    Keys must be str; values may be str, int, float, bool or ``None``.
    A ``None`` mapping is accepted and returned as-is.

    Args:
        metadata: The metadata dict to check, or ``None``.

    Returns:
        ``metadata`` itself, unchanged.

    Raises:
        ValueError: If ``metadata`` is neither a dict nor ``None``, is an empty
            dict, has a non-str key, or holds a value of an unsupported type.
    """
    if metadata is None:
        return metadata
    if not isinstance(metadata, dict):
        raise ValueError(f"Expected metadata to be a dict or None, got {metadata}")
    if len(metadata) == 0:
        raise ValueError(f"Expected metadata to be a non-empty dict, got {metadata}")

    for name, val in metadata.items():
        if not isinstance(name, str):
            raise ValueError(f"Expected metadata key to be a str, got {name}")
        # None and bool are accepted here alongside str, int and float.
        if val is None or isinstance(val, (bool, str, int, float)):
            continue
        raise ValueError(
            f"Expected metadata value to be a str, int, or float, got {val}"
        )
    return metadata
def validate_metadatas(metadatas: Metadatas) -> Metadatas:
    """Validate a list of metadata entries.

    Each entry is checked with ``validate_metadata``.

    Args:
        metadatas: The list of metadata entries to check.

    Returns:
        ``metadatas`` itself, unchanged.

    Raises:
        ValueError: If ``metadatas`` is not a list, or whatever
            ``validate_metadata`` raises for an invalid entry.
    """
    if isinstance(metadatas, list):
        for entry in metadatas:
            validate_metadata(entry)
        return metadatas
    raise ValueError(f"Expected metadatas to be a list, got {metadatas}")
def validate_where(where: Where) -> Where:
    """Validate a ``where`` filter expression.

    The filter must be a dict with exactly one entry. Its value is either a
    plain str/int/float, a single-operator expression dict, or -- for the
    ``$and`` / ``$or`` keys -- a list of at least two nested where expressions,
    each validated recursively.

    Args:
        where: The filter expression to check.

    Returns:
        ``where`` itself, unchanged.

    Raises:
        ValueError: On any structural or type violation.
    """
    logical_keys = ("$and", "$or")
    # Keys for which the plain-value type check is skipped.
    type_check_exempt_keys = ("$and", "$or", "$in", "$nin")
    numeric_operators = ("$gt", "$gte", "$lt", "$lte")
    list_operators = ("$in", "$nin")
    known_operators = ("$gt", "$gte", "$lt", "$lte", "$ne", "$eq", "$in", "$nin")

    if not isinstance(where, dict):
        raise ValueError(f"Expected where to be a dict, got {where}")
    if len(where) != 1:
        raise ValueError(f"Expected where to have exactly one operator, got {where}")

    for field, condition in where.items():
        if not isinstance(field, str):
            raise ValueError(f"Expected where key to be a str, got {field}")
        if field not in type_check_exempt_keys and not isinstance(
            condition, (str, int, float, dict)
        ):
            raise ValueError(
                f"Expected where value to be a str, int, float, or operator expression, got {condition}"
            )

        if field in logical_keys:
            if not isinstance(condition, list):
                raise ValueError(
                    f"Expected where value for $and or $or to be a list of where expressions, got {condition}"
                )
            if len(condition) <= 1:
                raise ValueError(
                    f"Expected where value for $and or $or to be a list with at least two where expressions, got {condition}"
                )
            for nested_where in condition:
                validate_where(nested_where)

        # Only operator expressions (dict values) need the checks below.
        if not isinstance(condition, dict):
            continue

        # An operator expression holds exactly one operator.
        if len(condition) != 1:
            raise ValueError(
                f"Expected operator expression to have exactly one operator, got {condition}"
            )
        for operator, operand in condition.items():
            # Ordering comparisons only make sense for numbers.
            if operator in numeric_operators and not isinstance(
                operand, (int, float)
            ):
                raise ValueError(
                    f"Expected operand value to be an int or a float for operator {operator}, got {operand}"
                )
            # Membership operators need a list operand.
            if operator in list_operators and not isinstance(operand, list):
                raise ValueError(
                    f"Expected operand value to be an list for operator {operator}, got {operand}"
                )
            if operator not in known_operators:
                raise ValueError(
                    f"Expected where operator to be one of $gt, $gte, $lt, $lte, $ne, $eq, $in, $nin, "
                    f"got {operator}"
                )
            if not isinstance(operand, (str, int, float, list)):
                raise ValueError(
                    f"Expected where operand value to be a str, int, float, or list of those type, got {operand}"
                )
            # List operands must be non-empty and homogeneous with their first item.
            if isinstance(operand, list):
                if len(operand) == 0 or any(
                    not isinstance(item, type(operand[0])) for item in operand
                ):
                    raise ValueError(
                        f"Expected where operand value to be a non-empty list, and all values to obe of the same type "
                        f"got {operand}"
                    )
    return where
def validate_where_document(where_document: WhereDocument) -> WhereDocument:
    """Validate a ``where_document`` filter expression.

    The filter must be a dict with exactly one entry whose key is one of
    ``$contains``, ``$not_contains``, ``$and`` or ``$or``. For ``$and`` /
    ``$or`` the value must be a list of at least two nested where-document
    expressions (validated recursively); otherwise it must be a non-empty str.

    Args:
        where_document: The filter expression to check.

    Returns:
        ``where_document`` itself, unchanged.

    Raises:
        ValueError: On any structural or type violation.
    """
    if not isinstance(where_document, dict):
        raise ValueError(
            f"Expected where document to be a dictionary, got {where_document}"
        )
    if len(where_document) != 1:
        raise ValueError(
            f"Expected where document to have exactly one operator, got {where_document}"
        )
    for operator, operand in where_document.items():
        if operator not in ["$contains", "$not_contains", "$and", "$or"]:
            # The message lists every accepted operator, including $not_contains.
            raise ValueError(
                f"Expected where document operator to be one of $contains, $not_contains, $and, $or, got {operator}"
            )
        if operator == "$and" or operator == "$or":
            if not isinstance(operand, list):
                raise ValueError(
                    f"Expected document value for $and or $or to be a list of where document expressions, got {operand}"
                )
            if len(operand) <= 1:
                raise ValueError(
                    f"Expected document value for $and or $or to be a list with at least two where document expressions, got {operand}"
                )
            for where_document_expression in operand:
                validate_where_document(where_document_expression)
        # Value is a $contains / $not_contains operand; report the operator
        # actually used so $not_contains errors are not attributed to $contains.
        elif not isinstance(operand, str):
            raise ValueError(
                f"Expected where document operand value for operator {operator} to be a str, got {operand}"
            )
        elif len(operand) == 0:
            raise ValueError(
                f"Expected where document operand value for operator {operator} to be a non-empty str"
            )
    return where_document
def validate_include(include: Include, allow_distances: bool) -> Include:
    """Validate that ``include`` is a list of allowed field names.

    Args:
        include: The list of field names to check.
        allow_distances: Whether ``"distances"`` is an accepted item.

    Returns:
        ``include`` itself, unchanged.

    Raises:
        ValueError: If ``include`` is not a list, or an item is not a str or
            not one of the allowed values.
    """
    if not isinstance(include, list):
        raise ValueError(f"Expected include to be a list, got {include}")

    # Built once up front: the allowed set does not depend on the item.
    allowed_values = ["embeddings", "documents", "metadatas", "uris", "data"]
    if allow_distances:
        allowed_values.append("distances")

    for item in include:
        if not isinstance(item, str):
            raise ValueError(f"Expected include item to be a str, got {item}")
        if item not in allowed_values:
            raise ValueError(
                f"Expected include item to be one of {', '.join(allowed_values)}, got {item}"
            )
    return include
def validate_n_results(n_results: int) -> int:
    """Validate that the requested number of results is a positive int.

    Args:
        n_results: The requested number of results.

    Returns:
        ``n_results`` itself, unchanged.

    Raises:
        ValueError: If ``n_results`` is not an int.
        TypeError: If ``n_results`` is zero or negative.
    """
    if isinstance(n_results, int):
        if n_results > 0:
            return n_results
        # NOTE: a non-positive count is reported as TypeError (not ValueError).
        raise TypeError(
            f"Number of requested results {n_results}, cannot be negative, or zero."
        )
    raise ValueError(
        f"Expected requested number of results to be a int, got {n_results}"
    )
def validate_embeddings(embeddings: Embeddings) -> Embeddings:
    """Validate that ``embeddings`` is a non-empty list of non-empty numeric lists.

    Every value must be an int or float; bools are rejected.

    Args:
        embeddings: The embeddings to check.

    Returns:
        ``embeddings`` itself, unchanged.

    Raises:
        ValueError: If the outer value is not a non-empty list, any embedding
            is not a non-empty list, or any value is not an int/float.
    """
    if not isinstance(embeddings, list):
        raise ValueError(f"Expected embeddings to be a list, got {embeddings}")
    if len(embeddings) == 0:
        raise ValueError(
            f"Expected embeddings to be a list with at least one item, got {embeddings}"
        )
    if any(not isinstance(candidate, list) for candidate in embeddings):
        raise ValueError(
            f"Expected each embedding in the embeddings to be a list, got {embeddings}"
        )

    for position, vector in enumerate(embeddings):
        if len(vector) == 0:
            raise ValueError(
                f"Expected each embedding in the embeddings to be a non-empty list, got empty embedding at pos {position}"
            )
        for component in vector:
            # bool is a subclass of int, so it has to be excluded explicitly.
            if isinstance(component, bool) or not isinstance(
                component, (int, float)
            ):
                raise ValueError(
                    f"Expected each value in the embedding to be a int or float, got {embeddings}"
                )
    return embeddings
def validate_batch(
    batch: Tuple[
        IDs,
        Optional[Embeddings],
        Optional[Metadatas],
        Optional[Documents],
        Optional[URIs],
    ],
    limits: Dict[str, Any],
) -> None:
    """Check that a batch does not exceed the configured maximum batch size.

    The batch size is the length of the first tuple element (the IDs).

    Args:
        batch: Tuple of (ids, embeddings, metadatas, documents, uris).
        limits: Mapping that must contain the key ``"max_batch_size"``.

    Raises:
        ValueError: If the number of IDs is greater than
            ``limits["max_batch_size"]``.
    """
    batch_size = len(batch[0])
    max_batch_size = limits["max_batch_size"]
    if batch_size > max_batch_size:
        raise ValueError(
            f"Batch size {batch_size} exceeds maximum batch size {max_batch_size}"
        )