import re from collections.abc import Collection from enum import StrEnum, auto from typing import Any from datasets import DatasetDict from pydantic import BaseModel, Field, model_validator from ether0.utils import TDataset REWARD_REASON_KEY = "reward_reason" # Sentinel key class RewardReason(StrEnum): FORMAT_FAILED = auto() INVALID_MOL = auto() # Catch-all for invalid values that aren't a molecule or a reaction INVALID_VALUE = auto() # Oracle regression values WRONG_NUMERICAL_ANSWER = auto() # Reaction/retro-synthesis failures INVALID_RXN = auto() WRONG_PRODUCT = auto() PRODUCT_IS_REACTANT = auto() NOT_PURCHASABLE = auto() # Molecule formula/functional group failures WRONG_FORMULA = auto() FAILED_CONSTRAINT = auto() # Unreasonable molecules FAILED_REOS_CHECK = auto() FAILED_RING_CHECK = auto() FAILED_COUNTERION_CHECK = auto() # Really this is a bug, but we don't want to blow up training if a # few bad examples slip through. INVALID_GROUND_TRUTH = auto() # Failover reason if we have an exception during a reward function. # NOTE: not using "failed" or "error" since an unhandled exception # may be something else REWARD_FUNCTION_EXCEPTION = auto() # These are automatically added if no other reason is given WRONG_ANSWER = auto() RIGHT_ANSWER = auto() def set_reason(self, metadata: dict | None) -> None: if metadata is not None: metadata[REWARD_REASON_KEY] = self.value @classmethod def set_default_reason(cls, reward: float, metadata: dict | None) -> None: if metadata is not None and REWARD_REASON_KEY not in metadata: (cls.RIGHT_ANSWER if reward >= 1.0 else cls.WRONG_ANSWER).set_reason( metadata ) SOLUTION_DELIMITER = "!:!" class RewardFunctionInfo(BaseModel): """Metadata used by a reward function to evaluate a solution.""" fxn_name: str = Field(description="Name of the reward function to use.") answer_info: str = Field( description="Serialized metadata used by the reward function." ) problem_type: str = Field(description="Problem type, for reference.") @model_validator(mode="before") @classmethod def check_card_number_not_present(cls, data: Any) -> Any: if isinstance(data, str): # Deserialize from a string 3-tuple fn, ainfo, pt = data.split(SOLUTION_DELIMITER, maxsplit=2) return {"fxn_name": fn, "answer_info": ainfo, "problem_type": pt} return data class QAExample(BaseModel): """Question-answer example with reward function info.""" id: str = Field(description="Unique identifier for this example.") problem: str = Field(description="Problem to solve.") problem_type: str = Field(description="Problem type, for reference or filtering.") solution: RewardFunctionInfo = Field( description="Metadata for the reward function." ) ideal: str | None = Field( description=( "An optional ideal answer. This could be a candidate SMILES, a log10 of" " water solubility, or None if having an ideal does not make sense." ) ) unformatted: str | None = Field( description=( "Optional raw data used to generate the problem, used for traceability." ) ) def filter_problem_types( dataset: TDataset, problem_types: str | Collection[str] | None ) -> TDataset: """Filter a dataset by problem types. Args: dataset: The dataset to filter. Can be a single Dataset or a DatasetDict. problem_types: A string or collection of strings specifying the problem types to filter by. - If None, the original dataset is returned. - If a string or a collection of strings: - Strings starting with "re:" are treated as regex patterns. If a regex filter is provided, then it must be the only filter. - Strings starting with "!" are treated as problem types to exclude. - Other strings are treated as exact problem types to include. - Mixing inclusion and exclusion rules (e.g. ["type_a", "!type_b"]) is not allowed. Returns: The filtered dataset. """ if problem_types is None: return dataset if isinstance(problem_types, str): # Assume single problem type as a string problem_types = [problem_types] problem_types = {pt.strip() for pt in problem_types} columns = ( next(iter(dataset.values())) if isinstance(dataset, DatasetDict) else dataset ).column_names # ether0-benchmark uses 'problem_type'; some variants may use 'type' type_col = "problem_type" if "problem_type" in columns else "type" if any(pt.startswith("re:") for pt in problem_types): # A regex was passed in if len(problem_types) != 1: raise ValueError( "If filtering by regex, only one filter is supported," f" passed {problem_types}." ) regex = re.compile(next(iter(problem_types)).removeprefix("re:")) def filter_func(x): return regex.match(x[type_col]) is not None else: # Treat as exact string match valid_problem_types = {pt for pt in problem_types if not pt.startswith("!")} invalid_problem_types = { pt.removeprefix("!") for pt in problem_types if pt.startswith("!") } if valid_problem_types: if invalid_problem_types: raise ValueError( "Cannot specify both problem types to keep and to exclude," f" passed {problem_types}." ) def filter_func(x): return x[type_col] in valid_problem_types else: def filter_func(x): return x[type_col] not in invalid_problem_types return dataset.filter(filter_func, desc="Filtering problem types")