Spaces:
Running
on
Zero
Running
on
Zero
import pytest | |
from datasets import Dataset | |
from ether0.models import QAExample, RewardFunctionInfo, filter_problem_types | |
class TestModels:
    """Tests that benchmark rows load into the ``ether0.models`` data types."""

    def test_load(self, ether0_benchmark_test: Dataset) -> None:
        """Parse every row into a ``QAExample`` and pin the first row's content.

        Args:
            ether0_benchmark_test: Iterable of row mappings whose keys are
                passed as keyword arguments to ``QAExample`` (presumably a
                pytest fixture defined elsewhere -- confirm).
        """
        # Build the whole list (not just row 0) so every row is constructed.
        examples = [QAExample(**row) for row in ether0_benchmark_test]
        first = examples[0]
        assert isinstance(first, QAExample)

        expected_problem = (
            "Generate a SMILES representation for a molecule containing groups:"
            " charged and nitro. It should also have formula C13H12N6O5."
        )
        assert first.id == "00c8bc2d-0bb3-53c2-8bdf-cd19616d4536"
        assert first.problem == expected_problem
        assert first.problem_type == "functional-group"
        assert first.ideal == "Cc1ncc([N+](=O)[O-])n1CC(=O)N/N=C/c1ccc([N+](=O)[O-])cc1"
        assert first.unformatted == "C13H12N6O5,['charged', 'nitro']"

        solution = first.solution
        assert isinstance(solution, RewardFunctionInfo)

        expected_solution = (
            "functional_group_eval",
            "('C13H12N6O5', ['charged', 'nitro'])",
            "functional-group",
        )
        # The attribute values and the dumped values must both equal the
        # expected triple, which also pins the order of the dumped fields.
        by_attribute = (
            solution.fxn_name,
            solution.answer_info,
            solution.problem_type,
        )
        assert by_attribute == expected_solution
        assert tuple(solution.model_dump().values()) == expected_solution
# NOTE: the num_expected_types numbers may have to be adjusted if we add
# more problem types to the dataset.
def test_filter_problem_types(
    ether0_benchmark_test: Dataset,
    filters: list[str] | None,
    should_remove_rows: bool,
    num_expected_types: int,
    should_raise: bool,
) -> None:
    """Check ``filter_problem_types`` against the expected outcome for ``filters``.

    When ``should_raise`` is true, the call must raise ``ValueError`` with the
    "keep and exclude" message. Otherwise the filtered result must contain
    exactly ``num_expected_types`` distinct ``problem_type`` values, and it
    must be shorter than the input exactly when ``should_remove_rows`` is true.

    NOTE(review): ``filters``, ``should_remove_rows``, ``num_expected_types``
    and ``should_raise`` are presumably supplied by a
    ``pytest.mark.parametrize`` decorator that is not visible here -- confirm
    it is present in the real file.
    """
    if not should_raise:
        kept = filter_problem_types(ether0_benchmark_test, filters)
        distinct_type_count = len(set(kept["problem_type"]))
        assert distinct_type_count == num_expected_types
        rows_were_removed = len(kept) < len(ether0_benchmark_test)
        assert rows_were_removed == should_remove_rows
        return

    # Error path: the filter call itself is expected to fail.
    expected_message = "Cannot specify both problem types to keep and to exclude"
    with pytest.raises(ValueError, match=expected_message):
        filter_problem_types(ether0_benchmark_test, filters)