File size: 8,041 Bytes
4c346eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
import logging
import re
from collections.abc import Collection, Mapping
from pathlib import Path

from datasets import Dataset
from molbloom import BloomFilter, canon
from rdkit import Chem
from rdkit.Chem.Draw import MolDraw2D, MolDraw2DSVG  # pylint: disable=no-name-in-module
from rdkit.Chem.Draw.rdMolDraw2D import MolDraw2DCairo
from rdkit.Chem.rdChemReactions import (  # pylint: disable=no-name-in-module
    ReactionFromSmarts,
)
from rdkit.Chem.rdDepictor import (  # pylint: disable=no-name-in-module
    Compute2DCoords,
    StraightenDepiction,
)
from rdkit.Chem.rdMolDescriptors import (  # pylint: disable=no-name-in-module
    GetMorganFingerprint,
)
from rdkit.Chem.rdmolfiles import MolFromSmiles  # pylint: disable=no-name-in-module

logger = logging.getLogger(__name__)


# Short human-readable nicknames for problem categories, i.e. the part of a
# problem type before the first "/" (see get_problem_category).
# The last four keys are full "category/subtype" problem types for individual
# ADME endpoints rather than bare categories.
# Some values embed LaTeX math markup (e.g. CL$_{\text{int}}$).
# NOTE(review): presumably these nicknames are rendered in plot labels or
# tables that understand LaTeX -- confirm against the callers.
PROBLEM_CATEGORY_TO_NICKNAME: Mapping[str, str] = {
    "functional-group": "functional group",
    "molecule-caption": "molecule caption",
    "molecule-completion": "SMILES completion",
    "molecule-formula": "elucidation",
    "molecule-name": "IUPAC name",
    "oracle-solubility": "solubility edit",
    "property": "multiple choice",
    "property-cat-brain": "BBB permeability",
    "property-cat-eve": "Human receptor binding",
    "property-cat-safety": "safety",
    "property-cat-smell": "scent",
    "property-regression-pka": "pKa",
    "property-regression-ld50": "LD50",
    "property-regression-adme": "ADME",
    "reaction-prediction": "reaction prediction",
    "retro-synthesis": "retrosynthesis",
    "simple-formula": "molecular formula",
    "property-regression-adme/log_hlm_clint": "log of HLM CL$_{\\text{int}}$",
    "property-regression-adme/log_mdr1-mdck_er": "log of MDR1-MDCK ER",
    "property-regression-adme/log_rlm_clint": "log of RLM CL$_{\\text{int}}$",
    "property-regression-adme/log_solubility": "log of aqueous solubility",
}


def get_problem_type(row: Mapping[str, str]) -> str:
    """Return the problem type of a row, falling back to its "type" entry.

    The "type" key is consulted whenever "problem_type" is missing or falsy
    (for example None or an empty string). A KeyError propagates if the
    fallback is needed but "type" is absent.
    """
    problem_type = row.get("problem_type")
    if problem_type:
        return problem_type
    return row["type"]


def get_problem_category(problem_type: str | None) -> str:
    return (problem_type or "").split("/", maxsplit=1)[0]


def get_problem_categories_from_datasets(*datasets: Dataset) -> Collection[str]:
    """Collect the distinct problem categories found across the given datasets.

    The "problem_type" column of every dataset is read and each entry is
    reduced to its category. Objects exposing an ``hf_dataset`` attribute are
    unwrapped to that attribute before the column is read.

    Returns:
        A set of category strings (empty when no datasets are given).
    """
    categories: set[str] = set()
    for dataset in datasets:
        # Wrapper objects carry the underlying dataset on `hf_dataset`.
        table = dataset.hf_dataset if hasattr(dataset, "hf_dataset") else dataset
        categories.update(map(get_problem_category, table["problem_type"]))
    return categories


# Use this regex with findall to extract SMILES strings from text.
# A match is a run of at least four SMILES tokens (Cl, Br, single-letter
# aliphatic or aromatic atoms, bracket atoms, digits, and bond/branch/ring
# symbols) that is neither preceded nor followed by a word character.
# Note this pattern currently fails on counterions ("." is not an accepted
# token), e.g.
# Cc1ccc(-c2ccc3c(c2)c2ccccc2c[n+]3C)cc1.[Cl-]
# NOTE(review): "C" is listed twice in the [BCNOPSFIC] class; the duplicate is
# redundant and does not change what is matched.
SMILES_PATTERN = re.compile(
    r"(?<!\w)(?:(?:Cl|Br|[BCNOPSFIC]|[cnops]|\[[^\]]+?\]|[0-9@+\-=#\\/()%])){4,}(?!\w)"
)


def make_sized_d2d(w: int = 400, h: int = 300) -> MolDraw2DCairo:
    """Create a MolDraw2DCairo drawer with the given canvas width and height.

    Args:
        w: Canvas width passed to MolDraw2DCairo.
        h: Canvas height passed to MolDraw2DCairo.

    Returns:
        A new drawer, e.g. for use as the ``d2d`` argument of draw_molecule
        or draw_reaction.
    """
    return MolDraw2DCairo(w, h)


def draw_molecule(
    smiles: str, bg_opacity: float = 1.0, d2d: MolDraw2D | None = None
) -> str:
    """Render a molecule given as SMILES and return the drawer's output.

    Args:
        smiles: SMILES string of the molecule to draw.
        bg_opacity: Value appended to the background colour as its opacity.
        d2d: Drawer to render with; when omitted, a MolDraw2DSVG created
            with size (-1, -1) is used.

    Returns:
        Whatever the drawer's GetDrawingText() produces.

    Raises:
        ValueError: If the SMILES cannot be converted to a molecule.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        raise ValueError(f"Failed to convert {smiles=} to a molecule.")

    # Lay the molecule out in 2D before handing it to the drawer.
    Compute2DCoords(molecule)
    StraightenDepiction(molecule)

    drawer = MolDraw2DSVG(-1, -1) if d2d is None else d2d
    options = drawer.drawOptions()
    options.useBWAtomPalette()
    # NOTE(review): assumes getBackgroundColour() yields the three RGB
    # components, so bg_opacity lands in the alpha slot -- confirm.
    background = options.getBackgroundColour()
    options.setBackgroundColour((*background, bg_opacity))

    drawer.DrawMolecule(molecule)
    drawer.FinishDrawing()
    return drawer.GetDrawingText()


def draw_reaction(
    rxn_smiles: str, bg_opacity: float = 1.0, d2d: MolDraw2D | None = None
) -> str:
    """Render a reaction given as reaction SMILES and return the drawer's output.

    Args:
        rxn_smiles: Reaction SMILES, parsed with ReactionFromSmarts using
            ``useSmiles=True``.
        bg_opacity: Value appended to the background colour as its opacity.
        d2d: Drawer to render with; when omitted, a MolDraw2DSVG created
            with size (-1, -1) is used.

    Returns:
        Whatever the drawer's GetDrawingText() produces.
    """
    reaction = ReactionFromSmarts(rxn_smiles, useSmiles=True)

    drawer = MolDraw2DSVG(-1, -1) if d2d is None else d2d
    options = drawer.drawOptions()
    options.useBWAtomPalette()
    # NOTE(review): assumes getBackgroundColour() yields the three RGB
    # components, so bg_opacity lands in the alpha slot -- confirm.
    background = options.getBackgroundColour()
    options.setBackgroundColour((*background, bg_opacity))

    drawer.DrawReaction(reaction)
    drawer.FinishDrawing()
    return drawer.GetDrawingText()


# Precompiled SMARTS patterns for protected bonds and ring atoms
# Ring carbon or ring sulfur double-bonded to an aliphatic, non-ring (R0)
# O, S, C or N atom, e.g. a ring carbonyl; get_ring_system tags these bonds
# as "protected".
_ring_db_pat = Chem.MolFromSmarts("[#6R,#16R]=[OR0,SR0,CR0,NR0]")
# Any atom that is a member of a ring; used to keep only ring-containing fragments.
_ring_atom_pat = Chem.MolFromSmarts("[R]")


# Cache of loaded bloom filters keyed by name; filled lazily by _get_bloom_filter.
bloom_filters: dict[str, BloomFilter] = {}


def _get_bits(mol: Chem.Mol) -> set[str]:
    """Return the Morgan fingerprint bit identifiers of a molecule as strings."""
    # GetMorganFingerprint fills this mapping in place; only its keys (the
    # bit identifiers) are needed, so the returned fingerprint is discarded.
    bit_info: dict[int, tuple[tuple[int, int], ...]] = {}
    GetMorganFingerprint(mol, 2, bitInfo=bit_info)  # type: ignore[arg-type]
    return set(map(str, bit_info))


# Directory containing this module; _get_bloom_filter looks for "<name>.bloom" files here.
ETHER0_DIR = Path(__file__).parent


def _get_bloom_filter(name: str) -> BloomFilter:
    """Return the bloom filter called ``name``, loading it on first use.

    The filter is read from ``<name>.bloom`` in ETHER0_DIR and then kept in
    the module-level ``bloom_filters`` cache for subsequent calls.
    """
    if name not in bloom_filters:
        filter_path = ETHER0_DIR / f"{name}.bloom"
        bloom_filters[name] = BloomFilter(str(filter_path))
    return bloom_filters[name]


def get_ring_system(mol: Chem.Mol) -> list[str]:
    """
    Extracts ring systems from an RDKit molecule and returns a list of SMILES.
    Bonds not in rings and not protected (e.g., ring carbonyls) are cleaved.

    Only single bonds are cleaved, and only fragments that contain at least
    one ring atom are reported. The input molecule is copied, not modified.

    Source: https://github.com/PatWalters/useful_rdkit_utils/blob/edb126e3fd71870ae2d1c9440b904106e3ef97a2/useful_rdkit_utils/ring_systems.py#L13
    Which has a MIT license, copyright 2021-2025 PatWalters.

    Args:
        mol: Molecule to extract ring systems from.

    Returns:
        One SMILES string per ring-containing fragment; empty if there are none.
    """  # noqa: D205
    # Copy to avoid mutating original
    mol = Chem.Mol(mol)

    # Tag protected bonds
    # Every bond gets an explicit False first so GetBoolProp below never hits
    # a bond without the property.
    for bond in mol.GetBonds():
        bond.SetBoolProp("protected", False)  # noqa: FBT003
    # Bonds matched by _ring_db_pat (ring atom double-bonded to a non-ring atom)
    # are flagged so they stay attached to their ring.
    for a1, a2 in mol.GetSubstructMatches(_ring_db_pat):
        b = mol.GetBondBetweenAtoms(a1, a2)
        b.SetBoolProp("protected", True)  # noqa: FBT003

    # Cleave linker bonds
    # Candidates are single bonds that are outside any ring and not protected.
    cleave_idxs = [
        b.GetIdx()
        for b in mol.GetBonds()
        if not b.IsInRing()
        and not b.GetBoolProp("protected")
        and b.GetBondType() == Chem.BondType.SINGLE
    ]
    if cleave_idxs:
        frag_mol = Chem.FragmentOnBonds(mol, cleave_idxs)
        Chem.SanitizeMol(frag_mol)
    else:
        # Nothing to cut: the (copied) molecule is processed as a whole.
        frag_mol = mol

    # Split into fragments and clean up
    ring_smiles: list[str] = []
    for frag in Chem.GetMolFrags(frag_mol, asMols=True):
        # Fragments without any ring atom (linkers, substituents) are dropped.
        if frag.HasSubstructMatch(_ring_atom_pat):
            # Atoms with atomic number 0 (presumably the dummy attachment points
            # left by FragmentOnBonds -- confirm) become plain hydrogens, which
            # RemoveAllHs then strips.
            for atom in frag.GetAtoms():
                if atom.GetAtomicNum() == 0:
                    atom.SetAtomicNum(1)
                    atom.SetIsotope(0)
            frag = Chem.RemoveAllHs(frag)  # noqa: PLW2901
            # Fix stereo on terminal double bonds
            # A double bond with a degree-1 end atom has its stereo cleared.
            for bd in frag.GetBonds():
                if bd.GetBondType() == Chem.BondType.DOUBLE and (
                    1 in {bd.GetBeginAtom().GetDegree(), bd.GetEndAtom().GetDegree()}
                ):
                    bd.SetStereo(Chem.BondStereo.STEREONONE)
            ring_smiles.append(Chem.MolToSmiles(frag))

    return ring_smiles


def is_reasonable_ring_system(mol: Chem.Mol, ref_mol: Chem.Mol | None = None) -> bool:
    """
    Check if a molecule has a reasonable ring system.

    Either no rings or the ring system is found in known rings.
    If reference is provided, those are assumed valid.

    Args:
        mol: Molecule whose ring systems are checked.
        ref_mol: Optional reference molecule; ring systems it contains are
            skipped instead of being looked up.

    Returns:
        True if every remaining ring system is in the "rings" bloom filter
        (vacuously True when there are none), otherwise False.
    """
    bloom_filter = _get_bloom_filter("rings")
    ring_systems = [canon(r) for r in get_ring_system(mol)]
    # remove from consideration all rings in ref_mol, since we'll always assume they're correct
    if ref_mol:
        # A set keeps each membership test below constant-time instead of
        # rescanning a list for every ring system of `mol`.
        ref_ring_systems = {canon(r) for r in get_ring_system(ref_mol)}
        ring_systems = [ring for ring in ring_systems if ring not in ref_ring_systems]
    return all(r in bloom_filter for r in ring_systems)


def is_reasonable_fp(mol: Chem.Mol, ref_mol: Chem.Mol | None = None) -> bool:
    """
    Check if a molecule has a reasonable fingerprint.

    If reference is provided, those fingerprints are assumed valid.

    Args:
        mol: Molecule whose fingerprint bits are checked.
        ref_mol: Optional reference molecule; bits it sets are skipped
            instead of being looked up.

    Returns:
        True if every remaining bit is in the "fingerprints" bloom filter
        (vacuously True when there are none), otherwise False.
    """
    bloom_filter = _get_bloom_filter("fingerprints")
    bits: Collection[str] = _get_bits(mol)
    # remove from consideration all fingerprint bits in ref_mol, since we'll
    # always assume they're correct
    if ref_mol:
        # Both sides are sets, so a set difference does the filtering directly.
        bits = set(bits) - _get_bits(ref_mol)
    return all(b in bloom_filter for b in bits)


def mol_from_smiles(smiles: str, *args, **kwargs) -> Chem.Mol | None:
    """MolFromSmiles is type-hinted to always return Mol, but can return None.

    Thin wrapper that forwards ``smiles`` and any extra positional/keyword
    arguments to RDKit's MolFromSmiles unchanged, while advertising the
    ``Chem.Mol | None`` return type so callers are prompted to handle None.
    """
    return MolFromSmiles(smiles, *args, **kwargs)