chroma / chromadb /auth /basic /__init__.py
badalsahani's picture
feat: chroma initial deploy
287a0bc
raw
history blame contribute delete
3.68 kB
import base64
import logging
from typing import Tuple, Any, cast
from overrides import override
from pydantic import SecretStr
from chromadb.auth import (
ServerAuthProvider,
ClientAuthProvider,
ServerAuthenticationRequest,
ServerAuthCredentialsProvider,
AuthInfoType,
BasicAuthCredentials,
ClientAuthCredentialsProvider,
ClientAuthResponse,
SimpleServerAuthenticationResponse,
)
from chromadb.auth.registry import register_provider, resolve_provider
from chromadb.config import System
from chromadb.telemetry.opentelemetry import (
OpenTelemetryGranularity,
trace_method,
)
from chromadb.utils import get_class
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Names exported by a star-import of this module. Only the two provider
# classes are listed; BasicAuthClientAuthResponse is not.
__all__ = ["BasicAuthServerProvider", "BasicAuthClientProvider"]
class BasicAuthClientAuthResponse(ClientAuthResponse):
    """Client auth response that exposes credentials as an HTTP header.

    The secret handed to the constructor is emitted verbatim after the
    ``Basic `` scheme prefix; no encoding is performed by this class.
    """

    def __init__(self, credentials: SecretStr) -> None:
        """Keep the secret that will later be placed in the header value."""
        self._credentials = credentials

    @override
    def get_auth_info_type(self) -> AuthInfoType:
        """Report that the auth info is carried as a header."""
        return AuthInfoType.HEADER

    @override
    def get_auth_info(self) -> Tuple[str, SecretStr]:
        """Return the header name and the secret-wrapped header value."""
        token = self._credentials.get_secret_value()
        # Re-wrap in SecretStr so the assembled header value stays masked.
        header_value = SecretStr(f"Basic {token}")
        return "Authorization", header_value
@register_provider("basic")
class BasicAuthClientProvider(ClientAuthProvider):
    """Client-side auth provider registered under the name ``"basic"``.

    Credentials are obtained from the credentials provider class named by
    the ``chroma_client_auth_credentials_provider`` setting.
    """

    _credentials_provider: ClientAuthCredentialsProvider[Any]

    def __init__(self, system: System) -> None:
        """Resolve and require the configured client credentials provider.

        Args:
            system: The system whose settings name the credentials provider
                and which supplies the provider instance via ``require``.
        """
        super().__init__(system)
        settings = system.settings
        self._settings = settings
        # Check the setting via settings.require() before it is read below.
        settings.require("chroma_client_auth_credentials_provider")
        provider_cls = get_class(
            str(settings.chroma_client_auth_credentials_provider),
            ClientAuthCredentialsProvider,
        )
        self._credentials_provider = system.require(provider_cls)

    @override
    def authenticate(self) -> ClientAuthResponse:
        """Build the auth response from the provider's current credentials.

        Returns:
            A ``BasicAuthClientAuthResponse`` wrapping the base64-encoded
            (UTF-8) form of the secret returned by the credentials provider.
        """
        secret = self._credentials_provider.get_credentials()
        raw_bytes = f"{secret.get_secret_value()}".encode("utf-8")
        encoded = base64.b64encode(raw_bytes).decode("utf-8")
        return BasicAuthClientAuthResponse(SecretStr(encoded))
@register_provider("basic")
class BasicAuthServerProvider(ServerAuthProvider):
    """Server-side auth provider registered under the name ``"basic"``.

    Incoming requests are checked by parsing their ``Authorization`` header
    into ``BasicAuthCredentials`` and delegating validation and identity
    lookup to the credentials provider named by the
    ``chroma_server_auth_credentials_provider`` setting.
    """

    _credentials_provider: ServerAuthCredentialsProvider

    def __init__(self, system: System) -> None:
        """Resolve and require the configured server credentials provider.

        Args:
            system: The system whose settings name the credentials provider
                and which supplies the provider instance via ``require``.
        """
        super().__init__(system)
        self._settings = system.settings
        # Check the setting via settings.require() before it is read below.
        system.settings.require("chroma_server_auth_credentials_provider")
        self._credentials_provider = cast(
            ServerAuthCredentialsProvider,
            system.require(
                resolve_provider(
                    str(system.settings.chroma_server_auth_credentials_provider),
                    ServerAuthCredentialsProvider,
                )
            ),
        )

    @trace_method("BasicAuthServerProvider.authenticate", OpenTelemetryGranularity.ALL)
    @override
    def authenticate(
        self, request: ServerAuthenticationRequest[Any]
    ) -> SimpleServerAuthenticationResponse:
        """Authenticate a request from its ``Authorization`` header.

        Args:
            request: The incoming request; its ``Authorization`` header is
                read via ``get_auth_info``.

        Returns:
            A response built from the credentials provider's validation
            result and user identity. If anything raises along the way, the
            error is logged and a response built with ``(False, None)`` is
            returned instead; no exception propagates to the caller.
        """
        try:
            _auth_header = request.get_auth_info(AuthInfoType.HEADER, "Authorization")
            # Parse the header once and reuse the result: validation and
            # identity lookup then operate on the same credentials object
            # and the header is not decoded twice per request.
            _credentials = BasicAuthCredentials.from_header(_auth_header)
            _validation = self._credentials_provider.validate_credentials(
                _credentials
            )
            return SimpleServerAuthenticationResponse(
                _validation,
                self._credentials_provider.get_user_identity(_credentials),
            )
        except Exception as e:
            # Deliberately broad: any failure while reading, parsing or
            # validating the header results in a denied response rather
            # than an exception reaching the caller.
            logger.error(f"BasicAuthServerProvider.authenticate failed: {repr(e)}")
            return SimpleServerAuthenticationResponse(False, None)