File size: 5,078 Bytes
287a0bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import string
from typing import Dict, Any

import hypothesis.strategies as st
import pytest
from hypothesis import given, settings

from chromadb.api import ServerAPI
from chromadb.config import System
from chromadb.test.conftest import _fastapi_fixture


@st.composite
def token_config(draw: st.DrawFn) -> Dict[str, Any]:
    """Draw one token-auth configuration shared by server and client.

    The same token is used as both the server and the client credential, and
    the same transport header (possibly ``None``) is used on both sides, so the
    drawn configuration is a *matching* one.
    ``test_fastapi_server_token_auth`` relies on this by expecting the client
    to be accepted.

    Returns:
        A dict with the keys ``token_transport_header``,
        ``chroma_server_auth_credentials``, ``chroma_client_auth_credentials``,
        ``chroma_server_auth_provider``, ``chroma_client_auth_provider``,
        ``chroma_server_auth_credentials_provider`` and ``is_persistent``.
    """
    # Draw order is kept stable so generated examples stay comparable.
    header = draw(st.sampled_from(["AUTHORIZATION", "X_CHROMA_TOKEN", None]))
    server_auth = draw(
        st.sampled_from(["token", "chromadb.auth.token.TokenAuthServerProvider"])
    )
    client_auth = draw(
        st.sampled_from(["token", "chromadb.auth.token.TokenAuthClientProvider"])
    )
    creds_source = draw(
        st.sampled_from(
            ["chromadb.auth.token.TokenConfigServerAuthCredentialsProvider"]
        )
    )
    # Valid tokens use only ASCII digits, letters and punctuation.
    allowed_chars = string.digits + string.ascii_letters + string.punctuation
    secret = draw(st.text(alphabet=allowed_chars, min_size=1, max_size=50))
    persistent = draw(st.booleans())

    config: Dict[str, Any] = dict(
        token_transport_header=header,
        chroma_server_auth_credentials=secret,
        chroma_client_auth_credentials=secret,
        chroma_server_auth_provider=server_auth,
        chroma_client_auth_provider=client_auth,
        chroma_server_auth_credentials_provider=creds_source,
        is_persistent=persistent,
    )
    return config


@settings(max_examples=10)
@given(token_config())
def test_fastapi_server_token_auth(token_config: Dict[str, Any]) -> None:
    """A client whose token matches the server's can use the API.

    Starts a server/client pair through ``_fastapi_fixture`` with the drawn
    configuration (identical token and transport header on both sides). It
    then checks that ``heartbeat`` succeeds and that, after ``reset_state``,
    ``list_collections`` returns an empty list.
    """
    api = _fastapi_fixture(
        is_persistent=token_config["is_persistent"],
        chroma_server_auth_provider=token_config["chroma_server_auth_provider"],
        chroma_server_auth_credentials_provider=token_config[
            "chroma_server_auth_credentials_provider"
        ],
        chroma_server_auth_credentials=token_config["chroma_server_auth_credentials"],
        chroma_client_auth_provider=token_config["chroma_client_auth_provider"],
        chroma_client_auth_token_transport_header=token_config[
            "token_transport_header"
        ],
        chroma_server_auth_token_transport_header=token_config[
            "token_transport_header"
        ],
        chroma_client_auth_credentials=token_config["chroma_client_auth_credentials"],
    )
    try:
        _sys: System = next(api)
        _sys.reset_state()
        _api = _sys.instance(ServerAPI)
        _api.heartbeat()
        assert _api.list_collections() == []
    finally:
        # ``_fastapi_fixture`` is a generator (it is advanced with ``next``),
        # presumably following pytest's yield-fixture protocol. Resuming it
        # past its ``yield`` runs whatever teardown it defines. Otherwise that
        # teardown would be skipped for every Hypothesis example. The ``None``
        # default swallows StopIteration when the generator has already
        # finished, e.g. because setup raised.
        next(api, None)


@st.composite
def random_token(draw: st.DrawFn) -> str:
    """Draw a short, non-empty alphanumeric string (1 to 5 characters).

    The characters come from ASCII letters and digits only, so appending the
    result to a valid token yields a token that is still well-formed but
    different.
    """
    alnum = string.ascii_letters + string.digits
    suffix: str = draw(st.text(alphabet=alnum, min_size=1, max_size=5))
    return suffix


@st.composite
def invalid_token(draw: st.DrawFn) -> str:
    """Draw a token made only of characters outside the valid token alphabet.

    The valid alphabet is ASCII digits, letters and punctuation.
    ``string.printable`` is documented as those three sets plus
    ``string.whitespace``, so the remaining characters are exactly ASCII
    whitespace: space, tab, newline, carriage return, vertical tab and form
    feed.

    Returns:
        A non-empty string of 1 to 50 such characters.
    """
    valid_chars = set(string.digits + string.ascii_letters + string.punctuation)
    # Sort the set difference: iteration order of a set of str depends on
    # PYTHONHASHSEED. Depending on the Hypothesis version, alphabet order can
    # influence which strings are drawn and how saved examples replay, so a
    # fixed order keeps runs reproducible.
    opposite_alphabet = sorted(set(string.printable) - valid_chars)
    return draw(st.text(alphabet=opposite_alphabet, min_size=1, max_size=50))


@settings(max_examples=10)
@given(tconf=token_config(), inval_tok=invalid_token())
def test_invalid_token(tconf: Dict[str, Any], inval_tok: str) -> None:
    """A client configured with a malformed token fails during setup.

    The server is configured normally from ``tconf``, but the client
    credential is replaced by a token drawn from ``invalid_token``. Some step
    between starting the fixture and obtaining the ``ServerAPI`` instance must
    raise, and the rendered exception must mention "Invalid token".
    """
    api = _fastapi_fixture(
        is_persistent=tconf["is_persistent"],
        chroma_server_auth_provider=tconf["chroma_server_auth_provider"],
        chroma_server_auth_credentials_provider=tconf[
            "chroma_server_auth_credentials_provider"
        ],
        chroma_server_auth_credentials=tconf["chroma_server_auth_credentials"],
        chroma_server_auth_token_transport_header=tconf["token_transport_header"],
        chroma_client_auth_provider=tconf["chroma_client_auth_provider"],
        chroma_client_auth_token_transport_header=tconf["token_transport_header"],
        chroma_client_auth_credentials=inval_tok,
    )
    try:
        with pytest.raises(Exception) as e:
            _sys: System = next(api)
            _sys.reset_state()
            _sys.instance(ServerAPI)
        # ``exconly()`` renders "<type>: <message>" in full. ``str(e)`` on
        # the ExceptionInfo wrapper only yields pytest's repr of the wrapper,
        # which may be truncated for long messages.
        assert "Invalid token" in e.exconly()
    finally:
        # If setup raised inside the generator it is already exhausted and
        # this is a no-op. If nothing raised, resuming it lets any teardown
        # after its ``yield`` run instead of leaving the setup alive.
        next(api, None)


@settings(max_examples=10)
@given(token_config(), random_token())
def test_fastapi_server_token_auth_wrong_token(
    token_config: Dict[str, Any], random_token: str
) -> None:
    """A client whose token differs from the server's is rejected.

    The client credential is the server's token plus a non-empty alphanumeric
    suffix, so it is well-formed but never equal to the server token.
    ``heartbeat`` is expected to succeed, since it is called outside
    ``pytest.raises``. ``list_collections`` must raise an exception whose
    rendering mentions "Unauthorized".
    """
    api = _fastapi_fixture(
        is_persistent=token_config["is_persistent"],
        chroma_server_auth_provider=token_config["chroma_server_auth_provider"],
        chroma_server_auth_credentials_provider=token_config[
            "chroma_server_auth_credentials_provider"
        ],
        chroma_server_auth_credentials=token_config["chroma_server_auth_credentials"],
        chroma_server_auth_token_transport_header=token_config[
            "token_transport_header"
        ],
        chroma_client_auth_provider=token_config["chroma_client_auth_provider"],
        chroma_client_auth_token_transport_header=token_config[
            "token_transport_header"
        ],
        chroma_client_auth_credentials=token_config["chroma_client_auth_credentials"]
        + random_token,
    )
    try:
        _sys: System = next(api)
        _sys.reset_state()
        _api = _sys.instance(ServerAPI)
        _api.heartbeat()
        with pytest.raises(Exception) as e:
            _api.list_collections()
        # Check the full "<type>: <message>" text rather than str() of the
        # ExceptionInfo wrapper, whose repr may be truncated by pytest.
        assert "Unauthorized" in e.exconly()
    finally:
        # Resume the ``_fastapi_fixture`` generator past its ``yield`` so any
        # teardown it defines runs for every Hypothesis example. The default
        # avoids StopIteration if it has already finished.
        next(api, None)