Spaces:
Sleeping
Sleeping
import logging | |
from typing import Any, Dict, Set, cast | |
from overrides import override | |
import yaml | |
from chromadb.auth import ( | |
AuthorizationContext, | |
ServerAuthorizationConfigurationProvider, | |
ServerAuthorizationProvider, | |
) | |
from chromadb.auth.registry import register_provider, resolve_provider | |
from chromadb.config import DEFAULT_TENANT, System | |
from chromadb.telemetry.opentelemetry import ( | |
OpenTelemetryGranularity, | |
trace_method, | |
) | |
# Module-level logger, named after this module, shared by the providers below.
logger = logging.getLogger(__name__)
class LocalUserConfigAuthorizationConfigurationProvider(
    ServerAuthorizationConfigurationProvider[Dict[str, Any]]
):
    """Authorization configuration provider backed by local settings.

    The configuration is resolved once, at construction time, from the first
    available source:

    1. the YAML file named by the ``chroma_server_authz_config_file`` setting;
    2. the value of the ``chroma_server_authz_config`` setting.
    """

    # Path of the YAML file; only assigned when the file-based source is used.
    _config_file: str
    # The configuration handed out by ``get_configuration``.
    _config: Dict[str, Any]

    def __init__(self, system: System) -> None:
        """Load the authorization configuration from the system settings.

        Args:
            system: The system whose settings name the configuration source.

        Raises:
            ValueError: If neither setting is provided, or if the configuration
                file does not contain a YAML mapping (e.g. it is empty).
        """
        super().__init__(system)
        self._settings = system.settings
        if self._settings.chroma_server_authz_config_file:
            self._config_file = str(self._settings.chroma_server_authz_config_file)
            # Decode explicitly as UTF-8 so the result does not depend on the
            # platform's default locale encoding.
            with open(self._config_file, "r", encoding="utf-8") as f:
                loaded = yaml.safe_load(f)
            # safe_load returns None for an empty document and a list/scalar
            # for non-mapping documents; reject those here rather than letting
            # consumers fail later with an opaque TypeError.
            if not isinstance(loaded, dict):
                raise ValueError(
                    f"Authorization configuration file {self._config_file} "
                    f"must contain a YAML mapping, got {type(loaded).__name__}"
                )
            self._config = loaded
        elif self._settings.chroma_server_authz_config:
            self._config = self._settings.chroma_server_authz_config
        else:
            raise ValueError(
                "No configuration (CHROMA_SERVER_AUTHZ_CONFIG_FILE) file or "
                "configuration (CHROMA_SERVER_AUTHZ_CONFIG) provided for "
                "LocalUserConfigAuthorizationConfigurationProvider"
            )

    def get_configuration(self) -> Dict[str, Any]:
        """Return the configuration loaded at construction time."""
        return self._config
class SimpleRBACAuthorizationProvider(ServerAuthorizationProvider):
    """Role-based authorization provider driven by a static configuration.

    At construction time the configuration obtained from the configured
    authorization configuration provider is expanded into one tuple per
    (user, action) pair, so that ``authorize`` is a set-membership test.

    The configuration is read through these keys:

    * ``users``: an iterable of mappings with ``id``, ``role`` and an optional
      ``tenant`` (``DEFAULT_TENANT`` is used when it is absent);
    * ``roles_mapping``: a mapping from role name to a mapping whose
      ``actions`` entry is an iterable of colon-separated strings.
    """

    _authz_config_provider: ServerAuthorizationConfigurationProvider[Dict[str, Any]]

    def __init__(self, system: System) -> None:
        """Resolve the configuration provider and build the permission map.

        Args:
            system: The system supplying settings and component resolution.

        Raises:
            ValueError: If ``chroma_server_authz_config_provider`` is empty.
            KeyError: If the configuration lacks one of the keys listed in the
                class docstring, or a user references an unknown role.
        """
        super().__init__(system)
        self._settings = system.settings
        system.settings.require("chroma_server_authz_config_provider")
        provider_name = self._settings.chroma_server_authz_config_provider
        if not provider_name:
            # Previously a falsy value skipped the assignment below and then
            # surfaced as an AttributeError; fail with an explicit message.
            raise ValueError(
                "No authorization configuration provider "
                "(CHROMA_SERVER_AUTHZ_CONFIG_PROVIDER) configured for "
                "SimpleRBACAuthorizationProvider"
            )
        _cls = resolve_provider(
            provider_name,
            ServerAuthorizationConfigurationProvider,
        )
        self._authz_config_provider = cast(
            ServerAuthorizationConfigurationProvider[Dict[str, Any]],
            self.require(_cls),
        )
        _config = self._authz_config_provider.get_configuration()

        # Maps a user id to the set of tuples
        # (user id, tenant, *action.split(":")) that the user is granted.
        self._authz_tuples_map: Dict[str, Set[Any]] = {}
        for user in _config["users"]:
            actions = _config["roles_mapping"][user["role"]]["actions"]
            user_id = user["id"]
            # The tenant does not depend on the action, so resolve it once.
            tenant = user.get("tenant", DEFAULT_TENANT)
            granted = {(user_id, tenant, *action.split(":")) for action in actions}
            # Users whose role grants no actions get no entry in the map.
            if granted:
                self._authz_tuples_map.setdefault(user_id, set()).update(granted)

        # Count the permission tuples themselves, not the number of map keys
        # (which is the number of users that have at least one permission).
        permission_count = sum(
            len(granted) for granted in self._authz_tuples_map.values()
        )
        logger.debug(
            f"Loaded {permission_count} permissions for "
            f"({len(_config['users'])}) users"
        )
        logger.info(
            "Authorization Provider SimpleRBACAuthorizationProvider initialized"
        )

    def authorize(self, context: AuthorizationContext) -> bool:
        """Decide whether the request described by ``context`` is allowed.

        Access is granted only when the tuple (user id, user tenant, resource
        type, action id) taken from ``context`` is among the tuples built for
        that user at construction time.

        Args:
            context: The user, resource and action being checked.

        Returns:
            True when access is granted, False otherwise (including when the
            user has no entry in the permission map).
        """
        _authz_tuple = (
            context.user.id,
            context.user.tenant,
            context.resource.type,
            context.action.id,
        )
        # Unknown users fall back to an empty collection, i.e. access denied.
        policy_decision = _authz_tuple in self._authz_tuples_map.get(
            context.user.id, ()
        )
        # Lazy %-formatting: the message is only built when DEBUG is enabled.
        logger.debug(
            "Authorization decision: Access %s for "
            "user [%s] attempting to [%s] on [%s]",
            "granted" if policy_decision else "denied",
            context.user.id,
            context.action.id,
            context.resource,
        )
        return policy_decision