# Spaces: Sleeping (page-header residue from the Hugging Face Space this file was captured from)
import json | |
import random | |
import string | |
from typing import Dict, Any, Tuple | |
import uuid | |
import hypothesis.strategies as st | |
import pytest | |
from hypothesis import given, settings | |
from chromadb import AdminClient | |
from chromadb.api import AdminAPI, ServerAPI | |
from chromadb.api.models.Collection import Collection | |
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System | |
from chromadb.test.conftest import _fastapi_fixture | |
# Every RBAC action a generated role can be granted or denied. The master role
# is granted all of them; regular roles get a random subset.
valid_action_space = [
    "tenant:create_tenant",
    "tenant:get_tenant",
    "db:create_database",
    "db:get_database",
    "db:reset",
    "db:list_collections",
    "collection:get_collection",
    "db:create_collection",
    "collection:delete_collection",
    "collection:update_collection",
    "collection:add",
    "collection:delete",
    "collection:get",
    "collection:query",
    "collection:peek",
    "collection:update",
    "collection:upsert",
    "collection:count",
]

# Role and user names are 1-20 ASCII letters. Because underscores are not in
# the alphabet, a generated name can never equal the reserved "__master__" id
# or the "__master_role__" role.
role_name = user_name = st.text(
    alphabet=string.ascii_letters, min_size=1, max_size=20
)

# A non-empty list of actions for one role (entries may repeat).
actions = st.lists(
    st.sampled_from(valid_action_space),
    min_size=1,
    max_size=len(valid_action_space),
)
@st.composite
def master_user(draw: st.DrawFn) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Strategy for the privileged ``__master__`` user and its role mapping.

    The master user holds ``__master_role__``, which is granted every action in
    ``valid_action_space`` and denied none. ``test_authz`` uses this user's
    credentials to set up preconditions (e.g. pre-existing collections) for
    the user under test.

    Returns:
        A ``(user, role_config)`` pair. ``user`` has ``role``, ``id``,
        ``tenant`` (``DEFAULT_TENANT``) and two ``tokens``. ``role_config`` maps
        ``"__master_role__"`` to its ``actions`` / ``unauthorized_actions``.
    """
    tokens = [
        {
            # The random numeric prefix makes token collisions between users
            # unlikely even when Hypothesis draws identical suffixes.
            "token": f"{random.randint(1,1000000)}_"
            + draw(
                st.text(
                    alphabet=string.ascii_letters + string.digits,
                    min_size=1,
                    max_size=25,
                )
            )
        }
        for _ in range(2)
    ]
    user = {
        "role": "__master_role__",
        "id": "__master__",
        "tenant": DEFAULT_TENANT,
        "tokens": tokens,
    }
    role_config = {
        "__master_role__": {
            "actions": valid_action_space,
            "unauthorized_actions": [],
        }
    }
    return user, role_config
@st.composite
def user_role_config(draw: st.DrawFn) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Strategy for a random non-master user together with the role it holds.

    The role is granted three groups of actions:

    * a random subset of ``valid_action_space``;
    * the actions implied by that subset (see below);
    * ``tenant:get_tenant`` and ``db:get_database``, which are always granted.

    Every other action in ``valid_action_space`` is recorded as unauthorized.

    Returns:
        A ``(user, role_config)`` pair. ``user`` has ``role``, ``id``,
        ``tenant`` (``DEFAULT_TENANT``) and two ``tokens``. ``role_config``
        maps the role name to ``{"actions": [...], "unauthorized_actions":
        [...]}``, and neither list contains duplicates.
    """
    role = draw(role_name)
    user = draw(user_name)
    granted = draw(actions)

    # Any collection data operation also grants collection lookup. Presumably
    # the server authorizes these operations against "collection:get_collection"
    # as well; confirm against the authz provider.
    collection_data_actions = {
        "collection:add",
        "collection:delete",
        "collection:get",
        "collection:query",
        "collection:peek",
        "collection:update",
        "collection:upsert",
        "collection:count",
    }
    if collection_data_actions.intersection(granted):
        granted.append("collection:get_collection")
    # Presumably peek is served by the "get" endpoint server-side; confirm.
    if "collection:peek" in granted:
        granted.append("collection:get")
    granted.extend(
        [
            "tenant:get_tenant",
            "db:get_database",
        ]
    )

    # The draw and the implied grants above can repeat actions. Keep only the
    # first occurrence of each so every action is exercised exactly once.
    granted = list(dict.fromkeys(granted))
    # Sorted rather than raw set order, which varies with string-hash
    # randomization between runs.
    unauthorized_actions = sorted(set(valid_action_space) - set(granted))

    role_config = {
        role: {
            "actions": granted,
            "unauthorized_actions": unauthorized_actions,
        }
    }
    user_config = {
        "role": role,
        "id": user,
        "tenant": DEFAULT_TENANT,
        "tokens": [
            {
                # The random numeric prefix makes token collisions between
                # users unlikely even when Hypothesis draws identical suffixes.
                "token": f"{random.randint(1,1000000)}_"
                + draw(
                    st.text(
                        alphabet=string.ascii_letters + string.digits,
                        min_size=1,
                        max_size=25,
                    )
                )
            }
            for _ in range(2)
        ],
    }
    return user_config, role_config
@st.composite
def rbac_config(draw: st.DrawFn) -> Dict[str, Any]:
    """Strategy for a complete RBAC configuration.

    The configuration holds 1-10 random users plus the ``__master__`` user.

    Returns:
        ``{"roles_mapping": {role_name: role_config, ...}, "users": [...]}``.
        The generated users come first and the master user is last.
    """
    # User ids are kept unique. Two generated users sharing an id would give
    # the server two entries for one identity, which makes the expected
    # permissions of the user test_authz picks ambiguous. Hypothesis readily
    # generates duplicate short names, so this does happen in practice.
    user_roles = draw(
        st.lists(
            user_role_config(),
            min_size=1,
            max_size=10,
            unique_by=lambda pair: pair[0]["id"],
        )
    )
    master, master_role = draw(master_user())

    users = [user for user, _ in user_roles]
    users.append(master)

    # Role names may repeat across users. The config drawn last for a name
    # wins, and both the server config and test_authz read this same mapping.
    roles_mapping: Dict[str, Any] = {}
    for _, role in user_roles:
        roles_mapping.update(role)
    roles_mapping.update(master_role)

    return {
        "roles_mapping": roles_mapping,
        "users": users,
    }
@st.composite
def token_config(draw: st.DrawFn) -> Dict[str, Any]:
    """Strategy for token-auth and RBAC-authz server/client settings.

    The ``*_file`` entries are always ``None``. test_authz supplies the
    credentials and the authz config inline, built from a drawn
    ``rbac_config``.

    Returns:
        A dict with these keys:

        * ``token_transport_header``
        * ``chroma_server_auth_credentials_file``
        * ``chroma_server_auth_provider``
        * ``chroma_client_auth_provider``
        * ``chroma_server_authz_config_file``
        * ``chroma_server_auth_credentials_provider``
        * ``chroma_server_authz_provider``
        * ``is_persistent``
    """
    token_header = draw(st.sampled_from(["AUTHORIZATION", "X_CHROMA_TOKEN", None]))
    # Both spellings of each provider are exercised. Presumably "token" is a
    # short alias for the fully-qualified class path; confirm.
    server_provider = draw(
        st.sampled_from(["token", "chromadb.auth.token.TokenAuthServerProvider"])
    )
    client_provider = draw(
        st.sampled_from(["token", "chromadb.auth.token.TokenAuthClientProvider"])
    )
    server_authz_provider = draw(
        st.just("chromadb.auth.authz.SimpleRBACAuthorizationProvider")
    )
    server_credentials_provider = draw(st.just("user_token_config"))
    persistence = draw(st.booleans())
    return {
        "token_transport_header": token_header,
        "chroma_server_auth_credentials_file": None,
        "chroma_server_auth_provider": server_provider,
        "chroma_client_auth_provider": client_provider,
        "chroma_server_authz_config_file": None,
        "chroma_server_auth_credentials_provider": server_credentials_provider,
        "chroma_server_authz_provider": server_authz_provider,
        "is_persistent": persistence,
    }
# Maps each RBAC action name to a callable ``(api, mapi, aapi)`` that performs
# that action once. The three arguments are:
#
# * ``api``: the ServerAPI authenticated as the user under test;
# * ``mapi``: the ServerAPI authenticated as the master user;
# * ``aapi``: the AdminAPI returned by master_api().
#
# Collection-level executors first create their target collection through
# ``mapi``, usually seeding it with one document (id "1"). They then perform
# the action under test through ``api``, or through a Collection handle bound
# to ``api``. Each lambda builds a tuple, so these steps run left to right.
# test_authz ignores the returned value.
api_executors = {
    # Tenant/database actions go through ``aapi``, which master_api() builds
    # from the master client's settings.
    # NOTE(review): that suggests these run with master credentials rather
    # than the tested user's. Confirm they actually exercise the user's role.
    "db:create_database": lambda api, mapi, aapi: (
        aapi.create_database(f"test-{uuid.uuid4()}")
    ),
    "db:get_database": lambda api, mapi, aapi: (aapi.get_database(DEFAULT_DATABASE),),
    "tenant:create_tenant": lambda api, mapi, aapi: (
        aapi.create_tenant(f"test-{uuid.uuid4()}")
    ),
    "tenant:get_tenant": lambda api, mapi, aapi: (aapi.get_tenant(DEFAULT_TENANT),),
    # Database-scoped actions issued directly with the user's API.
    "db:reset": lambda api, mapi, _: api.reset(),
    "db:list_collections": lambda api, mapi, _: api.list_collections(),
    "collection:get_collection": lambda api, mapi, _: (
        # pre-condition: the master creates the collection the user fetches
        mcol := mapi.create_collection(f"test-get-{uuid.uuid4()}"),
        api.get_collection(f"{mcol.name}"),
    ),
    "db:create_collection": lambda api, mapi, _: (
        api.create_collection(f"test-create-{uuid.uuid4()}"),
    ),
    # Not in valid_action_space. No generated role grants or denies it, so
    # test_authz never runs this entry.
    "db:get_or_create_collection": lambda api, mapi, _: (
        api.get_or_create_collection(f"test-get-or-create-{uuid.uuid4()}")
    ),
    "collection:delete_collection": lambda api, mapi, _: (
        # pre-condition: the master creates the collection the user deletes
        mcol := mapi.create_collection(f"test-delete-col-{uuid.uuid4()}"),
        api.delete_collection(f"{mcol.name}"),
    ),
    "collection:update_collection": lambda api, mapi, _: (
        # pre-condition: the master creates the collection; the user then
        # modifies its metadata through a handle bound to ``api``
        mcol := mapi.create_collection(f"test-modify-col-{uuid.uuid4()}"),
        col := Collection(api, f"{mcol.name}", mcol.id),
        col.modify(metadata={"test": "test"}),
    ),
    # Document-level actions below follow one pattern:
    # 1. the master creates a collection (and usually adds doc "1");
    # 2. the user operates on it through a Collection bound to ``api``.
    "collection:add": lambda api, mapi, _: (
        mcol := mapi.create_collection(f"test-add-doc-{uuid.uuid4()}"),
        col := Collection(api, f"{mcol.name}", mcol.id),
        col.add(documents=["test"], ids=["1"]),
    ),
    "collection:delete": lambda api, mapi, _: (
        mcol := mapi.create_collection(f"test-delete-doc-{uuid.uuid4()}"),
        mcol.add(documents=["test"], ids=["1"]),
        col := Collection(client=api, name=f"{mcol.name}", id=mcol.id),
        col.delete(ids=["1"]),
    ),
    "collection:get": lambda api, mapi, _: (
        mcol := mapi.create_collection(f"test-get-doc-{uuid.uuid4()}"),
        mcol.add(documents=["test"], ids=["1"]),
        col := Collection(api, f"{mcol.name}", mcol.id),
        col.get(ids=["1"]),
    ),
    "collection:query": lambda api, mapi, _: (
        mcol := mapi.create_collection(f"test-query-doc-{uuid.uuid4()}"),
        mcol.add(documents=["test"], ids=["1"]),
        col := Collection(api, f"{mcol.name}", mcol.id),
        col.query(query_texts=["test"]),
    ),
    "collection:peek": lambda api, mapi, _: (
        mcol := mapi.create_collection(f"test-peek-{uuid.uuid4()}"),
        mcol.add(documents=["test"], ids=["1"]),
        col := Collection(api, f"{mcol.name}", mcol.id),
        col.peek(),
    ),
    "collection:update": lambda api, mapi, _: (
        mcol := mapi.create_collection(f"test-update-{uuid.uuid4()}"),
        mcol.add(documents=["test"], ids=["1"]),
        col := Collection(api, f"{mcol.name}", mcol.id),
        col.update(ids=["1"], documents=["test1"]),
    ),
    "collection:upsert": lambda api, mapi, _: (
        mcol := mapi.create_collection(f"test-upsert-{uuid.uuid4()}"),
        mcol.add(documents=["test"], ids=["1"]),
        col := Collection(api, f"{mcol.name}", mcol.id),
        col.upsert(ids=["1"], documents=["test1"]),
    ),
    "collection:count": lambda api, mapi, _: (
        mcol := mapi.create_collection(f"test-count-{uuid.uuid4()}"),
        mcol.add(documents=["test"], ids=["1"]),
        col := Collection(api, f"{mcol.name}", mcol.id),
        col.count(),
    ),
}
def master_api(_settings: Settings) -> Tuple[ServerAPI, AdminAPI]:
    """Start a standalone client System for ``_settings`` and return its APIs.

    Args:
        _settings: Client settings. test_authz passes a copy of the server
            fixture's settings with the master user's token as the client
            credentials.

    Returns:
        ``(server_api, admin_api)``: the started System's ServerAPI instance,
        and an AdminClient constructed from that API's settings.
    """
    client_system = System(_settings)
    server_api = client_system.instance(ServerAPI)
    admin = AdminClient(server_api.get_settings())
    client_system.start()
    return server_api, admin
# Each example launches a server process, so the example count is kept small.
@settings(max_examples=10)
@given(token_config(), rbac_config())
def test_authz(token_config: Dict[str, Any], rbac_config: Dict[str, Any]) -> None:
    """Check that a random RBAC user can run exactly the actions its role grants.

    A Chroma server is launched with token authentication. Its
    ``SimpleRBACAuthorizationProvider`` is configured from ``rbac_config``,
    and the users list serves as the token credentials. A random non-master
    user then runs:

    * every granted action, each of which must succeed;
    * every unauthorized action, each of which must fail with an error
      mentioning "Unauthorized" or "Forbidden".

    Preconditions are created through a client authenticated as ``__master__``.
    """
    users = rbac_config["users"]
    # Built locally instead of being written back into the Hypothesis-drawn
    # ``token_config`` dict.
    server_credentials = json.dumps(users)
    random_user = random.choice([user for user in users if user["id"] != "__master__"])
    master = next(user for user in users if user["id"] == "__master__")
    random_token = random.choice(random_user["tokens"])["token"]

    server_fixture = _fastapi_fixture(
        is_persistent=token_config["is_persistent"],
        chroma_server_auth_provider=token_config["chroma_server_auth_provider"],
        chroma_server_auth_credentials_provider=token_config[
            "chroma_server_auth_credentials_provider"
        ],
        chroma_server_auth_credentials=server_credentials,
        chroma_client_auth_provider=token_config["chroma_client_auth_provider"],
        chroma_client_auth_token_transport_header=token_config[
            "token_transport_header"
        ],
        chroma_server_auth_token_transport_header=token_config[
            "token_transport_header"
        ],
        chroma_server_authz_provider=token_config["chroma_server_authz_provider"],
        chroma_server_authz_config=rbac_config,
        chroma_client_auth_credentials=random_token,
    )
    _sys: System = next(server_fixture)
    try:
        _sys.reset_state()
        master_settings = Settings(**dict(_sys.settings))
        master_settings.chroma_client_auth_credentials = master["tokens"][0]["token"]
        # NOTE(review): admin_api is created from the master's settings. The
        # tenant/database executors call it, so those actions appear to run
        # with master credentials rather than the random user's. Confirm they
        # really exercise the user's role.
        _master_api, admin_api = master_api(master_settings)
        _api = _sys.instance(ServerAPI)
        _api.heartbeat()

        role = rbac_config["roles_mapping"][random_user["role"]]
        for action in role["actions"]:
            print(action)
            api_executors[action](_api, _master_api, admin_api)  # type: ignore
        for unauthorized_action in role["unauthorized_actions"]:
            with pytest.raises(Exception) as exc_info:
                api_executors[unauthorized_action](_api, _master_api, admin_api)  # type: ignore
            # Check the exception's own message. The ExceptionInfo repr can be
            # truncated for long messages.
            message = str(exc_info.value)
            assert "Unauthorized" in message or "Forbidden" in message, message
    finally:
        # Resume the fixture generator so whatever teardown follows its
        # ``yield`` runs. This assumes a single-yield generator, as the one
        # ``next()`` above implies; confirm.
        next(server_fixture, None)