chroma / chromadb /auth /token /__init__.py
badalsahani's picture
feat: chroma initial deploy
287a0bc
raw
history blame contribute delete
10.6 kB
import hmac
import json
import logging
import string
from enum import Enum
from typing import List, Optional, Tuple, Any, TypedDict, cast, Dict, TypeVar

from overrides import override
from pydantic import SecretStr
import yaml
from chromadb.auth import (
ServerAuthProvider,
ClientAuthProvider,
ServerAuthenticationRequest,
ServerAuthCredentialsProvider,
AuthInfoType,
ClientAuthCredentialsProvider,
ClientAuthResponse,
SecretStrAbstractCredentials,
AbstractCredentials,
SimpleServerAuthenticationResponse,
SimpleUserIdentity,
)
from chromadb.auth.registry import register_provider, resolve_provider
from chromadb.config import System
from chromadb.telemetry.opentelemetry import (
OpenTelemetryGranularity,
trace_method,
)
from chromadb.utils import get_class
# Generic type variable used to parameterize AbstractCredentials in the
# credentials-provider method signatures of this module.
T = TypeVar("T")

logger = logging.getLogger(name=__name__)

# Names exported by a star-import of this module.
__all__ = ["TokenAuthServerProvider", "TokenAuthClientProvider"]

# Header names a token may travel in; these mirror the values of the
# TokenTransportHeader enum.
_token_transport_headers = [
    "Authorization",
    "X-Chroma-Token",
]
class TokenTransportHeader(Enum):
    """HTTP header used to carry the auth token between client and server.

    Each member's value is the literal header name.
    """

    # Token is sent as "Bearer <token>".
    AUTHORIZATION = "Authorization"
    # Token is sent as the bare header value.
    X_CHROMA_TOKEN = "X-Chroma-Token"
class TokenAuthClientAuthResponse(ClientAuthResponse):
    """Client-side auth response that places a token into a request header.

    The header name and value format depend on the chosen
    TokenTransportHeader: ``Authorization`` carries ``Bearer <token>``,
    while ``X-Chroma-Token`` carries the bare token.
    """

    _credentials: SecretStr
    _token_transport_header: TokenTransportHeader

    def __init__(
        self,
        credentials: SecretStr,
        token_transport_header: TokenTransportHeader = TokenTransportHeader.AUTHORIZATION,
    ) -> None:
        self._credentials = credentials
        self._token_transport_header = token_transport_header

    @override
    def get_auth_info_type(self) -> AuthInfoType:
        """Report that auth info is delivered as an HTTP header."""
        return AuthInfoType.HEADER

    @override
    def get_auth_info(self) -> Tuple[str, SecretStr]:
        """Return ``(header_name, header_value)`` for the configured transport.

        Raises:
            ValueError: if the stored transport header is not a known member.
        """
        transport = self._token_transport_header
        if transport == TokenTransportHeader.AUTHORIZATION:
            bearer_value = f"Bearer {self._credentials.get_secret_value()}"
            return "Authorization", SecretStr(bearer_value)
        if transport == TokenTransportHeader.X_CHROMA_TOKEN:
            raw_value = f"{self._credentials.get_secret_value()}"
            return "X-Chroma-Token", SecretStr(raw_value)
        raise ValueError(f"Invalid token transport header: {transport}")
# Characters permitted in auth tokens: ASCII letters, digits and punctuation,
# i.e. printable ASCII without whitespace. Built once instead of per character.
_VALID_TOKEN_CHARS = frozenset(
    string.digits + string.ascii_letters + string.punctuation
)


def check_token(token: str) -> None:
    """Validate that ``token`` contains only ASCII letters, digits and punctuation.

    The value is passed through ``str()`` before checking. An empty string
    contains no disallowed characters and therefore passes.

    Raises:
        ValueError: if any character is outside the allowed set (for example
            whitespace or non-ASCII characters).
    """
    token_str = str(token)
    if not all(c in _VALID_TOKEN_CHARS for c in token_str):
        # The message previously omitted punctuation, which is accepted.
        raise ValueError(
            "Invalid token. Must contain only ASCII letters, digits, and punctuation."
        )
@register_provider("token_config")
class TokenConfigServerAuthCredentialsProvider(ServerAuthCredentialsProvider):
    """Server credentials provider that checks requests against one static token.

    The expected token is read from the ``chroma_server_auth_credentials``
    setting and validated with ``check_token`` at construction time.
    """

    _token: SecretStr  # expected token, taken from settings

    def __init__(self, system: System) -> None:
        super().__init__(system)
        system.settings.require("chroma_server_auth_credentials")
        token_str = str(system.settings.chroma_server_auth_credentials)
        check_token(token_str)
        self._token = SecretStr(token_str)

    @trace_method(
        "TokenConfigServerAuthCredentialsProvider.validate_credentials",
        OpenTelemetryGranularity.ALL,
    )
    @override
    def validate_credentials(self, credentials: AbstractCredentials[T]) -> bool:
        """Return True iff ``credentials`` carry a token equal to the configured one.

        Returns False (and logs an error) when no ``"token"`` key is present.
        """
        _creds = cast(Dict[str, SecretStr], credentials.get_credentials())
        if "token" not in _creds:
            logger.error("Returned credentials do not contain token")
            return False
        # The supplied token comes from a request header, so compare in
        # constant time. Compare bytes because compare_digest rejects non-ASCII
        # str. "surrogatepass" keeps the encoding total and injective, so the
        # result matches plain string equality.
        supplied = _creds["token"].get_secret_value().encode("utf-8", "surrogatepass")
        expected = self._token.get_secret_value().encode("utf-8", "surrogatepass")
        return hmac.compare_digest(supplied, expected)

    @override
    def get_user_identity(
        self, credentials: AbstractCredentials[T]
    ) -> Optional[SimpleUserIdentity]:
        """A single shared token carries no user identity; always returns None."""
        return None
class Token(TypedDict):
    """One token entry belonging to a user."""

    token: str
    secret: str


class _UserRequired(TypedDict):
    # Keys every user entry must provide.
    id: str
    role: str
    tokens: List[Token]


class User(_UserRequired, total=False):
    """A user entry loaded from the auth credentials file or JSON setting.

    ``tenant`` and ``databases`` may be omitted entirely. The user-token
    credentials provider checks for their presence before reading them.
    """

    tenant: Optional[str]
    databases: Optional[List[str]]
@register_provider("user_token_config")
class UserTokenConfigServerAuthCredentialsProvider(ServerAuthCredentialsProvider):
    """Server credentials provider that maps per-user tokens to user identities.

    Users are loaded from one of two sources:

    * a YAML file named by ``chroma_server_auth_credentials_file``, read from
      its top-level ``users`` key, or
    * a JSON list in ``chroma_server_auth_credentials``.

    Every token is validated with ``check_token`` and must be unique across
    all users.
    """

    _users: List[User]
    _token_user_mapping: Dict[str, str]  # reverse mapping of token to user id

    def __init__(self, system: System) -> None:
        super().__init__(system)
        if system.settings.chroma_server_auth_credentials_file:
            system.settings.require("chroma_server_auth_credentials_file")
            user_file = str(system.settings.chroma_server_auth_credentials_file)
            with open(user_file) as f:
                self._users = cast(List[User], yaml.safe_load(f)["users"])
        elif system.settings.chroma_server_auth_credentials:
            self._users = cast(
                List[User], json.loads(system.settings.chroma_server_auth_credentials)
            )
        else:
            # Fail clearly instead of hitting an AttributeError on self._users below.
            raise ValueError(
                "UserTokenConfigServerAuthCredentialsProvider requires either "
                "chroma_server_auth_credentials_file or "
                "chroma_server_auth_credentials to be set"
            )
        self._token_user_mapping = {}
        for user in self._users:
            for t in user["tokens"]:
                token_str = t["token"]
                check_token(token_str)
                if token_str in self._token_user_mapping:
                    raise ValueError("Token already exists for another user")
                self._token_user_mapping[token_str] = user["id"]

    def find_user_by_id(self, _user_id: str) -> Optional[User]:
        """Return the first loaded user whose ``id`` equals ``_user_id``, else None."""
        return next((user for user in self._users if user["id"] == _user_id), None)

    @override
    def validate_credentials(self, credentials: AbstractCredentials[T]) -> bool:
        """Return True iff the supplied token belongs to some configured user.

        Returns False (and logs an error) when no ``"token"`` key is present.
        """
        _creds = cast(Dict[str, SecretStr], credentials.get_credentials())
        if "token" not in _creds:
            logger.error("Returned credentials do not contain token")
            return False
        return _creds["token"].get_secret_value() in self._token_user_mapping

    @override
    def get_user_identity(
        self, credentials: AbstractCredentials[T]
    ) -> Optional[SimpleUserIdentity]:
        """Map the supplied token to a SimpleUserIdentity.

        Returns None when no ``"token"`` key is present or the token is not
        known. A user's missing ``tenant`` defaults to ``"*"`` and a missing
        ``databases`` defaults to ``["*"]``.
        """
        _creds = cast(Dict[str, SecretStr], credentials.get_credentials())
        if "token" not in _creds:
            logger.error("Returned credentials do not contain token")
            return None
        _token = _creds["token"].get_secret_value()
        if _token not in self._token_user_mapping:
            # Return None rather than raising KeyError: the KeyError's repr
            # would contain the submitted token, and callers such as
            # TokenAuthServerProvider.authenticate log repr(e).
            return None
        # below is just simple identity mapping and may need future work for more
        # complex use cases
        _user_id = self._token_user_mapping[_token]
        _user = self.find_user_by_id(_user_id)
        return SimpleUserIdentity(
            user_id=_user_id,
            tenant=_user["tenant"] if _user and "tenant" in _user else "*",
            databases=_user["databases"] if _user and "databases" in _user else ["*"],
        )
class TokenAuthCredentials(SecretStrAbstractCredentials):
    """Credentials wrapper holding a single secret token.

    ``get_credentials`` exposes the token under the ``"token"`` key, which is
    the key the token credentials providers in this module look up.
    """

    _token: SecretStr

    def __init__(self, token: SecretStr) -> None:
        self._token = token

    @override
    def get_credentials(self) -> Dict[str, SecretStr]:
        return {"token": self._token}

    @staticmethod
    def from_header(
        header: str,
        token_transport_header: TokenTransportHeader = TokenTransportHeader.AUTHORIZATION,
    ) -> "TokenAuthCredentials":
        """Extract the token from a raw header value.

        For ``AUTHORIZATION``, a leading ``Bearer`` scheme followed by a space is
        removed if present. The match is case-insensitive, because HTTP auth
        schemes are. A value without the scheme is taken as the token itself.
        For ``X_CHROMA_TOKEN`` the whole value is the token. Surrounding
        whitespace is stripped in both cases.

        Raises:
            ValueError: if ``token_transport_header`` is not a known member.
        """
        if token_transport_header == TokenTransportHeader.AUTHORIZATION:
            token = header.lstrip()
            # Strip only a *leading* scheme, so "Bearer " appearing elsewhere
            # in the value is left untouched.
            if token[:7].lower() == "bearer ":
                token = token[7:]
            token = token.strip()
        elif token_transport_header == TokenTransportHeader.X_CHROMA_TOKEN:
            token = header.strip()
        else:
            raise ValueError(
                f"Invalid token transport header: {token_transport_header}"
            )
        return TokenAuthCredentials(SecretStr(token))
@register_provider("token")
class TokenAuthServerProvider(ServerAuthProvider):
    """Server auth provider that authenticates requests by a token in a header.

    The header is chosen by ``chroma_server_auth_token_transport_header`` (a
    TokenTransportHeader member name; AUTHORIZATION when unset). The extracted
    token is checked by the credentials provider named in
    ``chroma_server_auth_credentials_provider``.
    """

    _credentials_provider: ServerAuthCredentialsProvider
    _token_transport_header: TokenTransportHeader = TokenTransportHeader.AUTHORIZATION

    def __init__(self, system: System) -> None:
        super().__init__(system)
        self._settings = system.settings
        system.settings.require("chroma_server_auth_credentials_provider")
        self._credentials_provider = cast(
            ServerAuthCredentialsProvider,
            system.require(
                resolve_provider(
                    str(system.settings.chroma_server_auth_credentials_provider),
                    ServerAuthCredentialsProvider,
                )
            ),
        )
        if system.settings.chroma_server_auth_token_transport_header:
            # Looked up by member *name* (e.g. "X_CHROMA_TOKEN"), not header value.
            self._token_transport_header = TokenTransportHeader[
                str(system.settings.chroma_server_auth_token_transport_header)
            ]

    @trace_method("TokenAuthServerProvider.authenticate", OpenTelemetryGranularity.ALL)
    @override
    def authenticate(
        self, request: ServerAuthenticationRequest[Any]
    ) -> SimpleServerAuthenticationResponse:
        """Authenticate ``request`` by the token found in the configured header.

        Returns a failed response (``False``, no identity) when validation
        fails or when any step raises. Exceptions are logged and not
        propagated.
        """
        try:
            _auth_header = request.get_auth_info(
                AuthInfoType.HEADER, self._token_transport_header.value
            )
            _token_creds = TokenAuthCredentials.from_header(
                _auth_header, self._token_transport_header
            )
            # Resolve an identity only for credentials that validated. For a
            # rejected token the lookup is wasted work, and some providers
            # raise on unknown tokens, which would log an exception per bad
            # request.
            if not self._credentials_provider.validate_credentials(_token_creds):
                return SimpleServerAuthenticationResponse(False, None)
            return SimpleServerAuthenticationResponse(
                True,
                self._credentials_provider.get_user_identity(_token_creds),
            )
        except Exception as e:
            logger.error(f"TokenAuthServerProvider.authenticate failed: {repr(e)}")
            return SimpleServerAuthenticationResponse(False, None)
@register_provider("token")
class TokenAuthClientProvider(ClientAuthProvider):
    """Client auth provider that attaches a token to every outgoing request.

    The credentials provider class is named by
    ``chroma_client_auth_credentials_provider``. Its token is validated with
    ``check_token`` at construction. The transport header comes from
    ``chroma_client_auth_token_transport_header`` (a TokenTransportHeader
    member name; AUTHORIZATION when unset).
    """

    _credentials_provider: ClientAuthCredentialsProvider[Any]
    _token_transport_header: TokenTransportHeader = TokenTransportHeader.AUTHORIZATION

    def __init__(self, system: System) -> None:
        super().__init__(system)
        settings = system.settings
        self._settings = settings
        settings.require("chroma_client_auth_credentials_provider")
        provider_cls = get_class(
            str(settings.chroma_client_auth_credentials_provider),
            ClientAuthCredentialsProvider,
        )
        self._credentials_provider = system.require(provider_cls)
        # Fail fast on a malformed token rather than at first request.
        secret = self._credentials_provider.get_credentials()
        check_token(secret.get_secret_value())
        header_name = settings.chroma_client_auth_token_transport_header
        if header_name:
            self._token_transport_header = TokenTransportHeader[str(header_name)]

    @trace_method("TokenAuthClientProvider.authenticate", OpenTelemetryGranularity.ALL)
    @override
    def authenticate(self) -> ClientAuthResponse:
        """Build an auth response carrying the current token in the configured header."""
        return TokenAuthClientAuthResponse(
            self._credentials_provider.get_credentials(),
            self._token_transport_header,
        )