import os
import pytest
from unittest.mock import patch, MagicMock

import chromadb
from chromadb.db.impl.sqlite import SqliteDB
from chromadb.config import System, Settings


@pytest.mark.parametrize("migrations_hash_algorithm", [None, "md5", "sha256"])
@patch("chromadb.api.fastapi.FastAPI")
@patch.dict(os.environ, {}, clear=True)
def test_settings_valid_hash_algorithm(
    api_mock: MagicMock, migrations_hash_algorithm: str
) -> None:
    """
    Ensure that when no hash algorithm or a valid one is provided, the client is set up
    with that value
    """
    if migrations_hash_algorithm:
        settings = chromadb.config.Settings(
            chroma_api_impl="chromadb.api.fastapi.FastAPI",
            is_persistent=True,
            persist_directory="./foo",
            migrations_hash_algorithm=migrations_hash_algorithm,
        )
    else:
        settings = chromadb.config.Settings(
            chroma_api_impl="chromadb.api.fastapi.FastAPI",
            is_persistent=True,
            persist_directory="./foo",
        )

    client = chromadb.Client(settings)

    # Check that the mock was called
    assert api_mock.called

    # Retrieve the arguments with which the mock was called
    # `call_args` returns a tuple, where the first element is a tuple of positional arguments
    # and the second element is a dictionary of keyword arguments. We assume here that
    # the settings object is passed as a positional argument.
    args, kwargs = api_mock.call_args
    passed_settings = args[0] if args else None

    # Check if the default hash algorith was set
    expected_migrations_hash_algorithm = migrations_hash_algorithm or "md5"
    assert passed_settings
    assert (
        getattr(passed_settings.settings, "migrations_hash_algorithm", None)
        == expected_migrations_hash_algorithm
    )
    client.clear_system_cache()


@patch("chromadb.api.fastapi.FastAPI")
@patch.dict(os.environ, {}, clear=True)
def test_settings_invalid_hash_algorithm(mock: MagicMock) -> None:
    """
    Ensure that providing an invalid hash results in a raised exception and the client
    is not called
    """
    with pytest.raises(Exception):
        settings = chromadb.config.Settings(
            chroma_api_impl="chromadb.api.fastapi.FastAPI",
            migrations_hash_algorithm="invalid_hash_alg",
            persist_directory="./foo",
        )

        chromadb.Client(settings)

    assert not mock.called


@pytest.mark.parametrize("migrations_hash_algorithm", ["md5", "sha256"])
@patch("chromadb.db.migrations.verify_migration_sequence")
@patch("chromadb.db.migrations.hashlib")
@patch.dict(os.environ, {}, clear=True)
def test_hashlib_alg(hashlib_mock: MagicMock, verify_migration_sequence_mock: MagicMock, migrations_hash_algorithm: str) -> None:
    """
    Test that only the appropriate hashlib functions are called
    """
    db = SqliteDB(
        System(
            Settings(
                migrations="apply",
                allow_reset=True,
                migrations_hash_algorithm=migrations_hash_algorithm,
            )
        )
    )

    # replace the real migration application call with a mock we can check
    db.apply_migration = MagicMock()  # type: ignore [method-assign]

    # we don't want `verify_migration_sequence` to actually run since a) we're not testing that functionality and
    # b) db may be cached between tests, and we're changing the algorithm, so it may fail.
    # Instead, return a fake unapplied migration (expect `apply_migration` to be called after)
    verify_migration_sequence_mock.return_value = ["unapplied_migration"]

    db.start()

    assert db.apply_migration.called

    # Check if the default hash algorith was set
    expected_migrations_hash_algorithm = migrations_hash_algorithm or "md5"

    # check that the right algorithm was used
    if expected_migrations_hash_algorithm == "md5":
        assert hashlib_mock.md5.called
        assert not hashlib_mock.sha256.called
    elif expected_migrations_hash_algorithm == "sha256":
        assert not hashlib_mock.md5.called
        assert hashlib_mock.sha256.called
    else:
        # we only support the algorithms above
        assert False