# chroma/chromadb/auth/authz/__init__.py
# badalsahani: feat: chroma initial deploy (287a0bc)
import logging
from typing import Any, Dict, Set, cast
from overrides import override
import yaml
from chromadb.auth import (
AuthorizationContext,
ServerAuthorizationConfigurationProvider,
ServerAuthorizationProvider,
)
from chromadb.auth.registry import register_provider, resolve_provider
from chromadb.config import DEFAULT_TENANT, System
from chromadb.telemetry.opentelemetry import (
OpenTelemetryGranularity,
trace_method,
)
# Module-level logger, named after this module, used by the providers below.
logger = logging.getLogger(__name__)
@register_provider("local_authz_config")
class LocalUserConfigAuthorizationConfigurationProvider(
    ServerAuthorizationConfigurationProvider[Dict[str, Any]]
):
    """Authorization configuration provider backed by local settings.

    The configuration is loaded once, at construction time, from one of two
    settings (checked in this order):

    * ``chroma_server_authz_config_file`` -- path to a YAML file that is
      parsed with ``yaml.safe_load``;
    * ``chroma_server_authz_config`` -- a configuration value used as-is.
    """

    # Path of the YAML file the configuration was read from. Only set when
    # the file-based setting is used.
    _config_file: str
    # The loaded authorization configuration.
    _config: Dict[str, Any]

    def __init__(self, system: System) -> None:
        """Load the authorization configuration from the system settings.

        Args:
            system: The system whose ``settings`` carry the authz options.

        Raises:
            ValueError: If neither setting is provided, or if the YAML file
                does not contain a top-level mapping (e.g. it is empty).
            OSError: If the configuration file cannot be opened.
        """
        super().__init__(system)
        self._settings = system.settings
        if self._settings.chroma_server_authz_config_file:
            self._config_file = str(self._settings.chroma_server_authz_config_file)
            # Decode explicitly as UTF-8 so parsing does not depend on the
            # platform's default locale encoding.
            with open(self._config_file, "r", encoding="utf-8") as f:
                loaded = yaml.safe_load(f)
            # safe_load returns None for an empty document and may return a
            # list or scalar; fail here rather than with an unrelated error
            # when a consumer later indexes the configuration.
            if not isinstance(loaded, dict):
                raise ValueError(
                    f"Authorization configuration file {self._config_file} "
                    "must contain a YAML mapping at the top level, got "
                    f"{type(loaded).__name__}"
                )
            self._config = loaded
        elif self._settings.chroma_server_authz_config:
            self._config = self._settings.chroma_server_authz_config
        else:
            raise ValueError(
                "No configuration (CHROMA_SERVER_AUTHZ_CONFIG_FILE) file or "
                "configuration (CHROMA_SERVER_AUTHZ_CONFIG) provided for "
                "LocalUserConfigAuthorizationConfigurationProvider"
            )

    @override
    def get_configuration(self) -> Dict[str, Any]:
        """Return the configuration loaded at construction time."""
        return self._config
@register_provider("simple_rbac")
class SimpleRBACAuthorizationProvider(ServerAuthorizationProvider):
    """Role-based authorization provider driven by a static configuration.

    At construction time the configuration's ``users`` and ``roles_mapping``
    sections are expanded into a per-user set of permission tuples of the
    form ``(user_id, tenant, *action.split(":"))``. ``authorize`` then grants
    access only when the request's
    ``(user id, user tenant, resource type, action id)`` tuple is in the
    requesting user's set.
    """

    _authz_config_provider: ServerAuthorizationConfigurationProvider[Dict[str, Any]]

    def __init__(self, system: System) -> None:
        """Resolve the configuration provider and build the permission map.

        Args:
            system: The system whose ``settings`` name the authorization
                configuration provider to use.

        Raises:
            KeyError: If the configuration lacks ``users`` or
                ``roles_mapping``, a user lacks ``id`` or ``role``, or a
                user's role has no ``actions`` entry in ``roles_mapping``.
        """
        super().__init__(system)
        self._settings = system.settings
        system.settings.require("chroma_server_authz_config_provider")
        # NOTE(review): require() above presumably rejects a missing setting;
        # the guard is kept so behavior is unchanged if it does not.
        if self._settings.chroma_server_authz_config_provider:
            _cls = resolve_provider(
                self._settings.chroma_server_authz_config_provider,
                ServerAuthorizationConfigurationProvider,
            )
            self._authz_config_provider = cast(
                ServerAuthorizationConfigurationProvider[Dict[str, Any]],
                self.require(_cls),
            )
        _config = self._authz_config_provider.get_configuration()

        # Maps a user id to the set of permission tuples granted to that user.
        self._authz_tuples_map: Dict[str, Set[Any]] = {}
        for user in _config["users"]:
            actions = _config["roles_mapping"][user["role"]]["actions"]
            if not actions:
                # A role with no actions grants nothing; do not create an
                # entry for the user.
                continue
            user_id = user["id"]
            # Users without an explicit tenant belong to the default tenant.
            tenant = user.get("tenant", DEFAULT_TENANT)
            permissions = self._authz_tuples_map.setdefault(user_id, set())
            for action in actions:
                # Presumably actions look like "<resource_type>:<action_id>",
                # since authorize() compares against
                # (resource.type, action.id) -- TODO confirm.
                permissions.add((user_id, tenant, *action.split(":")))

        # Count the permission tuples themselves, not the number of users
        # that have at least one.
        permission_count = sum(
            len(perms) for perms in self._authz_tuples_map.values()
        )
        logger.debug(
            "Loaded %d permissions for (%d) users",
            permission_count,
            len(_config["users"]),
        )
        logger.info(
            "Authorization Provider SimpleRBACAuthorizationProvider initialized"
        )

    @trace_method(
        "SimpleRBACAuthorizationProvider.authorize",
        OpenTelemetryGranularity.ALL,
    )
    @override
    def authorize(self, context: AuthorizationContext) -> bool:
        """Decide whether the request described by ``context`` is allowed.

        Args:
            context: The authorization context; its ``user.id``,
                ``user.tenant``, ``resource.type`` and ``action.id`` are read.

        Returns:
            True if the permission tuple built from ``context`` is in the
            set loaded for that user, False otherwise (including when the
            user is unknown).
        """
        user_id = context.user.id
        requested = (
            user_id,
            context.user.tenant,
            context.resource.type,
            context.action.id,
        )
        # Unknown users fall back to an empty collection, so they are denied.
        policy_decision = requested in self._authz_tuples_map.get(user_id, ())
        # Lazy %-style args avoid building the message on every request when
        # debug logging is disabled.
        logger.debug(
            "Authorization decision: Access %s for "
            "user [%s] attempting to [%s] on [%s]",
            "granted" if policy_decision else "denied",
            user_id,
            context.action.id,
            context.resource,
        )
        return policy_decision