Spaces:
Sleeping
Sleeping
import importlib | |
import logging | |
import pkgutil | |
from typing import Union, Dict, Type, Callable # noqa: F401 | |
from chromadb.auth import ( | |
ClientAuthConfigurationProvider, | |
ClientAuthCredentialsProvider, | |
ClientAuthProtocolAdapter, | |
ServerAuthProvider, | |
ServerAuthConfigurationProvider, | |
ServerAuthCredentialsProvider, | |
ClientAuthProvider, | |
ServerAuthorizationConfigurationProvider, | |
ServerAuthorizationProvider, | |
) | |
from chromadb.utils import get_class | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Union of every provider interface this registry handles. Members are
# written as string forward references rather than the imported classes.
ProviderTypes = Union[
    "ClientAuthProvider",
    "ClientAuthConfigurationProvider",
    "ClientAuthCredentialsProvider",
    "ServerAuthProvider",
    "ServerAuthConfigurationProvider",
    "ServerAuthCredentialsProvider",
    "ClientAuthProtocolAdapter",
    "ServerAuthorizationProvider",
    "ServerAuthorizationConfigurationProvider",
]

# Registry layout: category name -> {short_hand -> provider class}.
# Every category starts with its own, independent empty mapping.
_provider_registry = {
    category: {}
    for category in (
        "client_auth_providers",
        "client_auth_config_providers",
        "client_auth_credentials_providers",
        "client_auth_protocol_adapters",
        "server_auth_providers",
        "server_auth_config_providers",
        "server_auth_credentials_providers",
        "server_authz_providers",
        "server_authz_config_providers",
    )
}  # type: Dict[str, Dict[str, Type[ProviderTypes]]]
def register_classes_from_package(package_name: str) -> None:
    """Import every direct submodule of the package ``package_name``.

    The imports are performed only for their side effects: importing a module
    executes its top-level code, so any registration done at import time in
    those modules takes effect. Nothing is returned.

    Only the modules found directly on the package's ``__path__`` are
    imported; nested subpackages are imported themselves but not walked.

    Args:
        package_name: Dotted name of an importable package.
    """
    pkg = importlib.import_module(package_name)
    for module_info in pkgutil.iter_modules(pkg.__path__):
        importlib.import_module(f"{package_name}.{module_info.name}")
def register_provider(
    short_hand: str,
) -> Callable[[Type[ProviderTypes]], Type[ProviderTypes]]:
    """Build a class decorator that registers a provider under ``short_hand``.

    The decorated class is stored in the ``_provider_registry`` category that
    matches the first provider base class it is a subclass of, then returned
    unchanged. An existing entry with the same short-hand in that category is
    overwritten.

    Args:
        short_hand: Key under which the class is stored in the registry.

    Returns:
        A decorator that takes the provider class and returns that same class.

    Raises:
        ValueError: Raised by the decorator when the class is not a subclass
            of any supported provider base class.
    """

    def decorator(cls: Type[ProviderTypes]) -> Type[ProviderTypes]:
        logger.debug("Registering provider: %s", short_hand)
        # Checked top to bottom; the first matching base class decides the
        # registry category, so the order here is significant.
        categories = (
            (ClientAuthProvider, "client_auth_providers"),
            (ClientAuthConfigurationProvider, "client_auth_config_providers"),
            (
                ClientAuthCredentialsProvider,
                "client_auth_credentials_providers",
            ),
            (ClientAuthProtocolAdapter, "client_auth_protocol_adapters"),
            (ServerAuthProvider, "server_auth_providers"),
            (ServerAuthConfigurationProvider, "server_auth_config_providers"),
            (
                ServerAuthCredentialsProvider,
                "server_auth_credentials_providers",
            ),
            (ServerAuthorizationProvider, "server_authz_providers"),
            (
                ServerAuthorizationConfigurationProvider,
                "server_authz_config_providers",
            ),
        )
        for base, category in categories:
            if issubclass(cls, base):
                _provider_registry[category][short_hand] = cls
                return cls
        raise ValueError(
            "Only ClientAuthProvider, ClientAuthConfigurationProvider, "
            "ClientAuthCredentialsProvider, ServerAuthProvider, "
            "ServerAuthConfigurationProvider, and ServerAuthCredentialsProvider, "
            "ClientAuthProtocolAdapter, ServerAuthorizationProvider, "
            "ServerAuthorizationConfigurationProvider can be registered."
        )

    return decorator
def resolve_provider(
    class_or_name: str, cls: Type[ProviderTypes]
) -> Type[ProviderTypes]:
    """Resolve a provider class from a registered short-hand or a class name.

    All direct submodules of ``chromadb.auth`` are imported first, so that any
    registration performed when those modules are imported has happened
    before the lookup. The registry category is chosen from the first
    provider base class that ``cls`` is a subclass of.

    Args:
        class_or_name: A short-hand previously stored in the registry, or any
            other string, which is handed to ``get_class`` (presumably a
            fully qualified class name; confirm against
            ``chromadb.utils.get_class``).
        cls: The provider base class that selects which registry category to
            search; also passed through to ``get_class`` on a registry miss.

    Returns:
        The registered class when ``class_or_name`` is a known short-hand in
        the selected category, otherwise whatever ``get_class`` returns.

    Raises:
        ValueError: If ``cls`` is not a subclass of any supported provider
            base class.
    """
    # Runs before the type check, as in the original: the import side effects
    # happen even when the call later fails with ValueError.
    register_classes_from_package("chromadb.auth")

    # Checked top to bottom; the first matching base class decides the
    # registry category, so the order here is significant.
    categories = (
        (ClientAuthProvider, "client_auth_providers"),
        (ClientAuthConfigurationProvider, "client_auth_config_providers"),
        (ClientAuthCredentialsProvider, "client_auth_credentials_providers"),
        (ClientAuthProtocolAdapter, "client_auth_protocol_adapters"),
        (ServerAuthProvider, "server_auth_providers"),
        (ServerAuthConfigurationProvider, "server_auth_config_providers"),
        (ServerAuthCredentialsProvider, "server_auth_credentials_providers"),
        (ServerAuthorizationProvider, "server_authz_providers"),
        (
            ServerAuthorizationConfigurationProvider,
            "server_authz_config_providers",
        ),
    )
    for base, category in categories:
        if issubclass(cls, base):
            _key = category
            break
    else:
        # The previous message ran two class names together (missing space
        # between adjacent literals) and spoke of registering, not resolving.
        raise ValueError(
            "Only ClientAuthProvider, ClientAuthConfigurationProvider, "
            "ClientAuthCredentialsProvider, ClientAuthProtocolAdapter, "
            "ServerAuthProvider, ServerAuthConfigurationProvider, "
            "ServerAuthCredentialsProvider, ServerAuthorizationProvider, "
            "and ServerAuthorizationConfigurationProvider can be resolved."
        )

    registered = _provider_registry[_key]
    if class_or_name in registered:
        return registered[class_or_name]
    # Not a registered short-hand: defer to get_class for the lookup.
    return get_class(class_or_name, cls)  # type: ignore