import base64
import logging
from typing import Tuple, Any, cast

from overrides import override
from pydantic import SecretStr

from chromadb.auth import (
    ServerAuthProvider,
    ClientAuthProvider,
    ServerAuthenticationRequest,
    ServerAuthCredentialsProvider,
    AuthInfoType,
    BasicAuthCredentials,
    ClientAuthCredentialsProvider,
    ClientAuthResponse,
    SimpleServerAuthenticationResponse,
)
from chromadb.auth.registry import register_provider, resolve_provider
from chromadb.config import System
from chromadb.telemetry.opentelemetry import (
    OpenTelemetryGranularity,
    trace_method,
)
from chromadb.utils import get_class

# Module-level logger; BasicAuthServerProvider.authenticate reports failed
# authentication attempts through it.
logger = logging.getLogger(__name__)

# Public API of this module: only the Basic-auth server and client providers
# are exported (BasicAuthClientAuthResponse is an internal helper).
__all__ = ["BasicAuthServerProvider", "BasicAuthClientProvider"]


class BasicAuthClientAuthResponse(ClientAuthResponse):
    """Client auth response that emits an HTTP ``Authorization: Basic ...`` header.

    The wrapped token is inserted verbatim after the ``Basic`` scheme prefix.
    Within this module it is built by ``BasicAuthClientProvider.authenticate``,
    which base64-encodes the configured credentials first.
    """

    def __init__(self, credentials: SecretStr) -> None:
        # Stored as a SecretStr so the token is not exposed in plain reprs.
        self._credentials = credentials

    @override
    def get_auth_info_type(self) -> AuthInfoType:
        """Report that the auth info travels as an HTTP header."""
        return AuthInfoType.HEADER

    @override
    def get_auth_info(self) -> Tuple[str, SecretStr]:
        """Return the ``(header_name, header_value)`` pair to attach to requests."""
        token = self._credentials.get_secret_value()
        header_value = SecretStr(f"Basic {token}")
        return "Authorization", header_value


@register_provider("basic")
class BasicAuthClientProvider(ClientAuthProvider):
    """Client auth provider that sends credentials using HTTP Basic auth.

    The credentials provider class is named by the
    ``chroma_client_auth_credentials_provider`` setting; it is loaded with
    ``get_class`` and instantiated through ``system.require``.
    """

    _credentials_provider: ClientAuthCredentialsProvider[Any]

    def __init__(self, system: System) -> None:
        super().__init__(system)
        settings = system.settings
        self._settings = settings
        # Fail early if the credentials-provider setting is missing.
        settings.require("chroma_client_auth_credentials_provider")
        provider_cls = get_class(
            str(settings.chroma_client_auth_credentials_provider),
            ClientAuthCredentialsProvider,
        )
        self._credentials_provider = system.require(provider_cls)

    @override
    def authenticate(self) -> ClientAuthResponse:
        """Build a Basic-auth response from the configured credentials.

        The credentials' secret value is formatted as text, UTF-8 encoded,
        base64-encoded, decoded back to a UTF-8 string and wrapped in a
        ``BasicAuthClientAuthResponse``.
        """
        secret = self._credentials_provider.get_credentials().get_secret_value()
        raw_text = f"{secret}"
        encoded_token = base64.b64encode(raw_text.encode("utf-8")).decode("utf-8")
        return BasicAuthClientAuthResponse(SecretStr(encoded_token))


@register_provider("basic")
class BasicAuthServerProvider(ServerAuthProvider):
    """Server auth provider that validates HTTP Basic ``Authorization`` headers.

    Credential validation and user-identity lookup are delegated to the
    credentials provider named by the ``chroma_server_auth_credentials_provider``
    setting. That provider is resolved with ``resolve_provider`` and
    instantiated through ``system.require``.
    """

    _credentials_provider: ServerAuthCredentialsProvider

    def __init__(self, system: System) -> None:
        super().__init__(system)
        self._settings = system.settings
        # Fail early if the credentials-provider setting is missing.
        system.settings.require("chroma_server_auth_credentials_provider")
        # cast() only narrows the static type; it performs no runtime check.
        self._credentials_provider = cast(
            ServerAuthCredentialsProvider,
            system.require(
                resolve_provider(
                    str(system.settings.chroma_server_auth_credentials_provider),
                    ServerAuthCredentialsProvider,
                )
            ),
        )

    @trace_method("BasicAuthServerProvider.authenticate", OpenTelemetryGranularity.ALL)
    @override
    def authenticate(
        self, request: ServerAuthenticationRequest[Any]
    ) -> SimpleServerAuthenticationResponse:
        """Authenticate ``request`` from its ``Authorization`` header.

        The header is parsed once into ``BasicAuthCredentials``. That same
        object is then passed to the credentials provider for validation and
        for the user-identity lookup.

        Returns:
            ``SimpleServerAuthenticationResponse(is_valid, identity)`` where
            ``is_valid`` is the credentials provider's validation result.
            If anything raises while reading or parsing the header, or while
            consulting the credentials provider, the error is logged and
            ``SimpleServerAuthenticationResponse(False, None)`` is returned.
            Exceptions are never propagated to the caller.
        """
        try:
            auth_header = request.get_auth_info(AuthInfoType.HEADER, "Authorization")
            # Parse once and reuse, so validation and identity lookup operate
            # on the same credentials object and the header isn't re-parsed.
            credentials = BasicAuthCredentials.from_header(auth_header)
            is_valid = self._credentials_provider.validate_credentials(credentials)
            identity = self._credentials_provider.get_user_identity(credentials)
            return SimpleServerAuthenticationResponse(is_valid, identity)
        except Exception as e:
            # Deliberate fail-closed boundary: any error (e.g. a missing or
            # malformed header) is treated as an authentication failure.
            logger.error("BasicAuthServerProvider.authenticate failed: %r", e)
            return SimpleServerAuthenticationResponse(False, None)