Spaces:
Sleeping
Sleeping
import logging | |
import multiprocessing | |
import os | |
import shutil | |
import socket | |
import subprocess | |
import tempfile | |
import time | |
from typing import ( | |
Any, | |
Dict, | |
Generator, | |
Iterator, | |
List, | |
Optional, | |
Sequence, | |
Tuple, | |
Callable, | |
) | |
import hypothesis | |
import pytest | |
import uvicorn | |
from requests.exceptions import ConnectionError | |
from typing_extensions import Protocol | |
import chromadb.server.fastapi | |
from chromadb.api import ClientAPI, ServerAPI | |
from chromadb.config import Settings, System | |
from chromadb.db.mixins import embeddings_queue | |
from chromadb.ingest import Producer | |
from chromadb.types import SeqId, SubmitEmbeddingRecord | |
from chromadb.api.client import Client as ClientCreator | |
# Raise the root logger to DEBUG so every record emitted during a test run is
# visible.
root_logger = logging.getLogger()
root_logger.setLevel(logging.DEBUG)  # This will only run when testing

logger = logging.getLogger(__name__)

# "dev" Hypothesis profile: a 45000 ms per-example deadline and three health
# checks silenced.
hypothesis.settings.register_profile(
    "dev",
    deadline=45000,
    suppress_health_check=[
        hypothesis.HealthCheck.data_too_large,
        hypothesis.HealthCheck.large_base_example,
        hypothesis.HealthCheck.function_scoped_fixture,
    ],
)
# HYPOTHESIS_PROFILE selects the active profile; "dev" is the fallback.
hypothesis.settings.load_profile(os.environ.get("HYPOTHESIS_PROFILE", "dev"))

# True unless CHROMA_CLUSTER_TEST_ONLY is set to exactly "1".
NOT_CLUSTER_ONLY = os.environ.get("CHROMA_CLUSTER_TEST_ONLY") != "1"
def skip_if_not_cluster() -> pytest.MarkDecorator:
    """Return a ``skipif`` marker that is active whenever ``NOT_CLUSTER_ONLY`` is true."""
    skip_reason = "Requires Kubernetes to be running with a valid config"
    return pytest.mark.skipif(NOT_CLUSTER_ONLY, reason=skip_reason)
def generate_self_signed_certificate() -> None:
    """Create a self-signed certificate/key pair by invoking ``openssl``.

    The OpenSSL configuration is read from ``openssl.cnf`` located next to
    this module.  The outputs are written to ``serverkey.pem`` and
    ``servercert.pem``; both names are relative, so they land in the current
    working directory of the calling process.

    Raises:
        FileNotFoundError: if ``openssl.cnf`` does not exist next to this
            module (also raised by ``subprocess`` when the ``openssl``
            executable cannot be found).
        subprocess.CalledProcessError: if ``openssl`` exits with a non-zero
            status.
    """
    config_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "openssl.cnf"
    )
    print(f"Config path: {config_path}")  # Debug print to verify path
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found at {config_path}")
    subprocess.run(
        [
            "openssl",
            "req",
            "-x509",
            "-newkey",
            "rsa:4096",
            "-keyout",
            "serverkey.pem",
            "-out",
            "servercert.pem",
            "-days",
            "365",
            "-nodes",
            "-subj",
            "/CN=localhost",
            "-config",
            config_path,
        ],
        # Fail loudly: without this a broken openssl run is ignored and the
        # caller goes on to use missing or stale PEM files.
        check=True,
    )
def find_free_port() -> int:
    """Return a TCP port number that the OS reported as free.

    A throwaway IPv4 stream socket is bound to port 0, the port chosen by the
    OS is read back, and the socket is closed before returning.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(("", 0))
        probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        address = probe.getsockname()
        return address[1]  # type: ignore
def _run_server(
    port: int,
    is_persistent: bool = False,
    persist_directory: Optional[str] = None,
    chroma_server_auth_provider: Optional[str] = None,
    chroma_server_auth_credentials_provider: Optional[str] = None,
    chroma_server_auth_credentials_file: Optional[str] = None,
    chroma_server_auth_credentials: Optional[str] = None,
    chroma_server_auth_token_transport_header: Optional[str] = None,
    chroma_server_authz_provider: Optional[str] = None,
    chroma_server_authz_config_file: Optional[str] = None,
    chroma_server_authz_config: Optional[Dict[str, Any]] = None,
    chroma_server_ssl_certfile: Optional[str] = None,
    chroma_server_ssl_keyfile: Optional[str] = None,
) -> None:
    """Build server settings and serve the Chroma FastAPI app with uvicorn.

    The call blocks inside ``uvicorn.run``.  The server listens on
    ``0.0.0.0:<port>``; the SSL key/cert paths are forwarded to uvicorn
    unchanged.  All ``chroma_server_auth*`` / ``chroma_server_authz*``
    arguments are passed straight through to ``Settings``.
    """
    # Persistence is only enabled when a directory is supplied as well;
    # otherwise the server runs non-persistent and no directory is passed.
    if is_persistent and persist_directory:
        persistence: Dict[str, Any] = {
            "is_persistent": is_persistent,
            "persist_directory": persist_directory,
        }
    else:
        persistence = {"is_persistent": False}

    settings = Settings(
        chroma_api_impl="chromadb.api.segment.SegmentAPI",
        chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB",
        chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB",
        chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB",
        chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
        allow_reset=True,
        chroma_server_auth_provider=chroma_server_auth_provider,
        chroma_server_auth_credentials_provider=chroma_server_auth_credentials_provider,
        chroma_server_auth_credentials_file=chroma_server_auth_credentials_file,
        chroma_server_auth_credentials=chroma_server_auth_credentials,
        chroma_server_auth_token_transport_header=chroma_server_auth_token_transport_header,
        chroma_server_authz_provider=chroma_server_authz_provider,
        chroma_server_authz_config_file=chroma_server_authz_config_file,
        chroma_server_authz_config=chroma_server_authz_config,
        **persistence,
    )

    app = chromadb.server.fastapi.FastAPI(settings).app()
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        log_level="error",
        timeout_keep_alive=30,
        ssl_keyfile=chroma_server_ssl_keyfile,
        ssl_certfile=chroma_server_ssl_certfile,
    )
def _await_server(api: ServerAPI, attempts: int = 0) -> None: | |
try: | |
api.heartbeat() | |
except ConnectionError as e: | |
if attempts > 15: | |
logger.error("Test server failed to start after 15 attempts") | |
raise e | |
else: | |
logger.info("Waiting for server to start...") | |
time.sleep(4) | |
_await_server(api, attempts + 1) | |
def _fastapi_fixture(
    is_persistent: bool = False,
    chroma_server_auth_provider: Optional[str] = None,
    chroma_server_auth_credentials_provider: Optional[str] = None,
    chroma_client_auth_provider: Optional[str] = None,
    chroma_server_auth_credentials_file: Optional[str] = None,
    chroma_client_auth_credentials: Optional[str] = None,
    chroma_server_auth_credentials: Optional[str] = None,
    chroma_client_auth_token_transport_header: Optional[str] = None,
    chroma_server_auth_token_transport_header: Optional[str] = None,
    chroma_server_authz_provider: Optional[str] = None,
    chroma_server_authz_config_file: Optional[str] = None,
    chroma_server_authz_config: Optional[Dict[str, Any]] = None,
    chroma_server_ssl_certfile: Optional[str] = None,
    chroma_server_ssl_keyfile: Optional[str] = None,
) -> Generator[System, None, None]:
    """Start a Chroma server in a child process and yield a client ``System``.

    The server runs ``_run_server`` in a daemon process created from the
    "spawn" multiprocessing context, on a port obtained from
    ``find_free_port``.  When ``is_persistent`` is true a fresh temporary
    directory is created and handed to the server.  The yielded ``System`` is
    configured for the FastAPI client against ``localhost:<port>``; SSL is
    enabled on the client exactly when ``chroma_server_ssl_certfile`` is
    truthy, and that same path is used as the verification setting.

    The statements after ``yield`` (stop the system, kill the process, remove
    the temporary directory) only execute if the generator is resumed after
    its single yield.
    """
    port = find_free_port()
    logger.info(f"Running test FastAPI server on port {port}")
    ctx = multiprocessing.get_context("spawn")

    persist_directory: Optional[str] = None
    if is_persistent:
        persist_directory = tempfile.mkdtemp()

    # Positional arguments for _run_server, in its declared parameter order.
    server_args: Tuple[
        int,
        bool,
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[Dict[str, Any]],
        Optional[str],
        Optional[str],
    ] = (
        port,
        is_persistent,
        persist_directory,
        chroma_server_auth_provider,
        chroma_server_auth_credentials_provider,
        chroma_server_auth_credentials_file,
        chroma_server_auth_credentials,
        chroma_server_auth_token_transport_header,
        chroma_server_authz_provider,
        chroma_server_authz_config_file,
        chroma_server_authz_config,
        chroma_server_ssl_certfile,
        chroma_server_ssl_keyfile,
    )
    proc = ctx.Process(target=_run_server, args=server_args, daemon=True)
    proc.start()

    client_settings = Settings(
        chroma_api_impl="chromadb.api.fastapi.FastAPI",
        chroma_server_host="localhost",
        chroma_server_http_port=str(port),
        allow_reset=True,
        chroma_client_auth_provider=chroma_client_auth_provider,
        chroma_client_auth_credentials=chroma_client_auth_credentials,
        chroma_client_auth_token_transport_header=chroma_client_auth_token_transport_header,
        chroma_server_ssl_verify=chroma_server_ssl_certfile,
        chroma_server_ssl_enabled=bool(chroma_server_ssl_certfile),
    )
    system = System(client_settings)
    api = system.instance(ServerAPI)
    system.start()
    # Block until the child process answers heartbeats.
    _await_server(api)
    yield system

    system.stop()
    proc.kill()
    if (
        is_persistent
        and persist_directory is not None
        and os.path.exists(persist_directory)
    ):
        shutil.rmtree(persist_directory)
def fastapi() -> Generator[System, None, None]:
    """Return the ``_fastapi_fixture`` generator for a non-persistent server."""
    server_fixture = _fastapi_fixture(is_persistent=False)
    return server_fixture
def fastapi_persistent() -> Generator[System, None, None]:
    """Return the ``_fastapi_fixture`` generator for a persistent server."""
    server_fixture = _fastapi_fixture(is_persistent=True)
    return server_fixture
def fastapi_ssl() -> Generator[System, None, None]:
    """Return the ``_fastapi_fixture`` generator for an SSL-enabled server.

    The certificate is generated eagerly, when this function is called, not
    when the returned generator is first advanced.  The PEM paths passed on
    are relative to the current working directory.
    """
    generate_self_signed_certificate()
    ssl_files = {
        "chroma_server_ssl_certfile": "./servercert.pem",
        "chroma_server_ssl_keyfile": "./serverkey.pem",
    }
    return _fastapi_fixture(is_persistent=False, **ssl_files)
def basic_http_client() -> Generator[System, None, None]:
    """Yield a FastAPI-client ``System`` aimed at a server on port 8000.

    No server is launched here; the function waits (via ``_await_server``)
    for one to answer heartbeats before starting the system.  The
    ``system.stop()`` after ``yield`` runs only if the generator is resumed.
    """
    client_settings = Settings(
        chroma_api_impl="chromadb.api.fastapi.FastAPI",
        chroma_server_http_port="8000",
        allow_reset=True,
    )
    system = System(client_settings)
    # The heartbeat wait happens before system.start(), as in the original.
    _await_server(system.instance(ServerAPI))
    system.start()
    yield system
    system.stop()
def fastapi_server_basic_auth() -> Generator[System, None, None]:
    """Yield a system for a server protected by htpasswd-file basic auth.

    A ``server.htpasswd`` file with a single ``admin`` entry is written into
    the current working directory, the server is pointed at it through the
    relative path ``./server.htpasswd``, and the client is configured with
    the credentials ``admin:admin``.  The file is removed after the delegated
    generator finishes, which requires this generator to be resumed.
    """
    htpasswd_path = os.path.abspath(os.path.join(".", "server.htpasswd"))
    with open(htpasswd_path, "w") as htpasswd:
        htpasswd.write(
            "admin:$2y$05$e5sRb6NCcSH3YfbIxe1AGu2h5K7OOd982OXKmd8WyQ3DRQ4MvpnZS\n"
        )
    yield from _fastapi_fixture(
        is_persistent=False,
        chroma_server_auth_provider="chromadb.auth.basic.BasicAuthServerProvider",
        chroma_server_auth_credentials_provider="chromadb.auth.providers.HtpasswdFileServerAuthCredentialsProvider",
        chroma_server_auth_credentials_file="./server.htpasswd",
        chroma_client_auth_provider="chromadb.auth.basic.BasicAuthClientProvider",
        chroma_client_auth_credentials="admin:admin",
    )
    os.remove(htpasswd_path)
def fastapi_server_basic_auth_param() -> Generator[System, None, None]:
    """Yield a system for a server protected by htpasswd-file basic auth.

    Writes ``server.htpasswd`` (one ``admin`` entry) into the current working
    directory, starts the server with the fully-qualified basic-auth provider
    class names and the relative credentials path ``./server.htpasswd``, and
    configures the client with ``admin:admin``.  The file is removed once the
    delegated generator finishes, which requires this generator to be resumed.
    """
    htpasswd_path = os.path.abspath(os.path.join(".", "server.htpasswd"))
    with open(htpasswd_path, "w") as htpasswd:
        htpasswd.write(
            "admin:$2y$05$e5sRb6NCcSH3YfbIxe1AGu2h5K7OOd982OXKmd8WyQ3DRQ4MvpnZS\n"
        )
    yield from _fastapi_fixture(
        is_persistent=False,
        chroma_server_auth_provider="chromadb.auth.basic.BasicAuthServerProvider",
        chroma_server_auth_credentials_provider="chromadb.auth.providers.HtpasswdFileServerAuthCredentialsProvider",
        chroma_server_auth_credentials_file="./server.htpasswd",
        chroma_client_auth_provider="chromadb.auth.basic.BasicAuthClientProvider",
        chroma_client_auth_credentials="admin:admin",
    )
    os.remove(htpasswd_path)
# TODO we need a generator for auth providers | |
def fastapi_server_basic_auth_file() -> Generator[System, None, None]:
    """Yield a system for a server whose basic-auth credentials live in a file.

    Writes ``server.htpasswd`` (one ``admin`` entry) into the current working
    directory, starts the server with the htpasswd-file credentials provider
    reading ``./server.htpasswd``, and configures the client with
    ``admin:admin``.  The file is removed once the delegated generator
    finishes, which requires this generator to be resumed.
    """
    htpasswd_path = os.path.abspath(os.path.join(".", "server.htpasswd"))
    with open(htpasswd_path, "w") as htpasswd:
        htpasswd.write(
            "admin:$2y$05$e5sRb6NCcSH3YfbIxe1AGu2h5K7OOd982OXKmd8WyQ3DRQ4MvpnZS\n"
        )
    yield from _fastapi_fixture(
        is_persistent=False,
        chroma_server_auth_provider="chromadb.auth.basic.BasicAuthServerProvider",
        chroma_server_auth_credentials_provider="chromadb.auth.providers.HtpasswdFileServerAuthCredentialsProvider",
        chroma_server_auth_credentials_file="./server.htpasswd",
        chroma_client_auth_provider="chromadb.auth.basic.BasicAuthClientProvider",
        chroma_client_auth_credentials="admin:admin",
    )
    os.remove(htpasswd_path)
def fastapi_server_basic_auth_shorthand() -> Generator[System, None, None]:
    """Yield a basic-auth system configured through short provider names.

    Same flow as the other basic-auth fixtures, except that the providers are
    given as ``"basic"`` and ``"htpasswd_file"`` instead of dotted class
    paths.  ``server.htpasswd`` is written into the current working directory
    beforehand and removed once the delegated generator finishes, which
    requires this generator to be resumed.
    """
    htpasswd_path = os.path.abspath(os.path.join(".", "server.htpasswd"))
    with open(htpasswd_path, "w") as htpasswd:
        htpasswd.write(
            "admin:$2y$05$e5sRb6NCcSH3YfbIxe1AGu2h5K7OOd982OXKmd8WyQ3DRQ4MvpnZS\n"
        )
    yield from _fastapi_fixture(
        is_persistent=False,
        chroma_server_auth_provider="basic",
        chroma_server_auth_credentials_provider="htpasswd_file",
        chroma_server_auth_credentials_file="./server.htpasswd",
        chroma_client_auth_provider="basic",
        chroma_client_auth_credentials="admin:admin",
    )
    os.remove(htpasswd_path)
def fastapi_server_basic_auth_invalid_cred() -> Generator[System, None, None]:
    """Yield a basic-auth system whose client sends ``admin:admin1``.

    The server side is set up like the other basic-auth fixtures (htpasswd
    file with one ``admin`` entry); only the client credentials differ.
    NOTE(review): presumably ``admin1`` does not match the stored hash, so
    requests are expected to be rejected — confirm against the tests.
    """
    htpasswd_path = os.path.abspath(os.path.join(".", "server.htpasswd"))
    with open(htpasswd_path, "w") as htpasswd:
        htpasswd.write(
            "admin:$2y$05$e5sRb6NCcSH3YfbIxe1AGu2h5K7OOd982OXKmd8WyQ3DRQ4MvpnZS\n"
        )
    yield from _fastapi_fixture(
        is_persistent=False,
        chroma_server_auth_provider="chromadb.auth.basic.BasicAuthServerProvider",
        chroma_server_auth_credentials_provider="chromadb.auth.providers.HtpasswdFileServerAuthCredentialsProvider",
        chroma_server_auth_credentials_file="./server.htpasswd",
        chroma_client_auth_provider="chromadb.auth.basic.BasicAuthClientProvider",
        chroma_client_auth_credentials="admin:admin1",
    )
    os.remove(htpasswd_path)
def integration() -> Generator[System, None, None]:
    """Fixture generator for returning a client configured via environment
    variables, intended for externally configured integration tests.

    Only ``allow_reset=True`` is set explicitly; every other setting is left
    to ``Settings`` itself.
    """
    system = System(Settings(allow_reset=True))
    system.start()
    yield system
    system.stop()
def sqlite() -> Generator[System, None, None]:
    """Fixture generator for the segment-based API on non-persistent Sqlite."""
    sqlite_impl = "chromadb.db.impl.sqlite.SqliteDB"
    system = System(
        Settings(
            chroma_api_impl="chromadb.api.segment.SegmentAPI",
            chroma_sysdb_impl=sqlite_impl,
            chroma_producer_impl=sqlite_impl,
            chroma_consumer_impl=sqlite_impl,
            chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
            is_persistent=False,
            allow_reset=True,
        )
    )
    system.start()
    yield system
    system.stop()
def sqlite_persistent() -> Generator[System, None, None]:
    """Fixture generator for the segment-based API on persistent Sqlite.

    Data is persisted into a fresh temporary directory.  After the ``yield``
    the system is stopped and the directory removed; that code runs only if
    the generator is resumed.
    """
    sqlite_impl = "chromadb.db.impl.sqlite.SqliteDB"
    data_dir = tempfile.mkdtemp()
    system = System(
        Settings(
            chroma_api_impl="chromadb.api.segment.SegmentAPI",
            chroma_sysdb_impl=sqlite_impl,
            chroma_producer_impl=sqlite_impl,
            chroma_consumer_impl=sqlite_impl,
            chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
            allow_reset=True,
            is_persistent=True,
            persist_directory=data_dir,
        )
    )
    system.start()
    yield system
    system.stop()
    if os.path.exists(data_dir):
        shutil.rmtree(data_dir)
def system_fixtures() -> List[Callable[[], Generator[System, None, None]]]:
    """Select the system fixture factories according to environment variables.

    Precedence, based on mere presence of the variable (its value is ignored):
    ``CHROMA_CLUSTER_TEST_ONLY`` → only ``basic_http_client``;
    ``CHROMA_INTEGRATION_TEST_ONLY`` → only ``integration``;
    otherwise the four local fixtures, plus ``integration`` when
    ``CHROMA_INTEGRATION_TEST`` is present.
    """
    env = os.environ
    if "CHROMA_CLUSTER_TEST_ONLY" in env:
        return [basic_http_client]
    if "CHROMA_INTEGRATION_TEST_ONLY" in env:
        return [integration]

    selected = [fastapi, fastapi_persistent, sqlite, sqlite_persistent]
    if "CHROMA_INTEGRATION_TEST" in env:
        selected.append(integration)
    return selected
def system_fixtures_auth() -> List[Callable[[], Generator[System, None, None]]]:
    """Return the fixture factories for servers with working basic auth."""
    return [
        fastapi_server_basic_auth_param,
        fastapi_server_basic_auth_file,
        fastapi_server_basic_auth_shorthand,
    ]
def system_fixtures_wrong_auth() -> List[Callable[[], Generator[System, None, None]]]:
    """Return the fixture factories whose client uses the ``admin:admin1`` credentials."""
    return [fastapi_server_basic_auth_invalid_cred]
def system_fixtures_ssl() -> List[Callable[[], Generator[System, None, None]]]:
    """Return the fixture factories for SSL-enabled servers."""
    return [fastapi_ssl]
def system_wrong_auth(
    request: pytest.FixtureRequest,
) -> Generator[ServerAPI, None, None]:
    """Yield the system produced by the fixture generator in ``request.param``.

    ``request.param`` is called without arguments and must return a
    generator.  Delegating with ``yield from`` (instead of taking a single
    value with ``next``) keeps that generator alive and resumes it when this
    generator is resumed, so the statements after its ``yield`` — its
    teardown — are executed rather than silently skipped.

    NOTE(review): the return annotation says ``ServerAPI`` but the consumer
    in this file declares the value as ``System`` — confirm and tighten.
    """
    yield from request.param()
def system(request: pytest.FixtureRequest) -> Generator[ServerAPI, None, None]:
    """Yield the system produced by the fixture generator in ``request.param``.

    ``request.param`` is called without arguments and must return a
    generator.  Delegating with ``yield from`` (instead of taking a single
    value with ``next``) keeps that generator alive and resumes it when this
    generator is resumed, so the statements after its ``yield`` — its
    teardown — are executed rather than silently skipped.

    NOTE(review): the return annotation says ``ServerAPI`` but the consumers
    in this file declare the value as ``System`` — confirm and tighten.
    """
    yield from request.param()
def system_ssl(request: pytest.FixtureRequest) -> Generator[ServerAPI, None, None]:
    """Yield the system produced by the fixture generator in ``request.param``.

    ``request.param`` is called without arguments and must return a
    generator.  Delegating with ``yield from`` (instead of taking a single
    value with ``next``) keeps that generator alive and resumes it when this
    generator is resumed, so the statements after its ``yield`` — its
    teardown — are executed rather than silently skipped.

    NOTE(review): the return annotation says ``ServerAPI`` but the consumer
    in this file declares the value as ``System`` — confirm and tighten.
    """
    yield from request.param()
def system_auth(request: pytest.FixtureRequest) -> Generator[ServerAPI, None, None]:
    """Yield the system produced by the fixture generator in ``request.param``.

    ``request.param`` is called without arguments and must return a
    generator.  Delegating with ``yield from`` (instead of taking a single
    value with ``next``) keeps that generator alive and resumes it when this
    generator is resumed, so the statements after its ``yield`` — its
    teardown — are executed rather than silently skipped.

    NOTE(review): the return annotation says ``ServerAPI`` but the consumer
    in this file declares the value as ``System`` — confirm and tighten.
    """
    yield from request.param()
def api(system: System) -> Generator[ServerAPI, None, None]:
    """Reset the system's state, then yield its ``ServerAPI`` instance."""
    system.reset_state()
    yield system.instance(ServerAPI)
def client(system: System) -> Generator[ClientAPI, None, None]:
    """Reset the system's state and yield a client built from it.

    ``clear_system_cache()`` is called on the client once the generator is
    resumed after the yield.
    """
    system.reset_state()
    chroma_client = ClientCreator.from_system(system)
    yield chroma_client
    chroma_client.clear_system_cache()
def client_ssl(system_ssl: System) -> Generator[ClientAPI, None, None]:
    """Reset the SSL system's state and yield a client built from it.

    ``clear_system_cache()`` is called on the client once the generator is
    resumed after the yield.
    """
    system_ssl.reset_state()
    chroma_client = ClientCreator.from_system(system_ssl)
    yield chroma_client
    chroma_client.clear_system_cache()
def api_wrong_cred(
    system_wrong_auth: System,
) -> Generator[ServerAPI, None, None]:
    """Reset the wrong-credentials system, then yield its ``ServerAPI`` instance."""
    system_wrong_auth.reset_state()
    yield system_wrong_auth.instance(ServerAPI)
def api_with_server_auth(system_auth: System) -> Generator[ServerAPI, None, None]:
    """Reset the authenticated system, then yield its ``ServerAPI`` instance."""
    system_auth.reset_state()
    yield system_auth.instance(ServerAPI)
# Producer / Consumer fixtures # | |
class ProducerFn(Protocol):
    """Call signature shared by the ``produce_n_*`` helpers in this module."""

    def __call__(
        self,
        producer: Producer,
        topic: str,
        embeddings: Iterator[SubmitEmbeddingRecord],
        n: int,
    ) -> Tuple[Sequence[SubmitEmbeddingRecord], Sequence[SeqId]]:
        # Implementations take ``n`` records from ``embeddings``, submit them
        # to ``topic`` through ``producer``, and return the submitted records
        # together with the sequence ids the producer handed back.
        ...
def produce_n_single(
    producer: Producer,
    topic: str,
    embeddings: Iterator[SubmitEmbeddingRecord],
    n: int,
) -> Tuple[Sequence[SubmitEmbeddingRecord], Sequence[SeqId]]:
    """Submit ``n`` records one at a time via ``producer.submit_embedding``.

    Each record is pulled from ``embeddings`` with ``next`` and submitted
    before the following one is pulled.  Returns the submitted records and
    the sequence ids returned by the producer, in submission order.
    """
    records = []
    ids = []
    remaining = n
    while remaining > 0:
        record = next(embeddings)
        ids.append(producer.submit_embedding(topic, record))
        records.append(record)
        remaining -= 1
    return records, ids
def produce_n_batch(
    producer: Producer,
    topic: str,
    embeddings: Iterator[SubmitEmbeddingRecord],
    n: int,
) -> Tuple[Sequence[SubmitEmbeddingRecord], Sequence[SeqId]]:
    """Submit ``n`` records in one ``producer.submit_embeddings`` call.

    All ``n`` records are pulled from ``embeddings`` first, then submitted as
    a single batch (the call is made even when ``n`` is 0).  Returns the
    batch and whatever sequence ids the producer returned for it.
    """
    batch = [next(embeddings) for _ in range(n)]
    batch_ids: Sequence[SeqId] = producer.submit_embeddings(topic, batch)
    return batch, batch_ids
def produce_fn_fixtures() -> List[ProducerFn]:
    """Return the producer helpers: single-record first, then batch."""
    producer_fns: List[ProducerFn] = [produce_n_single, produce_n_batch]
    return producer_fns
def produce_fns(
    request: pytest.FixtureRequest,
) -> Generator[ProducerFn, None, None]:
    """Yield ``request.param`` unchanged (the producer helper under test)."""
    producer_fn = request.param
    yield producer_fn
def pytest_configure(config):  # type: ignore
    # Flags the embeddings_queue module as running under test; ``config`` is
    # not used.
    # NOTE(review): presumably invoked by pytest as the ``pytest_configure``
    # hook — confirm this file is loaded as a conftest/plugin.
    embeddings_queue._called_from_test = True