# File: chroma/chromadb/test/auth/test_simple_rbac_authz.py
# Uploaded by badalsahani -- commit 287a0bc: "feat: chroma initial deploy"
import json
import random
import string
from typing import Dict, Any, Tuple
import uuid
import hypothesis.strategies as st
import pytest
from hypothesis import given, settings
from chromadb import AdminClient
from chromadb.api import AdminAPI, ServerAPI
from chromadb.api.models.Collection import Collection
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
from chromadb.test.conftest import _fastapi_fixture
# Complete set of "<resource>:<action>" names used by this module's role
# generators. Element order is kept exactly as listed.
valid_action_space = [
    # Tenant-level actions.
    "tenant:create_tenant",
    "tenant:get_tenant",
    # Database-level actions.
    "db:create_database",
    "db:get_database",
    "db:reset",
    "db:list_collections",
    # Collection lifecycle actions.
    "collection:get_collection",
    "db:create_collection",
    "collection:delete_collection",
    "collection:update_collection",
    # Record-level actions on a collection.
    "collection:add",
    "collection:delete",
    "collection:get",
    "collection:query",
    "collection:peek",
    "collection:update",
    "collection:upsert",
    "collection:count",
]
# Role identifiers: 1-20 ASCII letters.
role_name = st.text(
    min_size=1,
    max_size=20,
    alphabet=string.ascii_letters,
)

# User identifiers: same shape as role identifiers, but a separate strategy.
user_name = st.text(
    min_size=1,
    max_size=20,
    alphabet=string.ascii_letters,
)

# Non-empty list of action names sampled from ``valid_action_space``; the
# list is not forced to be unique, so repeated entries are possible.
actions = st.lists(
    st.sampled_from(valid_action_space),
    min_size=1,
    max_size=len(valid_action_space),
)
@st.composite
def master_user(draw: st.DrawFn) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Generate the fixed master user together with its role mapping.

    Returns:
        A ``(user, roles)`` pair. ``user`` is the ``__master__`` identity in
        ``DEFAULT_TENANT`` carrying two generated tokens. ``roles`` maps
        ``__master_role__`` to every entry of ``valid_action_space`` and an
        empty list of unauthorized actions.
    """
    token_suffix = st.text(
        alphabet=string.ascii_letters + string.digits,
        min_size=1,
        max_size=25,
    )

    tokens = []
    for _ in range(2):
        # The numeric prefix comes from ``random`` rather than from
        # Hypothesis; it is produced before the suffix is drawn.
        prefix = f"{random.randint(1,1000000)}_"
        tokens.append({"token": prefix + draw(token_suffix)})

    user_entry = {
        "role": "__master_role__",
        "id": "__master__",
        "tenant": DEFAULT_TENANT,
        "tokens": tokens,
    }
    role_entry = {
        "__master_role__": {
            "actions": valid_action_space,
            "unauthorized_actions": [],
        }
    }
    return user_entry, role_entry
@st.composite
def user_role_config(draw: st.DrawFn) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Generate one non-master user and the role definition assigned to it.

    The drawn action list is extended with the actions this module treats as
    implied:

    * any record-level collection action adds ``collection:get_collection``;
    * ``collection:peek`` adds ``collection:get``;
    * ``tenant:get_tenant`` and ``db:get_database`` are always granted.

    The granted list is de-duplicated (first occurrence wins) so that a
    consumer iterating over ``actions`` does not repeat the same action, and
    ``unauthorized_actions`` is emitted in ``valid_action_space`` order so the
    generated config does not depend on set iteration order.

    Returns:
        A ``(user, roles)`` pair. ``user`` has ``role``, ``id``, ``tenant``
        and two ``tokens``; ``roles`` maps the role name to its ``actions``
        and ``unauthorized_actions`` lists.
    """
    role = draw(role_name)
    user = draw(user_name)
    granted = list(draw(actions))

    record_level_actions = {
        "collection:add",
        "collection:delete",
        "collection:get",
        "collection:query",
        "collection:peek",
        "collection:update",
        "collection:upsert",
        "collection:count",
    }
    if not record_level_actions.isdisjoint(granted):
        # NOTE(review): presumably record-level calls need the collection to
        # be resolvable first -- confirm against the server's authz rules.
        granted.append("collection:get_collection")
    if "collection:peek" in granted:
        granted.append("collection:get")
    granted.extend(["tenant:get_tenant", "db:get_database"])

    # dict.fromkeys keeps insertion order, giving an order-preserving dedupe.
    actions_list = list(dict.fromkeys(granted))
    allowed = set(actions_list)
    unauthorized_actions = [
        action for action in valid_action_space if action not in allowed
    ]

    _role_config = {
        role: {
            "actions": actions_list,
            "unauthorized_actions": unauthorized_actions,
        }
    }

    token_suffix = st.text(
        alphabet=string.ascii_letters + string.digits,
        min_size=1,
        max_size=25,
    )
    tokens = []
    for _ in range(2):
        # The numeric prefix comes from ``random`` rather than from
        # Hypothesis; it is produced before the suffix is drawn.
        prefix = f"{random.randint(1,1000000)}_"
        tokens.append({"token": prefix + draw(token_suffix)})

    user_entry = {
        "role": role,
        "id": user,
        "tenant": DEFAULT_TENANT,
        "tokens": tokens,
    }
    return user_entry, _role_config
@st.composite
def rbac_config(draw: st.DrawFn) -> Dict[str, Any]:
    """Generate a full RBAC config: 1-10 regular users plus the master user.

    Regular users are required to have distinct ``id`` values. Without that
    constraint two generated entries could share an id while carrying
    different roles, which makes the expected permissions of a sampled user
    ambiguous.

    Returns:
        A dict with ``roles_mapping`` (role name -> role definition, merged
        across all users; a later definition replaces an earlier one with the
        same role name) and ``users`` (regular users followed by the master
        user).
    """
    user_roles = draw(
        st.lists(
            user_role_config(),
            min_size=1,
            max_size=10,
            unique_by=lambda pair: pair[0]["id"],
        )
    )
    muser, mrole = draw(master_user())

    users = [user for user, _ in user_roles]
    users.append(muser)

    roles_mapping: Dict[str, Any] = {}
    for _, role in user_roles:
        roles_mapping.update(role)
    roles_mapping.update(mrole)

    return {
        "roles_mapping": roles_mapping,
        "users": users,
    }
@st.composite
def token_config(draw: st.DrawFn) -> Dict[str, Any]:
    """Generate the token-auth / authz settings used to start the test server.

    Returns:
        A dict holding the token transport header (``AUTHORIZATION``,
        ``X_CHROMA_TOKEN`` or ``None``), the server and client auth providers
        (short name or dotted path), the authz provider, the credentials
        provider, a persistence flag, and the two ``*_file`` entries set to
        ``None``.
    """

    def pick(*options: Any) -> Any:
        # Draw one value from the given fixed set of options.
        return draw(st.sampled_from(list(options)))

    header = pick("AUTHORIZATION", "X_CHROMA_TOKEN", None)
    auth_server = pick("token", "chromadb.auth.token.TokenAuthServerProvider")
    auth_client = pick("token", "chromadb.auth.token.TokenAuthClientProvider")
    authz_server = pick("chromadb.auth.authz.SimpleRBACAuthorizationProvider")
    credentials_source = pick("user_token_config")
    persistent = draw(st.booleans())

    config: Dict[str, Any] = {
        "token_transport_header": header,
        "chroma_server_auth_credentials_file": None,
        "chroma_server_auth_provider": auth_server,
        "chroma_client_auth_provider": auth_client,
        "chroma_server_authz_config_file": None,
        "chroma_server_auth_credentials_provider": credentials_source,
        "chroma_server_authz_provider": authz_server,
        "is_persistent": persistent,
    }
    return config
# Executors for each action name. Every executor takes three positional
# arguments -- the API under test, a second API used for pre-conditions, and
# an admin API -- and performs exactly one call with the API under test (or
# the admin API for tenant/database actions). Pre-conditions (creating a
# collection, seeding a record) always go through the second API.


def _run_create_database(api: Any, mapi: Any, aapi: Any) -> Any:
    return aapi.create_database(f"test-{uuid.uuid4()}")


def _run_get_database(api: Any, mapi: Any, aapi: Any) -> Any:
    return (aapi.get_database(DEFAULT_DATABASE),)


def _run_create_tenant(api: Any, mapi: Any, aapi: Any) -> Any:
    return aapi.create_tenant(f"test-{uuid.uuid4()}")


def _run_get_tenant(api: Any, mapi: Any, aapi: Any) -> Any:
    return (aapi.get_tenant(DEFAULT_TENANT),)


def _run_reset(api: Any, mapi: Any, _: Any) -> Any:
    return api.reset()


def _run_list_collections(api: Any, mapi: Any, _: Any) -> Any:
    return api.list_collections()


def _run_get_collection(api: Any, mapi: Any, _: Any) -> Any:
    # pre-condition
    mcol = mapi.create_collection(f"test-get-{uuid.uuid4()}")
    return mcol, api.get_collection(f"{mcol.name}")


def _run_create_collection(api: Any, mapi: Any, _: Any) -> Any:
    return (api.create_collection(f"test-create-{uuid.uuid4()}"),)


def _run_get_or_create_collection(api: Any, mapi: Any, _: Any) -> Any:
    return api.get_or_create_collection(f"test-get-or-create-{uuid.uuid4()}")


def _run_delete_collection(api: Any, mapi: Any, _: Any) -> Any:
    # pre-condition
    mcol = mapi.create_collection(f"test-delete-col-{uuid.uuid4()}")
    return mcol, api.delete_collection(f"{mcol.name}")


def _run_update_collection(api: Any, mapi: Any, _: Any) -> Any:
    # pre-condition
    mcol = mapi.create_collection(f"test-modify-col-{uuid.uuid4()}")
    col = Collection(api, f"{mcol.name}", mcol.id)
    return mcol, col, col.modify(metadata={"test": "test"})


def _run_add(api: Any, mapi: Any, _: Any) -> Any:
    mcol = mapi.create_collection(f"test-add-doc-{uuid.uuid4()}")
    col = Collection(api, f"{mcol.name}", mcol.id)
    return mcol, col, col.add(documents=["test"], ids=["1"])


def _run_delete(api: Any, mapi: Any, _: Any) -> Any:
    mcol = mapi.create_collection(f"test-delete-doc-{uuid.uuid4()}")
    seeded = mcol.add(documents=["test"], ids=["1"])
    col = Collection(client=api, name=f"{mcol.name}", id=mcol.id)
    return mcol, seeded, col, col.delete(ids=["1"])


def _run_get(api: Any, mapi: Any, _: Any) -> Any:
    mcol = mapi.create_collection(f"test-get-doc-{uuid.uuid4()}")
    seeded = mcol.add(documents=["test"], ids=["1"])
    col = Collection(api, f"{mcol.name}", mcol.id)
    return mcol, seeded, col, col.get(ids=["1"])


def _run_query(api: Any, mapi: Any, _: Any) -> Any:
    mcol = mapi.create_collection(f"test-query-doc-{uuid.uuid4()}")
    seeded = mcol.add(documents=["test"], ids=["1"])
    col = Collection(api, f"{mcol.name}", mcol.id)
    return mcol, seeded, col, col.query(query_texts=["test"])


def _run_peek(api: Any, mapi: Any, _: Any) -> Any:
    mcol = mapi.create_collection(f"test-peek-{uuid.uuid4()}")
    seeded = mcol.add(documents=["test"], ids=["1"])
    col = Collection(api, f"{mcol.name}", mcol.id)
    return mcol, seeded, col, col.peek()


def _run_update(api: Any, mapi: Any, _: Any) -> Any:
    mcol = mapi.create_collection(f"test-update-{uuid.uuid4()}")
    seeded = mcol.add(documents=["test"], ids=["1"])
    col = Collection(api, f"{mcol.name}", mcol.id)
    return mcol, seeded, col, col.update(ids=["1"], documents=["test1"])


def _run_upsert(api: Any, mapi: Any, _: Any) -> Any:
    mcol = mapi.create_collection(f"test-upsert-{uuid.uuid4()}")
    seeded = mcol.add(documents=["test"], ids=["1"])
    col = Collection(api, f"{mcol.name}", mcol.id)
    return mcol, seeded, col, col.upsert(ids=["1"], documents=["test1"])


def _run_count(api: Any, mapi: Any, _: Any) -> Any:
    mcol = mapi.create_collection(f"test-count-{uuid.uuid4()}")
    seeded = mcol.add(documents=["test"], ids=["1"])
    col = Collection(api, f"{mcol.name}", mcol.id)
    return mcol, seeded, col, col.count()


api_executors = {
    "db:create_database": _run_create_database,
    "db:get_database": _run_get_database,
    "tenant:create_tenant": _run_create_tenant,
    "tenant:get_tenant": _run_get_tenant,
    "db:reset": _run_reset,
    "db:list_collections": _run_list_collections,
    "collection:get_collection": _run_get_collection,
    "db:create_collection": _run_create_collection,
    "db:get_or_create_collection": _run_get_or_create_collection,
    "collection:delete_collection": _run_delete_collection,
    "collection:update_collection": _run_update_collection,
    "collection:add": _run_add,
    "collection:delete": _run_delete,
    "collection:get": _run_get,
    "collection:query": _run_query,
    "collection:peek": _run_peek,
    "collection:update": _run_update,
    "collection:upsert": _run_upsert,
    "collection:count": _run_count,
}
def master_api(_settings: Settings) -> Tuple[ServerAPI, AdminAPI]:
    """Create a started client system from ``_settings``.

    Args:
        _settings: Settings used to construct the new ``System``.

    Returns:
        The system's ``ServerAPI`` instance and an ``AdminClient`` built from
        that API's settings. The system is started before returning.
    """
    master_system = System(_settings)
    server_api = master_system.instance(ServerAPI)
    admin_client = AdminClient(server_api.get_settings())
    # Start only after both clients exist, as the admin client is derived
    # from the server API's settings.
    master_system.start()
    return server_api, admin_client
@settings(max_examples=10)
@given(token_config=token_config(), rbac_config=rbac_config())
def test_authz(token_config: Dict[str, Any], rbac_config: Dict[str, Any]) -> None:
    """Check that a random non-master user can do exactly what its role allows.

    A server is started through ``_fastapi_fixture`` with the generated token
    and RBAC configuration, authenticated as one randomly chosen regular
    user. Every action in that user's role must succeed; every action listed
    as unauthorized must raise an error mentioning ``Unauthorized`` or
    ``Forbidden``. Pre-conditions run through a second client authenticated
    with the master user's first token.
    """
    authz_config = rbac_config
    token_config["chroma_server_authz_config"] = rbac_config
    token_config["chroma_server_auth_credentials"] = json.dumps(rbac_config["users"])
    random_user = random.choice(
        [user for user in authz_config["users"] if user["id"] != "__master__"]
    )
    _master_user = [
        user for user in authz_config["users"] if user["id"] == "__master__"
    ][0]
    random_token = random.choice(random_user["tokens"])["token"]
    role_entry = authz_config["roles_mapping"][random_user["role"]]

    api = _fastapi_fixture(
        is_persistent=token_config["is_persistent"],
        chroma_server_auth_provider=token_config["chroma_server_auth_provider"],
        chroma_server_auth_credentials_provider=token_config[
            "chroma_server_auth_credentials_provider"
        ],
        chroma_server_auth_credentials=token_config["chroma_server_auth_credentials"],
        chroma_client_auth_provider=token_config["chroma_client_auth_provider"],
        chroma_client_auth_token_transport_header=token_config[
            "token_transport_header"
        ],
        chroma_server_auth_token_transport_header=token_config[
            "token_transport_header"
        ],
        chroma_server_authz_provider=token_config["chroma_server_authz_provider"],
        chroma_server_authz_config=token_config["chroma_server_authz_config"],
        chroma_client_auth_credentials=random_token,
    )
    _sys: System = next(api)
    try:
        _sys.reset_state()
        _master_settings = Settings(**dict(_sys.settings))
        _master_settings.chroma_client_auth_credentials = _master_user["tokens"][0][
            "token"
        ]
        _master_api, admin_api = master_api(_master_settings)
        _api = _sys.instance(ServerAPI)
        _api.heartbeat()

        for action in role_entry["actions"]:
            print(action)
            api_executors[action](_api, _master_api, admin_api)  # type: ignore

        for unauthorized_action in role_entry["unauthorized_actions"]:
            with pytest.raises(Exception) as ex:
                api_executors[unauthorized_action](
                    _api, _master_api, admin_api
                )  # type: ignore
            # This check must sit after the ``with`` block: statements that
            # follow the raising call inside it never run. ``repr`` of the
            # raised exception covers both its class name and its message.
            raised = repr(ex.value)
            assert "Unauthorized" in raised or "Forbidden" in raised
    finally:
        # Resume the fixture generator so the code after its ``yield`` runs
        # for every example, including failing ones.
        # NOTE(review): assumes ``_fastapi_fixture`` is a single-yield
        # generator fixture -- confirm against conftest.
        next(api, None)