Spaces:
Running
Running
File size: 2,720 Bytes
287a0bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import shutil
import os
from typing import List, Hashable
import hypothesis.strategies as st
import onnxruntime
import pytest
from hypothesis import given, settings
from chromadb.utils.embedding_functions import ONNXMiniLM_L6_V2, _verify_sha256
def unique_by(x: Hashable) -> Hashable:
    """Return ``x`` unchanged.

    Passed as the ``unique_by`` key of the ``st.lists`` strategies in this
    module, so generated list elements are compared for uniqueness by their
    own value.
    """
    return x
@settings(deadline=None)
@given(
    providers=st.lists(
        st.sampled_from(onnxruntime.get_all_providers()).filter(
            lambda name: name not in onnxruntime.get_available_providers()
        ),
        min_size=1,
        unique_by=unique_by,
    )
)
def test_unavailable_provider_multiple(providers: List[str]) -> None:
    """Providers that onnxruntime knows but does not have available are rejected.

    ``providers`` is a non-empty list of distinct provider names drawn from
    ``onnxruntime.get_all_providers()`` minus the available ones. Building the
    embedding function with them and embedding a document must raise a
    ``ValueError`` carrying the "must be subset" message.
    """
    with pytest.raises(ValueError) as exc_info:
        embedding_fn = ONNXMiniLM_L6_V2(preferred_providers=providers)
        embedding_fn(["test"])
    message = str(exc_info.value)
    assert "Preferred providers must be subset of available providers" in message
# The per-example deadline is disabled because every example constructs the
# embedding function and runs it end to end; that can take longer than
# Hypothesis's default deadline and would otherwise surface as a spurious
# timing failure. This matches test_unavailable_provider_multiple.
@settings(deadline=None)
@given(
    providers=st.lists(
        st.sampled_from(onnxruntime.get_all_providers()).filter(
            lambda x: x in onnxruntime.get_available_providers()
        ),
        min_size=1,
        unique_by=unique_by,
    )
)
def test_available_provider(providers: List[str]) -> None:
    """Any non-empty, duplicate-free list of available providers is accepted.

    ``providers`` is drawn from ``onnxruntime.get_all_providers()`` restricted
    to the names also reported by ``onnxruntime.get_available_providers()``.
    The test passes when constructing the embedding function with them and
    embedding a single document raises nothing.
    """
    ef = ONNXMiniLM_L6_V2(preferred_providers=providers)
    ef(["test"])
def test_warning_no_providers_supplied() -> None:
    """Embed one document with a default-constructed embedding function.

    No ``preferred_providers`` argument is passed; the test passes as long as
    nothing raises.

    NOTE(review): the test name suggests a warning is expected, but nothing
    here captures or asserts on one -- confirm whether that check is missing.
    """
    embedding_fn = ONNXMiniLM_L6_V2()
    embedding_fn(["test"])
@given(
    providers=st.lists(
        st.sampled_from(onnxruntime.get_all_providers()).filter(
            lambda x: x in onnxruntime.get_available_providers()
        ),
        # A list needs at least two elements to contain a duplicate, so
        # single-element lists would always be thrown away by the filter
        # below. Starting at 2 avoids generating them in the first place.
        min_size=2,
    ).filter(lambda x: len(x) > len(set(x)))
)
def test_provider_repeating(providers: List[str]) -> None:
    """A provider list containing duplicates is rejected.

    ``providers`` is a list of available provider names in which at least one
    name occurs more than once. Building the embedding function with it and
    embedding a document must raise a ``ValueError`` carrying the "must be
    unique" message.
    """
    with pytest.raises(ValueError) as e:
        ef = ONNXMiniLM_L6_V2(preferred_providers=providers)
        ef(["test"])
    assert "Preferred providers must be unique" in str(e.value)
def test_invalid_sha256() -> None:
    """A mismatching expected SHA256 makes the embedding call fail.

    The download directory is removed first, then the instance's expected
    hash is overridden with a bogus value. Embedding a document must raise a
    ``ValueError`` whose message reports the hash mismatch.
    """
    ef = ONNXMiniLM_L6_V2()
    # ignore_errors=True: the directory may not exist yet (for example on a
    # fresh machine), and that must not fail the test before it starts.
    shutil.rmtree(ef.DOWNLOAD_PATH, ignore_errors=True)  # clean up any existing models
    # Set on the instance, outside the raises-block, so that only the call
    # under test can satisfy pytest.raises.
    ef._MODEL_SHA256 = "invalid"
    with pytest.raises(ValueError) as e:
        ef(["test"])
    assert "does not match expected SHA256 hash" in str(e.value)
def test_partial_download() -> None:
    """A corrupt archive left on disk is replaced by a valid download.

    Steps: wipe and recreate the download directory, plant a junk file where
    the model archive is expected, then ask the embedding function to download
    the model if needed. Afterwards the archive must exist, must pass
    ``_verify_sha256`` against the expected hash, and embedding one document
    must yield exactly one result.
    """
    embedding_fn = ONNXMiniLM_L6_V2()
    download_dir = embedding_fn.DOWNLOAD_PATH

    # Start from an empty download directory.
    shutil.rmtree(download_dir, ignore_errors=True)
    os.makedirs(download_dir, exist_ok=True)

    # Simulate an interrupted download with a file holding invalid content.
    archive_path = os.path.join(download_dir, embedding_fn.ARCHIVE_FILENAME)
    with open(archive_path, "wb") as stub:
        stub.write(b"invalid")

    # Expected to notice the bad archive and fetch the model again.
    embedding_fn._download_model_if_not_exists()

    assert os.path.exists(archive_path)
    assert _verify_sha256(str(archive_path), embedding_fn._MODEL_SHA256)
    assert len(embedding_fn(["test"])) == 1
|