Spaces:
Sleeping
Sleeping
import pytest | |
from typing import Generator, List, Callable, Iterator, Type, cast | |
from chromadb.config import System, Settings | |
from chromadb.test.conftest import ProducerFn | |
from chromadb.types import ( | |
SubmitEmbeddingRecord, | |
VectorQuery, | |
Operation, | |
ScalarEncoding, | |
Segment, | |
SegmentScope, | |
SeqId, | |
Vector, | |
) | |
from chromadb.ingest import Producer | |
from chromadb.segment import VectorReader | |
import uuid | |
import time | |
from chromadb.segment.impl.vector.local_hnsw import ( | |
LocalHnswSegment, | |
) | |
from chromadb.segment.impl.vector.local_persistent_hnsw import ( | |
PersistentLocalHnswSegment, | |
) | |
from chromadb.test.property.strategies import test_hnsw_config | |
from pytest import FixtureRequest | |
from itertools import count | |
import tempfile | |
import os | |
import shutil | |
def sqlite() -> Generator[System, None, None]:
    """Yield a started, non-persistent System and tear it down afterwards.

    A fresh temporary directory is created and handed to Settings as
    ``persist_directory`` (with ``is_persistent=False``). The teardown below
    the ``yield`` only runs if the consumer resumes this generator.
    """
    tmp_dir = tempfile.mkdtemp()
    system = System(
        Settings(
            allow_reset=True,
            is_persistent=False,
            persist_directory=tmp_dir,
        )
    )
    system.start()
    yield system

    # Teardown: stop the system, then drop the temp directory if it is still there.
    system.stop()
    if not os.path.exists(tmp_dir):
        return
    shutil.rmtree(tmp_dir)
def sqlite_persistent() -> Generator[System, None, None]:
    """Yield a started, persistent System and tear it down afterwards.

    A fresh temporary directory is created and handed to Settings as
    ``persist_directory`` (with ``is_persistent=True``). The teardown below
    the ``yield`` only runs if the consumer resumes this generator.
    """
    tmp_dir = tempfile.mkdtemp()
    system = System(
        Settings(
            allow_reset=True,
            is_persistent=True,
            persist_directory=tmp_dir,
        )
    )
    system.start()
    yield system

    # Teardown: stop the system, then drop the temp directory if it is still there.
    system.stop()
    if not os.path.exists(tmp_dir):
        return
    shutil.rmtree(tmp_dir)
# We exercise both the in-memory and the persistent sqlite systems with both the
# ephemeral and the persistent hnsw segments. We technically never expose persistent
# sqlite with in-memory hnsw to users, but it's a valid configuration, so we test it here.
def system_fixtures() -> List[Callable[[], Generator[System, None, None]]]:
    """Return the System generator factories the tests are run against."""
    factories: List[Callable[[], Generator[System, None, None]]] = [
        sqlite,
        sqlite_persistent,
    ]
    return factories
def system(request: FixtureRequest) -> Generator[System, None, None]:
    """Yield the System produced by the parametrized generator factory.

    ``request.param`` is called with no arguments and must return a generator
    that yields a System and performs its own cleanup after that ``yield``.

    Delegating with ``yield from`` keeps the inner generator alive while the
    consumer uses the System and resumes it afterwards, so its teardown code
    runs. Calling ``next()`` on a temporary generator would discard it right
    after the first value and the code after its ``yield`` would never execute.
    """
    yield from request.param()
def sample_embeddings() -> Iterator[SubmitEmbeddingRecord]:
    """Return an endless iterator of ADD records with predictable neighbors.

    Record ``i`` has the id ``embedding_{i}`` and the 2-d embedding
    ``[i**1.1, i**1.1]``. For each embedding other than the first and last of
    a sequence, its nearest neighbor is the previous one in the sequence and
    its second nearest neighbor is the subsequent one.
    """

    def _record_at(index: int) -> SubmitEmbeddingRecord:
        coordinate = index**1.1
        return SubmitEmbeddingRecord(
            id=f"embedding_{index}",
            embedding=[coordinate, coordinate],
            encoding=ScalarEncoding.FLOAT32,
            metadata=None,
            operation=Operation.ADD,
            collection_id=uuid.UUID(int=0),
        )

    # Lazily build one record per non-negative integer.
    return map(_record_at, count())
def vector_readers() -> List[Type[VectorReader]]:
    """Return the VectorReader implementations the tests are run against."""
    implementations: List[Type[VectorReader]] = [
        LocalHnswSegment,
        PersistentLocalHnswSegment,
    ]
    return implementations
def vector_reader(request: FixtureRequest) -> Generator[Type[VectorReader], None, None]:
    """Yield the VectorReader class carried by ``request.param``."""
    reader_cls = request.param
    yield reader_cls
def create_random_segment_definition() -> Segment:
    """Build a VECTOR-scoped Segment definition with a random UUID as its id.

    The topic is fixed, no collection is attached, and the shared
    ``test_hnsw_config`` is used as metadata.
    """
    segment_id = uuid.uuid4()
    definition = Segment(
        id=segment_id,
        type="test_type",
        scope=SegmentScope.VECTOR,
        topic="persistent://test/test/test_topic_1",
        collection=None,
        metadata=test_hnsw_config,
    )
    return definition
def sync(
    segment: VectorReader,
    seq_id: SeqId,
    timeout: float = 5.0,
    poll_interval: float = 0.25,
) -> None:
    """Block until ``segment.max_seqid()`` has reached ``seq_id``.

    Args:
        segment: The reader whose ``max_seqid()`` is polled.
        seq_id: The sequence id to wait for.
        timeout: Maximum number of seconds to wait (default 5).
        poll_interval: Seconds to sleep between polls (default 0.25).

    Raises:
        TimeoutError: If the segment has not caught up within ``timeout``.
    """
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    while True:
        # Always poll before giving up, so a segment that caught up during
        # the last sleep is not reported as timed out.
        if segment.max_seqid() >= seq_id:
            return
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise TimeoutError(f"Timed out waiting for seq_id {seq_id}")
        time.sleep(min(poll_interval, remaining))
def test_insert_and_count(
    system: System,
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    vector_reader: Type[VectorReader],
    produce_fns: ProducerFn,
) -> None:
    """count() reflects records produced both before and after start()."""
    producer = system.instance(Producer)
    system.reset_state()
    definition = create_random_segment_definition()
    topic = str(definition["topic"])

    # The first batch is produced before the segment is created and started.
    _, first_seq_ids = produce_fns(
        producer=producer, topic=topic, n=3, embeddings=sample_embeddings
    )
    last_seq_id = first_seq_ids[-1]

    segment = vector_reader(system, definition)
    segment.start()
    sync(segment, last_seq_id)
    assert segment.count() == 3

    # The second batch arrives while the segment is running.
    _, second_seq_ids = produce_fns(
        producer=producer, topic=topic, n=3, embeddings=sample_embeddings
    )
    last_seq_id = second_seq_ids[-1]
    sync(segment, last_seq_id)
    assert segment.count() == 6
def approx_equal(a: float, b: float, epsilon: float = 0.0001) -> bool:
    """Return True when ``a`` and ``b`` differ by strictly less than ``epsilon``."""
    difference = a - b
    return -epsilon < difference < epsilon
def approx_equal_vector(a: Vector, b: Vector, epsilon: float = 0.0001) -> bool:
    """Return True when ``a`` and ``b`` have equal length and close elements.

    Two elements are close when they differ by strictly less than ``epsilon``
    (the same tolerance test as ``approx_equal``).

    The lengths are compared first because ``zip`` stops at the shorter input,
    which would otherwise let a vector compare equal to a mere prefix of it.
    """
    if len(a) != len(b):
        return False
    return all(abs(x - y) < epsilon for x, y in zip(a, b))
def test_get_vectors(
    system: System,
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    vector_reader: Type[VectorReader],
    produce_fns: ProducerFn,
) -> None:
    """get_vectors() returns ids, embeddings and seq_ids, with or without an id filter."""
    producer = system.instance(Producer)
    system.reset_state()
    definition = create_random_segment_definition()
    topic = str(definition["topic"])

    segment = vector_reader(system, definition)
    segment.start()

    embeddings, seq_ids = produce_fns(
        producer=producer, topic=topic, embeddings=sample_embeddings, n=10
    )
    sync(segment, seq_ids[-1])

    # Unfiltered: every produced record comes back.
    fetched = segment.get_vectors()
    assert len(fetched) == len(embeddings)
    # Sort by id so the results line up with the production order.
    fetched = sorted(fetched, key=lambda record: record["id"])
    for got, want, want_seq_id in zip(fetched, embeddings, seq_ids):
        assert got["id"] == want["id"]
        assert approx_equal_vector(got["embedding"], cast(Vector, want["embedding"]))
        assert got["seq_id"] == want_seq_id

    # Filtered: only the last five ids are requested.
    tail_embeddings = embeddings[5:]
    tail_seq_ids = seq_ids[5:]
    wanted_ids = [record["id"] for record in tail_embeddings]
    fetched = segment.get_vectors(ids=wanted_ids)
    assert len(fetched) == 5
    fetched = sorted(fetched, key=lambda record: record["id"])
    for got, want, want_seq_id in zip(fetched, tail_embeddings, tail_seq_ids):
        assert got["id"] == want["id"]
        assert approx_equal_vector(got["embedding"], cast(Vector, want["embedding"]))
        assert got["seq_id"] == want_seq_id
def test_ann_query(
    system: System,
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    vector_reader: Type[VectorReader],
    produce_fns: ProducerFn,
) -> None:
    """query_vectors() returns the expected nearest neighbors for the sample data."""
    producer = system.instance(Producer)
    system.reset_state()
    segment_definition = create_random_segment_definition()
    topic = str(segment_definition["topic"])

    segment = vector_reader(system, segment_definition)
    segment.start()

    embeddings, seq_ids = produce_fns(
        producer=producer, topic=topic, embeddings=sample_embeddings, n=100
    )
    sync(segment, seq_ids[-1])

    # Each item is its own nearest neighbor (one at a time)
    for e in embeddings:
        vector = cast(Vector, e["embedding"])
        query = VectorQuery(
            vectors=[vector],
            k=1,
            allowed_ids=None,
            options=None,
            include_embeddings=True,
        )
        results = segment.query_vectors(query)
        assert len(results) == 1
        assert len(results[0]) == 1
        assert results[0][0]["id"] == e["id"]
        assert results[0][0]["embedding"] is not None
        assert approx_equal_vector(results[0][0]["embedding"], vector)

    # Each item is its own nearest neighbor (all at once)
    vectors = [cast(Vector, e["embedding"]) for e in embeddings]
    query = VectorQuery(
        vectors=vectors, k=1, allowed_ids=None, options=None, include_embeddings=False
    )
    results = segment.query_vectors(query)
    assert len(results) == len(embeddings)
    for r, e in zip(results, embeddings):
        assert len(r) == 1
        assert r[0]["id"] == e["id"]

    # Each item's 3 nearest neighbors are itself and the item before and after
    test_embeddings = embeddings[1:-1]
    vectors = [cast(Vector, e["embedding"]) for e in test_embeddings]
    query = VectorQuery(
        vectors=vectors, k=3, allowed_ids=None, options=None, include_embeddings=False
    )
    results = segment.query_vectors(query)
    assert len(results) == len(test_embeddings)
    # results[j] answers the query for embeddings[j + 1]; enumerating from 1
    # covers every query, including the last one.
    for i, r in enumerate(results, start=1):
        assert len(r) == 3
        assert r[0]["id"] == embeddings[i]["id"]
        assert r[1]["id"] == embeddings[i - 1]["id"]
        assert r[2]["id"] == embeddings[i + 1]["id"]
def test_delete(
    system: System,
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    vector_reader: Type[VectorReader],
    produce_fns: ProducerFn,
) -> None:
    """A DELETE removes the record from count, get and KNN results, and can be repeated."""
    producer = system.instance(Producer)
    system.reset_state()
    segment_definition = create_random_segment_definition()
    topic = str(segment_definition["topic"])

    segment = vector_reader(system, segment_definition)
    segment.start()

    embeddings, seq_ids = produce_fns(
        producer=producer, topic=topic, embeddings=sample_embeddings, n=5
    )
    sync(segment, seq_ids[-1])
    assert segment.count() == 5

    delete_record = SubmitEmbeddingRecord(
        id=embeddings[0]["id"],
        embedding=None,
        encoding=None,
        metadata=None,
        operation=Operation.DELETE,
        collection_id=uuid.UUID(int=0),
    )
    assert isinstance(seq_ids, List)
    seq_ids.append(
        produce_fns(
            producer=producer,
            topic=topic,
            n=1,
            embeddings=(delete_record for _ in range(1)),
        )[1][0]
    )
    sync(segment, seq_ids[-1])

    # Assert that the record is gone using `count`
    assert segment.count() == 4

    # Assert that the record is gone using `get`
    assert segment.get_vectors(ids=[embeddings[0]["id"]]) == []
    results = segment.get_vectors()
    assert len(results) == 4
    # get_vectors returns results in arbitrary order
    results = sorted(results, key=lambda v: v["id"])
    for actual, expected in zip(results, embeddings[1:]):
        assert actual["id"] == expected["id"]
        assert approx_equal_vector(
            actual["embedding"], cast(Vector, expected["embedding"])
        )

    # Assert that the record is gone from KNN search. k exceeds the number of
    # remaining records, so exactly the four survivors must come back.
    vector = cast(Vector, embeddings[0]["embedding"])
    query = VectorQuery(
        vectors=[vector], k=10, allowed_ids=None, options=None, include_embeddings=False
    )
    knn_results = segment.query_vectors(query)
    assert len(knn_results) == 1
    assert len(knn_results[0]) == 4
    assert set(r["id"] for r in knn_results[0]) == set(e["id"] for e in embeddings[1:])

    # Delete is idempotent
    seq_ids.append(
        produce_fns(
            producer=producer,
            topic=topic,
            n=1,
            embeddings=(delete_record for _ in range(1)),
        )[1][0]
    )
    sync(segment, seq_ids[-1])
    assert segment.count() == 4
def _test_update(
    producer: Producer,
    topic: str,
    segment: VectorReader,
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    operation: Operation,
) -> None:
    """Exercise the behavior that update and upsert have in common.

    Three sample records are submitted, then the first one is moved to
    ``[10.0, 10.0]`` using ``operation``; get_vectors and KNN results are
    checked against the new position.
    """
    embeddings = [next(sample_embeddings) for _ in range(3)]
    seq_ids: List[SeqId] = [
        producer.submit_embedding(topic, record) for record in embeddings
    ]
    sync(segment, seq_ids[-1])
    assert segment.count() == 3

    # Move the first record to a new location.
    moved_record = SubmitEmbeddingRecord(
        id=embeddings[0]["id"],
        embedding=[10.0, 10.0],
        encoding=ScalarEncoding.FLOAT32,
        metadata=None,
        operation=operation,
        collection_id=uuid.UUID(int=0),
    )
    seq_ids.append(producer.submit_embedding(topic, moved_record))
    sync(segment, seq_ids[-1])

    # The record count is unchanged and get_vectors reports the new embedding.
    assert segment.count() == 3
    everything = segment.get_vectors()
    assert len(everything) == 3
    moved = segment.get_vectors(ids=[embeddings[0]["id"]])
    assert moved[0]["embedding"] == [10.0, 10.0]

    # Querying at the old location: the moved record is now the farthest.
    old_location = cast(Vector, embeddings[0]["embedding"])
    old_query = VectorQuery(
        vectors=[old_location],
        k=3,
        allowed_ids=None,
        options=None,
        include_embeddings=False,
    )
    neighbors = segment.query_vectors(old_query)[0]
    for rank, expected in enumerate((embeddings[1], embeddings[2], embeddings[0])):
        assert neighbors[rank]["id"] == expected["id"]

    # Querying at the new location: the moved record is now the closest.
    new_location = [10.0, 10.0]
    new_query = VectorQuery(
        vectors=[new_location],
        k=3,
        allowed_ids=None,
        options=None,
        include_embeddings=False,
    )
    neighbors = segment.query_vectors(new_query)[0]
    for rank, expected in enumerate((embeddings[0], embeddings[2], embeddings[1])):
        assert neighbors[rank]["id"] == expected["id"]
def test_update(
    system: System,
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    vector_reader: Type[VectorReader],
    produce_fns: ProducerFn,
) -> None:
    """UPDATE moves an existing record and does not create a missing one."""
    producer = system.instance(Producer)
    system.reset_state()
    definition = create_random_segment_definition()
    topic = str(definition["topic"])

    segment = vector_reader(system, definition)
    segment.start()

    _test_update(producer, topic, segment, sample_embeddings, Operation.UPDATE)

    # Updating an id that was never added must leave the segment unchanged.
    missing_update = SubmitEmbeddingRecord(
        id="no_such_record",
        embedding=[10.0, 10.0],
        encoding=ScalarEncoding.FLOAT32,
        metadata=None,
        operation=Operation.UPDATE,
        collection_id=uuid.UUID(int=0),
    )
    _, produced_seq_ids = produce_fns(
        producer=producer,
        topic=topic,
        n=1,
        embeddings=(missing_update for _ in range(1)),
    )
    sync(segment, produced_seq_ids[0])

    assert segment.count() == 3
    assert segment.get_vectors(ids=["no_such_record"]) == []
def test_upsert(
    system: System,
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    vector_reader: Type[VectorReader],
    produce_fns: ProducerFn,
) -> None:
    """UPSERT moves an existing record and creates a missing one."""
    producer = system.instance(Producer)
    system.reset_state()
    definition = create_random_segment_definition()
    topic = str(definition["topic"])

    segment = vector_reader(system, definition)
    segment.start()

    _test_update(producer, topic, segment, sample_embeddings, Operation.UPSERT)

    # Upserting an id that was never added must insert it.
    new_record = SubmitEmbeddingRecord(
        id="no_such_record",
        embedding=[42, 42],
        encoding=ScalarEncoding.FLOAT32,
        metadata=None,
        operation=Operation.UPSERT,
        collection_id=uuid.UUID(int=0),
    )
    _, produced_seq_ids = produce_fns(
        producer=producer,
        topic=topic,
        n=1,
        embeddings=(new_record for _ in range(1)),
    )
    sync(segment, produced_seq_ids[0])

    assert segment.count() == 4
    inserted = segment.get_vectors(ids=["no_such_record"])
    assert len(inserted) == 1
    assert approx_equal_vector(inserted[0]["embedding"], [42, 42])
def test_delete_without_add(
    system: System,
    vector_reader: Type[VectorReader],
) -> None:
    """Submitting a DELETE for an unknown id against an empty segment must not raise."""
    producer = system.instance(Producer)
    system.reset_state()
    segment_definition = create_random_segment_definition()
    topic = str(segment_definition["topic"])

    segment = vector_reader(system, segment_definition)
    segment.start()

    assert segment.count() == 0

    delete_record = SubmitEmbeddingRecord(
        id="not_in_db",
        embedding=None,
        encoding=None,
        metadata=None,
        operation=Operation.DELETE,
        collection_id=uuid.UUID(int=0),
    )

    try:
        producer.submit_embedding(topic, delete_record)
    except Exception:
        # Only ordinary errors count as a failure of the code under test;
        # KeyboardInterrupt and SystemExit are left to propagate.
        pytest.fail("Unexpected error. Deleting on an empty segment should not raise.")
def test_delete_with_local_segment_storage(
    system: System,
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    vector_reader: Type[VectorReader],
    produce_fns: ProducerFn,
) -> None:
    """After a record DELETE, check segment.delete() for each segment type.

    For PersistentLocalHnswSegment, delete() must remove the storage folder
    and be safe to call twice; for LocalHnswSegment it must raise
    NotImplementedError.
    """
    producer = system.instance(Producer)
    system.reset_state()
    segment_definition = create_random_segment_definition()
    topic = str(segment_definition["topic"])

    segment = vector_reader(system, segment_definition)
    segment.start()

    embeddings, seq_ids = produce_fns(
        producer=producer, topic=topic, embeddings=sample_embeddings, n=5
    )
    sync(segment, seq_ids[-1])
    assert segment.count() == 5

    delete_record = SubmitEmbeddingRecord(
        id=embeddings[0]["id"],
        embedding=None,
        encoding=None,
        metadata=None,
        operation=Operation.DELETE,
        collection_id=uuid.UUID(int=0),
    )
    assert isinstance(seq_ids, List)
    seq_ids.append(
        produce_fns(
            producer=producer,
            topic=topic,
            n=1,
            embeddings=(delete_record for _ in range(1)),
        )[1][0]
    )
    sync(segment, seq_ids[-1])

    # Assert that the record is gone using `count`
    assert segment.count() == 4

    # Assert that the record is gone using `get`
    assert segment.get_vectors(ids=[embeddings[0]["id"]]) == []
    results = segment.get_vectors()
    assert len(results) == 4
    # get_vectors returns results in arbitrary order
    results = sorted(results, key=lambda v: v["id"])
    for actual, expected in zip(results, embeddings[1:]):
        assert actual["id"] == expected["id"]
        assert approx_equal_vector(
            actual["embedding"], cast(Vector, expected["embedding"])
        )

    # Assert that the record is gone from KNN search. k exceeds the number of
    # remaining records, so exactly the four survivors must come back.
    vector = cast(Vector, embeddings[0]["embedding"])
    query = VectorQuery(
        vectors=[vector], k=10, allowed_ids=None, options=None, include_embeddings=False
    )
    knn_results = segment.query_vectors(query)
    assert len(knn_results) == 1
    assert len(knn_results[0]) == 4
    assert set(r["id"] for r in knn_results[0]) == set(e["id"] for e in embeddings[1:])

    # Deleting the segment itself depends on the segment implementation
    if isinstance(segment, PersistentLocalHnswSegment):
        assert os.path.exists(segment._get_storage_folder())
        segment.delete()
        assert not os.path.exists(segment._get_storage_folder())
        segment.delete()  # should not raise
    elif isinstance(segment, LocalHnswSegment):
        with pytest.raises(NotImplementedError):
            segment.delete()
def test_reset_state_ignored_for_allow_reset_false(
    system: System,
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    vector_reader: Type[VectorReader],
    produce_fns: ProducerFn,
) -> None:
    """After a record DELETE, check how reset_state() treats the storage folder.

    Only PersistentLocalHnswSegment is checked: when ``segment._allow_reset``
    is truthy the storage folder must be gone after reset_state(), otherwise
    it must still exist.
    """
    producer = system.instance(Producer)
    system.reset_state()
    segment_definition = create_random_segment_definition()
    topic = str(segment_definition["topic"])

    segment = vector_reader(system, segment_definition)
    segment.start()

    embeddings, seq_ids = produce_fns(
        producer=producer, topic=topic, embeddings=sample_embeddings, n=5
    )
    sync(segment, seq_ids[-1])
    assert segment.count() == 5

    delete_record = SubmitEmbeddingRecord(
        id=embeddings[0]["id"],
        embedding=None,
        encoding=None,
        metadata=None,
        operation=Operation.DELETE,
        collection_id=uuid.UUID(int=0),
    )
    assert isinstance(seq_ids, List)
    seq_ids.append(
        produce_fns(
            producer=producer,
            topic=topic,
            n=1,
            embeddings=(delete_record for _ in range(1)),
        )[1][0]
    )
    sync(segment, seq_ids[-1])

    # Assert that the record is gone using `count`
    assert segment.count() == 4

    # Assert that the record is gone using `get`
    assert segment.get_vectors(ids=[embeddings[0]["id"]]) == []
    results = segment.get_vectors()
    assert len(results) == 4
    # get_vectors returns results in arbitrary order
    results = sorted(results, key=lambda v: v["id"])
    for actual, expected in zip(results, embeddings[1:]):
        assert actual["id"] == expected["id"]
        assert approx_equal_vector(
            actual["embedding"], cast(Vector, expected["embedding"])
        )

    # Assert that the record is gone from KNN search. k exceeds the number of
    # remaining records, so exactly the four survivors must come back.
    vector = cast(Vector, embeddings[0]["embedding"])
    query = VectorQuery(
        vectors=[vector], k=10, allowed_ids=None, options=None, include_embeddings=False
    )
    knn_results = segment.query_vectors(query)
    assert len(knn_results) == 1
    assert len(knn_results[0]) == 4
    assert set(r["id"] for r in knn_results[0]) == set(e["id"] for e in embeddings[1:])

    # reset_state() may only remove the storage folder when resets are allowed
    if isinstance(segment, PersistentLocalHnswSegment):
        if segment._allow_reset:
            assert os.path.exists(segment._get_storage_folder())
            segment.reset_state()
            assert not os.path.exists(segment._get_storage_folder())
        else:
            assert os.path.exists(segment._get_storage_folder())
            segment.reset_state()
            assert os.path.exists(segment._get_storage_folder())