ether0-inference / tests /test_models.py
jonahkall's picture
Upload 51 files
4c346eb verified
raw
history blame
2.95 kB
import pytest
from datasets import Dataset
from ether0.models import QAExample, RewardFunctionInfo, filter_problem_types
class TestModels:
def test_load(self, ether0_benchmark_test: Dataset) -> None:
ether0_parsed = [QAExample(**r) for r in ether0_benchmark_test]
ex_0 = ether0_parsed[0]
assert isinstance(ex_0, QAExample)
assert ex_0.id == "00c8bc2d-0bb3-53c2-8bdf-cd19616d4536"
assert (
ex_0.problem
== "Generate a SMILES representation for a molecule containing groups:"
" charged and nitro. It should also have formula C13H12N6O5."
)
assert ex_0.problem_type == "functional-group"
assert ex_0.ideal == "Cc1ncc([N+](=O)[O-])n1CC(=O)N/N=C/c1ccc([N+](=O)[O-])cc1"
assert ex_0.unformatted == "C13H12N6O5,['charged', 'nitro']"
assert isinstance(ex_0.solution, RewardFunctionInfo)
ex0_sol = ex_0.solution
assert (
(ex0_sol.fxn_name, ex0_sol.answer_info, ex0_sol.problem_type)
== tuple(ex0_sol.model_dump().values())
== (
"functional_group_eval",
"('C13H12N6O5', ['charged', 'nitro'])",
"functional-group",
)
)
# NOTE: the num_expected_types numbers may have to be adjusted if we add
# more problem types to the dataset.
@pytest.mark.parametrize(
("filters", "should_remove_rows", "num_expected_types", "should_raise"),
[
pytest.param([], False, 70, False, id="no-filter-1"),
pytest.param(None, False, 70, False, id="no-filter-2"),
pytest.param(["reaction-prediction"], True, 1, False, id="include-1"),
pytest.param(
["reaction-prediction", "retro-synthesis"],
True,
2,
False,
id="include-2",
),
pytest.param(["!reaction-prediction"], True, 69, False, id="exclude-1"),
pytest.param(
["!reaction-prediction", "molecule-name"],
# Note that in this case, should_remove_rows and num_expected are just
# dummy values. Filtering should fail before we get there.
True,
999,
True,
id="exclude-include",
),
],
)
def test_filter_problem_types(
ether0_benchmark_test: Dataset,
filters: list[str] | None,
should_remove_rows: bool,
num_expected_types: int,
should_raise: bool,
) -> None:
if should_raise:
with pytest.raises(
ValueError,
match="Cannot specify both problem types to keep and to exclude",
):
filter_problem_types(ether0_benchmark_test, filters)
return
filtered = filter_problem_types(ether0_benchmark_test, filters)
problem_types = set(filtered["problem_type"])
assert len(problem_types) == num_expected_types
assert (len(filtered) < len(ether0_benchmark_test)) == should_remove_rows