Spaces:
Sleeping
Sleeping
import os | |
import pytest | |
from unittest.mock import patch, MagicMock | |
import chromadb | |
from chromadb.db.impl.sqlite import SqliteDB | |
from chromadb.config import System, Settings | |
def test_settings_valid_hash_algorithm(
    api_mock: MagicMock, migrations_hash_algorithm: str
) -> None:
    """
    Ensure that when no hash algorithm or a valid one is provided, the client is set up
    with that value.

    Args:
        api_mock: Mock standing in for the API implementation; the test reads
            ``called`` and ``call_args`` from it.
            NOTE(review): presumably injected by a ``patch`` decorator that is not
            visible in this view -- confirm.
        migrations_hash_algorithm: Algorithm name to configure. A falsy value means
            the setting is omitted and "md5" is expected as the default.
            NOTE(review): presumably supplied by a parametrized fixture -- confirm.
    """
    # Only pass `migrations_hash_algorithm` when a value was provided, so that the
    # falsy case exercises the default of the Settings class.
    if migrations_hash_algorithm:
        settings = chromadb.config.Settings(
            chroma_api_impl="chromadb.api.fastapi.FastAPI",
            is_persistent=True,
            persist_directory="./foo",
            migrations_hash_algorithm=migrations_hash_algorithm,
        )
    else:
        settings = chromadb.config.Settings(
            chroma_api_impl="chromadb.api.fastapi.FastAPI",
            is_persistent=True,
            persist_directory="./foo",
        )

    client = chromadb.Client(settings)

    # Run the checks inside try/finally so the system cache is cleared even when an
    # assertion fails; otherwise a failure here could leak state into later tests.
    try:
        # Check that the mock was called
        assert api_mock.called

        # `call_args` unpacks into a tuple of positional arguments and a dictionary
        # of keyword arguments. We assume here that the object carrying the settings
        # is passed as the first positional argument.
        args, _kwargs = api_mock.call_args
        first_positional = args[0] if args else None
        assert first_positional is not None, (
            "API implementation was called without positional arguments"
        )

        # Check that the requested (or default) hash algorithm was set
        expected_migrations_hash_algorithm = migrations_hash_algorithm or "md5"
        assert (
            getattr(first_positional.settings, "migrations_hash_algorithm", None)
            == expected_migrations_hash_algorithm
        )
    finally:
        client.clear_system_cache()
def test_settings_invalid_hash_algorithm(mock: MagicMock) -> None:
    """
    Verify that an unsupported hash algorithm name is rejected with an exception
    and that the mocked API implementation is never invoked.

    NOTE(review): `mock` is presumably injected by a `patch` decorator that is not
    visible in this view -- confirm.
    """
    unsupported_algorithm = "invalid_hash_alg"

    # Either building the settings or creating the client is allowed to raise;
    # both steps stay inside the same `raises` context.
    with pytest.raises(Exception):
        invalid_settings = chromadb.config.Settings(
            chroma_api_impl="chromadb.api.fastapi.FastAPI",
            migrations_hash_algorithm=unsupported_algorithm,
            persist_directory="./foo",
        )
        chromadb.Client(invalid_settings)

    # The failure must happen before the API implementation is reached.
    assert not mock.called
def test_hashlib_alg(
    hashlib_mock: MagicMock,
    verify_migration_sequence_mock: MagicMock,
    migrations_hash_algorithm: str,
) -> None:
    """
    Test that only the appropriate hashlib functions are called.

    Args:
        hashlib_mock: Mock whose ``md5`` and ``sha256`` attributes are inspected.
            NOTE(review): presumably patches the ``hashlib`` module used by the
            migration code -- confirm against the (not visible) patch decorator.
        verify_migration_sequence_mock: Mock whose return value is replaced with a
            single fake unapplied migration.
            NOTE(review): presumably patches ``verify_migration_sequence`` -- confirm.
        migrations_hash_algorithm: Algorithm name to configure; a falsy value is
            expected to behave like "md5".
    """
    db = SqliteDB(
        System(
            Settings(
                migrations="apply",
                allow_reset=True,
                migrations_hash_algorithm=migrations_hash_algorithm,
            )
        )
    )

    # replace the real migration application call with a mock we can check
    db.apply_migration = MagicMock()  # type: ignore [method-assign]

    # we don't want `verify_migration_sequence` to actually run since a) we're not
    # testing that functionality and b) db may be cached between tests, and we're
    # changing the algorithm, so it may fail.
    # Instead, return a fake unapplied migration (expect `apply_migration` to be
    # called after)
    verify_migration_sequence_mock.return_value = ["unapplied_migration"]

    db.start()

    assert db.apply_migration.called

    # Fall back to the default hash algorithm when none was provided
    expected_migrations_hash_algorithm = migrations_hash_algorithm or "md5"

    # check that the right algorithm was used
    if expected_migrations_hash_algorithm == "md5":
        assert hashlib_mock.md5.called
        assert not hashlib_mock.sha256.called
    elif expected_migrations_hash_algorithm == "sha256":
        assert not hashlib_mock.md5.called
        assert hashlib_mock.sha256.called
    else:
        # we only support the algorithms above; raise explicitly instead of
        # `assert False`, which is stripped under `python -O` and gives no context
        raise AssertionError(
            "unsupported migrations hash algorithm: "
            f"{expected_migrations_hash_algorithm!r}"
        )