# chroma/chromadb/api/types.py
# Uploaded by badalsahani -- "feat: chroma initial deploy" (commit 287a0bc), 18.2 kB
from typing import Optional, Union, TypeVar, List, Dict, Any, Tuple, cast
from numpy.typing import NDArray
import numpy as np
from typing_extensions import Literal, TypedDict, Protocol
import chromadb.errors as errors
from chromadb.types import (
Metadata,
UpdateMetadata,
Vector,
LiteralValue,
LogicalOperator,
WhereOperator,
OperatorExpression,
Where,
WhereDocumentOperator,
WhereDocument,
)
from inspect import signature
from tenacity import retry
# Re-export types from chromadb.types
# Only the names listed in __all__ are exported by a star-import of this module.
__all__ = ["Metadata", "Where", "WhereDocument", "UpdateCollectionMetadata"]
# Generic helper: OneOrMany[T] is either a single T or a list of T.
T = TypeVar("T")
OneOrMany = Union[T, List[T]]
# URIs
# A URI is represented as a plain string; URIs is a list of them.
URI = str
URIs = List[URI]
def maybe_cast_one_to_many_uri(target: OneOrMany[URI]) -> URIs:
    """Normalize a single URI or a list of URIs into a list of URIs.

    A bare string is wrapped in a one-element list; any other value is
    returned unchanged (``cast`` only affects static typing).
    """
    uris = target if not isinstance(target, str) else [target]
    return cast(URIs, uris)
# IDs
# A record ID is represented as a plain string; IDs is a list of them.
ID = str
IDs = List[ID]
def maybe_cast_one_to_many_ids(target: OneOrMany[ID]) -> IDs:
    """Normalize a single ID or a list of IDs into a list of IDs.

    A bare string is wrapped in a one-element list; any other value is
    returned unchanged (``cast`` only affects static typing).
    """
    ids = target if not isinstance(target, str) else [target]
    return cast(IDs, ids)
# Embeddings
# An Embedding is an alias of chromadb.types.Vector; Embeddings is a list of them.
Embedding = Vector
Embeddings = List[Embedding]
def maybe_cast_one_to_many_embedding(target: OneOrMany[Embedding]) -> Embeddings:
    """Normalize a single embedding or a list of embeddings into a list of embeddings.

    A non-empty list whose first element is an ``int`` or ``float`` is treated
    as one embedding and wrapped in a list. Every other value -- including an
    empty list, which used to raise ``IndexError`` on ``target[0]`` -- is
    returned unchanged so the caller's validation (e.g. ``validate_embeddings``)
    can reject it with a meaningful error.
    """
    # Check truthiness before indexing: an empty list has no first element.
    if isinstance(target, list) and target and isinstance(target[0], (int, float)):
        # One Embedding
        return cast("Embeddings", [target])
    # Already a sequence
    return cast("Embeddings", target)
# Metadatas
# A list of Metadata values (Metadata is imported from chromadb.types).
Metadatas = List[Metadata]
def maybe_cast_one_to_many_metadata(target: OneOrMany[Metadata]) -> Metadatas:
    """Normalize a single metadata dict or a list of them into a list.

    A ``dict`` is wrapped in a one-element list; any other value is returned
    unchanged (``cast`` only affects static typing).
    """
    metadatas = [target] if isinstance(target, dict) else target
    return cast(Metadatas, metadatas)
# Collection-level metadata is an arbitrary string-keyed dictionary.
CollectionMetadata = Dict[str, Any]
# Alias of chromadb.types.UpdateMetadata (also listed in __all__).
UpdateCollectionMetadata = UpdateMetadata
# Documents
# A Document is represented as a plain string; Documents is a list of them.
Document = str
Documents = List[Document]
def is_document(target: Any) -> bool:
    """Return True when ``target`` is a single document, i.e. a ``str``."""
    return isinstance(target, str)
def maybe_cast_one_to_many_document(target: OneOrMany[Document]) -> Documents:
    """Normalize a single document or a list of documents into a list.

    A value recognised by ``is_document`` is wrapped in a one-element list;
    any other value is returned unchanged (``cast`` only affects static typing).
    """
    documents = [target] if is_document(target) else target
    return cast(Documents, documents)
# Images
# np.float64 is used instead of np.float_: the latter was only an alias of
# np.float64 and was removed in NumPy 2.0, where referencing it raises
# AttributeError at import time.
ImageDType = Union[np.uint, np.int_, np.float64]
# An Image is a NumPy array with one of the dtypes above; Images is a list of them.
Image = NDArray[ImageDType]
Images = List[Image]
def is_image(target: Any) -> bool:
    """Return True when ``target`` is a NumPy array with at least two dimensions."""
    return isinstance(target, np.ndarray) and len(target.shape) >= 2
def maybe_cast_one_to_many_image(target: OneOrMany[Image]) -> Images:
    """Normalize a single image or a list of images into a list.

    A value recognised by ``is_image`` is wrapped in a one-element list; any
    other value is returned unchanged (``cast`` only affects static typing).
    """
    images = [target] if is_image(target) else target
    return cast(Images, images)
# Type variable constrained to exactly one of the listed value types.
Parameter = TypeVar("Parameter", Document, Image, Embedding, Metadata, ID)
# This should just be List[Literal["documents", "embeddings", "metadatas", "distances"]]
# However, this provokes an incompatibility with the Overrides library and Python 3.7
Include = List[
    Union[
        Literal["documents"],
        Literal["embeddings"],
        Literal["metadatas"],
        Literal["distances"],
        Literal["uris"],
        Literal["data"],
    ]
]
# Re-export types from chromadb.types
# Each self-assignment rebinds the imported name as a module attribute of this module.
LiteralValue = LiteralValue
LogicalOperator = LogicalOperator
WhereOperator = WhereOperator
OperatorExpression = OperatorExpression
Where = Where
WhereDocumentOperator = WhereDocumentOperator
# Input accepted by an EmbeddingFunction: a list of documents or a list of images.
Embeddable = Union[Documents, Images]
D = TypeVar("D", bound=Embeddable, contravariant=True)
# A list of optional images; used as the bound of the DataLoader result type.
Loadable = List[Optional[Image]]
L = TypeVar("L", covariant=True, bound=Loadable)
class GetResult(TypedDict):
    """Typed dictionary with flat, per-record result lists.

    NOTE(review): judging by the name this is the shape returned by a ``get``
    call, and the Optional fields are presumably None when not requested --
    confirm against the callers.
    """

    # Record IDs; the only field that is never None.
    ids: List[ID]
    embeddings: Optional[List[Embedding]]
    documents: Optional[List[Document]]
    uris: Optional[URIs]
    # Loadable is a list of optional images.
    data: Optional[Loadable]
    metadatas: Optional[List[Metadata]]
class QueryResult(TypedDict):
    """Typed dictionary with nested result lists (one extra list level than GetResult).

    NOTE(review): judging by the name this is the shape returned by a ``query``
    call, with presumably one inner list per query input -- confirm against
    the callers.
    """

    # List of ID lists; the only field that is never None.
    ids: List[IDs]
    embeddings: Optional[List[List[Embedding]]]
    documents: Optional[List[List[Document]]]
    uris: Optional[List[List[URI]]]
    data: Optional[List[Loadable]]
    metadatas: Optional[List[List[Metadata]]]
    # Not present on GetResult.
    distances: Optional[List[List[float]]]
class IndexMetadata(TypedDict):
    """Typed dictionary describing an index: its dimensionality, sizes and creation time."""

    dimensionality: int
    # The current number of elements in the index (total = additions - deletes)
    curr_elements: int
    # The auto-incrementing ID of the last inserted element, never decreases so
    # can be used as a count of total historical size. Should increase by 1 every add.
    # Assume cannot overflow
    total_elements_added: int
    # NOTE(review): unit/epoch of this float is not visible here -- confirm.
    time_created: float
class EmbeddingFunction(Protocol[D]):
    """Protocol for callables that map an input of type D to Embeddings.

    Every subclass has its ``__call__`` replaced at class-creation time by a
    wrapper that normalizes the result with ``maybe_cast_one_to_many_embedding``
    and then checks it with ``validate_embeddings``.
    """

    def __call__(self, input: D) -> Embeddings:
        ...

    def __init_subclass__(cls) -> None:
        super().__init_subclass__()
        # Raise an exception if __call__ is not defined since it is expected to be defined
        call = getattr(cls, "__call__")

        # The wrapper closes over the subclass's own ``call`` captured above and
        # always invokes it positionally as call(self, input).
        def __call__(self: EmbeddingFunction[D], input: D) -> Embeddings:
            result = call(self, input)
            return validate_embeddings(maybe_cast_one_to_many_embedding(result))

        # Install the validating wrapper in place of the subclass's __call__.
        setattr(cls, "__call__", __call__)

    # NOTE(review): retry_kwargs are forwarded unchanged to tenacity's ``retry``;
    # the ``Dict`` value annotation looks narrower than what is actually passed.
    def embed_with_retries(self, input: D, **retry_kwargs: Dict) -> Embeddings:
        # Build a retrying version of self.__call__ and invoke it once with input.
        return retry(**retry_kwargs)(self.__call__)(input)
def validate_embedding_function(
    embedding_function: EmbeddingFunction[Embeddable],
) -> None:
    """Check that the function's ``__call__`` takes the protocol's parameter names.

    The parameter names of ``embedding_function.__class__.__call__`` are
    compared with those of ``EmbeddingFunction.__call__``.

    Raises:
        ValueError: if the two sets of parameter names differ.
    """
    actual_params = signature(embedding_function.__class__.__call__).parameters.keys()
    expected_params = signature(EmbeddingFunction.__call__).parameters.keys()
    if actual_params == expected_params:
        return
    raise ValueError(
        f"Expected EmbeddingFunction.__call__ to have the following signature: {expected_params}, got {actual_params}\n"
        "Please see https://docs.trychroma.com/embeddings for details of the EmbeddingFunction interface.\n"
        "Please note the recent change to the EmbeddingFunction interface: https://docs.trychroma.com/migration#migration-to-0416---november-7-2023 \n"
    )
class DataLoader(Protocol[L]):
    """Protocol for callables that take a list of URIs and return a value of type L.

    L is a covariant type variable bounded by Loadable (a list of optional images).
    """

    def __call__(self, uris: URIs) -> L:
        ...
def validate_ids(ids: IDs) -> IDs:
    """Validates ids to ensure it is a non-empty list of unique strings.

    Returns the same list when valid.

    Raises:
        ValueError: if ids is not a list, is empty, or contains a non-str item.
        errors.DuplicateIDError: if any ID occurs more than once.
    """
    if not isinstance(ids, list):
        raise ValueError(f"Expected IDs to be a list, got {ids}")
    if not ids:
        raise ValueError(f"Expected IDs to be a non-empty list, got {ids}")

    first_seen = set()
    repeated = set()
    for candidate in ids:
        if not isinstance(candidate, str):
            raise ValueError(f"Expected ID to be a str, got {candidate}")
        # An ID seen before goes to the duplicates; otherwise it is recorded as seen.
        bucket = repeated if candidate in first_seen else first_seen
        bucket.add(candidate)

    if not repeated:
        return ids

    duplicate_count = len(repeated)
    if duplicate_count < 10:
        listing = ", ".join(repeated)
        message = f"Expected IDs to be unique, found duplicates of: {listing}"
    else:
        # Sample at most the first 11 duplicates and show 5 from each end.
        sample = list(repeated)[:11]
        listing = f"{', '.join(sample[:5])}, ..., {', '.join(sample[-5:])}"
        message = f"Expected IDs to be unique, found {duplicate_count} duplicated IDs: {listing}"
    raise errors.DuplicateIDError(message)
def validate_metadata(metadata: Metadata) -> Metadata:
    """Validates metadata to ensure it is a dictionary of strings to strings, ints, floats or bools.

    ``None`` is accepted and returned unchanged.

    Raises:
        ValueError: if metadata is neither a dict nor None, is an empty dict,
            or holds a value that is not a str, int, float or bool.
        TypeError: if a key is not a str.
    """
    if metadata is None:
        return metadata
    if not isinstance(metadata, dict):
        raise ValueError(f"Expected metadata to be a dict or None, got {metadata}")
    if not metadata:
        raise ValueError(f"Expected metadata to be a non-empty dict, got {metadata}")
    for key, value in metadata.items():
        if not isinstance(key, str):
            raise TypeError(
                f"Expected metadata key to be a str, got {key} which is a {type(key)}"
            )
        # bool is a subclass of int, so it is accepted either way; it is named
        # explicitly to make the intent visible.
        if not isinstance(value, (bool, str, int, float)):
            raise ValueError(
                f"Expected metadata value to be a str, int, float or bool, got {value} which is a {type(value)}"
            )
    return metadata
def validate_update_metadata(metadata: UpdateMetadata) -> UpdateMetadata:
    """Validates update metadata: a dictionary of strings to strings, ints, floats, bools or None.

    ``None`` (no metadata at all) is accepted and returned unchanged. Unlike
    ``validate_metadata``, individual values may also be ``None``.

    Raises:
        ValueError: if metadata is neither a dict nor None, is an empty dict,
            has a non-str key, or holds a value of an unsupported type.
    """
    if not isinstance(metadata, dict) and metadata is not None:
        raise ValueError(f"Expected metadata to be a dict or None, got {metadata}")
    if metadata is None:
        return metadata
    if len(metadata) == 0:
        raise ValueError(f"Expected metadata to be a non-empty dict, got {metadata}")
    for key, value in metadata.items():
        if not isinstance(key, str):
            raise ValueError(f"Expected metadata key to be a str, got {key}")
        # bool is a subclass of int, so it is accepted either way; it is named
        # explicitly to make the intent visible.
        if not isinstance(value, (bool, str, int, float, type(None))):
            # The message lists every accepted type (it previously omitted
            # bool and None even though both pass the check above).
            raise ValueError(
                f"Expected metadata value to be a str, int, float, bool, or None, got {value}"
            )
    return metadata
def validate_metadatas(metadatas: Metadatas) -> Metadatas:
    """Validates metadatas to ensure it is a list of dictionaries of strings to strings, ints, floats or bools.

    Each item is checked with ``validate_metadata``; the same list is returned.

    Raises:
        ValueError: if metadatas is not a list (item errors come from
            ``validate_metadata``).
    """
    if isinstance(metadatas, list):
        for entry in metadatas:
            validate_metadata(entry)
        return metadatas
    raise ValueError(f"Expected metadatas to be a list, got {metadatas}")
def validate_where(where: Where) -> Where:
    """
    Validates where to ensure it is a dictionary of strings to strings, ints, floats or operator expressions,
    or in the case of $and and $or, a list of where expressions

    The dictionary must hold exactly one entry. ``$and`` / ``$or`` take a list
    of at least two where expressions, each validated recursively. A dict value
    is an operator expression with exactly one of $gt, $gte, $lt, $lte, $ne,
    $eq, $in, $nin. Returns ``where`` unchanged when valid.

    Raises:
        ValueError: on any violation of the rules above.
    """
    logical_operators = ("$and", "$or")
    list_operators = ("$in", "$nin")
    ordering_operators = ("$gt", "$gte", "$lt", "$lte")
    known_operators = ordering_operators + ("$ne", "$eq") + list_operators

    if not isinstance(where, dict):
        raise ValueError(f"Expected where to be a dict, got {where}")
    if len(where) != 1:
        raise ValueError(f"Expected where to have exactly one operator, got {where}")
    for key, value in where.items():
        if not isinstance(key, str):
            raise ValueError(f"Expected where key to be a str, got {key}")
        # NOTE: top-level keys named $in / $nin skip this value-type check,
        # exactly as $and / $or do.
        if (
            key not in logical_operators
            and key not in list_operators
            and not isinstance(value, (str, int, float, dict))
        ):
            raise ValueError(
                f"Expected where value to be a str, int, float, or operator expression, got {value}"
            )
        if key in logical_operators:
            if not isinstance(value, list):
                raise ValueError(
                    f"Expected where value for $and or $or to be a list of where expressions, got {value}"
                )
            if len(value) <= 1:
                raise ValueError(
                    f"Expected where value for $and or $or to be a list with at least two where expressions, got {value}"
                )
            for where_expression in value:
                validate_where(where_expression)
        # Value is an operator expression
        if isinstance(value, dict):
            # Ensure there is only one operator
            if len(value) != 1:
                raise ValueError(
                    f"Expected operator expression to have exactly one operator, got {value}"
                )
            for operator, operand in value.items():
                # Only numbers can be compared with gt, gte, lt, lte
                if operator in ordering_operators and not isinstance(
                    operand, (int, float)
                ):
                    raise ValueError(
                        f"Expected operand value to be an int or a float for operator {operator}, got {operand}"
                    )
                if operator in list_operators and not isinstance(operand, list):
                    raise ValueError(
                        f"Expected operand value to be a list for operator {operator}, got {operand}"
                    )
                if operator not in known_operators:
                    raise ValueError(
                        f"Expected where operator to be one of $gt, $gte, $lt, $lte, $ne, $eq, $in, $nin, "
                        f"got {operator}"
                    )
                if not isinstance(operand, (str, int, float, list)):
                    raise ValueError(
                        f"Expected where operand value to be a str, int, float, or list of those types, got {operand}"
                    )
                # A list operand must be non-empty and every item must be an
                # instance of the first item's type.
                if isinstance(operand, list) and (
                    len(operand) == 0
                    or not all(isinstance(x, type(operand[0])) for x in operand)
                ):
                    raise ValueError(
                        f"Expected where operand value to be a non-empty list, and all values to be of the same type, "
                        f"got {operand}"
                    )
    return where
def validate_where_document(where_document: WhereDocument) -> WhereDocument:
    """
    Validates where_document to ensure it is a dictionary of WhereDocumentOperator to strings, or in the case of $and and $or,
    a list of where_document expressions

    The dictionary must hold exactly one entry whose operator is one of
    $contains, $not_contains, $and, $or. ``$and`` / ``$or`` take a list of at
    least two expressions, each validated recursively; the other operators
    take a non-empty string. Returns ``where_document`` unchanged when valid.

    Raises:
        ValueError: on any violation of the rules above.
    """
    if not isinstance(where_document, dict):
        raise ValueError(
            f"Expected where document to be a dictionary, got {where_document}"
        )
    if len(where_document) != 1:
        raise ValueError(
            f"Expected where document to have exactly one operator, got {where_document}"
        )
    for operator, operand in where_document.items():
        if operator not in ["$contains", "$not_contains", "$and", "$or"]:
            # List every accepted operator ($not_contains was missing before).
            raise ValueError(
                f"Expected where document operator to be one of $contains, $not_contains, $and, $or, got {operator}"
            )
        if operator == "$and" or operator == "$or":
            if not isinstance(operand, list):
                raise ValueError(
                    f"Expected document value for $and or $or to be a list of where document expressions, got {operand}"
                )
            if len(operand) <= 1:
                raise ValueError(
                    f"Expected document value for $and or $or to be a list with at least two where document expressions, got {operand}"
                )
            for where_document_expression in operand:
                validate_where_document(where_document_expression)
        # Value belongs to a $contains / $not_contains operator; name the
        # operator actually used instead of always saying $contains.
        elif not isinstance(operand, str):
            raise ValueError(
                f"Expected where document operand value for operator {operator} to be a str, got {operand}"
            )
        elif len(operand) == 0:
            raise ValueError(
                f"Expected where document operand value for operator {operator} to be a non-empty str"
            )
    return where_document
def validate_include(include: Include, allow_distances: bool) -> Include:
    """Validates include to ensure it is a list of strings. Since get does not allow distances, allow_distances is used
    to control if distances is allowed

    Returns ``include`` unchanged when valid.

    Raises:
        ValueError: if include is not a list, or an item is not a str or is
            not one of the allowed values.
    """
    if not isinstance(include, list):
        raise ValueError(f"Expected include to be a list, got {include}")
    # The allowed set does not depend on the item, so build it once rather
    # than on every loop iteration.
    allowed_values = ["embeddings", "documents", "metadatas", "uris", "data"]
    if allow_distances:
        allowed_values.append("distances")
    for item in include:
        if not isinstance(item, str):
            raise ValueError(f"Expected include item to be a str, got {item}")
        if item not in allowed_values:
            raise ValueError(
                f"Expected include item to be one of {', '.join(allowed_values)}, got {item}"
            )
    return include
def validate_n_results(n_results: int) -> int:
    """Validates n_results to ensure it is a positive integer and returns it.

    Raises:
        ValueError: if n_results is not an int.
        TypeError: if n_results is zero or negative.

    NOTE: the two exception types read as swapped (a wrong type raises
    ValueError, a wrong value raises TypeError); they are kept as-is because
    callers may catch them.
    """
    if isinstance(n_results, int):
        if n_results <= 0:
            raise TypeError(
                f"Number of requested results {n_results}, cannot be negative, or zero."
            )
        return n_results
    # Check Number of requested results
    raise ValueError(
        f"Expected requested number of results to be a int, got {n_results}"
    )
def validate_embeddings(embeddings: Embeddings) -> Embeddings:
    """Validates embeddings to ensure it is a non-empty list of non-empty lists of ints or floats.

    Booleans are rejected even though ``bool`` is a subclass of ``int``.
    Returns ``embeddings`` unchanged when valid.

    Raises:
        ValueError: on any violation of the rules above.
    """
    if not isinstance(embeddings, list):
        raise ValueError(f"Expected embeddings to be a list, got {embeddings}")
    if not embeddings:
        raise ValueError(
            f"Expected embeddings to be a list with at least one item, got {embeddings}"
        )
    # Every item must be a list before any item's contents are inspected.
    if any(not isinstance(vector, list) for vector in embeddings):
        raise ValueError(
            f"Expected each embedding in the embeddings to be a list, got {embeddings}"
        )
    for position, vector in enumerate(embeddings):
        if not vector:
            raise ValueError(
                f"Expected each embedding in the embeddings to be a non-empty list, got empty embedding at pos {position}"
            )
        for component in vector:
            if isinstance(component, bool) or not isinstance(component, (int, float)):
                raise ValueError(
                    f"Expected each value in the embedding to be a int or float, got {embeddings}"
                )
    return embeddings
def validate_batch(
    batch: Tuple[
        IDs,
        Optional[Embeddings],
        Optional[Metadatas],
        Optional[Documents],
        Optional[URIs],
    ],
    limits: Dict[str, Any],
) -> None:
    """Check that the batch does not exceed ``limits["max_batch_size"]``.

    The batch size is the length of the first tuple element (the IDs).

    Raises:
        ValueError: if the batch size is greater than the configured maximum.
    """
    batch_size = len(batch[0])
    max_batch_size = limits["max_batch_size"]
    if batch_size > max_batch_size:
        raise ValueError(
            f"Batch size {batch_size} exceeds maximum batch size {max_batch_size}"
        )