# chromadb/test/ingest/test_producer_consumer.py
import asyncio
import os
import shutil
import tempfile
import pytest
from itertools import count
from typing import (
Generator,
List,
Callable,
Optional,
Dict,
Union,
Iterator,
Sequence,
Tuple,
)
from chromadb.ingest import Producer, Consumer
from chromadb.db.impl.sqlite import SqliteDB
from chromadb.ingest.impl.utils import create_topic_name
from chromadb.test.conftest import ProducerFn
from chromadb.types import (
SubmitEmbeddingRecord,
Operation,
EmbeddingRecord,
ScalarEncoding,
)
from chromadb.config import System, Settings
from pytest import FixtureRequest, approx
from asyncio import Event, wait_for, TimeoutError
import uuid
def sqlite() -> Generator[Tuple[Producer, Consumer], None, None]:
    """Yield an in-memory sqlite-backed Producer/Consumer pair.

    The same SqliteDB instance serves as both producer and consumer; the
    system is stopped after the test module finishes with the fixture.
    """
    system = System(Settings(allow_reset=True))
    sqlite_db = system.require(SqliteDB)
    system.start()
    yield sqlite_db, sqlite_db
    system.stop()
def sqlite_persistent() -> Generator[Tuple[Producer, Consumer], None, None]:
    """Yield a persistent sqlite-backed Producer/Consumer pair.

    Persists to a fresh temp directory, which is removed after the system
    is stopped.
    """
    persist_dir = tempfile.mkdtemp()
    settings = Settings(
        allow_reset=True, is_persistent=True, persist_directory=persist_dir
    )
    system = System(settings)
    sqlite_db = system.require(SqliteDB)
    system.start()
    yield sqlite_db, sqlite_db
    system.stop()
    if os.path.exists(persist_dir):
        shutil.rmtree(persist_dir)
def pulsar() -> Generator[Tuple[Producer, Consumer], None, None]:
    """Fixture generator for pulsar Producer + Consumer. This fixture requires a running
    pulsar cluster. You can use bin/cluster-test.sh to start a standalone pulsar and run this test.
    Assumes pulsar_broker_url etc is set from the environment variables like PULSAR_BROKER_URL.
    """
    settings = Settings(
        allow_reset=True,
        chroma_producer_impl="chromadb.ingest.impl.pulsar.PulsarProducer",
        chroma_consumer_impl="chromadb.ingest.impl.pulsar.PulsarConsumer",
    )
    system = System(settings)
    pulsar_producer = system.require(Producer)
    pulsar_consumer = system.require(Consumer)
    system.start()
    yield pulsar_producer, pulsar_consumer
    system.stop()
def fixtures() -> List[Callable[[], Generator[Tuple[Producer, Consumer], None, None]]]:
    """Return the fixture generators to parameterize producer/consumer tests with.

    Uses the local sqlite-backed fixtures by default; when
    CHROMA_CLUSTER_TEST_ONLY is set, runs only against the pulsar fixture.
    (The original bound a local named `fixtures`, shadowing this function's
    own name; early returns avoid that.)
    """
    if "CHROMA_CLUSTER_TEST_ONLY" in os.environ:
        return [pulsar]
    return [sqlite, sqlite_persistent]
@pytest.fixture(scope="module", params=fixtures())
def producer_consumer(
    request: FixtureRequest,
) -> Generator[Tuple[Producer, Consumer], None, None]:
    """Parameterized fixture yielding a (Producer, Consumer) pair.

    Delegates with `yield from` so the underlying fixture generator is
    resumed after the test module finishes and its teardown code runs.
    The original `yield next(request.param())` abandoned the generator,
    so `system.stop()` (and the persistent fixture's temp-dir cleanup)
    never executed.
    """
    yield from request.param()
@pytest.fixture(scope="module")
def sample_embeddings() -> Iterator[SubmitEmbeddingRecord]:
    """Infinite stream of deterministic sample records.

    Even-indexed records carry no metadata; odd-indexed records carry one
    str, one int, and one float metadata entry derived from the index.
    """
    def make_record(i: int) -> SubmitEmbeddingRecord:
        metadata: Optional[Dict[str, Union[str, int, float]]] = None
        if i % 2 != 0:
            metadata = {
                "str_key": f"value_{i}",
                "int_key": i,
                "float_key": i + i * 0.1,
            }
        return SubmitEmbeddingRecord(
            id=f"embedding_{i}",
            embedding=[i + i * 0.1, i + 1 + i * 0.1],
            encoding=ScalarEncoding.FLOAT32,
            metadata=metadata,
            operation=Operation.ADD,
            collection_id=uuid.uuid4(),
        )
    return (make_record(i) for i in count())
class CapturingConsumeFn:
    # All embeddings received so far, in arrival order. Extended by
    # __call__, which the consumer may invoke from another thread.
    embeddings: List[EmbeddingRecord]
    # Pending (n, event) waiters; each event is set once at least n
    # embeddings have been captured.
    waiters: List[Tuple[int, Event]]

    def __init__(self) -> None:
        """A function that captures embeddings and allows you to wait for a certain
        number of embeddings to be available. It must be constructed in the thread with
        the main event loop
        """
        self.embeddings = []
        self.waiters = []
        # Loop of the constructing thread, captured so __call__ can marshal
        # Event.set() calls back onto it from the consumer's thread.
        self._loop = asyncio.get_event_loop()

    def __call__(self, embeddings: Sequence[EmbeddingRecord]) -> None:
        # Record the new embeddings, then wake every waiter whose threshold
        # has now been reached.
        self.embeddings.extend(embeddings)
        for n, event in self.waiters:
            if len(self.embeddings) >= n:
                # event.set() is not thread safe, so we need to call it in the main event loop
                self._loop.call_soon_threadsafe(event.set)

    async def get(self, n: int, timeout_secs: int = 10) -> Sequence[EmbeddingRecord]:
        """Wait until at least n embeddings are available, then return the first n."""
        if len(self.embeddings) >= n:
            return self.embeddings[:n]
        else:
            event = Event()
            self.waiters.append((n, event))
            # timeout so we don't hang forever on failure
            await wait_for(event.wait(), timeout_secs)
            return self.embeddings[:n]
def assert_approx_equal(a: Sequence[float], b: Sequence[float]) -> None:
    """Assert that two float sequences are element-wise approximately equal.

    Also asserts equal lengths: the original bare zip() silently ignored
    trailing elements when one sequence was longer. approx() wraps only the
    expected side, per pytest's documented usage.
    """
    assert len(a) == len(b)
    for expected, actual in zip(a, b):
        assert actual == approx(expected)
def assert_records_match(
    inserted_records: Sequence[SubmitEmbeddingRecord],
    consumed_records: Sequence[EmbeddingRecord],
) -> None:
    """Given a list of inserted and consumed records, make sure they match"""
    assert len(consumed_records) == len(inserted_records)
    for sent, got in zip(inserted_records, consumed_records):
        # Scalar fields must match exactly.
        for field in ("id", "operation", "encoding", "metadata"):
            assert sent[field] == got[field]
        # Vectors are floats, so compare approximately.
        if sent["embedding"] is not None:
            assert got["embedding"] is not None
            assert_approx_equal(sent["embedding"], got["embedding"])
def full_topic_name(topic_name: str) -> str:
    """Qualify a bare topic name with the default tenant and namespace."""
    tenant = "default"
    namespace = "default"
    return create_topic_name(tenant, namespace, topic_name)
@pytest.mark.asyncio
async def test_backfill(
    producer_consumer: Tuple[Producer, Consumer],
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    produce_fns: ProducerFn,
) -> None:
    """Records produced before a subscription exists are delivered on subscribe."""
    producer, consumer = producer_consumer
    producer.reset_state()
    consumer.reset_state()
    topic = full_topic_name("test_topic")
    producer.create_topic(topic)
    inserted = produce_fns(producer, topic, sample_embeddings, 3)[0]
    capture = CapturingConsumeFn()
    consumer.subscribe(topic, capture, start=consumer.min_seqid())
    consumed = await capture.get(3)
    assert_records_match(inserted, consumed)
@pytest.mark.asyncio
async def test_notifications(
    producer_consumer: Tuple[Producer, Consumer],
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
) -> None:
    """Each record submitted after subscribing is pushed to the subscription."""
    producer, consumer = producer_consumer
    producer.reset_state()
    consumer.reset_state()
    topic = full_topic_name("test_topic")
    producer.create_topic(topic)
    submitted: List[SubmitEmbeddingRecord] = []
    capture = CapturingConsumeFn()
    consumer.subscribe(topic, capture, start=consumer.min_seqid())
    for n in range(1, 11):
        record = next(sample_embeddings)
        submitted.append(record)
        producer.submit_embedding(topic, record)
        # After each submit, everything so far should have been consumed.
        assert_records_match(submitted, await capture.get(n))
@pytest.mark.asyncio
async def test_multiple_topics(
    producer_consumer: Tuple[Producer, Consumer],
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
) -> None:
    """Subscriptions on different topics each receive only their own records."""
    producer, consumer = producer_consumer
    producer.reset_state()
    consumer.reset_state()
    topic_1 = full_topic_name("test_topic_1")
    topic_2 = full_topic_name("test_topic_2")
    producer.create_topic(topic_1)
    producer.create_topic(topic_2)
    submitted_1: List[SubmitEmbeddingRecord] = []
    submitted_2: List[SubmitEmbeddingRecord] = []
    capture_1 = CapturingConsumeFn()
    capture_2 = CapturingConsumeFn()
    consumer.subscribe(topic_1, capture_1, start=consumer.min_seqid())
    consumer.subscribe(topic_2, capture_2, start=consumer.min_seqid())
    # Interleave produces across the two topics, checking both streams
    # stay independent at every step.
    for n in range(1, 11):
        record_1 = next(sample_embeddings)
        submitted_1.append(record_1)
        producer.submit_embedding(topic_1, record_1)
        assert_records_match(submitted_1, await capture_1.get(n))
        record_2 = next(sample_embeddings)
        submitted_2.append(record_2)
        producer.submit_embedding(topic_2, record_2)
        assert_records_match(submitted_2, await capture_2.get(n))
@pytest.mark.asyncio
async def test_start_seq_id(
    producer_consumer: Tuple[Producer, Consumer],
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    produce_fns: ProducerFn,
) -> None:
    """A subscription started at a given seq_id sees only records after it."""
    producer, consumer = producer_consumer
    producer.reset_state()
    consumer.reset_state()
    topic = full_topic_name("test_topic")
    producer.create_topic(topic)
    capture_all = CapturingConsumeFn()
    capture_tail = CapturingConsumeFn()
    consumer.subscribe(topic, capture_all, start=consumer.min_seqid())
    first_batch = produce_fns(producer, topic, sample_embeddings, 5)[0]
    assert_records_match(first_batch, await capture_all.get(5))
    # Second subscription starts at the last seq_id seen so far, so it
    # should only receive the next batch.
    last_seq = capture_all.embeddings[-1]["seq_id"]
    consumer.subscribe(topic, capture_tail, start=last_seq)
    second_batch = produce_fns(producer, topic, sample_embeddings, 5)[0]
    assert isinstance(first_batch, list)
    first_batch.extend(second_batch)
    assert_records_match(first_batch[-5:], await capture_tail.get(5))
@pytest.mark.asyncio
async def test_end_seq_id(
    producer_consumer: Tuple[Producer, Consumer],
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    produce_fns: ProducerFn,
) -> None:
    """A subscription bounded by `end` stops delivering at that seq_id."""
    producer, consumer = producer_consumer
    producer.reset_state()
    consumer.reset_state()
    topic = full_topic_name("test_topic")
    producer.create_topic(topic)
    capture_all = CapturingConsumeFn()
    capture_bounded = CapturingConsumeFn()
    consumer.subscribe(topic, capture_all, start=consumer.min_seqid())
    inserted = produce_fns(producer, topic, sample_embeddings, 10)[0]
    assert_records_match(inserted, await capture_all.get(10))
    # Bound the second subscription at the 6th record's seq_id.
    end_seq = capture_all.embeddings[-5]["seq_id"]
    consumer.subscribe(topic, capture_bounded, start=consumer.min_seqid(), end=end_seq)
    assert_records_match(inserted[:6], await capture_bounded.get(6))
    # Should never produce a 7th
    with pytest.raises(TimeoutError):
        _ = await wait_for(capture_bounded.get(7), timeout=1)
@pytest.mark.asyncio
async def test_submit_batch(
    producer_consumer: Tuple[Producer, Consumer],
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
) -> None:
    """A batch submitted via submit_embeddings is delivered in full."""
    producer, consumer = producer_consumer
    producer.reset_state()
    consumer.reset_state()
    topic = full_topic_name("test_topic")
    batch = [next(sample_embeddings) for _ in range(100)]
    producer.create_topic(topic)
    producer.submit_embeddings(topic, embeddings=batch)
    capture = CapturingConsumeFn()
    consumer.subscribe(topic, capture, start=consumer.min_seqid())
    assert_records_match(batch, await capture.get(100))
@pytest.mark.asyncio
async def test_multiple_topics_batch(
    producer_consumer: Tuple[Producer, Consumer],
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    produce_fns: ProducerFn,
) -> None:
    """Batched produces interleaved across topics stay isolated per topic."""
    producer, consumer = producer_consumer
    producer.reset_state()
    consumer.reset_state()
    topic_count = 2
    captures = [CapturingConsumeFn() for _ in range(topic_count)]
    for t in range(topic_count):
        producer.create_topic(full_topic_name(f"test_topic_{t}"))
        consumer.subscribe(
            full_topic_name(f"test_topic_{t}"),
            captures[t],
            start=consumer.min_seqid(),
        )
    per_topic: List[List[SubmitEmbeddingRecord]] = [[] for _ in range(topic_count)]
    batch_size = 10
    total_to_produce = 100
    produced_so_far = 0
    for _ in range(total_to_produce // batch_size):
        # Round-robin a batch into each topic, verifying each stream after
        # its own batch lands.
        for t in range(topic_count):
            batch = produce_fns(
                producer,
                full_topic_name(f"test_topic_{t}"),
                sample_embeddings,
                batch_size,
            )[0]
            per_topic[t].extend(batch)
            consumed = await captures[t].get(produced_so_far + batch_size)
            assert_records_match(per_topic[t], consumed)
        produced_so_far += batch_size
@pytest.mark.asyncio
async def test_max_batch_size(
    producer_consumer: Tuple[Producer, Consumer],
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
) -> None:
    """Batches of exactly max_batch_size succeed; one more record raises ValueError."""
    producer, consumer = producer_consumer
    producer.reset_state()
    consumer.reset_state()
    topic = full_topic_name("test_topic")
    limit = producer.max_batch_size
    assert limit > 0
    # A batch of exactly the advertised limit must round-trip end-to-end.
    batch = [next(sample_embeddings) for _ in range(limit)]
    capture = CapturingConsumeFn()
    consumer.subscribe(topic, capture, start=consumer.min_seqid())
    producer.submit_embeddings(topic, embeddings=batch)
    received = await capture.get(limit, timeout_secs=120)
    assert_records_match(batch, received)
    # One record past the limit must be rejected.
    oversized = [next(sample_embeddings) for _ in range(limit + 1)]
    with pytest.raises(ValueError) as excinfo:
        producer.submit_embeddings(topic, embeddings=oversized)
    assert "Cannot submit more than" in str(excinfo.value)