import pytest
from typing import Generator, List, Callable, Iterator, Type, cast
from chromadb.config import System, Settings
from chromadb.test.conftest import ProducerFn
from chromadb.types import (
    SubmitEmbeddingRecord,
    VectorQuery,
    Operation,
    ScalarEncoding,
    Segment,
    SegmentScope,
    SeqId,
    Vector,
)
from chromadb.ingest import Producer
from chromadb.segment import VectorReader
import uuid
import time
from chromadb.segment.impl.vector.local_hnsw import (
    LocalHnswSegment,
)
from chromadb.segment.impl.vector.local_persistent_hnsw import (
    PersistentLocalHnswSegment,
)
from chromadb.test.property.strategies import test_hnsw_config
from pytest import FixtureRequest
from itertools import count
import tempfile
import os
import shutil


def _sqlite_system(persistent: bool) -> Generator[System, None, None]:
    """Build, start and yield a sqlite-backed System in a fresh temp directory.

    Teardown (stopping the system and removing the temp directory) runs when
    the generator is resumed after the yield, i.e. at fixture finalization.
    """
    save_path = tempfile.mkdtemp()
    settings = Settings(
        allow_reset=True,
        is_persistent=persistent,
        persist_directory=save_path,
    )
    system = System(settings)
    system.start()
    yield system
    system.stop()
    if os.path.exists(save_path):
        shutil.rmtree(save_path)


def sqlite() -> Generator[System, None, None]:
    """Fixture generator for an in-memory (non-persistent) sqlite DB"""
    yield from _sqlite_system(persistent=False)


def sqlite_persistent() -> Generator[System, None, None]:
    """Fixture generator for a persistent sqlite DB"""
    yield from _sqlite_system(persistent=True)


# We will exercise in memory, persistent sqlite with both ephemeral and persistent hnsw.
# We technically never expose persistent sqlite with memory hnsw to users, but it's a valid
# configuration, so we test it here.
def system_fixtures() -> List[Callable[[], Generator[System, None, None]]]:
    """Return the DB fixture generators to exercise (in-memory and persistent sqlite)."""
    return [sqlite, sqlite_persistent]


@pytest.fixture(scope="module", params=system_fixtures())
def system(request: FixtureRequest) -> Generator[System, None, None]:
    # `yield from` (rather than `yield next(request.param())`) keeps the
    # fixture generator alive so its teardown code (system.stop() and temp
    # directory cleanup) actually runs when the fixture is finalized.
    yield from request.param()


@pytest.fixture(scope="function")
def sample_embeddings() -> Iterator[SubmitEmbeddingRecord]:
    """Generate a sequence of embeddings with the property that for each
    embedding (other than the first and last), it's nearest neighbor is the
    previous in the sequence, and it's second nearest neighbor is the
    subsequent"""

    def create_record(i: int) -> SubmitEmbeddingRecord:
        # i**1.1 grows super-linearly, so each point is closer to its
        # predecessor than to its successor.
        vector = [i**1.1, i**1.1]
        record = SubmitEmbeddingRecord(
            id=f"embedding_{i}",
            embedding=vector,
            encoding=ScalarEncoding.FLOAT32,
            metadata=None,
            operation=Operation.ADD,
            collection_id=uuid.UUID(int=0),
        )
        return record

    return (create_record(i) for i in count())


def vector_readers() -> List[Type[VectorReader]]:
    """Return the vector segment implementations under test."""
    return [LocalHnswSegment, PersistentLocalHnswSegment]


@pytest.fixture(scope="module", params=vector_readers())
def vector_reader(request: FixtureRequest) -> Generator[Type[VectorReader], None, None]:
    yield request.param


def create_random_segment_definition() -> Segment:
    """Build a vector segment definition with a fresh random segment ID."""
    return Segment(
        id=uuid.uuid4(),
        type="test_type",
        scope=SegmentScope.VECTOR,
        topic="persistent://test/test/test_topic_1",
        collection=None,
        metadata=test_hnsw_config,
    )


def sync(segment: VectorReader, seq_id: SeqId) -> None:
    """Block until the segment has consumed up to seq_id (polls for up to 5s)."""
    # Try for up to 5 seconds, then throw a TimeoutError
    start = time.time()
    while time.time() - start < 5:
        if segment.max_seqid() >= seq_id:
            return
        time.sleep(0.25)
    raise TimeoutError(f"Timed out waiting for seq_id {seq_id}")


def test_insert_and_count(
    system: System,
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    vector_reader: Type[VectorReader],
    produce_fns: ProducerFn,
) -> None:
    """Records produced to the topic are reflected in segment.count()."""
    producer = system.instance(Producer)
    system.reset_state()
    segment_definition = create_random_segment_definition()
    topic = str(segment_definition["topic"])

    max_id = produce_fns(
        producer=producer, topic=topic, n=3, embeddings=sample_embeddings
    )[1][-1]

    segment = vector_reader(system, segment_definition)
    segment.start()

    sync(segment, max_id)

    assert segment.count() == 3

    max_id = produce_fns(
        producer=producer, topic=topic, n=3, embeddings=sample_embeddings
    )[1][-1]

    sync(segment, max_id)
    assert segment.count() == 6


def approx_equal(a: float, b: float, epsilon: float = 0.0001) -> bool:
    """Return True if a and b differ by less than epsilon."""
    return abs(a - b) < epsilon


def approx_equal_vector(a: Vector, b: Vector, epsilon: float = 0.0001) -> bool:
    """Return True if a and b are element-wise approximately equal."""
    return all(approx_equal(x, y, epsilon) for x, y in zip(a, b))


def test_get_vectors(
    system: System,
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    vector_reader: Type[VectorReader],
    produce_fns: ProducerFn,
) -> None:
    """get_vectors returns all records, or only the requested IDs."""
    producer = system.instance(Producer)
    system.reset_state()
    segment_definition = create_random_segment_definition()
    topic = str(segment_definition["topic"])

    segment = vector_reader(system, segment_definition)
    segment.start()

    embeddings, seq_ids = produce_fns(
        producer=producer, topic=topic, embeddings=sample_embeddings, n=10
    )

    sync(segment, seq_ids[-1])

    # Get all items
    vectors = segment.get_vectors()
    assert len(vectors) == len(embeddings)
    vectors = sorted(vectors, key=lambda v: v["id"])
    for actual, expected, seq_id in zip(vectors, embeddings, seq_ids):
        assert actual["id"] == expected["id"]
        assert approx_equal_vector(
            actual["embedding"], cast(Vector, expected["embedding"])
        )
        assert actual["seq_id"] == seq_id

    # Get selected IDs
    ids = [e["id"] for e in embeddings[5:]]
    vectors = segment.get_vectors(ids=ids)
    assert len(vectors) == 5
    vectors = sorted(vectors, key=lambda v: v["id"])
    for actual, expected, seq_id in zip(vectors, embeddings[5:], seq_ids[5:]):
        assert actual["id"] == expected["id"]
        assert approx_equal_vector(
            actual["embedding"], cast(Vector, expected["embedding"])
        )
        assert actual["seq_id"] == seq_id


def test_ann_query(
    system: System,
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    vector_reader: Type[VectorReader],
    produce_fns: ProducerFn,
) -> None:
    """KNN queries return each item as its own nearest neighbor, and its
    sequence neighbors as the next-nearest (see sample_embeddings)."""
    producer = system.instance(Producer)
    system.reset_state()
    segment_definition = create_random_segment_definition()
    topic = str(segment_definition["topic"])

    segment = vector_reader(system, segment_definition)
    segment.start()

    embeddings, seq_ids = produce_fns(
        producer=producer, topic=topic, embeddings=sample_embeddings, n=100
    )

    sync(segment, seq_ids[-1])

    # Each item is its own nearest neighbor (one at a time)
    for e in embeddings:
        vector = cast(Vector, e["embedding"])
        query = VectorQuery(
            vectors=[vector],
            k=1,
            allowed_ids=None,
            options=None,
            include_embeddings=True,
        )
        results = segment.query_vectors(query)
        assert len(results) == 1
        assert len(results[0]) == 1
        assert results[0][0]["id"] == e["id"]
        assert results[0][0]["embedding"] is not None
        assert approx_equal_vector(results[0][0]["embedding"], vector)

    # Each item is its own nearest neighbor (all at once)
    vectors = [cast(Vector, e["embedding"]) for e in embeddings]
    query = VectorQuery(
        vectors=vectors, k=1, allowed_ids=None, options=None, include_embeddings=False
    )
    results = segment.query_vectors(query)
    assert len(results) == len(embeddings)
    for r, e in zip(results, embeddings):
        assert len(r) == 1
        assert r[0]["id"] == e["id"]

    # Each item's 3 nearest neighbors are itself and the item before and after
    test_embeddings = embeddings[1:-1]
    vectors = [cast(Vector, e["embedding"]) for e in test_embeddings]
    query = VectorQuery(
        vectors=vectors, k=3, allowed_ids=None, options=None, include_embeddings=False
    )
    results = segment.query_vectors(query)
    assert len(results) == len(test_embeddings)
    # enumerate from 1 so every result is checked; the previous
    # zip(..., range(1, len(test_embeddings))) was one element short and
    # silently skipped the final embedding.
    for i, r in enumerate(results, start=1):
        assert len(r) == 3
        assert r[0]["id"] == embeddings[i]["id"]
        assert r[1]["id"] == embeddings[i - 1]["id"]
        assert r[2]["id"] == embeddings[i + 1]["id"]


def test_delete(
    system: System,
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    vector_reader: Type[VectorReader],
    produce_fns: ProducerFn,
) -> None:
    """Deleted records disappear from count, get and KNN; delete is idempotent."""
    producer = system.instance(Producer)
    system.reset_state()
    segment_definition = create_random_segment_definition()
    topic = str(segment_definition["topic"])

    segment = vector_reader(system, segment_definition)
    segment.start()

    embeddings, seq_ids = produce_fns(
        producer=producer, topic=topic, embeddings=sample_embeddings, n=5
    )

    sync(segment, seq_ids[-1])
    assert segment.count() == 5

    delete_record = SubmitEmbeddingRecord(
        id=embeddings[0]["id"],
        embedding=None,
        encoding=None,
        metadata=None,
        operation=Operation.DELETE,
        collection_id=uuid.UUID(int=0),
    )
    assert isinstance(seq_ids, list)
    seq_ids.append(
        produce_fns(
            producer=producer,
            topic=topic,
            n=1,
            embeddings=(delete_record for _ in range(1)),
        )[1][0]
    )
    sync(segment, seq_ids[-1])

    # Assert that the record is gone using `count`
    assert segment.count() == 4

    # Assert that the record is gone using `get`
    assert segment.get_vectors(ids=[embeddings[0]["id"]]) == []
    results = segment.get_vectors()
    assert len(results) == 4
    # get_vectors returns results in arbitrary order
    results = sorted(results, key=lambda v: v["id"])
    for actual, expected in zip(results, embeddings[1:]):
        assert actual["id"] == expected["id"]
        assert approx_equal_vector(
            actual["embedding"], cast(Vector, expected["embedding"])
        )

    # Assert that the record is gone from KNN search
    vector = cast(Vector, embeddings[0]["embedding"])
    query = VectorQuery(
        vectors=[vector], k=10, allowed_ids=None, options=None, include_embeddings=False
    )
    knn_results = segment.query_vectors(query)
    # k=10 but only 4 records remain, so exactly 4 neighbors come back
    # (previously this line re-asserted the stale `results` list)
    assert len(knn_results[0]) == 4
    assert set(r["id"] for r in knn_results[0]) == set(e["id"] for e in embeddings[1:])

    # Delete is idempotent
    seq_ids.append(
        produce_fns(
            producer=producer,
            topic=topic,
            n=1,
            embeddings=(delete_record for _ in range(1)),
        )[1][0]
    )
    sync(segment, seq_ids[-1])
    assert segment.count() == 4


def _test_update(
    producer: Producer,
    topic: str,
    segment: VectorReader,
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    operation: Operation,
) -> None:
    """Tests the common code paths between update & upsert"""
    embeddings = [next(sample_embeddings) for i in range(3)]

    seq_ids: List[SeqId] = []
    for e in embeddings:
        seq_ids.append(producer.submit_embedding(topic, e))

    sync(segment, seq_ids[-1])
    assert segment.count() == 3

    # Move the first embedding to a new, far-away location
    seq_ids.append(
        producer.submit_embedding(
            topic,
            SubmitEmbeddingRecord(
                id=embeddings[0]["id"],
                embedding=[10.0, 10.0],
                encoding=ScalarEncoding.FLOAT32,
                metadata=None,
                operation=operation,
                collection_id=uuid.UUID(int=0),
            ),
        )
    )
    sync(segment, seq_ids[-1])

    # Test new data from get_vectors
    assert segment.count() == 3
    results = segment.get_vectors()
    assert len(results) == 3
    results = segment.get_vectors(ids=[embeddings[0]["id"]])
    assert results[0]["embedding"] == [10.0, 10.0]

    # Test querying at the old location
    vector = cast(Vector, embeddings[0]["embedding"])
    query = VectorQuery(
        vectors=[vector], k=3, allowed_ids=None, options=None, include_embeddings=False
    )
    knn_results = segment.query_vectors(query)[0]
    assert knn_results[0]["id"] == embeddings[1]["id"]
    assert knn_results[1]["id"] == embeddings[2]["id"]
    assert knn_results[2]["id"] == embeddings[0]["id"]

    # Test querying at the new location
    vector = [10.0, 10.0]
    query = VectorQuery(
        vectors=[vector], k=3, allowed_ids=None, options=None, include_embeddings=False
    )
    knn_results = segment.query_vectors(query)[0]
    assert knn_results[0]["id"] == embeddings[0]["id"]
    assert knn_results[1]["id"] == embeddings[2]["id"]
    assert knn_results[2]["id"] == embeddings[1]["id"]


def test_update(
    system: System,
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    vector_reader: Type[VectorReader],
    produce_fns: ProducerFn,
) -> None:
    """UPDATE moves existing records; updating a missing record is a no-op."""
    producer = system.instance(Producer)
    system.reset_state()
    segment_definition = create_random_segment_definition()
    topic = str(segment_definition["topic"])

    segment = vector_reader(system, segment_definition)
    segment.start()

    _test_update(producer, topic, segment, sample_embeddings, Operation.UPDATE)

    # test updating a nonexistent record
    update_record = SubmitEmbeddingRecord(
        id="no_such_record",
        embedding=[10.0, 10.0],
        encoding=ScalarEncoding.FLOAT32,
        metadata=None,
        operation=Operation.UPDATE,
        collection_id=uuid.UUID(int=0),
    )
    seq_id = produce_fns(
        producer=producer,
        topic=topic,
        n=1,
        embeddings=(update_record for _ in range(1)),
    )[1][0]

    sync(segment, seq_id)

    assert segment.count() == 3
    assert segment.get_vectors(ids=["no_such_record"]) == []


def test_upsert(
    system: System,
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    vector_reader: Type[VectorReader],
    produce_fns: ProducerFn,
) -> None:
    """UPSERT moves existing records; upserting a missing record inserts it."""
    producer = system.instance(Producer)
    system.reset_state()
    segment_definition = create_random_segment_definition()
    topic = str(segment_definition["topic"])

    segment = vector_reader(system, segment_definition)
    segment.start()

    _test_update(producer, topic, segment, sample_embeddings, Operation.UPSERT)

    # test updating a nonexistent record
    upsert_record = SubmitEmbeddingRecord(
        id="no_such_record",
        embedding=[42, 42],
        encoding=ScalarEncoding.FLOAT32,
        metadata=None,
        operation=Operation.UPSERT,
        collection_id=uuid.UUID(int=0),
    )
    seq_id = produce_fns(
        producer=producer,
        topic=topic,
        n=1,
        embeddings=(upsert_record for _ in range(1)),
    )[1][0]

    sync(segment, seq_id)

    # Unlike UPDATE, UPSERT of a missing record inserts it
    assert segment.count() == 4
    result = segment.get_vectors(ids=["no_such_record"])
    assert len(result) == 1
    assert approx_equal_vector(result[0]["embedding"], [42, 42])


def test_delete_without_add(
    system: System,
    vector_reader: Type[VectorReader],
) -> None:
    """Deleting a record that was never added must not raise."""
    producer = system.instance(Producer)
    system.reset_state()
    segment_definition = create_random_segment_definition()
    topic = str(segment_definition["topic"])

    segment = vector_reader(system, segment_definition)
    segment.start()

    assert segment.count() == 0

    delete_record = SubmitEmbeddingRecord(
        id="not_in_db",
        embedding=None,
        encoding=None,
        metadata=None,
        operation=Operation.DELETE,
        collection_id=uuid.UUID(int=0),
    )

    try:
        producer.submit_embedding(topic, delete_record)
    except BaseException:
        pytest.fail("Unexpected error. Deleting on an empty segment should not raise.")


def test_delete_with_local_segment_storage(
    system: System,
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    vector_reader: Type[VectorReader],
    produce_fns: ProducerFn,
) -> None:
    """segment.delete() removes the persistent storage folder (and is
    idempotent); the in-memory implementation raises NotImplementedError."""
    producer = system.instance(Producer)
    system.reset_state()
    segment_definition = create_random_segment_definition()
    topic = str(segment_definition["topic"])

    segment = vector_reader(system, segment_definition)
    segment.start()

    embeddings, seq_ids = produce_fns(
        producer=producer, topic=topic, embeddings=sample_embeddings, n=5
    )

    sync(segment, seq_ids[-1])
    assert segment.count() == 5

    delete_record = SubmitEmbeddingRecord(
        id=embeddings[0]["id"],
        embedding=None,
        encoding=None,
        metadata=None,
        operation=Operation.DELETE,
        collection_id=uuid.UUID(int=0),
    )
    assert isinstance(seq_ids, list)
    seq_ids.append(
        produce_fns(
            producer=producer,
            topic=topic,
            n=1,
            embeddings=(delete_record for _ in range(1)),
        )[1][0]
    )
    sync(segment, seq_ids[-1])

    # Assert that the record is gone using `count`
    assert segment.count() == 4

    # Assert that the record is gone using `get`
    assert segment.get_vectors(ids=[embeddings[0]["id"]]) == []
    results = segment.get_vectors()
    assert len(results) == 4
    # get_vectors returns results in arbitrary order
    results = sorted(results, key=lambda v: v["id"])
    for actual, expected in zip(results, embeddings[1:]):
        assert actual["id"] == expected["id"]
        assert approx_equal_vector(
            actual["embedding"], cast(Vector, expected["embedding"])
        )

    # Assert that the record is gone from KNN search
    vector = cast(Vector, embeddings[0]["embedding"])
    query = VectorQuery(
        vectors=[vector], k=10, allowed_ids=None, options=None, include_embeddings=False
    )
    knn_results = segment.query_vectors(query)
    # k=10 but only 4 records remain, so exactly 4 neighbors come back
    # (previously this line re-asserted the stale `results` list)
    assert len(knn_results[0]) == 4
    assert set(r["id"] for r in knn_results[0]) == set(e["id"] for e in embeddings[1:])

    # Deleting the segment removes its storage folder; delete is idempotent
    if isinstance(segment, PersistentLocalHnswSegment):
        assert os.path.exists(segment._get_storage_folder())
        segment.delete()
        assert not os.path.exists(segment._get_storage_folder())
        segment.delete()  # should not raise
    elif isinstance(segment, LocalHnswSegment):
        with pytest.raises(NotImplementedError):
            segment.delete()


def test_reset_state_ignored_for_allow_reset_false(
    system: System,
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    vector_reader: Type[VectorReader],
    produce_fns: ProducerFn,
) -> None:
    """reset_state() removes storage only when the system allows reset."""
    producer = system.instance(Producer)
    system.reset_state()
    segment_definition = create_random_segment_definition()
    topic = str(segment_definition["topic"])

    segment = vector_reader(system, segment_definition)
    segment.start()

    embeddings, seq_ids = produce_fns(
        producer=producer, topic=topic, embeddings=sample_embeddings, n=5
    )

    sync(segment, seq_ids[-1])
    assert segment.count() == 5

    delete_record = SubmitEmbeddingRecord(
        id=embeddings[0]["id"],
        embedding=None,
        encoding=None,
        metadata=None,
        operation=Operation.DELETE,
        collection_id=uuid.UUID(int=0),
    )
    assert isinstance(seq_ids, list)
    seq_ids.append(
        produce_fns(
            producer=producer,
            topic=topic,
            n=1,
            embeddings=(delete_record for _ in range(1)),
        )[1][0]
    )
    sync(segment, seq_ids[-1])

    # Assert that the record is gone using `count`
    assert segment.count() == 4

    # Assert that the record is gone using `get`
    assert segment.get_vectors(ids=[embeddings[0]["id"]]) == []
    results = segment.get_vectors()
    assert len(results) == 4
    # get_vectors returns results in arbitrary order
    results = sorted(results, key=lambda v: v["id"])
    for actual, expected in zip(results, embeddings[1:]):
        assert actual["id"] == expected["id"]
        assert approx_equal_vector(
            actual["embedding"], cast(Vector, expected["embedding"])
        )

    # Assert that the record is gone from KNN search
    vector = cast(Vector, embeddings[0]["embedding"])
    query = VectorQuery(
        vectors=[vector], k=10, allowed_ids=None, options=None, include_embeddings=False
    )
    knn_results = segment.query_vectors(query)
    # k=10 but only 4 records remain, so exactly 4 neighbors come back
    # (previously this line re-asserted the stale `results` list)
    assert len(knn_results[0]) == 4
    assert set(r["id"] for r in knn_results[0]) == set(e["id"] for e in embeddings[1:])

    if isinstance(segment, PersistentLocalHnswSegment):
        if segment._allow_reset:
            assert os.path.exists(segment._get_storage_folder())
            segment.reset_state()
            assert not os.path.exists(segment._get_storage_folder())
        else:
            # reset not allowed: storage must survive reset_state()
            assert os.path.exists(segment._get_storage_folder())
            segment.reset_state()
            assert os.path.exists(segment._get_storage_folder())