chroma / chromadb /test /property /test_collections.py
badalsahani's picture
feat: chroma initial deploy
287a0bc
import pytest
import logging
import hypothesis.strategies as st
import chromadb.test.property.strategies as strategies
from chromadb.api import ClientAPI
import chromadb.api.types as types
from hypothesis.stateful import (
Bundle,
RuleBasedStateMachine,
rule,
initialize,
multiple,
consumes,
run_state_machine_as_test,
MultipleResults,
)
from typing import Dict, Optional
class CollectionStateMachine(RuleBasedStateMachine):
    """Stateful property test that checks the collection API against a model.

    Each rule performs one collection operation through ``self.api`` and
    mirrors its expected effect in ``self.model`` (a mapping from collection
    name to metadata). Rules then assert that what the API returns matches
    the model.
    """

    # Bundle of collections produced by rules and fed back into later rules.
    collections: Bundle[strategies.Collection] = Bundle("collections")
    # Expected state: collection name -> metadata (None when none was given).
    _model: Dict[str, Optional[types.CollectionMetadata]]

    def __init__(self, api: ClientAPI):
        """Store the API under test and start with an empty model."""
        super().__init__()
        self._model = {}
        self.api = api

    @initialize()
    def initialize(self) -> None:
        """Reset the API and clear the model before each run."""
        self.api.reset()
        self._model = {}

    @rule(target=collections, coll=strategies.collections())
    def create_coll(
        self, coll: strategies.Collection
    ) -> MultipleResults[strategies.Collection]:
        """Create a collection; expect an error for duplicates or empty metadata.

        Returns the collection into the bundle only when creation succeeded.
        """
        # Metadata can either be None or a non-empty dict
        if coll.name in self.model or (
            coll.metadata is not None and len(coll.metadata) == 0
        ):
            with pytest.raises(Exception):
                self.api.create_collection(
                    name=coll.name,
                    metadata=coll.metadata,
                    embedding_function=coll.embedding_function,
                )
            return multiple()

        c = self.api.create_collection(
            name=coll.name,
            metadata=coll.metadata,
            embedding_function=coll.embedding_function,
        )
        self.set_model(coll.name, coll.metadata)

        assert c.name == coll.name
        assert c.metadata == self.model[coll.name]
        return multiple(coll)

    @rule(coll=collections)
    def get_coll(self, coll: strategies.Collection) -> None:
        """Fetch a collection by name; expect an error if the model lacks it."""
        if coll.name in self.model:
            c = self.api.get_collection(name=coll.name)
            assert c.name == coll.name
            assert c.metadata == self.model[coll.name]
        else:
            with pytest.raises(Exception):
                self.api.get_collection(name=coll.name)

    @rule(coll=consumes(collections))
    def delete_coll(self, coll: strategies.Collection) -> None:
        """Delete a collection and check that it can no longer be fetched."""
        if coll.name in self.model:
            self.api.delete_collection(name=coll.name)
            self.delete_from_model(coll.name)
        else:
            with pytest.raises(Exception):
                self.api.delete_collection(name=coll.name)

        # Whichever branch ran, the name must not resolve afterwards.
        with pytest.raises(Exception):
            self.api.get_collection(name=coll.name)

    @rule()
    def list_collections(self) -> None:
        """Check that the listed collections match the model's names and count."""
        colls = self.api.list_collections()
        assert len(colls) == len(self.model)
        for c in colls:
            assert c.name in self.model

    # @rule for list_collections with limit and offset
    @rule(
        limit=st.integers(min_value=1, max_value=5),
        offset=st.integers(min_value=0, max_value=5),
    )
    def list_collections_with_limit_offset(self, limit: int, offset: int) -> None:
        """Check paginated listing against a manual slice of the full listing."""
        colls = self.api.list_collections(limit=limit, offset=offset)
        total_collections = self.api.count_collections()

        # get all collections
        all_colls = self.api.list_collections()
        # manually slice the collections based on the given limit and offset
        man_colls = all_colls[offset : offset + limit]

        # given limit and offset, make various assertions regarding the total
        # number of collections
        if limit + offset > total_collections:
            assert len(colls) == max(total_collections - offset, 0)
            # assert that our manually sliced collections are the same as the
            # ones returned by the API
            assert colls == man_colls
        else:
            assert len(colls) == limit

    @rule(
        target=collections,
        new_metadata=st.one_of(st.none(), strategies.collection_metadata),
        coll=st.one_of(consumes(collections), strategies.collections()),
    )
    def get_or_create_coll(
        self,
        coll: strategies.Collection,
        new_metadata: Optional[types.Metadata],
    ) -> MultipleResults[strategies.Collection]:
        """Exercise ``get_or_create_collection`` for new and existing names.

        Returns the collection into the bundle unless the call was expected
        to fail (empty metadata dict).
        """
        # Cases for get_or_create

        # Case 0
        # new_metadata is none, coll is an existing collection
        # get_or_create should return the existing collection with existing metadata
        # Essentially - an update with none is a no-op

        # Case 1
        # new_metadata is none, coll is a new collection
        # get_or_create should create a new collection with the metadata of None

        # Case 2
        # new_metadata is not none, coll is an existing collection
        # get_or_create should return the existing collection with updated metadata

        # Case 3
        # new_metadata is not none, coll is a new collection
        # get_or_create should create a new collection with the new metadata,
        # ignoring the metadata of the input coll.

        # The fact that we ignore the metadata of the generated collections is a
        # bit weird, but it is the easiest way to exercise all cases

        if new_metadata is not None and len(new_metadata) == 0:
            with pytest.raises(Exception):
                self.api.get_or_create_collection(
                    name=coll.name,
                    metadata=new_metadata,
                    embedding_function=coll.embedding_function,
                )
            return multiple()

        # Update model
        if coll.name not in self.model:
            # Handles case 1 and 3
            coll.metadata = new_metadata
        else:
            # Handles case 0 and 2
            coll.metadata = (
                self.model[coll.name] if new_metadata is None else new_metadata
            )
        self.set_model(coll.name, coll.metadata)

        # Update API
        c = self.api.get_or_create_collection(
            name=coll.name,
            metadata=new_metadata,
            embedding_function=coll.embedding_function,
        )

        # Check that model and API are in sync
        assert c.name == coll.name
        assert c.metadata == self.model[coll.name]
        return multiple(coll)

    @rule(
        target=collections,
        coll=consumes(collections),
        new_metadata=strategies.collection_metadata,
        new_name=st.one_of(st.none(), strategies.collection_name()),
    )
    def modify_coll(
        self,
        coll: strategies.Collection,
        new_metadata: types.Metadata,
        new_name: Optional[str],
    ) -> MultipleResults[strategies.Collection]:
        """Modify a collection's metadata and/or name and verify the result.

        The expected post-modify state is staged in local variables and only
        committed to the model once ``c.modify`` has succeeded, so that the
        expected-failure paths leave the model untouched.

        Returns the (possibly renamed) collection into the bundle on success,
        and nothing on any expected-failure path.
        """
        if coll.name not in self.model:
            with pytest.raises(Exception):
                self.api.get_collection(name=coll.name)
            return multiple()

        c = self.api.get_collection(name=coll.name)

        # Stage the expected state; nothing is written to the model yet.
        target_name = coll.name
        target_metadata = self.model[coll.name]

        if new_metadata is not None:
            if len(new_metadata) == 0:
                # NOTE(review): this probes get_or_create_collection rather than
                # c.modify; whether modify itself rejects an empty dict is not
                # visible from this file -- confirm before changing the call.
                with pytest.raises(Exception):
                    self.api.get_or_create_collection(
                        name=coll.name,
                        metadata=new_metadata,
                        embedding_function=coll.embedding_function,
                    )
                return multiple()
            target_metadata = new_metadata

        if new_name is not None:
            if new_name in self.model and new_name != coll.name:
                # Renaming onto another existing collection must fail. The model
                # is deliberately left as it was (presumably a rejected modify
                # leaves the stored collection unchanged -- TODO confirm).
                with pytest.raises(Exception):
                    c.modify(metadata=new_metadata, name=new_name)
                return multiple()
            target_name = new_name

        c.modify(metadata=new_metadata, name=new_name)

        # Commit the staged state to the model and to the bundle value.
        if target_name != coll.name:
            self.delete_from_model(coll.name)
        self.set_model(target_name, target_metadata)
        if new_metadata is not None:
            coll.metadata = new_metadata
        coll.name = target_name

        c = self.api.get_collection(name=coll.name)

        assert c.name == coll.name
        assert c.metadata == self.model[coll.name]
        return multiple(coll)

    def set_model(
        self, name: str, metadata: Optional[types.CollectionMetadata]
    ) -> None:
        """Record ``metadata`` as the expected metadata for collection ``name``."""
        model = self.model
        model[name] = metadata

    def delete_from_model(self, name: str) -> None:
        """Remove ``name`` from the model; raises ``KeyError`` if it is absent."""
        model = self.model
        del model[name]

    @property
    def model(self) -> Dict[str, Optional[types.CollectionMetadata]]:
        """The expected state: collection name mapped to its metadata."""
        return self._model
def test_collections(caplog: pytest.LogCaptureFixture, api: ClientAPI) -> None:
    """Run the collection state machine against the ``api`` fixture.

    Log capture is limited to ERROR and above for the duration of the run.
    """
    caplog.set_level(logging.ERROR)

    def build_machine() -> CollectionStateMachine:
        # A fresh machine is built from the same api fixture on each call.
        return CollectionStateMachine(api)

    run_state_machine_as_test(build_machine)  # type: ignore