# chroma / chromadb / test / conftest.py
# badalsahani's picture
# feat: chroma initial deploy
# 287a0bc
import logging
import multiprocessing
import os
import shutil
import socket
import subprocess
import tempfile
import time
from typing import (
Any,
Dict,
Generator,
Iterator,
List,
Optional,
Sequence,
Tuple,
Callable,
)
import hypothesis
import pytest
import uvicorn
from requests.exceptions import ConnectionError
from typing_extensions import Protocol
import chromadb.server.fastapi
from chromadb.api import ClientAPI, ServerAPI
from chromadb.config import Settings, System
from chromadb.db.mixins import embeddings_queue
from chromadb.ingest import Producer
from chromadb.types import SeqId, SubmitEmbeddingRecord
from chromadb.api.client import Client as ClientCreator
# Logging setup for the test session: the root logger's level is lowered to
# DEBUG at import time of this conftest module.
root_logger = logging.getLogger()
root_logger.setLevel(logging.DEBUG) # This will only run when testing
# Module-level logger used by the helper functions in this file.
logger = logging.getLogger(__name__)
# Register a Hypothesis profile named "dev" with a deadline of 45000
# (milliseconds, per Hypothesis's `deadline` setting) and with the listed
# health checks suppressed.
hypothesis.settings.register_profile(
"dev",
deadline=45000,
suppress_health_check=[
hypothesis.HealthCheck.data_too_large,
hypothesis.HealthCheck.large_base_example,
hypothesis.HealthCheck.function_scoped_fixture,
],
)
# Activate the profile named by HYPOTHESIS_PROFILE, defaulting to "dev".
hypothesis.settings.load_profile(os.getenv("HYPOTHESIS_PROFILE", "dev"))
# True unless the CHROMA_CLUSTER_TEST_ONLY environment variable is exactly "1".
# Used as the skip condition in skip_if_not_cluster().
NOT_CLUSTER_ONLY = os.getenv("CHROMA_CLUSTER_TEST_ONLY") != "1"
def skip_if_not_cluster() -> pytest.MarkDecorator:
    """Build a ``skipif`` marker for tests that need a running cluster.

    The marker skips the decorated test whenever ``NOT_CLUSTER_ONLY`` is true,
    i.e. whenever ``CHROMA_CLUSTER_TEST_ONLY`` is not set to ``"1"``.
    """
    skip_reason = "Requires Kubernetes to be running with a valid config"
    return pytest.mark.skipif(NOT_CLUSTER_ONLY, reason=skip_reason)
def generate_self_signed_certificate() -> None:
    """Generate a self-signed TLS certificate and key with the ``openssl`` CLI.

    The OpenSSL configuration is read from ``openssl.cnf`` located next to this
    file. The key is written to ``serverkey.pem`` and the certificate to
    ``servercert.pem``, both relative to the current working directory.

    Raises:
        FileNotFoundError: if ``openssl.cnf`` is missing (or, raised by
            ``subprocess.run``, if the ``openssl`` executable is not found).
        subprocess.CalledProcessError: if ``openssl`` exits with a non-zero
            status, so a failed generation is reported here instead of
            surfacing later as an unusable certificate.
    """
    config_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "openssl.cnf"
    )
    print(f"Config path: {config_path}")  # Debug print to verify path
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found at {config_path}")
    subprocess.run(
        [
            "openssl",
            "req",
            "-x509",
            "-newkey",
            "rsa:4096",
            "-keyout",
            "serverkey.pem",
            "-out",
            "servercert.pem",
            "-days",
            "365",
            "-nodes",
            "-subj",
            "/CN=localhost",
            "-config",
            config_path,
        ],
        # Fail loudly when openssl reports an error rather than ignoring it.
        check=True,
    )
def find_free_port() -> int:
    """Return a TCP port number that the OS reports as currently unused.

    A throwaway IPv4 socket is bound to port 0 so the OS assigns a port; the
    socket is closed before returning, so the port is not reserved for the
    caller.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(("", 0))
        probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        _host, assigned_port = probe.getsockname()
        return assigned_port  # type: ignore
def _run_server(
    port: int,
    is_persistent: bool = False,
    persist_directory: Optional[str] = None,
    chroma_server_auth_provider: Optional[str] = None,
    chroma_server_auth_credentials_provider: Optional[str] = None,
    chroma_server_auth_credentials_file: Optional[str] = None,
    chroma_server_auth_credentials: Optional[str] = None,
    chroma_server_auth_token_transport_header: Optional[str] = None,
    chroma_server_authz_provider: Optional[str] = None,
    chroma_server_authz_config_file: Optional[str] = None,
    chroma_server_authz_config: Optional[Dict[str, Any]] = None,
    chroma_server_ssl_certfile: Optional[str] = None,
    chroma_server_ssl_keyfile: Optional[str] = None,
) -> None:
    """Run a Chroma server locally.

    Builds sqlite-backed ``Settings``, wraps them in the FastAPI server and
    hands the app to ``uvicorn.run`` on ``0.0.0.0:<port>``. Persistence is
    enabled only when ``is_persistent`` is truthy AND ``persist_directory`` is
    non-empty; otherwise the server is configured with ``is_persistent=False``
    and no explicit persist directory.
    """
    # Options common to both the persistent and the non-persistent server.
    shared_options: Dict[str, Any] = {
        "chroma_api_impl": "chromadb.api.segment.SegmentAPI",
        "chroma_sysdb_impl": "chromadb.db.impl.sqlite.SqliteDB",
        "chroma_producer_impl": "chromadb.db.impl.sqlite.SqliteDB",
        "chroma_consumer_impl": "chromadb.db.impl.sqlite.SqliteDB",
        "chroma_segment_manager_impl": "chromadb.segment.impl.manager.local.LocalSegmentManager",
        "allow_reset": True,
        "chroma_server_auth_provider": chroma_server_auth_provider,
        "chroma_server_auth_credentials_provider": chroma_server_auth_credentials_provider,
        "chroma_server_auth_credentials_file": chroma_server_auth_credentials_file,
        "chroma_server_auth_credentials": chroma_server_auth_credentials,
        "chroma_server_auth_token_transport_header": chroma_server_auth_token_transport_header,
        "chroma_server_authz_provider": chroma_server_authz_provider,
        "chroma_server_authz_config_file": chroma_server_authz_config_file,
        "chroma_server_authz_config": chroma_server_authz_config,
    }

    storage_options: Dict[str, Any]
    if is_persistent and persist_directory:
        storage_options = {
            "is_persistent": is_persistent,
            "persist_directory": persist_directory,
        }
    else:
        # persist_directory is deliberately left at the Settings default here.
        storage_options = {"is_persistent": False}

    settings = Settings(**shared_options, **storage_options)
    application = chromadb.server.fastapi.FastAPI(settings).app()
    uvicorn.run(
        application,
        host="0.0.0.0",
        port=port,
        log_level="error",
        timeout_keep_alive=30,
        ssl_keyfile=chroma_server_ssl_keyfile,
        ssl_certfile=chroma_server_ssl_certfile,
    )
def _await_server(api: ServerAPI, attempts: int = 0) -> None:
    """Block until ``api.heartbeat()`` succeeds, retrying on ``ConnectionError``.

    After each failed heartbeat the function sleeps 4 seconds and tries again.
    Once the attempt counter (starting at ``attempts``) exceeds 15, the last
    ``ConnectionError`` is logged and re-raised.
    """
    attempt = attempts
    while True:
        try:
            api.heartbeat()
            return
        except ConnectionError as err:
            if attempt > 15:
                logger.error("Test server failed to start after 15 attempts")
                raise err
            logger.info("Waiting for server to start...")
            time.sleep(4)
            attempt += 1
def _fastapi_fixture(
    is_persistent: bool = False,
    chroma_server_auth_provider: Optional[str] = None,
    chroma_server_auth_credentials_provider: Optional[str] = None,
    chroma_client_auth_provider: Optional[str] = None,
    chroma_server_auth_credentials_file: Optional[str] = None,
    chroma_client_auth_credentials: Optional[str] = None,
    chroma_server_auth_credentials: Optional[str] = None,
    chroma_client_auth_token_transport_header: Optional[str] = None,
    chroma_server_auth_token_transport_header: Optional[str] = None,
    chroma_server_authz_provider: Optional[str] = None,
    chroma_server_authz_config_file: Optional[str] = None,
    chroma_server_authz_config: Optional[Dict[str, Any]] = None,
    chroma_server_ssl_certfile: Optional[str] = None,
    chroma_server_ssl_keyfile: Optional[str] = None,
) -> Generator[System, None, None]:
    """Launch a Chroma server in a child process and yield a client ``System``.

    The server runs ``_run_server`` in a daemon process created from the
    ``spawn`` multiprocessing context, on a port from ``find_free_port``.
    When ``is_persistent`` is truthy a fresh temporary directory is created
    and passed to the server. The yielded ``System`` is configured as a
    FastAPI client for ``localhost:<port>`` and is started, and the server's
    heartbeat is awaited, before the yield.

    Code after the ``yield`` (stop the system, kill the server process, remove
    the temporary directory) only runs if the generator is resumed.
    """
    port = find_free_port()
    logger.info(f"Running test FastAPI server on port {port}")
    mp_context = multiprocessing.get_context("spawn")

    # Persistence is passed to the server only when requested; otherwise the
    # server receives is_persistent=False and no directory.
    persist_directory: Optional[str] = None
    persistent_flag = False
    if is_persistent:
        persist_directory = tempfile.mkdtemp()
        persistent_flag = is_persistent

    server_kwargs: Dict[str, Any] = {
        "port": port,
        "is_persistent": persistent_flag,
        "persist_directory": persist_directory,
        "chroma_server_auth_provider": chroma_server_auth_provider,
        "chroma_server_auth_credentials_provider": chroma_server_auth_credentials_provider,
        "chroma_server_auth_credentials_file": chroma_server_auth_credentials_file,
        "chroma_server_auth_credentials": chroma_server_auth_credentials,
        "chroma_server_auth_token_transport_header": chroma_server_auth_token_transport_header,
        "chroma_server_authz_provider": chroma_server_authz_provider,
        "chroma_server_authz_config_file": chroma_server_authz_config_file,
        "chroma_server_authz_config": chroma_server_authz_config,
        "chroma_server_ssl_certfile": chroma_server_ssl_certfile,
        "chroma_server_ssl_keyfile": chroma_server_ssl_keyfile,
    }
    server_process = mp_context.Process(
        target=_run_server, kwargs=server_kwargs, daemon=True
    )
    server_process.start()

    client_settings = Settings(
        chroma_api_impl="chromadb.api.fastapi.FastAPI",
        chroma_server_host="localhost",
        chroma_server_http_port=str(port),
        allow_reset=True,
        chroma_client_auth_provider=chroma_client_auth_provider,
        chroma_client_auth_credentials=chroma_client_auth_credentials,
        chroma_client_auth_token_transport_header=chroma_client_auth_token_transport_header,
        # The server certificate doubles as the client's verification bundle.
        chroma_server_ssl_verify=chroma_server_ssl_certfile,
        chroma_server_ssl_enabled=bool(chroma_server_ssl_certfile),
    )
    system = System(client_settings)
    server_api = system.instance(ServerAPI)
    system.start()
    _await_server(server_api)

    yield system

    system.stop()
    server_process.kill()
    if (
        is_persistent
        and persist_directory is not None
        and os.path.exists(persist_directory)
    ):
        shutil.rmtree(persist_directory)
def fastapi() -> Generator[System, None, None]:
    """Return the fixture generator for a non-persistent FastAPI test server."""
    fixture = _fastapi_fixture(is_persistent=False)
    return fixture
def fastapi_persistent() -> Generator[System, None, None]:
    """Return the fixture generator for a persistent FastAPI test server."""
    fixture = _fastapi_fixture(is_persistent=True)
    return fixture
def fastapi_ssl() -> Generator[System, None, None]:
    """Return the fixture generator for a non-persistent server using TLS.

    A self-signed certificate is generated immediately (when this function is
    called), and the resulting ``./servercert.pem`` / ``./serverkey.pem`` paths
    are handed to the server fixture.
    """
    generate_self_signed_certificate()
    ssl_options = {
        "chroma_server_ssl_certfile": "./servercert.pem",
        "chroma_server_ssl_keyfile": "./serverkey.pem",
    }
    return _fastapi_fixture(is_persistent=False, **ssl_options)
def basic_http_client() -> Generator[System, None, None]:
    """Yield a FastAPI client ``System`` pointed at port 8000.

    No server is launched here; the heartbeat is awaited first, then the
    system is started, yielded, and stopped once the generator is resumed.
    """
    client_settings = Settings(
        chroma_api_impl="chromadb.api.fastapi.FastAPI",
        chroma_server_http_port="8000",
        allow_reset=True,
    )
    client_system = System(client_settings)
    # The heartbeat is awaited before start(), unlike in _fastapi_fixture.
    _await_server(client_system.instance(ServerAPI))
    client_system.start()
    yield client_system
    client_system.stop()
def fastapi_server_basic_auth() -> Generator[System, None, None]:
    """Yield a client for a server protected by htpasswd-file basic auth.

    Writes ``server.htpasswd`` (one ``admin`` entry) into the current working
    directory, delegates to ``_fastapi_fixture`` with fully-qualified provider
    class names and client credentials ``admin:admin``, and removes the file
    once the delegated generator is exhausted.
    """
    htpasswd_path = os.path.abspath(os.path.join(".", "server.htpasswd"))
    with open(htpasswd_path, "w") as htpasswd:
        htpasswd.write("admin:$2y$05$e5sRb6NCcSH3YfbIxe1AGu2h5K7OOd982OXKmd8WyQ3DRQ4MvpnZS\n")
    yield from _fastapi_fixture(
        is_persistent=False,
        chroma_server_auth_provider="chromadb.auth.basic.BasicAuthServerProvider",
        chroma_server_auth_credentials_provider="chromadb.auth.providers.HtpasswdFileServerAuthCredentialsProvider",
        chroma_server_auth_credentials_file="./server.htpasswd",
        chroma_client_auth_provider="chromadb.auth.basic.BasicAuthClientProvider",
        chroma_client_auth_credentials="admin:admin",
    )
    os.remove(htpasswd_path)
def fastapi_server_basic_auth_param() -> Generator[System, None, None]:
    """Yield a client for a basic-auth server (fully-qualified provider names).

    Writes ``server.htpasswd`` (one ``admin`` entry) into the current working
    directory, delegates to ``_fastapi_fixture`` with client credentials
    ``admin:admin``, and removes the file once the delegated generator is
    exhausted.
    """
    htpasswd_path = os.path.abspath(os.path.join(".", "server.htpasswd"))
    with open(htpasswd_path, "w") as htpasswd:
        htpasswd.write("admin:$2y$05$e5sRb6NCcSH3YfbIxe1AGu2h5K7OOd982OXKmd8WyQ3DRQ4MvpnZS\n")
    yield from _fastapi_fixture(
        is_persistent=False,
        chroma_server_auth_provider="chromadb.auth.basic.BasicAuthServerProvider",
        chroma_server_auth_credentials_provider="chromadb.auth.providers.HtpasswdFileServerAuthCredentialsProvider",
        chroma_server_auth_credentials_file="./server.htpasswd",
        chroma_client_auth_provider="chromadb.auth.basic.BasicAuthClientProvider",
        chroma_client_auth_credentials="admin:admin",
    )
    os.remove(htpasswd_path)
# TODO we need a generator for auth providers
def fastapi_server_basic_auth_file() -> Generator[System, None, None]:
    """Yield a client for a basic-auth server backed by an htpasswd file.

    Writes ``server.htpasswd`` (one ``admin`` entry) into the current working
    directory, delegates to ``_fastapi_fixture`` with fully-qualified provider
    class names and client credentials ``admin:admin``, and removes the file
    once the delegated generator is exhausted.
    """
    htpasswd_path = os.path.abspath(os.path.join(".", "server.htpasswd"))
    with open(htpasswd_path, "w") as htpasswd:
        htpasswd.write("admin:$2y$05$e5sRb6NCcSH3YfbIxe1AGu2h5K7OOd982OXKmd8WyQ3DRQ4MvpnZS\n")
    yield from _fastapi_fixture(
        is_persistent=False,
        chroma_server_auth_provider="chromadb.auth.basic.BasicAuthServerProvider",
        chroma_server_auth_credentials_provider="chromadb.auth.providers.HtpasswdFileServerAuthCredentialsProvider",
        chroma_server_auth_credentials_file="./server.htpasswd",
        chroma_client_auth_provider="chromadb.auth.basic.BasicAuthClientProvider",
        chroma_client_auth_credentials="admin:admin",
    )
    os.remove(htpasswd_path)
def fastapi_server_basic_auth_shorthand() -> Generator[System, None, None]:
    """Yield a client for a basic-auth server configured via shorthand names.

    Same setup as the other basic-auth fixtures, but the providers are given
    as the short identifiers ``"basic"`` and ``"htpasswd_file"`` instead of
    fully-qualified class paths. The ``server.htpasswd`` file is removed once
    the delegated generator is exhausted.
    """
    htpasswd_path = os.path.abspath(os.path.join(".", "server.htpasswd"))
    with open(htpasswd_path, "w") as htpasswd:
        htpasswd.write("admin:$2y$05$e5sRb6NCcSH3YfbIxe1AGu2h5K7OOd982OXKmd8WyQ3DRQ4MvpnZS\n")
    yield from _fastapi_fixture(
        is_persistent=False,
        chroma_server_auth_provider="basic",
        chroma_server_auth_credentials_provider="htpasswd_file",
        chroma_server_auth_credentials_file="./server.htpasswd",
        chroma_client_auth_provider="basic",
        chroma_client_auth_credentials="admin:admin",
    )
    os.remove(htpasswd_path)
def fastapi_server_basic_auth_invalid_cred() -> Generator[System, None, None]:
    """Yield a client whose credentials differ from the other auth fixtures.

    Same server setup as the other basic-auth fixtures, but the client is
    configured with ``admin:admin1`` (the others use ``admin:admin``). The
    ``server.htpasswd`` file is removed once the delegated generator is
    exhausted.
    """
    htpasswd_path = os.path.abspath(os.path.join(".", "server.htpasswd"))
    with open(htpasswd_path, "w") as htpasswd:
        htpasswd.write("admin:$2y$05$e5sRb6NCcSH3YfbIxe1AGu2h5K7OOd982OXKmd8WyQ3DRQ4MvpnZS\n")
    yield from _fastapi_fixture(
        is_persistent=False,
        chroma_server_auth_provider="chromadb.auth.basic.BasicAuthServerProvider",
        chroma_server_auth_credentials_provider="chromadb.auth.providers.HtpasswdFileServerAuthCredentialsProvider",
        chroma_server_auth_credentials_file="./server.htpasswd",
        chroma_client_auth_provider="chromadb.auth.basic.BasicAuthClientProvider",
        chroma_client_auth_credentials="admin:admin1",
    )
    os.remove(htpasswd_path)
def integration() -> Generator[System, None, None]:
    """Fixture generator for a client configured via environment variables.

    Intended for externally configured integration tests: only
    ``allow_reset=True`` is set explicitly here.
    """
    env_system = System(Settings(allow_reset=True))
    env_system.start()
    yield env_system
    env_system.stop()
def sqlite() -> Generator[System, None, None]:
    """Fixture generator for segment-based API using in-memory Sqlite"""
    sqlite_impl = "chromadb.db.impl.sqlite.SqliteDB"
    in_memory_system = System(
        Settings(
            chroma_api_impl="chromadb.api.segment.SegmentAPI",
            chroma_sysdb_impl=sqlite_impl,
            chroma_producer_impl=sqlite_impl,
            chroma_consumer_impl=sqlite_impl,
            chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
            is_persistent=False,
            allow_reset=True,
        )
    )
    in_memory_system.start()
    yield in_memory_system
    in_memory_system.stop()
def sqlite_persistent() -> Generator[System, None, None]:
    """Fixture generator for segment-based API using persistent Sqlite

    Data is stored in a fresh temporary directory, which is removed after the
    system is stopped (i.e. once the generator is resumed past its yield).
    """
    data_dir = tempfile.mkdtemp()
    sqlite_impl = "chromadb.db.impl.sqlite.SqliteDB"
    persistent_system = System(
        Settings(
            chroma_api_impl="chromadb.api.segment.SegmentAPI",
            chroma_sysdb_impl=sqlite_impl,
            chroma_producer_impl=sqlite_impl,
            chroma_consumer_impl=sqlite_impl,
            chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
            allow_reset=True,
            is_persistent=True,
            persist_directory=data_dir,
        )
    )
    persistent_system.start()
    yield persistent_system
    persistent_system.stop()
    if os.path.exists(data_dir):
        shutil.rmtree(data_dir)
def system_fixtures() -> List[Callable[[], Generator[System, None, None]]]:
    """Select the system fixture generators according to the environment.

    Precedence (highest first), based only on whether the variable is present:
    ``CHROMA_CLUSTER_TEST_ONLY`` -> just ``basic_http_client``;
    ``CHROMA_INTEGRATION_TEST_ONLY`` -> just ``integration``;
    otherwise the four local fixtures, plus ``integration`` when
    ``CHROMA_INTEGRATION_TEST`` is present.
    """
    env = os.environ
    if "CHROMA_CLUSTER_TEST_ONLY" in env:
        return [basic_http_client]
    if "CHROMA_INTEGRATION_TEST_ONLY" in env:
        return [integration]
    selected = [fastapi, fastapi_persistent, sqlite, sqlite_persistent]
    if "CHROMA_INTEGRATION_TEST" in env:
        selected.append(integration)
    return selected
def system_fixtures_auth() -> List[Callable[[], Generator[System, None, None]]]:
    """Return the fixture generators for servers with valid basic auth."""
    return [
        fastapi_server_basic_auth_param,
        fastapi_server_basic_auth_file,
        fastapi_server_basic_auth_shorthand,
    ]
def system_fixtures_wrong_auth() -> List[Callable[[], Generator[System, None, None]]]:
    """Return the fixture generators whose client uses invalid credentials."""
    return [fastapi_server_basic_auth_invalid_cred]
def system_fixtures_ssl() -> List[Callable[[], Generator[System, None, None]]]:
    """Return the fixture generators for TLS-enabled servers."""
    return [fastapi_ssl]
@pytest.fixture(scope="module", params=system_fixtures_wrong_auth())
def system_wrong_auth(
    request: pytest.FixtureRequest,
) -> Generator[System, None, None]:
    """Module-scoped ``System`` whose client uses invalid credentials.

    Parametrized over ``system_fixtures_wrong_auth()``; ``request.param`` is a
    zero-argument callable returning a generator that yields one ``System``.
    """
    # Delegate with ``yield from`` so the underlying generator is resumed at
    # fixture teardown and its cleanup code runs. Taking only ``next()`` of it
    # would drop the generator after its first value, skipping that cleanup.
    yield from request.param()
@pytest.fixture(scope="module", params=system_fixtures())
def system(request: pytest.FixtureRequest) -> Generator[System, None, None]:
    """Module-scoped ``System``, parametrized over ``system_fixtures()``.

    ``request.param`` is a zero-argument callable returning a generator that
    yields one ``System``.
    """
    # Delegate with ``yield from`` so the underlying generator is resumed at
    # fixture teardown and its cleanup code runs. Taking only ``next()`` of it
    # would drop the generator after its first value, skipping that cleanup.
    yield from request.param()
@pytest.fixture(scope="module", params=system_fixtures_ssl())
def system_ssl(request: pytest.FixtureRequest) -> Generator[System, None, None]:
    """Module-scoped ``System`` for a TLS-enabled server.

    Parametrized over ``system_fixtures_ssl()``; ``request.param`` is a
    zero-argument callable returning a generator that yields one ``System``.
    """
    # Delegate with ``yield from`` so the underlying generator is resumed at
    # fixture teardown and its cleanup code runs. Taking only ``next()`` of it
    # would drop the generator after its first value, skipping that cleanup.
    yield from request.param()
@pytest.fixture(scope="module", params=system_fixtures_auth())
def system_auth(request: pytest.FixtureRequest) -> Generator[System, None, None]:
    """Module-scoped ``System`` for a server with basic auth enabled.

    Parametrized over ``system_fixtures_auth()``; ``request.param`` is a
    zero-argument callable returning a generator that yields one ``System``.
    """
    # Delegate with ``yield from`` so the underlying generator is resumed at
    # fixture teardown and its cleanup code runs. Taking only ``next()`` of it
    # would drop the generator after its first value, skipping that cleanup.
    yield from request.param()
@pytest.fixture(scope="function")
def api(system: System) -> Generator[ServerAPI, None, None]:
    """Reset the shared system's state and yield its ``ServerAPI`` instance."""
    system.reset_state()
    yield system.instance(ServerAPI)
@pytest.fixture(scope="function")
def client(system: System) -> Generator[ClientAPI, None, None]:
    """Reset the shared system's state and yield a client built from it.

    The client's system cache is cleared after the test.
    """
    system.reset_state()
    test_client = ClientCreator.from_system(system)
    yield test_client
    test_client.clear_system_cache()
@pytest.fixture(scope="function")
def client_ssl(system_ssl: System) -> Generator[ClientAPI, None, None]:
    """Reset the TLS system's state and yield a client built from it.

    The client's system cache is cleared after the test.
    """
    system_ssl.reset_state()
    ssl_client = ClientCreator.from_system(system_ssl)
    yield ssl_client
    ssl_client.clear_system_cache()
@pytest.fixture(scope="function")
def api_wrong_cred(
    system_wrong_auth: System,
) -> Generator[ServerAPI, None, None]:
    """Reset the wrong-credentials system and yield its ``ServerAPI``."""
    system_wrong_auth.reset_state()
    yield system_wrong_auth.instance(ServerAPI)
@pytest.fixture(scope="function")
def api_with_server_auth(system_auth: System) -> Generator[ServerAPI, None, None]:
    """Reset the auth-enabled system and yield its ``ServerAPI``."""
    system_auth.reset_state()
    yield system_auth.instance(ServerAPI)
# Producer / Consumer fixtures #
class ProducerFn(Protocol):
    """Structural type of the ``produce_n_*`` helper functions."""

    def __call__(
        self,
        producer: Producer,
        topic: str,
        embeddings: Iterator[SubmitEmbeddingRecord],
        n: int,
    ) -> Tuple[Sequence[SubmitEmbeddingRecord], Sequence[SeqId]]:
        """Submit ``n`` records from ``embeddings`` to ``producer`` on ``topic``.

        Returns the submitted records together with their sequence ids.
        """
def produce_n_single(
    producer: Producer,
    topic: str,
    embeddings: Iterator[SubmitEmbeddingRecord],
    n: int,
) -> Tuple[Sequence[SubmitEmbeddingRecord], Sequence[SeqId]]:
    """Submit ``n`` records one at a time via ``producer.submit_embedding``.

    Each record is pulled from ``embeddings`` and submitted before the next
    one is pulled. Returns ``(submitted_records, seq_ids)`` in submission
    order.
    """
    records = []
    ids = []
    remaining = n
    while remaining > 0:
        record = next(embeddings)
        ids.append(producer.submit_embedding(topic, record))
        records.append(record)
        remaining -= 1
    return records, ids
def produce_n_batch(
    producer: Producer,
    topic: str,
    embeddings: Iterator[SubmitEmbeddingRecord],
    n: int,
) -> Tuple[Sequence[SubmitEmbeddingRecord], Sequence[SeqId]]:
    """Submit ``n`` records in one ``producer.submit_embeddings`` call.

    All ``n`` records are pulled from ``embeddings`` first, then submitted as
    a single batch. Returns ``(submitted_records, seq_ids)``.
    """
    batch = [next(embeddings) for _ in range(n)]
    batch_ids: Sequence[SeqId] = producer.submit_embeddings(topic, batch)
    return batch, batch_ids
def produce_fn_fixtures() -> List[ProducerFn]:
    """Return the producer helpers used to parametrize ``produce_fns``."""
    producers: List[ProducerFn] = [produce_n_single, produce_n_batch]
    return producers
@pytest.fixture(scope="module", params=produce_fn_fixtures())
def produce_fns(
    request: pytest.FixtureRequest,
) -> Generator[ProducerFn, None, None]:
    """Yield each producer helper from ``produce_fn_fixtures()`` in turn."""
    produce_fn = request.param
    yield produce_fn
def pytest_configure(config):  # type: ignore
    """Pytest hook: flag the embeddings queue module as running under test."""
    setattr(embeddings_queue, "_called_from_test", True)