import multiprocessing
from typing import Any, Dict, Generator, Optional, Tuple
import pytest
from chromadb import CloudClient
from chromadb.api import ServerAPI
from chromadb.auth.token import TokenTransportHeader
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
from chromadb.errors import AuthorizationError

from chromadb.test.conftest import _await_server, _run_server, find_free_port

# Name (the enum member's ``.name``) of the token transport header. The same
# string is handed to both the server-side and client-side token auth settings
# in ``mock_cloud_server`` so the two ends agree on where the token travels.
TOKEN_TRANSPORT_HEADER: str = TokenTransportHeader.X_CHROMA_TOKEN.name
# Host the CloudClient instances in these tests connect to; the mock server is
# a subprocess spawned on this same machine.
TEST_CLOUD_HOST: str = "localhost"


@pytest.fixture(scope="module")
def valid_token() -> str:
    """Return the one token the mock server is configured to accept.

    Module-scoped so the server fixture and every test share the same value.
    """
    accepted = "valid_token"
    return accepted


@pytest.fixture(scope="module")
def mock_cloud_server(valid_token: str) -> Generator[System, None, None]:
    """Run a token-authenticated Chroma server in a subprocess for this module.

    Setup:

    * Spawns a server with ``_run_server``. The server is configured with the
      token auth server/credentials providers so that it accepts
      ``valid_token`` sent via the ``TOKEN_TRANSPORT_HEADER`` header.
    * Builds a client-side ``System`` that points at that server. It uses the
      FastAPI implementation and the same token and header.
    * Starts the ``System``, waits for the server with ``_await_server``, and
      yields the ``System`` to the tests.

    Teardown always stops the client ``System`` (once it has been started).
    It then kills and joins the server process. Both steps sit in ``finally``
    blocks, so a failure during startup (for example ``_await_server``
    raising) does not leak the spawned server process.

    Args:
        valid_token: The token the server is configured to accept.

    Yields:
        The started client-side ``System``. Its ``settings`` carry the port
        the server was launched on (``chroma_server_http_port``).
    """
    chroma_server_auth_provider: str = "chromadb.auth.token.TokenAuthServerProvider"
    chroma_server_auth_credentials_provider: str = (
        "chromadb.auth.token.TokenConfigServerAuthCredentialsProvider"
    )
    chroma_server_auth_credentials: str = valid_token
    chroma_server_auth_token_transport_header: str = TOKEN_TRANSPORT_HEADER

    port = find_free_port()

    # Positional arguments for ``_run_server``.
    # NOTE(review): the meaning of the literal False/None slots is defined by
    # ``_run_server``'s signature in chromadb.test.conftest. Keep this order in
    # sync with it.
    args: Tuple[
        int,
        bool,
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[Dict[str, Any]],
    ] = (
        port,
        False,
        None,
        chroma_server_auth_provider,
        chroma_server_auth_credentials_provider,
        None,
        chroma_server_auth_credentials,
        chroma_server_auth_token_transport_header,
        None,
        None,
        None,
    )
    # "spawn" starts the server in a fresh interpreter instead of forking the
    # test process.
    ctx = multiprocessing.get_context("spawn")
    proc = ctx.Process(target=_run_server, args=args, daemon=True)
    proc.start()
    try:
        settings = Settings(
            chroma_api_impl="chromadb.api.fastapi.FastAPI",
            chroma_server_host=TEST_CLOUD_HOST,
            chroma_server_http_port=str(port),
            chroma_client_auth_provider="chromadb.auth.token.TokenAuthClientProvider",
            chroma_client_auth_credentials=valid_token,
            chroma_client_auth_token_transport_header=TOKEN_TRANSPORT_HEADER,
        )

        system = System(settings)
        api = system.instance(ServerAPI)
        system.start()
        try:
            # Wait for the spawned server to come up before tests use it.
            _await_server(api)
            yield system
        finally:
            system.stop()
    finally:
        # kill() only signals the child. join() reaps it, so the terminated
        # process does not linger (e.g. as a zombie on POSIX).
        proc.kill()
        proc.join()


def test_valid_key(mock_cloud_server: System, valid_token: str) -> None:
    """A CloudClient presenting the server's configured token can talk to it."""
    server_port = mock_cloud_server.settings.chroma_server_http_port
    client = CloudClient(
        tenant=DEFAULT_TENANT,
        database=DEFAULT_DATABASE,
        api_key=valid_token,
        cloud_host=TEST_CLOUD_HOST,
        cloud_port=server_port,  # type: ignore
        enable_ssl=False,
    )

    beat = client.heartbeat()
    assert beat


def test_invalid_key(mock_cloud_server: System, valid_token: str) -> None:
    """A CloudClient presenting an unrecognised token is rejected.

    Connects to the default tenant and database with a token that differs
    from the configured one and expects ``AuthorizationError``.
    """
    # Appending a suffix guarantees the token differs from the accepted one.
    rejected_token = f"{valid_token}_invalid"
    server_port = mock_cloud_server.settings.chroma_server_http_port
    # NOTE(review): the error may surface either while CloudClient is being
    # constructed or on the first heartbeat() call. Both are inside the
    # raises-context, so either is accepted.
    with pytest.raises(AuthorizationError):
        doomed = CloudClient(
            tenant=DEFAULT_TENANT,
            database=DEFAULT_DATABASE,
            api_key=rejected_token,
            cloud_host=TEST_CLOUD_HOST,
            cloud_port=server_port,  # type: ignore
            enable_ssl=False,
        )
        doomed.heartbeat()