Spaces:
Running
Running
refactor
Browse files- app.py +288 -0
- inference.py +303 -0
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-310.pyc +0 -0
- src/data/__init__.py +0 -0
- src/data/__pycache__/__init__.cpython-310.pyc +0 -0
- src/data/__pycache__/dataset.cpython-310.pyc +0 -0
- src/data/__pycache__/utils.cpython-310.pyc +0 -0
- src/data/dataset.py +317 -0
- src/data/utils.py +143 -0
- src/model/__init__.py +0 -0
- src/model/__pycache__/__init__.cpython-310.pyc +0 -0
- src/model/__pycache__/layers.cpython-310.pyc +0 -0
- src/model/__pycache__/loss.cpython-310.pyc +0 -0
- src/model/__pycache__/models.cpython-310.pyc +0 -0
- src/model/layers.py +234 -0
- src/model/loss.py +85 -0
- src/model/models.py +269 -0
- src/util/__init__.py +0 -0
- src/util/__pycache__/__init__.cpython-310.pyc +0 -0
- src/util/__pycache__/smiles_cor.cpython-310.pyc +0 -0
- src/util/__pycache__/utils.cpython-310.pyc +0 -0
- src/util/smiles_cor.py +1284 -0
- src/util/utils.py +930 -0
- train.py +462 -0
app.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from inference import Inference
|
| 3 |
+
import PIL
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import random
|
| 7 |
+
from rdkit import Chem
|
| 8 |
+
from rdkit.Chem import Draw
|
| 9 |
+
from rdkit.Chem.Draw import IPythonConsole
|
| 10 |
+
import shutil
|
| 11 |
+
import os
|
| 12 |
+
import time
|
| 13 |
+
|
| 14 |
+
class DrugGENConfig:
    """Base inference configuration for the DrugGEN models.

    Plain attribute-bag consumed by ``Inference`` (inference.py); subclasses
    override the target-specific fields (checkpoint dir, drug SMILES,
    max_atom).
    """
    # Inference configuration
    submodel='DrugGEN'  # model subtype name; also used by Inference for output dir naming
    inference_model="experiments/models/DrugGEN/"  # directory holding the trained generator checkpoint
    sample_num=100  # number of molecules to generate (overridden per request in app.py)
    disable_correction=False # corresponds to correct=True in old config

    # Data configuration
    inf_smiles='data/chembl_test.smi' # corresponds to inf_raw_file in old config
    train_smiles='data/chembl_train.smi'  # ChEMBL training SMILES (novelty/encoder reference)
    train_drug_smiles='data/akt1_train.smi'  # target-specific drug SMILES
    inf_batch_size=1  # one molecule per inference step
    mol_data_dir='data'  # directory where the preprocessed .pt dataset files live
    features=False  # False: node features are atom types only

    # Model configuration
    act='relu'  # generator activation function
    max_atom=45  # max atom count per molecule (one-shot generation size)
    dim=128  # transformer hidden dimension
    depth=1  # transformer depth
    heads=8  # attention heads
    mlp_ratio=3  # MLP expansion ratio
    dropout=0.  # dropout rate

    # Seed configuration
    set_seed=True  # seed RNGs for reproducibility
    seed=10  # default seed (app.py overrides this per request)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class DrugGENAKT1Config(DrugGENConfig):
    """Configuration for the AKT1-targeted DrugGEN model (UniProt P31749)."""
    submodel='DrugGEN'
    inference_model="experiments/models/DrugGEN-AKT1/"
    train_drug_smiles='data/akt1_train.smi'  # known AKT1 inhibitors
    max_atom=45
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class DrugGENCDK2Config(DrugGENConfig):
    """Configuration for the CDK2-targeted DrugGEN model (UniProt P24941)."""
    submodel='DrugGEN'
    inference_model="experiments/models/DrugGEN-CDK2/"
    train_drug_smiles='data/cdk2_train.smi'  # known CDK2 inhibitors
    max_atom=38  # CDK2 checkpoint was trained with a smaller graph size
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class NoTargetConfig(DrugGENConfig):
    """Configuration for the untargeted (general-purpose) generator."""
    submodel="NoTarget"
    inference_model="experiments/models/NoTarget/"
    train_drug_smiles='data/chembl_train.smi' # No specific target, use general ChEMBL data
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# Maps the UI radio-button label to its ready-made config instance.
model_configs = {
    "DrugGEN-AKT1": DrugGENAKT1Config(),
    "DrugGEN-CDK2": DrugGENCDK2Config(),
    "DrugGEN-NoTarget": NoTargetConfig(),
}
# BUG FIX: `function` rewrites "DrugGEN-NoTarget" to "NoTarget" before the
# dict lookup, which raised a KeyError for the NoTarget option. Register the
# shortened name as an alias so both spellings resolve to the same config.
model_configs["NoTarget"] = model_configs["DrugGEN-NoTarget"]
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def function(model_name: str, num_molecules: int, seed_num: int) -> tuple[PIL.Image, pd.DataFrame, str]:
    '''
    Run inference for the selected model and package results for Gradio.

    Args:
        model_name: UI radio value ("DrugGEN-AKT1", "DrugGEN-CDK2",
            "DrugGEN-NoTarget").
        num_molecules: requested number of molecules (capped at 250).
        seed_num: seed textbox value; empty/None picks a random seed.

    Returns:
        image, score_df, file path

    Raises:
        gr.Error: if more than 250 molecules are requested or the seed is
            not an integer.
    '''
    # BUG FIX: look up the config *before* renaming — model_configs is keyed
    # by the full UI labels, so renaming first raised KeyError("NoTarget").
    config = model_configs[model_name]

    # The shortened name matches Inference's submodel ("NoTarget"), which
    # determines output directory and file names.
    if model_name == "DrugGEN-NoTarget":
        model_name = "NoTarget"

    config.sample_num = num_molecules
    if config.sample_num > 250:
        raise gr.Error("You have requested to generate more than the allowed limit of 250 molecules. Please reduce your request to 250 or fewer.")

    if seed_num is None or seed_num.strip() == "":
        config.seed = random.randint(0, 10000)
    else:
        try:
            config.seed = int(seed_num)
        except ValueError:
            raise gr.Error("The seed must be an integer value!")

    inferer = Inference(config)
    start_time = time.time()
    scores = inferer.inference()  # single-row metrics DataFrame
    et = time.time() - start_time

    score_df = pd.DataFrame({
        "Runtime (seconds)": [et],
        "Validity": [scores["validity"].iloc[0]],
        "Uniqueness": [scores["uniqueness"].iloc[0]],
        "Novelty (Train)": [scores["novelty"].iloc[0]],
        "Novelty (Test)": [scores["novelty_test"].iloc[0]],
        "Drug Novelty": [scores["drug_novelty"].iloc[0]],
        "Max Length": [scores["max_len"].iloc[0]],
        "Mean Atom Type": [scores["mean_atom_type"].iloc[0]],
        "SNN ChEMBL": [scores["snn_chembl"].iloc[0]],
        "SNN Drug": [scores["snn_drug"].iloc[0]],
        "Internal Diversity": [scores["IntDiv"].iloc[0]],
        "QED": [scores["qed"].iloc[0]],
        "SA Score": [scores["sa"].iloc[0]]
    })

    # NOTE(review): Inference.inference() writes f'{submodel}_denovo_mols.smi'
    # in the CWD; confirm that this inference_drugs.txt path is also produced
    # (e.g. by DruggenDataset.matrices2mol) before relying on the rename.
    output_file_path = f'experiments/inference/{model_name}/inference_drugs.txt'
    new_path = f'{model_name}_denovo_mols.smi'
    os.rename(output_file_path, new_path)

    with open(new_path) as f:
        inference_drugs = f.read()

    # Trailing newline leaves an empty final element — drop it.
    generated_molecule_list = inference_drugs.split("\n")[:-1]

    # Seeded, deterministic sample of up to 12 *distinct* molecules for the
    # display grid (random.sample, unlike choices, never repeats an entry).
    rng = random.Random(config.seed)
    if num_molecules > 12:
        selected_molecules = rng.sample(
            generated_molecule_list,
            k=min(12, len(generated_molecule_list)),
        )
    else:
        selected_molecules = generated_molecule_list

    # Parse each SMILES once and keep only valid molecules (the previous
    # filter called Chem.MolFromSmiles twice per entry).
    parsed = (Chem.MolFromSmiles(smi) for smi in selected_molecules)
    selected_molecules = [mol for mol in parsed if mol is not None]

    drawOptions = Draw.rdMolDraw2D.MolDrawOptions()
    drawOptions.prepareMolsBeforeDrawing = False
    drawOptions.bondLineWidth = 0.5

    molecule_image = Draw.MolsToGridImage(
        selected_molecules,
        molsPerRow=3,
        subImgSize=(400, 400),
        maxMols=len(selected_molecules),
        returnPNG=False,
        drawOptions=drawOptions,
        highlightAtomLists=None,
        highlightBondLists=None,
    )

    return molecule_image, score_df, new_path
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# ---------------------------------------------------------------------------
# Gradio UI: left column holds controls/description, right column the results.
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("# DrugGEN: Target Centric De Novo Design of Drug Candidate Molecules with Graph Generative Deep Adversarial Networks")
            with gr.Row():
                gr.Markdown("[](https://arxiv.org/abs/2302.07868)")
                gr.Markdown("[](https://github.com/HUBioDataLab/DrugGEN)")

            with gr.Accordion("About DrugGEN Models", open=False):
                gr.Markdown("""
## Model Variations

### DrugGEN-AKT1
This model is designed to generate molecules targeting the human AKT1 protein (UniProt ID: P31749), a serine/threonine-protein kinase that plays a key role in regulating cell survival, metabolism, and growth. AKT1 is a significant target in cancer therapy, particularly for breast, colorectal, and ovarian cancers.

The model learns from:
- General drug-like molecules from ChEMBL database
- Known AKT1 inhibitors
- Maximum atom count: 45

### DrugGEN-CDK2
This model targets the human CDK2 protein (UniProt ID: P24941), a cyclin-dependent kinase involved in cell cycle regulation. CDK2 inhibitors are being investigated for treating various cancers, particularly those with dysregulated cell cycle control.

The model learns from:
- General drug-like molecules from ChEMBL database
- Known CDK2 inhibitors
- Maximum atom count: 38

### DrugGEN-NoTarget
This is a general-purpose model that generates diverse drug-like molecules without targeting a specific protein. It's useful for:
- Exploring chemical space
- Generating diverse scaffolds
- Creating molecules with drug-like properties

## How It Works
DrugGEN uses a graph-based generative adversarial network (GAN) architecture where:
1. The generator creates molecular graphs
2. The discriminator evaluates them against real molecules
3. The model learns to generate increasingly realistic and target-specific molecules

For more details, see our [paper on arXiv](https://arxiv.org/abs/2302.07868).
""")

            with gr.Accordion("Understanding the Metrics", open=False):
                gr.Markdown("""
## Evaluation Metrics

### Basic Metrics
- **Validity**: Percentage of generated molecules that are chemically valid
- **Uniqueness**: Percentage of unique molecules among valid ones
- **Runtime**: Time taken to generate the requested molecules

### Novelty Metrics
- **Novelty (Train)**: Percentage of molecules not found in the training set
- **Novelty (Test)**: Percentage of molecules not found in the test set
- **Drug Novelty**: Percentage of molecules not found in known drugs

### Structural Metrics
- **Max Length**: Maximum component length in the generated molecules
- **Mean Atom Type**: Average distribution of atom types
- **Internal Diversity**: Diversity within the generated set (higher is more diverse)

### Drug-likeness Metrics
- **QED (Quantitative Estimate of Drug-likeness)**: Score from 0-1 measuring how drug-like a molecule is (higher is better)
- **SA Score (Synthetic Accessibility)**: Score from 1-10 indicating ease of synthesis (lower is easier)

### Similarity Metrics
- **SNN ChEMBL**: Similarity to ChEMBL molecules (higher means more similar to known drug-like compounds)
- **SNN Drug**: Similarity to known drugs (higher means more similar to approved drugs)
""")

            model_name = gr.Radio(
                choices=("DrugGEN-AKT1", "DrugGEN-CDK2", "DrugGEN-NoTarget"),
                value="DrugGEN-AKT1",
                label="Select Target Model",
                info="Choose which protein target or general model to use for molecule generation"
            )

            num_molecules = gr.Slider(
                minimum=10,
                maximum=250,
                value=100,
                step=10,
                label="Number of Molecules to Generate",
                # BUG FIX: the original info string ended with a stray extra
                # quote ('...GPU.""'), which is a SyntaxError.
                info="This space runs on a CPU, which may result in slower performance. Generating 200 molecules takes approximately 6 minutes. Therefore, We set a 250-molecule cap. On a GPU, the model can generate 10,000 molecules in the same amount of time. Please check our GitHub repo for running our models on GPU."
            )

            seed_num = gr.Textbox(
                label="Random Seed (Optional)",
                value="",
                info="Set a specific seed for reproducible results, or leave empty for random generation"
            )

            submit_button = gr.Button(
                value="Generate Molecules",
                variant="primary",
                size="lg"
            )

        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("Generated Molecules"):
                    image_output = gr.Image(
                        label="Sample of Generated Molecules",
                        elem_id="molecule_display"
                    )
                    file_download = gr.File(
                        label="Download All Generated Molecules (SMILES format)",
                    )

                with gr.TabItem("Performance Metrics"):
                    scores_df = gr.Dataframe(
                        label="Model Performance Metrics",
                        headers=["Runtime (seconds)", "Validity", "Uniqueness", "Novelty (Train)", "Novelty (Test)",
                                 "Drug Novelty", "Max Length", "Mean Atom Type", "SNN ChEMBL", "SNN Drug",
                                 "Internal Diversity", "QED", "SA Score"]
                    )

            with gr.Accordion("Generation Settings", open=False):
                gr.Markdown("""
## Technical Details

- This demo runs on CPU which limits generation speed
- Generating 200 molecules takes approximately 6 minutes
- For faster generation or larger batches, run the model on GPU using our GitHub repository
- The model uses a graph-based representation of molecules
- Maximum atom count varies by model (AKT1: 45, CDK2: 38)
""")

            gr.Markdown("### Created by the HU BioDataLab | [GitHub](https://github.com/HUBioDataLab/DrugGEN) | [Paper](https://arxiv.org/abs/2302.07868)")

    # Wire the button to the inference function; api_name exposes it on the API.
    submit_button.click(function, inputs=[model_name, num_molecules, seed_num], outputs=[image_output, scores_df, file_download], api_name="inference")

# Queueing is required for long-running CPU inference requests.
demo.queue()
demo.launch()
|
| 288 |
+
|
inference.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import time
|
| 4 |
+
import random
|
| 5 |
+
import pickle
|
| 6 |
+
import argparse
|
| 7 |
+
import os.path as osp
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.utils.data
|
| 11 |
+
from torch_geometric.loader import DataLoader
|
| 12 |
+
|
| 13 |
+
import pandas as pd
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
|
| 16 |
+
from rdkit import RDLogger, Chem
|
| 17 |
+
from rdkit.Chem import QED, RDConfig
|
| 18 |
+
|
| 19 |
+
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
|
| 20 |
+
import sascorer
|
| 21 |
+
|
| 22 |
+
from src.util.utils import *
|
| 23 |
+
from src.model.models import Generator
|
| 24 |
+
from src.data.dataset import DruggenDataset
|
| 25 |
+
from src.data.utils import get_encoders_decoders, load_molecules
|
| 26 |
+
from src.model.loss import generator_loss
|
| 27 |
+
from src.util.smiles_cor import smi_correct
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Inference(object):
    """Inference class for DrugGEN.

    Loads a trained generator checkpoint, samples molecular graphs from the
    inference dataset, decodes them to SMILES, computes benchmark metrics,
    and writes all generated SMILES to a ``<submodel>_denovo_mols.smi`` file.

    NOTE(review): several names used below (``np``, ``AllChem``,
    ``fraction_valid``, ``fraction_unique``, ``novelty``, ``Metrics``,
    ``average_agg_tanimoto``, ``internal_diversity``) are not imported in
    this module directly; they appear to be supplied by
    ``from src.util.utils import *`` — confirm.
    """

    def __init__(self, config):
        """Build data pipeline and generator from a config/argparse namespace."""
        # Seed every RNG source so sampling is reproducible when requested.
        if config.set_seed:
            np.random.seed(config.seed)
            random.seed(config.seed)
            torch.manual_seed(config.seed)
            torch.cuda.manual_seed_all(config.seed)

            # Force deterministic cuDNN kernels (slower, but reproducible).
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

            os.environ["PYTHONHASHSEED"] = str(config.seed)

            print(f'Using seed {config.seed}')

        self.device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

        # Initialize configurations
        self.submodel = config.submodel
        self.inference_model = config.inference_model
        self.sample_num = config.sample_num
        self.disable_correction = config.disable_correction

        # Data loader.
        self.inf_smiles = config.inf_smiles # SMILES containing text file for first dataset.
        # Write the full path to file.

        inf_smiles_basename = osp.basename(self.inf_smiles)

        # Get the base name without extension and add max_atom to it
        self.max_atom = config.max_atom # Model is based on one-shot generation.
        inf_smiles_base = os.path.splitext(inf_smiles_basename)[0]

        # Change extension from .smi to .pt and add max_atom to the filename
        self.inf_dataset_file = f"{inf_smiles_base}{self.max_atom}.pt"

        self.inf_batch_size = config.inf_batch_size
        self.train_smiles = config.train_smiles
        self.train_drug_smiles = config.train_drug_smiles
        self.mol_data_dir = config.mol_data_dir # Directory where the dataset files are stored.
        self.dataset_name = self.inf_dataset_file.split(".")[0]
        self.features = config.features # Small model uses atom types as node features. (Boolean, False uses atom types only.)
        # Additional node features can be added. Please check new_dataloarder.py Line 102.

        # Get atom and bond encoders/decoders
        self.atom_encoder, self.atom_decoder, self.bond_encoder, self.bond_decoder = get_encoders_decoders(
            self.train_smiles,
            self.train_drug_smiles,
            self.max_atom
        )

        self.inf_dataset = DruggenDataset(self.mol_data_dir,
                                          self.inf_dataset_file,
                                          self.inf_smiles,
                                          self.max_atom,
                                          self.features,
                                          atom_encoder=self.atom_encoder,
                                          atom_decoder=self.atom_decoder,
                                          bond_encoder=self.bond_encoder,
                                          bond_decoder=self.bond_decoder)

        self.inf_loader = DataLoader(self.inf_dataset,
                                     shuffle=True,
                                     batch_size=self.inf_batch_size,
                                     drop_last=True) # PyG dataloader for the first GAN.

        # Dimensions derived from the decoders / the first dataset sample.
        self.m_dim = len(self.atom_decoder) if not self.features else int(self.inf_loader.dataset[0].x.shape[1]) # Atom type dimension.
        self.b_dim = len(self.bond_decoder) # Bond type dimension.
        self.vertexes = int(self.inf_loader.dataset[0].x.shape[0]) # Number of nodes in the graph.

        # Model configurations.
        self.act = config.act
        self.dim = config.dim
        self.depth = config.depth
        self.heads = config.heads
        self.mlp_ratio = config.mlp_ratio
        self.dropout = config.dropout

        self.build_model()

    def build_model(self):
        """Create generators and discriminators."""
        self.G = Generator(self.act,
                           self.vertexes,
                           self.b_dim,
                           self.m_dim,
                           self.dropout,
                           dim=self.dim,
                           depth=self.depth,
                           heads=self.heads,
                           mlp_ratio=self.mlp_ratio)
        self.G.to(self.device)
        self.print_network(self.G, 'G')

    def print_network(self, model, name):
        """Print out the network information."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def restore_model(self, submodel, model_directory):
        """Restore the trained generator and discriminator.

        Loads '<model_directory>/<submodel>-G.ckpt' onto the CPU first
        (map_location), then relies on the module already living on
        self.device.
        """
        print('Loading the model...')
        G_path = os.path.join(model_directory, '{}-G.ckpt'.format(submodel))
        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))

    def inference(self):
        """Generate ``sample_num`` molecules and compute benchmark metrics.

        Returns:
            pd.DataFrame: single-row frame with validity/novelty/similarity
            metrics, keyed by the column names app.py reads.
        """
        # Load the trained generator.
        self.restore_model(self.submodel, self.inference_model)

        # smiles data for metrics calculation.
        chembl_smiles = [line for line in open(self.train_smiles, 'r').read().splitlines()]
        chembl_test = [line for line in open(self.inf_smiles, 'r').read().splitlines()]
        drug_smiles = [line for line in open(self.train_drug_smiles, 'r').read().splitlines()]
        drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
        drug_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in drug_mols if x is not None]


        # Make directories if not exist.
        if not os.path.exists("experiments/inference/{}".format(self.submodel)):
            os.makedirs("experiments/inference/{}".format(self.submodel))

        # NOTE(review): `correct` is re-created after the sampling loop below;
        # this early instantiation looks redundant unless smi_correct's
        # constructor has side effects (e.g. creating files) — confirm.
        if not self.disable_correction:
            correct = smi_correct(self.submodel, "experiments/inference/{}".format(self.submodel))

        # NOTE(review): search_res is never appended to or returned here —
        # appears to be dead code; confirm before removing.
        search_res = pd.DataFrame(columns=["submodel", "validity",
                                           "uniqueness", "novelty",
                                           "novelty_test", "drug_novelty",
                                           "max_len", "mean_atom_type",
                                           "snn_chembl", "snn_drug", "IntDiv", "qed", "sa"])

        self.G.eval()

        start_time = time.time()
        metric_calc_dr = []    # generated SMILES, used to track how many were produced
        uniqueness_calc = []   # generated SMILES, used later for fingerprints
        real_smiles_snn = []   # real mols (one per batch) for the SNN-ChEMBL reference
        nodes_sample = torch.Tensor(size=[1, self.vertexes, 1]).to(self.device)
        generated_smiles = []
        val_counter = 0
        none_counter = 0

        # Inference mode
        with torch.inference_mode():
            pbar = tqdm(range(self.sample_num))
            pbar.set_description('Inference mode for {} model started'.format(self.submodel))
            for i, data in enumerate(self.inf_loader):
                val_counter += 1
                # Preprocess dataset
                _, a_tensor, x_tensor = load_molecules(
                    data=data,
                    batch_size=self.inf_batch_size,
                    device=self.device,
                    b_dim=self.b_dim,
                    m_dim=self.m_dim,
                )

                _, _, node_sample, edge_sample = self.G(a_tensor, x_tensor)

                # Argmax over the last axis turns logits into discrete labels.
                g_edges_hat_sample = torch.max(edge_sample, -1)[1]
                g_nodes_hat_sample = torch.max(node_sample, -1)[1]

                fake_mol_g = [self.inf_dataset.matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=False, file_name=self.dataset_name)
                              for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)]

                a_tensor_sample = torch.max(a_tensor, -1)[1]
                x_tensor_sample = torch.max(x_tensor, -1)[1]
                real_mols = [self.inf_dataset.matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True, file_name=self.dataset_name)
                             for e_, n_ in zip(a_tensor_sample, x_tensor_sample)]

                # Decode to SMILES and keep only the largest fragment
                # (max by string length over '.'-separated components).
                inference_drugs = [None if line is None else Chem.MolToSmiles(line) for line in fake_mol_g]
                inference_drugs = [None if x is None else max(x.split('.'), key=len) for x in inference_drugs]

                for molecules in inference_drugs:
                    if molecules is None:
                        none_counter += 1

                for molecules in inference_drugs:
                    if molecules is not None:
                        # Replace wildcard atoms with carbon before collecting.
                        molecules = molecules.replace("*", "C")
                        generated_smiles.append(molecules)
                        uniqueness_calc.append(molecules)
                        nodes_sample = torch.cat((nodes_sample, g_nodes_hat_sample.view(1, self.vertexes, 1)), 0)
                        pbar.update(1)
                        metric_calc_dr.append(molecules)

                real_smiles_snn.append(real_mols[0])
                generation_number = len([x for x in metric_calc_dr if x is not None])
                # Stop once enough molecules were produced, or if every
                # attempt so far decoded to None.
                if generation_number == self.sample_num or none_counter == self.sample_num:
                    break

        # Optionally run the SMILES correction pipeline over the raw output.
        if not self.disable_correction:
            correct = smi_correct(self.submodel, "experiments/inference/{}".format(self.submodel))
            gen_smi = correct.correct_smiles_list(generated_smiles)
        else:
            gen_smi = generated_smiles

        et = time.time() - start_time

        # Morgan fingerprints (radius 2, 1024 bits) for similarity metrics.
        gen_vecs = [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(x), 2, nBits=1024) for x in uniqueness_calc if Chem.MolFromSmiles(x) is not None]
        real_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in real_smiles_snn if x is not None]

        # With correction enabled, validity = fraction that survived
        # correction; otherwise use the standard fraction_valid metric.
        if not self.disable_correction:
            val = round(len(gen_smi)/self.sample_num, 3)
        else:
            val = round(fraction_valid(gen_smi), 3)

        uniq = round(fraction_unique(gen_smi), 3)
        nov = round(novelty(gen_smi, chembl_smiles), 3)
        nov_test = round(novelty(gen_smi, chembl_test), 3)
        drug_nov = round(novelty(gen_smi, drug_smiles), 3)
        max_len = round(Metrics.max_component(gen_smi, self.vertexes), 3)
        mean_atom = round(Metrics.mean_atom_type(nodes_sample), 3)
        snn_chembl = round(average_agg_tanimoto(np.array(real_vecs), np.array(gen_vecs)), 3)
        snn_drug = round(average_agg_tanimoto(np.array(drug_vecs), np.array(gen_vecs)), 3)
        int_div = round((internal_diversity(np.array(gen_vecs)))[0], 3)
        qed = round(np.mean([QED.qed(Chem.MolFromSmiles(x)) for x in gen_smi if Chem.MolFromSmiles(x) is not None]), 3)
        sa = round(np.mean([sascorer.calculateScore(Chem.MolFromSmiles(x)) for x in gen_smi if Chem.MolFromSmiles(x) is not None]), 3)

        model_res = pd.DataFrame({"submodel": [self.submodel], "validity": [val],
                                  "uniqueness": [uniq], "novelty": [nov],
                                  "novelty_test": [nov_test], "drug_novelty": [drug_nov],
                                  "max_len": [max_len], "mean_atom_type": [mean_atom],
                                  "snn_chembl": [snn_chembl], "snn_drug": [snn_drug],
                                  "IntDiv": [int_div], "qed": [qed], "sa": [sa]})

        # Write generated SMILES to a temporary file for app.py to use
        temp_file = f'{self.submodel}_denovo_mols.smi'
        with open(temp_file, 'w') as f:
            f.write("SMILES\n")
            for smiles in gen_smi:
                f.write(f"{smiles}\n")

        return model_res
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
if __name__=="__main__":
    # Command-line entry point; the argparse namespace mirrors the fields of
    # DrugGENConfig in app.py, so Inference accepts either source of config.
    parser = argparse.ArgumentParser()

    # Inference configuration.
    parser.add_argument('--submodel', type=str, default="DrugGEN", help="Chose model subtype: DrugGEN, NoTarget", choices=['DrugGEN', 'NoTarget'])
    parser.add_argument('--inference_model', type=str, help="Path to the model for inference")
    parser.add_argument('--sample_num', type=int, default=100, help='inference samples')
    parser.add_argument('--disable_correction', action='store_true', help='Disable SMILES correction')

    # Data configuration.
    parser.add_argument('--inf_smiles', type=str, required=True)
    parser.add_argument('--train_smiles', type=str, required=True)
    parser.add_argument('--train_drug_smiles', type=str, required=True)
    parser.add_argument('--inf_batch_size', type=int, default=1, help='Batch size for inference')
    parser.add_argument('--mol_data_dir', type=str, default='data')
    parser.add_argument('--features', action='store_true', help='features dimension for nodes')

    # Model configuration.
    parser.add_argument('--act', type=str, default="relu", help="Activation function for the model.", choices=['relu', 'tanh', 'leaky', 'sigmoid'])
    parser.add_argument('--max_atom', type=int, default=45, help='Max atom number for molecules must be specified.')
    parser.add_argument('--dim', type=int, default=128, help='Dimension of the Transformer Encoder model for the GAN.')
    parser.add_argument('--depth', type=int, default=1, help='Depth of the Transformer model from the GAN.')
    parser.add_argument('--heads', type=int, default=8, help='Number of heads for the MultiHeadAttention module from the GAN.')
    parser.add_argument('--mlp_ratio', type=int, default=3, help='MLP ratio for the Transformer.')
    parser.add_argument('--dropout', type=float, default=0., help='dropout rate')

    # Seed configuration.
    parser.add_argument('--set_seed', action='store_true', help='set seed for reproducibility')
    parser.add_argument('--seed', type=int, default=1, help='seed for reproducibility')

    config = parser.parse_args()
    inference = Inference(config)
    inference.inference()
|
src/__init__.py
ADDED
|
File without changes
|
src/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (150 Bytes). View file
|
|
|
src/data/__init__.py
ADDED
|
File without changes
|
src/data/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (155 Bytes). View file
|
|
|
src/data/__pycache__/dataset.cpython-310.pyc
ADDED
|
Binary file (12.9 kB). View file
|
|
|
src/data/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (4.75 kB). View file
|
|
|
src/data/dataset.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import os.path as osp
|
| 3 |
+
import re
|
| 4 |
+
import pickle
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch_geometric.data import Data, InMemoryDataset
|
| 12 |
+
|
| 13 |
+
from rdkit import Chem, RDLogger
|
| 14 |
+
|
| 15 |
+
from src.data.utils import label2onehot
|
| 16 |
+
|
| 17 |
+
RDLogger.DisableLog('rdApp.*')
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class DruggenDataset(InMemoryDataset):
    """
    In-memory PyTorch Geometric dataset built from a raw SMILES file.

    Each valid SMILES string is converted into a ``torch_geometric.data.Data``
    object with one-hot atom-type node features (optionally extended with
    extra RDKit-derived atom descriptors) and bond-type edge attributes.
    Also provides helpers to convert generated graphs back into RDKit
    molecules and to repair valency violations.
    """

    def __init__(self, root, dataset_file, raw_files, max_atom, features,
                 atom_encoder, atom_decoder, bond_encoder, bond_decoder,
                 transform=None, pre_transform=None, pre_filter=None):
        """
        Initialize the DruggenDataset with pre-loaded encoder/decoder dictionaries.

        Parameters:
            root (str): Root directory.
            dataset_file (str): Name of the processed dataset file.
            raw_files (str): Path to the raw SMILES file.
            max_atom (int): Maximum number of atoms allowed in a molecule.
            features (bool): Whether to include additional node features.
            atom_encoder (dict): Pre-loaded atom encoder dictionary (atomic number -> index).
            atom_decoder (dict): Pre-loaded atom decoder dictionary (index -> atomic number).
            bond_encoder (dict): Pre-loaded bond encoder dictionary (RDKit BondType -> index).
            bond_decoder (dict): Pre-loaded bond decoder dictionary (index -> RDKit BondType).
            transform, pre_transform, pre_filter: See PyG InMemoryDataset.
        """
        self.dataset_name = dataset_file.split(".")[0]
        self.dataset_file = dataset_file
        self.raw_files = raw_files
        self.max_atom = max_atom
        self.features = features

        # Use the provided encoder/decoder mappings.
        self.atom_encoder_m = atom_encoder
        self.atom_decoder_m = atom_decoder
        self.bond_encoder_m = bond_encoder
        self.bond_decoder_m = bond_decoder

        self.atom_num_types = len(atom_encoder)
        self.bond_num_types = len(bond_encoder)

        # InMemoryDataset.__init__ runs process() when the processed file is absent.
        super().__init__(root, transform, pre_transform, pre_filter)
        path = osp.join(self.processed_dir, dataset_file)
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # dataset files produced by this class (trusted source).
        self.data, self.slices = torch.load(path)
        self.root = root

    @property
    def processed_dir(self):
        """
        Returns the directory where processed dataset files are stored.
        """
        # Processed files live directly in root (no "processed/" subfolder).
        return self.root

    @property
    def raw_file_names(self):
        """
        Returns the raw SMILES file name.
        """
        return self.raw_files

    @property
    def processed_file_names(self):
        """
        Returns the name of the processed dataset file.
        """
        return self.dataset_file

    def _filter_smiles(self, smiles_list):
        """
        Filters the input list of SMILES strings to keep only valid molecules that:
          - Can be successfully parsed,
          - Have a number of atoms less than or equal to the maximum allowed (max_atom),
          - Contain only atoms present in the atom_encoder,
          - Contain only bonds present in the bond_encoder.

        Parameters:
            smiles_list (list): List of SMILES strings.

        Returns:
            max_length (int): Maximum number of atoms found in the filtered molecules.
            filtered_smiles (list): List of valid SMILES strings.
        """
        max_length = 0
        filtered_smiles = []
        for smiles in tqdm(smiles_list, desc="Filtering SMILES"):
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                continue

            # Check molecule size
            molecule_size = mol.GetNumAtoms()
            if molecule_size > self.max_atom:
                continue

            # Filter out molecules with atoms not in the atom_encoder
            # (membership tests against the dict check its keys: atomic numbers).
            if not all(atom.GetAtomicNum() in self.atom_encoder_m for atom in mol.GetAtoms()):
                continue

            # Filter out molecules with bonds not in the bond_encoder
            if not all(bond.GetBondType() in self.bond_encoder_m for bond in mol.GetBonds()):
                continue

            filtered_smiles.append(smiles)
            max_length = max(max_length, molecule_size)
        return max_length, filtered_smiles

    def _genA(self, mol, connected=True, max_length=None):
        """
        Generates the adjacency matrix for a molecule based on its bond structure.

        Parameters:
            mol (rdkit.Chem.Mol): The molecule.
            connected (bool): If True, ensures all atoms are connected.
            max_length (int, optional): The size of the matrix; if None, uses number of atoms in mol.

        Returns:
            np.array: Adjacency matrix with bond types as entries, or None if disconnected.
        """
        max_length = max_length if max_length is not None else mol.GetNumAtoms()
        A = np.zeros((max_length, max_length))
        begin = [b.GetBeginAtomIdx() for b in mol.GetBonds()]
        end = [b.GetEndAtomIdx() for b in mol.GetBonds()]
        bond_type = [self.bond_encoder_m[b.GetBondType()] for b in mol.GetBonds()]
        # Symmetric assignment: undirected molecular graph.
        A[begin, end] = bond_type
        A[end, begin] = bond_type
        # Degree computed over the real (unpadded) atoms only; a zero-degree
        # atom means the molecule has a disconnected fragment.
        degree = np.sum(A[:mol.GetNumAtoms(), :mol.GetNumAtoms()], axis=-1)
        return A if connected and (degree > 0).all() else None

    def _genX(self, mol, max_length=None):
        """
        Generates the feature vector for each atom in a molecule by encoding their atomic numbers.

        Parameters:
            mol (rdkit.Chem.Mol): The molecule.
            max_length (int, optional): Length of the feature vector; if None, uses number of atoms in mol.

        Returns:
            np.array: Array of atom feature indices, padded with zeros if necessary, or None on error.
        """
        max_length = max_length if max_length is not None else mol.GetNumAtoms()
        try:
            # Index 0 serves as the PAD label for positions beyond the real atoms.
            return np.array([self.atom_encoder_m[atom.GetAtomicNum()] for atom in mol.GetAtoms()] +
                            [0] * (max_length - mol.GetNumAtoms()))
        except KeyError as e:
            print(f"Skipping molecule with unsupported atom: {e}")
            print(f"Skipped SMILES: {Chem.MolToSmiles(mol)}")
            return None

    def _genF(self, mol, max_length=None):
        """
        Generates additional node features for a molecule using various atomic properties.

        Parameters:
            mol (rdkit.Chem.Mol): The molecule.
            max_length (int, optional): Number of rows in the features matrix; if None, uses number of atoms.

        Returns:
            np.array: Array of additional features for each atom, padded with zeros if necessary.
        """
        max_length = max_length if max_length is not None else mol.GetNumAtoms()
        # One-hot-style boolean descriptors per atom (degree, valence,
        # hybridization, H counts, radicals, ring membership, ...).
        features = np.array([[*[a.GetDegree() == i for i in range(5)],
                              *[a.GetExplicitValence() == i for i in range(9)],
                              *[int(a.GetHybridization()) == i for i in range(1, 7)],
                              *[a.GetImplicitValence() == i for i in range(9)],
                              a.GetIsAromatic(),
                              a.GetNoImplicit(),
                              *[a.GetNumExplicitHs() == i for i in range(5)],
                              *[a.GetNumImplicitHs() == i for i in range(5)],
                              *[a.GetNumRadicalElectrons() == i for i in range(5)],
                              a.IsInRing(),
                              *[a.IsInRingSize(i) for i in range(2, 9)]]
                             for a in mol.GetAtoms()], dtype=np.int32)
        # Zero-pad rows up to max_length so all molecules share one shape.
        return np.vstack((features, np.zeros((max_length - features.shape[0], features.shape[1]))))

    def decoder_load(self, dictionary_name, file):
        """
        Returns the pre-loaded decoder dictionary based on the dictionary name.

        Parameters:
            dictionary_name (str): Name of the dictionary ("atom" or "bond").
            file: Placeholder parameter for compatibility.

        Returns:
            dict: The corresponding decoder dictionary.

        Raises:
            ValueError: If dictionary_name is neither "atom" nor "bond".
        """
        if dictionary_name == "atom":
            return self.atom_decoder_m
        elif dictionary_name == "bond":
            return self.bond_decoder_m
        else:
            raise ValueError("Unknown dictionary name.")

    def matrices2mol(self, node_labels, edge_labels, strict=True, file_name=None):
        """
        Converts graph representations (node labels and edge labels) back to an RDKit molecule.

        Parameters:
            node_labels (iterable): Encoded atom labels.
            edge_labels (np.array): Adjacency matrix with encoded bond types.
            strict (bool): If True, sanitizes the molecule and returns None on failure.
            file_name: Placeholder parameter for compatibility.

        Returns:
            rdkit.Chem.Mol: The resulting molecule, or None if sanitization fails.
        """
        mol = Chem.RWMol()
        for node_label in node_labels:
            mol.AddAtom(Chem.Atom(self.atom_decoder_m[node_label]))
        # Only the lower triangle (start > end) is used so each undirected
        # bond is added exactly once.
        for start, end in zip(*np.nonzero(edge_labels)):
            if start > end:
                mol.AddBond(int(start), int(end), self.bond_decoder_m[edge_labels[start, end]])
        if strict:
            try:
                Chem.SanitizeMol(mol)
            except Exception:
                mol = None
        return mol

    def check_valency(self, mol):
        """
        Checks that no atom in the molecule has exceeded its allowed valency.

        Parameters:
            mol (rdkit.Chem.Mol): The molecule.

        Returns:
            tuple: (True, None) if valid; (False, atomid_valence) if there is a valency issue.
        """
        try:
            Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
            return True, None
        except ValueError as e:
            # NOTE(review): this parses RDKit's error text ("... atom # <idx> ...")
            # to recover the offending atom index — fragile if RDKit changes
            # its message wording; verify against the pinned RDKit version.
            e = str(e)
            p = e.find('#')
            e_sub = e[p:]
            atomid_valence = list(map(int, re.findall(r'\d+', e_sub)))
            return False, atomid_valence

    def correct_mol(self, mol):
        """
        Corrects a molecule by removing bonds until all atoms satisfy their valency limits.

        Parameters:
            mol (rdkit.Chem.Mol): The molecule.

        Returns:
            rdkit.Chem.Mol: The corrected molecule.
        """
        while True:
            flag, atomid_valence = self.check_valency(mol)
            if flag:
                break
            else:
                # Expecting two numbers: atom index and its valence.
                # NOTE(review): this assert assumes RDKit's message contains
                # exactly two integers after '#'; element symbols with digits
                # or message changes would break it.
                assert len(atomid_valence) == 2
                idx = atomid_valence[0]
                queue = []
                for b in mol.GetAtomWithIdx(idx).GetBonds():
                    queue.append((b.GetIdx(), int(b.GetBondType()), b.GetBeginAtomIdx(), b.GetEndAtomIdx()))
                # Remove the highest-order bond on the offending atom first.
                queue.sort(key=lambda tup: tup[1], reverse=True)
                if queue:
                    start = queue[0][2]
                    end = queue[0][3]
                    mol.RemoveBond(start, end)
        return mol


    def process(self, size=None):
        """
        Processes the raw SMILES file by filtering and converting each valid SMILES into a PyTorch Geometric Data object.
        The resulting dataset is saved to disk.

        Parameters:
            size (optional): Placeholder parameter for compatibility.

        Side Effects:
            Saves the processed dataset as a file in the processed directory.
        """
        # Read raw SMILES from file (assuming CSV with no header)
        smiles_list = pd.read_csv(self.raw_files, header=None)[0].tolist()
        max_length, filtered_smiles = self._filter_smiles(smiles_list)
        data_list = []
        self.m_dim = len(self.atom_decoder_m)
        for smiles in tqdm(filtered_smiles, desc='Processing dataset', total=len(filtered_smiles)):
            mol = Chem.MolFromSmiles(smiles)
            A = self._genA(mol, connected=True, max_length=max_length)
            if A is not None:
                x_array = self._genX(mol, max_length=max_length)
                if x_array is None:
                    continue
                x = torch.from_numpy(x_array).to(torch.long).view(1, -1)
                x = label2onehot(x, self.m_dim).squeeze()
                if self.features:
                    # Append the extra RDKit descriptors to the one-hot atom types.
                    f = torch.from_numpy(self._genF(mol, max_length=max_length)).to(torch.long).view(x.shape[0], -1)
                    x = torch.concat((x, f), dim=-1)
                adjacency = torch.from_numpy(A)
                # Dense adjacency -> sparse COO edge list with bond-type attrs.
                edge_index = adjacency.nonzero(as_tuple=False).t().contiguous()
                edge_attr = adjacency[edge_index[0], edge_index[1]].to(torch.long)
                data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, smiles=smiles)
                if self.pre_filter is not None and not self.pre_filter(data):
                    continue
                if self.pre_transform is not None:
                    data = self.pre_transform(data)
                data_list.append(data)
        torch.save(self.collate(data_list), osp.join(self.processed_dir, self.dataset_file))
|
src/data/utils.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pickle
|
| 3 |
+
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch_geometric.data import Data, InMemoryDataset
|
| 9 |
+
import torch_geometric.utils as geoutils
|
| 10 |
+
|
| 11 |
+
from rdkit import Chem, RDLogger
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def label2onehot(labels, dim, device=None):
    """Expand an integer label tensor into a one-hot encoding along a new last axis.

    Args:
        labels (torch.Tensor): Integer tensor of label indices.
        dim (int): Number of classes (size of the one-hot axis).
        device (optional): If truthy, device to place the result on.

    Returns:
        torch.Tensor: Float tensor of shape ``labels.shape + (dim,)``.
    """
    target_shape = list(labels.size()) + [dim]
    encoded = torch.zeros(target_shape)
    if device:
        encoded = encoded.to(device)

    # Scatter a 1.0 into the class slot indicated by each label.
    last_axis = encoded.dim() - 1
    encoded.scatter_(last_axis, labels.unsqueeze(-1), 1.)

    return encoded.float()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_encoders_decoders(raw_file1, raw_file2, max_atom):
    """
    Given two raw SMILES files, either load the atom and bond encoders/decoders
    if they exist (naming them based on the file names) or create and save them.

    Parameters:
        raw_file1 (str): Path to the first SMILES file.
        raw_file2 (str): Path to the second SMILES file.
        max_atom (int): Maximum allowed number of atoms in a molecule.

    Returns:
        atom_encoder (dict): Mapping from atomic numbers to indices.
        atom_decoder (dict): Mapping from indices to atomic numbers.
        bond_encoder (dict): Mapping from bond types to indices.
        bond_decoder (dict): Mapping from indices to bond types.
    """
    # Determine unique suffix based on the two file names (alphabetically sorted for consistency)
    name1 = os.path.splitext(os.path.basename(raw_file1))[0]
    name2 = os.path.splitext(os.path.basename(raw_file2))[0]
    sorted_names = sorted([name1, name2])
    suffix = f"{sorted_names[0]}_{sorted_names[1]}"

    # Define encoder/decoder directories and file paths
    # NOTE(review): paths are relative to the current working directory —
    # running from a different cwd will regenerate rather than reuse them.
    enc_dir = os.path.join("data", "encoders")
    dec_dir = os.path.join("data", "decoders")
    atom_encoder_path = os.path.join(enc_dir, f"atom_{suffix}.pkl")
    atom_decoder_path = os.path.join(dec_dir, f"atom_{suffix}.pkl")
    bond_encoder_path = os.path.join(enc_dir, f"bond_{suffix}.pkl")
    bond_decoder_path = os.path.join(dec_dir, f"bond_{suffix}.pkl")

    # If all files exist, load and return them
    # NOTE(review): pickle.load can execute arbitrary code on malicious files —
    # only load encoder/decoder files produced by this function.
    if (os.path.exists(atom_encoder_path) and os.path.exists(atom_decoder_path) and
        os.path.exists(bond_encoder_path) and os.path.exists(bond_decoder_path)):
        with open(atom_encoder_path, "rb") as f:
            atom_encoder = pickle.load(f)
        with open(atom_decoder_path, "rb") as f:
            atom_decoder = pickle.load(f)
        with open(bond_encoder_path, "rb") as f:
            bond_encoder = pickle.load(f)
        with open(bond_decoder_path, "rb") as f:
            bond_decoder = pickle.load(f)
        print("Loaded existing encoders/decoders!")
        return atom_encoder, atom_decoder, bond_encoder, bond_decoder

    # Otherwise, create the encoders/decoders
    print("Creating new encoders/decoders...")
    # Read SMILES from both files (assuming one SMILES per row, no header)
    smiles1 = pd.read_csv(raw_file1, header=None)[0].tolist()
    smiles2 = pd.read_csv(raw_file2, header=None)[0].tolist()
    smiles_combined = smiles1 + smiles2

    atom_labels = set()
    bond_labels = set()
    max_length = 0
    filtered_smiles = []

    # Process each SMILES: keep only valid molecules with <= max_atom atoms
    for smiles in tqdm(smiles_combined, desc="Processing SMILES"):
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            continue
        molecule_size = mol.GetNumAtoms()
        if molecule_size > max_atom:
            continue
        filtered_smiles.append(smiles)
        # Collect atomic numbers
        atom_labels.update([atom.GetAtomicNum() for atom in mol.GetAtoms()])
        max_length = max(max_length, molecule_size)
        # Collect bond types
        bond_labels.update([bond.GetBondType() for bond in mol.GetBonds()])

    # Add a PAD symbol (here using 0 for atoms)
    atom_labels.add(0)
    # Sorting makes index assignment deterministic across runs.
    atom_labels = sorted(atom_labels)

    # For bonds, prepend the PAD bond type (using rdkit's BondType.ZERO)
    bond_labels = sorted(bond_labels)
    bond_labels = [Chem.rdchem.BondType.ZERO] + bond_labels

    # Create encoder and decoder dictionaries
    atom_encoder = {l: i for i, l in enumerate(atom_labels)}
    atom_decoder = {i: l for i, l in enumerate(atom_labels)}
    bond_encoder = {l: i for i, l in enumerate(bond_labels)}
    bond_decoder = {i: l for i, l in enumerate(bond_labels)}

    # Ensure directories exist
    os.makedirs(enc_dir, exist_ok=True)
    os.makedirs(dec_dir, exist_ok=True)

    # Save the encoders/decoders to disk
    with open(atom_encoder_path, "wb") as f:
        pickle.dump(atom_encoder, f)
    with open(atom_decoder_path, "wb") as f:
        pickle.dump(atom_decoder, f)
    with open(bond_encoder_path, "wb") as f:
        pickle.dump(bond_encoder, f)
    with open(bond_decoder_path, "wb") as f:
        pickle.dump(bond_decoder, f)

    print("Encoders/decoders created and saved.")
    return atom_encoder, atom_decoder, bond_encoder, bond_decoder
|
| 127 |
+
|
| 128 |
+
def load_molecules(data=None, b_dim=32, m_dim=32, device=None, batch_size=32):
    """
    Convert a PyG mini-batch into dense, flattened graph tensors.

    Parameters:
        data: A torch_geometric batch with ``x``, ``edge_index``, ``edge_attr``
            and ``batch`` attributes. Must not be None (it is dereferenced
            immediately).
        b_dim (int): Number of bond classes for one-hot encoding the adjacency.
        m_dim (int): Unused in this implementation (kept for interface
            compatibility with callers).
        device: Device to move the batch (and one-hot output) to.
        batch_size (int): Number of graphs in the batch.

    Returns:
        tuple: (real_graphs, a_tensor, x_tensor) where
            real_graphs is the per-graph concatenation of flattened node and
            adjacency tensors, a_tensor is the one-hot dense adjacency, and
            x_tensor is the node-feature tensor reshaped per graph.

    NOTE(review): assumes every graph in the batch is padded to the same node
    count, so nodes-per-graph = total_nodes / batch_size divides exactly —
    confirm against the dataset's fixed max_length padding.
    """
    data = data.to(device)
    # Dense adjacency per graph; entries are bond-type labels from edge_attr.
    a = geoutils.to_dense_adj(
        edge_index = data.edge_index,
        batch=data.batch,
        edge_attr=data.edge_attr,
        max_num_nodes=int(data.batch.shape[0]/batch_size)
    )
    x_tensor = data.x.view(batch_size,int(data.batch.shape[0]/batch_size),-1)
    a_tensor = label2onehot(a, b_dim, device)

    # Flatten node and adjacency tensors and concatenate per graph.
    a_tensor_vec = a_tensor.reshape(batch_size,-1)
    x_tensor_vec = x_tensor.reshape(batch_size,-1)
    real_graphs = torch.concat((x_tensor_vec,a_tensor_vec),dim=-1)

    return real_graphs, a_tensor, x_tensor
|
src/model/__init__.py
ADDED
|
File without changes
|
src/model/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (156 Bytes). View file
|
|
|
src/model/__pycache__/layers.cpython-310.pyc
ADDED
|
Binary file (8.31 kB). View file
|
|
|
src/model/__pycache__/loss.cpython-310.pyc
ADDED
|
Binary file (2.04 kB). View file
|
|
|
src/model/__pycache__/models.cpython-310.pyc
ADDED
|
Binary file (7.35 kB). View file
|
|
|
src/model/layers.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
|
| 7 |
+
class MLP(nn.Module):
    """
    Two-layer feed-forward network: Linear -> ReLU -> Linear -> Dropout.

    Hidden and output widths default to the input width when not given.

    Attributes:
        fc1 (nn.Linear): Input projection.
        act (nn.ReLU): Non-linearity between the two projections.
        fc2 (nn.Linear): Output projection.
        droprateout (nn.Dropout): Dropout applied to the final output.
    """

    def __init__(self, in_feat, hid_feat=None, out_feat=None, dropout=0.):
        """
        Build the two linear layers and the output dropout.

        Args:
            in_feat (int): Input feature size.
            hid_feat (int, optional): Hidden size; falls back to ``in_feat``.
            out_feat (int, optional): Output size; falls back to ``in_feat``.
            dropout (float, optional): Dropout probability on the output.
        """
        super().__init__()

        # Falsy (None/0) hidden/output sizes default to the input size.
        hidden = hid_feat if hid_feat else in_feat
        output = out_feat if out_feat else in_feat

        self.fc1 = nn.Linear(in_feat, hidden)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(hidden, output)
        self.droprateout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Apply fc1, ReLU, fc2, then dropout.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Transformed tensor.
        """
        hidden = self.act(self.fc1(x))
        return self.droprateout(self.fc2(hidden))
|
| 55 |
+
|
| 56 |
+
class MHA(nn.Module):
    """
    Multi-head attention over node features, with edge features multiplicatively
    modulating the attention scores and also producing updated edge features.

    Attributes:
        heads (int): Number of attention heads.
        scale (float): 1/sqrt(dim) scaling factor (kept for compatibility).
        q, k, v (nn.Linear): Node projections to queries, keys, values.
        e (nn.Linear): Edge feature projection.
        d_k (int): Per-head feature size.
        out_e (nn.Linear): Output projection for edge features.
        out_n (nn.Linear): Output projection for aggregated node features.
    """

    def __init__(self, dim, heads, attention_dropout=0.):
        """
        Set up the projection layers.

        Args:
            dim (int): Feature dimensionality; must be divisible by ``heads``.
            heads (int): Number of attention heads.
            attention_dropout (float, optional): Accepted for interface
                compatibility; not applied in this implementation.
        """
        super().__init__()

        # Per-head split requires an even division of the feature width.
        assert dim % heads == 0

        self.heads = heads
        self.scale = 1. / math.sqrt(dim)  # Scaling factor for attention
        # Node-feature projections.
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        # Edge-feature projection.
        self.e = nn.Linear(dim, dim)
        self.d_k = dim // heads  # Dimension per head

        # Output projections.
        self.out_e = nn.Linear(dim, dim)
        self.out_n = nn.Linear(dim, dim)

    def forward(self, node, edge):
        """
        Run edge-modulated multi-head attention.

        Args:
            node (torch.Tensor): Node features, shape (batch, n, dim).
            edge (torch.Tensor): Edge features, shape (batch, n, n, dim).

        Returns:
            tuple: (node_out, edge_out) — updated node and edge features.
        """
        batch, n_nodes, width = node.shape
        head_dim = width // self.heads

        # Project nodes into per-head queries/keys/values; insert singleton
        # axes so queries broadcast against keys over all node pairs.
        queries = self.q(node).view(batch, n_nodes, self.heads, head_dim).unsqueeze(2)
        keys = self.k(node).view(batch, n_nodes, self.heads, head_dim).unsqueeze(1)
        values = self.v(node).view(batch, n_nodes, self.heads, head_dim)

        # Per-head edge embeddings, one vector per (i, j) node pair.
        edge_proj = self.e(edge).view(batch, n_nodes, n_nodes, self.heads, head_dim)

        # Scaled elementwise scores, multiplicatively modulated by the edges.
        scores = (queries * keys) / math.sqrt(self.d_k)
        scores = scores * (edge_proj + 1) * edge_proj

        # Edge output comes from the unnormalized scores.
        edge_out = self.out_e(scores.flatten(3))

        # Normalize over the key/node axis, then aggregate values.
        weights = F.softmax(scores, dim=2)
        aggregated = (weights * values.unsqueeze(1)).sum(dim=2).flatten(2)
        node_out = self.out_n(aggregated)

        return node_out, edge_out
|
| 138 |
+
|
| 139 |
+
class Encoder_Block(nn.Module):
    """
    Graph-transformer encoder block updating node and edge features jointly.

    Pipeline: layer-norm the nodes, run edge-modulated multi-head attention,
    add residuals, layer-norm both streams, then pass each through its own
    MLP with another residual + layer-norm.

    Attributes:
        ln1, ln3, ln4, ln5, ln6 (nn.LayerNorm): Normalization layers.
        attn (MHA): Edge-modulated multi-head attention.
        mlp, mlp2 (MLP): Feed-forward blocks for nodes and edges respectively.
    """

    def __init__(self, dim, heads, act, mlp_ratio=4, drop_rate=0.):
        """
        Build the block's submodules.

        Args:
            dim (int): Feature dimensionality.
            heads (int): Number of attention heads.
            act (callable): Activation function (accepted for compatibility;
                not used by this block).
            mlp_ratio (int, optional): Hidden-size multiplier for the MLPs.
            drop_rate (float, optional): Dropout rate for attention and MLPs.
        """
        super().__init__()

        self.ln1 = nn.LayerNorm(dim)
        self.attn = MHA(dim, heads, drop_rate)
        self.ln3 = nn.LayerNorm(dim)
        self.ln4 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, dim * mlp_ratio, dim, dropout=drop_rate)
        self.mlp2 = MLP(dim, dim * mlp_ratio, dim, dropout=drop_rate)
        self.ln5 = nn.LayerNorm(dim)
        self.ln6 = nn.LayerNorm(dim)

    def forward(self, x, y):
        """
        Apply attention and feed-forward stages with residual connections.

        Args:
            x (torch.Tensor): Node features.
            y (torch.Tensor): Edge features.

        Returns:
            tuple: (updated node features, updated edge features)
        """
        # Attention stage: note the node residual is taken from the
        # *normalized* input, matching the original formulation.
        normed_nodes = self.ln1(x)
        attn_nodes, attn_edges = self.attn(normed_nodes, y)
        node_res = self.ln3(normed_nodes + attn_nodes)
        edge_res = self.ln4(y + attn_edges)

        # Feed-forward stage with residual + post-norm on each stream.
        x = self.ln5(node_res + self.mlp(node_res))
        y = self.ln6(edge_res + self.mlp2(edge_res))
        return x, y
|
| 194 |
+
|
| 195 |
+
class TransformerEncoder(nn.Module):
    """Transformer encoder: a sequential stack of Encoder_Block modules.

    Attributes:
        Encoder_Blocks (nn.ModuleList): The stacked encoder blocks.
    """

    def __init__(self, dim, depth, heads, act, mlp_ratio=4, drop_rate=0.1):
        """Create ``depth`` identical encoder blocks.

        Args:
            dim (int): Feature dimensionality.
            depth (int): Number of encoder blocks to stack.
            heads (int): Attention heads per block.
            act (callable): Activation function forwarded to each block.
            mlp_ratio (int, optional): Hidden-size multiplier for block MLPs.
                Defaults to 4.
            drop_rate (float, optional): Dropout rate inside each block.
                Defaults to 0.1.
        """
        super().__init__()
        blocks = [
            Encoder_Block(dim, heads, act, mlp_ratio, drop_rate)
            for _ in range(depth)
        ]
        self.Encoder_Blocks = nn.ModuleList(blocks)

    def forward(self, x, y):
        """Push node and edge features through every encoder block in order.

        Args:
            x (torch.Tensor): Node feature tensor.
            y (torch.Tensor): Edge feature tensor.

        Returns:
            tuple: (final node features, final edge features)
        """
        for layer in self.Encoder_Blocks:
            x, y = layer(x, y)
        return x, y
|
src/model/loss.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def gradient_penalty(discriminator, real_node, real_edge, fake_node, fake_edge, batch_size, device):
    """Compute the WGAN-GP gradient penalty on real/fake interpolations.

    Random per-sample interpolations between real and generated graphs are
    scored by the critic, and the penalty pushes the joint gradient norm
    (over node and edge inputs together) toward 1.

    Args:
        discriminator: Critic called as ``discriminator(edge, node)``.
        real_node: Real node features, shape (batch, vertexes, nodes).
        real_edge: Real edge features, shape (batch, vertexes, vertexes, edges).
        fake_node: Generated node features (same shape as ``real_node``).
        fake_edge: Generated edge features (same shape as ``real_edge``).
        batch_size (int): Number of samples in the batch.
        device: Device on which the interpolation factors are created.

    Returns:
        torch.Tensor: Scalar gradient-penalty term.
    """
    # Per-sample random interpolation factors, broadcast over feature dims.
    eps_edge = torch.rand(batch_size, 1, 1, 1, device=device)
    eps_node = torch.rand(batch_size, 1, 1, device=device)

    # Interpolated samples must require grad so we can differentiate the
    # critic output with respect to them.
    int_node = (eps_node * real_node + (1 - eps_node) * fake_node).requires_grad_(True)
    int_edge = (eps_edge * real_edge + (1 - eps_edge) * fake_edge).requires_grad_(True)

    logits_interpolated = discriminator(int_edge, int_node)

    # ones_like keeps dtype/device consistent with the critic output
    # (the original built the tensor on CPU and moved it afterwards).
    grad_outputs = torch.ones_like(logits_interpolated)
    gradients = torch.autograd.grad(
        outputs=logits_interpolated,
        inputs=[int_node, int_edge],
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )

    # Flatten and concatenate so the norm is taken over BOTH inputs jointly.
    gradients_node = gradients[0].reshape(batch_size, -1)
    gradients_edge = gradients[1].reshape(batch_size, -1)
    combined = torch.cat([gradients_node, gradients_edge], dim=1)

    # Renamed from `gradient_penalty` so the local no longer shadows this
    # function's own name.
    penalty = ((combined.norm(2, dim=1) - 1) ** 2).mean()

    return penalty
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def discriminator_loss(generator, discriminator, drug_adj, drug_annot, mol_adj, mol_annot, batch_size, device, lambda_gp):
    """WGAN-GP critic loss over real drug graphs and generated molecules.

    Args:
        generator: Generator returning (node, edge, node_sample, edge_sample).
        discriminator: Critic called as ``discriminator(edge, node)``.
        drug_adj: Real drug adjacency (edge) tensor.
        drug_annot: Real drug annotation (node) tensor.
        mol_adj: Input molecule adjacency tensor fed to the generator.
        mol_annot: Input molecule annotation tensor fed to the generator.
        batch_size (int): Batch size, forwarded to the gradient penalty.
        device: Device used for the gradient-penalty interpolations.
        lambda_gp (float): Weight of the gradient-penalty term.

    Returns:
        tuple: (node, edge, d_loss) — transformer features of the generated
        batch plus the scalar critic loss.
    """
    # Critic score on real drug graphs (mean reduction for stable training).
    real_logits = discriminator(drug_adj, drug_annot)
    loss_real = -torch.mean(real_logits)

    # Generate molecules; detach samples so the generator is not updated here.
    node, edge, node_sample, edge_sample = generator(mol_adj, mol_annot)
    fake_logits = discriminator(edge_sample.detach(), node_sample.detach())
    loss_fake = torch.mean(fake_logits)

    # Gradient penalty on interpolations between real and generated graphs.
    gp = gradient_penalty(
        discriminator,
        drug_annot,
        drug_adj,
        node_sample.detach(),
        edge_sample.detach(),
        batch_size,
        device,
    )

    d_loss = loss_fake + loss_real + lambda_gp * gp
    return node, edge, d_loss
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def generator_loss(generator, discriminator, mol_adj, mol_annot, batch_size):
    """Adversarial generator loss: maximize the critic score of fakes.

    Args:
        generator: Generator returning (node, edge, node_sample, edge_sample).
        discriminator: Critic called as ``discriminator(edge, node)``.
        mol_adj: Input molecule adjacency tensor.
        mol_annot: Input molecule annotation tensor.
        batch_size (int): Kept for API compatibility; not used here.

    Returns:
        tuple: (g_loss, node, edge, node_sample, edge_sample)
    """
    # Generate a batch of fake molecules.
    node, edge, node_sample, edge_sample = generator(mol_adj, mol_annot)

    # Score them with the critic; the generator wants high scores, hence the
    # negated mean.
    fake_logits = discriminator(edge_sample, node_sample)
    g_loss = -torch.mean(fake_logits)

    return g_loss, node, edge, node_sample, edge_sample
|
src/model/models.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from src.model.layers import TransformerEncoder
|
| 4 |
+
|
| 5 |
+
class Generator(nn.Module):
    """Graph generator built around a Transformer encoder.

    Input node and edge features are projected by small MLPs, mixed by a
    Transformer encoder, and finally mapped back to atom/bond logits by
    linear readout layers.
    """

    def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio):
        """Initialize the generator.

        Args:
            act (str): Activation name ("relu", "leaky", "sigmoid" or "tanh").
            vertexes (int): Number of vertexes in the graph.
            edges (int): Number of edge feature channels.
            nodes (int): Number of node feature channels.
            dropout (float): Dropout rate.
            dim (int): Intermediate feature dimensionality.
            depth (int): Number of Transformer encoder blocks.
            heads (int): Attention heads in the Transformer.
            mlp_ratio (int): Hidden-size multiplier for the MLP modules.
        """
        super(Generator, self).__init__()
        self.vertexes = vertexes
        self.edges = edges
        self.nodes = nodes
        self.depth = depth
        self.dim = dim
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout

        # Resolve the activation name to a module instance; unknown names are
        # passed through unchanged (they will fail inside nn.Sequential, as in
        # the original implementation).
        act = {
            "relu": nn.ReLU(),
            "leaky": nn.LeakyReLU(),
            "sigmoid": nn.Sigmoid(),
            "tanh": nn.Tanh(),
        }.get(act, act)

        # Flattened feature sizes, kept as attributes for external reference.
        self.features = vertexes * vertexes * edges + vertexes * nodes
        self.transformer_dim = vertexes * vertexes * dim + vertexes * dim

        self.node_layers = nn.Sequential(
            nn.Linear(nodes, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout),
        )
        self.edge_layers = nn.Sequential(
            nn.Linear(edges, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout),
        )
        self.TransformerEncoder = TransformerEncoder(
            dim=self.dim, depth=self.depth, heads=self.heads, act=act,
            mlp_ratio=self.mlp_ratio, drop_rate=self.dropout,
        )

        self.readout_e = nn.Linear(self.dim, edges)
        self.readout_n = nn.Linear(self.dim, nodes)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, z_e, z_n):
        """Transform latent graph features into atom/bond logits.

        Args:
            z_e (torch.Tensor): Edge features, (batch, vertexes, vertexes, edges).
            z_n (torch.Tensor): Node features, (batch, vertexes, nodes).

        Returns:
            tuple: (node, edge, node_sample, edge_sample) — transformer
            features and the linear readout samples.
        """
        # Project raw features into the transformer dimensionality.
        node = self.node_layers(z_n)
        edge = self.edge_layers(z_e)

        # Enforce symmetry of the edge tensor over its two vertex axes.
        edge = 0.5 * (edge + edge.permute(0, 2, 1, 3))

        # Mix node and edge information with the Transformer encoder.
        node, edge = self.TransformerEncoder(node, edge)

        # Readout layers produce the final atom/bond logits.
        node_sample = self.readout_n(node)
        edge_sample = self.readout_e(edge)

        return node, edge, node_sample, edge_sample
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class Discriminator(nn.Module):
    """Critic network that scores graphs from node and edge features.

    Features are projected by small MLPs, mixed with a Transformer encoder,
    and the flattened node representation is reduced to one scalar per
    sample by an MLP head.

    This class is used in the DrugGEN model.
    """

    def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio):
        """Initialize the discriminator.

        Args:
            act (str): Activation name ("relu", "leaky", "sigmoid" or "tanh").
            vertexes (int): Number of vertexes.
            edges (int): Number of edge feature channels.
            nodes (int): Number of node feature channels.
            dropout (float): Dropout rate.
            dim (int): Intermediate feature dimensionality.
            depth (int): Number of Transformer encoder blocks.
            heads (int): Number of attention heads.
            mlp_ratio (int): Hidden-size multiplier for MLP modules.
        """
        super(Discriminator, self).__init__()
        self.vertexes = vertexes
        self.edges = edges
        self.nodes = nodes
        self.depth = depth
        self.dim = dim
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout

        # Map the activation name to a module instance; unknown names pass
        # through unchanged and will fail inside nn.Sequential, matching the
        # original behavior.
        act = {
            "relu": nn.ReLU(),
            "leaky": nn.LeakyReLU(),
            "sigmoid": nn.Sigmoid(),
            "tanh": nn.Tanh(),
        }.get(act, act)

        self.features = vertexes * vertexes * edges + vertexes * nodes
        self.transformer_dim = vertexes * vertexes * dim + vertexes * dim

        # Projection layers for node and edge features.
        self.node_layers = nn.Sequential(
            nn.Linear(nodes, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout),
        )
        self.edge_layers = nn.Sequential(
            nn.Linear(edges, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout),
        )
        # Transformer encoder for node/edge interaction modeling.
        self.TransformerEncoder = TransformerEncoder(
            dim=self.dim, depth=self.depth, heads=self.heads, act=act,
            mlp_ratio=self.mlp_ratio, drop_rate=self.dropout,
        )

        # Sizes of the flattened node / edge representations.
        self.node_features = vertexes * dim
        self.edge_features = vertexes * vertexes * dim
        # MLP head mapping aggregated node features to a scalar score.
        self.node_mlp = nn.Sequential(
            nn.Linear(self.node_features, 64), act,
            nn.Linear(64, 32), act,
            nn.Linear(32, 16), act,
            nn.Linear(16, 1),
        )

    def forward(self, z_e, z_n):
        """Score a batch of graphs.

        Args:
            z_e (torch.Tensor): Edge features, (batch, vertexes, vertexes, edges).
            z_n (torch.Tensor): Node features, (batch, vertexes, nodes).

        Returns:
            torch.Tensor: One scalar score per sample.
        """
        batch = z_n.shape[0]

        # Project node and edge features separately.
        node = self.node_layers(z_n)
        edge = self.edge_layers(z_e)

        # Symmetrize the edge tensor over its two vertex axes.
        edge = 0.5 * (edge + edge.permute(0, 2, 1, 3))

        # Mix features with the Transformer encoder.
        node, edge = self.TransformerEncoder(node, edge)

        # Flatten node features and predict a scalar per sample.
        return self.node_mlp(node.view(batch, -1))
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class simple_disc(nn.Module):
    """Simplified discriminator scoring flattened graph features with an MLP.

    The input is the concatenation of flattened atom-type and bond-type
    tensors; the output is one raw (unbounded) score per sample.

    This class is used in the NoTarget model.
    """

    def __init__(self, act, m_dim, vertexes, b_dim):
        """Initialize the simple discriminator.

        Args:
            act (str): Activation name ("relu", "leaky", "sigmoid" or "tanh").
            m_dim (int): Number of atom-type feature channels.
            vertexes (int): Number of vertexes.
            b_dim (int): Number of bond-type feature channels.

        Raises:
            ValueError: If ``act`` is not one of the supported names.
        """
        super().__init__()

        # Resolve the activation name; unlike the other models, unsupported
        # names are rejected explicitly here.
        if act == "relu":
            act = nn.ReLU()
        elif act == "leaky":
            act = nn.LeakyReLU()
        elif act == "sigmoid":
            act = nn.Sigmoid()
        elif act == "tanh":
            act = nn.Tanh()
        else:
            raise ValueError("Unsupported activation function: {}".format(act))

        # Total input size: flattened atom features plus flattened bond tensor.
        # (Leftover debug print() statements removed here.)
        features = vertexes * m_dim + vertexes * vertexes * b_dim
        self.predictor = nn.Sequential(
            nn.Linear(features, 256), act,
            nn.Linear(256, 128), act,
            nn.Linear(128, 64), act,
            nn.Linear(64, 32), act,
            nn.Linear(32, 16), act,
            nn.Linear(16, 1)
        )

    def forward(self, x):
        """Score a batch of flattened graph features.

        Args:
            x (torch.Tensor): Input of shape (batch, features).

        Returns:
            torch.Tensor: Raw scores of shape (batch, 1).
        """
        prediction = self.predictor(x)
        return prediction
|
src/util/__init__.py
ADDED
|
File without changes
|
src/util/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (155 Bytes). View file
|
|
|
src/util/__pycache__/smiles_cor.cpython-310.pyc
ADDED
|
Binary file (30.2 kB). View file
|
|
|
src/util/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (30 kB). View file
|
|
|
src/util/smiles_cor.py
ADDED
|
@@ -0,0 +1,1284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import random
|
| 4 |
+
import re
|
| 5 |
+
import itertools
|
| 6 |
+
import statistics
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.optim as optim
|
| 13 |
+
from torch.utils.data import DataLoader
|
| 14 |
+
from torchtext.data import TabularDataset, Field, BucketIterator, Iterator
|
| 15 |
+
|
| 16 |
+
from rdkit import Chem, rdBase, RDLogger
|
| 17 |
+
from rdkit.Chem import (
|
| 18 |
+
MolStandardize,
|
| 19 |
+
GraphDescriptors,
|
| 20 |
+
Lipinski,
|
| 21 |
+
AllChem,
|
| 22 |
+
)
|
| 23 |
+
from rdkit.Chem.rdSLNParse import MolFromSLN
|
| 24 |
+
from rdkit.Chem.rdmolfiles import MolFromSmiles
|
| 25 |
+
from chembl_structure_pipeline import standardizer
|
| 26 |
+
|
| 27 |
+
RDLogger.DisableLog('rdApp.*')
|
| 28 |
+
|
| 29 |
+
SEED = 42
|
| 30 |
+
random.seed(SEED)
|
| 31 |
+
torch.manual_seed(SEED)
|
| 32 |
+
torch.backends.cudnn.deterministic = True
|
| 33 |
+
|
| 34 |
+
##################################################################################################
|
| 35 |
+
##################################################################################################
|
| 36 |
+
# #
|
| 37 |
+
# THIS SCRIPT IS DIRECTLY ADAPTED FROM https://github.com/LindeSchoenmaker/SMILES-corrector #
|
| 38 |
+
# #
|
| 39 |
+
##################################################################################################
|
| 40 |
+
##################################################################################################
|
| 41 |
+
def is_smiles(array,
              TRG,
              reverse: bool,
              return_output=False,
              src=None,
              src_field=None):
    """Decode a batch of predicted token tensors into SMILES and test validity.

    Arguments:
        array: Tensor with the most probable token per position per sequence,
            shaped [trg len, batch size].
        TRG: Target field whose vocab maps indices back to tokens.
        reverse (bool): True if the target sequences are reversed.
        return_output (bool): If True, collect the decoded sequences (and,
            when ``src`` is given, the original inputs) in a dataframe.
        src: Optional source tensor whose sequences are stored as "ORIGINAL".
        src_field: Field used to decode ``src``.

    Returns:
        df: Dataframe with CORRECT/INCORRECT (and optionally ORIGINAL)
            columns, or None when ``return_output`` is False.
        valids: List of booleans, True where the prediction is a valid SMILES.
        smiless: List of the decoded SMILES strings.
    """
    def decode(tensor_column, field):
        # Map vocab indices back to tokens and keep everything up to (but not
        # including) the first <eos>. Decoding could stop earlier for speed,
        # but that would make the result sensitive to vocab ordering changes.
        tokens = [field.vocab.itos[int(idx)] for idx in tensor_column.tolist()]
        return list(itertools.takewhile(lambda tok: tok != "<eos>", tokens))

    valids = []
    smiless = []
    df = pd.DataFrame() if return_output else None
    batch_size = array.size(1)

    # Outputs are initialized to all zeros, so a leading 0 token is an
    # initialization artifact that must be skipped.
    start = 1 if int(array[0, 0].tolist()) == 0 else 0

    for i in range(batch_size):
        tokens = decode(array[start:, i], TRG)
        if reverse:
            tokens = tokens[::-1]
        smiles = "".join(tokens)
        # RDKit returns None for unparsable SMILES, a Mol object otherwise.
        valid = bool(MolFromSmiles(smiles))
        valids.append(valid)
        smiless.append(smiles)
        if return_output:
            df.loc[i, "CORRECT" if valid else "INCORRECT"] = smiles

    # Optionally attach the original (source) sequences; row 0 is <sos>.
    if return_output and src is not None:
        for i in range(batch_size):
            df.loc[i, "ORIGINAL"] = "".join(decode(src[1:, i], src_field))

    return df, valids, smiless
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def is_unchanged(array,
                 TRG,
                 reverse: bool,
                 return_output=False,
                 src=None,
                 src_field=None):
    """Count invalid predictions that are byte-identical to their source.

    Arguments:
        array: Tensor with most probable token for each location for each
            sequence in batch [trg len, batch size]
        TRG: target field for getting tokens from vocab
        reverse (bool): True if the target sequence is reversed
        return_output (bool): accepted for call-site symmetry with is_smiles;
            not used by this function
        src: source-token tensor [src len, batch size]; required
        src_field: source field for getting tokens from vocab; required

    Returns:
        unchanged (int): number of predictions that are invalid SMILES and
        equal to their decoded source sequence (i.e. the model copied the
        input verbatim without fixing it)
    """
    trg_field = TRG
    sources = []
    batch_size = array.size(1)
    unchanged = 0

    # The forward pass leaves row 0 of the prediction tensor at its all-zero
    # initialisation, so skip that row when the first token is 0.
    if int((array[0, 0]).tolist()) == 0:
        start = 1
    else:
        start = 0

    # Decode every source sequence in the batch into a SMILES string.
    for i in range(0, batch_size):
        # skip the first row: it holds <sos> for src
        sequence = (src[1:, i]).tolist()
        # token indices -> token strings
        src_tokens = [src_field.vocab.itos[int(t)] for t in sequence]
        # keep everything up to (not including) the <eos> marker
        rev_tokens = list(
            itertools.takewhile(lambda x: x != "<eos>", src_tokens))
        smiles = "".join(rev_tokens)
        sources.append(smiles)

    # Decode every predicted sequence and compare with its source.
    for i in range(0, batch_size):
        sequence = (array[start:, i]).tolist()
        trg_tokens = [trg_field.vocab.itos[int(t)] for t in sequence]
        rev_tokens = list(
            itertools.takewhile(lambda x: x != "<eos>", trg_tokens))
        if reverse:
            # targets were stored reversed; undo before comparing to src
            rev_tokens = rev_tokens[::-1]
        smiles = "".join(rev_tokens)
        # MolFromSmiles returns None for unparseable SMILES
        valid = True if MolFromSmiles(smiles) else False
        if not valid:
            if smiles == sources[i]:
                unchanged += 1

    return unchanged
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def molecule_reconstruction(array, TRG, reverse: bool, outputs):
    """Turns target tokens within batch into smiles and compares them to
    predicted output smiles.

    Arguments:
        array: Tensor with target's token for each location for each sequence
            in batch [trg len, batch size]
        TRG: target field for getting tokens from vocab
        reverse (bool): True if the target sequence is reversed
        outputs: list of predicted SMILES sequences (one per batch column)

    Returns:
        matches (int): number of predictions that are the same molecule as
        their target (mutual substructure match)
    """
    trg_field = TRG
    matches = 0
    targets = []
    batch_size = array.size(1)
    # Decode each target column to a SMILES string.
    for i in range(0, batch_size):
        # skip row 0: it is not filled in by forward
        sequence = (array[1:, i]).tolist()
        # token indices -> token strings
        trg_tokens = [trg_field.vocab.itos[int(t)] for t in sequence]
        # keep tokens up to the <eos> marker
        rev_tokens = list(
            itertools.takewhile(lambda x: x != "<eos>", trg_tokens))
        if reverse:
            rev_tokens = rev_tokens[::-1]
        smiles = "".join(rev_tokens)
        targets.append(smiles)
    for i in range(0, batch_size):
        m = MolFromSmiles(targets[i])
        p = MolFromSmiles(outputs[i])
        # Bug fix: also require m is not None — a target that fails to parse
        # previously crashed on m.HasSubstructMatch. An unparseable target
        # cannot be matched, so it simply does not count.
        if m is not None and p is not None:
            # mutual substructure match <=> same molecule
            if m.HasSubstructMatch(p) and p.HasSubstructMatch(m):
                matches += 1
    return matches
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def complexity_whitlock(mol: Chem.Mol, includeAllDescs=False):
    """
    Complexity as defined in DOI:10.1021/jo9814546
    S: complexity = 4*#rings + 2*#unsat + #hetatm + 2*#chiral
    Other descriptors:
    H: size = #bonds (Hydrogen atoms included)
    G: S + H
    Ratio: S / H

    Arguments:
        mol: RDKit molecule to score (a copy is made; `mol` is not modified)
        includeAllDescs: if False return only S; if True also compute H, G
            and the S/H ratio

    Returns:
        S (int) when includeAllDescs is False, otherwise a dict with keys
        WhitlockS, WhitlockH, WhitlockG, WhitlockRatio
    """
    # work on a copy so the caller's molecule stays untouched
    mol_ = Chem.Mol(mol)
    # non-aromatic rings only
    nrings = Lipinski.RingCount(mol_) - Lipinski.NumAromaticRings(mol_)
    Chem.rdmolops.SetAromaticity(mol_)
    # count bonds of order exactly 2; presumably aromatic bonds report 1.5
    # after SetAromaticity and are deliberately excluded — TODO confirm
    unsat = sum(1 for bond in mol_.GetBonds()
                if bond.GetBondTypeAsDouble() == 2)
    # heteroatoms = any atom that is not carbon
    hetatm = len(mol_.GetSubstructMatches(Chem.MolFromSmarts("[!#6]")))
    # embed a 3D conformer so chiral tags can be perceived from geometry
    AllChem.EmbedMolecule(mol_)
    Chem.AssignAtomChiralTagsFromStructure(mol_)
    chiral = len(Chem.FindMolChiralCenters(mol_))
    S = 4 * nrings + 2 * unsat + hetatm + 2 * chiral
    if not includeAllDescs:
        return S
    # H sums bond orders with explicit hydrogens added
    Chem.rdmolops.Kekulize(mol_)
    mol_ = Chem.AddHs(mol_)
    H = sum(bond.GetBondTypeAsDouble() for bond in mol_.GetBonds())
    G = S + H
    R = S / H
    return {"WhitlockS": S, "WhitlockH": H, "WhitlockG": G, "WhitlockRatio": R}
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def complexity_baronechanon(mol: Chem.Mol):
    """
    Complexity as defined in DOI:10.1021/ci000145p

    Sums a per-atom connectivity term, a per-atom element term (carbon vs
    heteroatom) and a per-ring size term on a kekulized,
    stereochemistry-free, heavy-atom-only copy of `mol`.
    """
    # never mutate the caller's molecule
    work = Chem.Mol(mol)
    Chem.Kekulize(work)
    Chem.RemoveStereochemistry(work)
    work = Chem.RemoveHs(work, updateExplicitCount=True)
    # connectivity term: 3 * 2^(heavy-atom valence - 1) per atom
    degree = sum(
        3 * 2**(atom.GetExplicitValence() - atom.GetNumExplicitHs() - 1)
        for atom in work.GetAtoms())
    # element term: carbons count 3, everything else 6
    counts = sum(3 if atom.GetSymbol() == "C" else 6
                 for atom in work.GetAtoms())
    # ring term: 6 per ring atom, summed over all SSSR rings
    ringterm = sum(6 * len(ring)
                   for ring in work.GetRingInfo().AtomRings())
    return degree + counts + ringterm
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def calc_complexity(array,
                    TRG,
                    reverse,
                    valids,
                    complexity_function=GraphDescriptors.BertzCT):
    """Calculates the mean complexity of targets whose prediction was not a
    valid SMILES.

    Arguments:
        array: Tensor with target's token for each location for each sequence
            in batch [trg len, batch size]
        TRG: target field for getting tokens from vocab
        reverse (bool): True if the target sequence is reversed
        valids: list of booleans, True where the prediction was a valid SMILES
        complexity_function: the complexity measure that will be used:
            GraphDescriptors.BertzCT, complexity_whitlock, or
            complexity_baronechanon

    Returns:
        mean (float): mean complexity over the decodable targets, or 0 when
        there are none
    """
    trg_field = TRG
    sources = []
    complexities = []
    loc = torch.BoolTensor(valids)
    # keep only the batch columns whose prediction was invalid
    array = array[:, loc == False]
    array_size = array.size(1)
    for i in range(0, array_size):
        # skip row 0: it is not filled in by forward
        sequence = (array[1:, i]).tolist()
        # token indices -> token strings
        trg_tokens = [trg_field.vocab.itos[int(t)] for t in sequence]
        # keep tokens up to the <eos> marker
        rev_tokens = list(
            itertools.takewhile(lambda x: x != "<eos>", trg_tokens))
        if reverse:
            rev_tokens = rev_tokens[::-1]
        sources.append("".join(rev_tokens))
    for source in sources:
        # Bug fix: MolFromSmiles does not raise on invalid input, it returns
        # None, so the old try/except never reached the SLN fallback and
        # complexity_function(None) could crash. Parse explicitly and fall
        # back to SLN only when SMILES parsing yields nothing.
        m = MolFromSmiles(source)
        if m is None:
            try:
                m = MolFromSLN(source)
            except BaseException:
                m = None
        if m is None:
            # sequence cannot be parsed at all; skip instead of crashing
            continue
        complexities.append(complexity_function(m))
    if len(complexities) > 0:
        mean = statistics.mean(complexities)
    else:
        mean = 0
    return mean
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def epoch_time(start_time, end_time):
    """Convert an elapsed wall-clock span into (whole minutes, leftover seconds)."""
    elapsed = end_time - start_time
    minutes = int(elapsed / 60)
    seconds = int(elapsed - minutes * 60)
    return minutes, seconds
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
class Convo:
    """Mixin with the training and evaluation loops shared by the
    transformer and convolutional sequence-to-sequence models.

    Expects the inheriting nn.Module to provide: loader_train, loader_valid,
    loader_drugex, TRG, SRC, device, lr, clip, epochs, out, plus forward()
    and translate_sentence().

    Methods
    -------
    train_model()
        train model for initialized number of epochs
    evaluate(return_output)
        use model with validation loader (& optionally drugex loader) to get
        test loss & other metrics
    translate(loader)
        translate inputs from loader (different from evaluate in that no
        target sequence is used)
    """

    def train_model(self):
        """Run the full training loop, checkpointing the best model to
        {self.out}.pkg and the last model to {self.out}_last.pkg."""
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        # NOTE(review): the log file is opened and closed but `info` is never
        # written to it — confirm whether logging was meant to happen here.
        log = open(f"{self.out}.log", "a")
        best_error = np.inf
        for epoch in range(self.epochs):
            self.train()
            start_time = time.time()
            loss_train = 0
            for i, batch in enumerate(self.loader_train):
                optimizer.zero_grad()
                # batch attributes follow the torchtext-style field names
                trg = batch.trg
                src = batch.src
                # teacher forcing: target shifted right (last token dropped)
                output, attention = self(src, trg[:, :-1])
                output_dim = output.shape[-1]
                # flatten batch and sequence dims for token-level CE loss
                output = output.contiguous().view(-1, output_dim)
                trg = trg[:, 1:].contiguous().view(-1)
                # padding positions do not contribute to the loss
                loss = nn.CrossEntropyLoss(
                    ignore_index=self.TRG.vocab.stoi[self.TRG.pad_token])
                a, b = output.view(-1), trg.to(self.device).view(-1)
                loss = loss(output, trg)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.parameters(), self.clip)
                optimizer.step()
                loss_train += loss.item()
            # len(loader) = number of batches, so this is a mean batch loss
            loss_train /= len(self.loader_train)
            info = f"Epoch: {epoch+1:02} step: {i} loss_train: {loss_train:.4g}"
            # the model generates trg from the validation src to assess
            # error rate, reconstruction rate and complexity metrics
            if self.loader_valid is not None:
                # only collect the output dataframes on the final epoch
                return_output = False
                if epoch + 1 == self.epochs:
                    return_output = True
                (
                    valids,
                    loss_valid,
                    valids_de,
                    df_output,
                    df_output_de,
                    right_molecules,
                    complexity,
                    unchanged,
                    unchanged_de,
                ) = self.evaluate(return_output)
                reconstruction_error = 1 - right_molecules / len(
                    self.loader_valid.dataset)
                error = 1 - valids / len(self.loader_valid.dataset)
                complexity = complexity / len(self.loader_valid)
                unchan = unchanged / (len(self.loader_valid.dataset) - valids)
                info += f" loss_valid: {loss_valid:.4g} error_rate: {error:.4g} molecule_reconstruction_error_rate: {reconstruction_error:.4g} unchanged: {unchan:.4g} invalid_target_complexity: {complexity:.4g}"
                if self.loader_drugex is not None:
                    error_de = 1 - valids_de / len(self.loader_drugex.dataset)
                    unchan_de = unchanged_de / (
                        len(self.loader_drugex.dataset) - valids_de)
                    info += f" error_rate_drugex: {error_de:.4g} unchanged_drugex: {unchan_de:.4g}"

                if reconstruction_error < best_error:
                    # new best model: checkpoint and remember the epoch
                    torch.save(self.state_dict(), f"{self.out}.pkg")
                    best_error = reconstruction_error
                    last_save = epoch
                else:
                    # early stop after 10 epochs without improvement
                    # NOTE(review): `last_save` is unbound until the first
                    # improving epoch — verify the first epoch always improves
                    # (best_error starts at np.inf, so it should).
                    if epoch - last_save >= 10 and best_error != 1:
                        torch.save(self.state_dict(), f"{self.out}_last.pkg")
                        (
                            valids,
                            loss_valid,
                            valids_de,
                            df_output,
                            df_output_de,
                            right_molecules,
                            complexity,
                            unchanged,
                            unchanged_de,
                        ) = self.evaluate(True)
                        end_time = time.time()
                        epoch_mins, epoch_secs = epoch_time(
                            start_time, end_time)
                        info += f" Time: {epoch_mins}m {epoch_secs}s"

                        break
            elif error < best_error:
                # NOTE(review): without a validation loader `error` is never
                # assigned before this comparison — confirm this code path is
                # reachable/intended.
                torch.save(self.state_dict(), f"{self.out}.pkg")
                best_error = error
            end_time = time.time()
            epoch_mins, epoch_secs = epoch_time(start_time, end_time)
            info += f" Time: {epoch_mins}m {epoch_secs}s"


        # always keep the final weights, then reload the best checkpoint and
        # dump the last collected output dataframes
        torch.save(self.state_dict(), f"{self.out}_last.pkg")
        log.close()
        self.load_state_dict(torch.load(f"{self.out}.pkg"))
        df_output.to_csv(f"{self.out}.csv", index=False)
        df_output_de.to_csv(f"{self.out}_de.csv", index=False)

    def evaluate(self, return_output):
        """Score the model on the validation loader (and, when configured,
        the drugex loader).

        Arguments:
            return_output (bool): when True, collect per-sequence output
                dataframes from is_smiles

        Returns:
            tuple: (valids, mean validation loss, valids_de, df_output,
            df_output_de, right_molecules, complexity, unchanged,
            unchanged_de)
        """
        self.eval()
        test_loss = 0
        df_output = pd.DataFrame()
        df_output_de = pd.DataFrame()
        valids = 0
        valids_de = 0
        unchanged = 0
        unchanged_de = 0
        right_molecules = 0
        complexity = 0
        with torch.no_grad():
            for _, batch in enumerate(self.loader_valid):
                trg = batch.trg
                src = batch.src
                output, attention = self.forward(src, trg[:, :-1])
                # greedy decode: most probable token at each position
                pred_token = output.argmax(2)
                # [batch, len] -> [len, batch] for the decoding helpers
                array = torch.transpose(pred_token, 0, 1)
                trg_trans = torch.transpose(trg, 0, 1)
                output_dim = output.shape[-1]
                output = output.contiguous().view(-1, output_dim)
                trg = trg[:, 1:].contiguous().view(-1)
                src_trans = torch.transpose(src, 0, 1)
                df_batch, valid, smiless = is_smiles(
                    array, self.TRG, reverse=True, return_output=return_output)
                unchanged += is_unchanged(
                    array,
                    self.TRG,
                    reverse=True,
                    return_output=return_output,
                    src=src_trans,
                    src_field=self.SRC,
                )
                matches = molecule_reconstruction(trg_trans,
                                                  self.TRG,
                                                  reverse=True,
                                                  outputs=smiless)
                # mean complexity of targets whose prediction was invalid
                complexity += calc_complexity(trg_trans,
                                              self.TRG,
                                              reverse=True,
                                              valids=valid)
                if df_batch is not None:
                    df_output = pd.concat([df_output, df_batch],
                                          ignore_index=True)
                right_molecules += matches
                valids += sum(valid)
                loss = nn.CrossEntropyLoss(
                    ignore_index=self.TRG.vocab.stoi[self.TRG.pad_token])
                loss = loss(output, trg)
                test_loss += loss.item()
            if self.loader_drugex is not None:
                for _, batch in enumerate(self.loader_drugex):
                    src = batch.src
                    # no targets here: free-running greedy translation
                    output = self.translate_sentence(src, self.TRG,
                                                     self.device)
                    # checks the number of valid smiles
                    pred_token = output.argmax(2)
                    array = torch.transpose(pred_token, 0, 1)
                    src_trans = torch.transpose(src, 0, 1)
                    df_batch, valid, smiless = is_smiles(
                        array,
                        self.TRG,
                        reverse=True,
                        return_output=return_output,
                        src=src_trans,
                        src_field=self.SRC,
                    )
                    unchanged_de += is_unchanged(
                        array,
                        self.TRG,
                        reverse=True,
                        return_output=return_output,
                        src=src_trans,
                        src_field=self.SRC,
                    )
                    if df_batch is not None:
                        df_output_de = pd.concat([df_output_de, df_batch],
                                                 ignore_index=True)
                    valids_de += sum(valid)
        return (
            valids,
            test_loss / len(self.loader_valid),
            valids_de,
            df_output,
            df_output_de,
            right_molecules,
            complexity,
            unchanged,
            unchanged_de,
        )

    def translate(self, loader):
        """Translate every batch in `loader` without target sequences.

        Returns:
            valids_de (int): number of valid SMILES produced
            df_output_de (pd.DataFrame): per-sequence outputs from is_smiles
        """
        self.eval()
        df_output_de = pd.DataFrame()
        valids_de = 0
        with torch.no_grad():
            for _, batch in enumerate(loader):
                src = batch.src
                output = self.translate_sentence(src, self.TRG, self.device)
                # checks the number of valid smiles
                pred_token = output.argmax(2)
                array = torch.transpose(pred_token, 0, 1)
                src_trans = torch.transpose(src, 0, 1)
                df_batch, valid, smiless = is_smiles(
                    array,
                    self.TRG,
                    reverse=True,
                    return_output=True,
                    src=src_trans,
                    src_field=self.SRC,
                )
                if df_batch is not None:
                    df_output_de = pd.concat([df_output_de, df_batch],
                                             ignore_index=True)
                valids_de += sum(valid)
        return valids_de, df_output_de
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
class Encoder(nn.Module):
    """Transformer encoder: token + learned position embeddings followed by a
    stack of self-attention encoder layers.

    NOTE(review): assumes `src` holds token indices shaped
    [batch size, src len] — confirm against the caller.
    """

    def __init__(self, input_dim, hid_dim, n_layers, n_heads, pf_dim, dropout,
                 max_length, device):
        super().__init__()
        self.device = device
        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        self.layers = nn.ModuleList([
            EncoderLayer(hid_dim, n_heads, pf_dim, dropout, device)
            for _ in range(n_layers)
        ])
        self.dropout = nn.Dropout(dropout)
        # sqrt(hid_dim) factor applied to token embeddings before the sum
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)

    def forward(self, src, src_mask):
        """Embed `src` and pass it through every encoder layer.

        src: [batch size, src len]; src_mask: [batch size, src len].
        Returns hidden states of shape [batch size, src len, hid dim].
        """
        n_batch = src.shape[0]
        n_tokens = src.shape[1]
        # one position index per token, replicated over the batch
        positions = (torch.arange(0, n_tokens)
                     .unsqueeze(0)
                     .repeat(n_batch, 1)
                     .to(self.device))
        hidden = self.dropout(self.tok_embedding(src) * self.scale +
                              self.pos_embedding(positions))
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, src_mask)
        return hidden
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
class EncoderLayer(nn.Module):
    """One encoder block: self-attention then a position-wise feed-forward,
    each wrapped in dropout + residual connection + LayerNorm (post-norm)."""

    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads,
                                                      dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(
            hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        """src: [batch, src len, hid dim]; src_mask: [batch, src len].
        Returns the transformed states, same shape as `src`."""
        # self-attention sublayer
        attended, _ = self.self_attention(src, src, src, src_mask)
        src = self.self_attn_layer_norm(src + self.dropout(attended))
        # position-wise feed-forward sublayer
        transformed = self.positionwise_feedforward(src)
        return self.ff_layer_norm(src + self.dropout(transformed))
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
class MultiHeadAttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention.

    Projects query/key/value, splits hid_dim into n_heads independent heads,
    attends, then projects the concatenated head outputs back to hid_dim.
    """

    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        assert hid_dim % n_heads == 0
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)
        # sqrt(head_dim): softmax temperature for scaled dot-product scores
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

    def forward(self, query, key, value, mask=None):
        """query/key/value: [batch, len, hid dim]; `mask` broadcasts over
        [batch, n heads, query len, key len] (0 = blocked position).

        Returns (output [batch, query len, hid dim],
                 attention [batch, n heads, query len, key len]).
        """
        n_batch = query.shape[0]

        def to_heads(proj):
            # [batch, len, hid dim] -> [batch, n heads, len, head dim]
            return proj.view(n_batch, -1, self.n_heads,
                             self.head_dim).permute(0, 2, 1, 3)

        q = to_heads(self.fc_q(query))
        k = to_heads(self.fc_k(key))
        v = to_heads(self.fc_v(value))
        # raw attention scores: [batch, n heads, query len, key len]
        scores = torch.matmul(q, k.permute(0, 1, 3, 2)) / self.scale
        if mask is not None:
            # blocked positions get -1e10 so softmax sends them to ~0
            scores = scores.masked_fill(mask == 0, -1e10)
        attention = torch.softmax(scores, dim=-1)
        context = torch.matmul(self.dropout(attention), v)
        # merge heads back: [batch, query len, hid dim]
        context = context.permute(0, 2, 1, 3).contiguous()
        context = context.view(n_batch, -1, self.hid_dim)
        return self.fc_o(context), attention
|
| 688 |
+
|
| 689 |
+
|
| 690 |
+
class PositionwiseFeedforwardLayer(nn.Module):
    """Two-layer position-wise MLP: hid_dim -> pf_dim (ReLU) -> hid_dim,
    with dropout after the activation."""

    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """x: [batch, seq len, hid dim] -> same shape."""
        expanded = torch.relu(self.fc_1(x))
        return self.fc_2(self.dropout(expanded))
|
| 706 |
+
|
| 707 |
+
|
| 708 |
+
class Decoder(nn.Module):
    """Transformer decoder: token + position embeddings, a stack of decoder
    layers, and a final linear projection to vocabulary logits."""

    def __init__(
        self,
        output_dim,
        hid_dim,
        n_layers,
        n_heads,
        pf_dim,
        dropout,
        max_length,
        device,
    ):
        super().__init__()
        self.device = device
        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        self.layers = nn.ModuleList([
            DecoderLayer(hid_dim, n_heads, pf_dim, dropout, device)
            for _ in range(n_layers)
        ])
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        # sqrt(hid_dim) factor applied to token embeddings before the sum
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """trg: [batch, trg len] token indices;
        enc_src: [batch, src len, hid dim] encoder states.

        Returns (logits [batch, trg len, output dim],
                 attention [batch, n heads, trg len, src len] from the last
                 decoder layer).
        """
        n_batch = trg.shape[0]
        n_steps = trg.shape[1]
        # one position index per target step, replicated over the batch
        positions = (torch.arange(0, n_steps)
                     .unsqueeze(0)
                     .repeat(n_batch, 1)
                     .to(self.device))
        hidden = self.dropout(self.tok_embedding(trg) * self.scale +
                              self.pos_embedding(positions))
        for decoder_layer in self.layers:
            hidden, attention = decoder_layer(hidden, enc_src, trg_mask,
                                              src_mask)
        return self.fc_out(hidden), attention
|
| 753 |
+
|
| 754 |
+
|
| 755 |
+
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, encoder (cross) attention,
    and a position-wise feed-forward — each with dropout + residual +
    LayerNorm (post-norm)."""

    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.enc_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads,
                                                      dropout, device)
        self.encoder_attention = MultiHeadAttentionLayer(
            hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(
            hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """trg: [batch, trg len, hid dim]; enc_src: [batch, src len, hid dim].
        Returns (trg states, attention over src positions)."""
        # masked self-attention over the target sequence
        self_attended, _ = self.self_attention(trg, trg, trg, trg_mask)
        trg = self.self_attn_layer_norm(trg + self.dropout(self_attended))
        # cross-attention: target queries attend to encoder states
        cross_attended, attention = self.encoder_attention(
            trg, enc_src, enc_src, src_mask)
        trg = self.enc_attn_layer_norm(trg + self.dropout(cross_attended))
        # position-wise feed-forward sublayer
        transformed = self.positionwise_feedforward(trg)
        trg = self.ff_layer_norm(trg + self.dropout(transformed))
        return trg, attention
|
| 793 |
+
|
| 794 |
+
|
| 795 |
+
class Seq2Seq(nn.Module, Convo):
    """Transformer encoder-decoder model; the training/evaluation behaviour
    (train_model, evaluate, translate) is inherited from the Convo mixin,
    which reads the loader/field/hyperparameter attributes set here.

    NOTE(review): `reverse` is stored but the Convo methods pass
    reverse=True explicitly — confirm the attribute is still meaningful.
    """

    def __init__(
        self,
        encoder,
        decoder,
        src_pad_idx,
        trg_pad_idx,
        device,
        loader_train: DataLoader,
        out: str,
        loader_valid=None,
        loader_drugex=None,
        epochs=100,
        lr=0.0005,
        clip=0.1,
        reverse=True,
        TRG=None,
        SRC=None,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device
        self.loader_train = loader_train
        # basename for checkpoints, logs and csv outputs
        self.out = out
        self.loader_valid = loader_valid
        self.loader_drugex = loader_drugex
        self.epochs = epochs
        self.lr = lr
        # max gradient norm for clip_grad_norm_ in Convo.train_model
        self.clip = clip
        self.reverse = reverse
        self.TRG = TRG
        self.SRC = SRC

    def make_src_mask(self, src):
        """Boolean mask hiding source padding.

        src: [batch size, src len] -> mask [batch size, 1, 1, src len],
        True where the token is not padding.
        """
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        return src_mask

    def make_trg_mask(self, trg):
        """Combined padding + causal (no-lookahead) mask for the target.

        trg: [batch size, trg len] -> mask [batch size, 1, trg len, trg len].
        """
        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
        # trg_pad_mask = [batch size, 1, 1, trg len]
        trg_len = trg.shape[1]
        # lower-triangular matrix: position i may only attend to <= i
        trg_sub_mask = torch.tril(
            torch.ones((trg_len, trg_len), device=self.device)).bool()
        # trg_sub_mask = [trg len, trg len]
        trg_mask = trg_pad_mask & trg_sub_mask
        # trg_mask = [batch size, 1, trg len, trg len]
        return trg_mask

    def forward(self, src, trg):
        """Teacher-forced forward pass.

        src: [batch size, src len]; trg: [batch size, trg len].
        Returns (logits [batch size, trg len, output dim],
                 attention [batch size, n heads, trg len, src len]).
        """
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        # enc_src = [batch size, src len, hid dim]
        output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)
        return output, attention

    def translate_sentence(self, src, trg_field, device, max_len=202):
        """Greedy autoregressive decoding of `src` with no target sequence.

        Starts every batch row from <sos> and appends the argmax token for
        exactly `max_len` steps (no early stop on <eos>; trimming happens
        later in is_smiles).

        Returns:
            output: decoder logits from the final step,
                [batch size, max_len + 1, output dim]
        """
        self.eval()
        src_mask = self.make_src_mask(src)
        with torch.no_grad():
            enc_src = self.encoder(src, src_mask)
        # seed the decoder input with the <sos> token for every batch row
        trg_indexes = [trg_field.vocab.stoi[trg_field.init_token]]
        batch_size = src.shape[0]
        trg = torch.LongTensor(trg_indexes).unsqueeze(0).to(device)
        trg = trg.repeat(batch_size, 1)
        for i in range(max_len):
            trg_mask = self.make_trg_mask(trg)
            with torch.no_grad():
                output, attention = self.decoder(trg, enc_src, trg_mask,
                                                 src_mask)
            # greedy choice: most probable next token for each row
            pred_tokens = output.argmax(2)[:, -1].unsqueeze(1)
            trg = torch.cat((trg, pred_tokens), 1)

        return output
|
| 883 |
+
|
| 884 |
+
|
| 885 |
+
def remove_floats(df: pd.DataFrame, subset: str) -> pd.DataFrame:
    """Preprocessing step to remove any entries that are not strings.

    The ``subset`` column is cast to ``str`` in place on the passed frame;
    rows whose value changed under that cast (i.e. were not strings to
    begin with) are dropped from the returned copy.

    Args:
        df: input dataframe (its ``subset`` column is mutated to str dtype).
        subset: name of the column to filter on.

    Returns:
        pd.DataFrame: copy containing only the rows whose value was
        already a string.
    """
    original_values = df[subset]
    # In-place cast mirrors the original implementation's side effect.
    df[subset] = original_values.astype(str)
    unchanged = df[subset] == original_values
    return df[unchanged].copy()
|
| 893 |
+
|
| 894 |
+
|
| 895 |
+
def smi_tokenizer(smi: str, reverse=False) -> list:
    """Tokenize a SMILES molecule into its atomic/structural tokens.

    Args:
        smi: SMILES string to tokenize.
        reverse: when True, return the token list back-to-front.

    Returns:
        list: SMILES tokens (bracket atoms, two-letter halogens,
        single-character atoms/bonds/ring digits, etc.).

    Raises:
        AssertionError: if the tokens do not re-join into the input,
        i.e. the regex failed to cover some character of ``smi``.
    """
    pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
    tokens = re.compile(pattern).findall(smi)
    # Sanity check: tokenization must be lossless.
    assert smi == "".join(tokens)
    return tokens[::-1] if reverse else tokens
|
| 913 |
+
|
| 914 |
+
|
| 915 |
+
def init_weights(m: nn.Module):
    """Xavier-uniform initialize the weight of a module, if it has one.

    Intended for use with ``model.apply(init_weights)``. Modules without a
    ``weight`` attribute, or with a 1-D weight (e.g. biases/embedding-like
    vectors), are left untouched.

    Args:
        m: module to (possibly) initialize in place.
    """
    weight = getattr(m, "weight", None)
    if weight is not None and weight.dim() > 1:
        nn.init.xavier_uniform_(weight.data)
|
| 918 |
+
|
| 919 |
+
|
| 920 |
+
def count_parameters(model: nn.Module):
    """Return the total number of trainable parameters in ``model``.

    Args:
        model: module whose parameters are counted.

    Returns:
        int: sum of element counts over parameters with requires_grad=True.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
| 922 |
+
|
| 923 |
+
|
| 924 |
+
def epoch_time(start_time, end_time):
    """Split an elapsed wall-clock interval into whole minutes and seconds.

    Args:
        start_time: epoch start timestamp (seconds).
        end_time: epoch end timestamp (seconds).

    Returns:
        tuple[int, int]: (elapsed minutes, remaining seconds).
    """
    elapsed = end_time - start_time
    minutes = int(elapsed / 60)
    seconds = int(elapsed - minutes * 60)
    return minutes, seconds
|
| 929 |
+
|
| 930 |
+
|
| 931 |
+
def initialize_model(folder_out: str,
                     data_source: str,
                     error_source: str,
                     device: torch.device,
                     threshold: int,
                     epochs: int,
                     layers: int = 3,
                     batch_size: int = 16,
                     invalid_type: str = "all",
                     num_errors: int = 1,
                     validation_step=False):
    """Build the torchtext data pipeline and an untrained SMILES-correction
    transformer (encoder/decoder Seq2Seq).

    Reads the invalid/valid SMILES pair CSVs, builds source/target
    vocabularies, saves those vocabularies next to the model output path,
    constructs bucketed data iterators, and instantiates the Seq2Seq model
    on ``device``.

    Args:
        folder_out: output folder; also the root the training CSVs are read
            from and where the model/vocab files are written.
        data_source: dataset identifier used in the CSV/model file names.
        error_source: path to a CSV of previously generated invalid SMILES
            (loaded as the "drugex" evaluation set).
        device: torch device for iterators and the model.
        threshold: maximum SMILES length; MAX_LENGTH = threshold + 2
            to leave room for the <sos>/<eos> tokens.
        epochs: number of training epochs passed into the Seq2Seq model.
        layers: number of encoder and decoder layers.
        batch_size: training batch size.
        invalid_type: type of errors introduced into the invalid SMILES
            (part of the file/model naming scheme).
        num_errors: number of errors per invalid SMILES (naming scheme).
        validation_step: when True, load separate train/dev splits and
            build a validation iterator; otherwise train on a single CSV.

    Returns:
        tuple: (model, out, SRC) — the Seq2Seq model moved to ``device``,
        the model output path prefix, and the source Field (with built
        vocabulary) for downstream tokenization.
    """

    # Source field: plain SMILES tokenization; target field: reversed
    # token order (see smi_tokenizer(reverse=True)).
    SRC = Field(
        tokenize=lambda x: smi_tokenizer(x),
        init_token="<sos>",
        eos_token="<eos>",
        batch_first=True,
    )
    TRG = Field(
        tokenize=lambda x: smi_tokenizer(x, reverse=True),
        init_token="<sos>",
        eos_token="<eos>",
        batch_first=True,
    )

    if validation_step:
        # Pre-split train/dev CSVs live under <folder_out>errors/split/.
        train, val = TabularDataset.splits(
            path=f'{folder_out}errors/split/',
            train=f"{data_source}_{invalid_type}_{num_errors}_errors_train.csv",
            validation=f"{data_source}_{invalid_type}_{num_errors}_errors_dev.csv",
            format="CSV",
            skip_header=False,
            fields={
                "ERROR": ("src", SRC),
                "STD_SMILES": ("trg", TRG)
            },
        )
        # Vocabulary is built over train AND dev, capped at 1000 tokens.
        SRC.build_vocab(train, val, max_size=1000)
        TRG.build_vocab(train, val, max_size=1000)
    else:
        train = TabularDataset(
            path=f'{folder_out}{data_source}_{invalid_type}_{num_errors}_errors.csv',
            format="CSV",
            skip_header=False,
            fields={
                "ERROR": ("src", SRC),
                "STD_SMILES": ("trg", TRG)
            },
        )
        SRC.build_vocab(train, max_size=1000)
        TRG.build_vocab(train, max_size=1000)

    # Separate evaluation set of generator-produced (DrugEx) errors.
    drugex = TabularDataset(
        path=error_source,
        format="csv",
        skip_header=False,
        fields={
            "SMILES": ("src", SRC),
            "SMILES_TARGET": ("trg", TRG)
        },
    )

    # SRC.vocab = torch.load('vocab_src.pth')
    # TRG.vocab = torch.load('vocab_trg.pth')

    # model parameters
    EPOCHS = epochs
    BATCH_SIZE = batch_size
    INPUT_DIM = len(SRC.vocab)
    OUTPUT_DIM = len(TRG.vocab)
    HID_DIM = 256
    ENC_LAYERS = layers
    DEC_LAYERS = layers
    ENC_HEADS = 8
    DEC_HEADS = 8
    ENC_PF_DIM = 512
    DEC_PF_DIM = 512
    ENC_DROPOUT = 0.1
    DEC_DROPOUT = 0.1
    SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token]
    TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]
    # add 2 to length for start and stop tokens
    MAX_LENGTH = threshold + 2

    # model name
    MODEL_OUT_FOLDER = f"{folder_out}"

    MODEL_NAME = "transformer_%s_%s_%s_%s_%s" % (
        invalid_type, num_errors, data_source, BATCH_SIZE, layers)
    if not os.path.exists(MODEL_OUT_FOLDER):
        os.mkdir(MODEL_OUT_FOLDER)

    out = os.path.join(MODEL_OUT_FOLDER, MODEL_NAME)

    # Persist the vocabularies so inference runs can restore the exact
    # token <-> index mapping used at training time.
    torch.save(SRC.vocab, f'{out}_vocab_src.pth')
    torch.save(TRG.vocab, f'{out}_vocab_trg.pth')

    # iterator is a dataloader
    # BucketIterator groups examples of similar length so that the amount
    # of padding per batch is minimized.
    if validation_step:
        train_iter, val_iter = BucketIterator.splits(
            (train, val),
            batch_sizes=(BATCH_SIZE, 256),
            sort_within_batch=True,
            shuffle=True,
            # the BucketIterator needs to be told what function it should use to
            # group the data.
            sort_key=lambda x: len(x.src),
            device=device,
        )
    else:
        train_iter = BucketIterator(
            train,
            batch_size=BATCH_SIZE,
            sort_within_batch=True,
            shuffle=True,
            # the BucketIterator needs to be told what function it should use to
            # group the data.
            sort_key=lambda x: len(x.src),
            device=device,
        )
        val_iter = None

    drugex_iter = Iterator(
        drugex,
        batch_size=64,
        device=device,
        sort=False,
        sort_within_batch=True,
        sort_key=lambda x: len(x.src),
        repeat=False,
    )

    # model initialization
    enc = Encoder(
        INPUT_DIM,
        HID_DIM,
        ENC_LAYERS,
        ENC_HEADS,
        ENC_PF_DIM,
        ENC_DROPOUT,
        MAX_LENGTH,
        device,
    )
    dec = Decoder(
        OUTPUT_DIM,
        HID_DIM,
        DEC_LAYERS,
        DEC_HEADS,
        DEC_PF_DIM,
        DEC_DROPOUT,
        MAX_LENGTH,
        device,
    )

    model = Seq2Seq(
        enc,
        dec,
        SRC_PAD_IDX,
        TRG_PAD_IDX,
        device,
        train_iter,
        out=out,
        loader_valid=val_iter,
        loader_drugex=drugex_iter,
        epochs=EPOCHS,
        TRG=TRG,
        SRC=SRC,
    ).to(device)

    return model, out, SRC
|
| 1121 |
+
|
| 1122 |
+
|
| 1123 |
+
def train_model(model, out, assess):
    """Apply saved weights (and assess, or continue training) or train anew.

    Behavior:
      * checkpoint exists and ``assess`` is truthy -> load best weights
        (``<out>.pkg``) and run a full evaluation pass.
      * checkpoint exists, no assessment -> resume training from the
        last-epoch weights (``<out>_last.pkg``), not the best epoch.
      * no checkpoint -> Xavier-initialize and train from scratch.

    Args:
        model: initialized (but possibly unweighted) Seq2Seq model.
        out: path prefix of the ``.pkg`` files with model parameters.
        assess: when True and a checkpoint exists, only evaluate.

    Returns:
        The model carrying its (new) weights.
    """
    checkpoint_found = os.path.exists(f"{out}.pkg")

    if checkpoint_found and assess:
        model.load_state_dict(torch.load(f=out + ".pkg"))
        evaluation = model.evaluate(True)
        (
            valids,
            loss_valid,
            valids_de,
            df_output,
            df_output_de,
            right_molecules,
            complexity,
            unchanged,
            unchanged_de,
        ) = evaluation
    elif checkpoint_found:
        # Resumes from the model after the last epoch, not the best epoch.
        # NOTE(review): the log file naming of epochs may need changing
        # when resuming — carried over from the original comment.
        model.load_state_dict(torch.load(f=out + "_last.pkg"))
        model.train_model()
    else:
        model = model.apply(init_weights)
        model.train_model()

    return model
|
| 1178 |
+
|
| 1179 |
+
|
| 1180 |
+
def correct_SMILES(model, out, error_source, device, SRC):
    """Correct invalid SMILES with the given model and collect the outputs.

    Loads the erroneous SMILES from ``error_source`` into a torchtext
    dataset, restores the trained weights (mapped onto CPU), and runs the
    model's ``translate`` over the whole set.

    Args:
        model: initialized Seq2Seq model.
        out: path prefix of the ``.pkg`` file with trained parameters.
        error_source: path to a CSV with a "SMILES" column of inputs.
        device: torch device the iterator places batches on.
        SRC: source Field with a previously built vocabulary; tokens in
            the new data that are not in this vocab map to <unk>.

    Returns:
        valids: number of fixed outputs.
        df_output: dataframe containing output (either correct or
            incorrect) & original input.
    """
    ## account for tokens that are not yet in SRC without changing existing SRC token embeddings
    errors = TabularDataset(
        path=error_source,
        format="csv",
        skip_header=False,
        fields={"SMILES": ("src", SRC)},
    )

    # Plain Iterator (no bucketing); batches sorted internally by length.
    errors_loader = Iterator(
        errors,
        batch_size=64,
        device=device,
        sort=False,
        sort_within_batch=True,
        sort_key=lambda x: len(x.src),
        repeat=False,
    )
    # map_location ensures GPU-trained weights load on CPU-only hosts.
    model.load_state_dict(torch.load(f=out + ".pkg", map_location=torch.device('cpu')))
    # add option to use different iterator maybe?

    valids, df_output = model.translate(errors_loader)
    # df_output.to_csv(f"{error_source}_fixed.csv", index=False)

    return valids, df_output
|
| 1216 |
+
|
| 1217 |
+
|
| 1218 |
+
|
| 1219 |
+
class smi_correct(object):
    """End-to-end SMILES correction: transformer fixing + standardization.

    Wraps model initialization, correction of an input CSV of SMILES, and
    RDKit/ChEMBL-style standardization of the corrected structures.
    """

    def __init__(self, model_name, trans_file_path):
        """Configure paths and fixed hyperparameters.

        Args:
            model_name: identifier of the correction model (stored only).
            trans_file_path: path to the transformer files (stored only).
        """
        # set random seed, used for error generation & initiation transformer
        self.SEED = 42
        random.seed(self.SEED)
        self.model_name = model_name
        # Output/working directory; created below if missing.
        self.folder_out = "data/"

        self.trans_file_path = trans_file_path

        if not os.path.exists(self.folder_out):
            os.makedirs(self.folder_out)

        # Error profile the pretrained model was trained on:
        # 'multiple' error types, up to 12 errors per molecule.
        self.invalid_type = 'multiple'
        self.num_errors = 12
        # Maximum SMILES length handled by the model.
        self.threshold = 200
        self.data_source = f"PAPYRUS_{self.threshold}"
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        self.initialize_source = 'data/papyrus_rnn_S.csv'  # change this path

    def standardization_pipeline(self, smile):
        """Standardize one SMILES: ChEMBL standardizer + salt stripping.

        Args:
            smile: SMILES string (non-strings yield None).

        Returns:
            The standardized canonical SMILES, or None when the input is
            not a string, cannot be parsed, or is excluded by the
            parent-molecule step.
        """
        desalter = MolStandardize.rdMolStandardize.LargestFragmentChooser()
        std_smile = None
        if not isinstance(smile, str): return None
        m = Chem.MolFromSmiles(smile)
        # skips smiles for which no mol file could be generated
        if m is not None:
            # standardizes
            std_m = standardizer.standardize_mol(m)
            # strips salts
            std_m_p, exclude = standardizer.get_parent_mol(std_m)
            if not exclude:
                # choose largest fragment for rare cases where chembl structure
                # pipeline leaves 2 fragments
                std_m_p_d = desalter.choose(std_m_p)
                std_smile = Chem.MolToSmiles(std_m_p_d)
        return std_smile

    def remove_smiles_duplicates(self, dataframe: pd.DataFrame,
                                 subset: str) -> pd.DataFrame:
        """Drop duplicate rows judged by the ``subset`` column."""
        return dataframe.drop_duplicates(subset=subset)

    def correct(self, smi):
        """Correct a CSV of SMILES and return standardized unique results.

        Args:
            smi: path to a CSV with a "SMILES" column of (possibly
                invalid) SMILES strings.

        Returns:
            pd.DataFrame: single "SMILES" column of standardized,
            de-duplicated corrected molecules (rows failing
            standardization are dropped).
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        model, out, SRC = initialize_model(self.folder_out,
                                           self.data_source,
                                           error_source=self.initialize_source,
                                           device=device,
                                           threshold=self.threshold,
                                           epochs=30,
                                           layers=3,
                                           batch_size=16,
                                           invalid_type=self.invalid_type,
                                           num_errors=self.num_errors)

        valids, df_output = correct_SMILES(model, out, smi, device,
                                           SRC)

        # Standardize each corrected SMILES; failures become NaN/None.
        df_output["SMILES"] = df_output.apply(lambda row: self.standardization_pipeline(row["CORRECT"]), axis=1)

        df_output = self.remove_smiles_duplicates(df_output, subset="SMILES").drop(columns=["CORRECT", "INCORRECT", "ORIGINAL"]).dropna()

        return df_output
|
src/util/utils.py
ADDED
|
@@ -0,0 +1,930 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import math
|
| 4 |
+
import datetime
|
| 5 |
+
import warnings
|
| 6 |
+
import itertools
|
| 7 |
+
from copy import deepcopy
|
| 8 |
+
from functools import partial
|
| 9 |
+
from collections import Counter
|
| 10 |
+
from multiprocessing import Pool
|
| 11 |
+
from statistics import mean
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import matplotlib.pyplot as plt
|
| 15 |
+
from matplotlib.lines import Line2D
|
| 16 |
+
from scipy.spatial.distance import cosine as cos_distance
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import wandb
|
| 20 |
+
|
| 21 |
+
from rdkit import Chem, DataStructs, RDLogger
|
| 22 |
+
from rdkit.Chem import (
|
| 23 |
+
AllChem,
|
| 24 |
+
Draw,
|
| 25 |
+
Descriptors,
|
| 26 |
+
Lipinski,
|
| 27 |
+
Crippen,
|
| 28 |
+
rdMolDescriptors,
|
| 29 |
+
FilterCatalog,
|
| 30 |
+
)
|
| 31 |
+
from rdkit.Chem.Scaffolds import MurckoScaffold
|
| 32 |
+
|
| 33 |
+
# Disable RDKit warnings
|
| 34 |
+
RDLogger.DisableLog("rdApp.*")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class Metrics(object):
    """
    Collection of static methods to compute various metrics for molecules.
    """

    @staticmethod
    def valid(x):
        """
        Checks whether the molecule is valid.

        Args:
            x: RDKit molecule object (or None).

        Returns:
            bool: True if molecule is valid and has a non-empty SMILES representation.
        """
        return x is not None and Chem.MolToSmiles(x) != ''

    @staticmethod
    def tanimoto_sim_1v2(data1, data2):
        """
        Computes the average Tanimoto similarity for paired fingerprints.

        Args:
            data1: Indexable array of fingerprints for the first set.
            data2: Indexable array of fingerprints for the second set.

        Returns:
            float: The average Tanimoto similarity between corresponding
            fingerprints over the overlapping prefix of the two sets.
        """
        # BUG FIX: the original expression
        #   min_len = data1.size if data1.size > data2.size else data2
        # selected the *larger* count and, in the else branch, assigned the
        # array itself rather than its size. Pair only up to the shorter set.
        min_len = min(data1.size, data2.size)
        sims = []
        for i in range(min_len):
            sims.append(DataStructs.FingerprintSimilarity(data1[i], data2[i]))
        return mean(sims)

    @staticmethod
    def mol_length(x):
        """
        Computes the length of the largest fragment (by character count) in a SMILES string.

        Args:
            x (str): SMILES string (None yields 0).

        Returns:
            int: Number of alphabetic characters in the longest fragment of the SMILES.
        """
        if x is not None:
            # Split at dots (.) and take the fragment with maximum length,
            # then count alphabetic characters.
            return len([char for char in max(x.split(sep="."), key=len).upper() if char.isalpha()])
        else:
            return 0

    @staticmethod
    def max_component(data, max_len):
        """
        Returns the average normalized length of molecules in the dataset.

        Each molecule's length is computed and divided by max_len, then averaged.

        Args:
            data (iterable): Collection of SMILES strings.
            max_len (int): Maximum possible length for normalization.

        Returns:
            float: Normalized average length.
        """
        lengths = np.array(list(map(Metrics.mol_length, data)), dtype=np.float32)
        return (lengths / max_len).mean()

    @staticmethod
    def mean_atom_type(data):
        """
        Computes the average number of unique atom types in the provided node data.

        Args:
            data (iterable): Iterable of tensors with a .unique() method
                (node/annotation tensors; one entry per molecule).

        Returns:
            float: The average count of unique atom types, minus one
                (presumably excluding the padding/empty type — TODO confirm).
        """
        atom_types_used = []
        for nodes in data:
            atom_types_used.append(len(nodes.unique().tolist()))
        return np.mean(atom_types_used) - 1
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def mols2grid_image(mols, path):
    """
    Saves grid images for a list of molecules.

    For each valid molecule, computes 2D coordinates and writes a
    1200x1200 PNG named by its 1-based position in the list.

    Args:
        mols (list): List of RDKit molecule objects (None entries allowed).
        path (str): Directory where images will be saved.
    """
    # Replace None molecules with an empty molecule
    sanitized = [Chem.RWMol() if mol is None else mol for mol in mols]

    for position, mol in enumerate(sanitized):
        if not Metrics.valid(mol):
            continue
        AllChem.Compute2DCoords(mol)
        file_path = os.path.join(path, "{}.png".format(position + 1))
        Draw.MolToFile(mol, file_path, size=(1200, 1200))
        # wandb.save(file_path) # Optionally save to Weights & Biases
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def save_smiles_matrices(mols, edges_hard, nodes_hard, path, data_source=None):
    """
    Saves the edge and node matrices along with SMILES strings to text files.

    Each output file contains the edge matrix, node matrix, and SMILES
    representation for one valid molecule, named by its 1-based position.

    Args:
        mols (list): List of RDKit molecule objects (None entries allowed).
        edges_hard (torch.Tensor): Tensor of edge features.
        nodes_hard (torch.Tensor): Tensor of node features.
        path (str): Directory where files will be saved.
        data_source: Optional data source information (unused).
    """
    sanitized = [Chem.RWMol() if mol is None else mol for mol in mols]

    for position, mol in enumerate(sanitized):
        if not Metrics.valid(mol):
            continue
        save_path = os.path.join(path, "{}.txt".format(position + 1))
        with open(save_path, "a") as handle:
            np.savetxt(handle, edges_hard[position].cpu().numpy(), header="edge matrix:\n", fmt='%1.2f')
            handle.write("\n")
            np.savetxt(handle, nodes_hard[position].cpu().numpy(), header="node matrix:\n", footer="\nsmiles:", fmt='%1.2f')
            handle.write("\n")
        # Append the SMILES representation to the file
        with open(save_path, "a") as handle:
            print(Chem.MolToSmiles(mol), file=handle)
        # wandb.save(save_path) # Optionally save to Weights & Biases
|
| 182 |
+
|
| 183 |
+
def dense_to_sparse_with_attr(adj):
    """
    Converts a dense adjacency matrix to a sparse representation.

    Args:
        adj (torch.Tensor): Adjacency tensor, 2D ([N, N]) or batched 3D
            ([B, N, N]); the last two dimensions must be square.

    Returns:
        tuple: (index, edge_attr) — the nonzero coordinates (batch offsets
        folded into the row/column indices for 3D input) and the values at
        those coordinates.
    """
    assert 2 <= adj.dim() <= 3
    assert adj.size(-1) == adj.size(-2)

    coords = adj.nonzero(as_tuple=True)
    edge_attr = adj[coords]

    if len(coords) == 3:
        # Fold the batch dimension into the node indices.
        offset = coords[0] * adj.size(-1)
        coords = (offset + coords[1], offset + coords[2])
    return coords, edge_attr
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def mol_sample(sample_directory, edges, nodes, idx, i, matrices2mol, dataset_name):
    """
    Samples molecules from edge and node predictions, then saves grid images and text files.

    Args:
        sample_directory (str): Directory to save the samples.
        edges (torch.Tensor): Edge predictions tensor.
        nodes (torch.Tensor): Node predictions tensor.
        idx (int): Current index for naming the sample.
        i (int): Epoch/iteration index.
        matrices2mol (callable): Function to convert matrices to an RDKit molecule.
        dataset_name (str): Name of the dataset for file naming.
    """
    sample_path = os.path.join(sample_directory, "{}_{}-epoch_iteration".format(idx + 1, i + 1))

    # Hard predictions: argmax over the feature dimension.
    edge_labels = torch.max(edges, -1)[1]
    node_labels = torch.max(nodes, -1)[1]

    # Convert every (node, edge) matrix pair into an RDKit molecule.
    mol = [
        matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(),
                     strict=True, file_name=dataset_name)
        for e_, n_ in zip(edge_labels, node_labels)
    ]

    if not os.path.exists(sample_path):
        os.makedirs(sample_path)

    mols2grid_image(mol, sample_path)
    save_smiles_matrices(mol, edge_labels.detach(), node_labels.detach(), sample_path)

    # Remove the directory if no files were saved
    if len(os.listdir(sample_path)) == 0:
        os.rmdir(sample_path)

    print("Valid molecules are saved.")
    print("Valid matrices and smiles are saved")
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def logging(log_path, start_time, i, idx, loss, save_path, drug_smiles, edge, node,
            matrices2mol, dataset_name, real_adj, real_annot, drug_vecs):
    """
    Logs training statistics and evaluation metrics.

    The function generates molecules from predictions, computes various metrics such as
    validity, uniqueness, novelty, and similarity scores, and logs them using wandb and a file.

    Args:
        log_path (str): Path to save the log file.
        start_time (float): Start time to compute elapsed time.
        i (int): Current iteration index.
        idx (int): Current epoch index.
        loss (dict): Dictionary to update with loss and metric values.
        save_path (str): Directory path to save sample outputs.
        drug_smiles (list): List of reference drug SMILES.
        edge (torch.Tensor): Edge prediction tensor.
        node (torch.Tensor): Node prediction tensor.
        matrices2mol (callable): Function to convert matrices to molecules.
        dataset_name (str): Dataset name.
        real_adj (torch.Tensor): Ground truth adjacency matrix tensor.
        real_annot (torch.Tensor): Ground truth annotation tensor.
        drug_vecs (list): List of drug vectors for similarity calculation.
    """
    # NOTE(review): this function shadows the stdlib `logging` module name;
    # consider renaming (e.g. `log_metrics`) in a future refactor.
    # Hard predictions: argmax over the feature (last) dimension.
    g_edges_hat_sample = torch.max(edge, -1)[1]
    g_nodes_hat_sample = torch.max(node, -1)[1]

    # Ground-truth tensors are also one-hot along the last axis; recover labels.
    a_tensor_sample = torch.max(real_adj, -1)[1].float()
    x_tensor_sample = torch.max(real_annot, -1)[1].float()

    # Generate molecules from predictions and real data
    mols = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(),
                         strict=True, file_name=dataset_name)
            for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)]
    real_mol = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(),
                             strict=True, file_name=dataset_name)
                for e_, n_ in zip(a_tensor_sample, x_tensor_sample)]

    # Compute average number of atom types
    atom_types_average = Metrics.mean_atom_type(g_nodes_hat_sample)
    real_smiles = [Chem.MolToSmiles(x) for x in real_mol if x is not None]
    # gen_smiles keeps a None placeholder per failed molecule (used for validity);
    # uniq_smiles keeps only successful conversions (used for uniqueness).
    gen_smiles = []
    uniq_smiles = []
    for line in mols:
        if line is not None:
            gen_smiles.append(Chem.MolToSmiles(line))
            uniq_smiles.append(Chem.MolToSmiles(line))
        elif line is None:
            gen_smiles.append(None)

    # Process SMILES to take the longest fragment if multiple are present
    gen_smiles_saves = [None if x is None else max(x.split('.'), key=len) for x in gen_smiles]
    uniq_smiles_saves = [None if x is None else max(x.split('.'), key=len) for x in uniq_smiles]

    # Save the generated SMILES to a text file
    sample_save_dir = os.path.join(save_path, "samples.txt")
    with open(sample_save_dir, "a") as f:
        for s in gen_smiles_saves:
            if s is not None:
                f.write(s + "\n")

    # k = number of distinct valid SMILES; fed to fraction_unique as unique@k cutoff.
    k = len(set(uniq_smiles_saves) - {None})
    et = time.time() - start_time
    # Drop microseconds from the elapsed-time string for readability.
    et = str(datetime.timedelta(seconds=et))[:-7]
    log_str = "Elapsed [{}], Epoch/Iteration [{}/{}]".format(et, idx, i + 1)

    # Generate molecular fingerprints for similarity computations
    gen_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in mols if x is not None]
    chembl_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in real_mol if x is not None]

    # Compute evaluation metrics: validity, uniqueness, novelty, similarity scores, and average maximum molecule length.
    valid = fraction_valid(gen_smiles_saves)
    unique = fraction_unique(uniq_smiles_saves, k)
    novel_starting_mol = novelty(gen_smiles_saves, real_smiles)
    novel_akt = novelty(gen_smiles_saves, drug_smiles)
    # Guard: no valid generated molecule -> similarity metrics are 0 by convention.
    if len(uniq_smiles_saves) == 0:
        snn_chembl = 0
        snn_akt = 0
        maxlen = 0
    else:
        snn_chembl = average_agg_tanimoto(np.array(chembl_vecs), np.array(gen_vecs))
        snn_akt = average_agg_tanimoto(np.array(drug_vecs), np.array(gen_vecs))
        maxlen = Metrics.max_component(uniq_smiles_saves, 45)

    # Update loss dictionary with computed metrics
    loss.update({
        'Validity': valid,
        'Uniqueness': unique,
        'Novelty': novel_starting_mol,
        'Novelty_akt': novel_akt,
        'SNN_chembl': snn_chembl,
        'SNN_akt': snn_akt,
        'MaxLen': maxlen,
        'Atom_types': atom_types_average
    })

    # Log metrics using wandb
    wandb.log({
        "Validity": valid,
        "Uniqueness": unique,
        "Novelty": novel_starting_mol,
        "Novelty_akt": novel_akt,
        "SNN_chembl": snn_chembl,
        "SNN_akt": snn_akt,
        "MaxLen": maxlen,
        "Atom_types": atom_types_average
    })

    # Append each metric to the log string and write to the log file
    for tag, value in loss.items():
        log_str += ", {}: {:.4f}".format(tag, value)
    with open(log_path, "a") as f:
        f.write(log_str + "\n")
    print(log_str)
    print("\n")
def plot_grad_flow(named_parameters, model, itera, epoch, grad_flow_directory):
    """
    Plot the gradients flowing through different layers during training.

    Useful for diagnosing vanishing/exploding gradients. The resulting bar
    chart (max and mean absolute gradient per non-bias parameter) is saved
    as a PNG in `grad_flow_directory`.

    Args:
        named_parameters (iterable): (name, parameter) tuples from the model.
        model (str): Model name, used in the output file name.
        itera (int): Iteration index, used in the output file name.
        epoch (int): Current epoch, used in the output file name.
        grad_flow_directory (str): Directory to save the gradient-flow plot.
    """
    ave_grads = []
    max_grads = []
    layers = []
    for n, p in named_parameters:
        # Skip biases and frozen parameters; one bar per weight tensor.
        if p.requires_grad and ("bias" not in n):
            layers.append(n)
            ave_grads.append(p.grad.abs().mean().cpu())
            max_grads.append(p.grad.abs().max().cpu())
    # Plot maximum gradients and average gradients for each layer
    plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
    plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
    plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k")
    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.ylim(bottom=-0.001, top=1)  # Zoom in on lower gradient regions
    plt.xlabel("Layers")
    plt.ylabel("Average Gradient")
    plt.title("Gradient Flow")
    plt.grid(True)
    plt.legend([
        Line2D([0], [0], color="c", lw=4),
        Line2D([0], [0], color="b", lw=4),
        Line2D([0], [0], color="k", lw=4)
    ], ['max-gradient', 'mean-gradient', 'zero-gradient'])
    # Save the plot to the specified directory
    plt.savefig(os.path.join(grad_flow_directory, "weights_" + model + "_" + str(itera) + "_" + str(epoch) + ".png"), dpi=500, bbox_inches='tight')
    # BUG FIX: close the figure after saving. The original left the figure
    # open, so each call drew on top of the previous bars (corrupting later
    # plots) and leaked a matplotlib figure per call.
    plt.close()
def get_mol(smiles_or_mol):
    """
    Load a SMILES string or an existing molecule into an RDKit Mol.

    Args:
        smiles_or_mol (str or RDKit Mol): SMILES string or RDKit molecule.

    Returns:
        RDKit Mol or None: Sanitized molecule, or None if parsing/sanitization fails.
    """
    # Anything that is not a string is assumed to already be a Mol.
    if not isinstance(smiles_or_mol, str):
        return smiles_or_mol
    if len(smiles_or_mol) == 0:
        return None
    parsed = Chem.MolFromSmiles(smiles_or_mol)
    if parsed is None:
        return None
    try:
        Chem.SanitizeMol(parsed)
    except ValueError:
        # Sanitization failed (e.g. valence error): treat as invalid.
        return None
    return parsed
def mapper(n_jobs):
    """
    Return a map-like function for serial or parallel processing.

    Args:
        n_jobs (int or pool object): 1 for serial map, >1 for a fresh
            multiprocessing pool, or an existing pool-like object.

    Returns:
        callable: A function with the signature of the built-in map
        that returns a list.
    """
    if n_jobs == 1:
        def _serial_map(*args, **kwargs):
            return list(map(*args, **kwargs))
        return _serial_map

    if isinstance(n_jobs, int):
        pool = Pool(n_jobs)

        def _pooled_map(*args, **kwargs):
            # Terminate the pool even if mapping raises.
            try:
                result = pool.map(*args, **kwargs)
            finally:
                pool.terminate()
            return result

        return _pooled_map

    # Otherwise assume a pool-like object was passed in directly.
    return n_jobs.map
def remove_invalid(gen, canonize=True, n_jobs=1):
    """
    Filter invalid molecules out of a list of SMILES strings.

    Args:
        gen (list): List of SMILES strings.
        canonize (bool): If True, also convert survivors to canonical SMILES.
        n_jobs (int): Number of parallel jobs.

    Returns:
        list: Valid SMILES (canonical if `canonize` is True, original otherwise).
    """
    if canonize:
        # canonic_smiles returns None for invalid input, so filtering Nones
        # removes invalid molecules and canonizes in one pass.
        return [smi for smi in mapper(n_jobs)(canonic_smiles, gen) if smi is not None]
    parsed = mapper(n_jobs)(get_mol, gen)
    return [smi for smi, mol in zip(gen, parsed) if mol is not None]
def fraction_valid(gen, n_jobs=1):
    """
    Compute the fraction of valid molecules in the dataset.

    Args:
        gen (list): List of SMILES strings.
        n_jobs (int): Number of parallel jobs.

    Returns:
        float: Fraction of molecules that are valid; 0 for an empty list.
    """
    # BUG FIX: the original divided by len(gen) unconditionally and raised
    # ZeroDivisionError when no molecules were generated.
    if len(gen) == 0:
        return 0
    mols = mapper(n_jobs)(get_mol, gen)
    # get_mol returns None for invalid molecules.
    return 1 - mols.count(None) / len(mols)
def canonic_smiles(smiles_or_mol):
    """
    Convert a SMILES string or molecule to its canonical SMILES.

    Args:
        smiles_or_mol (str or RDKit Mol): Input molecule.

    Returns:
        str or None: Canonical SMILES, or None if the input is invalid.
    """
    mol = get_mol(smiles_or_mol)
    return None if mol is None else Chem.MolToSmiles(mol)
def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
    """
    Compute the fraction of unique molecules.

    Optionally computes unique@k, where only the first k molecules are considered.

    Args:
        gen (list): List of SMILES strings.
        k (int): Optional cutoff for unique@k computation.
        n_jobs (int): Number of parallel jobs.
        check_validity (bool): If True, canonize and drop invalid molecules
            before measuring uniqueness; if False, compare raw strings.

    Returns:
        float: Fraction of unique molecules (0 for an empty list).
    """
    if k is not None:
        if len(gen) < k:
            warnings.warn("Can't compute unique@{}.".format(k) +
                          " gen contains only {} molecules".format(len(gen)))
        gen = gen[:k]
    if check_validity:
        canonic = list(mapper(n_jobs)(canonic_smiles, gen))
        canonic = [i for i in canonic if i is not None]
    else:
        # BUG FIX: the original had no check_validity=False branch and fell
        # through, implicitly returning None instead of a fraction.
        canonic = list(gen)
    return 0 if len(canonic) == 0 else len(set(canonic)) / len(canonic)
def novelty(gen, train, n_jobs=1):
    """
    Compute the novelty score of generated molecules.

    Novelty is the fraction of (canonical, valid) generated molecules that
    do not appear in the training set.

    Args:
        gen (list): Generated SMILES strings.
        train (list): Training SMILES strings.
        n_jobs (int): Number of parallel jobs.

    Returns:
        float: Novelty score (0 if no valid generated molecules).
    """
    canonical = mapper(n_jobs)(canonic_smiles, gen)
    # Invalid molecules canonize to None; exclude them.
    generated = set(canonical) - {None}
    if len(generated) == 0:
        return 0
    known = set(train)
    return len(generated - known) / len(generated)
def internal_diversity(gen):
    """
    Compute the internal diversity of a set of molecules.

    Internal diversity is one minus the average Tanimoto similarity between
    all pairs of fingerprints in `gen`.

    Args:
        gen: Array-like fingerprint representation of molecules.

    Returns:
        tuple: (mean, standard deviation) of the per-molecule diversity.
    """
    similarities = average_agg_tanimoto(gen, gen, agg="mean", intdiv=True)
    diversity = [1 - s for s in similarities]
    return np.mean(diversity), np.std(diversity)
def average_agg_tanimoto(stock_vecs, gen_vecs, batch_size=5000, agg='max', device='cpu', p=1, intdiv=False):
    """
    Compute the average aggregated Tanimoto similarity between two fingerprint sets.

    For each fingerprint in `gen_vecs`, aggregates (max or mean) its Tanimoto
    similarity against all fingerprints in `stock_vecs`, processing both sets
    in batches on the given device.

    Args:
        stock_vecs (numpy.ndarray): Reference fingerprint matrix.
        gen_vecs (numpy.ndarray): Generated fingerprint matrix.
        batch_size (int): Batch size for the pairwise computation.
        agg (str): Aggregation method, 'max' or 'mean'.
        device (str): Torch device for the matrix products.
        p (int): Power used for the generalized mean.
        intdiv (bool): If True, return the per-molecule scores instead of the mean.

    Returns:
        float or numpy.ndarray: Mean aggregated similarity, or the per-molecule array.
    """
    assert agg in ['max', 'mean'], "Can aggregate only max or mean"
    agg_tanimoto = np.zeros(len(gen_vecs))
    total = np.zeros(len(gen_vecs))

    for j in range(0, stock_vecs.shape[0], batch_size):
        x_stock = torch.tensor(stock_vecs[j:j + batch_size]).to(device).float()
        for i in range(0, gen_vecs.shape[0], batch_size):
            y_gen = torch.tensor(gen_vecs[i:i + batch_size]).to(device).float()
            y_gen = y_gen.transpose(0, 1)
            # Pairwise intersection counts (bit vectors -> dot products).
            tp = torch.mm(x_stock, y_gen)
            denom = x_stock.sum(1, keepdim=True) + y_gen.sum(0, keepdim=True) - tp
            jac = (tp / denom).cpu().numpy()
            # 0/0 (two all-zero fingerprints) is defined as similarity 1.
            jac[np.isnan(jac)] = 1
            if p != 1:
                jac = jac ** p
            width = y_gen.shape[1]
            if agg == 'max':
                agg_tanimoto[i:i + width] = np.maximum(agg_tanimoto[i:i + width], jac.max(0))
            else:  # agg == 'mean'
                agg_tanimoto[i:i + width] += jac.sum(0)
                total[i:i + width] += jac.shape[0]

    if agg == 'mean':
        agg_tanimoto /= total
    if p != 1:
        agg_tanimoto = agg_tanimoto ** (1 / p)
    return agg_tanimoto if intdiv else np.mean(agg_tanimoto)
def str2bool(v):
    """
    Convert a string to a boolean.

    Args:
        v (str): Input string.

    Returns:
        bool: True if the string equals 'true' (case insensitive), else False.
    """
    # BUG FIX: the original used `v.lower() in ('true')`, which is substring
    # membership in the *string* 'true' — so 't', 'ru', 'e', etc. were truthy.
    return v.lower() == 'true'
def obey_lipinski(mol):
    """
    Count how many of Lipinski's Rule-of-Five criteria a molecule satisfies.

    Evaluates molecular weight, H-bond donors/acceptors, logP range, and
    rotatable bond count on a sanitized copy of the input.

    Args:
        mol (RDKit Mol): Molecule object.

    Returns:
        int: Number of Lipinski rules satisfied (0-5).
    """
    mol = deepcopy(mol)
    Chem.SanitizeMol(mol)
    rule_1 = Descriptors.ExactMolWt(mol) < 500
    rule_2 = Lipinski.NumHDonors(mol) <= 5
    rule_3 = Lipinski.NumHAcceptors(mol) <= 10
    # BUG FIX: the original walrus bound the *comparison result* (a bool),
    # i.e. `logp := MolLogP(mol) >= -2`, so `logp <= 5` compared a bool and
    # was always True; the upper logP bound was never enforced.
    logp = Crippen.MolLogP(mol)
    rule_4 = -2 <= logp <= 5
    rule_5 = Chem.rdMolDescriptors.CalcNumRotatableBonds(mol) <= 10
    return np.sum([int(a) for a in [rule_1, rule_2, rule_3, rule_4, rule_5]])
def obey_veber(mol):
    """
    Count how many of Veber's rules a molecule satisfies.

    Veber's rules limit the number of rotatable bonds and the topological
    polar surface area (TPSA).

    Args:
        mol (RDKit Mol): Molecule object.

    Returns:
        int: Number of Veber's rules satisfied (0-2).
    """
    mol = deepcopy(mol)
    Chem.SanitizeMol(mol)
    checks = [
        rdMolDescriptors.CalcNumRotatableBonds(mol) <= 10,
        rdMolDescriptors.CalcTPSA(mol) <= 140,
    ]
    return np.sum([int(c) for c in checks])
def load_pains_filters():
    """
    Load the PAINS (Pan-Assay INterference compoundS) filters A, B, and C.

    Returns:
        FilterCatalog: An RDKit FilterCatalog containing all three PAINS sets.
    """
    params = FilterCatalog.FilterCatalogParams()
    catalogs = FilterCatalog.FilterCatalogParams.FilterCatalogs
    for pains_set in (catalogs.PAINS_A, catalogs.PAINS_B, catalogs.PAINS_C):
        params.AddCatalog(pains_set)
    return FilterCatalog.FilterCatalog(params)
def is_pains(mol, catalog):
    """
    Check whether a molecule matches any PAINS filter.

    Args:
        mol (RDKit Mol): Molecule object.
        catalog (FilterCatalog): Catalog of PAINS filters.

    Returns:
        bool: True if the molecule matches at least one PAINS filter.
    """
    return catalog.GetFirstMatch(mol) is not None
def mapper(n_jobs):
    """
    Return a map-like callable for serial or parallel processing.

    NOTE(review): this is a duplicate of the `mapper` defined earlier in the
    module; the later definition wins at import time. Consider deduplicating.

    Args:
        n_jobs (int or pool object): 1 for serial map, an int > 1 for a fresh
            multiprocessing pool, or an existing pool-like object.

    Returns:
        callable: A function behaving like map that returns a list.
    """
    if n_jobs == 1:
        def _map_serial(*call_args, **call_kwargs):
            return list(map(*call_args, **call_kwargs))
        return _map_serial

    if isinstance(n_jobs, int):
        worker_pool = Pool(n_jobs)

        def _map_parallel(*call_args, **call_kwargs):
            try:
                mapped = worker_pool.map(*call_args, **call_kwargs)
            finally:
                # Always tear the pool down, even on error.
                worker_pool.terminate()
            return mapped

        return _map_parallel

    # Fall back to the map method of a user-supplied pool object.
    return n_jobs.map
def fragmenter(mol):
    """
    Fragment a molecule on BRICS bonds and return the fragment SMILES.

    Args:
        mol (str or RDKit Mol): Input molecule.

    Returns:
        list: Fragment SMILES strings (dot-separated components of the result).
    """
    fragmented = AllChem.FragmentOnBRICSBonds(get_mol(mol))
    return Chem.MolToSmiles(fragmented).split(".")
def get_mol(smiles_or_mol):
    """
    Load a SMILES string or existing molecule into an RDKit Mol.

    NOTE(review): duplicate of the `get_mol` defined earlier in the module;
    the later definition wins at import time. Consider deduplicating.

    Args:
        smiles_or_mol (str or RDKit Mol): SMILES string or molecule.

    Returns:
        RDKit Mol or None: Sanitized molecule, or None if invalid.
    """
    if not isinstance(smiles_or_mol, str):
        # Already a molecule object: pass through unchanged.
        return smiles_or_mol
    if len(smiles_or_mol) == 0:
        return None
    candidate = Chem.MolFromSmiles(smiles_or_mol)
    if candidate is None:
        return None
    try:
        Chem.SanitizeMol(candidate)
    except ValueError:
        return None
    return candidate
def compute_fragments(mol_list, n_jobs=1):
    """
    BRICS-fragment a list of molecules and count fragment occurrences.

    Args:
        mol_list (list): Molecules (SMILES strings or RDKit Mols).
        n_jobs (int): Number of parallel jobs.

    Returns:
        Counter: Mapping from fragment SMILES to occurrence count.
    """
    fragment_counts = Counter()
    for fragment_list in mapper(n_jobs)(fragmenter, mol_list):
        fragment_counts.update(fragment_list)
    return fragment_counts
def compute_scaffolds(mol_list, n_jobs=1, min_rings=2):
    """
    Extract Murcko scaffolds from a list of molecules as canonical SMILES.

    Only scaffolds with at least `min_rings` rings are kept; molecules whose
    scaffold is invalid or too small are dropped.

    Args:
        mol_list (list): List of molecules.
        n_jobs (int): Number of parallel jobs.
        min_rings (int): Minimum number of rings required in a scaffold.

    Returns:
        Counter: Mapping from scaffold SMILES to occurrence count.
    """
    # (Removed a dead `scaffolds = Counter()` assignment that was
    # immediately overwritten in the original.)
    map_ = mapper(n_jobs)
    scaffolds = Counter(map_(partial(compute_scaffold, min_rings=min_rings), mol_list))
    # compute_scaffold returns None for invalid/too-small scaffolds.
    if None in scaffolds:
        scaffolds.pop(None)
    return scaffolds
def get_n_rings(mol):
    """
    Return the number of rings in a molecule.

    Args:
        mol (RDKit Mol): Molecule object.

    Returns:
        int: Ring count from the molecule's ring info.
    """
    ring_info = mol.GetRingInfo()
    return ring_info.NumRings()
def compute_scaffold(mol, min_rings=2):
    """
    Compute the Murcko scaffold of a molecule as canonical SMILES.

    Args:
        mol (str or RDKit Mol): Input molecule.
        min_rings (int): Minimum number of rings the scaffold must contain.

    Returns:
        str or None: Canonical scaffold SMILES, or None if the scaffold is
        empty, invalid, or has fewer than `min_rings` rings.
    """
    molecule = get_mol(mol)
    try:
        scaffold = MurckoScaffold.GetScaffoldForMol(molecule)
    except (ValueError, RuntimeError):
        # Scaffold extraction can fail on exotic structures.
        return None
    scaffold_smiles = Chem.MolToSmiles(scaffold)
    if scaffold_smiles == '' or get_n_rings(scaffold) < min_rings:
        return None
    return scaffold_smiles
class Metric:
    """
    Abstract base class for chemical metrics.

    Subclasses implement `precalc` (turn a molecule list into an intermediate
    representation) and `metric` (score two such representations).
    """

    def __init__(self, n_jobs=1, device='cpu', batch_size=512, **kwargs):
        self.n_jobs = n_jobs
        self.device = device
        self.batch_size = batch_size
        # Any extra keyword arguments become instance attributes.
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __call__(self, ref=None, gen=None, pref=None, pgen=None):
        """
        Compute the metric between reference and generated molecules.

        Exactly one of (ref, pref) and one of (gen, pgen) must be given;
        raw lists are passed through `precalc` first.

        Args:
            ref: Reference molecule list.
            gen: Generated molecule list.
            pref: Precalculated reference representation.
            pgen: Precalculated generated representation.

        Returns:
            The value produced by `self.metric`.
        """
        assert (ref is None) != (pref is None), "specify ref xor pref"
        assert (gen is None) != (pgen is None), "specify gen xor pgen"
        pref = self.precalc(ref) if pref is None else pref
        pgen = self.precalc(gen) if pgen is None else pgen
        return self.metric(pref, pgen)

    def precalc(self, molecules):
        """Precompute a representation from a molecule list (subclass hook)."""
        raise NotImplementedError

    def metric(self, pref, pgen):
        """Compute the metric from precalculated representations (subclass hook)."""
        raise NotImplementedError
class FragMetric(Metric):
    """Cosine-similarity metric over BRICS fragment count vectors."""

    def precalc(self, mols):
        fragment_counts = compute_fragments(mols, n_jobs=self.n_jobs)
        return {'frag': fragment_counts}

    def metric(self, pref, pgen):
        # Compare the two fragment-count vectors by cosine similarity.
        return cos_similarity(pref['frag'], pgen['frag'])
class ScafMetric(Metric):
    """Cosine-similarity metric over Murcko scaffold count vectors."""

    def precalc(self, mols):
        scaffold_counts = compute_scaffolds(mols, n_jobs=self.n_jobs)
        return {'scaf': scaffold_counts}

    def metric(self, pref, pgen):
        # Compare the two scaffold-count vectors by cosine similarity.
        return cos_similarity(pref['scaf'], pgen['scaf'])
def cos_similarity(ref_counts, gen_counts):
    """
    Compute the cosine similarity between two sparse count vectors.

    The vectors are given as mappings (e.g. Counters) from key to count;
    they are aligned over the union of their keys before comparison.

    Args:
        ref_counts (dict): Reference count vector.
        gen_counts (dict): Generated count vector.

    Returns:
        float: Cosine similarity, or NaN if either vector is empty.
    """
    if len(ref_counts) == 0 or len(gen_counts) == 0:
        return np.nan
    keys = np.unique(list(ref_counts.keys()) + list(gen_counts.keys()))
    ref_vec = np.array([ref_counts.get(key, 0) for key in keys])
    gen_vec = np.array([gen_counts.get(key, 0) for key in keys])
    # cos_distance is 1 - cosine similarity, so invert it back.
    return 1 - cos_distance(ref_vec, gen_vec)
train.py
ADDED
|
@@ -0,0 +1,462 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import random
|
| 4 |
+
import pickle
|
| 5 |
+
import argparse
|
| 6 |
+
import os.path as osp
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.utils.data
|
| 10 |
+
from torch import nn
|
| 11 |
+
from torch_geometric.loader import DataLoader
|
| 12 |
+
|
| 13 |
+
import wandb
|
| 14 |
+
from rdkit import RDLogger
|
| 15 |
+
|
| 16 |
+
torch.set_num_threads(5)
|
| 17 |
+
RDLogger.DisableLog('rdApp.*')
|
| 18 |
+
|
| 19 |
+
from src.util.utils import *
|
| 20 |
+
from src.model.models import Generator, Discriminator, simple_disc
|
| 21 |
+
from src.data.dataset import DruggenDataset
|
| 22 |
+
from src.data.utils import get_encoders_decoders, load_molecules
|
| 23 |
+
from src.model.loss import discriminator_loss, generator_loss
|
| 24 |
+
|
| 25 |
+
class Train(object):
    """Trainer for DrugGEN.

    Wires together the two PyG datasets (small molecules plus reference
    drugs), their loaders, and all hyper-parameters taken from ``config``,
    then builds the generator/discriminator pair via :meth:`build_model`.
    """

    def __init__(self, config):
        # Optional full determinism: seed python/numpy/torch and force
        # deterministic cuDNN kernels (slower, but reproducible).
        if config.set_seed:
            np.random.seed(config.seed)
            random.seed(config.seed)
            torch.manual_seed(config.seed)
            torch.cuda.manual_seed_all(config.seed)

            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

            os.environ["PYTHONHASHSEED"] = str(config.seed)

            print(f'Using seed {config.seed}')

        self.device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

        # Initialize configurations
        self.submodel = config.submodel

        # Data loader.
        self.raw_file = config.raw_file  # SMILES containing text file for dataset.
        # Write the full path to file.
        self.drug_raw_file = config.drug_raw_file  # SMILES containing text file for second dataset.
        # Write the full path to file.

        # Automatically infer dataset file names from raw file names
        raw_file_basename = osp.basename(self.raw_file)
        drug_raw_file_basename = osp.basename(self.drug_raw_file)

        # Get the base name without extension and add max_atom to it
        self.max_atom = config.max_atom  # Model is based on one-shot generation.
        raw_file_base = os.path.splitext(raw_file_basename)[0]
        drug_raw_file_base = os.path.splitext(drug_raw_file_basename)[0]

        # Change extension from .smi to .pt and add max_atom to the filename,
        # e.g. "chembl.smi" with max_atom=45 -> "chembl45.pt".
        self.dataset_file = f"{raw_file_base}{self.max_atom}.pt"
        self.drugs_dataset_file = f"{drug_raw_file_base}{self.max_atom}.pt"

        self.mol_data_dir = config.mol_data_dir  # Directory where the dataset files are stored.
        self.drug_data_dir = config.drug_data_dir  # Directory where the drug dataset files are stored.
        # NOTE(review): split(".") keeps only the part before the FIRST dot, so
        # a raw file named e.g. "chembl.v2.smi" would truncate the dataset
        # name — confirm input naming convention.
        self.dataset_name = self.dataset_file.split(".")[0]
        self.drugs_dataset_name = self.drugs_dataset_file.split(".")[0]
        self.features = config.features  # Small model uses atom types as node features. (Boolean, False uses atom types only.)
        # Additional node features can be added. Please check new_dataloarder.py Line 102.
        self.batch_size = config.batch_size  # Batch size for training.

        self.parallel = config.parallel

        # Atom/bond encoders-decoders are built from BOTH raw files so the two
        # datasets below share one vocabulary.
        atom_encoder, atom_decoder, bond_encoder, bond_decoder = get_encoders_decoders(
            self.raw_file,
            self.drug_raw_file,
            self.max_atom
        )
        self.atom_encoder = atom_encoder
        self.atom_decoder = atom_decoder
        self.bond_encoder = bond_encoder
        self.bond_decoder = bond_decoder

        self.dataset = DruggenDataset(self.mol_data_dir,
                                      self.dataset_file,
                                      self.raw_file,
                                      self.max_atom,
                                      self.features,
                                      atom_encoder=atom_encoder,
                                      atom_decoder=atom_decoder,
                                      bond_encoder=bond_encoder,
                                      bond_decoder=bond_decoder)

        self.loader = DataLoader(self.dataset,
                                 shuffle=True,
                                 batch_size=self.batch_size,
                                 drop_last=True)  # PyG dataloader for the GAN.

        self.drugs = DruggenDataset(self.drug_data_dir,
                                    self.drugs_dataset_file,
                                    self.drug_raw_file,
                                    self.max_atom,
                                    self.features,
                                    atom_encoder=atom_encoder,
                                    atom_decoder=atom_decoder,
                                    bond_encoder=bond_encoder,
                                    bond_decoder=bond_decoder)

        self.drugs_loader = DataLoader(self.drugs,
                                       shuffle=True,
                                       batch_size=self.batch_size,
                                       drop_last=True)  # PyG dataloader for the second GAN.

        # Dimensions are read off the first sample of the processed dataset.
        self.m_dim = len(self.atom_decoder) if not self.features else int(self.loader.dataset[0].x.shape[1])  # Atom type dimension.
        self.b_dim = len(self.bond_decoder)  # Bond type dimension.
        self.vertexes = int(self.loader.dataset[0].x.shape[0])  # Number of nodes in the graph.

        # Model configurations.
        self.act = config.act
        self.lambda_gp = config.lambda_gp
        self.dim = config.dim
        self.depth = config.depth
        self.heads = config.heads
        self.mlp_ratio = config.mlp_ratio
        self.ddepth = config.ddepth        # discriminator transformer depth
        self.ddropout = config.ddropout    # discriminator dropout

        # Training configurations.
        self.epoch = config.epoch
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.dropout = config.dropout
        self.beta1 = config.beta1
        self.beta2 = config.beta2

        # Directories.
        self.log_dir = config.log_dir
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir

        # Step size.
        self.log_step = config.log_sample_step

        # resume training
        self.resume = config.resume
        self.resume_epoch = config.resume_epoch
        self.resume_iter = config.resume_iter
        self.resume_directory = config.resume_directory

        # wandb configuration
        self.use_wandb = config.use_wandb
        self.online = config.online
        self.exp_name = config.exp_name

        # Arguments for the model: a single string encoding the run's
        # hyper-parameters, used as directory/file name for all outputs.
        self.arguments = "{}_{}_glr{}_dlr{}_dim{}_depth{}_heads{}_batch{}_epoch{}_dataset{}_dropout{}".format(self.exp_name, self.submodel, self.g_lr, self.d_lr, self.dim, self.depth, self.heads, self.batch_size, self.epoch, self.dataset_name, self.dropout)

        self.build_model(self.model_save_dir, self.arguments)
| 162 |
+
|
| 163 |
+
|
| 164 |
+
    def build_model(self, model_save_dir, arguments):
        """Create generators and discriminators.

        Instantiates ``self.G``/``self.D`` and their AdamW optimizers, writes
        a module summary for each network under ``model_save_dir/arguments``,
        then moves both networks to ``self.device`` (optionally wrapped in
        ``nn.DataParallel`` when ``self.parallel`` is set and several GPUs
        are visible).
        """

        ''' Generator is based on Transformer Encoder:

            @ g_conv_dim: Dimensions for MLP layers before Transformer Encoder
            @ vertexes: maximum length of generated molecules (atom length)
            @ b_dim: number of bond types
            @ m_dim: number of atom types (or number of features used)
            @ dropout: dropout possibility
            @ dim: Hidden dimension of Transformer Encoder
            @ depth: Transformer layer number
            @ heads: Number of multihead-attention heads
            @ mlp_ratio: Read-out layer dimension of Transformer
            @ drop_rate: depricated
            @ tra_conv: Whether module creates output for TransformerConv discriminator
        '''
        self.G = Generator(self.act,
                           self.vertexes,
                           self.b_dim,
                           self.m_dim,
                           self.dropout,
                           dim=self.dim,
                           depth=self.depth,
                           heads=self.heads,
                           mlp_ratio=self.mlp_ratio)

        ''' Discriminator implementation with Transformer Encoder:

            @ act: Activation function for MLP
            @ vertexes: maximum length of generated molecules (molecule length)
            @ b_dim: number of bond types
            @ m_dim: number of atom types (or number of features used)
            @ dropout: dropout possibility
            @ dim: Hidden dimension of Transformer Encoder
            @ depth: Transformer layer number
            @ heads: Number of multihead-attention heads
            @ mlp_ratio: Read-out layer dimension of Transformer'''

        # Note: the discriminator deliberately uses its own depth/dropout
        # (ddepth/ddropout) while sharing dim/heads/mlp_ratio with G.
        self.D = Discriminator(self.act,
                               self.vertexes,
                               self.b_dim,
                               self.m_dim,
                               self.ddropout,
                               dim=self.dim,
                               depth=self.ddepth,
                               heads=self.heads,
                               mlp_ratio=self.mlp_ratio)

        # Separate optimizers (and learning rates) for the adversarial pair.
        self.g_optimizer = torch.optim.AdamW(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.AdamW(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])

        # Dump per-network module/parameter listings next to the checkpoints.
        network_path = os.path.join(model_save_dir, arguments)
        self.print_network(self.G, 'G', network_path)
        self.print_network(self.D, 'D', network_path)

        # DataParallel wrapping happens after the summaries are written, so
        # the files describe the bare modules.
        if self.parallel and torch.cuda.device_count() > 1:
            print(f"Using {torch.cuda.device_count()} GPUs!")
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)

        self.G.to(self.device)
        self.D.to(self.device)
| 227 |
+
|
| 228 |
+
def print_network(self, model, name, save_dir):
|
| 229 |
+
"""Print out the network information."""
|
| 230 |
+
num_params = 0
|
| 231 |
+
for p in model.parameters():
|
| 232 |
+
num_params += p.numel()
|
| 233 |
+
|
| 234 |
+
if not os.path.exists(save_dir):
|
| 235 |
+
os.makedirs(save_dir)
|
| 236 |
+
|
| 237 |
+
network_path = os.path.join(save_dir, "{}_modules.txt".format(name))
|
| 238 |
+
with open(network_path, "w+") as file:
|
| 239 |
+
for module in model.modules():
|
| 240 |
+
file.write(f"{module.__class__.__name__}:\n")
|
| 241 |
+
print(module.__class__.__name__)
|
| 242 |
+
for n, param in module.named_parameters():
|
| 243 |
+
if param is not None:
|
| 244 |
+
file.write(f" - {n}: {param.size()}\n")
|
| 245 |
+
print(f" - {n}: {param.size()}")
|
| 246 |
+
break
|
| 247 |
+
file.write(f"Total number of parameters: {num_params}\n")
|
| 248 |
+
print(f"Total number of parameters: {num_params}\n\n")
|
| 249 |
+
|
| 250 |
+
def restore_model(self, epoch, iteration, model_directory):
|
| 251 |
+
"""Restore the trained generator and discriminator."""
|
| 252 |
+
print('Loading the trained models from epoch / iteration {}-{}...'.format(epoch, iteration))
|
| 253 |
+
|
| 254 |
+
G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(epoch, iteration))
|
| 255 |
+
D_path = os.path.join(model_directory, '{}-{}-D.ckpt'.format(epoch, iteration))
|
| 256 |
+
self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
|
| 257 |
+
self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
|
| 258 |
+
|
| 259 |
+
def save_model(self, model_directory, idx,i):
|
| 260 |
+
G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(idx+1,i+1))
|
| 261 |
+
D_path = os.path.join(model_directory, '{}-{}-D.ckpt'.format(idx+1,i+1))
|
| 262 |
+
torch.save(self.G.state_dict(), G_path)
|
| 263 |
+
torch.save(self.D.state_dict(), D_path)
|
| 264 |
+
|
| 265 |
+
def reset_grad(self):
|
| 266 |
+
"""Reset the gradient buffers."""
|
| 267 |
+
self.g_optimizer.zero_grad()
|
| 268 |
+
self.d_optimizer.zero_grad()
|
| 269 |
+
|
| 270 |
+
    def train(self, config):
        """Run the adversarial training loop.

        For each epoch, iterates the molecule loader in lockstep with the
        (re-wound) drug loader, performing one discriminator update and one
        generator update per batch; every ``self.log_step`` iterations it
        logs metrics, saves sampled molecules, and checkpoints the models.
        """
        # wandb runs in 'disabled' mode unless --use_wandb was given, so the
        # wandb.log/save calls below become no-ops in that case.
        if self.use_wandb:
            mode = 'online' if self.online else 'offline'
        else:
            mode = 'disabled'
        kwargs = {'name': self.exp_name, 'project': 'druggen', 'config': config,
                  'settings': wandb.Settings(_disable_stats=True), 'reinit': True, 'mode': mode, 'save_code': True}
        wandb.init(**kwargs)

        # Upload the module summaries written by build_model/print_network.
        wandb.save(os.path.join(self.model_save_dir, self.arguments, "G_modules.txt"))
        wandb.save(os.path.join(self.model_save_dir, self.arguments, "D_modules.txt"))

        # Per-run output locations, keyed by the hyper-parameter string.
        self.model_directory = os.path.join(self.model_save_dir, self.arguments)
        self.sample_directory = os.path.join(self.sample_dir, self.arguments)
        self.log_path = os.path.join(self.log_dir, "{}.txt".format(self.arguments))
        if not os.path.exists(self.model_directory):
            os.makedirs(self.model_directory)
        if not os.path.exists(self.sample_directory):
            os.makedirs(self.sample_directory)

        # smiles data for metrics calculation.
        # NOTE(review): the file handle is never closed, and Chem/AllChem are
        # presumably the rdkit modules brought in via
        # ``from src.util.utils import *`` — confirm against that module.
        drug_smiles = [line for line in open(self.drug_raw_file, 'r').read().splitlines()]
        drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
        drug_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in drug_mols if x is not None]

        if self.resume:
            self.restore_model(self.resume_epoch, self.resume_iter, self.resume_directory)

        # Start training.
        print('Start training...')
        self.start_time = time.time()
        for idx in range(self.epoch):
            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #
            # Load the data. The drug iterator is re-created on StopIteration
            # below, so the drug dataset may be smaller than the molecule one.
            dataloader_iterator = iter(self.drugs_loader)

            wandb.log({"epoch": idx})

            for i, data in enumerate(self.loader):
                try:
                    drugs = next(dataloader_iterator)
                except StopIteration:
                    dataloader_iterator = iter(self.drugs_loader)
                    drugs = next(dataloader_iterator)

                wandb.log({"iter": i})

                # Preprocess both dataset
                real_graphs, a_tensor, x_tensor = load_molecules(
                    data=data,
                    batch_size=self.batch_size,
                    device=self.device,
                    b_dim=self.b_dim,
                    m_dim=self.m_dim,
                )

                drug_graphs, drugs_a_tensor, drugs_x_tensor = load_molecules(
                    data=drugs,
                    batch_size=self.batch_size,
                    device=self.device,
                    b_dim=self.b_dim,
                    m_dim=self.m_dim,
                )

                # Training configuration.
                GEN_node = x_tensor  # Generator input node features (annotation matrix of real molecules)
                GEN_edge = a_tensor  # Generator input edge features (adjacency matrix of real molecules)
                # NOTE(review): DISC_node/DISC_edge would stay undefined for any
                # other submodel value; argparse limits choices to these two.
                if self.submodel == "DrugGEN":
                    DISC_node = drugs_x_tensor  # Discriminator input node features (annotation matrix of drug molecules)
                    DISC_edge = drugs_a_tensor  # Discriminator input edge features (adjacency matrix of drug molecules)
                elif self.submodel == "NoTarget":
                    DISC_node = x_tensor  # Discriminator input node features (annotation matrix of real molecules)
                    DISC_edge = a_tensor  # Discriminator input edge features (adjacency matrix of real molecules)

                # =================================================================================== #
                #                             2. Train the GAN                                        #
                # =================================================================================== #

                loss = {}
                self.reset_grad()
                # Compute discriminator loss.
                node, edge, d_loss = discriminator_loss(self.G,
                                                        self.D,
                                                        DISC_edge,
                                                        DISC_node,
                                                        GEN_edge,
                                                        GEN_node,
                                                        self.batch_size,
                                                        self.device,
                                                        self.lambda_gp)
                d_total = d_loss
                wandb.log({"d_loss": d_total.item()})

                loss["d_total"] = d_total.item()
                d_total.backward()
                self.d_optimizer.step()

                # Gradients from the D step must not leak into the G step.
                self.reset_grad()

                # Compute generator loss.
                generator_output = generator_loss(self.G,
                                                  self.D,
                                                  GEN_edge,
                                                  GEN_node,
                                                  self.batch_size)
                g_loss, node, edge, node_sample, edge_sample = generator_output
                g_total = g_loss
                wandb.log({"g_loss": g_total.item()})

                loss["g_total"] = g_total.item()
                g_total.backward()
                self.g_optimizer.step()

                # Logging. Samples and checkpoints are emitted together,
                # every log_step iterations.
                if (i+1) % self.log_step == 0:
                    logging(self.log_path, self.start_time, i, idx, loss, self.sample_directory,
                            drug_smiles,edge_sample, node_sample, self.dataset.matrices2mol,
                            self.dataset_name, a_tensor, x_tensor, drug_vecs)

                    mol_sample(self.sample_directory, edge_sample.detach(), node_sample.detach(),
                               idx, i, self.dataset.matrices2mol, self.dataset_name)
                    print("samples saved at epoch {} and iteration {}".format(idx,i))

                    self.save_model(self.model_directory, idx, i)
                    print("model saved at epoch {} and iteration {}".format(idx,i))
| 398 |
+
|
| 399 |
+
|
| 400 |
+
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    def _str2bool(value):
        """Lenient boolean converter for CLI flags.

        ``type=bool`` would treat ANY non-empty string (including "False")
        as True; this converter accepts the usual true/false spellings and
        rejects anything else.
        """
        if isinstance(value, bool):
            return value
        lowered = value.lower()
        if lowered in ('yes', 'true', 't', 'y', '1'):
            return True
        if lowered in ('no', 'false', 'f', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('Boolean value expected, got {!r}'.format(value))

    # Data configuration.
    parser.add_argument('--raw_file', type=str, required=True)
    parser.add_argument('--drug_raw_file', type=str, required=False, help='Required for DrugGEN model, optional for NoTarget')
    parser.add_argument('--drug_data_dir', type=str, default='data')
    parser.add_argument('--mol_data_dir', type=str, default='data')
    parser.add_argument('--features', action='store_true', help='features dimension for nodes')

    # Model configuration.
    parser.add_argument('--submodel', type=str, default="DrugGEN", help="Chose model subtype: DrugGEN, NoTarget", choices=['DrugGEN', 'NoTarget'])
    parser.add_argument('--act', type=str, default="relu", help="Activation function for the model.", choices=['relu', 'tanh', 'leaky', 'sigmoid'])
    parser.add_argument('--max_atom', type=int, default=45, help='Max atom number for molecules must be specified.')
    parser.add_argument('--dim', type=int, default=128, help='Dimension of the Transformer Encoder model for the GAN.')
    parser.add_argument('--depth', type=int, default=1, help='Depth of the Transformer model from the GAN.')
    parser.add_argument('--ddepth', type=int, default=1, help='Depth of the Transformer model from the discriminator.')
    parser.add_argument('--heads', type=int, default=8, help='Number of heads for the MultiHeadAttention module from the GAN.')
    parser.add_argument('--mlp_ratio', type=int, default=3, help='MLP ratio for the Transformer.')
    parser.add_argument('--dropout', type=float, default=0., help='dropout rate')
    parser.add_argument('--ddropout', type=float, default=0., help='dropout rate for the discriminator')
    parser.add_argument('--lambda_gp', type=float, default=10, help='Gradient penalty lambda multiplier for the GAN.')

    # Training configuration.
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size for the training.')
    parser.add_argument('--epoch', type=int, default=10, help='Epoch number for Training.')
    parser.add_argument('--g_lr', type=float, default=0.00001, help='learning rate for G')
    parser.add_argument('--d_lr', type=float, default=0.00001, help='learning rate for D')
    parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for Adam optimizer')
    parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer')
    parser.add_argument('--log_dir', type=str, default='experiments/logs')
    parser.add_argument('--sample_dir', type=str, default='experiments/samples')
    parser.add_argument('--model_save_dir', type=str, default='experiments/models')
    parser.add_argument('--log_sample_step', type=int, default=1000, help='step size for sampling during training')

    # Resume training.
    # BUGFIX: this flag used ``type=bool``, which parses "--resume False" as
    # True (bool() of any non-empty string). _str2bool keeps "--resume True"
    # working and fixes the false case; bare "--resume" also means True.
    parser.add_argument('--resume', type=_str2bool, nargs='?', const=True, default=False, help='resume training')
    parser.add_argument('--resume_epoch', type=int, default=None, help='resume training from this epoch')
    parser.add_argument('--resume_iter', type=int, default=None, help='resume training from this step')
    parser.add_argument('--resume_directory', type=str, default=None, help='load pretrained weights from this directory')

    # Seed configuration.
    parser.add_argument('--set_seed', action='store_true', help='set seed for reproducibility')
    parser.add_argument('--seed', type=int, default=1, help='seed for reproducibility')

    # wandb configuration.
    parser.add_argument('--use_wandb', action='store_true', help='use wandb for logging')
    parser.add_argument('--online', action='store_true', help='use wandb online')
    parser.add_argument('--exp_name', type=str, default='druggen', help='experiment name')
    parser.add_argument('--parallel', action='store_true', help='Parallelize training')

    config = parser.parse_args()

    # Check if drug_raw_file is provided when using DrugGEN model
    if config.submodel == "DrugGEN" and not config.drug_raw_file:
        parser.error("--drug_raw_file is required when using DrugGEN model")

    # If using NoTarget model and drug_raw_file is not provided, use a dummy file
    if config.submodel == "NoTarget" and not config.drug_raw_file:
        config.drug_raw_file = "data/akt_train.smi"  # Use a reference file for NoTarget model (AKT) (not used for training for ease of use and encoder/decoder's)

    trainer = Train(config)
    trainer.train(config)