Spaces:
Sleeping
Sleeping
import random | |
import uuid | |
from random import randint | |
from typing import cast, List, Any, Dict | |
import pytest | |
import hypothesis.strategies as st | |
from hypothesis import given, settings | |
from chromadb.api import ServerAPI | |
from chromadb.api.types import Embeddings, Metadatas | |
import chromadb.test.property.strategies as strategies | |
import chromadb.test.property.invariants as invariants | |
from chromadb.utils.batch_utils import create_batches | |
# A single shared Hypothesis strategy for collections (HNSW params included).
# `st.shared` with a fixed key makes every strategy that draws from this
# instance observe the same generated collection within one example.
collection_st = st.shared(
    strategies.collections(with_hnsw_params=True),
    key="coll",
)
@given(collection=collection_st, record_set=strategies.recordsets(collection_st))
@settings(deadline=None)
def test_add(
    api: ServerAPI,
    collection: strategies.Collection,
    record_set: strategies.RecordSet,
) -> None:
    """Property test: a generated record set can be added and then found again.

    ``collection`` and ``record_set`` are drawn by Hypothesis from the shared
    ``collection_st`` strategy, so the record set is generated for the same
    collection the test creates. ``api`` is a pytest fixture (assumed to come
    from a conftest not visible here).

    If the wrapped record set has invalid metadata, ``add`` must raise and the
    test ends. Otherwise the collection count must match the record set, and
    ``invariants.ann_accuracy`` must hold with ``n_results`` set to about a
    tenth of the record count (never less than 1).
    """
    api.reset()
    # TODO: Generative embedding functions
    coll = api.create_collection(
        name=collection.name,
        metadata=collection.metadata,  # type: ignore
        embedding_function=collection.embedding_function,
    )
    normalized_record_set = invariants.wrap_all(record_set)

    # Invalid metadata is a legitimate generated case: the add must be rejected.
    if not invariants.is_metadata_valid(normalized_record_set):
        with pytest.raises(Exception):
            coll.add(**normalized_record_set)
        return

    coll.add(**record_set)
    invariants.count(coll, cast(strategies.RecordSet, normalized_record_set))

    n_results = max(1, len(normalized_record_set["ids"]) // 10)
    invariants.ann_accuracy(
        coll,
        cast(strategies.RecordSet, normalized_record_set),
        n_results=n_results,
        embedding_function=collection.embedding_function,
    )
def create_large_recordset(
    min_size: int = 45000,
    max_size: int = 50000,
) -> strategies.RecordSet:
    """Build a synthetic record set whose length is random in ``[min_size, max_size]``.

    Both bounds are inclusive (``randint``). Record ``n`` gets a fresh UUID4
    string id, the constant embedding ``[1, 2, 3]``, the metadata
    ``{"some_key": "<n>"}`` and the document ``"Document <n>"``.
    """
    count = randint(min_size, max_size)
    positions = range(count)

    record_set: Dict[str, List[Any]] = {}
    record_set["ids"] = [str(uuid.uuid4()) for _ in positions]
    record_set["embeddings"] = cast(Embeddings, [[1, 2, 3] for _ in positions])
    record_set["metadatas"] = [{"some_key": f"{n}"} for n in positions]
    record_set["documents"] = [f"Document {n}" for n in positions]
    return cast(strategies.RecordSet, record_set)
@given(collection=collection_st)
@settings(deadline=None)
def test_add_large(api: ServerAPI, collection: strategies.Collection) -> None:
    """Add a record set of at least ``api.max_batch_size`` records in batches.

    The record count is drawn from
    ``[max_batch_size, max_batch_size + int(max_batch_size * random())]``, so a
    single ``add`` may exceed the limit. The records are instead split with
    ``create_batches`` and added chunk by chunk. Afterwards the collection
    count must equal the full record set.
    """
    api.reset()
    record_set = create_large_recordset(
        min_size=api.max_batch_size,
        max_size=api.max_batch_size + int(api.max_batch_size * random.random()),
    )
    coll = api.create_collection(
        name=collection.name,
        metadata=collection.metadata,  # type: ignore
        embedding_function=collection.embedding_function,
    )
    normalized_record_set = invariants.wrap_all(record_set)

    if not invariants.is_metadata_valid(normalized_record_set):
        with pytest.raises(Exception):
            coll.add(**normalized_record_set)
        return

    # create_batches yields positional (ids, embeddings, metadatas, documents)
    # tuples; presumably each is sized to fit api.max_batch_size.
    for batch in create_batches(
        api=api,
        ids=cast(List[str], record_set["ids"]),
        embeddings=cast(Embeddings, record_set["embeddings"]),
        metadatas=cast(Metadatas, record_set["metadatas"]),
        documents=cast(List[str], record_set["documents"]),
    ):
        coll.add(*batch)
    invariants.count(coll, cast(strategies.RecordSet, normalized_record_set))
@given(collection=collection_st)
@settings(deadline=None)
def test_add_large_exceeding(api: ServerAPI, collection: strategies.Collection) -> None:
    """A single ``add`` with more than ``api.max_batch_size`` records must fail.

    The raised error's message must contain "exceeds maximum batch size".
    """
    api.reset()
    # Strictly more than max_batch_size records. A set of exactly
    # max_batch_size records does not *exceed* the limit, and randint()
    # includes its lower bound, so min_size=max_batch_size made this flaky.
    min_size = api.max_batch_size + 1
    record_set = create_large_recordset(
        min_size=min_size,
        max_size=min_size + int(api.max_batch_size * random.random()),
    )
    coll = api.create_collection(
        name=collection.name,
        metadata=collection.metadata,  # type: ignore
        embedding_function=collection.embedding_function,
    )
    normalized_record_set = invariants.wrap_all(record_set)

    if not invariants.is_metadata_valid(normalized_record_set):
        with pytest.raises(Exception):
            coll.add(**normalized_record_set)
        return

    with pytest.raises(Exception) as e:
        coll.add(**record_set)
    assert "exceeds maximum batch size" in str(e.value)
# TODO: This test fails right now because the ids are not sorted by the input order
@pytest.mark.xfail(
    reason="get() does not yet return ids in the order they were requested",
    strict=False,
)
def test_out_of_order_ids(api: ServerAPI) -> None:
    """``get(ids=...)`` should return ids in the same order they were passed.

    Known failure (see the TODO above): it is marked ``xfail`` so the suite
    stays green. The test still runs, and reports XPASS once ordering is fixed.
    """
    api.reset()
    # Deliberately unsorted ids, mixing zero-padded and unpadded numbers so
    # that neither lexical nor numeric sorting reproduces the input order.
    ooo_ids = [
        "40",
        "05",
        "8",
        "6",
        "10",
        "01",
        "00",
        "3",
        "04",
        "20",
        "02",
        "9",
        "30",
        "11",
        "13",
        "2",
        "0",
        "7",
        "06",
        "5",
        "50",
        "12",
        "03",
        "4",
        "1",
    ]
    coll = api.create_collection(
        "test", embedding_function=lambda input: [[1, 2, 3] for _ in input]  # type: ignore
    )
    embeddings: Embeddings = [[1, 2, 3] for _ in ooo_ids]
    coll.add(ids=ooo_ids, embeddings=embeddings)
    get_ids = coll.get(ids=ooo_ids)["ids"]
    assert get_ids == ooo_ids
def test_add_partial(api: ServerAPI) -> None:
    """Records whose metadata or document is None round-trip through add/get.

    Record "2" has no metadata and record "3" has no document. ``get()`` must
    report those slots as None, in the original id order.
    """
    api.reset()
    collection = api.create_collection("test")

    # TODO: We need to clean up the api types to support this typing
    collection.add(
        ids=["1", "2", "3"],
        embeddings=[[1, 2, 3] for _ in range(3)],  # type: ignore
        metadatas=[{"a": 1}, None, {"a": 3}],  # type: ignore
        documents=["a", "b", None],  # type: ignore
    )

    fetched = collection.get()
    expected = {
        "ids": ["1", "2", "3"],
        "metadatas": [{"a": 1}, None, {"a": 3}],
        "documents": ["a", "b", None],
    }
    for field, wanted in expected.items():
        assert fetched[field] == wanted