# chroma/chromadb/test/auth/test_token_auth.py
# badalsahani's picture
# feat: chroma initial deploy
# 287a0bc
import string
from typing import Dict, Any
import hypothesis.strategies as st
import pytest
from hypothesis import given, settings
from chromadb.api import ServerAPI
from chromadb.config import System
from chromadb.test.conftest import _fastapi_fixture
@st.composite
def token_config(draw: st.DrawFn) -> Dict[str, Any]:
    """Hypothesis strategy that draws one token-auth configuration.

    The returned dict carries the transport header choice, the server and
    client provider identifiers, the server credentials provider, a
    persistence flag, and a single token that is stored under both the
    server and the client credential keys.
    """
    header_choices = ["AUTHORIZATION", "X_CHROMA_TOKEN", None]
    server_choices = ["token", "chromadb.auth.token.TokenAuthServerProvider"]
    client_choices = ["token", "chromadb.auth.token.TokenAuthClientProvider"]
    credentials_provider_choices = [
        "chromadb.auth.token.TokenConfigServerAuthCredentialsProvider"
    ]
    token_alphabet = string.digits + string.ascii_letters + string.punctuation

    # The draws happen in a fixed sequence: header, server provider, client
    # provider, credentials provider, token, persistence.
    transport_header = draw(st.sampled_from(header_choices))
    server_auth = draw(st.sampled_from(server_choices))
    client_auth = draw(st.sampled_from(client_choices))
    credentials_provider = draw(st.sampled_from(credentials_provider_choices))
    shared_token = draw(st.text(alphabet=token_alphabet, min_size=1, max_size=50))
    persistent = draw(st.booleans())

    return {
        "token_transport_header": transport_header,
        # Server and client are given the very same token.
        "chroma_server_auth_credentials": shared_token,
        "chroma_client_auth_credentials": shared_token,
        "chroma_server_auth_provider": server_auth,
        "chroma_client_auth_provider": client_auth,
        "chroma_server_auth_credentials_provider": credentials_provider,
        "is_persistent": persistent,
    }
@settings(max_examples=10)
@given(token_config())
def test_fastapi_server_token_auth(token_config: Dict[str, Any]) -> None:
    """Check that matching client/server tokens allow API calls.

    The fixture is configured from the drawn ``token_config`` (client and
    server share the same transport header and credentials), the system state
    is reset, and then ``heartbeat`` and ``list_collections`` are called; the
    collection list is expected to be empty.
    """
    # Both sides of the connection use the same header setting.
    transport_header = token_config["token_transport_header"]
    fixture = _fastapi_fixture(
        is_persistent=token_config["is_persistent"],
        chroma_server_auth_provider=token_config["chroma_server_auth_provider"],
        chroma_server_auth_credentials_provider=token_config[
            "chroma_server_auth_credentials_provider"
        ],
        chroma_server_auth_credentials=token_config["chroma_server_auth_credentials"],
        chroma_client_auth_provider=token_config["chroma_client_auth_provider"],
        chroma_client_auth_token_transport_header=transport_header,
        chroma_server_auth_token_transport_header=transport_header,
        chroma_client_auth_credentials=token_config["chroma_client_auth_credentials"],
    )
    # The fixture is an iterator; its first item is the configured System.
    system: System = next(fixture)
    system.reset_state()
    server_api = system.instance(ServerAPI)
    server_api.heartbeat()
    assert server_api.list_collections() == []
@st.composite
def random_token(draw: st.DrawFn) -> str:
    """Hypothesis strategy for a string of 1 to 5 ASCII letters and digits."""
    alphanumeric = string.ascii_letters + string.digits
    short_text = st.text(alphabet=alphanumeric, min_size=1, max_size=5)
    return draw(short_text)
@st.composite
def invalid_token(draw: st.DrawFn) -> str:
    """Hypothesis strategy for a token made only of disallowed characters.

    The alphabet is every character of ``string.printable`` that is not a
    digit, an ASCII letter, or punctuation (i.e. the whitespace characters).
    The result is 1 to 50 characters long.
    """
    allowed = set(string.digits + string.ascii_letters + string.punctuation)
    # Sort the difference: iterating a set of strings yields a different order
    # from one interpreter run to the next (hash randomization), and handing
    # Hypothesis an unstable alphabet order hurts reproducibility of examples.
    opposite_alphabet = sorted(set(string.printable) - allowed)
    return draw(st.text(alphabet=opposite_alphabet, min_size=1, max_size=50))
@settings(max_examples=10)
@given(tconf=token_config(), inval_tok=invalid_token())
def test_invalid_token(tconf: Dict[str, Any], inval_tok: str) -> None:
    """Check that a client configured with an invalid token is rejected.

    The server side is configured from ``tconf``; the client credentials are
    replaced by ``inval_tok``. Obtaining the system, resetting it, and
    instantiating ``ServerAPI`` is expected to raise an exception whose
    message contains ``"Invalid token"``.
    """
    api = _fastapi_fixture(
        is_persistent=tconf["is_persistent"],
        chroma_server_auth_provider=tconf["chroma_server_auth_provider"],
        chroma_server_auth_credentials_provider=tconf[
            "chroma_server_auth_credentials_provider"
        ],
        chroma_server_auth_credentials=tconf["chroma_server_auth_credentials"],
        chroma_server_auth_token_transport_header=tconf["token_transport_header"],
        chroma_client_auth_provider=tconf["chroma_client_auth_provider"],
        chroma_client_auth_token_transport_header=tconf["token_transport_header"],
        chroma_client_auth_credentials=inval_tok,
    )
    with pytest.raises(Exception) as e:
        _sys: System = next(api)
        _sys.reset_state()
        _sys.instance(ServerAPI)
    # Inspect the raised exception itself (e.value). str() of the ExceptionInfo
    # wrapper is its repr since pytest 5.0, not the exception message, so a
    # substring check against it only passes incidentally.
    assert "Invalid token" in str(e.value)
@settings(max_examples=10)
@given(token_config(), random_token())
def test_fastapi_server_token_auth_wrong_token(
    token_config: Dict[str, Any], random_token: str
) -> None:
    """Check that a client token that differs from the server's is rejected.

    The client credentials are the configured token with ``random_token``
    appended. ``heartbeat`` is called outside the ``pytest.raises`` block, so
    it must succeed; ``list_collections`` is then expected to raise an
    exception whose message contains ``"Unauthorized"``.
    """
    api = _fastapi_fixture(
        is_persistent=token_config["is_persistent"],
        chroma_server_auth_provider=token_config["chroma_server_auth_provider"],
        chroma_server_auth_credentials_provider=token_config[
            "chroma_server_auth_credentials_provider"
        ],
        chroma_server_auth_credentials=token_config["chroma_server_auth_credentials"],
        chroma_server_auth_token_transport_header=token_config[
            "token_transport_header"
        ],
        chroma_client_auth_provider=token_config["chroma_client_auth_provider"],
        chroma_client_auth_token_transport_header=token_config[
            "token_transport_header"
        ],
        # Appending the drawn suffix makes the client token differ from the
        # server's configured credentials.
        chroma_client_auth_credentials=token_config["chroma_client_auth_credentials"]
        + random_token,
    )
    _sys: System = next(api)
    _sys.reset_state()
    _api = _sys.instance(ServerAPI)
    _api.heartbeat()
    with pytest.raises(Exception) as e:
        _api.list_collections()
    # Inspect the raised exception itself (e.value). str() of the ExceptionInfo
    # wrapper is its repr since pytest 5.0, not the exception message, so a
    # substring check against it only passes incidentally.
    assert "Unauthorized" in str(e.value)