chroma / chromadb /auth /registry.py
badalsahani's picture
feat: chroma initial deploy
287a0bc
import importlib
import logging
import pkgutil
from typing import Union, Dict, Type, Callable # noqa: F401
from chromadb.auth import (
ClientAuthConfigurationProvider,
ClientAuthCredentialsProvider,
ClientAuthProtocolAdapter,
ServerAuthProvider,
ServerAuthConfigurationProvider,
ServerAuthCredentialsProvider,
ClientAuthProvider,
ServerAuthorizationConfigurationProvider,
ServerAuthorizationProvider,
)
from chromadb.utils import get_class
logger = logging.getLogger(__name__)
ProviderTypes = Union[
"ClientAuthProvider",
"ClientAuthConfigurationProvider",
"ClientAuthCredentialsProvider",
"ServerAuthProvider",
"ServerAuthConfigurationProvider",
"ServerAuthCredentialsProvider",
"ClientAuthProtocolAdapter",
"ServerAuthorizationProvider",
"ServerAuthorizationConfigurationProvider",
]
_provider_registry = {
"client_auth_providers": {},
"client_auth_config_providers": {},
"client_auth_credentials_providers": {},
"client_auth_protocol_adapters": {},
"server_auth_providers": {},
"server_auth_config_providers": {},
"server_auth_credentials_providers": {},
"server_authz_providers": {},
"server_authz_config_providers": {},
} # type: Dict[str, Dict[str, Type[ProviderTypes]]]
def register_classes_from_package(package_name: str) -> None:
package = importlib.import_module(package_name)
for _, module_name, _ in pkgutil.iter_modules(package.__path__):
full_module_name = f"{package_name}.{module_name}"
_ = importlib.import_module(full_module_name)
def register_provider(
short_hand: str,
) -> Callable[[Type[ProviderTypes]], Type[ProviderTypes]]:
def decorator(cls: Type[ProviderTypes]) -> Type[ProviderTypes]:
logger.debug("Registering provider: %s", short_hand)
global _provider_registry
if issubclass(cls, ClientAuthProvider):
_provider_registry["client_auth_providers"][short_hand] = cls
elif issubclass(cls, ClientAuthConfigurationProvider):
_provider_registry["client_auth_config_providers"][short_hand] = cls
elif issubclass(cls, ClientAuthCredentialsProvider):
_provider_registry["client_auth_credentials_providers"][short_hand] = cls
elif issubclass(cls, ClientAuthProtocolAdapter):
_provider_registry["client_auth_protocol_adapters"][short_hand] = cls
elif issubclass(cls, ServerAuthProvider):
_provider_registry["server_auth_providers"][short_hand] = cls
elif issubclass(cls, ServerAuthConfigurationProvider):
_provider_registry["server_auth_config_providers"][short_hand] = cls
elif issubclass(cls, ServerAuthCredentialsProvider):
_provider_registry["server_auth_credentials_providers"][short_hand] = cls
elif issubclass(cls, ServerAuthorizationProvider):
_provider_registry["server_authz_providers"][short_hand] = cls
elif issubclass(cls, ServerAuthorizationConfigurationProvider):
_provider_registry["server_authz_config_providers"][short_hand] = cls
else:
raise ValueError(
"Only ClientAuthProvider, ClientAuthConfigurationProvider, "
"ClientAuthCredentialsProvider, ServerAuthProvider, "
"ServerAuthConfigurationProvider, and ServerAuthCredentialsProvider, "
"ClientAuthProtocolAdapter, ServerAuthorizationProvider, "
"ServerAuthorizationConfigurationProvider can be registered."
)
return cls
return decorator
def resolve_provider(
class_or_name: str, cls: Type[ProviderTypes]
) -> Type[ProviderTypes]:
register_classes_from_package("chromadb.auth")
global _provider_registry
if issubclass(cls, ClientAuthProvider):
_key = "client_auth_providers"
elif issubclass(cls, ClientAuthConfigurationProvider):
_key = "client_auth_config_providers"
elif issubclass(cls, ClientAuthCredentialsProvider):
_key = "client_auth_credentials_providers"
elif issubclass(cls, ClientAuthProtocolAdapter):
_key = "client_auth_protocol_adapters"
elif issubclass(cls, ServerAuthProvider):
_key = "server_auth_providers"
elif issubclass(cls, ServerAuthConfigurationProvider):
_key = "server_auth_config_providers"
elif issubclass(cls, ServerAuthCredentialsProvider):
_key = "server_auth_credentials_providers"
elif issubclass(cls, ServerAuthorizationProvider):
_key = "server_authz_providers"
elif issubclass(cls, ServerAuthorizationConfigurationProvider):
_key = "server_authz_config_providers"
else:
raise ValueError(
"Only ClientAuthProvider, ClientAuthConfigurationProvider, "
"ClientAuthCredentialsProvider, ServerAuthProvider, "
"ServerAuthConfigurationProvider, and ServerAuthCredentialsProvider, "
"ClientAuthProtocolAdapter, ServerAuthorizationProvider,"
"ServerAuthorizationConfigurationProvider, can be registered."
)
if class_or_name in _provider_registry[_key]:
return _provider_registry[_key][class_or_name]
else:
return get_class(class_or_name, cls) # type: ignore