# chroma/chromadb/test/ef/test_default_ef.py
# badalsahani's picture
# feat: chroma initial deploy
# 287a0bc
# raw
# history blame contribute delete
# 2.72 kB
import shutil
import os
from typing import List, Hashable
import hypothesis.strategies as st
import onnxruntime
import pytest
from hypothesis import given, settings
from chromadb.utils.embedding_functions import ONNXMiniLM_L6_V2, _verify_sha256
def unique_by(x: Hashable) -> Hashable:
    """Return ``x`` unchanged.

    Used as the ``unique_by`` key for the ``st.lists`` strategies in this
    module, so generated elements are deduplicated on their own value.
    """
    return x
@settings(deadline=None)
@given(
    providers=st.lists(
        st.sampled_from(onnxruntime.get_all_providers()).filter(
            lambda name: name not in onnxruntime.get_available_providers()
        ),
        min_size=1,
        unique_by=unique_by,
    )
)
def test_unavailable_provider_multiple(providers: List[str]) -> None:
    """Preferring providers onnxruntime does not report as available must fail.

    ``providers`` is a non-empty list of distinct provider names drawn from
    ``onnxruntime.get_all_providers()`` that are absent from
    ``onnxruntime.get_available_providers()``.
    """
    expected_message = "Preferred providers must be subset of available providers"
    with pytest.raises(ValueError) as exc_info:
        embedder = ONNXMiniLM_L6_V2(preferred_providers=providers)
        embedder(["test"])
    assert expected_message in str(exc_info.value)
# No per-example deadline: each example constructs the embedding function and
# runs it, which may take longer than Hypothesis' default deadline. This
# matches the settings already used by test_unavailable_provider_multiple.
@settings(deadline=None)
@given(
    providers=st.lists(
        st.sampled_from(onnxruntime.get_all_providers()).filter(
            lambda x: x in onnxruntime.get_available_providers()
        ),
        min_size=1,
        unique_by=unique_by,
    )
)
def test_available_provider(providers: List[str]) -> None:
    """Any non-empty list of distinct, available providers must be accepted.

    ``providers`` is drawn from ``onnxruntime.get_all_providers()`` restricted
    to names present in ``onnxruntime.get_available_providers()``. The test
    passes as long as constructing and calling the embedding function does
    not raise.
    """
    ef = ONNXMiniLM_L6_V2(preferred_providers=providers)
    ef(["test"])
def test_warning_no_providers_supplied() -> None:
    """Smoke test: the embedding function runs with no preferred providers.

    NOTE(review): the test name mentions a warning, but nothing here asserts
    that one is emitted; the test only checks that no exception is raised.
    """
    default_embedder = ONNXMiniLM_L6_V2()
    default_embedder(["test"])
@given(
    providers=st.lists(
        st.sampled_from(onnxruntime.get_all_providers()).filter(
            lambda x: x in onnxruntime.get_available_providers()
        ),
        # A list needs at least two elements to contain a duplicate, so do not
        # generate single-element lists only for the filter below to reject.
        min_size=2,
    ).filter(lambda x: len(x) > len(set(x)))
)
def test_provider_repeating(providers: List[str]) -> None:
    """A preferred-provider list containing duplicates must be rejected.

    ``providers`` is a list of available provider names in which at least one
    name appears more than once.
    """
    with pytest.raises(ValueError) as e:
        ef = ONNXMiniLM_L6_V2(preferred_providers=providers)
        ef(["test"])
    assert "Preferred providers must be unique" in str(e.value)
def test_invalid_sha256() -> None:
    """Running with a wrong expected hash must fail SHA256 verification.

    The download directory is removed first so no previously downloaded model
    is present, then the expected hash on the instance is overridden with a
    bogus value before the embedding function is called.
    """
    ef = ONNXMiniLM_L6_V2()
    # ignore_errors=True: the directory may not exist yet (e.g. a fresh
    # environment); without it rmtree raises FileNotFoundError before the
    # behaviour under test is reached. Mirrors test_partial_download.
    shutil.rmtree(ef.DOWNLOAD_PATH, ignore_errors=True)
    # Set up the bad hash outside the raises block so only the call itself
    # can satisfy the expectation.
    ef._MODEL_SHA256 = "invalid"
    with pytest.raises(ValueError) as e:
        ef(["test"])
    assert "does not match expected SHA256 hash" in str(e.value)
def test_partial_download() -> None:
    """An invalid archive left on disk must be replaced by a valid download.

    Simulates an interrupted download by writing junk bytes at the archive
    path, then checks that ``_download_model_if_not_exists`` leaves a file
    there whose SHA256 matches the expected hash and that the embedding
    function returns one result for one input.
    """
    embedder = ONNXMiniLM_L6_V2()
    download_dir = embedder.DOWNLOAD_PATH

    # Start from an empty download directory.
    shutil.rmtree(download_dir, ignore_errors=True)
    os.makedirs(download_dir, exist_ok=True)

    # Plant a bogus archive to stand in for a partial download.
    archive_path = os.path.join(download_dir, embedder.ARCHIVE_FILENAME)
    with open(archive_path, "wb") as archive:
        archive.write(b"invalid")

    # Expected to fetch the model again over the bogus file.
    embedder._download_model_if_not_exists()

    assert os.path.exists(archive_path)
    assert _verify_sha256(str(archive_path), embedder._MODEL_SHA256)

    embeddings = embedder(["test"])
    assert len(embeddings) == 1