"""Reward-reason enums and dataset problem-type filtering utilities."""
import re
from collections.abc import Collection
from enum import StrEnum, auto
from typing import Any

from datasets import DatasetDict
from pydantic import BaseModel, Field, model_validator

from ether0.utils import TDataset
REWARD_REASON_KEY = "reward_reason" # Sentinel key | |
class RewardReason(StrEnum): | |
FORMAT_FAILED = auto() | |
INVALID_MOL = auto() | |
# Catch-all for invalid values that aren't a molecule or a reaction | |
INVALID_VALUE = auto() | |
# Oracle regression values | |
WRONG_NUMERICAL_ANSWER = auto() | |
# Reaction/retro-synthesis failures | |
INVALID_RXN = auto() | |
WRONG_PRODUCT = auto() | |
PRODUCT_IS_REACTANT = auto() | |
NOT_PURCHASABLE = auto() | |
# Molecule formula/functional group failures | |
WRONG_FORMULA = auto() | |
FAILED_CONSTRAINT = auto() | |
# Unreasonable molecules | |
FAILED_REOS_CHECK = auto() | |
FAILED_RING_CHECK = auto() | |
FAILED_COUNTERION_CHECK = auto() | |
# Really this is a bug, but we don't want to blow up training if a | |
# few bad examples slip through. | |
INVALID_GROUND_TRUTH = auto() | |
# Failover reason if we have an exception during a reward function. | |
# NOTE: not using "failed" or "error" since an unhandled exception | |
# may be something else | |
REWARD_FUNCTION_EXCEPTION = auto() | |
# These are automatically added if no other reason is given | |
WRONG_ANSWER = auto() | |
RIGHT_ANSWER = auto() | |
def set_reason(self, metadata: dict | None) -> None: | |
if metadata is not None: | |
metadata[REWARD_REASON_KEY] = self.value | |
def set_default_reason(cls, reward: float, metadata: dict | None) -> None: | |
if metadata is not None and REWARD_REASON_KEY not in metadata: | |
(cls.RIGHT_ANSWER if reward >= 1.0 else cls.WRONG_ANSWER).set_reason( | |
metadata | |
) | |
SOLUTION_DELIMITER = "!:!"  # Separator for the serialized 3-tuple string form


class RewardFunctionInfo(BaseModel):
    """Metadata used by a reward function to evaluate a solution."""

    fxn_name: str = Field(description="Name of the reward function to use.")
    answer_info: str = Field(
        description="Serialized metadata used by the reward function."
    )
    problem_type: str = Field(description="Problem type, for reference.")

    # Fix: restored the @model_validator(mode="before")/@classmethod decorators.
    # The method has the `cls`-first pre-validator shape and `model_validator`
    # is imported at module scope, but without the decorators pydantic would
    # never invoke it, so string-serialized inputs could not be deserialized.
    @model_validator(mode="before")
    @classmethod
    def check_card_number_not_present(cls, data: Any) -> Any:
        """Accept a SOLUTION_DELIMITER-joined string and expand it to a dict."""
        if isinstance(data, str):
            # Deserialize from a string 3-tuple
            fn, ainfo, pt = data.split(SOLUTION_DELIMITER, maxsplit=2)
            return {"fxn_name": fn, "answer_info": ainfo, "problem_type": pt}
        return data
class QAExample(BaseModel):
    """Question-answer example with reward function info."""

    id: str = Field(description="Unique identifier for this example.")
    problem: str = Field(description="Problem to solve.")
    problem_type: str = Field(description="Problem type, for reference or filtering.")
    # Deserialized via RewardFunctionInfo's validator, so a serialized
    # string form is also accepted here.
    solution: RewardFunctionInfo = Field(
        description="Metadata for the reward function."
    )
    # NOTE(review): ideal and unformatted are Optional in type but have no
    # default, so pydantic treats them as required (None must be passed
    # explicitly) — presumably intentional; confirm with dataset builders.
    ideal: str | None = Field(
        description=(
            "An optional ideal answer. This could be a candidate SMILES, a log10 of"
            " water solubility, or None if having an ideal does not make sense."
        )
    )
    unformatted: str | None = Field(
        description=(
            "Optional raw data used to generate the problem, used for traceability."
        )
    )
def filter_problem_types(
    dataset: TDataset, problem_types: str | Collection[str] | None
) -> TDataset:
    """Filter a dataset down to (or away from) the requested problem types.

    Args:
        dataset: The dataset to filter. Can be a single Dataset or a DatasetDict.
        problem_types: A string or collection of strings specifying the problem
            types to filter by.

            - If None, the original dataset is returned.
            - If a string or a collection of strings:
                - Strings starting with "re:" are treated as regex patterns.
                  If a regex filter is provided, then it must be the only filter.
                - Strings starting with "!" are treated as problem types to exclude.
                - Other strings are treated as exact problem types to include.
                - Mixing inclusion and exclusion rules (e.g. ["type_a", "!type_b"])
                  is not allowed.

    Returns:
        The filtered dataset.
    """
    if problem_types is None:
        return dataset
    if isinstance(problem_types, str):  # Assume single problem type as a string
        problem_types = [problem_types]
    problem_types = {pt.strip() for pt in problem_types}

    # Column names are shared across splits, so any one split suffices.
    any_split = (
        next(iter(dataset.values())) if isinstance(dataset, DatasetDict) else dataset
    )
    # ether0-benchmark uses 'problem_type'; some variants may use 'type'
    type_col = "problem_type" if "problem_type" in any_split.column_names else "type"

    has_regex = any(pt.startswith("re:") for pt in problem_types)
    if has_regex:
        # A regex was passed in
        if len(problem_types) != 1:
            raise ValueError(
                "If filtering by regex, only one filter is supported,"
                f" passed {problem_types}."
            )
        pattern = re.compile(next(iter(problem_types)).removeprefix("re:"))

        def filter_func(row):
            return pattern.match(row[type_col]) is not None

    else:
        # Treat as exact string match
        to_keep = {pt for pt in problem_types if not pt.startswith("!")}
        to_exclude = {
            pt.removeprefix("!") for pt in problem_types if pt.startswith("!")
        }
        if to_keep and to_exclude:
            raise ValueError(
                "Cannot specify both problem types to keep and to exclude,"
                f" passed {problem_types}."
            )
        if to_keep:

            def filter_func(row):
                return row[type_col] in to_keep

        else:

            def filter_func(row):
                return row[type_col] not in to_exclude

    return dataset.filter(filter_func, desc="Filtering problem types")