Spaces:
Sleeping
Sleeping
import string | |
from typing import Dict, Any | |
import hypothesis.strategies as st | |
import pytest | |
from hypothesis import given, settings | |
from chromadb.api import ServerAPI | |
from chromadb.config import System | |
from chromadb.test.conftest import _fastapi_fixture | |
@st.composite
def token_config(draw: st.DrawFn) -> Dict[str, Any]:
    """Draw a matching server/client token-auth configuration.

    The server and the client are given the same drawn token, so a client
    built from this config is expected to authenticate successfully.

    Returns:
        A dict with the keys ``token_transport_header``,
        ``chroma_server_auth_credentials``, ``chroma_client_auth_credentials``,
        ``chroma_server_auth_provider``, ``chroma_client_auth_provider``,
        ``chroma_server_auth_credentials_provider`` and ``is_persistent``.
    """
    # None leaves the transport header unconfigured on both sides.
    token_header = draw(st.sampled_from(["AUTHORIZATION", "X_CHROMA_TOKEN", None]))
    # "token" is presumably a short alias for the fully qualified provider
    # class path drawn alongside it; both spellings are exercised.
    server_provider = draw(
        st.sampled_from(["token", "chromadb.auth.token.TokenAuthServerProvider"])
    )
    client_provider = draw(
        st.sampled_from(["token", "chromadb.auth.token.TokenAuthClientProvider"])
    )
    server_credentials_provider = draw(
        st.sampled_from(
            ["chromadb.auth.token.TokenConfigServerAuthCredentialsProvider"]
        )
    )
    # Valid tokens: 1-50 chars of ASCII digits, letters and punctuation.
    token = draw(
        st.text(
            alphabet=string.digits + string.ascii_letters + string.punctuation,
            min_size=1,
            max_size=50,
        )
    )
    persistence = draw(st.booleans())
    return {
        "token_transport_header": token_header,
        "chroma_server_auth_credentials": token,
        "chroma_client_auth_credentials": token,
        "chroma_server_auth_provider": server_provider,
        "chroma_client_auth_provider": client_provider,
        "chroma_server_auth_credentials_provider": server_credentials_provider,
        "is_persistent": persistence,
    }


def _token_auth_fixture(config: Dict[str, Any], client_credentials: str) -> Any:
    """Build the FastAPI server fixture from a ``token_config`` dict.

    Every server-side setting and the client provider/header come from
    ``config``; only the client's credentials are supplied separately, so
    tests can send a wrong or malformed token.

    Returns:
        The iterator returned by ``_fastapi_fixture``; ``next()`` on it
        yields the ``System``.
    """
    return _fastapi_fixture(
        is_persistent=config["is_persistent"],
        chroma_server_auth_provider=config["chroma_server_auth_provider"],
        chroma_server_auth_credentials_provider=config[
            "chroma_server_auth_credentials_provider"
        ],
        chroma_server_auth_credentials=config["chroma_server_auth_credentials"],
        chroma_server_auth_token_transport_header=config["token_transport_header"],
        chroma_client_auth_provider=config["chroma_client_auth_provider"],
        chroma_client_auth_token_transport_header=config["token_transport_header"],
        chroma_client_auth_credentials=client_credentials,
    )


# Each example boots a server through the fixture, so keep the count small.
@settings(max_examples=10)
@given(token_config())
def test_fastapi_server_token_auth(token_config: Dict[str, Any]) -> None:
    """A client holding the server's token can reach authenticated endpoints."""
    api = _token_auth_fixture(
        token_config, token_config["chroma_client_auth_credentials"]
    )
    _sys: System = next(api)
    try:
        _sys.reset_state()
        _api = _sys.instance(ServerAPI)
        _api.heartbeat()
        assert _api.list_collections() == []
    finally:
        # Resume the fixture past its ``yield`` so any teardown after it
        # runs. This assumes a pytest-style generator that yields once; the
        # default swallows the StopIteration that ends it.
        next(api, None)


@st.composite
def random_token(draw: st.DrawFn) -> str:
    """Draw a non-empty suffix (1-5 ASCII letters/digits) for spoiling a token."""
    return draw(
        st.text(alphabet=string.ascii_letters + string.digits, min_size=1, max_size=5)
    )


@st.composite
def invalid_token(draw: st.DrawFn) -> str:
    """Draw a token built only from characters outside the valid-token alphabet.

    The remaining characters are the ``string.printable`` characters that are
    not digits, letters or punctuation (i.e. its whitespace characters).
    """
    # Sorted so the alphabet order does not depend on set iteration order
    # (which varies with hash randomization); this keeps replayed
    # Hypothesis examples reproducible.
    opposite_alphabet = sorted(
        set(string.printable)
        - set(string.digits + string.ascii_letters + string.punctuation)
    )
    return draw(st.text(alphabet=opposite_alphabet, min_size=1, max_size=50))


@settings(max_examples=10)
@given(tconf=token_config(), inval_tok=invalid_token())
def test_invalid_token(tconf: Dict[str, Any], inval_tok: str) -> None:
    """A malformed client token is rejected while the client is being set up."""
    api = _token_auth_fixture(tconf, inval_tok)
    with pytest.raises(Exception) as e:
        _sys: System = next(api)
        _sys.reset_state()
        _sys.instance(ServerAPI)
    # Check the exception's own message, not the (truncated) ExceptionInfo repr.
    assert "Invalid token" in str(e.value)


@settings(max_examples=10)
@given(token_config(), random_token())
def test_fastapi_server_token_auth_wrong_token(
    token_config: Dict[str, Any], random_token: str
) -> None:
    """A well-formed but wrong client token is refused on list_collections."""
    # Appending a non-empty suffix guarantees the client token differs from
    # the server's.
    api = _token_auth_fixture(
        token_config, token_config["chroma_client_auth_credentials"] + random_token
    )
    _sys: System = next(api)
    try:
        _sys.reset_state()
        _api = _sys.instance(ServerAPI)
        _api.heartbeat()
        with pytest.raises(Exception) as e:
            _api.list_collections()
        assert "Unauthorized" in str(e.value)
    finally:
        # Resume the fixture past its ``yield`` so any teardown after it runs.
        next(api, None)