File size: 3,290 Bytes
287a0bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import multiprocessing
from typing import Any, Dict, Generator, Optional, Tuple
import pytest
from chromadb import CloudClient
from chromadb.api import ServerAPI
from chromadb.auth.token import TokenTransportHeader
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
from chromadb.errors import AuthorizationError

from chromadb.test.conftest import _await_server, _run_server, find_free_port

# Host name the test clients use to reach the mock "cloud" server.
TEST_CLOUD_HOST: str = "localhost"
# Enum member *name* (not its value) selecting which header carries the auth
# token; the same value is handed to both the server and client settings.
TOKEN_TRANSPORT_HEADER: str = TokenTransportHeader.X_CHROMA_TOKEN.name


@pytest.fixture(scope="module")
def valid_token() -> str:
    """Provide the token shared by every test in this module.

    It is used both as the mock server's configured auth credential and as
    the API key of clients that are expected to authenticate successfully.
    """
    token = "valid_token"
    return token


@pytest.fixture(scope="module")
def mock_cloud_server(valid_token: str) -> Generator[System, None, None]:
    """Run a token-authenticated Chroma server in a subprocess for this module.

    The server is started through ``_run_server`` in a ``spawn``-context child
    process on a port from ``find_free_port()``, configured with
    ``valid_token`` as its auth credential and ``TOKEN_TRANSPORT_HEADER`` as
    the token transport header. A client-side ``System`` with matching
    settings is started, waited on via ``_await_server``, and yielded; its
    ``settings.chroma_server_http_port`` holds the server's port.

    Teardown stops the client ``System`` (once it has been started) and always
    kills and reaps the server process, even when setup fails part-way.
    """
    chroma_server_auth_provider: str = "chromadb.auth.token.TokenAuthServerProvider"
    chroma_server_auth_credentials_provider: str = (
        "chromadb.auth.token.TokenConfigServerAuthCredentialsProvider"
    )
    chroma_server_auth_credentials: str = valid_token
    chroma_server_auth_token_transport_header: str = TOKEN_TRANSPORT_HEADER

    port = find_free_port()

    # Positional arguments for _run_server. Their order must match that
    # function's signature in chromadb.test.conftest (not visible here).
    args: Tuple[
        int,
        bool,
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[Dict[str, Any]],
    ] = (
        port,
        False,
        None,
        chroma_server_auth_provider,
        chroma_server_auth_credentials_provider,
        None,
        chroma_server_auth_credentials,
        chroma_server_auth_token_transport_header,
        None,
        None,
        None,
    )
    ctx = multiprocessing.get_context("spawn")
    proc = ctx.Process(target=_run_server, args=args, daemon=True)
    proc.start()
    try:
        settings = Settings(
            chroma_api_impl="chromadb.api.fastapi.FastAPI",
            chroma_server_host=TEST_CLOUD_HOST,
            chroma_server_http_port=str(port),
            chroma_client_auth_provider="chromadb.auth.token.TokenAuthClientProvider",
            chroma_client_auth_credentials=valid_token,
            chroma_client_auth_token_transport_header=TOKEN_TRANSPORT_HEADER,
        )

        system = System(settings)
        api = system.instance(ServerAPI)
        system.start()
        try:
            # _await_server presumably polls until the server responds; it is
            # defined in chromadb.test.conftest.
            _await_server(api)
            yield system
        finally:
            system.stop()
    finally:
        # Kill the server even if client setup or the wait failed. join()
        # reaps the killed child so it does not linger for the rest of the
        # session.
        proc.kill()
        proc.join()


def test_valid_key(mock_cloud_server: System, valid_token: str) -> None:
    """A CloudClient using the server's configured token gets a heartbeat."""
    server_port = mock_cloud_server.settings.chroma_server_http_port
    client = CloudClient(
        tenant=DEFAULT_TENANT,
        database=DEFAULT_DATABASE,
        api_key=valid_token,
        cloud_host=TEST_CLOUD_HOST,
        cloud_port=server_port,  # type: ignore
        enable_ssl=False,
    )

    heartbeat = client.heartbeat()
    assert heartbeat


def test_invalid_key(mock_cloud_server: System, valid_token: str) -> None:
    """A CloudClient using the wrong token is rejected with AuthorizationError.

    The error may surface either while the client is constructed (connecting
    to the default tenant and database) or on the first heartbeat. Both calls
    sit inside the ``pytest.raises`` block.
    """
    # Any suffix makes the token differ from the server's configured credential.
    bad_token = valid_token + "_invalid"
    with pytest.raises(AuthorizationError):
        server_port = mock_cloud_server.settings.chroma_server_http_port
        rejected_client = CloudClient(
            tenant=DEFAULT_TENANT,
            database=DEFAULT_DATABASE,
            api_key=bad_token,
            cloud_host=TEST_CLOUD_HOST,
            cloud_port=server_port,  # type: ignore
            enable_ssl=False,
        )
        rejected_client.heartbeat()