Spaces:
Sleeping
Sleeping
import logging | |
import multiprocessing | |
from multiprocessing.connection import Connection | |
from typing import Generator, Callable | |
from hypothesis import given | |
import hypothesis.strategies as st | |
import pytest | |
import chromadb | |
from chromadb.api import ClientAPI, ServerAPI | |
from chromadb.config import Settings, System | |
import chromadb.test.property.strategies as strategies | |
import chromadb.test.property.invariants as invariants | |
from chromadb.test.property.test_embeddings import ( | |
EmbeddingStateMachine, | |
EmbeddingStateMachineStates, | |
collection_st as embedding_collection_st, | |
trace, | |
) | |
from hypothesis.stateful import ( | |
run_state_machine_as_test, | |
rule, | |
precondition, | |
initialize, | |
) | |
import os | |
import shutil | |
import tempfile | |
# Type alias: a zero-argument callable that returns a ServerAPI.
CreatePersistAPI = Callable[[], ServerAPI]

# The same SQLite-backed class is configured as sysdb, producer and consumer.
_SQLITE_DB_IMPL = "chromadb.db.impl.sqlite.SqliteDB"

# Persistent, resettable configurations exercised by the tests in this module
# (presumably the parameter list of the `settings` fixture -- confirm).
# NOTE: `tempfile.mkdtemp()` runs when this module is imported, so the persist
# directory exists on disk before any test starts.
configurations = [
    Settings(
        chroma_api_impl="chromadb.api.segment.SegmentAPI",
        chroma_sysdb_impl=_SQLITE_DB_IMPL,
        chroma_producer_impl=_SQLITE_DB_IMPL,
        chroma_consumer_impl=_SQLITE_DB_IMPL,
        chroma_segment_manager_impl=(
            "chromadb.segment.impl.manager.local.LocalSegmentManager"
        ),
        allow_reset=True,
        is_persistent=True,
        persist_directory=tempfile.mkdtemp(),
    ),
]
# Registered as a parametrized fixture: the body reads `request.param`, and the
# tests in this module request `settings` by name. Module scope matches the
# persist directory, which is created once at import time inside
# `configurations`; module-scoped fixtures are also safe to combine with
# Hypothesis-driven tests, which flag function-scoped fixtures.
@pytest.fixture(scope="module", params=configurations)
def settings(request: pytest.FixtureRequest) -> Generator[Settings, None, None]:
    """Yield one persistent ``Settings`` object and clean up its directory.

    Args:
        request: pytest fixture request; ``request.param`` is the ``Settings``
            entry currently being parametrized.

    Yields:
        The ``Settings`` instance, with its ``persist_directory`` guaranteed
        to exist for the duration of the fixture.
    """
    configuration = request.param
    save_path = configuration.persist_directory

    # `exist_ok=True` already makes this a no-op when the directory is there,
    # so no separate existence check is needed.
    os.makedirs(save_path, exist_ok=True)

    yield configuration

    # `ignore_errors=True` also covers the "already gone" case.
    shutil.rmtree(save_path, ignore_errors=True)
# Collections generated with both the HNSW and the persistent-HNSW parameters
# enabled.
_hnsw_collections = strategies.collections(
    with_hnsw_params=True,
    with_persistent_hnsw_params=True,
)

# Wrapped in `st.shared` under the key "coll" so that every strategy drawing
# from `collection_st` within one example sees the same generated collection.
collection_st = st.shared(_hnsw_collections, key="coll")
# NOTE(review): `collection_strategy` and `embeddings_strategy` are not fixtures
# defined in this file; they look like Hypothesis-provided arguments whose
# `@given(...)` decorator is not visible here -- confirm and restore if missing.
def test_persist(
    settings: Settings,
    collection_strategy: strategies.Collection,
    embeddings_strategy: strategies.RecordSet,
) -> None:
    """Check that added records survive stopping and re-creating the System.

    Phase 1 starts a ``System``, resets it, creates the collection and adds
    the record set. If the record set's metadata is invalid according to
    ``invariants.is_metadata_valid``, the add is expected to raise and the
    test ends there. Otherwise all invariants are checked.

    Phase 2 starts a second ``System`` from the same settings, fetches the
    collection by name and checks the same invariants again.

    Each system is stopped in a ``finally`` block so that a failed invariant
    or the invalid-metadata early return does not leave it running.
    """

    def check_invariants(coll) -> None:
        # Full set of record-level invariants, shared by both phases.
        invariants.count(coll, embeddings_strategy)
        invariants.metadatas_match(coll, embeddings_strategy)
        invariants.documents_match(coll, embeddings_strategy)
        invariants.ids_match(coll, embeddings_strategy)
        invariants.ann_accuracy(
            coll,
            embeddings_strategy,
            embedding_function=collection_strategy.embedding_function,
        )

    # Phase 1: write through a first system.
    system_1 = System(settings)
    api_1 = system_1.instance(ServerAPI)
    system_1.start()
    try:
        api_1.reset()
        coll = api_1.create_collection(
            name=collection_strategy.name,
            metadata=collection_strategy.metadata,
            embedding_function=collection_strategy.embedding_function,
        )

        if not invariants.is_metadata_valid(invariants.wrap_all(embeddings_strategy)):
            # Invalid metadata must be rejected; nothing was stored, so there
            # is nothing to verify after a restart.
            with pytest.raises(Exception):
                coll.add(**embeddings_strategy)
            return

        coll.add(**embeddings_strategy)
        check_invariants(coll)
    finally:
        system_1.stop()

    # Drop the references to the first system before building the second one.
    del api_1
    del system_1

    # Phase 2: read back through a second system built from the same settings.
    system_2 = System(settings)
    api_2 = system_2.instance(ServerAPI)
    system_2.start()
    try:
        coll = api_2.get_collection(
            name=collection_strategy.name,
            embedding_function=collection_strategy.embedding_function,
        )
        check_invariants(coll)
    finally:
        system_2.stop()

    del api_2
    del system_2
def load_and_check(
    settings: Settings,
    collection_name: str,
    record_set: strategies.RecordSet,
    conn: Connection,
) -> None:
    """Load a collection in a fresh ``System`` and verify it against ``record_set``.

    Used as the target of the subprocess started by
    ``PersistEmbeddingsStateMachine.persist``.

    Args:
        settings: Settings used to build the ``System``.
        collection_name: Name of the collection to fetch.
        record_set: Expected contents of the collection.
        conn: Connection on which any raised exception is sent before it is
            re-raised, so the caller on the other end can surface it.

    Raises:
        Exception: Whatever was raised while starting the system, fetching
            the collection, or checking the invariants.
    """
    try:
        system = System(settings)
        api = system.instance(ServerAPI)
        system.start()
        try:
            coll = api.get_collection(
                name=collection_name,
                embedding_function=strategies.not_implemented_embedding_function(),
            )
            invariants.count(coll, record_set)
            invariants.metadatas_match(coll, record_set)
            invariants.documents_match(coll, record_set)
            invariants.ids_match(coll, record_set)
            invariants.ann_accuracy(coll, record_set)
        finally:
            # Stop the system even when a check fails.
            system.stop()
    except Exception as e:
        # Report the failure over the connection, then re-raise.
        conn.send(e)
        raise
class PersistEmbeddingsStateMachineStates(EmbeddingStateMachineStates):
    """State names of the embedding state machine, plus a ``persist`` state."""

    # Passed to `on_state_change` when the `persist` step runs.
    persist = "persist"
class PersistEmbeddingsStateMachine(EmbeddingStateMachine):
    """Embedding state machine with an extra step that re-loads persisted data.

    The ``persist`` step spawns a separate process that builds a new system
    from the same settings and checks the collection against the record set
    tracked by this machine.

    NOTE(review): no ``@initialize`` / ``@rule`` / ``@precondition`` decorators
    are visible on the methods below although they are imported at the top of
    the file; they may have been stripped from this copy -- confirm.
    """

    def __init__(self, api: ClientAPI, settings: Settings):
        """Store the API and settings, reset the API, then init the base class.

        Args:
            api: Client API the state machine operates on.
            settings: Settings handed to the subprocess started by ``persist``.
        """
        self.api = api
        self.settings = settings
        # Countdown: set back to 10 when `persist` runs, decremented on every
        # other state change (see `on_state_change`).
        self.last_persist_delay = 10
        self.api.reset()
        super().__init__(self.api)

    # NOTE(review): a lone `# type: ignore` comment preceded this method in
    # the source; it looks like the tail of a missing `@initialize(...)`
    # decorator line -- confirm.
    def initialize(
        self, collection: strategies.Collection, batch_size: int, sync_threshold: int
    ) -> None:
        """Reset the API and create the collection the machine works on.

        Args:
            collection: Generated collection description (name, metadata,
                embedding function).
            batch_size: Not used by this method.
            sync_threshold: Not used by this method.
        """
        self.api.reset()
        self.collection = self.api.create_collection(
            name=collection.name,
            metadata=collection.metadata,
            embedding_function=collection.embedding_function,
        )
        self.embedding_function = collection.embedding_function
        trace("init")
        self.on_state_change(EmbeddingStateMachineStates.initialize)
        # Start from an empty record set.
        self.record_set_state = strategies.StateMachineRecordSet(
            ids=[], metadatas=[], documents=[], embeddings=[]
        )

    def persist(self) -> None:
        """Verify the persisted collection from a freshly spawned process.

        Raises:
            Exception: The exception reported by the subprocess, if it sent one.
            RuntimeError: If the subprocess exited with a non-zero code without
                reporting an exception (e.g. it crashed, or failed while
                sending), which would otherwise pass silently.
        """
        self.on_state_change(PersistEmbeddingsStateMachineStates.persist)
        collection_name = self.collection.name

        # Create a new process and then inside the process run the invariants
        # TODO: Once we switch off of duckdb and onto sqlite we can remove this
        ctx = multiprocessing.get_context("spawn")
        # Take the pipe from the same context that creates the process.
        parent_conn, child_conn = ctx.Pipe()
        p = ctx.Process(
            target=load_and_check,
            args=(self.settings, collection_name, self.record_set_state, child_conn),
        )
        try:
            p.start()
            p.join()
            if parent_conn.poll():
                e = parent_conn.recv()
                raise e
            # Read the exit code before `p.close()` releases the process.
            if p.exitcode != 0:
                raise RuntimeError(
                    "load_and_check subprocess exited with code "
                    f"{p.exitcode} without reporting an exception"
                )
        finally:
            # Release both pipe ends and the process handle on every path,
            # including when the child's exception is re-raised above.
            parent_conn.close()
            child_conn.close()
            p.close()

    def on_state_change(self, new_state: str) -> None:
        """Update the persist countdown for the state being entered."""
        if new_state == PersistEmbeddingsStateMachineStates.persist:
            self.last_persist_delay = 10
            return
        self.last_persist_delay -= 1

    def teardown(self) -> None:
        """Reset the API when the state machine is torn down."""
        self.api.reset()
def test_persist_embeddings_state(
    caplog: pytest.LogCaptureFixture, settings: Settings
) -> None:
    """Run ``PersistEmbeddingsStateMachine`` as a Hypothesis stateful test.

    One client is created from ``settings`` and shared by every state machine
    instance that Hypothesis constructs. Log capture is limited to ERROR and
    above.
    """
    caplog.set_level(logging.ERROR)
    client = chromadb.Client(settings)

    def build_machine() -> PersistEmbeddingsStateMachine:
        # Factory called by Hypothesis each time it needs a new machine.
        return PersistEmbeddingsStateMachine(api=client, settings=settings)

    run_state_machine_as_test(build_machine)  # type: ignore