chroma / chromadb /test /client /test_cloud_client.py
badalsahani's picture
feat: chroma initial deploy
287a0bc
raw
history blame contribute delete
3.29 kB
import multiprocessing
from typing import Any, Dict, Generator, Optional, Tuple
import pytest
from chromadb import CloudClient
from chromadb.api import ServerAPI
from chromadb.auth.token import TokenTransportHeader
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
from chromadb.errors import AuthorizationError
from chromadb.test.conftest import _await_server, _run_server, find_free_port
# Name of the X_CHROMA_TOKEN enum member; handed to both the server process
# arguments and the client Settings in this module so the two sides use the
# same token transport header setting.
TOKEN_TRANSPORT_HEADER = TokenTransportHeader.X_CHROMA_TOKEN.name
# Host used by the client Settings and by every CloudClient built in these
# tests to reach the locally spawned server.
TEST_CLOUD_HOST = "localhost"
@pytest.fixture(scope="module")
def valid_token() -> str:
    """Return the token string shared by the tests in this module."""
    token: str = "valid_token"
    return token
@pytest.fixture(scope="module")
def mock_cloud_server(valid_token: str) -> Generator[System, None, None]:
    """Start a server subprocess with token auth and yield a client ``System``.

    A server is launched through ``_run_server`` in a spawned process on a
    free port, configured with the token auth provider classes and
    ``valid_token`` as its credentials. A client-side ``System`` configured
    with the same token is then started, the fixture waits for the server
    with ``_await_server``, and the system is yielded to the tests.

    Cleanup runs in ``finally`` blocks so that the subprocess is killed (and
    the system stopped, once started) even when startup or the readiness
    wait fails; otherwise a failed setup would leave a stray server process
    behind for the rest of the test session.

    Args:
        valid_token: Token the server is configured to accept and the client
            system is configured to send.

    Yields:
        The started client ``System`` whose settings carry the server port.
    """
    chroma_server_auth_provider: str = "chromadb.auth.token.TokenAuthServerProvider"
    chroma_server_auth_credentials_provider: str = (
        "chromadb.auth.token.TokenConfigServerAuthCredentialsProvider"
    )
    chroma_server_auth_credentials: str = valid_token
    chroma_server_auth_token_transport_header: str = TOKEN_TRANSPORT_HEADER

    port = find_free_port()
    # NOTE(review): the positional order must match _run_server's signature in
    # conftest, which is not visible from this file -- confirm before
    # reordering or adding entries.
    args: Tuple[
        int,
        bool,
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[Dict[str, Any]],
    ] = (
        port,
        False,
        None,
        chroma_server_auth_provider,
        chroma_server_auth_credentials_provider,
        None,
        chroma_server_auth_credentials,
        chroma_server_auth_token_transport_header,
        None,
        None,
        None,
    )

    ctx = multiprocessing.get_context("spawn")
    proc = ctx.Process(target=_run_server, args=args, daemon=True)
    proc.start()
    try:
        settings = Settings(
            chroma_api_impl="chromadb.api.fastapi.FastAPI",
            chroma_server_host=TEST_CLOUD_HOST,
            chroma_server_http_port=str(port),
            chroma_client_auth_provider="chromadb.auth.token.TokenAuthClientProvider",
            chroma_client_auth_credentials=valid_token,
            chroma_client_auth_token_transport_header=TOKEN_TRANSPORT_HEADER,
        )
        system = System(settings)
        api = system.instance(ServerAPI)
        system.start()
        try:
            _await_server(api)
            yield system
        finally:
            # Stop the client system even if the server never became ready.
            system.stop()
    finally:
        # Always terminate the server process, then reap it so it does not
        # linger as a zombie; the timeout keeps teardown from blocking.
        proc.kill()
        proc.join(timeout=5)
def test_valid_key(mock_cloud_server: System, valid_token: str) -> None:
    """A CloudClient created with the valid token returns a truthy heartbeat."""
    # Reuse the port recorded in the fixture's client settings.
    server_port = mock_cloud_server.settings.chroma_server_http_port
    valid_client = CloudClient(
        tenant=DEFAULT_TENANT,
        database=DEFAULT_DATABASE,
        api_key=valid_token,
        cloud_host=TEST_CLOUD_HOST,
        cloud_port=server_port,  # type: ignore
        enable_ssl=False,
    )
    heartbeat = valid_client.heartbeat()
    assert heartbeat
def test_invalid_key(mock_cloud_server: System, valid_token: str) -> None:
    """Using a token other than the valid one raises AuthorizationError."""
    # Derive a token that differs from the valid one, then try to reach the
    # default tenant and database with it.
    invalid_token = f"{valid_token}_invalid"
    server_port = mock_cloud_server.settings.chroma_server_http_port

    # Both construction and the heartbeat call sit inside the context manager,
    # so the error may surface at either step.
    with pytest.raises(AuthorizationError):
        client = CloudClient(
            tenant=DEFAULT_TENANT,
            database=DEFAULT_DATABASE,
            api_key=invalid_token,
            cloud_host=TEST_CLOUD_HOST,
            cloud_port=server_port,  # type: ignore
            enable_ssl=False,
        )
        client.heartbeat()