import logging
import re
from collections.abc import Collection, Mapping
from pathlib import Path

from datasets import Dataset
from molbloom import BloomFilter, canon
from rdkit import Chem
from rdkit.Chem.Draw import MolDraw2D, MolDraw2DSVG  # pylint: disable=no-name-in-module
from rdkit.Chem.Draw.rdMolDraw2D import MolDraw2DCairo
from rdkit.Chem.rdChemReactions import (  # pylint: disable=no-name-in-module
    ReactionFromSmarts,
)
from rdkit.Chem.rdDepictor import (  # pylint: disable=no-name-in-module
    Compute2DCoords,
    StraightenDepiction,
)
from rdkit.Chem.rdMolDescriptors import (  # pylint: disable=no-name-in-module
    GetMorganFingerprint,
)
from rdkit.Chem.rdmolfiles import MolFromSmiles  # pylint: disable=no-name-in-module

logger = logging.getLogger(__name__)

# Display nickname for each known key. Most keys are problem categories (the
# part of a problem type before the first "/"); the last four are full
# "category/subtype" problem types.
PROBLEM_CATEGORY_TO_NICKNAME: Mapping[str, str] = {
    "functional-group": "functional group",
    "molecule-caption": "molecule caption",
    "molecule-completion": "SMILES completion",
    "molecule-formula": "elucidation",
    "molecule-name": "IUPAC name",
    "oracle-solubility": "solubility edit",
    "property": "multiple choice",
    "property-cat-brain": "BBB permeability",
    "property-cat-eve": "Human receptor binding",
    "property-cat-safety": "safety",
    "property-cat-smell": "scent",
    "property-regression-pka": "pKa",
    "property-regression-ld50": "LD50",
    "property-regression-adme": "ADME",
    "reaction-prediction": "reaction prediction",
    "retro-synthesis": "retrosynthesis",
    "simple-formula": "molecular formula",
    "property-regression-adme/log_hlm_clint": "log of HLM CL$_{\\text{int}}$",
    "property-regression-adme/log_mdr1-mdck_er": "log of MDR1-MDCK ER",
    "property-regression-adme/log_rlm_clint": "log of RLM CL$_{\\text{int}}$",
    "property-regression-adme/log_solubility": "log of aqueous solubility",
}


def get_problem_type(row: Mapping[str, str]) -> str:
    """
    Return the problem type stored in a dataset row.

    The "problem_type" entry is preferred; when it is missing or empty the
    "type" entry is used instead (a KeyError is raised if that is absent too).
    """
    return row.get("problem_type") or row["type"]


def get_problem_category(problem_type: str | None) -> str:
    """Return the part of a problem type before the first "/" ("" for None or empty)."""
    return (problem_type or "").split("/", maxsplit=1)[0]


def get_problem_categories_from_datasets(*datasets: Dataset) -> Collection[str]:
    """
    Collect the distinct problem categories found in the given datasets.

    Each argument is either indexed directly by "problem_type" or, when it has
    an `hf_dataset` attribute, through that attribute.
    """
    categories: set[str] = set()
    for dataset in datasets:
        # Wrapper objects expose the underlying dataset as `hf_dataset`
        problem_types = getattr(dataset, "hf_dataset", dataset)["problem_type"]
        categories.update(get_problem_category(pt) for pt in problem_types)
    return categories


# Use this regex with findall to extract SMILES strings from text.
# Note this function currently fails on counterions e.g.
# Cc1ccc(-c2ccc3c(c2)c2ccccc2c[n+]3C)cc1.[Cl-]
#
# FIXME(review): the definitions that belong here are corrupted in this copy of
# the module: the SMILES_PATTERN literal is cut off right after `(?`, and the
# header of the small helper that returns `MolDraw2DCairo(w, h)` is missing.
# The residue is kept verbatim on the next line so nothing is lost; restore
# both definitions from version control instead of guessing at the pattern.
#
# SMILES_PATTERN = re.compile( r"(? MolDraw2DCairo: return MolDraw2DCairo(w, h)


def _prepare_drawer(d2d: MolDraw2D | None, bg_opacity: float) -> MolDraw2D:
    """Return `d2d` (or a new `MolDraw2DSVG(-1, -1)` if None) with the shared draw options applied."""
    if d2d is None:
        d2d = MolDraw2DSVG(-1, -1)
    dopts = d2d.drawOptions()
    dopts.useBWAtomPalette()
    # Keep only the RGB components before appending the requested opacity, so
    # the result is always (r, g, b, alpha) even if the current colour already
    # carries an alpha channel.
    red, green, blue = tuple(dopts.getBackgroundColour())[:3]
    dopts.setBackgroundColour((red, green, blue, bg_opacity))
    return d2d


def draw_molecule(
    smiles: str, bg_opacity: float = 1.0, d2d: MolDraw2D | None = None
) -> str:
    """
    Draw a SMILES molecule and return the drawing string.

    Args:
        smiles: SMILES string of the molecule to draw.
        bg_opacity: Alpha component applied to the drawer's background colour.
        d2d: Drawer to use; a `MolDraw2DSVG(-1, -1)` is created when omitted.

    Raises:
        ValueError: If `smiles` cannot be converted to a molecule.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Failed to convert {smiles=} to a molecule.")
    Compute2DCoords(mol)
    StraightenDepiction(mol)
    d2d = _prepare_drawer(d2d, bg_opacity)
    d2d.DrawMolecule(mol)
    d2d.FinishDrawing()
    return d2d.GetDrawingText()


def draw_reaction(
    rxn_smiles: str, bg_opacity: float = 1.0, d2d: MolDraw2D | None = None
) -> str:
    """
    Draw a reaction SMILES and return the drawing string.

    Args:
        rxn_smiles: Reaction SMILES to draw (parsed with `useSmiles=True`).
        bg_opacity: Alpha component applied to the drawer's background colour.
        d2d: Drawer to use; a `MolDraw2DSVG(-1, -1)` is created when omitted.

    Raises:
        ValueError: If `rxn_smiles` cannot be converted to a reaction.
    """
    rxn = ReactionFromSmarts(rxn_smiles, useSmiles=True)
    if rxn is None:
        # Mirror draw_molecule: fail loudly instead of handing None to the drawer
        raise ValueError(f"Failed to convert {rxn_smiles=} to a reaction.")
    d2d = _prepare_drawer(d2d, bg_opacity)
    d2d.DrawReaction(rxn)
    d2d.FinishDrawing()
    return d2d.GetDrawingText()


# Precompiled SMARTS patterns for protected bonds and ring atoms
_ring_db_pat = Chem.MolFromSmarts("[#6R,#16R]=[OR0,SR0,CR0,NR0]")
_ring_atom_pat = Chem.MolFromSmarts("[R]")

# Cache of loaded bloom filters, keyed by filter name
bloom_filters: dict[str, BloomFilter] = {}


def _get_bits(mol: Chem.Mol) -> set[str]:
    """Get the fingerprint bits from a molecule."""
    # the keys are the actual bits
    bi: dict[int, tuple[tuple[int, int], ...]] = {}
    GetMorganFingerprint(mol, 2, bitInfo=bi)  # type: ignore[arg-type]
    return {str(k) for k in bi}


ETHER0_DIR = Path(__file__).parent


def _get_bloom_filter(name: str) -> BloomFilter:
    """Return the bloom filter stored at `ETHER0_DIR / f"{name}.bloom"`, loading it on first use."""
    try:
        return bloom_filters[name]
    except KeyError:
        loaded = bloom_filters[name] = BloomFilter(str(ETHER0_DIR / f"{name}.bloom"))
        return loaded


def get_ring_system(mol: Chem.Mol) -> list[str]:
    """
    Extracts ring systems from an RDKit molecule and returns a list of SMILES.

    Bonds not in rings and not protected (e.g., ring carbonyls) are cleaved.

    Source: https://github.com/PatWalters/useful_rdkit_utils/blob/edb126e3fd71870ae2d1c9440b904106e3ef97a2/useful_rdkit_utils/ring_systems.py#L13
    Which has a MIT license, copyright 2021-2025 PatWalters.
    """  # noqa: D205
    # Copy to avoid mutating original
    mol = Chem.Mol(mol)

    # Tag protected bonds: everything starts unprotected, then bonds matched by
    # the ring double-bond pattern are flagged
    for bond in mol.GetBonds():
        bond.SetBoolProp("protected", False)  # noqa: FBT003
    for a1, a2 in mol.GetSubstructMatches(_ring_db_pat):
        b = mol.GetBondBetweenAtoms(a1, a2)
        b.SetBoolProp("protected", True)  # noqa: FBT003

    # Cleave linker bonds: single, acyclic and not protected
    cleave_idxs = [
        b.GetIdx()
        for b in mol.GetBonds()
        if not b.IsInRing()
        and not b.GetBoolProp("protected")
        and b.GetBondType() == Chem.BondType.SINGLE
    ]
    if cleave_idxs:
        frag_mol = Chem.FragmentOnBonds(mol, cleave_idxs)
        Chem.SanitizeMol(frag_mol)
    else:
        frag_mol = mol

    # Split into fragments and clean up
    ring_smiles: list[str] = []
    for frag in Chem.GetMolFrags(frag_mol, asMols=True):
        # Only fragments that contain at least one ring atom are reported
        if frag.HasSubstructMatch(_ring_atom_pat):
            # Turn atoms with atomic number 0 into unlabeled hydrogens so the
            # hydrogen removal below strips them
            for atom in frag.GetAtoms():
                if atom.GetAtomicNum() == 0:
                    atom.SetAtomicNum(1)
                    atom.SetIsotope(0)
            frag = Chem.RemoveAllHs(frag)  # noqa: PLW2901
            # Fix stereo on terminal double bonds
            for bd in frag.GetBonds():
                if bd.GetBondType() == Chem.BondType.DOUBLE and (
                    1 in {bd.GetBeginAtom().GetDegree(), bd.GetEndAtom().GetDegree()}
                ):
                    bd.SetStereo(Chem.BondStereo.STEREONONE)
            ring_smiles.append(Chem.MolToSmiles(frag))
    return ring_smiles


def is_reasonable_ring_system(mol: Chem.Mol, ref_mol: Chem.Mol | None = None) -> bool:
    """
    Check if a molecule has a reasonable ring system.

    Either no rings or the ring system is found in known rings.
    If reference is provided, those are assumed valid.
    """
    bloom_filter = _get_bloom_filter("rings")
    ring_systems = [canon(r) for r in get_ring_system(mol)]
    # remove from consideration all rings in ref_mol, since we'll always assume they're correct
    if ref_mol is not None:
        ref_ring_systems = {canon(r) for r in get_ring_system(ref_mol)}
        ring_systems = [ring for ring in ring_systems if ring not in ref_ring_systems]
    return all((r in bloom_filter) for r in ring_systems)


def is_reasonable_fp(mol: Chem.Mol, ref_mol: Chem.Mol | None = None) -> bool:
    """
    Check if a molecule has a reasonable fingerprint.

    If reference is provided, those fingerprints are assumed valid.
    """
    bloom_filter = _get_bloom_filter("fingerprints")
    bits: Collection[str] = _get_bits(mol)
    # remove from consideration all bits in ref_mol, since we'll always assume they're correct
    if ref_mol is not None:
        bits = set(bits) - _get_bits(ref_mol)
    return all((b in bloom_filter) for b in bits)


def mol_from_smiles(smiles: str, *args, **kwargs) -> Chem.Mol | None:
    """MolFromSmiles is type-hinted to always return Mol, but can return None."""
    return MolFromSmiles(smiles, *args, **kwargs)