# Producer/Consumer ingest tests.
# NOTE(review): pytest fixture decorators (@pytest.fixture) and async test
# markers (@pytest.mark.asyncio) appear to have been stripped from this file
# during extraction — confirm against the original before running under pytest.
import asyncio
import os
import shutil
import tempfile
import uuid
from asyncio import Event, TimeoutError, wait_for
from itertools import count
from typing import (
    Callable,
    Dict,
    Generator,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import pytest
from pytest import FixtureRequest, approx

from chromadb.config import Settings, System
from chromadb.db.impl.sqlite import SqliteDB
from chromadb.ingest import Consumer, Producer
from chromadb.ingest.impl.utils import create_topic_name
from chromadb.test.conftest import ProducerFn
from chromadb.types import (
    EmbeddingRecord,
    Operation,
    ScalarEncoding,
    SubmitEmbeddingRecord,
)
def sqlite() -> Generator[Tuple[Producer, Consumer], None, None]:
    """Yield an in-memory SqliteDB serving as both Producer and Consumer."""
    sqlite_system = System(Settings(allow_reset=True))
    sqlite_db = sqlite_system.require(SqliteDB)
    sqlite_system.start()
    # SqliteDB implements both interfaces, so one object serves as the pair.
    yield sqlite_db, sqlite_db
    sqlite_system.stop()
def sqlite_persistent() -> Generator[Tuple[Producer, Consumer], None, None]:
    """Fixture generator for sqlite_persistent Producer + Consumer.

    Yields a persistent SqliteDB (acting as both Producer and Consumer)
    backed by a fresh temporary directory, and removes that directory on
    teardown.
    """
    save_path = tempfile.mkdtemp()
    system = System(
        Settings(allow_reset=True, is_persistent=True, persist_directory=save_path)
    )
    db = system.require(SqliteDB)
    system.start()
    try:
        yield db, db
    finally:
        # Run teardown even if the consumer closes this generator early
        # (GeneratorExit) or the test body raises — otherwise system.stop()
        # is skipped and the temporary directory leaks.
        system.stop()
        if os.path.exists(save_path):
            shutil.rmtree(save_path)
def pulsar() -> Generator[Tuple[Producer, Consumer], None, None]:
    """Fixture generator for pulsar Producer + Consumer. This fixture requires a running
    pulsar cluster. You can use bin/cluster-test.sh to start a standalone pulsar and run this test.
    Assumes pulsar_broker_url etc is set from the environment variables like PULSAR_BROKER_URL.
    """
    pulsar_settings = Settings(
        allow_reset=True,
        chroma_producer_impl="chromadb.ingest.impl.pulsar.PulsarProducer",
        chroma_consumer_impl="chromadb.ingest.impl.pulsar.PulsarConsumer",
    )
    pulsar_system = System(pulsar_settings)
    pulsar_producer = pulsar_system.require(Producer)
    pulsar_consumer = pulsar_system.require(Consumer)
    pulsar_system.start()
    yield pulsar_producer, pulsar_consumer
    pulsar_system.stop()
def fixtures() -> List[Callable[[], Generator[Tuple[Producer, Consumer], None, None]]]:
    """Return the fixture generators to run in the current environment."""
    if "CHROMA_CLUSTER_TEST_ONLY" in os.environ:
        # Cluster-only runs exercise just the pulsar-backed implementation.
        return [pulsar]
    return [sqlite, sqlite_persistent]
def producer_consumer(
    request: FixtureRequest,
) -> Generator[Tuple[Producer, Consumer], None, None]:
    """Parametrized fixture yielding a (Producer, Consumer) pair.

    Each param is a fixture-generator function. Keep a reference to the
    generator and exhaust it after the test so its teardown code
    (system.stop(), temp-dir cleanup) runs deterministically — the previous
    `yield next(request.param())` abandoned the generator, deferring teardown
    to whenever garbage collection happened to run.
    """
    fixture_gen = request.param()
    yield next(fixture_gen)
    # Resume the generator past its yield so teardown executes now.
    next(fixture_gen, None)
def sample_embeddings() -> Iterator[SubmitEmbeddingRecord]:
    """Infinite stream of deterministic SubmitEmbeddingRecords."""

    def make_record(i: int) -> SubmitEmbeddingRecord:
        # Odd indices carry metadata covering each scalar type; even carry none.
        metadata: Optional[Dict[str, Union[str, int, float]]] = None
        if i % 2 != 0:
            metadata = {"str_key": f"value_{i}", "int_key": i, "float_key": i + i * 0.1}
        return SubmitEmbeddingRecord(
            id=f"embedding_{i}",
            embedding=[i + i * 0.1, i + 1 + i * 0.1],
            encoding=ScalarEncoding.FLOAT32,
            metadata=metadata,
            operation=Operation.ADD,
            collection_id=uuid.uuid4(),
        )

    return map(make_record, count())
class CapturingConsumeFn:
    # All embeddings seen so far, in arrival order.
    embeddings: List[EmbeddingRecord]
    # (threshold, event) pairs; each event fires once len(embeddings) >= threshold.
    waiters: List[Tuple[int, Event]]

    def __init__(self) -> None:
        """A function that captures embeddings and allows you to wait for a certain
        number of embeddings to be available. It must be constructed in the thread with
        the main event loop
        """
        self.embeddings = []
        self.waiters = []
        self._loop = asyncio.get_event_loop()

    def __call__(self, embeddings: Sequence[EmbeddingRecord]) -> None:
        """Record a delivered batch; may be invoked from a non-event-loop thread."""
        self.embeddings.extend(embeddings)
        for n, event in self.waiters:
            if len(self.embeddings) >= n:
                # event.set() is not thread safe, so we need to call it in the main event loop
                self._loop.call_soon_threadsafe(event.set)

    async def get(self, n: int, timeout_secs: int = 10) -> Sequence[EmbeddingRecord]:
        "Wait until at least N embeddings are available, then return all embeddings"
        if len(self.embeddings) >= n:
            return self.embeddings[:n]
        event = Event()
        self.waiters.append((n, event))
        # Re-check after registering the waiter: a callback running on another
        # thread may have delivered the remaining embeddings between the check
        # above and the append, in which case nothing would ever set this
        # event (lost wakeup) and we would time out spuriously.
        if len(self.embeddings) >= n:
            event.set()
        # timeout so we don't hang forever on failure
        await wait_for(event.wait(), timeout_secs)
        return self.embeddings[:n]
def assert_approx_equal(a: Sequence[float], b: Sequence[float]) -> None:
    """Assert two float sequences are elementwise approximately equal.

    Fixes two defects in the original: (1) `approx(i) == approx(j)` placed
    approx on both sides of the comparison, which pytest does not support —
    approx belongs on the expected side only; (2) bare zip() silently ignored
    trailing elements when one sequence was longer, so a length mismatch
    could pass.
    """
    assert len(a) == len(b)
    for actual, expected in zip(a, b):
        assert actual == approx(expected)
def assert_records_match(
    inserted_records: Sequence[SubmitEmbeddingRecord],
    consumed_records: Sequence[EmbeddingRecord],
) -> None:
    """Given a list of inserted and consumed records, make sure they match"""
    assert len(consumed_records) == len(inserted_records)
    for sent, got in zip(inserted_records, consumed_records):
        # Exact-match fields first, then the float vector approximately.
        for field in ("id", "operation", "encoding", "metadata"):
            assert sent[field] == got[field]
        if sent["embedding"] is not None:
            assert got["embedding"] is not None
            assert_approx_equal(sent["embedding"], got["embedding"])
def full_topic_name(topic_name: str) -> str:
    """Qualify *topic_name* under the default tenant and namespace."""
    qualified = create_topic_name("default", "default", topic_name)
    return qualified
async def test_backfill(
    producer_consumer: Tuple[Producer, Consumer],
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    produce_fns: ProducerFn,
) -> None:
    """Records produced before subscribing are replayed to the subscriber."""
    producer, consumer = producer_consumer
    producer.reset_state()
    consumer.reset_state()

    topic = full_topic_name("test_topic")
    producer.create_topic(topic)
    submitted = produce_fns(producer, topic, sample_embeddings, 3)[0]

    # Subscribe only after producing, starting at the minimum seq id, so
    # every record must arrive via the backfill path.
    capture = CapturingConsumeFn()
    consumer.subscribe(topic, capture, start=consumer.min_seqid())
    received = await capture.get(3)
    assert_records_match(submitted, received)
async def test_notifications(
    producer_consumer: Tuple[Producer, Consumer],
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
) -> None:
    """Each submitted record is pushed to an already-open subscription."""
    producer, consumer = producer_consumer
    producer.reset_state()
    consumer.reset_state()

    topic = full_topic_name("test_topic")
    producer.create_topic(topic)

    capture = CapturingConsumeFn()
    consumer.subscribe(topic, capture, start=consumer.min_seqid())

    submitted: List[SubmitEmbeddingRecord] = []
    for produced_count in range(1, 11):
        record = next(sample_embeddings)
        submitted.append(record)
        producer.submit_embedding(topic, record)
        # After every submit the consumer must have seen everything so far.
        assert_records_match(submitted, await capture.get(produced_count))
async def test_multiple_topics(
    producer_consumer: Tuple[Producer, Consumer],
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
) -> None:
    """Subscriptions on different topics each receive only their own records."""
    producer, consumer = producer_consumer
    producer.reset_state()
    consumer.reset_state()

    topic_1 = full_topic_name("test_topic_1")
    topic_2 = full_topic_name("test_topic_2")
    producer.create_topic(topic_1)
    producer.create_topic(topic_2)

    submitted_1: List[SubmitEmbeddingRecord] = []
    submitted_2: List[SubmitEmbeddingRecord] = []
    capture_1 = CapturingConsumeFn()
    capture_2 = CapturingConsumeFn()
    consumer.subscribe(topic_1, capture_1, start=consumer.min_seqid())
    consumer.subscribe(topic_2, capture_2, start=consumer.min_seqid())

    # Interleave submissions across the two topics; each consumer should
    # accumulate only the records from its own topic.
    for i in range(10):
        record_1 = next(sample_embeddings)
        submitted_1.append(record_1)
        producer.submit_embedding(topic_1, record_1)
        assert_records_match(submitted_1, await capture_1.get(i + 1))

        record_2 = next(sample_embeddings)
        submitted_2.append(record_2)
        producer.submit_embedding(topic_2, record_2)
        assert_records_match(submitted_2, await capture_2.get(i + 1))
async def test_start_seq_id(
    producer_consumer: Tuple[Producer, Consumer],
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    produce_fns: ProducerFn,
) -> None:
    """A subscription starting at a later seq id sees only newer records."""
    producer, consumer = producer_consumer
    producer.reset_state()
    consumer.reset_state()

    topic = full_topic_name("test_topic")
    producer.create_topic(topic)

    capture_all = CapturingConsumeFn()
    capture_tail = CapturingConsumeFn()
    consumer.subscribe(topic, capture_all, start=consumer.min_seqid())

    submitted = produce_fns(producer, topic, sample_embeddings, 5)[0]
    assert_records_match(submitted, await capture_all.get(5))

    # Start a second subscription at the last seq id seen so far: it should
    # receive only the records produced after that point.
    start = capture_all.embeddings[-1]["seq_id"]
    consumer.subscribe(topic, capture_tail, start=start)
    second_batch = produce_fns(producer, topic, sample_embeddings, 5)[0]
    assert isinstance(submitted, list)
    submitted.extend(second_batch)
    assert_records_match(submitted[-5:], await capture_tail.get(5))
async def test_end_seq_id(
    producer_consumer: Tuple[Producer, Consumer],
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    produce_fns: ProducerFn,
) -> None:
    """A subscription bounded by an end seq id stops delivering past it."""
    producer, consumer = producer_consumer
    producer.reset_state()
    consumer.reset_state()

    topic = full_topic_name("test_topic")
    producer.create_topic(topic)

    capture_all = CapturingConsumeFn()
    capture_bounded = CapturingConsumeFn()
    consumer.subscribe(topic, capture_all, start=consumer.min_seqid())

    submitted = produce_fns(producer, topic, sample_embeddings, 10)[0]
    assert_records_match(submitted, await capture_all.get(10))

    # End at the 6th record's seq id; the bounded subscription should get
    # exactly records 1..6 (end is inclusive) and nothing more.
    end = capture_all.embeddings[-5]["seq_id"]
    consumer.subscribe(topic, capture_bounded, start=consumer.min_seqid(), end=end)
    assert_records_match(submitted[:6], await capture_bounded.get(6))

    # Should never produce a 7th
    with pytest.raises(TimeoutError):
        _ = await wait_for(capture_bounded.get(7), timeout=1)
async def test_submit_batch(
    producer_consumer: Tuple[Producer, Consumer],
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
) -> None:
    """A single batched submit delivers every record, in order."""
    producer, consumer = producer_consumer
    producer.reset_state()
    consumer.reset_state()

    topic = full_topic_name("test_topic")
    batch = [next(sample_embeddings) for _ in range(100)]
    producer.create_topic(topic)
    producer.submit_embeddings(topic, embeddings=batch)

    capture = CapturingConsumeFn()
    consumer.subscribe(topic, capture, start=consumer.min_seqid())
    received = await capture.get(100)
    assert_records_match(batch, received)
async def test_multiple_topics_batch(
    producer_consumer: Tuple[Producer, Consumer],
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    produce_fns: ProducerFn,
) -> None:
    """Batched production interleaved across topics stays isolated per topic."""
    producer, consumer = producer_consumer
    producer.reset_state()
    consumer.reset_state()

    N_TOPICS = 2
    consume_fns = [CapturingConsumeFn() for _ in range(N_TOPICS)]
    for topic_i in range(N_TOPICS):
        producer.create_topic(full_topic_name(f"test_topic_{topic_i}"))
        consumer.subscribe(
            full_topic_name(f"test_topic_{topic_i}"),
            consume_fns[topic_i],
            start=consumer.min_seqid(),
        )

    embeddings_n: List[List[SubmitEmbeddingRecord]] = [[] for _ in range(N_TOPICS)]

    PRODUCE_BATCH_SIZE = 10
    N_TO_PRODUCE = 100
    total_produced = 0
    for _ in range(N_TO_PRODUCE // PRODUCE_BATCH_SIZE):
        # Round-robin one batch into each topic, verifying running totals.
        for topic_i in range(N_TOPICS):
            batch = produce_fns(
                producer,
                full_topic_name(f"test_topic_{topic_i}"),
                sample_embeddings,
                PRODUCE_BATCH_SIZE,
            )[0]
            embeddings_n[topic_i].extend(batch)
            received = await consume_fns[topic_i].get(
                total_produced + PRODUCE_BATCH_SIZE
            )
            assert_records_match(embeddings_n[topic_i], received)
        total_produced += PRODUCE_BATCH_SIZE
async def test_max_batch_size(
    producer_consumer: Tuple[Producer, Consumer],
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
) -> None:
    """Batches of exactly max_batch_size succeed; one more raises ValueError."""
    producer, consumer = producer_consumer
    producer.reset_state()
    consumer.reset_state()

    topic = full_topic_name("test_topic")
    max_batch_size = producer.max_batch_size
    assert max_batch_size > 0

    # Make sure that we can produce a batch of size max_batch_size
    full_batch = [next(sample_embeddings) for _ in range(max_batch_size)]
    capture = CapturingConsumeFn()
    consumer.subscribe(topic, capture, start=consumer.min_seqid())
    producer.submit_embeddings(topic, embeddings=full_batch)
    # Generous timeout: max_batch_size may be large.
    received = await capture.get(max_batch_size, timeout_secs=120)
    assert_records_match(full_batch, received)

    # Make sure that we can't produce a batch of size > max_batch_size
    oversized = [next(sample_embeddings) for _ in range(max_batch_size + 1)]
    with pytest.raises(ValueError) as e:
        producer.submit_embeddings(topic, embeddings=oversized)
    assert "Cannot submit more than" in str(e.value)