# chroma/chromadb/test/ef/test_default_ef.py
# commit 287a0bc "feat: chroma initial deploy" (badalsahani)
import shutil
import os
from typing import List, Hashable
import hypothesis.strategies as st
import onnxruntime
import pytest
from hypothesis import given, settings
from chromadb.utils.embedding_functions import ONNXMiniLM_L6_V2, _verify_sha256
def unique_by(x: Hashable) -> Hashable:
    """Identity key for ``st.lists(unique_by=...)``: elements are compared as-is."""
    key = x
    return key
@settings(deadline=None)
@given(
    providers=st.lists(
        st.sampled_from(onnxruntime.get_all_providers()).filter(
            lambda name: name not in onnxruntime.get_available_providers()
        ),
        min_size=1,
        unique_by=unique_by,
    )
)
def test_unavailable_provider_multiple(providers: List[str]) -> None:
    """Requesting providers that this onnxruntime build lacks raises ValueError.

    ``providers`` is a non-empty, duplicate-free list drawn only from names in
    ``get_all_providers()`` that are absent from ``get_available_providers()``.
    """
    with pytest.raises(ValueError) as exc_info:
        ONNXMiniLM_L6_V2(preferred_providers=providers)(["test"])
    message = str(exc_info.value)
    assert "Preferred providers must be subset of available providers" in message
@settings(deadline=None)
@given(
    providers=st.lists(
        st.sampled_from(onnxruntime.get_all_providers()).filter(
            lambda x: x in onnxruntime.get_available_providers()
        ),
        min_size=1,
        unique_by=unique_by,
    )
)
def test_available_provider(providers: List[str]) -> None:
    """Any non-empty, duplicate-free subset of available providers is accepted.

    ``deadline=None``: every example constructs the embedding function and runs
    inference, which presumably loads (and on a clean cache downloads) the ONNX
    model. That is far slower than Hypothesis's default per-example deadline
    and uneven between the first and later examples, which would otherwise
    surface as DeadlineExceeded/Flaky failures.
    """
    ef = ONNXMiniLM_L6_V2(preferred_providers=providers)
    ef(["test"])
def test_warning_no_providers_supplied() -> None:
    """Embedding succeeds when no preferred providers are passed.

    NOTE(review): the test name suggests a warning is expected, but nothing
    here checks for one; consider asserting on it.
    """
    embed = ONNXMiniLM_L6_V2()
    embed(["test"])
@settings(deadline=None)
@given(
    providers=st.lists(
        st.sampled_from(onnxruntime.get_all_providers()).filter(
            lambda x: x in onnxruntime.get_available_providers()
        ),
        min_size=1,
    ).filter(lambda x: len(x) > len(set(x)))
)
def test_provider_repeating(providers: List[str]) -> None:
    """Listing the same available provider more than once raises ValueError.

    ``providers`` is drawn only from available providers and filtered to lists
    that contain at least one duplicate.

    ``deadline=None`` matches the other provider tests. It is needed in case
    the embedding call does slow work (model load/download) before it
    validates the provider list.
    """
    with pytest.raises(ValueError) as e:
        ef = ONNXMiniLM_L6_V2(preferred_providers=providers)
        ef(["test"])
    assert "Preferred providers must be unique" in str(e.value)
def test_invalid_sha256() -> None:
    """A model archive that does not match ``_MODEL_SHA256`` is rejected.

    The download directory is cleared first so that ``ef(["test"])`` cannot
    reuse a previously downloaded archive.
    """
    ef = ONNXMiniLM_L6_V2()
    # ignore_errors: on a clean cache (or when this test runs alone) the
    # directory does not exist yet, and a plain rmtree would raise
    # FileNotFoundError before the behaviour under test is reached.
    shutil.rmtree(ef.DOWNLOAD_PATH, ignore_errors=True)  # clean up any existing models
    # Override the expected hash on this instance only. The assignment cannot
    # raise, so it stays outside pytest.raises to keep the checked region tight.
    ef._MODEL_SHA256 = "invalid"
    with pytest.raises(ValueError) as e:
        ef(["test"])
    assert "does not match expected SHA256 hash" in str(e.value)
def test_partial_download() -> None:
    """A corrupt archive on disk is replaced by a valid re-download.

    Simulates an interrupted download by writing garbage where the archive
    belongs. The test then checks that ``_download_model_if_not_exists``
    leaves an archive whose hash verifies and that embedding works afterwards.
    """
    ef = ONNXMiniLM_L6_V2()
    shutil.rmtree(ef.DOWNLOAD_PATH, ignore_errors=True)  # clean up any existing models
    os.makedirs(ef.DOWNLOAD_PATH, exist_ok=True)
    path = os.path.join(ef.DOWNLOAD_PATH, ef.ARCHIVE_FILENAME)
    with open(path, "wb") as f:  # create invalid file to simulate partial download
        f.write(b"invalid")
    ef._download_model_if_not_exists()  # re-download model
    assert os.path.exists(path)
    # Reuse ``path`` rather than rebuilding it; os.path.join already returns str.
    assert _verify_sha256(path, ef._MODEL_SHA256)
    assert len(ef(["test"])) == 1