# Spaces:
# Sleeping
# Sleeping
import logging
from typing import Dict, Optional

import hypothesis.strategies as st
import pytest
from hypothesis.stateful import (
    Bundle,
    RuleBasedStateMachine,
    rule,
    initialize,
    multiple,
    consumes,
    run_state_machine_as_test,
    MultipleResults,
)

import chromadb.api.types as types
import chromadb.test.property.strategies as strategies
from chromadb.api import ClientAPI
class CollectionStateMachine(RuleBasedStateMachine):
    """Stateful test that checks collection operations against a dict model.

    Each rule performs one operation on ``self.api`` (create, get, delete,
    list, get-or-create, modify) and mirrors every successful mutation in
    ``self._model``, a mapping of collection name to metadata. After the call,
    the collection returned by the API is compared with the model entry.

    NOTE(review): the strategy arguments passed to ``@rule`` below
    (``strategies.collections()``, ``strategies.collection_metadata``,
    ``strategies.collection_name()`` and the integer bounds for limit/offset)
    are assumed to match what ``chromadb.test.property.strategies`` provides;
    confirm against that module.
    """

    collections: Bundle[strategies.Collection]
    _model: Dict[str, Optional[types.CollectionMetadata]]

    # Bundle of collections that rules have created and not yet consumed.
    collections = Bundle("collections")

    def __init__(self, api: ClientAPI):
        """Store the API under test and start with an empty model."""
        super().__init__()
        self._model = {}
        self.api = api

    @initialize()
    def initialize(self) -> None:
        """Reset the API and clear the model so both start out empty."""
        self.api.reset()
        self._model = {}

    @rule(target=collections, coll=strategies.collections())
    def create_coll(
        self, coll: strategies.Collection
    ) -> MultipleResults[strategies.Collection]:
        """Create ``coll`` and record it in the model.

        Creation is expected to raise when the name is already in the model or
        when the metadata is an empty dict; in that case nothing is added to
        the bundle.
        """
        # Metadata can either be None or a non-empty dict
        if coll.name in self.model or (
            coll.metadata is not None and len(coll.metadata) == 0
        ):
            with pytest.raises(Exception):
                self.api.create_collection(
                    name=coll.name,
                    metadata=coll.metadata,
                    embedding_function=coll.embedding_function,
                )
            return multiple()

        c = self.api.create_collection(
            name=coll.name,
            metadata=coll.metadata,
            embedding_function=coll.embedding_function,
        )
        self.set_model(coll.name, coll.metadata)

        assert c.name == coll.name
        assert c.metadata == self.model[coll.name]
        return multiple(coll)

    @rule(coll=collections)
    def get_coll(self, coll: strategies.Collection) -> None:
        """Fetch ``coll`` by name; expect a match if modelled, else an error."""
        if coll.name in self.model:
            c = self.api.get_collection(name=coll.name)
            assert c.name == coll.name
            assert c.metadata == self.model[coll.name]
        else:
            with pytest.raises(Exception):
                self.api.get_collection(name=coll.name)

    @rule(coll=consumes(collections))
    def delete_coll(self, coll: strategies.Collection) -> None:
        """Delete ``coll`` and verify it can no longer be fetched."""
        if coll.name in self.model:
            self.api.delete_collection(name=coll.name)
            self.delete_from_model(coll.name)
        else:
            with pytest.raises(Exception):
                self.api.delete_collection(name=coll.name)

        # Whichever branch ran, the collection must not exist afterwards.
        with pytest.raises(Exception):
            self.api.get_collection(name=coll.name)

    @rule()
    def list_collections(self) -> None:
        """Check that the API lists exactly the collections in the model."""
        colls = self.api.list_collections()
        assert len(colls) == len(self.model)
        for c in colls:
            assert c.name in self.model

    # @rule for list_collections with limit and offset
    @rule(
        limit=st.integers(min_value=1, max_value=5),
        offset=st.integers(min_value=0, max_value=5),
    )
    def list_collections_with_limit_offset(self, limit: int, offset: int) -> None:
        """Check a paginated listing against a manual slice of the full listing."""
        colls = self.api.list_collections(limit=limit, offset=offset)
        total_collections = self.api.count_collections()

        # get all collections
        all_colls = self.api.list_collections()
        # manually slice the collections based on the given limit and offset
        man_colls = all_colls[offset : offset + limit]

        # given limit and offset, make various assertions regarding the total
        # number of collections
        if limit + offset > total_collections:
            # The requested window runs past the end: only the tail is returned.
            assert len(colls) == max(total_collections - offset, 0)
            # assert that our manually sliced collections are the same as the
            # ones returned by the API
            assert colls == man_colls
        else:
            assert len(colls) == limit

    @rule(
        target=collections,
        new_metadata=st.one_of(st.none(), strategies.collection_metadata),
        coll=st.one_of(consumes(collections), strategies.collections()),
    )
    def get_or_create_coll(
        self,
        coll: strategies.Collection,
        new_metadata: Optional[types.Metadata],
    ) -> MultipleResults[strategies.Collection]:
        """Call ``get_or_create_collection`` and keep the model in sync.

        An empty ``new_metadata`` dict is expected to raise. Otherwise the
        model entry for ``coll.name`` is updated first and the API result is
        then compared against it.
        """
        # Cases for get_or_create

        # Case 0
        # new_metadata is none, coll is an existing collection
        # get_or_create should return the existing collection with existing metadata
        # Essentially - an update with none is a no-op

        # Case 1
        # new_metadata is none, coll is a new collection
        # get_or_create should create a new collection with the metadata of None

        # Case 2
        # new_metadata is not none, coll is an existing collection
        # get_or_create should return the existing collection with updated metadata

        # Case 3
        # new_metadata is not none, coll is a new collection
        # get_or_create should create a new collection with the new metadata,
        # ignoring the metadata of the input coll.

        # The fact that we ignore the metadata of the generated collections is a
        # bit weird, but it is the easiest way to exercise all cases

        if new_metadata is not None and len(new_metadata) == 0:
            with pytest.raises(Exception):
                self.api.get_or_create_collection(
                    name=coll.name,
                    metadata=new_metadata,
                    embedding_function=coll.embedding_function,
                )
            return multiple()

        # Update model
        if coll.name not in self.model:
            # Handles case 1 and 3
            coll.metadata = new_metadata
        else:
            # Handles case 0 and 2
            coll.metadata = (
                self.model[coll.name] if new_metadata is None else new_metadata
            )
        self.set_model(coll.name, coll.metadata)

        # Update API
        c = self.api.get_or_create_collection(
            name=coll.name,
            metadata=new_metadata,
            embedding_function=coll.embedding_function,
        )

        # Check that model and API are in sync
        assert c.name == coll.name
        assert c.metadata == self.model[coll.name]
        return multiple(coll)

    @rule(
        target=collections,
        coll=consumes(collections),
        new_metadata=strategies.collection_metadata,
        new_name=st.one_of(st.none(), strategies.collection_name()),
    )
    def modify_coll(
        self,
        coll: strategies.Collection,
        new_metadata: types.Metadata,
        new_name: Optional[str],
    ) -> MultipleResults[strategies.Collection]:
        """Modify the metadata and/or name of ``coll`` and verify the result.

        Returns an empty result (dropping ``coll`` from the bundle) when the
        collection is not in the model, when ``new_metadata`` is empty, or
        when ``new_name`` collides with another modelled collection.
        """
        if coll.name not in self.model:
            with pytest.raises(Exception):
                self.api.get_collection(name=coll.name)
            return multiple()

        c = self.api.get_collection(name=coll.name)

        if new_metadata is not None:
            if len(new_metadata) == 0:
                # Empty metadata is rejected; this is checked through
                # get_or_create_collection rather than c.modify.
                with pytest.raises(Exception):
                    self.api.get_or_create_collection(
                        name=coll.name,
                        metadata=new_metadata,
                        embedding_function=coll.embedding_function,
                    )
                return multiple()
            coll.metadata = new_metadata
            # NOTE(review): the model is updated here, before the rename below
            # is validated. If the rename is rejected, the model keeps
            # new_metadata although c.modify raised - confirm this is intended.
            self.set_model(coll.name, coll.metadata)

        if new_name is not None:
            if new_name in self.model and new_name != coll.name:
                with pytest.raises(Exception):
                    c.modify(metadata=new_metadata, name=new_name)
                return multiple()

            # Move the model entry from the old name to the new one.
            prev_metadata = self.model[coll.name]
            self.delete_from_model(coll.name)
            self.set_model(new_name, prev_metadata)
            coll.name = new_name

        c.modify(metadata=new_metadata, name=new_name)
        # Re-fetch under the (possibly new) name to compare stored state.
        c = self.api.get_collection(name=coll.name)

        assert c.name == coll.name
        assert c.metadata == self.model[coll.name]
        return multiple(coll)

    def set_model(
        self, name: str, metadata: Optional[types.CollectionMetadata]
    ) -> None:
        """Record ``metadata`` for collection ``name`` in the model."""
        model = self.model
        model[name] = metadata

    def delete_from_model(self, name: str) -> None:
        """Remove collection ``name`` from the model."""
        model = self.model
        del model[name]

    @property
    def model(self) -> Dict[str, Optional[types.CollectionMetadata]]:
        """The name -> metadata mapping that mirrors the API state.

        Exposed as a property because every caller in this class uses it as
        an attribute (``name in self.model``, ``self.model[name]``).
        """
        return self._model
def test_collections(caplog: pytest.LogCaptureFixture, api: ClientAPI) -> None:
    """Run ``CollectionStateMachine`` against ``api`` with log capture at ERROR."""

    def build_machine() -> CollectionStateMachine:
        # Factory handed to run_state_machine_as_test; every call builds a
        # new machine bound to the same ``api``.
        return CollectionStateMachine(api)

    caplog.set_level(logging.ERROR)
    run_state_machine_as_test(build_machine)  # type: ignore