Spaces:
Running
Running
File size: 3,290 Bytes
287a0bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import multiprocessing
from typing import Any, Dict, Generator, Optional, Tuple
import pytest
from chromadb import CloudClient
from chromadb.api import ServerAPI
from chromadb.auth.token import TokenTransportHeader
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
from chromadb.errors import AuthorizationError
from chromadb.test.conftest import _await_server, _run_server, find_free_port
# Host used for both the spawned test server's client Settings and every CloudClient.
TEST_CLOUD_HOST = "localhost"
# Name of the X_CHROMA_TOKEN transport-header enum member; used for both the
# server's and the client's token-transport-header setting so they agree.
TOKEN_TRANSPORT_HEADER = TokenTransportHeader.X_CHROMA_TOKEN.name
@pytest.fixture(scope="module")
def valid_token() -> str:
    """Provide the single token shared by the mock server and valid clients.

    Module-scoped, so every test in this file sees the same value.
    """
    token: str = "valid_token"
    return token
@pytest.fixture(scope="module")
def mock_cloud_server(valid_token: str) -> Generator[System, None, None]:
    """Run a token-authenticated Chroma server in a subprocess for this module.

    The server is started with ``_run_server`` in a freshly spawned process on
    a free local port. It is configured with ``TokenAuthServerProvider``, the
    ``TokenConfigServerAuthCredentialsProvider`` credentials provider,
    ``valid_token`` as its credentials, and ``TOKEN_TRANSPORT_HEADER`` as the
    token transport header. A client-side ``System`` using the same token and
    header is then started and handed to ``_await_server``.

    Yields:
        The client ``System``. Tests read the server's port from
        ``system.settings.chroma_server_http_port``.

    Teardown (stopping the client system, killing and reaping the server
    process) runs even when setup fails part-way. Cleanup code placed after
    a bare ``yield`` would be skipped by pytest if setup raised, leaking the
    server process.
    """
    chroma_server_auth_provider: str = "chromadb.auth.token.TokenAuthServerProvider"
    chroma_server_auth_credentials_provider: str = (
        "chromadb.auth.token.TokenConfigServerAuthCredentialsProvider"
    )
    chroma_server_auth_credentials: str = valid_token
    chroma_server_auth_token_transport_header: str = TOKEN_TRANSPORT_HEADER

    port = find_free_port()
    # Positional arguments for _run_server; order must match its parameter
    # list (defined elsewhere) -- NOTE(review): confirm when its signature changes.
    args: Tuple[
        int,
        bool,
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[Dict[str, Any]],
    ] = (
        port,
        False,
        None,
        chroma_server_auth_provider,
        chroma_server_auth_credentials_provider,
        None,
        chroma_server_auth_credentials,
        chroma_server_auth_token_transport_header,
        None,
        None,
        None,
    )
    # "spawn" starts the server in a fresh interpreter rather than forking
    # this process.
    ctx = multiprocessing.get_context("spawn")
    proc = ctx.Process(target=_run_server, args=args, daemon=True)
    proc.start()
    try:
        settings = Settings(
            chroma_api_impl="chromadb.api.fastapi.FastAPI",
            chroma_server_host=TEST_CLOUD_HOST,
            chroma_server_http_port=str(port),
            chroma_client_auth_provider="chromadb.auth.token.TokenAuthClientProvider",
            chroma_client_auth_credentials=valid_token,
            chroma_client_auth_token_transport_header=TOKEN_TRANSPORT_HEADER,
        )
        system = System(settings)
        api = system.instance(ServerAPI)
        system.start()
        try:
            # Presumably blocks until the server answers -- see _await_server.
            _await_server(api)
            yield system
        finally:
            system.stop()
    finally:
        # Always tear the server down, then join so the killed child is reaped.
        proc.kill()
        proc.join()
def test_valid_key(mock_cloud_server: System, valid_token: str) -> None:
    """A CloudClient presenting the server's configured token gets a heartbeat."""
    server_port = mock_cloud_server.settings.chroma_server_http_port
    client = CloudClient(
        api_key=valid_token,
        tenant=DEFAULT_TENANT,
        database=DEFAULT_DATABASE,
        cloud_host=TEST_CLOUD_HOST,
        cloud_port=server_port,  # type: ignore
        enable_ssl=False,
    )
    assert client.heartbeat()
def test_invalid_key(mock_cloud_server: System, valid_token: str) -> None:
    """A CloudClient presenting a wrong token is rejected.

    Either constructing the client or its first heartbeat must raise
    ``AuthorizationError``.
    """
    server_port = mock_cloud_server.settings.chroma_server_http_port
    # Appending a suffix guarantees the key differs from the valid token.
    bad_token = f"{valid_token}_invalid"
    with pytest.raises(AuthorizationError):
        # Connect to the default tenant and database using the wrong key.
        client = CloudClient(
            api_key=bad_token,
            tenant=DEFAULT_TENANT,
            database=DEFAULT_DATABASE,
            cloud_host=TEST_CLOUD_HOST,
            cloud_port=server_port,  # type: ignore
            enable_ssl=False,
        )
        client.heartbeat()
|