chroma / chromadb /test /db /test_hash.py
badalsahani's picture
feat: chroma initial deploy
287a0bc
raw
history blame contribute delete
4.2 kB
import os
from typing import Any, Dict, Optional
from unittest.mock import patch, MagicMock

import pytest

import chromadb
from chromadb.db.impl.sqlite import SqliteDB
from chromadb.config import System, Settings
@pytest.mark.parametrize("migrations_hash_algorithm", [None, "md5", "sha256"])
@patch("chromadb.api.fastapi.FastAPI")
@patch.dict(os.environ, {}, clear=True)
def test_settings_valid_hash_algorithm(
    api_mock: MagicMock, migrations_hash_algorithm: Optional[str]
) -> None:
    """
    Ensure that when no hash algorithm or a valid one is provided, the client is
    set up with that value.

    When ``migrations_hash_algorithm`` is ``None`` the keyword is omitted from the
    ``Settings`` call entirely, so the default is exercised; this test expects
    that default to be ``"md5"``.

    Args:
        api_mock: Mock replacing ``chromadb.api.fastapi.FastAPI`` (injected by
            ``@patch``).
        migrations_hash_algorithm: Algorithm name under test, or ``None`` to
            rely on the default.
    """
    settings_kwargs: Dict[str, Any] = {
        "chroma_api_impl": "chromadb.api.fastapi.FastAPI",
        "is_persistent": True,
        "persist_directory": "./foo",
    }
    if migrations_hash_algorithm:
        # Only pass the option when a value was requested so that the `None`
        # case really goes through the default.
        settings_kwargs["migrations_hash_algorithm"] = migrations_hash_algorithm
    settings = chromadb.config.Settings(**settings_kwargs)

    client = chromadb.Client(settings)
    try:
        # Check that the mocked API implementation was instantiated.
        assert api_mock.called

        # `call_args` unpacks into (positional args, keyword args). The first
        # positional argument is expected to expose the effective settings
        # through its `.settings` attribute.
        args, _ = api_mock.call_args
        assert args, "API implementation was called without positional arguments"
        first_arg = args[0]
        assert first_arg

        # Check that the requested (or default) hash algorithm was set.
        expected_migrations_hash_algorithm = migrations_hash_algorithm or "md5"
        assert (
            getattr(first_arg.settings, "migrations_hash_algorithm", None)
            == expected_migrations_hash_algorithm
        )
    finally:
        # Always clear the cache, even when an assertion above fails, so a
        # failure here cannot leak cached client state into later tests.
        client.clear_system_cache()
@patch("chromadb.api.fastapi.FastAPI")
@patch.dict(os.environ, {}, clear=True)
def test_settings_invalid_hash_algorithm(mock: MagicMock) -> None:
    """
    An unsupported ``migrations_hash_algorithm`` must raise, and the patched API
    implementation must never be instantiated.
    """
    with pytest.raises(Exception):
        # Both the Settings construction and the client creation stay inside
        # the `raises` block; whichever rejects the value satisfies the test.
        rejected_settings = chromadb.config.Settings(
            chroma_api_impl="chromadb.api.fastapi.FastAPI",
            migrations_hash_algorithm="invalid_hash_alg",
            persist_directory="./foo",
        )
        chromadb.Client(rejected_settings)

    # The mocked API class must not have been called at any point.
    mock.assert_not_called()
@pytest.mark.parametrize("migrations_hash_algorithm", ["md5", "sha256"])
@patch("chromadb.db.migrations.verify_migration_sequence")
@patch("chromadb.db.migrations.hashlib")
@patch.dict(os.environ, {}, clear=True)
def test_hashlib_alg(
    hashlib_mock: MagicMock,
    verify_migration_sequence_mock: MagicMock,
    migrations_hash_algorithm: str,
) -> None:
    """
    Verify that starting the database uses only the configured hashlib function.

    ``hashlib`` as seen by ``chromadb.db.migrations`` is replaced with a mock, so
    the test can observe which of ``md5`` / ``sha256`` was invoked.
    """
    db_settings = Settings(
        migrations="apply",
        allow_reset=True,
        migrations_hash_algorithm=migrations_hash_algorithm,
    )
    db = SqliteDB(System(db_settings))

    # Swap the real migration application for a mock we can inspect afterwards.
    db.apply_migration = MagicMock()  # type: ignore [method-assign]

    # `verify_migration_sequence` must not really run: a) that functionality is
    # not under test, and b) the db may be cached between tests while the
    # algorithm changes, which could make it fail. Report one fake unapplied
    # migration instead so that `apply_migration` is expected to be called.
    verify_migration_sequence_mock.return_value = ["unapplied_migration"]

    db.start()
    assert db.apply_migration.called

    # Fall back to "md5" when no algorithm value is given.
    expected_algorithm = migrations_hash_algorithm or "md5"

    # We only support the two algorithms below.
    assert expected_algorithm in ("md5", "sha256")

    # Exactly the expected hashlib constructor may have been used.
    md5_expected = expected_algorithm == "md5"
    assert hashlib_mock.md5.called == md5_expected
    assert hashlib_mock.sha256.called == (not md5_expected)