Spaces:
Running
on
Zero
Running
on
Zero
import ast | |
import logging | |
import operator | |
import re | |
import unicodedata | |
from collections.abc import Iterable, Mapping, Sequence | |
from typing import Protocol, cast | |
import exmol | |
from pydantic import JsonValue | |
from rdkit import Chem, DataStructs | |
from rdkit.Chem import GetMolFrags, SanitizeMol # pylint: disable=no-name-in-module | |
from rdkit.Chem.rdMolDescriptors import ( # pylint: disable=no-name-in-module | |
CalcMolFormula, | |
GetMorganFingerprintAsBitVect, | |
) | |
from rdkit.Chem.rdmolfiles import MolToSmiles # pylint: disable=no-name-in-module | |
from rdkit.rdBase import BlockLogs | |
from ether0.clients import fetch_forward_rxn, fetch_purchasable, fetch_solubility | |
from ether0.data import is_reasonable_fp, is_reasonable_ring_system, mol_from_smiles | |
from ether0.model_prompts import extract_answer_loose, extract_thought_answer_strict | |
from ether0.models import RewardFunctionInfo, RewardReason | |
# Hold a module-level BlockLogs instance for as long as this module is imported.
# NOTE(review): presumably this silences RDKit's log output while the object
# is alive -- confirm against the RDKit rdBase documentation.
block = BlockLogs()
# Module-scoped logger used by the evaluation functions in this file.
logger = logging.getLogger(__name__)
class RewardEvalFn(Protocol):
    """Structural type shared by every evaluation callable in this module.

    A conforming callable takes the extracted model answer (`yhat`), the
    reference answer information (`y`), the `soft` and `test` mode flags, and an
    optional per-sample `metadata` dictionary, and returns a float reward.
    """

    def __call__(
        self,
        yhat: str,
        y: str,
        soft: bool = False,
        test: bool = False,
        metadata: dict[str, JsonValue] | None = None,
    ) -> float: ...
def formula_diff(formula1: str, formula2: str) -> float:
    """Return the Euclidean (L2) distance between two molecular formulas.

    Only C, H, O, N, F, Cl, Br, P and S are counted; every other element symbol
    is ignored. A symbol without an explicit count contributes one atom, and
    repeated occurrences of the same symbol are summed.

    Args:
        formula1: First molecular formula, e.g. "C2H6O".
        formula2: Second molecular formula.

    Returns:
        Square root of the summed squared per-element count differences.
    """
    # Elements we care about in organic chemistry
    tracked = {"C", "H", "O", "N", "F", "Cl", "Br", "P", "S"}
    token = re.compile(r"([A-Z][a-z]?)(\d*)")

    def tally(formula: str) -> dict[str, int]:
        """Count the tracked elements appearing in one formula."""
        totals = dict.fromkeys(tracked, 0)
        for hit in token.finditer(formula):
            symbol, digits = hit.group(1, 2)
            if symbol in tracked:
                totals[symbol] += int(digits) if digits else 1
        return totals

    first = tally(formula1)
    second = tally(formula2)
    squared_distance = sum((first[el] - second[el]) ** 2 for el in tracked)
    return squared_distance**0.5
def format_reward(
    completions,
    reasoning: bool,
    reward: float = 1.0,
    **kwargs,  # noqa: ARG001
) -> list[float]:
    """Reward completions whose text follows the expected answer format.

    Args:
        completions: Either plain completion strings, or per-sample message
            lists whose first message carries the text under "content".
        reasoning: Forwarded to `extract_thought_answer_strict`.
        reward: Value granted to each correctly formatted completion.
        **kwargs: Ignored; accepted so callers may pass extra keyword arguments.

    Returns:
        One entry per completion: `reward` if an answer could be extracted,
        otherwise 0.0. An empty batch yields an empty list.
    """
    # An empty batch has no first element to inspect, so short-circuit instead
    # of failing with IndexError on completions[0]
    if not completions:
        return []
    if isinstance(completions[0], list):
        completion_contents = [completion[0]["content"] for completion in completions]
    else:
        completion_contents = completions
    # Note we check `is None` since an empty answer still counts as valid
    # formatting.
    return [
        0.0 if extract_thought_answer_strict(text, reasoning)[1] is None else reward
        for text in completion_contents
    ]
# LaTeX-style superscript such as ^{3,7}
SUPERSCRIPT_PATTERN = re.compile(r"\^{([\d,]+)}")
# A single letter wrapped in braces such as {N}
ITALICS_PATTERN = re.compile(r"{([a-zA-Z])}")
# A bracketed alphanumeric run (not nested, no hyphens inside) that follows a
# digit or hyphen and is itself followed by a hyphen
# https://regex101.com/r/6c8smX/1
USELESS_PARENTHESES = re.compile(r"([-\d])[\(\[{]([A-Za-z0-9]+)[\]\)}]-")


def normalize_iupac(s: str) -> str:
    """Strip presentation-only formatting from an IUPAC name for comparison.

    Args:
        s: Original IUPAC name.

    Returns:
        The case-folded name with superscript/italics markup, `$`, `~`, carets
        and redundant brackets removed, and spaces replaced by hyphens.
    """
    name = s.strip().casefold()
    # Rewrite ^{n} as ^(n) first, so the italics pattern below cannot touch it
    name = SUPERSCRIPT_PATTERN.sub(r"^(\1)", name)
    # Unwrap single braced letters; ^{1,5} is never matched since commas are excluded
    name = ITALICS_PATTERN.sub(r"\1", name)
    # Drop leftover markup characters
    for junk in ("$", "~"):
        name = name.replace(junk, "")
    # Drop brackets that are not nested and do not contain hyphens
    name = USELESS_PARENTHESES.sub(r"\1\2-", name)
    # Carets carry no meaning for comparison, and spaces are treated as hyphens
    without_carets = name.replace("^", "")
    return without_carets.replace(" ", "-")
def normalize_unicodes(s: str) -> str:
    """Normalize all Unicode dashes/hyphens to regular hyphen.

    The input is first NFKC-normalized. Every character in Unicode category
    "Pd" (dash punctuation) or "Po" (other punctuation, which also covers
    characters such as "," and ".") is then replaced by "-", and so is the
    minus sign U+2212.

    Args:
        s: Input string with potential Unicode characters.

    Returns:
        Unicode-normalized string.
    """
    s = unicodedata.normalize("NFKC", s)
    return "".join(
        # U+2212 MINUS SIGN is a math symbol (category Sm), so the category
        # test alone would miss it
        "-" if (c == "\u2212" or unicodedata.category(c) in {"Pd", "Po"}) else c
        for c in s
    )
def is_reasonable_molecule(
    mol: Chem.Mol,
    metadata: dict[str, JsonValue] | None,
    test: bool,  # noqa: ARG001
    ref_mol: Chem.Mol | None = None,
) -> bool:
    """Check a molecule against the module's plausibility heuristics.

    The checks run in order: sanitization, fragment count / counter-ion size,
    ring-system check, then fingerprint check. The first failing check records
    its reason in `metadata` and ends the evaluation.

    Args:
        mol: Molecule to check.
        metadata: Optional dictionary that receives the failure reason.
        test: Unused; the same checks apply at train and test time by convention.
        ref_mol: Optional reference molecule forwarded to the ring and
            fingerprint checks.

    Returns:
        True when every check passes, False otherwise.
    """
    # Valence is always validated
    try:
        SanitizeMol(mol)
    except Exception:
        RewardReason.INVALID_MOL.set_reason(metadata)
        return False

    # A valid response is a single molecule, optionally with one counter-ion.
    # Fragments are ordered from most to fewest atoms so the main molecule
    # comes first.
    fragments = sorted(
        GetMolFrags(mol, asMols=True), key=lambda frag: frag.GetNumAtoms(), reverse=True
    )
    fragment_count = len(fragments)
    if fragment_count > 2:  # noqa: PLR2004
        # More than a molecule plus one counter-ion
        RewardReason.FAILED_COUNTERION_CHECK.set_reason(metadata)
        return False
    if fragment_count == 2:  # noqa: PLR2004
        # Treat the smaller fragment as the counter-ion and insist it is small
        largest, counterion = fragments
        if counterion.GetNumHeavyAtoms() > 5:  # noqa: PLR2004
            RewardReason.FAILED_COUNTERION_CHECK.set_reason(metadata)
            return False
        # Only the largest fragment is evaluated from here on
        mol = largest

    if not is_reasonable_ring_system(mol, ref_mol):
        RewardReason.FAILED_RING_CHECK.set_reason(metadata)
        return False

    if not is_reasonable_fp(mol, ref_mol):
        RewardReason.FAILED_REOS_CHECK.set_reason(metadata)
        return False

    return True
FULL_SMILES_KEY = "full_smiles" | |
def set_full_smiles(smiles: str, metadata: dict[str, JsonValue] | None) -> None: | |
if metadata is not None: | |
metadata[FULL_SMILES_KEY] = smiles | |
# SMARTS queries for substructures treated as undesirable; a molecule that
# matches any one of them is flagged by contains_bad_substruct.
BAD_SMARTS_PATTERNS = [
    "[#16]-[#16]-[#16]",  # Three singly bonded sulfurs in a row
    "[#8]~[#8]",  # Peroxides (two oxygens joined by any bond)
    "[#7]-[NH2]",  # Hydrazines
    "[#7]-[NH3]",  # Nitrogen bonded to an NH3 group (weird charged amine)
    "[#7]~[#7]~[#7]",  # Three or more nitrogens in a row
    "[NX2](=[OX1])[O;$([X2]),$([X1-])]",  # Nitrite
    "[SX2][NX2]=[OX1]",  # Thionitrite
    "[$([NX3](=[OX1])(=[OX1])[O;$([X2]),$([X1-])]),$([NX3+]([OX1-])(=[OX1])[O;$([X2]),$([X1-])])]",  # Nitrate  # noqa: E501
    "[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]",  # Nitro
    "[NX2](=[OX1])[!#7;!#8]",  # Nitroso
    "[CX4]" + ("-[CX4]" * 6),  # Chain of 7 or more four-connected carbons
]
def contains_bad_substruct(mol: Chem.Mol) -> bool:
    """Return True if `mol` matches at least one entry of BAD_SMARTS_PATTERNS."""
    for smarts in BAD_SMARTS_PATTERNS:
        query = Chem.MolFromSmarts(smarts)
        if mol.HasSubstructMatch(query):
            # Stop at the first hit
            return True
    return False
def rxn_eval(
    yhat: str,
    y: str,
    soft: bool = False,  # noqa: ARG001
    test: bool = False,  # noqa: ARG001
    metadata: dict[str, JsonValue] | None = None,  # noqa: ARG001
) -> float:
    """Compare two strings after normalization, returning 1.0 or 0.0.

    The strings are first compared after `normalize_iupac` (which case-folds
    them). If they differ, they are compared once more after additionally
    applying `normalize_unicodes`.

    Args:
        yhat: Predicted string.
        y: Reference string.
        soft: Unused.
        test: Unused.
        metadata: Unused.

    Returns:
        1.0 if either comparison matches, otherwise 0.0.
    """
    # IUPAC-oriented normalization; should not affect other kinds of answers
    predicted = normalize_iupac(yhat)
    expected = normalize_iupac(y)
    if predicted == expected:
        return 1.0
    # Second chance: normalize Unicode punctuation on top of the above
    if normalize_unicodes(predicted) == normalize_unicodes(expected):
        return 1.0
    return 0.0
def str_eval(
    yhat: str,
    y: str,
    soft: bool = False,  # noqa: ARG001
    test: bool = False,  # noqa: ARG001
    metadata: dict[str, JsonValue] | None = None,
) -> float:
    """Return 1.0 if the strings are equal after `normalize_iupac`, else 0.0.

    The raw prediction is also recorded in `metadata` as the full SMILES.

    Args:
        yhat: Predicted string.
        y: Reference string.
        soft: Unused.
        test: Unused.
        metadata: Optional dictionary that receives the prediction.
    """
    set_full_smiles(yhat, metadata)
    matched = normalize_iupac(yhat) == normalize_iupac(y)
    return float(matched)
def valid_mol_eval(
    yhat: str,
    y: str,
    soft: bool = False,  # noqa: ARG001
    test: bool = False,
    metadata: dict[str, JsonValue] | None = None,
) -> float:
    """Validate if yhat is a valid SMILES string, when appended to y.

    Two candidates are tried in order: `yhat` alone (treated as a full SMILES)
    and `y + yhat` (treated as a completion). A candidate is only considered if
    it starts with `y`. The first candidate that parses decides the result.

    Args:
        yhat: Model-predicted SMILES string or partial completion.
        y: Base SMILES string (e.g. "O=C1CCC2=CC=C(O)C(OC)=C2C#CCC2=CC3=C4") to append
            yhat and check validity.
        soft: Unused.
        test: Forwarded to `is_reasonable_molecule`.
        metadata: Optional dictionary that receives the evaluated SMILES and
            any failure reason.

    Returns:
        1.0 if a candidate parses and passes `is_reasonable_molecule`,
        0.0 otherwise.
    """
    if not yhat:
        RewardReason.INVALID_MOL.set_reason(metadata)
        return 0.0
    # First attempt yhat alone (assuming full SMILES), then try y+yhat (assuming
    # partial) if that fails
    for smiles in (yhat, y + yhat):
        if not smiles.startswith(y):
            # only accept a solution containing the answer
            continue
        try:
            mol = mol_from_smiles(smiles)
        except Exception:
            # Report the candidate that actually failed, which may be y + yhat
            logger.exception(
                f"Failed to construct molecule from SMILES string {smiles!r}."
            )
            continue
        if mol is not None:
            set_full_smiles(smiles, metadata)
            return 1.0 if is_reasonable_molecule(mol, metadata, test) else 0.0
    # Nothing worked - mark as invalid
    RewardReason.INVALID_MOL.set_reason(metadata)
    return 0.0
# Tanimoto similarity at or above this value is accepted as a match by the
# evaluation functions that compare against a reference molecule.
SMOOTH_THRESHOLD_TANIMOTO_SIMILARITY = 0.7  # close enough
def tanimoto_similarity(
    m1: Chem.Mol | None, m2: Chem.Mol | None, atom_threshold: float = 10.0
) -> float:
    """Return the Tanimoto similarity of two molecules' Morgan fingerprints.

    `atom_threshold` is a relative fraction (e.g. `0.2` for 20%): when the
    heavy-atom counts differ by more than that fraction of the larger count,
    the similarity is reported as 0.0. This guards against fingerprints that
    look alike although one molecule is much bigger. The relative difference
    can never exceed 1, so the default of 10.0 has no practical effect.

    Args:
        m1: First molecule, or None.
        m2: Second molecule, or None.
        atom_threshold: Largest tolerated relative heavy-atom difference.

    Returns:
        0.0 if either molecule is None or the size check fails, otherwise the
        Tanimoto similarity of the two radius-2 Morgan bit vectors.
    """
    if m1 is None or m2 is None:
        return 0.0
    fp1 = GetMorganFingerprintAsBitVect(m1, 2)
    fp2 = GetMorganFingerprintAsBitVect(m2, 2)
    heavy1, heavy2 = m1.GetNumHeavyAtoms(), m2.GetNumHeavyAtoms()
    larger = max(heavy1, heavy2)
    # The size check is skipped when neither molecule has heavy atoms, which
    # also avoids dividing by zero
    if larger > 0 and abs(heavy1 - heavy2) / larger > atom_threshold:
        return 0.0
    return DataStructs.TanimotoSimilarity(fp1, fp2)
def exact_mol_match(m1: Chem.Mol, m2: Chem.Mol) -> float:
    """Return 1.0 if both molecules have the same canonical isomeric SMILES, else 0.0."""
    canonical = [
        MolToSmiles(mol, canonical=True, isomericSmiles=True)  # noqa: FURB120
        for mol in (m1, m2)
    ]
    return float(canonical[0] == canonical[1])
def get_largest_mol(smiles: str) -> Chem.Mol | None:
    """Return the fragment with the most atoms from a dot-separated SMILES.

    Fragments whose SMILES text is three characters or shorter are skipped
    (this removes counter-ions), as are fragments that fail to parse.

    Args:
        smiles: SMILES string, possibly holding several "."-separated fragments.

    Returns:
        The parsed fragment with the highest atom count (the first one on a
        tie), or None when no fragment qualifies.
    """
    candidates = []
    for part in smiles.split("."):
        if len(part) <= 3:  # noqa: PLR2004
            continue
        # Parse each fragment once and keep the result
        mol = mol_from_smiles(part)
        if mol is not None:
            candidates.append(mol)
    if not candidates:
        return None
    return max(candidates, key=lambda candidate: candidate.GetNumAtoms())
def product_eval(
    yhat: str,
    y: str,
    soft: bool = False,
    test: bool = False,  # noqa: ARG001
    metadata: dict[str, JsonValue] | None = None,
) -> float:
    """Score the largest fragment of a prediction against that of the reference.

    Args:
        yhat: Predicted SMILES, possibly with several fragments.
        y: Reference SMILES, possibly with several fragments.
        soft: If True score by Tanimoto similarity, otherwise by exact match.
        test: Unused.
        metadata: Optional dictionary that receives the evaluated SMILES and
            any failure reason.

    Returns:
        Reward in [0, 1].
    """
    predicted = get_largest_mol(yhat)
    reference = get_largest_mol(y)
    if predicted is None:
        RewardReason.INVALID_MOL.set_reason(metadata)
        return 0.0
    if reference is None:
        RewardReason.INVALID_GROUND_TRUTH.set_reason(metadata)
        logger.warning(f"Invalid ground truth molecule {y!r}.")
        return 0.0
    # Record the fragment actually scored, since yhat may hold several molecules
    set_full_smiles(MolToSmiles(predicted), metadata)
    scorer = tanimoto_similarity if soft else exact_mol_match
    return scorer(predicted, reference)
def caption_eval(
    yhat: str,
    y: str,
    soft: bool = False,
    test: bool = False,
    metadata: dict[str, JsonValue] | None = None,
) -> float:
    """Delegate to `product_eval`, additionally storing the Tanimoto similarity.

    When `metadata` is given, the similarity between the largest fragments of
    `yhat` and `y` is written to `metadata["tanimoto"]` (0.0 if either one
    cannot be obtained) before the reward is computed.
    """
    if metadata is not None:
        predicted = get_largest_mol(yhat)
        reference = get_largest_mol(y)
        if predicted is None or reference is None:
            similarity = 0.0
        else:
            similarity = tanimoto_similarity(predicted, reference)
        metadata["tanimoto"] = similarity
    return product_eval(yhat, y, soft, test, metadata)
def formula_eval(
    yhat: str,
    y: str,
    soft: bool = False,
    test: bool = False,
    metadata: dict[str, JsonValue] | None = None,
) -> float:
    """Reward a prediction sharing the reference's formula and resembling it.

    Args:
        yhat: Predicted SMILES.
        y: Reference SMILES.
        soft: If True, a reasonable molecule with the right formula but low
            similarity earns 0.5 instead of 0.0.
        test: Forwarded to `is_reasonable_molecule`.
        metadata: Optional dictionary that receives the prediction and any
            failure reason.

    Returns:
        1.0 when the formula matches, the molecule is reasonable and its
        Tanimoto similarity reaches SMOOTH_THRESHOLD_TANIMOTO_SIMILARITY;
        0.5 in the soft partial-credit case; otherwise 0.0.
    """
    set_full_smiles(yhat, metadata)
    predicted = mol_from_smiles(yhat)
    reference = mol_from_smiles(y)
    if predicted is None:
        RewardReason.INVALID_MOL.set_reason(metadata)
        return 0.0
    if reference is None:
        RewardReason.INVALID_GROUND_TRUTH.set_reason(metadata)
        logger.warning(f"Invalid ground truth molecule {y!r}.")
        return 0.0
    if CalcMolFormula(predicted) != CalcMolFormula(reference):
        RewardReason.WRONG_FORMULA.set_reason(metadata)
        return 0.0
    if not is_reasonable_molecule(predicted, metadata, test, ref_mol=reference):
        return 0.0
    if tanimoto_similarity(predicted, reference) >= SMOOTH_THRESHOLD_TANIMOTO_SIMILARITY:
        return 1.0
    # Right formula but not similar enough: partial credit only in soft mode
    return 0.5 if soft else 0.0
def functional_group_eval(
    yhat: str,
    y: str,
    soft: bool = False,
    test: bool = False,
    metadata: dict[str, JsonValue] | None = None,
) -> float:
    """Reward a prediction with the required formula and functional groups.

    Args:
        yhat: Predicted SMILES.
        y: Python literal whose first item is the target formula and whose
            second item is the list of required functional-group names.
        soft: If True, a reasonable molecule with the right formula that lacks
            some required group earns 0.5 instead of 0.0.
        test: Forwarded to `is_reasonable_molecule`.
        metadata: Optional dictionary that receives the prediction and any
            failure reason.

    Returns:
        1.0 when the formula matches, the molecule is reasonable and every
        required group is present (compared case-insensitively); 0.5 in the
        soft partial-credit case; otherwise 0.0.
    """
    set_full_smiles(yhat, metadata)
    predicted = mol_from_smiles(yhat)
    if predicted is None:
        RewardReason.INVALID_MOL.set_reason(metadata)
        return 0.0
    spec: tuple[str, list[str]] = ast.literal_eval(y)
    required_formula = spec[0]
    required_groups = {name.lower() for name in spec[1]}
    if CalcMolFormula(predicted) != required_formula:
        RewardReason.WRONG_FORMULA.set_reason(metadata)
        return 0.0
    present_groups: set[str] = {
        name.lower() for name in exmol.get_functional_groups(predicted, return_all=True)
    }
    if not is_reasonable_molecule(predicted, metadata, test):
        return 0.0
    if required_groups <= present_groups:
        return 1.0
    # Right formula but missing groups: partial credit only in soft mode
    return 0.5 if soft else 0.0
def oracle_solubility_eval(
    yhat: str,
    y: str,
    soft: bool = False,  # noqa: ARG001
    test: bool = False,
    metadata: dict[str, JsonValue] | None = None,
) -> float:
    """Evaluate solubility prediction using remote, giving a reward in [0, 1].

    The prediction must be a single parseable molecule, satisfy the structural
    constraint encoded in `y`, pass `is_reasonable_molecule`, and have a
    fetched solubility within 0.49 of the target.

    Args:
        yhat: Predicted SMILES.
        y: Python literal holding, in order: the constraint type ("scaffold",
            "groups" or "tanimoto"), the constraint data, the target
            solubility, and a direction (unused here).
        soft: Unused.
        test: Forwarded to `is_reasonable_molecule`.
        metadata: Optional dictionary that receives the prediction and any
            failure reason.

    Returns:
        1.0 if every check passes, otherwise 0.0.

    Raises:
        NotImplementedError: If a "scaffold" constraint cannot be parsed.
        ValueError: If the constraint type is unknown.
    """
    set_full_smiles(yhat, metadata)
    # we only want single molecules
    # (note: no failure reason is recorded on this path)
    if "." in yhat:
        return 0.0
    mhat = mol_from_smiles(yhat)
    if mhat is None:
        RewardReason.INVALID_MOL.set_reason(metadata)
        return 0.0
    y_args: tuple[str, str | list[str], float | str, str] = ast.literal_eval(y)
    constraint_type, constraint_data = y_args[:2]
    target = float(y_args[2])
    # Unused: direction = y_args[3]  # noqa: ERA001
    # Stays None for the "groups" constraint, which has no reference molecule
    ref_mol: Chem.Mol | None = None
    # first check constraint
    if constraint_type == "scaffold":
        # The prediction must contain the scaffold as a substructure
        ref_mol = mol_from_smiles(cast(str, constraint_data))
        if ref_mol is None:
            raise NotImplementedError(
                f"Didn't handle when {constraint_data=} is invalid."
            )
        if not mhat.HasSubstructMatch(ref_mol):
            RewardReason.FAILED_CONSTRAINT.set_reason(metadata)
            return 0.0
    elif constraint_type == "groups":
        # At least one of the requested groups must be present (case-insensitive)
        groups = [g.lower() for g in exmol.get_functional_groups(mhat, return_all=True)]
        if not any(group.lower() in groups for group in constraint_data):
            RewardReason.FAILED_CONSTRAINT.set_reason(metadata)
            return 0.0
    elif constraint_type == "tanimoto":
        # The prediction must be similar to the reference; an unparseable
        # reference yields similarity 0.0 and therefore fails the constraint
        ref_mol = mol_from_smiles(cast(str, constraint_data))
        if (
            tanimoto_similarity(mhat, ref_mol, atom_threshold=0.2)
            < SMOOTH_THRESHOLD_TANIMOTO_SIMILARITY
        ):
            RewardReason.FAILED_CONSTRAINT.set_reason(metadata)
            return 0.0
    else:
        raise ValueError(f"Unknown constraint type: {constraint_type}")
    if not is_reasonable_molecule(mhat, metadata, test, ref_mol=ref_mol):
        return 0.0
    # make sure we hit the target
    result = fetch_solubility(yhat)
    if "solubility" in result:
        solubility = result["solubility"]
        delta = solubility - target
        # hard coded to typical solubility accuracies
        # we subtract 0.01 because some questions ask for
        # 0.5 change and we don't want restatements to
        # be matches
        if abs(delta) > (0.5 - 0.01):
            RewardReason.WRONG_NUMERICAL_ANSWER.set_reason(metadata)
            return 0.0
        return 1.0
    # The response carried no "solubility" entry
    RewardReason.INVALID_MOL.set_reason(metadata)
    return 0.0
def oracle_rxn_eval(
    yhat: str,
    y: str,
    soft: bool = False,
    test: bool = False,  # noqa: ARG001
    metadata: dict[str, JsonValue] | None = None,
) -> float:
    """Evaluate forward reaction prediction using remote, giving a reward in [0, 1].

    `yhat` is expected in the form `reactants>reagents>products`, with species
    in each section separated by ".". It is validated syntactically, its
    species are parsed, the reactants are checked for purchasability, and the
    product returned by `fetch_forward_rxn` is compared with the reference.

    Args:
        yhat: Predicted reaction string.
        y: Reference product SMILES.
        soft: If True return the Tanimoto similarity between the fetched
            product and the reference instead of requiring an exact match.
        test: Unused.
        metadata: Optional dictionary that receives any failure reason.

    Returns:
        The Tanimoto similarity when `soft` is True and a product is returned;
        otherwise 1.0 for an exact product match and 0.0 for any failure.
    """
    # Cheap heuristics first: a reaction arrow and at least one "." separator
    if ">" not in yhat or "." not in yhat:
        RewardReason.INVALID_RXN.set_reason(metadata)
        return 0.0
    # make sure there are not more than two angle brackets
    if yhat.count(">") > 2:  # noqa: PLR2004
        RewardReason.INVALID_RXN.set_reason(metadata)
        return 0.0
    # ok now do real check on regex after heuristic checks
    # adapted partly from https://gist.github.com/lsauer/1312860/264ae813c2bd2c27a769d261c8c6b38da34e22fb
    # https://regex101.com/r/9bdE6H/1
    # basically SMILES_THINGS>SMILES_THINGS | empty>
    # The pattern requires two ">" characters, so together with the count check
    # above the three-way split below always has indices 0, 1 and 2
    if not re.match(
        r"^[^J][a-z0-9@+\-\[\]\(\)\\\/%=#$\.]{6,}>[a-z0-9@+\-\[\]\(\)\\\/%=#$\.]{0,}>",
        yhat,
        re.IGNORECASE,  # lower = aromatic, which we're fine matching
    ):
        RewardReason.INVALID_RXN.set_reason(metadata)
        return 0.0
    ymol = mol_from_smiles(y)
    if ymol is None:
        RewardReason.INVALID_GROUND_TRUTH.set_reason(metadata)
        logger.warning(f"Invalid ground truth molecule {y!r}.")
        return 0.0
    # Every reactant must parse
    reactant_smi = yhat.split(">")[0].split(".")
    reactants = [mol_from_smiles(r) for r in reactant_smi]
    if not all(x is not None for x in reactants):
        RewardReason.INVALID_MOL.set_reason(metadata)
        return 0.0
    # Reagents are optional; blank entries are skipped
    reagents = [mol_from_smiles(r) for r in yhat.split(">")[1].split(".") if r.strip()]
    if not all(x is not None for x in reagents):
        RewardReason.INVALID_MOL.set_reason(metadata)
        return 0.0
    # check products, if present, contain the desired product
    products = [mol_from_smiles(r) for r in yhat.split(">")[2].split(".") if r.strip()]
    # notice we pass if there are no products
    if products:
        if not all(x is not None for x in products):
            RewardReason.INVALID_MOL.set_reason(metadata)
            return 0.0
        if not any(exact_mol_match(m, ymol) == 1.0 for m in products):  # type: ignore[arg-type]
            RewardReason.INVALID_RXN.set_reason(metadata)
            return 0.0
    # Disallow products in the reactants or reagents
    if any(exact_mol_match(m, ymol) == 1.0 for m in (reactants + reagents)):  # type: ignore[arg-type]
        RewardReason.PRODUCT_IS_REACTANT.set_reason(metadata)
        return 0.0

    # check that the reactants are purchasable
    def is_small_so_probably_purchasable(smi: str) -> bool:
        mol = mol_from_smiles(smi)
        # Molecules with <= 4 heavy atoms are likely purchasable,
        # since they include solvents and counterions
        return mol is not None and mol.GetNumHeavyAtoms() <= 4  # noqa: PLR2004

    purchasable_results = fetch_purchasable(reactant_smi)
    # A reactant passes if the lookup reports it purchasable or it is small
    if not all(
        purchasable_results.get(r, False) or is_small_so_probably_purchasable(r)
        for r in reactant_smi
    ):
        RewardReason.NOT_PURCHASABLE.set_reason(metadata)
        return 0.0
    # Ask the forward-reaction service what this reaction produces
    result = fetch_forward_rxn(yhat)
    if "product" in result:
        product = result["product"]
        pmol = mol_from_smiles(product)
        if pmol is None:
            RewardReason.INVALID_MOL.set_reason(metadata)
            return 0.0
        if soft:
            return tanimoto_similarity(pmol, ymol)
        if exact_mol_match(pmol, ymol) == 1.0:
            return 1.0
        RewardReason.WRONG_PRODUCT.set_reason(metadata)
        return 0.0
    # The response carried no "product" entry
    RewardReason.INVALID_RXN.set_reason(metadata)
    return 0.0
def valid_molecule_eval(
    yhat: str,
    y: str,  # noqa: ARG001
    soft: bool = False,  # noqa: ARG001
    test: bool = False,  # noqa: ARG001
    metadata: dict[str, JsonValue] | None = None,  # noqa: ARG001
) -> float:
    """Return 1.0 if `yhat` is non-empty and parses with sanitization, else 0.0.

    Args:
        yhat: Predicted SMILES.
        y: Unused.
        soft: Unused.
        test: Unused.
        metadata: Unused.
    """
    if not yhat:
        return 0.0
    parsed = mol_from_smiles(yhat, sanitize=True)
    return 0.0 if parsed is None else 1.0
# Registry mapping reward-function names to their evaluation callables;
# accuracy_reward looks callables up here by name.
EVAL_FUNCTIONS: Mapping[str, RewardEvalFn] = {
    "str_eval": str_eval,
    "valid_mol_eval": valid_mol_eval,
    "caption_eval": caption_eval,
    "product_eval": product_eval,
    "rxn_eval": rxn_eval,
    "formula_eval": formula_eval,
    "functional_group_eval": functional_group_eval,
    "sol_eval": oracle_solubility_eval,
    "rxn_forward": oracle_rxn_eval,
    # The two entries below reuse evaluators defined above under other names
    "should_not_answer_eval": str_eval,
    "should_answer_eval": valid_molecule_eval,
}
# These correspond to open-ended problems that do not have a
# unique molecule as answer. For these reward-function names, accuracy_reward
# runs the bad-substructure check on fully rewarded answers and may add the
# good-molecule bonus.
APPLY_GOOD_MOLECULE_CHECK: set[str] = {
    "valid_mol_eval",
    "formula_eval",
    "functional_group_eval",
    "sol_eval",
}
def accuracy_reward( | |
completions: Sequence[list[Mapping[str, str]]] | Sequence[str], | |
solution: Iterable[str], | |
reasoning: bool, | |
metadata: list[dict[str, JsonValue]] | None = None, | |
soft: bool = False, | |
test: bool = False, | |
good_molecule_bonus: float = 0.0, | |
**kwargs, # noqa: ARG001 | |
) -> list[float]: | |
"""Reward function that checks if the completion is the same as the ground truth.""" | |
if isinstance(completions[0], list): | |
messages = cast(Sequence[list[Mapping[str, str]]], completions) | |
contents: Sequence[str] = [m[0]["content"] for m in messages] | |
else: | |
contents = cast(Sequence[str], completions) | |
if soft and test: | |
raise ValueError("Soft mode is not supported for test time accuracy reward.") | |
rewards = [] | |
problem_types: list[str | None] = [] | |
if metadata is None: | |
# Create empty metadata that we can use internal to this function | |
metadata = [{} for _ in contents] | |
else: | |
if metadata: | |
raise NotImplementedError(f"Received non-empty metadata {metadata}.") | |
metadata.extend([{} for _ in contents]) | |
for content, info, meta in zip(contents, solution, metadata, strict=True): | |
reward = 0.0 | |
reward_info = RewardFunctionInfo.model_validate(info) | |
fxn_name, answer_info, problem_type = tuple(reward_info.model_dump().values()) | |
try: | |
answer: str | None = ( | |
extract_answer_loose(content) | |
if test | |
else extract_thought_answer_strict(content, reasoning=reasoning)[1] | |
) | |
if answer is not None: | |
# During test time, see if full SMILES string was given as input | |
if problem_type == "valid_mol_eval" and test: | |
# If we're testing, we only allow full SMILES strings | |
reward = EVAL_FUNCTIONS[fxn_name]( | |
answer, answer_info, test=test, metadata=meta | |
) | |
else: | |
reward = EVAL_FUNCTIONS[fxn_name]( | |
answer, answer_info, soft=soft, metadata=meta | |
) | |
RewardReason.set_default_reason(reward, meta) | |
if reward == 1.0 and fxn_name in APPLY_GOOD_MOLECULE_CHECK: | |
if FULL_SMILES_KEY not in meta: | |
raise ValueError( # noqa: TRY301 | |
f"Missing full SMILES key in metadata {meta}" | |
f" with reward function {fxn_name}." | |
) | |
full_smiles = cast(str, meta[FULL_SMILES_KEY]) | |
mol = mol_from_smiles(full_smiles) | |
if mol is None: | |
raise ValueError( # noqa: TRY301 | |
f"Invalid full SMILES {full_smiles}" | |
f" with reward function {fxn_name}." | |
) | |
meta["is_good_molecule"] = not contains_bad_substruct(mol) | |
if meta["is_good_molecule"]: | |
reward += good_molecule_bonus | |
else: | |
RewardReason.FORMAT_FAILED.set_reason(meta) | |
rewards.append(reward) | |
problem_types.append(problem_type) | |
except Exception: | |
logger.exception( | |
f"Unhandled exception in {fxn_name=} for {problem_type=}" | |
f" with inputs {content=}, {answer_info=} {soft=}, and {test=}." | |
) | |
RewardReason.REWARD_FUNCTION_EXCEPTION.set_reason(meta) | |
rewards.append(reward) | |
problem_types.append(None) | |
return rewards | |