Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,950 Bytes
4c346eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import pytest
from datasets import Dataset
from ether0.models import QAExample, RewardFunctionInfo, filter_problem_types
class TestModels:
    """Checks that benchmark rows deserialize into the ether0 pydantic models."""

    def test_load(self, ether0_benchmark_test: Dataset) -> None:
        """Parse every benchmark row and spot-check the first example's fields."""
        # Every row is converted (not just the first), so a row that fails
        # QAExample validation makes this test fail.
        parsed_examples = [QAExample(**row) for row in ether0_benchmark_test]
        first = parsed_examples[0]
        assert isinstance(first, QAExample)

        expected_problem = (
            "Generate a SMILES representation for a molecule containing groups:"
            " charged and nitro. It should also have formula C13H12N6O5."
        )
        assert first.id == "00c8bc2d-0bb3-53c2-8bdf-cd19616d4536"
        assert first.problem == expected_problem
        assert first.problem_type == "functional-group"
        assert first.ideal == "Cc1ncc([N+](=O)[O-])n1CC(=O)N/N=C/c1ccc([N+](=O)[O-])cc1"
        assert first.unformatted == "C13H12N6O5,['charged', 'nitro']"

        assert isinstance(first.solution, RewardFunctionInfo)
        solution = first.solution
        expected_solution = (
            "functional_group_eval",
            "('C13H12N6O5', ['charged', 'nitro'])",
            "functional-group",
        )
        # The attribute values, the serialized dump and the expected triple
        # must all agree. Comparing against tuple(model_dump().values()) also
        # pins that the dump holds exactly these three fields, in the order
        # fxn_name, answer_info, problem_type.
        from_attributes = (
            solution.fxn_name,
            solution.answer_info,
            solution.problem_type,
        )
        from_dump = tuple(solution.model_dump().values())
        assert from_attributes == from_dump == expected_solution
# NOTE: the num_expected_types values may have to be adjusted if we add
# more problem types to the dataset.
@pytest.mark.parametrize(
    "filters, should_remove_rows, num_expected_types, should_raise",
    [
        # An empty list and None are both expected to leave the dataset as-is.
        pytest.param([], False, 70, False, id="no-filter-1"),
        pytest.param(None, False, 70, False, id="no-filter-2"),
        # Plain names keep only the listed problem types.
        pytest.param(["reaction-prediction"], True, 1, False, id="include-1"),
        pytest.param(
            ["reaction-prediction", "retro-synthesis"],
            True,
            2,
            False,
            id="include-2",
        ),
        # A "!" prefix drops the named problem type instead.
        pytest.param(["!reaction-prediction"], True, 69, False, id="exclude-1"),
        # Mixing an exclusion with an inclusion must raise, so the
        # should_remove_rows and num_expected_types values in this case are
        # placeholders that are never checked.
        pytest.param(
            ["!reaction-prediction", "molecule-name"],
            True,
            999,
            True,
            id="exclude-include",
        ),
    ],
)
def test_filter_problem_types(
    ether0_benchmark_test: Dataset,
    filters: list[str] | None,
    should_remove_rows: bool,
    num_expected_types: int,
    should_raise: bool,
) -> None:
    """Exercise filter_problem_types with include, exclude and invalid filters.

    For valid filters, check how many distinct problem types survive and
    whether any rows were dropped. For the mixed include/exclude filter,
    check that a ValueError with the expected message is raised.
    """
    if not should_raise:
        filtered = filter_problem_types(ether0_benchmark_test, filters)
        surviving_types = set(filtered["problem_type"])
        assert len(surviving_types) == num_expected_types
        rows_dropped = len(filtered) < len(ether0_benchmark_test)
        assert rows_dropped == should_remove_rows
    else:
        expected_message = "Cannot specify both problem types to keep and to exclude"
        with pytest.raises(ValueError, match=expected_message):
            filter_problem_types(ether0_benchmark_test, filters)
|