Spaces:
Running
on
Zero
Running
on
Zero
File size: 8,041 Bytes
4c346eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 |
import logging
import re
from collections.abc import Collection, Mapping
from pathlib import Path
from datasets import Dataset
from molbloom import BloomFilter, canon
from rdkit import Chem
from rdkit.Chem.Draw import MolDraw2D, MolDraw2DSVG # pylint: disable=no-name-in-module
from rdkit.Chem.Draw.rdMolDraw2D import MolDraw2DCairo
from rdkit.Chem.rdChemReactions import ( # pylint: disable=no-name-in-module
ReactionFromSmarts,
)
from rdkit.Chem.rdDepictor import ( # pylint: disable=no-name-in-module
Compute2DCoords,
StraightenDepiction,
)
from rdkit.Chem.rdMolDescriptors import ( # pylint: disable=no-name-in-module
GetMorganFingerprint,
)
from rdkit.Chem.rdmolfiles import MolFromSmiles # pylint: disable=no-name-in-module
# Module-level logger for this module.
logger = logging.getLogger(__name__)
# Short human-readable nicknames for problem types (presumably for display
# labels — confirm against callers).
# Keys are either bare categories (the part before the first "/", as returned
# by get_problem_category, e.g. "property-regression-adme") or full
# "category/subtype" problem types that get their own finer-grained nickname
# (the "property-regression-adme/..." entries at the end).
# Some values contain LaTeX math markup ($...$), e.g. for the "int" subscript.
PROBLEM_CATEGORY_TO_NICKNAME: Mapping[str, str] = {
    "functional-group": "functional group",
    "molecule-caption": "molecule caption",
    "molecule-completion": "SMILES completion",
    "molecule-formula": "elucidation",
    "molecule-name": "IUPAC name",
    "oracle-solubility": "solubility edit",
    "property": "multiple choice",
    "property-cat-brain": "BBB permeability",
    "property-cat-eve": "Human receptor binding",
    "property-cat-safety": "safety",
    "property-cat-smell": "scent",
    "property-regression-pka": "pKa",
    "property-regression-ld50": "LD50",
    "property-regression-adme": "ADME",
    "reaction-prediction": "reaction prediction",
    "retro-synthesis": "retrosynthesis",
    "simple-formula": "molecular formula",
    # Per-subtype ADME nicknames (full "category/subtype" keys).
    "property-regression-adme/log_hlm_clint": "log of HLM CL$_{\\text{int}}$",
    "property-regression-adme/log_mdr1-mdck_er": "log of MDR1-MDCK ER",
    "property-regression-adme/log_rlm_clint": "log of RLM CL$_{\\text{int}}$",
    "property-regression-adme/log_solubility": "log of aqueous solubility",
}
def get_problem_type(row: Mapping[str, str]) -> str:
    """Return the problem type recorded in a dataset row.

    ``row["problem_type"]`` is preferred. When it is absent or falsy (e.g. an
    empty string or None), the row's ``"type"`` field is used instead, which
    raises ``KeyError`` if that key is missing too.
    """
    preferred = row.get("problem_type")
    if preferred:
        return preferred
    return row["type"]
def get_problem_category(problem_type: str | None) -> str:
return (problem_type or "").split("/", maxsplit=1)[0]
def get_problem_categories_from_datasets(*datasets: Dataset) -> Collection[str]:
    """Collect the distinct problem categories found across the given datasets.

    Each argument is either a ``datasets.Dataset`` or an object exposing one
    through an ``hf_dataset`` attribute. Every value of its ``"problem_type"``
    column is reduced to a category with ``get_problem_category``.

    Returns:
        A set of category strings (empty when no datasets are passed).
    """
    categories: set[str] = set()
    for source in datasets:
        # Unwrap wrapper objects that hold the underlying HF dataset.
        if hasattr(source, "hf_dataset"):
            source = source.hf_dataset
        for problem_type in source["problem_type"]:
            categories.add(get_problem_category(problem_type))
    return categories
# Use this regex with findall to extract SMILES strings from text.
# A match is a run of 4 or more SMILES tokens that is not preceded or followed
# by a word character. Tokens are: Cl, Br, the uppercase atoms B/C/N/O/P/S/F/I,
# the aromatic atoms c/n/o/p/s, bracket atoms such as [n+], digits, and the
# bond/branch/stereo/charge characters @ + - = # \ / ( ) %.
# Known limitations:
# - "." is not a token, so for multi-component SMILES only the parts are
#   matched and short counterions are dropped, e.g. in
#   Cc1ccc(-c2ccc3c(c2)c2ccccc2c[n+]3C)cc1.[Cl-]
#   the "[Cl-]" component is lost.
# - Digits are tokens on their own, so purely numeric text (e.g. "2024") also
#   matches; callers may need to filter such hits.
# NOTE: "C" appears twice in the [BCNOPSFIC] class; this is harmless.
SMILES_PATTERN = re.compile(
    r"(?<!\w)(?:(?:Cl|Br|[BCNOPSFIC]|[cnops]|\[[^\]]+?\]|[0-9@+\-=#\\/()%])){4,}(?!\w)"
)
def make_sized_d2d(w: int = 400, h: int = 300) -> MolDraw2DCairo:
    """Build a fixed-size Cairo drawer.

    Args:
        w: Canvas width handed to ``MolDraw2DCairo``.
        h: Canvas height handed to ``MolDraw2DCairo``.

    Returns:
        A fresh ``MolDraw2DCairo``, usable as the ``d2d`` argument of
        ``draw_molecule`` or ``draw_reaction``.
    """
    width, height = w, h
    drawer = MolDraw2DCairo(width, height)
    return drawer
def draw_molecule(
    smiles: str, bg_opacity: float = 1.0, d2d: MolDraw2D | None = None
) -> str:
    """Draw a SMILES molecule and return the drawing text.

    The molecule gets 2D coordinates (``Compute2DCoords`` followed by
    ``StraightenDepiction``) and is drawn with a black-and-white atom palette.

    Args:
        smiles: SMILES string of the molecule to draw.
        bg_opacity: Alpha applied to the drawer's current background colour
            (1.0 = opaque, 0.0 = fully transparent).
        d2d: Drawer to render into. Defaults to ``MolDraw2DSVG(-1, -1)``. A
            fixed-size Cairo drawer (e.g. from ``make_sized_d2d``) may be
            passed instead.

    Returns:
        Whatever ``d2d.GetDrawingText()`` returns: SVG markup for the default
        drawer. NOTE(review): binary drawers such as Cairo may return bytes
        despite the ``str`` annotation.

    Raises:
        ValueError: If RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Failed to convert {smiles=} to a molecule.")
    Compute2DCoords(mol)
    StraightenDepiction(mol)
    if d2d is None:
        d2d = MolDraw2DSVG(-1, -1)
    dopts = d2d.drawOptions()
    dopts.useBWAtomPalette()
    # getBackgroundColour() returns an RGBA 4-tuple in current RDKit releases.
    # Appending bg_opacity directly would produce a 5-tuple, with bg_opacity in
    # a fifth slot instead of the alpha slot. Keep only RGB so bg_opacity
    # becomes the alpha; the slice is a no-op if an RGB 3-tuple is returned.
    dopts.setBackgroundColour((*dopts.getBackgroundColour()[:3], bg_opacity))
    d2d.DrawMolecule(mol)
    d2d.FinishDrawing()
    return d2d.GetDrawingText()
def draw_reaction(
    rxn_smiles: str, bg_opacity: float = 1.0, d2d: MolDraw2D | None = None
) -> str:
    """Draw a reaction SMILES and return the drawing text.

    The reaction is parsed with ``ReactionFromSmarts(..., useSmiles=True)`` and
    drawn with a black-and-white atom palette.

    Args:
        rxn_smiles: Reaction SMILES to draw.
        bg_opacity: Alpha applied to the drawer's current background colour
            (1.0 = opaque, 0.0 = fully transparent).
        d2d: Drawer to render into. Defaults to ``MolDraw2DSVG(-1, -1)``.

    Returns:
        Whatever ``d2d.GetDrawingText()`` returns: SVG markup for the default
        drawer. NOTE(review): binary drawers such as Cairo may return bytes
        despite the ``str`` annotation.

    Raises:
        ValueError: If the parser returns ``None``. RDKit's parser may also
            raise its own error for malformed input.
    """
    rxn = ReactionFromSmarts(rxn_smiles, useSmiles=True)
    if rxn is None:
        # Defensive guard mirroring draw_molecule: fail with a clear message
        # instead of a confusing error deep inside DrawReaction.
        raise ValueError(f"Failed to convert {rxn_smiles=} to a reaction.")
    if d2d is None:
        d2d = MolDraw2DSVG(-1, -1)
    dopts = d2d.drawOptions()
    dopts.useBWAtomPalette()
    # getBackgroundColour() returns an RGBA 4-tuple in current RDKit releases.
    # Appending bg_opacity directly would produce a 5-tuple, with bg_opacity in
    # a fifth slot instead of the alpha slot. Keep only RGB so bg_opacity
    # becomes the alpha; the slice is a no-op if an RGB 3-tuple is returned.
    dopts.setBackgroundColour((*dopts.getBackgroundColour()[:3], bg_opacity))
    d2d.DrawReaction(rxn)
    d2d.FinishDrawing()
    return d2d.GetDrawingText()
# SMARTS queries compiled once at import time and reused by get_ring_system.
# Any atom that is a member of a ring:
_ring_atom_pat = Chem.MolFromSmarts("[R]")
# A ring C or S atom double-bonded to an acyclic O, S, C or N atom (e.g. a ring
# carbonyl). Bonds matching this are tagged "protected" from cleavage:
_ring_db_pat = Chem.MolFromSmarts("[#6R,#16R]=[OR0,SR0,CR0,NR0]")

# Cache of loaded BloomFilter objects, keyed by filter name; filled lazily.
bloom_filters: dict[str, BloomFilter] = dict()
def _get_bits(mol: Chem.Mol) -> set[str]:
    """Return the radius-2 Morgan fingerprint bit IDs of ``mol`` as strings.

    ``GetMorganFingerprint`` records one ``bitInfo`` entry per set bit, so the
    dict's keys are exactly the bit IDs. They are stringified so they can be
    looked up in a Bloom filter.

    NOTE(review): ``GetMorganFingerprint`` is deprecated in recent RDKit in
    favour of ``rdFingerprintGenerator``. Confirm that a replacement yields
    identical bit IDs before migrating, since the stored Bloom filter depends
    on them.
    """
    bit_info: dict[int, tuple[tuple[int, int], ...]] = {}
    GetMorganFingerprint(mol, 2, bitInfo=bit_info)  # type: ignore[arg-type]
    return set(map(str, bit_info))
# Directory containing this module; Bloom filter files are loaded from here.
ETHER0_DIR = Path(__file__).parent


def _get_bloom_filter(name: str) -> BloomFilter:
    """Return the Bloom filter stored in ``ETHER0_DIR / f"{name}.bloom"``.

    The filter is loaded on first request and then served from the
    module-level ``bloom_filters`` cache.
    """
    if name not in bloom_filters:
        filter_path = ETHER0_DIR / f"{name}.bloom"
        bloom_filters[name] = BloomFilter(str(filter_path))
    return bloom_filters[name]
def get_ring_system(mol: Chem.Mol) -> list[str]:
    """
    Extract the ring systems of an RDKit molecule as a list of SMILES.

    Every acyclic single bond that is not tagged "protected" is cleaved, which
    separates ring systems from linkers and substituents. Exocyclic double
    bonds (e.g. ring carbonyls) are never cleaved, so they stay attached to
    their ring. Only fragments containing at least one ring atom are kept.
    In each kept fragment:
    - the dummy atoms left at cleavage points become hydrogens and are removed;
    - stereo is cleared on double bonds that have a degree-1 end atom.

    The input molecule is not modified; a copy is fragmented. One SMILES is
    returned per ring-containing fragment, so a ring system that occurs
    several times appears several times.

    Source: https://github.com/PatWalters/useful_rdkit_utils/blob/edb126e3fd71870ae2d1c9440b904106e3ef97a2/useful_rdkit_utils/ring_systems.py#L13
    Which has a MIT license, copyright 2021-2025 PatWalters.
    """  # noqa: D205
    # Copy to avoid mutating original
    mol = Chem.Mol(mol)
    # Tag protected bonds: clear the flag everywhere, then set it on matches of
    # _ring_db_pat (ring C/S double-bonded to an acyclic O/S/C/N).
    # NOTE(review): _ring_db_pat only matches double bonds ("="), which the
    # SINGLE-bond filter below already excludes. The flag therefore does not
    # appear to change which bonds are cleaved; it is kept to match upstream.
    for bond in mol.GetBonds():
        bond.SetBoolProp("protected", False)  # noqa: FBT003
    for a1, a2 in mol.GetSubstructMatches(_ring_db_pat):
        b = mol.GetBondBetweenAtoms(a1, a2)
        b.SetBoolProp("protected", True)  # noqa: FBT003
    # Cleave linker bonds: acyclic, unprotected, single
    cleave_idxs = [
        b.GetIdx()
        for b in mol.GetBonds()
        if not b.IsInRing()
        and not b.GetBoolProp("protected")
        and b.GetBondType() == Chem.BondType.SINGLE
    ]
    if cleave_idxs:
        # FragmentOnBonds caps each broken bond with dummy (atomic number 0)
        # atoms, which are handled per fragment below.
        frag_mol = Chem.FragmentOnBonds(mol, cleave_idxs)
        Chem.SanitizeMol(frag_mol)
    else:
        frag_mol = mol
    # Split into fragments and clean up
    ring_smiles: list[str] = []
    for frag in Chem.GetMolFrags(frag_mol, asMols=True):
        # Fragments without any ring atom (pure linkers/substituents) are skipped.
        if frag.HasSubstructMatch(_ring_atom_pat):
            # Turn dummy atoms into plain hydrogens (isotope label cleared) so
            # RemoveAllHs can strip them.
            for atom in frag.GetAtoms():
                if atom.GetAtomicNum() == 0:
                    atom.SetAtomicNum(1)
                    atom.SetIsotope(0)
            frag = Chem.RemoveAllHs(frag)  # noqa: PLW2901
            # Fix stereo on terminal double bonds
            for bd in frag.GetBonds():
                if bd.GetBondType() == Chem.BondType.DOUBLE and (
                    1 in {bd.GetBeginAtom().GetDegree(), bd.GetEndAtom().GetDegree()}
                ):
                    bd.SetStereo(Chem.BondStereo.STEREONONE)
            ring_smiles.append(Chem.MolToSmiles(frag))
    return ring_smiles
def is_reasonable_ring_system(mol: Chem.Mol, ref_mol: Chem.Mol | None = None) -> bool:
    """Check whether every ring system in ``mol`` is a known ring system.

    Ring systems are extracted with ``get_ring_system``, canonicalized with
    ``canon`` and looked up in the ``"rings"`` Bloom filter. A molecule with no
    rings passes trivially.

    Args:
        mol: Molecule to check.
        ref_mol: Optional reference molecule. Its ring systems are assumed
            valid and are excluded from the check.

    Returns:
        True if every remaining ring system is found in the Bloom filter.
        Bloom filters can give false positives, so True means "probably known".
    """
    bloom_filter = _get_bloom_filter("rings")
    ring_systems = [canon(r) for r in get_ring_system(mol)]
    # Ring systems present in ref_mol are always assumed correct, so drop them.
    # A set makes each membership test O(1) instead of scanning a list.
    if ref_mol is not None:
        ref_ring_systems = {canon(r) for r in get_ring_system(ref_mol)}
        ring_systems = [ring for ring in ring_systems if ring not in ref_ring_systems]
    return all(ring in bloom_filter for ring in ring_systems)
def is_reasonable_fp(mol: Chem.Mol, ref_mol: Chem.Mol | None = None) -> bool:
    """Check whether every Morgan fingerprint bit of ``mol`` is a known bit.

    Bits come from ``_get_bits`` (radius-2 Morgan, as strings) and are looked
    up in the ``"fingerprints"`` Bloom filter.

    Args:
        mol: Molecule to check.
        ref_mol: Optional reference molecule. Its fingerprint bits are assumed
            valid and are excluded from the check.

    Returns:
        True if every remaining bit is found in the Bloom filter. Bloom
        filters can give false positives, so True means "probably known".
    """
    bloom_filter = _get_bloom_filter("fingerprints")
    bits: set[str] = _get_bits(mol)
    # Fingerprint bits also produced by ref_mol are always assumed valid.
    if ref_mol is not None:
        bits -= _get_bits(ref_mol)
    return all(bit in bloom_filter for bit in bits)
def mol_from_smiles(smiles: str, *args, **kwargs) -> Chem.Mol | None:
    """Parse ``smiles`` with RDKit's ``MolFromSmiles``, typed as ``Mol | None``.

    RDKit's type hints say ``MolFromSmiles`` always returns a ``Mol``, but it
    returns ``None`` for unparsable input; this wrapper makes that explicit.
    Extra positional and keyword arguments are forwarded unchanged.
    """
    parsed = MolFromSmiles(smiles, *args, **kwargs)
    return parsed
|