Spaces:
Sleeping
Sleeping
import json | |
import logging | |
import string | |
from enum import Enum | |
from typing import List, Optional, Tuple, Any, TypedDict, cast, Dict, TypeVar | |
from overrides import override | |
from pydantic import SecretStr | |
import yaml | |
from chromadb.auth import ( | |
ServerAuthProvider, | |
ClientAuthProvider, | |
ServerAuthenticationRequest, | |
ServerAuthCredentialsProvider, | |
AuthInfoType, | |
ClientAuthCredentialsProvider, | |
ClientAuthResponse, | |
SecretStrAbstractCredentials, | |
AbstractCredentials, | |
SimpleServerAuthenticationResponse, | |
SimpleUserIdentity, | |
) | |
from chromadb.auth.registry import register_provider, resolve_provider | |
from chromadb.config import System | |
from chromadb.telemetry.opentelemetry import ( | |
OpenTelemetryGranularity, | |
trace_method, | |
) | |
from chromadb.utils import get_class | |
# NOTE(review): `override`, `register_provider`, `trace_method` and
# `OpenTelemetryGranularity` are imported above but never used in the visible
# code -- confirm that no decorators were lost from the definitions below.

# Type parameter used in the ``AbstractCredentials[T]`` annotations below.
T = TypeVar("T")

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Names exported by a star-import of this module.
__all__ = ["TokenAuthServerProvider", "TokenAuthClientProvider"]

# Header names that can carry a token; same strings as the
# TokenTransportHeader member values.
_token_transport_headers = ["Authorization", "X-Chroma-Token"]
class TokenTransportHeader(Enum):
    """HTTP headers that may carry the auth token.

    Each member's value is the literal header name.
    """

    # Token is sent as ``Bearer <token>`` in this header.
    AUTHORIZATION = "Authorization"
    # Token is sent bare in this header.
    X_CHROMA_TOKEN = "X-Chroma-Token"
class TokenAuthClientAuthResponse(ClientAuthResponse):
    """Client auth response that delivers a token through an HTTP header."""

    _token_transport_header: TokenTransportHeader

    def __init__(
        self,
        credentials: SecretStr,
        token_transport_header: TokenTransportHeader = TokenTransportHeader.AUTHORIZATION,
    ) -> None:
        """Remember the token and the header it should be sent in.

        Args:
            credentials: The token, wrapped as a ``SecretStr``.
            token_transport_header: Which header carries the token.
        """
        self._token_transport_header = token_transport_header
        self._credentials = credentials

    def get_auth_info_type(self) -> AuthInfoType:
        """Return ``AuthInfoType.HEADER``: the token travels in a header."""
        return AuthInfoType.HEADER

    def get_auth_info(self) -> Tuple[str, SecretStr]:
        """Build the ``(header name, header value)`` pair.

        Returns:
            ``("Authorization", "Bearer <token>")`` or
            ``("X-Chroma-Token", "<token>")``, with the value wrapped as a
            ``SecretStr``.

        Raises:
            ValueError: If the configured transport header is neither of the
                two supported members.
        """
        transport = self._token_transport_header
        if transport == TokenTransportHeader.AUTHORIZATION:
            bearer_value = f"Bearer {self._credentials.get_secret_value()}"
            return "Authorization", SecretStr(bearer_value)
        if transport == TokenTransportHeader.X_CHROMA_TOKEN:
            raw_value = f"{self._credentials.get_secret_value()}"
            return "X-Chroma-Token", SecretStr(raw_value)
        raise ValueError(
            f"Invalid token transport header: {self._token_transport_header}"
        )
# Characters a token may contain. Built once at import time instead of being
# re-concatenated for every character that is checked.
_VALID_TOKEN_CHARS = frozenset(
    string.digits + string.ascii_letters + string.punctuation
)


def check_token(token: str) -> None:
    """Validate that *token* uses only the permitted character set.

    The token is coerced with ``str()`` and every character must be an ASCII
    letter, an ASCII digit, or a character from ``string.punctuation``.
    Whitespace and non-ASCII characters are rejected. An empty token passes,
    since there is no offending character.

    Args:
        token: The token to validate.

    Raises:
        ValueError: If the token contains a character outside the allowed set.
    """
    token_str = str(token)
    if not _VALID_TOKEN_CHARS.issuperset(token_str):
        # The message names punctuation because the check accepts it too.
        raise ValueError(
            "Invalid token. Must contain only ASCII letters, digits, and punctuation."
        )
class TokenConfigServerAuthCredentialsProvider(ServerAuthCredentialsProvider):
    """Server credentials provider that accepts one statically configured token.

    The token is read from the ``chroma_server_auth_credentials`` setting.
    """

    _token: SecretStr

    def __init__(self, system: System) -> None:
        """Load and validate the configured token.

        Args:
            system: The system whose settings supply the token.

        Raises:
            ValueError: If the configured token contains characters that
                ``check_token`` rejects.
        """
        super().__init__(system)
        system.settings.require("chroma_server_auth_credentials")
        token_str = str(system.settings.chroma_server_auth_credentials)
        check_token(token_str)
        self._token = SecretStr(token_str)

    def validate_credentials(self, credentials: AbstractCredentials[T]) -> bool:
        """Return True when the presented token equals the configured one.

        Args:
            credentials: Credentials whose ``get_credentials()`` mapping is
                expected to hold the token under the ``"token"`` key.

        Returns:
            False when no ``"token"`` entry is present or the token differs.
        """
        # Imported locally so this block does not rely on a module-level import.
        import hmac

        _creds = cast(Dict[str, SecretStr], credentials.get_credentials())
        if "token" not in _creds:
            logger.error("Returned credentials do not contain token")
            return False
        presented = _creds["token"].get_secret_value()
        expected = self._token.get_secret_value()
        # Constant-time comparison: `==` short-circuits on the first differing
        # character, which can leak how much of the secret was guessed.
        # compare_digest only accepts non-ASCII input as bytes, so both sides
        # are encoded; "surrogatepass" keeps encoding total for any str.
        return hmac.compare_digest(
            presented.encode("utf-8", "surrogatepass"),
            expected.encode("utf-8", "surrogatepass"),
        )

    def get_user_identity(
        self, credentials: AbstractCredentials[T]
    ) -> Optional[SimpleUserIdentity]:
        """Return None: a single shared token carries no user identity."""
        return None
class Token(TypedDict):
    """One entry in a user's ``tokens`` list."""

    # The token string presented by clients.
    token: str
    # NOTE(review): not read anywhere in the visible code -- purpose unconfirmed.
    secret: str
class User(TypedDict):
    """A user record as loaded from the server credentials configuration."""

    # Identifier stored as the value side of the token -> user mapping.
    id: str
    # NOTE(review): not read anywhere in the visible code.
    role: str
    # Tenant reported in the user's identity; "*" is used when the key is absent.
    tenant: Optional[str]
    # Databases reported in the user's identity; ["*"] is used when absent.
    databases: Optional[List[str]]
    # Tokens that authenticate as this user.
    tokens: List[Token]
class UserTokenConfigServerAuthCredentialsProvider(ServerAuthCredentialsProvider):
    """Server credentials provider backed by a list of users with tokens.

    Users are loaded either from a YAML file (top-level ``users`` key) named by
    ``chroma_server_auth_credentials_file``, or from a JSON string in
    ``chroma_server_auth_credentials``. The file takes precedence.
    """

    _users: List[User]
    _token_user_mapping: Dict[str, str]  # reverse mapping of token to user

    def __init__(self, system: System) -> None:
        """Load the users and build the token -> user id mapping.

        Args:
            system: The system whose settings name the credentials source.

        Raises:
            ValueError: If neither credentials setting is configured, if a
                token contains characters that ``check_token`` rejects, or if
                the same token appears more than once.
        """
        super().__init__(system)
        if system.settings.chroma_server_auth_credentials_file:
            system.settings.require("chroma_server_auth_credentials_file")
            user_file = str(system.settings.chroma_server_auth_credentials_file)
            with open(user_file) as f:
                self._users = cast(List[User], yaml.safe_load(f)["users"])
        elif system.settings.chroma_server_auth_credentials:
            self._users = cast(
                List[User], json.loads(system.settings.chroma_server_auth_credentials)
            )
        else:
            # Without this branch `_users` would stay unset and the loop below
            # would fail with an unhelpful AttributeError.
            raise ValueError(
                "No user credentials configured: set "
                "chroma_server_auth_credentials_file or "
                "chroma_server_auth_credentials"
            )
        self._token_user_mapping = {}
        for user in self._users:
            for t in user["tokens"]:
                token_str = t["token"]
                check_token(token_str)
                if token_str in self._token_user_mapping:
                    raise ValueError("Token already exists for another user")
                self._token_user_mapping[token_str] = user["id"]

    def find_user_by_id(self, _user_id: str) -> Optional[User]:
        """Return the first user whose ``id`` equals *_user_id*, else None."""
        return next((user for user in self._users if user["id"] == _user_id), None)

    def validate_credentials(self, credentials: AbstractCredentials[T]) -> bool:
        """Return True when the presented token belongs to a known user.

        Returns False when the credentials hold no ``"token"`` entry.
        """
        _creds = cast(Dict[str, SecretStr], credentials.get_credentials())
        if "token" not in _creds:
            logger.error("Returned credentials do not contain token")
            return False
        return _creds["token"].get_secret_value() in self._token_user_mapping

    def get_user_identity(
        self, credentials: AbstractCredentials[T]
    ) -> Optional[SimpleUserIdentity]:
        """Map the presented token to a ``SimpleUserIdentity``.

        Returns:
            The identity of the token's user, or None when the credentials
            hold no ``"token"`` entry or the token is unknown. ``tenant``
            falls back to ``"*"`` and ``databases`` to ``["*"]`` when the
            user record lacks those keys.
        """
        _creds = cast(Dict[str, SecretStr], credentials.get_credentials())
        if "token" not in _creds:
            logger.error("Returned credentials do not contain token")
            return None
        # below is just simple identity mapping and may need future work for more
        # complex use cases
        # `.get` rather than indexing: an unknown token must not raise
        # KeyError, whose repr would carry the raw token into the caller's logs.
        _user_id = self._token_user_mapping.get(_creds["token"].get_secret_value())
        if _user_id is None:
            return None
        _user = self.find_user_by_id(_user_id)
        return SimpleUserIdentity(
            user_id=_user_id,
            tenant=_user["tenant"] if _user and "tenant" in _user else "*",
            databases=_user["databases"] if _user and "databases" in _user else ["*"],
        )
class TokenAuthCredentials(SecretStrAbstractCredentials):
    """Credentials object wrapping a single token."""

    _token: SecretStr

    def __init__(self, token: SecretStr) -> None:
        """Store the token."""
        self._token = token

    def get_credentials(self) -> Dict[str, SecretStr]:
        """Return the token in a mapping under the ``"token"`` key."""
        return {"token": self._token}

    # Declared static: the function takes no ``self``, so without the decorator
    # it only works when called on the class, never on an instance.
    @staticmethod
    def from_header(
        header: str,
        token_transport_header: TokenTransportHeader = TokenTransportHeader.AUTHORIZATION,
    ) -> "TokenAuthCredentials":
        """
        Extracts token from header and returns a TokenAuthCredentials object.

        Args:
            header: The raw header value.
            token_transport_header: Which header the value came from. For
                ``AUTHORIZATION`` a leading ``"Bearer "`` scheme is removed;
                for ``X_CHROMA_TOKEN`` the value is used as-is. Surrounding
                whitespace is stripped in both cases.

        Raises:
            ValueError: If *token_transport_header* is not a supported member.
        """
        if token_transport_header == TokenTransportHeader.AUTHORIZATION:
            token = header.strip()
            bearer_prefix = "Bearer "
            # Remove the scheme only where it belongs, at the start; a blanket
            # replace() would also delete "Bearer " from inside the value.
            if token.startswith(bearer_prefix):
                token = token[len(bearer_prefix):].strip()
        elif token_transport_header == TokenTransportHeader.X_CHROMA_TOKEN:
            token = header.strip()
        else:
            raise ValueError(
                f"Invalid token transport header: {token_transport_header}"
            )
        return TokenAuthCredentials(SecretStr(token))
class TokenAuthServerProvider(ServerAuthProvider):
    """Server-side auth provider that authenticates requests by header token.

    Token validation and identity lookup are delegated to the credentials
    provider named by ``chroma_server_auth_credentials_provider``.
    """

    _credentials_provider: ServerAuthCredentialsProvider
    _token_transport_header: TokenTransportHeader = TokenTransportHeader.AUTHORIZATION

    def __init__(self, system: System) -> None:
        """Resolve the credentials provider and the token transport header.

        Args:
            system: The system supplying settings and component instances.

        Raises:
            KeyError: If ``chroma_server_auth_token_transport_header`` is set
                to something that is not a ``TokenTransportHeader`` member name.
        """
        super().__init__(system)
        self._settings = system.settings
        system.settings.require("chroma_server_auth_credentials_provider")
        self._credentials_provider = cast(
            ServerAuthCredentialsProvider,
            system.require(
                resolve_provider(
                    str(system.settings.chroma_server_auth_credentials_provider),
                    ServerAuthCredentialsProvider,
                )
            ),
        )
        if system.settings.chroma_server_auth_token_transport_header:
            # Looked up by member NAME (e.g. "X_CHROMA_TOKEN"), not by the
            # header string the member stands for.
            self._token_transport_header = TokenTransportHeader[
                str(system.settings.chroma_server_auth_token_transport_header)
            ]

    def authenticate(
        self, request: ServerAuthenticationRequest[Any]
    ) -> SimpleServerAuthenticationResponse:
        """Authenticate *request* from the token in its configured header.

        Returns:
            A response carrying the validation result and, for accepted
            credentials, the identity reported by the credentials provider.
            Any exception raised along the way is logged and turned into a
            failed response with no identity.
        """
        try:
            _auth_header = request.get_auth_info(
                AuthInfoType.HEADER, self._token_transport_header.value
            )
            _token_creds = TokenAuthCredentials.from_header(
                _auth_header, self._token_transport_header
            )
            _is_valid = self._credentials_provider.validate_credentials(_token_creds)
            if not _is_valid:
                # Do not resolve an identity for rejected credentials: there is
                # none to report, and a lookup on an unknown token can raise
                # with the raw token in the exception, which would be logged.
                return SimpleServerAuthenticationResponse(False, None)
            return SimpleServerAuthenticationResponse(
                _is_valid,
                self._credentials_provider.get_user_identity(_token_creds),
            )
        except Exception as e:
            # Deliberate catch-all: a failure while authenticating must reject
            # the request rather than propagate.
            logger.error("TokenAuthServerProvider.authenticate failed: %r", e)
            return SimpleServerAuthenticationResponse(False, None)
class TokenAuthClientProvider(ClientAuthProvider):
    """Client-side auth provider that sends a token with each request.

    The token comes from the credentials provider class named by
    ``chroma_client_auth_credentials_provider``.
    """

    _credentials_provider: ClientAuthCredentialsProvider[Any]
    _token_transport_header: TokenTransportHeader = TokenTransportHeader.AUTHORIZATION

    def __init__(self, system: System) -> None:
        """Resolve the credentials provider, check its token, pick the header.

        Raises:
            ValueError: If the provider's token contains characters that
                ``check_token`` rejects.
            KeyError: If ``chroma_client_auth_token_transport_header`` is set
                to something that is not a ``TokenTransportHeader`` member name.
        """
        super().__init__(system)
        settings = system.settings
        self._settings = settings
        settings.require("chroma_client_auth_credentials_provider")
        provider_cls = get_class(
            str(settings.chroma_client_auth_credentials_provider),
            ClientAuthCredentialsProvider,
        )
        self._credentials_provider = system.require(provider_cls)
        # Fail at construction time if the token has disallowed characters.
        check_token(self._credentials_provider.get_credentials().get_secret_value())
        header_name = settings.chroma_client_auth_token_transport_header
        if header_name:
            # Looked up by member NAME (e.g. "X_CHROMA_TOKEN").
            self._token_transport_header = TokenTransportHeader[str(header_name)]

    def authenticate(self) -> ClientAuthResponse:
        """Return an auth response carrying the provider's current token."""
        return TokenAuthClientAuthResponse(
            self._credentials_provider.get_credentials(),
            self._token_transport_header,
        )