Spaces:
Sleeping
Sleeping
File size: 18,242 Bytes
287a0bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 |
from typing import Optional, Union, TypeVar, List, Dict, Any, Tuple, cast
from numpy.typing import NDArray
import numpy as np
from typing_extensions import Literal, TypedDict, Protocol
import chromadb.errors as errors
from chromadb.types import (
Metadata,
UpdateMetadata,
Vector,
LiteralValue,
LogicalOperator,
WhereOperator,
OperatorExpression,
Where,
WhereDocumentOperator,
WhereDocument,
)
from inspect import signature
from tenacity import retry
# Re-export types from chromadb.types
# Only the names listed here are picked up by a star-import of this module.
__all__ = ["Metadata", "Where", "WhereDocument", "UpdateCollectionMetadata"]
# Generic element type used by OneOrMany.
T = TypeVar("T")
# Either a single value of type T or a list of such values.
OneOrMany = Union[T, List[T]]
# URIs
# A URI is represented as a plain string.
URI = str
URIs = List[URI]
def maybe_cast_one_to_many_uri(target: OneOrMany[URI]) -> URIs:
    """Normalize one URI or a list of URIs into a list of URIs.

    A bare string is wrapped in a one-element list. Anything else is assumed
    to already be a sequence of URIs and is returned unchanged (same object).
    """
    uris = [target] if isinstance(target, str) else target
    # `cast` is a runtime no-op; the quoted type avoids a runtime name lookup.
    return cast("URIs", uris)
# IDs
# A record ID is represented as a plain string.
ID = str
IDs = List[ID]
def maybe_cast_one_to_many_ids(target: OneOrMany[ID]) -> IDs:
    """Normalize one ID or a list of IDs into a list of IDs.

    A bare string is wrapped in a one-element list. Anything else is assumed
    to already be a sequence of IDs and is returned unchanged (same object).
    """
    ids = [target] if isinstance(target, str) else target
    # `cast` is a runtime no-op; the quoted type avoids a runtime name lookup.
    return cast("IDs", ids)
# Embeddings
# A single embedding uses the Vector type imported from chromadb.types.
Embedding = Vector
Embeddings = List[Embedding]
def maybe_cast_one_to_many_embedding(target: OneOrMany[Embedding]) -> Embeddings:
    """Normalize one embedding or a list of embeddings into a list of embeddings.

    A list whose first element is a number is treated as a single embedding
    and wrapped in a one-element list. Every other input, including an empty
    list, is returned unchanged (same object).

    The previous implementation indexed ``target[0]`` unconditionally and so
    raised ``IndexError`` on an empty list. The empty list is now passed
    through, which lets ``validate_embeddings`` reject it with its regular
    ``ValueError``.
    """
    is_single_embedding = (
        isinstance(target, list)
        # Guard the index below: an empty list has no first element.
        and len(target) > 0
        and isinstance(target[0], (int, float))
    )
    if is_single_embedding:
        return cast("Embeddings", [target])
    # Already a sequence of embeddings (or something validation will reject).
    return cast("Embeddings", target)
# Metadatas
# A list of Metadata values (Metadata is imported from chromadb.types).
Metadatas = List[Metadata]
def maybe_cast_one_to_many_metadata(target: OneOrMany[Metadata]) -> Metadatas:
    """Normalize one metadata dict or a list of them into a list.

    A dict is wrapped in a one-element list. Anything else is assumed to
    already be a sequence of metadata and is returned unchanged (same object).
    """
    metadatas = [target] if isinstance(target, dict) else target
    # `cast` is a runtime no-op; the quoted type avoids a runtime name lookup.
    return cast("Metadatas", metadatas)
# Collection-level metadata: string keys mapped to arbitrary values.
CollectionMetadata = Dict[str, Any]
# Alias of the UpdateMetadata type imported from chromadb.types.
UpdateCollectionMetadata = UpdateMetadata
# Documents
# A document is represented as a plain string.
Document = str
Documents = List[Document]
def is_document(target: Any) -> bool:
    """Return True when `target` is a string, i.e. a single document."""
    return isinstance(target, str)
def maybe_cast_one_to_many_document(target: OneOrMany[Document]) -> Documents:
    """Normalize one document or a list of documents into a list of documents.

    A value accepted by `is_document` is wrapped in a one-element list.
    Anything else is assumed to already be a sequence of documents and is
    returned unchanged (same object).
    """
    documents = [target] if is_document(target) else target
    # `cast` is a runtime no-op; the quoted type avoids a runtime name lookup.
    return cast("Documents", documents)
# Images
# Scalar types accepted as image array elements.
# `np.float64` is spelled out because its former alias `np.float_` was removed
# in NumPy 2.0; referencing the alias makes this module fail at import time
# there. On NumPy 1.x both names refer to the same type.
ImageDType = Union[np.uint, np.int_, np.float64]
# An image is a NumPy array whose dtype is one of the scalar types above.
Image = NDArray[ImageDType]
Images = List[Image]
def is_image(target: Any) -> bool:
    """Return True when `target` is a NumPy array with at least two dimensions."""
    return isinstance(target, np.ndarray) and len(target.shape) >= 2
def maybe_cast_one_to_many_image(target: OneOrMany[Image]) -> Images:
    """Normalize one image or a list of images into a list of images.

    A value accepted by `is_image` is wrapped in a one-element list.
    Anything else is assumed to already be a sequence of images and is
    returned unchanged (same object).
    """
    images = [target] if is_image(target) else target
    # `cast` is a runtime no-op; the quoted type avoids a runtime name lookup.
    return cast("Images", images)
# Type variable constrained to the per-record value types defined in this module.
Parameter = TypeVar("Parameter", Document, Image, Embedding, Metadata, ID)
# This should just be List[Literal["documents", "embeddings", "metadatas", "distances"]]
# However, this provokes an incompatibility with the Overrides library and Python 3.7
Include = List[
    Union[
        Literal["documents"],
        Literal["embeddings"],
        Literal["metadatas"],
        Literal["distances"],
        Literal["uris"],
        Literal["data"],
    ]
]
# Re-export types from chromadb.types
# The self-assignments rebind the imported names in this module's namespace;
# they do not change the values.
LiteralValue = LiteralValue
LogicalOperator = LogicalOperator
WhereOperator = WhereOperator
OperatorExpression = OperatorExpression
Where = Where
WhereDocumentOperator = WhereDocumentOperator
# Input accepted by an embedding function: a list of documents or a list of images.
Embeddable = Union[Documents, Images]
# Contravariant type variable for the input of EmbeddingFunction.
D = TypeVar("D", bound=Embeddable, contravariant=True)
# A list whose entries are each an image or None.
Loadable = List[Optional[Image]]
# Covariant type variable for the result of DataLoader.
L = TypeVar("L", covariant=True, bound=Loadable)
class GetResult(TypedDict):
    """Typed dictionary holding one flat list per field.

    Only `ids` is always a list; every other field is declared Optional and
    may therefore be None.
    NOTE(review): presumably the result shape of a collection "get" call,
    with the lists aligned by position -- confirm against the callers.
    """

    ids: List[ID]
    embeddings: Optional[List[Embedding]]
    documents: Optional[List[Document]]
    uris: Optional[URIs]
    data: Optional[Loadable]
    metadatas: Optional[List[Metadata]]
class QueryResult(TypedDict):
    """Typed dictionary holding one list of lists per field.

    Each field is nested one level deeper than in GetResult, and adds
    `distances`. Only `ids` is always present as a list; every other field is
    declared Optional and may therefore be None.
    NOTE(review): presumably the outer list has one entry per query and the
    inner list one entry per match -- confirm against the callers.
    """

    ids: List[IDs]
    embeddings: Optional[List[List[Embedding]]]
    documents: Optional[List[List[Document]]]
    uris: Optional[List[List[URI]]]
    data: Optional[List[Loadable]]
    metadatas: Optional[List[List[Metadata]]]
    distances: Optional[List[List[float]]]
class IndexMetadata(TypedDict):
    """Typed dictionary of bookkeeping values for an index."""

    # NOTE(review): presumably the length of each vector in the index -- confirm.
    dimensionality: int
    # The current number of elements in the index (total = additions - deletes)
    curr_elements: int
    # The auto-incrementing ID of the last inserted element, never decreases so
    # can be used as a count of total historical size. Should increase by 1 every add.
    # Assume cannot overflow
    total_elements_added: int
    # NOTE(review): presumably a creation timestamp; unit/epoch not visible here -- confirm.
    time_created: float
class EmbeddingFunction(Protocol[D]):
    """Protocol for callables that turn an input of type D into Embeddings.

    Every subclass has its `__call__` replaced by a wrapper (installed in
    `__init_subclass__`) that normalizes and validates whatever the original
    `__call__` returns.
    """

    def __call__(self, input: D) -> Embeddings:
        ...

    def __init_subclass__(cls) -> None:
        """Wrap the subclass's `__call__` so its result is always validated."""
        super().__init_subclass__()
        # Raise an exception if __call__ is not defined since it is expected to be defined
        call = getattr(cls, "__call__")

        # The wrapper closes over `call` (the subclass's original __call__),
        # passes the input positionally, then coerces a single embedding into
        # a list of embeddings and validates the result.
        def __call__(self: EmbeddingFunction[D], input: D) -> Embeddings:
            result = call(self, input)
            return validate_embeddings(maybe_cast_one_to_many_embedding(result))

        # Install the wrapper on the subclass in place of the original.
        setattr(cls, "__call__", __call__)

    def embed_with_retries(self, input: D, **retry_kwargs: Any) -> Embeddings:
        """Call this embedding function through tenacity's `retry`.

        All keyword arguments are forwarded unchanged to `retry(...)`; the
        resulting decorator is applied to `self.__call__` and invoked once
        with `input`.
        """
        return retry(**retry_kwargs)(self.__call__)(input)
def validate_embedding_function(
    embedding_function: EmbeddingFunction[Embeddable],
) -> None:
    """Check that an embedding function's `__call__` has the protocol's parameter names.

    Only the parameter names of `__call__` on the object's class are compared
    with those of `EmbeddingFunction.__call__`; annotations and defaults are
    not inspected.

    NOTE(review): for subclasses of EmbeddingFunction the `__call__` seen here
    appears to be the wrapper installed by `__init_subclass__`, whose
    parameters are always (self, input) -- confirm whether that is intended.

    Raises:
        ValueError: if the two sets of parameter names differ.
    """
    function_signature = signature(
        embedding_function.__class__.__call__
    ).parameters.keys()
    protocol_signature = signature(EmbeddingFunction.__call__).parameters.keys()
    if function_signature == protocol_signature:
        return
    raise ValueError(
        f"Expected EmbeddingFunction.__call__ to have the following signature: {protocol_signature}, got {function_signature}\n"
        "Please see https://docs.trychroma.com/embeddings for details of the EmbeddingFunction interface.\n"
        "Please note the recent change to the EmbeddingFunction interface: https://docs.trychroma.com/migration#migration-to-0416---november-7-2023 \n"
    )
class DataLoader(Protocol[L]):
    """Protocol for callables that take a list of URIs and return a value of type L."""

    def __call__(self, uris: URIs) -> L:
        ...
def validate_ids(ids: IDs) -> IDs:
    """Validate that `ids` is a non-empty list of unique strings.

    Returns:
        The same `ids` object when validation succeeds.

    Raises:
        ValueError: if `ids` is not a list, is empty, or holds a non-string.
        errors.DuplicateIDError: if any ID occurs more than once. The message
            lists all duplicates when there are fewer than ten, otherwise a
            sample of them.
    """
    if not isinstance(ids, list):
        raise ValueError(f"Expected IDs to be a list, got {ids}")
    if not ids:
        raise ValueError(f"Expected IDs to be a non-empty list, got {ids}")

    seen = set()
    dups = set()
    for id_ in ids:
        if not isinstance(id_, str):
            raise ValueError(f"Expected ID to be a str, got {id_}")
        # A repeat sighting goes to `dups`; a first sighting goes to `seen`.
        (dups if id_ in seen else seen).add(id_)

    if not dups:
        return ids

    n_dups = len(dups)
    if n_dups < 10:
        example_string = ", ".join(dups)
        message = f"Expected IDs to be unique, found duplicates of: {example_string}"
    else:
        # Sample at most the first eleven duplicates in set-iteration order,
        # then show the first and last five of that sample.
        examples = list(dups)[:11]
        example_string = f"{', '.join(examples[:5])}, ..., {', '.join(examples[-5:])}"
        message = f"Expected IDs to be unique, found {n_dups} duplicated IDs: {example_string}"
    raise errors.DuplicateIDError(message)
def validate_metadata(metadata: Metadata) -> Metadata:
    """Validate a metadata value: None, or a non-empty dict of str keys to scalar values.

    Accepted values are str, int, float and bool.

    Returns:
        The same `metadata` object (or None) when validation succeeds.

    Raises:
        ValueError: if `metadata` is neither a dict nor None, is an empty
            dict, or holds a value of an unsupported type.
        TypeError: if a key is not a string.
    """
    if metadata is None:
        return metadata
    if not isinstance(metadata, dict):
        raise ValueError(f"Expected metadata to be a dict or None, got {metadata}")
    if not metadata:
        raise ValueError(f"Expected metadata to be a non-empty dict, got {metadata}")

    for key, value in metadata.items():
        if not isinstance(key, str):
            raise TypeError(
                f"Expected metadata key to be a str, got {key} which is a {type(key)}"
            )
        # bool is a subclass of int; it is listed explicitly only to make the
        # set of accepted types obvious.
        if not isinstance(value, (bool, str, int, float)):
            raise ValueError(
                f"Expected metadata value to be a str, int, float or bool, got {value} which is a {type(value)}"
            )
    return metadata
def validate_update_metadata(metadata: UpdateMetadata) -> UpdateMetadata:
    """Validate update metadata: None, or a non-empty dict of str keys to scalar-or-None values.

    Accepted values are str, int, float, bool and None.

    Returns:
        The same `metadata` object (or None) when validation succeeds.

    Raises:
        ValueError: if `metadata` is neither a dict nor None, is an empty
            dict, has a non-string key, or holds a value of an unsupported type.
    """
    if metadata is None:
        return metadata
    if not isinstance(metadata, dict):
        raise ValueError(f"Expected metadata to be a dict or None, got {metadata}")
    if not metadata:
        raise ValueError(f"Expected metadata to be a non-empty dict, got {metadata}")

    for key, value in metadata.items():
        if not isinstance(key, str):
            raise ValueError(f"Expected metadata key to be a str, got {key}")
        # None is allowed here, unlike in validate_metadata; bool is a
        # subclass of int and is listed only for clarity.
        if value is not None and not isinstance(value, (bool, str, int, float)):
            raise ValueError(
                f"Expected metadata value to be a str, int, or float, got {value}"
            )
    return metadata
def validate_metadatas(metadatas: Metadatas) -> Metadatas:
    """Validate a list of metadata values by running `validate_metadata` on each entry.

    Returns:
        The same `metadatas` list when validation succeeds.

    Raises:
        ValueError: if `metadatas` is not a list. Errors raised by
            `validate_metadata` for an individual entry propagate unchanged.
    """
    if isinstance(metadatas, list):
        for entry in metadatas:
            validate_metadata(entry)
        return metadatas
    raise ValueError(f"Expected metadatas to be a list, got {metadatas}")
def validate_where(where: Where) -> Where:
    """Validate a `where` filter expression.

    A filter is a dict with exactly one entry, in one of these forms:

    * ``{"$and": [...]}`` or ``{"$or": [...]}`` with a list of at least two
      nested filters, each validated recursively;
    * ``{field: value}`` with a str, int or float value;
    * ``{field: {operator: operand}}`` with exactly one operator out of
      $gt, $gte, $lt, $lte, $ne, $eq, $in, $nin.

    $gt/$gte/$lt/$lte require a numeric operand. $in/$nin require a non-empty
    list whose items are all instances of the first item's type.

    Returns:
        The same `where` object when validation succeeds.

    Raises:
        ValueError: on any violation of the rules above.
    """
    if not isinstance(where, dict):
        raise ValueError(f"Expected where to be a dict, got {where}")
    if len(where) != 1:
        raise ValueError(f"Expected where to have exactly one operator, got {where}")

    logical_keys = ("$and", "$or")
    numeric_operators = ("$gt", "$gte", "$lt", "$lte")
    list_operators = ("$in", "$nin")
    known_operators = numeric_operators + ("$ne", "$eq") + list_operators

    # The dict holds exactly one entry (checked above), so this runs once.
    for key, value in where.items():
        if not isinstance(key, str):
            raise ValueError(f"Expected where key to be a str, got {key}")

        # Keys named like logical or list operators skip the scalar/dict check.
        is_exempt_key = key in logical_keys or key in list_operators
        if not is_exempt_key and not isinstance(value, (str, int, float, dict)):
            raise ValueError(
                f"Expected where value to be a str, int, float, or operator expression, got {value}"
            )

        if key in logical_keys:
            if not isinstance(value, list):
                raise ValueError(
                    f"Expected where value for $and or $or to be a list of where expressions, got {value}"
                )
            if len(value) <= 1:
                raise ValueError(
                    f"Expected where value for $and or $or to be a list with at least two where expressions, got {value}"
                )
            for where_expression in value:
                validate_where(where_expression)

        # Everything below applies to operator expressions only.
        if not isinstance(value, dict):
            continue

        # Ensure there is only one operator
        if len(value) != 1:
            raise ValueError(
                f"Expected operator expression to have exactly one operator, got {value}"
            )

        for operator, operand in value.items():
            # Only numbers can be compared with gt, gte, lt, lte
            if operator in numeric_operators and not isinstance(
                operand, (int, float)
            ):
                raise ValueError(
                    f"Expected operand value to be an int or a float for operator {operator}, got {operand}"
                )
            if operator in list_operators and not isinstance(operand, list):
                raise ValueError(
                    f"Expected operand value to be an list for operator {operator}, got {operand}"
                )
            if operator not in known_operators:
                raise ValueError(
                    f"Expected where operator to be one of $gt, $gte, $lt, $lte, $ne, $eq, $in, $nin, "
                    f"got {operator}"
                )
            if not isinstance(operand, (str, int, float, list)):
                raise ValueError(
                    f"Expected where operand value to be a str, int, float, or list of those type, got {operand}"
                )
            if isinstance(operand, list):
                # The emptiness test must come first: operand[0] is only
                # read when the list has at least one item.
                if len(operand) == 0 or not all(
                    isinstance(x, type(operand[0])) for x in operand
                ):
                    raise ValueError(
                        f"Expected where operand value to be a non-empty list, and all values to obe of the same type "
                        f"got {operand}"
                    )
    return where
def validate_where_document(where_document: WhereDocument) -> WhereDocument:
    """Validate a `where_document` filter expression.

    A filter is a dict with exactly one entry whose key is one of $contains,
    $not_contains, $and or $or. $and/$or take a list of at least two nested
    filters, each validated recursively; the other two take a non-empty string.

    Returns:
        The same `where_document` object when validation succeeds.

    Raises:
        ValueError: on any violation of the rules above.
    """
    if not isinstance(where_document, dict):
        raise ValueError(
            f"Expected where document to be a dictionary, got {where_document}"
        )
    if len(where_document) != 1:
        raise ValueError(
            f"Expected where document to have exactly one operator, got {where_document}"
        )

    # The dict holds exactly one entry (checked above), so this runs once.
    for operator, operand in where_document.items():
        if operator not in ("$contains", "$not_contains", "$and", "$or"):
            raise ValueError(
                f"Expected where document operator to be one of $contains, $and, $or, got {operator}"
            )

        if operator in ("$and", "$or"):
            if not isinstance(operand, list):
                raise ValueError(
                    f"Expected document value for $and or $or to be a list of where document expressions, got {operand}"
                )
            if len(operand) <= 1:
                raise ValueError(
                    f"Expected document value for $and or $or to be a list with at least two where document expressions, got {operand}"
                )
            for where_document_expression in operand:
                validate_where_document(where_document_expression)
            continue

        # Remaining operators ($contains / $not_contains) take a non-empty string.
        if not isinstance(operand, str):
            raise ValueError(
                f"Expected where document operand value for operator $contains to be a str, got {operand}"
            )
        if not operand:
            raise ValueError(
                "Expected where document operand value for operator $contains to be a non-empty str"
            )
    return where_document
def validate_include(include: Include, allow_distances: bool) -> Include:
    """Validate that `include` is a list of permitted field names.

    Permitted names are embeddings, documents, metadatas, uris and data, plus
    distances when `allow_distances` is true.

    Returns:
        The same `include` list when validation succeeds.

    Raises:
        ValueError: if `include` is not a list, or holds an item that is not
            a string or not a permitted name.
    """
    if not isinstance(include, list):
        raise ValueError(f"Expected include to be a list, got {include}")

    # The permitted names do not depend on the item, so build them once.
    allowed_values = ["embeddings", "documents", "metadatas", "uris", "data"]
    if allow_distances:
        allowed_values.append("distances")

    for item in include:
        if not isinstance(item, str):
            raise ValueError(f"Expected include item to be a str, got {item}")
        if item not in allowed_values:
            raise ValueError(
                f"Expected include item to be one of {', '.join(allowed_values)}, got {item}"
            )
    return include
def validate_n_results(n_results: int) -> int:
    """Validate that `n_results` is a positive integer and return it.

    Raises:
        ValueError: if `n_results` is not an int.
        TypeError: if `n_results` is zero or negative.

    NOTE(review): the two exception types look swapped relative to the usual
    convention; they are kept as-is because callers may catch them.
    """
    # Check Number of requested results
    if isinstance(n_results, int):
        if n_results > 0:
            return n_results
        raise TypeError(
            f"Number of requested results {n_results}, cannot be negative, or zero."
        )
    raise ValueError(
        f"Expected requested number of results to be a int, got {n_results}"
    )
def validate_embeddings(embeddings: Embeddings) -> Embeddings:
    """Validate that `embeddings` is a non-empty list of non-empty lists of numbers.

    Numbers are ints or floats; bools are rejected even though bool is a
    subclass of int.

    Returns:
        The same `embeddings` list when validation succeeds.

    Raises:
        ValueError: on any violation of the rules above.
    """
    if not isinstance(embeddings, list):
        raise ValueError(f"Expected embeddings to be a list, got {embeddings}")
    if not embeddings:
        raise ValueError(
            f"Expected embeddings to be a list with at least one item, got {embeddings}"
        )
    # Every entry is type-checked before any entry's contents are inspected.
    if any(not isinstance(e, list) for e in embeddings):
        raise ValueError(
            f"Expected each embedding in the embeddings to be a list, got {embeddings}"
        )

    for i, embedding in enumerate(embeddings):
        if not embedding:
            raise ValueError(
                f"Expected each embedding in the embeddings to be a non-empty list, got empty embedding at pos {i}"
            )
        for value in embedding:
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                raise ValueError(
                    f"Expected each value in the embedding to be a int or float, got {embeddings}"
                )
    return embeddings
def validate_batch(
    batch: Tuple[
        IDs,
        Optional[Embeddings],
        Optional[Metadatas],
        Optional[Documents],
        Optional[URIs],
    ],
    limits: Dict[str, Any],
) -> None:
    """Check that a batch does not exceed the configured maximum batch size.

    The batch size is the length of the first tuple element (the IDs); the
    limit is read from ``limits["max_batch_size"]``.

    Raises:
        ValueError: if the batch size is greater than the limit.
    """
    batch_size = len(batch[0])
    max_batch_size = limits["max_batch_size"]
    if batch_size > max_batch_size:
        raise ValueError(
            f"Batch size {batch_size} exceeds maximum batch size {max_batch_size}"
        )
|